In [1]:
import torch
import torch.nn as nn
import numpy as np
import os
import random
import torch.nn.functional as F

In [2]:
class CrossAttention(nn.Module):
    def __init__(self, d_in_1, d_in_2, d_out_kq, d_out_v):
        super(CrossAttention, self).__init__()
        self.d_out_kq = d_out_kq
        self.W_query = nn.Parameter(torch.rand(d_in_1, d_out_kq))
        self.W_key = nn.Parameter(torch.rand(d_in_2, d_out_kq))
        self.W_value = nn.Parameter(torch.rand(d_in_2, d_out_v))
    
    def forward(self, modality1, modality2):
        
        queries = modality1.matmul(self.W_query)  
        keys = modality2.matmul(self.W_key)      
        values = modality2.matmul(self.W_value) 

        attn_scores = queries.matmul(keys.transpose(-2, -1))
        attn_weights = torch.softmax(attn_scores / (self.d_out_kq ** 0.5), dim=-1)
        context_vector = attn_weights.matmul(values) 
        return context_vector

In [3]:
class EmbracementLayer(nn.Module):
    def __init__(self, docking_dim, output_dim, num_modalities):
        super(EmbracementLayer, self).__init__()
        self.num_modalities = num_modalities
        self.docking_dim = docking_dim
        self.output_dim = output_dim

        # Linear layers to transform concatenated features to the output dimension
        self.linear = nn.Linear(docking_dim, output_dim)

        # Weights for each modality
        self.modality_weights = nn.Parameter(torch.ones(num_modalities))

    def forward(self, modalities):
        # Ensure that modalities have the same shape
        assert all(modalities[0].shape == modality.shape for modality in modalities), \
            "All modality tensors must have the same shape."

        # Concatenate the modalities along the last dimension
        combined = torch.cat(modalities, dim=-1)

        # Print the shapes for debugging
        print("Combined shape:", combined.shape)
        print("Modality weights shape:", self.modality_weights.shape)

        # Apply modality weights
        modality_weights = F.softmax(self.modality_weights, dim=0)

        # Expand modality weights to match features dimension
        num_features_per_modality = combined.size(-1) // self.num_modalities
        modality_weights_expanded = modality_weights.unsqueeze(-1).expand(-1, -1, num_features_per_modality)

        # Split combined tensor into chunks
        combined_splits = combined.chunk(self.num_modalities, dim=-1)

        # Ensure modality_weights shape matches combined_splits
        modality_weights = modality_weights.view(1, 1, 1, -1).expand(combined.size(0), combined.size(1), combined.size(2), -1)
        
        # Compute weighted sum
        weighted_combined = sum(weight * split for weight, split in zip(modality_weights_expanded.unbind(), combined_splits))
        # Flatten dimensions before applying the linear layer
        combined_view = weighted_combined.view(weighted_combined.size(0), -1, weighted_combined.size(-1))
        fused = self.linear(combined_view)
        return fused


In [4]:
# Linear transformation to match dimensions
class LinearTransformations(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearTransformations, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
    
    def forward(self, x):
        return self.linear(x)

In [5]:
def load_npy_files(directory):
    file_list = [os.path.join(directory, file) for file in os.listdir(directory) if file.endswith('.npy')]
    feature_vectors = [(file, torch.tensor(np.load(file))) for file in file_list]
    return feature_vectors

In [6]:
# Cross Attention Function
def PairCrossAttention(modalityAlpha, modalityBeta, d_out_kq=768, d_out_v=768):
    cross_attn = CrossAttention(modalityAlpha.shape[-1], modalityBeta.shape[-1], d_out_kq, d_out_v)
    modalityAlphaBeta = cross_attn(modalityAlpha, modalityBeta)
    return modalityAlphaBeta

In [7]:
if __name__ == "__main__":
    torch.manual_seed(42)
    
    # Load .npy files
    video_features = load_npy_files(r'/Users/kyleandrecastro/Documents/GitHub/features/visual')
    audio_features = load_npy_files(r'/Users/kyleandrecastro/Documents/GitHub/features/audio')
    text_features = load_npy_files(r'/Users/kyleandrecastro/Documents/GitHub/features/text')

    # Select the first file from each modality directories (for testing)
    video_file_name, video_features = video_features[0]
    audio_file_name, audio_features = audio_features[0]
    text_file_name, text_features = text_features[0]

    # # Print the file names
    # print("\nSelected File Names:")
    # print("Video file:", video_file_name)
    # print("Audio file:", audio_file_name)
    # print("Text file:", text_file_name)
    
    # Reshape features
    video_features = video_features.unsqueeze(0)  # Add batch dimension
    audio_features = audio_features.unsqueeze(0)    # Add batch dimension
    text_features = text_features.unsqueeze(0)    # Add batch dimension

    # Apply linear transformation to match dimensions
    linear_transform_video = LinearTransformations(video_features.shape[-1], 768)
    linear_transform_audio = LinearTransformations(audio_features.shape[-1], 768)
    linear_transform_text = LinearTransformations(text_features.shape[-1], 768)

    video_features = linear_transform_video(video_features)
    audio_features = linear_transform_audio(audio_features)
    text_features = linear_transform_text(text_features)
    
    # Cross-Attention for every possible pairs
    video_audio = PairCrossAttention(video_features, audio_features)
    video_text = PairCrossAttention(video_features, text_features)
    audio_video = PairCrossAttention(audio_features, video_features)
    audio_text = PairCrossAttention(audio_features, text_features)
    text_video = PairCrossAttention(text_features, video_features)
    text_audio = PairCrossAttention(text_features, audio_features)

    # Combine the Cross-Attention outputs
    video_combined = torch.cat((video_audio, video_text), dim=-1)
    audio_combined = torch.cat((audio_video, audio_text), dim=-1)
    text_combined = torch.cat((text_video, text_audio), dim=-1)

    # Adjust text_combined to match dimensions
    text_combined = text_combined.expand(-1, -1, video_combined.size(2), -1)

    print("Video Combined Shape:", video_combined.shape)
    print("Audio Combined Shape:", audio_combined.shape)
    print("Text Combined Shape:", text_combined.shape)

    # Instantiate Embracement Layer
    embracement_layer = EmbracementLayer(docking_dim=video_combined.shape[-1] * 3, output_dim=768, num_modalities=3)

    # Fuse the combined outputs
    fused_representation = embracement_layer([video_combined, audio_combined, text_combined])

    print("Fused Representation Shape:", fused_representation.shape)

Video Combined Shape: torch.Size([1, 1, 197, 1536])
Audio Combined Shape: torch.Size([1, 1, 197, 1536])
Text Combined Shape: torch.Size([1, 1, 197, 1536])
Combined shape: torch.Size([1, 1, 197, 4608])
Modality weights shape: torch.Size([3])


RuntimeError: The expanded size of the tensor (-1) isn't allowed in a leading, non-existing dimension 0