# MNIST Syft Data Scientist

In [None]:
import syft as sy
# NOTE(review): presumably switches off syft's verbose output for this notebook — confirm against the syft docs
sy.VERBOSE = False

# PART 1: Connect to a Remote Duet Server

As the data scientist, you want to perform data science on data that is sitting in the Data Owner's Duet server (in their Notebook).

In order to do this, we must run the code that the Data Owner sends us, which importantly includes their Duet Session ID. This will create a direct connection from our notebook to the remote Duet server. Once the connection is established, all traffic is sent directly between the two nodes.

Let's run the code below and follow the instructions it gives.

In [None]:
# Join the Data Owner's Duet session. The returned `duet` object is the handle used
# in the rest of this notebook to reach duet.torch, duet.torchvision and duet.syft.
# NOTE(review): loopback=True presumably targets a Duet running on the same machine — confirm
duet = sy.join_duet(loopback=True)

# PART 2: Launch a Duet Server and Connect

In [None]:
# let's get some references to our data owner's Duet torch and torchvision
torch = duet.torch
torchvision = duet.torchvision

# these aliases use the same names as the original mnist example,
# but each of them is reached through the Duet handles above
# NOTE: `transforms` is rebound later in this notebook to the composed transform
transforms = torchvision.transforms
datasets = torchvision.datasets
nn = torch.nn
F = torch.nn.functional
optim = torch.optim
StepLR = torch.optim.lr_scheduler.StepLR

In [None]:
# has_cuda_ptr = torch.cuda.is_available()
# try:
#     has_cuda = has_cuda_ptr.get(
#         request_block=True,
#         request_name="cuda.is_available()",
#         reason="If you have CUDA I will enable it to speed up training."
#     )
#     print(f"DO has CUDA: {has_cuda}")
# except Exception as e:
#     print("No permission, try requesting again.")
    
# print("Result of Blocking Request", has_cuda)

In [None]:
import torch as th
import torchvision as tv

In [None]:
# we need some transforms for our local copy of the MNIST data set
local_transform_1 = tv.transforms.ToTensor()  # this converts PIL images to Tensors
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean / std — confirm
local_transform_2 = tv.transforms.Normalize(0.1307, 0.3081)  # this normalizes the dataset

# compose our transforms so that both are applied, in this order, to every sample
local_transforms = tv.transforms.Compose([local_transform_1, local_transform_2])

In [None]:
# Display the Duet store (as the last expression of the cell, its value is shown by the notebook)
# NOTE(review): presumably a table of the objects currently held in the Data Owner's store — confirm
duet.store.pandas

In [None]:
# Hyper-parameters and run flags, taken from the command line args of the original MNIST example
args = dict(
    batch_size=64,
    test_batch_size=1000,
    epochs=14,
    lr=1.0,
    gamma=0.7,
    no_cuda=False,
    dry_run=False,
    seed=42,  # the meaning of life
    log_interval=10,
    save_model=False,
)

In [None]:
# keyword arguments shared by the test DataLoaders in this notebook
test_kwargs = dict(batch_size=args["test_batch_size"])

# this is our carefully curated test data which represents the goal of our problem domain
# (local torchvision: stored under ../data, downloaded if missing, run through local_transforms)
test_data = tv.datasets.MNIST(
    "../data",
    train=False,
    download=True,
    transform=local_transforms,
)
test_loader = th.utils.data.DataLoader(test_data, **test_kwargs)

In [None]:
# number of samples in the local test set
test_data_length = len(test_loader.dataset)
print(test_data_length)

# the same number, created through Duet as a syft Int, so that it can later be
# combined with other Duet-side values (it is the divisor of the accuracy in test())
test_data_length_ptr = duet.syft.lib.python.Int(test_data_length)
print(test_data_length_ptr)

In [None]:
# inspect the local test dataset object and its type
print(test_data, type(test_data))

The "duet" variable is now your reference to a whole world of remote operations including supported libraries like torch.

In [None]:
# we need the same transforms for the MNIST data set, this time built through the Duet torchvision handle
transform_1 = torchvision.transforms.ToTensor()  # this converts PIL images to Tensors
transform_2 = torchvision.transforms.Normalize(0.1307, 0.3081)  # this normalizes the dataset
print(type(transform_1), type(transform_2))

# collect both transforms in a list that is itself created through Duet,
# so that Compose below receives a Duet-side list rather than a local one
remote_list = duet.syft.lib.python.List()
remote_list.append(transform_1)
remote_list.append(transform_2)

