In [1]:
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence


In [2]:
# Set float32 matmul precision for Tensor Cores
torch.set_float32_matmul_precision('medium')

In [3]:
class MultimodalDataset(Dataset):
    def __init__(self, label_df, text_dir, audio_dir, video_dir):
        self.label_df = label_df
        self.text_dir = text_dir
        self.audio_dir = audio_dir
        self.video_dir = video_dir

        # List to store missing files
        self.missing_files = []

        # Filter out entries with missing files
        self.valid_indices = self._filter_valid_files()

    def _filter_valid_files(self):
        valid_indices = []
        for idx in range(len(self.label_df)):
            imdbid = self.label_df.iloc[idx]['IMDBid']

            text_path = os.path.join(self.text_dir, f"{imdbid}.npy")
            audio_path = os.path.join(self.audio_dir, f"feature_{imdbid}.npy")
            video_path = os.path.join(self.video_dir, f"{imdbid}_features.npy")

            missing_files = []
            if not os.path.exists(text_path):
                missing_files.append(text_path)
            if not os.path.exists(audio_path):
                missing_files.append(audio_path)
            if not os.path.exists(video_path):
                missing_files.append(video_path)

            if missing_files:
                self.missing_files.append({
                    'IMDBid': imdbid,
                    'missing_files': missing_files,
                    'missing_count': len(missing_files)
                })
            else:
                valid_indices.append(idx)

        # Print missing files after checking all
        if self.missing_files:
            print("Missing files:")
            total_missing_files = 0
            for item in self.missing_files:
                print(f"IMDBid: {item['IMDBid']} (Missing {item['missing_count']} files)")
                total_missing_files += item['missing_count']
                for file in item['missing_files']:
                    print(f"  Missing file: {file}")
            print(f"Total missing files: {total_missing_files}")
            print(f"Total IMDB IDs with missing files: {len(self.missing_files)}")
        else:
            print("No missing files.")

        return valid_indices

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        # Get the original index from the filtered valid indices
        original_idx = self.valid_indices[idx]
        imdbid = self.label_df.iloc[original_idx]['IMDBid']
        label = self.label_df.iloc[original_idx]['Label']

        text_path = os.path.join(self.text_dir, f"{imdbid}.npy")
        audio_path = os.path.join(self.audio_dir, f"feature_{imdbid}.npy")
        video_path = os.path.join(self.video_dir, f"{imdbid}_features.npy")

        text_data = np.load(text_path) if os.path.exists(text_path) else np.zeros((1024,))
        audio_data = np.load(audio_path) if os.path.exists(audio_path) else np.zeros((1, 197, 768))
        video_data = np.load(video_path) if os.path.exists(video_path) else np.zeros((95, 768))

        # Convert to torch tensors
        text_data = torch.tensor(text_data, dtype=torch.float32)
        audio_data = torch.tensor(audio_data, dtype=torch.float32)
        video_data = torch.tensor(video_data, dtype=torch.float32)

        return text_data, audio_data, video_data, label

In [4]:
# Define label mapping
label_map = {'red': 0, 'green': 1}  # Add other labels as needed

def collate_fn(batch):
    text_data, audio_data, video_data, labels = zip(*batch)

    # Convert lists to tensors
    text_data = torch.stack(text_data)
    audio_data = torch.stack(audio_data)

    # Padding for video data
    video_lengths = [v.size(0) for v in video_data]
    max_length = max(video_lengths)
    video_data_padded = torch.stack([torch.cat([v, torch.zeros(max_length - v.size(0), v.size(1))]) for v in video_data])

    # Convert labels to tensor using label_map
    try:
        labels = torch.tensor([label_map[label] for label in labels], dtype=torch.long)  # Ensure labels are integers
    except KeyError as e:
        print(f"Error: Label '{e}' not found in label_map.")
        raise

    return text_data, audio_data, video_data_padded, labels


In [5]:
# Load the labels DataFrame
label_df = pd.read_excel('C:\\Users\\edjin\\OneDrive\\Documents\\Programming Files\\Thesis\\SMCA\\misc\\MM-Trailer_dataset.xlsx')

# Splitting data into training and remaining sets
train_df, remaining_df = train_test_split(label_df, test_size=0.3, random_state=42)

# Further splitting remaining set into validation and test sets
val_df, test_df = train_test_split(remaining_df, test_size=0.5, random_state=42)

