In [86]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from taskdataset import TaskDataset
import matplotlib.pyplot as plt
import os
import time
import pickle as pkl
from tqdm import tqdm

from t2_functions import partition_ids

In [87]:
import sys

sys.path.append(os.path.join(os.getcwd(), "task_2_sybilattack/"))

from endpoints.requests import sybil, sybil_reset


In [89]:
print(os.getcwd())
dataset = torch.load("data/SybilAttack.pt")

/home/janek/Documents/Hackaton/ensembleAI-ScoutTeam/task_2_sybilattack


In [90]:
ids = np.array(dataset.ids)
binned_ids = partition_ids(ids, main_bin_num=10)

bin_size: 2000


In [57]:
def is_linearly_independent(set_of_vectors, vector):
    # Check if the new vector is linearly independent of the set
    coefficients = np.linalg.lstsq(set_of_vectors, vector, rcond=None)[0]
    return np.linalg.norm(vector - np.dot(set_of_vectors, coefficients)) > 1e-10

def find_independent_set(vectors, k=200):
    vectors = np.array(vectors)
    n = len(vectors)
    independent_set = vectors[:k]
    indexes = np.arange(k)
    i = 1
    for j, vector in enumerate(tqdm(vectors)):
        if i >= k:
            return np.array(independent_set), indexes, i
        if n - j < k - i:
            return np.array(independent_set), indexes, i
        if is_linearly_independent(np.array(independent_set[:i]).T, vector):
            independent_set[i] = vector
            indexes[i] = j
            i += 1
    return np.array(independent_set), indexes, i


In [82]:
task = 'binary'

sybil_reset(home_or_defense='home', binary_or_affine=task)
sybil_reset(home_or_defense='defense', binary_or_affine=task)


Request ok
{'msg': 'Successful sybil binary reset home'}
Request ok
{'msg': 'Successful sybil binary reset defense'}


In [83]:
ids = binned_ids[0]
print(len(ids))
A_reps = sybil(ids=ids,
               home_or_defense='home',
               binary_or_affine=task)

A_indep, indexes, success = find_independent_set(A_reps, k=200)
print(f"success: {success}")
print(indexes)

ids_train = ids[indexes]
mask = np.ones(len(ids), dtype=bool)
mask[indexes] = False
ids_test = ids[mask]

A_reps = np.array(A_reps)
A_train_reps = A_reps[indexes]
A_test_reps = A_reps[mask]

B_train_reps = sybil(ids=ids_train,
                 home_or_defense='defense',
                 binary_or_affine=task)
print(f"B train reps: {len(B_train_reps)}")

B_test_reps = sybil(ids=ids_test,
                 home_or_defense='defense',
                 binary_or_affine=task)
print(f"B test reps: {len(B_test_reps)}")


2000


 10%|█         | 200/2000 [00:00<00:04, 400.62it/s]


success: 200
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89
  90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107
 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
 198 199]
A train reps: 2000
A test reps: 1800


In [84]:
task='affine_v3'

os.makedirs(f'data/{task}', exist_ok=True)

with open(f'data/{task}/A_train', 'wb') as f:
    pkl.dump(A_train_reps, f)
with open(f'data/{task}/B_train', 'wb') as f:
    pkl.dump(B_train_reps, f)
with open(f'data/{task}/A_test', 'wb') as f:
    pkl.dump(A_test_reps, f)
with open(f'data/{task}/B_test', 'wb') as f:
    pkl.dump(B_test_reps, f)


In [85]:
with open(f'data/{task}/A_train', 'rb') as f:
    A_train_reps = pkl.load(f)
with open(f'data/{task}/B_train', 'rb') as f:
    B_train_reps = pkl.load(f)
with open(f'data/{task}/A_test', 'rb') as f:
    A_test_reps = pkl.load(f)
with open(f'data/{task}/B_test', 'rb') as f:
    B_test_reps = pkl.load(f)
    

In [29]:
class RepresentationsDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        x, y = self.x[idx], self.y[idx]
        return torch.tensor(x), torch.tensor(y)


