In [None]:
import os
import json
import time

import wandb
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from PIL import Image
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import clip_grad_norm_
#from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from transformers import Qwen2_5_VLForConditionalGeneration,  AutoProcessor
from datasets import Dataset as DF

from src import set_seed, load_dataset, get_device, load_system_prompt, change_labels, grouped_paths

In [None]:
# Resolve the compute device once via the project helper `get_device`.
# It is reused below as the `device_map` when loading the model and as the
# target of `batch.to(...)` in the train/eval loops.
device = get_device()

def create_model(model_id, cache_dir, use_flash_attention):
    """Load a Qwen2.5-VL checkpoint and place it on the module-level `device`.

    Args:
        model_id: Model identifier or path forwarded to `from_pretrained`.
        cache_dir: Directory forwarded as `cache_dir` to `from_pretrained`.
        use_flash_attention: When true, load the weights in bfloat16 with the
            "flash_attention_2" attention implementation. Otherwise let
            `torch_dtype="auto"` choose the dtype and keep the default
            attention implementation.

    Returns:
        The loaded `Qwen2_5_VLForConditionalGeneration` instance.
    """
    load_kwargs = {"cache_dir": cache_dir, "device_map": device}

    if use_flash_attention:
        load_kwargs["torch_dtype"] = torch.bfloat16
        load_kwargs["attn_implementation"] = "flash_attention_2"
    else:
        load_kwargs["torch_dtype"] = "auto"

    # Both configurations must go through the Qwen2.5-VL class. The fallback
    # used to reference Qwen2VLForConditionalGeneration, whose import is
    # commented out in this file, so it raised NameError when flash attention
    # was disabled.
    return Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **load_kwargs)


class CustomDataset(Dataset):
    """Index-aligned view over the columns of a formatted sample table.

    `data` must support `data[column]` for the columns "text", "labels",
    "images", "path_id" and "step_id"; each column is stored as an attribute
    of the same name and indexed by sample position.
    """

    def __init__(self, data):
        # Columns are read in this fixed order and kept by reference.
        for column in ("text", "labels", "images", "path_id", "step_id"):
            setattr(self, column, data[column])

    def __len__(self):
        # The label column defines the number of samples.
        n_samples = len(self.labels)
        return n_samples

    def __getitem__(self, index):
        # Tuple layout: (text, images, labels, path_id, step_id).
        sample = (
            self.text[index],
            self.images[index],
            self.labels[index],
            self.path_id[index],
            self.step_id[index],
        )
        return sample

# TODO: make the collatefunctor work with batches
class CollateFunctor:
    """Collate exactly one dataset sample into the tensors the model consumes.

    The sample's chat prompt and images go through `processor`; the assistant
    prefix and the tokenized gold action are then appended to the prompt ids,
    and a label tensor is built that masks everything except the action tokens.
    Batching is not supported, so there is no padding and no max length.
    """

    def __init__(self, processor, width, height):
        """Store the processor and the target image size.

        `width` and `height` are kept on the instance but are not applied to
        the images in `__call__`; the images are passed through at the size
        they are opened with.
        """
        self.processor = processor
        self.width = width
        self.height = height

    def __call__(self, batch):
        """Build model inputs for a one-sample batch.

        Args:
            batch: A list holding a single
                `(text, images, labels, path_id, step_id)` tuple.

        Returns:
            The processor output extended with `labels`, `input_ids`,
            `attention_mask`, `gold_action`, `path_id` and `step_id`.

        Raises:
            ValueError: If `batch` does not contain exactly one sample.
        """
        if len(batch) != 1:
            # Anything beyond batch[0] used to be dropped without notice.
            raise ValueError(
                f"CollateFunctor only supports batch_size=1, got {len(batch)} samples"
            )

        text, images, labels, path_id, step_id = batch[0]

        # Use the processor this functor was built with. The code previously
        # read a global named `processor`, which only exists when the training
        # script's __main__ section has run.
        tokenizer = self.processor.tokenizer
        action_tokens = tokenizer(labels, return_tensors="pt").input_ids
        label_start = tokenizer("<|im_start|>assistant\nAction: ", return_tensors="pt").input_ids

        images = [Image.open(img) for img in images]

        processed = self.processor(text=text, images=[images], return_tensors="pt")

        # Final sequence layout: prompt + assistant prefix + gold action.
        prompt_input_ids = torch.cat([processed["input_ids"], label_start], dim=1)
        input_ids = torch.cat([prompt_input_ids, action_tokens], dim=1)

        # -100 marks positions ignored by the loss, so only the action tokens
        # are supervised.
        # TODO: revisit whether the other chat-template tokens should be
        # supervised as well.
        prompt_mask = torch.full((1, prompt_input_ids.shape[1]), -100, dtype=torch.long)
        labels = torch.cat([prompt_mask, action_tokens], dim=1)

        # Single unpadded sample: every position is attended to.
        attention_mask = torch.ones(1, input_ids.shape[1])

        processed["labels"] = labels
        processed["input_ids"] = input_ids
        processed["attention_mask"] = attention_mask
        processed["gold_action"] = action_tokens
        processed["path_id"] = path_id
        processed["step_id"] = step_id

        return processed


