In [None]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Dataset
from tensorflow.keras.preprocessing.text import Tokenizer, tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences
from torchvision import models, transforms
from PIL import Image
import os
from sklearn.metrics import f1_score, accuracy_score
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import logging
import json


# Timestamped, INFO-level logging for everything this script reports.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)


# Every model and batch in this file is moved to this device (CPU only).
device = torch.device("cpu")
logging.info(f"Using device: {device}")


# Genre names; the evaluation code indexes label columns by position in
# this list, so its order is significant.
all_genres = [
    "Action",
    "Adventure",
    "Animation",
    "Comedy",
    "Crime",
    "Documentary",
    "Drama",
    "Family",
    "Fantasy",
    "History",
    "Horror",
    "Music",
    "Mystery",
    "Romance",
    "Science Fiction",
    "TV Movie",
    "Thriller",
    "War",
    "Western",
]


logging.info("Loading tokenizer...")
with open("tokenizer.json") as tokenizer_file:
    # NOTE(review): tokenizer_from_json is given whatever json.load returns;
    # presumably the file stores the tokenizer JSON as an encoded string --
    # confirm against the script that wrote tokenizer.json.
    tokenizer_config = json.load(tokenizer_file)
tokenizer = tokenizer_from_json(tokenizer_config)
# Fixed token-sequence length handed to FusionDataset for padding/truncation.
max_len = 100

logging.info("Loading datasets...")
split_files = (
    "train_fixed_updated.json",
    "val_fixed_updated.json",
    "test_fixed_updated.json",
)
train_df, val_df, test_df = map(pd.read_json, split_files)
logging.info(f"Loaded Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")


class ImageMovieGenreClassifier(nn.Module):
    """ResNet-18 classifier that also exposes its pooled feature vector.

    The stock ``fc`` head of the torchvision ResNet-18 is replaced by a
    linear layer with ``num_classes`` outputs. No pretrained weights are
    requested here (``weights=None``).
    """

    def __init__(self, num_classes=19):
        super().__init__()
        backbone = models.resnet18(weights=None)
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.resnet = backbone
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``(probabilities, features)`` for an image batch ``x``.

        ``features`` is the flattened output of the average-pool layer, i.e.
        the input of the classification head; ``probabilities`` is the
        sigmoid of the head's output.
        """
        net = self.resnet
        # Walk the backbone by hand (instead of calling net(x)) so the
        # activations just before the fc head can be returned as well.
        pipeline = (
            net.conv1,
            net.bn1,
            net.relu,
            net.maxpool,
            net.layer1,
            net.layer2,
            net.layer3,
            net.layer4,
            net.avgpool,
        )
        features = x
        for layer in pipeline:
            features = layer(features)
        features = torch.flatten(features, 1)
        logits = net.fc(features)
        return self.sigmoid(logits), features


class Attention(nn.Module):
    """Attention pooling that collapses a sequence of states into one vector.

    Each time step is projected from ``2 * hidden_dim`` to ``hidden_dim``,
    passed through tanh and scored against the learned vector ``v``; the
    scores are soft-maxed over dimension 1 and used to average the inputs.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.attn = nn.Linear(2 * hidden_dim, hidden_dim)
        self.v = nn.Parameter(torch.rand(hidden_dim))
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, lstm_output):
        """Return the attention-weighted sum of ``lstm_output`` over dim 1.

        ``lstm_output`` must be 3-D with a last dimension of
        ``2 * hidden_dim``; the result drops the middle (sequence) axis.
        """
        projected = self.attn(lstm_output)
        energy = self.tanh(projected)
        # One scalar score per time step: (B, T, H) @ (H, 1) -> (B, T).
        scores = torch.matmul(energy, self.v.unsqueeze(1)).squeeze(2)
        weights = self.softmax(scores)
        # (B, 1, T) x (B, T, D) -> (B, 1, D), then drop the singleton axis.
        pooled = torch.bmm(weights.unsqueeze(1), lstm_output)
        return pooled.squeeze(1)


class AttentionMovieGenreClassifier(nn.Module):
    """Text classifier: embedding -> bidirectional LSTM -> attention -> linear.

    ``forward`` returns both the temperature-scaled sigmoid outputs and the
    attention context vector (taken before dropout), so the context can be
    reused as a text feature elsewhere.
    """

    def __init__(self, vocab_size=10000, embedding_dim=128, hidden_dim=128, output_dim=19):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=1,
            batch_first=True,
            dropout=0.0,
            bidirectional=True,
        )
        self.attention = Attention(hidden_dim)
        # The LSTM is bidirectional, hence the 2 * hidden_dim input width.
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(0.6)
        self.sigmoid = nn.Sigmoid()
        # Learnable divisor applied to the logits before the sigmoid.
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, text):
        """Return ``(probabilities, context)`` for a batch of token ids."""
        token_vectors = self.embedding(text)
        # Only the per-step outputs are used; final hidden/cell states are dropped.
        states, _ = self.lstm(token_vectors)
        context = self.attention(states)
        logits = self.fc(self.dropout(context))
        probabilities = self.sigmoid(logits / self.temperature)
        return probabilities, context

