In [1]:
import torch
import torch.nn as nn

class TransformerBackbone(nn.Module):
    """Transformer encoder that mean-pools a padded sequence of frame embeddings.

    Args:
        input_dim: size of each input frame embedding.
        d_model: transformer hidden size (also the size of the pooled output).
        num_layers: number of ``nn.TransformerEncoderLayer`` blocks.
        n_heads: attention heads per encoder layer.
        dropout: dropout inside the encoder layers and on the pooled output.
        max_len: maximum supported sequence length (length of the learned
            positional embedding).
    """

    def __init__(self, input_dim=1024, d_model=512, num_layers=3, n_heads=8, dropout=0.2, max_len=30):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, d_model)
        # Learned absolute positional embedding: one vector per position, added to the input.
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model))
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, dim_feedforward=d_model * 4, dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):
        """Encode ``x`` and return the mean over each sequence's valid positions.

        Args:
            x: tensor of shape (batch, seq_len, input_dim) with seq_len <= max_len.
            lengths: number of valid positions per sequence, shape (batch,); a
                tensor on any device or anything ``torch.as_tensor`` accepts.
                Positions with index >= length are treated as padding.

        Returns:
            Tensor of shape (batch, d_model) in the dtype of the encoder output.

        Raises:
            ValueError: if seq_len exceeds the positional-embedding length.
        """
        seq_len = x.size(1)
        max_len = self.pos_embed.size(1)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len {max_len}")

        x = self.input_proj(x) + self.pos_embed[:, :seq_len]
        # NOTE(review): no src_key_padding_mask is passed, so padded positions still
        # take part in self-attention; they are only excluded from the pooling below.
        x = self.encoder(x)

        # Build the validity mask on x's device/dtype so CPU `lengths` and
        # half-precision inputs work without device errors or dtype promotion.
        lengths = torch.as_tensor(lengths, device=x.device)
        valid = torch.arange(seq_len, device=x.device)[None, :] < lengths[:, None]
        valid = valid.to(x.dtype).unsqueeze(2)  # (batch, seq_len, 1)
        summed = (x * valid).sum(dim=1)
        # clamp avoids division by zero for zero-length sequences
        count = valid.sum(dim=1).clamp(min=1)
        return self.dropout(summed / count)


class FrozenTransformerWrapper(nn.Module):
    """A default ``TransformerBackbone`` initialised from a checkpoint, with all parameters frozen.

    Args:
        scripted_model_path: path to a ``torch.save``'d dict whose
            ``'model_state_dict'`` entry holds the backbone weights. Keys starting
            with ``'fc'`` (presumably a classification head) are dropped.
        backbone_type: unused; kept for call-site compatibility.
    """

    def __init__(self, scripted_model_path, backbone_type='traits'):
        import warnings

        super().__init__()
        self.model = TransformerBackbone()
        # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only hosts;
        # the weights are copied into the (CPU) backbone either way, so move the
        # module with .to(device) afterwards as before.
        checkpoint = torch.load(scripted_model_path, map_location='cpu', weights_only=True)
        state_dict = checkpoint['model_state_dict']
        filtered = {k: v for k, v in state_dict.items() if not k.startswith('fc')}
        result = self.model.load_state_dict(filtered, strict=False)
        # strict=False would otherwise silently hide a key mismatch (e.g. a prefix)
        # and leave the backbone randomly initialised.
        if result.missing_keys:
            warnings.warn(
                f"{scripted_model_path}: backbone weights missing from checkpoint: {result.missing_keys}",
                stacklevel=2,
            )
        for p in self.model.parameters():
            p.requires_grad = False

    def forward(self, x, lengths):
        """Delegate to the frozen backbone; returns its pooled (batch, d_model) output."""
        return self.model(x, lengths)


class CrossAttentionBlock(nn.Module):
    """Post-norm residual cross-attention: ``LayerNorm(q + Dropout(Attn(q, kv, kv)))``.

    Args:
        d_model: embedding size shared by the query and key/value sequences.
        n_heads: number of attention heads (must divide ``d_model``).
        dropout: dropout on the attention weights and on the attention output.
    """

    def __init__(self, d_model=512, n_heads=8, dropout=0.2):
        super().__init__()
        # Attribute names form the state_dict keys of saved checkpoints - keep them.
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, kv):
        """Attend from ``q`` (batch, q_len, d_model) over ``kv`` (batch, kv_len, d_model)."""
        attended = self.attn(query=q, key=kv, value=kv)[0]
        return self.norm(self.dropout(attended) + q)



