In [1]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from subprocess import run
from torch.utils.data import Dataset
from torchvision import datasets, transforms

class MultimodalDataset(Dataset):
    """Base class for datasets whose samples are dicts of per-modality tensors.

    Subclasses implement ``_load_data`` (and optionally ``_download`` and
    ``_show_dataset_label_distribution``). ``_load_data`` must fill
    ``self.dataset`` (modality name -> tensor indexed by sample along dim 0),
    ``self.dataset_len`` and, when labels exist, ``self.labels``.

    Args:
        name: dataset name returned by ``_get_name``.
        dataset_dir: directory the subclass reads its files from.
        device: device subclasses move loaded tensors to.
        download: if True, ``_download`` is called before anything else.
        exclude_modality: modality a subclass may replace with a placeholder.
        target_modality: stored as-is; not read by this class.
        train: forwarded to ``_load_data`` to choose the split.
        transform: optional callable applied to each sample's data dict.
        adv_attack: optional callable ``(data, target) -> data`` applied after
            ``transform``.
    """

    def __init__(self, name, dataset_dir, device, download=False, exclude_modality='none', target_modality='none', train=True, transform=None, adv_attack=None):
        super().__init__()
        # Runs before any attribute below is set, so _download must not rely on self state.
        if download:
            self._download()

        self.name = name
        self.device = device
        self.dataset_dir = dataset_dir
        self.exclude_modality = exclude_modality
        self.transform = transform
        self.adv_attack = adv_attack
        self.target_modality = target_modality
        self.dataset = {}
        self.dataset_len = 0
        self.labels = None
        self.modalities = None
        self._load_data(train)

    def _download(self):
        """Fetch the raw data; must be provided by subclasses that support download=True."""
        raise NotImplementedError

    def _load_data(self, train):
        """Populate dataset, dataset_len and labels for the requested split."""
        raise NotImplementedError

    def _show_dataset_label_distribution(self):
        """Report the label distribution; subclass-specific."""
        raise NotImplementedError

    def _get_name(self):
        return self.name

    def _get_modalities(self):
        return self.modalities

    def _set_transform(self, transform):
        self.transform = transform

    def _set_adv_attack(self, adv_attack):
        self.adv_attack = adv_attack

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, index):
        """Return ``(data, labels)`` for one sample.

        ``data`` maps each modality to its sample cast to float32. ``labels`` is
        ``self.labels[index]``, or ``None`` when the dataset has no labels.
        ``transform`` and then ``adv_attack`` are applied to ``data`` when set.
        """
        # .float() keeps each tensor on the device _load_data placed it on; casting to
        # torch.cuda.FloatTensor would force CUDA and fail on CPU-only setups.
        data = {key: values[index].float() for key, values in self.dataset.items()}
        labels = self.labels[index] if self.labels is not None else None

        if self.transform is not None:
            data = self.transform(data)

        if self.adv_attack is not None:
            # Without labels, the attack receives the (transformed) data itself as target.
            target = labels if labels is not None else data
            data = self.adv_attack(data, target)

        return data, labels