In [30]:
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [31]:
class Linear(nn.Module):
    def __init__(self, input_size, output_size):
        super(Linear, self).__init__()
        self.fc1 = nn.Linear(input_size, output_size)
    
    def forward(self, x):
        x = self.fc1(x)
        return x

In [32]:
def validate(criterion, loader, net):
    net.eval()

    with torch.no_grad():
        true = []
        pred = []
        for x, y in loader:
            y_hat = net(x)
            true.append(y.numpy())
            pred.append(y_hat.numpy())

        true = torch.tensor(np.concatenate(true, axis=0))
        pred = torch.tensor(np.concatenate(pred, axis=0))

        loss = criterion(pred, true)

    return loss


In [33]:
def l1_reg(model, reg_lambda):
    l1_regularization = torch.tensor(0., device=model.parameters().__next__().device)
    for param in model.parameters():
        l1_regularization += torch.norm(param, p=1)
    return reg_lambda * l1_regularization


In [34]:
from torch.utils.tensorboard import SummaryWriter

def train(epochs, optim, criterion, regularise, trainloader, valloader, net, empty_net, reg_lambda=None):
    best_val_loss = np.inf
    best_net = empty_net
    writer = SummaryWriter()


    for epoch in range(epochs):  # loop over the dataset multiple times
        train_loss = 0

        progress_bar = tqdm(trainloader)

        for iter, (x, y) in enumerate(progress_bar):
            net.train()
            optim.zero_grad()

            y_hat = net(x)
            loss = criterion(y_hat, y)
            if reg_lambda is not None:
                loss += regularise(net, reg_lambda)
            loss.backward()
            optim.step()

            batch_loss = loss.item()
            train_loss += batch_loss
            if iter % 20 == 0:
                progress_bar.set_description(f"train | loss: {batch_loss:.4f}")

            writer.add_scalar('Training Loss', batch_loss, epoch)

        train_loss /= len(trainloader)
        print(f"Epoch [{epoch + 1}/{epochs}], Train Loss: {train_loss:.4f}")

        net.eval()

        val_loss = validate(criterion, valloader, net)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_net.load_state_dict(net.state_dict())

        if writer is not None:
            # log the validation loss and accuracy
            writer.add_scalar('Validation Loss', val_loss, epoch)

        print(f"Epoch [{epoch + 1}/{epochs}], Val Loss: {val_loss:.4f}")

    return net, best_net



In [35]:
criterion = nn.MSELoss()
batch_size = 32
epochs = 100
lr = 0.001


In [36]:
train_dataset = RepresentationsDataset(x=A_train_reps, y=B_train_reps)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = RepresentationsDataset(x=A_test_reps, y=B_test_reps)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [37]:
lin_net = Linear(384, 384)
lin_empty_net = Linear(384, 384)

optim = torch.optim.Adam(lr=lr, params=lin_net.parameters())

lin_last_net, lin_best_net = train(epochs, optim, criterion, l1_reg, train_loader, test_loader, lin_net, lin_empty_net, reg_lambda=0.0001)

train | loss: 4.1753: 100%|██████████| 7/7 [00:00<00:00, 57.14it/s]


Epoch [1/100], Train Loss: 2.5125
Epoch [1/100], Val Loss: 1.1748


train | loss: 1.6496: 100%|██████████| 7/7 [00:00<00:00, 34.01it/s]


Epoch [2/100], Train Loss: 1.3977
Epoch [2/100], Val Loss: 0.8715


train | loss: 1.2516: 100%|██████████| 7/7 [00:00<00:00, 98.36it/s]

Epoch [3/100], Train Loss: 1.1425





Epoch [3/100], Val Loss: 0.6619


train | loss: 1.0028: 100%|██████████| 7/7 [00:00<00:00, 63.88it/s]

Epoch [4/100], Train Loss: 0.9671





Epoch [4/100], Val Loss: 0.5595


train | loss: 0.9312: 100%|██████████| 7/7 [00:00<00:00, 33.74it/s]


Epoch [5/100], Train Loss: 0.8435
Epoch [5/100], Val Loss: 0.4889


train | loss: 0.7934: 100%|██████████| 7/7 [00:00<00:00, 46.84it/s]

