# EAGT Training — Real Frame/Audio Decoding + Dataset Caching (All-in-One)

This notebook provides a **practical training pipeline** for the Emotion-Aware Generative AI Tutor (EAGT):
1. **CSV ingestion** compatible with DAiSEE/SEMAINE-style manifests.
2. **Robust video frame decoding** using OpenCV (`cv2`), sampling a fixed-length clip from the middle of each video.
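3. **Audio decoding** with `soundfile` (resampled with `librosa` when the sample rate differs), falling back to **FFmpeg** extraction of the audio track from the video when no audio file is available.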
4. **Feature transforms** (resize/normalize for vision, Mel-spectrograms for audio).
5. **Disk caching** of preprocessed tensors to accelerate subsequent epochs.
6. **PyTorch dataset + dataloaders** with multi-worker I/O.
7. **Toy fusion model** (vision CNN + audio CNN) and a training loop scaffold.


In [None]:
# !pip install opencv-python soundfile librosa torch torchvision torchaudio matplotlib scikit-learn pyyaml
# !apt-get -y install ffmpeg
import os, sys, math, time, json, csv, hashlib, subprocess
from pathlib import Path
import numpy as np
import pandas as pd
import cv2
import soundfile as sf
import librosa
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
# One shared seed for the NumPy and PyTorch random number generators.
SEED = 1337
np.random.seed(SEED)
torch.manual_seed(SEED)

## 1. Configuration

In [None]:
from dataclasses import dataclass
@dataclass
class Config:
    """Paths and hyperparameters shared by the data pipeline and the training loop."""
    # Manifest CSVs. Rows are read for `video_path` and `label`, plus an optional `audio_path`.
    csv_train: str = 'configs/daisee_train.csv'
    csv_val:   str = 'configs/daisee_val.csv'
    # Directory holding cached feature tensors and audio extracted from videos.
    cache_dir: str = '.cache_eagt'
    # DataLoader settings.
    num_workers: int = 4
    batch_size: int = 8
    # Optimisation settings (number of passes over the training set, AdamW learning rate).
    epochs: int = 2
    lr: float = 3e-4
    # Evaluated once, when the class body runs: CUDA if available at that moment, else CPU.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Clip length in seconds, used for both the video and the audio crop.
    clip_seconds: float = 1.0
    # Frames sampled per second of clip (not the source video's frame rate).
    fps: int = 8
    # Frames are resized to img_size x img_size pixels.
    img_size: int = 112
    # Target audio sample rate in Hz.
    sr: int = 16000
    # Number of Mel bands in the audio spectrogram.
    n_mels: int = 64
# Module-wide default configuration instance.
CFG = Config()

## 2. FFmpeg Fallback

In [None]:
def have_ffmpeg():
    """Report whether an ``ffmpeg`` executable is available on ``PATH``.

    This is called for every sample that needs audio extracted from its video,
    so it looks the executable up on ``PATH`` instead of launching
    ``ffmpeg -version`` in a subprocess each time.

    Returns:
        bool: True when an ``ffmpeg`` executable can be located, else False.
    """
    import shutil  # local import keeps this cell self-contained

    return shutil.which('ffmpeg') is not None
def extract_audio_ffmpeg(video_path: str, out_wav: str, sr: int=16000, overwrite=False) -> bool:
    """Extract a mono audio track from a video into a WAV file using FFmpeg.

    Args:
        video_path: Video file to read.
        out_wav: Destination audio file; parent directories are created.
        sr: Output sample rate in Hz.
        overwrite: When False and ``out_wav`` already exists, the existing
            file is kept and reported as a success without running FFmpeg.

    Returns:
        bool: True if ``out_wav`` is available afterwards, False if FFmpeg is
        missing or the extraction failed.
    """
    target = Path(out_wav)
    if not overwrite and target.exists():
        # Nothing to do: a previous extraction is already in place.
        return True
    target.parent.mkdir(parents=True, exist_ok=True)

    # FFmpeg writes to a per-process scratch file that is renamed into place
    # only on success, so readers (e.g. other DataLoader workers) never see a
    # truncated WAV. The suffix is preserved because FFmpeg infers the output
    # format from it.
    scratch = target.with_name(f'{target.stem}.tmp{os.getpid()}{target.suffix}')
    cmd = [
        'ffmpeg', '-nostdin', '-y',
        '-i', video_path,
        '-ac', '1', '-ar', str(sr), '-vn',
        '-loglevel', 'error',
        str(scratch),
    ]
    try:
        # stdin is detached so FFmpeg cannot consume the caller's input stream.
        subprocess.run(cmd, check=True, stdin=subprocess.DEVNULL)
        os.replace(scratch, target)
        return True
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: FFmpeg ran and failed.
        # OSError: FFmpeg could not be started, or the rename failed.
        try:
            scratch.unlink()
        except OSError:
            pass
        return False

## 3. Video Decoding

In [None]:
def read_video_clip(path: str, clip_seconds: float, fps: int, img_size: int) -> np.ndarray:
    """Decode a fixed-length clip from the temporal centre of a video.

    ``int(clip_seconds * fps)`` timestamps are spread evenly over a window of
    ``clip_seconds`` centred on the middle of the video. Each timestamp is
    mapped to a source frame index (clamped to the valid range), and that
    frame is converted to RGB and resized to ``img_size`` x ``img_size``.
    Frames that cannot be read are replaced by black frames.

    Args:
        path: Video file to open with OpenCV.
        clip_seconds: Length of the sampled window in seconds.
        fps: Number of frames to sample per second of clip.
        img_size: Side length of the square output frames.

    Returns:
        np.ndarray: float32 array of shape ``(frames, img_size, img_size, 3)``
        with values scaled to ``[0, 1]``.

    Raises:
        ValueError: If ``clip_seconds * fps`` yields fewer than one frame.
        RuntimeError: If the video cannot be opened.
    """
    needed = int(clip_seconds * fps)
    if needed < 1:
        raise ValueError(
            f'clip_seconds * fps must yield at least one frame, got {clip_seconds} * {fps}'
        )

    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f'Cannot open video: {path}')
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Some containers report 0 fps; fall back to the requested sampling rate.
        src_fps = cap.get(cv2.CAP_PROP_FPS) or fps
        duration = total_frames / max(src_fps, 1)
        center_t = duration / 2.0
        times = np.linspace(
            center_t - clip_seconds / 2,
            center_t + clip_seconds / 2,
            num=needed,
            endpoint=False,
        )
        last_idx = max(total_frames - 1, 0)
        frames = []
        for t in times:
            # Clamp so clips longer than the video repeat its first/last frame.
            idx = min(max(int(t * src_fps), 0), last_idx)
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ok, frame = cap.read()
            if ok:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame, (img_size, img_size), interpolation=cv2.INTER_AREA)
            else:
                frame = np.zeros((img_size, img_size, 3), dtype=np.uint8)
            frames.append(frame)
    finally:
        # Always release the decoder handle, including on errors raised mid-loop.
        cap.release()
    return np.stack(frames, axis=0).astype(np.float32) / 255.0

## 4. Audio Decoding

In [None]:
def read_audio_clip(audio_path: str, video_path: str, clip_seconds: float, sr: int, cache_root: Path) -> np.ndarray:
    """Load a mono waveform of exactly ``int(sr * clip_seconds)`` samples.

    The source is ``audio_path`` when that file exists. Otherwise, if the
    video file exists and FFmpeg is available, the audio track is extracted
    once into ``cache_root / 'extracted_audio'`` and read from there. The
    waveform is down-mixed to mono, resampled to ``sr`` if needed, then
    centre-cropped or symmetrically zero-padded to the target length.

    If no source is available or decoding fails, silence is returned; a
    decode failure additionally emits a ``RuntimeWarning``.

    Args:
        audio_path: Optional standalone audio file (may be empty).
        video_path: Video to extract audio from when ``audio_path`` is unusable.
        clip_seconds: Desired clip length in seconds.
        sr: Target sample rate in Hz.
        cache_root: Directory under which extracted WAV files are stored.

    Returns:
        np.ndarray: 1-D float32 array of length ``int(sr * clip_seconds)``.
    """
    import warnings  # local import keeps this cell self-contained

    target = int(sr * clip_seconds)

    source = None
    if audio_path and Path(audio_path).exists():
        source = audio_path
    # The file check runs before the FFmpeg probe, and an empty video_path is
    # rejected explicitly (Path('') would resolve to the current directory).
    elif video_path and Path(video_path).is_file() and have_ffmpeg():
        vid_hash = hashlib.md5(video_path.encode('utf-8')).hexdigest()[:10]
        out_wav = cache_root / 'extracted_audio' / f'{vid_hash}.wav'
        if not out_wav.exists():
            extract_audio_ffmpeg(video_path, str(out_wav), sr=sr, overwrite=False)
        if out_wav.exists():
            source = str(out_wav)

    wav = None
    if source is not None:
        try:
            wav, native_sr = sf.read(source, dtype='float32')
            if wav.ndim == 2:
                wav = wav.mean(-1)  # average channels down to mono
            if native_sr != sr:
                wav = librosa.resample(wav, orig_sr=native_sr, target_sr=sr)
        except Exception as exc:
            # Best effort: one unreadable file should not abort training, but
            # the substitution by silence must be visible.
            warnings.warn(
                f'Could not decode audio {source!r} ({exc}); using silence instead.',
                RuntimeWarning,
            )
            wav = None
    if wav is None:
        wav = np.zeros(target, dtype=np.float32)

    if len(wav) > target:
        start = (len(wav) - target) // 2
        wav = wav[start:start + target]
    elif len(wav) < target:
        pad = target - len(wav)
        wav = np.pad(wav, (pad // 2, pad - pad // 2))
    return wav.astype(np.float32)

## 5. Transforms

In [None]:
# Per-channel RGB mean and standard deviation used to standardise frames
# (the values match the widely used ImageNet statistics).
IM_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IM_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def vision_transform(frames: np.ndarray)->torch.Tensor:
    """Standardise a clip per channel and move the channel axis to the front.

    Args:
        frames: 4-D array whose last axis holds the three colour channels.

    Returns:
        torch.Tensor: float32 tensor with axes reordered from ``(0, 1, 2, 3)``
        to ``(3, 0, 1, 2)``.
    """
    standardized = (frames - IM_MEAN) / IM_STD
    channels_first = standardized.transpose(3, 0, 1, 2)
    as_float32 = channels_first.astype(np.float32)
    return torch.from_numpy(as_float32)
def audio_mel_transform(wav: np.ndarray, sr:int, n_mels:int)->torch.Tensor:
    """Convert a waveform into a standardised log-Mel spectrogram tensor.

    The Mel spectrogram is computed by librosa between 30 Hz and ``sr // 2``,
    floored at 1e-8 before the natural log, then shifted and scaled by its own
    global mean and standard deviation.

    Args:
        wav: Audio samples.
        sr: Sample rate of ``wav`` in Hz.
        n_mels: Number of Mel bands.

    Returns:
        torch.Tensor: The normalised log-Mel spectrogram, in the layout
        returned by ``librosa.feature.melspectrogram``.
    """
    spectrogram = librosa.feature.melspectrogram(
        y=wav, sr=sr, n_mels=n_mels, fmin=30, fmax=sr // 2
    )
    # The floor avoids log(0) on silent input.
    log_mel = np.log(np.maximum(1e-8, spectrogram)).astype(np.float32)
    center = log_mel.mean()
    spread = log_mel.std() + 1e-6  # epsilon guards against a constant spectrogram
    normalised = (log_mel - center) / spread
    return torch.from_numpy(normalised)

## 6. Cache & Dataset

In [None]:
import json, hashlib
def cache_key(video_path:str,audio_path:str,label:str,cfg)->str:
    """Build the MD5 hex digest that identifies one preprocessed sample.

    The digest covers the video path, the audio path (a falsy value is treated
    as an empty string), the label, and a JSON dump of the preprocessing
    settings read from ``cfg`` (``clip_seconds``, ``fps``, ``img_size``,
    ``sr``, ``n_mels``), in that order.

    Returns:
        str: 32-character hexadecimal digest.
    """
    settings = {
        'clip_seconds': cfg.clip_seconds,
        'fps': cfg.fps,
        'img_size': cfg.img_size,
        'sr': cfg.sr,
        'n_mels': cfg.n_mels,
    }
    pieces = (
        video_path,
        audio_path or '',
        label,
        json.dumps(settings, sort_keys=True),
    )
    digest = hashlib.md5()
    for piece in pieces:
        digest.update(piece.encode())
    return digest.hexdigest()
def cache_path(root:Path,key:str)->Path:
    """Return the location of the ``<key>.pt`` cache file inside ``root``."""
    filename = f'{key}.pt'
    return root / filename

# Class names in index order, and the reverse lookup from name to class index.
LABELS=['frustration','confusion','boredom','engagement']
L2I={l:i for i,l in enumerate(LABELS)}


class VideoAudioDataset(torch.utils.data.Dataset):
    """Manifest-driven dataset yielding ``(vision, audio, label_index)`` triples.

    Each CSV row must provide ``video_path`` and ``label``; ``audio_path`` is
    optional. Preprocessed tensors are cached as ``.pt`` files under
    ``cfg.cache_dir``, so a sample is decoded only the first time it is seen
    with a given preprocessing configuration.
    """

    def __init__(self,csv_path:str,cfg):
        """Read the manifest at ``csv_path`` and make sure the cache directory exists."""
        self.df = pd.read_csv(csv_path)
        self.cfg = cfg
        self.root = Path(cfg.cache_dir)
        self.root.mkdir(parents=True, exist_ok=True)

    def __len__(self):
        """Return the number of manifest rows."""
        return len(self.df)

    def __getitem__(self,idx:int):
        """Return ``(vision_tensor, mel_tensor, class_index)`` for row ``idx``."""
        import warnings  # local import keeps this cell self-contained

        row = self.df.iloc[idx]
        video = str(row['video_path'])
        has_audio = 'audio_path' in row and not pd.isna(row['audio_path'])
        audio = str(row['audio_path']) if has_audio else ''
        label = str(row['label']).lower().strip()
        # NOTE: labels not present in LABELS fall back to class 0.
        y = L2I.get(label, 0)

        cached = cache_path(self.root, cache_key(video, audio, label, self.cfg))
        if cached.exists():
            try:
                blob = torch.load(cached)
                return blob['vision'], blob['audio'], y
            except Exception as exc:
                # An unreadable cache entry (e.g. left by an interrupted run)
                # is rebuilt below instead of aborting the epoch.
                warnings.warn(
                    f'Discarding unreadable cache file {cached} ({exc}).',
                    RuntimeWarning,
                )

        frames = read_video_clip(video, self.cfg.clip_seconds, self.cfg.fps, self.cfg.img_size)
        wav = read_audio_clip(audio, video, self.cfg.clip_seconds, self.cfg.sr, self.root)
        vis = vision_transform(frames)
        mel = audio_mel_transform(wav, self.cfg.sr, self.cfg.n_mels)

        # Write to a per-process scratch file and rename it into place, so other
        # DataLoader workers never observe a half-written cache entry.
        scratch = cached.with_name(f'{cached.name}.{os.getpid()}.tmp')
        try:
            torch.save({'vision': vis, 'audio': mel}, scratch)
            os.replace(scratch, cached)
        except BaseException:
            try:
                scratch.unlink()
            except OSError:
                pass
            raise
        return vis, mel, y

## 7. Model (Toy Fusion)

In [None]:
class VisionTiny(nn.Module):
    """Small 2-D CNN that embeds a clip by stacking its frames along the channel axis.

    Args:
        out_dim: Size of the output embedding.
        in_frames: Number of frames per clip. Defaults to
            ``int(CFG.clip_seconds * CFG.fps)``, the frame count the decoder
            produces for the global configuration, so the first convolution
            matches the data even when ``clip_seconds`` is not 1.
    """

    def __init__(self,out_dim=128,in_frames=None):
        super().__init__()
        if in_frames is None:
            in_frames = int(CFG.clip_seconds * CFG.fps)
        self.net = nn.Sequential(
            nn.Conv2d(3 * in_frames, 64, 3, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, out_dim),
        )

    def forward(self,x):
        """Embed a 5-D input unpacked as ``(b, c, t, h, w)``; returns ``(b, out_dim)``."""
        b, c, t, h, w = x.shape
        # Fold the time axis into channels so a plain 2-D convolution applies.
        return self.net(x.reshape(b, c * t, h, w))
class AudioTiny(nn.Module):
    """Small 2-D CNN that embeds a spectrogram treated as a one-channel image.

    Args:
        out_dim: Size of the output embedding.
    """

    def __init__(self,out_dim=128):
        super().__init__()
        layers = [
            nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self,mel):
        """Insert a channel axis at position 1 and run the network."""
        one_channel = mel.unsqueeze(1)
        return self.net(one_channel)
class FusionClassifier(nn.Module):
    """Late-fusion classifier over concatenated vision and audio embeddings.

    Args:
        feat: Embedding size produced by each branch.
        nc: Number of output classes (defaults to the number of ``LABELS``).
    """

    def __init__(self,feat=128,nc=len(LABELS)):
        super().__init__()
        self.v = VisionTiny(feat)
        self.a = AudioTiny(feat)
        head = [nn.ReLU(), nn.Dropout(0.3), nn.Linear(2 * feat, nc)]
        self.cls = nn.Sequential(*head)

    def forward(self,vis,mel):
        """Return class logits for a batch of clips and spectrograms."""
        vision_feat = self.v(vis)
        audio_feat = self.a(mel)
        fused = torch.cat([vision_feat, audio_feat], -1)
        return self.cls(fused)

## 8. Loaders

In [None]:
def make_loaders(cfg: Config):
    """Build the training and validation DataLoaders described by ``cfg``.

    Both loaders share batch size and worker settings; only the training
    loader shuffles. Host memory is pinned only when the target device is
    CUDA, which avoids a PyTorch warning and wasted work on CPU-only runs.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    shared = dict(
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        pin_memory=str(cfg.device).startswith('cuda'),
        # persistent_workers is only valid with at least one worker process.
        persistent_workers=cfg.num_workers > 0,
    )
    train_set = VideoAudioDataset(cfg.csv_train, cfg)
    val_set = VideoAudioDataset(cfg.csv_val, cfg)
    dl_tr = DataLoader(train_set, shuffle=True, **shared)
    dl_va = DataLoader(val_set, shuffle=False, **shared)
    return dl_tr, dl_va

## 9. Train

In [None]:
def run_epoch(model, loader, opt, device, train=True):
    """Run one pass over ``loader`` and return the mean loss and accuracy.

    Args:
        model: Module called as ``model(vis, mel)`` and returning class logits.
        loader: Iterable of ``(vis, mel, y)`` tensor batches.
        opt: Optimizer; only used when ``train`` is True.
        device: Device the batches are moved to.
        train: When True, parameters are updated. When False, the model is put
            in eval mode and gradient tracking is disabled, so no autograd
            graph is built during evaluation.

    Returns:
        tuple: ``(mean_loss, accuracy)`` averaged per sample; both are 0.0
        for an empty loader.
    """
    model.train(train)
    criterion = nn.CrossEntropyLoss()  # built once instead of once per batch
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.set_grad_enabled(bool(train)):
        for vis, mel, y in loader:
            vis, mel, y = vis.to(device), mel.to(device), y.to(device)
            if train:
                opt.zero_grad()
            logits = model(vis, mel)
            loss = criterion(logits, y)
            if train:
                loss.backward()
                opt.step()
            batch = y.numel()
            correct += (logits.argmax(-1) == y).sum().item()
            seen += batch
            # Weight by batch size so the result is a true per-sample mean.
            loss_sum += loss.item() * batch
    return loss_sum / max(seen, 1), correct / max(seen, 1)

def validate_report(model, loader, device):
    """Print a per-class precision/recall/F1 report for ``model`` on ``loader``.

    The full label set is passed to scikit-learn explicitly, so the report
    still works when some classes never occur in the data or the predictions
    (otherwise ``target_names`` would not match the observed classes and
    ``classification_report`` would raise). Classes without samples or
    predictions are reported as 0 instead of triggering warnings.

    Args:
        model: Module called as ``model(vis, mel)`` and returning class logits.
        loader: Iterable of ``(vis, mel, y)`` tensor batches.
        device: Device the inputs are moved to.
    """
    y_true = []
    y_pred = []
    model.eval()
    with torch.no_grad():
        for vis, mel, y in loader:
            logits = model(vis.to(device), mel.to(device))
            y_pred += logits.argmax(-1).cpu().tolist()
            y_true += y.tolist()
    print(classification_report(
        y_true,
        y_pred,
        labels=list(range(len(LABELS))),
        target_names=LABELS,
        digits=3,
        zero_division=0,
    ))

def train_main(cfg: Config):
    """Train a ``FusionClassifier`` for ``cfg.epochs`` epochs and return it.

    Each epoch runs one training pass and one validation pass and prints the
    loss/accuracy of both. A validation classification report is printed
    after the last epoch.
    """
    train_loader, val_loader = make_loaders(cfg)
    model = FusionClassifier().to(cfg.device)
    optimizer = optim.AdamW(model.parameters(), lr=cfg.lr)

    for finished in range(cfg.epochs):
        ep = finished + 1
        tr_loss, tr_acc = run_epoch(model, train_loader, optimizer, cfg.device, train=True)
        va_loss, va_acc = run_epoch(model, val_loader, optimizer, cfg.device, train=False)
        print(f'Epoch {ep:02d} | train {tr_loss:.4f}/{tr_acc:.3f} | val {va_loss:.4f}/{va_acc:.3f}')

    print('\nValidation Report:')
    validate_report(model, val_loader, cfg.device)
    return model

# model = train_main(CFG)

## 10. Visual Check

In [None]:
def show_sample(ds: VideoAudioDataset, idx=0):
    """Plot the frames and the spectrogram of sample ``idx`` from ``ds``.

    Frames are laid out on a grid at most eight columns wide, after undoing
    the mean/std normalisation and clipping to ``[0, 1]``. The spectrogram is
    drawn in a second figure.
    """
    vis, mel, y = ds[idx]
    _, t, _, _ = vis.shape
    cols = min(t, 8)
    rows = int(np.ceil(t / cols))

    fig, ax = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5))
    grid = np.array(ax).reshape(rows, cols)
    # Row-major walk over the grid; cells beyond the last frame stay blank.
    for k, panel in enumerate(grid.flat):
        panel.axis('off')
        if k >= t:
            continue
        img = vis[:, k].permute(1, 2, 0).numpy()
        img = (img * IM_STD + IM_MEAN).clip(0, 1)
        panel.imshow(img)
    plt.suptitle(f'Label: {LABELS[y]} (frames)')
    plt.show()

    plt.figure(figsize=(6, 3))
    plt.imshow(mel.numpy(), aspect='auto', origin='lower')
    plt.colorbar()
    plt.title('Mel (norm)')
    plt.tight_layout()
    plt.show()

# ds = VideoAudioDataset(CFG.csv_train, CFG); show_sample(ds,0)