<a href="https://colab.research.google.com/github/arumdauo/dixit-AI-bot/blob/main/storyteller.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive. The storyteller config file is read
# from a path under this mount point (presumably the card images and the
# description CSV it points to live on Drive as well).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install transformers torch pillow pandas

Load the Llama model and tokenizer

In [None]:
import os
import json
import torch
import pandas as pd
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
from google.colab import drive
from google.colab import runtime
from huggingface_hub import login

def load_config(config_path='config.json'):
    """Read a JSON configuration file and return the parsed object.

    Args:
        config_path: Path to the JSON file. Defaults to ``config.json`` in the
            current working directory.

    Returns:
        The object produced by ``json.load``. The notebook indexes it by key,
        so a JSON object (dict) is expected.
    """
    # Decode explicitly as UTF-8: without ``encoding`` the locale's preferred
    # encoding is used, which is not UTF-8 on every platform.
    with open(config_path, 'r', encoding='utf-8') as config_file:
        return json.load(config_file)

def load_llama_model(model_name, device=None):
    """Load a causal-LM tokenizer/model pair via ``from_pretrained``.

    Args:
        model_name: Hugging Face Hub id or local path handed to
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.
        device: Optional torch device to move the model to. When omitted, a
            CUDA device is used if ``torch.cuda.is_available()``, otherwise
            the CPU.

    Returns:
        ``(tokenizer, model)`` with ``model`` already moved to ``device``.
    """
    # Resolve the device locally instead of reading a module-level global that
    # must happen to be defined before this function is called.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    return tokenizer, model

# Select the torch device: a CUDA GPU when one is available, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Read the storyteller settings from Drive, authenticate against the Hugging
# Face Hub with the configured token, then load the configured model.
config_path = '/content/drive/MyDrive/Colab Notebooks/dixit/config_storyteller.json'
config = load_config(config_path)
login(config["huggingface_token"])
tokenizer, model = load_llama_model(config["model_name"])

# Perform the storyteller phase
Generate a hint from the card's image descriptions.

In [None]:
import re

def load_image_by_id(image_id, image_folder):
    """Open the PNG for a card from ``image_folder``.

    ``card_<image_id>.png`` is tried first, then ``<image_id>.png``.

    Returns:
        The ``PIL.Image`` returned by ``Image.open`` for the first existing
        file, or ``None`` (after printing a message) if neither file exists.
    """
    for filename in (f"card_{image_id}.png", f"{image_id}.png"):
        candidate = os.path.join(image_folder, filename)
        if os.path.exists(candidate):
            return Image.open(candidate)
    print("Image not found.")
    return None

def get_descriptions_from_csv(image_id, csv_path):
    """Look up the three caption-model descriptions for a card.

    The CSV must have an ``Image`` column holding file names. The row for
    ``card_<image_id>.png`` is preferred over the one for ``<image_id>.png``.

    Returns:
        ``{"BLIP": ..., "ViT": ..., "BLIP-2": ...}`` taken from the first
        matching row, or ``None`` (after printing a message) if no row matches.
    """
    table = pd.read_csv(csv_path)
    selected = None
    for filename in (f"card_{image_id}.png", f"{image_id}.png"):
        rows = table[table['Image'] == filename]
        if not rows.empty:
            selected = rows
            break

    if selected is None:
        print("Descriptions not found for this image ID.")
        return None

    return {column: selected[column].values[0] for column in ("BLIP", "ViT", "BLIP-2")}