Epoch [6/100], Train Loss: 0.7616





Epoch [6/100], Val Loss: 0.4387


train | loss: 0.7848: 100%|██████████| 7/7 [00:00<00:00, 80.26it/s]

Epoch [7/100], Train Loss: 0.7086





Epoch [7/100], Val Loss: 0.4040


train | loss: 0.7111: 100%|██████████| 7/7 [00:00<00:00, 77.69it/s]

Epoch [8/100], Train Loss: 0.6691





Epoch [8/100], Val Loss: 0.3739


train | loss: 0.6459: 100%|██████████| 7/7 [00:00<00:00, 90.22it/s]

Epoch [9/100], Train Loss: 0.6246





Epoch [9/100], Val Loss: 0.3524


train | loss: 0.6542: 100%|██████████| 7/7 [00:00<00:00, 70.21it/s]

Epoch [10/100], Train Loss: 0.5964





Epoch [10/100], Val Loss: 0.3331


train | loss: 0.6088: 100%|██████████| 7/7 [00:00<00:00, 79.97it/s]

Epoch [11/100], Train Loss: 0.5643





Epoch [11/100], Val Loss: 0.3175


train | loss: 0.5562: 100%|██████████| 7/7 [00:00<00:00, 62.84it/s]

Epoch [12/100], Train Loss: 0.5404





Epoch [12/100], Val Loss: 0.3052


train | loss: 0.5371: 100%|██████████| 7/7 [00:00<00:00, 88.45it/s]

Epoch [13/100], Train Loss: 0.5210





Epoch [13/100], Val Loss: 0.2931


train | loss: 0.5124: 100%|██████████| 7/7 [00:00<00:00, 68.60it/s]

Epoch [14/100], Train Loss: 0.5061





Epoch [14/100], Val Loss: 0.2827


train | loss: 0.4879: 100%|██████████| 7/7 [00:00<00:00, 49.83it/s]

Epoch [15/100], Train Loss: 0.4728





Epoch [15/100], Val Loss: 0.2728


train | loss: 0.4755: 100%|██████████| 7/7 [00:00<00:00, 88.15it/s]

Epoch [16/100], Train Loss: 0.4521





Epoch [16/100], Val Loss: 0.2636


train | loss: 0.3979: 100%|██████████| 7/7 [00:00<00:00, 79.09it/s]

Epoch [17/100], Train Loss: 0.4413





Epoch [17/100], Val Loss: 0.2558


train | loss: 0.4309: 100%|██████████| 7/7 [00:00<00:00, 56.98it/s]

Epoch [18/100], Train Loss: 0.4272





Epoch [18/100], Val Loss: 0.2495


train | loss: 0.4334: 100%|██████████| 7/7 [00:00<00:00, 85.46it/s]

Epoch [19/100], Train Loss: 0.4143





Epoch [19/100], Val Loss: 0.2430


train | loss: 0.4036: 100%|██████████| 7/7 [00:00<00:00, 62.02it/s]

Epoch [20/100], Train Loss: 0.3957





Epoch [20/100], Val Loss: 0.2365


train | loss: 0.4205: 100%|██████████| 7/7 [00:00<00:00, 54.19it/s]

Epoch [21/100], Train Loss: 0.3844





Epoch [21/100], Val Loss: 0.2314


train | loss: 0.3585: 100%|██████████| 7/7 [00:00<00:00, 69.16it/s]

Epoch [22/100], Train Loss: 0.3747





Epoch [22/100], Val Loss: 0.2258


train | loss: 0.3666: 100%|██████████| 7/7 [00:00<00:00, 74.77it/s]

Epoch [23/100], Train Loss: 0.3605





Epoch [23/100], Val Loss: 0.2212


train | loss: 0.3671: 100%|██████████| 7/7 [00:00<00:00, 54.41it/s]

Epoch [24/100], Train Loss: 0.3484





Epoch [24/100], Val Loss: 0.2170


train | loss: 0.3562: 100%|██████████| 7/7 [00:00<00:00, 93.56it/s]

Epoch [25/100], Train Loss: 0.3416





Epoch [25/100], Val Loss: 0.2128


