<a href="https://colab.research.google.com/github/arumdauo/dixit-AI-bot/blob/main/guesser.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive

# Mount Google Drive so the config, card images and checkpoints referenced
# under "/content/drive/My Drive/..." are readable from this runtime.
drive.mount(mountpoint='/content/drive')

# Extract CLIP embeddings

In [None]:
import os
import json
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image

def load_config(config_path):
    """Load and return the JSON configuration stored at ``config_path``.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.
    """
    with open(config_path, 'r', encoding='utf-8') as config_file:
        return json.load(config_file)

def _card_sort_key(filename):
    """Order card files by the first integer in their name ("2.png" before "10.png")."""
    import re

    match = re.search(r'\d+', filename)
    if match is None:
        # Names without a number go last, alphabetically.
        return (1, 0, filename)
    return (0, int(match.group()), filename)


def extract_image_embeddings(cards_folder, device, model=None, processor=None, target_dim=None):
    """Embed every ``.png`` card in ``cards_folder`` with CLIP and stack the results.

    Row ``i`` of the result belongs to the ``i``-th card in numeric filename
    order. The guesser cells look card descriptions up by ``str(i + 1)``.
    Numeric order keeps rows aligned with card numbers. Plain lexicographic
    order would not: there, "10.png" sorts before "2.png".

    Args:
        cards_folder: directory containing the card images.
        device: torch device the CLIP model runs on.
        model: CLIP model; defaults to the module-level ``clip_model``.
        processor: CLIP processor; defaults to the module-level ``clip_processor``.
        target_dim: output embedding size; defaults to the module-level ``embedding_dim``.

    Returns:
        A CPU tensor with one ``target_dim`` row per card.

    Raises:
        ValueError: if ``cards_folder`` contains no ``.png`` files.
    """
    model = clip_model if model is None else model
    processor = clip_processor if processor is None else processor
    target_dim = embedding_dim if target_dim is None else target_dim

    filenames = sorted(
        (name for name in os.listdir(cards_folder) if name.lower().endswith('.png')),
        key=_card_sort_key,
    )
    if not filenames:
        raise ValueError(f"No .png card images found in {cards_folder!r}")

    image_embeddings = []
    # One projection shared by all cards. A separate random layer per image
    # would put every card in a different, incomparable space.
    reduction_layer = None

    for filename in filenames:
        image_path = os.path.join(cards_folder, filename)
        # Context manager closes the file handle once the pixels are converted.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        inputs = processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            embedding = model.get_image_features(**inputs).squeeze()
            if embedding.shape[0] != target_dim:
                if reduction_layer is None:
                    reduction_layer = torch.nn.Linear(embedding.shape[0], target_dim).to(device)
                embedding = reduction_layer(embedding)
        image_embeddings.append(embedding.cpu())

    return torch.stack(image_embeddings)


# Pick the accelerator up front; the CLIP model below is moved onto it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

config_path = "/content/drive/My Drive/Colab Notebooks/dixit/config_guesser.json"
config = load_config(config_path)

cards_folder = config["cards_folder"]
embeddings_save_path = config["embeddings_save_path"]
clip_model_name = config["clip_model_name"]
# Target width of each saved card embedding. DixitAI compares these rows
# against 512-d DixitModel outputs, so the sizes must agree.
embedding_dim = 512

# extract_image_embeddings reads clip_processor / clip_model as globals.
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
clip_model = CLIPModel.from_pretrained(clip_model_name).to(device)

all_image_embeddings = extract_image_embeddings(cards_folder, device)
torch.save(all_image_embeddings, embeddings_save_path)
print(f"Image embeddings saved at {embeddings_save_path}")


# Llama, Dixit model<br>
Guesser phase with randomly drawn cards: the bot's six-card hand and the three other players' cards are sampled at random.

In [None]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    CLIPProcessor,
    CLIPModel,
    LlamaTokenizer,
    LlamaForCausalLM
)
import pandas as pd
from huggingface_hub import login
import json
import re
import os

def load_config(config_path):
    """Load and return the JSON configuration stored at ``config_path``.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.
    """
    with open(config_path, 'r', encoding='utf-8') as config_file:
        return json.load(config_file)

class DixitModel(nn.Module):
    """Three-layer MLP mapping a hint's CLIP text embedding to an ``embedding_dim`` vector.

    The layer attribute names (fc1/bn1/fc2/bn2/fc3) are the state-dict keys a
    saved checkpoint is loaded against, so they must not change.
    """

    def __init__(self, embedding_dim, dropout_rate=0.5):
        super().__init__()
        # Hidden block 1: embedding_dim -> 512
        self.fc1 = nn.Linear(embedding_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(dropout_rate)
        # Hidden block 2: 512 -> 256
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(dropout_rate)
        # Output projection back to embedding_dim, no activation.
        self.fc3 = nn.Linear(256, embedding_dim)

    def forward(self, hint_embedding):
        hidden = hint_embedding
        stages = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
        )
        for linear, norm, drop in stages:
            hidden = drop(F.relu(norm(linear(hidden))))
        return self.fc3(hidden)

