In [13]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import ExponentialLR
from torch.autograd import Variable
from torch.distributions.binomial import Binomial

In [14]:
class ClientModel(nn.Module):
    """Party-side embedding model built from three pretrained ImageNet backbones.

    DenseNet-169, ResNet-50 and VGG-19 have their classification layers
    removed. Each backbone's feature map is globally average-pooled to one
    value per channel, and the three vectors are concatenated into a
    (N, 1664 + 2048 + 512) embedding, the width the server head expects.
    """

    def __init__(self):
        super(ClientModel, self).__init__()
        self.densenet = torchvision.models.densenet169(pretrained=True)
        self.resnet = torchvision.models.resnet50(pretrained=True)
        self.vgg = torchvision.models.vgg19(pretrained=True)

        # Remove the classification layers (fully connected layers)
        self.densenet = nn.Sequential(*list(self.densenet.children())[:-1])
        self.resnet = nn.Sequential(*list(self.resnet.children())[:-1])
        self.vgg = nn.Sequential(*list(self.vgg.children())[:-1])

    @staticmethod
    def _pool(features):
        """Global-average-pool an (N, C, H, W) feature map to an (N, C) matrix."""
        return torch.flatten(nn.functional.adaptive_avg_pool2d(features, 1), 1)

    def forward(self, x):
        """Return the concatenated, pooled backbone features of image batch x.

        Args:
            x: image batch, presumably (N, 3, H, W) -- TODO confirm with the loaders.

        Returns:
            (N, 1664 + 2048 + 512) embedding tensor.
        """
        # torchvision's DenseNet applies ReLU to `features` before pooling;
        # that step lives in DenseNet.forward, not in its children.
        x1 = self._pool(nn.functional.relu(self.densenet(x)))
        # ResNet's children already end with global pooling, so this is a reshape.
        x2 = self._pool(self.resnet(x))
        # VGG's avgpool keeps a 7x7 map; flattening it would give 512*49 features.
        x3 = self._pool(self.vgg(x))

        # This is the party's embedding; the classification head lives on the
        # server (this class never defined `self.fc`).
        return torch.cat((x1, x2, x3), dim=1)

In [15]:
class ServerModel(nn.Module):
    """Three pretrained ImageNet backbones followed by an MLP head with 2 logits.

    ``forward`` accepts either an image batch or an already-extracted 2-D
    feature/embedding matrix of width 1664 + 2048 + 512. For an image batch,
    the DenseNet-169, ResNet-50 and VGG-19 feature maps are globally
    average-pooled and concatenated first. A 2-D matrix, such as the
    aggregated party embeddings, is fed straight to the head.
    """

    def __init__(self):
        super(ServerModel, self).__init__()
        self.densenet = torchvision.models.densenet169(pretrained=True)
        self.resnet = torchvision.models.resnet50(pretrained=True)
        self.vgg = torchvision.models.vgg19(pretrained=True)

        # Remove the classification layers (fully connected layers)
        self.densenet = nn.Sequential(*list(self.densenet.children())[:-1])
        self.resnet = nn.Sequential(*list(self.resnet.children())[:-1])
        self.vgg = nn.Sequential(*list(self.vgg.children())[:-1])

        # Channels of the pooled densenet + resnet + vgg features
        in_features = 1664 + 2048 + 512

        # Define new classification layers
        self.fc = nn.Sequential(
            nn.Dropout(0.4),
            nn.BatchNorm1d(in_features),
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 2),
        )

    @staticmethod
    def _pool(features):
        """Global-average-pool an (N, C, H, W) feature map to an (N, C) matrix."""
        return torch.flatten(nn.functional.adaptive_avg_pool2d(features, 1), 1)

    def _extract(self, x):
        """Pooled, concatenated backbone features of an image batch."""
        # torchvision's DenseNet applies ReLU to `features` before pooling;
        # that step lives in DenseNet.forward, not in its children.
        x1 = self._pool(nn.functional.relu(self.densenet(x)))
        # ResNet's children already end with global pooling, so this is a reshape.
        x2 = self._pool(self.resnet(x))
        # VGG's avgpool keeps a 7x7 map; without global pooling the flattened
        # width would be 512*49 instead of the 512 the head is sized for.
        x3 = self._pool(self.vgg(x))
        return torch.cat((x1, x2, x3), dim=1)

    def forward(self, x):
        """Return (N, 2) logits for an image batch or a 2-D feature matrix."""
        # A 2-D input cannot be an image batch, so treat it as features.
        if x.dim() != 2:
            x = self._extract(x)
        return self.fc(x)

