In [1]:
import numpy as np
import torch
from torchvision import transforms, datasets
from torchvision.utils import save_image

import unsplit.attacks as unsplit
from unsplit.models import *
from unsplit.util import *

In [52]:
class MnistMLP(torch.nn.Module):
    def __init__(self):
        super(MnistMLP, self).__init__()
        #self.features = []
        self.classifier = []
        self.layers = collections.OrderedDict()

        self.fc1 = torch.nn.Linear(28 * 28, 256, bias=False)
        self.classifier.append(self.fc1) #0
        self.layers["fc1"] = self.fc1

        self.fc1act = torch.nn.ReLU(inplace=False)
        self.classifier.append(self.fc1act) #1
        self.layers["fc1act"] = self.fc1act

        self.fc2 = torch.nn.Linear(256, 120, bias=False)
        self.classifier.append(self.fc2) #2
        self.layers["fc2"] = self.fc2

        self.fc2act = torch.nn.ReLU(inplace=False)
        self.classifier.append(self.fc2act) #3
        self.layers["fc2act"] = self.fc2act

        self.fc3 = torch.nn.Linear(120, 10, bias=False)
        self.classifier.append(self.fc3) #4
        self.layers["fc3"] = self.fc3

        self.initial_params = [
            param.clone().detach().data for param in self.parameters()
        ]

    def forward(self, x, start=0, end=4):
        if start == end:
            raise ValueError("start should be less than end")
        else:
            x = x.view(-1, 28 * 28)
            for idx, layer in enumerate(self.classifier[start:]):
                x = layer(x)
                if idx == end:
                    return x
                
    def get_params(self, end=4):
        params = []
        for layer in list(self.layers.values())[: end + 1]:
            params += list(layer.parameters())
        return params

    def restore_initial_params(self):
        for param, initial in zip(self.parameters(), self.initial_params):
            param.data = initial.requires_grad_(True)

In [53]:
dataset = 'mnist'

if dataset == 'mnist':
    trainset = datasets.MNIST('data/mnist', download=True, train=True, transform=transforms.ToTensor())
    testset = datasets.MNIST('data/mnist', download=True, train=False, transform=transforms.ToTensor())
    client, server, clone = MnistMLP(), MnistMLP(), MnistMLP()
elif dataset == 'f_mnist':
    trainset = datasets.FashionMNIST('data/f_mnist', download=True, train=True, transform=transforms.ToTensor())
    testset = datasets.FashionMNIST('data/f_mnist', download=True, train=False, transform=transforms.ToTensor())
    client, server, clone = MnistMLP(), MnistMLP(), MnistMLP()
elif dataset == 'cifar':
    trainset = datasets.CIFAR10('data/cifar', download=True, train=True, transform=transforms.ToTensor())
    testset = datasets.CIFAR10('data/cifar', download=True, train=False, transform=transforms.ToTensor())
    client, server, clone = CifarMLP(), CifarMLP(), CifarMLP()

trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=64)
testloader = torch.utils.data.DataLoader(testset, shuffle=True)

In [48]:
split_layer = 2

In [50]:
client_opt = torch.optim.Adam(client.parameters(), lr=0.001, amsgrad=True)
server_opt = torch.optim.Adam(server.parameters(), lr=0.001, amsgrad=True)
criterion = torch.nn.CrossEntropyLoss()

epochs = 10
for epoch in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        client_opt.zero_grad()
        server_opt.zero_grad()
        
        images = images.view(-1, 784)
        pred = server(client(images, end=split_layer), start=split_layer+1)
        print(pred)

        loss = criterion(pred, labels)
        loss.backward()
        running_loss += loss

        server_opt.step()
        client_opt.step()
    else:
        print(f'Epoch: {epoch} Loss: {running_loss / len(trainloader)} Acc: {get_test_score(client, server, testset, split=split_layer)}')


None


TypeError: cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not NoneType

In [65]:
import math

class DefenseLayer(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, input_dim, output_dim, U):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(input_dim, output_dim), requires_grad=True)
        self.U = U

    def reset_parameters(self) -> None:
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        torch.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, input):
        #return input @ self.U.weight.T.detach() @ self.weight
        output = torch.matmul(torch.matmul(input, self.U.weight.T), self.weight)
        assert output.requires_grad, "Output should have requires_grad=True"
        return output

