In [None]:
!wget https://www.dropbox.com/s/izk6khkrdwcncia/ted_humor_sdk_v1.zip?dl=1

--2026-01-16 11:03:11--  https://www.dropbox.com/s/izk6khkrdwcncia/ted_humor_sdk_v1.zip?dl=1
Resolving www.dropbox.com (www.dropbox.com)... 162.125.81.18, 2620:100:6031:18::a27d:5112
Connecting to www.dropbox.com (www.dropbox.com)|162.125.81.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://www.dropbox.com/scl/fi/002rz175n5ferwyvif2hv/ted_humor_sdk_v1.zip?rlkey=r8bszra1ez6zedylbx1d4wm99&dl=1 [following]
--2026-01-16 11:03:11--  https://www.dropbox.com/scl/fi/002rz175n5ferwyvif2hv/ted_humor_sdk_v1.zip?rlkey=r8bszra1ez6zedylbx1d4wm99&dl=1
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://ucb09e0e16242f1458faaa87fd7e.dl.dropboxusercontent.com/cd/0/inline/C5GVaeaWMmQSh-7nG6WLFOlAhgt83s1wpj4QjaikKex5kSmmCLlYwGXqq-FPhAfBqfFnrItVk38sbt6Zr-UNi_REkB5RVkMgGbsHVT6QbWgBygpibM8kF3qAs_vokYKDcIuMGkTwXUcAYZsZwmUv8KcU/file?dl=1# [following]
--2026-01-16 11:03:12--  https://ucb09e0e16242f1458faa

In [None]:
!unzip ted_humor_sdk_v1.zip?dl=1

Archive:  ted_humor_sdk_v1.zip?dl=1
   creating: final_humor_sdk/
  inflating: final_humor_sdk/word_embedding_list.pkl  
  inflating: final_humor_sdk/data_folds.pkl  
  inflating: final_humor_sdk/humor_label_sdk.pkl  
  inflating: final_humor_sdk/covarep_features_sdk.pkl  
  inflating: final_humor_sdk/openface_features_sdk.pkl  
  inflating: final_humor_sdk/word_embedding_indexes_sdk.pkl  


In [None]:
import os
import pickle
import numpy as np
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight
from transformers import get_cosine_schedule_with_warmup
import warnings
warnings.filterwarnings('ignore')

# CONFIGURATION

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The unzip step earlier in this notebook extracts into ./final_humor_sdk, so
# fall back to that folder when the preferred ./ted_humor_data is absent.
DATA_FOLDER = './ted_humor_data'
if not os.path.isdir(DATA_FOLDER) and os.path.isdir('./final_humor_sdk'):
    DATA_FOLDER = './final_humor_sdk'
SAVE_FOLDER = './exp_a_baseline_final'
os.makedirs(SAVE_FOLDER, exist_ok=True)

# Seed the NumPy and PyTorch generators so weight initialisation and batch
# shuffling are repeatable across "Restart & Run All".
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

BATCH_SIZE = 32
NUM_EPOCHS = 100
LEARNING_RATE = 1e-3
EARLY_STOP_PATIENCE = 15  # epochs without a dev-F1 improvement before stopping

print("=" * 80)
print("BASELINE - Standard Training")
print("=" * 80)
print(f"Device: {DEVICE}\n")

# LOAD DATA

def _load_pickle(filename):
    """Unpickle one file from DATA_FOLDER using the latin1 text encoding."""
    # NOTE(review): pickle can execute code stored in the file; only load
    # archives obtained from a trusted source.
    with open(f'{DATA_FOLDER}/{filename}', 'rb') as handle:
        return pickle.load(handle, encoding='latin1')


print("Loading data")
data_folds = _load_pickle('data_folds.pkl')
text_features = _load_pickle('word_embedding_indexes_sdk.pkl')
audio_features = _load_pickle('covarep_features_sdk.pkl')
video_features = _load_pickle('openface_features_sdk.pkl')
labels = _load_pickle('humor_label_sdk.pkl')
word_embeddings_list = _load_pickle('word_embedding_list.pkl')

# Sample ids for each split, as stored in the folds file.
train_ids, dev_ids, test_ids = (data_folds[split] for split in ('train', 'dev', 'test'))

# Dense float32 copy of the embedding table, used to initialise nn.Embedding.
word_embeddings_array = np.array(word_embeddings_list, dtype=np.float32)
print("✓ Data loaded successfully\n")

# SCALERS


def _collect_train_frames(feature_store, sample_ids, n_dims, modality_name):
    """Gather per-sample (frames, n_dims) float32 arrays used to fit a scaler.

    For dict entries the 'punchline_features' value is preferred, then
    'punchline', then the first stored value. Samples whose payload is
    missing, empty, or cannot be reshaped to ``n_dims`` columns are skipped
    and counted rather than silently ignored.

    Args:
        feature_store: mapping from sample id to that sample's features.
        sample_ids: ids to read (the training split).
        n_dims: expected number of feature columns per frame.
        modality_name: label used only in the skip report.

    Returns:
        List of 2-D float32 arrays, one per usable sample.
    """
    collected = []
    skipped = 0
    for sample_id in sample_ids:
        entry = feature_store.get(sample_id, {})
        if isinstance(entry, dict):
            if 'punchline_features' in entry:
                raw = entry['punchline_features']
            elif 'punchline' in entry:
                raw = entry['punchline']
            else:
                raw = next(iter(entry.values()), None)
        else:
            raw = entry
        if raw is None:
            skipped += 1
            continue
        try:
            frames = np.array(raw, dtype=np.float32).reshape(-1, n_dims)
        except (ValueError, TypeError):
            # Ragged / non-numeric payloads, or a size that is not a multiple of n_dims.
            skipped += 1
            continue
        if frames.shape[0] > 0:
            collected.append(frames)
        else:
            skipped += 1
    if skipped:
        print(f"  {modality_name}: skipped {skipped} training samples with unusable features")
    return collected


# Scalers are fitted on training frames only; a single all-zero row keeps
# fit() from failing when no usable sample was found.
train_audio_list = _collect_train_frames(audio_features, train_ids, 81, 'audio')
scaler_audio = StandardScaler().fit(np.vstack(train_audio_list) if train_audio_list else np.zeros((1, 81)))

train_video_list = _collect_train_frames(video_features, train_ids, 75, 'video')
scaler_video = StandardScaler().fit(np.vstack(train_video_list) if train_video_list else np.zeros((1, 75)))
print(f"Scalers created\n")

# ARCHITECTURE

class CrossModalAttention(nn.Module):
    """Scaled dot-product block where text features query another modality.

    NOTE(review): every caller in this file passes (batch, dim) vectors, so
    the score tensor is (batch, 1, 1) and the softmax over its last axis is
    always 1 for finite inputs. The result then reduces to
    dropout(out_proj(value(modality_feat))) and does not depend on
    ``text_feat``. Confirm whether that is intended.
    """

    def __init__(self, dim=256):
        super().__init__()
        self.dim = dim
        # Projection layers (creation order is kept so seeded init is unchanged).
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, text_feat, modality_feat):
        """Return the projected, dropout-regularised attention output."""
        queries = self.query(text_feat).unsqueeze(1)      # (batch, 1, dim) for 2-D input
        keys = self.key(modality_feat).unsqueeze(2)       # (batch, dim, 1) for 2-D input
        values = self.value(modality_feat).unsqueeze(1)   # (batch, 1, dim) for 2-D input
        scale = self.dim ** 0.5
        weights = F.softmax(torch.matmul(queries, keys) / scale, dim=-1)
        attended = torch.matmul(weights, values).squeeze(1)
        return self.dropout(self.out_proj(attended))


