In [2]:
# Block 1: Environment Setup & Imports

# ---------------------------------------------------------------------------
# 1. Install Necessary Libraries
# ---------------------------------------------------------------------------
!pip install -q torch torchvision torchaudio transformers pandas numpy scikit-learn matplotlib seaborn nltk librosa soundfile opencv-python-headless gdown facenet-pytorch captum tqdm moviepy ipywidgets # Added moviepy and ipywidgets for tqdm in notebooks

print("--- Libraries potentially installed/updated ---")

# ---------------------------------------------------------------------------
# 2. Import Essential Libraries
# ---------------------------------------------------------------------------
import os
import re
import glob
import zipfile
import tarfile # Needed for .tar.gz
import pickle
import requests
import warnings
import random
import time
import tempfile
from collections import Counter
from copy import deepcopy # For saving best model state

import numpy as np
import pandas as pd

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, Subset
import torch.optim as optim
from transformers import get_linear_schedule_with_warmup # LR Scheduler

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
from sklearn.manifold import TSNE # For visualization

from transformers import RobertaTokenizer, RobertaModel # For Text
from transformers import HubertModel, Wav2Vec2FeatureExtractor # For Audio (Using Feature Extractor)

import librosa
import librosa.display
# import soundfile as sf

import cv2
from facenet_pytorch import MTCNN
import torchvision.models as vision_models
import torchvision.transforms as vision_transforms
from PIL import Image
from moviepy.editor import VideoFileClip # For audio extraction

from captum.attr import IntegratedGradients, Occlusion, LayerIntegratedGradients

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm # Use notebook tqdm

# Optional: Logging
# import wandb
# from torch.utils.tensorboard import SummaryWriter

print("--- Essential libraries imported ---")

# ---------------------------------------------------------------------------
# 3. Basic Configuration & Setup
# ---------------------------------------------------------------------------
warnings.filterwarnings('ignore')
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Deterministic ops can sometimes slow things down or not be available
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device('cpu')
    print("Using CPU")

# Define base data directory consistently
BASE_DATA_DIR = 'data/MELD'

print("--- Basic configuration set ---")
print(f"Device: {device}")
print(f"Seed: {SEED}")
# ---------------------------------------------------------------------------
# End of Block 1
# ---------------------------------------------------------------------------

--- Libraries potentially installed/updated ---
--- Essential libraries imported ---
Using GPU: Tesla T4
--- Basic configuration set ---
Device: cuda
Seed: 42


In [3]:
# Block 2: Download and Extract MELD Raw & Nested Archives

import tarfile
import requests
import os
import glob
from tqdm.notebook import tqdm # Use notebook tqdm

# Ensure BASE_DATA_DIR is defined
if 'BASE_DATA_DIR' not in locals():
     BASE_DATA_DIR = 'data/MELD'

# ---------------------------------------------------------------------------
# 1. Setup Directories
# ---------------------------------------------------------------------------
EXTRACTED_RAW_DIR = os.path.join(BASE_DATA_DIR, 'MELD.Raw') # Initial extraction target
TAR_GZ_PATH = os.path.join(BASE_DATA_DIR, 'MELD.Raw.tar.gz') # Path for the main archive

NESTED_ARCHIVES = {
    'train': os.path.join(EXTRACTED_RAW_DIR, 'train.tar.gz'),
    'dev': os.path.join(EXTRACTED_RAW_DIR, 'dev.tar.gz'),
    'test': os.path.join(EXTRACTED_RAW_DIR, 'test.tar.gz'),
}
NESTED_EXTRACT_DIRS = {
    'train': os.path.join(EXTRACTED_RAW_DIR, 'train_extracted'),
    'dev': os.path.join(EXTRACTED_RAW_DIR, 'dev_extracted'),
    'test': os.path.join(EXTRACTED_RAW_DIR, 'test_extracted'),
}

os.makedirs(BASE_DATA_DIR, exist_ok=True)
print(f"Base data directory created/ensured: {BASE_DATA_DIR}")

# ---------------------------------------------------------------------------
# 2. Download Metadata Files (CSVs)
# ---------------------------------------------------------------------------
MELD_REPO_BASE_URL = "https://raw.githubusercontent.com/declare-lab/MELD/master/data/MELD/"
METADATA_FILES = [ "train_sent_emo.csv", "dev_sent_emo.csv", "test_sent_emo.csv" ]

def download_small_file(url, filename):
    if os.path.exists(filename): return True
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        with open(filename, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192): f.write(chunk)
        return True
    except Exception as e: print(f"Error downloading {url}: {e}"); return False

print("\n--- Downloading Metadata CSVs (checking existence) ---")
metadata_download_success = all(download_small_file(f"{MELD_REPO_BASE_URL}{file}", os.path.join(BASE_DATA_DIR, file)) for file in METADATA_FILES)

if metadata_download_success: print("Metadata CSVs present.")
else: print("ERROR: Failed to download one or more metadata files.")

# ---------------------------------------------------------------------------
# 3. Download MELD.Raw.tar.gz (Official URL)
# ---------------------------------------------------------------------------
MELD_RAW_URL = "https://huggingface.co/datasets/declare-lab/MELD/resolve/main/MELD.Raw.tar.gz"

def download_large_file(url, filename, chunk_size=8192*4): # larger chunk
    if os.path.exists(filename):
        print(f"Archive already exists: {filename}. Skipping download.")
        return True
    try:
        print(f"Downloading {os.path.basename(filename)} from {url} ...")
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            total_size = int(r.headers.get('content-length', 0))
            progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True, desc=os.path.basename(filename))
            with open(filename, 'wb') as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    progress_bar.update(len(chunk))
                    f.write(chunk)
            progress_bar.close()
            if total_size != 0 and progress_bar.n != total_size:
                 print("ERROR: Download size mismatch!"); return False
        print(f"\nSuccessfully downloaded: {filename}")
        return True
    except Exception as e:
        print(f"\nError downloading {url}: {e}")
        if os.path.exists(filename):
            try: os.remove(filename)
            except OSError: pass
        return False

print(f"\n--- Checking/Downloading MELD Raw Data Archive ---")
download_success = download_large_file(MELD_RAW_URL, TAR_GZ_PATH)

# ---------------------------------------------------------------------------
# 4. Extract MELD.Raw.tar.gz (Initial Extraction)
# ---------------------------------------------------------------------------
initial_extraction_done = False
if os.path.exists(EXTRACTED_RAW_DIR) and len(os.listdir(EXTRACTED_RAW_DIR)) > 1:
     print(f"Initial extraction directory '{EXTRACTED_RAW_DIR}' seems populated. Skipping initial extraction.")
     initial_extraction_done = True
elif download_success and os.path.exists(TAR_GZ_PATH):
    try:
        print(f"\nPerforming initial extraction of '{TAR_GZ_PATH}' to '{BASE_DATA_DIR}'...")
        with tarfile.open(TAR_GZ_PATH, "r:gz") as tar:
             members = tar.getmembers()
             progress_bar = tqdm(members, total=len(members), desc="Initial Extract")
             for member in progress_bar: tar.extract(member, path=BASE_DATA_DIR)
        print(f"\nInitial extraction complete. Contents should be in '{EXTRACTED_RAW_DIR}'")
        initial_extraction_done = True
    except Exception as e: print(f"ERROR during initial extraction of {TAR_GZ_PATH}: {e}")
else: print(f"Archive {TAR_GZ_PATH} not found or download failed. Cannot perform initial extraction.")

# ---------------------------------------------------------------------------
# 5. Extract Nested Archives (train.tar.gz, dev.tar.gz, test.tar.gz)
# ---------------------------------------------------------------------------
all_nested_extracted = True
if initial_extraction_done:
    print(f"\n--- Checking/Extracting Nested Archives within {EXTRACTED_RAW_DIR} ---")
    for split, archive_path in NESTED_ARCHIVES.items():
        extract_to_dir = NESTED_EXTRACT_DIRS[split]
        if os.path.exists(extract_to_dir) and len(os.listdir(extract_to_dir)) > 0:
            print(f"Nested extraction dir '{extract_to_dir}' seems populated. Skipping.")
        elif os.path.exists(archive_path):
             try:
                 print(f"Extracting nested archive '{os.path.basename(archive_path)}' to '{extract_to_dir}'...")
                 os.makedirs(extract_to_dir, exist_ok=True)
                 with tarfile.open(archive_path, "r:gz") as tar:
                      members = tar.getmembers()
                      progress_bar = tqdm(members, total=len(members), desc=f"Extract {split}")
                      for member in progress_bar: tar.extract(member, path=extract_to_dir)
                 print(f"Finished extracting {os.path.basename(archive_path)}.")
             except Exception as e: print(f"ERROR extracting nested archive {archive_path}: {e}"); all_nested_extracted = False
        else: print(f"ERROR: Nested archive not found: {archive_path}. Cannot extract {split} data."); all_nested_extracted = False
else: print("\nSkipping nested extraction because initial extraction did not complete or was skipped."); all_nested_extracted = False

# --- Final Check ---
if not metadata_download_success: print("\nERROR: Metadata failed to download.")
if not all_nested_extracted: print("\nERROR: Nested archive extraction failed or was skipped for one or more splits.")
if metadata_download_success and all_nested_extracted:
    print("\n--- Block 2 Data Download & ALL Extractions Potentially Complete ---")
    print(f"Actual data files should now be in subdirectories within: {EXTRACTED_RAW_DIR}")
# ---------------------------------------------------------------------------
# End of Block 2
# ---------------------------------------------------------------------------

Base data directory created/ensured: data/MELD

--- Downloading Metadata CSVs (checking existence) ---
Metadata CSVs present.

--- Checking/Downloading MELD Raw Data Archive ---
Downloading MELD.Raw.tar.gz from https://huggingface.co/datasets/declare-lab/MELD/resolve/main/MELD.Raw.tar.gz ...


MELD.Raw.tar.gz:   0%|          | 0.00/10.9G [00:00<?, ?iB/s]


Successfully downloaded: data/MELD/MELD.Raw.tar.gz

Performing initial extraction of 'data/MELD/MELD.Raw.tar.gz' to 'data/MELD'...


Initial Extract:   0%|          | 0/8 [00:00<?, ?it/s]


Initial extraction complete. Contents should be in 'data/MELD/MELD.Raw'

--- Checking/Extracting Nested Archives within data/MELD/MELD.Raw ---
Extracting nested archive 'train.tar.gz' to 'data/MELD/MELD.Raw/train_extracted'...