In [2]:
class MhdDataset(MultimodalDataset):
    """MHD dataset exposing 'image' and 'trajectory' modalities with digit labels.

    Splits are read from ``mhd_train.pt`` / ``mhd_test.pt`` in ``dataset_dir``.
    """

    def __init__(self, dataset_dir, device, download=False, exclude_modality='none', target_modality='none', train=True, transform=None, adv_attack=None):
        # The base constructor takes the dataset name first; passing it explicitly keeps
        # every following argument in its intended parameter slot.
        super().__init__('mhd', dataset_dir, device, download, exclude_modality, target_modality, train, transform, adv_attack)

    @staticmethod
    def _download():
        """Run the bundled download script located under the current working directory.

        Raises:
            subprocess.CalledProcessError: if the script exits with a non-zero status.
        """
        script = os.path.join(os.getcwd(), "datasets", "mhd", "download_mhd_dataset.sh")
        # Invoke bash explicitly. With shell=True and a list, extra items are only passed to
        # /bin/sh as positional parameters, so "bash" was never used to run the script.
        run(["bash", script], check=True)
        return

    def _load_data(self, train):
        """Load the train or test split into ``self.dataset`` and ``self.labels``.

        Images and trajectories are min-max scaled to [0, 1] over the whole split; the
        modality named by ``exclude_modality`` is replaced by a tensor filled with -1.
        """
        split = "train" if train else "test"
        data_path = os.path.join(self.dataset_dir, f"mhd_{split}.pt")

        # Entries are read positionally: [0] labels, [1] images, [2] trajectories.
        data = list(torch.load(data_path))
        self.dataset_len = len(data[0])

        images = (data[1] - torch.min(data[1])) / (torch.max(data[1]) - torch.min(data[1]))
        trajectories = (data[2] - torch.min(data[2])) / (torch.max(data[2]) - torch.min(data[2]))

        # full_like gives the placeholder the same shape and dtype as the real modality.
        if self.exclude_modality == 'image':
            images = torch.full_like(images, -1)
        elif self.exclude_modality == 'trajectory':
            trajectories = torch.full_like(trajectories, -1)

        self.dataset = {'image': images.to(self.device), 'trajectory': trajectories.to(self.device)}
        self.labels = data[0].to(self.device)
        return

    def _show_dataset_label_distribution(self):
        """Print and plot per-digit label counts for the train and test splits.

        One bar chart per split is saved as ``mhd_<split>.png`` in ``dataset_dir``.
        """
        for data_set in ['train', 'test']:
            data_path = os.path.join(self.dataset_dir, f"mhd_{data_set}.pt")
            data = torch.load(data_path)
            # _load_data reads this file positionally (labels first); a dict with a
            # "labels" key is accepted as well.
            labels = data["labels"] if isinstance(data, dict) else data[0]

            # Fresh counts per split; sharing one dict would fold train counts into test.
            label_dict = {str(digit): 0 for digit in range(10)}
            for label in labels:
                label_dict[str(label.item())] += 1

            print(f'Label count: {len(labels)}')
            print(label_dict)
            x_axis = np.arange(len(label_dict))
            fig, ax = plt.subplots(figsize=(20, 10))
            ax.set_xticks(x_axis)
            ax.set_xticklabels(list(label_dict.keys()))
            ax.set_title(f"MHD {data_set} set digit labels")
            ax.yaxis.grid(True)
            bars = ax.bar(x_axis, list(label_dict.values()), width=1, label="Label count", align='center')
            ax.bar_label(bars)
            fig.legend()
            fig.savefig(os.path.join(self.dataset_dir, f'mhd_{data_set}.png'))
            plt.close(fig)
        return

