In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import numpy as np
import pandas as pd
import json
import logging
import time
import ast # For safely evaluating string representations of lists
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, classification_report
from sklearn.model_selection import train_test_split # For splitting data
from torchvision import models, transforms
from PIL import Image
from transformers import BertTokenizer, BertModel

# Configure logging: INFO and above go to a log file and to the console.
# force=True (Python 3.8+) removes any handlers already attached to the root
# logger (e.g. installed by the notebook environment or a previous run of this
# cell). Without it, basicConfig() is silently a no-op whenever such handlers
# exist, so neither the file handler nor the INFO level would take effect.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("simple_tuned_genre_model_csv.log"),
        logging.StreamHandler(),
    ],
    force=True,
)

# Helper function to calculate F1 score safely
def calculate_f1(targets, predictions, average='macro'):
    """Return ``sklearn.metrics.f1_score`` with ill-defined scores set to 0.

    Args:
        targets: ground-truth labels (in this file, 0/1 indicator arrays).
        predictions: predicted labels with the same layout as ``targets``.
        average: averaging mode forwarded to ``f1_score``
            (e.g. 'macro', 'micro', 'weighted', 'samples').
    """
    # zero_division=0 scores classes with no predicted/true positives as 0.
    options = {'average': average, 'zero_division': 0}
    return f1_score(targets, predictions, **options)

# DATASET

