In [1]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import pandas as pd

In [2]:
# Location of the multimodal CSV. The two pieces are joined as plain strings,
# so `dataset_dir` has to keep its trailing slash.
dataset_dir = "data/"
dataset_file = "multimodal_dataset.csv"

# Read the whole table into memory for a first look at its columns.
dataset_df = pd.read_csv(f"{dataset_dir}{dataset_file}")

In [3]:
# Preview the first five rows (feature columns, image path and class label).
dataset_df.head(n=5)

Unnamed: 0.1,Unnamed: 0,feature0,feature1,feature2,feature3,feature4,feature5,feature6,feature7,feature8,...,feature32,feature33,feature34,feature35,feature36,feature37,feature38,feature39,path,class
0,0,-617.90826,101.58977,9.847578,26.403584,25.784245,1.755381,7.380787,-12.520157,8.840893,...,-3.261531,3.75202,-1.915234,-1.827078,-2.776441,2.127649,0.283703,-0.362988,./data/mnist_images/0\1.png,0
1,1,-636.50385,104.66348,18.78516,32.966637,32.18875,5.560347,2.799065,-8.271547,8.137912,...,-0.526549,1.419234,-0.246682,-0.890081,-5.292832,4.069287,0.7723,2.229786,./data/mnist_images/0\10.png,0
2,2,-600.72955,100.82433,3.306875,20.441507,27.031813,2.805102,7.517792,-12.253216,4.049151,...,-3.493665,1.765706,-1.252896,-1.452544,-4.556331,3.329124,-0.401874,1.102327,./data/mnist_images/0\1000.png,0
3,3,-591.3263,110.81088,2.862722,20.75193,25.868662,-0.488132,-4.731595,-16.296522,4.120318,...,-1.963178,0.605708,-2.856485,0.433533,-1.704196,1.299239,-0.282075,0.257796,./data/mnist_images/0\10005.png,0
4,4,-619.8362,97.7517,19.81103,26.886065,20.38125,5.995906,-1.315152,-11.737551,5.809255,...,-3.206312,1.513105,0.643656,2.057761,-3.23613,0.692797,-2.332294,1.293126,./data/mnist_images/0\1001.png,0


In [29]:
from pathlib import Path
from PIL import Image

class CustomImageDataset(Dataset):
    """Dataset pairing 40 tabular features with a grayscale image and a label.

    Every CSV row must provide the columns ``feature0`` .. ``feature39``,
    ``path`` (location of the image file) and ``class`` (integer label).

    The table is treated as read-only after construction: the feature matrix
    is converted to float32 once in ``__init__`` and reused for every item.

    Args:
        csv_file: Path of the CSV file to read.
        transform: Optional callable applied to the loaded PIL image.

    Raises:
        KeyError: At construction time if a feature column is missing.
    """

    def __init__(self, csv_file, transform=None):
        # Optional: drop the unnecessary index column. Done without
        # ``inplace`` and silently skipped when the column is absent.
        self.data = pd.read_csv(csv_file).drop(columns=['Unnamed: 0'], errors='ignore')

        # Explicitly store only the feature columns (feature0 to feature39)
        self.feature_cols = [f'feature{i}' for i in range(40)]

        # Cast the feature block to float32 a single time instead of slicing
        # and casting an object-dtype row on every __getitem__ call.
        self._features = self.data[self.feature_cols].to_numpy(dtype='float32')

        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data)

    @staticmethod
    def _resolve_path(image_path):
        """Return a variant of ``image_path`` that exists on this platform.

        The CSV stores some paths with backslash separators. Those open fine
        on Windows but not on POSIX systems, so when the stored path does not
        exist the backslashes are swapped for forward slashes. If neither
        variant exists the original string is returned unchanged, so the
        resulting error message names the path from the CSV.
        """
        if Path(image_path).exists():
            return image_path
        portable = str(image_path).replace('\\', '/')
        return portable if Path(portable).exists() else image_path

    def _load_image(self, image_path):
        """Open ``image_path`` and return it as a grayscale PIL image."""
        # ``convert`` returns a new, fully loaded image, so the underlying
        # file handle can be closed as soon as the ``with`` block exits.
        with Image.open(self._resolve_path(image_path)) as img:
            return img.convert('L')  # 'L' mode is grayscale

    def __getitem__(self, idx):
        """Return one sample as a dict.

        Keys:
            features: float32 tensor holding the 40 feature values.
            image: grayscale image, passed through ``transform`` when set.
            label: Python ``int`` class label.
            path: image path exactly as stored in the CSV.
        """
        # torch.tensor copies, so callers never alias the cached matrix.
        features = torch.tensor(self._features[idx])

        image_path = self.data['path'].iat[idx]
        image = self._load_image(image_path)
        if self.transform:
            image = self.transform(image)

        label = int(self.data['class'].iat[idx])

        return {
            'features': features,
            'image': image,
            'label': label,
            'path': image_path
        }

In [48]:
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

# Dataset setup
transform = transforms.Compose([
    ToTensor()
])
dataset = CustomImageDataset(dataset_dir+dataset_file, transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Visualize the first example of one shuffled batch
batch = next(iter(loader))
print("Features:", batch['features'][0])
print("Class:", batch['label'][0])
print("Path:", batch['path'][0])

# ToTensor yields (C, H, W) with C == 1 for grayscale images. imshow does not
# accept a trailing singleton channel, i.e. the (H, W, 1) array produced by
# permute(1, 2, 0), so drop the channel axis and plot a plain (H, W) array.
img = batch['image'][0].squeeze(0).numpy()
plt.imshow(img, cmap='gray')
plt.title(f"Class: {batch['label'][0].item()}")
plt.axis('off')
plt.show()

Features: tensor([-6.0912e+02,  1.2001e+02,  5.6122e+00,  2.8516e+01,  1.6095e+01,
         1.5480e+01,  7.1957e+00, -4.6163e-01,  5.8542e+00, -7.9944e-01,
        -3.3629e+00,  1.7297e+00, -4.0322e+00, -9.1066e-01, -1.9390e+00,
         6.3424e+00, -1.8770e+00, -4.8751e+00, -4.7109e+00,  5.7536e-02,
        -9.8716e+00,  3.4955e+00, -5.1081e+00, -6.4829e+00, -4.2228e+00,
         9.3065e-01, -3.1307e+00, -3.4545e-01, -1.0484e+00,  2.3160e+00,
        -2.3122e+00,  2.3118e+00,  9.6380e-01,  2.8476e+00, -1.9561e+00,
         6.7078e-01,  1.6186e+00,  4.5677e+00, -5.3396e+00,  1.4408e+00])
Class: tensor(9)
Path: ./data/mnist_images/9\21872.png


In [49]:
import torch.nn as nn
import torch.nn.functional as F

class AudioNet(nn.Module):
    """Fully connected classifier for 40-dimensional feature vectors.

    Three ``Linear`` + ``ReLU`` stages of width 100 are followed by a final
    ``Linear`` layer that emits 10 raw class logits. All layers live in
    ``self.model`` (an ``nn.Sequential``).
    """

    def __init__(self):
        super().__init__()
        widths = [40, 100, 100, 100]
        stages = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Linear(n_in, n_out), nn.ReLU()]
        stages.append(nn.Linear(widths[-1], 10))  # 10 classes
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Map a (batch, 40) input to (batch, 10) logits."""
        return self.model(x)

In [55]:
class ImageCNN(nn.Module):
    """Small CNN mapping single-channel images to 10 class logits.

    Two 3x3 convolution + ReLU + 2x2 max-pool stages feed a 128-unit hidden
    layer. ``fc1`` expects 32 * 7 * 7 inputs, which is what a 28x28 image
    yields after the two pooling steps halve each spatial side twice.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=32 * 7 * 7, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return (batch, 10) logits for a (batch, 1, H, W) input."""
        # Channels go 1 -> 16 -> 32 while each pooling step halves H and W.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.flatten(start_dim=1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [56]:
def evaluate_accuracy(model, loader, device, input_key=None):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    Args:
        model: Network called with a single input tensor and returning
            per-class scores of shape (batch, num_classes).
        loader: Iterable of batches; each batch is a mapping that holds the
            input tensor under ``input_key`` and the targets under ``'label'``.
        device: Device the batch tensors are moved to before the forward pass.
        input_key: Batch key holding the model input. When ``None`` (default)
            the previous rule applies: ``'features'`` for an ``AudioNet``
            instance and ``'image'`` for any other model.

    Returns:
        Accuracy as a float in [0, 1]; ``0.0`` when ``loader`` yields no
        samples (previously this raised ``ZeroDivisionError``).

    Note:
        The model is switched to eval mode and left in that mode.
    """
    if input_key is None:
        # Resolved once; the choice does not change from batch to batch.
        input_key = 'features' if isinstance(model, AudioNet) else 'image'

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in loader:
            inputs = batch[input_key].to(device)
            labels = batch['label'].to(device)
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        return 0.0
    return correct / total


In [57]:
from tqdm import tqdm
def train_model(model, train_loader, val_loader, device, epochs=10, lr=0.001):
    """Train a single-input classifier with Adam and cross-entropy loss.

    After every epoch the validation accuracy is computed with
    ``evaluate_accuracy`` and printed. Nothing is returned; ``model`` is
    updated in place.

    Args:
        model: Network to train. An ``AudioNet`` is fed ``batch['features']``,
            any other model is fed ``batch['image']``.
        train_loader: Iterable of training batches (mappings that also hold
            the targets under ``'label'``).
        val_loader: Iterable of validation batches in the same format.
        device: Device the model and the batch tensors are moved to.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
    """
    model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    adam = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch_idx in range(epochs):
        model.train()
        for batch in tqdm(train_loader):
            # Pick the modality that matches the model type.
            if isinstance(model, AudioNet):
                batch_inputs = batch['features']
            else:
                batch_inputs = batch['image']
            batch_inputs = batch_inputs.to(device)
            targets = batch['label'].to(device)

            # Clear stale gradients, then forward / backward / update.
            adam.zero_grad()
            loss = loss_fn(model(batch_inputs), targets)
            loss.backward()
            adam.step()

        val_acc = evaluate_accuracy(model, val_loader, device)
        print(f"Epoch {epoch_idx + 1}: Validation Accuracy = {val_acc:.4f}")


In [58]:
from torch.utils.data import random_split, DataLoader

# Split dataset 80/20 into training and validation subsets. A seeded generator
# keeps the split identical across kernel restarts, so every model in this
# notebook is scored on the same held-out samples and the numbers reproduce.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_set, val_set = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# DataLoaders shared by all models: every batch carries the feature vectors,
# the images and the labels.
# NOTE: `test_data_loader` wraps the validation split; the name is kept
# because later cells refer to it.
train_data_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_data_loader = DataLoader(val_set, batch_size=64)



In [59]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Audio model: trained with train_model's default epoch count and learning rate.
audio_model = AudioNet()
train_model(
    model=audio_model,
    train_loader=train_data_loader,
    val_loader=test_data_loader,
    device=device,
)

# Epoch 10: Validation Accuracy = 0.9485


100%|██████████| 375/375 [01:02<00:00,  6.02it/s]


Epoch 1: Validation Accuracy = 0.8332


100%|██████████| 375/375 [00:48<00:00,  7.75it/s]


Epoch 2: Validation Accuracy = 0.8817


100%|██████████| 375/375 [00:44<00:00,  8.34it/s]


Epoch 3: Validation Accuracy = 0.8888


100%|██████████| 375/375 [00:43<00:00,  8.58it/s]


Epoch 4: Validation Accuracy = 0.8880


100%|██████████| 375/375 [00:43<00:00,  8.55it/s]


Epoch 5: Validation Accuracy = 0.9223


100%|██████████| 375/375 [00:45<00:00,  8.20it/s]


Epoch 6: Validation Accuracy = 0.9233


100%|██████████| 375/375 [00:44<00:00,  8.41it/s]


Epoch 7: Validation Accuracy = 0.9248


100%|██████████| 375/375 [00:47<00:00,  7.92it/s]


Epoch 8: Validation Accuracy = 0.9362


100%|██████████| 375/375 [00:47<00:00,  7.87it/s]


Epoch 9: Validation Accuracy = 0.9400


100%|██████████| 375/375 [00:47<00:00,  7.90it/s]


Epoch 10: Validation Accuracy = 0.9485


In [60]:
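# Image model (training is disabled here; un-comment the two lines below to re-run it)
#image_model = ImageCNN()
#train_model(image_model, train_data_loader, test_data_loader, device)
# Result of an earlier run -- Epoch 10: Validation Accuracy = 0.9830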

100%|██████████| 375/375 [01:59<00:00,  3.15it/s]


Epoch 1: Validation Accuracy = 0.9590


100%|██████████| 375/375 [01:50<00:00,  3.40it/s]


Epoch 2: Validation Accuracy = 0.9732


100%|██████████| 375/375 [02:01<00:00,  3.08it/s]


Epoch 3: Validation Accuracy = 0.9732


100%|██████████| 375/375 [01:49<00:00,  3.43it/s]


Epoch 4: Validation Accuracy = 0.9772


100%|██████████| 375/375 [01:51<00:00,  3.35it/s]


Epoch 5: Validation Accuracy = 0.9820


100%|██████████| 375/375 [01:38<00:00,  3.82it/s]


Epoch 6: Validation Accuracy = 0.9832


100%|██████████| 375/375 [01:51<00:00,  3.37it/s]


Epoch 7: Validation Accuracy = 0.9833


100%|██████████| 375/375 [01:49<00:00,  3.44it/s]


Epoch 8: Validation Accuracy = 0.9838


100%|██████████| 375/375 [01:42<00:00,  3.65it/s]


Epoch 9: Validation Accuracy = 0.9778


100%|██████████| 375/375 [01:51<00:00,  3.35it/s]


Epoch 10: Validation Accuracy = 0.9830


In [68]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import random


In [69]:
class AudioNetWrapper(nn.Module):
    """Feature extractor built from a model that exposes a ``.model`` Sequential.

    Every layer of ``audio_model.model`` except the final one is kept, so the
    forward pass returns the activations that would have fed the dropped layer.
    """

    def __init__(self, audio_model):
        super().__init__()
        layers = audio_model.model
        self.audio_model = layers[:len(layers) - 1]  # exclude last layer

    def forward(self, x):
        """Return the truncated network's output for ``x``."""
        features = self.audio_model(x)  # (batch, 100) for this notebook's AudioNet
        return features
    

class ImageCNNWrapper(nn.Module):
    """Feature extractor that runs a CNN up to and including its ``fc1`` layer.

    The wrapped model must expose ``conv1``, ``conv2``, ``pool`` and ``fc1``;
    its final classification layer is never called.
    """

    def __init__(self, image_model):
        super().__init__()
        self.image_model = image_model

    def forward(self, x):
        """Return the ReLU-activated ``fc1`` output for the image batch ``x``."""
        cnn = self.image_model
        for conv in (cnn.conv1, cnn.conv2):
            x = cnn.pool(F.relu(conv(x)))
        flat = x.flatten(start_dim=1)
        return F.relu(cnn.fc1(flat))  # (batch, 128) for this notebook's ImageCNN

In [70]:
class AttentionFusion(nn.Module):
    """Fuse two feature vectors with single-query scaled dot-product attention.

    The query is projected from ``a``; keys and values are projected from both
    ``a`` and ``b``. The output is a weighted average of the two value
    vectors, with weights that are non-negative and sum to one.

    Args:
        dim: Size of the last axis of both inputs and of the fused output.
    """

    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        # The scores have shape (batch, 1, 2); normalise over the LAST axis,
        # i.e. across the two candidates. Normalising over dim=1 would run
        # over the singleton query axis and force every weight to 1.0, which
        # degrades the fusion to value(a) + value(b).
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, a, b):
        """Return the attention-weighted mix of ``a`` and ``b``.

        Args:
            a: Tensor of shape (batch, dim); also supplies the query.
            b: Tensor of shape (batch, dim).

        Returns:
            Tensor of shape (batch, dim).
        """
        queries = self.query(a).unsqueeze(1)  # (batch, 1, dim)
        keys = torch.stack([self.key(a), self.key(b)], dim=1)    # (batch, 2, dim)
        values = torch.stack([self.value(a), self.value(b)], dim=1)  # (batch, 2, dim)

        # Scaled dot product of the single query against both keys.
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (keys.size(-1) ** 0.5)  # (batch, 1, 2)
        weights = self.softmax(scores)  # (batch, 1, 2), each row sums to 1

        fused = torch.bmm(weights, values).squeeze(1)  # (batch, dim)
        return fused


In [71]:
class FusionModel(nn.Module):
    """Two-branch classifier that fuses audio-feature and image embeddings.

    Each branch's embedding is projected to ``aligned_dim``, the two
    projections are combined by ``AttentionFusion`` and a linear layer maps
    the fused vector to ``num_classes`` logits.

    Args:
        audio_model: Model wrapped by ``AudioNetWrapper`` (100-d embedding).
        image_model: Model wrapped by ``ImageCNNWrapper`` (128-d embedding).
        aligned_dim: Common width both embeddings are projected to.
        num_classes: Number of output logits.
    """

    def __init__(self, audio_model, image_model, aligned_dim=128, num_classes=10):
        super().__init__()
        # Branch backbones that stop before their final classification layer.
        self.model_a = AudioNetWrapper(audio_model)   # outputs 100
        self.model_b = ImageCNNWrapper(image_model)   # outputs 128

        # Project both embeddings to a common width.
        self.align_a = nn.Linear(in_features=100, out_features=aligned_dim)
        self.align_b = nn.Linear(in_features=128, out_features=aligned_dim)

        # Attention-based fusion followed by the final classifier.
        self.attn = AttentionFusion(aligned_dim)
        self.classifier = nn.Linear(in_features=aligned_dim, out_features=num_classes)

    def forward(self, audio_input, image_input):
        """Return (batch, num_classes) logits for a pair of modality batches."""
        aligned_a = self.align_a(self.model_a(audio_input))  # (batch, aligned_dim)
        aligned_b = self.align_b(self.model_b(image_input))  # (batch, aligned_dim)
        return self.classifier(self.attn(aligned_a, aligned_b))


In [76]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader`` for a two-input model.

    Args:
        model: Network called as ``model(audio, image)``.
        dataloader: Iterable of batches with the keys ``"features"``,
            ``"image"`` and ``"label"``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the model's parameters.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(avg_loss, accuracy)`` where the loss is averaged per sample
        and the accuracy is the fraction of correct predictions.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    for batch in tqdm(dataloader):
        audio = batch["features"].to(device)
        image = batch["image"].to(device)
        labels = batch["label"].to(device)

        # Clear stale gradients, then forward / backward / update.
        optimizer.zero_grad()
        logits = model(audio, image)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so the epoch mean is per sample.
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        n_correct += (torch.max(logits, 1).indices == labels).sum().item()
        n_seen += batch_size

    return loss_sum / n_seen, n_correct / n_seen


def evaluate(model, dataloader, criterion, device):
    """Compute loss and accuracy of a two-input model without updating it.

    Args:
        model: Network called as ``model(audio, image)``; switched to eval mode.
        dataloader: Iterable of batches with the keys ``"features"``,
            ``"image"`` and ``"label"``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(avg_loss, accuracy)`` where the loss is averaged per sample
        and the accuracy is the fraction of correct predictions.
    """
    model.eval()
    running = {"loss": 0, "hits": 0, "count": 0}

    with torch.no_grad():
        for batch in tqdm(dataloader):
            labels = batch["label"].to(device)
            logits = model(batch["features"].to(device), batch["image"].to(device))

            n = labels.size(0)
            # Weight the batch loss by its size so the mean is per sample.
            running["loss"] += criterion(logits, labels).item() * n
            running["hits"] += (torch.max(logits, 1).indices == labels).sum().item()
            running["count"] += n

    return running["loss"] / running["count"], running["hits"] / running["count"]


In [77]:
# Device config
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Initialize models: two freshly constructed branches wrapped by the fusion model.
# NOTE(review): this rebinds `audio_model` to an untrained AudioNet, replacing
# the one trained earlier in the notebook -- confirm that is intended.
audio_model = AudioNet()
image_model = ImageCNN()
model = FusionModel(audio_model, image_model).to(device)

# Loss, optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Train loop: one training pass and one evaluation pass per epoch, reusing the
# loaders created for the single-modality models.
epochs = 10
for epoch in range(epochs):
    train_loss, train_acc = train_epoch(model, train_data_loader, criterion, optimizer, device)
    test_loss, test_acc = evaluate(model, test_data_loader, criterion, device)

    print(
        f"Epoch {epoch+1}/{epochs} | "
        f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
        f"Test Loss: {test_loss:.4f} Acc: {test_acc:.4f}"
    )


KeyboardInterrupt: 