class DixitAI:
    """Dixit bot that scores cards by blending CLIP similarity with a LLaMA relevance rating.

    Card ids are 0-based row indices into the saved image-embedding tensor.
    Text descriptions are keyed by the 1-based card number (``str(card_id + 1)``),
    taken from the first integer in the descriptions CSV's ``Image`` column.
    Relies on the module-level ``device`` defined at the bottom of the cell.
    """

    # "Score: 7", "score: 7.5", "Score: 7/10" (case-insensitive).
    _SCORE_LABEL_RE = re.compile(r'score\s*:\s*(\d+(?:\.\d+)?)(?:\s*/\s*10)?', re.IGNORECASE)
    # A bare "7/10" anywhere in the text.
    _SCORE_FRACTION_RE = re.compile(r'\b(\d+(?:\.\d+)?)\s*/\s*10\b')
    # Last-resort heuristic: the first stand-alone digit 2-9.
    _SCORE_FALLBACK_RE = re.compile(r'\b([2-9])\b')

    def __init__(self, descriptions_file_path, checkpoint_path, embeddings_save_path, llama_model_path, hf_token,
                 clip_model_name="openai/clip-vit-base-patch32"):
        """Load descriptions, the DixitModel checkpoint, CLIP, LLaMA and the card embeddings.

        Args:
            descriptions_file_path: CSV with ``Image``, ``BLIP``, ``ViT`` and ``BLIP-2`` columns.
            checkpoint_path: torch checkpoint holding ``'model_state_dict'`` for DixitModel.
            embeddings_save_path: saved tensor of card embeddings, one row per card.
            llama_model_path: Hugging Face id or local path of the LLaMA model.
            hf_token: Hugging Face token passed to the LLaMA downloads.
            clip_model_name: CLIP model used to embed hints. Should be the
                model the card embeddings were extracted with.
        """
        self.card_descriptions = self._load_descriptions(descriptions_file_path)

        embedding_dim = 512
        self.dixit_model = DixitModel(embedding_dim=embedding_dim).to(device)
        checkpoint = torch.load(checkpoint_path, map_location=device)
        self.dixit_model.load_state_dict(checkpoint['model_state_dict'])
        self.dixit_model.eval()

        self.clip_model = CLIPModel.from_pretrained(clip_model_name).to(device)
        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name)

        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_auth_token=hf_token)
        self.llama_model = LlamaForCausalLM.from_pretrained(
            llama_model_path,
            use_auth_token=hf_token,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True
        ).to(device)
        self.llama_model.eval()

        self.all_image_embeddings = torch.load(embeddings_save_path, map_location=device)
        self.num_cards = self.all_image_embeddings.shape[0]

    @staticmethod
    def _load_descriptions(descriptions_file_path):
        """Read the captions CSV into ``{card_number: {caption_kind: text}}``.

        The key is the first integer in the ``Image`` value, normalized through
        ``int()``. Zero-padded names such as "card_007.png" therefore map to
        "7", the form produced by the ``str(card_id + 1)`` lookups.

        Raises:
            ValueError: if an ``Image`` value contains no digits.
        """
        descriptions = {}
        descriptions_df = pd.read_csv(descriptions_file_path)
        for _, row in descriptions_df.iterrows():
            match = re.search(r'\d+', str(row['Image']))
            if match is None:
                raise ValueError(f"No card number found in Image value {row['Image']!r}")
            descriptions[str(int(match.group()))] = {
                'blip_description': row['BLIP'],
                'vit_description': row['ViT'],
                'blip2_description': row['BLIP-2']
            }
        return descriptions

    @classmethod
    def _parse_score(cls, text):
        """Return the first 0-10 score found in ``text`` (clamped to [0, 10]), or None."""
        for pattern in (cls._SCORE_LABEL_RE, cls._SCORE_FRACTION_RE, cls._SCORE_FALLBACK_RE):
            match = pattern.search(text)
            if match:
                return min(max(float(match.group(1)), 0.0), 10.0)
        return None

    def _descriptions_for(self, card_id):
        """Return the (BLIP, ViT, BLIP-2) captions of 0-based ``card_id``, with placeholders."""
        entry = self.card_descriptions.get(str(card_id + 1), {})
        missing = "No description available"
        return (
            entry.get('blip_description', missing),
            entry.get('vit_description', missing),
            entry.get('blip2_description', missing),
        )

    def create_cot_prompt(self, hint, candidate_cards):
        """Build the LLaMA prompt asking for a 0-10 relatedness score per card.

        Card ids outside ``[0, num_cards)`` are silently left out of the prompt.
        """
        prompt_template = f"""You are an expert at matching images to a hint and giving a relatedness score.
            Your task is to evaluate each card's relevance to the given hint by giving a score.

            Hint: '{hint}'

            """
        for card_id in candidate_cards:
            if 0 <= card_id < self.num_cards:
                blip_desc, vit_desc, blip2_desc = self._descriptions_for(card_id)

                prompt_template += f"""
                Card {card_id + 1}:
                Descriptions: BLIP - {blip_desc} | ViT - {vit_desc} | BLIP-2 - {blip2_desc}
                Provide a numeric relatedness score (0-10) between image and hint.
                """

        return prompt_template

    def score_card(self, hint, card_id):
        """Cosine similarity between the DixitModel-mapped hint and card ``card_id``'s embedding.

        Raises:
            ValueError: if ``card_id`` is outside ``[0, num_cards)``.
        """
        if not (0 <= card_id < self.num_cards):
            raise ValueError(f"Invalid card_id: {card_id}. Must be between 0 and {self.num_cards-1}")

        # truncation=True: CLIP's text encoder has a fixed maximum token length.
        hint_inputs = self.clip_processor(text=hint, return_tensors="pt", padding=True, truncation=True).to(device)
        # Inference only: keep the DixitModel pass inside no_grad as well.
        with torch.no_grad():
            hint_embedding = self.clip_model.get_text_features(**hint_inputs)
            hint_embedding = F.normalize(hint_embedding, p=2, dim=-1)
            hint_embedding = self.dixit_model(hint_embedding).squeeze(0)
            hint_embedding = F.normalize(hint_embedding, p=2, dim=0)

            card_embedding = F.normalize(self.all_image_embeddings[card_id], p=2, dim=0)
            similarity = F.cosine_similarity(hint_embedding, card_embedding, dim=0).item()
        return similarity

    def generate_reasoning(self, hint, candidate_cards):
        """Ask LLaMA to rate ``candidate_cards`` against ``hint``.

        Returns:
            ``(reasoning_output, normalized_score)``. ``normalized_score`` maps a
            parsed 0-10 score to [-1, 1] via ``(score - 5) / 5``. It is 0.0 when
            no score can be found in the generated text.
        """
        prompt = self.create_cot_prompt(hint, candidate_cards)

        input_ids = self.llama_tokenizer.encode(prompt, return_tensors='pt').to(device)
        attention_mask = torch.ones_like(input_ids)
        prompt_len = input_ids.shape[1]
        # Up to 200 new tokens, keeping prompt + output within 2048 tokens.
        # Always allow at least one new token.
        max_new_tokens = max(1, min(200, 2048 - prompt_len))

        output = self.llama_model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=True,
            top_k=20,
            top_p=0.6,  # a lower top_p narrows the sampling nucleus, i.e. less diverse output
            temperature=0.3,
            no_repeat_ngram_size=3,
            pad_token_id=self.llama_tokenizer.eos_token_id
        )

        # Decode only the generated continuation. Slicing the decoded text by
        # len(prompt) assumes decode() reproduces the prompt character-for-character,
        # which tokenizers do not guarantee.
        reasoning_output = self.llama_tokenizer.decode(
            output[0][prompt_len:], skip_special_tokens=True
        ).strip()

        print("\n" + "="*30 + " DEBUG: LLaMA Reasoning Output " + "="*30)
        print(reasoning_output)
        print("="*75 + "\n")

        score_value = self._parse_score(reasoning_output)
        if score_value is None:
            return reasoning_output, 0.0

        print("Parsed score:", score_value)
        normalized_score = (score_value - 5) / 5
        return reasoning_output, normalized_score

    def choose_card(self, hint, candidate_cards):
        """Return ``(best_card, combined_score)`` over ``candidate_cards``.

        The combined score is 0.7 * LLaMA score + 0.3 * CLIP similarity.
        Ties go to the first card in ``candidate_cards``.

        Raises:
            ValueError: if ``candidate_cards`` is empty.
        """
        scores = {}
        for card_id in candidate_cards:
            similarity = self.score_card(hint, card_id)
            print(f"Card {card_id + 1} Similarity Score: {similarity:.3f}")
            _, reasoning_score = self.generate_reasoning(hint, [card_id])
            scores[card_id] = 0.7 * reasoning_score + 0.3 * similarity

        if not scores:
            raise ValueError("candidate_cards must contain at least one card id")
        best_card = max(scores, key=scores.get)
        return best_card, scores[best_card]

    def play_turn(self, hint):
        """Play a turn as card selector and guesser with randomly drawn cards.

        Returns:
            ``(chosen_card, guess_card)`` as 0-based card ids.
        """
        # First subphase: bot selects a card from its own six-card hand.
        bot_hand = random.sample(range(self.num_cards), 6)
        chosen_card, score = self.choose_card(hint, bot_hand)
        print(f"Selected card {chosen_card + 1} from hand with score {score:.3f} from {[card + 1 for card in bot_hand]}")

        # Second subphase: three other players' cards, disjoint from the hand.
        hand = set(bot_hand)
        available_cards = [i for i in range(self.num_cards) if i not in hand]
        other_cards = random.sample(available_cards, 3)

        print(f"\nOther players' cards: {[card + 1 for card in other_cards]}")

        # Bot guesses the target card among other players' cards.
        guess_card, guess_score = self.choose_card(hint, other_cards)
        print(f"Guessed card {guess_card + 1} with score {guess_score:.3f}")

        return chosen_card, guess_card

