In [1]:
import torch
import torch.nn as nn
import torchvision.models as models

class VisualBackbone(nn.Module):
    """Encode a clip of frames into one embedding vector.

    Every frame is pushed through a ResNet-18 trunk (all children of the
    torchvision model except its final classification layer), the per-frame
    features are averaged over the time axis, and the mean feature is
    linearly projected to ``embed_dim``.

    Args:
        embed_dim: Size of the returned embedding.
        pretrained: When True (default, the original behaviour) the trunk is
            initialised from the ImageNet checkpoint, which may trigger a
            download. Pass False for randomly initialised weights, e.g. in
            offline environments or unit tests.
    """

    def __init__(self, embed_dim=256, pretrained=True):
        super().__init__()
        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of
        # ``weights=``; fall back to the legacy keyword on older releases.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            # IMAGENET1K_V1 is the checkpoint ``pretrained=True`` resolved to.
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            resnet = models.resnet18(weights=weights)
        else:
            resnet = models.resnet18(pretrained=pretrained)
        # Keep everything up to and including global average pooling.
        self.cnn = nn.Sequential(*list(resnet.children())[:-1])  # output: (B, 512, 1, 1)
        # Read the feature width from the dropped head instead of hard-coding it.
        self.fc = nn.Linear(resnet.fc.in_features, embed_dim)

    def forward(self, x_seq):  # x_seq: (B, T, 3, H, W)
        """Return a ``(B, embed_dim)`` embedding for a ``(B, T, C, H, W)`` clip batch.

        Raises:
            ValueError: If ``x_seq`` is not a 5-D tensor.
        """
        if x_seq.dim() != 5:
            raise ValueError(
                f"expected x_seq of shape (B, T, C, H, W), got {tuple(x_seq.shape)}"
            )
        B, T, C, H, W = x_seq.shape
        # reshape (not view): the clip may be non-contiguous, e.g. after a
        # permute/transpose in the data pipeline, and view would raise.
        frames = x_seq.reshape(B * T, C, H, W)
        features = self.cnn(frames)      # (B*T, feat, 1, 1)
        features = features.reshape(B, T, -1).mean(dim=1)  # average over time
        return self.fc(features)         # (B, embed_dim)

class AudioBranch(nn.Module):
    """Encode an MFCC spectrogram into one embedding vector.

    A bidirectional LSTM reads the MFCC frames in time order; the final
    hidden states of the last layer's two directions are concatenated and
    linearly projected to ``embed_dim``.

    Args:
        embed_dim: Size of the returned embedding.
        n_mfcc: Number of MFCC coefficients per frame (LSTM input size).
        hidden_size: Hidden units per LSTM direction.
        num_layers: Number of stacked LSTM layers.

    The defaults reproduce the original fixed configuration (40 / 128 / 2).
    """

    def __init__(self, embed_dim=256, n_mfcc=40, hidden_size=128, num_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=n_mfcc,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Two directions are concatenated, hence 2 * hidden_size inputs.
        self.fc = nn.Linear(hidden_size * 2, embed_dim)

    def forward(self, x):  # x: (B, 1, n_mfcc, T)
        """Return a ``(B, embed_dim)`` embedding for a ``(B, 1, n_mfcc, T)`` batch."""
        x = x.squeeze(1).permute(0, 2, 1)  # -> (B, T, n_mfcc)
        _, (hn, _) = self.lstm(x)          # hn: (num_layers * 2, B, hidden_size)
        # hn[-2] / hn[-1] are the forward / backward states of the last layer.
        out = torch.cat((hn[-2], hn[-1]), dim=1)  # (B, 2 * hidden_size)
        return self.fc(out)                # (B, embed_dim)

class MultimodalEmotionRecognizer(nn.Module):
    """Late-fusion classifier over a visual and an audio embedding.

    The visual and audio branches each produce an ``embed_dim``-sized vector;
    the two are concatenated (visual first) and fed to a small MLP head that
    outputs ``num_classes`` logits.
    """

    def __init__(self, num_classes=4, embed_dim=256):
        super().__init__()
        self.visual_branch = VisualBackbone(embed_dim)
        self.audio_branch = AudioBranch(embed_dim)

        # Head input is the concatenation of both branch embeddings.
        fused_dim = 2 * embed_dim
        head_layers = [
            nn.Linear(fused_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, face_seq, mfcc):
        """Return class logits for a batch of (face clip, MFCC) pairs."""
        embeddings = [
            self.visual_branch(face_seq),  # (B, embed_dim)
            self.audio_branch(mfcc),       # (B, embed_dim)
        ]
        return self.classifier(torch.cat(embeddings, dim=1))

In [2]:
# Smoke test: push random inputs through the model and check the logit shape.
model = MultimodalEmotionRecognizer(num_classes=4)
x_img = torch.randn(2, 5, 3, 224, 224)   # (B, T, 3, H, W) face clips
x_mfcc = torch.randn(2, 1, 40, 100)      # (B, 1, n_mfcc, T) MFCC features

# Run the check in eval mode without autograd: Dropout is disabled, random
# noise is not folded into the BatchNorm running statistics, and no graph is
# kept around for a throwaway forward pass.
model.eval()
with torch.no_grad():
    out = model(x_img, x_mfcc)
model.train()  # restore the default mode for any later training cells

assert out.shape == (2, 4), f"unexpected output shape: {tuple(out.shape)}"
print(out.shape)  # should be (2, 4)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\mahmo/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 70.8MB/s]


torch.Size([2, 4])
