# In [None]:
import numpy as np
import torch
import os
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText
import faiss
import json
import tempfile
import sys
import matplotlib.pyplot as plt
import tqdm
from concurrent.futures import ThreadPoolExecutor

# Path to append for imports
sys.path.append(os.path.abspath(".."))
from inference.create_img import convert_mario_to_png

# Functions from level_similarity_search.py
def load_json_data(file_path):
    """Parse the UTF-8 encoded JSON file at *file_path* and return its content."""
    with open(file_path, encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed

def load_level_search_index(index_path):
    """Load the FAISS index and the level ids stored next to it.

    Args:
        index_path: Path to the serialized FAISS index. The level ids are read
            from the sibling text file ``index_path + ".indices"``, one integer
            per line.

    Returns:
        A ``(index, level_indices)`` tuple, where ``level_indices`` is a list
        of ints in file order.
    """
    index = faiss.read_index(index_path)
    with open(index_path + ".indices", "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a stray empty line at the end) instead of
        # crashing on int(""); int() itself tolerates surrounding whitespace.
        level_indices = [int(line) for line in f if line.strip()]
    # A length mismatch means FAISS labels cannot be mapped 1:1 to level ids;
    # warn rather than fail so existing workflows keep running.
    if len(level_indices) != index.ntotal:
        print(
            f"Warning: {index_path} holds {index.ntotal} vectors but "
            f"{len(level_indices)} level ids were read"
        )
    print(f"Index loaded from {index_path}")
    return index, level_indices

# Function to create level data structure expected by the system
def create_level_data(level_string):
    """Build the level-data dict consumed by the embedding pipeline.

    The string is split into rows on ``"\\n"``. A row containing ``"|"`` is
    further split on ``"|"`` and its empty pieces are discarded; a row without
    ``"|"`` is kept verbatim, even when it is empty.

    Returns:
        ``{"window": <list of row strings>, "level_string": <input string>}``
    """
    window = []
    for line in level_string.split("\n"):
        if "|" not in line:
            window.append(line)
            continue
        window.extend(piece for piece in line.split("|") if piece)
    return {"window": window, "level_string": level_string}

def generate_level_embedding(level_data, model, processor, game_type, tiles_dir="../../assets/mario"):
    """Render a level to an image and return the model's image hidden states.

    Args:
        level_data: Dict whose ``"window"`` entry is the list of level rows.
            The rows are joined with newlines before rendering.
        model: Vision-language model called with the processor's outputs; its
            ``image_hidden_states`` output is returned.
        processor: Processor used to build the chat-template inputs.
        game_type: Only ``"mario"`` is supported.
        tiles_dir: Tile directory passed to ``convert_mario_to_png``. The
            default is resolved relative to the current working directory.

    Returns:
        ``outputs.image_hidden_states`` as a float32 NumPy array on the CPU.

    Raises:
        ValueError: If *game_type* is not supported.
    """
    window = level_data["window"]

    # Create level image
    if game_type == "mario":
        img, _, _ = convert_mario_to_png("\n".join(window), tiles_dir=tiles_dir)
    else:
        raise ValueError(f"Game type {game_type} not supported yet")

    # Only reserve a unique path here and close the handle immediately.
    # Reopening a still-open NamedTemporaryFile by name is not portable
    # (it fails on Windows).
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
        temp_path = temp_file.name

    try:
        # Save inside the try so the temp file is removed even if saving fails.
        img.save(temp_path)

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "url": temp_path},
                ]
            },
        ]

        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)

        with torch.no_grad():
            outputs = model(**inputs)
            # Cast before .numpy(): NumPy has no bfloat16, so a model loaded
            # in reduced precision would otherwise fail here. For float32
            # tensors this is a no-op.
            embedding = outputs.image_hidden_states.to(torch.float32).cpu().numpy()

        return embedding

    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)

