In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
class EmbeddingDataset(Dataset):
    def __init__(self, img, vid, aud):

        self.img_emb = img["embeddings"]
        self.vid_emb = vid["embeddings"]
        self.aud_emb = aud["embeddings"]
        self.labels = torch.logical_and(img["labels"] == 1, aud["labels"] == 1).long()

        assert len(self.img_emb) == len(self.vid_emb)
        # print(len(self.img_emb), len(self.vid_emb), len(self.aud_emb))
        
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return (
            self.img_emb[idx],       # (6529)
            self.vid_emb[idx],       # (6529)
            self.aud_emb[idx],       # (53868)
            self.labels[idx]         # scalar
        )


In [2]:
img_path = "embeddings/image_embeddings.pt"
vid_path = "embeddings/video_embeddings.pt"
aud_path = "embeddings/audio_embeddings.pt"
img = torch.load(img_path)
vid = torch.load(vid_path)
aud = torch.load(aud_path)


In [3]:
aud["labels"].shape

torch.Size([53868])

In [4]:
type(aud)

dict

In [5]:
aud_projection = nn.Linear(768, 512)
aud_embed_512 = aud_projection(aud["embeddings"])
dim_reduced_aud = {}
dim_reduced_aud["embeddings"] = aud_embed_512
dim_reduced_aud["labels"] = aud["labels"]
dim_reduced_aud

{'embeddings': tensor([[-0.3066, -0.2047, -0.0683,  ..., -0.1312, -0.1106,  0.1742],
         [-0.1104, -0.1327,  0.0353,  ..., -0.0845,  0.0210, -0.0060],
         [-0.0369, -0.3680,  0.0231,  ..., -0.0512, -0.0616, -0.0716],
         ...,
         [-0.3582, -0.3175, -0.1303,  ..., -0.0630,  0.0294,  0.1945],
         [-0.2436, -0.3054, -0.0577,  ...,  0.0968, -0.0503,  0.2423],
         [-0.1108, -0.4214,  0.0558,  ..., -0.0789, -0.0576, -0.0923]],
        grad_fn=<AddmmBackward0>),
 'labels': tensor([0, 1, 1,  ..., 0, 0, 1])}

In [6]:
emb = dim_reduced_aud["embeddings"]  
lbl = dim_reduced_aud["labels"] 

B, D = emb.shape
k = B // 6529
B_new = 6529 * k

perm = torch.randperm(B)
emb_shuffled = emb[perm]
lbl_shuffled = lbl[perm]

emb_trim = emb_shuffled[:B_new]
lbl_trim = lbl_shuffled[:B_new]

emb_reshaped = emb_trim.view(6529, k, 512)      # [6529, k, 512]
emb_pooled = emb_reshaped.mean(dim=1)           # [6529, 512]

lbl_reshaped = lbl_trim.view(6529, k)           # [6529, k]
lbl_final = lbl_reshaped.mode(dim=1).values     # majority vote, shape [6529]

# 5. Pack results
aud_pooled = {
    "embeddings": emb_pooled,
    "labels": lbl_final
}


In [7]:
aud_pooled

{'embeddings': tensor([[-0.2038, -0.2742, -0.0172,  ..., -0.0917,  0.0087,  0.0459],
         [-0.1821, -0.3958, -0.0130,  ..., -0.1146,  0.0150,  0.0079],
         [-0.1709, -0.2794,  0.0339,  ..., -0.0696,  0.1286, -0.0282],
         ...,
         [-0.1735, -0.3745, -0.0219,  ..., -0.0481, -0.0325,  0.0464],
         [-0.2033, -0.3057,  0.0142,  ..., -0.1043,  0.0705,  0.0232],
         [-0.2173, -0.2382,  0.0565,  ..., -0.0454, -0.0009,  0.0303]],
        grad_fn=<MeanBackward1>),
 'labels': tensor([0, 0, 0,  ..., 1, 0, 0])}

In [8]:
dataset = EmbeddingDataset(img, vid, aud_pooled)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
dataset.img_emb.shape, dataset.vid_emb.shape, dataset.aud_emb.shape
# for img_emb, vid_emb, aud_emb in loader:
#     print("img embeddings: ", img_emb.shape)
#     print("vid embeddings: ", vid_emb.shape)
#     print("aud embeddings: ", aud_emb.shape)

(torch.Size([6529, 512]), torch.Size([6529, 512]), torch.Size([6529, 512]))

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusionClassifier(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()

        fusion_dim = embed_dim * 3   # img_emb + vid_emb = 1024

        self.fc1 = nn.Linear(fusion_dim, 512)
        self.ln1 = nn.LayerNorm(512)

        self.fc2 = nn.Linear(512, 256)
        self.ln2 = nn.LayerNorm(256)

        self.dropout = nn.Dropout(0.3)

        self.fc_out = nn.Linear(256, 2)  # For CrossEntropy (real/fake)

    def forward(self, img_emb, vid_emb, aud_emb):
        x = torch.cat([img_emb, vid_emb, aud_emb], dim=1)
        x = F.relu(self.ln1(self.fc1(x)))
        x = self.dropout(x)
        x = F.relu(self.ln2(self.fc2(x)))
        return self.fc_out(x)

In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = FusionClassifier(embed_dim=512).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

EPOCHS = 50

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for img_emb, vid_emb, aud_emb, labels in loader:
        
        img_emb = img_emb.detach().to(device)
        vid_emb = vid_emb.detach().to(device)
        aud_emb = aud_emb.detach().to(device)
        
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(img_emb, vid_emb, aud_emb)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    print(f"Epoch {epoch+1}/{EPOCHS} | Loss={total_loss:.4f} | Acc={acc:.4f}")


Epoch 1/50 | Loss=114.3643 | Acc=0.7134
Epoch 2/50 | Loss=98.5730 | Acc=0.7666
Epoch 3/50 | Loss=94.8755 | Acc=0.7839
Epoch 4/50 | Loss=93.2012 | Acc=0.7879
Epoch 5/50 | Loss=91.2878 | Acc=0.7906
Epoch 6/50 | Loss=89.2786 | Acc=0.7896
Epoch 7/50 | Loss=88.1678 | Acc=0.8007
Epoch 8/50 | Loss=87.8355 | Acc=0.7920
Epoch 9/50 | Loss=86.0160 | Acc=0.8009
Epoch 10/50 | Loss=84.7077 | Acc=0.8047
Epoch 11/50 | Loss=86.4613 | Acc=0.8013
Epoch 12/50 | Loss=83.0796 | Acc=0.8099
Epoch 13/50 | Loss=84.4831 | Acc=0.8046
Epoch 14/50 | Loss=82.3835 | Acc=0.8102
Epoch 15/50 | Loss=83.6175 | Acc=0.8110
Epoch 16/50 | Loss=82.4906 | Acc=0.8153
Epoch 17/50 | Loss=81.9328 | Acc=0.8127
Epoch 18/50 | Loss=82.5600 | Acc=0.8099
Epoch 19/50 | Loss=80.1833 | Acc=0.8157
Epoch 20/50 | Loss=79.3212 | Acc=0.8223
Epoch 21/50 | Loss=78.3036 | Acc=0.8229
Epoch 22/50 | Loss=78.7574 | Acc=0.8219
Epoch 23/50 | Loss=76.7085 | Acc=0.8268
Epoch 24/50 | Loss=76.9387 | Acc=0.8260
Epoch 25/50 | Loss=82.4264 | Acc=0.8134
Epoch 26

In [11]:
torch.save(model.state_dict(), "ml_model/fusion_model.pth")