In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from taskdataset import TaskDataset
import matplotlib.pyplot as plt
import os
import time

from t2_functions import partition_ids

In [2]:
import sys

# Remember the notebook's starting directory; a later cell chdir's back to it.
path = os.getcwd()
print(path)
# Temporarily move one level up so the relative paths below resolve.
os.chdir('..')
print(os.getcwd())
# NOTE(review): dataset_t1 is not referenced again in the visible notebook.
# NOTE(review): torch.load unpickles the file - only load trusted files.
dataset_t1 = torch.load("task_1_modelstealing/data/ModelStealingPub.pt")

# Make the task-2 folder importable so `endpoints` can be found below.
sys.path.append(os.path.join(os.getcwd(), "task_2_sybilattack/"))

# Project-local client helpers for the remote sybil endpoints.
from endpoints.requests import sybil, sybil_reset


/home/janek/Documents/Hackaton/ensembleAI-ScoutTeam/task_2_sybilattack
/home/janek/Documents/Hackaton/ensembleAI-ScoutTeam


In [3]:
# Restore the working directory saved in `path` before the earlier chdir('..').
os.chdir(path)
print(os.getcwd())
# Load the task-2 dataset relative to the restored working directory.
# NOTE(review): torch.load unpickles the file - only load trusted files.
dataset = torch.load("data/SybilAttack.pt")

/home/janek/Documents/Hackaton/ensembleAI-ScoutTeam/task_2_sybilattack


In [4]:
# Collect the sample ids of the loaded dataset as a numpy array.
ids = np.array(dataset.ids)
print(len(ids))
print(ids[:10])
# partition_ids is a project helper; judging by how binned_ids[0] is unpacked
# later, each entry is a (train_ids, test_ids) pair - TODO confirm.
binned_ids = partition_ids(ids, main_bin_num=10, train=0.1, test=0.9)

20000
[101031   8526  43127 191394 298792 121086 149475 102605 163605 101855]
bin_size: 2000
train: 200, test: 1800
train: 200, test: 1800
train: 200, test: 1800
train: 200, test: 1800
train: 200, test: 1800
train: 200, test: 1800
train: 200, test: 1800
train: 200, test: 1800
train: 200, test: 1800
train: 200, test: 1800


In [13]:
# Which variant of the task the endpoint calls below are made for.
task = 'binary'

# Reset both remote endpoints ('home' and 'defense') for the selected task
# before querying them. These are remote calls with server-side effects.
sybil_reset(home_or_defense='home', binary_or_affine=task)
sybil_reset(home_or_defense='defense', binary_or_affine=task)


Request ok
{'msg': 'Successful sybil binary reset home'}
Request ok
{'msg': 'Successful sybil binary reset defense'}


In [14]:
# Use only the first id partition for this experiment.
ids_train, ids_test = binned_ids[0]
print(len(ids_train))
print(len(ids_test))
# Query both endpoints with the SAME ids so that A_* and B_* are paired
# element-wise (A = 'home' output, B = 'defense' output).
# NOTE(review): assumes sybil() returns one representation per requested id,
# in request order - confirm against the endpoint client.
A_train_reps = sybil(ids=ids_train,
                 home_or_defense='home',
                 binary_or_affine=task)

B_train_reps = sybil(ids=ids_train,
                 home_or_defense='defense',
                 binary_or_affine=task)
print(f"A train reps: {len(A_train_reps)}")

# Same pairing for the held-out ids.
A_test_reps = sybil(ids=ids_test,
                 home_or_defense='home',
                 binary_or_affine=task)

B_test_reps = sybil(ids=ids_test,
                 home_or_defense='defense',
                 binary_or_affine=task)
print(f"A test reps: {len(A_test_reps)}")


200
1800
A train reps: 200
A test reps: 1800


In [15]:
class RepresentationsDataset(Dataset):
    """Paired dataset mapping input representations ``x`` to targets ``y``.

    Args:
        x: Indexable collection of input vectors (lists, numpy arrays or tensors).
        y: Indexable collection of target vectors, same length as ``x``.

    Raises:
        ValueError: If ``x`` and ``y`` do not contain the same number of items.

    Each item is returned as a pair of freshly allocated ``float32`` tensors,
    regardless of the input dtype, so that samples can be fed directly to
    default (float32) ``nn.Module`` layers.
    """

    def __init__(self, x, y):
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    @staticmethod
    def _to_float_tensor(value):
        """Return an independent float32 tensor copy of ``value``."""
        if isinstance(value, torch.Tensor):
            # torch.tensor(tensor) emits a UserWarning; detach+clone is the
            # recommended way to copy an existing tensor.
            return value.detach().clone().to(torch.float32)
        # Explicit dtype: float64 numpy input would otherwise stay float64
        # and clash with float32 model weights.
        return torch.tensor(value, dtype=torch.float32)

    def __getitem__(self, idx):
        return self._to_float_tensor(self.x[idx]), self._to_float_tensor(self.y[idx])


In [16]:
class MLP(nn.Module):
    """Two-layer perceptron: ``fc1`` -> ReLU -> ``fc2``.

    Args:
        input_size: Number of input features.
        hidden_size: Width of the hidden layer.
        output_size: Number of output features.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Apply the hidden layer with ReLU, then the output layer."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden)

In [17]:
class Linear(nn.Module):
    """Single affine layer (``fc1``) with no non-linearity.

    Args:
        input_size: Number of input features.
        output_size: Number of output features.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Return the affine projection of ``x``."""
        return self.fc1(x)

In [18]:
def validate(criterion, loader, net):
    """Compute ``criterion`` over every sample yielded by ``loader``.

    Predictions and targets from all batches are concatenated along dim 0 and
    the criterion is evaluated once on the full set (not averaged per batch).

    Args:
        criterion: Callable ``criterion(pred, true)`` returning a loss tensor.
        loader: Iterable of ``(x, y)`` tensor batches.
        net: Model to evaluate. It is switched to eval mode and left there.

    Returns:
        The loss tensor returned by ``criterion`` (no gradient attached).

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    net.eval()

    predictions = []
    targets = []
    with torch.no_grad():
        for x, y in loader:
            predictions.append(net(x))
            targets.append(y)

        if not predictions:
            raise ValueError("validate() received an empty loader")

        # Stay in torch: a numpy round-trip would fail for tensors that do
        # not live on the CPU and costs an extra copy.
        loss = criterion(torch.cat(predictions, dim=0), torch.cat(targets, dim=0))

    return loss