# compose our transforms
# NOTE: this rebinds `transforms`, which until now aliased torchvision.transforms
transforms = torchvision.transforms.Compose(remote_list)
print(type(transforms))

In [None]:
# TODO replace with local inference so this doesn't need to be on the DO side
# build the MNIST test split through the Duet torchvision handle, using the composed Duet-side transforms
test_data_ptr = torchvision.datasets.MNIST('../data', train=False, download=True, transform=transforms)
print(test_data_ptr)

# wrap it in a DataLoader created through the Duet torch handle; test_kwargs supplies the batch size
test_loader_ptr = torch.utils.data.DataLoader(test_data_ptr,**test_kwargs)
print(test_loader_ptr)
# TODO

In [None]:
class Net:
    """Hand-rolled stand-in for an ``nn.Module`` assembled from the layers in ``nn``.

    The class does not inherit from ``nn.Module``; it keeps its own list of
    layers (``self.modules``) and forwards ``train`` / ``eval`` / ``cuda`` /
    ``cpu`` / ``parameters`` to each layer in that list.
    """

    # default mode before train() / eval() has been called on an instance
    training = False

    def __init__(self) -> None:
        """Create the six layers and register them, in order, in ``self.modules``."""
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

        # The registry must live on the instance: a class-level list would be
        # shared by every Net, so each new instance would append its layers to
        # the layers of all earlier instances.
        self.modules = [
            self.conv1,
            self.conv2,
            self.dropout1,
            self.dropout2,
            self.fc1,
            self.fc2,
        ]

    def train(self, mode: bool = True) -> "Net":
        """Record ``mode`` on the instance, pass it to every layer, and return ``self``."""
        self.training = mode
        for module in self.modules:
            module.train(mode)
        return self

    def eval(self) -> "Net":
        """Shorthand for ``train(False)``."""
        return self.train(False)

    def forward(self, x):
        """Run ``x`` through the network and return the ``log_softmax`` output.

        Order of operations: conv1, relu, conv2, relu, 2x max-pool, dropout1,
        flatten from dim 1, fc1, relu, dropout2, fc2, log_softmax over dim 1.
        """
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

    def __call__(self, input):
        """Make the instance callable; delegates to :meth:`forward`."""
        return self.forward(input)

    # local list of remote ListPointers of TensorPointers
    def parameters(self, recurse: bool = True):
        """Return the parameters of all registered layers, concatenated into one list.

        The list is created with ``torch.python.List()`` and extended with
        ``module.parameters()`` of each layer in registration order.
        ``recurse`` is accepted but not used by this implementation.
        """
        params_list = torch.python.List()
        for module in self.modules:
            param_pointers = module.parameters()
            params_list += param_pointers

        return params_list

    def cuda(self, device) -> "Net":
        """Call ``cuda(device)`` on every registered layer and return ``self``."""
        for module in self.modules:
            module.cuda(device)
        return self

    def cpu(self) -> "Net":
        """Call ``cpu()`` on every registered layer and return ``self``."""
        for module in self.modules:
            module.cpu()
        return self

In [None]:
# our SOTA model, to be trained on the data owner's data
# NOTE: the base class is sy.Module, not nn.Module
class SyNet(sy.Module):
    """MNIST classifier: two convolutions followed by two linear layers."""

    def __init__(self):
        super().__init__()
        # layers are created from the Duet-backed `nn`, in the order they are used in forward()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return the log-softmax (over dim 1) class scores for the batch ``x``."""
        # convolutional part: conv -> relu, twice, then 2x max-pool and dropout
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        pooled = self.dropout1(F.max_pool2d(features, 2))

        # classifier part: flatten from dim 1, hidden linear layer, dropout, output layer
        flat = torch.flatten(pooled, 1)
        hidden = self.dropout2(F.relu(self.fc1(flat)))
        return F.log_softmax(self.fc2(hidden), dim=1)

In [None]:
# let's see if our Data Owner has CUDA
# CUDA detection is currently switched off: has_cuda is hard-coded to False and
# the request to the Data Owner is left commented out below
has_cuda = False
# has_cuda_ptr = torch.cuda.is_available()
# has_cuda = bool(has_cuda_ptr.get(request_block=True))
# print(has_cuda)

In [None]:
# CUDA is used only if it is not disabled in args and the Data Owner has it
use_cuda = not args["no_cuda"] and has_cuda
# seed the Duet-side torch with the configured seed
torch.manual_seed(args["seed"])

