In [4]:
!pip install -q transformers wandb


12.6
True


In [5]:
import wandb
# Authenticate this session with Weights & Biases up front; run_experiment()
# later calls wandb.init(), which needs a logged-in session. When no stored
# credentials are found this prompts interactively for an API key.
wandb.login()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
Aborted!


In [6]:
import json
import random
import warnings
from functools import partial
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torchaudio
from torch.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score
from transformers import (
    Wav2Vec2ForSequenceClassification,
    HubertForSequenceClassification,
    Wav2Vec2FeatureExtractor,
    AutoImageProcessor,
    TimesformerForVideoClassification,
)

# Blanket-suppress every warning category for the rest of the session.
# NOTE(review): this also hides deprecation and numerical warnings coming from
# torch / transformers / sklearn — consider narrowing the filter.
warnings.filterwarnings("ignore")

# Compute device string; also gates autocast / GradScaler in the training code.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# JSON metadata file consumed by EmotionDataset (one record per sample).
# NOTE(review): absolute /content path — presumably a Colab runtime; confirm.
METADATA = "/content/processed_data/metadata.json"
# Root directory for checkpoints; each experiment saves to OUT_DIR / <name>.
OUT_DIR = Path("/content/trained_encoders")
OUT_DIR.mkdir(parents=True, exist_ok=True)
# Size of the classification head passed as num_labels to from_pretrained().
NUM_EMOTIONS = 8

print(f"Device: {DEVICE}")




