<a href="https://colab.research.google.com/github/arumdauo/dixit-AI-bot/blob/main/guesser_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the JSON config read further down lives under this path.
drive.mount('/content/drive')

# Extract image embeddings with CLIP

In [None]:
import os
import torch
import torch.nn.functional as F
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import pickle
import json

def load_config(config_path):
    """Read a JSON configuration file and return its parsed content.

    Args:
        config_path: Path to the JSON file.

    Returns:
        The decoded JSON document (indexed like a dict by the rest of this notebook).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as config_file:
        return json.load(config_file)

def extract_image_embeddings(cards_folder):
    """Compute an L2-normalised CLIP image embedding for every ``.png`` file in a folder.

    Each file name (without extension) is parsed as an integer card number; the
    result is keyed by ``card_number - 1`` so that keys are 0-based.

    Relies on the module-level ``clip_model``, ``clip_processor`` and ``device``.

    Args:
        cards_folder: Directory containing the card images.

    Returns:
        dict mapping the 0-based card index to its embedding tensor (moved to the CPU).

    Raises:
        ValueError: If a ``.png`` file name (without extension) is not an integer.
    """
    image_embeddings = {}
    for filename in os.listdir(cards_folder):
        if not filename.endswith('.png'):
            continue
        card_id = int(os.path.splitext(filename)[0])
        image_path = os.path.join(cards_folder, filename)
        # Open through a context manager so the file handle is released right
        # away instead of lingering until garbage collection; convert() returns
        # an independent, fully loaded image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        inputs = clip_processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            embedding = clip_model.get_image_features(**inputs)
            embedding = F.normalize(embedding, p=2, dim=-1).cpu()
        image_embeddings[card_id - 1] = embedding  # Adjust to 0-based index

    return image_embeddings

# Location of the JSON config on the mounted Google Drive.
config_path = "/content/drive/My Drive/Colab Notebooks/dixit/config_guesser_train.json"
config = load_config(config_path)

# Input folder with the card images and output path for the pickled embeddings.
cards_folder = config["cards_folder"]
embeddings_path = config["embeddings_path"]

# Use the GPU when one is available; the CLIP model is moved to that device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clip_model = CLIPModel.from_pretrained(config["model_name"]).to(device)
clip_processor = CLIPProcessor.from_pretrained(config["model_name"])

# Embed every card once and persist the {0-based card index: embedding} dict with pickle.
image_embeddings = extract_image_embeddings(cards_folder)
with open(embeddings_path, 'wb') as f:
    pickle.dump(image_embeddings, f)

print(f"Image embeddings saved to {embeddings_path}")


# Training

In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import CLIPProcessor, CLIPModel
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
import nltk
import re
from collections import Counter
from nltk.corpus import stopwords
from sklearn.metrics import accuracy_score

# Load the pickled image embeddings.
# NOTE(review): `embeddings_path` is not assigned in this cell before this point (the
# config is only read further down), so this relies on a previously executed cell
# having defined it; in a fresh session this line raises NameError.
with open(embeddings_path, 'rb') as f:
    image_embeddings = pickle.load(f)

def load_config(config_path):
    """Read a JSON configuration file and return its parsed content.

    Args:
        config_path: Path to the JSON file.

    Returns:
        The decoded JSON document (indexed like a dict by the rest of this notebook).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as config_file:
        return json.load(config_file)

class DixitDataset(Dataset):
    """Dataset yielding (hint embedding, target card embedding, distractor card embedding).

    Rows come from a CSV with ``DESCRIPTION`` (hint text) and ``TARGET`` (1-based
    card number) columns. Hint text is reduced to keywords and encoded with the
    module-level ``clip_model`` / ``clip_processor`` on ``device``; keyword
    extraction uses the module-level ``stop_words`` set.
    """

    def __init__(self, csv_path, image_embeddings, debug_size=None):
        """Load the CSV; when ``debug_size`` is truthy keep a fixed-seed sample of that many rows."""
        frame = pd.read_csv(csv_path)
        self.data = frame.sample(n=debug_size, random_state=42) if debug_size else frame
        self.image_embeddings = image_embeddings
        self.truncated_count = 0

    def __len__(self):
        """Number of hint rows."""
        return len(self.data)

    def keyword_extract(self, text):
        """Lower-case ``text``, drop stop words and return up to 77 distinct words, most frequent first."""
        kept = [token for token in re.findall(r'\w+', text.lower()) if token not in stop_words]
        ranked = Counter(kept).most_common(77)
        return ' '.join(token for token, _count in ranked)

    def __getitem__(self, idx):
        """Build the triplet for row ``idx``; the distractor is drawn at random on every call."""
        row = self.data.iloc[idx]
        raw_hint = row['DESCRIPTION']
        target = int(row['TARGET']) - 1  # CSV targets are 1-based

        # Any card index other than the target may serve as the negative example.
        candidates = [card for card in range(len(self.image_embeddings)) if card != target]
        distractor = self.image_embeddings[np.random.choice(candidates)]

        keywords = self.keyword_extract(raw_hint)
        tokenized = clip_processor(
            text=keywords,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77
        ).to(device)

        with torch.no_grad():
            text_features = clip_model.get_text_features(**tokenized)
            hint_embedding = F.normalize(text_features, p=2, dim=-1).cpu()

        target_embedding = self.image_embeddings[target]
        return hint_embedding.squeeze(0), target_embedding.squeeze(0), distractor.squeeze(0)

