In [57]:
import torch
import scipy.io as sio
import random
import numpy as np
from torch.utils.data import DataLoader


In [58]:
class JointDataset(torch.utils.data.Dataset):
    """Pair each sample of one MNIST split with a random sample of a shifted label.

    Item ``i`` is ``(img_1, img_2, target_1, target_2)``:

    - ``img_1`` / ``target_1`` come from the first file at index ``i``.
    - ``img_2`` / ``target_2`` are drawn uniformly at random from the samples
      of the second file whose label equals
      ``(target_1 + label_offset) % num_classes``. The draw uses the
      module-level ``random`` generator.

    With the defaults this is the "next digit" pairing: 0 -> 1, ..., 9 -> 0.

    Both images are divided by 255 before being returned. This presumably
    expects raw uint8 pixel data in [0, 255] (TODO: confirm for new inputs).

    Args:
        mnist_pt_path_1: path to a ``torch.save``-d ``(data, targets)`` pair
            providing the anchor samples; its length defines ``len(self)``.
        mnist_pt_path_2: path to a ``(data, targets)`` pair providing the
            partner samples.
        num_classes: number of distinct label values. Every label in the
            second file must lie in ``[0, num_classes)``. Defaults to 10.
        label_offset: shift added to ``target_1`` (modulo ``num_classes``)
            to choose the partner label. Defaults to 1.
    """

    def __init__(self, mnist_pt_path_1, mnist_pt_path_2, num_classes=10, label_offset=1):
        self.mnist_pt_path_1 = mnist_pt_path_1
        self.mnist_pt_path_2 = mnist_pt_path_2
        self.num_classes = num_classes
        self.label_offset = label_offset

        # Each .pt file must hold a (data, targets) tuple, e.g. the legacy
        # torchvision MNIST processed/training.pt or test.pt files.
        # NOTE(review): torch.load unpickles its input; only load trusted files.
        self.mnist_data_1, self.mnist_targets_1 = torch.load(self.mnist_pt_path_1)
        self.mnist_data_2, self.mnist_targets_2 = torch.load(self.mnist_pt_path_2)

        # label -> indices into the second dataset. Built once, so the partner
        # lookup in __getitem__ is a dict access rather than a scan.
        self.mnist_target_idx_mapping = self.process_mnist_labels()

    def process_mnist_labels(self):
        """Group the indices of the second dataset by label.

        Returns:
            dict mapping every label in ``range(self.num_classes)`` to the
            list of indices (ints) of ``self.mnist_targets_2`` that carry that
            label. Labels absent from the data map to an empty list.

        Raises:
            KeyError: if a target lies outside ``[0, self.num_classes)``.
        """
        label_to_indices = {label: [] for label in range(self.num_classes)}
        # tolist() converts all targets in one call instead of a per-element
        # .item(), which matters for the 60k-sample MNIST training split.
        for i, label in enumerate(self.mnist_targets_2.tolist()):
            label_to_indices[int(label)].append(i)
        return label_to_indices

    def __len__(self):
        """Number of anchor samples, i.e. the length of the first dataset."""
        return len(self.mnist_data_1)

    def __getitem__(self, index: int):
        """Return ``(img_1 / 255, img_2 / 255, target_1, target_2)``.

        Args:
            index (int): index into the first dataset.

        Raises:
            IndexError: if ``index`` is out of range for the first dataset.
            ValueError: if the second dataset has no sample with the
                partner label.
        """
        mnist_img_1 = self.mnist_data_1[index]
        mnist_target_1 = int(self.mnist_targets_1[index])

        partner_label = (mnist_target_1 + self.label_offset) % self.num_classes
        indices_list = self.mnist_target_idx_mapping[partner_label]
        if not indices_list:
            # Raise ValueError rather than letting random.choice raise
            # IndexError: `for x in dataset` treats IndexError as
            # end-of-sequence and would silently stop iterating early.
            raise ValueError(
                f"No sample with label {partner_label} in the second dataset "
                f"to pair with index {index} (label {mnist_target_1})."
            )

        # Uses the global `random` module, so partner selection is only
        # reproducible if random.seed(...) is set (per worker under DataLoader).
        idx = random.choice(indices_list)
        mnist_img_2 = self.mnist_data_2[idx]
        mnist_target_2 = int(self.mnist_targets_2[idx])

        return mnist_img_1 / 255, mnist_img_2 / 255, mnist_target_1, mnist_target_2


In [59]:
# ## Build paired train/test datasets (both sides of each pair are drawn from MNIST)
# MNIST_TRAINING_PATH = "/home/achint/Practice_code/VAE/MNIST/MNIST/processed/training.pt"
# MNIST_TEST_PATH     = "/home/achint/Practice_code/VAE/MNIST/MNIST/processed/test.pt"

# joint_dataset_train = JointDataset(mnist_pt_path_1 = MNIST_TRAINING_PATH,
#                                                       mnist_pt_path_2 = MNIST_TRAINING_PATH)
# joint_dataset_test  = JointDataset(mnist_pt_path_1 = MNIST_TEST_PATH,
#                                                      mnist_pt_path_2 = MNIST_TEST_PATH)

# joint_dataset_train_loader = DataLoader(
#     joint_dataset_train,
#     batch_size=batch_size,
#     shuffle=True,
#     drop_last=True
# )
# joint_dataset_test_loader = DataLoader(
#     joint_dataset_test,
#     batch_size=batch_size,
#     shuffle=False,
#     drop_last=True
# )
# for i,data in enumerate(joint_dataset_train_loader):
#     data_target1=data[2] 
#     data_target2=data[3] 

In [60]:
# data_target1

tensor([3, 2, 7, 7, 6, 9, 3, 1, 5, 7, 6, 9, 1, 3, 0, 2, 6, 2, 9, 4, 0, 8, 9, 3,
        2, 9, 6, 2, 9, 5, 3, 2, 3, 3, 0, 5, 8, 1, 3, 5, 4, 1, 8, 7, 2, 3, 0, 7,
        6, 3, 7, 3, 3, 7, 5, 9, 9, 5, 4, 6, 0, 5, 7, 6, 6, 0, 8, 1, 6, 0, 0, 0,
        3, 7, 3, 9, 1, 8, 3, 1, 1, 8, 4, 0, 4, 6, 1, 4, 9, 3, 8, 0, 0, 7, 9, 2,
        5, 4, 6, 6, 2, 2, 5, 4, 0, 6, 9, 7, 2, 9, 9, 3, 1, 4, 2, 4, 9, 4, 4, 0,
        5, 0, 7, 0, 5, 6, 6, 9])

In [62]:
# data_target2

tensor([4, 3, 8, 8, 7, 0, 4, 2, 6, 8, 7, 0, 2, 4, 1, 3, 7, 3, 0, 5, 1, 9, 0, 4,
        3, 0, 7, 3, 0, 6, 4, 3, 4, 4, 1, 6, 9, 2, 4, 6, 5, 2, 9, 8, 3, 4, 1, 8,
        7, 4, 8, 4, 4, 8, 6, 0, 0, 6, 5, 7, 1, 6, 8, 7, 7, 1, 9, 2, 7, 1, 1, 1,
        4, 8, 4, 0, 2, 9, 4, 2, 2, 9, 5, 1, 5, 7, 2, 5, 0, 4, 9, 1, 1, 8, 0, 3,
        6, 5, 7, 7, 3, 3, 6, 5, 1, 7, 0, 8, 3, 0, 0, 4, 2, 5, 3, 5, 0, 5, 5, 1,
        6, 1, 8, 1, 6, 7, 7, 0])