def main():
    """Build a DixitAI from the global ``config`` and play one turn with a fixed hint."""
    init_keys = (
        "descriptions_file_path",
        "checkpoint_path",
        "embeddings_save_path",
        "llama_model_path",
        "hf_token",
    )
    dixit_ai = DixitAI(**{key: config[key] for key in init_keys})

    hint = "Looking forward to meet you"
    chosen_card, guessed_card = dixit_ai.play_turn(hint)

    # play_turn returns 0-based ids; report 1-based card numbers.
    print(
        "\nFinal Results:",
        f"Selected card: {chosen_card + 1}",
        f"Guessed card: {guessed_card + 1}",
        sep="\n",
    )

config_path = "/content/drive/My Drive/Colab Notebooks/dixit/config_guesser.json"
config = load_config(config_path)

# DixitAI moves every model it loads onto this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module-level copies of the config entries.
cards_folder = config["cards_folder"]
embeddings_save_path = config["embeddings_save_path"]
clip_model_name = config["clip_model_name"]
descriptions_file_path = config["descriptions_file_path"]
checkpoint_path = config["checkpoint_path"]
llama_model_path = config["llama_model_path"]
hf_token = config["hf_token"]

# In a notebook __name__ is "__main__", so running the cell plays a turn.
if __name__ == "__main__":
    main()

# Llama, Dixit model <br>
Guesser phase with user-entered cards: the hint, the bot's hand, and the other players' cards are typed in at the prompts.


