In [2]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from fastapi import APIRouter, HTTPException
from fastapi.responses import JSONResponse

# ---------------------------
# Cross-Attention Block
# ---------------------------
class CrossAttentionBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, dropout=0.1):
        super(CrossAttentionBlock, self).__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, embed_dim)
        )
        self.ff_norm = nn.LayerNorm(embed_dim)

    def forward(self, query, key_value):
        attn_output, _ = self.attn(query, key_value, key_value)
        out = self.norm(query + attn_output)
        out_ff = self.ff_norm(out + self.ff(out))
        return out_ff

# ---------------------------
# Adaptive Modality Selector
# ---------------------------
class AdaptiveSelector(nn.Module):
    def __init__(self, embed_dim, num_modalities=3):
        super(AdaptiveSelector, self).__init__()
        self.fc = nn.Linear(embed_dim * num_modalities, num_modalities)

    def forward(self, embeddings):
        combined = torch.cat(embeddings, dim=-1)
        weights = F.softmax(self.fc(combined), dim=-1)
        return weights

# ---------------------------
# Fusion Model
# ---------------------------
class FusionModel(nn.Module):
    def __init__(self, embed_dim, num_heads=4, num_slots=24):
        super(FusionModel, self).__init__()
        self.user_content_attn = CrossAttentionBlock(embed_dim, num_heads)
        self.user_context_attn = CrossAttentionBlock(embed_dim, num_heads)
        self.content_context_attn = CrossAttentionBlock(embed_dim, num_heads)
        self.selector = AdaptiveSelector(embed_dim, num_modalities=3)
        self.fc_out = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, num_slots)
        )

    def forward(self, user_emb, content_emb, context_emb):
        user = user_emb.unsqueeze(1)
        content = content_emb.unsqueeze(1)
        context = context_emb.unsqueeze(1)

        user_refined = self.user_content_attn(user, content) + self.user_context_attn(user, context)
        content_refined = self.user_content_attn(content, user) + self.content_context_attn(content, context)
        context_refined = self.user_context_attn(context, user) + self.content_context_attn(context, content)

        user_refined = user_refined.squeeze(1)
        content_refined = content_refined.squeeze(1)
        context_refined = context_refined.squeeze(1)

        weights = self.selector([user_refined, content_refined, context_refined])
        fused_emb = (
            weights[:, 0:1] * user_refined +
            weights[:, 1:2] * content_refined +
            weights[:, 2:3] * context_refined
        )

        slot_scores = self.fc_out(fused_emb)
        heatmap = torch.sigmoid(slot_scores)
        return heatmap


In [32]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np

# ---------------------------
# Custom Dataset
# ---------------------------
class FusionDataset(Dataset):
    def __init__(self, user_file,video_file, metadata_file):
        # Load input embeddings
        df_input_vid = pd.read_csv(video_file)
        df_input_user = pd.read_csv(user_file)
        self.user_embeddings = df_input_user.iloc[: ].values.astype(np.float32)
        self.video_embeddings = df_input_vid.iloc[:,:384].values.astype(np.float32)
        self.slot_ids = df_input_vid["slot_id"].values.astype(np.int64)

        # Load metadata embeddings (fixed per row)
        df_metadata = pd.read_csv(metadata_file)
        self.metadata_embeddings = df_metadata.values.astype(np.float32)

        # Ensure alignment (row i in metadata corresponds to row i in input)
        assert self.metadata_embeddings.shape[0] == len(self.user_embeddings), "Metadata and input row count mismatch"

    def __len__(self):
        return len(self.user_embeddings)

    def __getitem__(self, idx):
        user_emb = torch.tensor(self.user_embeddings[idx])
        video_emb = torch.tensor(self.video_embeddings[idx])
        metadata_emb = torch.tensor(self.metadata_embeddings[idx])
        slot_id = torch.tensor(self.slot_ids[idx])
        return user_emb, video_emb, metadata_emb, slot_id


In [None]:
# df_meta = pd.read_csv('metadata_embeddings.csv')
# df_meta