def search_similar_levels(query_level_data, model, processor, index, level_indices, game_type, top_k=5):
    """Find the levels in *index* most similar to the query level.

    The query is embedded with ``generate_level_embedding``, flattened to a
    single float32 row and L2-normalised, then searched for *top_k* neighbours.

    Args:
        query_level_data: Level dict accepted by ``generate_level_embedding``.
        model, processor: Passed through to ``generate_level_embedding``.
        index: FAISS index to search.
        level_indices: Level ids; FAISS label ``i`` maps to ``level_indices[i]``.
        game_type: Passed through to ``generate_level_embedding``.
        top_k: Maximum number of neighbours to return.

    Returns:
        ``(similar_level_indices, similarities)``: a list of level ids and the
        matching NumPy array of scores, in FAISS result order. FAISS pads its
        results with label ``-1`` when fewer than *top_k* vectors are
        available; those padding slots are dropped, so both results can be
        shorter than *top_k*.
    """
    query_features = generate_level_embedding(query_level_data, model, processor, game_type)
    query_features = query_features.astype(np.float32).reshape(1, -1)

    faiss.normalize_L2(query_features)

    distances, indices = index.search(query_features, top_k)
    # Maps a score in [-1, 1] to [0, 1]. NOTE(review): this assumes an
    # inner-product index over normalised vectors (cosine similarity);
    # confirm against how the index was built.
    similarities = (distances + 1) / 2

    # Label -1 means "no result". Using it as a list index would silently
    # return the last level id, so filter those slots out of both outputs.
    valid = indices[0] >= 0
    similar_level_indices = [level_indices[int(idx)] for idx in indices[0][valid]]

    return similar_level_indices, similarities[0][valid]

# --- Main Processing Code ---

# Load the JSON file with levels
input_json_path = "level_generation_results_20250521_100328.json"
levels_data = load_json_data(input_json_path)

# Set up the game type and embedding directory
game_type = "mario"
embedding_dir = f"embeddings_{game_type}"
top_k = 5  # Number of similar levels to find

# Load the model and processor once, so every level reuses them.
model_name = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
# from_pretrained loads onto the CPU; move the model to the GPU when one is
# available. Inputs follow via `.to(model.device)`, and the per-batch
# torch.cuda.empty_cache() below only has an effect when CUDA is in use.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Loading models...")
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForImageTextToText.from_pretrained(model_name).to(device)

# Load the search index
index, level_indices = load_level_search_index(f"{embedding_dir}/level_index.faiss")

# Function to process a single level
def process_level(level_entry):
    """Attach similarity-search results to one level entry, in place.

    Reads ``level_entry["level"]`` and runs ``search_similar_levels`` with the
    module-level ``model``, ``processor``, ``index``, ``level_indices``,
    ``game_type`` and ``top_k``. The outcome is stored under
    ``level_entry["similarity_results"]``:

    * on success: ``{"similar_levels": [int, ...], "similarity_scores": [float, ...]}``
    * on any exception: ``{"error": <exception message>}``; the message is
      also printed, so one bad level does not abort the whole run.

    Returns:
        The same, now mutated, ``level_entry`` dict.
    """
    try:
        level_string = level_entry["level"]
        query = create_level_data(level_string)
        found_ids, scores = search_similar_levels(
            query, model, processor, index, level_indices, game_type, top_k=top_k
        )
        # Coerce to built-in int/float so the entry can be written with json.dump.
        results = {
            "similar_levels": list(map(int, found_ids)),
            "similarity_scores": list(map(float, scores)),
        }
    except Exception as e:
        message = str(e)
        print(f"Error processing level: {message}")
        results = {"error": message}
    level_entry["similarity_results"] = results
    return level_entry

# Process levels in parallel to speed things up
print(f"Processing {len(levels_data)} levels...")
processed_levels = []

batch_size = 10  # Levels per batch; cached GPU memory is released after each batch
max_workers = 4

# NOTE(review): every worker shares the same model and processor objects;
# confirm they are safe to call concurrently. process_level catches any
# exception per level and records it under an "error" key instead of raising,
# so check the output for such entries.
# A single pool serves the whole run; re-creating it per batch only added
# thread start-up cost.
with ThreadPoolExecutor(max_workers=max_workers) as executor, tqdm.tqdm(total=len(levels_data)) as pbar:
    for start in range(0, len(levels_data), batch_size):
        batch = levels_data[start:start + batch_size]

        # executor.map yields results in input order, so output order matches input.
        for result in executor.map(process_level, batch):
            processed_levels.append(result)
            pbar.update(1)

        # Clear cached GPU memory between batches (no-op without CUDA).
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Save the results to a new JSON file
output_json = "levels_with_similarity_metrics.json"
with open(output_json, "w", encoding="utf-8") as f:
    json.dump(processed_levels, f, indent=2)

print(f"Results saved to {output_json}")
