In [26]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
class EmbeddingDataset(Dataset):
    def __init__(self, img, vid, aud):

        self.img_emb = img["embeddings"]
        self.vid_emb = vid["embeddings"]
        self.aud_emb = aud["embeddings"]
        self.labels = torch.logical_and(img["labels"] == 1, aud["labels"] == 1).long()

        assert len(self.img_emb) == len(self.vid_emb)
        # print(len(self.img_emb), len(self.vid_emb), len(self.aud_emb))
        
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return (
            self.img_emb[idx],       # (6529)
            self.vid_emb[idx],       # (6529)
            self.aud_emb[idx],       # (53868)
            self.labels[idx]         # scalar
        )


In [6]:
img_path = "embeddings/image_embeddings.pt"
vid_path = "embeddings/video_embeddings.pt"
aud_path = "embeddings/audio_embeddings.pt"
img = torch.load(img_path)
vid = torch.load(vid_path)
aud = torch.load(aud_path)


{'embeddings': tensor([[-0.0401, -0.0837,  0.0470,  ...,  0.0987,  0.4234, -0.1591],
         [ 0.2385, -0.0045, -0.1886,  ...,  0.3085,  0.1782, -0.0213],
         [-0.0053,  0.0596,  0.1795,  ...,  0.0959,  0.4545, -0.1240],
         ...,
         [ 0.1586,  0.0167,  0.3634,  ...,  0.2028,  0.3787, -0.1390],
         [ 0.1164,  0.0420,  0.2573,  ...,  0.0286,  0.4332,  0.1259],
         [-0.1189,  0.0500,  0.1513,  ...,  0.1987,  0.3084, -0.3461]]),
 'labels': tensor([0, 1, 1,  ..., 0, 0, 1])}

In [8]:
aud["labels"].shape

torch.Size([53868])

In [12]:
type(aud)

dict

In [9]:
aud_projection = nn.Linear(768, 512)
aud_embed_512 = aud_projection(aud["embeddings"])
dim_reduced_aud = {}
dim_reduced_aud["embeddings"] = aud_embed_512
dim_reduced_aud["labels"] = aud["labels"]
dim_reduced_aud

{'embeddings': tensor([[-0.0872, -0.0255,  0.0847,  ...,  0.0452, -0.0935, -0.0173],
         [-0.4114,  0.2869, -0.0347,  ...,  0.1379,  0.0814,  0.0863],
         [-0.2068,  0.0483, -0.0170,  ..., -0.0067, -0.1485, -0.0904],
         ...,
         [-0.0726, -0.0470,  0.1386,  ...,  0.1219, -0.1404, -0.1497],
         [-0.1465, -0.0144,  0.1058,  ...,  0.0989, -0.0414, -0.2342],
         [-0.2574,  0.0760, -0.0037,  ..., -0.0108, -0.1635, -0.1294]],
        grad_fn=<AddmmBackward0>),
 'labels': tensor([0, 1, 1,  ..., 0, 0, 1])}

In [19]:
emb = dim_reduced_aud["embeddings"]  
lbl = dim_reduced_aud["labels"] 

B, D = emb.shape
k = B // 6529
B_new = 6529 * k

perm = torch.randperm(B)
emb_shuffled = emb[perm]
lbl_shuffled = lbl[perm]

emb_trim = emb_shuffled[:B_new]
lbl_trim = lbl_shuffled[:B_new]

emb_reshaped = emb_trim.view(6529, k, 512)      # [6529, k, 512]
emb_pooled = emb_reshaped.mean(dim=1)           # [6529, 512]

lbl_reshaped = lbl_trim.view(6529, k)           # [6529, k]
lbl_final = lbl_reshaped.mode(dim=1).values     # majority vote, shape [6529]

# 5. Pack results
aud_pooled = {
    "embeddings": emb_pooled,
    "labels": lbl_final
}


In [20]:
aud_pooled