class MovieGenreDataset(Dataset):
    """
    Multi-modal dataset for movie genre classification.

    Each item bundles a tokenized overview, a poster image tensor and a
    multi-hot genre label vector.

    Args:
        df: DataFrame with a 'groups' column (a list of genre strings per row)
            and, optionally, 'overview' and 'poster_path' columns. Missing
            optional columns are filled on a private copy; ``df`` itself is
            never modified.
        tokenizer: HuggingFace-style tokenizer callable (e.g. BertTokenizer).
        all_genres: ordered genre names; defines the label-vector layout.
        poster_base_path: directory that relative poster paths are joined to.
        max_length: tokenizer ``max_length``; text is padded/truncated to it.
        transform: transform applied to loaded PIL posters. When it is falsy,
            loaded posters are not used and every item gets the placeholder.

    Raises:
        ValueError: if ``df`` has no 'groups' column.
    """

    def __init__(self, df, tokenizer, all_genres, poster_base_path, max_length=200, transform=None):
        if 'groups' not in df.columns:
            raise ValueError("Dataset requires a 'groups' column containing lists of genres.")

        # Add missing optional columns to a copy so the caller's DataFrame is untouched.
        missing_poster = 'poster_path' not in df.columns
        missing_overview = 'overview' not in df.columns
        if missing_poster or missing_overview:
            df = df.copy()
            if missing_poster:
                logging.warning("Column 'poster_path' not found. Using placeholder images.")
                df['poster_path'] = None
            if missing_overview:
                logging.warning("Column 'overview' not found. Using empty text.")
                df['overview'] = ""

        self.df = df
        self.tokenizer = tokenizer
        self.all_genres = all_genres
        self.genre_to_idx = {genre: i for i, genre in enumerate(all_genres)}
        self.num_genres = len(all_genres)
        self.max_length = max_length
        self.transform = transform
        self.poster_base_path = poster_base_path  # Base directory for poster files

        # Grey placeholder for missing/unreadable posters, built once rather than
        # per item. With a transform it gets the same ImageNet normalization
        # that the transforms built in main() apply to real posters.
        placeholder = torch.full((3, 224, 224), 0.5)
        if self.transform:
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            placeholder = normalize(placeholder)
        self._placeholder_image = placeholder

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict with 'input_ids', 'attention_mask', 'image' and 'labels' for row ``idx``."""
        row = self.df.iloc[idx]

        # 1. Text: tokenize the overview, treating NaN/None as empty text.
        text = row['overview']
        if not isinstance(text, str):
            text = "" if pd.isna(text) else str(text)

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # 2. Image: resolve the poster path relative to poster_base_path.
        relative_poster_path = row.get('poster_path', None)
        full_poster_path = None
        if relative_poster_path and isinstance(relative_poster_path, str):
            # A leading '/' would make os.path.join discard poster_base_path.
            relative_poster_path = relative_poster_path.lstrip('/')
            full_poster_path = os.path.join(self.poster_base_path, relative_poster_path)

        # Fresh copy so callers never share (and could mutate) the cached tensor.
        image = self._placeholder_image.clone()
        try:
            if full_poster_path and os.path.exists(full_poster_path):
                # The context manager releases the file handle deterministically,
                # including when decoding fails part-way through.
                with Image.open(full_poster_path) as poster:
                    loaded_image = poster.convert('RGB')
                if self.transform:
                    image = self.transform(loaded_image)
                # NOTE(review): without a transform the loaded poster is not used
                # and the placeholder is returned — confirm this is intended.
            elif relative_poster_path:
                logging.debug(f"Poster not found at expected path: {full_poster_path}. Using placeholder.")

        except Exception as e:
            logging.warning(f"Error loading image {full_poster_path} for index {idx}: {e}. Using placeholder.")
            # Make sure the placeholder is used if loading/transforming failed.
            image = self._placeholder_image.clone()

        # 3. Labels: multi-hot vector over all_genres; unknown genres are ignored.
        target_genres = torch.zeros(self.num_genres)
        try:
            movie_genres = row['groups']
            if isinstance(movie_genres, list):
                for genre in movie_genres:
                    genre_idx = self.genre_to_idx.get(genre)
                    if genre_idx is not None:
                        target_genres[genre_idx] = 1
            else:
                logging.warning(f"Data in 'groups' column is not a list for index {idx}. Using empty labels. Value: {movie_genres}")
        except KeyError:
            logging.error(f"Critical error: Column 'groups' not found in DataFrame for index {idx}.")
        except Exception as e:
            logging.error(f"Unexpected error processing genres for index {idx}: {e}")

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'image': image,
            'labels': target_genres
        }

    def calculate_class_weights(self, strategy='balanced', min_weight=1.0, max_weight=50.0):
        """
        Compute per-genre ``pos_weight`` values for ``nn.BCEWithLogitsLoss``.

        A movie counts at most once per genre (duplicate entries in a row's
        'groups' list are ignored), matching the binary labels of __getitem__.

        Args:
            strategy: 'balanced' (default): total / (num_genres * count).
                'neg_pos': (total - count) / count, the negative/positive ratio
                described in the BCEWithLogitsLoss documentation.
            min_weight: lower clamp bound for the weights.
            max_weight: upper clamp bound for the weights.

        Returns:
            1-D float32 tensor of length ``num_genres``.

        Raises:
            ValueError: if ``strategy`` is not recognised.
        """
        counts = [0] * self.num_genres
        for genres in self.df['groups']:
            if not isinstance(genres, list):
                continue
            for genre in set(genres):
                genre_idx = self.genre_to_idx.get(genre)
                if genre_idx is not None:
                    counts[genre_idx] += 1
        genre_counts = torch.tensor(counts, dtype=torch.float32)

        total_samples = len(self.df)
        eps = 1e-6  # Avoids division by zero for genres absent from this split
        if strategy == 'balanced':
            pos_weight = total_samples / (self.num_genres * (genre_counts + eps))
        elif strategy == 'neg_pos':
            pos_weight = (total_samples - genre_counts) / (genre_counts + eps)
        else:
            raise ValueError(f"Unknown class-weight strategy: {strategy!r}")

        pos_weight = torch.clamp(pos_weight, min=min_weight, max=max_weight)
        logging.info(f"Calculated pos_weight for BCEWithLogitsLoss: Min={pos_weight.min():.2f}, Max={pos_weight.max():.2f}, Mean={pos_weight.mean():.2f}")
        return pos_weight

# MODEL COMPONENTS

class SimpleMultiModalClassifier(nn.Module):
    """
    Text + poster genre classifier.

    A BERT pooled text embedding and a CNN image embedding (ImageNet-pretrained
    backbone with its classification head replaced by ``nn.Identity``) are each
    passed through dropout, concatenated, fused by Linear -> ReLU -> Dropout ->
    BatchNorm1d, and mapped to one logit per genre.

    Args:
        num_genres: number of output logits.
        text_model_name: pretrained BERT checkpoint name.
        img_model_name: 'resnet18', 'resnet34' or 'efficientnet_b0'.
        hidden_dim: width of the fusion layer.
        dropout: dropout probability used on both branches and in the fusion layer.

    Raises:
        ValueError: for an unsupported ``img_model_name``.
    """

    def __init__(self, num_genres, text_model_name='bert-base-uncased', img_model_name='resnet18', hidden_dim=512, dropout=0.3):
        super().__init__()
        self.num_genres = num_genres

        # Text branch
        self.bert = BertModel.from_pretrained(text_model_name)
        self.bert_dropout = nn.Dropout(dropout)
        text_dim = self.bert.config.hidden_size

        # Image branch: swap the backbone's ImageNet head for Identity so it
        # outputs the pooled feature vector the head used to consume.
        if img_model_name in ('resnet18', 'resnet34'):
            if img_model_name == 'resnet18':
                backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
            else:
                backbone = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
            image_dim = backbone.fc.in_features
            backbone.fc = nn.Identity()
        elif img_model_name == 'efficientnet_b0':
            backbone = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
            image_dim = backbone.classifier[1].in_features
            backbone.classifier = nn.Identity()
        else:
            raise ValueError(f"Unsupported image model: {img_model_name}")
        self.cnn = backbone
        self.img_dropout = nn.Dropout(dropout)

        # Fusion of the concatenated text and image features
        self.fusion_layer = nn.Sequential(
            nn.Linear(text_dim + image_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.BatchNorm1d(hidden_dim),
        )

        # One logit per genre (sigmoid/BCE applied outside the model)
        self.classifier = nn.Linear(hidden_dim, num_genres)

    def forward(self, input_ids, attention_mask, image):
        """Return raw genre logits for a batch of token ids, masks and images."""
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        text_vec = self.bert_dropout(bert_out.pooler_output)
        image_vec = self.img_dropout(self.cnn(image))
        fused = self.fusion_layer(torch.cat((text_vec, image_vec), dim=1))
        return self.classifier(fused)

# TRAINING & EVALUATION

def train_epoch(model, dataloader, optimizer, criterion, device, scheduler=None, max_grad_norm=1.0):
    """Train ``model`` for one pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: module called as ``model(input_ids, attention_mask, image)`` -> logits.
        dataloader: yields dicts with 'input_ids', 'attention_mask', 'image', 'labels'.
        optimizer: optimizer over the model's parameters.
        criterion: loss function taking ``(logits, labels)``.
        device: device the batch tensors are moved to.
        scheduler: optional per-batch LR scheduler, stepped (without arguments)
            after every optimizer step. Epoch-level schedulers that need a
            metric, such as ReduceLROnPlateau, must be stepped by the caller.
        max_grad_norm: gradient-norm clipping threshold; ``None`` disables clipping.

    Returns:
        Average loss per batch, or NaN if the dataloader produced no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in tqdm(dataloader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        image = batch['image'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask, image)
        loss = criterion(logits, labels)

        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        logging.warning("train_epoch received an empty dataloader; returning NaN loss.")
        return float('nan')
    return total_loss / num_batches


def evaluate(model, dataloader, criterion, device, threshold=0.5):
    """Score ``model`` on ``dataloader`` without gradients.

    A genre is predicted when ``sigmoid(logit) > threshold``.

    Returns:
        ``(metrics, predictions, labels)``. ``metrics`` holds the mean batch loss
        under 'loss' and F1 scores under 'f1_macro', 'f1_micro', 'f1_weighted'
        and 'f1_samples'. ``predictions`` and ``labels`` are NumPy arrays with
        one row per sample.
    """
    model.eval()
    loss_sum = 0.0
    pred_batches = []
    label_batches = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            ids, mask, img, target = (
                batch[key].to(device)
                for key in ('input_ids', 'attention_mask', 'image', 'labels')
            )
            logits = model(ids, mask, img)
            loss_sum += criterion(logits, target).item()

            is_positive = torch.sigmoid(logits) > threshold
            pred_batches.append(is_positive.float().cpu())
            label_batches.append(target.cpu())

    mean_loss = loss_sum / len(dataloader)
    predictions = torch.cat(pred_batches, dim=0).numpy()
    labels = torch.cat(label_batches, dim=0).numpy()

    metrics = {'loss': mean_loss}
    for mode in ('macro', 'micro', 'weighted', 'samples'):
        metrics[f'f1_{mode}'] = calculate_f1(labels, predictions, average=mode)
    return metrics, predictions, labels

# Function to safely parse lists
def parse_genre_list(genre_str):
    """Parse a stringified genre list such as "['Action', 'Drama']" into a list.

    ``ast.literal_eval`` is used, so only Python literals are evaluated, never code.

    Args:
        genre_str: usually a string from the CSV 'genres' column. Non-strings
            (NaN, None, numbers) and blank strings yield ``[]``. An
            already-parsed list or tuple is normalised directly, so applying
            the function twice is safe.

    Returns:
        The stripped, non-empty genre names in their original order. Input that
        cannot be parsed into a list/tuple is logged as a warning and yields ``[]``.
    """
    if isinstance(genre_str, (list, tuple)):
        parsed = genre_str
    elif not isinstance(genre_str, str) or not genre_str.strip():
        return []
    else:
        try:
            parsed = ast.literal_eval(genre_str)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
            logging.warning(f"Could not parse genre string: '{genre_str}'. Error: {e}")
            return []
        if not isinstance(parsed, (list, tuple)):
            logging.warning(f"Parsed genre data is not a list: {genre_str}")
            return []

    # str() tolerates non-string elements; the final filter stops whitespace-only
    # entries from turning into an empty-string "genre" class.
    cleaned = (str(item).strip() for item in parsed if item)
    return [name for name in cleaned if name]

# MAIN PIPELINE


def _load_and_split_data(csv_file_path, train_size, val_size, seed):
    """Load the movie CSV, parse its genres and split it into train/val/test.

    'genres' is parsed with ``parse_genre_list`` into a new 'groups' column.
    Rows whose non-empty 'genres' value parses to no genres are treated as
    malformed and dropped; rows that were empty to begin with (NaN, '' or '[]')
    are kept.

    Returns:
        ``(train_df, val_df, test_df)``, or ``None`` if the file is missing,
        lacks a 'genres' column, or fails to load/process (the cause is logged).
    """
    logging.info(f"Loading data from CSV: {csv_file_path}")
    try:
        full_df = pd.read_csv(csv_file_path)
        logging.info(f"Loaded {len(full_df)} total samples from CSV.")

        if 'genres' not in full_df.columns:
            logging.error("CSV file must contain a 'genres' column.")
            return None

        full_df['groups'] = full_df['genres'].apply(parse_genre_list)

        raw_genres = full_df['genres']
        stripped = raw_genres.astype(str).str.strip()
        originally_empty = raw_genres.isna() | stripped.isin(['', '[]'])
        parse_failed = full_df['groups'].map(len).eq(0) & ~originally_empty
        n_failed = int(parse_failed.sum())
        if n_failed > 0:
            logging.warning(f"Dropping {n_failed} rows due to failed genre parsing (malformed data).")
            full_df = full_df[~parse_failed].copy()

        logging.info("Splitting data into train, validation, and test sets...")
        # Carve out the training split first, then divide the remainder so that
        # val and test end up at val_size and (1 - train_size - val_size) overall.
        train_df, temp_df = train_test_split(full_df, train_size=train_size, random_state=seed)
        relative_test_size = (1 - train_size - val_size) / (1 - train_size)
        val_df, test_df = train_test_split(temp_df, test_size=relative_test_size, random_state=seed)
        logging.info(f"Split sizes: Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}")
    except FileNotFoundError:
        logging.error(f"Input CSV file not found: {csv_file_path}")
        return None
    except Exception as e:
        # logging.exception also records the traceback.
        logging.exception(f"Error loading or processing data from CSV: {e}")
        return None
    return train_df, val_df, test_df


def _plot_training_history(history, best_epoch, best_val_metric, metric_name, results_dir):
    """Save loss and validation-F1 curves to ``results_dir``/training_curves.png.

    Best effort: plotting errors are logged as warnings and never abort the run.
    The figure is always closed.
    """
    fig = None
    try:
        fig = plt.figure(figsize=(14, 6))

        plt.subplot(1, 2, 1)
        plt.plot(history['train_loss'], label='Train Loss')
        plt.plot(history['val_loss'], label='Validation Loss')
        plt.title('Loss During Training')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        if best_epoch != -1:
            plt.axvline(best_epoch, color='r', linestyle='--', label=f'Best Epoch ({best_epoch+1})')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)

        plt.subplot(1, 2, 2)
        plt.plot(history['val_f1_macro'], label='Validation F1 Macro')
        plt.plot(history['val_f1_weighted'], label='Validation F1 Weighted')
        plt.plot(history['val_f1_samples'], label='Validation F1 Samples', linestyle=':')
        plt.title('Validation F1 Scores')
        plt.xlabel('Epoch')
        plt.ylabel('F1 Score')
        if best_epoch != -1:
            best_curve = history.get(f'val_{metric_name}')
            if best_curve is not None:
                plt.scatter(best_epoch, best_curve[best_epoch], s=100, c='red', marker='*', label=f'Best {metric_name}={best_val_metric:.3f}')
            plt.axvline(best_epoch, color='r', linestyle='--', label=f'Best Epoch ({best_epoch+1})')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)

        plt.tight_layout()
        plot_path = os.path.join(results_dir, "training_curves.png")
        plt.savefig(plot_path)
        logging.info(f"Training curves saved to {plot_path}")
    except Exception as e:
        logging.warning(f"Could not plot training curves: {e}")
    finally:
        if fig is not None:
            plt.close(fig)