In [66]:
U = torch.nn.utils.parametrizations.orthogonal(nn.Linear(120, 120, bias=False))

In [75]:
class ClientMLP(torch.nn.Module):
    def __init__(self, U=None):
        super(ClientMLP, self).__init__()
        #self.features = []
        self.classifier = []
        self.U = U
        self.layers = collections.OrderedDict()

        self.fc1 = torch.nn.Linear(28 * 28, 256, bias=False)
        self.classifier.append(self.fc1) #0
        self.layers["fc1"] = self.fc1

        self.fc1act = torch.nn.ReLU(inplace=False)
        self.classifier.append(self.fc1act) #1
        self.layers["fc1act"] = self.fc1act

        self.fc2 = torch.nn.Linear(256, 120, bias=False)
        self.classifier.append(self.fc2) #2
        self.layers["fc2"] = self.fc2

        self.fc2act = torch.nn.ReLU(inplace=False)
        self.classifier.append(self.fc2act) #3
        self.layers["fc2act"] = self.fc2act

        if self.U:
            self.defense = DefenseLayer(120, 120, U)
            self.classifier.append(self.defense)
            self.layers["defense"] = self.defense

        self.initial_params = [
            param.clone().detach().data for param in self.parameters()
        ]

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.fc1act(self.fc1(x))
        x = self.fc2act(self.fc2(x))
        if self.U:
            x = self.defense(x)
        return x
                
    def get_params(self):
        params = []
        for layer in list(self.layers.values()):
            params += list(layer.parameters())
        return params

    def restore_initial_params(self):
        for param, initial in zip(self.parameters(), self.initial_params):
            param.data = initial.requires_grad_(True)


class ServerMLP(torch.nn.Module):
    def __init__(self):
        super(ServerMLP, self).__init__()
        #self.features = []
        self.classifier = []
        self.layers = collections.OrderedDict()

        self.fc3 = torch.nn.Linear(120, 10, bias=False)
        self.classifier.append(self.fc3) #4
        self.layers["fc3"] = self.fc3

        self.initial_params = [
            param.clone().detach().data for param in self.parameters()
        ]

    def forward(self, x):
        x = self.fc3(x)
        return x
                
    def get_params(self):
        params = []
        for layer in list(self.layers.values()):
            params += list(layer.parameters())
        return params

    def restore_initial_params(self):
        for param, initial in zip(self.parameters(), self.initial_params):
            param.data = initial.requires_grad_(True)

In [79]:
client, server = ClientMLP(U=U), ServerMLP()

In [71]:
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params:,} || all params: {all_param:,} || trainable%: {100 * trainable_params / all_param:.2f}"
    )

print_trainable_parameters(client)
print_trainable_parameters(ClientMLP())

trainable params: 260,224 || all params: 260,224 || trainable%: 100.00
trainable params: 231,424 || all params: 231,424 || trainable%: 100.00


In [63]:
def get_test_score(m1, m2, dataset, split=0):
    score = 0
    imageloader = get_random_example(dataset, count=2000)
    for image, label in imageloader:
        pred = m2(m1(image))
        if torch.argmax(pred) == label.detach():
            score += 1
    return 100 * score / len(imageloader)

In [81]:
with torch.no_grad():
    for images, labels in trainloader:
            client_opt.zero_grad()
            server_opt.zero_grad()
            
            pred = server(client(images))

            loss = criterion(pred, labels)
            print(loss)
            break

AssertionError: Output should have requires_grad=True

In [80]:
client_opt = torch.optim.Adam(client.parameters(), lr=0.001, amsgrad=True)
server_opt = torch.optim.Adam(server.parameters(), lr=0.001, amsgrad=True)
criterion = torch.nn.CrossEntropyLoss()

epochs = 10
for epoch in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        client_opt.zero_grad()
        server_opt.zero_grad()
        
        pred = server(client(images))

        loss = criterion(pred, labels)
        loss.backward()
        running_loss += loss

        server_opt.step()
        client_opt.step()
    else:
        print(f'Epoch: {epoch} Loss: {running_loss / len(trainloader)} Acc: {get_test_score(client, server, testset, split=0)}')


Epoch: 0 Loss: nan Acc: 10.219780219780219


KeyboardInterrupt: 