def generate_hint_1(descriptions, tokenizer, model):
    """Generate a short, poetic clue from the three caption descriptions.

    A few-shot prompt built from the ``BLIP``, ``ViT`` and ``BLIP-2``
    descriptions (in that order) is continued with sampled beam search
    (2 beams, ``do_sample=True``).

    Args:
        descriptions: Mapping with keys ``"BLIP"``, ``"ViT"`` and ``"BLIP-2"``.
        tokenizer: Tokenizer paired with ``model``; must expose
            ``pad_token_id``/``eos_token_id`` and ``decode``.
        model: Causal language model exposing ``generate`` and ``device``.

    Returns:
        The first line of the newly generated text, with every character other
        than word characters, whitespace and ``,.?!`` removed. It may be empty
        if the model produced nothing usable.
    """
    max_new_tokens = 15
    temperature = 0.9
    top_p = 0.8
    top_k = 50
    repetition_penalty = 1.5
    num_beams = 2

    examples = (
        "Examples of poetic, standalone clues:\n"
        "1. Context: A woman stands alone by the sea\n   Clue: \"Solitude.\"\n"
        "2. Context: An owl watches in the night\n   Clue: \"Should I fly?\"\n"
        "3. Context: Flowers bloom under a stormy sky\n   Clue: \"Petals defy.\"\n\n"
    )

    prompt = (
        f"{examples}"
        "Create a short, mysterious clue based on the essence of this scene:\n"
        f"{descriptions['BLIP']}\n"
        f"{descriptions['ViT']}\n"
        f"{descriptions['BLIP-2']}\n\n"
        "Generate a single phrase ending with a period.\n"
        "Clue: "
    )

    # Put the inputs on the model's own device rather than a global one.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Some tokenizers (e.g. Llama's) may define no pad token; fall back to EOS
    # so generate() gets an explicit pad id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    output = model.generate(
        inputs['input_ids'],
        attention_mask=inputs.get('attention_mask'),
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams,
        do_sample=True,
        pad_token_id=pad_token_id,
    )

    # Decode only the continuation. The prompt itself contains several
    # "Clue: " markers, and re-decoding it is not guaranteed to round-trip.
    prompt_length = inputs['input_ids'].shape[-1]
    generated = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    # The model may keep going (e.g. invent a "4. Context: ..." example), so
    # keep only the first non-empty line.
    lines = generated.strip().splitlines()
    hint = lines[0] if lines else ""
    # Drop quotes and other symbols, keeping words, spaces and ,.?!
    hint = re.sub(r'[^\w\s,.?!]', '', hint).strip()

    return hint

def generate_hint_2(descriptions, tokenizer, model):
    """Generate a second clue with a different prompt and sampling settings.

    The descriptions are listed in the order ``BLIP-2``, ``ViT``, ``BLIP``.
    Sampling uses a single beam with ``top_k=90`` and ``top_p=0.9``.

    Args:
        descriptions: Mapping with keys ``"BLIP"``, ``"ViT"`` and ``"BLIP-2"``.
        tokenizer: Tokenizer paired with ``model``; must expose
            ``pad_token_id``/``eos_token_id`` and ``decode``.
        model: Causal language model exposing ``generate`` and ``device``.

    Returns:
        The first line of the newly generated text, with every character other
        than word characters, whitespace and ``,.?!`` removed. It may be empty.
    """
    max_new_tokens = 15
    temperature = 0.9
    top_p = 0.9
    top_k = 90
    repetition_penalty = 1.3
    num_beams = 1

    # Plain ASCII apostrophe keeps the prompt free of encoding artifacts.
    examples = (
        "Examples of poetic clues:\n"
        "1. Context: A moonlit forest path\n   Clue: \"Footsteps\"\n"
        "2. Context: A candle in the dark\n   Clue: \"Shadow's embrace.\"\n"
        "3. Context: A river winding through quiet mountains\n   Clue: \"There is just one.\"\n\n"
    )

    prompt = (
        f"{examples}"
        "Now, create a short, mysterious clue based on this context:\n"
        f"1. {descriptions['BLIP-2']}\n"
        f"2. {descriptions['ViT']}\n"
        f"3. {descriptions['BLIP']}\n\n"
        "Generate a clue with a dot at the end.\n"
        "Clue: "
    )

    # Put the inputs on the model's own device rather than a global one.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Some tokenizers (e.g. Llama's) may define no pad token; fall back to EOS.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    output = model.generate(
        inputs['input_ids'],
        attention_mask=inputs.get('attention_mask'),
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams,
        do_sample=True,
        pad_token_id=pad_token_id,
    )

    # Decode only the continuation instead of splitting the echoed prompt,
    # which contains several "Clue:" markers of its own.
    prompt_length = inputs['input_ids'].shape[-1]
    generated = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    # Keep the first non-empty line; anything after it is the model drifting on.
    lines = generated.strip().splitlines()
    hint = lines[0] if lines else ""
    hint = re.sub(r'[^\w\s,.?!]', '', hint).strip()

    return hint

