In [None]:
import os
import json
import time
import base64
import requests

import torch
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as DT
#from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor


from utils import set_seed, load_dataset, get_device, load_system_prompt

# Resolve the compute device once through the project helper; the value is
# passed as `device_map` when the model is loaded and used to move each batch
# before the forward pass.
device = get_device()

In [None]:
# Only change these settings.
# `dataset` selects which R2R split file Args points at, `inst` is the index of
# the instruction used for every path, and `model` picks the Qwen version.
# NOTE: the name `model` is rebound to the loaded network in a later cell, so
# Args() must be constructed before that happens.
dataset = "test"
inst = 2
model = "qwen2_5"

# Directory that saved panorama/candidate images are written under, and the
# base URL that the episode endpoints are appended to.
SAVE_DIR = f""
BASE_URL = ""

# Header dictionary sent with every request to the server.
API_KEY = ""
headers = {"api-key" : API_KEY}

class Args:
    """Run configuration for this evaluation notebook.

    The constructor reads the module-level ``model``, ``dataset`` and ``inst``
    settings at call time, so it has to be instantiated while ``model`` still
    holds the configuration string.
    """

    def __init__(self):
        # --- model selection ---
        self.custom_model = ""
        self.model = "Qwen/Qwen2.5-VL-3B-Instruct"
        if model == "qwen2_5":
            self.version = 2
        else:
            self.version = 1
        self.cache_dir = ""

        # --- data and prompt locations ---
        # The test-data file contains all paths.
        self.test_data = f"./notebooks/R2R_{dataset}.json"
        self.system_prompt = "../prompts/system_prompt_panoramic.json"
        self.instruction_index = inst

        # --- runtime options ---
        self.use_flash_attention = True
        self.width = 320
        self.height = 240
        self.seed = 32

def get_path_instructions(data):
    """Map each path id to the instructions of its first occurrence in *data*.

    Args:
        data: Iterable of mappings that provide ``"path_id"`` and
            ``"instructions"``.

    Returns:
        A dict keyed by path id; later entries with an already-seen id are
        ignored.
    """
    instructions_by_path = {}

    for entry in data:
        identifier = entry["path_id"]
        if identifier in instructions_by_path:
            # Keep the first occurrence only.
            continue
        instructions_by_path[identifier] = entry["instructions"]

    return instructions_by_path


def format_prompt_v5(images_path, path_id, route_instruction, step_id, distance_traveled, candidates, processor, system_prompt):
    """Build the chat-formatted prompt for a single navigation step.

    The images are referenced in the order: panorama history, current
    panorama, candidate views (in candidate-index order).

    Args:
        images_path: Folder holding the images of the current episode. Files
            whose name starts with ``"pano"`` are panoramas; every other file
            is treated as a candidate view.
        path_id: Not used inside this function.
        route_instruction: Instruction text inserted into the prompt.
        step_id: Current step number inserted into the prompt.
        distance_traveled: Cumulative distance inserted into the prompt.
        candidates: Sequence of mappings with ``"relative_angle"`` and
            ``"distance"``; entry ``i`` is paired with the ``i``-th candidate
            image.
        processor: Object providing ``apply_chat_template``.
        system_prompt: Text of the system message.

    Returns:
        A one-row ``datasets.Dataset`` with the columns ``text``,
        ``candidates`` (candidate image paths) and ``panoramas`` (all panorama
        paths, the current one last).
    """
    def _text(value):
        return {"type" : "text", "text" : value}

    def _image(path):
        return {"type" : "image", "image" : path}

    # Split the folder listing into panoramas and candidate views.
    pano_paths, view_paths = [], []
    for name in os.listdir(images_path):
        bucket = pano_paths if name.startswith("pano") else view_paths
        bucket.append(os.path.join(images_path, name))

    # Order both groups by the number that follows the last underscore.
    pano_paths.sort(key=lambda p: int(p.split("_")[-1].split(".")[-2]))
    # The listing is probably already ordered, but that is not verified here.
    view_paths.sort(key=lambda p: int(p.split("_")[-1].split(".")[0]))

    # The newest panorama is the current view; the rest is history.
    history = list(pano_paths)
    current_panorama = history.pop()

    # Route instruction, current step, cumulative distance.
    header = f"Route instruction: {route_instruction}\nCurrent step: {step_id}\nCumulative Distance Traveled: {distance_traveled} meters\n\nPanorama Images from Previous Steps:"
    if not history:
        header += "[]"
    content = [_text(header)]

    # Panoramas from previous steps.
    for step, pano in enumerate(history):
        content += [_text(f"\n\tPanorama at step: {step}: "), _image(pano)]

    # Current panorama followed by the candidate section heading.
    content += [
        _text("\n\nCurrent Panorama Image:\n\t"),
        _image(current_panorama),
        _text("\n\nCandidate Directions:"),
    ]

    # One text description plus one view image per candidate.
    for idx, cand in enumerate(candidates):
        angle = round(cand["relative_angle"], 0)
        meters = round(cand["distance"], 2)
        if angle < 0:
            side = "Left"
        else:
            side = "Right"

        content.append(_text(f"\n\tCandidate: {idx}:\n\t\tRelative angle: {abs(angle)} degrees to the {side}\n\t\tDistance: {meters} meters\n\t\tview: "))
        content.append(_image(view_paths[idx]))

    # Adds the Stop candidate and the request to select a candidate.
    content.append(_text("\n\tCandidate: Stop\n\nNow, analyze the route instruction, your current position, and the available candidate directions. Select the candidate that best matches the instruction and helps you continue along the correct path. Answer on the format: Candidate: (and then the number)"))

    messages = [
        {"role" : "system", "content" : [_text(system_prompt)]},
        {"role" : "user", "content" : content},
    ]

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

    sample = {"text" : text, "candidates" : view_paths, "panoramas" : pano_paths}
    return DT.from_list([sample])

