In [None]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score
from sklearn.preprocessing import StandardScaler
from transformers import BertModel
import warnings
warnings.filterwarnings('ignore')

# Configuration

np.random.seed(42)
torch.manual_seed(42)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using: {DEVICE}\n")

# Settings
BATCH_SIZE = 4
NUM_EPOCHS = 15
LEARNING_RATE = 2e-5
DATA_FOLDER = './ted_humor_data'
SAVE_FOLDER = './saved_models'
os.makedirs(SAVE_FOLDER, exist_ok=True)

print(f"Batch size: {BATCH_SIZE}")
print(f"Epochs: {NUM_EPOCHS}")
print(f"Data folder: {DATA_FOLDER}\n")

# Dataset Class

class BertLSTMDataset(Dataset):
    """Dataset that preserves temporal sequences for Audio and Video."""

    def __init__(self, ids, text_data, audio_data, video_data, labels,
                 word_embeddings, audio_scaler, video_scaler, is_train=False):
        self.ids = ids
        self.text_data = text_data
        self.audio_data = audio_data
        self.video_data = video_data
        self.labels = labels
        self.word_embeddings = word_embeddings
        self.audio_scaler = audio_scaler
        self.video_scaler = video_scaler
        self.is_train = is_train

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        sample_id = self.ids[idx]

        # Text
        text_data = self.text_data[sample_id]
        if isinstance(text_data, dict):
            if 'punchline_features' in text_data:
                text_indices = text_data['punchline_features']
            elif 'punchline' in text_data:
                text_indices = text_data['punchline']
            else:
                text_indices = list(text_data.values())[0]
        else:
            text_indices = text_data

        word_indices = []
        if isinstance(text_indices, (list, np.ndarray)):
            for idx_list in text_indices:
                try:
                    if isinstance(idx_list, (list, np.ndarray)) and len(idx_list) > 0:
                        idx_val = int(idx_list[0]) if hasattr(idx_list[0], '__int__') else 0
                    else:
                        idx_val = int(idx_list) if hasattr(idx_list, '__int__') else 0

                    if 0 <= idx_val < 30522:  # BERT vocab size
                        word_indices.append(idx_val)
                except:
                    pass

        if len(word_indices) == 0:
            word_indices = [101]  # [CLS] token
        word_indices = word_indices[:512]

        # Audio
        # KEEP as sequence: shape (T, 81)
        audio_data = self.audio_data[sample_id]
        if isinstance(audio_data, dict):
            if 'punchline_features' in audio_data:
                audio_features = audio_data['punchline_features']
            elif 'punchline' in audio_data:
                audio_features = audio_data['punchline']
            else:
                audio_features = list(audio_data.values())[0]
        else:
            audio_features = audio_data

        try:
            audio_seq = np.array(audio_features, dtype=np.float32).reshape(-1, 81)
            if audio_seq.shape[0] == 0:
                audio_seq = np.zeros((1, 81), dtype=np.float32)

            if self.audio_scaler:
                audio_seq = self.audio_scaler.transform(audio_seq)
        except:
            audio_seq = np.zeros((1, 81), dtype=np.float32)

        # Video
        # KEEP as sequence: shape (T, 75)
        video_data = self.video_data[sample_id]
        if isinstance(video_data, dict):
            if 'punchline_features' in video_data:
                video_features = video_data['punchline_features']
            elif 'punchline' in video_data:
                video_features = video_data['punchline']
            else:
                video_features = list(video_data.values())[0]
        else:
            video_features = video_data

        try:
            video_seq = np.array(video_features, dtype=np.float32).reshape(-1, 75)
            if video_seq.shape[0] == 0:
                video_seq = np.zeros((1, 75), dtype=np.float32)

            if self.video_scaler:
                video_seq = self.video_scaler.transform(video_seq)
        except:
            video_seq = np.zeros((1, 75), dtype=np.float32)

        #  Modality Dropout
        if self.is_train and np.random.rand() < 0.2:
            which = np.random.choice(['text', 'audio', 'video'])
            if which == 'text':
                word_indices = [101]
            elif which == 'audio':
                audio_seq = np.zeros((1, 81), dtype=np.float32)
            elif which == 'video':
                video_seq = np.zeros((1, 75), dtype=np.float32)

        label = self.labels[sample_id]

        return {
            'word_indices': torch.LongTensor(word_indices),
            'audio': torch.FloatTensor(audio_seq),  # [T, 81]
            'video': torch.FloatTensor(video_seq),  # [T, 75]
            'label': torch.tensor(label, dtype=torch.long)
        }


# Collate Funciton: Handles Variable-Length Sequences


def collate_fn_lstm(batch):
    """Collate function that pads sequences and tracks their original lengths."""

    # Text
    max_text_len = max(len(item['word_indices']) for item in batch)
    word_indices_padded = []
    attention_masks = []

    for item in batch:
        indices = item['word_indices'].numpy()
        padded = np.pad(indices, (0, max_text_len - len(indices)), constant_values=0)
        word_indices_padded.append(torch.LongTensor(padded))

        mask = np.ones(len(indices))
        mask_padded = np.pad(mask, (0, max_text_len - len(indices)), constant_values=0)
        attention_masks.append(torch.LongTensor(mask_padded))

    # Audion Sequences
    audio_seqs = [item['audio'] for item in batch]  # List of [T_i, 81]
    audio_lengths = torch.tensor([len(seq) for seq in audio_seqs])
    audio_padded = pad_sequence(audio_seqs, batch_first=True)  # [batch, max_T, 81]

    # Video Sequences
    video_seqs = [item['video'] for item in batch]  # List of [T_i, 75]
    video_lengths = torch.tensor([len(seq) for seq in video_seqs])
    video_padded = pad_sequence(video_seqs, batch_first=True)  # [batch, max_T, 75]

    return {
        'word_indices': torch.stack(word_indices_padded),
        'attention_mask': torch.stack(attention_masks),
        'audio': audio_padded,
        'audio_lengths': audio_lengths,
        'video': video_padded,
        'video_lengths': video_lengths,
        'label': torch.stack([item['label'] for item in batch])
    }

# Encoder

class BertTextEncoder(nn.Module):
    """BERT-based text encoder (frozen)."""

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.bert_dim = 768

    def forward(self, input_ids, attention_mask=None):
        if input_ids.shape[0] == 0 or input_ids.shape[1] == 0:
            return torch.zeros(input_ids.shape[0] if input_ids.shape[0] > 0 else 1,
                              self.bert_dim).to(input_ids.device)

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        with torch.no_grad():
            outputs = self.bert(input_ids, attention_mask=attention_mask)
            cls_token = outputs.last_hidden_state[:, 0, :]  # [batch, 768]

        return cls_token


class LSTMAudioEncoder(nn.Module):
    """LSTM-based audio encoder that preserves temporal sequences."""

    def __init__(self, input_dim=81, hidden_dim=64, num_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.2 if num_layers > 1 else 0
        )
        self.hidden_dim = hidden_dim

    def forward(self, audio_seq, lengths):

        if audio_seq.shape[0] == 0:
            return torch.zeros(1, self.hidden_dim).to(audio_seq.device)

        # Pack padded sequence
        packed = pack_padded_sequence(audio_seq, lengths.cpu(),
                                     batch_first=True, enforce_sorted=False)

        # LSTM forward
        _, (hidden, cell) = self.lstm(packed)


        last_hidden = hidden[-1]

        return last_hidden


class LSTMVideoEncoder(nn.Module):
    """LSTM-based video encoder that preserves temporal sequences."""

    def __init__(self, input_dim=75, hidden_dim=64, num_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.2 if num_layers > 1 else 0
        )
        self.hidden_dim = hidden_dim

    def forward(self, video_seq, lengths):

        if video_seq.shape[0] == 0:
            return torch.zeros(1, self.hidden_dim).to(video_seq.device)

        # Pack padded sequence
        packed = pack_padded_sequence(video_seq, lengths.cpu(),
                                     batch_first=True, enforce_sorted=False)

        # LSTM forward
        _, (hidden, cell) = self.lstm(packed)


        last_hidden = hidden[-1]

        return last_hidden


# Multimodal Model

class BertLSTMModel(nn.Module):
    """Multimodal model with LSTM sequence encoders for audio and video."""

    def __init__(self, audio_hidden=64, video_hidden=64, lstm_layers=2):
        super().__init__()

        # Text encoder
        self.bert_encoder = BertTextEncoder()
        text_dim = 768

        # Audio encoder
        self.audio_encoder = LSTMAudioEncoder(input_dim=81, hidden_dim=audio_hidden,
                                              num_layers=lstm_layers)
        audio_dim = audio_hidden

        # Video encoder
        self.video_encoder = LSTMVideoEncoder(input_dim=75, hidden_dim=video_hidden,
                                              num_layers=lstm_layers)
        video_dim = video_hidden

        # Fusion layer
        total_dim = text_dim + audio_dim + video_dim
        self.fusion = nn.Sequential(
            nn.Linear(total_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 2)
        )

    def forward(self, word_indices, audio, video,
               audio_lengths, video_lengths, attention_mask=None):
        # Text encoding
        if word_indices.shape[1] > 0:
            text_out = self.bert_encoder(word_indices, attention_mask)
        else:
            text_out = torch.zeros(audio.shape[0], 768).to(audio.device)

        # Audio encoding
        if audio.shape[0] > 0 and audio.shape[1] > 0:
            audio_out = self.audio_encoder(audio, audio_lengths)
        else:
            audio_out = torch.zeros(audio.shape[0], 64).to(audio.device)

        # Video encoding
        if video.shape[0] > 0 and video.shape[1] > 0:
            video_out = self.video_encoder(video, video_lengths)
        else:
            video_out = torch.zeros(video.shape[0], 64).to(video.device)

        # Fusion
        combined = torch.cat([text_out, audio_out, video_out], dim=1)
        fused = self.fusion(combined)

        # Classification
        logits = self.classifier(fused)

        return logits


# Training Function

