In [1]:
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# ==================== METRICS ====================
def ccc(y_true, y_pred):
    """
    Concordance correlation coefficient (CCC) between two series, computed with NumPy.

    CCC = 2 * cov(t, p) / (var(t) + var(p) + (mean(t) - mean(p))^2), where the
    variances and the covariance are the population (1/N) versions.
    The value returned is the coefficient itself (1.0 means perfect agreement),
    not a loss.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values, same shape as ``y_true``.

    Returns:
        NumPy floating scalar with the CCC. The denominator is not guarded, so
        two constant, identical series produce NaN.
    """
    mu_true, mu_pred = np.mean(y_true), np.mean(y_pred)

    # Deviations from the mean are reused for both the variances and the covariance.
    dev_true = y_true - mu_true
    dev_pred = y_pred - mu_pred

    var_true = np.mean(np.square(dev_true))
    var_pred = np.mean(np.square(dev_pred))
    covariance = np.mean(dev_true * dev_pred)

    denominator = var_true + var_pred + np.square(mu_true - mu_pred)
    return np.mean(np.multiply(2., covariance) / denominator)

def acc_func(trues, preds):
    """Mean "accuracy" of a regression output, defined as ``mean(1 - |trues - preds|)``.

    The absolute error is computed once over the whole input; callers that want
    a per-trait score pass one column at a time.

    Args:
        trues: Ground-truth values (array-like).
        preds: Predicted values (array-like, broadcastable against ``trues``).

    Returns:
        NumPy floating scalar; 1.0 means the predictions match exactly.
    """
    abs_error = np.abs(np.asarray(trues) - np.asarray(preds))
    return np.mean(1 - abs_error)

def mf1(targets: list[np.ndarray] | np.ndarray,
        predicts: list[np.ndarray] | np.ndarray,
        return_scores: bool = False) -> float | tuple[float, list[float]]:
    """Mean macro-averaged F1 over the label columns (multilabel mMacroF1).

    Every column of the 2-D target/prediction arrays is scored on its own with
    ``classification_report`` and the 'macro avg' F1 of each column is averaged.

    Args:
        targets: Ground-truth labels, indexable as ``[sample, label]``.
        predicts: Predicted labels, same layout as ``targets``.
        return_scores: When True, also return the per-column scores.

    Returns:
        float: Mean of the per-column macro F1 scores
        or
        tuple[float, list[float]]: ``(mean, per_column_scores)`` when
        ``return_scores`` is True.
    """
    y_true = np.array(targets)
    y_pred = np.array(predicts)

    # One report per label column; zero_division=0 scores undefined metrics as 0.
    per_label_f1 = [
        classification_report(y_true[:, col], y_pred[:, col],
                              output_dict=True, zero_division=0)['macro avg']['f1-score']
        for col in range(y_pred.shape[1])
    ]

    mean_f1 = np.mean(per_label_f1)
    return (mean_f1, per_label_f1) if return_scores else mean_f1


def uar(targets: list[np.ndarray] | np.ndarray,
        predicts: list[np.ndarray] | np.ndarray,
        return_scores: bool = False) -> float | tuple[float, list[float]]:
    """Mean Unweighted Average Recall over the label columns (multilabel mUAR).

    For each column of the 2-D target/prediction arrays the 'macro avg' recall
    from ``classification_report`` is taken, then the columns are averaged.

    Args:
        targets: Ground-truth labels, indexable as ``[sample, label]``.
        predicts: Predicted labels, same layout as ``targets``.
        return_scores: When True, also return the per-column scores.

    Returns:
        float: Mean of the per-column macro recalls
        or
        tuple[float, list[float]]: ``(mean, per_column_scores)`` when
        ``return_scores`` is True.
    """
    truth = np.array(targets)
    guess = np.array(predicts)

    recalls = []
    n_labels = guess.shape[1]
    for label_idx in range(n_labels):
        report = classification_report(
            truth[:, label_idx],
            guess[:, label_idx],
            output_dict=True,
            zero_division=0,  # undefined metrics count as 0
        )
        recalls.append(report['macro avg']['recall'])

    mean_recall = np.mean(recalls)
    if not return_scores:
        return mean_recall
    return mean_recall, recalls