class MultiTaskFusionModel(nn.Module):
    """Fuse a frozen trait backbone and a frozen emotion backbone for multi-task prediction.

    Both backbones pool the same input sequence into one vector each. Each vector
    is re-encoded as a length-1 sequence, the two are cross-attended against each
    other, concatenated, passed through a shared MLP and then a task-specific
    linear head.

    Args:
        trait_model_path: checkpoint for the trait backbone.
        emo_model_path: checkpoint for the emotion backbone.
        d_model: hidden size of the fusion layers; must equal the backbones'
            pooled output size (512 for the default ``TransformerBackbone``).
        dropout: dropout in the fusion encoders, cross-attention and shared MLP.
        n_heads: attention heads in the fusion encoders and cross-attention blocks.
        num_traits: output size of the trait head.
        num_emotions: output size of the emotion head.
    """

    def __init__(self, trait_model_path="models_checkpoints/fiv2_best_checkpoint.pth", emo_model_path="models_checkpoints/cmu_mosei_best_checkpoint.pth", d_model=512, dropout=0.2, n_heads=8, num_traits=5, num_emotions=7):
        super().__init__()

        # Pretrained backbones; the wrapper sets requires_grad=False on their parameters.
        self.trait_model = FrozenTransformerWrapper(trait_model_path)
        self.emo_model = FrozenTransformerWrapper(emo_model_path)

        # Trainable fusion layers. Attribute names are state_dict keys - keep them,
        # and keep construction order (it fixes parameter init / RNG order).
        self.trait_encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout=dropout, batch_first=True), num_layers=2)
        self.emo_encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout=dropout, batch_first=True), num_layers=2)

        # cross1: trait queries attend to emotion features; cross2: the reverse.
        self.cross1 = CrossAttentionBlock(d_model, n_heads, dropout)
        self.cross2 = CrossAttentionBlock(d_model, n_heads, dropout)

        self.shared_mlp = nn.Sequential(
            nn.Linear(2 * d_model, d_model),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        self.trait_head = nn.Linear(d_model, num_traits)
        self.emo_head = nn.Linear(d_model, num_emotions)

    def forward(self, x, lengths, task='traits'):
        """Run both backbones and the fusion layers, returning one task's raw outputs.

        Args:
            x: input batch passed unchanged to both backbones; its last dimension
                must match the backbone ``input_dim`` (1024 by default).
            lengths: valid sequence length per batch item.
            task: ``'traits'`` or ``'emotions'``.

        Returns:
            (batch, num_traits) for ``'traits'`` or (batch, num_emotions) for
            ``'emotions'``; no activation is applied.

        Raises:
            ValueError: if ``task`` is not one of the supported values.
        """
        # Validate up front so a bad task fails before any forward computation.
        if task not in ('traits', 'emotions'):
            raise ValueError("task must be either 'traits' or 'emotions'")

        # Each backbone returns a pooled (batch, d_model) vector; treat it as a length-1 sequence.
        trait_feat = self.trait_model(x, lengths).unsqueeze(1)
        emo_feat = self.emo_model(x, lengths).unsqueeze(1)

        trait_encoded = self.trait_encoder(trait_feat)
        emo_encoded = self.emo_encoder(emo_feat)

        trait_cross = self.cross1(trait_encoded, emo_encoded)
        emo_cross = self.cross2(emo_encoded, trait_encoded)

        fused = torch.cat([trait_cross, emo_cross], dim=-1).squeeze(1)  # (batch, 2 * d_model)
        hidden = self.shared_mlp(fused)

        head = self.trait_head if task == 'traits' else self.emo_head
        return head(hidden)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
import cv2
import mediapipe as mp
from emonext_model import get_model

# =================== CONFIG ===================
# Use CUDA when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_FRAMES = 30  # intended number of frames sampled per video
FRAME_SIZE = 224  # side length (pixels) face crops are resized to
# Backbone checkpoints (dataset names taken from the filenames: CMU-MOSEI emotions, FIv2 traits).
EMO_CHECKPOINT = 'models_checkpoints/cmu_mosei_best_checkpoint.pth'
TRAIT_CHECKPOINT = 'models_checkpoints/fiv2_best_checkpoint.pth'
# Full state_dict of the fusion model (backbones + fusion layers).
FUSION_MODEL_PATH = 'models_checkpoints/multitask_fusion_model.pth'

# ================== Load EmoNeXt ==================
# Frame-level feature extractor, used for inference only (hence eval()).
# NOTE(review): get_model is project-local; whether it loads trained EmoNeXt
# weights is not visible here - confirm, otherwise the embeddings come from an
# untrained/differently-initialised network.
emonext = get_model(num_classes=7, model_size="base", in_22k=False).to(DEVICE)
emonext.eval()

# ================ Face Detection ==================
# MediaPipe face detector. model_selection=1 is the full-range model per the
# MediaPipe docs (verify for the installed version); detections below 0.6
# confidence are discarded.
mp_face_detection = mp.solutions.face_detection
face_detector = mp_face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.6)