class DixitModel(nn.Module):
    """Small MLP mapping a hint embedding to a vector of the same size.

    Layout: Linear(embedding_dim, 512) -> ReLU -> Dropout -> Linear(512, 256)
    -> ReLU -> Linear(256, embedding_dim).
    """

    def __init__(self, embedding_dim, dropout_rate):
        super().__init__()
        self.fc1 = nn.Linear(embedding_dim, 512)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, embedding_dim)

    def forward(self, hint_embedding):
        """Project ``hint_embedding`` through the three linear layers."""
        hidden = self.dropout(F.relu(self.fc1(hint_embedding)))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

def save_checkpoint(state, is_best, checkpoint_dir, filename="checkpoint.pth"):
    """Write ``state`` to ``checkpoint_dir/filename`` and, when ``is_best``, also to ``best_model.pth``.

    The directory is created if it does not exist yet.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    targets = [filename]
    if is_best:
        targets.append("best_model.pth")
    for target_name in targets:
        torch.save(state, os.path.join(checkpoint_dir, target_name))

def load_checkpoint(model, optimizer, checkpoint_dir, filename="checkpoint.pth"):
    """Restore model and optimizer state from ``checkpoint_dir/filename`` if that file exists.

    Args:
        model: Module whose ``load_state_dict`` receives ``checkpoint['model_state_dict']``.
        optimizer: Optimizer whose ``load_state_dict`` receives ``checkpoint['optimizer_state_dict']``.
        checkpoint_dir: Directory holding the checkpoint.
        filename: Checkpoint file name inside ``checkpoint_dir``.

    Returns:
        ``(start_epoch, best_val_loss)``: the values stored under ``'epoch'`` and
        ``'val_loss'`` (0 and ``inf`` when a key is missing), or ``(0, inf)`` when
        there is no checkpoint file. Note that ``best_val_loss`` is simply whatever
        the checkpoint stored under ``'val_loss'``.
    """
    checkpoint_path = os.path.join(checkpoint_dir, filename)
    if not os.path.isfile(checkpoint_path):
        print("No checkpoint found, starting from scratch.")
        return 0, float('inf')

    # Remap stored tensors onto the device the model currently lives on, so a
    # checkpoint written on a GPU runtime can still be resumed on a CPU-only one.
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else None
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint.get('epoch', 0)
    best_val_loss = checkpoint.get('val_loss', float('inf'))
    print(f"Loaded checkpoint '{checkpoint_path}' (epoch {start_epoch})")
    return start_epoch, best_val_loss

def cosine_similarity(a, b):
    """Return the cosine similarity of ``a`` and ``b`` computed along their last dimension."""
    similarity = F.cosine_similarity(a, b, dim=-1)
    return similarity

def train_epoch(model, dataloader, criterion, optimizer):
    """Train ``model`` for one pass over ``dataloader``.

    Uses the module-level ``device`` and ``CLIP_GRAD_NORM`` (gradient-norm clip value).

    Returns:
        Tuple of (summed batch loss divided by the number of batches,
        fraction of samples whose output is more cosine-similar to the target
        than to the distractor, mean over batches of
        ``mean target similarity - mean distractor similarity``).
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    batch_gaps = []

    for batch in tqdm(dataloader):
        hint_embedding, target_embedding, distractor_embedding = (part.to(device) for part in batch)

        # One optimisation step with gradient-norm clipping.
        optimizer.zero_grad()
        anchor = model(hint_embedding)
        loss = criterion(anchor, target_embedding, distractor_embedding)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD_NORM)
        optimizer.step()
        loss_sum += loss.item()

        # A sample counts as correct when the target is strictly closer than the distractor.
        sim_target = cosine_similarity(anchor, target_embedding)
        sim_distractor = cosine_similarity(anchor, distractor_embedding)
        n_correct += (sim_target > sim_distractor).sum().item()
        n_seen += hint_embedding.size(0)
        batch_gaps.append(sim_target.mean().item() - sim_distractor.mean().item())

    return loss_sum / len(dataloader), n_correct / n_seen, np.mean(batch_gaps)