Unnamed: 0,-7.190454006195068359e-02,9.712099283933639526e-02,-9.324948303401470184e-03,-6.264826655387878418e-02,5.963744223117828369e-02,6.435234844684600830e-02,5.576114729046821594e-02,2.086320519447326660e-02,4.424110738909803331e-05,7.791992276906967163e-02,...,-2.787113748490810394e-02,-4.807336255908012390e-02,6.197397783398628235e-02,-2.120164595544338226e-02,3.126373142004013062e-02,6.850423291325569153e-03,3.779976069927215576e-02,-3.380725160241127014e-02,-3.539919853210449219e-02,8.399283140897750854e-02
0,-0.071605,0.097135,-0.009412,-0.062509,0.059675,0.064583,0.055712,0.020920,0.000120,0.078008,...,-0.027753,-0.048074,0.062302,-0.021202,0.031166,0.006930,0.037799,-0.033908,-0.035279,0.083755
1,-0.071869,0.097151,-0.009437,-0.062648,0.059757,0.064579,0.056054,0.020933,0.000195,0.077736,...,-0.027898,-0.048080,0.061891,-0.021131,0.031179,0.006806,0.037976,-0.034157,-0.035365,0.084221
2,-0.071886,0.096961,-0.009515,-0.062791,0.059639,0.064347,0.055942,0.020867,0.000144,0.077758,...,-0.027836,-0.048053,0.061903,-0.021129,0.031344,0.006780,0.037859,-0.033988,-0.035389,0.084079
3,-0.071905,0.097223,-0.009422,-0.062683,0.059664,0.064460,0.055994,0.021125,-0.000036,0.077872,...,-0.027759,-0.047990,0.062213,-0.021202,0.031084,0.007036,0.037975,-0.034082,-0.035365,0.083965
4,-0.071917,0.097207,-0.009280,-0.062649,0.059675,0.064276,0.055778,0.020905,-0.000075,0.077951,...,-0.027953,-0.048067,0.061994,-0.021193,0.031257,0.006918,0.037764,-0.033796,-0.035447,0.084050
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
744,-0.076209,0.129769,-0.025833,-0.033387,0.076593,0.059540,0.041556,0.032412,-0.024596,0.069554,...,-0.018233,-0.043666,0.063515,-0.011683,0.009050,-0.010640,0.034465,-0.051907,-0.037282,0.087295
745,-0.076209,0.129769,-0.025833,-0.033387,0.076593,0.059540,0.041556,0.032412,-0.024596,0.069554,...,-0.018233,-0.043666,0.063515,-0.011683,0.009050,-0.010640,0.034465,-0.051907,-0.037282,0.087295
746,-0.076336,0.129694,-0.025872,-0.033433,0.076575,0.059766,0.041585,0.032389,-0.024601,0.069919,...,-0.018368,-0.043825,0.063517,-0.011463,0.009147,-0.010398,0.034655,-0.052002,-0.037296,0.087132
747,-0.076209,0.129769,-0.025833,-0.033387,0.076593,0.059540,0.041556,0.032412,-0.024596,0.069554,...,-0.018233,-0.043666,0.063515,-0.011683,0.009050,-0.010640,0.034465,-0.051907,-0.037282,0.087295


In [None]:
# # Preprocessing the csv file

# import pandas as pd
# import ast

# # Load the CSV
# df = pd.read_csv("channel_embedding_results.csv")

# # Convert the string list in embedding_response back to a Python list
# df["embedding_response"] = df["embedding_response"].apply(ast.literal_eval)

# # Expand the list into separate columns
# embeddings_df = pd.DataFrame(df["embedding_response"].tolist())

# # Rename columns as embedding_0, embedding_1, ...
# embeddings_df = embeddings_df.add_prefix("embedding_")

# # # Concatenate channel_id with the expanded embeddings
# # final_df = pd.concat([df["channel_id"], embeddings_df], axis=1)

# # Save to new CSV
# embeddings_df.to_csv("channel_embeddings_expanded.csv", index=False)