In [None]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    CLIPProcessor,
    CLIPModel,
    LlamaTokenizer,
    LlamaForCausalLM
)
import pandas as pd
from huggingface_hub import login
import json
import re
import os

def load_config(config_path):
    """Load and return the JSON configuration stored at ``config_path``.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.
    """
    with open(config_path, 'r', encoding='utf-8') as config_file:
        return json.load(config_file)

class DixitModel(nn.Module):
    """Three-layer MLP mapping a hint's CLIP text embedding to an ``embedding_dim`` vector.

    The layer attribute names (fc1/bn1/fc2/bn2/fc3) are the state-dict keys a
    saved checkpoint is loaded against, so they must not change.
    """

    def __init__(self, embedding_dim, dropout_rate=0.5):
        super().__init__()
        # Hidden block 1: embedding_dim -> 512
        self.fc1 = nn.Linear(embedding_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(dropout_rate)
        # Hidden block 2: 512 -> 256
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(dropout_rate)
        # Output projection back to embedding_dim, no activation.
        self.fc3 = nn.Linear(256, embedding_dim)

    def forward(self, hint_embedding):
        hidden = hint_embedding
        stages = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
        )
        for linear, norm, drop in stages:
            hidden = drop(F.relu(norm(linear(hidden))))
        return self.fc3(hidden)