class ImprovedMultimodalFusion(nn.Module):
    """Text/audio/video fusion classifier that outputs two-class logits.

    Each modality is encoded by a bidirectional LSTM (128 units per
    direction). Audio and video summaries are passed through
    CrossModalAttention together with the text summary, the three 256-dim
    vectors are concatenated (768) and fed to the fusion MLP and classifier.

    Args:
        word_embeddings_array: 2-D float array used to initialise the
            (trainable) embedding table; rows must be 300 wide to match the
            text LSTM input size.
    """

    def __init__(self, word_embeddings_array):
        super().__init__()

        # TEXT ENCODER
        embedding_tensor = torch.FloatTensor(word_embeddings_array)
        self.text_embedding = nn.Embedding.from_pretrained(embedding_tensor, freeze=False)
        self.text_lstm = nn.LSTM(300, 128, num_layers=1, batch_first=True, bidirectional=True)

        # AUDIO ENCODER
        self.audio_lstm = nn.LSTM(81, 128, num_layers=1, batch_first=True, bidirectional=True)

        # VIDEO ENCODER
        self.video_lstm = nn.LSTM(75, 128, num_layers=1, batch_first=True, bidirectional=True)

        # CROSS-MODAL ATTENTION
        self.audio_attention = CrossModalAttention(dim=256)
        self.video_attention = CrossModalAttention(dim=256)

        # FUSION LAYER
        self.fusion = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # CLASSIFIER
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 2)
        )

    def extract_final_hidden_state(self, lstm_output):
        """Summarise a batch-first BiLSTM output as [forward@last, backward@first].

        The split point is derived from the tensor width instead of a
        hard-coded 128, so the helper keeps working if the hidden size changes.

        NOTE(review): index -1 is the last step of the *padded* batch, so for
        sequences shorter than the batch maximum the forward state has also
        consumed padding steps. Consider packing sequences if that matters.
        """
        hidden = lstm_output.size(-1) // 2
        forward_final = lstm_output[:, -1, :hidden]
        backward_final = lstm_output[:, 0, hidden:]
        return torch.cat([forward_final, backward_final], dim=1)

    @staticmethod
    def _mask_modality(feat, missing_mask, modality):
        """Zero out ``feat`` for samples whose ``modality`` is flagged missing.

        ``missing_mask`` may be None (nothing dropped), a single mapping that
        applies to the whole batch, or a sequence with one mapping per sample
        (the form collate_fn stores under 'missing_mask').
        """
        if missing_mask is None:
            return feat
        if hasattr(missing_mask, 'get'):
            # Batch-level mask: drop the modality for every sample at once.
            return torch.zeros_like(feat) if missing_mask.get(modality, False) else feat
        drop_flags = [bool(sample_mask.get(modality, False)) for sample_mask in missing_mask]
        if len(drop_flags) != feat.size(0):
            raise ValueError(
                f"missing_mask has {len(drop_flags)} entries but the batch holds {feat.size(0)} samples"
            )
        drop = torch.tensor(drop_flags, dtype=torch.bool, device=feat.device)
        return feat.masked_fill(drop.unsqueeze(1), 0.0)

    def forward(self, word_indices, audio_raw, video_raw, missing_mask=None):
        """Compute logits of shape (batch, 2).

        Args:
            word_indices: LongTensor of token ids, (batch, text_len).
            audio_raw: float tensor, (batch, audio_len, 81).
            video_raw: float tensor, (batch, video_len, 75).
            missing_mask: None, a mapping with optional 'text'/'audio'/'video'
                booleans, or a per-sample sequence of such mappings.

        Raises:
            ValueError: if a per-sample mask sequence does not match the batch size.
        """
        # TEXT ENCODING
        text_embeds = self.text_embedding(word_indices)
        text_lstm_output, _ = self.text_lstm(text_embeds)
        text_feat = self.extract_final_hidden_state(text_lstm_output)

        # AUDIO ENCODING
        audio_lstm_output, _ = self.audio_lstm(audio_raw)
        audio_feat = self.extract_final_hidden_state(audio_lstm_output)

        # VIDEO ENCODING
        video_lstm_output, _ = self.video_lstm(video_raw)
        video_feat = self.extract_final_hidden_state(video_lstm_output)

        # APPLY MISSING MASKS (features are zeroed after encoding)
        text_feat = self._mask_modality(text_feat, missing_mask, 'text')
        audio_feat = self._mask_modality(audio_feat, missing_mask, 'audio')
        video_feat = self._mask_modality(video_feat, missing_mask, 'video')

        # CROSS-MODAL ATTENTION
        audio_aligned = self.audio_attention(text_feat, audio_feat)
        video_aligned = self.video_attention(text_feat, video_feat)

        # FUSION
        fused_input = torch.cat([text_feat, audio_aligned, video_aligned], dim=1)
        fused = self.fusion(fused_input)
        logits = self.classifier(fused)

        return logits


# DATASET AND DATALOADER

class SimpleDataset(Dataset):
    """Per-sample access to text indices, audio frames, video frames and label.

    Args:
        ids: sequence of sample ids served by this dataset.
        text_data / audio_data / video_data: containers indexed by sample id.
            An entry may be a dict (the 'punchline_features' value is
            preferred, then 'punchline', then the first stored value) or the
            raw payload itself.
        labels: container indexed by sample id giving the class label.
        audio_scaler / video_scaler: objects with ``transform`` applied to the
            frame matrix, or None to skip scaling.
        word_embeddings_list: embedding table; only its length is used, to
            discard out-of-range token ids.
        modality_drop_rate: probability of flagging each modality as missing
            in the returned 'missing_mask'.
    """

    MAX_TEXT_LEN = 512  # token ids beyond this position are discarded
    AUDIO_DIM = 81
    VIDEO_DIM = 75

    def __init__(self, ids, text_data, audio_data, video_data, labels, audio_scaler, video_scaler, word_embeddings_list, modality_drop_rate=0.0):
        self.ids = ids
        self.text_data = text_data
        self.audio_data = audio_data
        self.video_data = video_data
        self.labels = labels
        self.audio_scaler = audio_scaler
        self.video_scaler = video_scaler
        self.word_embeddings_list = word_embeddings_list
        self.modality_drop_rate = modality_drop_rate

    def __len__(self):
        return len(self.ids)

    @staticmethod
    def _select_features(entry):
        """Return the punchline payload of one entry (None for an empty dict)."""
        if not isinstance(entry, dict):
            return entry
        for key in ('punchline_features', 'punchline'):
            if key in entry:
                return entry[key]
        # Previously an empty dict raised IndexError here; treat it as "no data".
        return next(iter(entry.values()), None)

    def _clean_word_indices(self, text_indices):
        """Convert raw token ids to a non-empty list of in-vocabulary ints."""
        cleaned = []
        if isinstance(text_indices, (list, np.ndarray)):
            vocab_size = len(self.word_embeddings_list)
            for raw_index in text_indices:
                try:
                    if isinstance(raw_index, (list, np.ndarray)) and len(raw_index) > 0:
                        value = int(raw_index[0])
                    else:
                        value = int(raw_index)
                except (ValueError, TypeError, OverflowError):
                    # Non-numeric or otherwise unconvertible token: drop it.
                    continue
                if 0 <= value < vocab_size:
                    cleaned.append(value)
        if not cleaned:
            # Index 0 stands in for a sample without usable tokens.
            cleaned = [0]
        return cleaned[:self.MAX_TEXT_LEN]

    @staticmethod
    def _frame_matrix(raw, n_dims, scaler):
        """Return a (frames, n_dims) matrix; a single zero row when unusable."""
        fallback = np.zeros((1, n_dims), dtype=np.float32)
        if raw is None:
            return fallback
        try:
            frames = np.array(raw, dtype=np.float32).reshape(-1, n_dims)
            if frames.shape[0] == 0:
                frames = np.zeros((1, n_dims), dtype=np.float32)
            if scaler is not None:
                frames = scaler.transform(frames)
        except (ValueError, TypeError):
            # Malformed payload or scaler rejection: fall back to unscaled zeros.
            return fallback
        return frames

    def __getitem__(self, idx):
        """Build the tensors and the random missing-modality mask for one sample."""
        sample_id = self.ids[idx]

        # TEXT
        word_indices = self._clean_word_indices(self._select_features(self.text_data[sample_id]))

        # AUDIO
        audio_raw = self._frame_matrix(
            self._select_features(self.audio_data[sample_id]), self.AUDIO_DIM, self.audio_scaler
        )

        # VIDEO
        video_raw = self._frame_matrix(
            self._select_features(self.video_data[sample_id]), self.VIDEO_DIM, self.video_scaler
        )

        # One independent draw per modality, in text/audio/video order.
        missing_mask = {
            'text': np.random.rand() < self.modality_drop_rate,
            'audio': np.random.rand() < self.modality_drop_rate,
            'video': np.random.rand() < self.modality_drop_rate
        }

        return {
            'word_indices': torch.LongTensor(word_indices),
            'audio_raw': torch.FloatTensor(audio_raw),
            'video_raw': torch.FloatTensor(video_raw),
            'label': torch.tensor(self.labels[sample_id], dtype=torch.long),
            'missing_mask': missing_mask
        }


