In [1]:
# Imports for model
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.datasets as datasets
import torch
from torch.optim import Adam, SGD

# Imports for server connection
import socket
from send_receive import *

In [2]:
class MnistModel(nn.Module):
  def __init__(self) -> None:
      super().__init__()
      self.lin1 = nn.Linear(784, 256)
      self.lin2 = nn.Linear(256, 64)
      self.lin3 = nn.Linear(64, 10)

  def forward(self, X):
      x1 = F.relu(self.lin1(X))
      x2 = F.relu(self.lin2(x1))
      x3 = self.lin3(x2)
      return x3

  # Fit function
  def fit(self, X, y, optimizer, loss_fn, epochs):

    for epoch in range(epochs):

      ypred = self.forward(X)
      loss = loss_fn(ypred, y)

      optimizer.zero_grad()

      loss.backward()

      optimizer.step()

In [3]:
# Data fetching
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)

# Use first 50,000 entries from mnist training set, rest are for server
X_train = mnist_trainset.data[:50000,:]
X_train = X_train.float().flatten(start_dim=1, end_dim=2) # Flatten training images
Y_train = mnist_trainset.targets[:50000]

# Load testsset
X_test = mnist_testset.data
X_test = X_test.float().flatten(start_dim=1, end_dim=2) # Flatten test images
Y_test = mnist_testset.targets[:50000]

#X_train.shape
print(X_train.shape)

torch.Size([50000, 784])


In [None]:
class CronusKLLoss(nn.Module):
    def __init__(self, T=3.0):
        super().__init__()
        self.T = T
        self.kl = nn.KLDivLoss(reduction="batchmean")

    def forward(self, student_logits, teacher_logits):
        student_log_probs = F.log_softmax(student_logits / self.T, dim=1)
        teacher_probs = F.softmax(teacher_logits / self.T, dim=1)
        return self.kl(student_log_probs, teacher_probs) * (self.T ** 2)

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import socket


def start_clients(HOST, PORT, num_models, learning_rounds, local_epochs, collab_epochs):

    sampleSize = len(X_train) // num_models
    imgSize = len(X_train[0])

    X_trains = torch.zeros((sampleSize, imgSize, num_models))
    Y_trains = torch.zeros((sampleSize, num_models))

    # Fill data for each model
    for m in range(num_models):
        idx = torch.randperm(len(X_train))[:sampleSize]
        X_trains[:, :, m] = X_train[idx]
        Y_trains[:, m] = Y_train[idx]

    # Socket setup
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.connect((HOST, PORT))
    print(f"Connected to {HOST} on port {PORT}")

    # Receive public features
    features = recv_tensor(s)
    print(features.shape)

    # Create models
    models = [MnistModel() for _ in range(num_models)]

    # INITIALIZATION
    for i in range(num_models):
        optimAdam = Adam(models[i].parameters(), lr=0.0005)
        models[i].fit(
            X_trains[:, :, i],
            Y_trains[:, i].long(),
            optimAdam,
            nn.CrossEntropyLoss(),
            local_epochs
        )

    # Logging
    for j in range(num_models):
        preds = models[j].forward(X_test).argmax(dim=1)
        err = (preds != Y_test).float().mean()
        print(f"Initialization, model {j}, error {err}")

    # Initial predictions
    predictions = torch.stack(
        [models[i].forward(features).detach() for i in range(num_models)],
        dim=0
    )
    send_tensor(s, predictions)

    aggregationLabels = recv_tensor(s)

    # COLLABORATION
    for t in range(collab_epochs):

        predictions = []

        for j in range(num_models-1):

            optimSGD = SGD(models[j].parameters(), lr=0.001)

            models[j].fit(X_trains[:, :, j], Y_trains[:, j].long(),
                      optimSGD, nn.CrossEntropyLoss(), epochs=5)

            # One-step distillation update
            models[j].fit(
                features,
                aggregationLabels.detach(),
                optimSGD,
                CronusKLLoss(T=3.0),
                epochs=1
            )

            predictions.append(models[j].forward(features).detach())

        predictions = torch.stack(predictions, dim=0)
        send_tensor(s, predictions)

        aggregationLabels = recv_tensor(s)

        # Logging
        for j in range(num_models):
            preds = models[j].forward(X_test).argmax(dim=1)
            err = (preds != Y_test).float().mean()
            print(f"Collab step {t}, model {j}, error {err}")

    print("Finished")

    s.close()


# -----------------------------
# Run
# -----------------------------
HOST = "localhost"
PORT = 65435

num_models = 3
learning_rounds = 1
local_epochs = 50
collab_epochs = 50

start_clients(
    HOST,
    PORT,
    num_models,
    learning_rounds,
    local_epochs,
    collab_epochs
)

Connected to localhost on port 65435


  tensor = torch.from_numpy(array).to(device)


torch.Size([10000, 784])
Initialization, model 0, error 0.0778999999165535
Initialization, model 1, error 0.09099999815225601
Initialization, model 2, error 0.08470000326633453
Collab step 0, model 0, error 0.07909999787807465
Collab step 0, model 1, error 0.09009999781847
Collab step 0, model 2, error 0.08470000326633453
Collab step 1, model 0, error 0.07919999957084656
Collab step 1, model 1, error 0.08879999816417694
Collab step 1, model 2, error 0.08470000326633453
Collab step 2, model 0, error 0.07880000025033951
Collab step 2, model 1, error 0.08879999816417694
Collab step 2, model 2, error 0.08470000326633453
Collab step 3, model 0, error 0.07850000262260437
Collab step 3, model 1, error 0.0885000005364418
Collab step 3, model 2, error 0.08470000326633453
Collab step 4, model 0, error 0.0778999999165535
Collab step 4, model 1, error 0.08799999952316284
Collab step 4, model 2, error 0.08470000326633453
Collab step 5, model 0, error 0.07760000228881836
Collab step 5, model 1, erro