# Model Training — Multimodal (toy)

In [None]:
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd, numpy as np, random
# Seed every RNG this notebook draws from (torch, NumPy, stdlib `random`)
# with one shared value so repeated runs start from the same random state.
SEED = 1337
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)


In [None]:
class EAGTDataset(Dataset):
    """Toy multimodal dataset backed by a CSV file with a ``label`` column.

    Each item is ``(frame, waveform, label)``:

    * ``frame`` -- a random ``(3, 112, 112)`` tensor (placeholder),
    * ``waveform`` -- a random ``(16000,)`` tensor (placeholder),
    * ``label`` -- the integer class parsed from the CSV row.

    The tensors only give the training pipeline the right structure; no real
    media decoding happens yet.

    Labels are matched case-insensitively after trimming surrounding
    whitespace. Unrecognised labels (including missing values, which read as
    ``'nan'``) fall back to class 0 (``'frustration'``); a warning listing
    them is emitted at construction so the mislabelling is not silent.
    """

    # Canonical (lower-case, trimmed) label name -> class index.
    LABEL_MAP = {'frustration': 0, 'confusion': 1, 'boredom': 2, 'engagement': 3}

    def __init__(self, csv_path):
        """Load the CSV at ``csv_path`` and check its labels.

        Raises whatever ``pandas.read_csv`` raises for a missing/unreadable file.
        """
        import warnings  # local import: only needed for the unknown-label check

        self.df = pd.read_csv(csv_path)
        # Per-instance copy so mutating one dataset's mapping doesn't leak into others.
        self.lmap = dict(self.LABEL_MAP)

        # Surface labels that would silently collapse to class 0. A missing
        # 'label' column is still only reported lazily by __getitem__.
        if 'label' in self.df.columns:
            normalized = self.df['label'].map(self._normalize)
            unknown_mask = ~normalized.isin(list(self.lmap))
            if unknown_mask.any():
                examples = sorted(set(self.df['label'][unknown_mask].astype(str)))[:5]
                warnings.warn(
                    f"{int(unknown_mask.sum())} row(s) in {csv_path!r} have unrecognised "
                    f"labels (e.g. {examples}); they are mapped to class 0 ('frustration').",
                    stacklevel=2,
                )

    @staticmethod
    def _normalize(raw):
        """Canonicalise a raw CSV label: string form, whitespace-trimmed, lower-cased."""
        return str(raw).strip().lower()

    def _encode(self, raw):
        """Map a raw CSV label to its class index, defaulting to 0 when unknown."""
        return self.lmap.get(self._normalize(raw), 0)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        # TODO: real decoding; this is a toy random batch for structure
        frame = torch.randn(3, 112, 112)
        waveform = torch.randn(16000)
        # .iat reads a single cell instead of materialising the whole row as a Series.
        label = self._encode(self.df['label'].iat[i])
        return frame, waveform, label


In [None]:
class VisionCNN(nn.Module):
    """Small two-stage strided CNN that embeds 3-channel images into ``out_dim`` features."""

    def __init__(self, out_dim=128):
        super().__init__()
        layers = [
            # stage 1: 3 -> 32 channels, stride 2
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # stage 2: 32 -> 64 channels, stride 2
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # global average pool to one value per channel, then project
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, out_dim),
        ]
        # Kept as a single Sequential named `net` so state_dict keys stay stable.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Embed a batch of images; a (N, 3, H, W) input yields (N, out_dim)."""
        return self.net(x)