class CustomDataset(Dataset):
    """Torch dataset over the ``text``, ``panoramas`` and ``candidates`` columns."""

    def __init__(self, data):
        # Pull the three columns out of the mapping once.
        self.text, self.panoramas, self.candidates = (
            data["text"],
            data["panoramas"],
            data["candidates"],
        )

    def __len__(self):
        # The number of samples is the number of prompt texts.
        return len(self.text)

    def __getitem__(self, index):
        # Returns (text, panoramas, candidates) for the requested row.
        columns = (self.text, self.panoramas, self.candidates)
        return tuple(column[index] for column in columns)

# TODO: make the collatefunctor work with batches
class CollateFunctor:
    """Turn a single prompt sample into model-ready inputs.

    Only the first element of the batch is used (batch size 1), so no padding
    or maximum length is applied. ``width`` and ``height`` are stored for an
    optional resize of the candidate views, which is currently not performed.
    """

    def __init__(self, processor, width, height):
        self.processor = processor
        self.width = width
        self.height = height

    def __call__(self, batch):
        """Process ``batch[0]`` and append the start of the assistant answer.

        Args:
            batch: Sequence whose first item is ``(text, panoramas, candidates)``
                with the prompt text and two lists of image paths.

        Returns:
            The processor output with ``input_ids`` extended by the tokens of
            ``"<|im_start|>assistant\\nCandidate: "`` and an all-ones
            ``attention_mask`` of matching shape.
        """
        text, panoramas, candidates = batch[0]

        # Use the processor given to the constructor; relying on a module-level
        # `processor` only worked when such a global happened to exist.
        label_start = self.processor.tokenizer("<|im_start|>assistant\nCandidate: ", return_tensors="pt").input_ids

        # Image order must match the prompt: panoramas first, then candidates.
        images = [Image.open(path) for path in panoramas]
        images.extend(Image.open(path) for path in candidates)

        processed = self.processor(text=text, images=[images], return_tensors="pt")

        # Append the answer prefix so the next predicted token is the candidate.
        input_ids = torch.cat([processed["input_ids"], label_start], dim=1)
        processed["input_ids"] = input_ids
        # Nothing is padded, so every position is attended to; keep the mask in
        # the same integer dtype/device as the ids instead of a float tensor.
        processed["attention_mask"] = torch.ones_like(input_ids)

        return processed