In [19]:
def l1_reg(model, reg_lambda):
    """Return the L1 penalty ``reg_lambda * sum(|p|)`` over all model parameters.

    Every tensor yielded by ``model.parameters()`` contributes, so biases are
    penalised together with weights.

    Args:
        model: Module whose parameters are regularised.
        reg_lambda: Scalar multiplier applied to the summed absolute values.

    Returns:
        A 0-dim tensor that is differentiable with respect to the parameters.
        For a model without parameters a constant zero tensor is returned.
    """
    params = list(model.parameters())
    if not params:
        # Nothing to penalise; avoids StopIteration when probing for a device.
        return reg_lambda * torch.tensor(0.)
    # abs().sum() is the L1 norm of the flattened parameter tensor.
    penalty = sum(param.abs().sum() for param in params)
    return reg_lambda * penalty


In [20]:
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

def train(epochs, optim, criterion, regularise, trainloader, valloader, net, empty_net, reg_lambda=None):
    """Train ``net`` and keep a copy of the weights with the lowest validation loss.

    Args:
        epochs: Number of passes over ``trainloader``.
        optim: Optimizer already bound to ``net``'s parameters.
        criterion: Loss callable ``criterion(pred, true)``.
        regularise: Callable ``regularise(net, reg_lambda)`` returning a penalty
            tensor; only used when ``reg_lambda`` is not None.
        trainloader: Iterable of ``(x, y)`` training batches (must support ``len``).
        valloader: Iterable of ``(x, y)`` batches passed to ``validate`` each epoch.
        net: Model that is trained in place.
        empty_net: Model with the same architecture as ``net``; it receives the
            state dict of the best epoch.
        reg_lambda: Regularisation strength, or None to disable the penalty.

    Returns:
        Tuple ``(net, best_net)``: the model after the last epoch and
        ``empty_net`` loaded with the best-validation weights.

    Note:
        The reported training loss includes the regularisation penalty, the
        validation loss does not, so the two are not directly comparable when
        ``reg_lambda`` is set.
    """
    best_val_loss = np.inf
    best_net = empty_net
    writer = SummaryWriter()
    # Monotonic counter so every batch gets its own TensorBoard step; logging
    # all batches of an epoch at step `epoch` stacks them on one x value.
    global_step = 0

    try:
        for epoch in range(epochs):  # loop over the dataset multiple times
            train_loss = 0
            # validate() leaves the model in eval mode, so re-enable training.
            net.train()

            progress_bar = tqdm(trainloader)

            for batch_idx, (x, y) in enumerate(progress_bar):
                optim.zero_grad()

                y_hat = net(x)
                loss = criterion(y_hat, y)
                if reg_lambda is not None:
                    loss = loss + regularise(net, reg_lambda)
                loss.backward()
                optim.step()

                batch_loss = loss.item()
                train_loss += batch_loss
                if batch_idx % 20 == 0:
                    progress_bar.set_description(f"train | loss: {batch_loss:.4f}")

                writer.add_scalar('Training Loss', batch_loss, global_step)
                global_step += 1

            train_loss /= len(trainloader)
            print(f"Epoch [{epoch + 1}/{epochs}], Train Loss: {train_loss:.4f}")

            val_loss = validate(criterion, valloader, net)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # Copy the weights; `net` keeps changing in later epochs.
                best_net.load_state_dict(net.state_dict())

            # Validation is logged once per epoch.
            writer.add_scalar('Validation Loss', val_loss, epoch)

            print(f"Epoch [{epoch + 1}/{epochs}], Val Loss: {val_loss:.4f}")
    finally:
        # Flush pending events and release the log file even on failure.
        writer.close()

    return net, best_net



In [21]:
# Linear candidate plus a second, untrained instance that train() fills with
# the best-validation weights.
lin_net = Linear(384, 384)
lin_empty_net = Linear(384, 384)

# MLP candidate and its best-weights holder.
# NOTE(review): `empty_mpl` looks like a typo for `empty_mlp`; the name is
# used by the MLP training cell, so rename both together.
mlp = MLP(384, 384, 384)
empty_mpl = MLP(384, 384, 384)

# Shared training hyper-parameters.
# NOTE(review): 384 is presumably the representation width returned by the
# endpoints - confirm. No random seed is set, so runs are not reproducible.
criterion = nn.MSELoss()
batch_size = 16
epochs = 100
lr = 0.001


In [24]:
# Supervised pairs: 'home' representations are the inputs, 'defense'
# representations of the same ids are the regression targets.
train_dataset = RepresentationsDataset(x=A_train_reps, y=B_train_reps)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Held-out pairs; kept in a fixed order (no shuffling).
test_dataset = RepresentationsDataset(x=A_test_reps, y=B_test_reps)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [25]:
# Dedicated name for the optimizer: rebinding `optim` would shadow the
# `torch.optim` module imported under that alias at the top of the notebook.
lin_optimizer = torch.optim.Adam(lr=lr, params=lin_net.parameters())

# NOTE(review): test_loader is passed as the validation loader, so
# lin_best_net is selected on the test split.
lin_last_net, lin_best_net = train(epochs, lin_optimizer, criterion, l1_reg, train_loader, test_loader, lin_net, lin_empty_net, reg_lambda=0.0001)

train | loss: 1.1027: 100%|██████████| 13/13 [00:00<00:00, 86.47it/s]


Epoch [1/100], Train Loss: 0.6462
Epoch [1/100], Val Loss: 0.4872


train | loss: 0.5411: 100%|██████████| 13/13 [00:00<00:00, 83.60it/s]

Epoch [2/100], Train Loss: 0.4695





Epoch [2/100], Val Loss: 0.3832


train | loss: 0.4154: 100%|██████████| 13/13 [00:00<00:00, 143.37it/s]

Epoch [3/100], Train Loss: 0.3910





Epoch [3/100], Val Loss: 0.3639


train | loss: 0.3642: 100%|██████████| 13/13 [00:00<00:00, 77.03it/s]

Epoch [4/100], Train Loss: 0.3392





Epoch [4/100], Val Loss: 0.3510


train | loss: 0.3126: 100%|██████████| 13/13 [00:00<00:00, 117.28it/s]

Epoch [5/100], Train Loss: 0.3018





Epoch [5/100], Val Loss: 0.3435


train | loss: 0.2903: 100%|██████████| 13/13 [00:00<00:00, 56.89it/s]


Epoch [6/100], Train Loss: 0.2715
Epoch [6/100], Val Loss: 0.3370


train | loss: 0.2568: 100%|██████████| 13/13 [00:00<00:00, 60.57it/s]


Epoch [7/100], Train Loss: 0.2468
Epoch [7/100], Val Loss: 0.3349


train | loss: 0.2355: 100%|██████████| 13/13 [00:00<00:00, 124.40it/s]

Epoch [8/100], Train Loss: 0.2265





Epoch [8/100], Val Loss: 0.3298


train | loss: 0.2133: 100%|██████████| 13/13 [00:00<00:00, 126.10it/s]

Epoch [9/100], Train Loss: 0.2091





Epoch [9/100], Val Loss: 0.3262


train | loss: 0.2066: 100%|██████████| 13/13 [00:00<00:00, 63.66it/s]


