In [1]:
import os
import random
import torch
import torchaudio
import torchvision
import torchvision.transforms as T
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import OneCycleLR
from PIL import Image
from transformers import Wav2Vec2Model, Wav2Vec2Processor, CLIPTextModel, CLIPTokenizer
from tqdm.auto import tqdm

### Constants

In [2]:
# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths
BASE_DIR      = "/media/dtsarev/SatSSD/data"
TRAIN_CSV     = os.path.join(BASE_DIR, "train_split.csv")
VAL_CSV       = os.path.join(BASE_DIR, "valid_split.csv")
AUDIO_DIR     = os.path.join(BASE_DIR, "audio")
IMAGE_DIR     = os.path.join(BASE_DIR, "face_images")
TEXT_DIR      = os.path.join(BASE_DIR, "text")

# Hyperparameters
BATCH_SIZE    = 2
NUM_EPOCHS    = 20
LR            = 1e-4
WEIGHT_DECAY  = 1e-5
SAVE_EVERY    = 3       # epochs
NUM_WORKERS   = 4
PIN_MEMORY    = True

# Data settings
N_FRAMES      = 16      # frames per sample
MAX_AUDIO_LEN = 160000  # max waveform length (~10s @16kHz)

# Model dims
VIS_DIM       = 768
AUD_DIM       = 768
TXT_DIM       = 512
FUSION_DIM    = 512

In [3]:
print(DEVICE)

cuda


### Dataset and DataLoader

In [4]:
import pandas as pd
import numpy as np

class EMIDataset(Dataset):
    def __init__(self, csv_file, audio_dir, img_dir, text_dir, transform=None, wav_model="facebook/wav2vec2-base-960h"):
        df = pd.read_csv(csv_file)
        emo = ['Admiration','Amusement','Determination',
               'Empathic Pain','Excitement','Joy']
        # force floats, drop rows with NaN
        df[emo] = df[emo].apply(pd.to_numeric, errors='coerce')
        df = df.dropna(subset=emo).reset_index(drop=True)
        self.df = df
        self.audio_dir = audio_dir
        self.img_dir = img_dir
        self.text_dir = text_dir
        self.transform = transform or T.Compose([
            T.Resize((224,224)),
            T.ToTensor(),
            T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
        ])
        self.wav_processor = Wav2Vec2Processor.from_pretrained(wav_model)
        self.txt_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        fname = f"{int(row['Filename']):05d}"

        # --- Visual frames with fallback ---
        folder = os.path.join(self.img_dir, fname)
        imgs = []
        if os.path.isdir(folder):
            files = [f for f in os.listdir(folder) if f.lower().endswith(".jpg")]
            files = sorted(files)
        else:
            files = []

        if len(files) == 0:
            # no frames: use zeros
            for _ in range(N_FRAMES):
                imgs.append(torch.zeros(3,224,224))
        else:
            if len(files) >= N_FRAMES:
                idxs = np.linspace(0, len(files)-1, N_FRAMES, dtype=int)
            else:
                idxs = list(range(len(files))) + [len(files)-1] * (N_FRAMES - len(files))
            for i in idxs:
                img_path = os.path.join(folder, files[i])
                img = Image.open(img_path).convert("RGB")
                imgs.append(self.transform(img))
        visual = torch.stack(imgs)  # [T,3,224,224]

        # --- Audio (unchanged) ---
        wav, sr = torchaudio.load(os.path.join(self.audio_dir, f"{fname}.mp3"))
        wav = wav.mean(0, keepdim=True)
        # trim/pad waveform
        if wav.shape[1] > MAX_AUDIO_LEN:
            wav = wav[:, :MAX_AUDIO_LEN]
        else:
            wav = nn.functional.pad(wav, (0, MAX_AUDIO_LEN - wav.shape[1]))

        # **IMPORTANT**: pass a list of arrays
        wav_np = wav.squeeze(0).numpy()   # shape [MAX_AUDIO_LEN]
        proc = self.wav_processor(
            [wav_np],                       # <-- note the list
            sampling_rate=sr,
            return_tensors="pt",
            padding="longest",
            return_attention_mask=True
        )
        # now proc.input_values shape [1, seq_len], attention_mask [1, seq_len]
        audio_values = proc.input_values[0]     # [seq_len]
        audio_mask   = proc.attention_mask[0]   # [seq_len]

        # --- Text (unchanged) ---
        txt = open(os.path.join(self.text_dir, f"{fname}.txt")).read().strip()
        if len(txt)==0:
            txt = "neutral"
        toks = self.txt_tokenizer(txt, return_tensors="pt", padding="max_length", truncation=True)
        input_ids = toks.input_ids.squeeze(0)
        attn_mask = toks.attention_mask.squeeze(0)

        # --- Labels ---
        labels = torch.tensor([
            row['Admiration'], row['Amusement'], row['Determination'],
            row['Empathic Pain'], row['Excitement'], row['Joy']
        ], dtype=torch.float32)
        assert not torch.isnan(labels).any(), f"NaN in labels row {idx}"

        return {"visual": visual, "audio_values": audio_values, "audio_mask": audio_mask, "input_ids": input_ids, "attn_mask": attn_mask, "labels": labels}