In [3]:
class MnistSvhnDataset(MultimodalDataset):
    """MNIST and SVHN images paired by digit label.

    Modalities are 'mnist' and 'svhn'. The files ``mnist_svhn_{train,test}.pt`` in
    ``dataset_dir`` are dicts with keys 'mnist', 'svhn' and 'labels', as written by
    :meth:`_download`.
    """

    def __init__(self, name, dataset_dir, device, download=False, exclude_modality='none', target_modality='none', train=True, transform=None, adv_attack=None, max_d = 10000, dm=30):
        # Assigned before the base constructor, which immediately calls _download and
        # _load_data, so these values already exist if those hooks ever need them.
        self.max_d = max_d  # maximum number of datapoints per class (not read by this class)
        self.dm = dm        # data multiplier for random permutations (not read by this class)
        super().__init__(name, dataset_dir, device, download, exclude_modality, target_modality, train, transform, adv_attack)

    @staticmethod
    def _download():
        """Download MNIST and SVHN and save digit-paired train and test sets.

        Output goes to ``datasets/mnist_svhn`` relative to the current working directory
        as ``mnist_svhn_train.pt`` and ``mnist_svhn_test.pt``. For each digit, the i-th
        MNIST image is paired with the i-th SVHN image; unmatched surplus is dropped.
        """
        root = os.path.join("datasets", "mnist_svhn")
        tx = transforms.ToTensor()
        train_mnist = datasets.MNIST(root, train=True, download=True, transform=tx)
        test_mnist = datasets.MNIST(root, train=False, download=True, transform=tx)
        train_svhn = datasets.SVHN(root, split='train', download=True, transform=tx)
        test_svhn = datasets.SVHN(root, split='test', download=True, transform=tx)
        # SVHN labels come as a numpy array; convert to a LongTensor and fold them
        # into 0-9 with % 10 so they can key the same buckets as the MNIST targets.
        train_svhn.labels = torch.LongTensor(train_svhn.labels.squeeze().astype(int)) % 10
        test_svhn.labels = torch.LongTensor(test_svhn.labels.squeeze().astype(int)) % 10

        train_dict = MnistSvhnDataset._pair_split(train_mnist, train_svhn, "train", "training")
        test_dict = MnistSvhnDataset._pair_split(test_mnist, test_svhn, "test", "test")

        torch.save(train_dict, os.path.join(root, 'mnist_svhn_train.pt'))
        torch.save(test_dict, os.path.join(root, 'mnist_svhn_test.pt'))
        return

    @staticmethod
    def _pair_split(mnist_ds, svhn_ds, split, combine_label):
        """Pair one MNIST split with one SVHN split by digit and stack into tensors."""
        print(f"Exporting svhn {split} set...")
        svhn_by_digit = MnistSvhnDataset._bucket_by_digit(svhn_ds.data, svhn_ds.labels)
        print(f"Exporting mnist {split} set...")
        mnist_by_digit = MnistSvhnDataset._bucket_by_digit(mnist_ds.data, mnist_ds.targets)

        print(f"Combining {combine_label} datasets...")
        paired = {"mnist": [], "svhn": [], "labels": []}
        for dig in tqdm(svhn_by_digit.keys()):
            # zip truncates to the smaller bucket, so each pair shares the same digit.
            for mnist_feats, svhn_feats in zip(mnist_by_digit[dig], svhn_by_digit[dig]):
                paired["mnist"].append(mnist_feats[None, :])  # add a leading singleton axis
                paired["svhn"].append(torch.from_numpy(svhn_feats))
                paired["labels"].append(torch.tensor(int(dig)))
        return {key: torch.stack(values) for key, values in paired.items()}

    @staticmethod
    def _bucket_by_digit(images, labels):
        """Group ``images`` into lists keyed by their label as a string '0'..'9'."""
        buckets = {str(digit): [] for digit in range(10)}
        for feats, label in tqdm(zip(images, labels), total=len(labels)):
            buckets[str(label.item())].append(feats)
        return buckets

    def _load_data(self, train):
        """Load the train or test split into ``self.dataset`` and ``self.labels``.

        Both modalities are min-max scaled to [0, 1] over the whole split; the modality
        named by ``exclude_modality`` is replaced by a tensor filled with -1.
        """
        split = "train" if train else "test"
        data_path = os.path.join(self.dataset_dir, f"mnist_svhn_{split}.pt")

        data = torch.load(data_path)
        self.dataset_len = len(data["labels"])

        mnist = (data['mnist'] - torch.min(data['mnist'])) / (torch.max(data['mnist']) - torch.min(data['mnist']))
        svhn = (data['svhn'] - torch.min(data['svhn'])) / (torch.max(data['svhn']) - torch.min(data['svhn']))

        # full_like gives the placeholder the same shape and dtype as the real modality.
        if self.exclude_modality == 'mnist':
            mnist = torch.full_like(mnist, -1)
        elif self.exclude_modality == 'svhn':
            svhn = torch.full_like(svhn, -1)

        self.dataset = {'mnist': mnist.to(self.device), 'svhn': svhn.to(self.device)}
        self.labels = data["labels"].to(self.device)
        return

    def _show_dataset_label_distribution(self):
        """Print and plot per-digit label counts for the train and test splits.

        One bar chart per split is saved as ``ms_<split>.png`` in ``dataset_dir``.
        """
        for data_set in ['train', 'test']:
            data_path = os.path.join(self.dataset_dir, f"mnist_svhn_{data_set}.pt")
            labels = torch.load(data_path)["labels"]

            # Fresh counts per split; sharing one dict would fold train counts into test.
            label_dict = {str(digit): 0 for digit in range(10)}
            for label in labels:
                label_dict[str(label.item())] += 1

            print(f'Label count: {len(labels)}')
            print(label_dict)
            x_axis = np.arange(len(label_dict))
            fig, ax = plt.subplots(figsize=(20, 10))
            ax.set_xticks(x_axis)
            ax.set_xticklabels(list(label_dict.keys()))
            ax.set_title(f"MNIST-SVHN {data_set} set digit labels")
            ax.yaxis.grid(True)
            bars = ax.bar(x_axis, list(label_dict.values()), width=1, label="Label count", align='center')
            ax.bar_label(bars)
            fig.legend()
            fig.savefig(os.path.join(self.dataset_dir, f'ms_{data_set}.png'))
            plt.close(fig)
        return

