# Install Crypten
The pip release of this framework is currently unstable because of a version-dependency problem, so it needs to be installed from source. [Issue link](https://github.com/facebookresearch/CrypTen/issues/391). 

Skip this part if you already have CrypTen installed. These installation steps are written specifically for Google Colab.

In [None]:
# Fetch the CrypTen sources and switch into the checkout.
!git clone https://github.com/facebookresearch/CrypTen.git
%cd CrypTen
# Pin to this commit: commits after it break a version dependency.
!git checkout efe8edad571be1c586d0d9cefc562d562d4e9aa1
# Alternative install route (disabled): !python setup.py install --user
# Install the checked-out sources in editable mode.
%pip install -e .

## Check installed version

In [None]:
!pip show crypten

Name: crypten
Version: 0.4.0
Summary: CrypTen: secure machine learning in PyTorch.
Home-page: https://github.com/facebookresearch/CrypTen
Author: Facebook AI Research
Author-email: None
License: MIT licensed, as found in the LICENSE file
Location: /root/.local/lib/python3.8/site-packages/crypten-0.4.0-py3.8.egg
Requires: torch, torchvision, omegaconf, onnx, pandas, pyyaml, tensorboard, future, scipy, sklearn
Required-by: 


## Fix existing bug
[Issue link](https://github.com/facebookresearch/CrypTen/issues/438). Because this framework's `setup.py` uses the string "/config", the CrypTen config files are not copied properly during installation. You can either change "/config" to "config" in `setup.py` manually or run the cell below.

Note: Installing with `pip install -e .` (as done above) seems to fix this, in which case the cell below is not needed.

In [None]:
# The current setup file does not copy default.yaml into the installed configs folder,
# so copy it by hand from the source checkout (run from inside the CrypTen directory).
# NOTE(review): the destination is the egg location reported by `pip show` for a
# `setup.py install --user` install; it presumably does not exist after
# `pip install -e .` -- confirm before running.
!cp configs/default.yaml /root/.local/lib/python3.8/site-packages/crypten-0.4.0-py3.8.egg/configs/

# Restart the runtime
You need to restart the kernel runtime to load the newly installed crypten module. Once you have restarted, there is no need to re-run the prior cells; you can start from here. 

# Import Libraries

In [None]:
#import the libraries
import crypten
import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE: according to the original author, this initialization does not work on Windows.
# Initialize CrypTen.
crypten.init()
# Disable OpenMP threads -- needed by @mpc.run_multiprocess, which uses fork.
torch.set_num_threads(1)

# MNIST Image Classification

## Model

In [None]:
class ExampleNet(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    Layout: 5x5 convolution (1 -> 16 channels, no padding) -> ReLU ->
    2x2 max-pool -> flatten -> linear (2304 -> 100) -> ReLU -> linear (100 -> 10).
    The output is a tensor of 10 raw scores (logits) per input image.
    """

    def __init__(self):
        super().__init__()
        # 28x28 input -> 24x24 after the unpadded 5x5 convolution.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=0)
        # 24x24 -> 12x12 after pooling, hence 16 * 12 * 12 flattened features.
        self.fc1 = nn.Linear(16 * 12 * 12, 100)
        # One output per class (10 here).
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        """Return the 10 class scores for each image in the batch ``x``."""
        pooled = F.max_pool2d(F.relu(self.conv1(x)), 2)
        flat = pooled.view(-1, 16 * 12 * 12)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

## Device

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    # Report which GPU (index 0) was found.
    print(torch.cuda.get_device_name(device=0))
print(f'Using {device} backend.')

Using cpu backend.


## Dataset

In [None]:
%cd CrypTen/tutorials/

/content/CrypTen/tutorials


In [None]:
from torchvision.datasets import MNIST
from torchvision import transforms

In [None]:
# Mini-batch sizes for the training and evaluation loaders.
batch_size = 64
test_batch_size = 128
train_kwargs = dict(batch_size=batch_size, shuffle=True)  # reshuffle training data
test_kwargs = dict(batch_size=test_batch_size)

# Only convert images to tensors; normalisation is intentionally left disabled:
#   transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([transforms.ToTensor()])

# The training split triggers the download into ../data; the test split reuses it.
train_data = MNIST('../data', train=True, download=True, transform=transform)
test_data = MNIST('../data', train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(train_data, **train_kwargs)
test_loader = torch.utils.data.DataLoader(test_data, **test_kwargs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw



In [None]:
# Transform labels into one-hot encoding
# label_eye = torch.eye(10)
# y_one_hot = label_eye[test_data.targets]

## Train

In [None]:
# Plain-text (unencrypted) training setup.
learning_rate = 1e-2
model = ExampleNet().to(device)

# Mean cross-entropy over each mini-batch; read as a global by train() and test().
criterion = nn.CrossEntropyLoss(reduction='mean')

# Plain SGD whose learning rate is multiplied by 0.7 on every scheduler step.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)

In [None]:
import time
from tqdm.auto import tqdm
DISABLE_PROGRESS = False
from numpy import round

def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch and return ``(mean_batch_loss, elapsed_seconds)``.

    Args:
        model: module to optimise; switched to training mode.
        device: device the batches are moved to before the forward pass.
        train_loader: iterable of ``(data, target)`` batches that supports ``len()``.
        optimizer: optimizer stepping the model parameters.
        epoch: epoch number, used only for the progress-bar label.

    Returns:
        Tuple of the loss averaged over the batches actually processed
        (``0.0`` if the loader yields nothing) and the wall-clock time in seconds.

    The loss function is read from the notebook-level ``criterion`` variable.
    """
    model.train()
    progress_bar = tqdm(
        range(len(train_loader)), desc=f'Epoch {epoch} (Train)',
        disable=DISABLE_PROGRESS
    )

    start_time = time.perf_counter()
    total_loss = 0
    num_batches = 0
    try:
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            progress_bar.update(1)
            # Running mean of the per-batch losses seen so far.
            progress_bar.set_postfix(
                loss=round(total_loss / num_batches, 4)
            )
    finally:
        # Always release the bar so later bars render cleanly.
        progress_bar.close()
    elapsed_time = time.perf_counter() - start_time

    # Divide by the number of batches, not by the last batch index
    # (which is one too small and is zero for a single batch).
    mean_loss = total_loss / num_batches if num_batches else 0.0
    return mean_loss, elapsed_time

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    start_time = time.perf_counter()
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = correct / len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * accuracy))
    elapsed_time = time.perf_counter() - start_time
    return test_loss, elapsed_time, accuracy


In [None]:
# Train and evaluate the plain-text model, accumulating wall-clock timings.
epochs = 2
total_train_time = total_test_time = 0
for epoch in range(1, epochs + 1):
    # One optimisation pass followed by an evaluation pass.
    train_loss, train_time = train(model, device, train_loader, optimizer, epoch)
    test_loss, test_time, test_accuracy = test(model, device, test_loader)
    # Apply the per-epoch learning-rate decay.
    scheduler.step()

    total_train_time, total_test_time = (
        total_train_time + train_time,
        total_test_time + test_time,
    )

print(
    f'Train time: total {total_train_time:6g}, mean {(total_train_time/epochs):6g}.\n'
    f' Test time: total {total_test_time:6g}, mean {(total_test_time/epochs):6g}.'
)

Epoch 1 (Train):   0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0003, Accuracy: 9050/10000 (90%)



Epoch 2 (Train):   0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0003, Accuracy: 9210/10000 (92%)

Train time: total 66.5258, mean 33.2629.
 Test time: total 6.90227, mean 3.45113.


In [None]:
# Persist the whole trained module object (not just its state_dict) to the
# current working directory.
torch.save(model, 'mnist_model.pth')

# Encrypted Train

In [None]:
crypten.common.serial.register_safe_class(ExampleNet)

In [None]:
# Example input with the MNIST batch layout (1 image, 1 channel, 28x28), passed to
# from_pytorch alongside the model for the conversion.
dummy_input = torch.empty((1, 1, 28, 28)).to(device)
# Convert the already-trained plain-text model into a CrypTen model.
model_enc = crypten.nn.from_pytorch(model, dummy_input)
# To train from scratch instead, convert a freshly initialised network:
# model_enc = crypten.nn.from_pytorch(ExampleNet(), dummy_input)
# Encrypt the converted model in place.
model_enc.encrypt()

Graph encrypted module

In [None]:
# NOTE: these assignments rebind the notebook-level `optimizer` and `criterion`.
# The train/test helpers read `criterion` as a global, so from here on they use
# the CrypTen loss rather than the torch one.
optimizer = crypten.optim.SGD(model_enc.parameters(), lr=learning_rate)
criterion = crypten.nn.CrossEntropyLoss(reduction='mean') # Choose loss functions

In [None]:
def train_encrypted(model, device, train_loader, optimizer, epoch):
    """Run one training epoch on encrypted batches.

    Each plain-text batch is encrypted with ``crypten.cryptensor``; integer
    labels are first expanded to 10-way one-hot rows.

    Args:
        model: encrypted model to optimise; switched to training mode.
        device: device the encrypted batches are moved to.
        train_loader: iterable of plain-text ``(data, target)`` batches that
            supports ``len()``; ``target`` holds integer class indices.
        optimizer: optimizer stepping the model parameters.
        epoch: epoch number, used only for the progress-bar label.

    Returns:
        Tuple ``(mean_batch_loss, elapsed_seconds)``; the mean is taken over
        the batches actually processed (``0.0`` if the loader yields nothing).

    The loss function is read from the notebook-level ``criterion`` variable.
    """
    model.train()
    progress_bar = tqdm(
        range(len(train_loader)), desc=f'Epoch {epoch} (Train)',
        disable=DISABLE_PROGRESS
    )

    start_time = time.perf_counter()
    total_loss = 0
    num_batches = 0
    label_eye = torch.eye(10)  # row i is the one-hot encoding of class i
    try:
        for data, target in train_loader:
            data = crypten.cryptensor(data)
            target = crypten.cryptensor(label_eye[target])

            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # Decrypt the scalar loss only for reporting.
            total_loss += loss.get_plain_text().item()
            num_batches += 1
            progress_bar.update(1)
            progress_bar.set_postfix(
                loss=round(total_loss / num_batches, 4)
            )
    finally:
        # Always release the bar so later bars render cleanly.
        progress_bar.close()

    elapsed_time = time.perf_counter() - start_time
    # Divide by the number of batches, not by the last batch index
    # (which is one too small and is zero for a single batch).
    mean_loss = total_loss / num_batches if num_batches else 0.0
    return mean_loss, elapsed_time

In [None]:
def test_encrypted(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    start_time = time.perf_counter()
    label_eye = torch.eye(10)
    with torch.no_grad():
        for data, target in tqdm(test_loader):
            data = crypten.cryptensor(data)
            target = crypten.cryptensor(label_eye[target])

            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).get_plain_text().item()  # sum up batch loss
            
            pred = output.get_plain_text().argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.get_plain_text().view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = correct / len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * accuracy))
    elapsed_time = time.perf_counter() - start_time
    return test_loss, elapsed_time, accuracy