In [7]:
class EmotionDataset(Dataset):
    """Single-modality view over the samples of one split of the metadata file.

    The metadata file is a JSON list of records. Each record is read for the
    keys ``split``, ``emotion_idx`` and, depending on the modality,
    ``audio_path`` or ``frames_path``.

    Args:
        metadata_path: Path to the JSON metadata file.
        split: Only records whose ``"split"`` equals this value are kept.
        modality: ``"audio"`` or ``"video"``. Any other value yields items
            that carry only the emotion label.

    Each item is a dict with:
        ``"emotion"``: the record's ``emotion_idx``.
        ``"audio"``: 1-D waveform tensor (audio modality only).
        ``"video"``: float tensor of frames scaled by 1/255 with the last
            axis moved to position 1 (video modality only).
    """

    def __init__(self, metadata_path: str, split: str, modality: str):
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(metadata_path, encoding="utf-8") as f:
            data = json.load(f)
        self.samples = [s for s in data if s["split"] == split]
        self.modality = modality

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the label and the raw audio or video of sample ``idx``."""
        s = self.samples[idx]
        item = {"emotion": s["emotion_idx"]}
        if self.modality == "audio":
            # NOTE(review): the file's sample rate is discarded here while
            # prepare_audio() assumes 16 kHz — confirm the files are 16 kHz.
            wav, _ = torchaudio.load(s["audio_path"])
            if wav.dim() > 1 and wav.size(0) > 1:
                # Multi-channel file: squeeze(0) would leave a 2-D tensor that
                # the 1-D cropping code downstream slices along the channel
                # axis. Downmix to mono by averaging the channels instead.
                wav = wav.mean(dim=0)
            else:
                wav = wav.squeeze(0)
            item["audio"] = wav
        elif self.modality == "video":
            # assumes frames are stored as (T, H, W, C) with values in 0-255 — TODO confirm
            frames = np.load(s["frames_path"])
            item["video"] = torch.from_numpy(frames).permute(0, 3, 1, 2).float() / 255.0
        return item


def collate_fn(batch):
    """Merge a list of dataset items into one batch dict.

    Labels become a single tensor. Waveforms stay a plain Python list (they may
    differ in length), while video tensors are stacked along a new leading axis.
    Which modality keys are emitted is decided by the first item only.
    """
    labels = [sample["emotion"] for sample in batch]
    collated = {"emotion": torch.tensor(labels)}
    first = batch[0]
    if "audio" in first:
        collated["audio"] = [sample["audio"] for sample in batch]
    if "video" in first:
        videos = [sample["video"] for sample in batch]
        collated["video"] = torch.stack(videos)
    return collated


  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mkatrinpochtar[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Device: cuda | Train: 576 | Val: 144


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


config.json: 0.00B [00:00, ?B/s]



pytorch_model.bin:   0%|          | 0.00/378M [00:00<?, ?B/s]

Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at superb/wav2vec2-base-superb-er and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([4, 256]) in the checkpoint and torch.Size([8, 256]) in the model instantiated
- classifier.bias: found shape torch.Size([4]) in the checkpoint and torch.Size([8]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


preprocessor_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/378M [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/486M [00:00<?, ?B/s]

Some weights of TimesformerForVideoClassification were not initialized from the model checkpoint at facebook/timesformer-base-finetuned-k400 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([400, 768]) in the checkpoint and torch.Size([8, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([400]) in the checkpoint and torch.Size([8]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


preprocessor_config.json:   0%|          | 0.00/412 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/486M [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.



Epoch 1/20




Training (Audio):   1%|          | 1/144 [00:00<02:00,  1.19it/s][A
Training (Audio):   2%|▏         | 3/144 [00:01<00:39,  3.54it/s][A
Training (Audio):   3%|▎         | 4/144 [00:01<00:31,  4.50it/s][A
Training (Audio):   3%|▎         | 5/144 [00:01<00:25,  5.49it/s][A
Training (Audio):   4%|▍         | 6/144 [00:01<00:21,  6.38it/s][A
Training (Audio):   5%|▍         | 7/144 [00:01<00:19,  7.10it/s][A
Training (Audio):   6%|▌         | 8/144 [00:01<00:17,  7.57it/s][A
Training (Audio):   6%|▋         | 9/144 [00:01<00:16,  8.05it/s][A
Training (Audio):   7%|▋         | 10/144 [00:01<00:16,  8.21it/s][A
Training (Audio):   8%|▊         | 11/144 [00:01<00:15,  8.39it/s][A
Training (Audio):   8%|▊         | 12/144 [00:01<00:15,  8.69it/s][A
Training (Audio):   9%|▉         | 13/144 [00:02<00:15,  8.71it/s][A
Training (Audio):  10%|▉         | 14/144 [00:02<00:14,  8.90it/s][A
Training (Audio):  10%|█         | 15/144 [00:02<00:14,  9.19it/s][A
Training (Audio):  12%|█▏ 

[Audio] Train: loss=2.0724 acc=0.1198 f1=0.0957
[Audio]   Val: loss=2.0573 acc=0.2569 f1=0.1931


Training (Video): 100%|██████████| 144/144 [01:16<00:00,  1.89it/s]
Validation (Video): 100%|██████████| 36/36 [00:19<00:00,  1.87it/s]


[Video] Train: loss=2.1382 acc=0.1719 f1=0.1440
[Video]   Val: loss=2.0686 acc=0.1667 f1=0.1471
✓ Saved best audio encoder → /content/trained_encoders/best_audio_encoder (F1=0.1931)
✓ Saved best video encoder → /content/trained_encoders/best_video_encoder (F1=0.1471)

Epoch 2/20
→ Unfroze TimeSformer backbone


Training (Audio): 100%|██████████| 144/144 [00:10<00:00, 13.37it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 32.34it/s]


[Audio] Train: loss=2.0160 acc=0.2465 f1=0.1710
[Audio]   Val: loss=1.9726 acc=0.2778 f1=0.2073


Training (Video): 100%|██████████| 144/144 [01:49<00:00,  1.32it/s]
Validation (Video): 100%|██████████| 36/36 [00:18<00:00,  1.92it/s]


[Video] Train: loss=1.5132 acc=0.4462 f1=0.4366
[Video]   Val: loss=0.9649 acc=0.6736 f1=0.6420
✓ Saved best audio encoder → /content/trained_encoders/best_audio_encoder (F1=0.2073)
✓ Saved best video encoder → /content/trained_encoders/best_video_encoder (F1=0.6420)

Epoch 3/20
→ Unfroze Wav2Vec2/HubERT feature encoder


Training (Audio): 100%|██████████| 144/144 [00:10<00:00, 13.63it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 32.84it/s]


[Audio] Train: loss=1.8947 acc=0.2726 f1=0.2019
[Audio]   Val: loss=1.8493 acc=0.2917 f1=0.2015


Training (Video): 100%|██████████| 144/144 [01:48<00:00,  1.33it/s]
Validation (Video): 100%|██████████| 36/36 [00:18<00:00,  1.91it/s]


[Video] Train: loss=0.7846 acc=0.7222 f1=0.7202
[Video]   Val: loss=0.8175 acc=0.6806 f1=0.6406

Epoch 4/20


Training (Audio): 100%|██████████| 144/144 [00:14<00:00, 10.21it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 32.48it/s]


[Audio] Train: loss=1.7981 acc=0.3021 f1=0.2047
[Audio]   Val: loss=1.7539 acc=0.3264 f1=0.2191


Training (Video): 100%|██████████| 144/144 [01:48<00:00,  1.32it/s]
Validation (Video): 100%|██████████| 36/36 [00:18<00:00,  1.91it/s]


[Video] Train: loss=0.5224 acc=0.8281 f1=0.8287
[Video]   Val: loss=0.5964 acc=0.7917 f1=0.7812
✓ Saved best audio encoder → /content/trained_encoders/best_audio_encoder (F1=0.2191)
✓ Saved best video encoder → /content/trained_encoders/best_video_encoder (F1=0.7812)

Epoch 5/20


Training (Audio): 100%|██████████| 144/144 [00:10<00:00, 13.72it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 32.41it/s]


[Audio] Train: loss=1.6895 acc=0.3542 f1=0.2483
[Audio]   Val: loss=1.6479 acc=0.3403 f1=0.2446


Training (Video): 100%|██████████| 144/144 [01:49<00:00,  1.32it/s]
Validation (Video): 100%|██████████| 36/36 [00:18<00:00,  1.91it/s]


[Video] Train: loss=0.4221 acc=0.8490 f1=0.8468
[Video]   Val: loss=0.4022 acc=0.8194 f1=0.8172
✓ Saved best audio encoder → /content/trained_encoders/best_audio_encoder (F1=0.2446)
✓ Saved best video encoder → /content/trained_encoders/best_video_encoder (F1=0.8172)

Epoch 6/20


Training (Audio): 100%|██████████| 144/144 [00:10<00:00, 13.61it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 32.56it/s]


[Audio] Train: loss=1.6300 acc=0.3819 f1=0.2900
[Audio]   Val: loss=1.5767 acc=0.3403 f1=0.2279


Training (Video): 100%|██████████| 144/144 [01:43<00:00,  1.39it/s]
Validation (Video): 100%|██████████| 36/36 [00:17<00:00,  2.04it/s]


[Video] Train: loss=0.3024 acc=0.9045 f1=0.9048
[Video]   Val: loss=0.5522 acc=0.8194 f1=0.8139

Epoch 7/20


Training (Audio): 100%|██████████| 144/144 [00:10<00:00, 13.94it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 32.44it/s]


[Audio] Train: loss=1.5420 acc=0.4045 f1=0.3047
[Audio]   Val: loss=1.5965 acc=0.3403 f1=0.2322


Training (Video): 100%|██████████| 144/144 [01:44<00:00,  1.38it/s]
Validation (Video): 100%|██████████| 36/36 [00:17<00:00,  2.00it/s]


[Video] Train: loss=0.2273 acc=0.9288 f1=0.9292
[Video]   Val: loss=0.5012 acc=0.8125 f1=0.8136

Epoch 8/20


Training (Audio): 100%|██████████| 144/144 [00:10<00:00, 13.81it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 31.49it/s]


[Audio] Train: loss=1.5061 acc=0.4271 f1=0.3308
[Audio]   Val: loss=1.5013 acc=0.4028 f1=0.2933


Training (Video): 100%|██████████| 144/144 [01:44<00:00,  1.38it/s]
Validation (Video): 100%|██████████| 36/36 [00:17<00:00,  2.03it/s]


[Video] Train: loss=0.2099 acc=0.9306 f1=0.9313
[Video]   Val: loss=0.5110 acc=0.8333 f1=0.8274
✓ Saved best audio encoder → /content/trained_encoders/best_audio_encoder (F1=0.2933)
✓ Saved best video encoder → /content/trained_encoders/best_video_encoder (F1=0.8274)

Epoch 9/20


Training (Audio): 100%|██████████| 144/144 [00:13<00:00, 10.66it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 31.49it/s]


[Audio] Train: loss=1.4045 acc=0.4878 f1=0.3934
[Audio]   Val: loss=1.4232 acc=0.4028 f1=0.2931


Training (Video): 100%|██████████| 144/144 [01:51<00:00,  1.29it/s]
Validation (Video): 100%|██████████| 36/36 [00:18<00:00,  1.90it/s]


[Video] Train: loss=0.2221 acc=0.9236 f1=0.9237
[Video]   Val: loss=0.4005 acc=0.8958 f1=0.8955
✓ Saved best video encoder → /content/trained_encoders/best_video_encoder (F1=0.8955)

Epoch 10/20


Training (Audio): 100%|██████████| 144/144 [00:10<00:00, 13.85it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 32.32it/s]


[Audio] Train: loss=1.3594 acc=0.4931 f1=0.4068
[Audio]   Val: loss=1.3871 acc=0.4167 f1=0.3015


Training (Video): 100%|██████████| 144/144 [01:49<00:00,  1.32it/s]
Validation (Video): 100%|██████████| 36/36 [00:19<00:00,  1.85it/s]


[Video] Train: loss=0.1794 acc=0.9549 f1=0.9550
[Video]   Val: loss=0.4950 acc=0.8472 f1=0.8481
✓ Saved best audio encoder → /content/trained_encoders/best_audio_encoder (F1=0.3015)

Epoch 11/20


Training (Audio): 100%|██████████| 144/144 [00:09<00:00, 14.54it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 31.79it/s]


[Audio] Train: loss=1.2497 acc=0.5729 f1=0.4938
[Audio]   Val: loss=1.2355 acc=0.5486 f1=0.4828


Training (Video): 100%|██████████| 144/144 [01:49<00:00,  1.32it/s]
Validation (Video): 100%|██████████| 36/36 [00:18<00:00,  1.91it/s]


[Video] Train: loss=0.1456 acc=0.9583 f1=0.9583
[Video]   Val: loss=0.4338 acc=0.8819 f1=0.8800
✓ Saved best audio encoder → /content/trained_encoders/best_audio_encoder (F1=0.4828)

Epoch 12/20


Training (Audio): 100%|██████████| 144/144 [00:13<00:00, 10.93it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 31.17it/s]


[Audio] Train: loss=1.2249 acc=0.5955 f1=0.5304
[Audio]   Val: loss=1.2299 acc=0.5903 f1=0.5289


Training (Video): 100%|██████████| 144/144 [01:51<00:00,  1.30it/s]
Validation (Video): 100%|██████████| 36/36 [00:19<00:00,  1.86it/s]


[Video] Train: loss=0.0970 acc=0.9809 f1=0.9809
[Video]   Val: loss=0.3225 acc=0.8819 f1=0.8809
✓ Saved best audio encoder → /content/trained_encoders/best_audio_encoder (F1=0.5289)

Epoch 13/20


Training (Audio): 100%|██████████| 144/144 [00:09<00:00, 14.49it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 32.08it/s]


[Audio] Train: loss=1.0988 acc=0.6302 f1=0.5657
[Audio]   Val: loss=1.1848 acc=0.5625 f1=0.4912


Training (Video): 100%|██████████| 144/144 [01:51<00:00,  1.30it/s]
Validation (Video): 100%|██████████| 36/36 [00:19<00:00,  1.84it/s]


[Video] Train: loss=0.1122 acc=0.9618 f1=0.9619
[Video]   Val: loss=0.5836 acc=0.8333 f1=0.8281

Epoch 14/20


Training (Audio): 100%|██████████| 144/144 [00:10<00:00, 14.26it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 32.30it/s]


[Audio] Train: loss=1.0982 acc=0.6267 f1=0.5727
[Audio]   Val: loss=1.3265 acc=0.5347 f1=0.4576


Training (Video): 100%|██████████| 144/144 [01:50<00:00,  1.30it/s]
Validation (Video): 100%|██████████| 36/36 [00:20<00:00,  1.79it/s]


[Video] Train: loss=0.1476 acc=0.9618 f1=0.9624
[Video]   Val: loss=0.6223 acc=0.8264 f1=0.8239

Epoch 15/20


Training (Audio): 100%|██████████| 144/144 [00:10<00:00, 14.07it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 31.82it/s]


[Audio] Train: loss=1.0496 acc=0.6597 f1=0.6133
[Audio]   Val: loss=1.1054 acc=0.5972 f1=0.5367


Training (Video): 100%|██████████| 144/144 [01:51<00:00,  1.29it/s]
Validation (Video): 100%|██████████| 36/36 [00:19<00:00,  1.84it/s]


[Video] Train: loss=0.0956 acc=0.9740 f1=0.9739
[Video]   Val: loss=0.5583 acc=0.8681 f1=0.8659
✓ Saved best audio encoder → /content/trained_encoders/best_audio_encoder (F1=0.5367)

Epoch 16/20


Training (Audio): 100%|██████████| 144/144 [00:10<00:00, 14.04it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 31.82it/s]


[Audio] Train: loss=0.9645 acc=0.6927 f1=0.6651
[Audio]   Val: loss=1.0109 acc=0.7083 f1=0.6989


Training (Video): 100%|██████████| 144/144 [01:50<00:00,  1.30it/s]
Validation (Video): 100%|██████████| 36/36 [00:19<00:00,  1.86it/s]


[Video] Train: loss=0.1229 acc=0.9670 f1=0.9670
[Video]   Val: loss=0.6023 acc=0.8750 f1=0.8739
✓ Saved best audio encoder → /content/trained_encoders/best_audio_encoder (F1=0.6989)

Epoch 17/20


Training (Audio): 100%|██████████| 144/144 [00:09<00:00, 15.16it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 33.02it/s]


[Audio] Train: loss=0.8948 acc=0.7274 f1=0.7021
[Audio]   Val: loss=0.9164 acc=0.7222 f1=0.7181


Training (Video): 100%|██████████| 144/144 [01:49<00:00,  1.31it/s]
Validation (Video): 100%|██████████| 36/36 [00:19<00:00,  1.84it/s]


[Video] Train: loss=0.0797 acc=0.9774 f1=0.9775
[Video]   Val: loss=0.6001 acc=0.8403 f1=0.8374
✓ Saved best audio encoder → /content/trained_encoders/best_audio_encoder (F1=0.7181)

Epoch 18/20


Training (Audio): 100%|██████████| 144/144 [00:10<00:00, 13.79it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 31.44it/s]


[Audio] Train: loss=0.8822 acc=0.7378 f1=0.7227
[Audio]   Val: loss=0.9148 acc=0.7292 f1=0.7253


Training (Video): 100%|██████████| 144/144 [01:51<00:00,  1.29it/s]
Validation (Video): 100%|██████████| 36/36 [00:19<00:00,  1.84it/s]


[Video] Train: loss=0.1032 acc=0.9740 f1=0.9740
[Video]   Val: loss=0.5703 acc=0.8611 f1=0.8609
✓ Saved best audio encoder → /content/trained_encoders/best_audio_encoder (F1=0.7253)

Epoch 19/20


Training (Audio): 100%|██████████| 144/144 [00:10<00:00, 13.79it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 33.29it/s]


[Audio] Train: loss=0.8069 acc=0.7726 f1=0.7607
[Audio]   Val: loss=0.9015 acc=0.7778 f1=0.7791


Training (Video): 100%|██████████| 144/144 [01:49<00:00,  1.32it/s]
Validation (Video): 100%|██████████| 36/36 [00:18<00:00,  1.92it/s]


[Video] Train: loss=0.0678 acc=0.9826 f1=0.9826
[Video]   Val: loss=0.6097 acc=0.8681 f1=0.8636
✓ Saved best audio encoder → /content/trained_encoders/best_audio_encoder (F1=0.7791)

Epoch 20/20


Training (Audio): 100%|██████████| 144/144 [00:10<00:00, 14.29it/s]
Validation (Audio): 100%|██████████| 36/36 [00:01<00:00, 33.06it/s]


[Audio] Train: loss=0.7226 acc=0.7882 f1=0.7777
[Audio]   Val: loss=0.8005 acc=0.7847 f1=0.7864


Training (Video): 100%|██████████| 144/144 [01:52<00:00,  1.28it/s]
Validation (Video): 100%|██████████| 36/36 [00:20<00:00,  1.79it/s]


[Video] Train: loss=0.0580 acc=0.9878 f1=0.9879
[Video]   Val: loss=0.4600 acc=0.8750 f1=0.8753
✓ Saved best audio encoder → /content/trained_encoders/best_audio_encoder (F1=0.7864)

Training complete!
Best Audio F1: 0.7864 | Best Video F1: 0.8955


0,1
audio/train_acc,▁▂▃▃▃▄▄▄▅▅▆▆▆▆▇▇▇▇██
audio/train_f1,▁▂▂▂▃▃▃▃▄▄▅▅▆▆▆▇▇▇██
audio/train_loss,██▇▇▆▆▅▅▅▄▄▄▃▃▃▂▂▂▁▁
audio/val_acc,▁▁▁▂▂▂▂▃▃▃▅▅▅▅▆▇▇▇██
audio/val_f1,▁▁▁▁▂▁▁▂▂▂▄▅▅▄▅▇▇▇██
audio/val_loss,██▇▆▆▅▅▅▄▄▃▃▃▄▃▂▂▂▂▁
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
video/train_acc,▁▃▆▇▇▇▇█▇███████████
video/train_f1,▁▃▆▇▇▇██▇███████████
video/train_loss,█▆▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁

0,1
audio/train_acc,0.78819
audio/train_f1,0.77765
audio/train_loss,0.72265
audio/val_acc,0.78472
audio/val_f1,0.78639
audio/val_loss,0.80054
epoch,20
video/train_acc,0.98785
video/train_f1,0.98786
video/train_loss,0.05797


In [7]:
def crop_audio(wav, sr, duration, train):
    """Return exactly ``round(duration * sr)`` samples taken from ``wav``.

    Inputs no longer than the target are zero-padded on the right. Longer
    inputs are cropped: at a random offset when ``train`` is true, otherwise
    around the centre.
    """
    target_len = int(round(duration * sr))
    total = wav.numel()
    missing = target_len - total
    if missing >= 0:
        # Too short (or an exact fit): right-pad with zeros.
        return torch.nn.functional.pad(wav, (0, missing))
    if train:
        # Random window; the upper bound is exclusive, hence the +1.
        offset = torch.randint(0, total - target_len + 1, ()).item()
    else:
        offset = (total - target_len) // 2
    return wav[offset:offset + target_len]


def crop_video(video, n_frames, train):
    """Return ``n_frames`` frames selected along the first axis of ``video``.

    Clips with at most ``n_frames`` frames are resampled with evenly spaced,
    rounded indices (so frames repeat). Longer clips yield a contiguous window:
    random when ``train`` is true, centred otherwise.
    """
    available = video.shape[0]
    if available > n_frames:
        if train:
            # Random window; the upper bound is exclusive, hence the +1.
            first = torch.randint(0, available - n_frames + 1, ()).item()
        else:
            first = (available - n_frames) // 2
        return video[first:first + n_frames]
    # Not enough frames: spread indices over the clip, repeating as needed.
    picks = torch.linspace(0, available - 1, n_frames).round().long()
    return video[picks]


def prepare_audio(batch, processor, window_s, device, train=True):
    """Turn a collated audio batch into model keyword arguments and labels.

    Every waveform is cropped or padded to ``window_s`` seconds at a fixed
    16 kHz rate, then run through ``processor``. The attention mask is
    forwarded only when the processor returns one.

    Returns:
        ``(model_kwargs, labels)``, both already moved to ``device``.
    """
    sr = 16000
    windows = []
    for waveform in batch["audio"]:
        windows.append(crop_audio(waveform, sr, window_s, train).numpy())
    max_samples = int(window_s * sr)
    encoded = processor(
        windows,
        sampling_rate=sr,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_samples,
    )
    model_inputs = {"input_values": encoded["input_values"].to(device)}
    if "attention_mask" in encoded:
        model_inputs["attention_mask"] = encoded["attention_mask"].to(device)
    targets = batch["emotion"].to(device)
    return model_inputs, targets


def prepare_video(batch, processor, n_frames, device, train=True):
    """Turn a collated video batch into model keyword arguments and labels.

    Each clip is reduced to ``n_frames`` frames and handed to ``processor`` as
    a list of per-frame NumPy arrays with the channel axis moved last.
    ``do_rescale=False`` is passed because the dataset has already divided the
    pixel values by 255.

    Returns:
        ``({"pixel_values": ...}, labels)``, both already moved to ``device``.
    """
    def as_frame_list(clip):
        # Channel axis goes last for every frame of the clip.
        return [frame.permute(1, 2, 0).numpy() for frame in clip]

    frame_lists = [
        as_frame_list(crop_video(video, n_frames, train))
        for video in batch["video"]
    ]
    encoded = processor(frame_lists, return_tensors="pt", do_rescale=False)
    model_inputs = {"pixel_values": encoded["pixel_values"].to(device)}
    targets = batch["emotion"].to(device)
    return model_inputs, targets

In [None]:
def train_one_epoch(model, loader, prep_fn, optimizer, scaler):
    """Run one optimisation pass over ``loader`` and return epoch metrics.

    Uses mixed precision when ``DEVICE`` is CUDA. Gradients are unscaled before
    being clipped to a global norm of 1.0, so the clip threshold applies to the
    true gradient magnitudes.

    Returns:
        Dict with ``loss`` (mean of the per-batch losses), ``acc`` and
        weighted ``f1`` computed over every prediction made in the epoch.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    use_amp = DEVICE == "cuda"
    running_loss = 0.0
    y_pred, y_true = [], []

    for batch in tqdm(loader, leave=False):
        inputs, targets = prep_fn(batch, train=True)
        optimizer.zero_grad(set_to_none=True)

        with autocast("cuda", enabled=use_amp):
            logits = model(**inputs).logits
            loss = criterion(logits, targets)

        scaler.scale(loss).backward()
        # Unscale first so clipping sees real gradient norms.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        y_pred.extend(logits.argmax(1).detach().cpu().tolist())
        y_true.extend(targets.cpu().tolist())

    metrics = {}
    metrics["loss"] = running_loss / len(loader)
    metrics["acc"] = accuracy_score(y_true, y_pred)
    metrics["f1"] = f1_score(y_true, y_pred, average="weighted")
    return metrics