def _save_classification_report(test_labels, test_preds, all_genres, results_dir, target_f1):
    """Write the per-genre test classification report to CSV and log summary scores.

    Also logs whether the weighted-average F1 reached ``target_f1``.
    Best effort: errors are logged, not raised.
    """
    if not all_genres:
        logging.error("Cannot generate classification report: No target names (genres) found.")
        return
    try:
        report = classification_report(
            test_labels,
            test_preds,
            target_names=all_genres,
            zero_division=0,
            output_dict=True,
        )
        report_df = pd.DataFrame(report).transpose()
        report_path = os.path.join(results_dir, "classification_report.csv")
        report_df.to_csv(report_path)
        logging.info(f"Classification report saved to {report_path}")

        logging.info(f"Macro Avg Precision (Test):   {report_df.loc['macro avg', 'precision']:.4f}")
        logging.info(f"Macro Avg Recall (Test):      {report_df.loc['macro avg', 'recall']:.4f}")
        logging.info(f"Macro Avg F1 (Test):          {report_df.loc['macro avg', 'f1-score']:.4f}")
        logging.info(f"Weighted Avg F1 (Test):       {report_df.loc['weighted avg', 'f1-score']:.4f}")

        current_f1_weighted = report_df.loc['weighted avg', 'f1-score']
        if current_f1_weighted >= target_f1:
            logging.info(f"##### Goal Achieved: Target F1 score ({target_f1}) met or exceeded (Weighted F1: {current_f1_weighted:.4f}) #####")
        else:
            logging.warning(f"##### Goal Not Met: Target F1 score ({target_f1}) not reached (Weighted F1: {current_f1_weighted:.4f}) #####")
            logging.warning("Try something else for fusion later something like :(later)hyperparameter tuning, threshold optimization,diff. IMAGE_MODEL, adjusting MAX_LENGTH.")
    except Exception as e:
        logging.error(f"Could not generate or save classification report: {e}")


