In [None]:
import syft as sy
import copy

In [None]:
# Send syft's log output to a file *before* connecting, so log records emitted
# while joining the first Duet are written to the file too.
sy.logger.add(sink="./syft_ds.log")
# Connect to the first Data Owner. The id string is a placeholder: replace it
# with the Duet id for that session.
duet1 = sy.join_duet("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")

In [None]:
# Display the `torch` attribute exposed through the first Duet connection;
# this cell is evaluated only for its notebook output.
duet1.torch

In [None]:
# Connect to the second Data Owner. The id string is a placeholder: replace it
# with the Duet id for that session.
duet2 = sy.join_duet("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")

In [None]:
import torch
import torchvision

In [None]:
# Settings adapted from the original MNIST example's command-line arguments,
# plus the federated-learning knobs ("images", "clients", "rounds").
args = dict(
    images=60000,          # divisor for each client's averaging weight ("samples")
    clients=2,             # how many Duet connections become clients
    rounds=4,              # number of federated averaging rounds
    batch_size=64,         # remote training batch size
    test_batch_size=1000,  # local test-set batch size
    epochs=4,              # training epochs per client per round
    lr=1.0,                # Adadelta learning rate
    gamma=0.7,             # StepLR decay factor
    no_cuda=False,         # NOTE(review): not read by any visible cell
    dry_run=False,         # stop the train/test loops early for a quick smoke run
    torch_seed=0,          # seed used before building the local global model
    log_interval=10,       # fetch and print the loss every N training batches
    save_model=True,       # NOTE(review): not read by any visible cell
)