Epoch [10/100], Train Loss: 0.1941
Epoch [10/100], Val Loss: 0.3241


train | loss: 0.1749: 100%|██████████| 13/13 [00:00<00:00, 125.38it/s]

Epoch [11/100], Train Loss: 0.1821





Epoch [11/100], Val Loss: 0.3227


train | loss: 0.1710: 100%|██████████| 13/13 [00:00<00:00, 132.44it/s]

Epoch [12/100], Train Loss: 0.1717





Epoch [12/100], Val Loss: 0.3222


train | loss: 0.1636: 100%|██████████| 13/13 [00:00<00:00, 71.54it/s]

Epoch [13/100], Train Loss: 0.1627





Epoch [13/100], Val Loss: 0.3193


train | loss: 0.1558: 100%|██████████| 13/13 [00:00<00:00, 125.61it/s]

Epoch [14/100], Train Loss: 0.1535





Epoch [14/100], Val Loss: 0.3218


train | loss: 0.1472: 100%|██████████| 13/13 [00:00<00:00, 129.53it/s]

Epoch [15/100], Train Loss: 0.1474





Epoch [15/100], Val Loss: 0.3195


train | loss: 0.1390: 100%|██████████| 13/13 [00:00<00:00, 78.98it/s]

Epoch [16/100], Train Loss: 0.1414





Epoch [16/100], Val Loss: 0.3222


train | loss: 0.1334: 100%|██████████| 13/13 [00:00<00:00, 118.34it/s]

Epoch [17/100], Train Loss: 0.1370





Epoch [17/100], Val Loss: 0.3183


train | loss: 0.1294: 100%|██████████| 13/13 [00:00<00:00, 126.69it/s]

Epoch [18/100], Train Loss: 0.1323





Epoch [18/100], Val Loss: 0.3185


train | loss: 0.1237: 100%|██████████| 13/13 [00:00<00:00, 69.90it/s]

Epoch [19/100], Train Loss: 0.1289





Epoch [19/100], Val Loss: 0.3150


train | loss: 0.1196: 100%|██████████| 13/13 [00:00<00:00, 62.24it/s]


Epoch [20/100], Train Loss: 0.1261
Epoch [20/100], Val Loss: 0.3159


train | loss: 0.1249: 100%|██████████| 13/13 [00:00<00:00, 104.46it/s]

Epoch [21/100], Train Loss: 0.1239





Epoch [21/100], Val Loss: 0.3183


train | loss: 0.1167: 100%|██████████| 13/13 [00:00<00:00, 97.51it/s]

Epoch [22/100], Train Loss: 0.1219





Epoch [22/100], Val Loss: 0.3168


train | loss: 0.1166: 100%|██████████| 13/13 [00:00<00:00, 67.93it/s]

Epoch [23/100], Train Loss: 0.1201





Epoch [23/100], Val Loss: 0.3174


train | loss: 0.1260: 100%|██████████| 13/13 [00:00<00:00, 93.64it/s] 

Epoch [24/100], Train Loss: 0.1180





Epoch [24/100], Val Loss: 0.3151


train | loss: 0.1133: 100%|██████████| 13/13 [00:00<00:00, 89.05it/s]

Epoch [25/100], Train Loss: 0.1164





Epoch [25/100], Val Loss: 0.3141


train | loss: 0.1144: 100%|██████████| 13/13 [00:00<00:00, 112.28it/s]

Epoch [26/100], Train Loss: 0.1159





Epoch [26/100], Val Loss: 0.3149


train | loss: 0.1120: 100%|██████████| 13/13 [00:00<00:00, 55.66it/s]


Epoch [27/100], Train Loss: 0.1152
Epoch [27/100], Val Loss: 0.3149


train | loss: 0.1156: 100%|██████████| 13/13 [00:00<00:00, 57.44it/s]


Epoch [28/100], Train Loss: 0.1145
Epoch [28/100], Val Loss: 0.3155


train | loss: 0.1138: 100%|██████████| 13/13 [00:00<00:00, 46.76it/s]


Epoch [29/100], Train Loss: 0.1138
Epoch [29/100], Val Loss: 0.3138


train | loss: 0.1201: 100%|██████████| 13/13 [00:00<00:00, 110.78it/s]

Epoch [30/100], Train Loss: 0.1129





Epoch [30/100], Val Loss: 0.3124


train | loss: 0.1129: 100%|██████████| 13/13 [00:00<00:00, 109.30it/s]

Epoch [31/100], Train Loss: 0.1127





Epoch [31/100], Val Loss: 0.3147


train | loss: 0.1068: 100%|██████████| 13/13 [00:00<00:00, 108.77it/s]

Epoch [32/100], Train Loss: 0.1121





Epoch [32/100], Val Loss: 0.3102


train | loss: 0.1203: 100%|██████████| 13/13 [00:00<00:00, 58.73it/s]


Epoch [33/100], Train Loss: 0.1118
Epoch [33/100], Val Loss: 0.3130


train | loss: 0.1031: 100%|██████████| 13/13 [00:00<00:00, 121.35it/s]

Epoch [34/100], Train Loss: 0.1112





Epoch [34/100], Val Loss: 0.3118


train | loss: 0.1068: 100%|██████████| 13/13 [00:00<00:00, 105.91it/s]

Epoch [35/100], Train Loss: 0.1112





Epoch [35/100], Val Loss: 0.3103


train | loss: 0.1046: 100%|██████████| 13/13 [00:00<00:00, 74.67it/s]

Epoch [36/100], Train Loss: 0.1111





Epoch [36/100], Val Loss: 0.3114


train | loss: 0.1131: 100%|██████████| 13/13 [00:00<00:00, 116.80it/s]

Epoch [37/100], Train Loss: 0.1104





Epoch [37/100], Val Loss: 0.3143


train | loss: 0.1038: 100%|██████████| 13/13 [00:00<00:00, 133.61it/s]

Epoch [38/100], Train Loss: 0.1101





Epoch [38/100], Val Loss: 0.3136


train | loss: 0.1072: 100%|██████████| 13/13 [00:00<00:00, 59.46it/s]


Epoch [39/100], Train Loss: 0.1102
Epoch [39/100], Val Loss: 0.3104


train | loss: 0.1110: 100%|██████████| 13/13 [00:00<00:00, 64.84it/s]


Epoch [40/100], Train Loss: 0.1102
Epoch [40/100], Val Loss: 0.3127


train | loss: 0.1139: 100%|██████████| 13/13 [00:00<00:00, 68.41it/s]

Epoch [41/100], Train Loss: 0.1090





Epoch [41/100], Val Loss: 0.3140


train | loss: 0.1077: 100%|██████████| 13/13 [00:00<00:00, 99.81it/s]

Epoch [42/100], Train Loss: 0.1088





Epoch [42/100], Val Loss: 0.3124