# Face-crop preprocessing: resize to FRAME_SIZE x FRAME_SIZE, convert to a tensor
# and normalise with the ImageNet channel mean/std.
# NOTE(review): crops are sliced from OpenCV frames (BGR), while ToPILImage and the
# ImageNet statistics assume RGB order - confirm this matches how the training
# embeddings were extracted before changing it.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((FRAME_SIZE, FRAME_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

def get_evenly_spaced_frames(video_path, num_frames):
    """Read ``num_frames`` frames spread evenly from the first to the last frame of a video.

    Args:
        video_path: path (str or path-like) to a video file.
        num_frames: number of frame indices to sample.

    Returns:
        List of BGR frames as returned by ``cv2.VideoCapture.read``. Frames that
        fail to decode are skipped, so the list may be shorter than ``num_frames``.
        It is empty when the video cannot be opened or reports no frames.
    """
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            return []
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Some containers/streams report 0 (or a negative value); linspace over
        # [0, total - 1] would then produce invalid frame indices.
        if total <= 0:
            return []
        frames = []
        for idx in np.linspace(0, total - 1, num=num_frames, dtype=int):
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ok, frame = cap.read()
            if ok:
                frames.append(frame)
        return frames
    finally:
        # Always release the capture handle, even if decoding raises.
        cap.release()

def get_face(frame):
    """Return the crop of the first face MediaPipe detects in ``frame``, or None.

    Args:
        frame: BGR image array of shape (H, W, 3), as read by OpenCV.

    Returns:
        The cropped region (same channel order as ``frame``), or None when no face
        is detected or the detected box, clamped to the image, is empty.
    """
    results = face_detector.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    if not results.detections:
        return None
    # NOTE(review): detections[0] is used as "the" face; no score/size ranking is applied.
    bbox = results.detections[0].location_data.relative_bounding_box
    h, w = frame.shape[:2]
    # The bounding box is relative ([0, 1]) and may extend past the image; clamp it.
    x1 = max(int(bbox.xmin * w), 0)
    y1 = max(int(bbox.ymin * h), 0)
    x2 = min(int((bbox.xmin + bbox.width) * w), w)
    y2 = min(int((bbox.ymin + bbox.height) * h), h)
    # A degenerate/out-of-frame box gives an empty crop that would crash the
    # downstream resize; treat it as "no face".
    if x2 <= x1 or y2 <= y1:
        return None
    return frame[y1:y2, x1:x2]

def get_embeddings_from_video(video_path, num_frames):
    """Extract EmoNeXt features for face crops taken from ``num_frames`` sampled frames.

    A frame with no detected face reuses the most recent face found earlier;
    frames before the first detection are dropped, so the result can have fewer
    than ``num_frames`` rows.

    Returns:
        CPU tensor with one row per kept frame (presumably (n, 1024) for the
        "base" EmoNeXt model - TODO confirm), or None if no face was found in
        any sampled frame (including when no frames could be read).
    """
    crops = []
    latest_face = None
    for frame in get_evenly_spaced_frames(video_path, num_frames):
        detected = get_face(frame)
        if detected is not None:
            latest_face = detected
        if latest_face is None:
            continue  # no face seen yet, nothing to carry forward
        crops.append(transform(latest_face))

    if len(crops) == 0:
        return None

    with torch.no_grad():
        inputs = torch.stack(crops).to(DEVICE)
        # `stn` is presumably a spatial-transformer alignment step applied before
        # the backbone's feature extractor.
        features = emonext.forward_features(emonext.stn(inputs))
    return features.cpu()

