In [None]:
import os
import json
import time

import wandb
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from PIL import Image
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import clip_grad_norm_
#from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from transformers import Qwen2_5_VLForConditionalGeneration,  AutoProcessor
from datasets import Dataset as DF


from utils import set_seed, load_dataset, get_device, load_system_prompt, grouped_paths

In [None]:
# Resolved once when this cell runs; create_model and both epoch loops read this global.
device = get_device()


def create_model(model_id, cache_dir, use_flash_attention):
    """Load a Qwen2.5-VL conditional-generation model onto ``device``.

    Args:
        model_id: Hugging Face model id or local path handed to ``from_pretrained``.
        cache_dir: directory used by ``from_pretrained`` to cache downloaded weights.
        use_flash_attention: if true, load the weights in bfloat16 and request the
            ``flash_attention_2`` implementation; otherwise use ``torch_dtype="auto"``
            and the library's default attention implementation.

    Returns:
        The loaded ``Qwen2_5_VLForConditionalGeneration`` instance.
    """
    # NOTE(review): the flash-attention branch originally carried a Norwegian reminder
    # ("remember to change back"); it is unclear which setting it referred to.
    load_kwargs = {"cache_dir": cache_dir, "device_map": device}
    if use_flash_attention:
        load_kwargs["torch_dtype"] = torch.bfloat16
        load_kwargs["attn_implementation"] = "flash_attention_2"
    else:
        load_kwargs["torch_dtype"] = "auto"
    return Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **load_kwargs)


class CustomDataset(Dataset):
    """Map-style dataset over the columns produced by ``format_prompts_v5``.

    ``data`` must support column access by name (``data["text"]`` etc.), e.g. a
    ``datasets.Dataset`` or a plain dict of equal-length lists.  Each item is the
    tuple ``(text, panoramas, candidates, label, path_id, step_id)``, which is the
    order ``CollateFunctor`` unpacks.
    """

    def __init__(self, data):
        # Materialise every column once so __getitem__ is plain indexing.
        for column in ("text", "labels", "panoramas", "candidates", "path_id", "step_id"):
            setattr(self, column, data[column])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        ordered_columns = (
            self.text,
            self.panoramas,
            self.candidates,
            self.labels,
            self.path_id,
            self.step_id,
        )
        return tuple(column[index] for column in ordered_columns)


# TODO: make the collate functor support batch sizes larger than 1 (needs padding).
class CollateFunctor:
    """Turn one formatted sample into Qwen2.5-VL inputs with a masked LM target.

    Each sample is the tuple yielded by ``CustomDataset.__getitem__``:
    ``(text, panoramas, candidates, gold_candidate, path_id, step_id)``.  The chat
    ``text`` is tokenised together with the panorama and candidate images, the
    assistant prefix ``"<|im_start|>assistant\\nCandidate: "`` is appended, and the
    gold candidate tokens follow.  Only the gold candidate tokens carry labels;
    every prompt position is set to -100 so the loss ignores it.

    Only batches of exactly one sample are supported, so no padding is done.
    """

    def __init__(self, processor, width, height, panorama_size=(960, 240)):
        """
        Args:
            processor: the model's ``AutoProcessor`` (must expose ``.tokenizer``).
            width, height: size each candidate-view image is resized to.
            panorama_size: (width, height) each panorama is resized to.  The
                original note mentions 1440 x 360, presumably the source resolution.
        """
        self.processor = processor
        self.width = width
        self.height = height
        self.panorama_size = tuple(panorama_size)

    @staticmethod
    def _load_resized(path, size):
        """Return a LANCZOS-resized copy of the image at ``path``, closing the file."""
        with Image.open(path) as img:
            # resize() loads the pixels and returns a new image, so the result
            # stays valid after the source file is closed.
            return img.resize(size, Image.Resampling.LANCZOS)

    def __call__(self, batch):
        if len(batch) != 1:
            raise ValueError(
                f"CollateFunctor only supports batch_size=1, got a batch of {len(batch)} samples"
            )
        text, panoramas, candidates, gold_candidate, path_id, step_id = batch[0]

        tokenizer = self.processor.tokenizer
        candidate_tokens = tokenizer(f"{gold_candidate}", return_tensors="pt").input_ids
        # Explicitly open the assistant turn and add "Candidate: " so the target is
        # only the candidate identifier itself.
        label_start = tokenizer("<|im_start|>assistant\nCandidate: ", return_tensors="pt").input_ids

        # Image order must match the <image> placeholders in `text`: panoramas
        # (history, then current) followed by candidate views.
        images = [self._load_resized(p, self.panorama_size) for p in panoramas]
        images.extend(self._load_resized(p, (self.width, self.height)) for p in candidates)

        processed = self.processor(text=text, images=[images], return_tensors="pt")

        prompt_input_ids = torch.cat([processed["input_ids"], label_start], dim=1)
        input_ids = torch.cat([prompt_input_ids, candidate_tokens], dim=1)
        # NOTE(review): the whole prompt (system/user turns included) is masked out;
        # the original author flagged revisiting this.
        labels = torch.cat([torch.full_like(prompt_input_ids, -100), candidate_tokens], dim=1)

        processed["labels"] = labels
        processed["input_ids"] = input_ids
        # Integer mask, matching the dtype the processor itself emits.
        processed["attention_mask"] = torch.ones_like(input_ids)
        # Metadata for the epoch loops; they remove these before calling the model.
        processed["gold_candidate"] = candidate_tokens
        processed["path_id"] = path_id
        processed["step_id"] = step_id

        return processed