def main():
    """Train and evaluate the text + poster genre classifier end to end.

    Loads movies_data.csv from ``base_dir`` and splits it 70/15/15. It then trains
    SimpleMultiModalClassifier with early stopping on validation
    EARLY_STOPPING_METRIC and evaluates the best checkpoint saved during this run
    on the test split. Outputs written under ``base_dir``: the checkpoint,
    training curves, a classification report and a JSON summary.
    """
    # Configuration
    SEED = 42
    MAX_LENGTH = 256
    BATCH_SIZE = 32
    LEARNING_RATE = 3e-5
    WEIGHT_DECAY = 0.01
    DROPOUT = 0.3
    N_EPOCHS = 20
    IMAGE_MODEL = 'resnet18'
    TEXT_MODEL = 'bert-base-uncased'
    EARLY_STOPPING_PATIENCE = 5
    EARLY_STOPPING_METRIC = 'f1_weighted'
    PREDICTION_THRESHOLD = 0.5
    SCHEDULER_FACTOR = 0.5
    SCHEDULER_PATIENCE = 2
    TRAIN_SIZE = 0.7  # 70% for training
    VAL_SIZE = 0.15   # 15% for validation (remaining 15% for test)
    TARGET_F1 = 0.7   # Weighted-F1 goal; the previous model reached about 0.5

    # Seeds. cudnn.benchmark favours speed over bit-exact reproducibility.
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)
        torch.backends.cudnn.benchmark = True

    # Directories
    base_dir = "/content/drive/MyDrive/movie_dataset_complete 2 (1)/movie_dataset_complete"
    logging.info(f"Using specified base directory: {base_dir}")
    csv_file_path = os.path.join(base_dir, "movies_data.csv")
    poster_dir = os.path.join(base_dir, "posters")
    models_dir = os.path.join(base_dir, "models_output", "simple_tuned_csv")
    results_dir = os.path.join(base_dir, "results_output", "simple_tuned_csv")
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)

    if not os.path.isfile(csv_file_path):
        logging.error(f"Input CSV file not found: {csv_file_path}")
        return
    if not os.path.isdir(poster_dir):
        logging.warning(f"Poster directory not found: {poster_dir}. Using placeholder images.")

    # Device selection
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    logging.info(f"Using device: {device}")

    # Data
    splits = _load_and_split_data(csv_file_path, TRAIN_SIZE, VAL_SIZE, SEED)
    if splits is None:
        return
    train_df, val_df, test_df = splits

    # The label vocabulary comes from the training split only; genres that
    # appear only in val/test have no output unit and are ignored by the datasets.
    all_genres = sorted({
        genre
        for genres in train_df['groups']
        if isinstance(genres, list)
        for genre in genres
    })
    num_genres = len(all_genres)
    logging.info(f"Found {num_genres} unique genres in training data.")
    if num_genres == 0:
        logging.error("No genres found after processing the 'genres'/'groups' column. Cannot proceed.")
        return

    # Tokenizer and image transforms
    tokenizer = BertTokenizer.from_pretrained(TEXT_MODEL)
    img_size = 224
    train_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    val_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Datasets and dataloaders
    logging.info("Creating datasets...")
    train_dataset = MovieGenreDataset(train_df, tokenizer, all_genres, poster_dir, MAX_LENGTH, train_transform)
    val_dataset = MovieGenreDataset(val_df, tokenizer, all_genres, poster_dir, MAX_LENGTH, val_transform)
    test_dataset = MovieGenreDataset(test_df, tokenizer, all_genres, poster_dir, MAX_LENGTH, val_transform)

    # Class weights come from the training split only.
    pos_weight = train_dataset.calculate_class_weights().to(device)

    logging.info("Creating dataloaders...")
    num_workers = 2
    # Pinned host memory only speeds up host-to-CUDA copies.
    pin_memory = device.type == 'cuda'
    # BatchNorm1d in the fusion layer raises in train mode on a single-sample
    # batch, so drop the last training batch only when it would have size 1.
    drop_last_train = len(train_dataset) % BATCH_SIZE == 1
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=False, drop_last=drop_last_train)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=False)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=False)

    # Model, loss, optimizer, scheduler
    logging.info(f"Creating model (Text: {TEXT_MODEL}, Image: {IMAGE_MODEL}, Dropout: {DROPOUT})...")
    model = SimpleMultiModalClassifier(
        num_genres=num_genres,
        text_model_name=TEXT_MODEL,
        img_model_name=IMAGE_MODEL,
        hidden_dim=512,
        dropout=DROPOUT
    )
    model.to(device)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f"Model has {trainable_params:,} trainable parameters.")

    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    # mode='max' because the monitored metric is an F1 score. The `verbose`
    # flag is deprecated in recent PyTorch; the LR is logged every epoch instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=SCHEDULER_FACTOR,
        patience=SCHEDULER_PATIENCE,
    )

    # Training loop
    logging.info(f"Starting training for up to {N_EPOCHS} epochs...")
    logging.info(f"Models will be saved to: {models_dir}")
    logging.info(f"Results will be saved to: {results_dir}")
    best_model_path = os.path.join(models_dir, "best_model.pt")
    best_val_metric = -1.0
    best_epoch = -1
    saved_epoch = None  # Epoch whose weights this run successfully wrote to best_model_path
    epochs_no_improve = 0
    epochs_run = 0
    history = {
        'train_loss': [], 'val_loss': [], 'val_f1_macro': [], 'val_f1_micro': [],
        'val_f1_weighted': [], 'val_f1_samples': [],
    }

    start_time = time.time()
    for epoch in range(N_EPOCHS):
        epoch_start = time.time()

        train_loss = train_epoch(model, train_dataloader, optimizer, criterion, device)
        val_metrics, _, _ = evaluate(model, val_dataloader, criterion, device, PREDICTION_THRESHOLD)
        epochs_run = epoch + 1

        history['train_loss'].append(train_loss)
        for name in ('loss', 'f1_macro', 'f1_micro', 'f1_weighted', 'f1_samples'):
            history[f'val_{name}'].append(val_metrics[name])

        epoch_time = time.time() - epoch_start
        logging.info(
            f"Epoch {epoch+1}/{N_EPOCHS} [{epoch_time:.2f}s] | "
            f"LR: {optimizer.param_groups[0]['lr']:.2e} | "
            f"Train Loss: {train_loss:.4f} | Val Loss: {val_metrics['loss']:.4f} | "
            f"Val F1 Macro: {val_metrics['f1_macro']:.4f} | Val F1 Wgt: {val_metrics['f1_weighted']:.4f} | "
            f"Val F1 Samp: {val_metrics['f1_samples']:.4f}"
        )

        current_metric = val_metrics[EARLY_STOPPING_METRIC]
        scheduler.step(current_metric)

        if current_metric > best_val_metric:
            best_val_metric = current_metric
            best_epoch = epoch
            epochs_no_improve = 0
            try:
                torch.save(model.state_dict(), best_model_path)
                saved_epoch = epoch
                logging.info(f"  >> Improvement found! New best {EARLY_STOPPING_METRIC}: {current_metric:.4f}. Model saved to {best_model_path}")
            except Exception as e:
                logging.error(f"  >> Error saving model to {best_model_path}: {e}")
        else:
            epochs_no_improve += 1
            logging.info(f"  (No improvement for {epochs_no_improve}/{EARLY_STOPPING_PATIENCE} epochs)")

        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            logging.info(f"Early stopping triggered after {epoch+1} epochs.")
            break

    total_time = time.time() - start_time
    logging.info(f"Training finished in {total_time/60:.2f} minutes. Best epoch: {best_epoch+1} ({EARLY_STOPPING_METRIC}: {best_val_metric:.4f})")

    _plot_training_history(history, best_epoch, best_val_metric, EARLY_STOPPING_METRIC, results_dir)

    # Final evaluation on the test set. Only a checkpoint written during this
    # run is loaded, never a best_model.pt left over from an earlier run.
    logging.info("Loading best model for final evaluation on test set...")
    if saved_epoch is not None:
        try:
            model.load_state_dict(torch.load(best_model_path, map_location=device))
            logging.info(f"Loaded best model state from epoch {saved_epoch+1}.")
        except Exception as e:
            logging.error(f"Error loading best model from {best_model_path}: {e}. Evaluating with the last state.")
    else:
        logging.warning("No best model was saved during this run. Evaluating with the model state from the final epoch.")

    logging.info(f"Evaluating on the test set using threshold: {PREDICTION_THRESHOLD}...")
    test_metrics, test_preds, test_labels = evaluate(model, test_dataloader, criterion, device, PREDICTION_THRESHOLD)

    logging.info("--- Test Set Performance ---")
    logging.info(f" Test Loss:       {test_metrics['loss']:.4f}")
    logging.info(f" Test F1 Macro:   {test_metrics['f1_macro']:.4f}")
    logging.info(f" Test F1 Micro:   {test_metrics['f1_micro']:.4f}")
    logging.info(f" Test F1 Weighted:{test_metrics['f1_weighted']:.4f}")
    logging.info(f" Test F1 Samples: {test_metrics['f1_samples']:.4f}")
    logging.info("---------")

    _save_classification_report(test_labels, test_preds, all_genres, results_dir, TARGET_F1)

    # Final results summary
    final_results = {
        'best_epoch': best_epoch + 1 if best_epoch != -1 else None,
        f'best_val_metric ({EARLY_STOPPING_METRIC})': best_val_metric if best_val_metric != -1.0 else None,
        'test_metrics': test_metrics,
        'training_time_minutes': total_time / 60,
        'config': {
            'seed': SEED,
            'max_length': MAX_LENGTH,
            'batch_size': BATCH_SIZE,
            'learning_rate': LEARNING_RATE,
            'weight_decay': WEIGHT_DECAY,
            'dropout': DROPOUT,
            'n_epochs_run': epochs_run,
            'n_epochs_max': N_EPOCHS,
            'image_model': IMAGE_MODEL,
            'text_model': TEXT_MODEL,
            'early_stopping_patience': EARLY_STOPPING_PATIENCE,
            'early_stopping_metric': EARLY_STOPPING_METRIC,
            'scheduler_factor': SCHEDULER_FACTOR,
            'scheduler_patience': SCHEDULER_PATIENCE,
            'prediction_threshold': PREDICTION_THRESHOLD,
            'train_split_size': TRAIN_SIZE,
            'val_split_size': VAL_SIZE,
        }
    }
    results_path = os.path.join(results_dir, "final_results.json")
    try:
        with open(results_path, "w") as f:
            json.dump(final_results, f, indent=4, default=str)
        logging.info(f"Final results summary saved to {results_path}")
    except Exception as e:
        logging.error(f"Could not save final results summary: {e}")

    logging.info("Script finished.")

