### Prerequisite Packages

In [20]:
import sys
import os
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
from torcheval.metrics import BinaryPrecision, BinaryRecall, BinaryF1Score
from sklearn.model_selection import train_test_split, KFold

In [21]:
sys.path.append('../')

from modules.cross_attentionb import CrossAttentionB
from modules.dataloader import load_npy_files
from modules.classifier import DenseLayer, BCELoss, CustomLoss, BCEWithLogits
from modules.linear_transformation import LinearTransformations
from modules.output_max import output_max
from evaluation_validation.train_model import train_model
from evaluation_validation.evaluate_model import evaluate_model
from evaluation_validation.test_model import test_model

### Data Loading

In [22]:
class MultimodalDataset(Dataset):
    def __init__(self, id_label_df, text_features, audio_features, video_features):
        self.id_label_df = id_label_df
        
        # Convert feature lists to dictionaries for fast lookup
        self.text_features = {os.path.basename(file).split('.')[0]: tensor for file, tensor in text_features}
        self.audio_features = {os.path.basename(file).split('_')[1].split('.')[0]: tensor for file, tensor in audio_features}
        self.video_features = {os.path.basename(file).split('_')[0]: tensor for file, tensor in video_features}

        # List to store missing files
        self.missing_files = []

        # Filter out entries with missing files
        self.valid_files = self._filter_valid_files()


    def _filter_valid_files(self):
        valid_files = []
        for idx in range(len(self.id_label_df)):
            imdbid = self.id_label_df.iloc[idx]['IMDBid']

            # Check if the IMDBid exists in each modality's features
            if imdbid in self.text_features and imdbid in self.audio_features and imdbid in self.video_features:
                valid_files.append(idx)
            else:
                self.missing_files.append({'IMDBid': imdbid})

        # Print missing files after checking all
        if self.missing_files:
            print("Missing files:")
            for item in self.missing_files:
                print(f"IMDBid: {item['IMDBid']}")
            print(f"Total IMDB IDs with missing files: {len(self.missing_files)}")
        else:
            print("No missing files.")

        return valid_files

    def __len__(self):
        return len(self.valid_files)

    def __getitem__(self, idx):
        # Get the original index from the filtered valid files
        original_idx = self.valid_files[idx]
        imdbid = self.id_label_df.iloc[original_idx]['IMDBid']
        label = self.id_label_df.iloc[original_idx]['Label']

        # Retrieve data from the loaded features
        text_data = self.text_features.get(imdbid, torch.zeros((1024,)))
        audio_data = self.audio_features.get(imdbid, torch.zeros((1, 197, 768)))
        video_data = self.video_features.get(imdbid, torch.zeros((95, 768)))
        
        # Define label mapping
        label_map = {'red': 0, 'green': 1} 
        
        # Convert labels to tensor using label_map
        try:
            label_data = torch.tensor([label_map[label]], dtype=torch.float32)  # Ensure labels are integers
        except KeyError as e:
            print(f"Error: Label '{e}' not found in label_map.")
            raise

        return text_data, audio_data, video_data, label_data


In [23]:
def collate_fn(batch):
    text_data, audio_data, video_data, label_data = zip(*batch)

    # Convert lists to tensors
    text_data = torch.stack(text_data)
    audio_data = torch.stack(audio_data)

    # Padding for video data
    # Determine maximum length of video sequences in the batch
    video_lengths = [v.size(0) for v in video_data]
    max_length = max(video_lengths)

    # Pad video sequences to the maximum length
    video_data_padded = torch.stack([
        F.pad(v, (0, 0, 0, max_length - v.size(0)), "constant", 0)
        for v in video_data
    ])

    # Convert labels to tensor and ensure the shape [batch_size, 1]
    label_data = torch.stack(label_data)  # Convert list of tensors to a single tensor

    return text_data, audio_data, video_data_padded, label_data

In [24]:
# Load the labels DataFrame
id_label_df = pd.read_excel('../misc/MM-Trailer_dataset.xlsx')

# Define the directories
text_features_dir = '../misc/textStream_BERT/feature_vectors'
audio_features_dir = '../misc/audio_fe/logmel_spectrograms'
video_features_dir = '../misc/visualStream_ViT'

# Load the feature vectors from each directory
text_features = load_npy_files(text_features_dir)
audio_features = load_npy_files(audio_features_dir)
video_features = load_npy_files(video_features_dir)