class DixitAI:
    """Dixit bot that scores cards by blending CLIP similarity with a LLaMA relevance rating.

    Card ids are 0-based row indices into the saved image-embedding tensor.
    Text descriptions are keyed by the 1-based card number (``str(card_id + 1)``),
    taken from the first integer in the descriptions CSV's ``Image`` column.
    Relies on the module-level ``device`` defined at the bottom of the cell.
    """

    # "Score: 7", "score: 7.5", "Score: 7/10" (case-insensitive).
    _SCORE_LABEL_RE = re.compile(r'score\s*:\s*(\d+(?:\.\d+)?)(?:\s*/\s*10)?', re.IGNORECASE)
    # A bare "7/10" anywhere in the text.
    _SCORE_FRACTION_RE = re.compile(r'\b(\d+(?:\.\d+)?)\s*/\s*10\b')
    # Last-resort heuristic: the first stand-alone digit 2-9.
    _SCORE_FALLBACK_RE = re.compile(r'\b([2-9])\b')

    def __init__(self, descriptions_file_path, checkpoint_path, embeddings_save_path, llama_model_path, hf_token,
                 clip_model_name="openai/clip-vit-base-patch32"):
        """Load descriptions, the DixitModel checkpoint, CLIP, LLaMA and the card embeddings.

        Args:
            descriptions_file_path: CSV with ``Image``, ``BLIP``, ``ViT`` and ``BLIP-2`` columns.
            checkpoint_path: torch checkpoint holding ``'model_state_dict'`` for DixitModel.
            embeddings_save_path: saved tensor of card embeddings, one row per card.
            llama_model_path: Hugging Face id or local path of the LLaMA model.
            hf_token: Hugging Face token passed to the LLaMA downloads.
            clip_model_name: CLIP model used to embed hints. Should be the
                model the card embeddings were extracted with.
        """
        self.card_descriptions = self._load_descriptions(descriptions_file_path)

        embedding_dim = 512
        self.dixit_model = DixitModel(embedding_dim=embedding_dim).to(device)
        checkpoint = torch.load(checkpoint_path, map_location=device)
        self.dixit_model.load_state_dict(checkpoint['model_state_dict'])
        self.dixit_model.eval()

        self.clip_model = CLIPModel.from_pretrained(clip_model_name).to(device)
        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name)

        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_auth_token=hf_token)
        self.llama_model = LlamaForCausalLM.from_pretrained(
            llama_model_path,
            use_auth_token=hf_token,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True
        ).to(device)
        self.llama_model.eval()

        self.all_image_embeddings = torch.load(embeddings_save_path, map_location=device)
        self.num_cards = self.all_image_embeddings.shape[0]

    @staticmethod
    def _load_descriptions(descriptions_file_path):
        """Read the captions CSV into ``{card_number: {caption_kind: text}}``.

        The key is the first integer in the ``Image`` value, normalized through
        ``int()``. Zero-padded names such as "card_007.png" therefore map to
        "7", the form produced by the ``str(card_id + 1)`` lookups.

        Raises:
            ValueError: if an ``Image`` value contains no digits.
        """
        descriptions = {}
        descriptions_df = pd.read_csv(descriptions_file_path)
        for _, row in descriptions_df.iterrows():
            match = re.search(r'\d+', str(row['Image']))
            if match is None:
                raise ValueError(f"No card number found in Image value {row['Image']!r}")
            descriptions[str(int(match.group()))] = {
                'blip_description': row['BLIP'],
                'vit_description': row['ViT'],
                'blip2_description': row['BLIP-2']
            }
        return descriptions

    @classmethod
    def _parse_score(cls, text):
        """Return the first 0-10 score found in ``text`` (clamped to [0, 10]), or None."""
        for pattern in (cls._SCORE_LABEL_RE, cls._SCORE_FRACTION_RE, cls._SCORE_FALLBACK_RE):
            match = pattern.search(text)
            if match:
                return min(max(float(match.group(1)), 0.0), 10.0)
        return None

    def _descriptions_for(self, card_id):
        """Return the (BLIP, ViT, BLIP-2) captions of 0-based ``card_id``, with placeholders."""
        entry = self.card_descriptions.get(str(card_id + 1), {})
        missing = "No description available"
        return (
            entry.get('blip_description', missing),
            entry.get('vit_description', missing),
            entry.get('blip2_description', missing),
        )

    def create_cot_prompt(self, hint, candidate_cards):
        """Build a score-per-card prompt; generate_reasoning in this cell uses its own prompt instead.

        Card ids outside ``[0, num_cards)`` are silently left out of the prompt.
        """
        prompt_template = f"""You are an expert at matching images to a hint and giving a relatedness score.
            Your task is to evaluate each card's relevance to the given hint by giving a score.

            Hint: '{hint}'

            """
        for card_id in candidate_cards:
            if 0 <= card_id < self.num_cards:
                blip_desc, vit_desc, blip2_desc = self._descriptions_for(card_id)

                prompt_template += f"""
                Card {card_id + 1}:
                Descriptions: BLIP - {blip_desc} | ViT - {vit_desc} | BLIP-2 - {blip2_desc}
                Provide a relatedness score (0-10) between image and hint.
                Score:
                """

        return prompt_template

    def _build_reasoning_prompt(self, hint, candidate_cards):
        """Return the reasoning-and-score prompt for ``hint``; out-of-range card ids are skipped."""
        prompt = f"""You are an expert at evaluating the relevance of images to a hint by analyzing descriptions.
        Your task is to assess each image's relevance to the hint and give a relatedness score (0-10).

        Hint: '{hint}'

        For each candidate image, analyze its descriptions and follow these instructions:
        1. Write a brief reasoning explaining the relevance or irrelevance of the descriptions to the hint.
        2. On a new line, provide a "Score:" followed by a numeric score (0-10), with 10 being the most relevant and 0 the least.

        """

        for card_id in candidate_cards:
            if 0 <= card_id < self.num_cards:
                blip_desc, vit_desc, blip2_desc = self._descriptions_for(card_id)

                prompt += f"""
                Description for Card {card_id + 1}: {blip_desc}, {vit_desc}, {blip2_desc}
                Reasoning and Score:
                """

        return prompt

    def score_card(self, hint, card_id):
        """Cosine similarity between the DixitModel-mapped hint and card ``card_id``'s embedding.

        Raises:
            ValueError: if ``card_id`` is outside ``[0, num_cards)``.
        """
        if not (0 <= card_id < self.num_cards):
            raise ValueError(f"Invalid card_id: {card_id}. Must be between 0 and {self.num_cards-1}")

        # truncation=True: CLIP's text encoder has a fixed maximum token length.
        hint_inputs = self.clip_processor(text=hint, return_tensors="pt", padding=True, truncation=True).to(device)
        # Inference only: keep the DixitModel pass inside no_grad as well.
        with torch.no_grad():
            hint_embedding = self.clip_model.get_text_features(**hint_inputs)
            hint_embedding = F.normalize(hint_embedding, p=2, dim=-1)
            hint_embedding = self.dixit_model(hint_embedding).squeeze(0)
            hint_embedding = F.normalize(hint_embedding, p=2, dim=0)

            card_embedding = F.normalize(self.all_image_embeddings[card_id], p=2, dim=0)
            similarity = F.cosine_similarity(hint_embedding, card_embedding, dim=0).item()
        return similarity

    def generate_reasoning(self, hint, candidate_cards):
        """Ask LLaMA for a short reasoning and a 0-10 score for ``candidate_cards``.

        Returns:
            ``(reasoning_output, normalized_score)``. ``normalized_score`` maps a
            parsed 0-10 score to [-1, 1] via ``(score - 5) / 5``. It is 0.0 when
            no score can be found in the generated text.
        """
        prompt = self._build_reasoning_prompt(hint, candidate_cards)

        input_ids = self.llama_tokenizer.encode(prompt, return_tensors='pt').to(device)
        attention_mask = torch.ones_like(input_ids)
        prompt_len = input_ids.shape[1]
        # Up to 500 new tokens, keeping prompt + output within 2048 tokens.
        # Always allow at least one new token.
        max_new_tokens = max(1, min(500, 2048 - prompt_len))

        output = self.llama_model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=True,
            top_k=20,
            top_p=0.9,
            temperature=0.2,
            no_repeat_ngram_size=3,
            pad_token_id=self.llama_tokenizer.eos_token_id
        )

        # Decode only the generated continuation. Slicing the decoded text by
        # len(prompt) assumes decode() reproduces the prompt character-for-character,
        # which tokenizers do not guarantee.
        reasoning_output = self.llama_tokenizer.decode(
            output[0][prompt_len:], skip_special_tokens=True
        ).strip()

        print("\n" + "="*30 + "LLaMA Reasoning Output" + "="*30)
        print("Reasoning Output:\n", reasoning_output)
        print("="*75 + "\n")

        score_value = self._parse_score(reasoning_output)
        if score_value is None:
            return reasoning_output, 0.0

        normalized_score = (score_value - 5) / 5
        print("Normalized Score:", normalized_score)
        return reasoning_output, normalized_score

    def choose_card(self, hint, candidate_cards):
        """Return ``(best_card, combined_score)`` over ``candidate_cards``.

        The combined score is 0.7 * LLaMA score + 0.3 * CLIP similarity.
        Ties go to the first card in ``candidate_cards``.

        Raises:
            ValueError: if ``candidate_cards`` is empty.
        """
        scores = {}
        for card_id in candidate_cards:
            similarity = self.score_card(hint, card_id)
            print(f"Card {card_id + 1} Similarity Score: {similarity:.3f}")
            _, reasoning_score = self.generate_reasoning(hint, [card_id])
            scores[card_id] = 0.7 * reasoning_score + 0.3 * similarity

        if not scores:
            raise ValueError("candidate_cards must contain at least one card id")
        best_card = max(scores, key=scores.get)
        return best_card, scores[best_card]

    def _read_card_ids(self, prompt):
        """Prompt until the user enters valid 1-based card numbers; return them 0-based."""
        while True:
            raw = input(prompt)
            try:
                card_ids = [int(token) - 1 for token in raw.split()]
            except ValueError:
                print("Card IDs must be whole numbers; please try again.")
                continue
            if card_ids and all(0 <= card_id < self.num_cards for card_id in card_ids):
                return card_ids
            print(f"Please enter card IDs between 1 and {self.num_cards}.")

    def play_turn(self, hint):
        """Play a turn as card selector and guesser with user-entered cards.

        Card numbers are typed 1-based and handled 0-based internally.

        Returns:
            ``(chosen_card, guess_card)`` as 0-based card ids.
        """
        bot_hand = self._read_card_ids("Enter six card IDs for the bot's hand, separated by spaces: ")

        chosen_card, score = self.choose_card(hint, bot_hand)
        print(f"Selected card {chosen_card + 1} from hand with score {score:.3f} from {[card + 1 for card in bot_hand]}")

        other_cards = self._read_card_ids("Enter three card IDs for other players' cards, separated by spaces: ")

        guess_card, guess_score = self.choose_card(hint, other_cards)
        print(f"Guessed card {guess_card + 1} with score {guess_score:.3f}")

        return chosen_card, guess_card