In [6]:
# DataLoaders
train_ds = EMIDataset(TRAIN_CSV, AUDIO_DIR, IMAGE_DIR, TEXT_DIR)
val_ds   = EMIDataset(VAL_CSV,   AUDIO_DIR, IMAGE_DIR, TEXT_DIR)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

### Model

In [7]:
class EMIModel(nn.Module):
    def __init__(self):
        super().__init__()
        # (backbones and temporal modules unchanged) …
        self.vis_backbone = torchvision.models.vit_b_16(pretrained=True)
        self.vis_backbone.heads = nn.Identity()
        self.aud_backbone = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
        self.txt_backbone = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
        vis_layer = nn.TransformerEncoderLayer(d_model=VIS_DIM, nhead=8)
        self.vis_temp = nn.TransformerEncoder(vis_layer, num_layers=2)
        aud_layer = nn.TransformerEncoderLayer(d_model=AUD_DIM, nhead=8)
        self.aud_temp = nn.TransformerEncoder(aud_layer, num_layers=2)
        self.fusion = nn.Sequential(
            nn.Linear(VIS_DIM + AUD_DIM + TXT_DIM, FUSION_DIM),
            nn.ReLU(), nn.Dropout(0.3)
        )
        self.head = nn.Sequential(nn.Linear(FUSION_DIM, 6), nn.Sigmoid())

    def forward(self,
                visual,        # [B, T, 3, 224,224]
                audio_values,  # [B, seq_len]
                audio_mask,    # [B, seq_len]
                input_ids,     # [B, L]
                attn_mask      # [B, L]
               ):
        bs = visual.size(0)

        # Visual stream
        v = visual.view(-1,3,224,224)
        v_feat = self.vis_backbone(v)                        # [B*T, VIS_DIM]
        v_seq  = v_feat.view(bs, N_FRAMES, -1).permute(1,0,2) # [T, B, VIS_DIM]
        v_out  = self.vis_temp(v_seq).mean(0)                 # [B, VIS_DIM]

        # Audio stream (with mask)
        a_feats = self.aud_backbone(
            input_values=audio_values,       # [bs, seq_len]
            attention_mask=audio_mask        # [bs, seq_len]
        ).last_hidden_state                 # [bs, L, AUD_DIM]
        a_seq   = a_feats.permute(1,0,2)    # [L, bs, AUD_DIM]
        a_out   = self.aud_temp(a_seq).mean(0)  # [bs, AUD_DIM]

        # Text stream
        t_out = self.txt_backbone(
            input_ids=input_ids,
            attention_mask=attn_mask
        ).pooler_output                                     # [B, TXT_DIM]
        
        # debug checks
        if torch.isnan(v_out).any():
            raise ValueError("v_out has NaNs")
        if torch.isnan(a_out).any():
            raise ValueError("NaNs in a_out")
        if torch.isnan(t_out).any():
            raise ValueError("t_out has NaNs")

        # Fusion & prediction
        x = torch.cat([v_out, a_out, t_out], dim=1)        # [B, sum_dims]
        x = self.fusion(x)                                 # [B, FUSION_DIM]
        return self.head(x)                                # [B, 6]

In [8]:
# instantiate
model = EMIModel().to(DEVICE)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Trainig

In [9]:
criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

def pearson_corr(preds, targets):
    preds = preds.detach().cpu()
    targets = targets.detach().cpu()
    vx = preds - preds.mean(0)
    vy = targets - targets.mean(0)
    corr = (vx*vy).sum(0) / (torch.sqrt((vx**2).sum(0) * (vy**2).sum(0)) + 1e-8)
    return corr.mean().item()

In [10]:
scheduler = OneCycleLR(
    optimizer,
    max_lr=5e-5,
    total_steps=NUM_EPOCHS * len(train_loader),
    pct_start=0.1,
    div_factor=100,
    final_div_factor=1e4,
    anneal_strategy="cos"
)