train | loss: 0.3602: 100%|██████████| 7/7 [00:00<00:00, 80.60it/s]

Epoch [26/100], Train Loss: 0.3294





Epoch [26/100], Val Loss: 0.2092


train | loss: 0.3013: 100%|██████████| 7/7 [00:00<00:00, 52.35it/s]

Epoch [27/100], Train Loss: 0.3261





Epoch [27/100], Val Loss: 0.2056


train | loss: 0.3128: 100%|██████████| 7/7 [00:00<00:00, 28.74it/s]


Epoch [28/100], Train Loss: 0.3125
Epoch [28/100], Val Loss: 0.2019


train | loss: 0.3043: 100%|██████████| 7/7 [00:00<00:00, 82.90it/s]

Epoch [29/100], Train Loss: 0.3066





Epoch [29/100], Val Loss: 0.1985


train | loss: 0.3003: 100%|██████████| 7/7 [00:00<00:00, 78.84it/s]

Epoch [30/100], Train Loss: 0.2946





Epoch [30/100], Val Loss: 0.1953


train | loss: 0.3135: 100%|██████████| 7/7 [00:00<00:00, 44.60it/s]

Epoch [31/100], Train Loss: 0.2921





Epoch [31/100], Val Loss: 0.1923


train | loss: 0.2843: 100%|██████████| 7/7 [00:00<00:00, 79.29it/s]

Epoch [32/100], Train Loss: 0.2816





Epoch [32/100], Val Loss: 0.1890


train | loss: 0.2869: 100%|██████████| 7/7 [00:00<00:00, 64.28it/s]

Epoch [33/100], Train Loss: 0.2761





Epoch [33/100], Val Loss: 0.1869


train | loss: 0.2861: 100%|██████████| 7/7 [00:00<00:00, 70.35it/s]

Epoch [34/100], Train Loss: 0.2694





Epoch [34/100], Val Loss: 0.1844


train | loss: 0.2906: 100%|██████████| 7/7 [00:00<00:00, 81.78it/s]

Epoch [35/100], Train Loss: 0.2635





Epoch [35/100], Val Loss: 0.1819


train | loss: 0.2452: 100%|██████████| 7/7 [00:00<00:00, 57.54it/s]

Epoch [36/100], Train Loss: 0.2586





Epoch [36/100], Val Loss: 0.1797


train | loss: 0.2465: 100%|██████████| 7/7 [00:00<00:00, 57.05it/s]

Epoch [37/100], Train Loss: 0.2525





Epoch [37/100], Val Loss: 0.1776


train | loss: 0.2297: 100%|██████████| 7/7 [00:00<00:00, 67.87it/s]

Epoch [38/100], Train Loss: 0.2463





Epoch [38/100], Val Loss: 0.1752


train | loss: 0.2354: 100%|██████████| 7/7 [00:00<00:00, 85.88it/s]

Epoch [39/100], Train Loss: 0.2454





Epoch [39/100], Val Loss: 0.1730


train | loss: 0.2315: 100%|██████████| 7/7 [00:00<00:00, 57.85it/s]

Epoch [40/100], Train Loss: 0.2416





Epoch [40/100], Val Loss: 0.1714


train | loss: 0.2276: 100%|██████████| 7/7 [00:00<00:00, 42.60it/s]

Epoch [41/100], Train Loss: 0.2314





Epoch [41/100], Val Loss: 0.1689


train | loss: 0.2313: 100%|██████████| 7/7 [00:00<00:00, 80.47it/s]

Epoch [42/100], Train Loss: 0.2311





Epoch [42/100], Val Loss: 0.1672


train | loss: 0.2156: 100%|██████████| 7/7 [00:00<00:00, 77.59it/s]

Epoch [43/100], Train Loss: 0.2227





Epoch [43/100], Val Loss: 0.1649


train | loss: 0.2094: 100%|██████████| 7/7 [00:00<00:00, 48.24it/s]

Epoch [44/100], Train Loss: 0.2231





Epoch [44/100], Val Loss: 0.1635


train | loss: 0.2015: 100%|██████████| 7/7 [00:00<00:00, 58.75it/s]

Epoch [45/100], Train Loss: 0.2173