def process_predictions(pred_emo, true_emo):
    """Turn emotion logits and soft targets into binary multilabel lists.

    Logits are soft-maxed over dim 1. Column 0 is treated specially
    (presumably the neutral class - verify against the label layout): when its
    probability reaches 1 - 1/7, every other label of that sample is forced
    to 0. Otherwise each remaining column is switched on when its probability
    reaches 1/7.

    Args:
        pred_emo: Logit tensor of shape (batch, classes).
        true_emo: Target tensor with the same layout; any value > 0 in
            columns 1.. counts as a positive label.

    Returns:
        (pred_emo_bin, true_emo_bin): two nested lists of 0/1 ints covering
        columns 1.. only (column 0 is dropped from both).
    """
    probs = F.softmax(pred_emo, dim=1).cpu().detach().numpy()
    labels = true_emo.cpu().detach().numpy()

    col0_cutoff = 1 - 1 / 7
    label_cutoff = 1 / 7

    col0_dominant = probs[:, 0] >= col0_cutoff
    active = (probs[:, 1:] >= label_cutoff).astype(int)
    # Samples dominated by column 0 get no other labels at all.
    active[col0_dominant] = 0

    pred_emo_bin = active.tolist()
    true_emo_bin = (labels[:, 1:] > 0).astype(int).tolist()
    return pred_emo_bin, true_emo_bin

# ==================== DATA LOADING ====================
def load_fiv2(pkl_path, emb_dir):
    """Build a TensorDataset of (embedding, trait targets, length) for the trait task.

    The annotation pickle is expected to be a dict keyed by trait name, each
    value mapping a clip file name (e.g. ``"abc.mp4"``) to a score. Embeddings
    are read from ``<emb_dir>/<subfolder>/<clip>.pt``; the file stem is matched
    against the clip name with ``.mp4`` removed.

    A file is skipped when its stem has no annotation, when the embedding
    contains NaN, or when its shape is not (30, 1024).

    Args:
        pkl_path: Path to the pickled annotation dict.
        emb_dir: Directory whose sub-directories contain the ``.pt`` files.

    Returns:
        TensorDataset with tensors ``X`` (N, 30, 1024), ``y`` (N, 5, float32,
        ordered as openness, conscientiousness, extraversion, agreeableness,
        neuroticism) and ``lengths`` (N,). ``torch.stack`` fails if no file
        passes the filters.
    """
    import pickle

    # NOTE(security): pickle.load can execute arbitrary code - only point this
    # at annotation files from a trusted source.
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f, encoding='latin1')

    trait_keys = ['openness', 'conscientiousness', 'extraversion', 'agreeableness', 'neuroticism']

    # Clip id -> float32 target vector, built once so each embedding file costs
    # a single dict lookup instead of repeated scans over the annotation table.
    targets_by_video = {}
    for name in data['openness'].keys():
        short = name.replace('.mp4', '')
        if short not in targets_by_video:  # first annotation wins on a clash
            targets_by_video[short] = np.array(
                [data[key][name] for key in trait_keys], dtype='float32'
            )

    X, y, lengths = [], [], []
    for folder in Path(emb_dir).iterdir():
        for pt_file in folder.glob('*.pt'):
            target = targets_by_video.get(pt_file.stem)
            if target is None:
                continue
            emb = torch.load(pt_file, weights_only=True)
            if isinstance(emb, dict):
                # Some files wrap the tensor as {'emb': tensor}.
                emb = emb['emb']
            if torch.isnan(emb).any() or emb.shape != (30, 1024):
                continue
            X.append(emb)
            y.append(torch.tensor(target))
            lengths.append(torch.tensor(emb.shape[0]))
    return TensorDataset(torch.stack(X), torch.stack(y), torch.stack(lengths))