{'embeddings': tensor([[-0.2268,  0.0517,  0.0248,  ...,  0.0564, -0.0724, -0.0970],
         [-0.1260,  0.1102,  0.0634,  ...,  0.0185, -0.0122, -0.1263],
         [-0.2011,  0.1134,  0.0098,  ...,  0.0105, -0.0329, -0.1198],
         ...,
         [-0.1055,  0.0473,  0.0776,  ...,  0.0540, -0.0653, -0.1282],
         [-0.2200,  0.0602,  0.0294,  ..., -0.0005, -0.0868, -0.1063],
         [-0.1973,  0.1149, -0.0024,  ...,  0.0216, -0.0781, -0.0973]],
        grad_fn=<MeanBackward1>),
 'labels': tensor([0, 0, 1,  ..., 0, 0, 1])}

In [27]:
dataset = EmbeddingDataset(img, vid, aud_pooled)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
dataset.img_emb.shape, dataset.vid_emb.shape, dataset.aud_emb.shape
# for img_emb, vid_emb, aud_emb in loader:
#     print("img embeddings: ", img_emb.shape)
#     print("vid embeddings: ", vid_emb.shape)
#     print("aud embeddings: ", aud_emb.shape)

(torch.Size([6529, 512]), torch.Size([6529, 512]), torch.Size([6529, 512]))

In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusionClassifier(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()

        fusion_dim = embed_dim * 3   # img_emb + vid_emb = 1024

        self.fc1 = nn.Linear(fusion_dim, 512)
        self.ln1 = nn.LayerNorm(512)

        self.fc2 = nn.Linear(512, 256)
        self.ln2 = nn.LayerNorm(256)

        self.dropout = nn.Dropout(0.3)

        self.fc_out = nn.Linear(256, 2)  # For CrossEntropy (real/fake)

    def forward(self, img_emb, vid_emb, aud_emb):
        x = torch.cat([img_emb, vid_emb, aud_emb], dim=1)
        x = F.relu(self.ln1(self.fc1(x)))
        x = self.dropout(x)
        x = F.relu(self.ln2(self.fc2(x)))
        return self.fc_out(x)

In [29]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = FusionClassifier(embed_dim=512).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

EPOCHS = 1000

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for img_emb, vid_emb, aud_emb, labels in loader:
        
        img_emb = img_emb.detach().to(device)
        vid_emb = vid_emb.detach().to(device)
        aud_emb = aud_emb.detach().to(device)
        
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(img_emb, vid_emb, aud_emb)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    print(f"Epoch {epoch+1}/{EPOCHS} | Loss={total_loss:.4f} | Acc={acc:.4f}")


Epoch 1/1000 | Loss=118.0835 | Acc=0.7032
Epoch 2/1000 | Loss=97.7033 | Acc=0.7689
Epoch 3/1000 | Loss=94.9402 | Acc=0.7850
Epoch 4/1000 | Loss=96.8790 | Acc=0.7738
Epoch 5/1000 | Loss=93.1471 | Acc=0.7915
Epoch 6/1000 | Loss=89.5215 | Acc=0.7912
Epoch 7/1000 | Loss=87.7378 | Acc=0.7978
Epoch 8/1000 | Loss=89.4216 | Acc=0.7948
Epoch 9/1000 | Loss=89.0548 | Acc=0.8021
Epoch 10/1000 | Loss=88.0696 | Acc=0.8007
Epoch 11/1000 | Loss=84.5410 | Acc=0.8112
Epoch 12/1000 | Loss=86.2078 | Acc=0.8099
Epoch 13/1000 | Loss=85.8886 | Acc=0.8059
Epoch 14/1000 | Loss=83.3070 | Acc=0.8121
Epoch 15/1000 | Loss=83.7456 | Acc=0.8134
Epoch 16/1000 | Loss=84.8978 | Acc=0.8104
Epoch 17/1000 | Loss=81.6696 | Acc=0.8187
Epoch 18/1000 | Loss=81.2220 | Acc=0.8151
Epoch 19/1000 | Loss=80.0931 | Acc=0.8210
Epoch 20/1000 | Loss=80.4051 | Acc=0.8180
Epoch 21/1000 | Loss=78.6940 | Acc=0.8251
Epoch 22/1000 | Loss=77.5068 | Acc=0.8280
Epoch 23/1000 | Loss=77.6096 | Acc=0.8266
Epoch 24/1000 | Loss=78.0516 | Acc=0.8269


In [31]:
torch.save(model.state_dict(), "ml_model/fusion_model.pth")