In [None]:
class SyNet(sy.Module):
    """Convolutional MNIST classifier written against a pluggable torch namespace.

    Every layer and functional op is looked up through ``self.torch_ref``, so the
    same definition is built with the local ``torch`` module in this notebook and
    with a Duet client's ``torch`` attribute elsewhere.
    """

    def __init__(self, torch_ref):
        super().__init__(torch_ref=torch_ref)
        nn = self.torch_ref.nn
        # Layers are assigned in the same order as the reference definition.
        # NOTE(review): attribute order may determine state_dict key order; keep it.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return log-probabilities over the 10 classes for input batch ``x``."""
        F = self.torch_ref.nn.functional
        # conv -> relu -> conv -> relu -> max-pool -> dropout
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        # flatten -> fc -> relu -> dropout -> fc -> log_softmax
        hidden = F.relu(self.fc1(self.torch_ref.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)

In [None]:
# One dict of per-client state for each Duet connection opened above.
# The connections are listed explicitly instead of looked up by name via eval().
duets = [duet1, duet2]
if args['clients'] > len(duets):
    raise ValueError(
        f"args['clients'] is {args['clients']} but only {len(duets)} Duet connections exist"
    )
clients = [{'duet': duet} for duet in duets[:args['clients']]]

In [None]:
# Build the global model locally; seeding right before construction makes its
# initial weights reproducible.
torch.manual_seed(seed=args['torch_seed'])
local_model = SyNet(torch_ref=torch)

In [None]:
# Download the MNIST archive manually with wget, then extract it in the working
# directory. `-nc` (no-clobber) skips the download when MNIST.tar.gz already
# exists, so re-running the cell does not pile up MNIST.tar.gz.1, .2, ...
# tar runs without -v to keep the cell output short.
!wget -nc www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxf MNIST.tar.gz

In [None]:
# Local preprocessing for the MNIST test set: convert PIL images to tensors,
# then normalize with mean 0.1307 and std 0.3081.
local_transform_1, local_transform_2 = (
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(0.1307, 0.3081),
)

# Apply both steps in sequence.
local_transforms = torchvision.transforms.Compose(
    [local_transform_1, local_transform_2]
)

In [None]:
# The benchmark test set stays on our side. Evaluating locally tells us whether
# the Data Owners' private training data improves results on it.
test_kwargs = dict(batch_size=args["test_batch_size"])

test_data = torchvision.datasets.MNIST(
    './', train=False, download=True, transform=local_transforms
)
test_loader = torch.utils.data.DataLoader(test_data, **test_kwargs)

# Size of the test set, used later to decide how many batches to evaluate.
test_data_length = len(test_loader.dataset)
print(test_data_length)

In [None]:
# Give every client a handle to its Data Owner's torch and a local model.
# All client models start from the *same* weights as the global model:
# federated averaging assumes a shared initialisation. Averaging networks that
# were initialised independently (as per-client seeds would produce) does not,
# in general, give a useful first-round global model. This matches how the
# training loop re-synchronises clients after each round.
for client in clients:
    client['remote_torch'] = client['duet'].torch
    client['model'] = SyNet(torch)
    client['model'].load_state_dict(copy.deepcopy(local_model.state_dict()))

In [None]:
# The Data Owner is NOT queried for CUDA here: the flag is hard-coded to False.
# NOTE(review): no visible cell reads `has_cuda`.
has_cuda = False
print(has_cuda)

In [None]:
def train(client, epoch, args):
    """Run one epoch of remote training for a single federated client.

    Parameters
    ----------
    client : dict
        Per-client state. Reads ``'train_data_length'``, ``'remote_model'``,
        ``'train_loader_ptr'``, ``'optim'``, ``'remote_torch'`` and ``'duet'``.
    epoch : int
        Epoch number, used only in log lines.
    args : dict
        Reads ``"batch_size"``, ``"log_interval"`` and ``"dry_run"``.

    Returns
    -------
    None
        Prints a message and returns without training when
        ``client['remote_model'].is_local`` is true.
    """
    import math

    # Real ceiling division. The previous round(x + 0.5) used banker's rounding,
    # so an exact odd quotient (e.g. 64 / 64 -> round(1.5) == 2) over-counted by one.
    train_batches = math.ceil(client['train_data_length'] / args["batch_size"])
    print(f"> Running train in {train_batches} batches")
    if client['remote_model'].is_local:
        print("Training requires remote model")
        return

    client['remote_model'].train()
    # Remote running total of the loss. It is created once per epoch; creating it
    # inside the loop reset it every batch and cost an extra remote op per batch.
    train_loss = client['duet'].python.Float(0)

    for batch_idx, data in enumerate(client['train_loader_ptr']):
        # Index rather than unpack: `data` may be a remote pointer.
        data_ptr, target_ptr = data[0], data[1]
        client['optim'].zero_grad()
        output = client['remote_model'](data_ptr)
        loss = client['remote_torch'].nn.functional.nll_loss(output, target_ptr)
        loss.backward()
        client['optim'].step()
        loss_item = loss.item()
        train_loss += loss_item
        if batch_idx % args["log_interval"] == 0:
            # Blocking request for the loss value; it stays None if the
            # request is not answered.
            local_loss = loss_item.get(
                name="loss",
                reason="To evaluate training progress",
                request_block=True,
                timeout_secs=5
            )
            if local_loss is not None:
                print("Train Epoch: {} {} {:.4}".format(epoch, batch_idx, local_loss))
            else:
                print("Train Epoch: {} {} ?".format(epoch, batch_idx))
            if args["dry_run"]:
                break
        # Stop explicitly once the expected number of batches has been processed.
        if batch_idx >= train_batches - 1:
            print("batch_idx >= train_batches, breaking")
            break

In [None]:
def test_local(model, test_loader, test_data_length):
    current_model = None
    # download remote model
    if not model.is_local:
        current_model = model.get(
            request_block=True,
            name="model_download",
            reason="test evaluation",
            timeout_secs=5
        )
    else:
        current_model = model
    # + 0.5 lets us math.ceil without the import
    test_batches = round((test_data_length / args["test_batch_size"]) + 0.5)
    print(f"> Running test_local in {test_batches} batches")
    current_model.eval()
    test_loss = 0.0
    correct = 0.0

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            output = current_model(data)
            iter_loss = torch.nn.functional.nll_loss(output, target, reduction="sum").item()
            test_loss = test_loss + iter_loss
            pred = output.argmax(dim=1)
            total = pred.eq(target).sum().item()
            correct += total
            if args["dry_run"]:
                break
                
            if batch_idx >= test_batches - 1:
                print("batch_idx >= test_batches, breaking")
                break

    accuracy = correct / test_data_length
    print(f"Test Set Accuracy: {100 * accuracy}%")

In [None]:
def averageModels(global_model, clients):
    """Federated averaging: overwrite ``global_model`` with the weighted mean
    of the client models' state dicts.

    Each client's weight is its ``'samples'`` value divided by the total over
    all clients, so the weights sum to 1. Using the raw ``'samples'`` values
    produced a weighted *sum* whenever they did not add up to 1 (e.g. two
    clients each reporting 1.0).

    Parameters
    ----------
    global_model : module with ``state_dict`` / ``load_state_dict``
        Updated in place.
    clients : list of dict
        Each dict provides ``'model'`` (same architecture as ``global_model``)
        and ``'samples'`` (non-negative relative weight).

    Returns
    -------
    The same ``global_model`` object. It is returned unchanged if ``clients``
    is empty.

    Raises
    ------
    ValueError
        If the client weights do not sum to a positive value.
    """
    if not clients:
        return global_model

    # Fetch each client's state dict once rather than once per parameter key.
    client_states = [client['model'].state_dict() for client in clients]
    samples = [float(client['samples']) for client in clients]
    total = sum(samples)
    if total <= 0:
        raise ValueError("client 'samples' weights must sum to a positive value")
    weights = [s / total for s in samples]

    global_dict = global_model.state_dict()
    for key in list(global_dict.keys()):
        weighted = [state[key].float() * w for state, w in zip(client_states, weights)]
        # Average in float, then cast back so every entry keeps its original dtype.
        global_dict[key] = torch.stack(weighted, 0).sum(0).to(global_dict[key].dtype)

    global_model.load_state_dict(global_dict)
    return global_model

In [None]:
# Each Data Owner lets us build a DataLoader over their MNIST training set.
# The transforms, dataset and loader are all created through the client's Duet
# connection (its `torchvision` / `torch` attributes).
train_kwargs = {"batch_size": args["batch_size"]}

for client in clients:
    remote_tv = client['duet'].torchvision
    client['remote_torchvision'] = remote_tv

    to_tensor = remote_tv.transforms.ToTensor()
    normalize = remote_tv.transforms.Normalize(0.1307, 0.3081)

    # Collect the transforms in a List created through the Duet, then compose it.
    pipeline = client['duet'].python.List()
    pipeline.append(to_tensor)
    pipeline.append(normalize)
    client['remote_list'] = pipeline

    client['transforms'] = remote_tv.transforms.Compose(pipeline)
    client['train_data_ptr'] = remote_tv.datasets.MNIST(
        './', train=True, download=True, transform=client['transforms']
    )
    client['train_loader_ptr'] = client['remote_torch'].utils.data.DataLoader(
        client['train_data_ptr'], **train_kwargs
    )

In [None]:
# Normally we would not know the length of a remote dataset, so we ask for it.
# The training loop needs it to know when to stop.
def get_train_length(train_data_ptr):
    """Ask the Data Owner for the length of a remote training dataset.

    Parameters
    ----------
    train_data_ptr : remote dataset pointer
        Its ``__len__()`` must return an object supporting ``.get(...)``.

    Returns
    -------
    The dataset length as returned by the Data Owner.

    Raises
    ------
    RuntimeError
        If the blocking request returns None, presumably because it was denied
        or not answered within the timeout. Without this check the None would
        only fail later with an unrelated TypeError.
    """
    # Call __len__ directly: on a pointer it yields a pointer to the remote length
    # (fetched with .get below) rather than an int.
    train_length_ptr = train_data_ptr.__len__()
    train_data_length = train_length_ptr.get(
        request_block=True,
        name="train_size",
        reason="To write the training loop",
        timeout_secs=5,
    )
    if train_data_length is None:
        raise RuntimeError("Could not obtain the remote training set length")
    return train_data_length


for client in clients:
    n_train = get_train_length(client['train_data_ptr'])
    client['train_data_length'] = n_train
    # This client's size relative to args['images']; used as its averaging weight.
    client['samples'] = n_train / args['images']
    print(f"Training Dataset size is: {n_train}")

In [None]:
%%time
import time
args["dry_run"] = True  # comment to do a full train
print("Starting Training")

for fed_round in range(args['rounds']):
    for client in clients:
        # Send the client's current model to its Data Owner and build a fresh
        # remote optimizer and LR scheduler for this round.
        client['remote_model'] = client['model'].send(client['duet']).cpu()
        client['optim'] = client['remote_torch'].optim.Adadelta(client['remote_model'].parameters(), lr=args['lr'])
        client['sched'] = client['remote_torch'].optim.lr_scheduler.StepLR(client['optim'], step_size=1, gamma=args['gamma'])

        # Train the client for all args["epochs"] epochs. There is no
        # unconditional break after the first epoch. In a dry run, train() still
        # stops each epoch after its first batch.
        for epoch in range(1, args["epochs"] + 1):
            epoch_start = time.time()
            print(f"Epoch: {epoch}")
            # remote training on model with remote_torch
            train(client, epoch, args)
            client['sched'].step()
            epoch_end = time.time()
            print(f"Epoch time: {int(epoch_end - epoch_start)} seconds")

        # Get the client model back for averaging.
        # NOTE(review): the request reason says "test evaluation" although the
        # model is downloaded for averaging. It is kept as-is in case the Data
        # Owner's request handlers match on it.
        downloaded = client['remote_model'].get(
            request_block=True,
            name="model_download",
            reason="test evaluation",
            timeout_secs=5
        )
        if downloaded is None:
            raise RuntimeError("Could not download a client model for averaging")
        client['model'] = downloaded

    # Average all the clients
    local_model = averageModels(local_model, clients)

    # local testing on model with local torch
    test_local(local_model, test_loader, test_data_length)

    # Share the global model with the clients
    for client in clients:
        client['model'].load_state_dict(copy.deepcopy(local_model.state_dict()))

print("Finished Training")