Extract train:   0%|          | 0/9991 [00:00<?, ?it/s]

Finished extracting train.tar.gz.
Extracting nested archive 'dev.tar.gz' to 'data/MELD/MELD.Raw/dev_extracted'...


Extract dev:   0%|          | 0/1113 [00:00<?, ?it/s]

Finished extracting dev.tar.gz.
Extracting nested archive 'test.tar.gz' to 'data/MELD/MELD.Raw/test_extracted'...


Extract test:   0%|          | 0/4809 [00:00<?, ?it/s]

Finished extracting test.tar.gz.

--- Block 2 Data Download & ALL Extractions Potentially Complete ---
Actual data files should now be in subdirectories within: data/MELD/MELD.Raw


In [4]:
# Block 3: Load and Preprocess Data with NaN Handling

import pandas as pd
import numpy as np
import os
import re
from sklearn.preprocessing import LabelEncoder
import pickle
from tqdm.notebook import tqdm

# Ensure BASE_DATA_DIR and EXTRACTED_RAW_DIR are defined
if 'BASE_DATA_DIR' not in locals(): BASE_DATA_DIR = 'data/MELD'
if 'EXTRACTED_RAW_DIR' not in locals(): EXTRACTED_RAW_DIR = os.path.join(BASE_DATA_DIR, 'MELD.Raw')

print("--- Starting Block 3: Load and Preprocess Data ---")
print(f"Expecting raw video data within subdirectories of: {EXTRACTED_RAW_DIR}")

# ---------------------------------------------------------------------------
# 1. Load Metadata CSVs
# ---------------------------------------------------------------------------
try:
    df_train = pd.read_csv(os.path.join(BASE_DATA_DIR, 'train_sent_emo.csv'))
    df_val = pd.read_csv(os.path.join(BASE_DATA_DIR, 'dev_sent_emo.csv'))
    df_test = pd.read_csv(os.path.join(BASE_DATA_DIR, 'test_sent_emo.csv'))
    df_train['Split'] = 'train'
    df_val['Split'] = 'dev'
    df_test['Split'] = 'test'
    print("\nMetadata CSVs loaded successfully.")
except FileNotFoundError as e: print(f"ERROR: Metadata CSV files not found in {BASE_DATA_DIR}."); raise e

# ---------------------------------------------------------------------------
# 2. Initial Data Exploration
# ---------------------------------------------------------------------------
print("\n--- Sample Data (Train Set) ---")
print(df_train.head(2))
df_full = pd.concat([df_train, df_val, df_test], ignore_index=True)
print(f"\nTotal samples (combined for analysis): {len(df_full)}")

# ---------------------------------------------------------------------------
# 3. Text Preprocessing
# ---------------------------------------------------------------------------
def clean_text(text):
    text = str(text).lower(); text = re.sub(r"[^a-z0-9\s]", '', text); text = re.sub(r'\s+', ' ', text).strip(); return text
print("\n--- Applying Text Cleaning ---")
tqdm.pandas(desc="Cleaning Text"); df_train['clean_text'] = df_train['Utterance'].progress_apply(clean_text)
tqdm.pandas(desc="Cleaning Text"); df_val['clean_text'] = df_val['Utterance'].progress_apply(clean_text)
tqdm.pandas(desc="Cleaning Text"); df_test['clean_text'] = df_test['Utterance'].progress_apply(clean_text)
print("Cleaned text generated ('clean_text' column).")

# ---------------------------------------------------------------------------
# 4. Label Encoding (Emotion, Sentiment, Speaker with NaN Handling)
# ---------------------------------------------------------------------------
emotion_encoder = LabelEncoder(); sentiment_encoder = LabelEncoder(); speaker_encoder = LabelEncoder()

# --- Handle potential NaNs in Speaker column BEFORE fitting ---
print("\n--- Checking/Handling NaNs in 'Speaker' column ---")
nan_found = False
for df, name in zip([df_train, df_val, df_test], ['Train', 'Val', 'Test']):
    nan_count = df['Speaker'].isnull().sum()
    if nan_count > 0:
        print(f"Found {nan_count} NaN values in {name} Speaker column. Filling with 'UnknownSpeaker'.")
        df['Speaker'].fillna('UnknownSpeaker', inplace=True)
        nan_found = True
if not nan_found: print("No NaNs found in Speaker columns.")

# Fit encoders - ensure 'UnknownSpeaker' is included if NaNs were found
# Create the full list of unique speakers after potential NaN handling
all_speakers_for_fit = pd.concat([df_train['Speaker'], df_val['Speaker'], df_test['Speaker']]).unique()
all_emotions_for_fit = pd.concat([df_train['Emotion'], df_val['Emotion'], df_test['Emotion']]).unique()
all_sentiments_for_fit = pd.concat([df_train['Sentiment'], df_val['Sentiment'], df_test['Sentiment']]).unique()

print("\n--- Fitting Label Encoders ---")
emotion_encoder.fit(all_emotions_for_fit)
sentiment_encoder.fit(all_sentiments_for_fit)
speaker_encoder.fit(all_speakers_for_fit)
print("Label Encoders Fitted.")

# Apply fitted encoders
print("\n--- Applying Label Encoders to Splits ---")
for df in [df_train, df_val, df_test]:
    df['emotion_label'] = emotion_encoder.transform(df['Emotion'])
    df['sentiment_label'] = sentiment_encoder.transform(df['Sentiment'])
    df['speaker_label'] = speaker_encoder.transform(df['Speaker'])
print("Labels applied.")

# Store and Print mappings
encoder_mappings = {
    'emotion': dict(zip(emotion_encoder.classes_, emotion_encoder.transform(emotion_encoder.classes_))),
    'sentiment': dict(zip(sentiment_encoder.classes_, sentiment_encoder.transform(sentiment_encoder.classes_))),
    'speaker': dict(zip(speaker_encoder.classes_, speaker_encoder.transform(speaker_encoder.classes_))),
}
print("\n--- Label Mappings ---")
print("Emotion Mapping:", encoder_mappings['emotion'])
print(f"Speaker Mapping includes 'UnknownSpeaker': {'UnknownSpeaker' in encoder_mappings['speaker']}")
mapping_path = os.path.join(BASE_DATA_DIR, 'encoder_mappings.pkl')
with open(mapping_path, 'wb') as f: pickle.dump(encoder_mappings, f)
print(f"Encoder mappings saved to {mapping_path}")

# ---------------------------------------------------------------------------
# 5. Generate Video File Paths
# ---------------------------------------------------------------------------
print("\n--- Generating Video File Paths ---")
def generate_video_path(df):
    base_extracted_dir = EXTRACTED_RAW_DIR
    split_subfolders = { 'train': 'train_extracted/train_splits', 'dev': 'dev_extracted/dev_splits_complete', 'test': 'test_extracted/output_repeated_splits_test'}
    tqdm.pandas(desc=f"Generating {df['Split'].iloc[0]} video paths")
    path_template = os.path.join(base_extracted_dir, "{split_subfolder}", "dia{did}_utt{uid}.mp4")
    df['video_file'] = df.progress_apply( lambda row: path_template.format(split_subfolder=split_subfolders[row['Split']], did=row['Dialogue_ID'], uid=row['Utterance_ID']), axis=1)
    return df
df_train = generate_video_path(df_train); df_val = generate_video_path(df_val); df_test = generate_video_path(df_test)
print("Video file paths generated.")

# ---------------------------------------------------------------------------
# 6. Verify Video File Paths Exist
# ---------------------------------------------------------------------------
print("\n--- Verifying Sample Video File Paths Exist ---")
verified_count = 0; error_paths_video = []; check_limit = min(10, len(df_train))
for i in tqdm(range(check_limit), desc="Verifying Video Paths"):
    video_path = df_train.iloc[i]['video_file']; video_exists = os.path.exists(video_path)
    if video_exists: verified_count += 1
    else: error_paths_video.append(video_path)
    if len(error_paths_video) >= 5: print("Stopping verification early..."); break
if error_paths_video: print(f"\nWARN: Video paths not found:\n - " + "\n - ".join(error_paths_video))
print(f"\n{verified_count} out of {check_limit} sample train video paths verified.")
paths_verified = verified_count == check_limit
if paths_verified: print("--- Video path verification successful! ---")
else: print("!!! WARNING: Video path verification FAILED. Check structure/logic.")

# ---------------------------------------------------------------------------
# 7. Final Checks and Saving
# ---------------------------------------------------------------------------
print("\n--- Final DataFrame Shapes ---")
print(f"Train: {df_train.shape}, Validation: {df_val.shape}, Test: {df_test.shape}")
print("\n--- Columns in Train DataFrame ---")
print(df_train.columns.tolist())
save_preprocessed = True
if save_preprocessed and paths_verified:
    try:
        print("\n--- Saving Preprocessed DataFrames ---")
        df_train.to_csv(os.path.join(BASE_DATA_DIR, 'preprocessed_train.csv'), index=False)
        df_val.to_csv(os.path.join(BASE_DATA_DIR, 'preprocessed_val.csv'), index=False)
        df_test.to_csv(os.path.join(BASE_DATA_DIR, 'preprocessed_test.csv'), index=False)
        print("Saved preprocessed dataframes.")
    except Exception as e: print(f"Error saving preprocessed files: {e}")
elif save_preprocessed: print("\nSkipping saving preprocessed files due to path verification failure.")

print("\n--- Block 3 Load and Preprocess Data Complete ---")
# ---------------------------------------------------------------------------
# End of Block 3
# ---------------------------------------------------------------------------

--- Starting Block 3: Load and Preprocess Data ---
Expecting raw video data within subdirectories of: data/MELD/MELD.Raw

Metadata CSVs loaded successfully.

--- Sample Data (Train Set) ---
   Sr No.                                          Utterance          Speaker  \
0       1  also I was the point person on my company’s tr...         Chandler   
1       2                   You must’ve had your hands full.  The Interviewer   

   Emotion Sentiment  Dialogue_ID  Utterance_ID  Season  Episode  \
0  neutral   neutral            0             0       8       21   
1  neutral   neutral            0             1       8       21   

      StartTime       EndTime  Split  
0  00:16:16,059  00:16:21,731  train  
1  00:16:21,940  00:16:23,442  train  

Total samples (combined for analysis): 13708

--- Applying Text Cleaning ---


Cleaning Text:   0%|          | 0/9989 [00:00<?, ?it/s]

Cleaning Text:   0%|          | 0/1109 [00:00<?, ?it/s]

