## 0 - Pull in dependencies and initiate Duet session

In [None]:
# First, let's pull in our dependencies and initiate our Duet session.
# NOTE(review): `random`, `nn` and `optim` are not referenced by the cells in
# this notebook; kept in case other cells/extensions rely on them.
import torch
import random
from torch import nn
from torch import optim
import syft as sy

# Join the Duet session; `duet` is the handle used below to send tensors and
# models and to reach the peer's torch namespace (`duet.torch`).
# NOTE(review): loopback=True presumably joins a Duet hosted on this same
# machine (e.g. by a companion data-owner notebook) -- confirm in syft docs.
duet = sy.join_duet(loopback=True)

## 1 - Define and send our remote assets

Here we'll define the remote model segment, which will have the remote input data fed into it, and send the data it needs. This includes:

- Our first model segment
- Our MNIST training data (sent to the Duet peer batch by batch)

In [None]:
import torchvision


# MNIST preprocessing pipeline:
#   1. ToTensor  -- PIL image -> float tensor
#   2. Normalize -- shift/scale with mean 0.1307 and std 0.3081
local_transform_1 = torchvision.transforms.ToTensor()
local_transform_2 = torchvision.transforms.Normalize(0.1307, 0.3081)
local_transforms = torchvision.transforms.Compose(
    [local_transform_1, local_transform_2]
)

# Hyper-parameters and run options shared by the cells below.
# NOTE(review): several entries (e.g. seed, no_cuda, log_interval, save_model,
# dry_run) are not read by any cell in this notebook.
args = dict(
    batch_size=64,
    test_batch_size=1000,
    epochs=8,
    lr=1.0,
    gamma=0.7,
    no_cuda=False,
    dry_run=False,
    seed=42,  # the meaning of life
    log_interval=10,
    save_model=True,
)

In [None]:
from syft.util import get_root_data_path

# DataLoader options for the training split.
train_kwargs = dict(batch_size=args["batch_size"])

train_data = torchvision.datasets.MNIST(
    str(get_root_data_path()),
    train=True,
    download=True,
    transform=local_transforms,
)
train_loader = torch.utils.data.DataLoader(train_data, **train_kwargs)

# Send every image batch to the Duet peer. The matching labels stay local,
# because the loss (and model2) are computed on this side.
# NOTE(review): the whole training split is sent even when the training cell
# runs in dry-run mode and only uses the first few batches.
data_pointer, labels = [], []
for image, label in train_loader:
    data_pointer.append(image.send(duet))
    labels.append(label)

In [None]:
# Handle to the peer's torch namespace.
# NOTE(review): not referenced by later cells, which use `duet.torch` directly.
remote_torch = duet.torch


# Subclassing sy.Module (not nn.Module) is what lets syft serialise the
# model so it can be sent to the Duet peer.
class SyNet1(sy.Module):
    """First (remote) model segment: two 3x3 convolutions, then 2x2 max-pool.

    ``torch_ref`` is the torch namespace the layers are built from, and the
    functional calls in ``forward`` also go through ``self.torch_ref``.
    """

    def __init__(self, torch_ref):
        super().__init__(torch_ref=torch_ref)
        layers = self.torch_ref.nn
        self.conv1 = layers.Conv2d(1, 32, 3, 1)
        self.conv2 = layers.Conv2d(32, 64, 3, 1)

    def forward(self, x):
        functional = self.torch_ref.nn.functional
        hidden = functional.relu(self.conv1(x))
        hidden = functional.relu(self.conv2(hidden))
        return functional.max_pool2d(hidden, 2)
    
# Build the first segment locally, then send it to the Duet peer.
model1 = SyNet1(torch)
model1_ptr = model1.send(duet)

# Optimizer and scheduler are created through duet.torch and given the remote
# parameters, so they presumably live (and step) on the peer's side.
# Bug fix: the learning rate must come from args; `["lr"]` was a list literal.
# NOTE(review): args["lr"] (1.0) is the value model2 uses with Adadelta; for
# SGD that is a large step size -- confirm it is intended.
opt1 = duet.torch.optim.SGD(params=model1_ptr.parameters(), lr=args["lr"])
sch1 = duet.torch.optim.lr_scheduler.StepLR(opt1, step_size=1, gamma=args["gamma"])