# print("Expanded CSV saved as channel_embeddings_expanded.csv")


Expanded CSV saved as channel_embeddings_expanded.csv


In [23]:
# df  = pd.read_csv('channel_embeddings_expanded.csv')



# row_to_duplicate = df.iloc[[0]]   # keep as DataFrame

# # Duplicate it 100 times
# duplicated_rows = pd.concat([row_to_duplicate] * 100, ignore_index=True)

# # Append duplicated rows back to original DataFrame
# df_extended = pd.concat([df, duplicated_rows], ignore_index=True)


In [25]:
# df_vids_emb = df_extended.copy()
# df_vids_emb['slot_id'] = 45
# df_vids_emb

Unnamed: 0,embedding_0,embedding_1,embedding_2,embedding_3,embedding_4,embedding_5,embedding_6,embedding_7,embedding_8,embedding_9,...,embedding_375,embedding_376,embedding_377,embedding_378,embedding_379,embedding_380,embedding_381,embedding_382,embedding_383,slot_id
0,0.047816,0.066386,-0.03279,0.017043,-0.019379,0.012575,0.140922,0.0582,0.011558,-0.008675,...,0.014132,-0.038559,0.04619,-0.046468,0.093365,0.088916,-0.026842,-0.068403,0.104807,45
1,0.047816,0.066386,-0.03279,0.017043,-0.019379,0.012575,0.140922,0.0582,0.011558,-0.008675,...,0.014132,-0.038559,0.04619,-0.046468,0.093365,0.088916,-0.026842,-0.068403,0.104807,45
2,0.047816,0.066386,-0.03279,0.017043,-0.019379,0.012575,0.140922,0.0582,0.011558,-0.008675,...,0.014132,-0.038559,0.04619,-0.046468,0.093365,0.088916,-0.026842,-0.068403,0.104807,45
3,0.047816,0.066386,-0.03279,0.017043,-0.019379,0.012575,0.140922,0.0582,0.011558,-0.008675,...,0.014132,-0.038559,0.04619,-0.046468,0.093365,0.088916,-0.026842,-0.068403,0.104807,45
4,0.047816,0.066386,-0.03279,0.017043,-0.019379,0.012575,0.140922,0.0582,0.011558,-0.008675,...,0.014132,-0.038559,0.04619,-0.046468,0.093365,0.088916,-0.026842,-0.068403,0.104807,45
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
96,0.047816,0.066386,-0.03279,0.017043,-0.019379,0.012575,0.140922,0.0582,0.011558,-0.008675,...,0.014132,-0.038559,0.04619,-0.046468,0.093365,0.088916,-0.026842,-0.068403,0.104807,45
97,0.047816,0.066386,-0.03279,0.017043,-0.019379,0.012575,0.140922,0.0582,0.011558,-0.008675,...,0.014132,-0.038559,0.04619,-0.046468,0.093365,0.088916,-0.026842,-0.068403,0.104807,45
98,0.047816,0.066386,-0.03279,0.017043,-0.019379,0.012575,0.140922,0.0582,0.011558,-0.008675,...,0.014132,-0.038559,0.04619,-0.046468,0.093365,0.088916,-0.026842,-0.068403,0.104807,45
99,0.047816,0.066386,-0.03279,0.017043,-0.019379,0.012575,0.140922,0.0582,0.011558,-0.008675,...,0.014132,-0.038559,0.04619,-0.046468,0.093365,0.088916,-0.026842,-0.068403,0.104807,45


In [26]:
# df_vids_emb.to_csv("vid_embs_expanded.csv", index=False)
# df_extended.to_csv("user_embs_expanded.csv", index=False)




In [33]:
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split

# ---------------------------
# Hyperparameters
# ---------------------------
BATCH_SIZE = 128
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
EMBED_DIM = 384
NUM_HEADS = 4
NUM_SLOTS = 168

# ---------------------------
# Load Dataset
# ---------------------------
dataset = FusionDataset("user_embs_expanded.csv", "vid_embs_expanded.csv","user_embs_expanded.csv")