train | loss: 0.1067: 100%|██████████| 13/13 [00:00<00:00, 95.19it/s]

Epoch [43/100], Train Loss: 0.1088





Epoch [43/100], Val Loss: 0.3144


train | loss: 0.1055: 100%|██████████| 13/13 [00:00<00:00, 69.97it/s]

Epoch [44/100], Train Loss: 0.1088





Epoch [44/100], Val Loss: 0.3138


train | loss: 0.1057: 100%|██████████| 13/13 [00:00<00:00, 117.11it/s]

Epoch [45/100], Train Loss: 0.1089





Epoch [45/100], Val Loss: 0.3154


train | loss: 0.1056: 100%|██████████| 13/13 [00:00<00:00, 119.44it/s]

Epoch [46/100], Train Loss: 0.1090





Epoch [46/100], Val Loss: 0.3134


train | loss: 0.1071: 100%|██████████| 13/13 [00:00<00:00, 93.81it/s]

Epoch [47/100], Train Loss: 0.1092





Epoch [47/100], Val Loss: 0.3145


train | loss: 0.1047: 100%|██████████| 13/13 [00:00<00:00, 72.70it/s]

Epoch [48/100], Train Loss: 0.1082





Epoch [48/100], Val Loss: 0.3157


train | loss: 0.1084: 100%|██████████| 13/13 [00:00<00:00, 124.06it/s]

Epoch [49/100], Train Loss: 0.1081





Epoch [49/100], Val Loss: 0.3144


train | loss: 0.1040: 100%|██████████| 13/13 [00:00<00:00, 89.34it/s]

Epoch [50/100], Train Loss: 0.1082





Epoch [50/100], Val Loss: 0.3159


train | loss: 0.1087: 100%|██████████| 13/13 [00:00<00:00, 79.27it/s]

Epoch [51/100], Train Loss: 0.1080





Epoch [51/100], Val Loss: 0.3158


train | loss: 0.1079: 100%|██████████| 13/13 [00:00<00:00, 61.22it/s]


Epoch [52/100], Train Loss: 0.1083
Epoch [52/100], Val Loss: 0.3170


train | loss: 0.1049: 100%|██████████| 13/13 [00:00<00:00, 117.16it/s]

Epoch [53/100], Train Loss: 0.1073





Epoch [53/100], Val Loss: 0.3162


train | loss: 0.1042: 100%|██████████| 13/13 [00:00<00:00, 114.50it/s]

Epoch [54/100], Train Loss: 0.1079





Epoch [54/100], Val Loss: 0.3142


train | loss: 0.1101: 100%|██████████| 13/13 [00:00<00:00, 105.64it/s]

Epoch [55/100], Train Loss: 0.1073





Epoch [55/100], Val Loss: 0.3152


train | loss: 0.1132: 100%|██████████| 13/13 [00:00<00:00, 51.26it/s]


Epoch [56/100], Train Loss: 0.1076
Epoch [56/100], Val Loss: 0.3154


train | loss: 0.1098: 100%|██████████| 13/13 [00:00<00:00, 64.51it/s]


Epoch [57/100], Train Loss: 0.1080
Epoch [57/100], Val Loss: 0.3159


train | loss: 0.1093: 100%|██████████| 13/13 [00:00<00:00, 99.61it/s]

Epoch [58/100], Train Loss: 0.1074





Epoch [58/100], Val Loss: 0.3167


train | loss: 0.1036: 100%|██████████| 13/13 [00:00<00:00, 102.88it/s]

Epoch [59/100], Train Loss: 0.1075





Epoch [59/100], Val Loss: 0.3142


train | loss: 0.1075: 100%|██████████| 13/13 [00:00<00:00, 92.59it/s]

Epoch [60/100], Train Loss: 0.1071





Epoch [60/100], Val Loss: 0.3152


train | loss: 0.1106: 100%|██████████| 13/13 [00:00<00:00, 67.47it/s]

Epoch [61/100], Train Loss: 0.1070





Epoch [61/100], Val Loss: 0.3151


train | loss: 0.1064: 100%|██████████| 13/13 [00:00<00:00, 87.86it/s]

Epoch [62/100], Train Loss: 0.1069





Epoch [62/100], Val Loss: 0.3151


train | loss: 0.1076: 100%|██████████| 13/13 [00:00<00:00, 116.59it/s]

Epoch [63/100], Train Loss: 0.1071





Epoch [63/100], Val Loss: 0.3162


train | loss: 0.1106: 100%|██████████| 13/13 [00:00<00:00, 94.24it/s]

Epoch [64/100], Train Loss: 0.1067





Epoch [64/100], Val Loss: 0.3125


train | loss: 0.1120: 100%|██████████| 13/13 [00:00<00:00, 77.24it/s]

Epoch [65/100], Train Loss: 0.1069





Epoch [65/100], Val Loss: 0.3146


train | loss: 0.1041: 100%|██████████| 13/13 [00:00<00:00, 114.74it/s]

Epoch [66/100], Train Loss: 0.1070





Epoch [66/100], Val Loss: 0.3183


train | loss: 0.1106: 100%|██████████| 13/13 [00:00<00:00, 102.73it/s]

Epoch [67/100], Train Loss: 0.1066





Epoch [67/100], Val Loss: 0.3180


train | loss: 0.1076: 100%|██████████| 13/13 [00:00<00:00, 50.50it/s]


Epoch [68/100], Train Loss: 0.1066
Epoch [68/100], Val Loss: 0.3180


train | loss: 0.1085: 100%|██████████| 13/13 [00:00<00:00, 110.75it/s]

Epoch [69/100], Train Loss: 0.1066





Epoch [69/100], Val Loss: 0.3158


train | loss: 0.1027: 100%|██████████| 13/13 [00:00<00:00, 57.92it/s]


Epoch [70/100], Train Loss: 0.1066
Epoch [70/100], Val Loss: 0.3159


train | loss: 0.1038: 100%|██████████| 13/13 [00:00<00:00, 113.12it/s]

Epoch [71/100], Train Loss: 0.1067





Epoch [71/100], Val Loss: 0.3168


train | loss: 0.1045: 100%|██████████| 13/13 [00:00<00:00, 98.87it/s]

Epoch [72/100], Train Loss: 0.1061





Epoch [72/100], Val Loss: 0.3211


train | loss: 0.0972: 100%|██████████| 13/13 [00:00<00:00, 73.92it/s]

Epoch [73/100], Train Loss: 0.1066





Epoch [73/100], Val Loss: 0.3208


train | loss: 0.0997: 100%|██████████| 13/13 [00:00<00:00, 112.71it/s]

Epoch [74/100], Train Loss: 0.1066





Epoch [74/100], Val Loss: 0.3219


train | loss: 0.1064: 100%|██████████| 13/13 [00:00<00:00, 118.17it/s]

Epoch [75/100], Train Loss: 0.1062