def format_prompts_v3_5(dataset, processor, system_prompt, instruction_index, path, data_type="train"):
    """Turn raw navigation samples into chat-formatted training rows.

    For every sample a system + user conversation is rendered with
    `processor.apply_chat_template` (no generation prompt, not tokenized).
    The user turn interleaves text with the previous-step images and the
    current image.

    Args:
        dataset: Iterable of samples providing the keys "instructions",
            "step_id", "distance_traveled", "past_images", "previous_actions",
            "current_image", "possible_actions", "gold_label" and "path_id".
        processor: Object exposing `apply_chat_template`.
        system_prompt: Text placed in the system message.
        instruction_index: Which entry of `sample["instructions"]` to use.
        path: Dataset root directory.
        data_type: Sub-directory of `path` that holds the images.

    Returns:
        A `datasets.Dataset` with the columns "text", "labels", "images"
        (past image paths followed by the current image path), "path_id" and
        "step_id".
    """
    split_dir = os.path.join(path, data_type)
    rows = []

    for sample in dataset:
        header = (
            f"Route Instruction: {sample['instructions'][instruction_index]}\n"
            f"Current Step: {sample['step_id']}\n"
            f"Cummulative Distance Traveled: {sample['distance_traveled']}\n"
            "Images from Previous Steps: "
        )

        past_image_paths = [os.path.join(split_dir, name) for name in sample["past_images"]]
        if not past_image_paths:
            # No history yet: show an explicit empty list in the prompt.
            header += "[]"

        user_content = [{"type": "text", "text": header}]
        user_content.extend({"type": "image", "image": image_path} for image_path in past_image_paths)

        actions_text = (
            f"\nActions performed at Previous Steps: {str(sample['previous_actions'])}\nCurrent image:"
        )
        user_content.append({"type": "text", "text": actions_text})

        current_image_path = os.path.join(split_dir, sample["current_image"])
        user_content.append({"type": "image", "image": current_image_path})

        question_text = (
            f"\nPossible actions: {str(sample['possible_actions'])}\n"
            "Now predict the next action based on the input you have recived. "
            "Answer on the format: Action: (an the action you choose)"
        )
        user_content.append({"type": "text", "text": question_text})

        conversation = [
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
            {"role": "user", "content": user_content},
        ]
        rendered = processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=False
        )

        rows.append(
            {
                "text": rendered,
                "labels": sample["gold_label"],
                # Same order as the image placeholders in the prompt.
                "images": past_image_paths + [current_image_path],
                "path_id": sample["path_id"],
                "step_id": sample["step_id"],
            }
        )

    return DF.from_list(rows)


def _accumulation_group_sizes(grouped_data, batch_size):
    """Return the number of samples in each gradient-accumulation group.

    Episodes are taken in the iteration order of `grouped_data` and packed
    `batch_size` at a time; the final group holds whatever is left over.
    """
    sizes = []
    pending = 0
    total_episodes = len(grouped_data)
    for episode_number, steps in enumerate(grouped_data.values(), start=1):
        pending += len(steps)
        if episode_number % batch_size == 0 or episode_number == total_episodes:
            sizes.append(pending)
            pending = 0
    return sizes