Cleaning Text:   0%|          | 0/2610 [00:00<?, ?it/s]

Cleaned text generated ('clean_text' column).

--- Checking/Handling NaNs in 'Speaker' column ---
No NaNs found in Speaker columns.

--- Fitting Label Encoders ---
Label Encoders Fitted.

--- Applying Label Encoders to Splits ---
Labels applied.

--- Label Mappings ---
Emotion Mapping: {'anger': 0, 'disgust': 1, 'fear': 2, 'joy': 3, 'neutral': 4, 'sadness': 5, 'surprise': 6}
Speaker Mapping includes 'UnknownSpeaker': False
Encoder mappings saved to data/MELD/encoder_mappings.pkl

--- Generating Video File Paths ---


Generating train video paths:   0%|          | 0/9989 [00:00<?, ?it/s]

Generating dev video paths:   0%|          | 0/1109 [00:00<?, ?it/s]

Generating test video paths:   0%|          | 0/2610 [00:00<?, ?it/s]

Video file paths generated.

--- Verifying Sample Video File Paths Exist ---


Verifying Video Paths:   0%|          | 0/10 [00:00<?, ?it/s]


10 out of 10 sample train video paths verified.
--- Video path verification successful! ---

--- Final DataFrame Shapes ---
Train: (9989, 17), Validation: (1109, 17), Test: (2610, 17)

--- Columns in Train DataFrame ---
['Sr No.', 'Utterance', 'Speaker', 'Emotion', 'Sentiment', 'Dialogue_ID', 'Utterance_ID', 'Season', 'Episode', 'StartTime', 'EndTime', 'Split', 'clean_text', 'emotion_label', 'sentiment_label', 'speaker_label', 'video_file']

--- Saving Preprocessed DataFrames ---
Saved preprocessed dataframes.

--- Block 3 Load and Preprocess Data Complete ---


In [11]:
# Block 4: Dataset and DataLoader Classes
# ---------------------------------------------------------------------------
# 1. Install/Check necessary library for audio extraction
# ---------------------------------------------------------------------------
# !pip install -q moviepy # Should be installed from Block 1
print("--- Checking moviepy (ensure installed) ---")

# ---------------------------------------------------------------------------
# 2. Import Libraries
# ---------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset, DataLoader, Subset
import pandas as pd
import numpy as np
import os
import time
import tempfile
import pickle
from transformers import RobertaTokenizer, HubertModel, Wav2Vec2FeatureExtractor
import librosa
import cv2
from facenet_pytorch import MTCNN
import torchvision.models as vision_models
import torchvision.transforms as vision_transforms
from PIL import Image
from moviepy.video.io.VideoFileClip import VideoFileClip
from torch import nn
import torch.nn.functional as F
from tqdm.notebook import tqdm

# ---------------------------------------------------------------------------
# 3. Setup (Device, Paths)
# ---------------------------------------------------------------------------
if 'device' not in locals(): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if 'BASE_DATA_DIR' not in locals(): BASE_DATA_DIR = 'data/MELD'
print(f"Using device: {device}")

# ---------------------------------------------------------------------------
# 4. Define Multimodal Dataset Class
# ---------------------------------------------------------------------------
class MultimodalEmotionDataset(Dataset):
    _is_initialized = False 
    _audio_processor = None; _audio_model = None; _face_detector = None
    _video_feature_extractor = None; _video_transform = None

    def __init__(self, dataframe, tokenizer, max_text_length=128, device='cuda'):
        self.dataframe = dataframe.copy(); self.tokenizer = tokenizer
        self.max_text_length = max_text_length; self.device = torch.device(device)
        if not MultimodalEmotionDataset._is_initialized:
            print(f"--- Initializing Dataset Models (First Instance) on {self.device} ---")
            print("Initializing HuBERT+Extractor..."); MultimodalEmotionDataset._audio_processor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-base-ls960")
            MultimodalEmotionDataset._audio_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").to(self.device); MultimodalEmotionDataset._audio_model.eval()
            print("Initializing MTCNN+ResNet50..."); MultimodalEmotionDataset._face_detector = MTCNN(keep_all=False, device=self.device, select_largest=True, post_process=False)
            MultimodalEmotionDataset._video_feature_extractor = vision_models.resnet50(weights=vision_models.ResNet50_Weights.IMAGENET1K_V1)
            MultimodalEmotionDataset._video_feature_extractor.fc = nn.Identity(); MultimodalEmotionDataset._video_feature_extractor = MultimodalEmotionDataset._video_feature_extractor.to(self.device); MultimodalEmotionDataset._video_feature_extractor.eval()
            MultimodalEmotionDataset._video_transform = vision_transforms.Compose([ vision_transforms.Resize(256), vision_transforms.CenterCrop(224), vision_transforms.ToTensor(), vision_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ])
            MultimodalEmotionDataset._is_initialized = True; print("Dataset Models Initialized and shared.")
        self.audio_processor = MultimodalEmotionDataset._audio_processor; self.audio_model = MultimodalEmotionDataset._audio_model
        self.face_detector = MultimodalEmotionDataset._face_detector; self.video_feature_extractor = MultimodalEmotionDataset._video_feature_extractor
        self.video_transform = MultimodalEmotionDataset._video_transform; self.audio_feature_dim = self.audio_model.config.hidden_size; self.video_feature_dim = 2048

    def __len__(self): return len(self.dataframe)

    def __getitem__(self, idx):
        if torch.is_tensor(idx): idx = idx.tolist()
        try:
            row = self.dataframe.iloc[idx]; text = row.get('clean_text', ""); video_path = row.get('video_file')
            emotion_label = torch.tensor(row.get('emotion_label', 0), dtype=torch.long); sentiment_label = torch.tensor(row.get('sentiment_label', 0), dtype=torch.long)
            speaker_label = torch.tensor(row.get('speaker_label', 0), dtype=torch.long)
            text_features = self._process_text(text); audio_features = torch.zeros(self.audio_feature_dim, dtype=torch.float); video_features = torch.zeros(self.video_feature_dim, dtype=torch.float)
            if isinstance(video_path, str) and os.path.exists(video_path):
                 video_features = self._process_video(video_path)
                 if torch.any(video_features != 0): audio_features = self._process_audio(video_path)
            return { 'text_input_ids': text_features['input_ids'], 'text_attention_mask': text_features['attention_mask'], 'audio_features': audio_features, 'video_features': video_features,
                     'emotion_label': emotion_label, 'sentiment_label': sentiment_label, 'speaker_label': speaker_label }
        except Exception as e:
             dummy_text = self._process_text(""); return { 'text_input_ids': dummy_text['input_ids'], 'text_attention_mask': dummy_text['attention_mask'],'audio_features': torch.zeros(self.audio_feature_dim, dtype=torch.float),
                     'video_features': torch.zeros(self.video_feature_dim, dtype=torch.float), 'emotion_label': torch.tensor(0, dtype=torch.long), 'sentiment_label': torch.tensor(0, dtype=torch.long),'speaker_label': torch.tensor(0, dtype=torch.long) }

    def _process_text(self, text):
        encoding = self.tokenizer(text, padding='max_length', truncation=True, max_length=self.max_text_length, return_tensors='pt', add_special_tokens=True)
        return { 'input_ids': encoding['input_ids'].squeeze(0), 'attention_mask': encoding['attention_mask'].squeeze(0) }

    def _process_audio(self, video_path):
        temp_audio_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmpfile: temp_audio_path = tmpfile.name
            with VideoFileClip(video_path) as video_clip:
                if video_clip.audio is None: return torch.zeros(self.audio_feature_dim, dtype=torch.float)
                video_clip.audio.write_audiofile(temp_audio_path, codec='pcm_s16le')
            if librosa.get_duration(path=temp_audio_path) < 0.1: return torch.zeros(self.audio_feature_dim, dtype=torch.float)
            y, sr = librosa.load(temp_audio_path, sr=16000)
            if len(y) == 0: return torch.zeros(self.audio_feature_dim, dtype=torch.float)
            inputs = self.audio_processor(y, sampling_rate=16000, return_tensors="pt", padding=True)
            input_values = inputs.input_values.to(self.device)
            with torch.no_grad(): outputs = self.audio_model(input_values); features = outputs.last_hidden_state
            pooled_features = torch.mean(features, dim=1).squeeze(0)
            return pooled_features.cpu()
        except Exception as e: return torch.zeros(self.audio_feature_dim, dtype=torch.float)
        finally:
            if temp_audio_path and os.path.exists(temp_audio_path):
                try: os.remove(temp_audio_path)
                except Exception: pass

    def _process_video(self, video_path):
        try:
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened(): return torch.zeros(self.video_feature_dim, dtype=torch.float)
            frame_features = []; frame_count = 0; max_frames = 16; sample_rate = 5
            while cap.isOpened() and len(frame_features) < max_frames:
                ret, frame = cap.read();
                if not ret: break
                if frame_count % sample_rate == 0:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB); frame_pil = Image.fromarray(frame_rgb)
                    boxes = self.face_detector.detect(frame_pil, landmarks=False)
                    if boxes is not None and len(boxes) > 0 and boxes[0] is not None:
                        box = boxes[0]; img_w, img_h = frame_pil.size
                        x1, y1, x2, y2 = [int(max(0, min(coord, dim))) for coord, dim in zip(box, [img_w, img_h, img_w, img_h])]
                        if x1 < x2 and y1 < y2:
                             face_pil = frame_pil.crop((x1, y1, x2, y2))
                             if face_pil.size[0] > 0 and face_pil.size[1] > 0:
                                 face_tensor = self.video_transform(face_pil).unsqueeze(0).to(self.device)
                                 with torch.no_grad(): features = self.video_feature_extractor(face_tensor)
                                 frame_features.append(features.squeeze(0).cpu())
                frame_count += 1
            cap.release()
            if frame_features: return torch.mean(torch.stack(frame_features), dim=0)
            else: return torch.zeros(self.video_feature_dim, dtype=torch.float)
        except Exception as e: return torch.zeros(self.video_feature_dim, dtype=torch.float)


# ---------------------------------------------------------------------------
# 5. Initialize Tokenizer, Load DataFrames, Create Datasets & DataLoaders
# ---------------------------------------------------------------------------
print("\n--- Initializing Tokenizer ---")
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
print("RoBERTa Tokenizer initialized.")