def collate_fn(batch):
    """Pad variable-length samples to the longest one in the batch and stack them.

    Token ids are right-padded with 0 and audio/video frames with all-zero
    rows, using ``pad_sequence`` instead of a tensor -> NumPy -> tensor
    round-trip per sample.

    Args:
        batch: list of dicts with 'word_indices', 'audio_raw', 'video_raw',
            'label' tensors and a 'missing_mask' dict.

    Returns:
        Dict with batch-first 'word_indices' (int64), 'audio_raw' and
        'video_raw' (float32), stacked 'label', and 'missing_mask' as the
        list of per-sample mask dicts.
    """
    pad = torch.nn.utils.rnn.pad_sequence
    return {
        'word_indices': pad([item['word_indices'].long() for item in batch], batch_first=True, padding_value=0),
        'audio_raw': pad([item['audio_raw'].float() for item in batch], batch_first=True, padding_value=0.0),
        'video_raw': pad([item['video_raw'].float() for item in batch], batch_first=True, padding_value=0.0),
        'label': torch.stack([item['label'] for item in batch]),
        'missing_mask': [item['missing_mask'] for item in batch]
    }


# TRAINING AND EVALUATION

def train_epoch(model, loader, optimizer, loss_fn, device, scheduler=None):
    """Run one training pass over ``loader`` and return the mean batch loss.

    Args:
        model: network called as ``model(word_indices, audio_raw, video_raw, missing_mask=None)``.
        loader: iterable of batches produced by ``collate_fn``.
        optimizer: optimizer updating ``model``'s parameters.
        loss_fn: callable ``loss_fn(logits, labels)``.
        device: device the batch tensors are moved to.
        scheduler: optional LR scheduler stepped once per optimizer step.
            Needed for schedules sized in optimizer steps, such as the
            warmup/cosine schedule built in this notebook.

    Returns:
        Average loss over the batches (0.0 for an empty loader).
    """
    model.train()
    total_loss = 0.0
    for batch in loader:
        word_indices = batch['word_indices'].to(device)
        audio_raw = batch['audio_raw'].to(device)
        video_raw = batch['video_raw'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(word_indices, audio_raw, video_raw, missing_mask=None)
        loss = loss_fn(logits, labels)
        loss.backward()
        # Clip the global gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        total_loss += loss.item()

    return total_loss / max(len(loader), 1)


def evaluate(model, loader, device):
    """Score ``model`` on ``loader`` with all modalities present.

    Returns:
        Tuple ``(accuracy, macro_f1, macro_precision, macro_recall)``.
    """
    model.eval()
    preds = []
    targets = []

    with torch.no_grad():
        for batch in loader:
            word_indices = batch['word_indices'].to(device)
            audio_raw = batch['audio_raw'].to(device)
            video_raw = batch['video_raw'].to(device)
            labels = batch['label'].to(device)

            logits = model(word_indices, audio_raw, video_raw, missing_mask=None)
            pred = logits.argmax(dim=1)
            preds.extend(pred.cpu().numpy())
            targets.extend(labels.cpu().numpy())

    acc = accuracy_score(targets, preds)
    f1 = f1_score(targets, preds, average='macro', zero_division=0)
    prec = precision_score(targets, preds, average='macro', zero_division=0)
    rec = recall_score(targets, preds, average='macro', zero_division=0)

    return acc, f1, prec, rec


# MAIN TRAINING LOOP

print("=" * 80)
print("TRAINING BASELINE MODEL")
print("=" * 80 + "\n")

train_set = SimpleDataset(train_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, modality_drop_rate=0.0)
dev_set = SimpleDataset(dev_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, modality_drop_rate=0.0)
test_set = SimpleDataset(test_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, modality_drop_rate=0.0)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=0)
dev_loader = DataLoader(dev_set, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)

# Compute class weights
train_labels_array = np.array([labels[sample_id] for sample_id in train_ids])
class_weights = compute_class_weight('balanced', classes=np.unique(train_labels_array), y=train_labels_array)
class_weights = torch.FloatTensor(class_weights).to(DEVICE)
loss_fn = nn.CrossEntropyLoss(weight=class_weights)

# Model, optimizer, scheduler
model = ImprovedMultimodalFusion(word_embeddings_array).to(DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
# The schedule length is expressed in optimizer steps (epochs * batches), so the
# scheduler is advanced per batch inside train_epoch. Stepping it once per epoch
# kept the learning rate inside the 500-step warmup (starting at 0) for the
# whole run.
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=500, num_training_steps=NUM_EPOCHS * len(train_loader), num_cycles=0.5)

# Start below any reachable F1 so the first epoch always writes a checkpoint
# for the later load_state_dict call.
best_f1 = -1.0
patience_counter = 0

print(f"{'Epoch':<8} {'Train Loss':<12} {'Dev F1':<10}")
print("-" * 30)

for epoch in range(NUM_EPOCHS):
    loss = train_epoch(model, train_loader, optimizer, loss_fn, DEVICE, scheduler=scheduler)
    dev_acc, dev_f1, _, _ = evaluate(model, dev_loader, DEVICE)

    if dev_f1 > best_f1:
        best_f1 = dev_f1
        patience_counter = 0
        torch.save(model.state_dict(), f'{SAVE_FOLDER}/baseline_best.pt')
    else:
        patience_counter += 1

    if (epoch + 1) % 20 == 0:
        print(f"{epoch+1:<8} {loss:<12.4f} {dev_f1:<10.4f}")

    if patience_counter >= EARLY_STOP_PATIENCE:
        print(f"\nEarly stopping at epoch {epoch+1}")
        break

print()

# STRESS TEST: MISSING MODALITY EVALUATION

print("=" * 80)
print("STRESS TEST: MISSING MODALITY EVALUATION")
print("=" * 80 + "\n")

# map_location lets a checkpoint written on one device be restored on another
# (e.g. trained on GPU, re-scored on a CPU-only runtime).
model.load_state_dict(torch.load(f'{SAVE_FOLDER}/baseline_best.pt', map_location=DEVICE))
# Dropout must be disabled while scoring; set the mode explicitly instead of
# relying on whatever mode the training loop happened to leave behind.
model.eval()

# Each pattern marks which modalities are zeroed inside the model's forward pass.
missingness_patterns = [
    {'text': False, 'audio': False, 'video': False, 'name': 'All Present'},
    {'text': True, 'audio': False, 'video': False, 'name': 'Text Missing'},
    {'text': False, 'audio': True, 'video': False, 'name': 'Audio Missing'},
    {'text': False, 'audio': False, 'video': True, 'name': 'Video Missing'},
    {'text': True, 'audio': True, 'video': False, 'name': 'Text+Audio Missing'},
    {'text': True, 'audio': False, 'video': True, 'name': 'Text+Video Missing'},
    {'text': False, 'audio': True, 'video': True, 'name': 'Audio+Video Missing'},
]

print(f"{'Pattern':<30} {'Accuracy':<12} {'F1':<10}")
print("-" * 52)

# results maps pattern name -> test accuracy (F1 is printed but not stored).
results = {}
with torch.no_grad():
    for pattern in missingness_patterns:
        # The mask is the same for every batch of a pattern, so build it once.
        missing_mask = {k: v for k, v in pattern.items() if k != 'name'}
        preds, targets = [], []
        for batch in test_loader:
            word_indices = batch['word_indices'].to(DEVICE)
            audio_raw = batch['audio_raw'].to(DEVICE)
            video_raw = batch['video_raw'].to(DEVICE)
            labels_batch = batch['label'].to(DEVICE)

            logits = model(word_indices, audio_raw, video_raw, missing_mask=missing_mask)
            preds.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(labels_batch.cpu().numpy())

        acc = accuracy_score(targets, preds)
        f1 = f1_score(targets, preds, average='macro', zero_division=0)
        results[pattern['name']] = acc
        print(f"{pattern['name']:<30} {acc:.4f} ({acc*100:5.2f}%) {f1:.4f}")