def main():
    """Build a DixitAI from the global ``config`` and play one turn with a typed-in hint."""
    settings = {
        name: config[name]
        for name in (
            "descriptions_file_path",
            "checkpoint_path",
            "embeddings_save_path",
            "llama_model_path",
            "hf_token",
        )
    }
    dixit_ai = DixitAI(**settings)

    hint = input("Enter a hint for this turn: ")
    chosen_card, guessed_card = dixit_ai.play_turn(hint)

    # play_turn returns 0-based ids; report 1-based card numbers.
    print(
        "\nFinal Results:",
        f"Selected card: {chosen_card + 1}",
        f"Guessed card: {guessed_card + 1}",
        sep="\n",
    )

config_path = "/content/drive/My Drive/Colab Notebooks/dixit/config_guesser.json"
config = load_config(config_path)

# DixitAI moves every model it loads onto this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module-level copies of the config entries.
cards_folder = config["cards_folder"]
embeddings_save_path = config["embeddings_save_path"]
clip_model_name = config["clip_model_name"]
descriptions_file_path = config["descriptions_file_path"]
checkpoint_path = config["checkpoint_path"]
llama_model_path = config["llama_model_path"]
hf_token = config["hf_token"]

# In a notebook __name__ is "__main__", so running the cell plays a turn.
if __name__ == "__main__":
    main()


In [None]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    CLIPProcessor,
    CLIPModel,
    LlamaTokenizer,
    LlamaForCausalLM
)
import pandas as pd
from huggingface_hub import login
import json
import re
import os

def load_config(config_path):
    """Load and return the JSON configuration stored at ``config_path``.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.
    """
    with open(config_path, 'r', encoding='utf-8') as config_file:
        return json.load(config_file)