Epoch [45/100], Val Loss: 0.1622


train | loss: 0.2079: 100%|██████████| 7/7 [00:00<00:00, 71.52it/s]

Epoch [46/100], Train Loss: 0.2137





Epoch [46/100], Val Loss: 0.1606


train | loss: 0.2073: 100%|██████████| 7/7 [00:00<00:00, 40.87it/s]

Epoch [47/100], Train Loss: 0.2086





Epoch [47/100], Val Loss: 0.1595


train | loss: 0.2099: 100%|██████████| 7/7 [00:00<00:00, 43.47it/s]

Epoch [48/100], Train Loss: 0.2034





Epoch [48/100], Val Loss: 0.1573


train | loss: 0.2009: 100%|██████████| 7/7 [00:00<00:00, 84.56it/s]

Epoch [49/100], Train Loss: 0.2013





Epoch [49/100], Val Loss: 0.1559


train | loss: 0.2090: 100%|██████████| 7/7 [00:00<00:00, 78.72it/s]

Epoch [50/100], Train Loss: 0.1985





Epoch [50/100], Val Loss: 0.1546


train | loss: 0.1981: 100%|██████████| 7/7 [00:00<00:00, 48.84it/s]

Epoch [51/100], Train Loss: 0.1954





Epoch [51/100], Val Loss: 0.1531


train | loss: 0.1810: 100%|██████████| 7/7 [00:00<00:00, 91.14it/s]

Epoch [52/100], Train Loss: 0.1932





Epoch [52/100], Val Loss: 0.1522


train | loss: 0.1834: 100%|██████████| 7/7 [00:00<00:00, 60.20it/s]

Epoch [53/100], Train Loss: 0.1927





Epoch [53/100], Val Loss: 0.1512


train | loss: 0.1895: 100%|██████████| 7/7 [00:00<00:00, 61.75it/s]

Epoch [54/100], Train Loss: 0.1896





Epoch [54/100], Val Loss: 0.1500


train | loss: 0.1919: 100%|██████████| 7/7 [00:00<00:00, 72.11it/s]

Epoch [55/100], Train Loss: 0.1854





Epoch [55/100], Val Loss: 0.1486


train | loss: 0.1865: 100%|██████████| 7/7 [00:00<00:00, 63.78it/s]

Epoch [56/100], Train Loss: 0.1819





Epoch [56/100], Val Loss: 0.1474


train | loss: 0.1840: 100%|██████████| 7/7 [00:00<00:00, 52.54it/s]

Epoch [57/100], Train Loss: 0.1828





Epoch [57/100], Val Loss: 0.1467


train | loss: 0.1800: 100%|██████████| 7/7 [00:00<00:00, 73.22it/s]

Epoch [58/100], Train Loss: 0.1771





Epoch [58/100], Val Loss: 0.1452


train | loss: 0.1763: 100%|██████████| 7/7 [00:00<00:00, 79.39it/s]

Epoch [59/100], Train Loss: 0.1764





Epoch [59/100], Val Loss: 0.1444


train | loss: 0.1716: 100%|██████████| 7/7 [00:00<00:00, 51.61it/s]

Epoch [60/100], Train Loss: 0.1748





Epoch [60/100], Val Loss: 0.1432


train | loss: 0.1726: 100%|██████████| 7/7 [00:00<00:00, 77.85it/s]

Epoch [61/100], Train Loss: 0.1724





Epoch [61/100], Val Loss: 0.1425


train | loss: 0.1720: 100%|██████████| 7/7 [00:00<00:00, 83.47it/s]

Epoch [62/100], Train Loss: 0.1690





Epoch [62/100], Val Loss: 0.1409


train | loss: 0.1698: 100%|██████████| 7/7 [00:00<00:00, 73.70it/s]

Epoch [63/100], Train Loss: 0.1675





Epoch [63/100], Val Loss: 0.1399


train | loss: 0.1714: 100%|██████████| 7/7 [00:00<00:00, 70.67it/s]

Epoch [64/100], Train Loss: 0.1649





Epoch [64/100], Val Loss: 0.1388