In [None]:
class SyNet2(sy.Module):
    """Second (local) model segment: dropout, flatten, two linear layers.

    Consumes the activation produced by SyNet1 and returns log-probabilities
    over 10 classes (``log_softmax`` over dim 1), suitable for ``nll_loss``.
    """

    def __init__(self, torch_ref):
        super(SyNet2, self).__init__(torch_ref=torch_ref)
        # Channel-wise dropout on the conv feature map coming from SyNet1.
        self.dropout1 = self.torch_ref.nn.Dropout2d(0.25)
        # dropout2 is applied to the 2-D (batch, 128) output of fc1, so plain
        # element-wise Dropout is the correct layer. Dropout2d expects 3-D/4-D
        # input and emits a deprecation warning on 2-D tensors.
        self.dropout2 = self.torch_ref.nn.Dropout(0.5)
        # 9216 = 64 * 12 * 12: SyNet1's output for 28x28 MNIST images, flattened.
        self.fc1 = self.torch_ref.nn.Linear(9216, 128)
        self.fc2 = self.torch_ref.nn.Linear(128, 10)

    def forward(self, x):
        x = self.dropout1(x)
        x = self.torch_ref.flatten(x, 1)
        x = self.fc1(x)
        x = self.torch_ref.nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = self.torch_ref.nn.functional.log_softmax(x, dim=1)
        return output
    
# The second segment stays local, so it uses a plain torch optimizer and
# scheduler (not duet.torch ones).
model2 = SyNet2(torch)
opt2 = torch.optim.Adadelta(params=model2.parameters(), lr=args["lr"])
sch2 = torch.optim.lr_scheduler.StepLR(
    optimizer=opt2,
    step_size=1,
    gamma=args["gamma"],
)

## 3 - Define our training logic

In [None]:
# Smoke-test switch: when True, train on at most 11 batches of one epoch.
# NOTE(review): this deliberately ignores args["dry_run"]; set it to
# args["dry_run"] for a full training run.
dry_run = True

for epoch in range(args["epochs"]):
    for batch_idx, (batch_ptr, target) in enumerate(zip(data_pointer, labels)):
        if dry_run and batch_idx > 10:
            break
        opt1.zero_grad()
        opt2.zero_grad()

        # Forward through the remote segment, then fetch a *clone* of the
        # activation; the original pointer stays remote so backward can be
        # continued from it below.
        activation_ptr = model1_ptr(batch_ptr)
        activation = activation_ptr.clone().get(request_block=True, reason="process the model")
        activation.retain_grad()

        # Finish the forward pass and compute the loss locally.
        pred = model2(activation)
        loss = torch.nn.functional.nll_loss(pred, target)
        loss.backward()

        # Ship d(loss)/d(activation) back and backpropagate through model1.
        grad_ptr = activation.grad.clone().send(duet)
        activation_ptr.backward(grad_ptr)

        opt1.step()
        opt2.step()
    print(f"Epoch {epoch} Loss: {loss.item()}")
    if dry_run:
        break
    # Decay both learning rates once per completed epoch (StepLR, step_size=1).
    # Without these calls the schedulers created above would have no effect.
    sch1.step()
    sch2.step()

## 4 - Pull in our Test Set

In [None]:
test_kwargs = {
    "batch_size": args["test_batch_size"],
}

test_data = torchvision.datasets.MNIST(str(get_root_data_path()), train=False, download=True, transform=local_transforms)
# Bug fix: batch the test split with test_kwargs (args["test_batch_size"]),
# not the training options. The test loop's batch count is derived from
# args["test_batch_size"].
test_loader = torch.utils.data.DataLoader(test_data, **test_kwargs)

In [None]:
# Pull the trained first segment back from the Duet peer so both segments can
# be evaluated locally. request_block=True presumably waits for the data
# owner to approve the request; `reason` is the message shown with it.
model1 = model1_ptr.get(request_block=True, reason="run testing on the model")

## 5 - Test our Model

In [None]:
# test_data_length = len(test_loader.dataset)
# test_batches = round((test_data_length / args["test_batch_size"]) + 0.5)
# test_loss = 0.0
# correct = 0.0
    
# for batch_idx, (data, target) in enumerate(test_loader):
#     output = model2(model1(data))
#     iter_loss = torch.nn.functional.nll_loss(output, target, reduction="sum").item()
#     test_loss = test_loss + iter_loss
#     pred = output.argmax(dim=1)
#     total = pred.eq(target).sum().item()
#     correct += total
            
#     if batch_idx >= test_batches - 1:
#         print("batch_idx >= test_batches, breaking")
#         break
#
# accuracy = correct / test_data_length
# print(f"Test Set Accuracy: {100 * accuracy}%")