def load_mosei(csv_path, emb_dir):
    """Build a TensorDataset of (embedding, emotion targets, length) from a label CSV.

    The CSV must contain ``video_name``, ``text`` and ``Other`` columns; the
    latter two are dropped and every remaining column except ``video_name`` is
    used, in file order, as a float32 target. The embedding for a row is read
    from ``<emb_dir>/<video_name>.pt``.

    A row is skipped when its embedding file is missing, contains NaN, or does
    not have shape (30, 1024).

    Args:
        csv_path: Path to the label CSV.
        emb_dir: Directory holding one ``.pt`` file per clip.

    Returns:
        TensorDataset with tensors ``X`` (N, 30, 1024), ``y`` (N, n_labels)
        and ``lengths`` (N,).
    """
    table = pd.read_csv(csv_path).drop(columns=["text", "Other"])
    emb_root = Path(emb_dir)

    embeddings, labels, seq_lens = [], [], []
    for _, row in tqdm(table.iterrows(), total=len(table)):
        emb_path = emb_root / f"{row['video_name']}.pt"
        if emb_path.exists():
            emb = torch.load(emb_path, weights_only=True)
            # Some files wrap the tensor as {'emb': tensor}.
            emb = emb['emb'] if isinstance(emb, dict) else emb
            rejected = torch.isnan(emb).any() or emb.shape != (30, 1024)
            if not rejected:
                label_values = row.drop("video_name").values.astype('float32')
                embeddings.append(emb)
                labels.append(torch.tensor(label_values))
                seq_lens.append(torch.tensor(emb.shape[0]))

    return TensorDataset(torch.stack(embeddings), torch.stack(labels), torch.stack(seq_lens))


In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load data
# All paths are relative to the notebook's working directory. Each loader
# returns a TensorDataset of (embeddings, targets, lengths).
# Trait data ("fiv2" - presumably First Impressions V2, confirm): the
# validation set is built from the *test* annotation file and folder.
fiv2_train = load_fiv2("Embeddings_emonext/fiv2_embeddings/annotation_training.pkl", "Embeddings_emonext/fiv2_embeddings/train")
fiv2_val = load_fiv2("Embeddings_emonext/fiv2_embeddings/annotation_test.pkl", "Embeddings_emonext/fiv2_embeddings/test")
# Emotion data (CMU-MOSEI embeddings): likewise, validation uses the test CSV/folder.
mosei_train = load_mosei("emo_video/train_full.csv", "Embeddings_emonext/cmu_mosei_embeddings/train")
mosei_val = load_mosei("emo_video/test_full/test_full.csv", "Embeddings_emonext/cmu_mosei_embeddings/test")

100%|██████████| 16274/16274 [04:20<00:00, 62.37it/s]
100%|██████████| 4653/4653 [01:41<00:00, 45.68it/s]


In [3]:
# Wrap the datasets in DataLoaders with batches of 1024 samples.
# Only the training loaders shuffle; validation loaders keep dataset order.
train_loader_traits = DataLoader(fiv2_train, batch_size=1024, shuffle=True)
val_loader_traits = DataLoader(fiv2_val, batch_size=1024)
train_loader_emo = DataLoader(mosei_train, batch_size=1024, shuffle=True)
val_loader_emo = DataLoader(mosei_val, batch_size=1024)

In [4]:
# ==================== MODEL ====================

class TransformerBackbone(nn.Module):
    """Transformer encoder that maps a frame sequence to one pooled vector.

    Pipeline: linear projection to ``d_model`` -> add learned positional
    embedding -> ``num_layers`` encoder layers -> length-masked mean over time
    -> dropout.

    Args:
        input_dim: Feature size of each input frame.
        d_model: Width of the transformer.
        num_layers: Number of encoder layers.
        n_heads: Attention heads per layer.
        dropout: Dropout rate inside the encoder and on the pooled output.
        max_len: Number of learned positions; longer inputs are not supported
            because the positional table cannot be broadcast onto them.
    """

    def __init__(self, input_dim=1024, d_model=512, num_layers=3, n_heads=8, dropout=0.2, max_len=30):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model))
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, dim_feedforward=d_model * 4, dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):
        """Encode ``x`` and mean-pool the first ``lengths[b]`` steps of each sample.

        Args:
            x: Input of shape (batch, seq, input_dim).
            lengths: Integer tensor of shape (batch,) with the valid step count
                per sample; it may live on a different device than ``x``.

        Returns:
            Tensor of shape (batch, d_model).
        """
        x = self.input_proj(x) + self.pos_embed[:, :x.size(1)]
        # No padding mask is passed here: every position takes part in
        # attention, only the pooling below honours ``lengths``.
        x = self.encoder(x)
        # Build the mask on x's device so CPU-resident lengths do not trigger a
        # device mismatch when the activations are on the GPU.
        lengths = lengths.to(x.device)
        mask = torch.arange(x.size(1), device=x.device)[None, :] < lengths[:, None]
        mask = mask.float().unsqueeze(2)
        summed = (x * mask).sum(dim=1)
        count = mask.sum(dim=1).clamp(min=1)  # avoid 0/0 for zero-length samples
        pooled = summed / count
        return self.dropout(pooled)