train | loss: 0.1641: 100%|██████████| 7/7 [00:00<00:00, 76.87it/s]

Epoch [65/100], Train Loss: 0.1655





Epoch [65/100], Val Loss: 0.1385


train | loss: 0.1593: 100%|██████████| 7/7 [00:00<00:00, 91.61it/s]

Epoch [66/100], Train Loss: 0.1622





Epoch [66/100], Val Loss: 0.1376


train | loss: 0.1613: 100%|██████████| 7/7 [00:00<00:00, 37.04it/s]

Epoch [67/100], Train Loss: 0.1609





Epoch [67/100], Val Loss: 0.1365


train | loss: 0.1604: 100%|██████████| 7/7 [00:00<00:00, 50.31it/s]

Epoch [68/100], Train Loss: 0.1600





Epoch [68/100], Val Loss: 0.1361


train | loss: 0.1598: 100%|██████████| 7/7 [00:00<00:00, 74.03it/s]

Epoch [69/100], Train Loss: 0.1583





Epoch [69/100], Val Loss: 0.1348


train | loss: 0.1607: 100%|██████████| 7/7 [00:00<00:00, 57.52it/s]

Epoch [70/100], Train Loss: 0.1558





Epoch [70/100], Val Loss: 0.1343


train | loss: 0.1592: 100%|██████████| 7/7 [00:00<00:00, 49.45it/s]

Epoch [71/100], Train Loss: 0.1520





Epoch [71/100], Val Loss: 0.1334


train | loss: 0.1453: 100%|██████████| 7/7 [00:00<00:00, 79.68it/s]

Epoch [72/100], Train Loss: 0.1524





Epoch [72/100], Val Loss: 0.1323


train | loss: 0.1522: 100%|██████████| 7/7 [00:00<00:00, 91.20it/s]

Epoch [73/100], Train Loss: 0.1526





Epoch [73/100], Val Loss: 0.1318


train | loss: 0.1427: 100%|██████████| 7/7 [00:00<00:00, 21.62it/s]


Epoch [74/100], Train Loss: 0.1498
Epoch [74/100], Val Loss: 0.1306


train | loss: 0.1551: 100%|██████████| 7/7 [00:00<00:00, 40.74it/s]

Epoch [75/100], Train Loss: 0.1495





Epoch [75/100], Val Loss: 0.1301


train | loss: 0.1435: 100%|██████████| 7/7 [00:00<00:00, 79.22it/s]

Epoch [76/100], Train Loss: 0.1480





Epoch [76/100], Val Loss: 0.1294


train | loss: 0.1472: 100%|██████████| 7/7 [00:00<00:00, 58.50it/s]

Epoch [77/100], Train Loss: 0.1469





Epoch [77/100], Val Loss: 0.1291


train | loss: 0.1526: 100%|██████████| 7/7 [00:00<00:00, 38.27it/s]

Epoch [78/100], Train Loss: 0.1446





Epoch [78/100], Val Loss: 0.1279


train | loss: 0.1398: 100%|██████████| 7/7 [00:00<00:00, 81.44it/s]

Epoch [79/100], Train Loss: 0.1438





Epoch [79/100], Val Loss: 0.1271


train | loss: 0.1492: 100%|██████████| 7/7 [00:00<00:00, 69.95it/s]

Epoch [80/100], Train Loss: 0.1418





Epoch [80/100], Val Loss: 0.1264


train | loss: 0.1422: 100%|██████████| 7/7 [00:00<00:00, 55.79it/s]

Epoch [81/100], Train Loss: 0.1411





Epoch [81/100], Val Loss: 0.1257


train | loss: 0.1362: 100%|██████████| 7/7 [00:00<00:00, 43.17it/s]

Epoch [82/100], Train Loss: 0.1413





Epoch [82/100], Val Loss: 0.1255


train | loss: 0.1391: 100%|██████████| 7/7 [00:00<00:00, 29.15it/s]


Epoch [83/100], Train Loss: 0.1405
Epoch [83/100], Val Loss: 0.1246


train | loss: 0.1332: 100%|██████████| 7/7 [00:00<00:00, 71.61it/s]

Epoch [84/100], Train Loss: 0.1377