# Directory paths
text_dir = 'C:\\Users\\edjin\\OneDrive\\Documents\\Programming Files\\Thesis\\SMCA\\misc\\textStream_BERT\\feature_vectors\\feature_vectors'
audio_dir = 'C:\\Users\\edjin\\OneDrive\\Documents\\Programming Files\\Thesis\\SMCA\\misc\\audio_fe\\logmel_spectrograms'
video_dir = 'C:\\Users\\edjin\\OneDrive\\Documents\\Programming Files\\Thesis\\SMCA\\misc\\visualStream_ViT\\feature_vectors'

# Create datasets
train_dataset = MultimodalDataset(train_df, text_dir, audio_dir, video_dir)
val_dataset = MultimodalDataset(val_df, text_dir, audio_dir, video_dir)
test_dataset = MultimodalDataset(test_df, text_dir, audio_dir, video_dir)

# Create DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn)


Missing files:
IMDBid: tt2494280 (Missing 3 files)
  Missing file: C:\Users\edjin\OneDrive\Documents\Programming Files\Thesis\SMCA\misc\textStream_BERT\feature_vectors\feature_vectors\tt2494280.npy
  Missing file: C:\Users\edjin\OneDrive\Documents\Programming Files\Thesis\SMCA\misc\audio_fe\logmel_spectrograms\feature_tt2494280.npy
  Missing file: C:\Users\edjin\OneDrive\Documents\Programming Files\Thesis\SMCA\misc\visualStream_ViT\feature_vectors\tt2494280_features.npy
IMDBid: tt1724962 (Missing 3 files)
  Missing file: C:\Users\edjin\OneDrive\Documents\Programming Files\Thesis\SMCA\misc\textStream_BERT\feature_vectors\feature_vectors\tt1724962.npy
  Missing file: C:\Users\edjin\OneDrive\Documents\Programming Files\Thesis\SMCA\misc\audio_fe\logmel_spectrograms\feature_tt1724962.npy
  Missing file: C:\Users\edjin\OneDrive\Documents\Programming Files\Thesis\SMCA\misc\visualStream_ViT\feature_vectors\tt1724962_features.npy
IMDBid: tt1152836 (Missing 3 files)
  Missing file: C:\Users\edji

In [6]:
# Function to print a sample from the dataset
def print_sample(dataset, index):
    text_data, audio_data, video_data, label = dataset[index]
    print(f"Sample {index}:")
    print("Text Data:", text_data)
    print("Audio Data:", audio_data)
    print("Video Data:", video_data)
    print("-" * 30)
    print("Text Data Shape:", text_data.shape)
    print("Audio Data Shape:", audio_data.shape)
    print("Video Data Shape:", video_data.shape)
    print("Label:", label)
    print("-" * 30)

# Print a sample from each dataset
print("Training Dataset Sample:")
print_sample(train_dataset, 5)  # Change 5 to any index to view different samples

print("Validation Dataset Sample:")
print_sample(val_dataset, 5)  # Change 5 to any index to view different samples

print("Test Dataset Sample:")
print_sample(test_dataset, 5)  # Change 5 to any index to view different samples

Training Dataset Sample:
Sample 5:
Text Data: tensor([ 0.5833, -0.4434, -0.8684,  ..., -0.5946, -0.7582,  0.4914])
Audio Data: tensor([[[-0.3771, -0.1759, -0.0416,  ...,  0.1546,  0.1987, -0.0360],
         [-0.3328, -0.3164, -0.1281,  ...,  0.1952,  0.1640, -0.1560],
         [-0.3665, -0.0383, -0.0546,  ...,  0.1496,  0.1882, -0.1261],
         ...,
         [-0.3427, -0.1351, -0.1521,  ...,  0.1828,  0.0589, -0.0564],
         [-0.3111, -0.0340, -0.1754,  ...,  0.1957,  0.1211, -0.1917],
         [-0.3148, -0.1266, -0.1421,  ...,  0.0888,  0.1190, -0.0885]]])
