In [3]:
import torch

embedding = torch.rand(10, 56,100,100)

In [14]:
# let's get the client and server model from the embedding
from Models.mobilenetv2 import MobileNetV2
from Dataloaders.dataloader_cifar10 import Dataloader_cifar10
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm import tqdm
import os
import sys
import cv2

model = MobileNetV2(num_classes=10)
client_model = model.get_client_model()
server_model = model.get_server_model()

# split the server model into 2 parts
server_model_1 = server_model[:len(server_model)-1]
server_model_2 = server_model[-1]

# client model and server 1 are fixed, server 2 is trainable
client_model = client_model.train()
server_model_1 = server_model_1.train()
server_model_2 = server_model_2.train()

loss = torch.nn.CrossEntropyLoss()
# https://github.com/d-li14/mobilenetv2.pytorch
# batch size 256; epoch 150; learning rate 0.05; LR decay strategy cosine; weight decay 0.00004
optimizer = torch.optim.Adam(server_model_2.parameters(), lr=0.05)

# put them in cuda
client_model = client_model.cuda()
server_model_1 = server_model_1.cuda()
server_model_2 = server_model_2.cuda()

# get the data
train, test, val, labels = Dataloader_cifar10(128, 2014)

# train
epoch = 20
for i in tqdm(range(epoch)):
    for data in train:
        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()
        optimizer.zero_grad()
        # with torch.no_grad():
        outputs = client_model(inputs)
        outputs = server_model_1(outputs)
        outputs = server_model_2(outputs)
        # get the index of the max value of the output?
        loss_train = loss(outputs, labels)
        loss_train.backward()
        optimizer.step()
    print('train loss: ',loss_train.item())
    
    # val
    if i % 5 == 0:
        for data in val:
            inputs, labels = data
            inputs = inputs.cuda()
            labels = labels.cuda()
            with torch.no_grad():
                outputs = client_model(inputs)
                outputs = server_model_1(outputs)
                outputs = server_model_2(outputs)
            loss_val = loss(outputs, labels)
        print('val loss: ', loss_val.item())

# test
correct = 0
total = 0
with torch.no_grad():
    for data in test:
        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()
        outputs = client_model(inputs)
        outputs = server_model_1(outputs)
        outputs = server_model_2(outputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % ( 100 * correct / total)) 


 

Files already downloaded and verified
Files already downloaded and verified


  0%|          | 0/100 [00:00<?, ?it/s]

train loss:  0.6799212098121643


  1%|          | 1/100 [00:13<21:46, 13.20s/it]

val loss:  1.2833060026168823


  2%|▏         | 2/100 [00:23<18:45, 11.48s/it]

train loss:  0.673984169960022


  3%|▎         | 3/100 [00:33<17:42, 10.96s/it]

train loss:  1.7235395908355713


  4%|▍         | 4/100 [00:44<17:05, 10.69s/it]

train loss:  1.005233645439148


  5%|▌         | 5/100 [00:54<16:41, 10.54s/it]

train loss:  1.3770822286605835
train loss:  1.5769028663635254


  6%|▌         | 6/100 [01:07<17:52, 11.41s/it]

val loss:  1.8506985902786255


  7%|▋         | 7/100 [01:17<17:08, 11.05s/it]

train loss:  1.7342307567596436


  8%|▊         | 8/100 [01:28<16:36, 10.83s/it]

train loss:  1.3301993608474731


  9%|▉         | 9/100 [01:38<16:10, 10.67s/it]

train loss:  1.7982712984085083


 10%|█         | 10/100 [01:48<15:51, 10.57s/it]

train loss:  1.2410037517547607
train loss:  2.573725700378418


 11%|█         | 11/100 [02:01<16:44, 11.29s/it]

val loss:  0.5315096974372864


 12%|█▏        | 12/100 [02:12<16:07, 11.00s/it]

train loss:  1.8356878757476807


 13%|█▎        | 13/100 [02:22<15:38, 10.79s/it]

train loss:  1.3045153617858887


 14%|█▍        | 14/100 [02:32<15:15, 10.64s/it]

train loss:  1.2198320627212524


 15%|█▌        | 15/100 [02:43<14:56, 10.55s/it]

train loss:  0.5310781002044678
train loss:  0.7152358293533325


 16%|█▌        | 16/100 [02:56<15:51, 11.32s/it]

val loss:  2.3010005950927734


 17%|█▋        | 17/100 [03:06<15:15, 11.03s/it]

train loss:  2.519162654876709


 18%|█▊        | 18/100 [03:16<14:46, 10.82s/it]

train loss:  2.133892059326172


 19%|█▉        | 19/100 [03:27<14:23, 10.67s/it]

train loss:  0.8783572316169739


 20%|██        | 20/100 [03:37<14:03, 10.55s/it]

train loss:  1.0628490447998047
train loss:  0.5412212610244751


 21%|██        | 21/100 [03:50<14:53, 11.31s/it]