def _text_entry(text):
    """Chat-template content entry holding plain text."""
    return {"type": "text", "text": text}


def _image_entry(image_path):
    """Chat-template content entry referencing an image file."""
    return {"type": "image", "image": image_path}


def format_prompts_v5(dataset, processor, system_prompt, instruction_index, path, data_type="train"):
    """Render each navigation step as a chat prompt plus its image paths.

    Every user message lists, in order: the route instruction, step index and
    cumulative distance travelled, the panoramas of previous steps, the current
    panorama, each candidate direction (angle, distance, view image), and a final
    "Stop" option with the answering instruction.

    Args:
        dataset: iterable of step dicts with keys ``path_id``, ``step_id``,
            ``instructions``, ``current_image``, ``image_history``, ``candidates``
            and ``gold_label``.
        processor: processor whose ``apply_chat_template`` renders the messages.
        system_prompt: text of the system turn.
        instruction_index: which entry of ``sample["instructions"]`` to use.
        path: dataset root; image paths are resolved under ``path/data_type``.
        data_type: split sub-directory ("train", "val" or "test").

    Returns:
        A ``datasets.Dataset`` with columns text, labels, candidates, panoramas,
        path_id and step_id.  ``panoramas`` holds the history followed by the
        current panorama, which is the image order the prompt uses.

    Note:
        Cumulative distance resets only when ``path_id`` changes, so the samples
        of one path are assumed to be consecutive and in step order.
    """
    split_root = os.path.join(path, data_type)
    rows = []

    active_path = None
    travelled = 0

    for sample in dataset:
        path_id = sample["path_id"]
        step_id = sample["step_id"]
        instruction = sample["instructions"][instruction_index]

        current_panorama = os.path.join(split_root, sample["current_image"])
        history = [os.path.join(split_root, name) for name in sample["image_history"]]

        if active_path != path_id:
            travelled = 0
            active_path = path_id

        header = f"Route instruction: {instruction}\nCurrent step: {step_id}\nCumulative Distance Traveled: {travelled} meters\n\nPanorama Images from Previous Steps:"
        if not history:
            header += "[]"
        content = [_text_entry(header)]

        for step_index, panorama in enumerate(history):
            content.extend([
                _text_entry(f"\n\tPanorama at step: {step_index}: "),
                _image_entry(panorama),
            ])

        content.extend([
            _text_entry("\n\nCurrent Panorama Image:\n\t"),
            _image_entry(current_panorama),
            _text_entry("\n\nCandidate Directions:"),
        ])

        # NOTE(review): the displayed number is the enumerate index, while the gold
        # lookup below uses str(gold_label) as the dict key; this assumes candidates
        # are keyed "0", "1", ... in insertion order.
        candidate_images = []
        for index, candidate in enumerate(sample["candidates"].values()):
            angle = round(candidate["relative_angle"], 0)
            distance = round(candidate["distance"], 2)
            side = "Left" if angle < 0 else "Right"
            view_path = os.path.join(split_root, candidate["image_path"])

            content.append(_text_entry(
                f"\n\tCandidate: {index}:\n\t\tRelative angle: {abs(angle)} degrees to the {side}\n\t\tDistance: {distance} meters\n\t\tview: "
            ))
            content.append(_image_entry(view_path))
            candidate_images.append(view_path)

        # The stop option is written "Stop" to match the gold label; earlier runs
        # used "STOP" here, which did not match the label.
        content.append(_text_entry(
            "\n\tCandidate: Stop\n\nNow, analyze the route instruction, your current position, and the available candidate directions. Select the candidate that best matches the instruction and helps you continue along the correct path. Answer on the format: Candidate: (and then the number)"
        ))

        messages = [
            {"role": "system", "content": [_text_entry(system_prompt)]},
            {"role": "user", "content": content},
        ]
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

        history.append(current_panorama)

        rows.append({
            "text": text,
            "labels": str(sample["gold_label"]),
            "candidates": candidate_images,
            "panoramas": history,
            "path_id": path_id,
            "step_id": step_id,
        })

        # Advance the running distance by the chosen candidate's distance.
        gold = sample["gold_label"]
        if gold != "Stop":
            travelled = round(sample["candidates"][str(gold)]["distance"] + travelled, 2)

    return DF.from_list(rows)


