In [None]:
# Imports for model
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.datasets as datasets
import torch
from torch.optim import Adam, SGD

# Imports for server connection
import socket
import io
from send_receive import *

In [None]:
class MnistModel(nn.Module):
    """Fully connected MNIST classifier: 784 -> 256 -> 64 -> 10.

    ``forward`` returns raw, unnormalised logits so the output can be passed
    directly to ``nn.CrossEntropyLoss`` (which applies log-softmax itself).
    """

    def __init__(self) -> None:
        super().__init__()
        self.lin1 = nn.Linear(784, 256)
        self.lin2 = nn.Linear(256, 64)
        self.lin3 = nn.Linear(64, 10)

    def forward(self, X):
        """Map a batch of flattened images ``(batch, 784)`` to logits ``(batch, 10)``."""
        hidden = F.relu(self.lin1(X))
        hidden = F.relu(self.lin2(hidden))
        # No activation on the output layer: a ReLU here would clamp every
        # negative logit to 0 and destroy the ranking between classes.
        return self.lin3(hidden)

    # Fit function
    def fit(self, X, y, optimizer=None, loss_fn=None, epochs=1):
        """Train with full-batch gradient descent on ``(X, y)``.

        Args:
            X: inputs of shape ``(N, 784)``. This can be a tensor or anything
                ``torch.as_tensor`` accepts (e.g. a numpy array). It is cast
                to float32 to match the layer weights.
            y: targets. 1-D targets are treated as class indices and cast to
                int64. Higher-rank targets are cast to float32 (e.g.
                per-class probabilities).
            optimizer: optimizer over ``self.parameters()``. If omitted, a
                fresh ``Adam`` with default hyperparameters is created for
                this call, so its state does not persist between calls.
            loss_fn: callable ``loss_fn(logits, y)``. Defaults to
                ``nn.CrossEntropyLoss()``.
            epochs: number of full-batch optimisation steps.

        Returns:
            list[float]: the loss computed at each epoch, before that epoch's step.
        """
        if optimizer is None:
            optimizer = Adam(self.parameters())
        if loss_fn is None:
            loss_fn = nn.CrossEntropyLoss()

        X = torch.as_tensor(X, dtype=torch.float32)
        y = torch.as_tensor(y)
        # CrossEntropyLoss requires int64 class indices; float labels (e.g.
        # from a numpy float array) would otherwise raise.
        y = y.long() if y.dim() == 1 else y.float()

        self.train()
        losses = []
        for _ in range(epochs):
            loss = loss_fn(self.forward(X), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        return losses

In [None]:
# Data fetching: MNIST train/test splits. Each 28x28 image is flattened into a
# 784-long float32 vector; pixel values are left on their raw 0-255 scale.
# NOTE(review): no normalisation is applied here. The public features sent by
# the server are presumably on the same scale, so confirm before rescaling.

mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)

# `.data` holds the raw image tensor; flatten everything after the sample axis.
X_train = mnist_trainset.data.float().flatten(start_dim=1)
Y_train = mnist_trainset.targets

X_test = mnist_testset.data.float().flatten(start_dim=1)
Y_test = mnist_testset.targets

# Sanity checks: MnistModel's first layer expects 784 input features.
assert X_train.shape[1] == X_test.shape[1] == 784, "MnistModel expects 784 input features"
assert len(X_train) == len(Y_train) and len(X_test) == len(Y_test)

Y_test.shape


In [None]:
HOST = "127.0.0.1"  # Standard loopback interface address (localhost)
PORT = 65432  # Port to listen on (non-privileged ports are > 1023)

SEED = 0
numModels = 10
iterations = 100
LOCAL_EPOCHS = 1       # full-batch gradient steps per fit() call
LEARNING_RATE = 1e-3

torch.manual_seed(SEED)  # makes subset sampling and weight init reproducible

# Each model gets an independent random subset holding 1/numModels of the
# training set. Subsets of different models may overlap.
sampleSize = len(X_train) // numModels
imgSize = X_train.shape[1]

# Per-model training data, kept as tensors. X_train is float32 and Y_train
# holds class-index labels, which is what nn.Linear / CrossEntropyLoss expect.
X_trains, Y_trains = [], []
for _ in range(numModels):
    idx = torch.randperm(len(X_train))[:sampleSize]
    X_trains.append(X_train[idx])
    Y_trains.append(Y_train[idx])

# Create models
models = [MnistModel() for _ in range(numModels)]
assert imgSize == models[0].lin1.in_features

# One optimizer per model, created once so Adam's moment estimates persist
# across communication rounds.
optimizers = [Adam(model.parameters(), lr=LEARNING_RATE) for model in models]
loss_fn = nn.CrossEntropyLoss()

# The context manager guarantees the connection is closed even on error.
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.connect((HOST, PORT))

    # Public features from the server. NOTE(review): assumed to be a numeric
    # array of shape (N, 784) on the same 0-255 scale as X_train. Confirm.
    features = recv_numpy(s)
    features_t = torch.as_tensor(features, dtype=torch.float32)

    for iteration in range(iterations):
        # Local training on each model's private subset
        for model, opt, X_local, y_local in zip(models, optimizers, X_trains, Y_trains):
            model.fit(X_local, y_local, opt, loss_fn, LOCAL_EPOCHS)

        # Predictions on the public set. no_grad is required: a tensor that
        # tracks gradients cannot be converted to numpy.
        with torch.no_grad():
            predictions = np.stack([model(features_t).numpy() for model in models], axis=0)

        # Send predictions to server
        send_numpy(s, predictions)

        # Receive aggregation from server and convert it to tensors
        aggregationFeatures = torch.as_tensor(recv_numpy(s), dtype=torch.float32)
        aggregationLabels = torch.as_tensor(recv_numpy(s))
        # 1-D labels are treated as class indices (int64). Otherwise they are
        # treated as per-class probabilities (float32). CrossEntropyLoss
        # accepts both forms (probabilities need PyTorch >= 1.10).
        if aggregationLabels.dim() == 1:
            aggregationLabels = aggregationLabels.long()
        else:
            aggregationLabels = aggregationLabels.float()

        for model, opt in zip(models, optimizers):
            model.fit(aggregationFeatures, aggregationLabels, opt, loss_fn, LOCAL_EPOCHS)

# Get error rates: fraction of test images whose arg-max class is wrong
with torch.no_grad():
    for i, model in enumerate(models):
        predictedClasses = model(X_test).argmax(dim=1)
        errorRate = (predictedClasses != Y_test).float().mean().item()
        print(f"Error rate for model {i}: {errorRate}")