def train_epoch(model, loader, optimizer, loss_fn, device):
    """Train for one epoch."""
    model.train()
    total_loss = 0

    for batch in loader:
        word_indices = batch['word_indices'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        audio = batch['audio'].to(device)
        audio_lengths = batch['audio_lengths'].to(device)
        video = batch['video'].to(device)
        video_lengths = batch['video_lengths'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(word_indices, audio, video,
                      audio_lengths, video_lengths, attention_mask)
        loss = loss_fn(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(loader)


# Evaluation


def evaluate(model, loader, device, missing_mod=None):
    """Evaluate model with optional missing modality"""
    model.eval()
    preds = []
    targets = []

    with torch.no_grad():
        for batch in loader:
            word_indices = batch['word_indices'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            audio = batch['audio'].to(device)
            audio_lengths = batch['audio_lengths'].to(device)
            video = batch['video'].to(device)
            video_lengths = batch['video_lengths'].to(device)
            labels = batch['label'].to(device)

            # Simulate missing modality
            if missing_mod == 'audio':
                audio = torch.zeros_like(audio)
            elif missing_mod == 'video':
                video = torch.zeros_like(video)
            elif missing_mod == 'text':
                word_indices = torch.zeros_like(word_indices)
                attention_mask = torch.zeros_like(attention_mask)

            logits = model(word_indices, audio, video,
                          audio_lengths, video_lengths, attention_mask)
            pred = logits.argmax(dim=1)

            preds.extend(pred.cpu().numpy())
            targets.extend(labels.cpu().numpy())

    acc = accuracy_score(targets, preds)
    f1 = f1_score(targets, preds, average='macro', zero_division=0)
    return acc, f1


# Data Loading

def load_data():
    """Load and prepare data."""
    print("Loading data")

    with open(f'{DATA_FOLDER}/data_folds.pkl', 'rb') as f:
        data_folds = pickle.load(f, encoding='latin1')

    with open(f'{DATA_FOLDER}/word_embedding_list.pkl', 'rb') as f:
        word_embeddings = pickle.load(f, encoding='latin1')

    with open(f'{DATA_FOLDER}/word_embedding_indexes_sdk.pkl', 'rb') as f:
        text_features = pickle.load(f, encoding='latin1')

    with open(f'{DATA_FOLDER}/covarep_features_sdk.pkl', 'rb') as f:
        audio_features = pickle.load(f, encoding='latin1')

    with open(f'{DATA_FOLDER}/openface_features_sdk.pkl', 'rb') as f:
        video_features = pickle.load(f, encoding='latin1')

    with open(f'{DATA_FOLDER}/humor_label_sdk.pkl', 'rb') as f:
        labels = pickle.load(f, encoding='latin1')

    # Parse folds
    train_ids = None
    try:
        if isinstance(data_folds, dict) and 'train' in data_folds:
            train_ids = data_folds['train']
            dev_ids = data_folds['dev']
            test_ids = data_folds['test']
        elif isinstance(data_folds, list):
            train_ids = data_folds[0].get('train', [])
            dev_ids = data_folds[0].get('dev', [])
            test_ids = data_folds[0].get('test', [])
    except:
        pass

    if train_ids is None or len(train_ids) == 0:
        print("Splitting data 70/15/15")
        all_ids = list(range(len(labels)))
        np.random.shuffle(all_ids)
        split1 = int(0.7 * len(all_ids))
        split2 = int(0.85 * len(all_ids))
        train_ids = all_ids[:split1]
        dev_ids = all_ids[split1:split2]
        test_ids = all_ids[split2:]

    print(f"Train: {len(train_ids)}, Dev: {len(dev_ids)}, Test: {len(test_ids)}")

    # Fit scalers on training audio/video
    train_audio_list = []
    for id in train_ids:
        try:
            audio_data = audio_features[id]
            if isinstance(audio_data, dict):
                audio_data = audio_data.get('punchline_features',
                                           audio_data.get('punchline',
                                           list(audio_data.values())[0]))
            audio_arr = np.array(audio_data, dtype=np.float32).reshape(-1, 81)
            train_audio_list.append(audio_arr)
        except:
            pass

    train_audio = np.vstack(train_audio_list) if train_audio_list else np.zeros((1, 81))
    scaler_audio = StandardScaler().fit(train_audio)

    train_video_list = []
    for id in train_ids:
        try:
            video_data = video_features[id]
            if isinstance(video_data, dict):
                video_data = video_data.get('punchline_features',
                                           video_data.get('punchline',
                                           list(video_data.values())[0]))
            video_arr = np.array(video_data, dtype=np.float32).reshape(-1, 75)
            train_video_list.append(video_arr)
        except:
            pass

    train_video = np.vstack(train_video_list) if train_video_list else np.zeros((1, 75))
    scaler_video = StandardScaler().fit(train_video)

    # Create datasets
    train_set = BertLSTMDataset(train_ids, text_features, audio_features, video_features,
                               labels, word_embeddings, scaler_audio, scaler_video, is_train=True)
    dev_set = BertLSTMDataset(dev_ids, text_features, audio_features, video_features,
                             labels, word_embeddings, scaler_audio, scaler_video, is_train=False)
    test_set = BertLSTMDataset(test_ids, text_features, audio_features, video_features,
                              labels, word_embeddings, scaler_audio, scaler_video, is_train=False)

    return train_set, dev_set, test_set, word_embeddings


def main():
    """Main training pipeline with all phases"""

    train_set, dev_set, test_set, word_emb = load_data()

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                             collate_fn=collate_fn_lstm)
    dev_loader = DataLoader(dev_set, batch_size=BATCH_SIZE, collate_fn=collate_fn_lstm)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, collate_fn=collate_fn_lstm)

    loss_fn = nn.CrossEntropyLoss()

    # Phase 1
    print("\n" + "="*70)
    print("PHASE 1: BERT + LSTM AUDIO/VIDEO (Temporal Sequence Modeling)")
    print("="*70)

    model = BertLSTMModel(audio_hidden=64, video_hidden=64, lstm_layers=2).to(DEVICE)

    # Layerwise learning rates
    param_groups = [
        {'params': model.bert_encoder.parameters(), 'lr': 1e-5},
        {'params': model.audio_encoder.parameters(), 'lr': 1e-3},
        {'params': model.video_encoder.parameters(), 'lr': 1e-3},
        {'params': model.fusion.parameters(), 'lr': 1e-3},
        {'params': model.classifier.parameters(), 'lr': 1e-3}
    ]
    opt = optim.AdamW(param_groups, weight_decay=0.01)

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=NUM_EPOCHS, eta_min=1e-6)

    best_f1 = 0
    for epoch in range(NUM_EPOCHS):
        loss = train_epoch(model, train_loader, opt, loss_fn, DEVICE)
        dev_acc, dev_f1 = evaluate(model, dev_loader, DEVICE)
        scheduler.step()

        if dev_f1 > best_f1:
            best_f1 = dev_f1
            torch.save(model.state_dict(), f'{SAVE_FOLDER}/phase1_lstm.pt')

        if (epoch + 1) % 3 == 0:
            print(f"Epoch {epoch+1:2d} | Loss: {loss:.4f} | Dev Acc: {dev_acc:.4f} | F1: {dev_f1:.4f}")

    model.load_state_dict(torch.load(f'{SAVE_FOLDER}/phase1_lstm.pt'))
    test_acc, test_f1 = evaluate(model, test_loader, DEVICE)
    print(f" Phase 1: Acc={test_acc:.4f}, F1={test_f1:.4f}\n")

    # Missing Dataset
    print("="*70)
    print("PHASE 2: ABLATION STUDY (Missing Modalities)")
    print("="*70)

    for missing in [None, 'audio', 'video', 'text']:
        acc, f1 = evaluate(model, test_loader, DEVICE, missing)
        name = missing if missing else 'All'
        print(f"Missing {name:6s}: Acc={acc:.4f}, F1={f1:.4f}")

    print()

    # Knowledge Distillation
    print("="*70)
    print("PHASE 3: KNOWLEDGE DISTILLATION (Audio Teacher)")
    print("="*70)

    # Train audio unimodal teacher
    print("Training audio-only teacher")
    class SimpleAudioTeacher(nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = LSTMAudioEncoder(input_dim=81, hidden_dim=64, num_layers=2)
            self.classifier = nn.Sequential(
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Linear(32, 2)
            )

        def forward(self, audio, lengths):
            audio_feat = self.lstm(audio, lengths)
            return self.classifier(audio_feat)

    audio_teacher = SimpleAudioTeacher().to(DEVICE)
    audio_opt = optim.Adam(audio_teacher.parameters(), lr=1e-3)

    for epoch in range(10):
        audio_teacher.train()
        for batch in train_loader:
            audio = batch['audio'].to(DEVICE)
            audio_lengths = batch['audio_lengths'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            audio_opt.zero_grad()
            logits = audio_teacher(audio, audio_lengths)
            loss = loss_fn(logits, labels)
            loss.backward()
            audio_opt.step()

    # Train student with KD
    print("Training student with KD loss")
    model_kd = BertLSTMModel(audio_hidden=64, video_hidden=64, lstm_layers=2).to(DEVICE)

    param_groups_kd = [
        {'params': model_kd.bert_encoder.parameters(), 'lr': 1e-5},
        {'params': model_kd.audio_encoder.parameters(), 'lr': 1e-3},
        {'params': model_kd.video_encoder.parameters(), 'lr': 1e-3},
        {'params': model_kd.fusion.parameters(), 'lr': 1e-3},
        {'params': model_kd.classifier.parameters(), 'lr': 1e-3}
    ]
    opt_kd = optim.AdamW(param_groups_kd, weight_decay=0.01)
    scheduler_kd = optim.lr_scheduler.CosineAnnealingLR(opt_kd, T_max=NUM_EPOCHS)

    for epoch in range(NUM_EPOCHS):
        model_kd.train()
        total_loss = 0

        for batch in train_loader:
            word_indices = batch['word_indices'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            audio = batch['audio'].to(DEVICE)
            audio_lengths = batch['audio_lengths'].to(DEVICE)
            video = batch['video'].to(DEVICE)
            video_lengths = batch['video_lengths'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            opt_kd.zero_grad()

            student_logits = model_kd(word_indices, audio, video,
                                      audio_lengths, video_lengths, attention_mask)
            ce_loss = loss_fn(student_logits, labels)

            with torch.no_grad():
                teacher_logits = audio_teacher(audio, audio_lengths)

            T = 4.0
            kd_loss = 0.5 * torch.nn.functional.kl_div(
                torch.nn.functional.log_softmax(student_logits/T, dim=1),
                torch.softmax(teacher_logits/T, dim=1),
                reduction='batchmean'
            )

            loss = ce_loss + kd_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model_kd.parameters(), 1.0)
            opt_kd.step()

            total_loss += loss.item()

        scheduler_kd.step()

        if (epoch + 1) % 3 == 0:
            print(f"Epoch {epoch+1:2d} | Loss: {total_loss/len(train_loader):.4f}")

    test_acc_kd, test_f1_kd = evaluate(model_kd, test_loader, DEVICE)
    print(f" Phase 3: Acc={test_acc_kd:.4f}, F1={test_f1_kd:.4f}\n")

    # Results
    print(f"Phase 1 (LSTM Baseline): Acc={test_acc:.4f}, F1={test_f1:.4f}")
    print(f"Phase 3 (KD): Acc={test_acc_kd:.4f}, F1={test_f1_kd:.4f}")

    print(f"\nModels saved to: {SAVE_FOLDER}/")


if __name__ == '__main__':
    main()

Using: cuda

Batch size: 4
Epochs: 15
Data folder: ./ted_humor_data

Loading data
Train: 10598, Dev: 2626, Test: 3290

PHASE 1: BERT + LSTM AUDIO/VIDEO (Temporal Sequence Modeling)
Epoch  3 | Loss: 0.6884 | Dev Acc: 0.5967 | F1: 0.5938


KeyboardInterrupt: 

In [None]:
!wget https://www.dropbox.com/s/izk6khkrdwcncia/ted_humor_sdk_v1.zip?dl=1

--2025-12-29 08:32:41--  https://www.dropbox.com/s/izk6khkrdwcncia/ted_humor_sdk_v1.zip?dl=1
Resolving www.dropbox.com (www.dropbox.com)... 162.125.3.18, 2620:100:6018:18::a27d:312
Connecting to www.dropbox.com (www.dropbox.com)|162.125.3.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://www.dropbox.com/scl/fi/002rz175n5ferwyvif2hv/ted_humor_sdk_v1.zip?rlkey=r8bszra1ez6zedylbx1d4wm99&dl=1 [following]
--2025-12-29 08:32:42--  https://www.dropbox.com/scl/fi/002rz175n5ferwyvif2hv/ted_humor_sdk_v1.zip?rlkey=r8bszra1ez6zedylbx1d4wm99&dl=1
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc5d6abf8e1da565817b8b7131c6.dl.dropboxusercontent.com/cd/0/inline/C3-68fuzf7TwdRf1gofoJuEiQxA3g2WxCPKMvA2JweobVIWyI1sDyOf6gbt2vgOlXxn_Zem9PHQBioxDLeteO_8XqBYjKJCS1RYwXJwmMYLhjW4Ee4FhlvEbefIRVpMRfJ3faP7LKCQAzs4Cu_iZlVyM/file?dl=1# [following]
--2025-12-29 08:32:42--  https://uc5d6abf8e1da565817b8b71

In [None]:
!unzip ted_humor_sdk_v1.zip?dl=1

Archive:  ted_humor_sdk_v1.zip?dl=1
   creating: final_humor_sdk/
  inflating: final_humor_sdk/word_embedding_list.pkl  
  inflating: final_humor_sdk/data_folds.pkl  
  inflating: final_humor_sdk/humor_label_sdk.pkl  
  inflating: final_humor_sdk/covarep_features_sdk.pkl  
  inflating: final_humor_sdk/openface_features_sdk.pkl  
  inflating: final_humor_sdk/word_embedding_indexes_sdk.pkl  


In [None]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score
from sklearn.preprocessing import StandardScaler
import warnings
warnings.filterwarnings('ignore')

# Configuration

np.random.seed(42)
torch.manual_seed(42)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Key Settings
BATCH_SIZE = 4
NUM_EPOCHS = 20
LEARNING_RATE = 2e-5
TEXT_DROPOUT = 0.8
DATA_FOLDER = './ted_humor_data'
SAVE_FOLDER = './saved_models'
os.makedirs(SAVE_FOLDER, exist_ok=True)

# 1. Dataset with raw sequences

class BertLSTMDataset(Dataset):

    def __init__(self, ids, text_data, audio_data, video_data, labels,
                 word_embeddings, audio_scaler, video_scaler,
                 is_train=False, text_dropout=0.0):
        self.ids = ids
        self.text_data = text_data
        self.audio_data = audio_data
        self.video_data = video_data
        self.labels = labels
        self.word_embeddings = word_embeddings
        self.audio_scaler = audio_scaler
        self.video_scaler = video_scaler
        self.is_train = is_train
        self.text_dropout = text_dropout

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        sample_id = self.ids[idx]

        # Get Text
        text_data = self.text_data[sample_id]
        if isinstance(text_data, dict):
            if 'punchline_features' in text_data:
                text_indices = text_data['punchline_features']
            elif 'punchline' in text_data:
                text_indices = text_data['punchline']
            else:
                text_indices = list(text_data.values())[0]
        else:
            text_indices = text_data

        word_indices = []
        if isinstance(text_indices, (list, np.ndarray)):
            for idx_list in text_indices:
                try:
                    if isinstance(idx_list, (list, np.ndarray)) and len(idx_list) > 0:
                        idx_val = int(idx_list[0]) if hasattr(idx_list[0], '__int__') else 0
                    else:
                        idx_val = int(idx_list) if hasattr(idx_list, '__int__') else 0

                    if 0 <= idx_val <= 30522:
                        word_indices.append(idx_val)
                except:
                    pass

        if len(word_indices) == 0:
            word_indices = [101]

        word_indices = word_indices[:512]

        # Get Audio - Raw Sequence
        audio_data = self.audio_data[sample_id]
        if isinstance(audio_data, dict):
            if 'punchline_features' in audio_data:
                audio_features = audio_data['punchline_features']
            elif 'punchline' in audio_data:
                audio_features = audio_data['punchline']
            else:
                audio_features = list(audio_data.values())[0]
        else:
            audio_features = audio_data

        try:
            audio_raw = np.array(audio_features, dtype=np.float32).reshape(-1, 81)
            if self.audio_scaler:
                audio_raw = self.audio_scaler.transform(audio_raw)
            audio_seq = audio_raw[:50]  # Cap at 50 timesteps
        except:
            audio_seq = np.zeros((1, 81), dtype=np.float32)

        # Get Video - Raw Sequence

        video_data = self.video_data[sample_id]
        if isinstance(video_data, dict):
            if 'punchline_features' in video_data:
                video_features = video_data['punchline_features']
            elif 'punchline' in video_data:
                video_features = video_data['punchline']
            else:
                video_features = list(video_data.values())[0]
        else:
            video_features = video_data

        try:
            video_raw = np.array(video_features, dtype=np.float32).reshape(-1, 75)
            if self.video_scaler:
                video_raw = self.video_scaler.transform(video_raw)
            video_seq = video_raw[:50]
        except:
            video_seq = np.zeros((1, 75), dtype=np.float32)

        if self.is_train and np.random.rand() < self.text_dropout:
            word_indices = [101]

        label = self.labels[sample_id]

        return {
            'word_indices': torch.LongTensor(word_indices),
            'audio': torch.FloatTensor(audio_seq),  # (seq_len, 81)
            'video': torch.FloatTensor(video_seq),  # (seq_len, 75)
            'audio_len': len(audio_seq),
            'video_len': len(video_seq),
            'label': torch.tensor(label, dtype=torch.long)
        }


def collate_fn_lstm(batch):
    """Collate function for variable-length sequences"""
    # Pad text
    max_text_len = max(len(item['word_indices']) for item in batch)
    word_indices_padded = []
    attention_masks = []

    for item in batch:
        indices = item['word_indices'].numpy()
        padded = np.pad(indices, (0, max_text_len - len(indices)), constant_values=0)
        word_indices_padded.append(torch.LongTensor(padded))

        mask = np.ones(len(indices))
        mask_padded = np.pad(mask, (0, max_text_len - len(indices)), constant_values=0)
        attention_masks.append(torch.LongTensor(mask_padded))

    # Pad audio
    max_audio_len = max(len(item['audio']) for item in batch)
    audio_padded = []
    audio_lengths = []

    for item in batch:
        audio = item['audio'].numpy()
        if len(audio) < max_audio_len:
            audio = np.pad(audio, ((0, max_audio_len - len(audio)), (0, 0)), constant_values=0)
        audio_padded.append(torch.FloatTensor(audio))
        audio_lengths.append(item['audio_len'])

    # Pad video
    max_video_len = max(len(item['video']) for item in batch)
    video_padded = []
    video_lengths = []

    for item in batch:
        video = item['video'].numpy()
        if len(video) < max_video_len:
            video = np.pad(video, ((0, max_video_len - len(video)), (0, 0)), constant_values=0)
        video_padded.append(torch.FloatTensor(video))
        video_lengths.append(item['video_len'])

    return {
        'word_indices': torch.stack(word_indices_padded),
        'attention_mask': torch.stack(attention_masks),
        'audio': torch.stack(audio_padded),  # (batch, max_len, 81)
        'video': torch.stack(video_padded),  # (batch, max_len, 75)
        'audio_lengths': torch.LongTensor(audio_lengths),
        'video_lengths': torch.LongTensor(video_lengths),
        'label': torch.stack([item['label'] for item in batch])
    }


# 2. Encoder

class LSTMTextEncoder(nn.Module):

    def __init__(self, vocab_size=30522, embedding_dim=300, hidden_dim=384, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim,
                           num_layers=num_layers, batch_first=True, bidirectional=True)
        self.text_dim = hidden_dim * 2

    def forward(self, input_ids, attention_mask=None):
        """
        Args:
            input_ids: (batch, seq_len)

        Returns:
            text_out: (batch, text_dim) - final hidden state
        """
        if input_ids.shape[0] == 0 or input_ids.shape[1] == 0:
            return torch.zeros(max(input_ids.shape[0], 1), self.text_dim).to(input_ids.device)

        embedded = self.embedding(input_ids)

        _, (hidden, cell) = self.lstm(embedded)

        text_out = torch.cat([hidden[-2], hidden[-1]], dim=1)

        return text_out


class LSTMAudioEncoder(nn.Module):


    def __init__(self, input_dim=81, hidden_dim=64, num_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim,
                           num_layers=num_layers, batch_first=True)

    def forward(self, audio_seq, audio_lengths):
        """
        Args:
            audio_seq: (batch, max_seq_len, 81)
            audio_lengths: (batch,)

        Returns:
            audio_out: (batch, hidden_dim) - final hidden state
        """
        if audio_seq.shape[0] == 0 or audio_seq.shape[1] == 0:
            return torch.zeros(audio_seq.shape[0], self.lstm.hidden_size).to(audio_seq.device)

        _, (hidden, cell) = self.lstm(audio_seq)

        audio_out = hidden[-1]

        return audio_out


class LSTMVideoEncoder(nn.Module):
    """LSTM encoder for video - raw sequence → final hidden state"""

    def __init__(self, input_dim=75, hidden_dim=64, num_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim,
                           num_layers=num_layers, batch_first=True)

    def forward(self, video_seq, video_lengths):
        """
        Args:
            video_seq: (batch, max_seq_len, 75)
            video_lengths: (batch,)

        Returns:
            video_out: (batch, hidden_dim) - final hidden state
        """
        if video_seq.shape[0] == 0 or video_seq.shape[1] == 0:
            return torch.zeros(video_seq.shape[0], self.lstm.hidden_size).to(video_seq.device)

        _, (hidden, cell) = self.lstm(video_seq)

        video_out = hidden[-1]

        return video_out


# Gated Fusion

class GatedFusion(nn.Module):

    def __init__(self, text_dim=768, audio_dim=64, video_dim=64):
        super().__init__()

        self.audio_gate = nn.Sequential(
            nn.Linear(audio_dim + text_dim, 128),
            nn.ReLU(),
            nn.Linear(128, audio_dim),
            nn.Sigmoid()
        )

        self.video_gate = nn.Sequential(
            nn.Linear(video_dim + text_dim, 128),
            nn.ReLU(),
            nn.Linear(128, video_dim),
            nn.Sigmoid()
        )

        self.text_proj = nn.Linear(text_dim, 256)
        self.audio_proj = nn.Linear(audio_dim, 256)
        self.video_proj = nn.Linear(video_dim, 256)

    def forward(self, text_feat, audio_feat, video_feat):
        # Compute gates
        audio_input = torch.cat([audio_feat, text_feat], dim=1)
        audio_gate = self.audio_gate(audio_input)
        audio_gated = audio_feat * audio_gate

        video_input = torch.cat([video_feat, text_feat], dim=1)
        video_gate = self.video_gate(video_input)
        video_gated = video_feat * video_gate

        # Project and fuse
        text_proj = self.text_proj(text_feat)
        audio_proj = self.audio_proj(audio_gated)
        video_proj = self.video_proj(video_gated)

        fused = text_proj + audio_proj + video_proj

        return fused


class CrossModalAttention(nn.Module):
    """Align audio/video to text"""

    def __init__(self, dim=64):
        super().__init__()
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.dim = dim

    def forward(self, text_feat_small, modality_feat):
        q = self.query(text_feat_small)
        k = self.key(modality_feat)
        v = self.value(modality_feat)

        scores = torch.matmul(q.unsqueeze(1), k.unsqueeze(2)) / (self.dim ** 0.5)
        weights = F.softmax(scores, dim=-1)

        output = torch.matmul(weights, v.unsqueeze(1)).squeeze(1)
        output = self.out_proj(output)

        return output


# Main Model

class ImprovedBertLSTMModel(nn.Module):
    """Complete multimodal model with raw sequences"""

    def __init__(self, audio_hidden=64, video_hidden=64, lstm_layers=2, text_hidden=384):
        super().__init__()

        # LSTM Text encoder
        self.text_encoder = LSTMTextEncoder(embedding_dim=300, hidden_dim=text_hidden,
                                           num_layers=lstm_layers)
        text_dim = self.text_encoder.text_dim

        # LSTM Audio encoder
        self.audio_encoder = LSTMAudioEncoder(input_dim=81,
                                             hidden_dim=audio_hidden,
                                             num_layers=lstm_layers)
        audio_dim = audio_hidden

        # LSTM Video encoder
        self.video_encoder = LSTMVideoEncoder(input_dim=75,
                                             hidden_dim=video_hidden,
                                             num_layers=lstm_layers)
        video_dim = video_hidden

        # Cross-modal attention
        self.audio_attention = CrossModalAttention(dim=audio_hidden)
        self.video_attention = CrossModalAttention(dim=video_hidden)
        self.text_small_proj = nn.Linear(text_dim, audio_hidden)

        # Gated Fusion
        self.gated_fusion = GatedFusion(text_dim=text_dim,
                                       audio_dim=audio_dim,
                                       video_dim=video_dim)

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 2)
        )

    def forward(self, word_indices, audio, video,
               audio_lengths, video_lengths, attention_mask=None):

        # Encode: raw sequences → final hidden states
        text_out = self.text_encoder(word_indices, attention_mask)
        audio_out = self.audio_encoder(audio, audio_lengths)
        video_out = self.video_encoder(video, video_lengths)

        # Cross-modal attentio
        text_small = self.text_small_proj(text_out)
        audio_aligned = self.audio_attention(text_small, audio_out)
        video_aligned = self.video_attention(text_small, video_out)

        # Gated fusion: final hidden states → fused representation
        fused = self.gated_fusion(text_out, audio_aligned, video_aligned)  # (batch, 256)

        # Classify
        logits = self.classifier(fused)  # (batch, 2)

        return logits


class SimpleTeacher(nn.Module):
    """Simple teacher for knowledge distillation"""

    def __init__(self, input_size):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU()
        )
        self.out = nn.Linear(64, 2)

    def forward(self, x):
        x = self.enc(x)
        return self.out(x)


# Training and Evaluation

def train_epoch(model, loader, optimizer, loss_fn, device):
    """Train for one epoch"""
    model.train()
    total_loss = 0

    for batch in loader:
        word_indices = batch['word_indices'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        audio = batch['audio'].to(device)
        video = batch['video'].to(device)
        audio_lengths = batch['audio_lengths'].to(device)
        video_lengths = batch['video_lengths'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()

        logits = model(word_indices, audio, video,
                      audio_lengths, video_lengths,
                      attention_mask=attention_mask)
        loss = loss_fn(logits, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(loader)


def evaluate(model, loader, device, missing_mod=None):
    """Evaluate model"""
    model.eval()
    preds = []
    targets = []

    with torch.no_grad():
        for batch in loader:
            word_indices = batch['word_indices'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            audio = batch['audio'].to(device)
            video = batch['video'].to(device)
            audio_lengths = batch['audio_lengths'].to(device)
            video_lengths = batch['video_lengths'].to(device)
            labels = batch['label'].to(device)

            if missing_mod == 'audio':
                audio = torch.zeros_like(audio)
                audio_lengths = torch.ones_like(audio_lengths)
            elif missing_mod == 'video':
                video = torch.zeros_like(video)
                video_lengths = torch.ones_like(video_lengths)
            elif missing_mod == 'text':
                word_indices = torch.zeros_like(word_indices)
                attention_mask = torch.zeros_like(attention_mask)

            logits = model(word_indices, audio, video,
                          audio_lengths, video_lengths,
                          attention_mask=attention_mask)
            pred = logits.argmax(dim=1)

            preds.extend(pred.cpu().numpy())
            targets.extend(labels.cpu().numpy())

    acc = accuracy_score(targets, preds)
    f1 = f1_score(targets, preds, average='macro', zero_division=0)
    return acc, f1


# Data Loading

def load_data():
    """Load TED-Humor dataset"""


    with open(f'{DATA_FOLDER}/data_folds.pkl', 'rb') as f:
        data_folds = pickle.load(f, encoding='latin1')

    with open(f'{DATA_FOLDER}/word_embeddings_list.pkl', 'rb') as f:
        word_embeddings = pickle.load(f, encoding='latin1')

    with open(f'{DATA_FOLDER}/word_embedding_indexes_sdk.pkl', 'rb') as f:
        text_features = pickle.load(f, encoding='latin1')

    with open(f'{DATA_FOLDER}/covarep_features_sdk.pkl', 'rb') as f:
        audio_features = pickle.load(f, encoding='latin1')

    with open(f'{DATA_FOLDER}/openface_features_sdk.pkl', 'rb') as f:
        video_features = pickle.load(f, encoding='latin1')

    with open(f'{DATA_FOLDER}/humor_labels_sdk.pkl', 'rb') as f:
        labels = pickle.load(f, encoding='latin1')

    # Get splits
    train_ids = None
    try:
        if isinstance(data_folds, dict) and 'train' in data_folds:
            train_ids = data_folds['train']
            dev_ids = data_folds['dev']
            test_ids = data_folds['test']
        elif isinstance(data_folds, list):
            train_ids = data_folds[0].get('train')
            dev_ids = data_folds[0].get('dev')
            test_ids = data_folds[0].get('test')
    except:
        pass

    if train_ids is None or len(train_ids) == 0:

        all_ids = list(range(len(labels)))
        np.random.shuffle(all_ids)
        split1 = int(0.7 * len(all_ids))
        split2 = int(0.85 * len(all_ids))
        train_ids = all_ids[:split1]
        dev_ids = all_ids[split1:split2]
        test_ids = all_ids[split2:]

    print(f"Train: {len(train_ids)}, Dev: {len(dev_ids)}, Test: {len(test_ids)}")

    # Scale audio/video
    train_audio_list = []
    for id in train_ids:
        try:
            audio_data = audio_features[id]
            if isinstance(audio_data, dict):
                audio_data = audio_data.get('punchline_features',
                                           audio_data.get('punchline', list(audio_data.values())[0]))
            audio_arr = np.array(audio_data, dtype=np.float32).reshape(-1, 81)
            train_audio_list.append(audio_arr)
        except:
            pass

    train_audio = np.vstack(train_audio_list) if train_audio_list else np.zeros((1, 81))
    scaler_audio = StandardScaler().fit(train_audio)

    train_video_list = []
    for id in train_ids:
        try:
            video_data = video_features[id]
            if isinstance(video_data, dict):
                video_data = video_data.get('punchline_features',
                                           video_data.get('punchline', list(video_data.values())[0]))
            video_arr = np.array(video_data, dtype=np.float32).reshape(-1, 75)
            train_video_list.append(video_arr)
        except:
            pass

    train_video = np.vstack(train_video_list) if train_video_list else np.zeros((1, 75))
    scaler_video = StandardScaler().fit(train_video)

    # Create datasets
    train_set = BertLSTMDataset(train_ids, text_features, audio_features,
                               video_features, labels, word_embeddings,
                               scaler_audio, scaler_video, is_train=True,
                               text_dropout=TEXT_DROPOUT)

    dev_set = BertLSTMDataset(dev_ids, text_features, audio_features,
                             video_features, labels, word_embeddings,
                             scaler_audio, scaler_video, is_train=False,
                             text_dropout=0.0)

    test_set = BertLSTMDataset(test_ids, text_features, audio_features,
                              video_features, labels, word_embeddings,
                              scaler_audio, scaler_video, is_train=False,
                              text_dropout=0.0)

    return train_set, dev_set, test_set, word_embeddings



def main():

    # Load data
    train_set, dev_set, test_set, _ = load_data()

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
                             shuffle=True, collate_fn=collate_fn_lstm)
    dev_loader = DataLoader(dev_set, batch_size=BATCH_SIZE,
                           collate_fn=collate_fn_lstm)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE,
                            collate_fn=collate_fn_lstm)

    loss_fn = nn.CrossEntropyLoss()

    # Phase 1: Baseline


    model_phase1 = ImprovedBertLSTMModel(audio_hidden=64, video_hidden=64,
                                         lstm_layers=2, text_hidden=384).to(DEVICE)

    param_groups = [
        {'params': model_phase1.text_encoder.parameters(), 'lr': 1e-3},
        {'params': model_phase1.audio_encoder.parameters(), 'lr': 1e-3},
        {'params': model_phase1.video_encoder.parameters(), 'lr': 1e-3},
        {'params': model_phase1.audio_attention.parameters(), 'lr': 1e-3},
        {'params': model_phase1.video_attention.parameters(), 'lr': 1e-3},
        {'params': model_phase1.text_small_proj.parameters(), 'lr': 1e-3},
        {'params': model_phase1.gated_fusion.parameters(), 'lr': 1e-3},
        {'params': model_phase1.classifier.parameters(), 'lr': 1e-3}
    ]

    optimizer_p1 = optim.AdamW(param_groups, weight_decay=0.01)
    scheduler_p1 = optim.lr_scheduler.CosineAnnealingLR(optimizer_p1,
                                                        T_max=NUM_EPOCHS,
                                                        eta_min=1e-6)

    best_f1 = 0
    for epoch in range(NUM_EPOCHS):
        loss = train_epoch(model_phase1, train_loader, optimizer_p1, loss_fn, DEVICE)
        dev_acc, dev_f1 = evaluate(model_phase1, dev_loader, DEVICE)
        scheduler_p1.step()

        if dev_f1 > best_f1:
            best_f1 = dev_f1
            torch.save(model_phase1.state_dict(),
                      f'{SAVE_FOLDER}/phase1_baseline.pt')
            status = " Saved"
        else:
            status = ""

        if (epoch + 1) % 3 == 0 or epoch == 0:
            print(f"Epoch {epoch+1:2d}/{NUM_EPOCHS} | Loss: {loss:.4f} | "
                  f"Dev Acc: {dev_acc:.4f} | F1: {dev_f1:.4f} | {status}")

    model_phase1.load_state_dict(torch.load(f'{SAVE_FOLDER}/phase1_baseline.pt'))
    test_acc_p1, test_f1_p1 = evaluate(model_phase1, test_loader, DEVICE)
    print(f"\n Phase 1 - All Modalities: Acc={test_acc_p1:.4f}, F1={test_f1_p1:.4f}\n")

    # Phase 2: Missing Dataset



    phase2_results = {}
    for missing in [None, 'audio', 'video', 'text']:
        acc, f1 = evaluate(model_phase1, test_loader, DEVICE, missing)
        name = missing.upper() if missing else 'ALL'
        phase2_results[name] = (acc, f1)
        print(f"Missing {name:6s}: Accuracy={acc:.4f}, F1={f1:.4f}")

    print()

    # Phase 3: Knowledge Distillation


    print("Training audio teacher")
    audio_teacher = SimpleTeacher(input_size=81).to(DEVICE)
    audio_opt = optim.Adam(audio_teacher.parameters(), lr=1e-3)

    for epoch in range(10):
        audio_teacher.train()
        for batch in train_loader:
            audio = batch['audio'].to(DEVICE)
            # Mean pool audio for teacher
            audio_mean = audio.mean(dim=1)  # (batch, 81)
            labels = batch['label'].to(DEVICE)
            audio_opt.zero_grad()
            logits = audio_teacher(audio_mean)
            loss = loss_fn(logits, labels)
            loss.backward()
            audio_opt.step()

    print("Training student with KD")
    model_phase3 = ImprovedBertLSTMModel(audio_hidden=64, video_hidden=64,
                                         lstm_layers=2, text_hidden=384).to(DEVICE)

    param_groups_p3 = [
        {'params': model_phase3.text_encoder.parameters(), 'lr': 1e-3},
        {'params': model_phase3.audio_encoder.parameters(), 'lr': 1e-3},
        {'params': model_phase3.video_encoder.parameters(), 'lr': 1e-3},
        {'params': model_phase3.audio_attention.parameters(), 'lr': 1e-3},
        {'params': model_phase3.video_attention.parameters(), 'lr': 1e-3},
        {'params': model_phase3.text_small_proj.parameters(), 'lr': 1e-3},
        {'params': model_phase3.gated_fusion.parameters(), 'lr': 1e-3},
        {'params': model_phase3.classifier.parameters(), 'lr': 1e-3}
    ]

    optimizer_p3 = optim.AdamW(param_groups_p3, weight_decay=0.01)
    scheduler_p3 = optim.lr_scheduler.CosineAnnealingLR(optimizer_p3,
                                                        T_max=NUM_EPOCHS,
                                                        eta_min=1e-6)

    best_f1_p3 = 0
    for epoch in range(NUM_EPOCHS):
        model_phase3.train()
        total_loss = 0

        for batch in train_loader:
            word_indices = batch['word_indices'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            audio = batch['audio'].to(DEVICE)
            video = batch['video'].to(DEVICE)
            audio_lengths = batch['audio_lengths'].to(DEVICE)
            video_lengths = batch['video_lengths'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            optimizer_p3.zero_grad()

            # Student prediction
            student_logits = model_phase3(word_indices, audio, video,
                                         audio_lengths, video_lengths,
                                         attention_mask=attention_mask)
            ce_loss = loss_fn(student_logits, labels)

            # Teacher prediction (mean pool audio)
            audio_mean = audio.mean(dim=1)
            with torch.no_grad():
                teacher_logits = audio_teacher(audio_mean)

            # KD loss
            T = 4.0
            kd_loss = 0.5 * F.kl_div(
                F.log_softmax(student_logits / T, dim=1),
                F.softmax(teacher_logits / T, dim=1),
                reduction='batchmean'
            )

            loss = ce_loss + kd_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model_phase3.parameters(), 1.0)
            optimizer_p3.step()

            total_loss += loss.item()

        dev_acc, dev_f1 = evaluate(model_phase3, dev_loader, DEVICE)
        scheduler_p3.step()

        if dev_f1 > best_f1_p3:
            best_f1_p3 = dev_f1
            torch.save(model_phase3.state_dict(),
                      f'{SAVE_FOLDER}/phase3_kd.pt')
            status = " Saved"
        else:
            status = ""

        if (epoch + 1) % 3 == 0 or epoch == 0:
            print(f"Epoch {epoch+1:2d}/{NUM_EPOCHS} | Loss: {total_loss/len(train_loader):.4f} | "
                  f"Dev Acc: {dev_acc:.4f} | F1: {dev_f1:.4f} | {status}")

    model_phase3.load_state_dict(torch.load(f'{SAVE_FOLDER}/phase3_kd.pt'))
    test_acc_p3, test_f1_p3 = evaluate(model_phase3, test_loader, DEVICE)
    print(f"\n Phase 3 - KD: Acc={test_acc_p3:.4f}, F1={test_f1_p3:.4f}\n")

    # Phase 4 : Calibration



    model_phase4 = ImprovedBertLSTMModel(audio_hidden=64, video_hidden=64,
                                         lstm_layers=2, text_hidden=384).to(DEVICE)

    param_groups_p4 = [
        {'params': model_phase4.text_encoder.parameters(), 'lr': 1e-3},
        {'params': model_phase4.audio_encoder.parameters(), 'lr': 1e-3},
        {'params': model_phase4.video_encoder.parameters(), 'lr': 1e-3},
        {'params': model_phase4.audio_attention.parameters(), 'lr': 1e-3},
        {'params': model_phase4.video_attention.parameters(), 'lr': 1e-3},
        {'params': model_phase4.text_small_proj.parameters(), 'lr': 1e-3},
        {'params': model_phase4.gated_fusion.parameters(), 'lr': 1e-3},
        {'params': model_phase4.classifier.parameters(), 'lr': 1e-3}
    ]

    optimizer_p4 = optim.AdamW(param_groups_p4, weight_decay=0.01)
    scheduler_p4 = optim.lr_scheduler.CosineAnnealingLR(optimizer_p4,
                                                        T_max=NUM_EPOCHS,
                                                        eta_min=1e-6)

    best_f1_p4 = 0
    for epoch in range(NUM_EPOCHS):
        model_phase4.train()
        total_loss = 0

        for batch in train_loader:
            word_indices = batch['word_indices'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            audio = batch['audio'].to(DEVICE)
            video = batch['video'].to(DEVICE)
            audio_lengths = batch['audio_lengths'].to(DEVICE)
            video_lengths = batch['video_lengths'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            # Random modality dropout
            if np.random.rand() < 0.15:
                which = np.random.choice(['text', 'audio', 'video'])
                if which == 'text':
                    word_indices = torch.zeros_like(word_indices)
                    attention_mask = torch.zeros_like(attention_mask)
                elif which == 'audio':
                    audio = torch.zeros_like(audio)
                    audio_lengths = torch.ones_like(audio_lengths)
                else:
                    video = torch.zeros_like(video)
                    video_lengths = torch.ones_like(video_lengths)

            optimizer_p4.zero_grad()

            logits = model_phase4(word_indices, audio, video,
                                 audio_lengths, video_lengths,
                                 attention_mask=attention_mask)
            ce_loss = loss_fn(logits, labels)

            # Calibration loss
            probs = F.softmax(logits, dim=1)
            max_prob = probs.max(dim=1)[0]
            cal_loss = torch.clamp(max_prob - 0.8, min=0).mean()

            loss = ce_loss + 0.2 * cal_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model_phase4.parameters(), 1.0)
            optimizer_p4.step()

            total_loss += loss.item()

        dev_acc, dev_f1 = evaluate(model_phase4, dev_loader, DEVICE)
        scheduler_p4.step()

        if dev_f1 > best_f1_p4:
            best_f1_p4 = dev_f1
            torch.save(model_phase4.state_dict(),
                      f'{SAVE_FOLDER}/phase4_calibration.pt')
            status = " Saved"
        else:
            status = ""

        if (epoch + 1) % 3 == 0 or epoch == 0:
            print(f"Epoch {epoch+1:2d}/{NUM_EPOCHS} | Loss: {total_loss/len(train_loader):.4f} | "
                  f"Dev Acc: {dev_acc:.4f} | F1: {dev_f1:.4f} | {status}")

    model_phase4.load_state_dict(torch.load(f'{SAVE_FOLDER}/phase4_calibration.pt'))
    test_acc_p4, test_f1_p4 = evaluate(model_phase4, test_loader, DEVICE)
    print(f"\n Phase 4 - Calibration: Acc={test_acc_p4:.4f}, F1={test_f1_p4:.4f}\n")


    print("Phase 4 Missing Data:")
    for missing in [None, 'audio', 'video', 'text']:
        acc, f1 = evaluate(model_phase4, test_loader, DEVICE, missing)
        name = missing.upper() if missing else 'ALL'
        print(f"  Missing {name:6s}: Accuracy={acc:.4f}, F1={f1:.4f}")

    # SUMMARY

    print()
    print()
    print(" Result:")
    print(f"Phase 1 (Baseline):      Acc={test_acc_p1:.4f}, F1={test_f1_p1:.4f}")
    print(f"Phase 3 (KD):            Acc={test_acc_p3:.4f}, F1={test_f1_p3:.4f}")
    print(f"Phase 4 (Calibration):   Acc={test_acc_p4:.4f}, F1={test_f1_p4:.4f}")
    print()

if __name__ == '__main__':
    main()

Using device: cuda
Train: 10598, Dev: 2626, Test: 3290
Epoch  1/20 | Loss: 0.7227 | Dev Acc: 0.5141 | F1: 0.4944 |  Saved
Epoch  3/20 | Loss: 0.7083 | Dev Acc: 0.5175 | F1: 0.4963 |  Saved
Epoch  6/20 | Loss: 0.6955 | Dev Acc: 0.5560 | F1: 0.5496 |  Saved
Epoch  9/20 | Loss: 0.6719 | Dev Acc: 0.5819 | F1: 0.5815 |  Saved
Epoch 12/20 | Loss: 0.6452 | Dev Acc: 0.5952 | F1: 0.5950 |  Saved
Epoch 15/20 | Loss: 0.6046 | Dev Acc: 0.5971 | F1: 0.5945 | 
Epoch 18/20 | Loss: 0.5694 | Dev Acc: 0.5929 | F1: 0.5876 | 

 Phase 1 - All Modalities: Acc=0.6067, F1=0.6064

Missing ALL   : Accuracy=0.6067, F1=0.6064
Missing AUDIO : Accuracy=0.5492, F1=0.5193
Missing VIDEO : Accuracy=0.6027, F1=0.6024
Missing TEXT  : Accuracy=0.5802, F1=0.5796

Training audio teacher
Training student with KD
Epoch  1/20 | Loss: 0.7272 | Dev Acc: 0.4966 | F1: 0.3485 |  Saved
Epoch  3/20 | Loss: 0.7128 | Dev Acc: 0.5061 | F1: 0.4532 | 
Epoch  6/20 | Loss: 0.7053 | Dev Acc: 0.4992 | F1: 0.4716 | 
Epoch  9/20 | Loss: 0.6945 

In [None]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight
from transformers import get_cosine_schedule_with_warmup
import warnings
warnings.filterwarnings('ignore')

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DATA_FOLDER = './ted_humor_data'
SAVE_FOLDER = './saved_models'
os.makedirs(SAVE_FOLDER, exist_ok=True)

BATCH_SIZE = 32
NUM_EPOCHS = 150
LEARNING_RATE_TEXT = 5e-5
LEARNING_RATE_LSTM = 5e-5
LEARNING_RATE_MLP = 2e-4

print("=" * 80)
print("AUTO-ENSEMBLE: Training 3 models + ensemble")
print("=" * 80)


# Load embeddings once
print("Loading embeddings")
with open(f'{DATA_FOLDER}/word_embedding_list.pkl', 'rb') as f:
    word_embeddings_list = pickle.load(f, encoding='latin1')
word_embeddings_array = np.array(word_embeddings_list, dtype=np.float32)
embedding_dim = word_embeddings_array.shape[1]

# Model architecture
class SimpleTextEncoder(nn.Module):
    def __init__(self, embedding_tensor):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(embedding_tensor, freeze=False)
        embed_dim = embedding_tensor.shape[1]
        self.lstm = nn.LSTM(embed_dim, 128, num_layers=1, batch_first=True, bidirectional=True, dropout=0.0)
        self.projection = nn.Linear(128, 256)
        self.dropout = nn.Dropout(0.2)

    def forward(self, word_indices):
        if word_indices.shape[1] == 0:
            return torch.zeros(word_indices.shape[0], 256).to(word_indices.device)
        embeds = self.embedding(word_indices)
        lstm_out, (h, c) = self.lstm(embeds)
        out = h[-1]
        out = self.projection(out)
        out = self.dropout(out)
        return out

class SimpleAudioEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(81, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 256),
        )

    def forward(self, audio_raw):
        audio_pooled = audio_raw.mean(dim=1)
        return self.net(audio_pooled)

class SimpleVideoEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(75, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 256),
        )

    def forward(self, video_raw):
        video_pooled = video_raw.mean(dim=1)
        return self.net(video_pooled)

class SimpleFusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(256 + 256 + 256, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

    def forward(self, text, audio, video):
        combined = torch.cat([text, audio, video], dim=1)
        return self.net(combined)

class SimpleMultimodalModel(nn.Module):
    def __init__(self, embedding_tensor):
        super().__init__()
        self.text_encoder = SimpleTextEncoder(embedding_tensor)
        self.audio_encoder = SimpleAudioEncoder()
        self.video_encoder = SimpleVideoEncoder()
        self.fusion = SimpleFusion()
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 2)
        )

    def forward(self, word_indices, audio_raw, video_raw):
        text_feat = self.text_encoder(word_indices)
        audio_feat = self.audio_encoder(audio_raw)
        video_feat = self.video_encoder(video_raw)
        fused = self.fusion(text_feat, audio_feat, video_feat)
        logits = self.classifier(fused)
        return logits

# Dataset
class SimpleDataset(Dataset):
    def __init__(self, ids, text_data, audio_data, video_data, labels, audio_scaler, video_scaler):
        self.ids = ids
        self.text_data = text_data
        self.audio_data = audio_data
        self.video_data = video_data
        self.labels = labels
        self.audio_scaler = audio_scaler
        self.video_scaler = video_scaler

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        sample_id = self.ids[idx]

        text_data = self.text_data[sample_id]
        if isinstance(text_data, dict):
            text_indices = text_data.get('punchline_features', text_data.get('punchline', list(text_data.values())[0]))
        else:
            text_indices = text_data

        word_indices = []
        if isinstance(text_indices, (list, np.ndarray)):
            for idx_val in text_indices:
                try:
                    idx_val = int(idx_val[0]) if isinstance(idx_val, (list, np.ndarray)) and len(idx_val) > 0 else int(idx_val)
                    if 0 <= idx_val < len(word_embeddings_list):
                        word_indices.append(idx_val)
                except:
                    pass

        if len(word_indices) == 0:
            word_indices = [0]
        word_indices = word_indices[:512]

        audio_data = self.audio_data[sample_id]
        if isinstance(audio_data, dict):
            audio_features = audio_data.get('punchline_features', audio_data.get('punchline', list(audio_data.values())[0]))
        else:
            audio_features = audio_data

        try:
            audio_raw = np.array(audio_features, dtype=np.float32).reshape(-1, 81)
            if audio_raw.shape[0] == 0:
                audio_raw = np.zeros((1, 81), dtype=np.float32)
            if self.audio_scaler:
                audio_raw = self.audio_scaler.transform(audio_raw)
        except:
            audio_raw = np.zeros((1, 81), dtype=np.float32)

        video_data = self.video_data[sample_id]
        if isinstance(video_data, dict):
            video_features = video_data.get('punchline_features', video_data.get('punchline', list(video_data.values())[0]))
        else:
            video_features = video_data

        try:
            video_raw = np.array(video_features, dtype=np.float32).reshape(-1, 75)
            if video_raw.shape[0] == 0:
                video_raw = np.zeros((1, 75), dtype=np.float32)
            if self.video_scaler:
                video_raw = self.video_scaler.transform(video_raw)
        except:
            video_raw = np.zeros((1, 75), dtype=np.float32)

        return {
            'word_indices': torch.LongTensor(word_indices),
            'audio_raw': torch.FloatTensor(audio_raw),
            'video_raw': torch.FloatTensor(video_raw),
            'label': torch.tensor(self.labels[sample_id], dtype=torch.long)
        }

def collate_fn(batch):
    max_text_len = max(len(item['word_indices']) for item in batch)
    word_indices_padded = []
    for item in batch:
        indices = item['word_indices']
        padded = np.pad(indices.numpy(), (0, max_text_len - len(indices)), constant_values=0)
        word_indices_padded.append(torch.LongTensor(padded))

    max_audio_len = max(item['audio_raw'].shape[0] for item in batch)
    audio_raw_padded = []
    for item in batch:
        audio_raw = item['audio_raw'].numpy()
        if audio_raw.shape[0] < max_audio_len:
            padded = np.pad(audio_raw, ((0, max_audio_len - audio_raw.shape[0]), (0, 0)))
        else:
            padded = audio_raw
        audio_raw_padded.append(torch.FloatTensor(padded))

    max_video_len = max(item['video_raw'].shape[0] for item in batch)
    video_raw_padded = []
    for item in batch:
        video_raw = item['video_raw'].numpy()
        if video_raw.shape[0] < max_video_len:
            padded = np.pad(video_raw, ((0, max_video_len - video_raw.shape[0]), (0, 0)))
        else:
            padded = video_raw
        video_raw_padded.append(torch.FloatTensor(padded))

    return {
        'word_indices': torch.stack(word_indices_padded),
        'audio_raw': torch.stack(audio_raw_padded),
        'video_raw': torch.stack(video_raw_padded),
        'label': torch.stack([item['label'] for item in batch])
    }

def train_epoch(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for batch in loader:
        word_indices = batch['word_indices'].to(device)
        audio_raw = batch['audio_raw'].to(device)
        video_raw = batch['video_raw'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(word_indices, audio_raw, video_raw)
        loss = loss_fn(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, device):
    model.eval()
    preds = []
    targets = []
    with torch.no_grad():
        for batch in loader:
            word_indices = batch['word_indices'].to(device)
            audio_raw = batch['audio_raw'].to(device)
            video_raw = batch['video_raw'].to(device)
            labels = batch['label'].to(device)

            logits = model(word_indices, audio_raw, video_raw)
            pred = logits.argmax(dim=1)

            preds.extend(pred.cpu().numpy())
            targets.extend(labels.cpu().numpy())

    acc = accuracy_score(targets, preds)
    f1 = f1_score(targets, preds, average='macro', zero_division=0)
    return acc, f1

# Load data once
print("Loading data")
with open(f'{DATA_FOLDER}/data_folds.pkl', 'rb') as f:
    data_folds = pickle.load(f, encoding='latin1')
with open(f'{DATA_FOLDER}/word_embedding_indexes_sdk.pkl', 'rb') as f:
    text_features = pickle.load(f, encoding='latin1')
with open(f'{DATA_FOLDER}/covarep_features_sdk.pkl', 'rb') as f:
    audio_features = pickle.load(f, encoding='latin1')
with open(f'{DATA_FOLDER}/openface_features_sdk.pkl', 'rb') as f:
    video_features = pickle.load(f, encoding='latin1')
with open(f'{DATA_FOLDER}/humor_label_sdk.pkl', 'rb') as f:
    labels = pickle.load(f, encoding='latin1')

train_ids = data_folds['train']
dev_ids = data_folds['dev']
test_ids = data_folds['test']

# Create scalers
train_audio_list = []
for id in train_ids:
    try:
        audio_data = audio_features.get(id, {})
        if isinstance(audio_data, dict):
            audio_feat = audio_data.get('punchline_features', audio_data.get('punchline', list(audio_data.values())[0] if audio_data else None))
        else:
            audio_feat = audio_data
        if audio_feat is not None:
            arr = np.array(audio_feat, dtype=np.float32).reshape(-1, 81)
            if arr.shape[0] > 0:
                train_audio_list.append(arr)
    except:
        pass

scaler_audio = StandardScaler().fit(np.vstack(train_audio_list) if train_audio_list else np.zeros((1, 81)))

train_video_list = []
for id in train_ids:
    try:
        video_data = video_features.get(id, {})
        if isinstance(video_data, dict):
            video_feat = video_data.get('punchline_features', video_data.get('punchline', list(video_data.values())[0] if video_data else None))
        else:
            video_feat = video_data
        if video_feat is not None:
            arr = np.array(video_feat, dtype=np.float32).reshape(-1, 75)
            if arr.shape[0] > 0:
                train_video_list.append(arr)
    except:
        pass

scaler_video = StandardScaler().fit(np.vstack(train_video_list) if train_video_list else np.zeros((1, 75)))

# Train 3 models
seeds = [42, 43, 44]
trained_models = []

for seed_idx, seed in enumerate(seeds):
    print("\n" + "=" * 80)
    print(f"TRAINING MODEL {seed_idx + 1}/3 (SEED={seed})")
    print("=" * 80)

    # Set random seeds
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Create datasets
    train_set = SimpleDataset(train_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video)
    dev_set = SimpleDataset(dev_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video)
    test_set = SimpleDataset(test_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video)

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=0)
    dev_loader = DataLoader(dev_set, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)

    # Loss with class weighting
    train_labels = np.array([labels[id] for id in train_ids])
    class_weights = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)
    class_weights = torch.FloatTensor(class_weights).to(DEVICE)
    loss_fn = nn.CrossEntropyLoss(weight=class_weights)

    # Model
    embedding_tensor = torch.FloatTensor(word_embeddings_array).to(DEVICE)
    model = SimpleMultimodalModel(embedding_tensor).to(DEVICE)

    param_groups = [
        {'params': model.text_encoder.parameters(), 'lr': LEARNING_RATE_TEXT},
        {'params': model.audio_encoder.parameters(), 'lr': LEARNING_RATE_LSTM},
        {'params': model.video_encoder.parameters(), 'lr': LEARNING_RATE_LSTM},
        {'params': model.fusion.parameters(), 'lr': LEARNING_RATE_LSTM},
        {'params': model.classifier.parameters(), 'lr': LEARNING_RATE_MLP}
    ]

    optimizer = optim.AdamW(param_groups, weight_decay=1e-5)
    total_steps = NUM_EPOCHS * len(train_loader)
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=500, num_training_steps=total_steps, num_cycles=0.5)

    best_f1 = 0
    patience = 20
    patience_counter = 0

    for epoch in range(NUM_EPOCHS):
        loss = train_epoch(model, train_loader, optimizer, loss_fn, DEVICE)
        dev_acc, dev_f1 = evaluate(model, dev_loader, DEVICE)
        scheduler.step()

        if dev_f1 > best_f1:
            best_f1 = dev_f1
            patience_counter = 0
            torch.save(model.state_dict(), f'{SAVE_FOLDER}/simple_best_seed{seed}.pt')
        else:
            patience_counter += 1

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1:3d} | Loss: {loss:.4f} | Dev F1: {dev_f1:.4f}")

        if patience_counter >= patience:
            print(f"Early stopped at epoch {epoch+1}")
            break

    # Evaluate on test
    model.load_state_dict(torch.load(f'{SAVE_FOLDER}/simple_best_seed{seed}.pt'))
    test_acc, test_f1 = evaluate(model, test_loader, DEVICE)
    print(f"\n Model {seed_idx + 1} test accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")

    trained_models.append((model, test_acc, test_f1))