class DixitModel(nn.Module):
    """Three-layer MLP mapping a hint's CLIP text embedding to a unit-length ``embedding_dim`` vector.

    Unlike a plain projection head, ``forward`` L2-normalizes its output along
    the last dimension. The layer attribute names (fc1/bn1/fc2/bn2/fc3) are the
    state-dict keys a saved checkpoint is loaded against.

    Args:
        embedding_dim: size of the input and output embeddings.
        dropout_rate: dropout probability after each hidden block (default 0.5).
    """

    def __init__(self, embedding_dim, dropout_rate=0.5):
        super().__init__()
        self.fc1 = nn.Linear(embedding_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(dropout_rate)

        self.fc3 = nn.Linear(256, embedding_dim)

    def forward(self, hint_embedding):
        x = F.relu(self.bn1(self.fc1(hint_embedding)))
        x = self.dropout1(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout2(x)
        x = self.fc3(x)
        return F.normalize(x, p=2, dim=-1)

class DixitAI:
    """Dixit bot that scores cards by blending CLIP similarity with a LLaMA relevance rating.

    Card ids are 0-based row indices into the saved image-embedding tensor.
    Text descriptions are keyed by the 1-based card number (``str(card_id + 1)``),
    taken from the first integer in the descriptions CSV's ``Image`` column.
    Relies on the module-level ``device`` defined at the bottom of the cell.
    """

    # "Score: 7", "score: 7.5", "Score: 7/10" (case-insensitive).
    _SCORE_LABEL_RE = re.compile(r'score\s*:\s*(\d+(?:\.\d+)?)(?:\s*/\s*10)?', re.IGNORECASE)
    # A bare "7/10" anywhere in the text.
    _SCORE_FRACTION_RE = re.compile(r'\b(\d+(?:\.\d+)?)\s*/\s*10\b')
    # Last-resort heuristic: first stand-alone 2-9 or 10 not followed by "."
    # (so numbered list items like "1." are skipped).
    _SCORE_FALLBACK_RE = re.compile(r'\b([2-9]|10)\b(?!\.)')

    def __init__(self, descriptions_file_path, checkpoint_path, embeddings_save_path, llama_model_path, hf_token,
                 clip_model_name="openai/clip-vit-base-patch32"):
        """Load descriptions, the DixitModel checkpoint, CLIP, LLaMA and the card embeddings.

        Args:
            descriptions_file_path: CSV with ``Image``, ``BLIP``, ``ViT`` and ``BLIP-2`` columns.
            checkpoint_path: torch checkpoint holding ``'model_state_dict'`` for DixitModel.
            embeddings_save_path: saved tensor of card embeddings, one row per card.
            llama_model_path: Hugging Face id or local path of the LLaMA model.
            hf_token: Hugging Face token passed to the LLaMA downloads.
            clip_model_name: CLIP model used to embed hints. Should be the
                model the card embeddings were extracted with.
        """
        self.card_descriptions = self._load_descriptions(descriptions_file_path)

        embedding_dim = 512
        self.dixit_model = DixitModel(embedding_dim=embedding_dim).to(device)
        checkpoint = torch.load(checkpoint_path, map_location=device)
        self.dixit_model.load_state_dict(checkpoint['model_state_dict'])
        self.dixit_model.eval()

        self.clip_model = CLIPModel.from_pretrained(clip_model_name).to(device)
        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name)

        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_auth_token=hf_token)
        self.llama_model = LlamaForCausalLM.from_pretrained(
            llama_model_path,
            use_auth_token=hf_token,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True
        ).to(device)
        self.llama_model.eval()

        self.all_image_embeddings = torch.load(embeddings_save_path, map_location=device)
        self.num_cards = self.all_image_embeddings.shape[0]

    @staticmethod
    def _load_descriptions(descriptions_file_path):
        """Read the captions CSV into ``{card_number: {caption_kind: text}}``.

        The key is the first integer in the ``Image`` value, normalized through
        ``int()``. Zero-padded names such as "card_007.png" therefore map to
        "7", the form produced by the ``str(card_id + 1)`` lookups.

        Raises:
            ValueError: if an ``Image`` value contains no digits.
        """
        descriptions = {}
        descriptions_df = pd.read_csv(descriptions_file_path)
        for _, row in descriptions_df.iterrows():
            match = re.search(r'\d+', str(row['Image']))
            if match is None:
                raise ValueError(f"No card number found in Image value {row['Image']!r}")
            descriptions[str(int(match.group()))] = {
                'blip_description': row['BLIP'],
                'vit_description': row['ViT'],
                'blip2_description': row['BLIP-2']
            }
        return descriptions

    @classmethod
    def _parse_score(cls, text):
        """Return the first 0-10 score found in ``text`` (clamped to [0, 10]), or None."""
        for pattern in (cls._SCORE_LABEL_RE, cls._SCORE_FRACTION_RE, cls._SCORE_FALLBACK_RE):
            match = pattern.search(text)
            if match:
                return min(max(float(match.group(1)), 0.0), 10.0)
        return None

    def _descriptions_for(self, card_id):
        """Return the (BLIP, ViT, BLIP-2) captions of 0-based ``card_id``, with placeholders."""
        entry = self.card_descriptions.get(str(card_id + 1), {})
        missing = "No description available"
        return (
            entry.get('blip_description', missing),
            entry.get('vit_description', missing),
            entry.get('blip2_description', missing),
        )

    def _build_reasoning_prompt(self, hint, candidate_cards):
        """Return the reasoning-and-score prompt for ``hint``; out-of-range card ids are skipped."""
        prompt = f"""You are an expert at evaluating the relevance of images to a hint by analyzing descriptions.
        Your task is to assess each image's relevance to the hint and give a relatedness score (0-10).

        Hint: '{hint}'

        For each candidate image, analyze its descriptions and follow these instructions:
        1. Write a brief reasoning explaining the relevance or irrelevance of the descriptions to the hint.
        2. On a new line, provide a "Score:" followed by a numeric score (0-10), with 10 being the most relevant and 0 the least.

        """

        for card_id in candidate_cards:
            if 0 <= card_id < self.num_cards:
                blip_desc, vit_desc, blip2_desc = self._descriptions_for(card_id)

                prompt += f"""
                Description for Card {card_id + 1}: {blip_desc}, {vit_desc}, {blip2_desc}
                Reasoning and Score:
                """

        return prompt

    def score_card(self, hint, card_id):
        """Cosine similarity between the DixitModel-mapped hint and card ``card_id``'s embedding.

        Raises:
            ValueError: if ``card_id`` is outside ``[0, num_cards)``.
        """
        if not (0 <= card_id < self.num_cards):
            raise ValueError(f"Invalid card_id: {card_id}. Must be between 0 and {self.num_cards-1}")

        # truncation=True: CLIP's text encoder has a fixed maximum token length.
        hint_inputs = self.clip_processor(text=hint, return_tensors="pt", padding=True, truncation=True).to(device)
        # Inference only: keep the DixitModel pass inside no_grad as well.
        with torch.no_grad():
            hint_embedding = self.clip_model.get_text_features(**hint_inputs)
            hint_embedding = F.normalize(hint_embedding, p=2, dim=-1)
            hint_embedding = self.dixit_model(hint_embedding).squeeze(0)
            hint_embedding = F.normalize(hint_embedding, p=2, dim=0)

            card_embedding = F.normalize(self.all_image_embeddings[card_id], p=2, dim=0)
            similarity = F.cosine_similarity(hint_embedding, card_embedding, dim=0).item()
        print("\n" + "="*30 + "="*30)
        return similarity

    def generate_reasoning(self, hint, candidate_cards):
        """Ask LLaMA for a short reasoning and a 0-10 score for ``candidate_cards``.

        Returns:
            ``(reasoning_output, normalized_score)``. ``normalized_score`` maps a
            parsed 0-10 score to [-1, 1] via ``(score - 5) / 5``. It is 0.0 when
            no score can be found in the generated text.
        """
        prompt = self._build_reasoning_prompt(hint, candidate_cards)

        input_ids = self.llama_tokenizer.encode(prompt, return_tensors='pt').to(device)
        attention_mask = torch.ones_like(input_ids)
        prompt_len = input_ids.shape[1]
        # Up to 500 new tokens, keeping prompt + output within 2048 tokens.
        # Always allow at least one new token.
        max_new_tokens = max(1, min(500, 2048 - prompt_len))

        output = self.llama_model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=True,
            top_k=20,
            top_p=0.9,
            temperature=0.2,
            no_repeat_ngram_size=3,
            pad_token_id=self.llama_tokenizer.eos_token_id
        )

        # Decode only the generated continuation. Slicing the decoded text by
        # len(prompt) assumes decode() reproduces the prompt character-for-character,
        # which tokenizers do not guarantee.
        reasoning_output = self.llama_tokenizer.decode(
            output[0][prompt_len:], skip_special_tokens=True
        ).strip()

        print("Reasoning Output:\n", reasoning_output)

        score_value = self._parse_score(reasoning_output)
        if score_value is None:
            return reasoning_output, 0.0

        normalized_score = (score_value - 5) / 5
        print("Normalized Score:", normalized_score)
        return reasoning_output, normalized_score

    def choose_card(self, hint, candidate_cards):
        """Return ``(best_card, combined_score)`` over ``candidate_cards``.

        The combined score is 0.6 * LLaMA score + 0.4 * CLIP similarity.
        Ties go to the first card in ``candidate_cards``.

        Raises:
            ValueError: if ``candidate_cards`` is empty.
        """
        scores = {}
        for card_id in candidate_cards:
            similarity = self.score_card(hint, card_id)
            print(f"Card {card_id + 1} Similarity Score: {similarity:.3f}")
            _, reasoning_score = self.generate_reasoning(hint, [card_id])
            combined_score = 0.6 * reasoning_score + 0.4 * similarity
            print(f"Combined Score: {combined_score:.3f}")
            scores[card_id] = combined_score

        if not scores:
            raise ValueError("candidate_cards must contain at least one card id")
        best_card = max(scores, key=scores.get)
        return best_card, scores[best_card]

    def _read_card_ids(self, prompt):
        """Prompt until the user enters valid 1-based card numbers; return them 0-based."""
        while True:
            raw = input(prompt)
            try:
                card_ids = [int(token) - 1 for token in raw.split()]
            except ValueError:
                print("Card IDs must be whole numbers; please try again.")
                continue
            if card_ids and all(0 <= card_id < self.num_cards for card_id in card_ids):
                return card_ids
            print(f"Please enter card IDs between 1 and {self.num_cards}.")

    def play_turn(self, hint):
        """Play a turn as card selector and guesser with user-entered cards.

        Card numbers are typed 1-based and handled 0-based internally.

        Returns:
            ``(chosen_card, guess_card)`` as 0-based card ids.
        """
        bot_hand = self._read_card_ids("Enter six card IDs for the bot's hand, separated by spaces: ")

        chosen_card, score = self.choose_card(hint, bot_hand)
        print(f"Selected card {chosen_card + 1} from hand with score {score:.3f} from {[card + 1 for card in bot_hand]}")

        other_cards = self._read_card_ids("Enter three card IDs for other players' cards, separated by spaces: ")

        guess_card, guess_score = self.choose_card(hint, other_cards)
        print(f"Guessed card {guess_card + 1} with score {guess_score:.3f}")

        return chosen_card, guess_card

