## Utility functions to split MNIST data in two parts: left and right halves

This code saves the images and the labels, and assigns an index to each sample. The indexes are needed to link the two halves with PSI (Private Set Intersection). The two parts of each image will be distributed to the Data Owners.

The Data Scientist holds the labels and links each label to its respective halves. In this way, the Data Scientist can join the two parts of a sample and run the remote segment of the Split Neural Network on them.
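
The linkage step can be sketched in plain Python. This is only an illustration of what the indexes are for: a real PSI protocol computes the intersection without revealing the indexes that are not shared, whereas the sketch below uses an ordinary set intersection. The dictionaries and their contents are hypothetical stand-ins, not data produced by this notebook.

```python
# Hypothetical stand-in data: each party holds records keyed by sample index.
owner_1 = {0: "first_half_0", 1: "first_half_1", 2: "first_half_2", 4: "first_half_4"}
owner_2 = {1: "second_half_1", 2: "second_half_2", 3: "second_half_3", 4: "second_half_4"}
labels = {0: 5, 1: 0, 2: 4, 3: 1, 4: 9}

# Assumption: a plain set intersection stands in for PSI here.
# PSI would yield the same shared indexes without exposing the others.
shared = sorted(owner_1.keys() & owner_2.keys() & labels.keys())

# Align the three parties' records on the shared indexes.
linked = [(i, owner_1[i], owner_2[i], labels[i]) for i in shared]
print(shared)
```

Only samples whose index is known to both Data Owners and to the Data Scientist survive the linkage, which is why every sample needs a stable index before the data is distributed.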

In [1]:
from __future__ import print_function
import syft as sy
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from typing import List, Tuple
from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler

from torchvision import datasets, transforms

In [2]:
def split_data(dataset, worker_list=None, n_workers=2):
    """Split every image in `dataset` into one slice per worker.

    Returns a dict mapping each worker to its list of image slices,
    plus a list of labels and a list of sample indexes. Position k in
    all of these lists refers to the same sample.
    """
    if worker_list is None:
        worker_list = list(range(n_workers))

    # One list of image slices per worker (Data Owner).
    dic_single_datasets = {worker: [] for worker in worker_list}
    label_list = []
    index_list = []

    # enumerate() provides the running index used later to link the slices.
    for idx, (tensor, label) in enumerate(dataset):
        # Slices are taken along the last dimension, which is the width
        # for the (channel, height, width) tensors that ToTensor() produces.
        width = tensor.shape[-1] // len(worker_list)
        for i, worker in enumerate(worker_list[:-1]):
            dic_single_datasets[worker].append(tensor[:, :, width * i : width * (i + 1)])

        # The last worker takes the remaining columns, including any remainder.
        last = len(worker_list) - 1
        dic_single_datasets[worker_list[-1]].append(tensor[:, :, width * last :])
        label_list.append(torch.Tensor([label]))
        index_list.append(torch.Tensor([idx]))

    return dic_single_datasets, label_list, index_list
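
To see what the slicing inside `split_data` does to a single image, the following minimal sketch applies the same indexing to a synthetic tensor. The helper name `slice_last_dim` and the synthetic `image` are assumptions made for illustration; they are not part of the notebook. The slices are taken along the last dimension, which for a `(1, 28, 28)` tensor in channel, height, width order is the width.

```python
import torch


def slice_last_dim(tensor, n_parts):
    """Split a tensor into n_parts contiguous slices along its last dimension."""
    width = tensor.shape[-1] // n_parts
    parts = [tensor[..., width * i : width * (i + 1)] for i in range(n_parts - 1)]
    # The last part takes everything that is left, including any remainder.
    parts.append(tensor[..., width * (n_parts - 1) :])
    return parts


# Synthetic stand-in for one MNIST image of shape (channel, height, width).
image = torch.arange(28 * 28, dtype=torch.float32).reshape(1, 28, 28)

halves = slice_last_dim(image, 2)
thirds = slice_last_dim(image, 3)
print([tuple(p.shape) for p in halves])
print([tuple(p.shape) for p in thirds])

# Concatenating the slices along the last dimension recovers the original image.
restored = torch.cat(halves, dim=-1)
```

With two workers each slice is 14 columns wide. With three workers the width does not divide evenly, so the last slice absorbs the remainder and is one column wider than the others.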

In [3]:
transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5,), (0.5,)),
                              ])
trainset = datasets.MNIST('mnist', download=True, train=True, transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz
Extracting mnist/MNIST/raw/train-images-idx3-ubyte.gz to mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting mnist/MNIST/raw/train-labels-idx1-ubyte.gz to mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw
Processing...
Done!
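
The transform used when loading the dataset first converts each image to a tensor with pixel values in [0, 1] (`ToTensor`) and then applies `Normalize((0.5,), (0.5,))`, which computes `(x - mean) / std` per pixel. The short sketch below works through that arithmetic in plain Python; the function name `normalize` is a hypothetical helper for illustration, not a torchvision call.

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-pixel formula of a mean/std normalization: (x - mean) / std."""
    return (x - mean) / std


# Pixel values after ToTensor lie in [0, 1]; with mean = std = 0.5 they map to [-1, 1].
for x in (0.0, 0.25, 0.5, 1.0):
    print(x, "->", normalize(x))
```

Black pixels (0.0) map to -1.0 and white pixels (1.0) map to 1.0, so every slice stored for the Data Owners holds values in [-1, 1].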



In [4]:
dic_single_datasets, label_list, index_list = split_data(trainset)

In [5]:
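import os
import pickle

# Make sure the output directory exists before writing to it.
os.makedirs("data_file", exist_ok=True)

# Image slices for the two Data Owners, followed by the index list
# and the label list.
outputs = {
    "data_file/1st_owner": dic_single_datasets[0],
    "data_file/2nd_owner": dic_single_datasets[1],
    "data_file/indexlist": index_list,
    "data_file/labellist": label_list,
}

# A context manager closes each file even if pickling fails.
for path, obj in outputs.items():
    with open(path, "wb") as f:
        pickle.dump(obj, f)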
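
A party that receives one of these files could read it back along the following lines. This is a minimal sketch under stated assumptions: it writes small stand-in lists to a temporary directory so that it runs on its own, whereas the real files hold lists of torch tensors, and the variable names are hypothetical.

```python
import os
import pickle
import tempfile

# Stand-in data; the files written by the notebook hold lists of torch tensors.
stand_in_halves = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
stand_in_indexes = [0.0, 1.0, 2.0]

with tempfile.TemporaryDirectory() as data_dir:
    # Write the stand-in files using the same file names as the notebook.
    for name, obj in (("1st_owner", stand_in_halves), ("indexlist", stand_in_indexes)):
        with open(os.path.join(data_dir, name), "wb") as f:
            pickle.dump(obj, f)

    # What a Data Owner might do on its side: load its slices and the index list.
    with open(os.path.join(data_dir, "1st_owner"), "rb") as f:
        loaded_halves = pickle.load(f)
    with open(os.path.join(data_dir, "indexlist"), "rb") as f:
        loaded_indexes = pickle.load(f)

# Position k in every list refers to the same sample, so zipping pairs them up.
by_index = {int(i): half for i, half in zip(loaded_indexes, loaded_halves)}
print(by_index[1])
```

Because `pickle.load` can execute arbitrary code, files in this format should only be loaded when they come from a trusted source.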