# Ensemble evaluation
print("\n" + "=" * 80)
print("ENSEMBLE EVALUATION")
print("=" * 80)

test_set = SimpleDataset(test_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)

all_preds_list = []
all_labels_list = []

with torch.no_grad():
    for batch in test_loader:
        word_indices = batch['word_indices'].to(DEVICE)
        audio_raw = batch['audio_raw'].to(DEVICE)
        video_raw = batch['video_raw'].to(DEVICE)
        labels_batch = batch['label'].to(DEVICE)

        logits_list = []
        for model, _, _ in trained_models:
            model.eval()
            logits = model(word_indices, audio_raw, video_raw)
            logits_list.append(logits)

        avg_logits = sum(logits_list) / len(logits_list)
        preds = avg_logits.argmax(dim=1)

        all_preds_list.append(preds)
        all_labels_list.append(labels_batch)

all_preds = torch.cat(all_preds_list).cpu().numpy()
all_labels = torch.cat(all_labels_list).cpu().numpy()

acc = accuracy_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds, average='macro')
prec = precision_score(all_labels, all_preds, average='macro')
rec = recall_score(all_labels, all_preds, average='macro')

print("\n" + "=" * 80)
print("ENSEMBLE RESULTS")
print("=" * 80)
print(f"\nIndividual models:")
for i, (_, test_acc, test_f1) in enumerate(trained_models, 1):
    print(f"  Model {i}: Accuracy={test_acc:.4f}, F1={test_f1:.4f}")

