In [1]:
import torch
import torch.nn as nn
import numpy as np
import os
import torch.nn.functional as F

In [2]:
import sys
sys.path.append('../')

from modules.cross_attention import CrossAttention
from modules.dataloader import load_npy_files

In [3]:
class EmbracementLayer(nn.Module):
    """Fuse three modality tensors into a single representation.

    The three inputs are joined along their last axis, projected by a linear
    layer, layer-normalised and finally passed through ReLU.

    Args:
        d_in: Last-axis size of the concatenated input, i.e. the sum of the
            last-axis sizes of the three tensors given to ``forward``.
        d_out: Last-axis size of the fused output.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Construction order is kept (fc -> norm -> activation) so seeded
        # initialisation and state_dict ordering stay the same.
        self.fc = nn.Linear(in_features=d_in, out_features=d_out)
        self.norm = nn.LayerNorm(d_out)
        self.activation = nn.ReLU()

    def forward(self, video_features, audio_features, text_features):
        """Return ``ReLU(LayerNorm(Linear(cat(video, audio, text))))``.

        All three tensors must agree on every axis except the last, because
        they are concatenated along ``dim=-1``.
        """
        stacked = torch.cat((video_features, audio_features, text_features), dim=-1)
        return self.activation(self.norm(self.fc(stacked)))

In [4]:
# Linear transformation to match dimensions
class LinearTransformations(nn.Module):
    """Single learnable affine projection applied to the last axis.

    Args:
        input_dim: Size of the last axis of the incoming tensor.
        output_dim: Size of the last axis of the projected tensor.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Project ``x`` from ``input_dim`` to ``output_dim`` on its last axis."""
        projected = self.linear(x)
        return projected

In [5]:
def HadamardProduct(tensor1, tensor2):
    """Return the element-wise product of two identically shaped tensors.

    Broadcasting is deliberately not allowed: the shapes must match exactly.

    Raises:
        ValueError: If ``tensor1.shape`` and ``tensor2.shape`` differ.
    """
    if tensor1.shape == tensor2.shape:
        # Shapes agree, so ``*`` is a pure element-wise multiply.
        return tensor1 * tensor2

    raise ValueError("Tensors must have the same shape for Hadamard product.")

In [6]:
# Cross Attention Function
def PairCrossAttention(modalityAlpha, modalityBeta, d_out_kq=768, d_out_v=768, cross_attn=None):
    """Apply cross-attention to a pair of modalities and return the result.

    Args:
        modalityAlpha: First input handed to the attention module. When a
            module is built here, its last-axis size is the first
            ``CrossAttention`` constructor argument.
        modalityBeta: Second input handed to the attention module. When a
            module is built here, its last-axis size is the second
            ``CrossAttention`` constructor argument.
        d_out_kq: Third ``CrossAttention`` constructor argument (presumably
            the key/query projection width -- confirm against
            ``modules.cross_attention``). Unused when ``cross_attn`` is given.
        d_out_v: Fourth ``CrossAttention`` constructor argument (presumably
            the value projection width -- confirm against
            ``modules.cross_attention``). Unused when ``cross_attn`` is given.
        cross_attn: Optional pre-built callable invoked as
            ``cross_attn(modalityAlpha, modalityBeta)``. Pass an existing
            module to reuse the same weights across calls. When ``None``
            (the default), a new ``CrossAttention`` is constructed on every
            call, so each call starts from freshly initialised weights that
            are discarded afterwards.

    Returns:
        Whatever the attention module returns for the two inputs.
    """
    if cross_attn is None:
        # Default path: build a throwaway module sized from the inputs.
        cross_attn = CrossAttention(
            modalityAlpha.shape[-1], modalityBeta.shape[-1], d_out_kq, d_out_v
        )
    return cross_attn(modalityAlpha, modalityBeta)

In [7]:
if __name__ == "__main__":
    # Smoke test of the fusion pipeline on the first file of each modality:
    # project -> pairwise cross-attention -> Hadamard combine -> embracement.

    # Fix the RNG so the randomly initialised layers below are reproducible.
    torch.manual_seed(42)

    # Common width every modality is projected to before cross-attention.
    d_model = 768

    # Root folder holding the 'visual', 'audio' and 'text' feature folders.
    # Set the MULTIMODAL_FEATURES_DIR environment variable to point elsewhere;
    # the default is the location this notebook was originally run against.
    features_root = os.environ.get(
        "MULTIMODAL_FEATURES_DIR",
        r'/Users/kyleandrecastro/Documents/GitHub/features',
    )

    # Load .npy files; each result is indexed below as (file_name, features) pairs.
    video_files = load_npy_files(os.path.join(features_root, 'visual'))
    audio_files = load_npy_files(os.path.join(features_root, 'audio'))
    text_files = load_npy_files(os.path.join(features_root, 'text'))

    # Select the first file from each modality directory (for testing).
    try:
        video_file_name, video_features = video_files[0]
        audio_file_name, audio_features = audio_files[0]
        text_file_name, text_features = text_files[0]
    except IndexError as exc:
        # An empty modality folder would otherwise surface as a bare IndexError.
        raise FileNotFoundError(
            f"No feature files found in at least one modality folder under {features_root!r}"
        ) from exc

    # Add a leading batch dimension to every modality.
    video_features = video_features.unsqueeze(0)
    audio_features = audio_features.unsqueeze(0)
    text_features = text_features.unsqueeze(0)

    # Apply linear transformation to match dimensions.
    # Layers are created in video/audio/text order so seeded weights stay stable.
    linear_transform_video = LinearTransformations(video_features.shape[-1], d_model)
    linear_transform_audio = LinearTransformations(audio_features.shape[-1], d_model)
    linear_transform_text = LinearTransformations(text_features.shape[-1], d_model)

    video_features = linear_transform_video(video_features)
    audio_features = linear_transform_audio(audio_features)
    text_features = linear_transform_text(text_features)

    # Cross-attention for every ordered pair of modalities.
    video_audio = PairCrossAttention(video_features, audio_features, d_out_kq=d_model, d_out_v=d_model)
    video_text = PairCrossAttention(video_features, text_features, d_out_kq=d_model, d_out_v=d_model)
    audio_video = PairCrossAttention(audio_features, video_features, d_out_kq=d_model, d_out_v=d_model)
    audio_text = PairCrossAttention(audio_features, text_features, d_out_kq=d_model, d_out_v=d_model)
    text_video = PairCrossAttention(text_features, video_features, d_out_kq=d_model, d_out_v=d_model)
    text_audio = PairCrossAttention(text_features, audio_features, d_out_kq=d_model, d_out_v=d_model)

    # Combine the two cross-attention outputs of each modality element-wise.
    video_combined = HadamardProduct(video_audio, video_text)
    audio_combined = HadamardProduct(audio_video, audio_text)
    text_combined = HadamardProduct(text_video, text_audio)

    # Broadcast the text branch to the video branch's shape so all three can be
    # concatenated on the last axis; the target shape is taken from the data
    # rather than a fixed (1, 1, 197, 768).
    text_combined = text_combined.expand_as(video_combined)

    print("Video Combined Shape:", video_combined.shape)
    print("Audio Combined Shape:", audio_combined.shape)
    print("Text Combined Shape:", text_combined.shape)

    # Initialize and apply the Embracement Layer on the concatenated width.
    d_in = video_combined.shape[-1] + audio_combined.shape[-1] + text_combined.shape[-1]
    embracement_layer = EmbracementLayer(d_in, d_model)

    fused_representation = embracement_layer(video_combined, audio_combined, text_combined)

    print("Fused Representation Shape:", fused_representation.shape)

Video Combined Shape: torch.Size([1, 1, 197, 768])
Audio Combined Shape: torch.Size([1, 1, 197, 768])
Text Combined Shape: torch.Size([1, 1, 197, 768])
Fused Representation Shape: torch.Size([1, 1, 197, 768])