In [19]:
# Arguments and parameters
num_clients = 4   # number of parties; models[num_clients] is the server
lr = 0.0001
lr_decay = 0.9    # ExponentialLR gamma
batch_size = 32
num_epochs = 5
quant_bin = 8     # quantization parameter: binomial trials per element (m in quantize)
theta = 0.15      # DP noise parameter: success probability is 0.5 + theta * x
seed = 0          # seeds dropout, batch shuffling and DP noise sampling

# Seed every RNG up front so a Restart & Run All reproduces the results.
torch.manual_seed(seed)
np.random.seed(seed)

In [20]:
# Make one embedding model per party plus the server model.
# Index layout: models[0 .. num_clients-1] are the parties and models[num_clients]
# is the server, which turns the aggregated embedding into class logits.
models = [ClientModel() for _ in range(num_clients)] + [ServerModel()]
optimizers = [torch.optim.Adam(model.parameters(), lr=lr) for model in models]
schedulers = [ExponentialLR(optimizer, gamma=lr_decay) for optimizer in optimizers]

# The server emits 2 logits per sample and ImageFolder yields integer class
# indices, which is what CrossEntropyLoss expects. BCEWithLogitsLoss would
# need float targets shaped like the logits.
criterion = nn.CrossEntropyLoss()

In [21]:
# Discrete differential privacy noise
def quantize(x, theta, m):
    """Encode x as binomial counts (discrete differential-privacy noise).

    Each element becomes a draw from Binomial(m, 0.5 + theta * x). Its mean,
    m/2 + m*theta*x, is linear in x, which ``dequantize`` inverts for a sum.
    The success probability is clamped to [0, 1]. That is equivalent to
    clipping x to [-1/(2*theta), 1/(2*theta)]; without the clamp, Binomial
    rejects out-of-range probabilities.

    The result holds the sampled counts but keeps x's autograd graph (a
    straight-through estimator), so its gradient with respect to x is the identity.

    Args:
        x: tensor of any shape.
        theta: positive scale mapping x to a probability offset.
        m: number of binomial trials per element.

    Returns:
        Tensor shaped like x with values in {0, ..., m}.
    """
    p = torch.clamp(0.5 + theta * x.detach(), 0.0, 1.0)
    noise = Binomial(m, p).sample()
    # Forward value is exactly `noise` (x - x.detach() is zero); the backward
    # pass routes the gradient straight through to x.
    return noise + (x - x.detach())

def dequantize(q, theta, m, n):
    det = torch.sub(q, m * n / 2)
    sum = torch.div(det, theta * m)
    return sum

In [None]:
# Split datasets among parties
# Split train-validation
# TODO: not implemented -- `???` is a placeholder, so this cell does not run.
# It should presumably give each party its own view of the same samples (the
# parties' embeddings are summed per sample later on) and carve out a
# validation split; evaluate('validation') currently reads train_loaders.
???

In [None]:
# Load datasets
# The pretrained torchvision backbones expect ImageNet-normalized input. A fixed
# (H, W) is also needed: Resize(256) alone keeps the aspect ratio, so
# non-square images would give tensors the default collate cannot stack.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# The parties' embeddings are summed per sample, so every party must see the
# same samples in the same order. Identically seeded generators give every
# train loader the same shuffle (assumes the per-party datasets are
# index-aligned and equally long -- TODO confirm once the split exists).
shuffle_seed = 0

train_loaders = []
test_loaders = []

for i in range(num_clients):
    # TODO: per-party dataset roots
    train_data = torchvision.datasets.ImageFolder(root="???", transform=transform)
    train_loaders.append(DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4,
                                    generator=torch.Generator().manual_seed(shuffle_seed)))
    test_data = torchvision.datasets.ImageFolder(root="???", transform=transform)
    # Evaluation needs no shuffling; a fixed order also keeps the parties aligned.
    test_loaders.append(DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4))