print(f"\n ENSEMBLE RESULTS:")
print(f"  Accuracy:  {acc:.4f} ({acc*100:.2f}%)")
print(f"  F1 Score:  {f1:.4f}")
print(f"  Precision: {prec:.4f}")
print(f"  Recall:    {rec:.4f}")



print("\n" + "=" * 80)
print("Overall")
print("=" * 80)
print("\n1. 3 models trained: simple_best_seed42.pt, simple_best_seed43.pt, simple_best_seed44.pt")
print("2.  Ensemble accuracy:", f"{acc*100:.2f}%")



In [None]:
Loading embeddings...
Loading data...

================================================================================
TRAINING MODEL 1/3 (SEED=42)
================================================================================
Epoch  10 | Loss: 0.6934 | Dev F1: 0.3333
Epoch  20 | Loss: 0.6909 | Dev F1: 0.5174
Epoch  30 | Loss: 0.6675 | Dev F1: 0.5826
Epoch  40 | Loss: 0.6289 | Dev F1: 0.6226
Epoch  50 | Loss: 0.6020 | Dev F1: 0.6529
Epoch  60 | Loss: 0.5846 | Dev F1: 0.6647
Epoch  70 | Loss: 0.5656 | Dev F1: 0.6716
Epoch  80 | Loss: 0.5412 | Dev F1: 0.6795
Epoch  90 | Loss: 0.5175 | Dev F1: 0.6799
Epoch 100 | Loss: 0.4884 | Dev F1: 0.6742
Early stopped at epoch 104

 Model 1 test accuracy: 0.6805 (68.05%)