print()


with open(f'{SAVE_FOLDER}/results.json', 'w') as f:
    json.dump(results, f, indent=2)

print("=" * 80)
print("Results")
print("=" * 80)
print(f"\nResults saved to: {SAVE_FOLDER}/results.json")
print(f"\nKey Finding:")
print(f"  All Present: {results['All Present']*100:.2f}%")
print(f"  Text Missing: {results['Text Missing']*100:.2f}%")
print(f"  Drop: {(results['All Present'] - results['Text Missing'])*100:.2f}pp ← VULNERABILITY!")

BASELINE - Standard Training
Device: cuda

Loading data
✓ Data loaded successfully

Scalers created

TRAINING BASELINE MODEL

Epoch    Train Loss   Dev F1    
------------------------------
20       0.5247       0.6841    

Early stopping at epoch 30

STRESS TEST: MISSING MODALITY EVALUATION

Pattern                        Accuracy     F1        
----------------------------------------------------
All Present                    0.6942 (69.42%) 0.6934
Text Missing                   0.5459 (54.59%) 0.5025
Audio Missing                  0.6821 (68.21%) 0.6819
Video Missing                  0.6927 (69.27%) 0.6918
Text+Audio Missing             0.5109 (51.09%) 0.4085
Text+Video Missing             0.5359 (53.59%) 0.4833
Audio+Video Missing            0.6760 (67.60%) 0.6757

Results

Results saved to: ./exp_a_baseline_final/results.json

Key Finding:
  All Present: 69.42%
  Text Missing: 54.59%
  Drop: 14.83pp ← VULNERABILITY!


In [None]:
import os
import pickle
import numpy as np
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight
from transformers import get_cosine_schedule_with_warmup
import warnings
warnings.filterwarnings('ignore')

# CONFIGURATION

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The unzip step earlier in this notebook extracts into ./final_humor_sdk, so
# fall back to that folder when the preferred ./ted_humor_data is absent.
DATA_FOLDER = './ted_humor_data'
if not os.path.isdir(DATA_FOLDER) and os.path.isdir('./final_humor_sdk'):
    DATA_FOLDER = './final_humor_sdk'
SAVE_FOLDER = './exp_b_proposed_final'
os.makedirs(SAVE_FOLDER, exist_ok=True)

# Seed the NumPy and PyTorch generators so weight initialisation, batch
# shuffling and random modality dropping are repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

BATCH_SIZE = 32
NUM_EPOCHS = 100
LEARNING_RATE = 1e-3
EARLY_STOP_PATIENCE = 15
MODALITY_DROP_RATE = 0.2
KD_TEMPERATURE = 4.0
KD_WEIGHT = 0.5
ECE_WEIGHT = 0.1


print(f"Device: {DEVICE}\n")

# LOAD DATA

def _load_pickle(filename):
    """Unpickle one file from DATA_FOLDER using the latin1 text encoding."""
    # NOTE(review): pickle can execute code stored in the file; only load
    # archives obtained from a trusted source.
    with open(f'{DATA_FOLDER}/{filename}', 'rb') as handle:
        return pickle.load(handle, encoding='latin1')


print("Loading data...")
data_folds = _load_pickle('data_folds.pkl')
text_features = _load_pickle('word_embedding_indexes_sdk.pkl')
audio_features = _load_pickle('covarep_features_sdk.pkl')
video_features = _load_pickle('openface_features_sdk.pkl')
labels = _load_pickle('humor_label_sdk.pkl')
word_embeddings_list = _load_pickle('word_embedding_list.pkl')

# Sample ids for each split, as stored in the folds file.
train_ids, dev_ids, test_ids = (data_folds[split] for split in ('train', 'dev', 'test'))

# Dense float32 copy of the embedding table, used to initialise nn.Embedding.
word_embeddings_array = np.array(word_embeddings_list, dtype=np.float32)
print(" Data loaded successfully\n")

# SCALERS


def _collect_train_frames(feature_store, sample_ids, n_dims, modality_name):
    """Gather per-sample (frames, n_dims) float32 arrays used to fit a scaler.

    For dict entries the 'punchline_features' value is preferred, then
    'punchline', then the first stored value. Samples whose payload is
    missing, empty, or cannot be reshaped to ``n_dims`` columns are skipped
    and counted rather than silently ignored.

    Args:
        feature_store: mapping from sample id to that sample's features.
        sample_ids: ids to read (the training split).
        n_dims: expected number of feature columns per frame.
        modality_name: label used only in the skip report.

    Returns:
        List of 2-D float32 arrays, one per usable sample.
    """
    collected = []
    skipped = 0
    for sample_id in sample_ids:
        entry = feature_store.get(sample_id, {})
        if isinstance(entry, dict):
            if 'punchline_features' in entry:
                raw = entry['punchline_features']
            elif 'punchline' in entry:
                raw = entry['punchline']
            else:
                raw = next(iter(entry.values()), None)
        else:
            raw = entry
        if raw is None:
            skipped += 1
            continue
        try:
            frames = np.array(raw, dtype=np.float32).reshape(-1, n_dims)
        except (ValueError, TypeError):
            # Ragged / non-numeric payloads, or a size that is not a multiple of n_dims.
            skipped += 1
            continue
        if frames.shape[0] > 0:
            collected.append(frames)
        else:
            skipped += 1
    if skipped:
        print(f"  {modality_name}: skipped {skipped} training samples with unusable features")
    return collected


print("Creating scalers")
# Scalers are fitted on training frames only; a single all-zero row keeps
# fit() from failing when no usable sample was found.
train_audio_list = _collect_train_frames(audio_features, train_ids, 81, 'audio')
scaler_audio = StandardScaler().fit(np.vstack(train_audio_list) if train_audio_list else np.zeros((1, 81)))

train_video_list = _collect_train_frames(video_features, train_ids, 75, 'video')
scaler_video = StandardScaler().fit(np.vstack(train_video_list) if train_video_list else np.zeros((1, 75)))


# ARCHITECTURE CLASSES

class CrossModalAttention(nn.Module):
    """Scaled dot-product block where text features query another modality.

    NOTE(review): every caller in this file passes (batch, dim) vectors, so
    the score tensor is (batch, 1, 1) and the softmax over its last axis is
    always 1 for finite inputs. The result then reduces to
    dropout(out_proj(value(modality_feat))) and does not depend on
    ``text_feat``. Confirm whether that is intended.
    """

    def __init__(self, dim=256):
        super().__init__()
        self.dim = dim
        # Projection layers (creation order is kept so seeded init is unchanged).
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, text_feat, modality_feat):
        """Return the projected, dropout-regularised attention output."""
        queries = self.query(text_feat).unsqueeze(1)      # (batch, 1, dim) for 2-D input
        keys = self.key(modality_feat).unsqueeze(2)       # (batch, dim, 1) for 2-D input
        values = self.value(modality_feat).unsqueeze(1)   # (batch, 1, dim) for 2-D input
        scale = self.dim ** 0.5
        weights = F.softmax(torch.matmul(queries, keys) / scale, dim=-1)
        attended = torch.matmul(weights, values).squeeze(1)
        return self.dropout(self.out_proj(attended))