def main():
    """Play one interactive turn with the checkpoint selected for this cell.

    Reads the module-level settings assigned at the bottom of this cell. There,
    ``checkpoint_path`` comes from ``config["checkpoint_path_contrastiveloss"]``.
    Reading ``config["checkpoint_path"]`` here would silently ignore that
    choice and load a different checkpoint into this cell's DixitModel.
    """
    dixit_ai = DixitAI(
        descriptions_file_path=descriptions_file_path,
        checkpoint_path=checkpoint_path,
        embeddings_save_path=embeddings_save_path,
        llama_model_path=llama_model_path,
        hf_token=hf_token,
    )

    hint = input("Enter a hint for this turn: ")
    chosen_card, guessed_card = dixit_ai.play_turn(hint)

    # play_turn returns 0-based ids; report 1-based card numbers.
    print("\nFinal Results:")
    print(f"Selected card: {chosen_card + 1}")
    print(f"Guessed card: {guessed_card + 1}")

config_path = "/content/drive/My Drive/Colab Notebooks/dixit/config_guesser.json"
config = load_config(config_path)

# DixitAI moves every model it loads onto this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module-level copies of the config entries. This cell's checkpoint_path is
# the contrastive-loss checkpoint, not config["checkpoint_path"].
cards_folder = config["cards_folder"]
embeddings_save_path = config["embeddings_save_path"]
clip_model_name = config["clip_model_name"]
descriptions_file_path = config["descriptions_file_path"]
checkpoint_path = config["checkpoint_path_contrastiveloss"]
llama_model_path = config["llama_model_path"]
hf_token = config["hf_token"]

# In a notebook __name__ is "__main__", so running the cell plays a turn.
if __name__ == "__main__":
    main()
