In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import syft as sy
import numpy as np
import torch as th
from syft import VirtualMachine
from pathlib import Path
from torchvision import datasets, transforms
from syft.core.plan.plan_builder import PLAN_BUILDER_VM, make_plan, build_plan_inputs, ROOT_CLIENT
from syft.lib.python.collections.ordered_dict import OrderedDict
from syft.lib.python.list import List
from matplotlib import pyplot as plt
from syft import logger
from syft import SyModule, SySequential
logger.remove()

# Dataset

In [3]:
# Local cache directory for the MNIST files: <home>/.pysyft/mnist.
mnist_path = Path.home().joinpath(".pysyft", "mnist")
# Create the directory (and any missing parents); do nothing if it already exists.
mnist_path.mkdir(parents=True, exist_ok=True)

In [4]:
# Training split: downloaded into `mnist_path` when missing, converted to tensors
# and normalized with the constants (0.1307,) / (0.3081,).
mnist_train = datasets.MNIST(
    str(mnist_path),
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    ),
)

# Test split: read from the same root with the same preprocessing; no download
# is requested here.
mnist_test = datasets.MNIST(
    mnist_path,
    train=False,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    ),
)

In [5]:
# Shuffled MNIST loaders. NOTE: `train_loader` and `test_loader` are rebound to
# CIFAR-10 loaders further down in this notebook.
train_loader = th.utils.data.DataLoader(
    mnist_train,
    batch_size=64 * 3,
    shuffle=True,
    pin_memory=True,
)
test_loader = th.utils.data.DataLoader(
    mnist_test,
    batch_size=1024,
    shuffle=True,
    pin_memory=True,
)

# Define Plan

In [6]:
# !pip install timm
import timm

In [7]:
# Reference ResNet-18 created through timm with `pretrained=True`.
# Its state dict (minus the "fc." entries) is loaded into the SyModule model
# defined later in this notebook.
pretrained_model = timm.create_model('resnet18', pretrained=True)
# model = timm.create_model('resnet18d', pretrained=True)

In [8]:
class BasicBlock(SyModule):
    """Residual block: two 3x3 conv + batch-norm stages with a skip connection.

    Args:
        f_in: Number of input channels.
        f_out: Number of output channels of both convolutions.
        stride1: Stride of the first convolution; the second convolution uses
            the Conv2d default stride.
        downsample: When truthy, the skip connection goes through a 1x1
            convolution + batch norm so that its channel count and spatial
            size match the main branch. When falsy, the input is added as is.
        **kwargs: Forwarded to ``SyModule``; this notebook passes ``input_size``.
    """

    def __init__(self, f_in, f_out, stride1=1, downsample=False, **kwargs):
        super().__init__(**kwargs)

        # Main branch. Layers are created in the same order as before so that
        # parameter registration order is unchanged.
        self.conv1 = th.nn.Conv2d(f_in, f_out, kernel_size=(3, 3), stride=stride1, padding=(1, 1), bias=False)
        self.bn1 = th.nn.BatchNorm2d(f_out)
        self.act1 = th.nn.ReLU()
        self.conv2 = th.nn.Conv2d(f_out, f_out, kernel_size=(3, 3), padding=(1, 1), bias=False)
        self.bn2 = th.nn.BatchNorm2d(f_out)
        self.act2 = th.nn.ReLU()

        if downsample:
            # The shortcut must shrink the input exactly as conv1 does, so it
            # follows `stride1`. It used to be fixed at 2, which only matched
            # the main branch when stride1 == 2.
            self.downsample = SySequential(
                th.nn.Conv2d(f_in, f_out, kernel_size=(1, 1), stride=stride1, bias=False),
                th.nn.BatchNorm2d(f_out),
                input_size=self.input_size,
            )
        else:
            self.downsample = None

    def forward(self, x):
        """Apply the block to ``x`` and return the activated residual sum."""
        # NOTE: statements are kept in their original order and form because
        # this method is converted to a plan by SyModule.
        residual = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.bn2(x)

        if self.downsample is not None:
            # User-defined Sy modules are called with keywords and return a
            # sequence; element 0 is the output.
            residual = self.downsample(x=residual)[0]
        x += residual
        x = self.act2(x)
        return x
        