if __name__ == "__main__":
    main()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 168MB/s]
Training: 100%|██████████| 387/387 [01:59<00:00,  3.23it/s]
Evaluating: 100%|██████████| 83/83 [00:08<00:00, 10.01it/s]
Training: 100%|██████████| 387/387 [01:57<00:00,  3.30it/s]
Evaluating: 100%|██████████| 83/83 [00:08<00:00, 10.13it/s]
Training: 100%|██████████| 387/387 [01:57<00:00,  3.30it/s]
Evaluating: 100%|██████████| 83/83 [00:08<00:00, 10.11it/s]
Training: 100%|██████████| 387/387 [01:57<00:00,  3.30it/s]
Evaluating: 100%|██████████| 83/83 [00:08<00:00, 10.15it/s]
Training: 100%|██████████| 387/387 [01:57<00:00,  3.30it/s]
Evaluating: 100%|██████████| 83/83 [00:08<00:00, 10.14it/s]
Training: 100%|██████████| 387/387 [01:57<00:00,  3.30it/s]
Evaluating: 100%|██████████| 83/83 [00:08<00:00, 10.12it/s]
Training: 100%|██████████| 387/387 [01:57<00:00,  3.30it/s]
Evaluating: 100%|██████████| 83/83 

In [None]:
import os
import pandas as pd
import numpy as np
import json
import logging
import pickle
from tqdm import tqdm
import re
import shutil
from PIL import Image
from sklearn.model_selection import train_test_split
import nltk
from nltk.sentiment import SentimentIntensityAnalyzer
from google.colab import drive

# Mount Google Drive so the dataset CSV and posters are reachable under /content/drive.
drive.mount('/content/drive')

# Configure logging (file + console). force=True replaces any handlers already
# attached to the root logger (e.g. by the training cell above or by the
# notebook environment). Without it, this basicConfig call is a silent no-op
# and /content/data_preprocessing.log is never written.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("/content/data_preprocessing.log"),
        logging.StreamHandler(),
    ],
    force=True,
)

# Download NLTK resources (the VADER lexicon used by SentimentIntensityAnalyzer).
nltk.download('vader_lexicon', quiet=True)

# Local working directories under /content (outside the mounted Drive folder).
base_dir = "/content"
data_dir = os.path.join(base_dir, "data")
processed_data_dir = os.path.join(base_dir, "processed_data")
poster_dir = os.path.join(base_dir, "posters")

# Create directories
os.makedirs(data_dir, exist_ok=True)
os.makedirs(processed_data_dir, exist_ok=True)
os.makedirs(os.path.join(processed_data_dir, "splits"), exist_ok=True)
os.makedirs(poster_dir, exist_ok=True)

# Path to your CSV file in Google Drive
MOVIE_CSV_PATH = "/content/drive/MyDrive/movie_dataset_complete 2 (1)/movie_dataset_complete/movies_data.csv"  # Update this path

# Path to your movie posters folder in Google Drive
POSTER_FOLDER_PATH = "/content/drive/MyDrive/movie_dataset_complete 2 (1)/movie_dataset_complete/posters"  # Update this path

def load_movie_data_csv(file_path):
    """
    Read the movie CSV at ``file_path`` into a DataFrame and return it.

    If any of the expected columns 'id', 'title' or 'overview' is absent, a
    warning is logged but loading still succeeds. A preview (head and column
    names) is printed. Any exception is logged and re-raised unchanged.
    """
    logging.info(f"Loading movie data from CSV: {file_path}")

    try:
        movies = pd.read_csv(file_path)
        logging.info(f"Loaded {len(movies)} movies from CSV")

        expected = ('id', 'title', 'overview')
        absent = list(filter(lambda name: name not in movies.columns, expected))
        if absent:
            logging.warning(f"CSV is missing required columns: {absent}")

        # Quick visual sanity check of what was loaded
        print("Sample data:")
        print(movies.head())
        print("\nColumns:", movies.columns.tolist())
    except Exception as err:
        logging.error(f"Error loading CSV file: {err}")
        raise

    return movies