def validate_epoch(model, dataloader, criterion):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Uses the module-level ``device``.

    Returns:
        Tuple of (summed batch loss divided by the number of batches,
        fraction of samples whose output is more cosine-similar to the target
        than to the distractor, mean over batches of
        ``mean target similarity - mean distractor similarity``).
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    batch_gaps = []

    with torch.no_grad():
        for batch in dataloader:
            hint_embedding, target_embedding, distractor_embedding = (part.to(device) for part in batch)

            anchor = model(hint_embedding)
            loss_sum += criterion(anchor, target_embedding, distractor_embedding).item()

            sim_target = cosine_similarity(anchor, target_embedding)
            sim_distractor = cosine_similarity(anchor, distractor_embedding)
            n_correct += (sim_target > sim_distractor).sum().item()
            n_seen += hint_embedding.size(0)
            batch_gaps.append(sim_target.mean().item() - sim_distractor.mean().item())

    return loss_sum / len(dataloader), n_correct / n_seen, np.mean(batch_gaps)

# Stop-word list used by DixitDataset.keyword_extract.
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

config_path = "/content/drive/My Drive/Colab Notebooks/dixit/config_guesser_train.json"
config = load_config(config_path)

csv_path = config["csv_path"]
cards_folder = config["cards_folder"]
checkpoint_dir = config["checkpoint_dir"]
embeddings_path = config["embeddings_path"]

BATCH_SIZE = 32
LEARNING_RATE = 0.001
EPOCHS = 50
DROPOUT_RATE = 0.6  # Increased from 0.5
VALIDATION_SPLIT = 0.2
EARLY_STOPPING_PATIENCE = 5
WEIGHT_DECAY = 1e-5
TRIPLET_MARGIN = 0.4  # Increased from 0.3
CLIP_GRAD_NORM = 1.0

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

clip_model = CLIPModel.from_pretrained(config["model_name"]).to(device)
clip_processor = CLIPProcessor.from_pretrained(config["model_name"])

# Random train/validation split of the hint dataset.
dataset = DixitDataset(csv_path, image_embeddings)
train_size = int((1 - VALIDATION_SPLIT) * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Infer the embedding width from one batch of hint embeddings.
embedding_dim = next(iter(train_loader))[0].shape[1]
model = DixitModel(embedding_dim, DROPOUT_RATE).to(device)
criterion = nn.TripletMarginLoss(margin=TRIPLET_MARGIN)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

start_epoch, best_val_loss = load_checkpoint(model, optimizer, checkpoint_dir)

early_stopping_counter = 0

# Per-epoch history consumed by the plots below. The accuracy lists were
# previously never created and the loss lists never filled, so plotting failed
# with a NameError (or drew empty curves).
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

for epoch in range(start_epoch, EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    train_loss, train_accuracy, train_similarity_diff = train_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_accuracy, val_similarity_diff = validate_epoch(model, val_loader, criterion)

    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    print(f"Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}, "
          f"Training Similarity Difference: {train_similarity_diff:.4f}")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}, "
          f"Validation Similarity Difference: {val_similarity_diff:.4f}")

    # Early stopping bookkeeping: reset on improvement, count otherwise.
    is_best = val_loss < best_val_loss
    if is_best:
        best_val_loss = val_loss
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1

    save_checkpoint({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss
    }, is_best, checkpoint_dir)

    if early_stopping_counter >= EARLY_STOPPING_PATIENCE:
        print("Early stopping triggered.")
        break

plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss over Epochs')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Training Accuracy')
plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Accuracy over Epochs')
plt.legend()

plt.tight_layout()
plt.show()


In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import CLIPProcessor, CLIPModel
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
import nltk
import re
from collections import Counter
from nltk.corpus import stopwords
from sklearn.metrics import accuracy_score

