In [None]:
import syft as sy
from syft.core.remote_dataloader import RemoteDataset
from syft.core.remote_dataloader import RemoteDataLoader
import torch
import torchvision
import time
import matplotlib.pyplot as plt

In [None]:
# Join the Duet session. The returned handle is used by the cells below to reach
# the peer's object store (duet.store) and its remote modules (duet.torch, ...).
# NOTE(review): loopback=True presumably targets a Duet peer on the same machine — confirm.
duet = sy.join_duet(loopback=True)

## Run client until end

In [None]:
# Look up the object the peer stored under the "meta" key; it is handed to
# RemoteDataset in the cells below.
meta_ptr = duet.store["meta"]

In [None]:
# Bare expression so the notebook displays the store contents as cell output.
# NOTE(review): presumably a tabular (pandas) view of the stored objects — confirm.
duet.store.pandas

In [None]:
# Build a RemoteDataset on the remote side from the stored metadata pointer.
# NOTE(review): rds_ptr / rdl_ptr are not referenced again in the visible part of
# this notebook (training uses train_loader_ptr) — confirm whether this cell is still needed.
rds_ptr = duet.syft.core.remote_dataloader.RemoteDataset(meta_ptr)
# Build a RemoteDataLoader on the remote side that wraps the RemoteDataset.
rdl_ptr = duet.syft.core.remote_dataloader.RemoteDataLoader(rds_ptr, batch_size=32)
# call load_dataset to create the real Dataset object on remote side
rdl_ptr.load_dataset()
# call create_dataloader to create the real DataLoader object on remote side
rdl_ptr.create_dataloader()

### Create the Model in remote

This model is defined here and must be sent to the client site for training. The only disadvantage I see is that the client must provide the computational resources, which becomes costly when training a bigger model.

When defining the model, we must inherit from `sy.Module` and pass a torch reference (`torch_ref`) to its constructor.