================================================================================
TRAINING MODEL 2/3 (SEED=43)
================================================================================
Epoch  10 | Loss: 0.6935 | Dev F1: 0.3333
Epoch  20 | Loss: 0.6905 | Dev F1: 0.5072
Epoch  30 | Loss: 0.6697 | Dev F1: 0.5650
Epoch  40 | Loss: 0.6357 | Dev F1: 0.6186
Epoch  50 | Loss: 0.6039 | Dev F1: 0.6412
Epoch  60 | Loss: 0.5860 | Dev F1: 0.6603
Epoch  70 | Loss: 0.5683 | Dev F1: 0.6691
Epoch  80 | Loss: 0.5458 | Dev F1: 0.6781
Epoch  90 | Loss: 0.5177 | Dev F1: 0.6792
Epoch 100 | Loss: 0.4841 | Dev F1: 0.6795
Epoch 110 | Loss: 0.4510 | Dev F1: 0.6662
Early stopped at epoch 117

 Model 2 test accuracy: 0.6900 (69.00%)

================================================================================
TRAINING MODEL 3/3 (SEED=44)
================================================================================
Epoch  10 | Loss: 0.6946 | Dev F1: 0.3333
Epoch  20 | Loss: 0.6920 | Dev F1: 0.3535
Epoch  30 | Loss: 0.6734 | Dev F1: 0.5614
Epoch  40 | Loss: 0.6384 | Dev F1: 0.6069
Epoch  50 | Loss: 0.6078 | Dev F1: 0.6395
Epoch  60 | Loss: 0.5880 | Dev F1: 0.6596
Epoch  70 | Loss: 0.5685 | Dev F1: 0.6713
Epoch  80 | Loss: 0.5460 | Dev F1: 0.6773
Epoch  90 | Loss: 0.5172 | Dev F1: 0.6728
Epoch 100 | Loss: 0.4860 | Dev F1: 0.6607
Epoch 110 | Loss: 0.4473 | Dev F1: 0.6568
Early stopped at epoch 111

 Model 3 test accuracy: 0.6836 (68.36%)