# to save space we delete candidate images after use
def delete_candidate_images(folder):
    """Remove every regular file in *folder* whose name does not start with "pano".

    Panorama files and non-file entries (such as sub-directories) are kept.
    """
    for name in os.listdir(folder):
        if name.startswith("pano"):
            # Panoramas are kept.
            continue

        target = os.path.join(folder, name)
        if os.path.isfile(target):
            os.remove(target)
    

# Networking functions
def save_image_from_base64(base64_str, filename):
    """Decode a base64-encoded image and write it to ``SAVE_DIR/filename``."""
    # base64 text -> raw bytes -> PIL image
    decoded = Image.open(BytesIO(base64.b64decode(base64_str)))

    # The target path is relative to the module-level SAVE_DIR.
    decoded.save(os.path.join(SAVE_DIR, filename))

def handle_data(path_id, step_id, data):
    """Store the images of one server response and return the candidate metadata.

    Args:
        path_id: Name of the sub-folder the images are written to.
        step_id: Step number used in the image file names.
        data: Mapping with a base64 ``"panorama"`` and a ``"candidates"``
            mapping whose values carry ``"image"``, ``"relative_angle"`` and
            ``"distance"``.

    Returns:
        A list with one ``{"relative_angle", "distance"}`` dict per candidate,
        in the iteration order of ``data["candidates"]``.
    """
    # Panorama of the current step.
    save_image_from_base64(data["panorama"], f"{path_id}/pano_step_{step_id}.png")

    # Candidate views; the file index matches the position in the result list.
    summaries = []
    for idx, cand in enumerate(data["candidates"].values()):
        save_image_from_base64(cand["image"], f"{path_id}/step_{step_id}_candidate_{idx}.png")
        summaries.append({key: cand[key] for key in ("relative_angle", "distance")})

    return summaries
    
def start_episode(path_id):
    """Start a new episode on the server and store its first observation.

    Args:
        path_id: Identifier of the path; it is part of the request URL and is
            also used as the folder name for the saved images.

    Returns:
        ``(path_id, candidates)`` when the server answers with HTTP 200, where
        ``candidates`` is the list produced by ``handle_data`` for step 0.
        ``None`` for any other status code.
    """
    url = f"{BASE_URL}/episode/start_episode/{path_id}"
    response = requests.post(url, headers=headers)

    if response.status_code != 200:
        # An error body is not guaranteed to be JSON; fall back to the raw
        # text so that reporting the failure cannot raise on its own.
        try:
            detail = response.json()
        except ValueError:
            detail = response.text
        print(f"Error starting episode: {detail}")
        return None

    print(f"Episode started for path_id: {path_id}")
    data = response.json()
    candidates = handle_data(path_id, 0, data)
    return path_id, candidates

def end_episode(path_id):
    """End the episode of *path_id* on the server.

    Args:
        path_id: Identifier of the path; it is part of the request URL.

    Returns:
        The decoded JSON body when the server answers with HTTP 200,
        otherwise ``None``.
    """
    url = f"{BASE_URL}/episode/end_episode/{path_id}"
    response = requests.delete(url, headers=headers)

    if response.status_code != 200:
        # An error body is not guaranteed to be JSON; fall back to the raw
        # text so that reporting the failure cannot raise on its own.
        try:
            detail = response.json()
        except ValueError:
            detail = response.text
        print(f"Error ending episode: {detail}")
        return None

    print(f"Episode ended for path_id: {path_id}\n")
    return response.json()

def take_action(session_id, action, step_id):
    """Send one action to the server and store the returned observation.

    Args:
        session_id: Identifier used in the request URL and as the folder name
            for the saved images.
        action: Value sent as the ``action`` query parameter.
        step_id: Step number used in the names of the saved images.

    Returns:
        The candidate list produced by ``handle_data`` when the server answers
        with HTTP 200, otherwise ``None``.
    """
    url = f"{BASE_URL}/take_action/{session_id}"
    params = {"action": action}
    response = requests.post(url, params=params, headers=headers)

    if response.status_code != 200:
        # An error body is not guaranteed to be JSON; fall back to the raw
        # text so that reporting the failure cannot raise on its own.
        try:
            detail = response.json()
        except ValueError:
            detail = response.text
        print(f"Error taking action: {detail}")
        return None

    data = response.json()
    return handle_data(session_id, step_id, data)

