In [36]:
import torch
from torch import nn
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import os

In [37]:
training_dataset = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

Files already downloaded and verified


In [38]:
test_dataset = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

Files already downloaded and verified


In [39]:
training_dataloader = DataLoader(training_dataset, batch_size=32)
test_dataloader = DataLoader(test_dataset, batch_size=32)

In [40]:
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([32, 3, 32, 32])
Shape of y: torch.Size([32]) torch.int64


In [41]:
class ClientNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_stack1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=(3, 3), padding="same"),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, kernel_size=(3, 3), padding="same"),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d((2, 2))  
        )
    
    def forward(self, data):
        x = self.conv_stack1(data)
        return x
    
class ServerNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_stack2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=(3, 3), padding="same"),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, kernel_size=(3, 3), padding="same"),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d((2, 2))  
        )

        self.conv_stack3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(3, 3), padding="same"),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, kernel_size=(3, 3), padding="same"),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d((2, 2))  
        )

        self.classification_stack = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(),
            nn.Linear(4*4*128, 10),
            nn.Softmax(1)
        )
        
    
    def forward(self, data):
        x = self.conv_stack2(data)
        x = self.conv_stack3(x)
        return self.classification_stack(x)

In [42]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [43]:
serverModel = ServerNN().to(device)
clientModel = ClientNN().to(device)
print(serverModel)
print(clientModel)

ServerNN(
  (conv_stack2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (4): ReLU()
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (conv_stack3): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (4): ReLU()
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (classification_s

In [44]:
loss_fn = nn.CrossEntropyLoss()
clientOptimizer = torch.optim.Adam(clientModel.parameters(), lr=3e-4)
serverOptimizer = torch.optim.Adam(serverModel.parameters(), lr=3e-4)

In [45]:
def serverTrain(batch, clientOutputCPU, y, serverModel, loss_fn, serverOptimizer):
    serverModel.train()
    serverOptimizer.zero_grad()
    clientOutput = clientOutputCPU.to(device)
    pred = serverModel(clientOutput)
    loss = loss_fn(pred, y)
    loss.backward()
    serverOptimizer.step()
    loss_value = loss.item()
    print(f"Loss : {loss_value}, Batch : {batch}")
    return clientOutputCPU.grad.clone().detach()


In [46]:
def clientTrain(dataloader, clientModel, serverModel, loss_fn, clientOptimizer, serverOptimizer):
    clientModel.train()
    for batch, (X, y) in enumerate(dataloader):
        X = X.to(device)
        
        clientOutput = clientModel(X)
        clientOutput = clientOutput.clone().detach().requires_grad_(True)
        grads = serverTrain(batch, clientOutput, y, serverModel, loss_fn, serverOptimizer)
        clientOutput.backward(grads)
        clientOptimizer.step()
        clientOptimizer.zero_grad()

In [47]:
epochs = 5
for epoch in range(epochs):
    print(f"Epoch {epoch} started!")
    clientTrain(training_dataloader, clientModel, serverModel, loss_fn, clientOptimizer, serverOptimizer)


Epoch 0 started!
Loss : 2.297450304031372, Batch : 0
Loss : 2.247063159942627, Batch : 1
Loss : 2.3187785148620605, Batch : 2
Loss : 2.2961623668670654, Batch : 3
Loss : 2.2806894779205322, Batch : 4
Loss : 2.31801700592041, Batch : 5
Loss : 2.2773256301879883, Batch : 6
Loss : 2.2142996788024902, Batch : 7
Loss : 2.261676073074341, Batch : 8
Loss : 2.2241785526275635, Batch : 9
Loss : 2.287748098373413, Batch : 10
Loss : 2.2325687408447266, Batch : 11
Loss : 2.2701210975646973, Batch : 12
Loss : 2.246320962905884, Batch : 13
Loss : 2.224458932876587, Batch : 14
Loss : 2.2676734924316406, Batch : 15
Loss : 2.1996219158172607, Batch : 16
Loss : 2.197838306427002, Batch : 17
Loss : 2.256319761276245, Batch : 18
Loss : 2.256678581237793, Batch : 19


KeyboardInterrupt: 