In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from torchvision.models import ResNet18_Weights
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from sklearn.utils import resample
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import json
import os
import logging
from tqdm import tqdm


# Root logger: timestamped records at INFO level and above.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Seed the NumPy and PyTorch global generators with the same fixed value.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Genre vocabulary. The position of a name in this list is the column index
# used to read that genre out of each row's "genre_labels" vector.
all_genres = [
    "Action",
    "Adventure",
    "Animation",
    "Comedy",
    "Crime",
    "Documentary",
    "Drama",
    "Family",
    "Fantasy",
    "History",
    "Horror",
    "Music",
    "Mystery",
    "Romance",
    "Science Fiction",
    "TV Movie",
    "Thriller",
    "War",
    "Western",
]


# Read the three pre-split JSON files into DataFrames (train, val, test).
logger.info("Loading datasets...")
train_df, val_df, test_df = map(
    pd.read_json,
    (
        "train_fixed_updated.json",
        "val_fixed_updated.json",
        "test_fixed_updated.json",
    ),
)
print(f"Loaded Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")



# Oversample rare genres: for each listed genre, draw (with replacement) a
# fixed number of rows that carry that genre and append them to the training
# frame. Rows with several rare genres can therefore be drawn more than once.
logger.info("Oversampling rare genres...")
rare_genres = ["TV Movie", "History", "Mystery", "Adventure", "Fantasy", "Crime", "Music", "War", "Western", "Documentary", "Science Fiction", "Animation", "Family"]

# Every rare genre gets the same number of extra rows, so compute it once.
n_samples = len(train_df) // 15
# Label matrix (rows = movies, columns ordered like all_genres); built once so
# the per-genre mask is a column lookup instead of a per-row Python lambda.
label_matrix = np.array(train_df["genre_labels"].tolist())

oversampled_dfs = [train_df]
for genre in rare_genres:
    genre_df = train_df[label_matrix[:, all_genres.index(genre)] == 1]
    if genre_df.empty:
        # resample() cannot draw from an empty frame; skip instead of crashing.
        logger.warning(f"No training rows for genre '{genre}'; skipping oversampling")
        continue
    oversampled = resample(genre_df, replace=True, n_samples=n_samples, random_state=42)
    oversampled_dfs.append(oversampled)

# NOTE: the original index is kept, so index labels repeat in the result;
# use positional access (iloc) on this frame.
train_df_oversampled = pd.concat(oversampled_dfs)
print(f"Oversampled training set: {len(train_df_oversampled)} samples")


class MovieImageDataset(Dataset):
    """Dataset yielding ``(poster image, genre label tensor)`` pairs.

    Rows are read positionally from ``df``. The ``poster_path`` column locates
    the image file under ``image_dir`` and ``genre_labels`` holds the label
    vector for that row.

    Args:
        df: DataFrame with ``poster_path`` and ``genre_labels`` columns.
        image_dir: Directory that the poster paths are relative to.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, df, image_dir, transform=None):
        self.df = df
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the (optionally transformed) image and float32 labels at ``idx``."""
        # Fetch the row once; each iloc lookup builds a new Series.
        row = self.df.iloc[idx]
        img_path = os.path.join(self.image_dir, row["poster_path"].lstrip("/"))
        try:
            # convert() returns an independent in-memory copy, so the file
            # handle can be released immediately instead of waiting for GC.
            with Image.open(img_path) as source:
                image = source.convert("RGB")
        except Exception as e:
            # Best effort: an unreadable poster becomes a black placeholder so
            # one bad file does not abort the whole epoch.
            logger.warning(f"Error loading image {img_path}: {e}")
            image = Image.new("RGB", (224, 224), (0, 0, 0))

        labels = np.array(row["genre_labels"], dtype=np.float32)

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(labels, dtype=torch.float32)


# Steps shared by both pipelines: resize to 224x224 first, then convert to a
# tensor and normalize per channel (presumably the ImageNet statistics that
# match the pretrained backbone -- confirm if the weights change).
_resize_step = transforms.Resize((224, 224))
_tensor_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

