In [None]:
from torchvision import datasets, transforms
import torch
from torch import nn
from torch import functional as F
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from opacus import PrivacyEngine
from torchvision import datasets, transforms
import os
import tenseal as ts
import syft as sy
sy.load("tenseal")
from collections import OrderedDict

In [None]:
# Instantiate the MNIST training split rooted at the working directory;
# download=True asks torchvision to fetch the files when they are missing.
train_dataset = datasets.MNIST(
    root=".",
    train=True,
    download=True,
)

In [None]:
# Statistics handed to transforms.Normalize for the MNIST images.
mnist_mean, mnist_std = 0.1307, 0.3081

# Optimisation settings read by train_model().
batch_size, epochs, lr = 64, 1, 0.1

# Settings forwarded to the Opacus PrivacyEngine (noise multiplier and
# per-sample gradient clipping bound).
sigma, max_per_sample_grad_norm = 1.0, 1.0
delta = 1e-5  # not referenced by any of the visible cells

# Dataset root, weights cache file name and compute device.
root, weights_filename = ".", "mnist_cnn_weights.pt"
device = torch.device("cpu")


class Model(nn.Module):
    """Small CNN classifier: two conv/pool stages followed by a two-layer head.

    The head uses a square activation (``x * x``) between ``fc1`` and ``fc2``
    instead of a ReLU. The layer attribute names (``conv1``, ``conv2``,
    ``fc1``, ``fc2``) determine the ``state_dict`` keys used elsewhere in the
    notebook.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=16, kernel_size=8, stride=2, padding=3
        )
        self.conv2 = nn.Conv2d(
            in_channels=16, out_channels=32, kernel_size=4, stride=2
        )
        self.fc1 = nn.Linear(in_features=32 * 4 * 4, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=10)

    def forward(self, x):
        """Map a batch of images to 10 class scores.

        Shape comments below hold for a ``[B, 1, 28, 28]`` input.
        """
        # conv1 + ReLU -> [B, 16, 14, 14]; 2x2 max-pool, stride 1 -> [B, 16, 13, 13]
        features = F.max_pool2d(F.relu(self.conv1(x)), 2, 1)
        # conv2 + ReLU -> [B, 32, 5, 5]; 2x2 max-pool, stride 1 -> [B, 32, 4, 4]
        features = F.max_pool2d(F.relu(self.conv2(features)), 2, 1)
        # Flatten to [B, 512] for the fully connected head.
        flat = features.view(-1, 32 * 4 * 4)
        hidden = self.fc1(flat)  # -> [B, 32]
        # Square activation, then the output layer -> [B, 10].
        return self.fc2(hidden * hidden)

In [None]:
def train(model, device, train_loader, optimizer, epoch):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: module to train; switched to training mode.
        device: device each batch is moved to before the forward pass.
        train_loader: iterable yielding ``(data, target)`` batches.
        optimizer: optimizer stepped once per batch.
        epoch: accepted for interface compatibility; not used in the body.

    Returns:
        list[float]: the cross-entropy loss of every batch, in iteration
        order. Previously these values were collected and then discarded.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    losses = []
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses

def test(model, device, test_loader):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()  # sum up batch loss
            pred = output.argmax(
                dim=1, keepdim=True
            )  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n".format(
            test_loss,
            correct,
            len(test_loader.dataset),
            100.0 * correct / len(test_loader.dataset),
        )
    )
    return correct / len(test_loader.dataset)


def train_model():
    """Train a fresh ``Model`` on MNIST with an Opacus engine attached.

    Reads the module-level settings (``device``, ``lr``, ``root``,
    ``batch_size``, ``epochs``, ``sigma``, ``max_per_sample_grad_norm``,
    ``mnist_mean``, ``mnist_std``, ``weights_filename``). After every epoch the
    model is evaluated on the MNIST test split; once training finishes the
    ``state_dict`` is saved to ``weights_filename``.

    Returns:
        The trained model.
    """
    model = Model().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0)

    # Identical preprocessing for both splits: tensor conversion, then
    # normalisation with the configured mean / std.
    preprocess = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((mnist_mean,), (mnist_std,)),
        ]
    )
    # Loader settings shared by the train and test loaders.
    loader_options = dict(shuffle=True, num_workers=1, pin_memory=True)

    train_set = datasets.MNIST(root, train=True, download=True, transform=preprocess)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, **loader_options
    )

    test_set = datasets.MNIST(root, train=False, transform=preprocess)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=1000, **loader_options
    )

    # Orders passed to the engine as `alphas`: 1.1 .. 10.9 in steps of 0.1,
    # followed by the integers 12 .. 63.
    orders = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64))
    privacy_engine = PrivacyEngine(
        model,
        batch_size=batch_size,
        sample_size=len(train_loader.dataset),
        alphas=orders,
        noise_multiplier=sigma,
        max_grad_norm=max_per_sample_grad_norm,
    )
    privacy_engine.attach(optimizer)

    for epoch_number in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch_number)
        test(model, device, test_loader)

    torch.save(model.state_dict(), weights_filename)
    return model

In [None]:
# Reuse cached weights when they exist; otherwise train from scratch.
# The existence check deliberately uses the very same path that train_model()
# saves to and torch.load() reads from, so the three can never disagree
# (the previous check prefixed `root`, which only matched while root == ".").
if os.path.isfile(weights_filename):
    model = Model().to(device)
    # map_location keeps the load working if the weights were saved on another device.
    model.load_state_dict(torch.load(weights_filename, map_location=device))
else:
    model = train_model()
# Inference only from here on.
model.eval()

In [None]:
# state_dict keys of the convolutional base (kept in a syft List so they can
# be sent over Duet) and of the fully connected head (kept as a plain list).
conv_base_params = sy.lib.python.List(
    ["conv1.weight", "conv1.bias", "conv2.weight", "conv2.bias"]
)
fc_head_params = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]

state_dict = model.state_dict()

# Convolutional tensors, in the same order as conv_base_params.
conv_base_weights = sy.lib.python.List()
for conv_key in conv_base_params:
    conv_base_weights.append(state_dict[conv_key])

# Head tensors keyed by their state_dict name, in fc_head_params order.
fully_connected_weights = OrderedDict(
    (fc_key, state_dict[fc_key]) for fc_key in fc_head_params
)

In [None]:
class FullyConnectedHead():
    """Evaluates the ``fc1 -> square -> fc2`` head on a batch of activations.

    The input only needs to support ``.mm(...)``, ``+`` with a tensor and an
    in-place ``.square_()``; the annotations name ``ts.CKKSTensor`` as the
    intended type. The annotations are strings so that defining the class
    does not require ``ts`` to be importable.
    """

    def __init__(self, parameters):
        """Store the four head tensors.

        Args:
            parameters: mapping with the keys ``"fc1.weight"``, ``"fc1.bias"``,
                ``"fc2.weight"`` and ``"fc2.bias"``.
        """
        self.fc1_weight = parameters["fc1.weight"]
        self.fc1_bias = parameters["fc1.bias"]
        self.fc2_weight = parameters["fc2.weight"]
        self.fc2_bias = parameters["fc2.bias"]

    def forward(self, enc_x: "ts.CKKSTensor", batch_size: int) -> "ts.CKKSTensor":
        """Apply the head to ``enc_x``.

        Args:
            enc_x: activation batch (assumed to have ``batch_size`` rows —
                the caller is responsible for keeping the two in sync).
            batch_size: number of rows each bias is replicated to.

        Returns:
            The result of ``((enc_x @ W1.T + b1) ** 2) @ W2.T + b2``.
        """
        rows = int(batch_size)
        # Replicate each bias once per row so it is added with an explicit
        # [rows, features] shape. Stacking a single copy equals unsqueeze(0),
        # so no special case for rows == 1 is required.
        fc1_bias = torch.stack([self.fc1_bias] * rows)
        fc2_bias = torch.stack([self.fc2_bias] * rows)

        out = enc_x.mm(self.fc1_weight.T) + fc1_bias
        out.square_()  # square activation, applied in place
        out = out.mm(self.fc2_weight.T) + fc2_bias
        return out

    def __call__(self, *args, **kwargs):
        """Forward every argument to :meth:`forward`."""
        return self.forward(*args, **kwargs)

# Head built from the fc1 / fc2 tensors extracted from the trained model.
fully_connected_head = FullyConnectedHead(parameters=fully_connected_weights)

In [None]:
# Join a Duet session. loopback=True is passed straight to sy.join_duet
# (presumably both peers run on the same machine — confirm against the syft docs).
duet = sy.join_duet(loopback=True)

In [None]:
# Send the convolutional-base parameter names and tensors to the duet peer,
# marked searchable and tagged so they can be looked up by tag.
conv_base_params_ptr_ = conv_base_params.send(duet, searchable=True, tags=["conv_base_names"])
conv_base_weights_ptr = conv_base_weights.send(duet, searchable=True, tags=["conv_base_weights"])

In [None]:
# Show what the duet store currently holds.
print(duet.store)

### <img src="https://github.com/OpenMined/design-assets/raw/master/logos/OM/mark-primary-light.png" alt="OpenMined logo" width="100"/> Checkpoint 1: Now STOP and run the Data Owner notebook until the next checkpoint.

In [None]:
# Fetch the three objects stored under "batch_size", "context" and
# "encrypted_activation" in the duet store. request_block=True is passed to
# every .get() (presumably it blocks until the request is answered — confirm).
# NOTE: this rebinds the module-level `batch_size` that train_model() reads,
# so re-running training after this cell would use the fetched value.
batch_size = duet.store["batch_size"].get(request_block=True)
print("Got the batch_size")
context = duet.store["context"].get(request_block=True)
print("Got the context")
encrypted_activation = duet.store["encrypted_activation"].get(request_block=True)
print("Got the encrypted_activation")
# Attach the fetched context to the fetched tensor before computing on it.
encrypted_activation.link_context(context)
print("Linked the context")

In [None]:
# Run the fully connected head on the activation fetched from the duet store.
encrypted_result = fully_connected_head(encrypted_activation, batch_size)

In [None]:
# Send the output of the head to the duet peer, searchable under the "result" tag.
encrypted_result_ptr = encrypted_result.send(duet, searchable=True, tags=["result"])
# NOTE(review): the original remark here read "comment this to hang the store";
# presumably printing the store at this point blocks, which is why the call
# below is left disabled — confirm before re-enabling.
# print(duet.store)

### <img src="https://github.com/OpenMined/design-assets/raw/master/logos/OM/mark-primary-light.png" alt="OpenMined logo" width="100"/> Checkpoint 2: Now STOP and run the Data Owner notebook until the next checkpoint.