In [None]:
# Train the encrypted model. This must use train_encrypted (which encrypts each
# batch and one-hot encodes the labels), not the plain-text train() helper.
num_epochs = 2
for epoch in range(1, num_epochs+1):
    train_encrypted(model_enc, device, train_loader, optimizer, epoch)

In [None]:
# Evaluate the encrypted model: inputs are encrypted, while the one-hot targets
# are kept as plain tensors.
model_enc.eval()
test_loss = 0
correct = 0
start_time = time.perf_counter()
label_eye = torch.eye(10)  # row i is the one-hot encoding of class i
with torch.no_grad():
    for data, target in tqdm(test_loader):
        data = crypten.cryptensor(data)
        target = label_eye[target]

        data, target = data.to(device), target.to(device)
        output = model_enc(data)
        # criterion was created with reduction='mean'; weight each batch mean by
        # its batch size so the division by the dataset size below yields a
        # per-sample average.
        test_loss += criterion(output, target).get_plain_text().item() * len(target)

        pred = output.get_plain_text().argmax(dim=1, keepdim=True)  # index of the highest score
        # Collapse the one-hot rows back to class indices before comparing.
        target = target.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)
accuracy = correct / len(test_loader.dataset)

print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * accuracy))
elapsed_time = time.perf_counter() - start_time

  0%|          | 0/10 [00:00<?, ?it/s]


Test set: Average loss: 0.0003, Accuracy: 9210/10000 (92%)



In [None]:
# x_train = crypten.cryptensor(train_data.data.reshape((-1, 1, 28, 28))) # original shape is (N, width, length)

# # Transform labels into one-hot encoding
# label_eye = torch.eye(10)
# y_one_hot = label_eye[train_data.targets]
# y_train = crypten.cryptensor(y_one_hot)

In [None]:
# model_enc.train() # Change to training mode

# # Set parameters: learning rate, num_epochs
# learning_rate = 0.001
# num_epochs = 2

# # Train the model: SGD on encrypted data
# for i in range(num_epochs):

#     # forward pass
#     output = model_enc(x_train)
#     loss_value = criterion(output, y_train)
    
#     # set gradients to zero
#     model_enc.zero_grad()

#     # perform backward pass
#     loss_value.backward()

#     # update parameters
#     model_enc.update_parameters(learning_rate) 
    
#     # examine the loss after each epoch
#     print("Epoch: {0:d} Loss: {1:.4f}".format(i, loss_value.get_plain_text()))