def load_config(config_path):
    """Read a JSON configuration file and return its parsed content.

    Args:
        config_path: Path to the JSON file.

    Returns:
        The decoded JSON document (indexed like a dict by the rest of this notebook).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as config_file:
        return json.load(config_file)

class DixitDataset(Dataset):
    """Triplet dataset: CLIP text embedding of a hint, its target card and one random other card.

    The CSV supplies ``DESCRIPTION`` (hint text) and ``TARGET`` (1-based card
    number). Text encoding relies on the module-level ``clip_model``,
    ``clip_processor`` and ``device``; keyword filtering relies on the
    module-level ``stop_words``.
    """

    def __init__(self, csv_path, image_embeddings, debug_size=None):
        """Read the CSV, optionally keeping only a fixed-seed sample of ``debug_size`` rows."""
        table = pd.read_csv(csv_path)
        if debug_size:
            table = table.sample(n=debug_size, random_state=42)
        self.data = table
        self.image_embeddings = image_embeddings
        self.truncated_count = 0

    def __len__(self):
        """Number of rows available."""
        return len(self.data)

    def keyword_extract(self, text):
        """Return the non-stop-words of ``text`` (lower-cased, deduplicated, at most 77), most frequent first."""
        counts = Counter(w for w in re.findall(r'\w+', text.lower()) if w not in stop_words)
        top_words = [w for w, _freq in counts.most_common(77)]
        return ' '.join(top_words)

    def __getitem__(self, idx):
        """Return ``(hint_embedding, target_embedding, distractor_embedding)`` for row ``idx``."""
        record = self.data.iloc[idx]
        description = record['DESCRIPTION']
        target = int(record['TARGET']) - 1  # convert the 1-based card number to a 0-based key

        # Negative example: a uniformly random card that is not the target.
        other_cards = [card for card in range(len(self.image_embeddings)) if card != target]
        distractor = self.image_embeddings[np.random.choice(other_cards)]

        hint_inputs = clip_processor(
            text=self.keyword_extract(description),
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77
        ).to(device)

        with torch.no_grad():
            raw_features = clip_model.get_text_features(**hint_inputs)
            hint_embedding = F.normalize(raw_features, p=2, dim=-1).cpu()

        target_embedding = self.image_embeddings[target]
        return hint_embedding.squeeze(0), target_embedding.squeeze(0), distractor.squeeze(0)

class DixitModel(nn.Module):
    """MLP with batch norm, dropout and a skip connection that outputs unit-length vectors.

    Layout: Linear(embedding_dim, 512) -> BatchNorm -> LeakyReLU -> Dropout
    -> (+ skip from the input) -> Linear(512, 256) -> BatchNorm -> LeakyReLU
    -> Dropout -> Linear(256, embedding_dim) -> L2 normalisation.

    Args:
        embedding_dim: Size of the input and output vectors.
        dropout_rate: Dropout probability used after both hidden layers.
        negative_slope: Slope of the LeakyReLU for negative inputs.
    """

    def __init__(self, embedding_dim, dropout_rate, negative_slope=0.01):
        super().__init__()
        hidden_dim = 512
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(hidden_dim, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(dropout_rate)

        self.fc3 = nn.Linear(256, embedding_dim)
        self.negative_slope = negative_slope

        # The skip connection adds the input to the 512-wide hidden activation.
        # That only lines up when embedding_dim == 512; for any other width the
        # input is linearly projected to 512 first. nn.Identity has no
        # parameters, so the 512 case keeps the same state_dict as before.
        if embedding_dim == hidden_dim:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Linear(embedding_dim, hidden_dim, bias=False)

    def forward(self, hint_embedding):
        """Map ``hint_embedding`` to an L2-normalised vector of the same size."""
        x = F.leaky_relu(self.bn1(self.fc1(hint_embedding)), negative_slope=self.negative_slope)
        x = self.dropout1(x)
        x = x + self.shortcut(hint_embedding)
        x = F.leaky_relu(self.bn2(self.fc2(x)), negative_slope=self.negative_slope)
        x = self.dropout2(x)
        x = self.fc3(x)
        x = F.normalize(x, p=2, dim=-1)
        return x

def save_checkpoint(state, is_best, checkpoint_dir, filename="checkpoint.pth"):
    """Persist ``state`` under ``checkpoint_dir``.

    The state is always written to ``filename``; when ``is_best`` is true a
    second copy is written to ``best_model.pth``. Missing directories are created.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    regular_path = os.path.join(checkpoint_dir, filename)
    torch.save(state, regular_path)
    if not is_best:
        return
    best_path = os.path.join(checkpoint_dir, "best_model.pth")
    torch.save(state, best_path)

def load_checkpoint(model, optimizer, checkpoint_dir, filename="checkpoint.pth"):
    """Restore model and optimizer state from ``checkpoint_dir/filename`` if that file exists.

    Args:
        model: Module whose ``load_state_dict`` receives ``checkpoint['model_state_dict']``.
        optimizer: Optimizer whose ``load_state_dict`` receives ``checkpoint['optimizer_state_dict']``.
        checkpoint_dir: Directory holding the checkpoint.
        filename: Checkpoint file name inside ``checkpoint_dir``.

    Returns:
        ``(start_epoch, best_val_loss)``: the values stored under ``'epoch'`` and
        ``'val_loss'`` (0 and ``inf`` when a key is missing), or ``(0, inf)`` when
        there is no checkpoint file. Note that ``best_val_loss`` is simply whatever
        the checkpoint stored under ``'val_loss'``.
    """
    checkpoint_path = os.path.join(checkpoint_dir, filename)
    if not os.path.isfile(checkpoint_path):
        print("No checkpoint found, starting from scratch.")
        return 0, float('inf')

    # Remap stored tensors onto the device the model currently lives on, so a
    # checkpoint written on a GPU runtime can still be resumed on a CPU-only one.
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else None
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint.get('epoch', 0)
    best_val_loss = checkpoint.get('val_loss', float('inf'))
    print(f"Loaded checkpoint '{checkpoint_path}' (epoch {start_epoch})")
    return start_epoch, best_val_loss