def refine_hint(hint_1, hint_2, tokenizer, model):
    """Ask the model to fuse two candidate clues into one short hint.

    Args:
        hint_1: First candidate clue (inserted as "Concept 1").
        hint_2: Second candidate clue (inserted as "Concept 2").
        tokenizer: Tokenizer paired with ``model``; must expose
            ``pad_token_id``/``eos_token_id`` and ``decode``.
        model: Causal language model exposing ``generate`` and ``device``.

    Returns:
        The first line of the newly generated text (at most 8 new tokens), with
        every character other than word characters, whitespace and ``,.?!``
        removed. It may be empty.
    """
    max_new_tokens = 8
    temperature = 0.9
    top_p = 0.9
    top_k = 70
    repetition_penalty = 1.5
    num_beams = 3

    prompt = (
        "Given the following concepts, generate a very short and enigmatic hint:\n"
        f"Concept 1: {hint_1}\n"
        f"Concept 2: {hint_2}\n\n"
        "Generate a very short and enigmatic hint.\n"
        "Hint: "
    )

    # Put the inputs on the model's own device rather than a global one.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Some tokenizers (e.g. Llama's) may define no pad token; fall back to EOS.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    output = model.generate(
        inputs['input_ids'],
        attention_mask=inputs.get('attention_mask'),
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams,
        do_sample=True,
        pad_token_id=pad_token_id,
    )

    # Decode only the continuation rather than splitting the echoed prompt.
    prompt_length = inputs['input_ids'].shape[-1]
    generated = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    lines = generated.strip().splitlines()
    refined_hint = lines[0] if lines else ""
    refined_hint = re.sub(r'[^\w\s,.?!]', '', refined_hint).strip()

    return refined_hint

def select_final_hint(hint_1, hint_2, refined_hint):
    """Choose which hint to present.

    Every candidate is first normalized with ``clean_hint``. The first cleaned
    candidate longer than two characters wins, checked in the order: refined
    hint, hint 1, hint 2. If none qualifies, a fixed fallback message is
    returned.
    """
    cleaned_1, cleaned_2, cleaned_refined = (
        clean_hint(candidate) for candidate in (hint_1, hint_2, refined_hint)
    )

    for candidate in (cleaned_refined, cleaned_1, cleaned_2):
        if len(candidate) > 2:
            return candidate
    return "No hint available."

def clean_hint(hint):
    """Strip leading and trailing runs of periods and whitespace from ``hint``.

    Interior punctuation is left alone, e.g. ``" ..a.b.. "`` -> ``"a.b"``.
    """
    return re.sub(r'^[.\s]+|[.\s]+$', '', hint)

# Interactive driver: pick a card, show it, print its descriptions, then run
# the hint pipeline (two candidate clues, one refinement, final selection) and
# print every intermediate result.
# Surrounding whitespace is stripped so a stray space typed at the prompt does
# not break the card_<id>.png / <id>.png file-name lookups.
image_id = input("Enter the image ID (e.g., 3 for card_3.png): ").strip()
image = load_image_by_id(image_id, config["image_folder"])
# Both lookups return None on failure (after printing a message themselves).
if image is not None:
    display(image)  # notebook display helper (available in IPython/Colab)
    descriptions = get_descriptions_from_csv(image_id, config["csv_path"])
    if descriptions is not None:
        print("Descriptions:")
        for key, desc in descriptions.items():
            print(f"{key}: {desc}")
        hint_1 = generate_hint_1(descriptions, tokenizer, model)
        print("\nGenerated Hint 1:")
        print(hint_1)
        hint_2 = generate_hint_2(descriptions, tokenizer, model)
        print("\nGenerated Hint 2:")
        print(hint_2)
        refined_hint = refine_hint(hint_1, hint_2, tokenizer, model)
        print("\nRefined Hint:")
        print(refined_hint)
        final_hint = select_final_hint(hint_1, hint_2, refined_hint)
        print("\nFinal Hint:")
        print(final_hint)