class AudioEnc(nn.Module):
    """Two-stage strided 1-D CNN that embeds mono waveforms into ``out_dim`` features."""

    def __init__(self, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(1, 32, 9, 4, 4), nn.ReLU(),
            nn.Conv1d(32, 64, 9, 4, 4), nn.ReLU(),
            nn.AdaptiveAvgPool1d(1), nn.Flatten(),
            nn.Linear(64, out_dim),
        )

    def forward(self, x):
        """Embed waveform(s).

        Accepts a single waveform ``(T,)``, a batch ``(B, T)``, or a batch that
        already carries the mono channel axis ``(B, 1, T)``. Returns
        ``(B, out_dim)``, with ``B == 1`` for a single waveform.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)  # (T,) -> (1, T)
        if x.dim() == 2:
            # Conv1d expects a channel axis; previously this was added even to
            # (B, 1, T) input, producing a 4-D tensor that Conv1d rejects.
            x = x.unsqueeze(1)  # (B, T) -> (B, 1, T)
        return self.net(x)
class Fusion(nn.Module):
    """Late-fusion classifier: concatenate vision and audio embeddings, then an MLP head.

    Args:
        emb_dim: output width of each encoder (``VisionCNN`` / ``AudioEnc``).
        hidden_dim: width of the head's hidden layer.
        num_classes: number of output logits.
        dropout: dropout probability in the head.

    The defaults build exactly the original 128+128 -> 256 -> 4 model, with the
    same submodule names (``v``, ``a``, ``fc``) so existing checkpoints still load.
    """

    def __init__(self, emb_dim=128, hidden_dim=256, num_classes=4, dropout=0.3):
        super().__init__()
        self.v = VisionCNN(out_dim=emb_dim)
        self.a = AudioEnc(out_dim=emb_dim)
        # The head consumes both embeddings side by side, so its input width is
        # derived from emb_dim rather than hard-coded.
        self.fc = nn.Sequential(
            nn.Linear(2 * emb_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, xf, xa):
        """Return class logits for image input ``xf`` and waveform input ``xa``."""
        vf = self.v(xf)
        af = self.a(xa)
        return self.fc(torch.cat([vf, af], dim=-1))


In [None]:
def _train_one_epoch(model, loader, opt, lossfn, device=None):
    """Run one optimisation pass over ``loader``; return ``(mean_loss, accuracy)``.

    Metrics are accumulated from the training-mode forward pass of each batch.
    Raises ValueError if the loader yields no samples.
    """
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for xf, xa, y in loader:
        if device is not None:
            xf, xa, y = xf.to(device), xa.to(device), y.to(device)
        opt.zero_grad()
        logits = model(xf, xa)
        loss = lossfn(logits, y)
        loss.backward()
        opt.step()

        batch = y.numel()
        correct += (logits.argmax(-1) == y).sum().item()
        seen += batch
        # Assumes lossfn uses mean reduction (nn.CrossEntropyLoss default);
        # re-weight by batch size so the epoch mean is per-sample.
        total_loss += loss.item() * batch
    if seen == 0:
        raise ValueError("training loader produced no samples; is the CSV empty?")
    return total_loss / seen, correct / seen


def train(csv_path, epochs=2, bs=8, lr=3e-4, *, device=None):
    """Train a ``Fusion`` model on the dataset described by ``csv_path``.

    Args:
        csv_path: CSV file read by ``EAGTDataset``.
        epochs: number of passes over the data.
        bs: mini-batch size.
        lr: AdamW learning rate.
        device: optional torch device (e.g. ``"cuda"``); when given, the model
            and every batch are moved there. ``None`` leaves everything as built.

    Prints per-epoch mean loss and accuracy and returns the trained model,
    which is left in training mode (call ``.eval()`` before inference).
    """
    ds = EAGTDataset(csv_path)
    dl = DataLoader(ds, batch_size=bs, shuffle=True)
    m = Fusion()
    if device is not None:
        # Move before building the optimizer so it references the moved parameters.
        m = m.to(device)
    opt = optim.AdamW(m.parameters(), lr=lr)
    lossfn = nn.CrossEntropyLoss()
    for ep in range(1, epochs + 1):
        mean_loss, acc = _train_one_epoch(m, dl, opt, lossfn, device)
        print(f'Epoch {ep}: loss={mean_loss:.4f} acc={acc:.3f}')
    return m
# train('configs/daisee_split.csv', epochs=1)
