In [None]:
import os
import json
import time
import base64
import requests

import torch
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as DT
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from transformers import Qwen2_5_VLForConditionalGeneration


from utils import set_seed, load_dataset, get_device, load_system_prompt

# Resolve the compute device once via the project helper; it is reused later
# as the model's device_map and as the target when moving each batch.
device = get_device()

In [None]:
# Only change the values in this block.
dataset = "test"  # split name, interpolated into Args.test_data (R2R_{dataset}.json)
inst = 2  # index into each path's "instructions" list (Args.instruction_index)
adjust = False  # when True, an extra "Adjust" action is sent before every "Move"

# Root directory for the per-path step images; also the prefix of the
# "-actions.json" output file. Left blank here and must be filled in.
SAVE_DIR = f""
# Base URL of the episode HTTP API. Left blank here and must be filled in.
BASE_URL = ""

# Sent as the "api-key" header on every request to the episode API.
API_KEY = ""
headers = {"api-key" : API_KEY}

class Args:
    """Run configuration for the evaluation notebook.

    The module-level ``dataset`` and ``inst`` values are read when the
    instance is constructed, so they must be defined before ``Args()``.
    """

    def __init__(self):
        # Checkpoint handed to ``from_pretrained`` for the model weights.
        # Left blank here (like SAVE_DIR / BASE_URL); fill in before loading.
        self.custom_model = ""
        # Hub id used to load the processor.
        self.model = "Qwen/Qwen2.5-VL-3B-Instruct"
        # 1 selects the Qwen2-VL model class, any other value Qwen2.5-VL.
        self.version = 2
        self.cache_dir = ""
        # Contains all paths.
        self.test_data = f"./notebooks/R2R_{dataset}.json"
        self.system_prompt = "./prompts/system_prompt_v3_5_no_adjust.json"
        # Which of a path's instructions is used as the route instruction.
        self.instruction_index = inst
        self.use_flash_attention = True
        # Passed to CollateFunctor when it is constructed.
        self.width = 320
        self.height = 240
        self.seed = 32

def get_path_instructions(data):
    """Map each ``path_id`` in ``data`` to its ``instructions``.

    When a ``path_id`` occurs more than once, the instructions of its first
    occurrence are kept and later entries are ignored.
    """
    instructions_by_path = {}

    for entry in data:
        key = entry["path_id"]
        if key in instructions_by_path:
            continue
        instructions_by_path[key] = entry["instructions"]

    return instructions_by_path

class CustomDataset(Dataset):
    """Dataset pairing each prompt text with its list of images.

    ``data`` must support ``data["text"]`` and ``data["images"]``; both are
    indexed with the same position in ``__getitem__``.
    """

    def __init__(self, data):
        self.text = data["text"]
        self.images = data["images"]

    def __len__(self):
        """Return the number of prompt texts."""
        return len(self.text)

    def __getitem__(self, index):
        """Return the ``(text, images)`` pair stored at ``index``."""
        prompt_text = self.text[index]
        prompt_images = self.images[index]
        return prompt_text, prompt_images

# TODO: make the collatefunctor work with batches
class CollateFunctor:
    """Collate function that turns one ``(text, image_paths)`` sample into model inputs.

    Only the first element of the batch is used, so it must be driven with
    ``batch_size=1``. ``width`` and ``height`` are stored but not read by
    ``__call__``.
    """

    # No batch, therefore no max length
    def __init__(self, processor, width, height):
        self.processor = processor
        self.width = width
        self.height = height

    def __call__(self, batch):
        """Build the processor output for ``batch[0]``.

        The token ids of the assistant prefix ``"<|im_start|>assistant\\nAction: "``
        are appended to the prompt ids, and the attention mask is replaced by
        an all-ones mask covering the extended sequence.
        """
        text, images = batch[0]
        # Use the processor this functor was built with rather than a
        # same-named module global, so the class also works outside the notebook.
        processor = self.processor
        label_start = processor.tokenizer("<|im_start|>assistant\nAction: ", return_tensors="pt").input_ids

        images = [Image.open(img) for img in images]

        processed = processor(text=text, images=[images], return_tensors="pt")

        prompt_input_ids = processed["input_ids"]
        input_ids = torch.cat([prompt_input_ids, label_start], dim=1)

        # The processor's own mask does not cover the appended prefix tokens.
        attention_mask = torch.ones(1, input_ids.shape[1])
        processed["input_ids"] = input_ids
        processed["attention_mask"] = attention_mask

        return processed


