In [1]:
# Imports for model
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.datasets as datasets
import torch
from torch.optim import Adam, SGD

# Imports for server connection
import socket
from send_receive import *

In [2]:
class MnistModel(nn.Module):
    """Fully connected MNIST classifier: 784 -> 256 -> 64 -> 10.

    ReLU follows each hidden layer; `forward` returns raw logits (no softmax),
    so it pairs with losses that apply their own (log-)softmax.
    """

    def __init__(self) -> None:
        super().__init__()
        # Layers are registered in this order (lin1, lin2, lin3); under a fixed
        # seed that order determines which random values each layer receives.
        self.lin1 = nn.Linear(784, 256)
        self.lin2 = nn.Linear(256, 64)
        self.lin3 = nn.Linear(64, 10)

    def forward(self, X):
        """Map a (batch, 784) float tensor to (batch, 10) class logits."""
        hidden = F.relu(self.lin1(X))
        hidden = F.relu(self.lin2(hidden))
        return self.lin3(hidden)

    def fit(self, X, y, optimizer, loss_fn, epochs):
        """Full-batch training: one optimizer step on all of (X, y) per epoch.

        Args:
            X: inputs for every sample, fed to `forward` in a single call.
            y: targets accepted by `loss_fn` as its second argument
               (class indices for `nn.CrossEntropyLoss`, probabilities for
               `SoftKLLoss`).
            optimizer: optimizer expected to be bound to this model's parameters.
            loss_fn: callable ``loss_fn(logits, y)`` returning a scalar tensor.
            epochs: number of gradient steps to take.

        Returns:
            None; parameters are updated in place.
        """
        for _ in range(epochs):
            loss = loss_fn(self.forward(X), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

In [49]:
# Data fetching: MNIST train/test splits, downloaded into ./data, no transform
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)

# Clients own the first N_CLIENT_TRAIN training images; the remaining ones are for the server.
N_CLIENT_TRAIN = 50000


def _as_flat_float(images):
    """Cast an (N, H, W) image stack to float and flatten each image into a row: (N, H*W)."""
    return images.float().flatten(1, 2)


X_train = _as_flat_float(mnist_trainset.data[:N_CLIENT_TRAIN])
Y_train = mnist_trainset.targets[:N_CLIENT_TRAIN]

# Held-out MNIST test set, used to measure each client's error
X_test = _as_flat_float(mnist_testset.data)
Y_test = mnist_testset.targets

print(X_train.shape)

torch.Size([50000, 784])


In [18]:
class CronusKLLoss(nn.Module):
    """Temperature-scaled KL distillation loss between student and teacher logits.

    Returns ``T**2 * KL(softmax(teacher / T) || softmax(student / T))`` averaged
    over the batch (``reduction="batchmean"``). Multiplying by ``T**2`` is the
    usual knowledge-distillation convention for keeping gradient magnitudes
    comparable across temperatures.
    """

    def __init__(self, T=3.0):
        super().__init__()
        self.T = T
        self.kl = nn.KLDivLoss(reduction="batchmean")

    def forward(self, student_logits, teacher_logits):
        temperature = self.T
        # KLDivLoss expects log-probabilities as input and probabilities as target.
        log_p_student = F.log_softmax(student_logits / temperature, dim=1)
        p_teacher = F.softmax(teacher_logits / temperature, dim=1)
        scale = temperature ** 2
        return self.kl(log_p_student, p_teacher) * scale

class SoftKLLoss(nn.Module):
    """Batch-mean KL divergence of student logits against teacher probabilities.

    Unlike `CronusKLLoss`, no temperature is applied, and the teacher side must
    already be a probability distribution along dim 1. The value is
    ``KL(teacher_probs || softmax(student_logits))`` averaged over the batch.
    """

    def __init__(self):
        super().__init__()
        self.kl = nn.KLDivLoss(reduction="batchmean")

    def forward(self, student_logits, teacher_probs):
        return self.kl(F.log_softmax(student_logits, dim=1), teacher_probs)

In [71]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import socket


def _partition_client_data(num_models):
    """Split the global client pool (X_train, Y_train) into disjoint, equal-size shards.

    Each shard holds ``len(X_train) // num_models`` samples; the remainder
    (``len(X_train) % num_models`` samples) is left unused.

    Returns:
        Two lists of length ``num_models``: input shards and label shards
        (labels cast to int64 class indices).
    """
    sample_size = len(X_train) // num_models
    # One shared permutation, so no training sample is given to more than one client.
    perm = torch.randperm(len(X_train))[: sample_size * num_models]
    shard_indices = perm.view(num_models, sample_size)
    X_shards = [X_train[idx] for idx in shard_indices]
    Y_shards = [Y_train[idx].long() for idx in shard_indices]
    return X_shards, Y_shards


@torch.no_grad()
def _test_error(model):
    """Fraction of the global test set (X_test, Y_test) that `model` misclassifies."""
    preds = model(X_test).argmax(dim=1)
    return (preds != Y_test).float().mean()