print("\n--- Loading Preprocessed DataFrames ---")
try:
    df_train = pd.read_csv(os.path.join(BASE_DATA_DIR, 'preprocessed_train.csv'))
    df_val = pd.read_csv(os.path.join(BASE_DATA_DIR, 'preprocessed_val.csv'))
    df_test = pd.read_csv(os.path.join(BASE_DATA_DIR, 'preprocessed_test.csv'))
    print("Preprocessed dataframes loaded.")
except FileNotFoundError: print("ERROR: Preprocessed CSV files not found. Run Block 3."); raise

print("\n--- Creating Datasets ---")
try:
    train_dataset = MultimodalEmotionDataset(df_train, tokenizer, device=device)
    val_dataset = MultimodalEmotionDataset(df_val, tokenizer, device=device)
    test_dataset = MultimodalEmotionDataset(df_test, tokenizer, device=device)
    print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
except Exception as e: print(f"Error creating Datasets: {e}"); raise

# --- Create DataLoaders ---
BATCH_SIZE = 16   #  16 or 32
NUM_WORKERS = 2   # or 0 for stability
PIN_MEMORY = False

print(f"\n--- Creating DataLoaders (Batch Size: {BATCH_SIZE}, NUM_WORKERS: {NUM_WORKERS}) ---")
dl_args = {'batch_size': BATCH_SIZE, 'num_workers': NUM_WORKERS, 'pin_memory': PIN_MEMORY}

train_loader = DataLoader(train_dataset, shuffle=True, **dl_args)
val_loader = DataLoader(val_dataset, shuffle=False, **dl_args)
test_loader = DataLoader(test_dataset, shuffle=False, **dl_args)
print(f"DataLoaders created. Train batches: {len(train_loader)}")

# --- Test one batch ---
print("\n--- Testing one batch (to check for errors) ---")
try:
    batch_start_time = time.time(); batch = next(iter(train_loader)); batch_end_time = time.time()
    print(f"Time to load one batch: {batch_end_time - batch_start_time:.2f} seconds")
    for key, tensor in batch.items(): print(f"  {key}: {tensor.shape}")
except Exception as e: print(f"Error loading first batch: {e}");

print("\n--- Block 4 Dataset and DataLoaders Complete ---")
# ---------------------------------------------------------------------------
# End of Block 4
# ---------------------------------------------------------------------------

--- Checking moviepy (ensure installed) ---
Using device: cuda

--- Initializing Tokenizer ---
RoBERTa Tokenizer initialized.

--- Loading Preprocessed DataFrames ---
Preprocessed dataframes loaded.

--- Creating Datasets ---
--- Initializing Dataset Models (First Instance) on cuda ---
Initializing HuBERT+Extractor...
Initializing MTCNN+ResNet50...
Dataset Models Initialized and shared.
Train: 9989, Val: 1109, Test: 2610

--- Creating DataLoaders (Batch Size: 16, NUM_WORKERS: 2) ---
DataLoaders created. Train batches: 625

--- Testing one batch (to check for errors) ---
Time to load one batch: 1.37 seconds
  text_input_ids: torch.Size([16, 128])
  text_attention_mask: torch.Size([16, 128])
  audio_features: torch.Size([16, 768])
  video_features: torch.Size([16, 2048])
  emotion_label: torch.Size([16])
  sentiment_label: torch.Size([16])
  speaker_label: torch.Size([16])

--- Block 4 Dataset and DataLoaders Complete (Fixed moviepy Import) ---


In [12]:
# Block 5: Model Architecture (Aggressive Dropout)

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import RobertaModel
import pickle
import os

# Check device is defined
if 'device' not in locals(): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if 'BASE_DATA_DIR' not in locals(): BASE_DATA_DIR = 'data/MELD'

# Load encoder mappings to get config parameters
try:
    with open(os.path.join(BASE_DATA_DIR, 'encoder_mappings.pkl'), 'rb') as f: encoder_mappings = pickle.load(f)
    num_emotions = len(encoder_mappings['emotion']); num_sentiments = len(encoder_mappings['sentiment']); num_speakers = len(encoder_mappings['speaker'])
    print(f"Loaded counts: Emotions={num_emotions}, Sentiments={num_sentiments}, Speakers={num_speakers}")
except FileNotFoundError: print("ERROR: encoder_mappings.pkl not found. Run Block 3."); raise

print("\n--- Defining Model Components (Aggressive Dropout) ---")

# ---------------------------------------------------------------------------
# 1. Define Modality Encoders
# ---------------------------------------------------------------------------
class TextEncoder(nn.Module):
    def __init__(self, model_name='roberta-base', freeze_bert=False, output_dim=768):
        super().__init__(); self.output_dim = output_dim
        self.bert = RobertaModel.from_pretrained(model_name)
        if freeze_bert: print("Freezing RoBERTa"); [p.requires_grad_(False) for p in self.bert.parameters()]
        else: print("Fine-tuning RoBERTa")
    def forward(self, input_ids, attention_mask): return self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]

class AudioEncoder(nn.Module): # HuBERT base 768
    def __init__(self, input_dim=768, projection_dim=128):
        super().__init__(); self.output_dim = projection_dim
        self.projection = nn.Linear(input_dim, projection_dim)
    def forward(self, x): return self.projection(x)

class VideoEncoder(nn.Module): # ResNet50 2048
    def __init__(self, input_dim=2048, projection_dim=128, dropout=0.1): # Dropout refers to encoder output only
        super().__init__(); self.output_dim = projection_dim
        self.projection = nn.Linear(input_dim, projection_dim); self.dropout = nn.Dropout(dropout)
    def forward(self, x): return self.dropout(self.projection(x))

# ---------------------------------------------------------------------------
# 2. Define the Main Multimodal Classifier
# ---------------------------------------------------------------------------
class EnhancedMultimodalClassifier(nn.Module):
    def __init__( self, num_emotions, num_sentiments, num_speakers,
                 text_input_dim=768, audio_input_dim=768, video_input_dim=2048,
                 projection_dim=128, speaker_embed_dim=64, fusion_hidden_dim=256,
                 num_heads=8, num_layers=4, transformer_ff_dim=512,
                 dropout_rate=0.4,
                 fine_tune_text=True, device='cuda' ):
        super().__init__()
        self.device = torch.device(device); self.num_emotions = num_emotions; self.num_sentiments = num_sentiments
        # Instantiate encoders
        self.text_encoder = TextEncoder(freeze_bert=(not fine_tune_text), output_dim=text_input_dim)
        self.audio_encoder = AudioEncoder(input_dim=audio_input_dim, projection_dim=projection_dim)
        self.video_encoder = VideoEncoder(input_dim=video_input_dim, projection_dim=projection_dim, dropout=0.2)
        self.text_proj = nn.Linear(text_input_dim, projection_dim)
        self.speaker_embedding = nn.Embedding(num_speakers, speaker_embed_dim)
        self.modality_embedding = nn.Embedding(3, fusion_hidden_dim)
        combined_input_dim = projection_dim + speaker_embed_dim
        # Apply main dropout rate to fusion input and final dropout
        self.fusion_input_proj = nn.Sequential( nn.Linear(combined_input_dim, fusion_hidden_dim), nn.ReLU(), nn.Dropout(dropout_rate) )
        encoder_layer = nn.TransformerEncoderLayer( d_model=fusion_hidden_dim, nhead=num_heads, dim_feedforward=transformer_ff_dim, dropout=dropout_rate, activation='relu', batch_first=True )
        transformer_norm = nn.LayerNorm(fusion_hidden_dim)
        self.transformer = nn.TransformerEncoder( encoder_layer, num_layers=num_layers, norm=transformer_norm )
        self.dropout = nn.Dropout(dropout_rate) # Main dropout after fusion
        self.final_layer_norm = nn.LayerNorm(fusion_hidden_dim)
        self.emotion_classifier = nn.Linear(fusion_hidden_dim, num_emotions)
        self.sentiment_classifier = nn.Linear(fusion_hidden_dim, num_sentiments)
        unimodal_dropout = nn.Dropout(0.2)
        self.unimodal_emotion_classifier_t = nn.Linear(projection_dim, num_emotions)
        self.unimodal_emotion_classifier_a = nn.Linear(projection_dim, num_emotions)
        self.unimodal_emotion_classifier_v = nn.Linear(projection_dim, num_emotions)
        print(f"\n--- EnhancedMultimodalClassifier Initialized (Dropout Rate: {dropout_rate}) ---")

    def forward( self, text_input_ids, text_attention_mask, audio_features, video_features, speaker_ids ):
        t = self.text_encoder(text_input_ids, text_attention_mask); t_proj = self.text_proj(t)
        a_proj = self.audio_encoder(audio_features); v_proj = self.video_encoder(video_features)
        spk_emb = self.speaker_embedding(speaker_ids)
        t_comb = torch.cat([t_proj, spk_emb], dim=-1); a_comb = torch.cat([a_proj, spk_emb], dim=-1); v_comb = torch.cat([v_proj, spk_emb], dim=-1)
        t_fused_in = self.fusion_input_proj(t_comb); a_fused_in = self.fusion_input_proj(a_comb); v_fused_in = self.fusion_input_proj(v_comb)
        mod_indices = torch.arange(3, device=self.device).unsqueeze(0).expand(t_fused_in.size(0), -1); mod_emb = self.modality_embedding(mod_indices)
        tokens = torch.stack([t_fused_in, a_fused_in, v_fused_in], dim=1); tokens = tokens + mod_emb
        transformer_output = self.transformer(tokens)
        fused = transformer_output.mean(dim=1); fused = self.final_layer_norm(fused); fused_dropped = self.dropout(fused) # Apply main dropout
        main_emotion_logits = self.emotion_classifier(fused_dropped); main_sentiment_logits = self.sentiment_classifier(fused_dropped)
        unimodal_emotion_logits_t = self.unimodal_emotion_classifier_t(t_proj) # No dropout or use unimodal_dropout(t_proj)
        unimodal_emotion_logits_a = self.unimodal_emotion_classifier_a(a_proj) # No dropout or use unimodal_dropout(a_proj)
        unimodal_emotion_logits_v = self.unimodal_emotion_classifier_v(v_proj) # No dropout or use unimodal_dropout(v_proj)
        return ( main_emotion_logits, main_sentiment_logits, unimodal_emotion_logits_t, unimodal_emotion_logits_a, unimodal_emotion_logits_v )