In [9]:
class ResNet18(SyModule):
    """ResNet-18 style network built from ``BasicBlock`` stages.

    Layout: 7x7 stride-2 stem convolution with batch norm, ReLU and max
    pooling; four stages (``layer1`` .. ``layer4``) of two ``BasicBlock``s each;
    global average pooling; a linear head with 10 outputs.

    Args:
        **kwargs: Forwarded to ``SyModule``; this notebook passes ``input_size``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Stem.
        self.conv1 = th.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.bn1 = th.nn.BatchNorm2d(64)
        self.act1 = th.nn.ReLU()
        self.maxpool = th.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # One row per stage:
        # (in channels, out channels, input_size of block 1, input_size of block 2).
        # NOTE(review): the input sizes are hard-coded; presumably they are only
        # the dummy shapes handed to each BasicBlock - confirm against SyModule.
        stage_table = [
            (64, 64, (2, 64, 7, 7), (2, 64, 7, 7)),
            (64, 128, (2, 64, 7, 7), (2, 128, 7, 7)),
            (128, 256, (2, 128, 4, 4), (2, 256, 4, 4)),
            (256, 512, (2, 256, 2, 2), (2, 512, 2, 2)),
        ]
        for stage_no, (ch_in, ch_out, size_block1, size_block2) in enumerate(stage_table, start=1):
            is_first_stage = stage_no == 1
            # Every stage except the first starts with a strided block that
            # also projects the skip connection.
            entry_block = BasicBlock(
                f_in=ch_in,
                f_out=ch_out,
                downsample=not is_first_stage,
                stride1=1 if is_first_stage else 2,
                input_size=size_block1,
            )
            follow_block = BasicBlock(f_in=ch_out, f_out=ch_out, input_size=size_block2)
            setattr(self, f"layer{stage_no}", SySequential(entry_block, follow_block))

        # Head.
        self.global_pool = th.nn.AdaptiveAvgPool2d(1)
        self.fc = th.nn.Linear(in_features=512, out_features=10, bias=True)

    def forward(self, x):
        """Run stem, the four stages and the head; return the output of ``fc``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)

        # layer1..layer4 are user-defined Sy modules: they are called with
        # keyword arguments and return a sequence, hence `(x=x)[0]`.
        x = self.layer1(x=x)[0]
        x = self.layer2(x=x)[0]
        x = self.layer3(x=x)[0]
        x = self.layer4(x=x)[0]
        x = self.global_pool(x).flatten(1)
        x = self.fc(x)
        return x
        
        

In [10]:
# Instantiate the SyModule network. `input_size` is forwarded to SyModule
# (presumably the dummy input shape used to build the forward plan - TODO confirm).
model = ResNet18(input_size=(2,3,32,32))

In [11]:
# Keep every pretrained entry except those whose key contains "fc.": the
# classifier head of this notebook's model is not loaded from the checkpoint.
state_dict = {
    key: value
    for key, value in pretrained_model.state_dict().items()
    if "fc." not in key
}

In [12]:
# strict=False because the "fc." entries were filtered out of `state_dict`,
# so the model's own `fc` layer keeps its initial weights.
# The trailing semicolon suppresses the cell's output.
model.load_state_dict(state_dict, strict=False);

## Data

In [13]:
# Local cache directory for CIFAR-10: <home>/.pysyft/cifar10.
cifar10_path = Path.home().joinpath(".pysyft", "cifar10")
cifar10_path.mkdir(parents=True, exist_ok=True)

# Pair of per-channel tuples unpacked into transforms.Normalize below.
norm = (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)

# Training split: random horizontal flips, then tensor conversion and normalization.
cifar_train = datasets.CIFAR10(
    cifar10_path,
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(*norm),
        ]
    ),
)

# Test split: no augmentation, same normalization; no download is requested here.
cifar_test = datasets.CIFAR10(
    cifar10_path,
    train=False,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(*norm),
        ]
    ),
)

Files already downloaded and verified


In [14]:
# CIFAR-10 loaders; these rebind the `train_loader` / `test_loader` names
# created earlier for MNIST.
train_loader = th.utils.data.DataLoader(
    cifar_train,
    batch_size=64,
    shuffle=True,
    pin_memory=True,
)
test_loader = th.utils.data.DataLoader(
    cifar_test,
    batch_size=1024,
    shuffle=True,
    pin_memory=True,
)

## Plan

In [15]:
# `torch` as exposed by the plan builder's ROOT_CLIENT; the `train` plan uses it
# for the optimizer and the loss instead of the local `th` module.
remote_torch = ROOT_CLIENT.torch
# One batch taken from `train_loader`, wrapped in a Syft List. It is the default
# value of the `dl` argument of the `train` plan.
dummy_dl = sy.lib.python.List([next(iter(train_loader))])

In [16]:
@make_plan
def train(dl = dummy_dl, model=model):
    """Plan that runs one optimizer step per batch in ``dl`` and returns the model.

    Args:
        dl: Iterable of batches; each batch is indexed as ``xy[0]`` (inputs)
            and ``xy[1]`` (targets). Defaults to ``dummy_dl``.
        model: Module called as ``model(x=x)``; element 0 of the result is fed
            to the cross-entropy loss. Defaults to the notebook's ``model``.

    Returns:
        A one-element list containing ``model``.
    """
    
    # A new AdamW optimizer is created on every call, so optimizer state is not
    # carried over between calls.
    optimizer = remote_torch.optim.AdamW(model.parameters())

    for xy in dl:
        optimizer.zero_grad()
        x, y = xy[0], xy[1]
        # User-defined Sy modules return a sequence; element 0 is the output.
        out = model(x=x)[0]
        loss = remote_torch.nn.functional.cross_entropy(out, y)
        loss.backward()
        optimizer.step()
    
    return [model]