# device object created through the Duet torch handle
device = torch.device("cuda" if use_cuda else "cpu")
# device.type is fetched with .get() so that it can be printed locally
print(f"DO device is {device.type.get()}")

In [None]:
# instantiate our model
# this will construct everything inside __init__ on the DO side
# (the hand-rolled Net class is kept as a commented-out alternative)
# model = Net()
model = SyNet()

In [None]:
# if CUDA is enabled, let's send our model to the GPU; otherwise keep it on the CPU
# use_cuda (rather than has_cuda) is checked so that args["no_cuda"] is honoured and
# the model goes to the same kind of device that `device` was created for
if use_cuda:
    model.cuda(device)
else:
    model.cpu()

In [None]:
# let's get our parameters for optimization
params = model.parameters()
# show the returned object and its type
print(params, type(params))

In [None]:
# Adadelta optimizer from the Duet-side torch.optim, using the learning rate from args
optimizer = optim.Adadelta(params, lr=args["lr"])
print(optimizer, type(optimizer))

In [None]:
# StepLR scheduler with step_size=1 and the decay factor args["gamma"];
# scheduler.step() is called once per epoch in the training cell
scheduler = StepLR(optimizer, step_size=1, gamma=args["gamma"])
print(scheduler, type(scheduler))

In [None]:
# a simple training loop, very similar to the original PyTorch MNIST example
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one pass over ``train_loader``, updating ``model`` with ``optimizer``.

    Every ``args["log_interval"]`` batches the loss value is requested with
    ``.get(...)`` (blocking, 10 second timeout) and printed; ``?`` is printed
    when the request returns ``None``. If ``args["dry_run"]`` is set, the loop
    stops right after the first such log line.

    ``device`` is currently unused (see the TODO in the loop). ``epoch`` is
    only used in the log output.
    """
    print("> Running train")
    model.train()
    # TODO: `for batch_idx, (data, target) in enumerate(train_loader)` requires tuple support
    for step, batch in enumerate(train_loader):
        # work around until tuple support: index the batch instead of unpacking it
        inputs = batch[0]
        labels = batch[1]
        # TODO: inputs / labels are not moved with .to(device): wont accept device pointer from this side?

        # one optimisation step on the negative log-likelihood loss
        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()
        loss_value = loss.item()

        # nothing to report for this batch
        if step % args["log_interval"] != 0:
            continue

        fetched = loss_value.get(
            request_name="loss",
            reason="To evaluate training progress",
            request_block=True,
            timeout_secs=10
        )
        if fetched is None:
            print("Train Epoch: {} {} ?".format(epoch, step))
        else:
            print("Train Epoch: {} {} {:.4}".format(epoch, step, fetched))
        if args["dry_run"]:
            break

In [None]:
# TODO replace with local inference and local test set
# the same as our training loop, except that no gradients are needed and the
# number of correct predictions is accumulated on the DO side
def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print its accuracy in percent.

    The number of correct predictions is accumulated in a Duet-side Float and
    divided by the module-level ``test_data_length_ptr``. Only that single
    accuracy value is requested with ``.get(...)`` (blocking, 10 second
    timeout); ``?`` is printed when the request returns ``None``.

    Besides its arguments the function reads the module-level names ``duet``,
    ``torch``, ``F``, ``args`` and ``test_data_length_ptr``. ``device`` is
    currently unused. When ``args["dry_run"]`` is set only the first batch is
    evaluated, while the divisor is still the full test set size.
    """
    print("> Running test")
    model.eval()
    # NOTE: the summed loss is accumulated but never fetched or printed; it is kept
    # so that an average loss can be requested later without changing the loop
    test_loss = duet.syft.lib.python.Float(0)
    correct_ptr = duet.syft.lib.python.Float(0)
    with torch.no_grad():
        for tensor_ptr in test_loader:
            # work around until tuple support: index the batch instead of unpacking it
            data_ptr, target_ptr = tensor_ptr[0], tensor_ptr[1]
            # TODO: data / target are not moved with .to(device)
            # Are we sending these each time or some other way?
            # (data.send(duet) / target.send(duet) are not needed while the loader is on the DO side)

            output = model(data_ptr)
            loss = F.nll_loss(output, target_ptr, reduction='sum').item()
            test_loss = test_loss + loss

            # predicted class = index of the largest log-probability
            pred = output.argmax(dim=1)
            total = pred.eq(target_ptr).sum().item()
            correct_ptr += total

            if args["dry_run"]:
                break

    accuracy = correct_ptr / test_data_length_ptr
    # we need to batch or block these requests so the loop doesnt break
    result = accuracy.get(
        request_block=True,
        timeout_secs=10,
        request_name="accuracy",
        reason="To see the accuracy on DO's test set"
    )
    # the fetched value is the accuracy, so it is labelled as such (not as a loss)
    if result is not None:
        print("Test Set Accuracy (%):", 100 * result)
    else:
        print("Test Set Accuracy (%): ?")