In [None]:
# Build the run configuration first: Args() reads the `model` setting string,
# and the name `model` is rebound to the loaded network further down.
args = Args()
set_seed(args.seed)

# It may be that the max-pixel setting was the problem and that the model could
# not see the whole panorama image.
processor = AutoProcessor.from_pretrained(args.model, cache_dir=args.cache_dir)

# Keyword arguments shared by both model versions.
load_kwargs = {
    "cache_dir": args.cache_dir,
    "torch_dtype": torch.bfloat16,
    "device_map": device,
}
# Honour the configuration flag instead of always forcing flash attention.
if args.use_flash_attention:
    load_kwargs["attn_implementation"] = "flash_attention_2"

if args.version == 1:
    # Imported here because the top-of-file import of this class is commented
    # out; without it this branch would fail with a NameError.
    from transformers import Qwen2VLForConditionalGeneration

    model_cls = Qwen2VLForConditionalGeneration
else:
    model_cls = Qwen2_5_VLForConditionalGeneration

# The weights come from `custom_model`, the processor above from `model`.
model = model_cls.from_pretrained(args.custom_model, **load_kwargs)

test_data = load_dataset(args.test_data)
system_prompt = load_system_prompt(args.system_prompt)

model.eval()
print("model loaded to memory")

In [None]:
# Evaluation loop: for every path, start an episode on the server, let the
# model pick one candidate per step until it answers "Stop", the step limit is
# reached or the server stops returning candidates, then store the actions.
collate_fn = CollateFunctor(processor, args.width, args.height)

# path_id -> list of decoded predictions, one per step
model_actions = {}

os.makedirs(SAVE_DIR, exist_ok=True)

# Upper bound on the number of steps taken in one episode.
MAX_STEPS = 32

with torch.no_grad():
    # NOTE(review): hard-coded resume offset; the former "479" note looks like
    # an earlier resume point. `i` counts from the offset, not from the start
    # of test_data.
    for i, path in enumerate(test_data[1320:]):
        print(f"Number: {i}")
        step_id = 0
        path_id = path["path_id"]
        route_instruction = path["instructions"][args.instruction_index]
        cumulative_distance = 0

        # Kept per path so the predictions can be inspected afterwards.
        model_predictions = []

        images_path = os.path.join(SAVE_DIR, f"{path_id}")
        os.makedirs(images_path, exist_ok=True)
        time.sleep(0.5)

        model_prediction = None
        episode = start_episode(path_id)
        if episode is None:
            # start_episode reports a failure by returning None (it has
            # already printed the error); unpacking that would raise a
            # TypeError and abort the whole run, so skip this path instead.
            continue
        session_id, candidates = episode

        # The episode is assumed to be over once no candidates are returned.
        while step_id < MAX_STEPS and model_prediction != "Stop" and candidates is not None:
            prompt = format_prompt_v5(images_path, path_id, route_instruction, step_id, cumulative_distance, candidates, processor, system_prompt)

            # Named `step_dataset` so the `dataset` configuration string is
            # not overwritten.
            step_dataset = CustomDataset(prompt)
            data_loader = DataLoader(
                step_dataset,
                batch_size=1,
                collate_fn=collate_fn
            )

            for batch in data_loader:
                batch = batch.to(device)

                outputs = model(**batch)

                # Greedy choice at the last position only: the prompt already
                # ends with "Candidate: ", so this single token is the answer.
                next_token = torch.argmax(outputs.logits[0, -1])
                model_prediction = processor.decode(next_token)
                model_predictions.append(model_prediction)

                # isdecimal() only accepts strings that int() can parse,
                # unlike isnumeric(), which also accepts e.g. fractions.
                if model_prediction.isdecimal() and int(model_prediction) < len(candidates):
                    selected_candidate = candidates[int(model_prediction)]
                    cumulative_distance = round(cumulative_distance + selected_candidate["distance"], 2)

                # Candidate views are deleted after use to save space.
                delete_candidate_images(images_path)

                step_id += 1
                candidates = take_action(path_id, model_prediction, step_id)

        delete_candidate_images(images_path)
        model_actions[path_id] = model_predictions
        end_episode(path_id)

with open(f"{SAVE_DIR}-actions.json", "w", encoding="utf-8") as file:
    json.dump(model_actions, file, indent=4)