# ---------------------------------------------------------------------------
# 3. Instantiate the Model
# ---------------------------------------------------------------------------
print("\n--- Instantiating the Multimodal Model (Aggressive Dropout) ---")
config = { "num_emotions": num_emotions, "num_sentiments": num_sentiments, "num_speakers": num_speakers,
           "text_input_dim": 768, "audio_input_dim": 768, "video_input_dim": 2048, "projection_dim": 128,
           "speaker_embed_dim": 64, "fusion_hidden_dim": 256, "num_heads": 8, "num_layers": 4,
           "transformer_ff_dim": 512,
           "dropout_rate": 0.4,
           "fine_tune_text": True, "device": device }
model = EnhancedMultimodalClassifier(**config).to(device)
print("Model Instantiated and moved to device.")

# ---------------------------------------------------------------------------
# 4. Print Model Summary
# ---------------------------------------------------------------------------
print("\n--- Model Architecture Summary ---")
total_params = sum(p.numel() for p in model.parameters()); trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total Parameters: {total_params:,}")
print(f"Trainable Parameters: {trainable_params:,}")

print("\n--- Block 5 Model Architecture Complete (Aggressive Dropout) ---")
# ---------------------------------------------------------------------------
# End of Block 5
# ---------------------------------------------------------------------------

Loaded counts: Emotions=7, Sentiments=3, Speakers=304

--- Defining Model Components (Aggressive Dropout) ---

--- Instantiating the Multimodal Model (Aggressive Dropout) ---


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fine-tuning RoBERTa

--- EnhancedMultimodalClassifier Initialized (Dropout Rate: 0.4) ---
Model Instantiated and moved to device.

--- Model Architecture Summary ---
Total Parameters: 127,289,119
Trainable Parameters: 127,289,119

--- Block 5 Model Architecture Complete (Aggressive Dropout) ---


In [13]:
# Block 6: Training and Evaluation Functions with Class Weighting

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight # Import for class weights
import pickle
import os
import pandas as pd # Needed for loading df_train for weights

# --- Setup ---
if 'device' not in locals(): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if 'BASE_DATA_DIR' not in locals(): BASE_DATA_DIR = 'data/MELD'
try:
    with open(os.path.join(BASE_DATA_DIR, 'encoder_mappings.pkl'), 'rb') as f: encoder_mappings = pickle.load(f)
except FileNotFoundError: print("WARN: encoder_mappings.pkl not found for Block 6."); encoder_mappings = None
print(f"--- Defining Training/Evaluation Functions (Using Device: {device}) ---")

# ---------------------------------------------------------------------------
# 1. Class Weight & Loss Function Setup
# ---------------------------------------------------------------------------
# --- Calculate Class Weights for Emotion Loss ---
print("\n--- Calculating Class Weights for Emotion ---")
try:
    # Load train dataframe if needed
    df_train_path = os.path.join(BASE_DATA_DIR, 'preprocessed_train.csv')
    if 'df_train' not in locals() or not isinstance(df_train, pd.DataFrame):
        print(f"Loading train dataframe from {df_train_path} for class weight calculation...")
        df_train = pd.read_csv(df_train_path)

    emotion_labels_train = df_train['emotion_label'].values
    emotion_classes = np.unique(emotion_labels_train)
    # Compute weights inversely proportional to class frequency
    class_weights_emotion = compute_class_weight('balanced', classes=emotion_classes, y=emotion_labels_train)
    emotion_class_weights_tensor = torch.tensor(class_weights_emotion, dtype=torch.float).to(device)
    print(f"Using Emotion Class Weights: {np.round(class_weights_emotion, 2)}")
    criterion_nll_emotion = nn.CrossEntropyLoss(weight=emotion_class_weights_tensor)
except Exception as e:
    print(f"WARN: Could not calculate/load data for class weights: {e}. Using standard CrossEntropyLoss for emotion.")
    criterion_nll_emotion = nn.CrossEntropyLoss()

# Standard NLL loss for sentiment (MELD sentiment is more balanced)
criterion_nll_sentiment = nn.CrossEntropyLoss()
# KL Div loss for distillation
criterion_kl = nn.KLDivLoss(reduction='batchmean')
log_softmax = nn.LogSoftmax(dim=-1); softmax = nn.Softmax(dim=-1)
# Loss hyperparameters
ALPHA_NLL_UNI = 0.2; BETA_KL_DISTILL = 0.5; TEMPERATURE = 8.0
print(f"Loss Params: Alpha={ALPHA_NLL_UNI}, Beta={BETA_KL_DISTILL}, Temp={TEMPERATURE}. Using Class Weights for Emotion.")

# ---------------------------------------------------------------------------
# 2. Training Function (Single Epoch - Uses Weighted Loss)
# ---------------------------------------------------------------------------
def train_one_epoch(model, train_loader, optimizer, device, grad_clip_value=1.0, scheduler=None):
    model.train(); total_loss = 0.0; total_emotion_correct = 0; total_sentiment_correct = 0; total_samples = 0
    progress_bar = tqdm(train_loader, desc="Training", leave=False)
    for batch in progress_bar:
        text_input_ids = batch['text_input_ids'].to(device); text_attention_mask = batch['text_attention_mask'].to(device)
        audio_features = batch['audio_features'].to(device); video_features = batch['video_features'].to(device)
        emotion_labels = batch['emotion_label'].to(device); sentiment_labels = batch['sentiment_label'].to(device)
        speaker_labels = batch['speaker_label'].to(device)
        optimizer.zero_grad()
        main_emo_logits, main_sent_logits, uni_emo_t_logits, uni_emo_a_logits, uni_emo_v_logits = model(
            text_input_ids, text_attention_mask, audio_features, video_features, speaker_labels )
        # --- Calculate Combined Loss (Using potentially weighted emotion loss) ---
        loss_emo_main = criterion_nll_emotion(main_emo_logits, emotion_labels) # Uses weighted loss
        loss_sent_main = criterion_nll_sentiment(main_sent_logits, sentiment_labels)
        loss_main = loss_emo_main + 0.5 * loss_sent_main
        loss_emo_t = criterion_nll_emotion(uni_emo_t_logits, emotion_labels); loss_emo_a = criterion_nll_emotion(uni_emo_a_logits, emotion_labels); loss_emo_v = criterion_nll_emotion(uni_emo_v_logits, emotion_labels)
        loss_unimodal = loss_emo_t + loss_emo_a + loss_emo_v
        with torch.no_grad(): kl_target_emo = softmax(main_emo_logits / TEMPERATURE)
        loss_kl_emo_t = criterion_kl(log_softmax(uni_emo_t_logits / TEMPERATURE), kl_target_emo)
        loss_kl_emo_a = criterion_kl(log_softmax(uni_emo_a_logits / TEMPERATURE), kl_target_emo)
        loss_kl_emo_v = criterion_kl(log_softmax(uni_emo_v_logits / TEMPERATURE), kl_target_emo)
        loss_distill = loss_kl_emo_t + loss_kl_emo_a + loss_kl_emo_v
        loss = loss_main + ALPHA_NLL_UNI * loss_unimodal + BETA_KL_DISTILL * loss_distill
        # --- End Loss Calculation ---
        loss.backward()
        if grad_clip_value > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_value)
        optimizer.step()
        if scheduler is not None: scheduler.step()
        batch_size = emotion_labels.size(0); total_loss += loss.item() * batch_size; total_samples += batch_size
        emotion_preds = torch.argmax(main_emo_logits, dim=1); sentiment_preds = torch.argmax(main_sent_logits, dim=1)
        total_emotion_correct += (emotion_preds == emotion_labels).sum().item(); total_sentiment_correct += (sentiment_preds == sentiment_labels).sum().item()
        current_lr = optimizer.param_groups[0]['lr']
        progress_bar.set_postfix({'Loss': f"{loss.item():.4f}", 'EmoAcc': f"{(total_emotion_correct/total_samples):.3f}", 'LR': f"{current_lr:.2e}"})
    avg_loss = total_loss / total_samples; avg_emotion_acc = total_emotion_correct / total_samples; avg_sentiment_acc = total_sentiment_correct / total_samples
    return { 'loss': avg_loss, 'emotion_acc': avg_emotion_acc, 'sentiment_acc': avg_sentiment_acc }

# ---------------------------------------------------------------------------
# 3. Evaluation Function (Weighted Loss for reporting consistency)
# ---------------------------------------------------------------------------
def evaluate(model, data_loader, device):
    model.eval(); total_loss = 0.0; all_emotion_preds, all_emotion_labels = [], []; all_sentiment_preds, all_sentiment_labels = [], []
    progress_bar = tqdm(data_loader, desc="Evaluating", leave=False)
    with torch.no_grad():
        for batch in progress_bar:
            text_input_ids = batch['text_input_ids'].to(device); text_attention_mask = batch['text_attention_mask'].to(device)
            audio_features = batch['audio_features'].to(device); video_features = batch['video_features'].to(device)
            emotion_labels = batch['emotion_label'].to(device); sentiment_labels = batch['sentiment_label'].to(device)
            speaker_labels = batch['speaker_label'].to(device)
            main_emo_logits, main_sent_logits, _, _, _ = model( text_input_ids, text_attention_mask, audio_features, video_features, speaker_labels )
            # Uses weighted emotion loss for reporting consistency
            loss_emo_main = criterion_nll_emotion(main_emo_logits, emotion_labels)
            loss_sent_main = criterion_nll_sentiment(main_sent_logits, sentiment_labels)
            loss = loss_emo_main + 0.5 * loss_sent_main
            total_loss += loss.item() * emotion_labels.size(0)
            emotion_preds = torch.argmax(main_emo_logits, dim=1); sentiment_preds = torch.argmax(main_sent_logits, dim=1)
            all_emotion_preds.extend(emotion_preds.cpu().numpy()); all_emotion_labels.extend(emotion_labels.cpu().numpy())
            all_sentiment_preds.extend(sentiment_preds.cpu().numpy()); all_sentiment_labels.extend(sentiment_labels.cpu().numpy())
    avg_loss = total_loss / len(all_emotion_labels)
    emotion_accuracy = accuracy_score(all_emotion_labels, all_emotion_preds)
    emotion_f1_weighted = f1_score(all_emotion_labels, all_emotion_preds, average='weighted', zero_division=0)
    emotion_f1_macro = f1_score(all_emotion_labels, all_emotion_preds, average='macro', zero_division=0)
    emotion_cm = confusion_matrix(all_emotion_labels, all_emotion_preds)
    emo_names = list(encoder_mappings['emotion'].keys()) if encoder_mappings and 'emotion' in encoder_mappings else None
    emotion_report = classification_report(all_emotion_labels, all_emotion_preds, target_names=emo_names, digits=4, zero_division=0)
    sentiment_accuracy = accuracy_score(all_sentiment_labels, all_sentiment_preds)
    sentiment_f1_weighted = f1_score(all_sentiment_labels, all_sentiment_preds, average='weighted', zero_division=0)
    sentiment_f1_macro = f1_score(all_sentiment_labels, all_sentiment_preds, average='macro', zero_division=0)
    sentiment_cm = confusion_matrix(all_sentiment_labels, all_sentiment_preds)
    sent_names = list(encoder_mappings['sentiment'].keys()) if encoder_mappings and 'sentiment' in encoder_mappings else None
    sentiment_report = classification_report(all_sentiment_labels, all_sentiment_preds, target_names=sent_names, digits=4, zero_division=0)
    return { 'loss': avg_loss, 'emotion_acc': emotion_accuracy, 'emotion_f1_weighted': emotion_f1_weighted, 'emotion_f1_macro': emotion_f1_macro,
             'emotion_cm': emotion_cm, 'emotion_report': emotion_report, 'sentiment_acc': sentiment_accuracy, 'sentiment_f1_weighted': sentiment_f1_weighted,
             'sentiment_f1_macro': sentiment_f1_macro, 'sentiment_cm': sentiment_cm, 'sentiment_report': sentiment_report,
             'raw_emotion_preds': all_emotion_preds, 'raw_emotion_labels': all_emotion_labels }