def cosine_similarity(a, b):
    """Cosine similarity between ``a`` and ``b`` along the last axis."""
    last_axis = -1
    return F.cosine_similarity(a, b, dim=last_axis)

def train_epoch(model, dataloader, criterion, optimizer):
    """Run one training epoch and report loss, ranking accuracy and similarity gap.

    Reads the module-level ``device`` and ``CLIP_GRAD_NORM``.

    Returns:
        ``(avg_loss, accuracy, avg_similarity_diff)`` where ``avg_loss`` is the
        summed batch loss divided by the number of batches, ``accuracy`` the
        share of samples whose output is strictly more cosine-similar to the
        target than to the distractor, and ``avg_similarity_diff`` the mean over
        batches of (mean target similarity - mean distractor similarity).
    """
    model.train()
    running_loss = 0
    hits = 0
    seen = 0
    gaps = []

    for hints, targets, distractors in tqdm(dataloader):
        hints = hints.to(device)
        targets = targets.to(device)
        distractors = distractors.to(device)

        optimizer.zero_grad()
        anchor = model(hints)
        loss = criterion(anchor, targets, distractors)
        loss.backward()
        # Clip the global gradient norm before the optimizer update.
        nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD_NORM)
        optimizer.step()

        running_loss += loss.item()

        target_sim = cosine_similarity(anchor, targets)
        distractor_sim = cosine_similarity(anchor, distractors)
        hits += (target_sim > distractor_sim).sum().item()
        seen += hints.size(0)
        gaps.append(target_sim.mean().item() - distractor_sim.mean().item())

    return running_loss / len(dataloader), hits / seen, np.mean(gaps)

def validate_epoch(model, dataloader, criterion):
    """Evaluate ``model`` over ``dataloader`` in eval mode with gradients disabled.

    Reads the module-level ``device``.

    Returns:
        ``(avg_loss, accuracy, avg_similarity_diff)`` where ``avg_loss`` is the
        summed batch loss divided by the number of batches, ``accuracy`` the
        share of samples whose output is strictly more cosine-similar to the
        target than to the distractor, and ``avg_similarity_diff`` the mean over
        batches of (mean target similarity - mean distractor similarity).
    """
    model.eval()
    running_loss = 0
    hits = 0
    seen = 0
    gaps = []

    with torch.no_grad():
        for hints, targets, distractors in dataloader:
            hints = hints.to(device)
            targets = targets.to(device)
            distractors = distractors.to(device)

            anchor = model(hints)
            running_loss += criterion(anchor, targets, distractors).item()

            target_sim = cosine_similarity(anchor, targets)
            distractor_sim = cosine_similarity(anchor, distractors)
            hits += (target_sim > distractor_sim).sum().item()
            seen += hints.size(0)
            gaps.append(target_sim.mean().item() - distractor_sim.mean().item())

    return running_loss / len(dataloader), hits / seen, np.mean(gaps)

# Stop-word list used by DixitDataset.keyword_extract.
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

config_path = "/content/drive/My Drive/Colab Notebooks/dixit/config_guesser_train.json"
config = load_config(config_path)

csv_path = config["csv_path"]
cards_folder = config["cards_folder"]
checkpoint_dir = config["checkpoint_dir"]
embeddings_path = config["embeddings_path"]

BATCH_SIZE = 32
LEARNING_RATE = 0.0001 # Tried 0.001
EPOCHS = 50
DROPOUT_RATE = 0.6 # Increased from 0.5
VALIDATION_SPLIT = 0.2
EARLY_STOPPING_PATIENCE = 5
WEIGHT_DECAY = 1e-5 # Tried 1e-5
TRIPLET_MARGIN = 0.4  # Increased from 0.3
CLIP_GRAD_NORM = 1.0

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

clip_model = CLIPModel.from_pretrained(config["model_name"]).to(device)
clip_processor = CLIPProcessor.from_pretrained(config["model_name"])

# Pickled {0-based card index: embedding} dict.
with open(embeddings_path, 'rb') as f:
    image_embeddings = pickle.load(f)