class FrozenTransformerWrapper(nn.Module):
    """A ``TransformerBackbone`` initialised from a checkpoint with gradients disabled.

    The checkpoint must be a dict with a ``'model_state_dict'`` entry. Keys
    starting with ``'fc'`` are dropped before loading (presumably the task
    head of the pretrained model - confirm against the checkpoint).

    Freezing here only sets ``requires_grad = False``; it does not switch the
    backbone to eval mode, so calling ``.train()`` on a parent module still
    enables dropout inside it.

    Args:
        scripted_model_path: Path to the checkpoint file.
        backbone_type: Currently unused; kept for interface compatibility.
    """

    def __init__(self, scripted_model_path, backbone_type='traits'):
        super().__init__()
        import warnings

        self.model = TransformerBackbone()
        # Load onto the CPU so a checkpoint saved from a GPU still opens on a
        # CPU-only machine; the caller moves the whole model afterwards.
        checkpoint = torch.load(scripted_model_path, map_location='cpu', weights_only=True)
        state_dict = checkpoint['model_state_dict']
        filtered = {k: v for k, v in state_dict.items() if not k.startswith('fc')}
        load_result = self.model.load_state_dict(filtered, strict=False)
        if load_result.missing_keys:
            # strict=False would otherwise hide a key mismatch and leave these
            # parameters at their random initial values - frozen.
            warnings.warn(
                f"{scripted_model_path}: {len(load_result.missing_keys)} backbone "
                f"parameter(s) not found in checkpoint and left randomly initialised: "
                f"{load_result.missing_keys}"
            )
        for p in self.model.parameters():
            p.requires_grad = False

    def forward(self, x, lengths):
        """Return the backbone's pooled features for ``x`` / ``lengths``."""
        return self.model(x, lengths)


class CrossAttentionBlock(nn.Module):
    """Multi-head cross-attention followed by a residual connection and LayerNorm.

    Args:
        d_model: Embedding width of queries, keys and values.
        n_heads: Number of attention heads.
        dropout: Dropout rate used inside the attention and on its output.
    """

    def __init__(self, d_model=512, n_heads=8, dropout=0.2):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, kv):
        """Attend from ``q`` to ``kv`` (used as both key and value).

        Returns ``LayerNorm(q + Dropout(attention_output))``, same shape as ``q``.
        """
        attended = self.attn(query=q, key=kv, value=kv)[0]  # drop the attention weights
        residual = q + self.dropout(attended)
        return self.norm(residual)