val loss:  2.2737250328063965


 22%|██▏       | 22/100 [04:00<14:17, 11.00s/it]

train loss:  1.918734073638916


 23%|██▎       | 23/100 [04:10<13:49, 10.78s/it]

train loss:  1.17880117893219


 24%|██▍       | 24/100 [04:21<13:26, 10.62s/it]

train loss:  0.874615490436554


 25%|██▌       | 25/100 [04:31<13:09, 10.52s/it]

train loss:  1.3317281007766724
train loss:  1.4357728958129883


 26%|██▌       | 26/100 [04:44<13:54, 11.28s/it]

val loss:  1.5737113952636719


 27%|██▋       | 27/100 [04:54<13:20, 10.97s/it]

train loss:  1.4533605575561523


 28%|██▊       | 28/100 [05:05<12:54, 10.75s/it]

train loss:  1.5515905618667603


 29%|██▉       | 29/100 [05:15<12:32, 10.59s/it]

train loss:  2.013132095336914


 30%|███       | 30/100 [05:25<12:14, 10.49s/it]

train loss:  1.668703556060791
train loss:  1.8721858263015747


 31%|███       | 31/100 [05:38<12:56, 11.25s/it]

val loss:  2.051382303237915


 32%|███▏      | 32/100 [05:48<12:22, 10.93s/it]

train loss:  1.4211286306381226


 33%|███▎      | 33/100 [05:59<11:59, 10.73s/it]

train loss:  0.9488620758056641


 34%|███▍      | 34/100 [06:09<11:39, 10.60s/it]

train loss:  0.7363617420196533


 35%|███▌      | 35/100 [06:19<11:21, 10.49s/it]

train loss:  0.6525201797485352
train loss:  3.2339415550231934


 36%|███▌      | 36/100 [06:32<12:00, 11.26s/it]

val loss:  2.1911637783050537


 37%|███▋      | 37/100 [06:42<11:31, 10.97s/it]

train loss:  0.810379683971405


 38%|███▊      | 38/100 [06:53<11:05, 10.73s/it]

train loss:  0.7357950806617737


 39%|███▉      | 39/100 [07:02<10:36, 10.44s/it]

train loss:  0.7441762089729309


 40%|████      | 40/100 [07:12<10:14, 10.23s/it]

train loss:  1.502135992050171
train loss:  2.4388601779937744


 41%|████      | 41/100 [07:25<10:42, 10.89s/it]

val loss:  3.116422176361084


 42%|████▏     | 42/100 [07:34<10:12, 10.56s/it]

train loss:  0.8879203200340271


 43%|████▎     | 43/100 [07:44<09:48, 10.33s/it]

train loss:  3.0048987865448


 44%|████▍     | 44/100 [07:54<09:30, 10.18s/it]

train loss:  0.790829062461853


 45%|████▌     | 45/100 [08:04<09:14, 10.08s/it]

train loss:  1.6766748428344727
train loss:  2.9241979122161865


 46%|████▌     | 46/100 [08:16<09:43, 10.81s/it]

val loss:  3.343167543411255


 47%|████▋     | 47/100 [08:26<09:17, 10.52s/it]

train loss:  0.9278885722160339


 48%|████▊     | 48/100 [08:36<08:56, 10.32s/it]

train loss:  2.8830673694610596


 49%|████▉     | 49/100 [08:46<08:38, 10.17s/it]

train loss:  1.764209508895874


 50%|█████     | 50/100 [08:56<08:23, 10.07s/it]

train loss:  1.4193823337554932
train loss:  1.27657151222229


 51%|█████     | 51/100 [09:08<08:48, 10.79s/it]

val loss:  3.723851442337036


 52%|█████▏    | 52/100 [09:18<08:24, 10.50s/it]

train loss:  0.3444065451622009


 53%|█████▎    | 53/100 [09:28<08:04, 10.30s/it]

train loss:  2.6777687072753906


 54%|█████▍    | 54/100 [09:38<07:47, 10.17s/it]

train loss:  2.413179636001587


 55%|█████▌    | 55/100 [09:47<07:33, 10.07s/it]

train loss:  1.8058786392211914
train loss:  1.6517902612686157


 56%|█████▌    | 56/100 [10:00<07:55, 10.81s/it]

val loss:  1.737433671951294


 57%|█████▋    | 57/100 [10:10<07:32, 10.53s/it]

train loss:  1.4296820163726807


 58%|█████▊    | 58/100 [10:20<07:13, 10.33s/it]

train loss:  0.8381726741790771


 59%|█████▉    | 59/100 [10:30<06:57, 10.19s/it]

train loss:  1.052212119102478


 60%|██████    | 60/100 [10:39<06:43, 10.09s/it]

train loss:  2.3655693531036377
train loss:  2.086556911468506


 61%|██████    | 61/100 [10:52<07:01, 10.82s/it]