Epoch [75/100], Val Loss: 0.3189


train | loss: 0.1106: 100%|██████████| 13/13 [00:00<00:00, 52.08it/s]


Epoch [76/100], Train Loss: 0.1060
Epoch [76/100], Val Loss: 0.3189


train | loss: 0.1008: 100%|██████████| 13/13 [00:00<00:00, 53.14it/s]


Epoch [77/100], Train Loss: 0.1055
Epoch [77/100], Val Loss: 0.3186


train | loss: 0.1035: 100%|██████████| 13/13 [00:00<00:00, 43.71it/s]


Epoch [78/100], Train Loss: 0.1062
Epoch [78/100], Val Loss: 0.3206


train | loss: 0.1039: 100%|██████████| 13/13 [00:00<00:00, 104.38it/s]

Epoch [79/100], Train Loss: 0.1061





Epoch [79/100], Val Loss: 0.3181


train | loss: 0.1051: 100%|██████████| 13/13 [00:00<00:00, 119.59it/s]

Epoch [80/100], Train Loss: 0.1059





Epoch [80/100], Val Loss: 0.3217


train | loss: 0.0943: 100%|██████████| 13/13 [00:00<00:00, 99.12it/s]

Epoch [81/100], Train Loss: 0.1067





Epoch [81/100], Val Loss: 0.3207


train | loss: 0.1010: 100%|██████████| 13/13 [00:00<00:00, 106.08it/s]

Epoch [82/100], Train Loss: 0.1064





Epoch [82/100], Val Loss: 0.3207


train | loss: 0.1076: 100%|██████████| 13/13 [00:00<00:00, 101.66it/s]

Epoch [83/100], Train Loss: 0.1059





Epoch [83/100], Val Loss: 0.3163


train | loss: 0.1030: 100%|██████████| 13/13 [00:00<00:00, 110.08it/s]

Epoch [84/100], Train Loss: 0.1066





Epoch [84/100], Val Loss: 0.3186


train | loss: 0.1072: 100%|██████████| 13/13 [00:00<00:00, 89.96it/s]

Epoch [85/100], Train Loss: 0.1064





Epoch [85/100], Val Loss: 0.3189


train | loss: 0.1021: 100%|██████████| 13/13 [00:00<00:00, 75.73it/s]

Epoch [86/100], Train Loss: 0.1065





Epoch [86/100], Val Loss: 0.3167


train | loss: 0.1016: 100%|██████████| 13/13 [00:00<00:00, 64.39it/s]


Epoch [87/100], Train Loss: 0.1062
Epoch [87/100], Val Loss: 0.3185


train | loss: 0.1045: 100%|██████████| 13/13 [00:00<00:00, 132.09it/s]

Epoch [88/100], Train Loss: 0.1060





Epoch [88/100], Val Loss: 0.3206


train | loss: 0.1087: 100%|██████████| 13/13 [00:00<00:00, 100.25it/s]

Epoch [89/100], Train Loss: 0.1065





Epoch [89/100], Val Loss: 0.3202


train | loss: 0.1061: 100%|██████████| 13/13 [00:00<00:00, 81.75it/s]

Epoch [90/100], Train Loss: 0.1063





Epoch [90/100], Val Loss: 0.3186


train | loss: 0.1016: 100%|██████████| 13/13 [00:00<00:00, 89.87it/s] 

Epoch [91/100], Train Loss: 0.1056





Epoch [91/100], Val Loss: 0.3170


train | loss: 0.1115: 100%|██████████| 13/13 [00:00<00:00, 117.41it/s]

Epoch [92/100], Train Loss: 0.1053





Epoch [92/100], Val Loss: 0.3213


train | loss: 0.1046: 100%|██████████| 13/13 [00:00<00:00, 115.77it/s]

Epoch [93/100], Train Loss: 0.1061





Epoch [93/100], Val Loss: 0.3212


train | loss: 0.1035: 100%|██████████| 13/13 [00:00<00:00, 98.53it/s] 

Epoch [94/100], Train Loss: 0.1061





Epoch [94/100], Val Loss: 0.3191


train | loss: 0.1026: 100%|██████████| 13/13 [00:00<00:00, 63.52it/s]


Epoch [95/100], Train Loss: 0.1053
Epoch [95/100], Val Loss: 0.3210


train | loss: 0.1003: 100%|██████████| 13/13 [00:00<00:00, 115.75it/s]

Epoch [96/100], Train Loss: 0.1059





Epoch [96/100], Val Loss: 0.3222


train | loss: 0.0986: 100%|██████████| 13/13 [00:00<00:00, 122.77it/s]

Epoch [97/100], Train Loss: 0.1062





Epoch [97/100], Val Loss: 0.3204


train | loss: 0.1027: 100%|██████████| 13/13 [00:00<00:00, 54.60it/s]


Epoch [98/100], Train Loss: 0.1067
Epoch [98/100], Val Loss: 0.3232


train | loss: 0.1049: 100%|██████████| 13/13 [00:00<00:00, 30.13it/s]


Epoch [99/100], Train Loss: 0.1065
Epoch [99/100], Val Loss: 0.3226


train | loss: 0.1116: 100%|██████████| 13/13 [00:00<00:00, 107.60it/s]

Epoch [100/100], Train Loss: 0.1062





Epoch [100/100], Val Loss: 0.3184


In [27]:
# Dedicated name for the optimizer: rebinding `optim` would shadow the
# `torch.optim` module imported under that alias at the top of the notebook.
mlp_optimizer = torch.optim.Adam(lr=lr, params=mlp.parameters())

# NOTE(review): test_loader is passed as the validation loader, so
# mlp_best_net is selected on the test split.
mlp_last_net, mlp_best_net = train(epochs, mlp_optimizer, criterion, l1_reg, train_loader, test_loader, mlp, empty_mpl, reg_lambda=0.001)

train | loss: 8.0861: 100%|██████████| 10/10 [00:00<00:00, 23.50it/s]


Epoch [1/100], Train Loss: 6.7905
Epoch [1/100], Val Loss: 0.3649


train | loss: 5.3760: 100%|██████████| 10/10 [00:00<00:00, 75.64it/s]


Epoch [2/100], Train Loss: 4.4671
Epoch [2/100], Val Loss: 0.3516


train | loss: 3.4162: 100%|██████████| 10/10 [00:00<00:00, 90.46it/s]


Epoch [3/100], Train Loss: 2.7488
Epoch [3/100], Val Loss: 0.3486


train | loss: 1.9976: 100%|██████████| 10/10 [00:00<00:00, 64.43it/s]


Epoch [4/100], Train Loss: 1.5308
Epoch [4/100], Val Loss: 0.3496


train | loss: 1.0385: 100%|██████████| 10/10 [00:00<00:00, 69.91it/s]


Epoch [5/100], Train Loss: 0.8201
Epoch [5/100], Val Loss: 0.3497