================================================================================
ENSEMBLE EVALUATION
================================================================================

================================================================================
ENSEMBLE RESULTS
================================================================================

Individual models:
  Model 1: Accuracy=0.6805, F1=0.6785
  Model 2: Accuracy=0.6900, F1=0.6899
  Model 3: Accuracy=0.6836, F1=0.6833

ENSEMBLE RESULTS:
  Accuracy:  0.6988 (69.88%)
  F1 Score:  0.6984
  Precision: 0.7002
  Recall:    0.6990



================================================================================
Overall:
================================================================================

1.  3 models trained: simple_best_seed42.pt, simple_best_seed43.pt, simple_best_seed44.pt
2.  Ensemble accuracy: 69.88%


In [2]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight
from transformers import get_cosine_schedule_with_warmup
import warnings
warnings.filterwarnings('ignore')

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DATA_FOLDER = './ted_humor_data'
SAVE_FOLDER = './saved_models_phase'
os.makedirs(SAVE_FOLDER, exist_ok=True)

BATCH_SIZE = 32
NUM_EPOCHS = 100

print("=" * 80)
print("ROBUST MULTIMODAL FUSION WITH MISSING MODALITIES")

print("=" * 80)
print(f"Device: {DEVICE}\n")

# Load all data
print("Loading data...")
with open(f'{DATA_FOLDER}/data_folds.pkl', 'rb') as f:
    data_folds = pickle.load(f, encoding='latin1')
with open(f'{DATA_FOLDER}/word_embedding_indexes_sdk.pkl', 'rb') as f:
    text_features = pickle.load(f, encoding='latin1')
with open(f'{DATA_FOLDER}/covarep_features_sdk.pkl', 'rb') as f:
    audio_features = pickle.load(f, encoding='latin1')
with open(f'{DATA_FOLDER}/openface_features_sdk.pkl', 'rb') as f:
    video_features = pickle.load(f, encoding='latin1')
with open(f'{DATA_FOLDER}/humor_label_sdk.pkl', 'rb') as f:
    labels = pickle.load(f, encoding='latin1')
with open(f'{DATA_FOLDER}/word_embedding_list.pkl', 'rb') as f:
    word_embeddings_list = pickle.load(f, encoding='latin1')

train_ids = data_folds['train']
dev_ids = data_folds['dev']
test_ids = data_folds['test']

word_embeddings_array = np.array(word_embeddings_list, dtype=np.float32)

# Create scalers
train_audio_list = []
for id in train_ids:
    try:
        audio_data = audio_features.get(id, {})
        if isinstance(audio_data, dict):
            audio_feat = audio_data.get('punchline_features', audio_data.get('punchline', list(audio_data.values())[0] if audio_data else None))
        else:
            audio_feat = audio_data
        if audio_feat is not None:
            arr = np.array(audio_feat, dtype=np.float32).reshape(-1, 81)
            if arr.shape[0] > 0:
                train_audio_list.append(arr)
    except:
        pass

scaler_audio = StandardScaler().fit(np.vstack(train_audio_list) if train_audio_list else np.zeros((1, 81)))

train_video_list = []
for id in train_ids:
    try:
        video_data = video_features.get(id, {})
        if isinstance(video_data, dict):
            video_feat = video_data.get('punchline_features', video_data.get('punchline', list(video_data.values())[0] if video_data else None))
        else:
            video_feat = video_data
        if video_feat is not None:
            arr = np.array(video_feat, dtype=np.float32).reshape(-1, 75)
            if arr.shape[0] > 0:
                train_video_list.append(arr)
    except:
        pass

scaler_video = StandardScaler().fit(np.vstack(train_video_list) if train_video_list else np.zeros((1, 75)))

# Cross-Modal Attention Mechanism


class CrossModalAttention(nn.Module):

    def __init__(self, dim=256):
        super().__init__()
        self.dim = dim

        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, text_feat, modality_feat):
        q = self.query(text_feat)
        k = self.key(modality_feat)
        v = self.value(modality_feat)

        scores = torch.matmul(q.unsqueeze(1), k.unsqueeze(2)) / (self.dim ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)

        attn_output = torch.matmul(attn_weights, v.unsqueeze(1)).squeeze(1)
        output = self.out_proj(attn_output)
        output = self.dropout(output)

        return output

# Improved Multimodal Fusion

class ImprovedMultimodalFusion(nn.Module):

    def __init__(self, word_embeddings_array):
        super().__init__()

        # Text: Embedding + Bidirectional LSTM
        embedding_tensor = torch.FloatTensor(word_embeddings_array)
        self.text_embedding = nn.Embedding.from_pretrained(embedding_tensor, freeze=False)
        self.text_lstm = nn.LSTM(300, 128, num_layers=1, batch_first=True, bidirectional=True)

        # Audio: Bidirectional LSTM on raw 81-dim frames
        self.audio_lstm = nn.LSTM(81, 128, num_layers=1, batch_first=True, bidirectional=True)

        # Video: Bidirectional LSTM on raw 75-dim frames
        self.video_lstm = nn.LSTM(75, 128, num_layers=1, batch_first=True, bidirectional=True)

        # Cross-modal attention mechanisms
        self.audio_attention = CrossModalAttention(dim=256)
        self.video_attention = CrossModalAttention(dim=256)

        # Fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 2)
        )

    def extract_final_hidden_state(self, lstm_output):
        """Extract final hidden state from bidirectional LSTM"""
        forward_final = lstm_output[:, -1, :128]
        backward_final = lstm_output[:, 0, 128:]
        final_state = torch.cat([forward_final, backward_final], dim=1)
        return final_state

    def forward(self, word_indices, audio_raw, video_raw, missing_mask=None):
        # Text
        text_embeds = self.text_embedding(word_indices)
        text_lstm_output, _ = self.text_lstm(text_embeds)
        text_feat = self.extract_final_hidden_state(text_lstm_output)

        # Audio
        audio_lstm_output, _ = self.audio_lstm(audio_raw)
        audio_feat = self.extract_final_hidden_state(audio_lstm_output)

        # Video
        video_lstm_output, _ = self.video_lstm(video_raw)
        video_feat = self.extract_final_hidden_state(video_lstm_output)

        # Apply missing masks
        if missing_mask is not None:
            if missing_mask.get('text', False):
                text_feat = torch.zeros_like(text_feat)
            if missing_mask.get('audio', False):
                audio_feat = torch.zeros_like(audio_feat)
            if missing_mask.get('video', False):
                video_feat = torch.zeros_like(video_feat)

        # Cross-modal attention: Align audio/video to text
        audio_aligned = self.audio_attention(text_feat, audio_feat)
        video_aligned = self.video_attention(text_feat, video_feat)

        # Fusion: Concatenate aligned features
        fused_input = torch.cat([text_feat, audio_aligned, video_aligned], dim=1)
        fused = self.fusion(fused_input)
        logits = self.classifier(fused)

        return logits


# Dataset & Collate