In [4]:
class MoseiDataset(MultimodalDataset):
    """MOSEI dataset with 'text', 'audio' and 'vision' modalities.

    Splits are read from ``mosei_train.pt`` / ``mosei_test.pt`` in ``dataset_dir``.
    """

    def __init__(self, dataset_dir, device, download=False, exclude_modality='none', target_modality='none', train=True, transform=None, adv_attack=None):
        # The base constructor takes the dataset name first; passing it explicitly keeps
        # every following argument in its intended parameter slot.
        super().__init__('mosei', dataset_dir, device, download, exclude_modality, target_modality, train, transform, adv_attack)

    @staticmethod
    def _download():
        """Download the raw splits and convert them to ``mosei_{train,test}.pt`` dicts.

        The saved train file concatenates the original train and validation splits.

        Raises:
            subprocess.CalledProcessError: if the download script exits non-zero.
        """
        dataset_dir = os.path.join(os.getcwd(), "datasets", "mosei")
        # Invoke bash explicitly. With shell=True and a list, extra items are only passed to
        # /bin/sh as positional parameters, so "bash" was never used to run the script.
        run(["bash", os.path.join(dataset_dir, "download_mosei_dataset.sh")], check=True)

        keys = ('text', 'audio', 'vision', 'labels')
        # NOTE(review): the .dt files are unpickled objects (accessed via attributes), so
        # only load them from a trusted source; newer PyTorch releases default torch.load
        # to weights_only=True, which may reject such objects — confirm against the
        # installed version.
        data = torch.load(os.path.join(dataset_dir, "mosei_train.dt"))
        val_data = torch.load(os.path.join(dataset_dir, "mosei_valid.dt"))
        dataset = {key: torch.concat((getattr(data, key), getattr(val_data, key))) for key in keys}
        torch.save(dataset, os.path.join(dataset_dir, "mosei_train.pt"))

        data = torch.load(os.path.join(dataset_dir, "mosei_test.dt"))
        dataset = {key: getattr(data, key) for key in keys}
        torch.save(dataset, os.path.join(dataset_dir, "mosei_test.pt"))
        return

    def _load_data(self, train):
        """Load the train or test split into ``self.dataset`` and ``self.labels``.

        The modality named by ``exclude_modality`` is replaced by a tensor filled with
        -1; every other modality is min-max scaled to [0, 1] over the whole split.
        """
        split = "train" if train else "test"
        data = torch.load(os.path.join(self.dataset_dir, f"mosei_{split}.pt"))

        self.dataset = {mod: data[mod].to(self.device) for mod in ('text', 'audio', 'vision')}
        self.labels = data['labels'].to(self.device)
        self.dataset_len = len(self.labels)

        if self.exclude_modality not in ('none', None):
            # torch.full expects a size, not a tensor; full_like copies shape, dtype and device.
            self.dataset[self.exclude_modality] = torch.full_like(self.dataset[self.exclude_modality], -1)

        for mod in ['text', 'audio', 'vision']:
            if mod != self.exclude_modality:
                values = self.dataset[mod]
                self.dataset[mod] = (values - torch.min(values)) / (torch.max(values) - torch.min(values))

        return

