In [1]:
import torch
import torchvision
from torchvision.datasets import MNIST
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Root directory handed to torchvision's MNIST loader (download=True is passed,
# so the files are fetched into this directory when needed).
data_dir = "/scratch/qingqu_root/qingqu1/DL/xlxiao/data/"

# Build the train split first, then the test split. transform=None means no
# transform is attached; later cells read the raw `.data` / `.targets` tensors.
mnist_train, mnist_test = (
    MNIST(root=data_dir, download=True, train=is_train, transform=None)
    for is_train in (True, False)
)

In [6]:
def pad_image(image, pad_size, left=True):
    """Zero-pad the last two dimensions of ``image`` on one corner.

    Args:
        image: Tensor to pad; only its last two dimensions grow.
        pad_size: Number of zero rows and zero columns to add.
        left: If truthy, the zeros are placed before the content (extra rows
            on top, extra columns on the left), so the original content ends
            up in the bottom-right corner. Otherwise the zeros are placed
            after the content (bottom/right) and the content stays in the
            top-left corner.

    Returns:
        A new tensor whose last two dimensions are each ``pad_size`` larger.
    """
    # F.pad takes amounts for the last two dims as (left, right, top, bottom).
    zeros_before = (pad_size, 0, pad_size, 0)
    zeros_after = (0, pad_size, 0, pad_size)
    padding = zeros_before if left else zeros_after
    return F.pad(image, padding, "constant", 0)

def random_shift(image, max_shift):
    """Circularly shift ``image`` down and to the right by random amounts.

    Two offsets are drawn independently from NumPy's global RNG, each uniform
    over ``0 .. max_shift - 1``: first the shift along dim 1 (right), then the
    shift along dim 0 (down). The shift is done with ``torch.roll``, so
    content pushed past an edge re-enters on the opposite side.

    Args:
        image: Tensor with at least two dimensions; dims 0 and 1 are rolled.
        max_shift: Exclusive upper bound for each offset. Fractional values
            are truncated toward zero. If the bound is zero or negative after
            truncation there is nothing to sample, so an unshifted copy is
            returned and no random numbers are consumed.

    Returns:
        A new tensor with the same shape as ``image``.
    """
    max_shift = int(max_shift)
    if max_shift <= 0:
        # np.random.randint(0) would raise; treat an empty range as "no shift".
        return image.clone()

    # Draw order (right, then down) is kept stable so seeded runs stay reproducible.
    right_shift = int(np.random.randint(max_shift))
    down_shift = int(np.random.randint(max_shift))

    return torch.roll(image, shifts=(down_shift, right_shift), dims=(0, 1))