class ImprovedMultimodalFusion(nn.Module):
    """Text/audio/video fusion classifier that outputs two-class logits.

    Each modality is encoded by a bidirectional LSTM (128 units per
    direction). Audio and video summaries are passed through
    CrossModalAttention together with the text summary, the three 256-dim
    vectors are concatenated (768) and fed to the fusion MLP and classifier.

    Args:
        word_embeddings_array: 2-D float array used to initialise the
            (trainable) embedding table; rows must be 300 wide to match the
            text LSTM input size.
    """

    def __init__(self, word_embeddings_array):
        super().__init__()

        # TEXT ENCODER
        embedding_tensor = torch.FloatTensor(word_embeddings_array)
        self.text_embedding = nn.Embedding.from_pretrained(embedding_tensor, freeze=False)
        self.text_lstm = nn.LSTM(300, 128, num_layers=1, batch_first=True, bidirectional=True)

        # AUDIO ENCODER
        self.audio_lstm = nn.LSTM(81, 128, num_layers=1, batch_first=True, bidirectional=True)

        # VIDEO ENCODER
        self.video_lstm = nn.LSTM(75, 128, num_layers=1, batch_first=True, bidirectional=True)

        # CROSS-MODAL ATTENTION
        self.audio_attention = CrossModalAttention(dim=256)
        self.video_attention = CrossModalAttention(dim=256)

        # FUSION LAYER
        self.fusion = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # CLASSIFIER
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 2)
        )

    def extract_final_hidden_state(self, lstm_output):
        """Summarise a batch-first BiLSTM output as [forward@last, backward@first].

        The split point is derived from the tensor width instead of a
        hard-coded 128, so the helper keeps working if the hidden size changes.

        NOTE(review): index -1 is the last step of the *padded* batch, so for
        sequences shorter than the batch maximum the forward state has also
        consumed padding steps. Consider packing sequences if that matters.
        """
        hidden = lstm_output.size(-1) // 2
        forward_final = lstm_output[:, -1, :hidden]
        backward_final = lstm_output[:, 0, hidden:]
        return torch.cat([forward_final, backward_final], dim=1)

    @staticmethod
    def _mask_modality(feat, missing_mask, modality):
        """Zero out ``feat`` for samples whose ``modality`` is flagged missing.

        ``missing_mask`` may be None (nothing dropped), a single mapping that
        applies to the whole batch, or a sequence with one mapping per sample
        (the form collate_fn stores under 'missing_mask').
        """
        if missing_mask is None:
            return feat
        if hasattr(missing_mask, 'get'):
            # Batch-level mask: drop the modality for every sample at once.
            return torch.zeros_like(feat) if missing_mask.get(modality, False) else feat
        drop_flags = [bool(sample_mask.get(modality, False)) for sample_mask in missing_mask]
        if len(drop_flags) != feat.size(0):
            raise ValueError(
                f"missing_mask has {len(drop_flags)} entries but the batch holds {feat.size(0)} samples"
            )
        drop = torch.tensor(drop_flags, dtype=torch.bool, device=feat.device)
        return feat.masked_fill(drop.unsqueeze(1), 0.0)

    def forward(self, word_indices, audio_raw, video_raw, missing_mask=None):
        """Compute logits of shape (batch, 2).

        Args:
            word_indices: LongTensor of token ids, (batch, text_len).
            audio_raw: float tensor, (batch, audio_len, 81).
            video_raw: float tensor, (batch, video_len, 75).
            missing_mask: None, a mapping with optional 'text'/'audio'/'video'
                booleans, or a per-sample sequence of such mappings.

        Raises:
            ValueError: if a per-sample mask sequence does not match the batch size.
        """
        # TEXT ENCODING
        text_embeds = self.text_embedding(word_indices)
        text_lstm_output, _ = self.text_lstm(text_embeds)
        text_feat = self.extract_final_hidden_state(text_lstm_output)

        # AUDIO ENCODING
        audio_lstm_output, _ = self.audio_lstm(audio_raw)
        audio_feat = self.extract_final_hidden_state(audio_lstm_output)

        # VIDEO ENCODING
        video_lstm_output, _ = self.video_lstm(video_raw)
        video_feat = self.extract_final_hidden_state(video_lstm_output)

        # APPLY MISSING MASKS (features are zeroed after encoding)
        text_feat = self._mask_modality(text_feat, missing_mask, 'text')
        audio_feat = self._mask_modality(audio_feat, missing_mask, 'audio')
        video_feat = self._mask_modality(video_feat, missing_mask, 'video')

        # CROSS-MODAL ATTENTION
        audio_aligned = self.audio_attention(text_feat, audio_feat)
        video_aligned = self.video_attention(text_feat, video_feat)

        # FUSION
        fused_input = torch.cat([text_feat, audio_aligned, video_aligned], dim=1)
        fused = self.fusion(fused_input)
        logits = self.classifier(fused)

        return logits


# UNIMODAL TEACHER MODELS

class TextOnlyModel(nn.Module):
    """Text-only classifier: trainable embedding -> BiLSTM -> MLP head (2 logits)."""

    def __init__(self, word_embeddings_array):
        super().__init__()
        pretrained = torch.FloatTensor(word_embeddings_array)
        self.text_embedding = nn.Embedding.from_pretrained(pretrained, freeze=False)
        self.text_lstm = nn.LSTM(
            input_size=300,
            hidden_size=128,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        head_layers = [nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.2), nn.Linear(128, 2)]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, word_indices, audio_raw=None, video_raw=None, missing_mask=None):
        """Classify from token ids; the other arguments are accepted but unused."""
        states = self.text_lstm(self.text_embedding(word_indices))[0]
        # Forward direction read at the last step, backward direction at the first.
        summary = torch.cat((states[:, -1, :128], states[:, 0, 128:]), dim=1)
        return self.classifier(summary)


class AudioOnlyModel(nn.Module):
    """Audio-only classifier: BiLSTM over 81-dim frames -> MLP head (2 logits)."""

    def __init__(self, word_embeddings_array):
        # word_embeddings_array is accepted but never read by this class.
        super().__init__()
        self.audio_lstm = nn.LSTM(
            input_size=81,
            hidden_size=128,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        head_layers = [nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.2), nn.Linear(128, 2)]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, word_indices=None, audio_raw=None, video_raw=None, missing_mask=None):
        """Classify from ``audio_raw``; the other arguments are accepted but unused."""
        states = self.audio_lstm(audio_raw)[0]
        # Forward direction read at the last step, backward direction at the first.
        last_forward, first_backward = states[:, -1, :128], states[:, 0, 128:]
        return self.classifier(torch.cat((last_forward, first_backward), dim=1))


class VideoOnlyModel(nn.Module):
    """Video-only classifier: BiLSTM over 75-dim frames -> MLP head (2 logits)."""

    def __init__(self, word_embeddings_array):
        # word_embeddings_array is accepted but never read by this class.
        super().__init__()
        self.video_lstm = nn.LSTM(
            input_size=75,
            hidden_size=128,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        head_layers = [nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.2), nn.Linear(128, 2)]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, word_indices=None, audio_raw=None, video_raw=None, missing_mask=None):
        """Classify from ``video_raw``; the other arguments are accepted but unused."""
        states = self.video_lstm(video_raw)[0]
        # Forward direction read at the last step, backward direction at the first.
        last_forward, first_backward = states[:, -1, :128], states[:, 0, 128:]
        return self.classifier(torch.cat((last_forward, first_backward), dim=1))


# DATASET AND DATALOADER