print("\n--- Block 6 Training/Evaluation Functions Defined---")
# ---------------------------------------------------------------------------
# End of Block 6
# ---------------------------------------------------------------------------

--- Defining Training/Evaluation Functions (Using Device: cuda) ---

--- Calculating Class Weights for Emotion ---
Using Emotion Class Weights: [1.29 5.27 5.32 0.82 0.3  2.09 1.18]
Loss Params: Alpha=0.2, Beta=0.5, Temp=8.0. Using Class Weights for Emotion.

--- Block 6 Training/Evaluation Functions Defined (Class Weights Added) ---


In [None]:
# Block 7: Optimizer, Scheduler, and Training Loop (Aggressive Regularization)

import torch
import torch.optim as optim
from transformers import get_linear_schedule_with_warmup
import numpy as np
import time
import os
import pickle
from copy import deepcopy
import matplotlib.pyplot as plt

# --- Ensure Prerequisites ---
if 'device' not in locals(): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if 'model' not in locals(): print("ERROR: Model not defined (Run Block 5)"); raise NameError
if 'train_loader' not in locals() or 'val_loader' not in locals() or 'test_loader' not in locals(): print("ERROR: DataLoaders not defined (Run Block 4)"); raise NameError
if 'encoder_mappings' not in locals():
    try:
        with open(os.path.join(BASE_DATA_DIR, 'encoder_mappings.pkl'), 'rb') as f: encoder_mappings = pickle.load(f)
    except: print("WARN: encoder_mappings.pkl not found."); encoder_mappings = None

model.to(device)

print("\n--- Starting Block 7: Training Setup and Execution (Aggressive Regularization) ---")

# ---------------------------------------------------------------------------
# 1. Training Hyperparameters
# ---------------------------------------------------------------------------
NUM_EPOCHS = 25       # Run for enough epochs, relying on early stopping
LEARNING_RATE = 2e-5  # showed promise
WEIGHT_DECAY = 1e-3   # increased from 1e-4
GRAD_CLIP_VALUE = 1.0 
EARLY_STOPPING_PATIENCE = 4 # reduced from 5
BEST_MODEL_PATH = 'best_multimodal_model_agg_reg_cw.pth'

print(f"Training Params: Epochs={NUM_EPOCHS}, LR={LEARNING_RATE}, WeightDecay={WEIGHT_DECAY}")
print(f"Early Stopping: Patience={EARLY_STOPPING_PATIENCE}, Saving to {BEST_MODEL_PATH}")

# ---------------------------------------------------------------------------
# 2. Initialize Optimizer and Scheduler (Re-initialize with new WD and steps)
# ---------------------------------------------------------------------------
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
print(f"\nOptimizer: AdamW (WD={WEIGHT_DECAY})")

if 'train_loader' in locals():
    num_batches_per_epoch = len(train_loader)
    print(f"Detected {num_batches_per_epoch} batches per epoch.")
    total_training_steps = num_batches_per_epoch * NUM_EPOCHS
    num_warmup_steps = int(0.05 * total_training_steps) # 5% warmup
    scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_training_steps )
    print(f"Scheduler: Linear Warmup ({num_warmup_steps} steps) then Decay over {total_training_steps} steps")
else:
    print("ERROR: train_loader not found, cannot initialize scheduler steps.")
    raise NameError("train_loader not found")


# ---------------------------------------------------------------------------
# 3. Training Loop 
# ---------------------------------------------------------------------------
history = { 'train_loss': [], 'val_loss': [], 'train_emotion_acc': [], 'val_emotion_acc': [], 'val_emotion_f1_weighted': [],
            'train_sentiment_acc': [], 'val_sentiment_acc': [], 'val_sentiment_f1_weighted': [] }
best_val_f1 = -1.0; epochs_no_improve = 0; best_model_state = None

print("\n--- Starting NEW Training Run (Aggressive Regularization + Class Weights) ---"); start_training_time = time.time()

# Loop starts from 1
for epoch in range(1, NUM_EPOCHS + 1):
    epoch_start_time = time.time(); print(f"\nEpoch {epoch}/{NUM_EPOCHS}")
    # Pass scheduler to the training function
    train_metrics = train_one_epoch(model, train_loader, optimizer, device, GRAD_CLIP_VALUE, scheduler=scheduler)
    # Validation
    val_metrics = evaluate(model, val_loader, device)
    # Logging
    history['train_loss'].append(train_metrics['loss']); history['val_loss'].append(val_metrics['loss'])
    history['train_emotion_acc'].append(train_metrics['emotion_acc']); history['val_emotion_acc'].append(val_metrics['emotion_acc'])
    history['train_sentiment_acc'].append(train_metrics['sentiment_acc']); history['val_sentiment_acc'].append(val_metrics['sentiment_acc'])
    history['val_emotion_f1_weighted'].append(val_metrics['emotion_f1_weighted']); history['val_sentiment_f1_weighted'].append(val_metrics['sentiment_f1_weighted'])
    epoch_duration = time.time() - epoch_start_time
    # Print Summary
    print(f"Epoch {epoch} Summary [Time: {epoch_duration:.2f}s | LR: {optimizer.param_groups[0]['lr']:.2e}]:")
    print(f"  Train: Loss={train_metrics['loss']:.4f}, EmoAcc={train_metrics['emotion_acc']:.4f}, SentAcc={train_metrics['sentiment_acc']:.4f}")
    print(f"  Val:   Loss={val_metrics['loss']:.4f}, EmoAcc={val_metrics['emotion_acc']:.4f}, EmoF1_W={val_metrics['emotion_f1_weighted']:.4f}, SentAcc={val_metrics['sentiment_acc']:.4f}, SentF1_W={val_metrics['sentiment_f1_weighted']:.4f}")
    # Early Stopping Check
    current_val_f1 = val_metrics['emotion_f1_weighted']
    if current_val_f1 > best_val_f1:
        best_val_f1 = current_val_f1; epochs_no_improve = 0; best_model_state = deepcopy(model.state_dict())
        torch.save(model.state_dict(), BEST_MODEL_PATH); print(f"  Validation F1 Improved! Saving model to {BEST_MODEL_PATH}")
    else:
        epochs_no_improve += 1; print(f"  Validation F1 did not improve. Patience: {epochs_no_improve}/{EARLY_STOPPING_PATIENCE}")
    if epochs_no_improve >= EARLY_STOPPING_PATIENCE: print(f"\nEarly stopping triggered after {epoch} epochs."); break

# --- End of Training Loop ---
total_training_time = time.time() - start_training_time
print(f"\n--- Training Finished ---")
print(f"Total Training Time: {total_training_time / 60:.2f} minutes")
print(f"Best Validation Emotion F1 (Weighted): {best_val_f1:.4f}")

# ---------------------------------------------------------------------------
# 4. Load Best Model and Evaluate on Test Set
# ---------------------------------------------------------------------------
if os.path.exists(BEST_MODEL_PATH):
    print(f"\nLoading best model weights from {BEST_MODEL_PATH}")
    model.load_state_dict(torch.load(BEST_MODEL_PATH)) # Load weights on current device
elif best_model_state is not None: print("\nLoading best model weights from memory."); model.load_state_dict(best_model_state)
else: print("\nWARN: No best model saved. Evaluating with the final model state.")

print("\n--- Evaluating on Test Set (using best model) ---")
test_metrics = evaluate(model, test_loader, device)

# --- Print Final Results ---
print("\n" + "="*50); print("       Final Test Set Performance (Aggressive Reg + CW)"); print("="*50)
print(f"Test Loss: {test_metrics['loss']:.4f}")
print("\n--- Emotion Results (Test) ---"); print(f"  Accuracy: {test_metrics['emotion_acc']:.4f}, F1 (Weighted): {test_metrics['emotion_f1_weighted']:.4f}, F1 (Macro): {test_metrics['emotion_f1_macro']:.4f}")
print("  Classification Report:\n", test_metrics['emotion_report'])
print("\n--- Sentiment Results (Test) ---"); print(f"  Accuracy: {test_metrics['sentiment_acc']:.4f}, F1 (Weighted): {test_metrics['sentiment_f1_weighted']:.4f}, F1 (Macro): {test_metrics['sentiment_f1_macro']:.4f}")
print("  Classification Report:\n", test_metrics['sentiment_report'])

# --- Save history and test_metrics for this run ---
run_name = "agg_reg_cw"
try:
    with open(f'training_history_{run_name}.pkl', 'wb') as f: pickle.dump(history, f)
    with open(f'test_metrics_{run_name}.pkl', 'wb') as f: pickle.dump(test_metrics, f)
    print(f"\nSaved training history and test metrics to pickle files (suffix: _{run_name}).")
except Exception as e: print(f"Error saving history/metrics: {e}")

print("\n--- Block 7 Training and Evaluation Complete ---")
# ---------------------------------------------------------------------------
# End of Block 7
# ---------------------------------------------------------------------------


--- Starting Block 7: Training Setup and Execution (Aggressive Regularization) ---
Training Params: Epochs=25, LR=2e-05, WeightDecay=0.001
Early Stopping: Patience=4, Saving to best_multimodal_model_agg_reg_cw.pth