@torch.no_grad()
def _predict_logits(models, features):
    """Stack every model's logits on `features` into one (num_models, N, 10) tensor."""
    return torch.stack([model(features) for model in models], dim=0)


def start_clients(HOST, PORT, num_models, initialization_epochs, collab_epochs):
    """Train `num_models` local MNIST clients that collaborate through a server.

    Protocol as implemented here:
      1. Receive the public feature matrix from the server.
      2. Initialization: train each client on its private shard with Adam.
      3. Send the stacked client logits on the public features; receive the
         server's reply and turn it into soft labels with a softmax over dim 1.
      4. For `collab_epochs` rounds: one hard-label SGD step on the private
         shard, one soft-label SGD step (SoftKLLoss) on the public features,
         then exchange predictions with the server again.

    Reads the notebook globals X_train, Y_train (client data pool) and
    X_test, Y_test (evaluation set), and uses `send_tensor` / `recv_tensor`
    from `send_receive`.

    Args:
        HOST, PORT: address of the aggregation server (must be listening).
        num_models: number of clients; the client pool is split into this many
            disjoint shards.
        initialization_epochs: full-batch Adam steps per client before collaboration.
        collab_epochs: number of communication rounds with the server.
    """
    X_shards, Y_shards = _partition_client_data(num_models)
    # Loss modules are stateless, so one instance of each is reused everywhere.
    ce_loss = nn.CrossEntropyLoss()
    soft_loss = SoftKLLoss()

    try:
        # The context manager closes the socket on every exit path, including
        # exceptions other than KeyboardInterrupt.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.connect((HOST, PORT))
            print(f"Connected to {HOST} on port {PORT}")

            # Receive public features
            features = recv_tensor(s)
            print(features.shape)

            models = [MnistModel() for _ in range(num_models)]

            # INITIALIZATION
            for model, X_local, Y_local in zip(models, X_shards, Y_shards):
                optimizer = Adam(model.parameters(), lr=0.0005)
                model.fit(X_local, Y_local, optimizer, ce_loss, initialization_epochs)

            for j, model in enumerate(models):
                print(f"Initialization, model {j}, error {_test_error(model)}")

            # Initial predictions -> server -> soft labels
            send_tensor(s, _predict_logits(models, features))
            # NOTE(review): softmax over dim=1 assumes the server replies with (N, 10) logits.
            soft_labels = F.softmax(recv_tensor(s), dim=1).detach()

            # COLLABORATION: two optimizers over the same parameters, one per loss.
            optim_hard = [SGD(model.parameters(), lr=0.01) for model in models]
            optim_soft = [SGD(model.parameters(), lr=0.001) for model in models]

            for t in range(collab_epochs):
                for j, model in enumerate(models):
                    # 1. HARD update (anchor) on the client's private shard
                    model.fit(X_shards[j], Y_shards[j], optim_hard[j], ce_loss, epochs=1)
                    # 2. SOFT update (regularizer) toward the server's soft labels
                    model.fit(features, soft_labels, optim_soft[j], soft_loss, epochs=1)

                send_tensor(s, _predict_logits(models, features))
                soft_labels = F.softmax(recv_tensor(s), dim=1).detach()

                for j, model in enumerate(models):
                    print(f"Collab step {t}, model {j}, error {_test_error(model)}")

            print("Finished")

    except KeyboardInterrupt:
        print("Keyboard Interrupt, Thread Closed")
        
    

In [72]:
# -----------------------------
# Run: connect to the aggregation server and train the clients
# -----------------------------
# The server must already be listening on this address.
HOST, PORT = "localhost", 65435

num_models = 10              # number of client models
initialization_epochs = 50   # local training steps per client before collaboration
collab_epochs = 50           # communication rounds with the server

start_clients(
    HOST,
    PORT,
    num_models=num_models,
    initialization_epochs=initialization_epochs,
    collab_epochs=collab_epochs,
)

Connected to localhost on port 65435
torch.Size([10000, 784])
Initialization, model 0, error 0.09080000221729279
Initialization, model 1, error 0.09390000253915787
Initialization, model 2, error 0.10320000350475311
Initialization, model 3, error 0.1023000031709671
Initialization, model 4, error 0.10289999842643738
Initialization, model 5, error 0.09749999642372131
Initialization, model 6, error 0.0917000025510788
Initialization, model 7, error 0.09870000183582306
Initialization, model 8, error 0.09189999848604202
Initialization, model 9, error 0.09200000017881393
Collab step 0, model 0, error 0.08919999748468399
Collab step 0, model 1, error 0.09359999746084213
Collab step 0, model 2, error 0.10360000282526016
Collab step 0, model 3, error 0.10209999978542328
Collab step 0, model 4, error 0.10180000215768814
Collab step 0, model 5, error 0.09740000218153
Collab step 0, model 6, error 0.09160000085830688
Collab step 0, model 7, error 0.09669999778270721
Collab step 0, model 8, error 0.0