def _accumulation_window_sizes(grouped_data, batch_size):
    """Return the sample count of each gradient-accumulation window.

    Episodes (the values of ``grouped_data``) are taken in order, ``batch_size``
    at a time; the last window may contain fewer episodes.
    """
    episode_sizes = [len(steps) for steps in grouped_data.values()]
    return [
        sum(episode_sizes[start:start + batch_size])
        for start in range(0, len(episode_sizes), batch_size)
    ]


def run_epoch_train(model, processor, dataloader, optimizer, lr_scheduler, epoch, batch_size, grouped_data):
    """Train for one epoch, accumulating gradients over ``batch_size`` episodes.

    Samples are fed one at a time.  An episode is considered finished when its gold
    label decodes to "Stop".  After ``batch_size`` episodes (or at the end of the
    dataloader) the gradients are clipped to norm 1.0, then the optimizer and the
    LR scheduler each take one step.  Each sample's loss is divided by the number
    of samples in its window, so the accumulated gradient is a per-sample mean.

    Args:
        grouped_data: mapping whose values are the steps of each episode, in the
            same order the dataloader yields them (presumably built by
            ``grouped_paths`` keyed by path id - confirm).
        batch_size: number of episodes per optimizer step.

    Returns:
        ``(mean loss per sample, per-sample analysis records)``.

    Raises:
        ValueError: if ``grouped_data`` is empty.
    """
    print(f"Training model, epoch: {epoch}")
    model.train()

    window_sizes = _accumulation_window_sizes(grouped_data, batch_size)
    if not window_sizes:
        raise ValueError("grouped_data is empty; nothing to train on")
    window_index = 0
    batch_length = window_sizes[window_index]
    print(f"samples in first accumulation window: {batch_length}")

    epoch_loss = 0
    window_loss = 0
    episodes_in_window = 0
    model_loss_analyse = []
    num_batches = len(dataloader)

    optimizer.zero_grad()  # start the epoch from clean gradients

    for i, batch in enumerate(dataloader):
        # Metadata added by CollateFunctor that the model's forward() must not see.
        gold_candidate = batch.pop("gold_candidate")
        path_id = batch.pop("path_id")
        step_id = batch.pop("step_id")

        batch.to(device)
        outputs = model(**batch)

        loss = outputs.loss
        loss_value = loss.item()
        window_loss += loss_value
        epoch_loss += loss_value
        (loss / batch_length).backward()

        wandb.log({"step_loss": loss_value})

        # Position -2 is the logit that predicts the final (candidate) token; the
        # last position predicts one token beyond the target.
        argmax = torch.argmax(outputs.logits, dim=2)[0]
        model_prediction = processor.decode(argmax[-2])
        gold = processor.decode(gold_candidate[0])

        model_loss_analyse.append({
            "loss": loss_value,
            "path_id": path_id,
            "step_id": step_id,
            "gold_label": gold,
            "argmax_result": model_prediction,
        })

        if gold == "Stop":
            episodes_in_window += 1

        if episodes_in_window == batch_size or i + 1 == num_batches:
            wandb.log({"episode_loss": window_loss / batch_length})

            clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()  # exactly one scheduler step per optimizer step
            optimizer.zero_grad()

            window_loss = 0
            episodes_in_window = 0
            window_index += 1
            if window_index < len(window_sizes):
                batch_length = window_sizes[window_index]

    epoch_loss = epoch_loss / num_batches if num_batches else 0.0
    wandb.log({"training_loss": epoch_loss, "learning_rate": lr_scheduler.get_last_lr()[0]})

    return epoch_loss, model_loss_analyse
    

