In [44]:
import os
import sys
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split, KFold

In [45]:
# Append the path for module imports
sys.path.append('../')

# Import custom modules
from modules.cross_attention import CrossAttention
from modules.dataloader import load_npy_files
from modules.classifier import DenseLayer, BCELoss
from modules.linear_transformation import LinearTransformations

In [46]:
def collate_fn(batch):
    """Collate ``(text, audio, video, label)`` samples into batched tensors.

    Text, audio and label tensors are stacked as given, so within one batch each
    of those must share a single shape. Video tensors may differ in length along
    dim 0; they are right-padded with zeros to the longest sequence in the batch.

    Args:
        batch: non-empty iterable of ``(text, audio, video, label)`` tensor tuples
            (a ``DataLoader`` passes a list).

    Returns:
        ``(text, audio, video, labels)``. Each is the per-sample tensors stacked
        along a new leading batch dimension. ``video`` has shape
        ``(batch_size, max_len, *trailing_dims)``.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    samples = list(batch)
    if not samples:
        raise ValueError("collate_fn received an empty batch")

    text_data, audio_data, video_data, label_data = zip(*samples)

    # Stack the fixed-size modalities along a new batch dimension
    text_batch = torch.stack(text_data)
    audio_batch = torch.stack(audio_data)

    # pad_sequence zero-pads along dim 0 for inputs of any rank (1-D, 2-D, 3-D, ...).
    # A hand-written F.pad(v, (0, 0, 0, k)) only targets dim 0 when v is exactly 2-D.
    video_batch = nn.utils.rnn.pad_sequence(list(video_data), batch_first=True, padding_value=0.0)

    # Labels are stacked as given: per-sample shape [1] yields [batch_size, 1],
    # while scalar labels would yield [batch_size].
    label_batch = torch.stack(label_data)

    return text_batch, audio_batch, video_batch, label_batch

In [47]:
# Root folders can be overridden through an environment variable so the notebook
# runs on machines other than the original author's; the default reproduces the
# original absolute paths exactly.
GITHUB_ROOT = os.environ.get('SMCA_GITHUB_ROOT', '/Users/kyleandrecastro/Documents/GitHub')
SMCA_MISC_DIR = os.path.join(GITHUB_ROOT, 'SMCA', 'misc')

# Load the labels DataFrame
id_label_df = pd.read_excel(os.path.join(SMCA_MISC_DIR, 'MM-Trailer_dataset.xlsx'))

# Define the feature directories (text and audio read from their 'test' subfolders)
text_features_dir = os.path.join(SMCA_MISC_DIR, 'textStream_BERT', 'feature_vectors', 'test')
audio_features_dir = os.path.join(SMCA_MISC_DIR, 'audio_fe', 'logmel_spectrograms', 'test')
video_features_dir = os.path.join(GITHUB_ROOT, 'features', 'visual')

# Load the feature vectors from each directory.
# NOTE(review): the smoke-test cell unpacks each entry as (file_name, tensor) and
# assumes index i refers to the same trailer in all three lists — confirm that
# load_npy_files returns entries in a consistent (e.g. sorted) order.
text_features = load_npy_files(text_features_dir)
audio_features = load_npy_files(audio_features_dir)
video_features = load_npy_files(video_features_dir)

# NOTE(review): MultimodalDataset is not defined or imported in the cells shown
# here; the commented-out code below needs it before it can be enabled.

# # Splitting data for training, validation, and testing
# train_df, val_test_df = train_test_split(id_label_df, test_size=0.3, random_state=42)

# # Further splitting remaining set into validation and test sets
# val_df, test_df = train_test_split(val_test_df, test_size=0.5, random_state=42)

# # Create datasets
# train_dataset = MultimodalDataset(train_df, text_features, audio_features, video_features)
# val_dataset = MultimodalDataset(val_df, text_features, audio_features, video_features)
# test_dataset = MultimodalDataset(test_df, text_features, audio_features, video_features)

# # Create DataLoaders
# train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=collate_fn)
# val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=1, collate_fn=collate_fn)
# test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=1, collate_fn=collate_fn)

# # Combine all data for K-fold cross-validation
# full_dataset = MultimodalDataset(id_label_df, text_features, audio_features, video_features)

In [48]:
# Cross Attention Function
def PairCrossAttention(modalityAlpha, modalityBeta, d_out_kq=768, d_out_v=768, cross_attn=None):
    """Apply cross-attention between two modality tensors.

    Args:
        modalityAlpha: tensor passed as the first argument to the attention module.
            When a module is built here, its last dimension sizes that module's
            first input.
        modalityBeta: tensor passed as the second argument to the attention module.
            When a module is built here, its last dimension sizes that module's
            second input.
        d_out_kq: forwarded to ``CrossAttention`` as its third constructor argument
            (presumably the key/query projection width — confirm in
            modules.cross_attention).
        d_out_v: forwarded to ``CrossAttention`` as its fourth constructor argument
            (presumably the value projection width — confirm).
        cross_attn: optional pre-built module, or any callable taking
            ``(modalityAlpha, modalityBeta)``. When omitted, a new ``CrossAttention``
            is constructed on every call. Its weights are not stored anywhere, so
            they cannot be trained or reused. Pass a persistent module (e.g. one
            registered on an ``nn.Module``) to keep learned weights.
            ``d_out_kq``/``d_out_v`` are ignored in that case.

    Returns:
        Whatever the attention module returns for ``(modalityAlpha, modalityBeta)``.
    """
    if cross_attn is None:
        cross_attn = CrossAttention(modalityAlpha.shape[-1], modalityBeta.shape[-1], d_out_kq, d_out_v)
    return cross_attn(modalityAlpha, modalityBeta)

In [49]:
def HadamardProduct(tensor1, tensor2):
    """Return the element-wise (Hadamard) product of two equally shaped tensors.

    Broadcasting is deliberately refused: both shapes must match exactly.

    Raises:
        ValueError: if ``tensor1.shape`` differs from ``tensor2.shape``.
    """
    if tensor1.shape == tensor2.shape:
        return tensor1 * tensor2
    raise ValueError("Tensors must have the same shape for Hadamard product.")

In [50]:
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) dimension into one.

    Uses ``Tensor.view``, so the input's memory layout must allow a view
    (e.g. a contiguous tensor); no copy is ever made.
    """

    def forward(self, x):
        leading = x.shape[0]
        return x.view(leading, -1)

