# DCGAN - Syft Duet - Data Scientist 🥁

Contributed by [@Koukyosyumei](https://github.com/Koukyosyumei)

This example trains a DCGAN on the MNIST dataset with Syft.
This notebook is mainly based on the original PyTorch [example](https://github.com/OpenMined/PySyft/tree/dev/examples/duet/dcgan/original).

## PART 1: Connect to a Remote Duet Server

As the Data Scientist, you want to perform data science on data that sits on the Data Owner's Duet server, running in their notebook.

To do this, we must run the code that the Data Owner sends us, which includes their Duet Server ID. The code will look like this, with their real Server ID in place of the placeholder:

```
import syft as sy
duet = sy.duet('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
```

This will create a direct connection from your notebook to the remote Duet server. Once the connection is established, all traffic is sent directly between the two nodes.

Paste the code or Server ID that the Data Owner gives you and run it in the cell below. It will return your Client ID, which you must send to the Data Owner to enter into Duet so it can pair your notebooks. The cell below uses `loopback=True`, which pairs with a Data Owner notebook running on the same machine, so no IDs need to be exchanged.

In [None]:
import syft as sy
duet = sy.join_duet(loopback=True)
sy.logger.add(sink="./syft_ds.log")

### <img src="https://github.com/OpenMined/design-assets/raw/master/logos/OM/mark-primary-light.png" alt="he-black-box" width="100"/> Checkpoint 0 : Now STOP and run the Data Owner notebook until Checkpoint 1.

In [None]:
import os
import torch
import torchvision
import torchvision.utils as vutils

from PIL import Image

try:
    # make notebook progress bars nicer
    from tqdm.notebook import tqdm
except ImportError:
    print("Unable to import tqdm")

In [None]:
remote_torch = duet.torch

# Set params

In [None]:
config = {
    "image_size":28,
    "batch_size":64,
    "no_cuda":True,
    "seed":42,
    "nz":100,
    "ngf":28,
    "ndf":28,
    "lr":0.0002,
    "beta1":0.5,
    "ngpu":0,
    "num_iter":100,
    "dry_run":True,
    "log_interval":2,
    "save_model":False,
    "save_model_interval":20,
    "cuda":False,
    "outf":"result",
}

In [None]:
# create the output directory for sample images and checkpoints if it doesn't exist yet
os.makedirs(config["outf"], exist_ok=True)

# Load Data

In [None]:
from syft.util import get_root_data_path

# we need some transforms for the MNIST data set
remote_torchvision = duet.torchvision

transform_1 = remote_torchvision.transforms.ToTensor()  # this converts PIL images to Tensors with values in [0, 1]
transform_2 = remote_torchvision.transforms.Normalize((0.5), (0.5))  # this rescales values to [-1, 1], matching the generator's Tanh output

remote_list = duet.python.List()  # create a remote list to add the transforms to
remote_list.append(transform_1)
remote_list.append(transform_2)

# compose our transforms
transforms = remote_torchvision.transforms.Compose(remote_list)

# The DO has kindly let us initialise a DataLoader for their training set
train_kwargs = {
    "batch_size": config["batch_size"],
    "shuffle":True
}
train_data_ptr = remote_torchvision.datasets.MNIST(str(get_root_data_path()), train=True, download=True, transform=transforms)
train_loader_ptr = remote_torch.utils.data.DataLoader(train_data_ptr,**train_kwargs)
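`ToTensor` scales MNIST's 8-bit pixels into [0, 1], and `Normalize` with mean 0.5 and standard deviation 0.5 then applies `(x - mean) / std`, which maps them into [-1, 1], the same range as the generator's final `Tanh`. The plain-Python sketch below (no Syft or torchvision needed; the helper names `to_tensor_scale` and `normalize_pixel` are hypothetical) shows that arithmetic for a single pixel:

```python
MEAN, STD = 0.5, 0.5  # the values passed to Normalize above


def to_tensor_scale(byte_value: int) -> float:
    """Hypothetical helper: ToTensor maps 8-bit pixel values 0..255 into [0.0, 1.0]."""
    return byte_value / 255.0


def normalize_pixel(x: float, mean: float = MEAN, std: float = STD) -> float:
    """Hypothetical helper: the per-channel formula Normalize applies."""
    return (x - mean) / std


for raw in (0, 128, 255):
    print(raw, "->", normalize_pixel(to_tensor_scale(raw)))
```

Black pixels (0) end up at -1.0 and white pixels (255) at 1.0, so real images and generated images live in the same value range when the discriminator compares them.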

In [None]:
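# normally we would not necessarily know the length of a remote dataset, so let's ask for it
# so we can pass it to our training loop and know when to stop
def get_train_length(train_data_ptr):
    # __len__() on a pointer returns another pointer, so request the actual value
    train_length_ptr = train_data_ptr.__len__()
    train_length = train_length_ptr.get(request_block=True)
    return train_length

# reuse a length fetched by an earlier run of this cell instead of asking again
try:
    if train_data_length is None:
        train_data_length = get_train_length(train_data_ptr)
except NameError:
    train_data_length = get_train_length(train_data_ptr)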

print(f"Training Dataset size is: {train_data_length}")

# Check GPU

In [None]:
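# let's ask to see if our Data Owner has CUDA
has_cuda_ptr = remote_torch.cuda.is_available()
# .get() may return None if the request is denied or times out; bool() then gives False
has_cuda = bool(has_cuda_ptr.get(request_block=True))
print("Is CUDA available? :", has_cuda)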

use_cuda = not config["no_cuda"] and has_cuda
# now we can set the seed
remote_torch.manual_seed(config["seed"])

device = remote_torch.device("cuda" if use_cuda else "cpu")
#print(f"Data Owner device is {device.type.get()}")

# Define and Create models

In this notebook, we use two models adapted from the PyTorch DCGAN example:

https://github.com/pytorch/examples/tree/master/dcgan

Their structure differs slightly from the originals: PySyft doesn't support `transforms.Resize`, so rather than resizing MNIST's 28x28 images up to the 64x64 input the original example uses, we adjust the layers so they match the 28x28 input directly.
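To see why the `(kernel, stride, padding)` choices in the `Generator` and `Discriminator` cells produce exactly 28x28 images, the standard PyTorch output-size formulas (with dilation 1 and no output padding) can be traced layer by layer. This is a pure-Python sketch; the helper names are hypothetical, and the layer tuples are copied by hand from the model definitions:

```python
def conv_transpose_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # ConvTranspose2d: (in - 1) * stride - 2 * padding + kernel
    return (size - 1) * stride - 2 * padding + kernel


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # Conv2d: floor((in + 2 * padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) for conv1..conv4 of each model
GENERATOR_LAYERS = [(4, 1, 0), (3, 2, 1), (4, 2, 1), (4, 2, 1)]
DISCRIMINATOR_LAYERS = [(4, 2, 1), (4, 2, 1), (3, 2, 1), (4, 1, 0)]


def trace(size, layers, fn):
    """Return the spatial size before and after every layer."""
    sizes = [size]
    for kernel, stride, padding in layers:
        size = fn(size, kernel, stride, padding)
        sizes.append(size)
    return sizes


print("Generator:    ", trace(1, GENERATOR_LAYERS, conv_transpose_out))
print("Discriminator:", trace(28, DISCRIMINATOR_LAYERS, conv_out))
```

The generator grows a 1x1 noise vector through 4, 7 and 14 to 28, and the discriminator mirrors that path back down to a single 1x1 score. The 3x3 kernel in the middle layer is what turns 4 into 7 (and 7 into 4) without any resizing.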

In [None]:
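nz = config["nz"]    # size of the latent noise vector z
ngf = config["ngf"]  # base number of feature maps in the generator
ndf = config["ndf"]  # base number of feature maps in the discriminator
nc = 1               # number of image channels (MNIST is grayscale)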

In [None]:
class Generator(sy.Module):
    def __init__(self, torch_ref):
        super(Generator, self).__init__(torch_ref=torch_ref)

        """
        input is Z, going into a convolution
        ----------------------------------------------------------------
                Layer (type)               Output Shape         Param #
        ================================================================
           ConvTranspose2d-1            [-1, 112, 4, 4]         179,200
               BatchNorm2d-2            [-1, 112, 4, 4]             224
                      ReLU-3            [-1, 112, 4, 4]               0
           ConvTranspose2d-4             [-1, 56, 7, 7]          56,448
               BatchNorm2d-5             [-1, 56, 7, 7]             112
                      ReLU-6             [-1, 56, 7, 7]               0
           ConvTranspose2d-7           [-1, 28, 14, 14]          25,088
               BatchNorm2d-8           [-1, 28, 14, 14]              56
                      ReLU-9           [-1, 28, 14, 14]               0
          ConvTranspose2d-10            [-1, 1, 28, 28]             448
                     Tanh-11            [-1, 1, 28, 28]               0
        ================================================================
        Total params: 261,576
        Trainable params: 261,576
        Non-trainable params: 0      
        """
        
        self.conv1 = self.torch_ref.nn.ConvTranspose2d(nz, ngf*4, 4, 1, 0, bias=False)
        self.norm1 = self.torch_ref.nn.BatchNorm2d(ngf*4)
        self.relu1 = self.torch_ref.nn.ReLU(True)

        self.conv2 = self.torch_ref.nn.ConvTranspose2d(ngf*4, ngf*2, 3, 2, 1, bias=False)
        self.norm2 = self.torch_ref.nn.BatchNorm2d(ngf*2)
        self.relu2 = self.torch_ref.nn.ReLU(True)

        self.conv3 = self.torch_ref.nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False)
        self.norm3 = self.torch_ref.nn.BatchNorm2d(ngf)
        self.relu3 = self.torch_ref.nn.ReLU(True)

        self.conv4 = self.torch_ref.nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.relu3(x)
        x = self.conv4(x)
        output = self.torch_ref.nn.Tanh()(x)
        return output

In [None]:
class Discriminator(sy.Module):
    def __init__(self, torch_ref):
        super(Discriminator, self).__init__(torch_ref=torch_ref)
        
        """
        input is (nc) x 28 x 28
        ----------------------------------------------------------------
                Layer (type)               Output Shape         Param #
        ================================================================
                    Conv2d-1           [-1, 28, 14, 14]             448
                 LeakyReLU-2           [-1, 28, 14, 14]               0
                    Conv2d-3             [-1, 56, 7, 7]          25,088
               BatchNorm2d-4             [-1, 56, 7, 7]             112
                 LeakyReLU-5             [-1, 56, 7, 7]               0
                    Conv2d-6            [-1, 112, 4, 4]          56,448
               BatchNorm2d-7            [-1, 112, 4, 4]             224
                 LeakyReLU-8            [-1, 112, 4, 4]               0
                    Conv2d-9              [-1, 1, 1, 1]           1,792
                  Sigmoid-10              [-1, 1, 1, 1]               0
        ================================================================
        Total params: 84,112
        Trainable params: 84,112
        Non-trainable params: 0        
        """
        
        self.conv1 = self.torch_ref.nn.Conv2d(nc, ndf, 4, 2, 1, bias=False)
        self.lerelu1 = self.torch_ref.nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = self.torch_ref.nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)
        self.norm2 = self.torch_ref.nn.BatchNorm2d(ndf * 2)
        self.lerelu2 = self.torch_ref.nn.LeakyReLU(0.2, inplace=True)
        self.conv3 = self.torch_ref.nn.Conv2d(ndf * 2, ndf * 4, 3, 2, 1, bias=False)
        self.norm3 = self.torch_ref.nn.BatchNorm2d(ndf * 4)
        self.lerelu3 = self.torch_ref.nn.LeakyReLU(0.2, inplace=True)
        self.conv4 = self.torch_ref.nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.lerelu1(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.lerelu2(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.lerelu3(x)
        x = self.conv4(x)
        output = self.torch_ref.nn.Sigmoid()(x)

        return output.view(-1, 1).squeeze(1)
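The parameter totals in the two layer summaries above can be checked by hand: every convolution uses `bias=False`, so it contributes only `in_channels * out_channels * k * k` weights, and each `BatchNorm2d` contributes one learnable scale and one shift per channel (its running statistics are buffers, not parameters). A pure-Python sketch, with hypothetical helper names and the sizes from `config`:

```python
NZ, NGF, NDF, NC = 100, 28, 28, 1  # config["nz"], config["ngf"], config["ndf"] and nc


def conv_params(in_ch: int, out_ch: int, kernel: int) -> int:
    # bias=False: only the weight tensor counts
    return in_ch * out_ch * kernel * kernel


def batchnorm_params(channels: int) -> int:
    # affine weight and bias, one of each per channel
    return 2 * channels


generator_params = (
    conv_params(NZ, NGF * 4, 4) + batchnorm_params(NGF * 4)
    + conv_params(NGF * 4, NGF * 2, 3) + batchnorm_params(NGF * 2)
    + conv_params(NGF * 2, NGF, 4) + batchnorm_params(NGF)
    + conv_params(NGF, NC, 4)
)
discriminator_params = (
    conv_params(NC, NDF, 4)
    + conv_params(NDF, NDF * 2, 4) + batchnorm_params(NDF * 2)
    + conv_params(NDF * 2, NDF * 4, 3) + batchnorm_params(NDF * 4)
    + conv_params(NDF * 4, 1, 4)
)
print(f"Generator: {generator_params:,}  Discriminator: {discriminator_params:,}")
```

Most of the generator's parameters sit in its first layer, which projects the 100-dimensional noise vector onto 112 feature maps of size 4x4.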

In [None]:
# build the models locally, then send them to the Data Owner's Duet server
local_netG = Generator(torch)
local_netD = Discriminator(torch)
netG = local_netG.send(duet)
netD = local_netD.send(duet)

In [None]:
# if we have CUDA, let's send our models to the GPU
if has_cuda:
    netD.cuda(device)
    netG.cuda(device)
else:
    netD.cpu()
    netG.cpu()

# Training

Make sure that both models are located on the remote (Data Owner) side.

In [None]:
assert not netD.is_local, "Training requires remote model"
assert not netG.is_local, "Training requires remote model"

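Create a fixed batch of noise vectors for inference. Feeding the same noise to the generator throughout training makes the saved samples comparable over time.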

In [None]:
fixed_noise = torch.randn(config["batch_size"], nz, 1, 1)
fixed_noise.tag("noise")
fixed_noise_ptr = fixed_noise.send(duet, pointable=True)

In [None]:
# ceiling division: a final partial batch still counts as a batch
train_batches = (train_data_length + config["batch_size"] - 1) // config["batch_size"]
print(f"> Running train in {train_batches} batches")
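The number of batches per epoch is the ceiling of `N / batch_size`. Integer ceiling division computes this exactly. Rounding `N / batch_size + 0.5` agrees for MNIST's 60,000 training images (938 batches), but because Python's `round()` rounds halves to even, it returns 2 instead of 1 when `N` equals the batch size. A small sketch comparing the approaches (`num_batches` is a hypothetical helper):

```python
import math


def num_batches(dataset_len: int, batch_size: int) -> int:
    # Integer ceiling division: a final partial batch still counts as one batch.
    return (dataset_len + batch_size - 1) // batch_size


for n in (60000, 64, 65):
    exact = math.ceil(n / 64)
    print(f"{n} samples: num_batches={num_batches(n, 64)}, math.ceil={exact}, "
          f"round(n / 64 + 0.5)={round(n / 64 + 0.5)}")
```

Integer division also avoids floating-point rounding for very large datasets, which is why it is used here instead of `math.ceil` on a float.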

Set the optimizers and the loss function.

In [None]:
optimizerD = remote_torch.optim.Adam(netD.parameters(), lr=config["lr"], betas=(config["beta1"], 0.999))
optimizerG = remote_torch.optim.Adam(netG.parameters(), lr=config["lr"], betas=(config["beta1"], 0.999))
criterion = remote_torch.nn.BCELoss()

In [None]:
real_label = 1
fake_label = 0
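`BCELoss` (with its default mean reduction) averages the per-element loss $\ell(p, y)$ below over the batch, where $p$ is the discriminator's output and $y$ the label. Plugging in `real_label` = 1 and `fake_label` = 0 shows how the objectives named in the training-loop comments arise: minimizing the first sum maximizes $\log D(x) + \log(1 - D(G(z)))$ for the discriminator, and labelling the fakes as real for the generator step maximizes $\log D(G(z))$.

```latex
\begin{aligned}
\ell(p, y) &= -\left[\, y \log p + (1 - y) \log (1 - p) \,\right] \\
\ell\big(D(x),\, 1\big) + \ell\big(D(G(z)),\, 0\big) &= -\left[\, \log D(x) + \log\big(1 - D(G(z))\big) \,\right] \\
\ell\big(D(G(z)),\, 1\big) &= -\log D(G(z))
\end{aligned}
```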

In [None]:
if config["dry_run"]:
    num_iter = 1
else:
    num_iter = config["num_iter"]

for epoch in range(num_iter):
    
    netD.train()
    netG.train()
    
    for batch_idx, data in enumerate(train_loader_ptr):
        
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        
        optimizerD.zero_grad()
        data_ptr, target_ptr = data[0], data[1]
        batch_size = config["batch_size"]
       
        output = netD(data_ptr)
        label = remote_torch.zeros_like(output)
        label = label.fill_(real_label)
        
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()
        

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1)
        fake = netG(noise)
        output = netD(fake.detach())
        label = remote_torch.zeros_like(output)
        label.fill_(fake_label)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        
        #netG.zero_grad()
        optimizerG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()
        
        if batch_idx % config["log_interval"] == 0:
            # get the losses; .get() may return None if the request is denied or times out
            local_errD = errD.item().get(request_block=True)
            local_errG = errG.item().get(request_block=True)

            if (local_errD is not None) and (local_errG is not None):
                print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f'
                      % (epoch, num_iter, batch_idx, train_batches,
                         local_errD, local_errG))
            else:
                print('[%d/%d][%d/%d] Loss_D: - Loss_G: - '
                      % (epoch, num_iter, batch_idx, train_batches))

            # create fake images from the fixed noise already sent to the Data Owner
            created_img = netG(fixed_noise_ptr).get(request_block=True)
            if created_img is not None:
                vutils.save_image(
                    created_img.detach(),
                    f"{config['outf']}/fake_samples_{epoch}_{batch_idx}.png",
                    normalize=True,
                )
            
        if batch_idx >= train_batches - 1:
            break

        if config["dry_run"]:
            break
         
    # save snapshots
    if (config["save_model"]) and (epoch % config["save_model_interval"] == 0):
        netG.get(
            request_block=True,
            timeout_secs=5,
            delete_obj=False
        ).save(f"{config['outf']}/netG_mnist_{epoch}.pt")

        netD.get(
            request_block=True,
            timeout_secs=5,
            delete_obj=False
        ).save(f"{config['outf']}/netD_mnist_{epoch}.pt")

# Save models

In [None]:
local_netG = netG.get(
    request_block=True,
    timeout_secs=5,
    delete_obj=False
)

local_netG.save(f"{config['outf']}/final_netG_mnist.pt")

netD.get(
    request_block=True,
    timeout_secs=5,
    delete_obj=False
).save(f"{config['outf']}/final_netD_mnist.pt")

# Inference locally

In [None]:
assert local_netG.is_local, "model is remote; call .get() first"

created_img = local_netG(fixed_noise)

vutils.save_image(
    created_img.detach(),
    "fake_samples_local.png",
    normalize=True,
)
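With `normalize=True` and no explicit range, `vutils.save_image` shifts the whole batch by its own minimum and divides by its own `max - min` before writing the PNG, so even a dim, low-contrast batch is stretched to full brightness. The pure-Python sketch below (hypothetical helper name, operating on a flat list instead of a tensor) mirrors that shift-and-scale:

```python
def minmax_rescale(values, eps=1e-5):
    """Map the smallest value to 0.0 and the largest to 1.0; eps guards against a constant input."""
    low, high = min(values), max(values)
    return [(v - low) / max(high - low, eps) for v in values]


# Tanh outputs that happen not to span the full [-1, 1] range
pixels = [-0.5, 0.0, 0.5]
print(minmax_rescale(pixels))  # [0.0, 0.5, 1.0]
```

Because each saved batch is rescaled by its own extremes, brightness is not directly comparable between snapshots taken at different points in training.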

In [None]:
fake_image = Image.open("fake_samples_local.png")
#display(fake_image)

### <img src="https://github.com/OpenMined/design-assets/raw/master/logos/OM/mark-primary-light.png" alt="he-black-box" width="100"/> Checkpoint 1 : Now STOP and run the Data Owner notebook until the next checkpoint.