def pad_then_shift(image, pad_size, left=True):
    """Zero-pad ``image`` on one corner, then apply a small random circular shift.

    Args:
        image: Tensor to pad and shift.
        pad_size: Number of zero rows/columns added by ``pad_image``; also
            sets the shift range (offsets are drawn from ``0 .. pad_size // 4 - 1``).
        left: Forwarded to ``pad_image`` to choose which corner receives the zeros.

    Returns:
        The padded and shifted tensor.
    """
    # Integer division keeps the random bound an int (pad_size / 4 produced a
    # float). The floor of 1 means small pads (pad_size < 4) yield a zero shift
    # instead of an empty, invalid sampling range.
    max_shift = max(1, pad_size // 4)
    return random_shift(pad_image(image, pad_size, left), max_shift)

def create_m1_samples(data, targets, class_size=5000, pad_size=28):
    """Build single-digit samples: ``class_size`` padded, shifted images per class.

    For each of the 10 classes, ``class_size`` images are picked at random
    without replacement (NumPy global RNG), each is passed through
    ``pad_then_shift`` with the default corner, and paired with a one-hot
    target of length 10.

    Args:
        data: Image tensor indexed along dim 0, aligned with ``targets``.
        targets: 1-D tensor of integer class labels in ``0..9``.
        class_size: Number of samples to generate per class.
        pad_size: Padding amount forwarded to ``pad_then_shift``.

    Returns:
        ``(images, labels)``: the stacked samples, ordered class by class, and
        the matching stacked one-hot targets of shape ``(10 * class_size, 10)``.

    Raises:
        ValueError: If some class has fewer than ``class_size`` images
            (previously this surfaced as an IndexError part-way through).
    """
    num_classes = 10
    all_data = []
    all_targets = []
    for cla_idx in range(num_classes):
        # Positions of this class in `data`; computed once per class.
        class_indices = torch.nonzero(targets == cla_idx).flatten()
        cla_length = class_indices.numel()
        if cla_length < class_size:
            raise ValueError(
                f"class {cla_idx} has only {cla_length} samples, "
                f"but class_size={class_size} were requested"
            )

        # Random subset without replacement; only the chosen rows are copied.
        chosen_ones = np.random.permutation(cla_length)[:class_size]
        cur_cla_data = data[class_indices[chosen_ones]]

        # Every sample of this class shares the same one-hot target;
        # torch.stack copies it per row at the end.
        cur_target = torch.zeros(num_classes)
        cur_target[cla_idx] = 1.0

        for cur_image in cur_cla_data:
            all_data.append(pad_then_shift(cur_image, pad_size))
            all_targets.append(cur_target)
    return torch.stack(all_data, dim=0), torch.stack(all_targets, dim=0)

def create_m2_samples(data, targets, class_size=5000, pad_size=28):
    """Build two-digit samples: ``class_size`` overlaid images per class pair.

    For every unordered pair of distinct classes ``(idx_1, idx_2)`` with
    ``idx_1 < idx_2``, ``class_size`` images are drawn at random from each
    class. For each sample a coin flip (PyTorch global RNG) decides which
    digit is padded with ``left=True`` and which with ``left=False``; both are
    passed through ``pad_then_shift`` and combined with an element-wise
    maximum. The target has 0.5 at both class positions.

    Args:
        data: Image tensor indexed along dim 0, aligned with ``targets``.
        targets: 1-D tensor of integer class labels in ``0..9``.
        class_size: Number of samples to generate per class pair.
        pad_size: Padding amount forwarded to ``pad_then_shift``.

    Returns:
        ``(images, labels)``: the stacked samples, ordered pair by pair, and
        the matching stacked targets of shape ``(45 * class_size, 10)``.

    Raises:
        ValueError: If some class has fewer than ``class_size`` images
            (previously this surfaced as an IndexError part-way through).
    """
    num_classes = 10

    # Per-class positions in `data`, computed once instead of once per pair.
    class_indices = [
        torch.nonzero(targets == cla_idx).flatten() for cla_idx in range(num_classes)
    ]
    for cla_idx, indices in enumerate(class_indices):
        if indices.numel() < class_size:
            raise ValueError(
                f"class {cla_idx} has only {indices.numel()} samples, "
                f"but class_size={class_size} were requested"
            )

    def sample_class(cla_idx):
        # Fresh random subset (without replacement) for every pair.
        indices = class_indices[cla_idx]
        chosen_ones = np.random.permutation(indices.numel())[:class_size]
        return data[indices[chosen_ones]]

    all_data = []
    all_targets = []
    for idx_1 in range(num_classes):
        for idx_2 in range(idx_1 + 1, num_classes):
            # Draw order (first class, then second) is kept stable for seeded runs.
            cla_data_1 = sample_class(idx_1)
            cla_data_2 = sample_class(idx_2)

            # All samples of this pair share one target; torch.stack copies it per row.
            cur_target = torch.zeros(num_classes)
            cur_target[idx_1] = 0.5
            cur_target[idx_2] = 0.5

            for image_1, image_2 in zip(cla_data_1, cla_data_2):
                # Plain bool rather than a 1-element tensor, so `left` is unambiguous.
                first_left = bool(torch.rand(1) < 0.5)
                cur_data_1 = pad_then_shift(image_1, pad_size, left=first_left)
                cur_data_2 = pad_then_shift(image_2, pad_size, left=not first_left)
                all_data.append(torch.maximum(cur_data_1, cur_data_2))
                all_targets.append(cur_target)
    return torch.stack(all_data, dim=0), torch.stack(all_targets, dim=0)

def create_dataset(num_samples, data, targets, pad_size=28):
    """Assemble a dataset of single-digit samples followed by two-digit samples.

    Args:
        num_samples: Pair ``(m1_num, m2_num)``: samples per class for the
            single-digit part and samples per class pair for the two-digit part.
        data: Source images, forwarded to both sample generators.
        targets: Source labels, forwarded to both sample generators.
        pad_size: Padding amount forwarded to both sample generators.

    Returns:
        ``(images, labels)`` with the single-digit block first and the
        two-digit block second, concatenated along dim 0.
    """
    per_class, per_pair = num_samples
    # Single-digit samples are generated before two-digit samples.
    singles = create_m1_samples(data, targets, class_size=per_class, pad_size=pad_size)
    pairs = create_m2_samples(data, targets, class_size=per_pair, pad_size=pad_size)

    # zip pairs up (m1_data, m2_data) and (m1_targets, m2_targets).
    images, labels = (torch.cat(parts, dim=0) for parts in zip(singles, pairs))
    return images, labels

In [7]:
# The sample generators draw from both global RNGs (NumPy for subset selection
# and shifts, PyTorch for the left/right coin flip). Seed both so that
# re-running this cell regenerates the same dataset.
DATASET_SEED = 0
np.random.seed(DATASET_SEED)
torch.manual_seed(DATASET_SEED)

# Pixel values are scaled by each split's maximum before sample generation.
# Counts are [samples per digit class, samples per digit pair].
trainset, trainlabels = create_dataset([3100, 200], mnist_train.data / torch.max(mnist_train.data), 
                                       mnist_train.targets, pad_size=28)
testset, testlabels = create_dataset([800, 50], mnist_test.data / torch.max(mnist_test.data), 
                                       mnist_test.targets, pad_size=28)

In [13]:
import pickle
# Bundle both splits under the keys downstream loaders read.
to_save = dict(
    train_data=trainset,
    train_label=trainlabels,
    test_data=testset,
    test_label=testlabels,
)

# NOTE(review): this output directory (.../xlxiao/DL/data/) differs from the
# MNIST data_dir used earlier (.../DL/xlxiao/data/) — confirm this is intended.
with open("/scratch/qingqu_root/qingqu1/xlxiao/DL/data/mnist_combine.pkl", 'wb') as out_file:
    pickle.dump(to_save, out_file)