In [None]:
class MedMNISTModel(sy.Module):
    """Small CNN classifier built on a swappable torch reference.

    Every layer and op is reached through ``self.torch_ref`` rather than the
    module-level ``torch`` import, so the same definition works with whichever
    torch reference the instance holds.

    Architecture:
        * ``conv_head``: three Conv2d(3x3) -> ReLU -> MaxPool2d(2) stages with
          3 -> 32 -> 64 -> 64 channels.
        * ``classification_head``: 2304 -> 128 -> 128 -> 6 fully connected layers.

    ``forward`` returns log-probabilities (``log_softmax`` over the class
    dimension), which is what ``nll_loss`` consumes and what makes
    ``exp(output)`` a probability vector.
    """

    def __init__(self, torch_ref):
        """Build the layers.

        Args:
            torch_ref: the torch module (or equivalent reference) used to
                construct layers and run ops.
        """
        super().__init__(torch_ref=torch_ref)
        self.conv_head = self.torch_ref.nn.Sequential(
            self.torch_ref.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3),
            self.torch_ref.nn.ReLU(),
            self.torch_ref.nn.MaxPool2d(kernel_size=2, stride=2),
            self.torch_ref.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            self.torch_ref.nn.ReLU(),
            self.torch_ref.nn.MaxPool2d(kernel_size=2, stride=2),
            self.torch_ref.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),
            self.torch_ref.nn.ReLU(),
            self.torch_ref.nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # 2304 = 64 channels * 6 * 6 spatial positions after conv_head.
        # NOTE(review): that corresponds to 3x64x64 inputs — confirm against the dataset.
        self.classification_head = self.torch_ref.nn.Sequential(
            self.torch_ref.nn.Linear(in_features=2304, out_features=128),
            self.torch_ref.nn.ReLU(),
            self.torch_ref.nn.Linear(in_features=128, out_features=128),
            self.torch_ref.nn.ReLU(),
            self.torch_ref.nn.Linear(in_features=128, out_features=6),
        )

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: batch of images with 3 channels (batch dimension first).

        Returns:
            Tensor with one row per image and 6 columns of log-probabilities.
        """
        x = self.conv_head(x)

        # Flatten everything except the batch dimension for the linear layers.
        x = self.torch_ref.flatten(x, start_dim=1)

        x = self.classification_head(x)
        # nll_loss (used for training/testing) expects log-probabilities, and the
        # inference helpers exponentiate the output, so normalise the logits here.
        return self.torch_ref.nn.functional.log_softmax(x, dim=1)

In [None]:
# Instantiate the model with the locally imported torch as its torch reference;
# this local instance is what gets sent to the Duet peer in a later cell.
local_model = MedMNISTModel(torch)

In [None]:
# Hyper-parameters and run switches read by the training/testing cells below.
args = dict(
    batch_size=64,          # training batch size
    test_batch_size=1000,   # evaluation batch size
    epochs=10,
    lr=1.0,                 # optimizer learning rate
    gamma=0.7,              # learning-rate scheduler decay factor
    no_cuda=False,          # set True to force CPU
    dry_run=False,          # set True to stop after a single batch/epoch
    seed=42,                # the meaning of life
    log_interval=10,        # report the loss every N batches
    save_model=True,
)

In [None]:
# Send the locally built model to the Duet peer. From here on `model` is the handle
# used for remote training (train() refuses to run while model.is_local is true).
model = local_model.send(duet)

In [None]:
# Handle to the peer's torch module; calls made through it (optimizers, loss
# functions, tensors, ...) are issued against the Duet peer rather than local torch.
remote_torch = duet.torch

In [None]:
# Ask the peer whether CUDA is available. Start from False so the flag has a
# defined value even before the remote answer arrives.
has_cuda = False
has_cuda_ptr = remote_torch.cuda.is_available()
# .get() issues a blocking access request for the remote value; bool() maps a
# missing answer (None) to False.
has_cuda = bool(has_cuda_ptr.get(
    request_block=True,
    reason="To run test and inference locally",
    timeout_secs=5,  # increase if the peer needs longer to approve the request
))
print(has_cuda)

In [None]:
# CUDA is used only when the peer has it AND the user has not disabled it via args["no_cuda"].
use_cuda = not args["no_cuda"] and has_cuda
# now we can set the seed (on the remote torch, so remote training is reproducible)
remote_torch.manual_seed(args["seed"])

# Remote device object matching the decision above.
device = remote_torch.device("cuda" if use_cuda else "cpu")
# device.type lives on the peer, so it has to be fetched with .get() before printing.
print(f"Data Owner device is {device.type.get()}")

In [None]:
# Place the remote model on the device chosen above. Branch on use_cuda (not
# has_cuda): when args["no_cuda"] is set on a CUDA-capable peer, `device` is a CPU
# device and the model must stay on the CPU as well.
if use_cuda:
    model.cuda(device)
else:
    model.cpu()

In [None]:
# Parameters of the remote model; passed to the remote optimizer in the next cell.
params = model.parameters()

In [None]:
# Both objects are created through remote_torch, i.e. on the peer next to the model.
# Adadelta optimizer over the remote parameters with the configured learning rate.
optimizer = remote_torch.optim.Adadelta(params, lr=args["lr"])
# StepLR with step_size=1: every scheduler.step() (called once per epoch in the
# training cell) scales the learning rate by args["gamma"].
scheduler = remote_torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=args["gamma"])

## Training loop

In [None]:
def train(model, torch_ref, train_loader, optimizer, epoch, args, train_data_length):
    """Run one training epoch against a remote model.

    Args:
        model: remote model handle; training is skipped when ``model.is_local`` is true.
        torch_ref: torch reference used for the loss function (the remote torch).
        train_loader: iterable yielding ``(data, target)`` pairs for each batch.
        optimizer: optimizer bound to the model's parameters.
        epoch: epoch number, used only in progress messages.
        args: settings dict; reads ``batch_size``, ``log_interval`` and ``dry_run``.
        train_data_length: number of training samples, used to cap the batch count.

    Returns:
        The remote Float accumulating the per-batch losses of this epoch, or
        ``None`` when the model is local and nothing was trained.

    Note:
        The batch cap is derived from ``args["batch_size"]``; it only matches the
        real number of batches if ``train_loader`` was built with the same size.
    """
    # Integer ceiling division. round(x + 0.5) is not a ceiling: Python rounds
    # halves to the nearest even number, so e.g. 64 samples / 64 per batch gave 2.
    train_batches = -(-train_data_length // args["batch_size"])
    print(f"> Running train in {train_batches} batches")
    if model.is_local:
        print("Training requires remote model")
        return None

    model.train()

    # Remote Float used for summation. Created once, before the loop, so it
    # actually accumulates across batches instead of being reset every iteration.
    train_loss = duet.python.Float(0)

    for batch_idx, data in enumerate(train_loader):
        data_ptr, target_ptr = data[0], data[1]
        optimizer.zero_grad()
        output = model(data_ptr)
        loss = torch_ref.nn.functional.nll_loss(output, target_ptr)
        loss.backward()
        optimizer.step()
        loss_item = loss.item()
        train_loss += loss_item
        if batch_idx % args["log_interval"] == 0:
            # Request the scalar loss from the peer; None means it was not
            # granted within the timeout.
            local_loss = loss_item.get(
                reason="To evaluate training progress",
                request_block=True,
                timeout_secs=5
            )
            if local_loss is not None:
                print("Train Epoch: {} {} {:.4}".format(epoch, batch_idx, local_loss))
            else:
                print("Train Epoch: {} {} ?".format(epoch, batch_idx))
        if batch_idx >= train_batches - 1:
            print("batch_idx >= train_batches, breaking")
            break
        if args["dry_run"]:
            # Smoke-test mode: a single batch is enough.
            break

    return train_loss

In [None]:
def test_local(model, torch_ref, test_loader, test_data_length):
    # download remote model
    if not model.is_local:
        local_model = model.get(
            request_block=True,
            reason="test evaluation",
            timeout_secs=5
        )
    else:
        local_model = model
    # + 0.5 lets us math.ceil without the import
    test_batches = round((test_data_length / args["test_batch_size"]) + 0.5)
    print(f"> Running test_local in {test_batches} batches")
    local_model.eval()
    test_loss = 0.0
    correct = 0.0

    with torch_ref.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            output = local_model(data)
            iter_loss = torch_ref.nn.functional.nll_loss(output, target, reduction="sum").item()
            test_loss = test_loss + iter_loss
            pred = output.argmax(dim=1)
            total = pred.eq(target).sum().item()
            correct += total
            if args["dry_run"]:
                break
                
            if batch_idx >= test_batches - 1:
                print("batch_idx >= test_batches, breaking")
                break

    accuracy = correct / test_data_length
    print(f"Test Set Accuracy: {100 * accuracy}%")

## Dataloader remote transformation

In [None]:
remote_torchvision = duet.torchvision

# NOTE(review): the composed `transforms` object is built on the peer but is not
# passed to the dataset or loader anywhere in this cell — confirm whether it
# should be applied.
transform_1 = remote_torchvision.transforms.ToTensor()  # this converts PIL images to Tensors
transform_2 = remote_torchvision.transforms.Normalize(0.3, 0.3)  # this normalizes the dataset

remote_list = duet.python.List()  # create a remote list to add the transforms to
remote_list.append(transform_1)
remote_list.append(transform_2)

# compose our transforms
transforms = remote_torchvision.transforms.Compose(remote_list)

# The DO has kindly let us initialise a DataLoader for their training set.
# The batch size must come from args: train() derives its per-epoch batch budget
# from args["batch_size"], so a different loader batch size would make every
# epoch stop early (or run past the intended cap).
train_kwargs = {
    "batch_size": args["batch_size"],
}
train_data_ptr = duet.syft.core.remote_dataloader.RemoteDataset(meta_ptr)
train_loader_ptr = duet.syft.core.remote_dataloader.RemoteDataLoader(train_data_ptr, **train_kwargs)
# call load_dataset to create the real Dataset object on remote side
train_loader_ptr.load_dataset()
# call create_dataloader to create the real DataLoader object on remote side
train_loader_ptr.create_dataloader()

In [None]:
def get_train_length(train_data_ptr):
    """Return the number of items ``len()`` reports for the training dataset handle."""
    return len(train_data_ptr)

# Compute the dataset length only once per kernel session: if the name is already
# bound to a non-None value (e.g. this cell was run before), keep it.
try:
    if train_data_length is None:
        train_data_length = get_train_length(train_data_ptr)
except NameError:
        # First run: the name does not exist yet, so reading it above raised NameError.
        train_data_length = get_train_length(train_data_ptr)

print(f"Training Dataset size is: {train_data_length}")

## Training call

In [None]:
# Optional smoke test: set args["dry_run"] = True before running this cell to stop
# after one batch of one epoch (train() and the loop below both check the flag).
# args["dry_run"] = False  # comment to do a full train
print("Starting Training")
for epoch in range(1, args["epochs"] + 1):
    epoch_start = time.time()
    print(f"Epoch: {epoch}")
    # remote training on model with remote_torch
    train(model, remote_torch, train_loader_ptr, optimizer, epoch, args, train_data_length)
    # local testing on model with local torch
    # NOTE(review): disabled — test_loader / test_data_length are not defined in the
    # visible part of this notebook.
#     test_local(model, torch, test_loader, test_data_length)
    # Advance the learning-rate schedule once per epoch.
    scheduler.step()
    epoch_end = time.time()
    print(f"Epoch time: {int(epoch_end - epoch_start)} seconds")
    if args["dry_run"]:
        break
print("Finished Training")

In [None]:
    # Fetch the trained model from the peer and persist it, but only when saving is
    # enabled in args. The request can come back empty (None) if it is not granted
    # within the timeout, so check before calling .save().
    if args["save_model"]:
        trained_model = model.get(
            request_block=True,
            reason="test evaluation",
            timeout_secs=5
        )
        if trained_model is None:
            print("Model request was not granted in time, nothing was saved")
        else:
            trained_model.save("./duet_mnist.pt")

## Inference

In [None]:
def draw_image_and_label(image, label):
    """Show an image with its ground-truth label as the title.

    Args:
        image: array-like accepted by ``imshow``.
        label: value rendered in the title after "Ground Truth: ".
    """
    # Explicit figure/axes instead of the implicit pyplot state machine.
    fig, ax = plt.subplots()
    ax.imshow(image, cmap="gray", interpolation="none")
    ax.set_title("Ground Truth: {}".format(label))
    # Lay out after the content exists; calling it on an empty figure does nothing useful.
    fig.tight_layout()
    
def prep_for_inference(image):
    """Turn a single image tensor into a float batch of one.

    Leading singleton dimensions are added until the tensor is 4-D:
    a 2-D ``(H, W)`` image becomes ``(1, 1, H, W)`` (unchanged from before), a
    3-D ``(C, H, W)`` image becomes ``(1, C, H, W)``, and a tensor that is
    already 4-D is left as is.

    Args:
        image: tensor holding one image (or an already batched image).

    Returns:
        Floating-point tensor with at least four dimensions.
    """
    image_batch = image
    # Previously two dimensions were always added, which turned a multi-channel
    # (C, H, W) image into a 5-D tensor that a 2-D convolution cannot consume.
    while image_batch.dim() < 4:
        image_batch = image_batch.unsqueeze(0)
    # Multiplying by 1.0 promotes integer pixel values to floating point.
    image_batch = image_batch * 1.0
    return image_batch

In [None]:
def classify_local(image, model):
    """Classify one image with a local model.

    Args:
        image: image tensor, batched via ``prep_for_inference``.
        model: model exposing ``is_local``; must be local.

    Returns:
        ``(class_num, local_y)`` where ``local_y`` is ``exp(model output)`` squeezed
        to a vector and ``class_num`` is the index of its largest entry.
        When the model is remote, returns ``(-1, tensor([-1.]))``.
    """
    if not model.is_local:
        print("model is remote try .get()")
        return -1, torch.Tensor([-1])
    image_tensor = torch.Tensor(prep_for_inference(image))
    # Inference only: no need to record the autograd graph.
    with torch.no_grad():
        output = model(image_tensor)
    # The output is exponentiated, i.e. treated as log-probabilities.
    local_y = torch.exp(output).squeeze()
    # argmax always yields a single index, even when several entries tie.
    class_num = torch.argmax(local_y)
    return class_num, local_y

In [None]:
def classify_remote(image, model):
    """Classify one image with the remote model and fetch the result.

    Args:
        image: image tensor, batched via ``prep_for_inference`` and sent through
            the remote torch handle.
        model: model exposing ``is_local``; must be remote.

    Returns:
        ``(class_num, local_y)`` as local tensors: ``local_y`` is the fetched
        ``exp(model output)`` squeezed to a vector and ``class_num`` the index of
        its largest entry. On any failure (model is local, or the result request
        was not granted) returns ``(-1, tensor([-1.]))``.
    """
    if model.is_local:
        print("model is local try .send()")
        # Local sentinel, the same as every other failure path; no need to create
        # an object on the peer just to signal an error.
        return -1, torch.Tensor([-1])
    image_tensor_ptr = remote_torch.Tensor(prep_for_inference(image))
    output = model(image_tensor_ptr)
    preds = remote_torch.exp(output)
    preds_result = preds.get(
        request_block=True,
        reason="To see a real world example of inference",
        timeout_secs=10
    )
    if preds_result is None:
        print("No permission to do inference, request again")
        return -1, torch.Tensor([-1])

    # now we have the local tensor we can use local torch
    local_y = torch.Tensor(preds_result).squeeze()
    # argmax always yields a single index, even when several entries tie.
    class_num = torch.argmax(local_y)
    return class_num, local_y