train | loss: 0.6259: 100%|██████████| 10/10 [00:00<00:00, 60.16it/s]


Epoch [6/100], Train Loss: 0.5650
Epoch [6/100], Val Loss: 0.3519


train | loss: 0.4838: 100%|██████████| 10/10 [00:00<00:00, 67.86it/s]


Epoch [7/100], Train Loss: 0.4215
Epoch [7/100], Val Loss: 0.3547


train | loss: 0.3785: 100%|██████████| 10/10 [00:00<00:00, 83.06it/s]


Epoch [8/100], Train Loss: 0.3532
Epoch [8/100], Val Loss: 0.3560


train | loss: 0.3339: 100%|██████████| 10/10 [00:00<00:00, 84.06it/s]


Epoch [9/100], Train Loss: 0.3134
Epoch [9/100], Val Loss: 0.3549


train | loss: 0.3049: 100%|██████████| 10/10 [00:00<00:00, 80.92it/s]


Epoch [10/100], Train Loss: 0.2857
Epoch [10/100], Val Loss: 0.3547


train | loss: 0.2915: 100%|██████████| 10/10 [00:00<00:00, 84.64it/s]


Epoch [11/100], Train Loss: 0.2700
Epoch [11/100], Val Loss: 0.3551


train | loss: 0.2562: 100%|██████████| 10/10 [00:00<00:00, 42.48it/s]


Epoch [12/100], Train Loss: 0.2574
Epoch [12/100], Val Loss: 0.3528


train | loss: 0.2392: 100%|██████████| 10/10 [00:00<00:00, 69.58it/s]


Epoch [13/100], Train Loss: 0.2472
Epoch [13/100], Val Loss: 0.3525


train | loss: 0.2482: 100%|██████████| 10/10 [00:00<00:00, 79.35it/s]


Epoch [14/100], Train Loss: 0.2389
Epoch [14/100], Val Loss: 0.3522


train | loss: 0.2373: 100%|██████████| 10/10 [00:00<00:00, 85.96it/s]


Epoch [15/100], Train Loss: 0.2329
Epoch [15/100], Val Loss: 0.3518


train | loss: 0.2241: 100%|██████████| 10/10 [00:00<00:00, 78.91it/s]


Epoch [16/100], Train Loss: 0.2275
Epoch [16/100], Val Loss: 0.3506


train | loss: 0.2142: 100%|██████████| 10/10 [00:00<00:00, 88.37it/s]


Epoch [17/100], Train Loss: 0.2204
Epoch [17/100], Val Loss: 0.3505


train | loss: 0.2129: 100%|██████████| 10/10 [00:00<00:00, 24.97it/s]


Epoch [18/100], Train Loss: 0.2151
Epoch [18/100], Val Loss: 0.3512


train | loss: 0.2266: 100%|██████████| 10/10 [00:00<00:00, 46.36it/s]


Epoch [19/100], Train Loss: 0.2108
Epoch [19/100], Val Loss: 0.3515


train | loss: 0.2046: 100%|██████████| 10/10 [00:00<00:00, 93.04it/s]


Epoch [20/100], Train Loss: 0.2068
Epoch [20/100], Val Loss: 0.3521


train | loss: 0.1929: 100%|██████████| 10/10 [00:00<00:00, 52.70it/s]


Epoch [21/100], Train Loss: 0.2049
Epoch [21/100], Val Loss: 0.3510


train | loss: 0.2012: 100%|██████████| 10/10 [00:00<00:00, 91.80it/s]


Epoch [22/100], Train Loss: 0.2015
Epoch [22/100], Val Loss: 0.3509


train | loss: 0.2001: 100%|██████████| 10/10 [00:00<00:00, 94.15it/s]


Epoch [23/100], Train Loss: 0.2001
Epoch [23/100], Val Loss: 0.3511


train | loss: 0.1934: 100%|██████████| 10/10 [00:00<00:00, 57.80it/s]


Epoch [24/100], Train Loss: 0.2005
Epoch [24/100], Val Loss: 0.3512


train | loss: 0.1958: 100%|██████████| 10/10 [00:00<00:00, 54.50it/s]

Epoch [25/100], Train Loss: 0.1987





Epoch [25/100], Val Loss: 0.3505


train | loss: 0.2001: 100%|██████████| 10/10 [00:00<00:00, 64.20it/s]


Epoch [26/100], Train Loss: 0.1969
Epoch [26/100], Val Loss: 0.3503


train | loss: 0.2031: 100%|██████████| 10/10 [00:00<00:00, 90.04it/s]


Epoch [27/100], Train Loss: 0.1972
Epoch [27/100], Val Loss: 0.3518


train | loss: 0.1924: 100%|██████████| 10/10 [00:00<00:00, 69.89it/s]


Epoch [28/100], Train Loss: 0.1959
Epoch [28/100], Val Loss: 0.3514


train | loss: 0.1860: 100%|██████████| 10/10 [00:00<00:00, 79.73it/s]


Epoch [29/100], Train Loss: 0.1955
Epoch [29/100], Val Loss: 0.3518


train | loss: 0.1909: 100%|██████████| 10/10 [00:00<00:00, 83.28it/s]


Epoch [30/100], Train Loss: 0.1953
Epoch [30/100], Val Loss: 0.3509


train | loss: 0.1964: 100%|██████████| 10/10 [00:00<00:00, 33.87it/s]


Epoch [31/100], Train Loss: 0.1964
Epoch [31/100], Val Loss: 0.3508


train | loss: 0.1844: 100%|██████████| 10/10 [00:00<00:00, 81.60it/s]


Epoch [32/100], Train Loss: 0.1927
Epoch [32/100], Val Loss: 0.3507


train | loss: 0.1803: 100%|██████████| 10/10 [00:00<00:00, 86.87it/s]


Epoch [33/100], Train Loss: 0.1938
Epoch [33/100], Val Loss: 0.3515


train | loss: 0.1964: 100%|██████████| 10/10 [00:00<00:00, 97.65it/s]


Epoch [34/100], Train Loss: 0.1927
Epoch [34/100], Val Loss: 0.3512


train | loss: 0.1817: 100%|██████████| 10/10 [00:00<00:00, 85.65it/s]


Epoch [35/100], Train Loss: 0.1920
Epoch [35/100], Val Loss: 0.3512


train | loss: 0.1909: 100%|██████████| 10/10 [00:00<00:00, 49.53it/s]


Epoch [36/100], Train Loss: 0.1920
Epoch [36/100], Val Loss: 0.3515


train | loss: 0.1817: 100%|██████████| 10/10 [00:00<00:00, 51.02it/s]


Epoch [37/100], Train Loss: 0.1924
Epoch [37/100], Val Loss: 0.3510


train | loss: 0.1981: 100%|██████████| 10/10 [00:00<00:00, 71.20it/s]