@torch.no_grad()
def run_epoch_eval_or_test(model, processor, dataloader, mode="val"):
    """Evaluate the model and log step- and path-level metrics to wandb.

    For every sample, the greedy (argmax) token predicted for the candidate
    position is decoded and compared with the gold candidate.  Logged metrics:
    mean loss, step accuracy, success rate (fraction of paths whose entire
    prediction sequence equals the gold sequence), and Stop recall and precision.

    Args:
        mode: "val" logs under "validation ..." keys; any other value logs under
            "test ..." keys.

    Returns:
        ``(mean loss, path_dict)`` where ``path_dict[path_id]`` is
        ``{"gold": [...], "model": [...]}`` ordered by ``step_id``.
    """
    model.eval()

    total_loss = 0
    correct = 0
    stop_total = 0
    stop_hits = 0
    stop_pred_total = 0
    predictions = []

    for batch in dataloader:
        # Metadata added by CollateFunctor that the model's forward() must not see.
        gold_candidate = batch.pop("gold_candidate")
        path_id = batch.pop("path_id")
        step_id = batch.pop("step_id")

        batch.to(device)
        outputs = model(**batch)
        total_loss += outputs.loss.item()

        # Greedy choice of action; position -2 predicts the candidate token.
        argmax = torch.argmax(outputs.logits, dim=2)[0]
        model_prediction = processor.decode(argmax[-2])
        gold = processor.decode(gold_candidate[0])

        predictions.append({"gold": gold, "model": model_prediction, "path_id": path_id, "step_id": step_id})

        if model_prediction == gold:
            correct += 1
        if gold == "Stop":
            stop_total += 1
            if model_prediction == "Stop":
                stop_hits += 1
        if model_prediction == "Stop":
            stop_pred_total += 1

    # Rebuild each path's action sequence in step order.
    path_dict = {}
    for pred in sorted(predictions, key=lambda x: (x["path_id"], x["step_id"])):
        entry = path_dict.setdefault(pred["path_id"], {"gold": [], "model": []})
        entry["gold"].append(pred["gold"])
        entry["model"].append(pred["model"])

    successes = sum(1 for d in path_dict.values() if d["gold"] == d["model"])
    num_batches = len(dataloader)

    # Guard against empty splits instead of raising ZeroDivisionError.
    sucess_rate = successes / len(path_dict) if path_dict else 0.0
    total_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = correct / num_batches if num_batches else 0.0
    stop_recall = stop_hits / stop_total if stop_total > 0 else 0.0
    stop_precision = stop_hits / stop_pred_total if stop_pred_total > 0 else 0.0

    prefix = "validation" if mode == "val" else "test"
    wandb.log({
        f"{prefix}_loss": total_loss,
        f"{prefix} accuracy": accuracy,
        f"{prefix} sucess rate": sucess_rate,
        f"{prefix}: Stop recall": stop_recall,
        f"{prefix} Stop Precision": stop_precision,
    })

    # path_dict lets the caller inspect the model's actions per path; the
    # dataloaders are not shuffled, so they can be matched back to the data.
    return total_loss, path_dict
    
class Args:
    """Hard-coded configuration for this fine-tuning run."""

    def __init__(self):
        self.model = "Qwen/Qwen2.5-VL-3B-Instruct"
        self.custom_model = ""
        self.cache_dir = ""

        # these should all have the same name
        self.wandb_name = "Qwen2_5-panoramic-r2r-full"
        self.outputs_path = "./experiment-outputs/Qwen2_5-panoramic-r2r-full"
        self.model_checkpoint_path = "/"

        # for dataset
        self.train_data = "/dataset/train/train_data.json"
        self.val_data = "/dataset/val/val_data.json"
        self.test_data = "/dataset/test/test_data.json"
        # Root under which "<split>/<image>" paths are resolved; must match the
        # directory the split files above live in (was misspelled "/datase/").
        self.dataset_root = "/dataset/"
        self.system_prompt = "./prompts/system_prompt_panoramic.json"

        # hyperparameters
        self.lr = 1e-5
        self.weight_decay = 0.1
        self.warmup_ratio = 0.1
        # Must stay False: episodes are reconstructed from consecutive samples.
        self.shuffle = False
        self.epochs = 3  # one epoch per route instruction (each route has three)

        self.notes = "Finetune of Qwen2.5 panoramic with freezed vision on the full R2R dataset"

        self.batch_size = 1  # episodes per optimizer step (gradient accumulation)
        self.instruction_index_train = 0
        self.instruction_index_val = 2
        self.freeze_vision = True
        self.use_flash_attention = True
        self.width = 320  # candidate-view resize width
        self.height = 240  # candidate-view resize height
        self.seed = 128