Epoch [84/100], Val Loss: 0.1242


train | loss: 0.1374: 100%|██████████| 7/7 [00:00<00:00, 71.53it/s]

Epoch [85/100], Train Loss: 0.1369





Epoch [85/100], Val Loss: 0.1233


train | loss: 0.1422: 100%|██████████| 7/7 [00:00<00:00, 70.69it/s]

Epoch [86/100], Train Loss: 0.1357





Epoch [86/100], Val Loss: 0.1228


train | loss: 0.1329: 100%|██████████| 7/7 [00:00<00:00, 90.35it/s]

Epoch [87/100], Train Loss: 0.1358





Epoch [87/100], Val Loss: 0.1220


train | loss: 0.1349: 100%|██████████| 7/7 [00:00<00:00, 60.13it/s]

Epoch [88/100], Train Loss: 0.1334





Epoch [88/100], Val Loss: 0.1213


train | loss: 0.1340: 100%|██████████| 7/7 [00:00<00:00, 52.89it/s]

Epoch [89/100], Train Loss: 0.1325





Epoch [89/100], Val Loss: 0.1206


train | loss: 0.1317: 100%|██████████| 7/7 [00:00<00:00, 69.72it/s]

Epoch [90/100], Train Loss: 0.1318





Epoch [90/100], Val Loss: 0.1200


train | loss: 0.1315: 100%|██████████| 7/7 [00:00<00:00, 69.81it/s]

Epoch [91/100], Train Loss: 0.1307





Epoch [91/100], Val Loss: 0.1195


train | loss: 0.1293: 100%|██████████| 7/7 [00:00<00:00, 35.62it/s]

Epoch [92/100], Train Loss: 0.1308





Epoch [92/100], Val Loss: 0.1184


train | loss: 0.1293: 100%|██████████| 7/7 [00:00<00:00, 73.00it/s]

Epoch [93/100], Train Loss: 0.1301





Epoch [93/100], Val Loss: 0.1187


train | loss: 0.1264: 100%|██████████| 7/7 [00:00<00:00, 81.90it/s]

Epoch [94/100], Train Loss: 0.1288





Epoch [94/100], Val Loss: 0.1177


train | loss: 0.1280: 100%|██████████| 7/7 [00:00<00:00, 37.33it/s]

Epoch [95/100], Train Loss: 0.1302





Epoch [95/100], Val Loss: 0.1176


train | loss: 0.1253: 100%|██████████| 7/7 [00:00<00:00, 76.88it/s]

Epoch [96/100], Train Loss: 0.1280





Epoch [96/100], Val Loss: 0.1169


train | loss: 0.1288: 100%|██████████| 7/7 [00:00<00:00, 66.92it/s]

Epoch [97/100], Train Loss: 0.1260





Epoch [97/100], Val Loss: 0.1161


train | loss: 0.1298: 100%|██████████| 7/7 [00:00<00:00, 34.08it/s]


Epoch [98/100], Train Loss: 0.1260
Epoch [98/100], Val Loss: 0.1158


train | loss: 0.1265: 100%|██████████| 7/7 [00:00<00:00, 25.56it/s]


Epoch [99/100], Train Loss: 0.1259
Epoch [99/100], Val Loss: 0.1150


train | loss: 0.1329: 100%|██████████| 7/7 [00:00<00:00, 46.02it/s]

Epoch [100/100], Train Loss: 0.1249





Epoch [100/100], Val Loss: 0.1144


In [None]:
mlp = MLP(384, 384, 384)
empty_mpl = MLP(384, 384, 384)

optim = torch.optim.Adam(lr=lr, params=mlp.parameters())

mlp_last_net, mlp_best_net = train(epochs, optim, criterion, l1_reg, train_loader, test_loader, mlp, empty_mpl, reg_lambda=0.001)

In [39]:
test_loss_lin = validate(criterion, test_loader, lin_last_net)
print(f"lin test: {test_loss_lin}")

train_loss_lin = validate(criterion, train_loader, lin_last_net)
print(f"lin train: {train_loss_lin}")


lin test: 0.11437146365642548
lin train: 0.06165158003568649