class SimpleDataset(Dataset):
    def __init__(self, ids, text_data, audio_data, video_data, labels, audio_scaler, video_scaler, word_embeddings_list, modality_drop_rate=0.0):
        self.ids = ids
        self.text_data = text_data
        self.audio_data = audio_data
        self.video_data = video_data
        self.labels = labels
        self.audio_scaler = audio_scaler
        self.video_scaler = video_scaler
        self.word_embeddings_list = word_embeddings_list
        self.modality_drop_rate = modality_drop_rate

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        sample_id = self.ids[idx]

        # Text
        text_data = self.text_data[sample_id]
        if isinstance(text_data, dict):
            text_indices = text_data.get('punchline_features', text_data.get('punchline', list(text_data.values())[0]))
        else:
            text_indices = text_data

        word_indices = []
        if isinstance(text_indices, (list, np.ndarray)):
            for idx_val in text_indices:
                try:
                    idx_val = int(idx_val[0]) if isinstance(idx_val, (list, np.ndarray)) and len(idx_val) > 0 else int(idx_val)
                    if 0 <= idx_val < len(self.word_embeddings_list):
                        word_indices.append(idx_val)
                except:
                    pass

        if len(word_indices) == 0:
            word_indices = [0]
        word_indices = word_indices[:512]

        # Audio
        audio_data = self.audio_data[sample_id]
        if isinstance(audio_data, dict):
            audio_features = audio_data.get('punchline_features', audio_data.get('punchline', list(audio_data.values())[0]))
        else:
            audio_features = audio_data

        try:
            audio_raw = np.array(audio_features, dtype=np.float32).reshape(-1, 81)
            if audio_raw.shape[0] == 0:
                audio_raw = np.zeros((1, 81), dtype=np.float32)
            if self.audio_scaler:
                audio_raw = self.audio_scaler.transform(audio_raw)
        except:
            audio_raw = np.zeros((1, 81), dtype=np.float32)

        # Video
        video_data = self.video_data[sample_id]
        if isinstance(video_data, dict):
            video_features = video_data.get('punchline_features', video_data.get('punchline', list(video_data.values())[0]))
        else:
            video_features = video_data

        try:
            video_raw = np.array(video_features, dtype=np.float32).reshape(-1, 75)
            if video_raw.shape[0] == 0:
                video_raw = np.zeros((1, 75), dtype=np.float32)
            if self.video_scaler:
                video_raw = self.video_scaler.transform(video_raw)
        except:
            video_raw = np.zeros((1, 75), dtype=np.float32)

        missing_mask = {
            'text': np.random.rand() < self.modality_drop_rate,
            'audio': np.random.rand() < self.modality_drop_rate,
            'video': np.random.rand() < self.modality_drop_rate
        }

        return {
            'word_indices': torch.LongTensor(word_indices),
            'audio_raw': torch.FloatTensor(audio_raw),
            'video_raw': torch.FloatTensor(video_raw),
            'label': torch.tensor(self.labels[sample_id], dtype=torch.long),
            'missing_mask': missing_mask
        }

def collate_fn(batch):
    max_text_len = max(len(item['word_indices']) for item in batch)
    word_indices_padded = [torch.LongTensor(np.pad(item['word_indices'].numpy(), (0, max_text_len - len(item['word_indices'])), constant_values=0)) for item in batch]

    max_audio_len = max(item['audio_raw'].shape[0] for item in batch)
    audio_raw_padded = []
    for item in batch:
        audio_raw = item['audio_raw'].numpy()
        if audio_raw.shape[0] < max_audio_len:
            padded = np.pad(audio_raw, ((0, max_audio_len - audio_raw.shape[0]), (0, 0)))
        else:
            padded = audio_raw
        audio_raw_padded.append(torch.FloatTensor(padded))

    max_video_len = max(item['video_raw'].shape[0] for item in batch)
    video_raw_padded = []
    for item in batch:
        video_raw = item['video_raw'].numpy()
        if video_raw.shape[0] < max_video_len:
            padded = np.pad(video_raw, ((0, max_video_len - video_raw.shape[0]), (0, 0)))
        else:
            padded = video_raw
        video_raw_padded.append(torch.FloatTensor(padded))

    return {
        'word_indices': torch.stack(word_indices_padded),
        'audio_raw': torch.stack(audio_raw_padded),
        'video_raw': torch.stack(video_raw_padded),
        'label': torch.stack([item['label'] for item in batch]),
        'missing_mask': [item['missing_mask'] for item in batch]
    }

# Training and Evaluation

def train_epoch(model, loader, optimizer, loss_fn, device, use_kd=False, teacher=None, temperature=1.0, lambda_kd=0.5):
    model.train()
    total_loss = 0
    for batch in loader:
        word_indices = batch['word_indices'].to(device)
        audio_raw = batch['audio_raw'].to(device)
        video_raw = batch['video_raw'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(word_indices, audio_raw, video_raw, missing_mask=None)
        loss_ce = loss_fn(logits, labels)

        if use_kd and teacher is not None:
            teacher.eval()
            with torch.no_grad():
                teacher_logits = teacher(word_indices, audio_raw, video_raw, missing_mask=None)

            student_probs = F.log_softmax(logits / temperature, dim=1)
            teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
            loss_kd = F.kl_div(student_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)

            loss = loss_ce + lambda_kd * loss_kd
        else:
            loss = loss_ce

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, device, compute_ece=False):
    model.eval()
    preds = []
    targets = []
    confidences = []

    with torch.no_grad():
        for batch in loader:
            word_indices = batch['word_indices'].to(device)
            audio_raw = batch['audio_raw'].to(device)
            video_raw = batch['video_raw'].to(device)
            labels = batch['label'].to(device)

            logits = model(word_indices, audio_raw, video_raw, missing_mask=None)
            probs = F.softmax(logits, dim=1)
            pred = logits.argmax(dim=1)
            conf = probs.max(dim=1)[0]

            preds.extend(pred.cpu().numpy())
            targets.extend(labels.cpu().numpy())
            confidences.extend(conf.cpu().numpy())

    acc = accuracy_score(targets, preds)
    f1 = f1_score(targets, preds, average='macro', zero_division=0)
    prec = precision_score(targets, preds, average='macro', zero_division=0)
    rec = recall_score(targets, preds, average='macro', zero_division=0)

    ece = 0.0
    if compute_ece:
        confidences = np.array(confidences)
        n_bins = 10
        bin_edges = np.linspace(0, 1, n_bins + 1)
        preds_array = np.array(preds)
        targets_array = np.array(targets)

        for i in range(n_bins):
            mask = (confidences >= bin_edges[i]) & (confidences < bin_edges[i+1])
            if mask.sum() > 0:
                bin_acc = (preds_array[mask] == targets_array[mask]).mean()
                bin_conf = confidences[mask].mean()
                ece += mask.sum() / len(preds) * abs(bin_acc - bin_conf)

    return acc, f1, prec, rec, ece

print("=" * 80)
print("PHASE 1: Working Baseline")
print("=" * 80)
print("\n Enhanced Architecture:")


baseline_results = {}

# Train unimodal models for comparison
class TextOnlyModel(nn.Module):
    def __init__(self, word_embeddings_array):
        super().__init__()
        embedding_tensor = torch.FloatTensor(word_embeddings_array)
        self.text_embedding = nn.Embedding.from_pretrained(embedding_tensor, freeze=False)
        self.text_lstm = nn.LSTM(300, 128, num_layers=1, batch_first=True, bidirectional=True)
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 2)
        )

    def forward(self, word_indices, audio_raw=None, video_raw=None, missing_mask=None):
        text_embeds = self.text_embedding(word_indices)
        text_lstm_output, _ = self.text_lstm(text_embeds)
        forward_final = text_lstm_output[:, -1, :128]
        backward_final = text_lstm_output[:, 0, 128:]
        text_feat = torch.cat([forward_final, backward_final], dim=1)
        return self.classifier(text_feat)

class AudioOnlyModel(nn.Module):
    def __init__(self, word_embeddings_array):
        super().__init__()
        self.audio_lstm = nn.LSTM(81, 128, num_layers=1, batch_first=True, bidirectional=True)
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 2)
        )

    def forward(self, word_indices=None, audio_raw=None, video_raw=None, missing_mask=None):
        audio_lstm_output, _ = self.audio_lstm(audio_raw)
        forward_final = audio_lstm_output[:, -1, :128]
        backward_final = audio_lstm_output[:, 0, 128:]
        audio_feat = torch.cat([forward_final, backward_final], dim=1)
        return self.classifier(audio_feat)

class VideoOnlyModel(nn.Module):
    def __init__(self, word_embeddings_array):
        super().__init__()
        self.video_lstm = nn.LSTM(75, 128, num_layers=1, batch_first=True, bidirectional=True)
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 2)
        )

    def forward(self, word_indices=None, audio_raw=None, video_raw=None, missing_mask=None):
        video_lstm_output, _ = self.video_lstm(video_raw)
        forward_final = video_lstm_output[:, -1, :128]
        backward_final = video_lstm_output[:, 0, 128:]
        video_feat = torch.cat([forward_final, backward_final], dim=1)
        return self.classifier(video_feat)

teachers = {}

for teacher_name, ModelClass in [('TEXT_teacher', TextOnlyModel), ('AUDIO_teacher', AudioOnlyModel), ('VIDEO_teacher', VideoOnlyModel)]:
    print(f"\nTraining {teacher_name}")

    np.random.seed(42)
    torch.manual_seed(42)

    train_set = SimpleDataset(train_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, 0.0)
    dev_set = SimpleDataset(dev_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, 0.0)
    test_set = SimpleDataset(test_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, 0.0)

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=0)
    dev_loader = DataLoader(dev_set, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)

    train_labels_array = np.array([labels[id] for id in train_ids])
    class_weights = compute_class_weight('balanced', classes=np.unique(train_labels_array), y=train_labels_array)
    class_weights = torch.FloatTensor(class_weights).to(DEVICE)
    loss_fn = nn.CrossEntropyLoss(weight=class_weights)

    model = ModelClass(word_embeddings_array).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=500, num_training_steps=NUM_EPOCHS * len(train_loader), num_cycles=0.5)

    best_f1 = 0
    patience = 15
    patience_counter = 0

    for epoch in range(NUM_EPOCHS):
        loss = train_epoch(model, train_loader, optimizer, loss_fn, DEVICE)
        dev_acc, dev_f1, _, _ = evaluate(model, dev_loader, DEVICE)[:4]
        scheduler.step()

        if dev_f1 > best_f1:
            best_f1 = dev_f1
            patience_counter = 0
            torch.save(model.state_dict(), f'{SAVE_FOLDER}/{teacher_name}.pt')
        else:
            patience_counter += 1

        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1:3d} | Loss: {loss:.4f} | Dev F1: {dev_f1:.4f}")

        if patience_counter >= patience:
            print(f"Early stopped at epoch {epoch+1}")
            break

    model.load_state_dict(torch.load(f'{SAVE_FOLDER}/{teacher_name}.pt'))
    test_acc, test_f1, test_prec, test_rec, _ = evaluate(model, test_loader, DEVICE)

    teachers[teacher_name] = model
    baseline_results[teacher_name] = {'accuracy': test_acc, 'f1': test_f1, 'precision': test_prec, 'recall': test_rec}
    print(f"  {teacher_name}: {test_acc*100:.2f}%")

print(f"\nTraining multimodal baseline (with cross-modal attention)")

train_set = SimpleDataset(train_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, 0.0)
dev_set = SimpleDataset(dev_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, 0.0)
test_set = SimpleDataset(test_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, 0.0)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=0)
dev_loader = DataLoader(dev_set, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)

baseline_model = ImprovedMultimodalFusion(word_embeddings_array).to(DEVICE)
optimizer = optim.AdamW(baseline_model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=500, num_training_steps=NUM_EPOCHS * len(train_loader), num_cycles=0.5)

best_f1 = 0
patience = 15
patience_counter = 0

for epoch in range(NUM_EPOCHS):
    loss = train_epoch(baseline_model, train_loader, optimizer, loss_fn, DEVICE)
    dev_acc, dev_f1, _, _ = evaluate(baseline_model, dev_loader, DEVICE)[:4]
    scheduler.step()

    if dev_f1 > best_f1:
        best_f1 = dev_f1
        patience_counter = 0
        torch.save(baseline_model.state_dict(), f'{SAVE_FOLDER}/baseline_multimodal.pt')
    else:
        patience_counter += 1

    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1:3d} | Loss: {loss:.4f} | Dev F1: {dev_f1:.4f}")

    if patience_counter >= patience:
        print(f"Early stopped at epoch {epoch+1}")
        break

baseline_model.load_state_dict(torch.load(f'{SAVE_FOLDER}/baseline_multimodal.pt'))
test_acc, test_f1, test_prec, test_rec, ece = evaluate(baseline_model, test_loader, DEVICE, compute_ece=True)

baseline_results['Multimodal_Baseline'] = {'accuracy': test_acc, 'f1': test_f1, 'precision': test_prec, 'recall': test_rec, 'ece': ece}

print(f"✅ Multimodal Baseline: {test_acc*100:.2f}%\n")

print("=" * 80)
print("PHASE 2: Stress Test")
print("=" * 80)

print("\nEvaluating baseline under missing modalities")

test_set_no_drop = SimpleDataset(test_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, 0.0)
test_loader_no_drop = DataLoader(test_set_no_drop, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)