# Training adds light augmentation (random flip and rotation of up to 10
# degrees) between the resize and the tensor conversion.
train_transform = transforms.Compose(
    [
        _resize_step,
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        *_tensor_steps,
    ]
)
# Validation and test use the deterministic steps only.
val_test_transform = transforms.Compose([_resize_step, *_tensor_steps])


# Wrap each split in a dataset, then in a DataLoader. Only the training
# loader shuffles; validation and test keep the DataFrame row order.
image_dir = "images/"
logger.info("Creating datasets...")
train_dataset = MovieImageDataset(df=train_df_oversampled, image_dir=image_dir, transform=train_transform)
val_dataset = MovieImageDataset(df=val_df, image_dir=image_dir, transform=val_test_transform)
test_dataset = MovieImageDataset(df=test_df, image_dir=image_dir, transform=val_test_transform)

_loader_options = {"batch_size": 16, "num_workers": 0}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, **_loader_options)
test_loader = DataLoader(test_dataset, **_loader_options)


class ImageMovieGenreClassifier(nn.Module):
    """ResNet-18 backbone with a ``num_classes``-way sigmoid output head.

    The backbone is created with the ``IMAGENET1K_V1`` weights and its final
    fully connected layer is replaced by a fresh linear layer. ``forward``
    returns sigmoid activations (values in [0, 1]), not raw logits.
    """

    def __init__(self, num_classes=19):
        super().__init__()
        backbone = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        # Swap the 1000-way ImageNet head for one sized to the genre count.
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.resnet = backbone
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-class sigmoid scores for the image batch ``x``."""
        logits = self.resnet(x)
        return self.sigmoid(logits)

def compute_class_weights(labels, genres=None, boosted_genres=None, boost=2.5):
    """Compute inverse-frequency class weights with an extra boost for some genres.

    Each class gets ``n_samples / (n_classes * positives_in_class)``; classes
    whose name is in ``boosted_genres`` are then multiplied by ``boost``.

    Args:
        labels: 2-D array-like of shape (n_samples, n_classes) with 0/1 entries.
        genres: Class names, one per column. Defaults to the module-level
            ``all_genres`` list.
        boosted_genres: Names whose weight is multiplied by ``boost``. Defaults
            to the fixed list of under-represented genres used by this script.
        boost: Multiplier applied to the boosted genres.

    Returns:
        A 1-D ``torch.float32`` tensor with one finite weight per class.
    """
    labels = np.asarray(labels)
    n_samples, n_classes = labels.shape
    class_counts = labels.sum(axis=0)
    # A class with no positives would divide by zero and give an infinite
    # weight, which turns the loss into inf/nan; treat it as one occurrence.
    safe_counts = np.where(class_counts > 0, class_counts, 1)
    weights = n_samples / (n_classes * safe_counts)

    if genres is None:
        genres = all_genres
    if boosted_genres is None:
        boosted_genres = (
            "TV Movie", "History", "Mystery", "Crime", "Music", "War",
            "Western", "Documentary", "Science Fiction", "Animation", "Family",
        )
    boosted = frozenset(boosted_genres)

    for i, genre in enumerate(genres):
        if genre in boosted:
            weights[i] *= boost
    return torch.tensor(weights, dtype=torch.float32)

# Derive per-class loss weights from the oversampled training labels, keep
# them on the training device, and persist a copy to disk.
logger.info("Computing class weights...")
train_labels = np.array(train_df_oversampled["genre_labels"].to_list())
device = torch.device("cpu")
logger.info(f"Using device: {device}")
class_weights = compute_class_weights(train_labels)
class_weights = class_weights.to(device)
np.save("image_class_weights.npy", class_weights.cpu().numpy())


class LabelSmoothingBCELoss(nn.Module):
    """Binary cross-entropy on probabilities with label smoothing and class weights.

    Targets are pulled toward 0.5: ``t * (1 - smoothing) + smoothing / 2``.
    ``pred`` must already be probabilities in [0, 1], not logits.

    Args:
        smoothing: Amount of label smoothing; 0 leaves the targets unchanged.
        weight: Optional per-class weight tensor broadcastable to the input.
            When omitted, the module-level ``class_weights`` tensor is looked
            up at call time, as this loss has always done.
    """

    def __init__(self, smoothing=0.1, weight=None):
        super().__init__()
        self.smoothing = smoothing
        self.weight = weight

    def forward(self, pred, target):
        """Return the mean weighted BCE between ``pred`` and the smoothed ``target``."""
        target_smooth = target * (1 - self.smoothing) + self.smoothing / 2
        weight = class_weights if self.weight is None else self.weight
        # The functional form avoids building a new nn.BCELoss on every call.
        return nn.functional.binary_cross_entropy(pred, target_smooth, weight=weight)


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs=10, patience=3):
    """Train ``model`` with per-epoch validation, checkpointing and early stopping.

    After each epoch the validation outputs are binarised with a per-genre
    threshold chosen to maximise that genre's F1 on the validation set, and
    macro-F1 over all genres is used for model selection. The best weights
    are written to ``best_image_model.pth``. Loss and accuracy curves are
    saved to ``loss_curves_image.png`` and ``accuracy_curves_image.png``.

    Relies on the module-level ``device``, ``all_genres`` and ``logger``.

    Args:
        model: Network returning per-class probabilities.
        train_loader: Loader yielding ``(images, labels)`` training batches.
        val_loader: Loader yielding ``(images, labels)`` validation batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Scheduler stepped once per epoch with the summed validation loss.
        epochs: Maximum number of epochs.
        patience: Number of consecutive non-improving epochs before stopping.

    Returns:
        The best validation macro-F1 observed.
    """
    # Genres searched on a finer, lower threshold grid; built once, not per epoch.
    low_grid_genres = frozenset([
        "TV Movie", "History", "Mystery", "Crime", "Music", "War",
        "Western", "Documentary", "Science Fiction", "Animation", "Family",
    ])
    low_grid = np.arange(0.01, 0.3, 0.05)
    default_grid = np.arange(0.1, 0.6, 0.1)

    best_val_f1 = 0
    # Tracks whether a checkpoint exists, so a run whose macro-F1 never rises
    # above 0 still leaves a file for the caller to load.
    checkpoint_saved = False
    trigger = 0
    train_losses, val_losses, val_accuracies = [], [], []
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        logger.info(f"Starting epoch {epoch+1}/{epochs}")
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_losses.append(train_loss/len(train_loader))

        model.eval()
        val_loss = 0
        val_preds, val_true = [], []
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f"Epoch {epoch+1} Validation"):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                val_preds.append(outputs.cpu().numpy())
                val_true.append(labels.cpu().numpy())

        val_losses.append(val_loss/len(val_loader))
        val_preds = np.concatenate(val_preds)
        val_true = np.concatenate(val_true)

        # Per-genre threshold search: keep the threshold with the highest F1,
        # falling back to 0.5 when no candidate scores above zero.
        best_val_preds_binary = np.zeros_like(val_preds)
        for i, genre in enumerate(all_genres):
            best_f1, best_thresh = 0, 0.5
            thresh_range = low_grid if genre in low_grid_genres else default_grid
            for thresh in thresh_range:
                preds_binary = (val_preds[:, i] > thresh).astype(int)
                # zero_division=0 returns the same 0.0 as the default but
                # without a warning for every empty prediction set.
                f1 = f1_score(val_true[:, i], preds_binary, zero_division=0)
                if f1 > best_f1:
                    best_f1, best_thresh = f1, thresh
            best_val_preds_binary[:, i] = (val_preds[:, i] > best_thresh).astype(int)

        val_f1_macro = f1_score(val_true, best_val_preds_binary, average="macro", zero_division=0)
        # Element-wise accuracy over every (sample, genre) cell.
        val_accuracy = (val_true == best_val_preds_binary).mean()
        val_accuracies.append(val_accuracy)

        positive_preds = np.sum(best_val_preds_binary, axis=0)
        logger.info(f"Positive predictions per genre: {positive_preds}")

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}, Val F1 Macro: {val_f1_macro:.4f}, Val Accuracy: {val_accuracy:.4f}")

        scheduler.step(val_loss)
        if val_f1_macro > best_val_f1 or not checkpoint_saved:
            best_val_f1 = max(best_val_f1, val_f1_macro)
            torch.save(model.state_dict(), "best_image_model.pth")
            checkpoint_saved = True
            trigger = 0
        else:
            trigger += 1
            if trigger >= patience:
                print("Early stopping triggered")
                break

    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label="Train Loss")
    plt.plot(val_losses, label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Image Model: Training and Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.savefig("loss_curves_image.png", dpi=300)
    plt.close()

    plt.figure(figsize=(10, 6))
    plt.plot(val_accuracies, label="Val Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title("Image Model: Validation Accuracy")
    plt.legend()
    plt.grid(True)
    plt.savefig("accuracy_curves_image.png", dpi=300)
    plt.close()

    return best_val_f1


def _save_genre_barplot(scores, palette, xlabel, title, path):
    """Save a horizontal bar chart of one score per genre (x-axis 0 to 1) to ``path``."""
    plt.figure(figsize=(12, 6))
    # hue plus legend=False is the replacement seaborn >= 0.13 asks for when a
    # palette is given without hue; the bars are coloured as before.
    sns.barplot(x=scores, y=all_genres, hue=all_genres, palette=palette, legend=False)
    plt.xlabel(xlabel)
    plt.ylabel("Genre")
    plt.title(title)
    plt.xlim(0, 1)
    for i, score in enumerate(scores):
        plt.text(score + 0.01, i, f"{score:.4f}", va="center")
    plt.tight_layout()
    plt.savefig(path, dpi=300)
    plt.close()


def evaluate_model(model, loader):
    """Evaluate ``model`` on ``loader`` using per-genre F1-optimal thresholds.

    Prints aggregate and per-genre metrics, saves per-genre F1 and accuracy
    bar charts plus a confusion matrix for "Action", and prints up to five
    misclassified rows taken from the loader's own dataset.

    NOTE: thresholds are chosen on the same data they are then scored on, so
    the reported metrics are optimistic for that split.

    Relies on the module-level ``device``, ``all_genres`` and ``logger``.

    Args:
        model: Network returning per-class probabilities.
        loader: Loader yielding ``(images, labels)`` batches. For the
            misclassification listing its dataset should expose a ``df``
            attribute whose rows are iterated in order (no shuffling).

    Returns:
        Tuple ``(preds_binary, true, best_thresholds)``: the thresholded
        predictions, the ground-truth labels, and one threshold per genre.
    """
    model.eval()
    preds, true = [], []
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Evaluation"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds.append(outputs.cpu().numpy())
            true.append(labels.cpu().numpy())

    preds = np.concatenate(preds)
    true = np.concatenate(true)

    # Genres searched on the finer grid that starts at 0.01.
    fine_grid_genres = frozenset([
        "TV Movie", "History", "Mystery", "Crime", "Music", "War",
        "Western", "Documentary", "Science Fiction", "Animation", "Family",
    ])
    fine_grid = np.arange(0.01, 0.8, 0.05)
    coarse_grid = np.arange(0.1, 0.9, 0.1)

    best_thresholds = []
    for i, genre in enumerate(all_genres):
        thresh_range = fine_grid if genre in fine_grid_genres else coarse_grid
        best_f1, best_thresh = 0, 0.5
        for thresh in thresh_range:
            candidate = (preds[:, i] > thresh).astype(int)
            # zero_division=0 returns the same 0.0 as the default, without warnings.
            f1 = f1_score(true[:, i], candidate, zero_division=0)
            if f1 > best_f1:
                best_f1, best_thresh = f1, thresh
        best_thresholds.append(best_thresh)

    preds_binary = np.zeros_like(preds)
    for i, thresh in enumerate(best_thresholds):
        preds_binary[:, i] = (preds[:, i] > thresh).astype(int)

    f1_macro = f1_score(true, preds_binary, average="macro", zero_division=0)
    f1_micro = f1_score(true, preds_binary, average="micro", zero_division=0)
    precision = precision_score(true, preds_binary, average="macro", zero_division=0)
    recall = recall_score(true, preds_binary, average="macro", zero_division=0)
    # Element-wise accuracy over every (sample, genre) cell.
    accuracy = (true == preds_binary).mean()

    print("Image Model Evaluation:")
    print(f"F1 Macro: {f1_macro:.4f}, F1 Micro: {f1_micro:.4f}, Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}, Recall: {recall:.4f}")

    f1_scores, accuracy_scores = [], []
    for i, genre in enumerate(all_genres):
        f1 = f1_score(true[:, i], preds_binary[:, i], zero_division=0)
        acc = accuracy_score(true[:, i], preds_binary[:, i])
        f1_scores.append(f1)
        accuracy_scores.append(acc)
        print(f"{genre} - F1: {f1:.4f}, Accuracy: {acc:.4f}")

    _save_genre_barplot(f1_scores, "viridis", "F1 Score", "Image Model: F1 Scores by Genre", "f1_scores_image.png")
    _save_genre_barplot(accuracy_scores, "magma", "Accuracy", "Image Model: Accuracy by Genre", "accuracy_scores_image.png")

    action_idx = all_genres.index("Action")
    # labels=[0, 1] keeps the matrix 2x2 even if only one class is present,
    # so the tick labels below always line up.
    cm = confusion_matrix(true[:, action_idx], preds_binary[:, action_idx], labels=[0, 1])
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False,
                xticklabels=["Not Action", "Action"], yticklabels=["Not Action", "Action"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Image Model: Confusion Matrix for Action")
    plt.tight_layout()
    plt.savefig("cm_image_action.png", dpi=300)
    plt.close()

    # Use the frame behind the loader that was actually evaluated, not a fixed
    # global frame, so rows line up for any split.
    source_df = getattr(loader.dataset, "df", None)
    if source_df is None or len(source_df) != len(preds_binary):
        logger.warning("Skipping misclassification samples: loader dataset has no matching 'df' frame")
    else:
        pred_idx = [np.flatnonzero(row == 1) for row in preds_binary]
        true_idx = [np.flatnonzero(row == 1) for row in true]
        errors = source_df.copy()
        errors["pred_genres"] = pred_idx
        errors["true_genres"] = true_idx
        # A row is correct only when its predicted genre set matches exactly.
        errors["correct"] = [set(p) == set(t) for p, t in zip(pred_idx, true_idx)]
        incorrect = errors[~errors["correct"]]
        print("\nSample misclassifications:")
        for _, row in incorrect.head(5).iterrows():
            print(f"Movie: {row['title']}")
            print(f"True Genres: {[all_genres[i] for i in row['true_genres']]}")
            print(f"Pred Genres: {[all_genres[i] for i in row['pred_genres']]}")
            print(f"Poster Path: {row['poster_path']}")
            print()

    return preds_binary, true, best_thresholds


# Train the image model, reload the best checkpoint, evaluate it on the
# validation split, and persist predictions, thresholds and hyperparameters.
logger.info("Initializing model...")

# Single source of truth for the values that are both used below and written
# to the hyperparameter file, so the saved record cannot drift from the run.
learning_rate = 0.001
num_epochs = 10
label_smoothing = 0.1

model = ImageMovieGenreClassifier().to(device)
criterion = LabelSmoothingBCELoss(smoothing=label_smoothing)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)

print("Training Image-based CNN")
best_f1_image = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs=num_epochs, patience=3)


logger.info("Evaluating model...")
# map_location keeps the load working if the checkpoint was written on a
# different device than the one in use now.
model.load_state_dict(torch.load("best_image_model.pth", map_location=device))
val_preds_binary, val_true, best_thresholds = evaluate_model(model, val_loader)

np.save("val_image_preds.npy", val_preds_binary)
np.save("val_image_true.npy", val_true)
np.save("image_best_thresholds.npy", best_thresholds)

hyperparams = {
    "model": "ResNet18",
    "num_classes": model.resnet.fc.out_features,
    "learning_rate": learning_rate,
    "batch_size": train_loader.batch_size,
    # Configured maximum; early stopping may end training sooner.
    "epochs": num_epochs,
    "dropout": 0.0,
    "label_smoothing": label_smoothing
}
with open("image_model_hyperparams.json", "w") as f:
    json.dump(hyperparams, f)

print("Best image model saved to best_image_model.pth")

2025-04-30 00:34:24,321 - INFO - Loading datasets...
2025-04-30 00:34:25,634 - INFO - Oversampling rare genres...


Loaded Train: 63639, Val: 7981, Test: 7963


2025-04-30 00:34:25,880 - INFO - Creating datasets...
2025-04-30 00:34:25,889 - INFO - Computing class weights...
2025-04-30 00:34:25,989 - INFO - Using device: cpu
2025-04-30 00:34:25,993 - INFO - Initializing model...


Oversampled training set: 118785 samples


2025-04-30 00:34:26,293 - INFO - Starting epoch 1/10


Training Image-based CNN


Epoch 1 Training: 100%|██████████| 7425/7425 [1:01:49<00:00,  2.00it/s]
Epoch 1 Validation: 100%|██████████| 499/499 [02:03<00:00,  4.02it/s]
2025-04-30 01:38:19,892 - INFO - Positive predictions per genre: [1507. 1360. 1800. 3261. 1654. 1464. 5241. 1630.  539.  386. 3195. 1727.
 1446. 1574.  799. 1384. 1817.  710.  468.]
2025-04-30 01:38:19,924 - INFO - Starting epoch 2/10


Epoch 1/10, Train Loss: 0.3927, Val Loss: 0.3540, Val F1 Macro: 0.3068, Val Accuracy: 0.8047


Epoch 2 Training: 100%|██████████| 7425/7425 [1:02:00<00:00,  2.00it/s]
Epoch 2 Validation: 100%|██████████| 499/499 [02:05<00:00,  3.96it/s]
2025-04-30 02:42:26,256 - INFO - Positive predictions per genre: [1727. 1077.  579. 3592. 1430. 1769. 6088.  713.  545.  486. 1038. 1343.
  879. 2556.  564.  928. 2259.  579.  401.]
2025-04-30 02:42:26,290 - INFO - Starting epoch 3/10


Epoch 2/10, Train Loss: 0.3789, Val Loss: 0.3422, Val F1 Macro: 0.3477, Val Accuracy: 0.8321


Epoch 3 Training: 100%|██████████| 7425/7425 [1:02:06<00:00,  1.99it/s]
Epoch 3 Validation: 100%|██████████| 499/499 [02:03<00:00,  4.05it/s]
2025-04-30 03:46:36,397 - INFO - Positive predictions per genre: [1397.  663.  597. 3068. 2166. 1588. 6135. 1017.  806.  533. 1103. 1619.
 1639. 1770.  586.  657. 1763.  944.  300.]
2025-04-30 03:46:36,428 - INFO - Starting epoch 4/10


Epoch 3/10, Train Loss: 0.3693, Val Loss: 0.3475, Val F1 Macro: 0.3510, Val Accuracy: 0.8311


Epoch 4 Training: 100%|██████████| 7425/7425 [1:02:11<00:00,  1.99it/s]
Epoch 4 Validation: 100%|██████████| 499/499 [02:03<00:00,  4.03it/s]
2025-04-30 04:50:52,329 - INFO - Positive predictions per genre: [1940. 1454.  536. 2966. 1523. 1253. 5663.  822.  717.  554.  884. 1211.
 1097. 2344.  725. 1067. 1673.  998.  413.]
2025-04-30 04:50:52,331 - INFO - Starting epoch 5/10


Epoch 4/10, Train Loss: 0.3577, Val Loss: 0.3517, Val F1 Macro: 0.3485, Val Accuracy: 0.8322


Epoch 5 Training: 100%|██████████| 7425/7425 [1:02:23<00:00,  1.98it/s]
Epoch 5 Validation: 100%|██████████| 499/499 [02:03<00:00,  4.03it/s]
2025-04-30 05:55:20,380 - INFO - Positive predictions per genre: [1397. 1073.  567. 3052. 2151. 1253. 5667.  704.  764.  527. 1129. 1107.
  718. 2774. 1057. 1168. 1674.  346.  298.]
2025-04-30 05:55:20,382 - INFO - Starting epoch 6/10


Epoch 5/10, Train Loss: 0.3438, Val Loss: 0.3509, Val F1 Macro: 0.3386, Val Accuracy: 0.8333


Epoch 6 Training: 100%|██████████| 7425/7425 [1:02:59<00:00,  1.96it/s]
Epoch 6 Validation: 100%|██████████| 499/499 [02:07<00:00,  3.93it/s]
2025-04-30 07:00:26,682 - INFO - Positive predictions per genre: [2132. 1126.  573. 3091. 2280. 1202. 6413.  656.  767.  416.  963. 1030.
 1121. 1932.  755.  914. 1611.  275.  267.]


Epoch 6/10, Train Loss: 0.3181, Val Loss: 0.3560, Val F1 Macro: 0.3437, Val Accuracy: 0.8353
Early stopping triggered


2025-04-30 07:00:27,160 - INFO - Evaluating model...
Evaluation: 100%|██████████| 499/499 [02:06<00:00,  3.94it/s]


Image Model Evaluation:
F1 Macro: 0.3567, F1 Micro: 0.4286, Accuracy: 0.8377
Precision: 0.3104, Recall: 0.4573
Action - F1: 0.3933, Accuracy: 0.8129
Adventure - F1: 0.2869, Accuracy: 0.8897
Animation - F1: 0.6636, Accuracy: 0.9629
Comedy - F1: 0.5793, Accuracy: 0.7097
Crime - F1: 0.2957, Accuracy: 0.7875
Documentary - F1: 0.3643, Accuracy: 0.8146
Drama - F1: 0.6262, Accuracy: 0.5578
Family - F1: 0.4030, Accuracy: 0.9295
Fantasy - F1: 0.2458, Accuracy: 0.8831
History - F1: 0.1439, Accuracy: 0.9120
Horror - F1: 0.4393, Accuracy: 0.8570
Music - F1: 0.2816, Accuracy: 0.7999
Mystery - F1: 0.1700, Accuracy: 0.7871
Romance - F1: 0.3937, Accuracy: 0.7750
Science Fiction - F1: 0.3054, Accuracy: 0.9105
TV Movie - F1: 0.2342, Accuracy: 0.8984
Thriller - F1: 0.4062, Accuracy: 0.7835
War - F1: 0.1372, Accuracy: 0.8723
Western - F1: 0.4077, Accuracy: 0.9731



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(x=f1_scores, y=all_genres, palette="viridis")

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(x=accuracy_scores, y=all_genres, palette="magma")



Sample misclassifications:
Movie: Destruction Force
True Genres: ['Crime']
Pred Genres: ['Action', 'Crime', 'Drama']
Poster Path: /hLQxmeMTmxgn4HjraPb6kkxgd9N.jpg

Movie: The Perez Family
True Genres: ['Comedy', 'Drama', 'Romance']
Pred Genres: ['Comedy', 'Drama', 'Music', 'Romance']
Poster Path: /jxi3UDHf86kDk5hqJxnpEMzVch2.jpg

Movie: Queensrÿche: The Art of Live
True Genres: ['Music']
Pred Genres: ['Documentary', 'Music']
Poster Path: /krJsYaLovn1JgkjprEu6DWFWrdh.jpg

Movie: The Totenwackers
True Genres: ['Adventure', 'Comedy']
Pred Genres: ['Drama', 'Romance']
Poster Path: /vj45dWwcwBlbp1COcthiJAiTN8Y.jpg

Movie: 400 Against 1: A History of Organized Crime
True Genres: ['Action', 'Crime', 'Drama']
Pred Genres: ['Documentary', 'Drama', 'Music']
Poster Path: /vYjpX0uLBCtjJBf42zJk4OxepPx.jpg

Best image model saved to best_image_model.pth