# ================== Load Model ==================
class FrozenTransformerWrapper(nn.Module):
    """A default ``TransformerBackbone`` initialised from a checkpoint, with all parameters frozen.

    NOTE(review): this class is also defined in the first cell; this definition
    shadows it once this cell runs - keep only one of them.

    Args:
        checkpoint_path: path to a ``torch.save``'d dict whose
            ``'model_state_dict'`` entry holds the backbone weights. Keys starting
            with ``'fc'`` (presumably a classification head) are dropped.
    """

    def __init__(self, checkpoint_path):
        import warnings

        super().__init__()
        self.model = TransformerBackbone()
        # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only hosts;
        # the weights are copied into the (CPU) backbone either way, so move the
        # module with .to(device) afterwards as before.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
        state_dict = checkpoint['model_state_dict']
        filtered = {k: v for k, v in state_dict.items() if not k.startswith('fc')}
        result = self.model.load_state_dict(filtered, strict=False)
        # strict=False would otherwise silently hide a key mismatch (e.g. a prefix)
        # and leave the backbone randomly initialised.
        if result.missing_keys:
            warnings.warn(
                f"{checkpoint_path}: backbone weights missing from checkpoint: {result.missing_keys}",
                stacklevel=2,
            )
        for p in self.model.parameters():
            p.requires_grad = False

    def forward(self, x, lengths):
        """Delegate to the frozen backbone; returns its pooled (batch, d_model) output."""
        return self.model(x, lengths)


# ================== Inference Function ==================
@torch.no_grad()
def infer_from_video(video_path, model, num_frames=None):
    """Predict emotion probabilities and trait scores for a single video.

    Args:
        video_path: path to a video file readable by OpenCV.
        model: module called as ``model(emb, lengths, task=...)`` with
            ``task`` in {'emotions', 'traits'} (e.g. ``MultiTaskFusionModel``).
        num_frames: number of frames to sample; defaults to ``NUM_FRAMES``.

    Returns:
        ``(emo_probs, trait_scores)``: 1-D CPU tensors (softmax over the emotion
        head, raw trait-head outputs).

    Raises:
        ValueError: if no face embedding could be produced or it contains NaNs.
    """
    if num_frames is None:
        num_frames = NUM_FRAMES
    emb = get_embeddings_from_video(video_path, num_frames=num_frames)
    if emb is None or torch.isnan(emb).any():
        raise ValueError("Embedding failed or NaNs present")

    emb = emb.unsqueeze(0).to(DEVICE)  # (1, n_kept_frames, feat_dim)
    lengths = torch.tensor([emb.shape[1]], device=DEVICE)

    # Run in eval mode (no dropout) but restore the caller's train/eval mode afterwards.
    was_training = model.training
    model.eval()
    try:
        emo_logits = model(emb, lengths, task='emotions')
        trait_out = model(emb, lengths, task='traits')
    finally:
        model.train(was_training)

    emo_probs = F.softmax(emo_logits, dim=1).squeeze(0).cpu()
    trait_scores = trait_out.squeeze(0).cpu()
    return emo_probs, trait_scores


In [3]:
# ================== Main ==================

# Build the fusion model from the configured backbone checkpoints, then load the
# full fusion state_dict (which also covers the backbone weights).
fusion_model = MultiTaskFusionModel(trait_model_path=TRAIT_CHECKPOINT, emo_model_path=EMO_CHECKPOINT).to(DEVICE)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
fusion_model.load_state_dict(torch.load(FUSION_MODEL_PATH, map_location=DEVICE, weights_only=True))
fusion_model.eval()

VIDEO_PATH = 'sample_data/meow2.mp4'
emotions, traits = infer_from_video(VIDEO_PATH, fusion_model)


In [4]:
# Pretty-print the predictions from the previous cell.
# NOTE(review): label order is assumed to match the head output order used in training - confirm.
emotion_names = ['Neutral', 'Anger', 'Disgust', 'Fear', 'Happiness', 'Sadness', 'Surprise']
trait_names = ['Openness', 'Conscientiousness', 'Extraversion', 'Agreeableness', 'Neuroticism']

for header, labels, scores in (
    ("\nEmotion distribution (softmax, sum=1):", emotion_names, emotions),
    ("\nPersonality traits (OCEAN):", trait_names, traits),
):
    print(header)
    for label, score in zip(labels, scores.tolist()):
        print(f"  {label}: {score:.4f}")


Emotion distribution (softmax, sum=1):
  Neutral: 0.0096
  Anger: 0.0062
  Disgust: 0.0616
  Fear: 0.0133
  Happiness: 0.6115
  Sadness: 0.2362
  Surprise: 0.0617

Personality traits (OCEAN):
  Openness: 0.5722
  Conscientiousness: 0.5299
  Extraversion: 0.4823
  Agreeableness: 0.5799
  Neuroticism: 0.5416