In [None]:
# The DO has kindly let us initialise a DataLoader for their training set
train_kwargs = {
    "batch_size": args["batch_size"],
}
# MNIST training split built through the Duet torchvision handle, with the composed Duet-side transforms
train_data_ptr = torchvision.datasets.MNIST('../data', train=True, download=True, transform=transforms)
print(train_data_ptr)
# DataLoader over that split, created through the Duet torch handle
train_loader_ptr = torch.utils.data.DataLoader(train_data_ptr,**train_kwargs)
print(train_loader_ptr)

In [None]:
# force a dry run: train() stops after its first logged batch and test() after its first batch
args["dry_run"] = True

for epoch in range(1, args["epochs"] + 1):
    train(args, model, device, train_loader_ptr, optimizer, epoch)
    test(model, device, test_loader_ptr)
    scheduler.step()
    # a dry run is a single smoke-test epoch; with args["dry_run"] = False above,
    # the loop runs for all args["epochs"] epochs instead of always stopping here
    if args["dry_run"]:
        break

In [None]:
# marker that the training cells above have finished
print("Done")

# PART 3: Inference

In [None]:
import matplotlib.pyplot as plt
def draw_image_and_label(image, label):
    """Show ``image`` in grayscale on a new figure, titled with its ground-truth ``label``."""
    plt.figure()  # new figure for this image; the handle itself is not needed
    plt.tight_layout()
    caption = "Ground Truth: {}".format(label)
    plt.imshow(image, cmap="gray", interpolation="none")
    plt.title(caption)
    
def prep_for_inference(image):
    """Return ``image`` with two leading axes of size 1 added, multiplied by 1.0.

    NOTE(review): the two extra axes are presumably the batch and channel
    dimensions expected by the model, and the multiplication presumably
    yields a floating-point tensor — confirm.
    """
    return image.unsqueeze(0).unsqueeze(0) * 1.0

In [None]:
def classify(image):
    """Classify ``image`` with the module-level ``model`` and return the result.

    The image is batched with ``prep_for_inference``, turned into a tensor
    through the Duet ``torch`` handle, moved to the module-level ``device``
    and passed through ``model``. The exponentiated output is fetched with
    ``.get()`` and handled locally from there on.

    Returns:
        ``(class_num, local_y)``: the index of the largest entry of
        ``local_y`` as a 0-d tensor, and the squeezed local tensor holding
        the exponentiated model output.
    """
    # use the argument, not the notebook-level `image_1`, so that any image can be classified
    image_tensor_ptr = torch.Tensor(prep_for_inference(image))
    image_tensor_ptr = image_tensor_ptr.to(device)

    output = model(image_tensor_ptr)

    # the model returns log_softmax values; exp turns them back into probabilities
    preds = torch.exp(output)
    local_y = th.Tensor(preds.get())
    local_y = local_y.squeeze()
    # argmax always yields a single index, even when two classes tie for the
    # maximum (an `== max` mask would then select several positions and int() would fail)
    class_num = th.argmax(local_y)
    print(int(class_num))
    return class_num, local_y

In [None]:
# let's grab something from the test set
import random
total_images = test_data_length # 10000
# randrange excludes the upper bound, so `index` is always a valid position
# (randint(0, total_images) could return total_images itself, one past the end)
index = random.randrange(total_images)
print("Random Test Image:", index)
# the image lives in batch number `batch`, at offset `batch_index` inside that batch
batch, batch_index = divmod(index, test_kwargs["batch_size"])
for count, tensor_ptr in enumerate(test_loader):
    if count == batch:
        data, target = tensor_ptr[0], tensor_ptr[1]
        break

print(f"Displaying {index} == {batch_index} in Batch: {batch}/{len(test_loader)}")
# drop the channel axis so that the image can be plotted as a 28x28 picture
image_1 = data[batch_index].reshape((28, 28))
label_1 = target[batch_index]
draw_image_and_label(image_1, label_1)

In [None]:
# classify the image picked above and compare the prediction with its label
class_num, preds = classify(image_1)
print(f"Prediction: {class_num} Ground Truth: {label_1}")
# second value returned by classify(): the exponentiated model output, one entry per class
print(preds)