# Encrypted Transfer Learning

In this notebook, we're going to build an encrypted model on top of a pretrained model. This lets us use existing, already-trained models to our advantage while retaining privacy for both the data holder and the model owner.

Authors:
 - Alejandro Aristizábal - Github: [@aristizabal95](https://github.com/aristizabal95)
 
This notebook is based on PySyft's [official tutorial on Encrypted NN](https://github.com/OpenMined/PySyft/blob/dev/examples/tutorials/Part%2012%20-%20Train%20an%20Encrypted%20Neural%20Network%20on%20Encrypted%20Data.ipynb).

## Step 1: Create Workers

These workers will hold both the data and the model in encrypted form, split into additive secret shares (Additive Sharing Tensors).
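To build intuition for what these workers will hold, here is a minimal pure-Python sketch of additive secret sharing. It is not PySyft's implementation: the modulus `Q`, the fixed-point `PRECISION`, and the helper names `encode`, `decode`, `share` and `reconstruct` are illustrative assumptions. Floats are first encoded as integers (the shares live in modular arithmetic), then split into random shares that only reveal the value once all of them are summed.

```python
import random

# Assumed parameters for illustration only; PySyft's actual field size and precision may differ.
Q = 2**61 - 1   # a Mersenne prime used as the modulus of the shares
PRECISION = 3   # number of decimal digits kept when encoding floats
BASE = 10


def encode(x):
    """Fixed-point encode a float as a field element (negatives wrap around mod Q)."""
    return int(round(x * BASE**PRECISION)) % Q


def decode(v):
    """Map a field element back to a signed float."""
    if v > Q // 2:
        v -= Q  # upper half of the field represents negative numbers
    return v / BASE**PRECISION


def share(secret, n_workers):
    """Split a field element into n random shares that sum to it mod Q."""
    shares = [random.randrange(Q) for _ in range(n_workers - 1)]
    shares.append((secret - sum(shares)) % Q)
    return shares


def reconstruct(shares):
    """Recombine all shares; any strict subset looks uniformly random."""
    return sum(shares) % Q


# Two workers each hold one share of a and one share of b.
a = share(encode(1.5), 2)
b = share(encode(-0.25), 2)

# Addition is local: each worker adds its own two shares, no communication needed.
c = [(x + y) % Q for x, y in zip(a, b)]
print(decode(reconstruct(c)))  # 1.25
```

Because addition works share-by-share, workers can add encrypted values without talking to each other; multiplying two shared values needs an extra protocol that this sketch does not cover.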

```python
%matplotlib inline
%config InlineBackend.figure_format = "retina"

import matplotlib.pyplot as plt

import os
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
import syft as sy
import numpy as np
```

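```python
# Hook PyTorch so that tensors gain PySyft's extra functionality
hook = sy.TorchHook(torch)

# Create the virtual workers that will hold the encrypted data and model
alice = sy.VirtualWorker(id="alice", hook=hook)
bob = sy.VirtualWorker(id="bob", hook=hook)
james = sy.VirtualWorker(id="james", hook=hook)
```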



## Get the Dataset

For this example, we'll use a dataset of flower photos hosted on Kaggle. Download it from [this link](https://www.kaggle.com/alxmamaev/flowers-recognition/downloads/flowers-recognition.zip/2), save it in the same directory as this notebook, and extract the zip file to get the directory `flowers/`. **Make sure the directory has this exact name.**

**Note:** This dataset contains images of varying sizes, with an average size of 320x240 pixels. We'll deal with this later.
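One common way to handle varying sizes is to scale each image while keeping its aspect ratio and then crop a fixed-size square from the middle. The sketch below only computes that geometry, using assumed target sizes (shorter side resized to 256 pixels, then a 224x224 center crop, a frequent choice for ImageNet-pretrained models). The helpers `resize_shorter_side` and `center_crop_box` are hypothetical, and torchvision's own rounding may differ slightly.

```python
def resize_shorter_side(width, height, target):
    """Scale so the shorter side equals `target`, keeping the aspect ratio."""
    if width <= height:
        return target, int(target * height / width)
    return int(target * width / height), target


def center_crop_box(width, height, size):
    """Return (left, top, right, bottom) of a size x size crop centred in the image."""
    left = (width - size) // 2
    top = (height - size) // 2
    return left, top, left + size, top + size


# Whatever the input size, every image ends up as the same 224x224 crop.
for w, h in [(320, 240), (240, 320), (500, 333)]:
    rw, rh = resize_shorter_side(w, h, 256)
    box = center_crop_box(rw, rh, 224)
    print((w, h), "->", (rw, rh), "-> crop", box)
```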

```python
# Make sure the dataset was downloaded and extracted as described above
assert os.path.exists("flowers"), (
    "It looks like there's no folder called \"flowers\". "
    "Did you follow the instructions above?"
)
```

### Split the data into train and test datasets

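The data obtained from Kaggle isn't split into training and testing sets. This script will do that for you ;) It moves every class folder into `flowers/train`, then moves a random 10% of each class's images into `flowers/test`.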

```python
source_path = 'flowers'
train_path = 'flowers/train'
test_path = 'flowers/test'
test_percent = 0.1  # 10% of the data will be used for testing


def visible_entries(path):
    """List directory entries, ignoring hidden files such as .DS_Store."""
    return [f for f in os.listdir(path) if not f.startswith('.')]


##### CREATING THE TRAINING SPLIT ######

# First check whether the training folder exists
if not os.path.exists(train_path):
    print("Training folder doesn't exist. Creating it")
    os.makedirs(train_path)

# If the training folder is empty, move every class folder into it
if not visible_entries(train_path):
    print("Training folder is empty. Moving data to it")
    # Only move class directories, excluding the train and test folders themselves
    datafolders = {f for f in visible_entries(source_path)
                   if os.path.isdir(os.path.join(source_path, f))} - {'train', 'test'}
    for folder in datafolders:
        shutil.move(os.path.join(source_path, folder), train_path)


##### CREATING THE TESTING SPLIT ######

# First check whether the testing folder exists
if not os.path.exists(test_path):
    print("Testing folder doesn't exist. Creating it")
    os.makedirs(test_path)

# If the testing folder is empty, move a random subset of each class into it
if not visible_entries(test_path):
    print("Testing folder is empty. Moving data to it")
    for folder in visible_entries(train_path):
        if not os.path.isdir(os.path.join(train_path, folder)):
            continue  # skip stray files; only class folders are split
        os.makedirs(os.path.join(test_path, folder), exist_ok=True)

        files = sorted(visible_entries(os.path.join(train_path, folder)))
        size = int(len(files) * test_percent)
        rand_files = np.random.choice(files, size, replace=False)
        for file in rand_files:
            shutil.move(os.path.join(train_path, folder, file),
                        os.path.join(test_path, folder, file))
```
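After running the split, it can be worth checking that each class ended up with roughly 10% of its images in `flowers/test`. The helper `summarize_split` below is a hypothetical sketch that counts visible files per class folder in both splits. Note that `np.random.choice` is not seeded in the script above, so the exact test files would differ if the split were recreated from scratch, unless a seed (e.g. `np.random.seed(0)`) is set first.

```python
import os


def summarize_split(train_path, test_path):
    """Return {class_name: (n_train, n_test)} for an ImageFolder-style train/test split."""
    def count(root):
        counts = {}
        for cls in sorted(os.listdir(root)):
            cls_dir = os.path.join(root, cls)
            if cls.startswith('.') or not os.path.isdir(cls_dir):
                continue  # skip hidden entries such as .DS_Store and stray files
            counts[cls] = sum(1 for f in os.listdir(cls_dir) if not f.startswith('.'))
        return counts

    train, test = count(train_path), count(test_path)
    classes = sorted(set(train) | set(test))
    return {cls: (train.get(cls, 0), test.get(cls, 0)) for cls in classes}


# Only runs if the split from the previous cell exists on disk.
if os.path.isdir("flowers/train") and os.path.isdir("flowers/test"):
    for cls, (n_train, n_test) in summarize_split("flowers/train", "flowers/test").items():
        total = n_train + n_test
        share = n_test / total if total else 0.0
        print(f"{cls:>12}: {n_train} train, {n_test} test ({share:.1%} held out)")
```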

## Load and Transform our Data

We'll use torchvision's `ImageFolder` class to load our data.
The pretrained models expect the input data to have certain properties (e.g. a fixed width and height, a specific normalization). For that, we're going to use `transforms`.
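For the normalization part, the sketch below mimics what `transforms.ToTensor()` followed by `transforms.Normalize(mean, std)` do to an RGB image: scale pixel values to [0, 1], standardize each channel, and move channels first. The mean/std values are the ImageNet statistics torchvision documents for its pretrained models; whether the model chosen in this notebook uses exactly these is an assumption, and `to_normalized_chw` is a hypothetical NumPy helper written for illustration.

```python
import numpy as np

# Per-channel ImageNet statistics commonly used with torchvision's pretrained models.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_normalized_chw(image_hwc_uint8):
    """Mimic ToTensor() + Normalize(): uint8 HxWxC -> float32 CxHxW, standardized per channel."""
    x = image_hwc_uint8.astype(np.float32) / 255.0  # [0, 255] -> [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD           # broadcasts over the last (channel) axis
    return np.transpose(x, (2, 0, 1))                # HWC -> CHW, the layout PyTorch expects


img = np.full((224, 224, 3), 255, dtype=np.uint8)  # a dummy all-white 224x224 RGB image
out = to_normalized_chw(img)
print(out.shape)  # (3, 224, 224)
```

Standardizing with the statistics the model was pretrained on keeps its inputs in the same range it saw during training, which is what makes reusing its weights worthwhile.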