In [None]:
if __name__ == "__main__":
    args = Args()
    set_seed(args.seed)

    # Directory for the JSON analysis files written below.
    os.makedirs(args.outputs_path, exist_ok=True)

    model = create_model(args.model, args.cache_dir, args.use_flash_attention)
    processor = AutoProcessor.from_pretrained(args.model, cache_dir=args.cache_dir)

    if args.freeze_vision:
        # Freeze every parameter under model.visual.
        for param in model.visual.parameters():
            param.requires_grad = False

    train_data = load_dataset(args.train_data)
    # Used by run_epoch_train to size its gradient-accumulation windows.
    grouped_data = grouped_paths(train_data)

    val_data = load_dataset(args.val_data)
    test_data = load_dataset(args.test_data)
    system_prompt = load_system_prompt(args.system_prompt)

    collate_functor = CollateFunctor(processor, args.width, args.height)

    def make_dataloader(formatted_data):
        """Wrap formatted prompts in an ordered, single-sample DataLoader."""
        # batch_size must be 1 because CollateFunctor handles one sample at a time,
        # and order must be kept because episodes span consecutive samples.
        return DataLoader(
            CustomDataset(formatted_data),
            batch_size=1,
            shuffle=False,
            collate_fn=collate_functor,
        )

    val = format_prompts_v5(val_data, processor, system_prompt, args.instruction_index_val, args.dataset_root, data_type="val")
    val_dataloader = make_dataloader(val)

    # Roughly one optimizer step per `batch_size` episodes; round up so a trailing
    # partial window is counted, and keep it an integer for the scheduler.
    steps_per_epoch = (len(grouped_data) + args.batch_size - 1) // args.batch_size
    total_steps = steps_per_epoch * args.epochs
    warmup_steps = int(args.warmup_ratio * total_steps)

    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )

    wandb.init(
        # set the wandb project where this run will be logged
        project="masters-thesis",
        name=args.wandb_name,
        notes=args.notes,
        job_type="fine-tune",

        # track hyperparameters and run metadata
        config={
            "model": args.model,
            "dataset": "room-to-room",
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "learning_rate": args.lr,
            "weight_decay": args.weight_decay,
            "warmup_ratio": args.warmup_ratio,
            "optimizer": "AdamW",
            "schedular": "linear_schedule_with_warmup",
            "architecture": "transformer",
            "seed": args.seed,
            "action_space": "panoramic",
            "model_checkpoint_path": args.model_checkpoint_path,
            "train_data_path": args.train_data,
            "prompt_path": args.system_prompt,
            "outputs_path": args.outputs_path,
            "samples": len(train_data),
            "instruction_index_train": "all three instructions",
            "instruction_index_val": args.instruction_index_val,
        },
    )

    start = time.time()
    for i in range(args.epochs):
        # Cycle through the route instructions, one per epoch (assumes three
        # instructions per sample); `% 3` keeps working beyond six epochs.
        instruction = i % 3

        print(f"instruction number: {instruction}")
        formatted_train = format_prompts_v5(train_data, processor, system_prompt, instruction, args.dataset_root, data_type="train")
        dataloader = make_dataloader(formatted_train)

        # Validation runs before this epoch's training (epoch 0 gives a baseline).
        _, path_dict = run_epoch_eval_or_test(model, processor, val_dataloader, mode="val")

        with open(os.path.join(args.outputs_path, f"val_path_dict_epoch_{i}.json"), "w", encoding="utf-8") as file:
            json.dump(path_dict, file, indent=4)

        train_loss, model_analyse = run_epoch_train(
            model=model,
            processor=processor,
            dataloader=dataloader,
            optimizer=optimizer,
            lr_scheduler=scheduler,
            epoch=i,
            batch_size=args.batch_size,
            grouped_data=grouped_data,
        )

        with open(os.path.join(args.outputs_path, f"mode_analyse_{i}.json"), "w", encoding="utf-8") as file:
            json.dump(model_analyse, file, indent=4)

        model_save_path = os.path.join(args.model_checkpoint_path, f"checkpoint_{i}")
        model.save_pretrained(model_save_path)
        print(f"model saved at: {model_save_path}, epoch {i} finished")

    _, val_path_dict = run_epoch_eval_or_test(model, processor, val_dataloader, mode="val")
    end = time.time()
    print(f"training time: {(end-start)/60} minutes")

    test = format_prompts_v5(test_data, processor, system_prompt, args.instruction_index_val, args.dataset_root, data_type="test")
    test_dataloader = make_dataloader(test)

    _, test_path_dict = run_epoch_eval_or_test(model, processor, test_dataloader, mode="test")

    with open(os.path.join(args.outputs_path, "val_path_dict.json"), "w", encoding="utf-8") as file:
        json.dump(val_path_dict, file, indent=4)

    with open(os.path.join(args.outputs_path, "test_path_dict.json"), "w", encoding="utf-8") as file:
        json.dump(test_path_dict, file, indent=4)

    wandb.finish()

Over er treningskoden