def copy_posters_from_drive(drive_folder_path, dest_dir=None):
    """
    Copy movie posters from a Google Drive folder into the local poster directory.

    Only regular files with a .jpg/.jpeg/.png extension (case-insensitive) are
    considered. Files already present in the destination are skipped without
    being re-checked. Every newly copied file is opened with PIL and
    ``verify()``-ed, and files that fail verification are deleted again.

    Args:
        drive_folder_path: Path to the Google Drive folder containing poster images.
        dest_dir: Destination folder. Defaults to the module-level ``poster_dir``
            and is created if it does not exist.

    Returns:
        int: Number of posters newly copied and verified. Returns 0 if
        ``drive_folder_path`` does not exist.
    """
    if dest_dir is None:
        dest_dir = poster_dir

    logging.info(f"Copying poster images from: {drive_folder_path}")

    # Make sure the drive folder exists
    if not os.path.exists(drive_folder_path):
        logging.error(f"Drive folder not found: {drive_folder_path}")
        return 0

    os.makedirs(dest_dir, exist_ok=True)

    copied = 0
    skipped = 0
    errors = 0

    # List all image files in the drive folder
    image_extensions = ('.jpg', '.jpeg', '.png')
    poster_files = [
        f for f in os.listdir(drive_folder_path)
        if os.path.isfile(os.path.join(drive_folder_path, f))
        and f.lower().endswith(image_extensions)
    ]

    logging.info(f"Found {len(poster_files)} poster images in drive folder")

    for poster_file in tqdm(poster_files, desc="Copying posters"):
        src_path = os.path.join(drive_folder_path, poster_file)
        dst_path = os.path.join(dest_dir, poster_file)

        # Skip if the file already exists in the destination
        if os.path.exists(dst_path):
            skipped += 1
            continue

        try:
            shutil.copy2(src_path, dst_path)
        except OSError as e:  # shutil's errors are all OSError subclasses
            logging.error(f"Error copying poster {poster_file}: {e}")
            errors += 1
            continue

        # Verify the copied image. The context manager closes the file handle,
        # which the previous bare Image.open() left open.
        try:
            with Image.open(dst_path) as img:
                img.verify()
        except Exception as e:  # verify() can raise several exception types for corrupt files
            logging.error(f"Invalid image file: {poster_file} - {e}")
            errors += 1
            # Remove the invalid copy; a failure here is reported but not double-counted.
            try:
                os.remove(dst_path)
            except OSError as remove_error:
                logging.error(f"Could not remove invalid poster {dst_path}: {remove_error}")
        else:
            copied += 1

    logging.info(f"Poster copying complete. Copied: {copied}, Skipped: {skipped}, Errors: {errors}")
    return copied

# Matches 'name': '...' / "name": "..." entries inside a stringified list of genre dicts.
_GENRE_NAME_RE = re.compile(r"""['"]name['"]\s*:\s*['"]([^'"]+)['"]""")