missingness_patterns = [
    {'text': False, 'audio': False, 'video': False, 'name': 'All Present'},
    {'text': True, 'audio': False, 'video': False, 'name': 'Text Missing'},
    {'text': False, 'audio': True, 'video': False, 'name': 'Audio Missing'},
    {'text': False, 'audio': False, 'video': True, 'name': 'Video Missing'},
    {'text': True, 'audio': True, 'video': False, 'name': 'Text+Audio Missing'},
    {'text': True, 'audio': False, 'video': True, 'name': 'Text+Video Missing'},
    {'text': False, 'audio': True, 'video': True, 'name': 'Audio+Video Missing'},
    {'text': True, 'audio': True, 'video': True, 'name': 'All Missing'},
]

print(f"\n{'Missingness Pattern':<30} {'Accuracy':<12} {'F1':<10} {'Drop':<10}")
print("-" * 62)

missingness_results = {}
all_present_acc = None

with torch.no_grad():
    for pattern in missingness_patterns:
        preds, targets = [], []

        for batch in test_loader_no_drop:
            word_indices = batch['word_indices'].to(DEVICE)
            audio_raw = batch['audio_raw'].to(DEVICE)
            video_raw = batch['video_raw'].to(DEVICE)
            labels_batch = batch['label'].to(DEVICE)

            missing_mask = {k: v for k, v in pattern.items() if k != 'name'}
            logits = baseline_model(word_indices, audio_raw, video_raw, missing_mask=missing_mask)

            preds.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(labels_batch.cpu().numpy())

        acc = accuracy_score(targets, preds)
        f1 = f1_score(targets, preds, average='macro', zero_division=0)

        if pattern['name'] == 'All Present':
            all_present_acc = acc
            drop = 0.0
        else:
            drop = (all_present_acc - acc) * 100

        missingness_results[pattern['name']] = {'accuracy': acc, 'f1': f1, 'drop': drop}
        print(f"{pattern['name']:<30} {acc:.4f} ({acc*100:5.2f}%)  {f1:.4f}    {drop:+.2f}%")

print("\n" + "=" * 80)
print("PHASE 3: Cross Modal Knowledge Distillation")
print("=" * 80)

best_teacher_name = max([(k, v['accuracy']) for k, v in baseline_results.items() if 'teacher' in k], key=lambda x: x[1])[0]
best_teacher = teachers[best_teacher_name]
print(f"\nBest teacher: {best_teacher_name} ({baseline_results[best_teacher_name]['accuracy']*100:.2f}%)\n")

train_set_kd = SimpleDataset(train_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, 0.0)
dev_set_kd = SimpleDataset(dev_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, 0.0)
test_set_kd = SimpleDataset(test_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, 0.0)

train_loader_kd = DataLoader(train_set_kd, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=0)
dev_loader_kd = DataLoader(dev_set_kd, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)
test_loader_kd = DataLoader(test_set_kd, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)

kd_student = ImprovedMultimodalFusion(word_embeddings_array).to(DEVICE)
optimizer_kd = optim.AdamW(kd_student.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler_kd = get_cosine_schedule_with_warmup(optimizer_kd, num_warmup_steps=500, num_training_steps=NUM_EPOCHS * len(train_loader_kd), num_cycles=0.5)

best_f1_kd = 0
patience_kd = 15
patience_counter_kd = 0

for epoch in range(NUM_EPOCHS):
    loss = train_epoch(kd_student, train_loader_kd, optimizer_kd, loss_fn, DEVICE, use_kd=True, teacher=best_teacher, temperature=4.0, lambda_kd=0.5)
    dev_acc, dev_f1, _, _ = evaluate(kd_student, dev_loader_kd, DEVICE)[:4]
    scheduler_kd.step()

    if dev_f1 > best_f1_kd:
        best_f1_kd = dev_f1
        patience_counter_kd = 0
        torch.save(kd_student.state_dict(), f'{SAVE_FOLDER}/kd_student.pt')
    else:
        patience_counter_kd += 1

    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1:3d} | Loss: {loss:.4f} | Dev F1: {dev_f1:.4f}")

    if patience_counter_kd >= patience_kd:
        print(f"Early stopped at epoch {epoch+1}")
        break

kd_student.load_state_dict(torch.load(f'{SAVE_FOLDER}/kd_student.pt'))
test_acc_kd, test_f1_kd, test_prec_kd, test_rec_kd, ece_kd = evaluate(kd_student, test_loader_kd, DEVICE, compute_ece=True)

print(f"\n KD Student: {test_acc_kd*100:.2f}% (vs Baseline: {baseline_results['Multimodal_Baseline']['accuracy']*100:.2f}%)")
print(f"   Improvement: {(test_acc_kd - baseline_results['Multimodal_Baseline']['accuracy'])*100:+.2f} pp\n")

print("=" * 80)
print("PHASE 4: Confidence Calibration under missingness")
print("=" * 80)

train_set_cal = SimpleDataset(train_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, 0.2)
dev_set_cal = SimpleDataset(dev_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, 0.0)
test_set_cal = SimpleDataset(test_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, 0.0)

train_loader_cal = DataLoader(train_set_cal, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=0)
dev_loader_cal = DataLoader(dev_set_cal, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)
test_loader_cal = DataLoader(test_set_cal, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)

calibrated_model = ImprovedMultimodalFusion(word_embeddings_array).to(DEVICE)
optimizer_cal = optim.AdamW(calibrated_model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler_cal = get_cosine_schedule_with_warmup(optimizer_cal, num_warmup_steps=500, num_training_steps=NUM_EPOCHS * len(train_loader_cal), num_cycles=0.5)

best_f1_cal = 0
patience_cal = 15
patience_counter_cal = 0

print("\nTraining with 20% modality dropout\n")

for epoch in range(NUM_EPOCHS):
    loss = train_epoch(calibrated_model, train_loader_cal, optimizer_cal, loss_fn, DEVICE)
    dev_acc, dev_f1, _, _ = evaluate(calibrated_model, dev_loader_cal, DEVICE)[:4]
    scheduler_cal.step()

    if dev_f1 > best_f1_cal:
        best_f1_cal = dev_f1
        patience_counter_cal = 0
        torch.save(calibrated_model.state_dict(), f'{SAVE_FOLDER}/calibrated_model.pt')
    else:
        patience_counter_cal += 1

    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1:3d} | Loss: {loss:.4f} | Dev F1: {dev_f1:.4f}")

    if patience_counter_cal >= patience_cal:
        print(f"Early stopped at epoch {epoch+1}")
        break

calibrated_model.load_state_dict(torch.load(f'{SAVE_FOLDER}/calibrated_model.pt'))
test_acc_cal, test_f1_cal, test_prec_cal, test_rec_cal, ece_cal = evaluate(calibrated_model, test_loader_cal, DEVICE, compute_ece=True)

print(f"\n Calibrated Model: {test_acc_cal*100:.2f}%")
print(f"   ECE: {ece_cal:.4f}")
print(f"   Robustness: Trained with 20% modality dropout\n")

# Final Results

print("=" * 80)
print(" Final Summary")
print("=" * 80)


print("  PHASE 1: Working Baseline")
print(f"{'Model':<30} {'Accuracy':<12} {'F1':<10}")
print("-" * 52)
for name, res in baseline_results.items():
    print(f"{name:<30} {res['accuracy']*100:6.2f}%         {res['f1']:.4f}")

print("\n PHASE 2: Stress Test Results")
print(f"{'Pattern':<30} {'Accuracy':<12} {'Drop':<10}")
print("-" * 52)
for pattern, res in missingness_results.items():
    print(f"{pattern:<30} {res['accuracy']*100:6.2f}%         {res['drop']:+.2f}%")

print("\n PHASE 3: Knowledge Distillation")
print(f"Baseline: {baseline_results['Multimodal_Baseline']['accuracy']*100:.2f}%")
print(f"KD Student: {test_acc_kd*100:.2f}%")
print(f"Improvement: {(test_acc_kd - baseline_results['Multimodal_Baseline']['accuracy'])*100:+.2f} pp")

print("\n PHASE 4: Calibration")
print(f"Accuracy: {test_acc_cal*100:.2f}%")
print(f"ECE: {ece_cal:.4f}")


ModuleNotFoundError: No module named 'numpy'

In [None]:
================================================================================
ROBUST MULTIMODAL FUSION WITH MISSING MODALITIES

================================================================================
Device: cuda

Loading data...
================================================================================
PHASE 1: Working Baseline
================================================================================

 Enhanced Architecture:

Training TEXT_teacher
Epoch  20 | Loss: 0.5578 | Dev F1: 0.6697
Early stopped at epoch 33
 TEXT_teacher: 67.90%

Training AUDIO_teacher
Epoch  20 | Loss: 0.6595 | Dev F1: 0.5635
Epoch  40 | Loss: 0.5496 | Dev F1: 0.5654
Early stopped at epoch 48
 AUDIO_teacher: 58.88%

Training VIDEO_teacher
Epoch  20 | Loss: 0.6817 | Dev F1: 0.5108
Epoch  40 | Loss: 0.6536 | Dev F1: 0.5192
Early stopped at epoch 52
 VIDEO_teacher: 53.83%

Training multimodal baseline (with cross-modal attention)
Epoch  20 | Loss: 0.5271 | Dev F1: 0.6848
Early stopped at epoch 37
 Multimodal Baseline: 68.48%

================================================================================
PHASE 2: Stress Test
================================================================================

Evaluating baseline under missing modalities

Missingness Pattern            Accuracy     F1         Drop
--------------------------------------------------------------
All Present                    0.6848 (68.48%)  0.6844    +0.00%
Text Missing                   0.5666 (56.66%)  0.5658    +11.82%
Audio Missing                  0.6742 (67.42%)  0.6742    +1.06%
Video Missing                  0.6821 (68.21%)  0.6814    +0.27%
Text+Audio Missing             0.5161 (51.61%)  0.4649    +16.87%
Text+Video Missing             0.5562 (55.62%)  0.5535    +12.86%
Audio+Video Missing            0.6696 (66.96%)  0.6696    +1.52%
All Missing                    0.5021 (50.21%)  0.3343    +18.27%

================================================================================
PHASE 3: Cross Modal Knowledge Distillation
================================================================================

Best teacher: TEXT_teacher (67.90%)

Epoch  20 | Loss: 0.5483 | Dev F1: 0.6818
Early stopped at epoch 36

  KD Student: 69.39% (vs Baseline: 68.48%)
   Improvement: +0.91 pp

================================================================================
PHASE 4: Confidence Calibration under missingness
================================================================================

Training with 20% modality dropout

Epoch  20 | Loss: 0.5222 | Dev F1: 0.6698
Early stopped at epoch 34

  Calibrated Model: 68.84%
   ECE: 0.0577
   Robustness: Trained with 20% modality dropout

================================================================================
 Final Summary
================================================================================
  PHASE 1: Working Baseline
Model                          Accuracy     F1
----------------------------------------------------
TEXT_teacher                    67.90%         0.6781
AUDIO_teacher                   58.88%         0.5882
VIDEO_teacher                   53.83%         0.5364
Multimodal_Baseline             68.48%         0.6844

 PHASE 2: Stress Test Results
Pattern                        Accuracy     Drop
----------------------------------------------------
All Present                     68.48%         +0.00%
Text Missing                    56.66%         +11.82%
Audio Missing                   67.42%         +1.06%
Video Missing                   68.21%         +0.27%
Text+Audio Missing              51.61%         +16.87%
Text+Video Missing              55.62%         +12.86%
Audio+Video Missing             66.96%         +1.52%
All Missing                     50.21%         +18.27%

 PHASE 3: Knowledge Distillation
Baseline: 68.48%
KD Student: 69.39%
Improvement: +0.91 pp

 PHASE 4: Calibration
Accuracy: 68.84%
ECE: 0.0577


In [1]:
!wget https://www.dropbox.com/s/izk6khkrdwcncia/ted_humor_sdk_v1.zip?dl=1

--2026-01-09 09:16:07--  https://www.dropbox.com/s/izk6khkrdwcncia/ted_humor_sdk_v1.zip?dl=1
Resolving www.dropbox.com (www.dropbox.com)... 162.125.81.18, 2620:100:6031:18::a27d:5112
Connecting to www.dropbox.com (www.dropbox.com)|162.125.81.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://www.dropbox.com/scl/fi/002rz175n5ferwyvif2hv/ted_humor_sdk_v1.zip?rlkey=r8bszra1ez6zedylbx1d4wm99&dl=1 [following]
--2026-01-09 09:16:08--  https://www.dropbox.com/scl/fi/002rz175n5ferwyvif2hv/ted_humor_sdk_v1.zip?rlkey=r8bszra1ez6zedylbx1d4wm99&dl=1
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc6315a566efd0962680a98e7777.dl.dropboxusercontent.com/cd/0/inline/C4nVZJXh8dT9E_zEkkvQjZUwO0zl6OvuZw0aQqw3ZevMd5cLmot9X0uLl8D40JG4jOoh1ATHpJi4vvp52RmHZe1MUvJq2oaSC6lx0MypCScvz6Qhvdw4xjYmeTsABipnd7Nvwglq90KcOidn8csENZPa/file?dl=1# [following]
--2026-01-09 09:16:08--  https://uc6315a566efd0962680a