In [5]:
class MosiDataset(MultimodalDataset):
    """MOSI dataset with 'text', 'audio' and 'vision' modalities.

    Splits are read from ``mosi_train.pt`` / ``mosi_test.pt`` in ``dataset_dir``.
    """

    def __init__(self, dataset_dir, device, download=False, exclude_modality='none', target_modality='none', train=True, transform=None, adv_attack=None):
        # The base constructor takes the dataset name first; passing it explicitly keeps
        # every following argument in its intended parameter slot.
        super().__init__('mosi', dataset_dir, device, download, exclude_modality, target_modality, train, transform, adv_attack)

    @staticmethod
    def _download():
        """Download the raw splits and convert them to ``mosi_{train,test}.pt`` dicts.

        The saved train file concatenates the original train and validation splits.

        Raises:
            subprocess.CalledProcessError: if the download script exits non-zero.
        """
        dataset_dir = os.path.join(os.getcwd(), "datasets", "mosi")
        # Invoke bash explicitly. With shell=True and a list, extra items are only passed to
        # /bin/sh as positional parameters, so "bash" was never used to run the script.
        run(["bash", os.path.join(dataset_dir, "download_mosi_dataset.sh")], check=True)

        keys = ('text', 'audio', 'vision', 'labels')
        # NOTE(review): the .dt files are unpickled objects (accessed via attributes), so
        # only load them from a trusted source; newer PyTorch releases default torch.load
        # to weights_only=True, which may reject such objects — confirm against the
        # installed version.
        data = torch.load(os.path.join(dataset_dir, "mosi_train.dt"))
        val_data = torch.load(os.path.join(dataset_dir, "mosi_valid.dt"))
        dataset = {key: torch.concat((getattr(data, key), getattr(val_data, key))) for key in keys}
        torch.save(dataset, os.path.join(dataset_dir, "mosi_train.pt"))

        data = torch.load(os.path.join(dataset_dir, "mosi_test.dt"))
        dataset = {key: getattr(data, key) for key in keys}
        torch.save(dataset, os.path.join(dataset_dir, "mosi_test.pt"))
        return

    def _load_data(self, train):
        """Load the train or test split into ``self.dataset`` and ``self.labels``.

        The modality named by ``exclude_modality`` is replaced by a tensor filled with
        -1; every other modality is min-max scaled to [0, 1] over the whole split.
        """
        split = "train" if train else "test"
        data = torch.load(os.path.join(self.dataset_dir, f"mosi_{split}.pt"))

        self.dataset = {mod: data[mod].to(self.device) for mod in ('text', 'audio', 'vision')}
        self.labels = data['labels'].to(self.device)
        self.dataset_len = len(self.labels)

        if self.exclude_modality not in ('none', None):
            # torch.full expects a size, not a tensor; full_like copies shape, dtype and device.
            self.dataset[self.exclude_modality] = torch.full_like(self.dataset[self.exclude_modality], -1)

        for mod in ['text', 'audio', 'vision']:
            if mod != self.exclude_modality:
                values = self.dataset[mod]
                self.dataset[mod] = (values - torch.min(values)) / (torch.max(values) - torch.min(values))

        return

In [8]:
def setup_dataset(m_path, dataset_name, device="cpu", train=True):
    """Build one of the supported datasets from ``<m_path>/datasets/<dataset_name>``.

    Args:
        m_path: directory containing the ``datasets`` folder.
        dataset_name: one of 'mhd', 'mosi', 'mosei', 'mnist_svhn'.
        device: device the loaded tensors are moved to.
        train: load the train split if True, otherwise the test split.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    dataset_dir = os.path.join(m_path, "datasets", dataset_name)
    # Keyword arguments keep every value in its intended parameter: MhdDataset,
    # MosiDataset and MoseiDataset take no `name` argument, MnistSvhnDataset does.
    common = dict(dataset_dir=dataset_dir, device=device, download=False,
                  exclude_modality=None, target_modality=None, train=train)
    if dataset_name == 'mhd':
        return MhdDataset(**common)
    if dataset_name == 'mosi':
        return MosiDataset(**common)
    if dataset_name == 'mosei':
        return MoseiDataset(**common)
    if dataset_name == 'mnist_svhn':
        return MnistSvhnDataset(name='mnist_svhn', **common)
    raise ValueError(f"Unknown dataset_name {dataset_name!r}; expected one of 'mhd', 'mosi', 'mosei', 'mnist_svhn'")

# Project root: the parent of the directory the notebook runs from.
m_path = os.path.dirname(os.getcwd())
# Train splits of every supported dataset, built in the order listed.
mhd_dataset, mosi_dataset, mosei_dataset, ms_dataset = (
    setup_dataset(m_path, dataset_key) for dataset_key in ("mhd", "mosi", "mosei", "mnist_svhn")
)