class SimpleDataset(Dataset):
    """Per-sample access to text indices, audio frames, video frames and label.

    Args:
        ids: sequence of sample ids served by this dataset.
        text_data / audio_data / video_data: containers indexed by sample id.
            An entry may be a dict (the 'punchline_features' value is
            preferred, then 'punchline', then the first stored value) or the
            raw payload itself.
        labels: container indexed by sample id giving the class label.
        audio_scaler / video_scaler: objects with ``transform`` applied to the
            frame matrix, or None to skip scaling.
        word_embeddings_list: embedding table; only its length is used, to
            discard out-of-range token ids.
        modality_drop_rate: probability of flagging each modality as missing
            in the returned 'missing_mask'.
    """

    MAX_TEXT_LEN = 512  # token ids beyond this position are discarded
    AUDIO_DIM = 81
    VIDEO_DIM = 75

    def __init__(self, ids, text_data, audio_data, video_data, labels, audio_scaler, video_scaler, word_embeddings_list, modality_drop_rate=0.0):
        self.ids = ids
        self.text_data = text_data
        self.audio_data = audio_data
        self.video_data = video_data
        self.labels = labels
        self.audio_scaler = audio_scaler
        self.video_scaler = video_scaler
        self.word_embeddings_list = word_embeddings_list
        self.modality_drop_rate = modality_drop_rate

    def __len__(self):
        return len(self.ids)

    @staticmethod
    def _select_features(entry):
        """Return the punchline payload of one entry (None for an empty dict)."""
        if not isinstance(entry, dict):
            return entry
        for key in ('punchline_features', 'punchline'):
            if key in entry:
                return entry[key]
        # Previously an empty dict raised IndexError here; treat it as "no data".
        return next(iter(entry.values()), None)

    def _clean_word_indices(self, text_indices):
        """Convert raw token ids to a non-empty list of in-vocabulary ints."""
        cleaned = []
        if isinstance(text_indices, (list, np.ndarray)):
            vocab_size = len(self.word_embeddings_list)
            for raw_index in text_indices:
                try:
                    if isinstance(raw_index, (list, np.ndarray)) and len(raw_index) > 0:
                        value = int(raw_index[0])
                    else:
                        value = int(raw_index)
                except (ValueError, TypeError, OverflowError):
                    # Non-numeric or otherwise unconvertible token: drop it.
                    continue
                if 0 <= value < vocab_size:
                    cleaned.append(value)
        if not cleaned:
            # Index 0 stands in for a sample without usable tokens.
            cleaned = [0]
        return cleaned[:self.MAX_TEXT_LEN]

    @staticmethod
    def _frame_matrix(raw, n_dims, scaler):
        """Return a (frames, n_dims) matrix; a single zero row when unusable."""
        fallback = np.zeros((1, n_dims), dtype=np.float32)
        if raw is None:
            return fallback
        try:
            frames = np.array(raw, dtype=np.float32).reshape(-1, n_dims)
            if frames.shape[0] == 0:
                frames = np.zeros((1, n_dims), dtype=np.float32)
            if scaler is not None:
                frames = scaler.transform(frames)
        except (ValueError, TypeError):
            # Malformed payload or scaler rejection: fall back to unscaled zeros.
            return fallback
        return frames

    def __getitem__(self, idx):
        """Build the tensors and the random missing-modality mask for one sample."""
        sample_id = self.ids[idx]

        # TEXT
        word_indices = self._clean_word_indices(self._select_features(self.text_data[sample_id]))

        # AUDIO
        audio_raw = self._frame_matrix(
            self._select_features(self.audio_data[sample_id]), self.AUDIO_DIM, self.audio_scaler
        )

        # VIDEO
        video_raw = self._frame_matrix(
            self._select_features(self.video_data[sample_id]), self.VIDEO_DIM, self.video_scaler
        )

        # One independent draw per modality, in text/audio/video order.
        missing_mask = {
            'text': np.random.rand() < self.modality_drop_rate,
            'audio': np.random.rand() < self.modality_drop_rate,
            'video': np.random.rand() < self.modality_drop_rate
        }

        return {
            'word_indices': torch.LongTensor(word_indices),
            'audio_raw': torch.FloatTensor(audio_raw),
            'video_raw': torch.FloatTensor(video_raw),
            'label': torch.tensor(self.labels[sample_id], dtype=torch.long),
            'missing_mask': missing_mask
        }


def collate_fn(batch):
    """Combine per-sample dicts into a single batch dict.

    Every sequence is right-padded with zeros along its first axis up to the
    longest sequence of the same modality in the batch, then the samples are
    stacked. ``label`` tensors are stacked as they are and ``missing_mask``
    is returned as a list holding one entry per sample, in batch order.
    """

    def _stack_frames(key):
        # Frame-level features: pad only the first axis, never the feature axis.
        frames = [sample[key].numpy() for sample in batch]
        longest = max(f.shape[0] for f in frames)
        stacked = []
        for f in frames:
            if f.shape[0] < longest:
                f = np.pad(f, ((0, longest - f.shape[0]), (0, 0)))
            stacked.append(torch.FloatTensor(f))
        return torch.stack(stacked)

    # Token ids: pad with index 0 at the end of each sequence.
    tokens = [sample['word_indices'].numpy() for sample in batch]
    longest_text = max(len(seq) for seq in tokens)
    word_indices = torch.stack([
        torch.LongTensor(np.pad(seq, (0, longest_text - len(seq)), constant_values=0))
        for seq in tokens
    ])

    audio_raw = _stack_frames('audio_raw')
    video_raw = _stack_frames('video_raw')

    return {
        'word_indices': word_indices,
        'audio_raw': audio_raw,
        'video_raw': video_raw,
        'label': torch.stack([sample['label'] for sample in batch]),
        'missing_mask': [sample['missing_mask'] for sample in batch],
    }


# TRAINING AND EVALUATION