Optimizer: AdamW (WD=0.001)
Detected 625 batches per epoch.
Scheduler: Linear Warmup (781 steps) then Decay over 15625 steps

--- Starting NEW Training Run (Aggressive Regularization + Class Weights) ---

Epoch 1/25


Training:   0%|          | 0/625 [00:00<?, ?it/s]

In [None]:
# Block 8: Plotting Training History and Results

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os
import pickle

# --- Setup ---
if 'BASE_DATA_DIR' not in locals(): BASE_DATA_DIR = 'data/MELD'
plt.style.use('seaborn-v0_8-whitegrid') # Set plot style

# --- Load necessary data ---
try: # Load history and metrics saved from Block 7
    with open('training_history.pkl', 'rb') as f: history = pickle.load(f)
    with open('test_metrics.pkl', 'rb') as f: test_metrics = pickle.load(f)
except FileNotFoundError: print("ERROR: history or test_metrics pkl file not found. Run Block 7 first."); raise
try: # Load encoder mappings for labels
    with open(os.path.join(BASE_DATA_DIR, 'encoder_mappings.pkl'), 'rb') as f: encoder_mappings = pickle.load(f)
except FileNotFoundError: print("WARN: encoder_mappings.pkl not found."); encoder_mappings = None

print("--- Starting Block 8: Plotting History and Results ---")

# ---------------------------------------------------------------------------
# 1. Plot Training History
# ---------------------------------------------------------------------------
def plot_training_history(history, num_epochs_ran):
    epochs = range(1, num_epochs_ran + 1)
    fig, axes = plt.subplots(2, 3, figsize=(18, 10)); fig.suptitle('Training History', fontsize=16)
    # Loss
    axes[0, 0].plot(epochs, history.get('train_loss', []), 'o-', label='Train Loss'); axes[0, 0].plot(epochs, history.get('val_loss', []), 'o-', label='Validation Loss')
    axes[0, 0].set_title('Loss vs. Epochs'); axes[0, 0].set_xlabel('Epoch'); axes[0, 0].set_ylabel('Loss'); axes[0, 0].legend(); axes[0, 0].grid(True)
    # Emotion Accuracy
    axes[0, 1].plot(epochs, history.get('train_emotion_acc', []), 'o-', label='Train Acc'); axes[0, 1].plot(epochs, history.get('val_emotion_acc', []), 'o-', label='Validation Acc')
    axes[0, 1].set_title('Emotion Accuracy vs. Epochs'); axes[0, 1].set_xlabel('Epoch'); axes[0, 1].set_ylabel('Accuracy'); axes[0, 1].legend(); axes[0, 1].grid(True)
    # Sentiment Accuracy
    axes[0, 2].plot(epochs, history.get('train_sentiment_acc', []), 'o-', label='Train Acc'); axes[0, 2].plot(epochs, history.get('val_sentiment_acc', []), 'o-', label='Validation Acc')
    axes[0, 2].set_title('Sentiment Accuracy vs. Epochs'); axes[0, 2].set_xlabel('Epoch'); axes[0, 2].set_ylabel('Accuracy'); axes[0, 2].legend(); axes[0, 2].grid(True)
    # Validation Emotion F1
    axes[1, 0].plot(epochs, history.get('val_emotion_f1_weighted', []), 'o-', label='Emotion F1 (Wt)', color='purple')
    axes[1, 0].set_title('Validation Emotion F1 (Weighted)'); axes[1, 0].set_xlabel('Epoch'); axes[1, 0].set_ylabel('F1 Score'); axes[1, 0].legend(); axes[1, 0].grid(True)
    # Validation Sentiment F1
    axes[1, 1].plot(epochs, history.get('val_sentiment_f1_weighted', []), 'o-', label='Sentiment F1 (Wt)', color='green')
    axes[1, 1].set_title('Validation Sentiment F1 (Weighted)'); axes[1, 1].set_xlabel('Epoch'); axes[1, 1].set_ylabel('F1 Score'); axes[1, 1].legend(); axes[1, 1].grid(True)
    axes[1, 2].axis('off'); plt.tight_layout(rect=[0, 0.03, 1, 0.95]); plt.savefig('training_history_plots.png'); plt.show()

num_epochs_completed = len(history['val_loss'])
if num_epochs_completed > 0: print(f"\n--- Plotting Training History ({num_epochs_completed} Epochs) ---"); plot_training_history(history, num_epochs_completed)
else: print("\nWARN: History dictionary is empty. Cannot plot training history.")

# ---------------------------------------------------------------------------
# 2. Display Final Test Results & Confusion Matrices
# ---------------------------------------------------------------------------
print("\n" + "="*50); print("       Final Test Set Performance (Recap)"); print("="*50)
print(f"\nTest Loss: {test_metrics.get('loss', 'N/A'):.4f}")
# --- Emotion ---
print("\n--- Emotion Results (Test) ---")
print(f"  Accuracy: {test_metrics.get('emotion_acc', 'N/A'):.4f}, F1 (Weighted): {test_metrics.get('emotion_f1_weighted', 'N/A'):.4f}, F1 (Macro): {test_metrics.get('emotion_f1_macro', 'N/A'):.4f}")
print("  Classification Report:\n", test_metrics.get('emotion_report', 'N/A'))
emotion_cm = test_metrics.get('emotion_cm'); emotion_names = list(encoder_mappings['emotion'].keys()) if encoder_mappings else None
if emotion_cm is not None:
    plt.figure(figsize=(8, 6)); sns.heatmap(emotion_cm, annot=True, fmt='d', cmap='Blues', xticklabels=emotion_names, yticklabels=emotion_names)
    plt.title('Emotion Confusion Matrix (Test Set)'); plt.xlabel('Predicted Label'); plt.ylabel('True Label'); plt.savefig('emotion_confusion_matrix_test.png'); plt.show()
# --- Sentiment ---
print("\n--- Sentiment Results (Test) ---")
print(f"  Accuracy: {test_metrics.get('sentiment_acc', 'N/A'):.4f}, F1 (Weighted): {test_metrics.get('sentiment_f1_weighted', 'N/A'):.4f}, F1 (Macro): {test_metrics.get('sentiment_f1_macro', 'N/A'):.4f}")
print("  Classification Report:\n", test_metrics.get('sentiment_report', 'N/A'))
sentiment_cm = test_metrics.get('sentiment_cm'); sentiment_names = list(encoder_mappings['sentiment'].keys()) if encoder_mappings else None
if sentiment_cm is not None:
    plt.figure(figsize=(6, 4)); sns.heatmap(sentiment_cm, annot=True, fmt='d', cmap='Blues', xticklabels=sentiment_names, yticklabels=sentiment_names)
    plt.title('Sentiment Confusion Matrix (Test Set)'); plt.xlabel('Predicted Label'); plt.ylabel('True Label'); plt.savefig('sentiment_confusion_matrix_test.png'); plt.show()

print("\n--- Block 8 Plotting History and Results Complete ---")
# ---------------------------------------------------------------------------
# End of Block 8
# ---------------------------------------------------------------------------

In [None]:
# Block 9: Detailed Model Analysis

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.manifold import TSNE
import pandas as pd
import os
import pickle
from tqdm.notebook import tqdm
import time # For t-SNE timing

# --- Setup ---
if 'device' not in locals(): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if 'BASE_DATA_DIR' not in locals(): BASE_DATA_DIR = 'data/MELD'
if 'BEST_MODEL_PATH' not in locals(): BEST_MODEL_PATH = 'best_multimodal_model.pth'

# --- Load Model ---
# Ensure model CLASS is defined
if 'EnhancedMultimodalClassifier' not in globals():
     print("ERROR: Model class 'EnhancedMultimodalClassifier' not defined. Ensure Block 5's definition was run.")
     raise NameError("Model class not defined")

# Load encoder mappings to get config parameters (num classes/speakers)
try:
    with open(os.path.join(BASE_DATA_DIR, 'encoder_mappings.pkl'), 'rb') as f: encoder_mappings = pickle.load(f)
    num_emotions = len(encoder_mappings['emotion']); num_sentiments = len(encoder_mappings['sentiment']); num_speakers = len(encoder_mappings['speaker'])
except FileNotFoundError: print("ERROR: encoder_mappings.pkl not found. Cannot infer model config."); raise

config = { "num_emotions": num_emotions, "num_sentiments": num_sentiments, "num_speakers": num_speakers,
           "text_input_dim": 768, "audio_input_dim": 768, "video_input_dim": 2048, "projection_dim": 128,
           "speaker_embed_dim": 64, "fusion_hidden_dim": 256, "num_heads": 8, "num_layers": 4,
           "transformer_ff_dim": 512, "dropout_rate": 0.2, "fine_tune_text": True, "device": device }
model = EnhancedMultimodalClassifier(**config)

# Load best weights
if os.path.exists(BEST_MODEL_PATH):
    print(f"Loading best model weights from {BEST_MODEL_PATH}")
    # Load state dict, handling potential device mismatch if needed
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=torch.device('cpu'))) # Load to CPU first
    model.to(device) # Then move to target device
    print("Best model loaded.")
else:
    print(f"ERROR: Best model file not found at {BEST_MODEL_PATH}. Cannot perform analysis.")
    raise FileNotFoundError(f"Best model not found at {BEST_MODEL_PATH}")

model.eval() # Set to eval mode for analysis

# --- Load DataLoaders ---
if 'test_loader' not in locals() or 'test_dataset' not in locals() or 'tokenizer' not in locals():
     print("ERROR: test_loader/test_dataset/tokenizer not defined. Run Block 4.")
     raise NameError("DataLoader/Dataset/Tokenizer not found")

# Create inverse mappings for labels
inv_emotion_map = {v: k for k, v in encoder_mappings['emotion'].items()}
inv_sentiment_map = {v: k for k, v in encoder_mappings['sentiment'].items()}

print("\n--- Starting Block 9: Detailed Model Analysis ---")