val loss:  5.907161712646484


 62%|██████▏   | 62/100 [11:02<06:39, 10.52s/it]

train loss:  0.5521330833435059


 63%|██████▎   | 63/100 [11:12<06:21, 10.31s/it]

train loss:  1.9706764221191406


 64%|██████▍   | 64/100 [11:21<06:05, 10.16s/it]

train loss:  0.26285940408706665


 65%|██████▌   | 65/100 [11:31<05:52, 10.07s/it]

train loss:  2.272261381149292
train loss:  1.8154892921447754


 66%|██████▌   | 66/100 [11:44<06:07, 10.80s/it]

val loss:  4.818183898925781


 67%|██████▋   | 67/100 [11:54<05:46, 10.51s/it]

train loss:  0.9488795399665833


 68%|██████▊   | 68/100 [12:04<05:30, 10.32s/it]

train loss:  2.269735336303711


 69%|██████▉   | 69/100 [12:13<05:14, 10.16s/it]

train loss:  1.4881142377853394


 70%|███████   | 70/100 [12:23<05:02, 10.07s/it]

train loss:  1.212486982345581
train loss:  1.7686767578125


 71%|███████   | 71/100 [12:36<05:13, 10.80s/it]

val loss:  2.2947144508361816


 72%|███████▏  | 72/100 [12:45<04:54, 10.51s/it]

train loss:  2.8540425300598145


 73%|███████▎  | 73/100 [12:55<04:37, 10.29s/it]

train loss:  2.2752208709716797


 74%|███████▍  | 74/100 [13:05<04:23, 10.15s/it]

train loss:  1.3645622730255127


 75%|███████▌  | 75/100 [13:15<04:11, 10.05s/it]

train loss:  1.1755837202072144
train loss:  1.0884803533554077


 76%|███████▌  | 76/100 [13:27<04:19, 10.80s/it]

val loss:  3.657409429550171


 77%|███████▋  | 77/100 [13:37<04:02, 10.53s/it]

train loss:  0.6837446689605713


 78%|███████▊  | 78/100 [13:47<03:47, 10.34s/it]

train loss:  0.7387427687644958


 79%|███████▉  | 79/100 [13:57<03:33, 10.19s/it]

train loss:  1.8861607313156128


 80%|████████  | 80/100 [14:07<03:22, 10.10s/it]

train loss:  2.4389843940734863
train loss:  0.9862842559814453


 81%|████████  | 81/100 [14:20<03:25, 10.83s/it]

val loss:  3.7880029678344727


 82%|████████▏ | 82/100 [14:29<03:09, 10.54s/it]

train loss:  0.9762613773345947


 83%|████████▎ | 83/100 [14:39<02:55, 10.33s/it]

train loss:  1.0013933181762695


 84%|████████▍ | 84/100 [14:49<02:43, 10.19s/it]

train loss:  1.3964107036590576


 85%|████████▌ | 85/100 [14:59<02:31, 10.08s/it]

train loss:  0.6514539122581482
train loss:  1.2012014389038086


 86%|████████▌ | 86/100 [15:11<02:31, 10.80s/it]

val loss:  1.4326810836791992


 87%|████████▋ | 87/100 [15:21<02:16, 10.52s/it]

train loss:  1.609519124031067


 88%|████████▊ | 88/100 [15:31<02:03, 10.32s/it]

train loss:  1.2855918407440186


 89%|████████▉ | 89/100 [15:41<01:52, 10.18s/it]

train loss:  0.8573668003082275


 90%|█████████ | 90/100 [15:51<01:40, 10.09s/it]

train loss:  0.8195911645889282
train loss:  2.1143572330474854


 91%|█████████ | 91/100 [16:03<01:37, 10.83s/it]

val loss:  4.382082939147949


 92%|█████████▏| 92/100 [16:13<01:24, 10.53s/it]

train loss:  2.0743515491485596


 93%|█████████▎| 93/100 [16:23<01:12, 10.32s/it]

train loss:  1.1139471530914307


 94%|█████████▍| 94/100 [16:33<01:01, 10.18s/it]

train loss:  1.334349274635315


 95%|█████████▌| 95/100 [16:43<00:50, 10.08s/it]

train loss:  1.6716196537017822
train loss:  0.9433549642562866


 96%|█████████▌| 96/100 [16:55<00:43, 10.82s/it]

val loss:  2.364748477935791


 97%|█████████▋| 97/100 [17:05<00:31, 10.53s/it]

train loss:  1.227202296257019


 98%|█████████▊| 98/100 [17:15<00:20, 10.32s/it]

train loss:  1.9942082166671753


 99%|█████████▉| 99/100 [17:25<00:10, 10.19s/it]

train loss:  1.7990056276321411


100%|██████████| 100/100 [17:35<00:00, 10.55s/it]

train loss:  2.167130470275879





Accuracy of the network on the 10000 test images: 79 %