# Random train/validation split of the hint dataset.
dataset = DixitDataset(csv_path, image_embeddings)
train_size = int((1 - VALIDATION_SPLIT) * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Infer the embedding width from one batch of hint embeddings.
embedding_dim = next(iter(train_loader))[0].shape[1]
model = DixitModel(embedding_dim, DROPOUT_RATE).to(device)
criterion = nn.TripletMarginLoss(margin=TRIPLET_MARGIN)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# Halve the learning rate after 3 epochs without validation-loss improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5, verbose=True)

start_epoch, best_val_loss = load_checkpoint(model, optimizer, checkpoint_dir)
early_stopping_counter = 0

# Per-epoch history consumed by the plots below.
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

for epoch in range(start_epoch, EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    train_loss, train_accuracy, train_similarity_diff = train_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_accuracy, val_similarity_diff = validate_epoch(model, val_loader, criterion)

    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    print(f"Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}, "
          f"Training Similarity Difference: {train_similarity_diff:.4f}")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}, "
          f"Validation Similarity Difference: {val_similarity_diff:.4f}")

    is_best = val_loss < best_val_loss
    if is_best:
        best_val_loss = val_loss
        early_stopping_counter = 0
    else:
        # Increment first so the message reports the count including this epoch
        # (it previously printed the stale value, e.g. "0/5" on the first miss).
        early_stopping_counter += 1
        print(f"Early stopping counter: {early_stopping_counter}/{EARLY_STOPPING_PATIENCE}")

    save_checkpoint({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss
    }, is_best, checkpoint_dir)

    scheduler.step(val_loss)

    if early_stopping_counter >= EARLY_STOPPING_PATIENCE:
        print("Early stopping triggered.")
        break

plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss over Epochs')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Training Accuracy')
plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Accuracy over Epochs')
plt.legend()

plt.tight_layout()
plt.show()


In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import CLIPProcessor, CLIPModel
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
import nltk
import re
from collections import Counter
from nltk.corpus import stopwords
from sklearn.metrics import accuracy_score

# Load the pickled image embeddings.
# NOTE(review): `embeddings_path` is not assigned in this cell before this point (the
# config is only read further down), so this relies on a previously executed cell
# having defined it; in a fresh session this line raises NameError.
with open(embeddings_path, 'rb') as f:
    image_embeddings = pickle.load(f)

def load_config(config_path):
    """Read a JSON configuration file and return its parsed content.

    Args:
        config_path: Path to the JSON file.

    Returns:
        The decoded JSON document (indexed like a dict by the rest of this notebook).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as config_file:
        return json.load(config_file)

class DixitDataset(Dataset):
    """Triplet dataset whose items are placed on ``device``.

    Rows come from a CSV with ``DESCRIPTION`` (hint text) and ``TARGET`` (1-based
    card number) columns. Hint text is reduced to keywords and encoded with the
    module-level ``clip_model`` / ``clip_processor``; keyword extraction uses the
    module-level ``stop_words`` set.

    Args:
        csv_path: Path of the CSV file.
        image_embeddings: Mapping from 0-based card index to that card's embedding.
        device: Device on which the returned tensors are created.
        debug_size: When truthy, keep only a fixed-seed sample of this many rows.
    """

    def __init__(self, csv_path, image_embeddings, device, debug_size=None):
        self.data = pd.read_csv(csv_path)
        if debug_size:
            self.data = self.data.sample(n=debug_size, random_state=42)
        self.image_embeddings = image_embeddings
        self.device = device
        self.truncated_count = 0

    def __len__(self):
        """Number of hint rows."""
        return len(self.data)

    def keyword_extract(self, text):
        """Lower-case ``text``, drop stop words and return up to 77 distinct words, most frequent first."""
        words = re.findall(r'\w+', text.lower())
        words = [word for word in words if word not in stop_words]
        most_common_words = [word for word, _ in Counter(words).most_common(77)]
        return ' '.join(most_common_words)

    def _as_device_tensor(self, value):
        """Return a detached copy of ``value`` as a tensor on ``self.device``.

        ``torch.tensor(existing_tensor)`` emits a copy-construct UserWarning on
        every call, so tensors are copied explicitly; anything else still goes
        through ``torch.tensor``.
        """
        if torch.is_tensor(value):
            return value.detach().to(self.device, copy=True)
        return torch.tensor(value, device=self.device)

    def __getitem__(self, idx):
        """Return ``(hint_embedding, target_embedding, distractor)`` for row ``idx``.

        Only the hint embedding is squeezed along axis 0; the two card embeddings
        are returned with the shape they have in ``image_embeddings``.
        """
        row = self.data.iloc[idx]
        hint = row['DESCRIPTION']
        target = int(row['TARGET']) - 1  # CSV targets are 1-based

        # Negative example: a uniformly random card that is not the target.
        distractor_idxs = [card for card in range(len(self.image_embeddings)) if card != target]
        distractor = self._as_device_tensor(self.image_embeddings[np.random.choice(distractor_idxs)])

        hint = self.keyword_extract(hint)
        hint_inputs = clip_processor(
            text=hint,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77
        ).to(self.device)

        with torch.no_grad():
            hint_embedding = clip_model.get_text_features(**hint_inputs)
            hint_embedding = F.normalize(hint_embedding, p=2, dim=-1)

        target_embedding = self._as_device_tensor(self.image_embeddings[target])

        return hint_embedding.squeeze(0), target_embedding, distractor

class DixitModel(nn.Module):
    """Two-layer MLP that maps a hint embedding to a unit-length vector of the same size.

    Layout: Linear -> ReLU -> Dropout -> Linear -> L2 normalisation, all of
    width ``embedding_dim``.

    Args:
        embedding_dim: Size of the input and output vectors.
        dropout_rate: Dropout probability. When omitted (``None``) the
            module-level ``DROPOUT_RATE`` is used, as before; passing it
            explicitly removes the dependency on that global.
    """

    def __init__(self, embedding_dim, dropout_rate=None):
        super().__init__()
        if dropout_rate is None:
            dropout_rate = DROPOUT_RATE
        self.fc1 = nn.Linear(embedding_dim, embedding_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, hint_embedding):
        """Return the L2-normalised projection of ``hint_embedding``."""
        output = self.fc1(hint_embedding)
        output = self.relu(output)
        output = self.dropout(output)
        output = self.fc2(output)
        output = F.normalize(output, p=2, dim=-1)
        return output

class NTXentLoss(nn.Module):
    """Temperature-scaled cross-entropy over one positive and one negative cosine similarity.

    For each anchor the two similarities (divided by ``temperature``) form a
    two-way logit vector whose correct class is the positive (index 0).
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, anchor, positive, negative):
        """Return the mean cross-entropy; 3-D positives/negatives are squeezed along axis 1 first."""
        positive = positive.squeeze(1) if positive.dim() == 3 else positive
        negative = negative.squeeze(1) if negative.dim() == 3 else negative

        assert anchor.dim() == 2 and positive.dim() == 2 and negative.dim() == 2

        scaled_sims = [
            F.cosine_similarity(anchor, candidate, dim=1) / self.temperature
            for candidate in (positive, negative)
        ]
        logits = torch.stack(scaled_sims, dim=1)

        # Column 0 holds the positive similarity, so every label is 0.
        labels = anchor.new_zeros(anchor.size(0), dtype=torch.long)

        return F.cross_entropy(logits, labels)