Video Data: tensor([[-0.0010,  0.0832, -0.0718,  ..., -0.3403, -0.1714,  0.2119],
        [-0.0048,  0.1294,  0.0153,  ..., -0.3646, -0.1083,  0.1913],
        [-0.0551, -0.1336, -0.1215,  ..., -0.0564, -0.0422,  0.1340],
        ...,
        [ 0.1419,  0.0912, -0.1560,  ...,  0.0119, -0.0102,  0.0301],
        [-0.0872, -0.2430,  0.1118,  ..., -0.0621,  0.2087, -0.1113],
        [-0.1739,  0.0709, -0.0382,  ..., -0.2455,  0.05

In [7]:
def print_dataloader_samples(dataloader, num_batches=1):
    """
    Print a few batches from the DataLoader to inspect the data.

    Args:
        dataloader (DataLoader): The DataLoader instance.
        num_batches (int): The number of batches to print.
    """
    for i, batch in enumerate(dataloader):
        if i >= num_batches:
            break
        
        text_data, audio_data, video_data, labels = batch

        # Convert labels to a list of integers if they are tensors
        if isinstance(labels, torch.Tensor):
            labels = labels.tolist()

        print(f"Batch {i}:")
        print("Text Data Shape:", text_data.shape)
        print("Audio Data Shape:", audio_data.shape)
        print("Video Data Shape:", video_data.shape)
        print("Labels:", labels)
        print("-" * 30)

# Print a few batches from the training DataLoader
print("Training DataLoader Samples:")
print_dataloader_samples(train_dataloader, num_batches=5)

# Print a few batches from the validation DataLoader
print("Validation DataLoader Samples:")
print_dataloader_samples(val_dataloader, num_batches=5)

# Print a few batches from the validation DataLoader
print("Validation DataLoader Samples:")
print_dataloader_samples(test_dataloader, num_batches=5)

Training DataLoader Samples:
Batch 0:
Text Data Shape: torch.Size([8, 1024])
Audio Data Shape: torch.Size([8, 1, 197, 768])
Video Data Shape: torch.Size([8, 167, 768])
Labels: [0, 1, 1, 1, 1, 0, 0, 1]
------------------------------
Batch 1:
Text Data Shape: torch.Size([8, 1024])
Audio Data Shape: torch.Size([8, 1, 197, 768])
Video Data Shape: torch.Size([8, 141, 768])
Labels: [1, 1, 0, 0, 1, 1, 1, 1]
------------------------------
Batch 2:
Text Data Shape: torch.Size([8, 1024])
Audio Data Shape: torch.Size([8, 1, 197, 768])
Video Data Shape: torch.Size([8, 142, 768])
Labels: [1, 1, 1, 1, 0, 1, 1, 1]
------------------------------
Batch 3:
Text Data Shape: torch.Size([8, 1024])
Audio Data Shape: torch.Size([8, 1, 197, 768])
Video Data Shape: torch.Size([8, 141, 768])
Labels: [1, 1, 1, 1, 1, 0, 1, 1]
------------------------------
Batch 4:
Text Data Shape: torch.Size([8, 1024])
Audio Data Shape: torch.Size([8, 1, 197, 768])
Video Data Shape: torch.Size([8, 148, 768])
Labels: [1, 0, 1, 1,

In [8]:
class GatedMultimodalUnit(torch.nn.Module):
    def __init__(self, text_dim, audio_dim, video_dim, output_dim):
        super(GatedMultimodalUnit, self).__init__()
        
        # Linear transformation for text features
        self.text_linear = nn.Linear(text_dim, output_dim)
        
        # Convolutional layers for audio and video features
        self.audio_conv = nn.Conv1d(audio_dim, output_dim, kernel_size=1)
        self.video_conv = nn.Conv1d(video_dim, output_dim, kernel_size=1)
        
        # Activation functions
        self.activation = nn.Tanh()
        self.gate_activation = nn.Sigmoid() # Logistic()
        
        # Weight matrices for each modality
        self.W1 = nn.Parameter(torch.Tensor(output_dim, output_dim))
        self.W2 = nn.Parameter(torch.Tensor(output_dim, output_dim))
        self.W3 = nn.Parameter(torch.Tensor(output_dim, output_dim))
        
        
        # Gating matrices
        self.Y1 = nn.Parameter(torch.Tensor(output_dim, output_dim * 3))
        self.Y2 = nn.Parameter(torch.Tensor(output_dim, output_dim * 3))
        self.Y3 = nn.Parameter(torch.Tensor(output_dim, output_dim * 3))
        
        # Initialize weights
        self._initialize_weights()
        
    def _initialize_weights(self):
        
        # Initialize weight matrices
        init.xavier_uniform_(self.W1)
        init.xavier_uniform_(self.W2)
        init.xavier_uniform_(self.W3)
        
        # Initialize gating matrices
        init.xavier_uniform_(self.Y1)
        init.xavier_uniform_(self.Y2)
        init.xavier_uniform_(self.Y3)
        
        
    def forward(self, text_features, audio_features, video_features):
        
       # Process text features
        x_t = self.text_linear(text_features)  # Shape: [batch_size, output_dim]
        h1 = self.activation(torch.matmul(x_t, self.W1))  # Shape: [batch_size, output_dim]

        print('-'*50)
        print("x_t Shape:", x_t.shape)
        print("h1 Shape:", h1.shape)
        print("W1 Shape:", self.W1.shape)


        # Process audio features
        audio_features = audio_features.squeeze(1).permute(0, 2, 1)  # Shape: [batch_size, audio_dim, sequence_length]
        x_a = self.audio_conv(audio_features).mean(dim=-1)  # Shape: [batch_size, output_dim]
        h2 = self.activation(torch.matmul(x_a, self.W2))  # Shape: [batch_size, output_dim]
        
        print('-'*50)
        print("x_a Shape:", x_a.shape)
        print("h2 Shape:", h2.shape)
        print("W2 Shape:", self.W2.shape)
        
        # Process video features
        video_features = video_features.permute(0, 2, 1)  # Shape: [batch_size, video_dim, sequence_length]
        x_v = self.video_conv(video_features).mean(dim=-1)  # Shape: [batch_size, output_dim]
        h3 = self.activation(torch.matmul(x_v, self.W3))  # Shape: [batch_size, output_dim]
        
        print('-'*50)
        print("x_v Shape:", x_v.shape)
        print("h3 Shape:", h3.shape)
        print("W13 Shape:", self.W3.shape)
        
        # Combine processed features
        x = torch.cat((x_t, x_a, x_v), dim=1)
               
        # Compute gating weights for each modality
        z1 = self.gate_activation(torch.matmul(x, self.Y1.t()))
        z2 = self.gate_activation(torch.matmul(x, self.Y2.t()))
        z3 = self.gate_activation(torch.matmul(x, self.Y3.t()))
        
        print('-'*50)
        print("z1 Shape:", z1.shape)
        print("z2 Shape:", z2.shape)
        print("z3 Shape:", z3.shape)
        
        # Calculate final output
        h = z1 * h1 + z2 * h2 + z3 * h3
        return h

In [9]:
# Define dimensions
text_dim = 1024
audio_dim = 768  # Number of channels in audio data
video_dim = 768  # Number of channels in video data
output_dim = 512  # You can set this to any value, depending on your requirements

# Instantiate the GMU model
gmu = GatedMultimodalUnit(text_dim, audio_dim, video_dim, output_dim)

# Use DataLoader to get a batch of data
for batch in train_dataloader:  # You can use any DataLoader (train_dataloader, val_dataloader, etc.)
    text_data, audio_data, video_data, _ = batch
    
    # Feed the entire batch to the GMU model
    with torch.no_grad():
        output = gmu(text_data, audio_data, video_data)
    
    # Print the output shape
    print('-'*50)
    print("GMU Output Shape:", output.shape)
    print("GMU Output: ", output)
    
    # Break after the first batch for testing purposes
    break


--------------------------------------------------
x_t Shape: torch.Size([8, 512])
h1 Shape: torch.Size([8, 512])
W1 Shape: torch.Size([512, 512])
--------------------------------------------------
x_a Shape: torch.Size([8, 512])
h2 Shape: torch.Size([8, 512])
W2 Shape: torch.Size([512, 512])
--------------------------------------------------
x_v Shape: torch.Size([8, 512])
h3 Shape: torch.Size([8, 512])
W13 Shape: torch.Size([512, 512])
--------------------------------------------------
z1 Shape: torch.Size([8, 512])
z2 Shape: torch.Size([8, 512])
z3 Shape: torch.Size([8, 512])
--------------------------------------------------
GMU Output Shape: torch.Size([8, 512])
GMU Output:  tensor([[-0.0544,  0.0315,  0.1469,  ...,  0.0982,  0.0119, -0.0439],
        [-0.3414, -0.0562,  0.1583,  ...,  0.1756, -0.1139, -0.1555],
        [-0.1362,  0.0417,  0.0033,  ...,  0.1210, -0.0395, -0.1473],
        ...,
        [-0.1897, -0.0133,  0.0968,  ...,  0.1642, -0.0928,  0.0761],
        [-0.1032, 