class MultiTaskFusionModel(nn.Module):
    """Fuses a frozen trait backbone and a frozen emotion backbone for two tasks.

    Both backbones see the same input and each yields one pooled vector. Each
    vector becomes a length-1 sequence, passes through its own 2-layer
    transformer encoder, and the two streams then cross-attend to each other.
    The concatenated result feeds a shared MLP with a 5-output trait head and
    a 7-output emotion head.

    Args:
        trait_model_path: Checkpoint for the trait backbone.
        emo_model_path: Checkpoint for the emotion backbone.
        d_model: Feature width of the fusion layers.
        dropout: Dropout rate of the fusion layers.
    """

    def __init__(self, trait_model_path="models_checkpoints/fiv2_best_checkpoint.pth", emo_model_path="models_checkpoints/cmu_mosei_best_checkpoint.pth", d_model=512, dropout=0.2):
        super().__init__()

        def build_encoder():
            # 2-layer encoder: 8 heads, feed-forward width 4 * d_model.
            layer = nn.TransformerEncoderLayer(d_model, 8, 4 * d_model, dropout=dropout, batch_first=True)
            return nn.TransformerEncoder(layer, num_layers=2)

        # Frozen, pretrained feature extractors.
        self.trait_model = FrozenTransformerWrapper(trait_model_path)
        self.emo_model = FrozenTransformerWrapper(emo_model_path)

        # Trainable per-stream encoders.
        self.trait_encoder = build_encoder()
        self.emo_encoder = build_encoder()

        # Bidirectional cross-attention between the two streams.
        self.cross1 = CrossAttentionBlock(d_model, 8, dropout)
        self.cross2 = CrossAttentionBlock(d_model, 8, dropout)

        self.shared_mlp = nn.Sequential(
            nn.Linear(2 * d_model, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.trait_head = nn.Linear(d_model, 5)
        self.emo_head = nn.Linear(d_model, 7)

    def forward(self, x, lengths, task='traits'):
        """Run the fusion network and return the output of the requested head.

        Args:
            x: Input sequence batch passed to both backbones.
            lengths: Valid lengths passed to both backbones.
            task: ``'traits'`` (5 outputs) or ``'emotions'`` (7 outputs).

        Raises:
            ValueError: If ``task`` is neither of the two supported names.
        """
        pooled_trait = self.trait_model(x, lengths)
        pooled_emo = self.emo_model(x, lengths)

        # (batch, d_model) -> (batch, 1, d_model): each stream is a single token.
        trait_seq = self.trait_encoder(pooled_trait.unsqueeze(1))
        emo_seq = self.emo_encoder(pooled_emo.unsqueeze(1))

        trait_attended = self.cross1(trait_seq, emo_seq)
        emo_attended = self.cross2(emo_seq, trait_seq)

        fused = torch.cat((trait_attended, emo_attended), dim=-1).squeeze(1)
        hidden = self.shared_mlp(fused)

        if task not in ('traits', 'emotions'):
            raise ValueError("task must be either 'traits' or 'emotions'")
        head = self.trait_head if task == 'traits' else self.emo_head
        return head(hidden)

In [8]:
# ==================== TRAINING ====================

# Instantiate the fusion model (this loads both backbone checkpoints from
# their default paths) and move it to the selected device.
model = MultiTaskFusionModel().to(DEVICE)
# model.parameters() also yields the frozen backbone weights; they have
# requires_grad=False and therefore never receive gradients.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Trait scores are regressed with mean squared error.
criterion_traits = nn.MSELoss()

def weighted_kl_div(log_probs, target_probs, class_weights):
    """
    Class-weighted KL divergence between target and predicted distributions:

        L = mean_over_batch( sum_j w_j * q_j * (log q_j - log p_j) )

    Args:
        log_probs: Predicted log-probabilities ``log p``, shape (B, C).
        target_probs: Target probabilities ``q``, shape (B, C).
        class_weights: Per-class weights ``w``, broadcastable to (B, C).

    Returns:
        Scalar tensor with the batch-averaged weighted divergence.
    """
    # The 1e-8 offset keeps log() finite where the target probability is 0.
    log_target = torch.log(target_probs + 1e-8)
    per_class = target_probs * (log_target - log_probs)  # (B, C)
    per_sample = (per_class * class_weights).sum(dim=1)  # (B,)
    return per_sample.mean()

# Per-class weights for the 7 emotion outputs used by weighted_kl_div.
# NOTE(review): the values look hand-tuned (possibly inverse class frequency) - confirm their origin.
class_weights = torch.tensor([3.0, 10.0, 13.0, 20.0, 2.0, 15.0, 18.0], device=DEVICE)

# Each epoch trains on all emotion batches first, then on all trait batches,
# sharing one optimizer; validation of both tasks follows.
for epoch in range(1, 201):
    model.train()
    # Running sums of per-batch losses (not divided by the number of batches).
    total_loss_emo, total_loss_traits = 0, 0

    # --- Emotion task: weighted KL divergence against the target distribution ---
    for xb, yb, lb in train_loader_emo:
        xb, yb, lb = xb.to(DEVICE), yb.to(DEVICE), lb.to(DEVICE)
        optimizer.zero_grad()
        logits = model(xb, lb, task='emotions')
        log_probs = F.log_softmax(logits, dim=1)
        loss = weighted_kl_div(log_probs, yb, class_weights)
        loss.backward()
        optimizer.step()
        total_loss_emo += loss.item()

    # --- Trait task: MSE regression on the 5 trait scores ---
    for xb, yb, lb in train_loader_traits:
        xb, yb, lb = xb.to(DEVICE), yb.to(DEVICE), lb.to(DEVICE)
        optimizer.zero_grad()
        preds = model(xb, lb, task='traits')
        loss = criterion_traits(preds, yb)
        loss.backward()
        optimizer.step()
        total_loss_traits += loss.item()

    # === Validation ===
    model.eval()

    # Emotions
    # Collect raw logits/targets over the whole validation set so the
    # multilabel metrics are computed once on all samples.
    emo_val_loss, raw_preds, raw_targets = 0, [], []
    with torch.no_grad():
        for xb, yb, lb in val_loader_emo:
            xb, yb, lb = xb.to(DEVICE), yb.to(DEVICE), lb.to(DEVICE)
            logits = model(xb, lb, task='emotions')
            log_probs = F.log_softmax(logits, dim=1)
            loss = weighted_kl_div(log_probs, yb, class_weights)
            emo_val_loss += loss.item()
            raw_preds.append(logits)
            raw_targets.append(yb)
    # Binarise predictions/targets, then score with mean macro F1 and mean UAR.
    bin_preds, bin_targets = process_predictions(torch.cat(raw_preds), torch.cat(raw_targets))
    f1 = mf1(bin_targets, bin_preds)
    recall = uar(bin_targets, bin_preds)

    # Traits
    trait_val_loss, preds, targets = 0, [], []
    with torch.no_grad():
        for xb, yb, lb in val_loader_traits:
            xb, yb, lb = xb.to(DEVICE), yb.to(DEVICE), lb.to(DEVICE)
            out = model(xb, lb, task='traits')
            loss = criterion_traits(out, yb)
            trait_val_loss += loss.item()
            preds.append(out.cpu())
            targets.append(yb.cpu())
    preds = torch.cat(preds).numpy()
    targets = torch.cat(targets).numpy()
    # CCC and accuracy are computed per trait column and then averaged over the 5 traits.
    ccc_score = np.mean([ccc(targets[:, i], preds[:, i]) for i in range(5)])
    acc_score = np.mean([acc_func(targets[:, i], preds[:, i]) for i in range(5)])

    # Reported losses are sums over batches, so train and validation values are
    # not on the same scale (the loaders have different numbers of batches).
    print(f"[Epoch {epoch}] Emo Train Loss: {total_loss_emo:.4f} | Emo Val Loss: {emo_val_loss:.4f} | mF1: {f1:.4f} | mUAR: {recall:.4f}")
    print(f"           Trait Train Loss: {total_loss_traits:.4f} | Trait Val Loss: {trait_val_loss:.4f} | CCC: {ccc_score:.4f} | mACC: {acc_score:.4f}")


[Epoch 1] Emo Train Loss: 115.8124 | Emo Val Loss: 61.5978 | mF1: 0.5458 | mUAR: 0.5683
         Trait Train Loss: 1.0879 | Trait Val Loss: 0.3517 | CCC: 0.0688 | mACC: 0.6307390332221985
[Epoch 2] Emo Train Loss: 99.2456 | Emo Val Loss: 58.4184 | mF1: 0.5232 | mUAR: 0.5682
         Trait Train Loss: 0.4226 | Trait Val Loss: 0.0775 | CCC: 0.1132 | mACC: 0.8410788774490356
[Epoch 3] Emo Train Loss: 95.3153 | Emo Val Loss: 62.0565 | mF1: 0.5429 | mUAR: 0.5702
         Trait Train Loss: 0.4820 | Trait Val Loss: 0.0757 | CCC: 0.1291 | mACC: 0.8443183898925781
[Epoch 4] Emo Train Loss: 94.4056 | Emo Val Loss: 62.9582 | mF1: 0.5329 | mUAR: 0.5651
         Trait Train Loss: 0.5657 | Trait Val Loss: 0.0790 | CCC: 0.1773 | mACC: 0.8419104814529419
[Epoch 5] Emo Train Loss: 92.7408 | Emo Val Loss: 61.3848 | mF1: 0.5218 | mUAR: 0.5598
         Trait Train Loss: 0.6871 | Trait Val Loss: 0.0599 | CCC: 0.1956 | mACC: 0.8613244891166687
[Epoch 6] Emo Train Loss: 92.8208 | Emo Val Loss: 58.5034 | mF1:

In [10]:
# Save model weights
# Stores the state_dict of the model as it is after the last executed epoch
# (no best-epoch selection happens in the training loop). The state_dict also
# contains the frozen backbone weights.
# NOTE(review): the models_checkpoints/ directory is expected to exist already.
torch.save(model.state_dict(), 'models_checkpoints/multitask_fusion_model.pth')