def save_checkpoint(state, is_best, checkpoint_dir, filename="checkpoint.pth"):
    """Save ``state`` to ``checkpoint_dir/filename``; if ``is_best``, save it to ``best_model.pth`` as well.

    ``checkpoint_dir`` is created when it does not exist.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    file_names = (filename, "best_model.pth") if is_best else (filename,)
    for file_name in file_names:
        torch.save(state, os.path.join(checkpoint_dir, file_name))

def load_checkpoint(model, optimizer, checkpoint_dir, filename="checkpoint.pth"):
    """Restore model and optimizer state from ``checkpoint_dir/filename`` if that file exists.

    Args:
        model: Module whose ``load_state_dict`` receives ``checkpoint['model_state_dict']``.
        optimizer: Optimizer whose ``load_state_dict`` receives ``checkpoint['optimizer_state_dict']``.
        checkpoint_dir: Directory holding the checkpoint.
        filename: Checkpoint file name inside ``checkpoint_dir``.

    Returns:
        ``(start_epoch, best_val_loss)``: the values stored under ``'epoch'`` and
        ``'val_loss'`` (0 and ``inf`` when a key is missing), or ``(0, inf)`` when
        there is no checkpoint file. Note that ``best_val_loss`` is simply whatever
        the checkpoint stored under ``'val_loss'``.
    """
    checkpoint_path = os.path.join(checkpoint_dir, filename)
    if not os.path.isfile(checkpoint_path):
        print("No checkpoint found, starting from scratch.")
        return 0, float('inf')

    # Remap stored tensors onto the device the model currently lives on, so a
    # checkpoint written on a GPU runtime can still be resumed on a CPU-only one.
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else None
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint.get('epoch', 0)
    best_val_loss = checkpoint.get('val_loss', float('inf'))
    print(f"Loaded checkpoint '{checkpoint_path}' (epoch {start_epoch})")
    return start_epoch, best_val_loss

def cosine_similarity(a, b):
    """Compute the cosine similarity of ``a`` and ``b`` over their final dimension."""
    result = F.cosine_similarity(a, b, dim=-1)
    return result

def train_epoch(model, dataloader, criterion, optimizer):
    """Train ``model`` for one pass over ``dataloader`` (no gradient clipping in this variant).

    Reads the module-level ``device``.

    Returns:
        ``(avg_loss, accuracy, avg_similarity_diff)``: summed batch loss divided
        by the number of batches, share of samples whose output is strictly more
        cosine-similar to the target than to the distractor, and the mean over
        batches of the per-batch mean of (target similarity - distractor similarity).
    """
    model.train()
    running_loss = 0
    hits = 0
    seen = 0
    margins = []

    for hints, targets, distractors in tqdm(dataloader):
        hints = hints.to(device)
        # squeeze(1) drops axis 1 of the card embeddings when it has size 1.
        targets = targets.squeeze(1).to(device)
        distractors = distractors.squeeze(1).to(device)

        optimizer.zero_grad()
        projected = model(hints)
        loss = criterion(projected, targets, distractors)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        sim_pos = F.cosine_similarity(projected, targets)
        sim_neg = F.cosine_similarity(projected, distractors)
        hits += (sim_pos > sim_neg).sum().item()
        seen += hints.size(0)
        margins.append((sim_pos - sim_neg).mean().item())

    return running_loss / len(dataloader), hits / seen, np.mean(margins)

def validate_epoch(model, dataloader, criterion):
    """Evaluate ``model`` over ``dataloader`` in eval mode with gradients disabled.

    Reads the module-level ``device``.

    Returns:
        ``(avg_loss, accuracy, avg_similarity_diff)``: summed batch loss divided
        by the number of batches, share of samples whose output is strictly more
        cosine-similar to the target than to the distractor, and the mean over
        batches of the per-batch mean of (target similarity - distractor similarity).
    """
    model.eval()
    running_loss = 0
    hits = 0
    seen = 0
    margins = []

    with torch.no_grad():
        for hints, targets, distractors in dataloader:
            hints = hints.to(device)
            # squeeze(1) drops axis 1 of the card embeddings when it has size 1.
            targets = targets.squeeze(1).to(device)
            distractors = distractors.squeeze(1).to(device)

            projected = model(hints)
            running_loss += criterion(projected, targets, distractors).item()

            sim_pos = F.cosine_similarity(projected, targets)
            sim_neg = F.cosine_similarity(projected, distractors)
            hits += (sim_pos > sim_neg).sum().item()
            seen += hints.size(0)
            margins.append((sim_pos - sim_neg).mean().item())

    return running_loss / len(dataloader), hits / seen, np.mean(margins)
# Stop-word list used by DixitDataset.keyword_extract.
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

config_path = "/content/drive/My Drive/Colab Notebooks/dixit/config_guesser_train.json"
config = load_config(config_path)

csv_path = config["csv_path"]
cards_folder = config["cards_folder"]
# This run reads its checkpoint directory from a dedicated config key.
checkpoint_dir = config["checkpoint_contrastiveloss_dir"]
embeddings_path = config["embeddings_path"]

BATCH_SIZE = 32
LEARNING_RATE = 0.001 # Tried 0.001
EPOCHS = 50
DROPOUT_RATE = 0.5 # NOTE(review): earlier comment said "Increased from 0.5", which is stale for a value of 0.5
VALIDATION_SPLIT = 0.2
EARLY_STOPPING_PATIENCE = 5
WEIGHT_DECAY = 1e-5 # Tried 1e-5
TRIPLET_MARGIN = 0.4  # Increased from 0.3; not referenced by the code in this cell (the criterion is NTXentLoss)
CLIP_GRAD_NORM = 1.0  # not referenced by this cell's train_epoch, which does not clip gradients

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

clip_model = CLIPModel.from_pretrained(config["model_name"]).to(device)
clip_processor = CLIPProcessor.from_pretrained(config["model_name"])

# Random train/validation split of the hint dataset.
dataset = DixitDataset(csv_path, image_embeddings, device)
train_size = int((1 - VALIDATION_SPLIT) * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Infer the embedding width from one batch of hint embeddings.
embedding_dim = next(iter(train_loader))[0].shape[1]
model = DixitModel(embedding_dim).to(device)
criterion = NTXentLoss(temperature=0.5)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# Halve the learning rate after 3 epochs without validation-loss improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5, verbose=True)

# Resume from an existing checkpoint when there is one.
start_epoch, best_val_loss = load_checkpoint(model, optimizer, checkpoint_dir)

early_stopping_counter = 0

# Per-epoch history consumed by the plots below.
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

for epoch in range(start_epoch, EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    train_loss, train_accuracy, train_similarity_diff = train_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_accuracy, val_similarity_diff = validate_epoch(model, val_loader, criterion)

    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    print(f"Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}, "
          f"Training Similarity Difference: {train_similarity_diff:.4f}")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}, "
          f"Validation Similarity Difference: {val_similarity_diff:.4f}")

    # Early stopping bookkeeping: reset on improvement, count otherwise.
    is_best = val_loss < best_val_loss
    if is_best:
        best_val_loss = val_loss
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1
        print(f"Early stopping counter: {early_stopping_counter}/{EARLY_STOPPING_PATIENCE}")

    # The stored 'val_loss' is this epoch's validation loss (not the best so far).
    save_checkpoint({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss
    }, is_best, checkpoint_dir)

    scheduler.step(val_loss)

    if early_stopping_counter >= EARLY_STOPPING_PATIENCE:
        print("Early stopping triggered.")
        break

# NOTE(review): the figure is laid out as 1x3 but only two panels are drawn.
plt.figure(figsize=(18, 5))
plt.subplot(1, 3, 1)
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss over Epochs')
plt.legend()
plt.subplot(1, 3, 2)
plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Training Accuracy')
plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Accuracy over Epochs')
plt.legend()

plt.tight_layout()
plt.show()