@torch.no_grad()
def evaluate(model, loader, prep_fn):
    """Score ``model`` on ``loader`` without tracking gradients.

    Batches are prepared with ``train=False``. Mixed precision is used when
    ``DEVICE`` is CUDA.

    Returns:
        Dict with ``loss`` (mean of the per-batch losses), ``acc`` and
        weighted ``f1`` over the whole loader.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    use_amp = DEVICE == "cuda"
    running_loss = 0.0
    y_pred, y_true = [], []

    for batch in tqdm(loader, leave=False):
        inputs, targets = prep_fn(batch, train=False)
        with autocast("cuda", enabled=use_amp):
            logits = model(**inputs).logits
            loss = criterion(logits, targets)
        running_loss += loss.item()
        y_pred.extend(logits.argmax(1).cpu().tolist())
        y_true.extend(targets.cpu().tolist())

    metrics = {}
    metrics["loss"] = running_loss / len(loader)
    metrics["acc"] = accuracy_score(y_true, y_pred)
    metrics["f1"] = f1_score(y_true, y_pred, average="weighted")
    return metrics

In [None]:
def seed_all(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (plus every CUDA device, if any)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def run_experiment(cfg):
    """Fine-tune one encoder as described by ``cfg`` and log the run to W&B.

    The model starts partially frozen (the audio feature encoder, or everything
    except the classifier for video). At epoch index ``freeze_epochs`` every
    parameter is unfrozen and training continues with a fresh optimizer at a
    tenth of the base learning rate. The checkpoint with the best validation
    weighted-F1 is saved to ``OUT_DIR / cfg["name"]``; training stops early
    after ``patience`` consecutive epochs without improvement.

    Args:
        cfg: Experiment configuration dict.
            Required keys: ``name``, ``modality`` (``"audio"``; any other value
            selects the video branch), ``model`` (pretrained checkpoint id),
            ``lr``, ``batch_size``, ``epochs``.
            Optional keys: ``window_s`` (audio, default 3.0), ``n_frames``
            (video, default 8), ``freeze_epochs`` (default 2), ``patience``
            (default 5), ``seed`` (default 42).

    Returns:
        Dict with ``name``, ``best_f1`` and ``path`` (checkpoint directory).

    Raises:
        Whatever model loading, data loading or training raises. The W&B run is
        closed (marked as failed) and model references are released first.
    """
    import gc  # function-scope import: only the cleanup step needs it

    seed_all(cfg.get("seed", 42))
    wandb.init(project="uncanny-valley-encoders", name=cfg["name"],
               group=cfg["modality"], config=cfg, reinit=True)

    best_f1, patience_cnt = 0.0, 0
    has_checkpoint = False
    save_path = OUT_DIR / cfg["name"]
    # Bound up front so the cleanup in ``finally`` is safe even if setup fails.
    model = optimizer = scaler = None
    succeeded = False

    try:
        modality = cfg["modality"]
        train_ds = EmotionDataset(METADATA, "train", modality)
        val_ds = EmotionDataset(METADATA, "val", modality)
        train_loader = DataLoader(train_ds, batch_size=cfg["batch_size"], shuffle=True,
                                  num_workers=0, collate_fn=collate_fn)
        val_loader = DataLoader(val_ds, batch_size=cfg["batch_size"], shuffle=False,
                                num_workers=0, collate_fn=collate_fn)

        if modality == "audio":
            model_cls = (HubertForSequenceClassification if "hubert" in cfg["model"].lower()
                         else Wav2Vec2ForSequenceClassification)
            # The checkpoint's classifier has a different label count, hence
            # ignore_mismatched_sizes: the head is re-initialised.
            model = model_cls.from_pretrained(
                cfg["model"], num_labels=NUM_EMOTIONS, ignore_mismatched_sizes=True)
            processor = Wav2Vec2FeatureExtractor.from_pretrained(cfg["model"])
            prep_fn = partial(prepare_audio, processor=processor,
                              window_s=cfg.get("window_s", 3.0), device=DEVICE)
            if hasattr(model, "freeze_feature_encoder"):
                model.freeze_feature_encoder()
        else:
            model = TimesformerForVideoClassification.from_pretrained(
                cfg["model"], num_labels=NUM_EMOTIONS, ignore_mismatched_sizes=True)
            processor = AutoImageProcessor.from_pretrained(cfg["model"])
            prep_fn = partial(prepare_video, processor=processor,
                              n_frames=cfg.get("n_frames", 8), device=DEVICE)
            # Warm-up phase: train only the new classification head.
            for n, p in model.named_parameters():
                if "classifier" not in n:
                    p.requires_grad = False

        model.to(DEVICE)
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()), lr=cfg["lr"])
        scaler = GradScaler(enabled=DEVICE == "cuda")

        for epoch in range(cfg["epochs"]):
            # Unfreeze backbone after warmup
            if epoch == cfg.get("freeze_epochs", 2):
                for p in model.parameters():
                    p.requires_grad = True
                # New optimizer so the newly trainable parameters are included;
                # the whole network continues at a tenth of the base LR.
                optimizer = torch.optim.AdamW(model.parameters(), lr=cfg["lr"] * 0.1)
                scaler = GradScaler(enabled=DEVICE == "cuda")

            t = train_one_epoch(model, train_loader, prep_fn, optimizer, scaler)
            v = evaluate(model, val_loader, prep_fn)

            wandb.log({
                "epoch": epoch + 1,
                "train/loss": t["loss"], "train/acc": t["acc"], "train/f1": t["f1"],
                "val/loss": v["loss"], "val/acc": v["acc"], "val/f1": v["f1"],
                "lr": optimizer.param_groups[0]["lr"],
            })
            print(f"  [{epoch+1:2d}/{cfg['epochs']}] "
                  f"t_f1={t['f1']:.3f} v_f1={v['f1']:.3f} v_loss={v['loss']:.3f}")

            # Always checkpoint the first evaluated epoch so the returned path
            # exists even when validation F1 never rises above 0.0.
            if v["f1"] > best_f1 or not has_checkpoint:
                best_f1 = v["f1"]
                save_path.mkdir(parents=True, exist_ok=True)
                model.save_pretrained(str(save_path))
                processor.save_pretrained(str(save_path))
                has_checkpoint = True
                patience_cnt = 0
            else:
                patience_cnt += 1
                if patience_cnt >= cfg.get("patience", 5):
                    print(f"  Early stopping at epoch {epoch+1}")
                    break

        wandb.log({"best_val_f1": best_f1})
        succeeded = True
    finally:
        # Close the W&B run even on failure so the next experiment starts clean;
        # a non-zero exit code marks the run as failed in the dashboard.
        wandb.finish(exit_code=0 if succeeded else 1)
        # The optimizer and scaler also hold parameter references; drop all of
        # them and collect before asking the allocator to release cached memory.
        model = optimizer = scaler = None
        gc.collect()
        torch.cuda.empty_cache()

    print(f"  Best F1: {best_f1:.4f} -> {save_path}\n")
    return {"name": cfg["name"], "best_f1": best_f1, "path": str(save_path)}

In [None]:
_WAV2VEC2_CKPT = "superb/wav2vec2-base-superb-er"
_HUBERT_CKPT = "superb/hubert-base-superb-er"
_TIMESFORMER_CKPT = "facebook/timesformer-base-finetuned-k400"


def _audio_exp(name, model, lr, window_s):
    """Audio config: batch size, epochs, warm-up and patience are shared by all audio runs."""
    return {"name": name, "modality": "audio",
            "model": model,
            "lr": lr, "window_s": window_s, "batch_size": 8,
            "epochs": 25, "freeze_epochs": 3, "patience": 5}


def _video_exp(name, lr, n_frames, batch_size, freeze_epochs=1):
    """Video config: checkpoint, epochs and patience are shared by all video runs."""
    return {"name": name, "modality": "video",
            "model": _TIMESFORMER_CKPT,
            "lr": lr, "n_frames": n_frames, "batch_size": batch_size,
            "epochs": 15, "freeze_epochs": freeze_epochs, "patience": 4}


EXPERIMENTS = [
    # --- Audio: Wav2Vec2 (three LRs at a 3 s window, plus 3e-5 at a 2 s window) ---
    _audio_exp("wav2vec2-lr1e5-w3s", _WAV2VEC2_CKPT, 1e-5, 3.0),
    _audio_exp("wav2vec2-lr3e5-w3s", _WAV2VEC2_CKPT, 3e-5, 3.0),
    _audio_exp("wav2vec2-lr5e5-w3s", _WAV2VEC2_CKPT, 5e-5, 3.0),
    _audio_exp("wav2vec2-lr3e5-w2s", _WAV2VEC2_CKPT, 3e-5, 2.0),

    # --- Audio: HuBERT (same grid as Wav2Vec2) ---
    _audio_exp("hubert-lr1e5-w3s", _HUBERT_CKPT, 1e-5, 3.0),
    _audio_exp("hubert-lr3e5-w3s", _HUBERT_CKPT, 3e-5, 3.0),
    _audio_exp("hubert-lr5e5-w3s", _HUBERT_CKPT, 5e-5, 3.0),
    _audio_exp("hubert-lr3e5-w2s", _HUBERT_CKPT, 3e-5, 2.0),

    # --- Video: TimeSformer (2 LRs x 2 frame counts, plus a longer-warm-up variant) ---
    # 16-frame clips use half the batch size of the 8-frame ones.
    _video_exp("timesformer-lr1e5-8f", 1e-5, 8, 4),
    _video_exp("timesformer-lr3e5-8f", 3e-5, 8, 4),
    _video_exp("timesformer-lr1e5-16f", 1e-5, 16, 2),
    _video_exp("timesformer-lr3e5-16f", 3e-5, 16, 2),
    _video_exp("timesformer-lr1e5-8f-freeze3", 1e-5, 8, 4, freeze_epochs=3),
]

# Run the sweep sequentially, collecting one result dict per experiment.
results = []
_banner = "=" * 60
for exp in EXPERIMENTS:
    print(f"{_banner}\n{exp['name']}\n{_banner}")
    results.append(run_experiment(exp))

In [None]:
# Leaderboard of every finished experiment, best validation F1 first.
_rule = "=" * 60
print(f"\n{_rule}")
print("RESULTS SUMMARY")
print(_rule)
print(f"{'Name':30s} {'Best Val F1':>12s}")
print("-" * 44)
_ranked = sorted(results, key=lambda entry: -entry["best_f1"])
for r in _ranked:
    print(f"{r['name']:30s} {r['best_f1']:12.4f}")