# Split dataset (80/10/10)
train_idx, temp_idx = train_test_split(np.arange(len(dataset)), test_size=0.2, random_state=42)
val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)

train_subset = torch.utils.data.Subset(dataset, train_idx)
val_subset = torch.utils.data.Subset(dataset, val_idx)
test_subset = torch.utils.data.Subset(dataset, test_idx)

train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_subset, batch_size=BATCH_SIZE, shuffle=False)

# ---------------------------
# Initialize Model
# ---------------------------

model = FusionModel(embed_dim=EMBED_DIM, num_heads=NUM_HEADS, num_slots=NUM_SLOTS)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# ---------------------------
# Loss and Optimizer
# ---------------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# ---------------------------
# Training Loop
# ---------------------------
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0
    for user_emb, video_emb, metadata_emb, slot_id in train_loader:
        user_emb, video_emb, metadata_emb, slot_id = user_emb.to(device), video_emb.to(device), metadata_emb.to(device), slot_id.to(device)

        optimizer.zero_grad()
        slot_scores = model(user_emb, video_emb, metadata_emb)
        loss = criterion(slot_scores, slot_id)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Validation
    model.eval()
    val_loss, top1_correct, top3_correct = 0, 0, 0
    with torch.no_grad():
        for user_emb, video_emb, metadata_emb, slot_id in val_loader:
            user_emb, video_emb, metadata_emb, slot_id = user_emb.to(device), video_emb.to(device), metadata_emb.to(device), slot_id.to(device)
            slot_scores = model(user_emb, video_emb, metadata_emb)
            val_loss += criterion(slot_scores, slot_id).item()

            probs = torch.softmax(slot_scores, dim=-1)
            top1_correct += (probs.argmax(dim=-1) == slot_id).sum().item()
            top3_correct += (torch.topk(probs, k=3, dim=-1).indices == slot_id.unsqueeze(1)).any(dim=1).sum().item()

    val_loss /= len(val_loader)
    top1_acc = top1_correct / len(val_subset)
    top3_acc = top3_correct / len(val_subset)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: {running_loss/len(train_loader):.4f} | Val Loss: {val_loss:.4f} | Top1 Acc: {top1_acc:.4f} | Top3 Acc: {top3_acc:.4f}")


Epoch 1/10 | Train Loss: 5.2426 | Val Loss: 5.0738 | Top1 Acc: 0.0000 | Top3 Acc: 0.0000
Epoch 2/10 | Train Loss: 5.0736 | Val Loss: 4.9203 | Top1 Acc: 0.0000 | Top3 Acc: 0.0000
Epoch 3/10 | Train Loss: 4.9215 | Val Loss: 4.8045 | Top1 Acc: 1.0000 | Top3 Acc: 1.0000
Epoch 4/10 | Train Loss: 4.8113 | Val Loss: 4.7285 | Top1 Acc: 1.0000 | Top3 Acc: 1.0000
Epoch 5/10 | Train Loss: 4.7307 | Val Loss: 4.6817 | Top1 Acc: 1.0000 | Top3 Acc: 1.0000
Epoch 6/10 | Train Loss: 4.6852 | Val Loss: 4.6520 | Top1 Acc: 1.0000 | Top3 Acc: 1.0000
Epoch 7/10 | Train Loss: 4.6541 | Val Loss: 4.6316 | Top1 Acc: 1.0000 | Top3 Acc: 1.0000
Epoch 8/10 | Train Loss: 4.6341 | Val Loss: 4.6161 | Top1 Acc: 1.0000 | Top3 Acc: 1.0000
Epoch 9/10 | Train Loss: 4.6171 | Val Loss: 4.6035 | Top1 Acc: 1.0000 | Top3 Acc: 1.0000
Epoch 10/10 | Train Loss: 4.6050 | Val Loss: 4.5928 | Top1 Acc: 1.0000 | Top3 Acc: 1.0000


In [34]:
torch.save(model.state_dict(), "fusion_model.pth")
print("Model saved to fusion_model.pth")


Model saved to fusion_model.pth