In [11]:
def train_one_epoch(model, loader, criterion, optimizer):
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader, desc="Training"):
        visual      = batch["visual"].to(DEVICE)         # [B,T,3,224,224]
        audio_vals  = batch["audio_values"].to(DEVICE)   # [B, seq_len]
        audio_mask  = batch["audio_mask"].to(DEVICE)     # [B, seq_len]
        input_ids   = batch["input_ids"].to(DEVICE)      # [B, L]
        attn_mask   = batch["attn_mask"].to(DEVICE)      # [B, L]
        labels      = batch["labels"].to(DEVICE)         # [B, 6]

        preds = model(visual, audio_vals, audio_mask, input_ids, attn_mask)
        loss  = criterion(preds, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        #scheduler.step()
        total_loss += loss.item() * visual.size(0)
        #print(f'{loss} | {preds} | {labels}')
    return total_loss / len(loader.dataset)

def validate(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    total_corr = 0.0
    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation"):
            visual     = batch["visual"].to(DEVICE)
            audio_vals = batch["audio_values"].to(DEVICE)
            audio_mask  = batch["audio_mask"].to(DEVICE)
            input_ids  = batch["input_ids"].to(DEVICE)
            attn_mask  = batch["attn_mask"].to(DEVICE)
            labels     = batch["labels"].to(DEVICE)

            preds = model(visual, audio_vals, audio_mask, input_ids, attn_mask)
            loss  = criterion(preds, labels)
            total_loss += loss.item() * visual.size(0)
            total_corr += pearson_corr(preds, labels) * visual.size(0)
    return total_loss / len(loader.dataset), total_corr / len(loader.dataset)

In [12]:
for epoch in range(1, NUM_EPOCHS+1):
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_corr = validate(model, val_loader, criterion)

    print(f"Epoch {epoch}/{NUM_EPOCHS} | "
          f"Train Loss: {train_loss:.4f} | "
          f"Val Loss: {val_loss:.4f} | "
          f"Val Pearson: {val_corr:.4f}")

    if epoch % SAVE_EVERY == 0:
        ckpt = f"emi_epoch_{epoch}.pth"
        torch.save(model.state_dict(), ckpt)
        print(f"Saved checkpoint: {ckpt}")

print("Training complete!")

Training:   0%|          | 0/4036 [00:00<?, ?it/s]

Validation:   0%|          | 0/2294 [00:00<?, ?it/s]

Epoch 1/20 | Train Loss: 0.0425 | Val Loss: 0.0333 | Val Pearson: 0.1434


Training:   0%|          | 0/4036 [00:00<?, ?it/s]

Validation:   0%|          | 0/2294 [00:00<?, ?it/s]

Epoch 2/20 | Train Loss: 0.0336 | Val Loss: 0.0319 | Val Pearson: 0.1557


Training:   0%|          | 0/4036 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e76abfd6e80>
Traceback (most recent call last):
  File "/home/dtsarev/master_of_cv/DIPLOM/repos/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1663, in __del__
    self._shutdown_workers()
  File "/home/dtsarev/master_of_cv/DIPLOM/repos/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1646, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/home/dtsarev/anaconda3/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e76abfd6e80>
Traceback (most recent call last):
  File "/home/dtsarev/master_of_cv/DIPLOM/repos/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1663, in __de

Validation:   0%|          | 0/2294 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e76abfd6e80>
Traceback (most recent call last):
  File "/home/dtsarev/master_of_cv/DIPLOM/repos/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1663, in __del__
    self._shutdown_workers()
  File "/home/dtsarev/master_of_cv/DIPLOM/repos/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1646, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/home/dtsarev/anaconda3/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e76abfd6e80>
Traceback (most recent call last):
  File "/home/dtsarev/master_of_cv/DIPLOM/repos/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1663, in __de

Epoch 3/20 | Train Loss: 0.0295 | Val Loss: 0.0327 | Val Pearson: 0.1606
Saved checkpoint: emi_epoch_3.pth


Training:   0%|          | 0/4036 [00:00<?, ?it/s]

Validation:   0%|          | 0/2294 [00:00<?, ?it/s]

Epoch 4/20 | Train Loss: 0.0259 | Val Loss: 0.0322 | Val Pearson: 0.1675


Training:   0%|          | 0/4036 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e76abfd6e80>
Traceback (most recent call last):
  File "/home/dtsarev/master_of_cv/DIPLOM/repos/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1663, in __del__
    self._shutdown_workers()
  File "/home/dtsarev/master_of_cv/DIPLOM/repos/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1646, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/home/dtsarev/anaconda3/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e76abfd6e80>
Traceback (most recent call last):
  File "/home/dtsarev/master_of_cv/DIPLOM/repos/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1663, in __de

Validation:   0%|          | 0/2294 [00:00<?, ?it/s]

Epoch 5/20 | Train Loss: 0.0223 | Val Loss: 0.0334 | Val Pearson: 0.1638


Training:   0%|          | 0/4036 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e76abfd6e80>
Traceback (most recent call last):
  File "/home/dtsarev/master_of_cv/DIPLOM/repos/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1663, in __del__
    self._shutdown_workers()
  File "/home/dtsarev/master_of_cv/DIPLOM/repos/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1627, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/dtsarev/anaconda3/lib/python3.11/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dtsarev/anaconda3/lib/python3.11/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dtsarev/anaconda3/lib/python3.11/multiprocessing/connection.py", line 930, in wait
    ready = selector.select(timeout)
            ^^^^^^^^^^^^^^^^^

KeyboardInterrupt: 