Epoch [38/100], Train Loss: 0.1917
Epoch [38/100], Val Loss: 0.3511


train | loss: 0.1816: 100%|██████████| 10/10 [00:00<00:00, 87.56it/s]


Epoch [39/100], Train Loss: 0.1900
Epoch [39/100], Val Loss: 0.3512


train | loss: 0.1803: 100%|██████████| 10/10 [00:00<00:00, 85.80it/s]


Epoch [40/100], Train Loss: 0.1891
Epoch [40/100], Val Loss: 0.3506


train | loss: 0.2055: 100%|██████████| 10/10 [00:00<00:00, 43.01it/s]


Epoch [41/100], Train Loss: 0.1879
Epoch [41/100], Val Loss: 0.3505


train | loss: 0.1972: 100%|██████████| 10/10 [00:00<00:00, 28.19it/s]


Epoch [42/100], Train Loss: 0.1862
Epoch [42/100], Val Loss: 0.3518


train | loss: 0.1829: 100%|██████████| 10/10 [00:00<00:00, 62.23it/s]


Epoch [43/100], Train Loss: 0.1856
Epoch [43/100], Val Loss: 0.3519


train | loss: 0.1806: 100%|██████████| 10/10 [00:00<00:00, 85.44it/s]


Epoch [44/100], Train Loss: 0.1867
Epoch [44/100], Val Loss: 0.3525


train | loss: 0.1796: 100%|██████████| 10/10 [00:00<00:00, 84.59it/s]


Epoch [45/100], Train Loss: 0.1861
Epoch [45/100], Val Loss: 0.3511


train | loss: 0.1831: 100%|██████████| 10/10 [00:00<00:00, 84.84it/s]


Epoch [46/100], Train Loss: 0.1853
Epoch [46/100], Val Loss: 0.3510


train | loss: 0.1750: 100%|██████████| 10/10 [00:00<00:00, 89.94it/s]


Epoch [47/100], Train Loss: 0.1841
Epoch [47/100], Val Loss: 0.3514


train | loss: 0.1763: 100%|██████████| 10/10 [00:00<00:00, 44.45it/s]


Epoch [48/100], Train Loss: 0.1837
Epoch [48/100], Val Loss: 0.3519


train | loss: 0.1921: 100%|██████████| 10/10 [00:00<00:00, 51.67it/s]


Epoch [49/100], Train Loss: 0.1845
Epoch [49/100], Val Loss: 0.3519


train | loss: 0.1719: 100%|██████████| 10/10 [00:00<00:00, 82.79it/s]


Epoch [50/100], Train Loss: 0.1837
Epoch [50/100], Val Loss: 0.3516


train | loss: 0.1793: 100%|██████████| 10/10 [00:00<00:00, 94.87it/s]


Epoch [51/100], Train Loss: 0.1833
Epoch [51/100], Val Loss: 0.3520


train | loss: 0.1848: 100%|██████████| 10/10 [00:00<00:00, 88.77it/s]


Epoch [52/100], Train Loss: 0.1829
Epoch [52/100], Val Loss: 0.3520


train | loss: 0.1887: 100%|██████████| 10/10 [00:00<00:00, 85.94it/s]


Epoch [53/100], Train Loss: 0.1846
Epoch [53/100], Val Loss: 0.3522


train | loss: 0.1782: 100%|██████████| 10/10 [00:00<00:00, 62.51it/s]


Epoch [54/100], Train Loss: 0.1838
Epoch [54/100], Val Loss: 0.3519


train | loss: 0.1776: 100%|██████████| 10/10 [00:00<00:00, 43.50it/s]


Epoch [55/100], Train Loss: 0.1844
Epoch [55/100], Val Loss: 0.3518


train | loss: 0.1774: 100%|██████████| 10/10 [00:00<00:00, 72.01it/s]


Epoch [56/100], Train Loss: 0.1834
Epoch [56/100], Val Loss: 0.3520


train | loss: 0.1699: 100%|██████████| 10/10 [00:00<00:00, 84.97it/s]


Epoch [57/100], Train Loss: 0.1831
Epoch [57/100], Val Loss: 0.3524


train | loss: 0.1923: 100%|██████████| 10/10 [00:00<00:00, 88.01it/s]


Epoch [58/100], Train Loss: 0.1822
Epoch [58/100], Val Loss: 0.3524


train | loss: 0.1692: 100%|██████████| 10/10 [00:00<00:00, 93.54it/s]


Epoch [59/100], Train Loss: 0.1837
Epoch [59/100], Val Loss: 0.3527


train | loss: 0.1862: 100%|██████████| 10/10 [00:00<00:00, 53.20it/s]


Epoch [60/100], Train Loss: 0.1825
Epoch [60/100], Val Loss: 0.3523


train | loss: 0.1960: 100%|██████████| 10/10 [00:00<00:00, 49.38it/s]


Epoch [61/100], Train Loss: 0.1836
Epoch [61/100], Val Loss: 0.3523


train | loss: 0.1737: 100%|██████████| 10/10 [00:00<00:00, 73.44it/s]


Epoch [62/100], Train Loss: 0.1840
Epoch [62/100], Val Loss: 0.3524


train | loss: 0.1819: 100%|██████████| 10/10 [00:00<00:00, 88.91it/s]


Epoch [63/100], Train Loss: 0.1827
Epoch [63/100], Val Loss: 0.3516


train | loss: 0.1861: 100%|██████████| 10/10 [00:00<00:00, 86.86it/s]


Epoch [64/100], Train Loss: 0.1830
Epoch [64/100], Val Loss: 0.3525


train | loss: 0.1773: 100%|██████████| 10/10 [00:00<00:00, 88.44it/s]


Epoch [65/100], Train Loss: 0.1833
Epoch [65/100], Val Loss: 0.3527


train | loss: 0.1944: 100%|██████████| 10/10 [00:00<00:00, 82.96it/s]


Epoch [66/100], Train Loss: 0.1845
Epoch [66/100], Val Loss: 0.3532


train | loss: 0.1893: 100%|██████████| 10/10 [00:00<00:00, 53.87it/s]


Epoch [67/100], Train Loss: 0.1831
Epoch [67/100], Val Loss: 0.3520


train | loss: 0.1811: 100%|██████████| 10/10 [00:00<00:00, 55.56it/s]


Epoch [68/100], Train Loss: 0.1848
Epoch [68/100], Val Loss: 0.3526


train | loss: 0.1786: 100%|██████████| 10/10 [00:00<00:00, 83.09it/s]


Epoch [69/100], Train Loss: 0.1822
Epoch [69/100], Val Loss: 0.3516


train | loss: 0.1760: 100%|██████████| 10/10 [00:00<00:00, 84.61it/s]


Epoch [70/100], Train Loss: 0.1820
Epoch [70/100], Val Loss: 0.3525