def format_prompt_v3_5(images_path, step_id, route_instruction, distance_traveled, previous_actions, move_possible, processor, system_prompt):
    """Build the chat prompt for one navigation step.

    Every file in ``images_path`` is treated as a step image and ordered by
    the integer between the last ``_`` and the following ``.`` of its path.
    The highest-numbered image is the current view; all others are previous
    steps. Raises ``IndexError`` when the directory is empty.

    Returns a single-row ``datasets.Dataset`` with the rendered chat-template
    ``text`` and the ordered ``images`` paths (previous images, then current).
    """

    def _step_number(path):
        return int(path.split("_")[-1].split(".")[0])

    ordered_paths = sorted(
        (os.path.join(images_path, name) for name in os.listdir(images_path)),
        key=_step_number,
    )
    current_image = ordered_paths[-1]
    earlier_images = ordered_paths[:-1]

    header_text = f"Route Instruction: {route_instruction}\nCurrent Step: {step_id}\nCummulative Distance Traveled: {distance_traveled}\nImages from Previous Steps: "
    if not earlier_images:
        # No image entries follow the header, so show an explicit empty list.
        header_text += "[]"

    if move_possible:
        possible_actions = ["Left", "Right", "Move", "Stop"]
    else:
        possible_actions = ["Left", "Right", "Stop"]

    content = [{"type" : "text", "text" : header_text}]
    content += [{"type" : "image", "image" : path} for path in earlier_images]
    content += [
        {
            "type" : "text",
            "text" : f"\nActions performed at Previous Steps: {str(previous_actions)}\nCurrent image:",
        },
        {"type" : "image", "image" : current_image},
        {
            "type" : "text",
            "text" : f"\nPossible actions: {str(possible_actions)}\nNow predict the next action based on the input you have recived. Answer on the format: Action: (an the action you choose)",
        },
    ]

    messages = [
        {"role" : "system", "content" : [{"type" : "text", "text" : system_prompt}]},
        {"role" : "user", "content" : content},
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

    formatted_sample = {
        "text" : text,
        "images" : earlier_images + [current_image],
    }
    return DT.from_list([formatted_sample])
    

# Networking functions
def save_image_from_base64(base64_str, filename):
    """Decode a base64-encoded image and save it as ``filename`` under ``SAVE_DIR``."""
    # base64 text -> raw bytes -> PIL image
    decoded_image = Image.open(BytesIO(base64.b64decode(base64_str)))

    # The output format is chosen by PIL from the filename extension.
    decoded_image.save(os.path.join(SAVE_DIR, filename))

def start_episode(path_id, step_id=0):
    """Start a new episode for ``path_id`` and save its first image.

    Args:
        path_id: Path identifier; also used as the image sub-directory name.
        step_id: Step number used in the saved image's filename
            (``{path_id}/step_{step_id}.png``). Defaults to 0 so the function
            no longer depends on a module-level ``step_id`` being set.

    Returns:
        ``(path_id, distance, move_possible)`` on HTTP 200, otherwise
        ``(None, None, None)`` so callers can always unpack three values.
    """
    url = f"{BASE_URL}/episode/start_episode/{path_id}"
    response = requests.post(url, headers=headers)

    if response.status_code != 200:
        print(f"Error starting episode: {response.json()}")
        return None, None, None

    print(f"Episode started for path_id: {path_id}")
    data = response.json()  # The response is in JSON format
    save_image_from_base64(data["image"], f"{path_id}/step_{step_id}.png")

    return path_id, data["distance"], data["move_possible"]

def end_episode(path_id):
    """End the episode for ``path_id``.

    Returns the decoded JSON body on HTTP 200, otherwise ``None`` after
    printing the error payload.
    """
    endpoint = f"{BASE_URL}/episode/end_episode/{path_id}"
    reply = requests.delete(endpoint, headers=headers)

    if reply.status_code != 200:
        print(f"Error ending episode: {reply.json()}")
        return None

    print(f"Episode ended for path_id: {path_id}\n")
    return reply.json()

def take_action(session_id, action, step_id):
    """Send ``action`` to the session and save the image returned for this step.

    Args:
        session_id: Session identifier used in the request URL and as the
            image sub-directory name.
        action: Action string sent as the ``action`` query parameter.
        step_id: Step number used in the saved image's filename.

    Returns:
        ``(distance, move_possible)`` on HTTP 200, otherwise ``(None, None)``.
    """
    url = f"{BASE_URL}/take_action/{session_id}"
    params = {"action": action}
    response = requests.post(url, params=params, headers=headers)

    if response.status_code != 200:
        print(f"Error taking action: {response.json()}")
        return None, None

    data = response.json()  # The response is in JSON format
    # Save under the session passed in, not a module-level ``path_id``.
    save_image_from_base64(data["image"], f"{session_id}/step_{step_id}.png")
    return data["distance"], data["move_possible"]

In [None]:
args = Args()
set_seed(args.seed)
# It may be that the max-pixel setting was the problem and that the model could
# not see the whole panorama image.
processor = AutoProcessor.from_pretrained(args.model, cache_dir=args.cache_dir)

# version 1 -> Qwen2-VL, anything else -> Qwen2.5-VL. Both are loaded with the
# same arguments, so only the class differs.
model_cls = (
    Qwen2VLForConditionalGeneration
    if args.version == 1
    else Qwen2_5_VLForConditionalGeneration
)
model = model_cls.from_pretrained(
    args.custom_model,
    cache_dir=args.cache_dir,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=device,
)

test_data = load_dataset(args.test_data)
system_prompt = load_system_prompt(args.system_prompt)

model.eval()
print("model loaded to memory")

# No DataLoader is built in this cell: the collate function and the per-step
# dataset only exist inside the episode loop, which creates its own loader.

In [None]:
collate_fn = CollateFunctor(processor, args.width, args.height)

# path_id -> list of actions predicted by the model for that path
model_actions = {}

os.makedirs(SAVE_DIR, exist_ok=True)

with torch.no_grad():
    # Hard-coded start offset into the test data (original note: "3704 is the
    # previous one").
    for i, path in enumerate(test_data[1310:]):
        print(f"Number: {i}")
        # step_id and path_id stay module-level names on purpose: the
        # networking helpers may read them as globals.
        step_id = 0
        path_id = path["path_id"]
        route_instruction = path["instructions"][args.instruction_index]
        previous_actions = []

        # to view the predictions
        model_predictions = []

        images_path = os.path.join(SAVE_DIR, f"{path_id}")
        os.makedirs(images_path, exist_ok=True)
        time.sleep(0.5)

        model_prediction = None
        # Tolerate a bare None from start_episode so a failed start skips the
        # step loop below instead of crashing on tuple unpacking.
        started = start_episode(path_id)
        if started is None:
            started = (None, None, None)
        session_id, distance, move_possible = started

        # The episode is treated as over once the server stops returning
        # distance / move_possible, the model says "Stop", or 32 steps passed.
        while (
            step_id < 32
            and model_prediction != "Stop"
            and distance is not None
            and move_possible is not None
        ):
            prompt = format_prompt_v3_5(images_path, step_id, route_instruction, distance, previous_actions, move_possible, processor, system_prompt)
            # Named step_dataset so the config string ``dataset`` is not overwritten.
            step_dataset = CustomDataset(prompt)
            data_loader = DataLoader(
                step_dataset,
                batch_size=1,
                collate_fn=collate_fn
            )

            for batch in data_loader:
                batch = batch.to(device)

                outputs = model(**batch)
                argmax = torch.argmax(outputs.logits, dim=2)[0]
                # Only the last position is decoded: a single forward pass,
                # no further tokens are generated.
                model_prediction = processor.decode(argmax[-1])
                model_predictions.append(model_prediction)

                if model_prediction == "Move" and adjust:
                    # The adjust turn is its own step and gets its own image.
                    step_id += 1
                    previous_actions.append("Automatically Turn Towards Node")
                    distance, move_possible = take_action(path_id, "Adjust", step_id)

                step_id += 1
                previous_actions.append(model_prediction)
                distance, move_possible = take_action(path_id, model_prediction, step_id)

        model_actions[path_id] = model_predictions
        end_episode(path_id)

with open(f"{SAVE_DIR}-actions.json", "w", encoding="utf-8") as file:
    json.dump(model_actions, file, indent=4)

In [None]:
# Write the collected predictions to "<SAVE_DIR>-actions.json" (overwrites).
actions_path = f"{SAVE_DIR}-actions.json"
with open(actions_path, "w", encoding="utf-8") as out_file:
    json.dump(model_actions, out_file, indent=4)

In [None]:
# Show the collected predictions (path_id -> list of predicted actions).
print(model_actions)