In [None]:
# Load all the necessary modules

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from config import get_config
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pandas as pd

from AUDIO.model import AudioCNN
from MOCAP.model import Simple1DCNN
from VIDEO.model_video import VideoCNN

from AUDIO.dataset_audio import AudioDataset
from MOCAP.dataset_mc import MotionDataset
from VIDEO.dataset import MultiDataset


In [None]:
# Build one network per modality, each constructed with num_classes=5.
# The trained weights are not attached here; they are loaded in a later cell.
model1, model2, model3 = (
    AudioCNN(num_classes=5),     # audio network
    Simple1DCNN(num_classes=5),  # motion-capture network
    VideoCNN(num_classes=5),     # video network
)

In [None]:
# Load the saved state_dicts into the model instances.
# map_location='cpu' makes torch.load place every stored tensor on the CPU, so a
# checkpoint that was written from a GPU still loads on a machine without CUDA.
# NOTE(review): torch.load unpickles the file, so only point these paths at
# checkpoints from a trusted source.
# The shape of each model's fc5 weight is printed right after loading as a quick
# sanity check on the restored parameters.
model1.load_state_dict(torch.load('/path/to/audio_model.pth', map_location='cpu'))
print(model1.fc5.weight.shape)
model2.load_state_dict(torch.load('/path/to/mocap_model.pth', map_location='cpu'))
print(model2.fc5.weight.shape)
model3.load_state_dict(torch.load('/path/to/video_model.pth', map_location='cpu'))
print(model3.fc5.weight.shape)

In [None]:
# One empty accumulator per modality; the extraction loops append model outputs to these.
representations_mocap, representations_video, representations_audio = [], [], []

In [None]:
# Put all three networks into evaluation mode before running inference.
# NOTE(review): presumably these are nn.Module subclasses, where eval() turns off
# training-only behaviour (e.g. dropout, batch-norm statistic updates) — confirm
# against the model definitions.
# The last call is a bare expression, so the notebook displays its return value.
model1.eval()
model2.eval()
model3.eval()

In [None]:
# Shared settings; 'n_class' and 'batch_size' are the only keys read in this cell.
config = get_config()

# One dataset per modality, each built from its own CSV index and data root.
dataset_mocap = MotionDataset(
    csv_file='/path/to/combined_mocap.csv',
    root_dir='/path/to/mocap/data',
    nb_class=config['n_class'],
)
dataset_video = MultiDataset(
    csv_file='/path/to/combined_video.csv',
    root_dir='/path/to/video/data',
    nb_class=config['n_class'],
)
dataset_audio = AudioDataset(
    csv_file='/path/to/combined_audio.csv',
    root_dir='/path/to/audio/data',
    nb_class=config['n_class'],
)

# Shuffled, batched loaders over the same datasets.
# NOTE(review): the extraction loops further down iterate the datasets directly,
# not these loaders — confirm whether the loaders are still needed.
dataloader_mocap = DataLoader(
    dataset_mocap,
    batch_size=config['batch_size'],
    shuffle=True,
)
dataloader_video = DataLoader(
    dataset_video,
    batch_size=config['batch_size'],
    shuffle=True,
)
dataloader_audio = DataLoader(
    dataset_audio,
    batch_size=config['batch_size'],
    shuffle=True,
)

In [None]:
def _extract_representations(model, dataset):
    """Run ``model`` on every sample of ``dataset`` and collect the outputs.

    Args:
        model: Callable network; invoked once per sample.
        dataset: Iterable yielding ``(input, label)`` pairs. The label is ignored.

    Returns:
        list: One model output per sample, in dataset order.

    Each input gets two leading singleton dimensions before the forward pass
    (presumably batch and channel axes — confirm against the model definitions).
    """
    outputs = []
    # Inference only: without no_grad every stored output would keep its whole
    # autograd graph alive, growing memory with each sample.
    with torch.no_grad():
        for sample, _label in dataset:
            outputs.append(model(sample.unsqueeze(0).unsqueeze(0)))
    return outputs


# Each list is rebuilt from scratch, so re-running this cell never duplicates
# entries or appends to a name that a later cell has rebound to a tensor.
representations_mocap = _extract_representations(model2, dataset_mocap)
representations_video = _extract_representations(model3, dataset_video)
representations_audio = _extract_representations(model1, dataset_audio)

In [None]:
# Collapse each per-sample list into a single tensor by joining along dimension 1.
# NOTE(review): every name is rebound from a list to a tensor here, so running this
# cell again without re-running the extraction will fail (torch.cat expects a
# sequence of tensors). Also confirm dim=1 is the intended axis for stacking samples.
representations_mocap = torch.cat(
    tensors=representations_mocap,
    dim=1,
)
representations_video = torch.cat(
    tensors=representations_video,
    dim=1,
)
representations_audio = torch.cat(
    tensors=representations_audio,
    dim=1,
)

In [None]:
# Join the three modality tensors along dimension 1, in the order mocap, video, audio.
fused_representations = torch.cat(
    (representations_mocap, representations_video, representations_audio),
    dim=1,
)

In [None]:
# Take an independent float32 copy of the fused representation, cut off from autograd.
# fused_representations is already a tensor (it comes from torch.cat), and passing a
# tensor to torch.tensor() is the discouraged copy-construct form that emits a
# UserWarning; detach().clone() is the documented equivalent.
input_tensor = fused_representations.detach().clone().to(torch.float32)

# Add a leading dimension to act as the batch axis.
# NOTE(review): this assumes the fused representation describes a single sample — confirm.
input_tensor = input_tensor.unsqueeze(dim=0)