def run_epoch_train(model, processor, dataloader, optimizer, lr_scheduler, epoch, batch_size, grouped_data):
    """Train for one epoch, accumulating gradients over `batch_size` episodes.

    The dataloader yields one step (sample) at a time. An episode counts as
    finished when the decoded gold action equals "Stop". After every
    `batch_size` finished episodes, and after the very last sample, gradients
    are clipped to norm 1.0 and the optimizer and scheduler are stepped. Each
    sample's loss is divided by the number of samples in its accumulation
    group before `backward()`.

    Previously both the group-size list and the step condition skipped
    index/episode 0, so the first optimizer step covered `batch_size + 1`
    episodes. Groups now contain exactly `batch_size` episodes (except
    possibly the last one).

    Args:
        model: Model called as `model(**batch)`; its output must expose
            `.loss` and `.logits`.
        processor: Used to decode predicted and gold token ids.
        dataloader: Yields batches with the keys "gold_action", "path_id" and
            "step_id" in addition to the model inputs.
        optimizer: Optimizer stepped once per accumulation group.
        lr_scheduler: Scheduler stepped together with the optimizer.
        epoch: Epoch index, used for logging only.
        batch_size: Number of episodes per optimizer step.
        grouped_data: Mapping whose values are the per-episode step
            collections. NOTE(review): assumed to iterate in the same episode
            order as `dataloader` - confirm against the caller.

    Returns:
        `(epoch_loss, model_loss_analyse)`: the mean per-sample loss and a
        list with one record per sample (loss, path_id, step_id, gold_label,
        argmax_result).

    Raises:
        ValueError: If `grouped_data` is empty.
    """
    print(f"Training model, epoch: {epoch}")
    model.train()

    group_sizes = _accumulation_group_sizes(grouped_data, batch_size)
    if not group_sizes:
        raise ValueError("grouped_data is empty; there is nothing to train on")

    epoch_loss = 0
    group_loss = 0
    model_loss_analyse = []

    episodes_seen = 0
    group_index = 0
    group_size = group_sizes[group_index]
    print(group_size)

    num_batches = len(dataloader)

    for i, batch in enumerate(dataloader):
        gold_action = batch["gold_action"]
        path_id = batch["path_id"]
        step_id = batch["step_id"]

        # These entries are bookkeeping only and must not reach the model.
        del batch["gold_action"]
        del batch["path_id"]
        del batch["step_id"]

        batch.to(device)

        outputs = model(**batch)

        loss = outputs.loss
        loss_value = loss.item()
        group_loss += loss_value
        epoch_loss += loss_value

        # Scale so the accumulated gradient is the mean over the group.
        (loss / group_size).backward()

        wandb.log({"step_loss": loss_value})

        # The logits at position -2 predict the last token of the sequence.
        model_prediction = processor.decode(outputs.logits[0, -2].argmax(dim=-1))
        gold = processor.decode(gold_action[0])

        model_loss_analyse.append(
            {
                "loss": loss_value,
                "path_id": path_id,
                "step_id": step_id,
                "gold_label": gold,
                "argmax_result": model_prediction,
            }
        )

        is_last_sample = (i + 1) == num_batches
        episode_finished = gold == "Stop"
        if episode_finished:
            episodes_seen += 1

        group_complete = episode_finished and episodes_seen % batch_size == 0
        if group_complete or is_last_sample:
            wandb.log({"batch_loss": group_loss / group_size})

            clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            # Keep the scheduler step paired with every optimizer step.
            lr_scheduler.step()

            optimizer.zero_grad()  # Reset gradients for the next accumulation group

            group_loss = 0
            group_index += 1
            if group_index < len(group_sizes):
                group_size = group_sizes[group_index]

        if (i + 1) % 500 == 0 or is_last_sample:
            print(f"step: {i}")
            print(f"Path_id:  {path_id}, step_id: {step_id}")
            print(f"step Loss: {loss_value}")
            print(f"Model Prediction: {repr(model_prediction)}")
            print(f"Gold Action: {repr(gold)}")
            print()
            # Only the last two positions are inspected, so softmax just those
            # instead of the whole sequence. Row 0 is position -2, row 1 is -1.
            # NOTE(review): the token ids below are hard-coded; presumably they
            # are the tokenizer ids of the named strings - verify.
            tail_probs = F.softmax(outputs.logits[0, -2:], dim=-1)
            print(f"probability for Stop on -2: {tail_probs[0, 10674]}")
            print(f"probability for Move on -2: {tail_probs[0, 9860]}")
            print(f"probability for Right on -2: {tail_probs[0, 5979]}")
            print(f"probability for Left on -2: {tail_probs[0, 5415]}")
            print(f"probability for <|im_end|> on -1: {tail_probs[1, 151645]}")
            print(f"probability for backslash n on -1: {tail_probs[1, 198]}")
            print(f"probability for <|endoftext|> on -1: {tail_probs[1, 151643]}")
            print()
            print("-" * 60)

    epoch_loss /= num_batches
    wandb.log({"training_loss": epoch_loss, "learning_rate": lr_scheduler.get_last_lr()[0]})

    return epoch_loss, model_loss_analyse
    