In [22]:
# train function
def train():
    """Run one epoch of federated training with discrete DP noise.

    For every mini-batch, the parties are iterated in lockstep. Each party
    embeds its own view of the batch and perturbs the embedding with
    ``quantize``. The server sums the noisy embeddings, dequantizes the sum,
    computes the loss and back-propagates. Because ``quantize`` keeps its
    input's autograd graph, gradients reach every party model.
    Every party's and the server's learning rate is decayed once at the end.

    Assumes the per-party loaders yield the same samples in the same order, so
    all parties' targets agree -- TODO confirm.
    """
    for model in models:
        model.train()

    server = models[num_clients]
    for party_batches in zip(*train_loaders):
        # Labels are shared by the aligned parties; the server uses one copy.
        targets = party_batches[0][1]

        # At party side: embed the local view (batch-first, no transpose) and
        # add differential privacy noise.
        noisy_embeddings = [
            quantize(models[i](inputs), theta, quant_bin)
            for i, (inputs, _) in enumerate(party_batches)
        ]

        # TODO: send each noisy embedding (detached) to the smart contract and
        # retrieve the aggregated sum from it; summed locally as a stand-in.
        embedding_sum = torch.stack(noisy_embeddings).sum(dim=0)

        # At server side: dequantize the discrete sum into a continuous sum
        embedding_sum = dequantize(embedding_sum, theta, quant_bin, num_clients)

        # compute outputs
        outputs = server(embedding_sum)
        loss = criterion(outputs, targets)

        # parties and server compute gradients and take an optimizer step
        for optimizer in optimizers:
            optimizer.zero_grad()
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()

    # parties and server calculate their new learning rates
    for scheduler in schedulers:
        scheduler.step()

In [23]:
# evaluation function for validation and testing
def evaluate(mode):
    """Compute accuracy (%) and mean per-batch loss of the federated model.

    The parties are iterated in lockstep and their embeddings are summed
    without DP noise. All models run in eval mode under ``torch.no_grad()``.
    Each model's previous train/eval mode is restored afterwards.

    Args:
        mode: 'validation' evaluates on ``train_loaders``; anything else uses
            ``test_loaders``.

    Returns:
        (accuracy, loss) as Python floats. Both are nan when the loaders
        yield no batches.
    """
    # TODO: use a held-out validation split once one exists; 'validation'
    # currently reports metrics on the training data.
    data_loaders = train_loaders if mode == 'validation' else test_loaders

    was_training = [model.training for model in models]
    for model in models:
        model.eval()

    total = 0
    correct = 0
    total_loss = 0.0
    n = 0
    try:
        with torch.no_grad():
            for party_batches in zip(*data_loaders):
                # Labels are shared by the aligned parties; use one copy.
                targets = party_batches[0][1]

                # At party side: generate embeddings (batch-first, no transpose)
                embeddings = [models[i](inputs) for i, (inputs, _) in enumerate(party_batches)]

                # At server side
                embedding_sum = torch.stack(embeddings).sum(dim=0)
                outputs = models[num_clients](embedding_sum)

                # accumulate plain numbers so no autograd state is kept
                total_loss += criterion(outputs, targets).item()
                predicted = outputs.argmax(dim=1)
                correct += (predicted == targets).sum().item()
                total += targets.size(0)
                n += 1
    finally:
        for model, training in zip(models, was_training):
            model.train(training)

    if n == 0:
        return (float('nan'), float('nan'))
    return (100 * correct / total, total_loss / n)

In [None]:
# Main training loop: every epoch trains once, then reports validation and test metrics.
rule = '\n-----------------------------------'
report = 'Val Loss: {:.2f} \t Val Accuracy: {:.2f} \t Test Loss: {:.2f} \t Test Accuracy: {:.2f}'

for epoch in range(num_epochs):
    print(rule)
    print('Epoch: [%d/%d]' % (epoch + 1, num_epochs))

    train()
    val_accuracy, val_loss = evaluate(mode='validation')
    test_accuracy, test_loss = evaluate(mode='test')

    print(report.format(val_loss, val_accuracy, test_loss, test_accuracy))