def train_epoch(model, loader, optimizer, loss_fn, device, use_kd=False, teacher=None, temperature=1.0, lambda_kd=0.5):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Args:
        model: Network called as
            ``model(word_indices, audio_raw, video_raw, missing_mask=None)``.
        loader: Iterable of batch dicts holding ``word_indices``, ``audio_raw``,
            ``video_raw`` and ``label`` tensors. It does not need ``__len__``.
        optimizer: Optimizer for ``model``; stepped once per batch.
        loss_fn: Callable ``loss_fn(logits, labels)`` giving the supervised loss.
        device: Device the batch tensors are moved to.
        use_kd: Add a distillation term when True and ``teacher`` is given.
        teacher: Model providing soft targets; called like ``model``, in eval
            mode and without gradient tracking.
        temperature: Temperature dividing both student and teacher logits.
        lambda_kd: Weight of the distillation term.

    Returns:
        Sum of the per-batch losses divided by the number of batches, or
        ``0.0`` when ``loader`` yields no batches.

    NOTE(review): a ``missing_mask`` entry in the batch is ignored here; the
    model is always called with ``missing_mask=None``, so masks sampled by the
    dataset have no effect on training. Confirm whether that is intended.
    """
    model.train()

    distill = use_kd and teacher is not None
    if distill:
        # Eval mode only needs to be set once per epoch, not once per batch.
        teacher.eval()

    total_loss = 0.0
    num_batches = 0
    for batch in loader:
        word_indices = batch['word_indices'].to(device)
        audio_raw = batch['audio_raw'].to(device)
        video_raw = batch['video_raw'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(word_indices, audio_raw, video_raw, missing_mask=None)
        loss = loss_fn(logits, labels)

        if distill:
            with torch.no_grad():
                teacher_logits = teacher(word_indices, audio_raw, video_raw, missing_mask=None)
            student_log_probs = F.log_softmax(logits / temperature, dim=1)
            teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
            # The T^2 factor keeps the soft-target gradients on the same scale
            # as the hard-label loss when the temperature changes.
            loss_kd = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)
            loss = loss + lambda_kd * loss_kd

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    # Count batches instead of calling len(loader): this avoids a
    # ZeroDivisionError on an empty loader and works for plain iterables.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches


def evaluate(model, loader, device):
    """Score ``model`` on ``loader``.

    The model is switched to eval mode and run without gradient tracking,
    always with ``missing_mask=None``. The predicted class is the argmax of
    the logits over dim 1.

    Returns:
        Tuple ``(accuracy, macro_f1, macro_precision, macro_recall)``.
    """
    model.eval()
    y_pred, y_true = [], []

    with torch.no_grad():
        for batch in loader:
            inputs = [batch[key].to(device) for key in ('word_indices', 'audio_raw', 'video_raw')]
            gold = batch['label'].to(device)

            scores = model(*inputs, missing_mask=None)
            y_pred.extend(scores.argmax(dim=1).cpu().numpy())
            y_true.extend(gold.cpu().numpy())

    # Macro averaging; classes that are never predicted score 0 instead of warning.
    macro = {'average': 'macro', 'zero_division': 0}
    return (
        accuracy_score(y_true, y_pred),
        f1_score(y_true, y_pred, **macro),
        precision_score(y_true, y_pred, **macro),
        recall_score(y_true, y_pred, **macro),
    )


# PHASE 1: TRAIN UNIMODAL TEACHERS

print("=" * 80)
print("PHASE 1: TRAINING UNIMODAL TEACHERS")
print("=" * 80 + "\n")

# Every split is built with modality_drop_rate=0.0, so no modality is ever
# flagged as missing for the teachers.
train_set = SimpleDataset(train_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, modality_drop_rate=0.0)
dev_set = SimpleDataset(dev_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, modality_drop_rate=0.0)
test_set = SimpleDataset(test_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, modality_drop_rate=0.0)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=0)
dev_loader = DataLoader(dev_set, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)

# 'balanced' class weights are derived from the training labels only.
train_labels_array = np.array([labels[sample_id] for sample_id in train_ids])
class_weights = compute_class_weight('balanced', classes=np.unique(train_labels_array), y=train_labels_array)
class_weights = torch.FloatTensor(class_weights).to(DEVICE)
loss_fn = nn.CrossEntropyLoss(weight=class_weights)

teachers = {}
baseline_results = {}

for teacher_name, ModelClass in [('TEXT_teacher', TextOnlyModel), ('AUDIO_teacher', AudioOnlyModel), ('VIDEO_teacher', VideoOnlyModel)]:
    print(f"Training {teacher_name}...")

    model = ModelClass(word_embeddings_array).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=500, num_training_steps=NUM_EPOCHS * len(train_loader), num_cycles=0.5)
    steps_per_epoch = len(train_loader)
    checkpoint_path = f'{SAVE_FOLDER}/{teacher_name}.pt'

    # Start below any reachable F1 so the first epoch always writes a
    # checkpoint; otherwise the reload below could fail or pick up a stale
    # file left by an earlier run.
    best_f1 = -1.0
    patience_counter = 0

    for epoch in range(NUM_EPOCHS):
        loss = train_epoch(model, train_loader, optimizer, loss_fn, DEVICE)
        dev_acc, dev_f1, _, _ = evaluate(model, dev_loader, DEVICE)
        # The schedule is defined over NUM_EPOCHS * len(train_loader) optimizer
        # steps (500 of them warmup). train_epoch does not step it, so advance
        # it by one epoch's worth of steps here; a single step per epoch would
        # keep the learning rate inside the warmup ramp for the whole run.
        # The learning rate is therefore constant within an epoch.
        for _ in range(steps_per_epoch):
            scheduler.step()

        if dev_f1 > best_f1:
            best_f1 = dev_f1
            patience_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
        else:
            patience_counter += 1

        if patience_counter >= EARLY_STOP_PATIENCE:
            break

    # Restore the best dev-F1 weights before touching the test set.
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    test_acc, test_f1, test_prec, test_rec = evaluate(model, test_loader, DEVICE)
    teachers[teacher_name] = model
    baseline_results[teacher_name] = {'accuracy': test_acc, 'f1': test_f1}
    print(f"  {teacher_name}: {test_acc*100:.2f}%\n")

# PHASE 2: MODALITY DROPOUT TRAINING

print("=" * 80)
print("PHASE 2: MODALITY DROPOUT TRAINING")
print("=" * 80 + "\n")

# NOTE(review): this dataset samples a per-sample missing_mask at rate
# MODALITY_DROP_RATE, but train_epoch always calls the model with
# missing_mask=None, so the sampled masks never reach the model. As written
# this phase trains on fully observed inputs - confirm the intended behaviour.
train_set_dropout = SimpleDataset(train_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, modality_drop_rate=MODALITY_DROP_RATE)
train_loader_dropout = DataLoader(train_set_dropout, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=0)

student_dropout = ImprovedMultimodalFusion(word_embeddings_array).to(DEVICE)
optimizer = optim.AdamW(student_dropout.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=500, num_training_steps=NUM_EPOCHS * len(train_loader_dropout), num_cycles=0.5)
steps_per_epoch = len(train_loader_dropout)
checkpoint_path = f'{SAVE_FOLDER}/student_dropout.pt'

# Start below any reachable F1 so the first epoch always writes a checkpoint;
# otherwise the reload below could fail or pick up a stale file.
best_f1 = -1.0
patience_counter = 0

for epoch in range(NUM_EPOCHS):
    loss = train_epoch(student_dropout, train_loader_dropout, optimizer, loss_fn, DEVICE)
    dev_acc, dev_f1, _, _ = evaluate(student_dropout, dev_loader, DEVICE)
    # The schedule is defined over NUM_EPOCHS * len(train_loader_dropout)
    # optimizer steps (500 of them warmup). train_epoch does not step it, so
    # advance it by one epoch's worth of steps here; a single step per epoch
    # would keep the learning rate inside the warmup ramp for the whole run.
    for _ in range(steps_per_epoch):
        scheduler.step()

    if dev_f1 > best_f1:
        best_f1 = dev_f1
        patience_counter = 0
        torch.save(student_dropout.state_dict(), checkpoint_path)
    else:
        patience_counter += 1

    if patience_counter >= EARLY_STOP_PATIENCE:
        break

# Restore the best dev-F1 weights before touching the test set.
student_dropout.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
test_acc_dropout, _, _, _ = evaluate(student_dropout, test_loader, DEVICE)
print(f"Student (Dropout): {test_acc_dropout*100:.2f}%\n")

# PHASE 3: KNOWLEDGE DISTILLATION

print("=" * 80)
print("PHASE 3: KNOWLEDGE DISTILLATION (LCKD)")
print("=" * 80 + "\n")

# Choose the teacher on the dev split. baseline_results holds *test* metrics,
# and selecting a model on them would leak the test set into model selection.
# Dev macro-F1 is the same criterion used for early stopping.
teacher_dev_f1 = {name: evaluate(candidate, dev_loader, DEVICE)[1] for name, candidate in teachers.items()}
best_teacher_name = max(teacher_dev_f1, key=teacher_dev_f1.get)
best_teacher = teachers[best_teacher_name]
# The percentage printed here is the chosen teacher's test accuracy, for reporting only.
print(f"Using {best_teacher_name} as teacher ({baseline_results[best_teacher_name]['accuracy']*100:.2f}%)\n")

student_kd = ImprovedMultimodalFusion(word_embeddings_array).to(DEVICE)
optimizer = optim.AdamW(student_kd.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=500, num_training_steps=NUM_EPOCHS * len(train_loader), num_cycles=0.5)
steps_per_epoch = len(train_loader)
checkpoint_path = f'{SAVE_FOLDER}/student_lckd.pt'

# Start below any reachable F1 so the first epoch always writes a checkpoint;
# otherwise the reload below could fail or pick up a stale file.
best_f1 = -1.0
patience_counter = 0

for epoch in range(NUM_EPOCHS):
    loss = train_epoch(student_kd, train_loader, optimizer, loss_fn, DEVICE, use_kd=True, teacher=best_teacher, temperature=KD_TEMPERATURE, lambda_kd=KD_WEIGHT)
    dev_acc, dev_f1, _, _ = evaluate(student_kd, dev_loader, DEVICE)
    # The schedule is defined over NUM_EPOCHS * len(train_loader) optimizer
    # steps (500 of them warmup). train_epoch does not step it, so advance it
    # by one epoch's worth of steps here; a single step per epoch would keep
    # the learning rate inside the warmup ramp for the whole run.
    for _ in range(steps_per_epoch):
        scheduler.step()

    if dev_f1 > best_f1:
        best_f1 = dev_f1
        patience_counter = 0
        torch.save(student_kd.state_dict(), checkpoint_path)
    else:
        patience_counter += 1

    if patience_counter >= EARLY_STOP_PATIENCE:
        break

# Restore the best dev-F1 weights before touching the test set.
student_kd.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
test_acc_kd, _, _, _ = evaluate(student_kd, test_loader, DEVICE)
print(f"Student (KD): {test_acc_kd*100:.2f}%\n")

# PHASE 4: CALIBRATION

print("=" * 80)
print("PHASE 4: CALIBRATION (ECE Loss)")
print("=" * 80 + "\n")

# NOTE(review): despite the heading, no ECE / calibration term is added in
# this phase. The objective is the same cross-entropy + distillation loss as
# Phase 3; only the data loader differs. The loader's dataset samples
# missing_mask flags at MODALITY_DROP_RATE, but train_epoch always calls the
# model with missing_mask=None, so those flags are not applied either.
train_set_cal = SimpleDataset(train_ids, text_features, audio_features, video_features, labels, scaler_audio, scaler_video, word_embeddings_list, modality_drop_rate=MODALITY_DROP_RATE)
train_loader_cal = DataLoader(train_set_cal, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=0)

student_final = ImprovedMultimodalFusion(word_embeddings_array).to(DEVICE)
optimizer = optim.AdamW(student_final.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=500, num_training_steps=NUM_EPOCHS * len(train_loader_cal), num_cycles=0.5)
steps_per_epoch = len(train_loader_cal)
checkpoint_path = f'{SAVE_FOLDER}/student_final.pt'

# Start below any reachable F1 so the first epoch always writes a checkpoint;
# otherwise the reload below could fail or pick up a stale file.
best_f1 = -1.0
patience_counter = 0

for epoch in range(NUM_EPOCHS):
    loss = train_epoch(student_final, train_loader_cal, optimizer, loss_fn, DEVICE, use_kd=True, teacher=best_teacher, temperature=KD_TEMPERATURE, lambda_kd=KD_WEIGHT)
    dev_acc, dev_f1, _, _ = evaluate(student_final, dev_loader, DEVICE)
    # The schedule is defined over NUM_EPOCHS * len(train_loader_cal) optimizer
    # steps (500 of them warmup). train_epoch does not step it, so advance it
    # by one epoch's worth of steps here; a single step per epoch would keep
    # the learning rate inside the warmup ramp for the whole run.
    for _ in range(steps_per_epoch):
        scheduler.step()

    if dev_f1 > best_f1:
        best_f1 = dev_f1
        patience_counter = 0
        torch.save(student_final.state_dict(), checkpoint_path)
    else:
        patience_counter += 1

    if patience_counter >= EARLY_STOP_PATIENCE:
        break

# Restore the best dev-F1 weights before touching the test set.
student_final.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
test_acc_final, _, _, _ = evaluate(student_final, test_loader, DEVICE)
print(f"Student (Final): {test_acc_final*100:.2f}%\n")

# LOAD BASELINE FOR COMPARISON

print("=" * 80)
print("LOADING BASELINE FOR COMPARISON")
print("=" * 80 + "\n")

# Load baseline results from Experiment A.
# Only a missing/unreadable file (OSError) or undecodable/malformed content
# (ValueError, which covers json.JSONDecodeError) counts as "no baseline".
# Anything else, including KeyboardInterrupt, propagates instead of being
# swallowed by a bare except.
try:
    with open('./exp_a_baseline_final/results.json', 'r') as f:
        baseline_results_all = json.load(f)
    print(" Baseline results loaded\n")
except (OSError, ValueError) as err:
    print("Baseline results not found")
    # Report the reason so a corrupt file is not mistaken for a missing one.
    print(f"  ({type(err).__name__}: {err})")
    baseline_results_all = {}


# STRESS TEST: MISSING MODALITY EVALUATION

print("=" * 80)
print("STRESS TEST: MISSING MODALITY EVALUATION (Proposed)")
print("=" * 80 + "\n")

# In each pattern, True marks a modality as missing; 'name' is the display label.
missingness_patterns = [
    {'text': False, 'audio': False, 'video': False, 'name': 'All Present'},
    {'text': True, 'audio': False, 'video': False, 'name': 'Text Missing'},
    {'text': False, 'audio': True, 'video': False, 'name': 'Audio Missing'},
    {'text': False, 'audio': False, 'video': True, 'name': 'Video Missing'},
    {'text': True, 'audio': True, 'video': False, 'name': 'Text+Audio Missing'},
    {'text': True, 'audio': False, 'video': True, 'name': 'Text+Video Missing'},
    {'text': False, 'audio': True, 'video': True, 'name': 'Audio+Video Missing'},
]

print(f"{'Pattern':<30} {'Accuracy':<12} {'F1':<10}")
print("-" * 52)

proposed_results = {}
# Set eval mode explicitly rather than relying on an earlier evaluate() call
# having left the model in it.
student_final.eval()
with torch.no_grad():
    for pattern in missingness_patterns:
        # The same batch-level mask applies to every batch of this pattern.
        missing_mask = {k: v for k, v in pattern.items() if k != 'name'}
        preds, targets = [], []
        for batch in test_loader:
            word_indices = batch['word_indices'].to(DEVICE)
            audio_raw = batch['audio_raw'].to(DEVICE)
            video_raw = batch['video_raw'].to(DEVICE)
            labels_batch = batch['label'].to(DEVICE)

            logits = student_final(word_indices, audio_raw, video_raw, missing_mask=missing_mask)
            preds.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(labels_batch.cpu().numpy())

        acc = accuracy_score(targets, preds)
        f1 = f1_score(targets, preds, average='macro', zero_division=0)
        # Store a plain float so the results stay JSON-serialisable downstream.
        proposed_results[pattern['name']] = float(acc)
        print(f"{pattern['name']:<30} {acc:.4f} ({acc*100:5.2f}%) {f1:.4f}")

print()

# COMPARISON AND RESULTS

print("=" * 80)
print("COMPARISON: BASELINE vs PROPOSED")
print("=" * 80 + "\n")

print(f"{'Pattern':<30} {'Baseline':<12} {'Proposed':<12} {'Improvement':<12}")
print("-" * 66)

# Absolute accuracy gain on 'Text Missing' (0.08 = 8 percentage points) that
# the proposed model must exceed for the run to be flagged as a success.
TEXT_MISSING_SUCCESS_MARGIN = 0.08

improvements = {}
for pattern in missingness_patterns:
    pattern_name = pattern['name']
    # Patterns without a baseline entry are left out of the table.
    if pattern_name in baseline_results_all:
        baseline_acc = baseline_results_all[pattern_name]
        proposed_acc = proposed_results[pattern_name]
        # Cast to a builtin float: numpy scalars would make the comparison
        # below yield numpy.bool_, which json.dump cannot serialise.
        improvement = float(proposed_acc - baseline_acc)
        improvements[pattern_name] = improvement
        print(f"{pattern_name:<30} {baseline_acc:.4f} ({baseline_acc*100:5.2f}%) {proposed_acc:.4f} ({proposed_acc*100:5.2f}%) {improvement:+.4f} ({improvement*100:+.2f}pp)")

print()


results_comparison = {
    'baseline': baseline_results_all,
    'proposed': proposed_results,
    'improvements': improvements,
    'success': bool(improvements.get('Text Missing', 0) > TEXT_MISSING_SUCCESS_MARGIN)
}

with open(f'{SAVE_FOLDER}/results_comparison.json', 'w') as f:
    json.dump(results_comparison, f, indent=2)

print("=" * 80)
print("Results")
print("=" * 80)
print(f"\nResults saved to: {SAVE_FOLDER}/results_comparison.json")

# Headline numbers for the 'Text Missing' pattern; skipped without a baseline.
if baseline_results_all:
    text_missing_baseline = baseline_results_all.get('Text Missing', 0)
    text_missing_proposed = proposed_results.get('Text Missing', 0)
    improvement = text_missing_proposed - text_missing_baseline
    print(f"  Baseline (Text Missing): {text_missing_baseline*100:.2f}%")
    print(f"  Proposed (Text Missing): {text_missing_proposed*100:.2f}%")
    print(f"  Improvement: {improvement*100:+.2f}pp")


Device: cuda

Loading data...
 Data loaded successfully

Creating scalers
PHASE 1: TRAINING UNIMODAL TEACHERS

Training TEXT_teacher...
  TEXT_teacher: 67.84%

Training AUDIO_teacher...
  AUDIO_teacher: 59.67%

Training VIDEO_teacher...
  VIDEO_teacher: 53.47%

PHASE 2: MODALITY DROPOUT TRAINING

Student (Dropout): 68.97%

PHASE 3: KNOWLEDGE DISTILLATION (LCKD)

Using TEXT_teacher as teacher (67.84%)

Student (KD): 69.12%

PHASE 4: CALIBRATION (ECE Loss)

Student (Final): 68.75%

LOADING BASELINE FOR COMPARISON

 Baseline results loaded

STRESS TEST: MISSING MODALITY EVALUATION (Proposed)

Pattern                        Accuracy     F1        
----------------------------------------------------
All Present                    0.6875 (68.75%) 0.6875
Text Missing                   0.5593 (55.93%) 0.5589
Audio Missing                  0.6805 (68.05%) 0.6805
Video Missing                  0.6884 (68.84%) 0.6884
Text+Audio Missing             0.5356 (53.56%) 0.5148
Text+Video Missing       