# ---------------------------------------------------------------------------
# 1. Error Analysis: Analyze Misclassifications
# ---------------------------------------------------------------------------
def analyze_misclassifications(model, data_loader, tokenizer, device, num_samples=10):
    model.eval(); misclassified_samples = []; count = 0
    print(f"\n--- Analyzing up to {num_samples} Emotion Misclassifications ---")
    progress_bar = tqdm(data_loader, desc="Finding Misclassifications", leave=False)
    with torch.no_grad():
        for batch in progress_bar:
            text_input_ids = batch['text_input_ids'].to(device); text_attention_mask = batch['text_attention_mask'].to(device)
            audio_features = batch['audio_features'].to(device); video_features = batch['video_features'].to(device)
            emotion_labels = batch['emotion_label'].cpu(); speaker_labels = batch['speaker_label'].to(device) # Keep true labels on CPU
            main_emo_logits, _, _, _, _ = model( text_input_ids, text_attention_mask, audio_features, video_features, speaker_labels )
            emotion_preds = torch.argmax(main_emo_logits, dim=1).cpu()
            incorrect_indices = torch.where(emotion_preds != emotion_labels)[0]
            for idx in incorrect_indices:
                if count < num_samples:
                    original_text_tokens = batch['text_input_ids'][idx].cpu() # Decode from CPU tensor
                    original_text = tokenizer.decode(original_text_tokens, skip_special_tokens=True)
                    true_label_idx = emotion_labels[idx].item(); pred_label_idx = emotion_preds[idx].item()
                    misclassified_samples.append({ 'text': original_text, 'true_emotion': inv_emotion_map.get(true_label_idx, f"Idx {true_label_idx}"),
                        'predicted_emotion': inv_emotion_map.get(pred_label_idx, f"Idx {pred_label_idx}"),
                        'confidence': torch.softmax(main_emo_logits[idx], dim=0)[pred_label_idx].item() })
                    count += 1
                else: break
            if count >= num_samples: break
    print(f"\nFound {len(misclassified_samples)} misclassified examples:")
    for i, sample in enumerate(misclassified_samples):
        print(f"\nExample {i+1}: Text: \"{sample['text']}\"")
        print(f"  True: {sample['true_emotion']} | Predicted: {sample['predicted_emotion']} (Conf: {sample['confidence']:.3f})")
    return misclassified_samples
misclassified = analyze_misclassifications(model, test_loader, tokenizer, device, num_samples=10)

# ---------------------------------------------------------------------------
# 2. Simulate Modality Ablation
# ---------------------------------------------------------------------------
def simulate_ablation(model, data_loader, device):
    model.eval(); results = { 'TAV': {'preds': [], 'labels': []}, 'T__': {'preds': [], 'labels': []}, 'A__': {'preds': [], 'labels': []},
                              'V__': {'preds': [], 'labels': []}, 'TA_': {'preds': [], 'labels': []}, 'T_V': {'preds': [], 'labels': []}, '_AV': {'preds': [], 'labels': []} }
    print("\n--- Simulating Modality Ablation ---"); progress_bar = tqdm(data_loader, desc="Simulating Ablation", leave=False)
    with torch.no_grad():
        for batch in progress_bar:
            text_input_ids = batch['text_input_ids'].to(device); text_attention_mask = batch['text_attention_mask'].to(device)
            audio_features = batch['audio_features'].to(device); video_features = batch['video_features'].to(device)
            emotion_labels = batch['emotion_label'].cpu().numpy(); speaker_labels = batch['speaker_label'].to(device)
            t = model.text_encoder(text_input_ids, text_attention_mask); t_proj = model.text_proj(t)
            a_proj = model.audio_encoder(audio_features); v_proj = model.video_encoder(video_features)
            spk_emb = model.speaker_embedding(speaker_labels)
            def run_fusion_forward(t_in, a_in, v_in, spk_in):
                 t_comb = torch.cat([t_in, spk_in], dim=-1); a_comb = torch.cat([a_in, spk_in], dim=-1); v_comb = torch.cat([v_in, spk_in], dim=-1)
                 t_fused_in = model.fusion_input_proj(t_comb); a_fused_in = model.fusion_input_proj(a_comb); v_fused_in = model.fusion_input_proj(v_comb)
                 mod_indices = torch.arange(3, device=model.device).unsqueeze(0).expand(t_fused_in.size(0), -1); mod_emb = model.modality_embedding(mod_indices)
                 tokens = torch.stack([t_fused_in, a_fused_in, v_fused_in], dim=1); tokens = tokens + mod_emb
                 transformer_output = model.transformer(tokens); fused = transformer_output.mean(dim=1); fused = model.final_layer_norm(fused)
                 return model.emotion_classifier(fused)
            z_a = torch.zeros_like(a_proj); z_v = torch.zeros_like(v_proj); z_t = torch.zeros_like(t_proj)
            scenarios = { 'TAV': (t_proj, a_proj, v_proj), 'T__': (t_proj, z_a, z_v), 'A__': (z_t, a_proj, z_v), 'V__': (z_t, z_a, v_proj),
                          'TA_': (t_proj, a_proj, z_v), 'T_V': (t_proj, z_a, v_proj), '_AV': (z_t, a_proj, v_proj) }
            for key, (t_in, a_in, v_in) in scenarios.items():
                 logits = run_fusion_forward(t_in, a_in, v_in, spk_emb); preds = torch.argmax(logits, dim=1).cpu().numpy()
                 results[key]['preds'].extend(preds); results[key]['labels'].extend(emotion_labels)
    ablation_metrics = {}
    print("\nAblation Results (Emotion):")
    for key, data in results.items():
        acc = accuracy_score(data['labels'], data['preds']); f1 = f1_score(data['labels'], data['preds'], average='weighted', zero_division=0)
        ablation_metrics[key] = {'acc': acc, 'f1_weighted': f1}; print(f"  '{key}': Acc={acc:.4f}, F1(Wt)={f1:.4f}")
    # Plotting
    labels = list(ablation_metrics.keys()); acc_values = [m['acc'] for m in ablation_metrics.values()]; f1_values = [m['f1_weighted'] for m in ablation_metrics.values()]
    x = np.arange(len(labels)); width = 0.35; fig, ax = plt.subplots(figsize=(12, 6))
    rects1 = ax.bar(x - width/2, acc_values, width, label='Accuracy', color='skyblue'); rects2 = ax.bar(x + width/2, f1_values, width, label='F1 (Weighted)', color='lightcoral')
    ax.set_ylabel('Score'); ax.set_title('Simulated Modality Ablation Performance (Emotion)'); ax.set_xticks(x); ax.set_xticklabels(labels); ax.legend()
    ax.bar_label(rects1, padding=3, fmt='%.3f'); ax.bar_label(rects2, padding=3, fmt='%.3f'); fig.tight_layout(); plt.savefig('simulated_ablation_performance.png'); plt.show()
    return ablation_metrics
ablation_results = simulate_ablation(model, test_loader, device)

# ---------------------------------------------------------------------------
# 3. Feature Visualization with t-SNE
# ---------------------------------------------------------------------------
def visualize_feature_space(model, data_loader, device, num_samples=1500):
    model.eval(); features_list = []; labels_list = []
    subset_indices = np.random.choice(len(data_loader.dataset), min(num_samples, len(data_loader.dataset)), replace=False)
    subset = Subset(data_loader.dataset, subset_indices)
    # Reduce batch size for visualization if memory is tight, use num_workers=0
    subset_loader = DataLoader(subset, batch_size=max(1, data_loader.batch_size // 2), shuffle=False, num_workers=0)
    print(f"\n--- Generating Features for t-SNE ({len(subset)} samples) ---")
    progress_bar = tqdm(subset_loader, desc="Getting Features", leave=False)
    with torch.no_grad():
        for batch in progress_bar:
            text_input_ids = batch['text_input_ids'].to(device); text_attention_mask = batch['text_attention_mask'].to(device)
            audio_features = batch['audio_features'].to(device); video_features = batch['video_features'].to(device)
            emotion_labels = batch['emotion_label'].cpu().numpy(); speaker_labels = batch['speaker_label'].to(device)
            t = model.text_encoder(text_input_ids, text_attention_mask); t_proj = model.text_proj(t)
            a_proj = model.audio_encoder(audio_features); v_proj = model.video_encoder(video_features)
            spk_emb = model.speaker_embedding(speaker_labels)
            t_comb = torch.cat([t_proj, spk_emb], dim=-1); a_comb = torch.cat([a_proj, spk_emb], dim=-1); v_comb = torch.cat([v_proj, spk_emb], dim=-1)
            t_fused_in = model.fusion_input_proj(t_comb); a_fused_in = model.fusion_input_proj(a_comb); v_fused_in = model.fusion_input_proj(v_comb)
            mod_indices = torch.arange(3, device=model.device).unsqueeze(0).expand(t_fused_in.size(0), -1); mod_emb = model.modality_embedding(mod_indices)
            tokens = torch.stack([t_fused_in, a_fused_in, v_fused_in], dim=1); tokens = tokens + mod_emb
            transformer_output = model.transformer(tokens); fused = transformer_output.mean(dim=1)
            # fused = model.final_layer_norm(fused) # Use features before final norm/dropout
            features_list.append(fused.cpu().numpy()); labels_list.append(emotion_labels)
    features_array = np.concatenate(features_list, axis=0); labels_array = np.concatenate(labels_list, axis=0)
    print("Running t-SNE..."); tsne_start_time = time.time()
    tsne = TSNE(n_components=2, random_state=SEED, perplexity=min(30, len(features_array)-1), n_iter=300, verbose=0) # Adjust perplexity
    features_2d = tsne.fit_transform(features_array)
    print(f"t-SNE completed in {time.time() - tsne_start_time:.2f} seconds.")
    # Plotting
    plt.figure(figsize=(12, 10)); emotion_names = list(inv_emotion_map.values()) if inv_emotion_map else sorted(np.unique(labels_array))
    num_classes = len(emotion_names); colors = plt.cm.rainbow(np.linspace(0, 1, num_classes))
    for i, label_name in enumerate(emotion_names):
         label_idx = -1;
         for idx, name in inv_emotion_map.items():
             if name == label_name: label_idx = idx; break
         if label_idx != -1: mask = labels_array == label_idx; plt.scatter(features_2d[mask, 0], features_2d[mask, 1], color=colors[i], label=label_name, alpha=0.6, s=15)
    plt.title('t-SNE Visualization of Fused Feature Space (Colored by Emotion)'); plt.xlabel('t-SNE Component 1'); plt.ylabel('t-SNE Component 2')
    plt.legend(markerscale=2); plt.grid(True); plt.savefig('tsne_feature_visualization.png'); plt.show()

visualize_feature_space(model, test_loader, device, num_samples=1500) # Run on test set subset

print("\n--- Block 9 Detailed Model Analysis Complete ---")
# ---------------------------------------------------------------------------
# End of Block 9
# ---------------------------------------------------------------------------