print(f"Number of text feature vectors loaded: {len(text_features)}")
print(f"Number of audio feature vectors loaded: {len(audio_features)}")
print(f"Number of video feature vectors loaded: {len(video_features)}")

Number of text feature vectors loaded: 1353
Number of audio feature vectors loaded: 1353
Number of video feature vectors loaded: 1353


### Important Functions

In [25]:
class MultiheadCrossAttention(nn.Module):
    def __init__(self, embed_dim, num_heads=8):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        
    def forward(self, query, key_value):
        # Ensure inputs are 3D: (batch_size, sequence_length, embed_dim)
        if query.dim() == 2:
            query = query.unsqueeze(1)  # Add sequence length dimension
        if key_value.dim() == 2:
            key_value = key_value.unsqueeze(1)  # Add sequence length dimension
            
        output, _ = self.multihead_attn(query, key_value, key_value)
        return output

### SMCA Functions and Model

In [26]:
def SMCAStage1(modalityAlpha, modalityBeta, d_out_v, device):
    # Ensure inputs are 3D
    if modalityAlpha.dim() == 2:
        modalityAlpha = modalityAlpha.unsqueeze(1)
    if modalityBeta.dim() == 2:
        modalityBeta = modalityBeta.unsqueeze(1)

    # Initialize the cross attention module
    cross_attn = MultiheadCrossAttention(d_out_v).to(device)

    # Cross-attention: Alpha -> Beta
    alphaBeta = cross_attn(modalityAlpha, modalityBeta)  # Shape: (batch_size, seq_alpha, d_out_v)

    # Cross-attention: Beta -> Alpha
    betaAlpha = cross_attn(modalityBeta, modalityAlpha)  # Shape: (batch_size, seq_beta, d_out_v)

    # Get sequence lengths
    seq_len_alpha = alphaBeta.size(1)
    seq_len_beta = betaAlpha.size(1)
    max_seq_len = max(seq_len_alpha, seq_len_beta)

    # Pad sequences to match lengths
    if seq_len_alpha < max_seq_len:
        alphaBeta = F.pad(alphaBeta, (0, 0, 0, max_seq_len - seq_len_alpha), value=0)
    if seq_len_beta < max_seq_len:
        betaAlpha = F.pad(betaAlpha, (0, 0, 0, max_seq_len - seq_len_beta), value=0)

    # Concatenate the cross-attention outputs
    modalityAlphaBeta = torch.cat((alphaBeta, betaAlpha), dim=-1)
    return modalityAlphaBeta

In [27]:
import torch
import torch.nn as nn

class ProjectionLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(ProjectionLayer, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

def SMCAStage2(modalityAlphaBeta, modalityGamma, d_out_v, device):
    # Ensure modalityGamma is 3D
    if modalityGamma.dim() == 2:
        modalityGamma = modalityGamma.unsqueeze(1)

    # Project modalityAlphaBeta to match embedding dimension
    projection = nn.Linear(modalityAlphaBeta.shape[-1], d_out_v).to(device)
    modalityAlphaBetaProjected = projection(modalityAlphaBeta)

    # Initialize cross attention
    cross_attn = MultiheadCrossAttention(d_out_v).to(device)

    # Cross-attention: AlphaBeta -> Gamma
    alphaBetaGamma = cross_attn(modalityAlphaBetaProjected, modalityGamma)

    # Cross-attention: Gamma -> AlphaBeta
    gammaAlphaBeta = cross_attn(modalityGamma, modalityAlphaBetaProjected)

    # Get sequence lengths
    seq_len_alphaBeta = alphaBetaGamma.size(1)
    seq_len_gamma = gammaAlphaBeta.size(1)
    max_seq_len = max(seq_len_alphaBeta, seq_len_gamma)

    # Pad sequences to match lengths
    if seq_len_alphaBeta < max_seq_len:
        alphaBetaGamma = F.pad(alphaBetaGamma, (0, 0, 0, max_seq_len - seq_len_alphaBeta), value=0)
    if seq_len_gamma < max_seq_len:
        gammaAlphaBeta = F.pad(gammaAlphaBeta, (0, 0, 0, max_seq_len - seq_len_gamma), value=0)

    # Concatenate and apply global average pooling
    multimodal_representation = torch.cat((alphaBetaGamma, gammaAlphaBeta), dim=-1)
    GAP = torch.mean(multimodal_representation, dim=1)
    
    return GAP


In [28]:
class SMCAModel(nn.Module):
    def __init__(self, d_out_v, device):
        super(SMCAModel, self).__init__()
        self.d_out_v = d_out_v
        self.device = device
    
    def forward(self, modalityAlpha, modalityBeta, modalityGamma):
        # Stage 1: Cross attention between modalityAlpha and modalityBeta
        modalityAlphaBeta = SMCAStage1(modalityAlpha, modalityBeta, self.d_out_v, self.device)
        # Stage 2: Cross attention with modalityAlphaBeta and modalityGamma
        multimodal_representation = SMCAStage2(modalityAlphaBeta, modalityGamma, self.d_out_v, self.device)
        
        return multimodal_representation

Test on one instance

In [29]:
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SMCAModel(768, device)  # Only need d_out_v

# Select the first file from each modality directories (for testing)
video_file_name, video_feature = video_features[5]
audio_file_name, audio_feature = audio_features[5]
text_file_name, text_feature = text_features[5]

# Print the file names
print("\nSelected File Names:")
print("Audio file:", audio_file_name)
print("Video file:", video_file_name)
print("Text file:", text_file_name)

video_feature = video_feature.unsqueeze(0)  # Add batch dimension
text_feature = text_feature.unsqueeze(0)    # Add batch dimension

modalityAlpha=audio_feature 
modalityBeta=text_feature
modalityGamma=video_feature

# Apply linear transformation to match dimensions
linear_transform_Alpha = LinearTransformations(modalityAlpha.shape[-1], 768)
linear_transform_Beta = LinearTransformations(modalityBeta.shape[-1], 768)
linear_transform_Gamma = LinearTransformations(modalityGamma.shape[-1], 768)

modalityAlpha = linear_transform_Alpha(modalityAlpha)
modalityBeta = linear_transform_Beta(modalityBeta)
modalityGamma = linear_transform_Gamma(modalityGamma)

print("Audio: ",modalityAlpha.shape)
print("Text: ",modalityBeta.shape)
print("Video: ",modalityGamma.shape)

outputs = model(modalityAlpha, modalityBeta, modalityGamma)

print("Stage 2:", outputs.shape)





Selected File Names:
Audio file: ../misc/audio_fe/logmel_spectrograms/feature_tt3168230.npy
Video file: ../misc/visualStream_ViT/tt0367652_features.npy
Text file: ../misc/textStream_BERT/feature_vectors/tt0452598.npy
Audio:  torch.Size([1, 197, 768])
Text:  torch.Size([1, 768])
Video:  torch.Size([1, 83, 768])
Stage 2: torch.Size([1, 1536])


Test on Entire Dataset

In [30]:
# Load the labels DataFrame
id_label_df = pd.read_excel('../misc/MM-Trailer_dataset.xlsx')

# Define the directories
text_features_dir = '../misc/textStream_BERT/feature_vectors'
audio_features_dir = '../misc/audio_fe/logmel_spectrograms'
video_features_dir = '../misc/visualStream_ViT'

# Load the feature vectors from each directory
text_features = load_npy_files(text_features_dir)
audio_features = load_npy_files(audio_features_dir)
video_features = load_npy_files(video_features_dir)

print(f"Number of text feature vectors loaded: {len(text_features)}")
print(f"Number of audio feature vectors loaded: {len(audio_features)}")
print(f"Number of video feature vectors loaded: {len(video_features)}")

# Drop unnecessary columns
id_label_df = id_label_df.drop(columns=['Movie Title', 'URL'])

# Splitting data for training, validation, and testing
train_df, val_test_df = train_test_split(id_label_df, test_size=0.3, random_state=42)

# Further splitting remaining set into validation and test sets
val_df, test_df = train_test_split(val_test_df, test_size=0.5, random_state=42)

print("-" * 30)

# Create datasets
train_dataset = MultimodalDataset(train_df, text_features, audio_features, video_features)
val_dataset = MultimodalDataset(val_df, text_features, audio_features, video_features)
test_dataset = MultimodalDataset(test_df, text_features, audio_features, video_features)

# Create DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=True, num_workers=0, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, num_workers=0, collate_fn=collate_fn)

# Combine all data for K-fold cross-validation
full_dataset = MultimodalDataset(id_label_df, text_features, audio_features, video_features)

Number of text feature vectors loaded: 1353
Number of audio feature vectors loaded: 1353
Number of video feature vectors loaded: 1353
------------------------------
Missing files:
IMDBid: tt2494280
IMDBid: tt1724962
IMDBid: tt1152836
IMDBid: tt0389790
IMDBid: tt3053228
IMDBid: tt1045778
IMDBid: tt1758795
IMDBid: tt0099385
IMDBid: tt2917484
IMDBid: tt4769836
IMDBid: tt0089652
IMDBid: tt0465494
IMDBid: tt3675748
IMDBid: tt2126362
IMDBid: tt0988083
IMDBid: tt2101341
IMDBid: tt0401997
IMDBid: tt1661461
IMDBid: tt1313139
IMDBid: tt1094661
IMDBid: tt5162658
IMDBid: tt0104839
IMDBid: tt1288558
IMDBid: tt5962210
IMDBid: tt2937696
IMDBid: tt0284363
IMDBid: tt5580390
IMDBid: tt2293750
IMDBid: tt2980472
IMDBid: tt0082186
IMDBid: tt0924129
IMDBid: tt0988595
IMDBid: tt1349482
IMDBid: tt4158096
IMDBid: tt1403241
IMDBid: tt2713642
IMDBid: tt1682940
IMDBid: tt10327354
IMDBid: tt1087842
IMDBid: tt1800302
IMDBid: tt0113855
IMDBid: tt2504022
IMDBid: tt7248248
IMDBid: tt1720164
IMDBid: tt1336621
IMDBid: t

In [31]:
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the model
model = SMCAModel(768, device)

print(f"Model parameters: {list(model.parameters())}")


# Training loop
for text_features, audio_features, video_features, targets in train_dataloader:
    # Move features to the specified device
    text_features = text_features.to(device)
    audio_features = audio_features.to(device)
    video_features = video_features.to(device)

    # Squeeze the audio features to remove the extra dimension
    audio_features = audio_features.squeeze(1) 

    # Print dimensions for debugging
    print("Text Dimension: ", text_features.shape)  
    print("Audio Dimension: ", audio_features.shape)  
    print("Video Dimension: ", video_features.shape) 

    # Apply linear transformations to match dimensions
    linear_transform_Alpha = LinearTransformations(audio_features.shape[-1], 768) 
    linear_transform_Beta = LinearTransformations(text_features.shape[-1], 768)   
    linear_transform_Gamma = LinearTransformations(video_features.shape[-1], 768)    

    # Transform features to match the target dimension
    modalityAlpha = linear_transform_Alpha(audio_features)  
    modalityBeta = linear_transform_Beta(text_features)    
    modalityGamma = linear_transform_Gamma(video_features)

    # Print shapes after transformation to verify the batch dimension
    print("Transformed Audio Dimension: ", modalityAlpha.shape)  
    print("Transformed Text Dimension: ", modalityBeta.shape)    
    print("Transformed Video Dimension: ", modalityGamma.shape)  

    # Pass inputs through the SMCA model
    outputs = model(
        modalityAlpha=modalityAlpha,  # Ensure to pass transformed modalities
        modalityBeta=modalityBeta,
        modalityGamma=modalityGamma,
    )

    print("Stage 2:", outputs.shape)  # Check the output shape
    print("--------")


Model parameters: []
Text Dimension:  torch.Size([16, 1024])
Audio Dimension:  torch.Size([16, 197, 768])
Video Dimension:  torch.Size([16, 281, 768])
Transformed Audio Dimension:  torch.Size([16, 197, 768])
Transformed Text Dimension:  torch.Size([16, 768])
Transformed Video Dimension:  torch.Size([16, 281, 768])
Stage 2: torch.Size([16, 1536])
--------
Text Dimension:  torch.Size([16, 1024])
Audio Dimension:  torch.Size([16, 197, 768])
Video Dimension:  torch.Size([16, 164, 768])
Transformed Audio Dimension:  torch.Size([16, 197, 768])
Transformed Text Dimension:  torch.Size([16, 768])
Transformed Video Dimension:  torch.Size([16, 164, 768])
Stage 2: torch.Size([16, 1536])
--------
Text Dimension:  torch.Size([16, 1024])
Audio Dimension:  torch.Size([16, 197, 768])
Video Dimension:  torch.Size([16, 143, 768])
Transformed Audio Dimension:  torch.Size([16, 197, 768])
Transformed Text Dimension:  torch.Size([16, 768])
Transformed Video Dimension:  torch.Size([16, 143, 768])
Stage 2: tor

In [32]:
def train_model(model, dense_layer, dataloader, criterion, optimizer, device):
    model.train()
    dense_layer.train()  # Set the model to training mode
    total_loss = 0.0

    for text_features, audio_features, video_features, targets in dataloader:
        text_features, audio_features, video_features, targets = (
            text_features.to(device),
            audio_features.to(device),
            video_features.to(device),
            targets.to(device).view(-1)
        )
        
        optimizer.zero_grad()
        
        # Pass inputs through SMCA model
        # Squeeze the audio features to remove the extra dimension
        audio_features = audio_features.squeeze(1) 

        # Apply linear transformations to match dimensions
        linear_transform_Alpha = LinearTransformations(audio_features.shape[-1], 768) 
        linear_transform_Beta = LinearTransformations(text_features.shape[-1], 768)   
        linear_transform_Gamma = LinearTransformations(video_features.shape[-1], 768)    

        # Transform features to match the target dimension
        modalityAlpha = linear_transform_Alpha(audio_features)  
        modalityBeta = linear_transform_Beta(text_features)    
        modalityGamma = linear_transform_Gamma(video_features)
        
        outputs = model(
            modalityAlpha=modalityAlpha,  # Ensure to pass transformed modalities
            modalityBeta=modalityBeta,
            modalityGamma=modalityGamma,
        )

        # Pass the fused features through the dense layer
        predictions = dense_layer(outputs).view(-1)

        # Compute loss
        loss = criterion(predictions, targets)
        total_loss += loss.item()
        # Backward pass and optimization
        loss.backward()
        optimizer.step()

    return total_loss / len(dataloader)

In [33]:
def evaluate_model(model, dense_layer, dataloader, criterion, device):
    model.eval()
    dense_layer.eval()
    total_loss = 0.0

    # Initialize the metrics for binary classification
    precision_metric = BinaryPrecision().to(device)
    recall_metric = BinaryRecall().to(device)
    f1_metric = BinaryF1Score().to(device)

    precision_metric.reset()
    recall_metric.reset()
    f1_metric.reset()

    all_predictions = []
    all_targets = []
    
    with torch.no_grad():
         for text_features, audio_features, video_features, targets in dataloader:
            text_features, audio_features, video_features, targets = (
                text_features.to(device),
                audio_features.to(device),
                video_features.to(device),
                targets.to(device).view(-1)
            )
            
            # Pass inputs through SMCA model
            # Squeeze the audio features to remove the extra dimension
            audio_features = audio_features.squeeze(1) 

            # Apply linear transformations to match dimensions
            linear_transform_Alpha = LinearTransformations(audio_features.shape[-1], 768) 
            linear_transform_Beta = LinearTransformations(text_features.shape[-1], 768)   
            linear_transform_Gamma = LinearTransformations(video_features.shape[-1], 768)    

            # Transform features to match the target dimension
            modalityAlpha = linear_transform_Alpha(audio_features)  
            modalityBeta = linear_transform_Beta(text_features)    
            modalityGamma = linear_transform_Gamma(video_features)
            
            outputs = model(modalityAlpha=modalityAlpha, modalityBeta=modalityBeta, modalityGamma=modalityGamma)

            # Pass the fused features through the dense layer
            predictions = dense_layer(outputs).view(-1) 

            all_predictions.extend(predictions.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            
            # Compute loss
            loss = criterion(predictions, targets)
            total_loss += loss.item()

            # Apply threshold to get binary predictions
            preds = (predictions > 0.5).float()
            
            # Update the precision, recall, and F1 score metrics
            precision_metric.update(preds.long(), targets.long())
            recall_metric.update(preds.long(), targets.long())
            f1_metric.update(preds.long(), targets.long())

    # Compute precision, recall, and F1 score
    precision = precision_metric.compute().item()
    recall = recall_metric.compute().item()
    f1_score = f1_metric.compute().item()

    average_loss = total_loss / len(dataloader)
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1_score:.4f}")
    
    return average_loss, precision, recall, f1_score


In [34]:
def test_model(model, dense_layer, dataloader, criterion, device):
    model.eval()
    dense_layer.eval()  # Set the model to evaluation mode
    total_loss = 0

    # Initialize the metrics for binary classification
    precision_metric = BinaryPrecision().to(device)
    recall_metric = BinaryRecall().to(device)
    f1_metric = BinaryF1Score().to(device)

    with torch.no_grad():
        for text_features, audio_features, video_features, targets in dataloader:
            text_features, audio_features, video_features, targets = (
                text_features.to(device),
                audio_features.to(device),
                video_features.to(device),
                targets.to(device).view(-1)
            )
            
            # Pass inputs through SMCA model
            # Squeeze the audio features to remove the extra dimension
            audio_features = audio_features.squeeze(1) 

            # Apply linear transformations to match dimensions
            linear_transform_Alpha = LinearTransformations(audio_features.shape[-1], 768) 
            linear_transform_Beta = LinearTransformations(text_features.shape[-1], 768)   
            linear_transform_Gamma = LinearTransformations(video_features.shape[-1], 768)    

            # Transform features to match the target dimension
            modalityAlpha = linear_transform_Alpha(audio_features)  
            modalityBeta = linear_transform_Beta(text_features)    
            modalityGamma = linear_transform_Gamma(video_features)
            
            outputs = model(modalityAlpha=modalityAlpha, modalityBeta=modalityBeta, modalityGamma=modalityGamma)
            
            # Pass the fused features through the dense layer
            predictions = dense_layer(outputs).view(-1)  
            
            # Compute loss
            loss = criterion(predictions, targets)
            total_loss += loss.item()

            # Apply threshold to get binary predictions
            preds = (predictions > 0.5).float()
            
            # Update the precision, recall, and F1 score metrics
            precision_metric.update(preds.long(), targets.long())
            recall_metric.update(preds.long(), targets.long())
            f1_metric.update(preds.long(), targets.long())

    # Compute precision, recall, and F1 score
    precision = precision_metric.compute().item()
    recall = recall_metric.compute().item()
    f1_score = f1_metric.compute().item()

    average_loss = total_loss / len(dataloader)

    print(f"Test Loss: {average_loss:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test Recall: {recall:.4f}")
    print(f"Test F1 Score: {f1_score:.4f}")

    return average_loss, precision, recall, f1_score


In [35]:
def get_optimizer(parameters, lr=1e-3):
    # Create an optimizer, for example, Adam
    return optim.Adam(parameters, lr=lr)

In [36]:
if __name__ == "__main__":
    torch.manual_seed(42)

    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Determine the output dimensions
    output_dim = 768

    # Initialize the SMCA model A
    model = SMCAModel(768, device)
    model.to(device)  # Move the model to the correct device

    # Initialize the DenseLayer with the largest output size
    dense_layer = DenseLayer(input_size=output_dim*2).to(device)  # Initialize and move to the correct device

    # Define the loss function and optimizer
    criterion = BCEWithLogits()  # Use appropriate loss function
    
    for param in model.parameters():
        if param.grad is None:
            print("No gradient for:", param)
    optimizer = get_optimizer(list(model.parameters()) + list(dense_layer.parameters()))


    # Training loop
    num_epochs = 10  # Set the number of epochs you want to train for
   
    for epoch in range(num_epochs):
        print("-" * 30)
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Ensure you have a dataloader that yields inputs and targets
        train_loss = train_model(model=model, dense_layer=dense_layer, dataloader=train_dataloader, criterion=criterion, optimizer=optimizer, device=device)

        # Validate step
        val_loss, precision, recall, f1_score = evaluate_model(model=model, dense_layer=dense_layer, dataloader=val_dataloader, criterion=criterion, device=device)

        print(f"Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

    # Testing the model
    print("-" * 30)
    print("Testing the model on the test set...")
    test_loss, test_precision, test_recall, test_f1_score = test_model(model=model, dense_layer=dense_layer, dataloader=test_dataloader, criterion=criterion, device=device)


Device: cpu
------------------------------
Epoch 1/10
Precision: 0.0000
Recall: 0.0000
F1 Score: 0.0000
Training Loss: 0.6857, Validation Loss: 0.6789
------------------------------
Epoch 2/10
Precision: 0.0000
Recall: 0.0000
F1 Score: 0.0000
Training Loss: 0.6725, Validation Loss: 0.6674
------------------------------
Epoch 3/10
Precision: 0.0000
Recall: 0.0000
F1 Score: 0.0000
Training Loss: 0.6615, Validation Loss: 0.6564
------------------------------
Epoch 4/10
Precision: 0.0000
Recall: 0.0000
F1 Score: 0.0000
Training Loss: 0.6510, Validation Loss: 0.6503
------------------------------
Epoch 5/10
Precision: 0.0000
Recall: 0.0000
F1 Score: 0.0000
Training Loss: 0.6410, Validation Loss: 0.6397
------------------------------
Epoch 6/10
Precision: 0.0000
Recall: 0.0000
F1 Score: 0.0000
Training Loss: 0.6328, Validation Loss: 0.6350
------------------------------
Epoch 7/10
Precision: 0.0000
Recall: 0.0000
F1 Score: 0.0000
Training Loss: 0.6252, Validation Loss: 0.6282
---------------