In [51]:
class EmbracementLayer(nn.Module):
    """Fuse three modality tensors: concatenate, then Linear -> LayerNorm -> ReLU.

    Args:
        d_in: size of the concatenated last dimension, i.e. the sum of the three
            inputs' last-dimension sizes.
        d_out: output feature size.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Attribute names determine state_dict keys (fc.*, norm.*); keep them stable.
        self.fc = nn.Linear(d_in, d_out)
        self.norm = nn.LayerNorm(d_out)
        self.activation = nn.ReLU()

    def forward(self, video_features, audio_features, text_features):
        # Join the modalities along the feature (last) axis, then project,
        # normalise and rectify in one pipeline.
        joined = torch.cat((video_features, audio_features, text_features), dim=-1)
        return self.activation(self.norm(self.fc(joined)))

In [52]:
if __name__ == "__main__":
    import re

    torch.manual_seed(42)

    # Select the first file from each modality directory (for testing).
    # The samples get their own names so the lists loaded earlier
    # (text_features / audio_features / video_features) are not overwritten;
    # otherwise re-running this cell would index into a tensor instead of a list.
    text_file_name, text_sample = text_features[0]
    audio_file_name, audio_sample = audio_features[0]
    video_file_name, video_sample = video_features[0]

    # The three lists are indexed independently, so check that index 0 refers to
    # the same trailer id ("tt" followed by digits in the file name) in every modality.
    trailer_ids = set()
    for file_name in (text_file_name, audio_file_name, video_file_name):
        id_match = re.search(r"tt\d+", os.path.basename(file_name))
        trailer_ids.add(id_match.group(0) if id_match else None)
    assert len(trailer_ids) == 1 and None not in trailer_ids, (
        f"First files of each modality do not share one trailer id: {trailer_ids}"
    )

    print("Text file name:", text_file_name)
    print("Audio file name:", audio_file_name)
    print("Video file name:", video_file_name)

    print("Text features shape:", text_sample.shape)
    print("Audio features shape:", audio_sample.shape)
    print("Video features shape:", video_sample.shape, '\n')

    # Bring every modality to 2-D (sequence_length, feature_dim):
    # audio drops its leading size-1 dimension, text gains one.
    audio_sample = audio_sample.squeeze(0)
    text_sample = text_sample.unsqueeze(0)

    print("text_features shape:", text_sample.shape)
    print("audio_features shape:", audio_sample.shape)
    print("video_features shape:", video_sample.shape, '\n')

    # Cross-Attention for every possible pair
    video_audio = PairCrossAttention(video_sample, audio_sample)
    video_text = PairCrossAttention(video_sample, text_sample)
    audio_video = PairCrossAttention(audio_sample, video_sample)
    audio_text = PairCrossAttention(audio_sample, text_sample)
    text_video = PairCrossAttention(text_sample, video_sample)
    text_audio = PairCrossAttention(text_sample, audio_sample)

    # Print shapes of cross-attention results
    print("video_audio shape:", video_audio.shape)
    print("video_text shape:", video_text.shape)
    print("audio_video shape:", audio_video.shape)
    print("audio_text shape:", audio_text.shape)
    print("text_video shape:", text_video.shape)
    print("text_audio shape:", text_audio.shape, '\n')

    # Combine the Cross-Attention outputs using the Hadamard product
    video_combined = HadamardProduct(video_audio, video_text)
    audio_combined = HadamardProduct(audio_video, audio_text)
    text_combined = HadamardProduct(text_video, text_audio)

    print("Text Combined Shape:", text_combined.shape)
    print("Video Combined Shape:", video_combined.shape)
    print("Audio Combined Shape:", audio_combined.shape, '\n')

    # Mean-pool each modality over its sequence axis (dim 0) so every row
    # contributes to the fused vector. Indexing with [-1] would keep only the
    # last row (e.g. 1 of 95 video rows, 1 of 197 audio rows).
    video_pooled = video_combined.mean(dim=0)
    audio_pooled = audio_combined.mean(dim=0)
    text_pooled = text_combined.mean(dim=0)

    # Fusion using Embracement Layer
    d_in = video_pooled.shape[-1] + audio_pooled.shape[-1] + text_pooled.shape[-1]
    embracement_layer = EmbracementLayer(d_in, d_in)
    fused_features = embracement_layer(video_pooled, audio_pooled, text_pooled)

    print("Fused Features Shape:", fused_features.shape)

Text file name: /Users/kyleandrecastro/Documents/GitHub/SMCA/misc/textStream_BERT/feature_vectors/test/text_tt0021814.npy
Audio file name: /Users/kyleandrecastro/Documents/GitHub/SMCA/misc/audio_fe/logmel_spectrograms/test/audio_tt0021814.npy
Video file name: /Users/kyleandrecastro/Documents/GitHub/features/visual/tt0021814_features.npy
Text features shape: torch.Size([1024])
Audio features shape: torch.Size([1, 197, 768])
Video features shape: torch.Size([95, 768]) 

text_features shape: torch.Size([1, 1024])
audio_features shape: torch.Size([197, 768])
video_features shape: torch.Size([95, 768]) 

video_audio shape: torch.Size([95, 768])
video_text shape: torch.Size([95, 768])
audio_video shape: torch.Size([197, 768])
audio_text shape: torch.Size([197, 768])
text_video shape: torch.Size([1, 768])
text_audio shape: torch.Size([1, 768]) 

Text Combined Shape: torch.Size([1, 768])
Video Combined Shape: torch.Size([95, 768])
Audio Combined Shape: torch.Size([197, 768]) 

Fused Features S