In [None]:
# we do not need to generator 32 channels together. One way is to generate them 

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, inputsize, hiddensize, outputsize):
        super(Generator, self).__init__()
        self.inputsize = inputsize # 8, 16, 24
        self.outputsize = outputsize
        self.hiddensize = hiddensize
        self.section1 = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(self.inputsize, hiddensize * 8, 3, 1, padding=1, bias=False, dilation=1),
            nn.BatchNorm2d(hiddensize * 8),
            nn.ReLU(True)
        )
        self.section2 = nn.Sequential(
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(hiddensize * 8, hiddensize * 4, 3, 1, padding=1 , bias=False, dilation=1),
            nn.BatchNorm2d(hiddensize * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(hiddensize * 4, hiddensize * 2, 3, 1, padding=1, bias=False, dilation=1),
            nn.BatchNorm2d(hiddensize * 2),
            nn.ReLU(True),
            # # state size. (ngf*2) x 16 x 16
            # nn.ConvTranspose2d(hiddensize * 2, hiddensize, 4, 1, 0, bias=False),
            # nn.BatchNorm2d(hiddensize),
            # nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(hiddensize * 2, self.outputsize, 3, 1, padding=1, bias=False, dilation=1),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        output = self.section1(input)
        output = self.section2(output)
        return output

In [5]:
# use the client and server for mobilenetv2
import torch
from Models.mobilenetv2 import mobilenetv2_splitter
client, server = mobilenetv2_splitter(num_classes = 100, 
                                      weight_root = '/home/tonypeng/Workspace1/adaptfilter/Adaptfilter/Weights/cifar-10', 
                                      device = 'cuda:0', 
                                      partition = -1)
# get the dataloader 
from Dataloaders import dataloader_cifar10
train, test, classes = dataloader_cifar10.Dataloader_cifar10()

client = client.to('cuda:0')
server = server.to('cuda:0')
client.eval()
server.eval()

with torch.no_grad():
    correct = 0
    for i, data in enumerate(test, 0):
        inputs, labels = data
        inputs, labels = inputs.to('cuda:0'), labels.to('cuda:0')
        outputs = client(inputs)
        outputs = server(outputs)
        outputs = torch.argmax(outputs, dim = 1)
        correct_rate = (outputs == labels).sum().item()
        correct += correct_rate
        print('Accuracy ', correct_rate/100)
    print('Total Accuracy ', correct/50000)
    


Files already downloaded and verified
Files already downloaded and verified
Accuracy  1.13
Accuracy  1.15
Accuracy  1.16
Accuracy  1.09
Accuracy  1.18
Accuracy  1.19
Accuracy  1.11
Accuracy  1.15
Accuracy  1.12
Accuracy  1.17
Accuracy  1.17
Accuracy  1.13
Accuracy  1.14
Accuracy  1.1
Accuracy  1.11
Accuracy  1.04
Accuracy  1.16
Accuracy  1.17
Accuracy  1.15
Accuracy  1.12
Accuracy  1.15
Accuracy  1.15
Accuracy  1.17
Accuracy  1.12
Accuracy  1.19
Accuracy  1.17
Accuracy  1.15
Accuracy  1.16
Accuracy  1.12
Accuracy  1.09
Accuracy  1.17
Accuracy  1.19
Accuracy  1.2
Accuracy  1.12
Accuracy  1.16
Accuracy  1.15
Accuracy  1.2
Accuracy  1.17
Accuracy  1.16
Accuracy  1.22
Accuracy  1.1
Accuracy  1.16
Accuracy  1.19
Accuracy  1.1
Accuracy  1.2
Accuracy  1.11
Accuracy  1.1
Accuracy  1.19
Accuracy  1.17
Accuracy  1.16
Accuracy  1.03
Accuracy  1.15
Accuracy  1.15
Accuracy  1.16
Accuracy  1.12
Accuracy  1.16
Accuracy  1.01
Accuracy  1.1
Accuracy  1.14
Accuracy  1.12
Accuracy  1.15
Accuracy  1.12
Ac