In [17]:
def test(test_loader, model):
    correct = []
    model.eval()

    for data, target in test_loader:        
        output = model(x=data)[0]
        _, pred = th.max(output, 1)
        correct.append(th.sum(np.squeeze(pred.eq(target.data))))
    acc = sum(correct) / len(test_loader.dataset)
    return acc

In [18]:
# Create a VirtualMachine named "alice", obtain a client for it and send the
# `train` plan there. `train_ptr` is what the training loop calls.
alice_client = VirtualMachine(name="alice").get_client()
train_ptr = train.send(alice_client)

The expected accuracy for this pretrained model:

| iter          | Test acc |
|---------------|----------|
| 10            | 10%      |
| 100           | 25%      |
| 200           | 43%      |
| 300           | 54%      |
| 782 (1 epoch) | 70%      |

Currently, training is very slow because the model needs to be serialized & deserialized on every call to the plan.

In [None]:
# One call to the plan per batch: the current model is passed in and the
# updated model is fetched back and rebound to `model`.
for i, (x, y) in enumerate(train_loader):
    print(f"iter {i}")

    # The plan iterates over `dl`, so a single batch is wrapped in a list.
    dl = [[x, y]]
    res_ptr = train_ptr(dl=dl, model=model)
    [model] = res_ptr.get()

    # Evaluate locally on every 10th iteration, skipping iteration 0.
    if i > 0 and i % 10 == 0:
        print(f"Iter: {i} Test accuracy: {test(test_loader, model):.2F}", flush=True)

    # The check runs after the step, so iteration 51 is the last one executed.
    if i > 50:
        break

iter 0
iter 1
iter 2
iter 3
iter 4
iter 5
iter 6
iter 7
iter 8
iter 9
iter 10


  grad = getattr(obj, "grad", None)


Iter: 10 Test accuracy: 0.37
iter 11
iter 12
iter 13
iter 14
iter 15


# Appendix

In [66]:
# layer1 = SySequential(
#     BasicBlock(f_in=64, f_out=64, stride1=1, input_size=(2, 64, 7, 7)),
#     BasicBlock(f_in=64, f_out=64, input_size=(2, 64, 7, 7))
# )


In [22]:
# layer2 = SySequential(
#     BasicBlock(f_in=64, f_out=128, stride1=2, downsample=True, input_size=(1, 64, 7, 7)),
#     BasicBlock(f_in=128, f_out=128, input_size=(1, 128, 7, 7))
# )

In [23]:
# layer3 = SySequential(
#     BasicBlock(f_in=128, f_out=256, stride1=2, downsample=True, input_size=(1, 128, 4, 4)),
#     BasicBlock(f_in=256, f_out=256, input_size=(1, 256, 4, 4))
# )

In [24]:
# layer4 = SySequential(
#     BasicBlock(f_in=256, f_out=512, stride1=2, downsample=True, input_size=(2, 256, 2, 2)),
#     BasicBlock(f_in=512, f_out=512, input_size=(2, 512, 2, 2))
# )

In [None]:
# # Cell
# class XResNet(nn.Sequential):
#     def __init__(self, block, expansion, layers, p=0.0, c_in=3, n_out=1000, stem_szs=(32,32,64),
#                  widen=1.0, sa=False, act_cls=defaults.activation, ndim=2, ks=3, stride=2, **kwargs):
#         store_attr('block,expansion,act_cls,ndim,ks')
#         if ks % 2 == 0: raise Exception('kernel size has to be odd!')
#         stem_szs = [c_in, *stem_szs]
#         stem = [ConvLayer(stem_szs[i], stem_szs[i+1], ks=ks, stride=stride if i==0 else 1,
#                           act_cls=act_cls, ndim=ndim) for i in range(3)]

#         block_szs = [int(o*widen) for o in [64,128,256,512] +[256]*(len(layers)-4)]
#         block_szs = [64//expansion] + block_szs
#         blocks    = self._make_blocks(layers, block_szs, sa, stride, **kwargs)

#         super().__init__(
#             *stem, MaxPool(ks=ks, stride=stride, padding=ks//2, ndim=ndim),
#             *blocks,
#             AdaptiveAvgPool(sz=1, ndim=ndim), Flatten(), nn.Dropout(p),
#             nn.Linear(block_szs[-1]*expansion, n_out),
#         )
#         init_cnn(self)

#     def _make_blocks(self, layers, block_szs, sa, stride, **kwargs):
#         return [self._make_layer(ni=block_szs[i], nf=block_szs[i+1], blocks=l,
#                                  stride=1 if i==0 else stride, sa=sa and i==len(layers)-4, **kwargs)
#                 for i,l in enumerate(layers)]

#     def _make_layer(self, ni, nf, blocks, stride, sa, **kwargs):
#         return nn.Sequential(
#             *[self.block(self.expansion, ni if i==0 else nf, nf, stride=stride if i==0 else 1,
#                       sa=sa and i==(blocks-1), act_cls=self.act_cls, ndim=self.ndim, ks=self.ks, **kwargs)
#               for i in range(blocks)])