def _safe_ratio(numerator, denominator):
    """Return `numerator / denominator`, or 0.0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0.0


def _per_action_recall_precision(predictions, actions=("Left", "Right", "Move", "Stop")):
    """Compute per-action recall and precision from gold/model prediction records."""
    gold_counts = dict.fromkeys(actions, 0)
    model_counts = dict.fromkeys(actions, 0)
    hit_counts = dict.fromkeys(actions, 0)

    for record in predictions:
        gold, predicted = record["gold"], record["model"]
        if gold in gold_counts:
            gold_counts[gold] += 1
        if predicted in model_counts:
            model_counts[predicted] += 1
        if gold == predicted and gold in hit_counts:
            hit_counts[gold] += 1

    recall = {a: _safe_ratio(hit_counts[a], gold_counts[a]) for a in actions}
    precision = {a: _safe_ratio(hit_counts[a], model_counts[a]) for a in actions}
    return recall, precision


@torch.no_grad()
def run_epoch_eval_or_test(model, processor, dataloader, mode="eval"):
    """Evaluate the model on a validation or test dataloader and log to wandb.

    For every sample the token predicted at position -2 (greedy argmax) is
    compared with the decoded gold action. The function logs the mean loss,
    step accuracy, path success rate (a path succeeds when all of its steps
    are predicted correctly) and per-action recall/precision.

    Recall now returns 0.0 for an action that never occurs in the gold labels
    instead of raising ZeroDivisionError, matching how precision was already
    guarded. Empty input yields 0.0 metrics rather than a division error.

    Args:
        model: Model called as `model(**batch)`; its output must expose
            `.loss` and `.logits`.
        processor: Used to decode predicted and gold token ids.
        dataloader: Yields batches with the keys "gold_action", "path_id" and
            "step_id" in addition to the model inputs.
        mode: "eval" logs under the "validation" prefix; any other value logs
            under the "test" prefix.

    Returns:
        `(total_loss, path_dict)`: the mean loss and a dict mapping each
        path_id to `{"gold": [...], "model": [...]}`, ordered by step_id.
    """
    model.eval()

    actions = ("Left", "Right", "Move", "Stop")
    total_loss = 0
    correct = 0
    predictions = []

    for batch in dataloader:
        gold_action = batch["gold_action"]
        path_id = batch["path_id"]
        step_id = batch["step_id"]

        # These entries are bookkeeping only and must not reach the model.
        del batch["gold_action"]
        del batch["path_id"]
        del batch["step_id"]

        batch.to(device)

        outputs = model(**batch)
        total_loss += outputs.loss.item()

        # Greedy choice: the logits at position -2 predict the final token.
        predicted_id = outputs.logits[0, -2].argmax(dim=-1)
        model_prediction = processor.decode(predicted_id).strip(" ")
        gold = processor.decode(gold_action[0])

        predictions.append(
            {"gold": gold, "model": model_prediction, "path_id": path_id, "step_id": step_id}
        )

        if model_prediction == gold:
            correct += 1

    # Group the step predictions per path, in step order.
    sorted_predictions = sorted(predictions, key=lambda x: (x["path_id"], x["step_id"]))
    path_dict = {}
    for record in sorted_predictions:
        entry = path_dict.setdefault(record["path_id"], {"gold": [], "model": []})
        entry["gold"].append(record["gold"])
        entry["model"].append(record["model"])

    model_label_recall, model_label_precision = _per_action_recall_precision(
        sorted_predictions, actions
    )

    solved_paths = sum(1 for d in path_dict.values() if d["gold"] == d["model"])
    success_rate = _safe_ratio(solved_paths, len(path_dict))

    num_batches = len(dataloader)
    total_loss = _safe_ratio(total_loss, num_batches)
    accuracy = _safe_ratio(correct, num_batches)

    # The metric names (including the "sucess" spelling) are kept verbatim so
    # existing wandb charts keep receiving the same keys.
    prefix = "validation" if mode == "eval" else "test"
    metrics = {
        f"{prefix}_loss": total_loss,
        f"{prefix} accuracy": accuracy,
        f"{prefix} sucess rate": success_rate,
    }
    for action in actions:
        metrics[f"{prefix}: {action} recall"] = model_label_recall[action]
    for action in actions:
        metrics[f"{prefix}: {action} precision"] = model_label_precision[action]
    wandb.log(metrics)

    return total_loss, path_dict


class Args:
    """Hard-coded experiment configuration for the fine-tuning script."""

    def __init__(self):
        # Model to fine-tune and where to cache the downloaded weights.
        self.model = "Qwen/Qwen2.5-VL-3B-Instruct"
        self.cache_dir = ""

        # Run naming and output locations (they should all share the same name).
        self.wandb_name = ""
        self.outputs_path = "./experiment-outputs/"
        self.model_checkpoint_path = ""

        # Dataset files. Every split lives under `dataset_root/<split>/`.
        self.train_data = "/dataset/train/train_data.json"
        # Fixed: this was "/datasetval/val_data.json", which lacked the path
        # separator and did not sit under `dataset_root` like the other splits.
        self.val_data = "/dataset/val/val_data.json"
        self.test_data = "/dataset/test/test_data.json"
        self.dataset_root = "/dataset/"
        self.system_prompt = "./prompts/system_prompt_low_level.json"

        # Optimisation hyperparameters.
        self.lr = 0.00001
        self.weight_decay = 0.1
        self.warmup_ratio = 0.1
        self.shuffle = False
        self.epochs = 1
        self.batch_size = 1

        # Free-text notes attached to the wandb run.
        self.notes = ""

        # Model and data options.
        self.use_flash_attention = True
        self.freeze_vision = True
        self.instruction_index_train = 0
        self.instruction_index_val = 2
        self.width = 320
        self.height = 240
        self.seed = 128

In [None]:
if __name__ == "__main__":
    # End-to-end run: load model and data, validate then train once per epoch,
    # save a checkpoint per epoch, and finally evaluate on the test split.
    args = Args()
    set_seed(args.seed)

    # Directory for the JSON outputs; exist_ok avoids the check-then-create race.
    os.makedirs(args.outputs_path, exist_ok=True)

    model = create_model(args.model, args.cache_dir, args.use_flash_attention)
    processor = AutoProcessor.from_pretrained(args.model, cache_dir=args.cache_dir, max_pixels=args.width*args.height)

    if args.freeze_vision:
        # Freeze the vision tower so only the remaining parameters are trained.
        for param in model.visual.parameters():
            param.requires_grad = False

    train_data = load_dataset(args.train_data)
    train_data = change_labels(train_data)
    grouped_data = grouped_paths(train_data)

    val_data = load_dataset(args.val_data)
    val_data = change_labels(val_data)

    test_data = load_dataset(args.test_data)
    test_data = change_labels(test_data)

    system_prompt = load_system_prompt(args.system_prompt)

    val = format_prompts_v3_5(val_data, processor, system_prompt, args.instruction_index_val, args.dataset_root, data_type="val")
    val_dataset = CustomDataset(val)
    collate_functor = CollateFunctor(processor, args.width, args.height)

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=1, # Enforce batch size of 1 as phi-3.5 vision does not support a larger batch size out of the box (qwen does but i'll look at that later)
        shuffle=False,
        collate_fn=collate_functor,
    )

    # NOTE: this may have been the problem with the previously trained model;
    # run with the vision tower frozen first, then retry without freezing.
    # One optimizer step per `batch_size` episodes. Ceiling division keeps the
    # count an integer, which is what the scheduler expects.
    steps_per_epoch = -(-len(grouped_data) // args.batch_size)
    total_steps = steps_per_epoch * args.epochs
    warmup_steps = int(args.warmup_ratio * total_steps)

    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )

    wandb.init(
        # set the wandb project where this run will be logged
        project="masters-thesis",
        name=args.wandb_name,
        notes=args.notes,
        job_type="fine-tune",

        # track hyperparameters and run metadata
        config={
            "model" :  args.model,
            "dataset": "room-to-room",
            "epochs": args.epochs,
            "batch_size" : args.batch_size,
            "learning_rate": args.lr,
            "weight_decay" : args.weight_decay,
            "warmup_ratio" : args.warmup_ratio,
            "optimizer" : "AdamW",
            "schedular" : "linear_schedule_with_warmup",
            "architecture": "transformer",
            "seed" : args.seed,
            "action_space" : ["Move", "Left", "Right", "Stop"],
            "model_checkpoint_path" : args.model_checkpoint_path,
            "train_data_path" : args.train_data,
            "prompt_path" : args.system_prompt,
            "outputs_path" : args.outputs_path,
            "samples" : len(train_data),
            "instruction_index_train" : "all three instructions", #args.instruction_index_train,
            "instruction_index_val" : args.instruction_index_val
        }
    )

    start = time.time()
    for i in range(args.epochs):
        # Rotate through the instructions, one per epoch. The modulo keeps the
        # index in range for any number of epochs (the previous `i - 3` only
        # wrapped once). NOTE(review): assumes three instructions per sample,
        # as stated in the wandb config above - confirm.
        instruction = i % 3

        print(f"instruction number: {instruction}")

        formatted_data = format_prompts_v3_5(train_data, processor, system_prompt, instruction, args.dataset_root)

        dataset = CustomDataset(formatted_data)

        dataloader = DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=collate_functor
        )

        # Validation runs before this epoch's training, so the file for epoch i
        # reflects the model as it was at the start of epoch i.
        val_loss, path_dict = run_epoch_eval_or_test(model, processor, val_dataloader)

        with open(os.path.join(args.outputs_path, f"val_path_dict_epoch_{i}.json"), "w", encoding="utf-8") as file:
            json.dump(path_dict, file, indent=4)

        # training model
        train_loss, model_analyse = run_epoch_train(
            model=model,
            processor=processor,
            dataloader=dataloader,
            optimizer=optimizer,
            lr_scheduler=scheduler,
            epoch=i,
            batch_size=args.batch_size,
            grouped_data=grouped_data
        )

        with open(os.path.join(args.outputs_path, f"model_analyse_{i}.json"), "w", encoding="utf-8") as file:
            json.dump(model_analyse, file, indent=4)

        model_save_path = os.path.join(args.model_checkpoint_path, f"checkpoint_{i}")
        model.save_pretrained(model_save_path)
        print(f"model saved at: {model_save_path}, epoch {i} finished")

        # The validation path dict used to be dumped a second time here, to the
        # same file with the same content; that redundant write was removed.

    end = time.time()
    print(f"training time: {(end-start)/60} minutes")

    # run test
    test = format_prompts_v3_5(test_data, processor, system_prompt, args.instruction_index_val, args.dataset_root, data_type="test")
    test_dataset = CustomDataset(test)

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=collate_functor,
    )
    test_loss, test_path_dict = run_epoch_eval_or_test(model, processor, test_dataloader, mode="test")

    with open(os.path.join(args.outputs_path,"test_path_dict.json"), "w", encoding="utf-8") as file:
        json.dump(test_path_dict, file, indent=4)

    wandb.finish()