train | loss: 0.1783: 100%|██████████| 10/10 [00:00<00:00, 80.91it/s]


Epoch [71/100], Train Loss: 0.1823
Epoch [71/100], Val Loss: 0.3529


train | loss: 0.1853: 100%|██████████| 10/10 [00:00<00:00, 91.34it/s]


Epoch [72/100], Train Loss: 0.1834
Epoch [72/100], Val Loss: 0.3528


train | loss: 0.1761: 100%|██████████| 10/10 [00:00<00:00, 20.04it/s]


Epoch [73/100], Train Loss: 0.1833
Epoch [73/100], Val Loss: 0.3527


train | loss: 0.1892: 100%|██████████| 10/10 [00:00<00:00, 38.88it/s]


Epoch [74/100], Train Loss: 0.1834
Epoch [74/100], Val Loss: 0.3522


train | loss: 0.1794: 100%|██████████| 10/10 [00:00<00:00, 84.32it/s]


Epoch [75/100], Train Loss: 0.1838
Epoch [75/100], Val Loss: 0.3521


train | loss: 0.1821: 100%|██████████| 10/10 [00:00<00:00, 88.00it/s]


Epoch [76/100], Train Loss: 0.1830
Epoch [76/100], Val Loss: 0.3522


train | loss: 0.1765: 100%|██████████| 10/10 [00:00<00:00, 86.14it/s]


Epoch [77/100], Train Loss: 0.1823
Epoch [77/100], Val Loss: 0.3525


train | loss: 0.1829: 100%|██████████| 10/10 [00:00<00:00, 64.98it/s]


Epoch [78/100], Train Loss: 0.1821
Epoch [78/100], Val Loss: 0.3531


train | loss: 0.1806: 100%|██████████| 10/10 [00:00<00:00, 41.24it/s]


Epoch [79/100], Train Loss: 0.1840
Epoch [79/100], Val Loss: 0.3529


train | loss: 0.1769: 100%|██████████| 10/10 [00:00<00:00, 78.27it/s]


Epoch [80/100], Train Loss: 0.1829
Epoch [80/100], Val Loss: 0.3526


train | loss: 0.1862: 100%|██████████| 10/10 [00:00<00:00, 88.44it/s]


Epoch [81/100], Train Loss: 0.1828
Epoch [81/100], Val Loss: 0.3519


train | loss: 0.1878: 100%|██████████| 10/10 [00:00<00:00, 76.80it/s]


Epoch [82/100], Train Loss: 0.1826
Epoch [82/100], Val Loss: 0.3522


train | loss: 0.1736: 100%|██████████| 10/10 [00:00<00:00, 89.81it/s]


Epoch [83/100], Train Loss: 0.1843
Epoch [83/100], Val Loss: 0.3531


train | loss: 0.1839: 100%|██████████| 10/10 [00:00<00:00, 50.56it/s]


Epoch [84/100], Train Loss: 0.1819
Epoch [84/100], Val Loss: 0.3518


train | loss: 0.1864: 100%|██████████| 10/10 [00:00<00:00, 47.89it/s]


Epoch [85/100], Train Loss: 0.1827
Epoch [85/100], Val Loss: 0.3530


train | loss: 0.1827: 100%|██████████| 10/10 [00:00<00:00, 80.45it/s]


Epoch [86/100], Train Loss: 0.1824
Epoch [86/100], Val Loss: 0.3533


train | loss: 0.1843: 100%|██████████| 10/10 [00:00<00:00, 86.87it/s]


Epoch [87/100], Train Loss: 0.1827
Epoch [87/100], Val Loss: 0.3525


train | loss: 0.1977: 100%|██████████| 10/10 [00:00<00:00, 92.45it/s]


Epoch [88/100], Train Loss: 0.1823
Epoch [88/100], Val Loss: 0.3520


train | loss: 0.1861: 100%|██████████| 10/10 [00:00<00:00, 91.31it/s]


Epoch [89/100], Train Loss: 0.1823
Epoch [89/100], Val Loss: 0.3534


train | loss: 0.1906: 100%|██████████| 10/10 [00:00<00:00, 29.65it/s]


Epoch [90/100], Train Loss: 0.1824
Epoch [90/100], Val Loss: 0.3529


train | loss: 0.1741: 100%|██████████| 10/10 [00:00<00:00, 44.04it/s]


Epoch [91/100], Train Loss: 0.1830
Epoch [91/100], Val Loss: 0.3531


train | loss: 0.1865: 100%|██████████| 10/10 [00:00<00:00, 88.92it/s]


Epoch [92/100], Train Loss: 0.1819
Epoch [92/100], Val Loss: 0.3524


train | loss: 0.1864: 100%|██████████| 10/10 [00:00<00:00, 85.50it/s]


Epoch [93/100], Train Loss: 0.1822
Epoch [93/100], Val Loss: 0.3527


train | loss: 0.1916: 100%|██████████| 10/10 [00:00<00:00, 87.41it/s]


Epoch [94/100], Train Loss: 0.1820
Epoch [94/100], Val Loss: 0.3532


train | loss: 0.1731: 100%|██████████| 10/10 [00:00<00:00, 72.77it/s]


Epoch [95/100], Train Loss: 0.1823
Epoch [95/100], Val Loss: 0.3528


train | loss: 0.1743: 100%|██████████| 10/10 [00:00<00:00, 49.96it/s]


Epoch [96/100], Train Loss: 0.1817
Epoch [96/100], Val Loss: 0.3530


train | loss: 0.1876: 100%|██████████| 10/10 [00:00<00:00, 58.72it/s]


Epoch [97/100], Train Loss: 0.1828
Epoch [97/100], Val Loss: 0.3531


train | loss: 0.1852: 100%|██████████| 10/10 [00:00<00:00, 77.12it/s]


Epoch [98/100], Train Loss: 0.1825
Epoch [98/100], Val Loss: 0.3526


train | loss: 0.1828: 100%|██████████| 10/10 [00:00<00:00, 85.01it/s]


Epoch [99/100], Train Loss: 0.1822
Epoch [99/100], Val Loss: 0.3523


train | loss: 0.1798: 100%|██████████| 10/10 [00:00<00:00, 84.68it/s]

Epoch [100/100], Train Loss: 0.1825
Epoch [100/100], Val Loss: 0.3533





In [28]:
# Evaluate the last-epoch weights of both models on the held-out pairs.
# NOTE(review): the best-validation copies (lin_best_net / mlp_best_net) are
# not evaluated here.
test_loss_lin = validate(criterion, test_loader, lin_last_net)
print(f"lin test: {test_loss_lin}")

test_loss_mlp = validate(criterion, test_loader, mlp_last_net)
# Label fixed: this line reports the MLP loss, not the linear one.
print(f"mlp test: {test_loss_mlp}")


lin test: 0.31156280636787415
lin test: 0.34074345231056213