class FusionModel(nn.Module):
    """Two-layer head that classifies concatenated image and text features.

    The two feature vectors are joined along dimension 1, passed through a
    256-unit ReLU layer with dropout, and mapped to ``num_classes`` sigmoid
    outputs.
    """

    def __init__(self, image_in_features=512, text_in_features=256, num_classes=19):
        super().__init__()
        fused_width = image_in_features + text_in_features
        self.fc1 = nn.Linear(fused_width, 256)
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, num_classes)
        self.sigmoid = nn.Sigmoid()

    def forward(self, image_features, text_features):
        """Return per-class probabilities for the paired feature batches."""
        fused = torch.cat([image_features, text_features], dim=1)
        hidden = self.dropout(torch.relu(self.fc1(fused)))
        return self.sigmoid(self.fc2(hidden))


class FusionDataset(Dataset):
    """Dataset yielding ``(image, token_ids, labels)`` for each dataframe row.

    Each row is read for three columns: ``poster_path`` (image file, relative
    to ``image_dir``), ``processed_overview`` (text) and ``genre_labels``
    (label vector).
    """

    def __init__(self, df, image_dir, tokenizer, max_len=100):
        self.df = df
        self.image_dir = image_dir
        self.tokenizer = tokenizer
        self.max_len = max_len
        # Resize to 224x224, convert to a tensor, then normalise per channel.
        steps = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.transform = transforms.Compose(steps)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Build one sample; an unreadable poster is replaced by a black image."""
        row = self.df.iloc[idx]

        # Leading slashes are stripped so the path joins under image_dir.
        img_path = os.path.join(self.image_dir, row["poster_path"].lstrip("/"))
        try:
            poster = Image.open(img_path).convert("RGB")
        except Exception as e:
            logging.warning(f"Error loading image {img_path}: {e}")
            poster = Image.new("RGB", (224, 224), (0, 0, 0))
        image = self.transform(poster)

        overview = row["processed_overview"]
        if not overview:
            overview = ""
        token_ids = self.tokenizer.texts_to_sequences([overview])[0]
        # Pad / truncate at the end so every sample has max_len tokens.
        padded = pad_sequences(
            [token_ids],
            maxlen=self.max_len,
            padding="post",
            truncating="post",
        )
        text_tensor = torch.tensor(padded, dtype=torch.long).squeeze()

        label_array = np.array(row["genre_labels"], dtype=np.float32)
        return image, text_tensor, torch.tensor(label_array)


def evaluate_model(image_model, text_model, fusion_model, loader):
    """Score the image + text + fusion pipeline on every batch of ``loader``.

    All three models are switched to eval mode and run without gradients.
    Fusion outputs are thresholded at 0.5 (strictly greater) to obtain binary
    multi-label predictions.

    Args:
        image_model: Model whose forward returns ``(_, image_features)``.
        text_model: Model whose forward returns ``(_, text_features)``.
        fusion_model: Model mapping both feature batches to probabilities.
        loader: Iterable of ``(images, texts, labels)`` batches.

    Returns:
        ``(preds_binary, true, metrics)`` where ``metrics`` holds
        ``f1_macro``, ``f1_micro``, ``accuracy`` (exact-match over all
        labels), and macro-averaged ``precision`` and ``recall``.
    """
    # Imported here so the block is self-contained; sklearn.metrics is
    # already a dependency of this file.
    from sklearn.metrics import precision_score, recall_score

    image_model.eval()
    text_model.eval()
    fusion_model.eval()
    preds, true = [], []
    with torch.no_grad():
        for images, texts, labels in tqdm(loader, desc="Evaluation"):
            images = images.to(device)
            texts = texts.to(device)
            _, image_features = image_model(images)
            _, text_features = text_model(texts)
            outputs = fusion_model(image_features, text_features)
            preds.append(outputs.cpu().numpy())
            true.append(labels.cpu().numpy())
    preds = np.concatenate(preds)
    true = np.concatenate(true)
    preds_binary = (preds > 0.5).astype(int)

    # zero_division=0 scores a class with no positives/predictions as 0,
    # which is also what the default does, minus the warning.
    f1_macro = f1_score(true, preds_binary, average="macro", zero_division=0)
    f1_micro = f1_score(true, preds_binary, average="micro", zero_division=0)
    accuracy = accuracy_score(true, preds_binary)
    # Previously both of these were computed with f1_score, so the reported
    # "precision" and "recall" were just two more copies of macro F1.
    precision = precision_score(true, preds_binary, average="macro", zero_division=0)
    recall = recall_score(true, preds_binary, average="macro", zero_division=0)

    metrics = {
        "f1_macro": f1_macro,
        "f1_micro": f1_micro,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
    }
    return preds_binary, true, metrics


def _save_metric_curve(values, label, title, ylabel, filename):
    """Plot one per-epoch metric series and write the figure to ``filename``."""
    plt.figure(figsize=(10, 5))
    plt.plot(values, label=label)
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(filename)
    plt.close()


def train_fusion_model(image_model, text_model, fusion_model, train_loader, val_loader, epochs=7):
    """Train the fusion head on features from the two fixed encoders.

    Only ``fusion_model``'s parameters are optimised (AdamW, lr 0.001); the
    image and text models are run under ``torch.no_grad()`` purely as feature
    extractors. After each epoch the model is validated with
    ``evaluate_model``. Whenever validation macro-F1 improves, the weights
    and validation predictions are written to disk. Training stops early
    after 3 consecutive epochs without improvement. Loss, F1 and accuracy
    curves are saved as PNG files at the end.

    Args:
        image_model: Encoder returning ``(_, image_features)``.
        text_model: Encoder returning ``(_, text_features)``.
        fusion_model: Module being trained.
        train_loader: Iterable of ``(images, texts, labels)`` training batches.
        val_loader: Iterable of validation batches, same layout.
        epochs: Maximum number of epochs.

    Returns:
        The best validation macro-F1 observed.
    """
    optimizer = torch.optim.AdamW(fusion_model.parameters(), lr=0.001)
    # Per-class loss weights; positions 15 and 9 are "TV Movie" and "History"
    # in all_genres (assuming label columns follow that order -- confirm).
    class_weights = torch.ones(19).to(device)
    class_weights[15] = 2.0
    class_weights[9] = 2.0
    criterion = nn.BCELoss(weight=class_weights)
    best_f1 = 0
    patience = 3
    patience_counter = 0
    metrics_history = {"train_loss": [], "val_f1_macro": [], "val_accuracy": []}

    # The encoders are feature extractors only. torch.no_grad() does not stop
    # BatchNorm from using batch statistics and updating its running stats,
    # so without eval() the first epoch would train on different features
    # than validation/later epochs see and would alter the image encoder.
    image_model.eval()
    text_model.eval()

    for epoch in range(epochs):
        # evaluate_model() leaves the fusion head in eval mode; switch back.
        fusion_model.train()
        train_loss = 0
        for images, texts, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
            images = images.to(device)
            texts = texts.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            with torch.no_grad():
                _, image_features = image_model(images)
                _, text_features = text_model(texts)
            outputs = fusion_model(image_features, text_features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)
        metrics_history["train_loss"].append(train_loss)

        val_preds_binary, val_true, metrics = evaluate_model(image_model, text_model, fusion_model, val_loader)
        metrics_history["val_f1_macro"].append(metrics["f1_macro"])
        metrics_history["val_accuracy"].append(metrics["accuracy"])

        logging.info(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, "
                     f"Val F1 Macro: {metrics['f1_macro']:.4f}, Val Accuracy: {metrics['accuracy']:.4f}")

        # Always checkpoint the first epoch so "best_fusion_model.pth" exists
        # (and is from this run) even if validation F1 never rises above 0.
        if epoch == 0 or metrics["f1_macro"] > best_f1:
            best_f1 = metrics["f1_macro"]
            torch.save(fusion_model.state_dict(), "best_fusion_model.pth")
            torch.save(fusion_model.state_dict(), f"fusion_model_epoch_{epoch+1}.pth")
            np.save("val_fusion_preds.npy", val_preds_binary)
            np.save("val_fusion_true.npy", val_true)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logging.info("Early stopping triggered")
                break

    _save_metric_curve(metrics_history["train_loss"], "Train Loss",
                       "Training Loss", "Loss", "loss_curves_fusion.png")
    _save_metric_curve(metrics_history["val_f1_macro"], "Val F1 Macro",
                       "Validation F1 Macro", "F1 Macro", "f1_scores_fusion.png")
    _save_metric_curve(metrics_history["val_accuracy"], "Val Accuracy",
                       "Validation Accuracy", "Accuracy", "accuracy_scores_fusion.png")

    return best_f1


def final_evaluation(image_model, text_model, fusion_model, test_loader):
    """Evaluate on ``test_loader`` and log overall and per-genre scores.

    Delegates the forward passes and aggregate metrics to ``evaluate_model``,
    logs the aggregate line, then logs a binary F1 for each entry of
    ``all_genres`` (label column ``i`` is scored for genre ``i``).

    Returns:
        The ``(preds_binary, true, metrics)`` triple from ``evaluate_model``.
    """
    result = evaluate_model(image_model, text_model, fusion_model, test_loader)
    preds_binary, true, metrics = result

    logging.info("Fusion Model Evaluation:")
    summary = (
        f"F1 Macro: {metrics['f1_macro']:.4f}, F1 Micro: {metrics['f1_micro']:.4f}, "
        f"Accuracy: {metrics['accuracy']:.4f}, Precision: {metrics['precision']:.4f}, "
        f"Recall: {metrics['recall']:.4f}"
    )
    logging.info(summary)

    for column, genre_name in enumerate(all_genres):
        genre_f1 = f1_score(
            true[:, column],
            preds_binary[:, column],
            average="binary",
            zero_division=0,
        )
        logging.info(f"{genre_name} F1: {genre_f1:.4f}")

    return preds_binary, true, metrics


if __name__ == "__main__":
    # Build the two encoders and the fusion head on the configured device.
    image_model = ImageMovieGenreClassifier().to(device)
    text_model = AttentionMovieGenreClassifier().to(device)
    fusion_model = FusionModel().to(device)

    # Restore encoder checkpoints when the files exist; a missing file only
    # produces a warning and the encoder keeps its fresh initialisation.
    encoder_checkpoints = (
        (
            image_model,
            "best_image_model.pth",
            "Loaded image model weights",
            "Image model weights not found. Initialize with random weights.",
        ),
        (
            text_model,
            "best_text_model.pth",
            "Loaded text model weights",
            "Text model weights not found. Initialize with random weights.",
        ),
    )
    for encoder, weights_path, loaded_msg, missing_msg in encoder_checkpoints:
        try:
            encoder.load_state_dict(torch.load(weights_path, map_location=device))
        except FileNotFoundError:
            logging.warning(missing_msg)
        else:
            logging.info(loaded_msg)

    # One dataset per split, all reading posters from the same directory.
    train_dataset, val_dataset, test_dataset = (
        FusionDataset(frame, "images/", tokenizer, max_len)
        for frame in (train_df, val_df, test_df)
    )

    # Only the training loader shuffles.
    loader_options = {"batch_size": 16, "num_workers": 0}
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_dataset, **loader_options)
    test_loader = DataLoader(test_dataset, **loader_options)

    logging.info("Training fusion model...")
    best_f1 = train_fusion_model(
        image_model,
        text_model,
        fusion_model,
        train_loader,
        val_loader,
        epochs=7,
    )
    logging.info(f"Best validation F1 Macro: {best_f1:.4f}")

    logging.info("Evaluating on test set...")
    # Evaluate the best checkpoint written during training, not the weights
    # left over from the last epoch.
    best_state = torch.load("best_fusion_model.pth", map_location=device)
    fusion_model.load_state_dict(best_state)
    test_preds_binary, test_true, test_metrics = final_evaluation(
        image_model, text_model, fusion_model, test_loader
    )

    # Persist test predictions, ground truth and metrics.
    np.save("test_fusion_preds.npy", test_preds_binary)
    np.save("test_fusion_true.npy", test_true)
    with open("test_fusion_metrics.json", "w") as metrics_file:
        json.dump(test_metrics, metrics_file)

2025-04-30 16:47:45,213 - INFO - Using device: cpu


2025-04-30 16:47:45,214 - INFO - Loading tokenizer...
2025-04-30 16:47:45,416 - INFO - Loading datasets...
2025-04-30 16:47:46,776 - INFO - Loaded Train: 63639, Val: 7981, Test: 7963
2025-04-30 16:47:46,873 - INFO - Loaded image model weights
2025-04-30 16:47:46,878 - INFO - Loaded text model weights
2025-04-30 16:47:47,102 - INFO - Training fusion model...
Epoch 1 Training: 100%|██████████| 3978/3978 [18:51<00:00,  3.52it/s]
Evaluation: 100%|██████████| 499/499 [02:16<00:00,  3.65it/s]
2025-04-30 17:08:55,285 - INFO - Epoch 1/7, Train Loss: 0.1617, Val F1 Macro: 0.4989, Val Accuracy: 0.2191
Epoch 2 Training: 100%|██████████| 3978/3978 [18:10<00:00,  3.65it/s]
Evaluation: 100%|██████████| 499/499 [02:15<00:00,  3.69it/s]
2025-04-30 17:29:20,942 - INFO - Epoch 2/7, Train Loss: 0.1503, Val F1 Macro: 0.4845, Val Accuracy: 0.2322
Epoch 3 Training: 100%|██████████| 3978/3978 [18:10<00:00,  3.65it/s]
Evaluation: 100%|██████████| 499/499 [02:15<00:00,  3.68it/s]
2025-04-30 17:49:47,198 - INFO