def _parse_genre_value(value):
    """Parse one raw ``genres`` cell into a list; non-string values are returned unchanged.

    Parsing order for string cells:
    1. ``ast.literal_eval``: handles "['Action', 'Drama']" and
       "[{'id': 12, 'name': 'Adventure'}]", including names containing an
       apostrophe, which the former quote-swapping JSON approach broke on.
    2. ``json.loads``: handles JSON containing null/true/false.
    3. Regex extraction of 'name' entries, which returns [] when nothing matches.

    A list or tuple result from steps 1-2 is returned as a list.
    """
    if not isinstance(value, str):
        return value
    text = value.strip()
    for parse in (ast.literal_eval, json.loads):
        try:
            parsed = parse(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            continue
        if isinstance(parsed, (list, tuple)):
            return list(parsed)
        break  # parsed to a non-list literal (e.g. a single dict): fall back to regex
    return _GENRE_NAME_RE.findall(text)


def _genre_names(genres):
    """Return genre names for one parsed ``genres`` value.

    A list whose items are all dicts with a 'name' key yields those names.
    Any other list is returned as-is. Non-list values yield [].
    """
    if not isinstance(genres, list):
        return []
    if all(isinstance(g, dict) and 'name' in g for g in genres):
        return [g['name'] for g in genres]
    return genres


def extract_features(df):
    """
    Derive model features from the raw movie DataFrame and drop unusable rows.

    Works on a copy of ``df`` and adds or overwrites these columns:
      * ``genres``: string cells parsed into lists (only if the column exists).
      * ``groups``: list of genre names per movie.
      * ``title_length`` / ``overview_length``: character counts (0 when missing).
      * ``title_sentiment`` / ``overview_sentiment``: VADER compound scores
        (0 when missing).
      * ``vote_average``: copied from ``rating``, or 5.0, when absent.

    A row is kept only if its ``overview`` is present and longer than 10
    characters and at least one genre was found. ``df`` must contain ``title``
    and ``overview`` columns.

    Returns:
        pandas.DataFrame: An independent copy with the surviving rows, so later
        column assignments do not trigger SettingWithCopyWarning.
    """
    logging.info("Extracting features from dataset...")

    # Make a copy to avoid modifying the original
    processed_df = df.copy()

    if 'genres' in processed_df.columns:
        # Parse cell by cell so a missing first value or one malformed entry
        # cannot derail parsing of the whole column.
        processed_df['genres'] = processed_df['genres'].apply(_parse_genre_value)
        # 'groups' is the genre-name list column the model expects
        processed_df['groups'] = processed_df['genres'].apply(_genre_names)
    elif 'groups' not in processed_df.columns:
        logging.warning("No 'genres' column found; every movie gets an empty genre list")
        processed_df['groups'] = pd.Series(index=processed_df.index, dtype=object).apply(lambda _: [])

    # Calculate title length
    processed_df['title_length'] = processed_df['title'].apply(
        lambda x: len(str(x)) if pd.notna(x) else 0
    )

    # Calculate overview length
    processed_df['overview_length'] = processed_df['overview'].apply(
        lambda x: len(str(x)) if pd.notna(x) and isinstance(x, str) else 0
    )

    sia = SentimentIntensityAnalyzer()

    # Title sentiment (VADER compound score)
    processed_df['title_sentiment'] = processed_df['title'].apply(
        lambda x: sia.polarity_scores(str(x))['compound'] if pd.notna(x) else 0
    )

    # Overview sentiment (VADER compound score)
    processed_df['overview_sentiment'] = processed_df['overview'].apply(
        lambda x: sia.polarity_scores(str(x))['compound'] if pd.notna(x) and isinstance(x, str) else 0
    )

    # Add vote_average if not present (used as a feature by the model)
    if 'vote_average' not in processed_df.columns:
        if 'rating' in processed_df.columns:
            processed_df['vote_average'] = processed_df['rating']
        else:
            processed_df['vote_average'] = 5.0  # Default middle value
            logging.warning("Adding default vote_average of 5.0")

    # Filter out movies without overviews or genres
    valid_df = processed_df[
        processed_df['overview'].notna() &
        (processed_df['overview_length'] > 10) &
        processed_df['groups'].apply(lambda x: len(x) > 0)
    ].copy()

    logging.info(f"After filtering: {len(valid_df)} movies (removed {len(processed_df) - len(valid_df)})")

    # Log genre distribution, most frequent first
    genre_counts = Counter(genre for genres in valid_df['groups'] for genre in genres)

    logging.info("Genre distribution:")
    for genre, count in genre_counts.most_common():
        logging.info(f"  {genre}: {count} movies ({count/len(valid_df)*100:.2f}%)")

    return valid_df

def _poster_id(value):
    """Return ``value`` as an int movie/poster ID, or None if it is not an integral ID."""
    if isinstance(value, bool):
        return None
    if isinstance(value, (int, np.integer)):
        return int(value)
    if isinstance(value, (float, np.floating)):
        # NaN and inf are not integers, so they map to None
        return int(value) if float(value).is_integer() else None
    if isinstance(value, str) and value.strip().isdecimal():
        return int(value.strip())
    return None


def _normalize_title(text):
    """Lowercase ``text`` and strip punctuation, for loose title comparison."""
    return re.sub(r'[^\w\s]', '', text.lower())


def match_posters_to_movies(df, poster_directory):
    """
    Attach a ``poster_path`` to each movie by matching files in ``poster_directory``.

    Image files (.jpg/.jpeg/.png, case-insensitive) are indexed in two ways:
      * a purely numeric base name (e.g. ``123.jpg``) is treated as a movie ID;
      * any other base name is treated as a title, with ``_`` read as a space
        (e.g. ``The_Avengers.jpg`` -> ``The Avengers``).

    Each movie is matched, in order of priority, by:
      1. its ``id`` (ints, integral floats and digit strings are all accepted);
      2. exact ``title`` match;
      3. case-insensitive, punctuation-free title match.

    Args:
        df: Movies DataFrame, optionally with ``id`` and ``title`` columns.
        poster_directory: Directory containing the poster files.

    Returns:
        pandas.DataFrame: A copy of ``df`` whose ``poster_path`` column holds the
        matched file path, or None when no poster matched. An existing
        ``poster_path`` column is overwritten. ``df`` itself is not modified.
    """
    logging.info("Matching poster files to movies...")

    poster_files = [
        f for f in os.listdir(poster_directory)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    ]

    logging.info(f"Found {len(poster_files)} poster files in {poster_directory}")

    # Separate indexes for ID-named and title-named files
    id_to_poster = {}
    title_to_poster = {}
    for poster_file in poster_files:
        file_base = os.path.splitext(poster_file)[0]
        full_path = os.path.join(poster_directory, poster_file)
        # isdecimal (not isdigit) so int() can never fail on e.g. superscript digits
        if file_base.isdecimal():
            id_to_poster[int(file_base)] = full_path
        else:
            title_to_poster[file_base.replace('_', ' ')] = full_path

    # Precompute normalized titles once (previously O(movies x posters));
    # setdefault keeps the first file in listing order, as the old linear scan did.
    normalized_to_poster = {}
    for title_key, path in title_to_poster.items():
        normalized_to_poster.setdefault(_normalize_title(title_key), path)

    def find_poster(movie_id, title):
        key = _poster_id(movie_id)
        if key is not None and key in id_to_poster:
            return id_to_poster[key]
        if title is None or (not isinstance(title, str) and pd.isna(title)):
            return None  # missing title: previously crashed on .lower()
        title = str(title)
        if title in title_to_poster:
            return title_to_poster[title]
        return normalized_to_poster.get(_normalize_title(title))

    result = df.copy()
    n_rows = len(result)
    ids = result['id'].tolist() if 'id' in result.columns else [None] * n_rows
    titles = result['title'].tolist() if 'title' in result.columns else [None] * n_rows
    result['poster_path'] = [find_poster(movie_id, title) for movie_id, title in zip(ids, titles)]

    matched_posters = result['poster_path'].notna().sum()
    matched_pct = matched_posters / n_rows * 100 if n_rows else 0.0
    logging.info(f"Matched {matched_posters} posters to movies ({matched_pct:.2f}%)")

    return result

def _first_genre_strata(frame, holdout_fraction, log_fallback=False):
    """Return each movie's first genre as stratification labels, or None for a random split.

    Returns None when:
      * there is no ``groups`` column or no rows;
      * any movie has no genre;
      * there are more than ``len(frame) // 10`` distinct first genres;
      * ``train_test_split`` would reject the labels: a class with a single
        member, or fewer held-out (``ceil(holdout_fraction * n)``) or
        remaining samples than classes.
    """
    if 'groups' not in frame.columns or len(frame) == 0:
        return None

    labels = frame['groups'].apply(lambda genres: genres[0] if len(genres) > 0 else None)
    n_classes = labels.nunique()

    if n_classes > len(frame) // 10 or labels.isna().any():
        if log_fallback:
            logging.info("Too many unique genres for stratification, using random split instead")
        return None

    # Mirror scikit-learn's stratification checks so we fall back to a random
    # split instead of letting train_test_split raise ValueError.
    n_holdout = math.ceil(holdout_fraction * len(frame))
    n_remaining = len(frame) - n_holdout
    if labels.value_counts().min() < 2 or n_holdout < n_classes or n_remaining < n_classes:
        if log_fallback:
            logging.info("Genre classes too small for stratification, using random split instead")
        return None

    return labels


def create_data_splits(df, test_size=0.1, val_size=0.1, random_state=42):
    """
    Split the dataset into train, validation, and test sets.

    Splits are stratified by each movie's first genre when that is feasible
    (see ``_first_genre_strata``); otherwise a plain random split is used.

    Args:
        df: Processed movies DataFrame (ideally with a ``groups`` column).
        test_size: Fraction of ``df`` held out as the test set.
        val_size: Fraction of the *full* ``df`` used for validation; it is
            rescaled against the remaining train+val portion.
        random_state: Seed passed to both ``train_test_split`` calls.

    Returns:
        tuple: ``(train_df, val_df, test_df)``.
    """
    logging.info("Creating train, validation, and test splits...")

    # First split off the test set
    stratify_col = _first_genre_strata(df, test_size, log_fallback=True)
    train_val_df, test_df = train_test_split(
        df, test_size=test_size, random_state=random_state,
        stratify=stratify_col
    )

    # Rescale so val_size remains a fraction of the original dataset
    val_size_adjusted = val_size / (1 - test_size)

    # Only stratify the second split if the first one was stratified
    stratify_train_val = (
        _first_genre_strata(train_val_df, val_size_adjusted)
        if stratify_col is not None else None
    )

    train_df, val_df = train_test_split(
        train_val_df, test_size=val_size_adjusted, random_state=random_state,
        stratify=stratify_train_val
    )

    logging.info(f"Split sizes: Train: {len(train_df)}, Validation: {len(val_df)}, Test: {len(test_df)}")

    return train_df, val_df, test_df

def filter_movies_with_posters(df):
    """
    Keep only movies whose ``poster_path`` points to an existing file.

    A row is dropped when its ``poster_path`` is missing, is not a str or
    path-like object, or does not exist on disk.

    Args:
        df: DataFrame with a ``poster_path`` column.

    Returns:
        pandas.DataFrame: A new (copied) DataFrame with the surviving rows.
        ``df`` itself is not modified.
    """
    logging.info("Filtering dataset to only include movies with posters...")

    # Build the mask directly. Adding a temporary column to a filtered slice
    # (the previous approach) triggered pandas' SettingWithCopyWarning.
    has_poster = df['poster_path'].apply(
        lambda path: isinstance(path, (str, os.PathLike)) and os.path.exists(path)
    ).astype(bool)  # astype keeps the mask boolean even for an empty frame

    final_df = df[has_poster].copy()

    logging.info(f"After filtering for posters: {len(final_df)} movies (removed {len(df) - len(final_df)})")

    return final_df

def save_to_drive(processed_df, train_df, val_df, test_df,
                  drive_output_dir="/content/drive/MyDrive/movie_genre_classification",
                  local_output_dir=None):
    """
    Save the processed dataset and its splits to Google Drive and locally.

    Files written:
      * ``<drive_output_dir>/processed_movies.pkl`` and ``.csv``
      * ``<drive_output_dir>/splits/{train,val,test}.pkl``
      * ``<local_output_dir>/splits/{train,val,test}.pkl`` (local copy for immediate use)

    Args:
        processed_df: Full processed dataset.
        train_df, val_df, test_df: The split DataFrames.
        drive_output_dir: Output directory on the mounted Drive.
        local_output_dir: Local output directory. Defaults to the module-level
            ``processed_data_dir``.

    All output directories are created if missing.
    """
    logging.info("Saving processed data to Google Drive...")

    if local_output_dir is None:
        local_output_dir = processed_data_dir

    drive_splits_dir = os.path.join(drive_output_dir, "splits")
    local_splits_dir = os.path.join(local_output_dir, "splits")
    for directory in (drive_output_dir, drive_splits_dir, local_splits_dir):
        os.makedirs(directory, exist_ok=True)

    # Full processed dataset (pickle keeps list-valued columns intact; CSV is for inspection)
    processed_df.to_pickle(os.path.join(drive_output_dir, "processed_movies.pkl"))
    processed_df.to_csv(os.path.join(drive_output_dir, "processed_movies.csv"), index=False)

    splits = {"train": train_df, "val": val_df, "test": test_df}

    for split_name, split_df in splits.items():
        split_df.to_pickle(os.path.join(drive_splits_dir, f"{split_name}.pkl"))

    # Also save to local directory for immediate use
    for split_name, split_df in splits.items():
        split_df.to_pickle(os.path.join(local_splits_dir, f"{split_name}.pkl"))

    logging.info(f"Data saved to {drive_output_dir}")

def main():
    """
    Run the whole preprocessing pipeline and print a summary.

    Pipeline: load the CSV, copy posters from Drive, extract features, attach
    poster paths, drop movies without a poster on disk, split into
    train/validation/test, and save everything. Then print the split sizes,
    the resulting columns and the ten most frequent genres.
    """
    # Global NumPy seed; the splits themselves are seeded via random_state.
    np.random.seed(42)

    movies = load_movie_data_csv(MOVIE_CSV_PATH)
    copy_posters_from_drive(POSTER_FOLDER_PATH)

    features = extract_features(movies)
    features = match_posters_to_movies(features, poster_dir)
    features = filter_movies_with_posters(features)

    train_df, val_df, test_df = create_data_splits(features)
    save_to_drive(features, train_df, val_df, test_df)

    summary_lines = (
        "\n===== DATA PREPROCESSING SUMMARY =====",
        f"Total dataset size: {len(features)} movies",
        f"Training set: {len(train_df)} movies",
        f"Validation set: {len(val_df)} movies",
        f"Test set: {len(test_df)} movies",
    )
    for line in summary_lines:
        print(line)

    print("\nFeature columns:")
    for column_name in features.columns:
        print(f"- {column_name}")

    # Tally how many movies carry each genre (insertion order breaks ties)
    tally = {}
    for movie_genres in features['groups']:
        for genre_name in movie_genres:
            tally[genre_name] = tally.get(genre_name, 0) + 1
    ranked = sorted(tally.items(), key=lambda item: item[1], reverse=True)

    print("\nTop 10 genres:")
    for genre_name, count in ranked[:10]:
        print(f"- {genre_name}: {count} movies ({count/len(features)*100:.2f}%)")

    logging.info("Data preprocessing complete!")

# Entry point: run the preprocessing pipeline when this module/cell is executed
# as the main program.
if __name__ == "__main__":
    main()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Sample data:
        id                             title  \
0   950387                 A Minecraft Movie   
1  1229730                        Carjackers   
2  1197306                     A Working Man   
3  1125899                           Cleaner   
4   822119  Captain America: Brave New World   

                                         genres  \
0  ['Family', 'Comedy', 'Adventure', 'Fantasy']   
1                       ['Action', 'Adventure']   
2               ['Action', 'Crime', 'Thriller']   
3                        ['Action', 'Thriller']   
4     ['Action', 'Thriller', 'Science Fiction']   

                                            overview          poster_path  \
0  Four misfits find themselves struggling with o...   posters/950387.jpg   
1  By day, they're invisible—valets, hostesses, a...  posters/1229730.jpg   
2  Levon Cade left behind a dec

Copying posters: 100%|██████████| 130/130 [00:03<00:00, 36.31it/s] 



===== DATA PREPROCESSING SUMMARY =====
Total dataset size: 109 movies
Training set: 87 movies
Validation set: 11 movies
Test set: 11 movies

Feature columns:
- id
- title
- genres
- overview
- poster_path
- release_date
- vote_average
- groups
- title_length
- overview_length
- title_sentiment
- overview_sentiment

Top 10 genres:
- Comedy: 33 movies (30.28%)
- Thriller: 29 movies (26.61%)
- Action: 29 movies (26.61%)
- Drama: 27 movies (24.77%)
- Horror: 25 movies (22.94%)
- Romance: 18 movies (16.51%)
- Science Fiction: 14 movies (12.84%)
- Adventure: 14 movies (12.84%)
- Fantasy: 11 movies (10.09%)
- Animation: 10 movies (9.17%)


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_df['poster_exists'] = filtered_df['poster_path'].apply(os.path.exists)
