In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from taskdataset import TaskDataset
import matplotlib.pyplot as plt
import os
import time

from t2_functions import partition_ids

In [2]:
import sys

path = os.getcwd()
print(path)
os.chdir('..')
print(os.getcwd())
dataset_t1 = torch.load("task_1_modelstealing/data/ModelStealingPub.pt")

sys.path.append(os.path.join(os.getcwd(), "task_2_sybilattack/"))

from endpoints.requests import sybil, sybil_reset


/home/janek/Documents/Hackaton/ensembleAI-ScoutTeam/task_2_sybilattack
/home/janek/Documents/Hackaton/ensembleAI-ScoutTeam


In [3]:
os.chdir(path)
print(os.getcwd())
dataset = torch.load("data/SybilAttack.pt")

/home/janek/Documents/Hackaton/ensembleAI-ScoutTeam/task_2_sybilattack


In [4]:
ids = np.array(dataset.ids)
print(len(ids))
print(ids[:10])
binned_ids = partition_ids(ids, main_bin_num=10, train=0.1, test=0.9)

20000
[101031   8526  43127 191394 298792 121086 149475 102605 163605 101855]
bin_size: 2000
train: 200, test: 1800
train: 200, test: 1800
train: 200, test: 1800
train: 200, test: 1800
train: 200, test: 1800
train: 200, test: 1800
train: 200, test: 1800
train: 200, test: 1800
train: 200, test: 1800
train: 200, test: 1800


In [36]:
task = 'affine'

sybil_reset(home_or_defense='home', binary_or_affine=task)
sybil_reset(home_or_defense='defense', binary_or_affine=task)


Request ok
{'msg': 'Successful sybil affine reset home'}
Request ok
{'msg': 'Successful sybil affine reset defense'}


In [37]:
ids_train, ids_test = binned_ids[0]
print(len(ids_train))
print(len(ids_test))
A_train_reps = sybil(ids=ids_train,
                 home_or_defense='home',
                 binary_or_affine=task)

B_train_reps = sybil(ids=ids_train,
                 home_or_defense='defense',
                 binary_or_affine=task)
print(f"A train reps: {len(A_train_reps)}")

A_test_reps = sybil(ids=ids_test,
                 home_or_defense='home',
                 binary_or_affine=task)

B_test_reps = sybil(ids=ids_test,
                 home_or_defense='defense',
                 binary_or_affine=task)
print(f"A test reps: {len(A_test_reps)}")


200
1800
A train reps: 200
A test reps: 1800


In [7]:
class RepresentationsDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        x, y = self.x[idx], self.y[idx]
        return torch.tensor(x), torch.tensor(y)


In [8]:
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [9]:
class Linear(nn.Module):
    def __init__(self, input_size, output_size):
        super(Linear, self).__init__()
        self.fc1 = nn.Linear(input_size, output_size)
    
    def forward(self, x):
        x = self.fc1(x)
        return x

In [10]:
def validate(criterion, loader, net):
    net.eval()

    with torch.no_grad():
        true = []
        pred = []
        for x, y in loader:
            y_hat = net(x)
            true.append(y.numpy())
            pred.append(y_hat.numpy())

        true = torch.tensor(np.concatenate(true, axis=0))
        pred = torch.tensor(np.concatenate(pred, axis=0))

        loss = criterion(pred, true)

    return loss


In [11]:
def l1_reg(model, reg_lambda):
    l1_regularization = torch.tensor(0., device=model.parameters().__next__().device)
    for param in model.parameters():
        l1_regularization += torch.norm(param, p=1)
    return reg_lambda * l1_regularization


In [75]:
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

def train(epochs, optim, criterion, regularise, trainloader, valloader, net, empty_net, reg_lambda=None):
    best_val_loss = np.inf
    best_net = empty_net

    for epoch in range(epochs):  # loop over the dataset multiple times
        train_loss = 0

        for iter, (x, y) in enumerate(tqdm(trainloader)):
            net.train()
            optim.zero_grad()

            y_hat = net(x)
            loss = criterion(y_hat, y)
            if reg_lambda is not None:
                loss += regularise(net, reg_lambda)
            loss.backward()
            optim.step()

            batch_loss = loss.item()
            train_loss += batch_loss

        train_loss /= len(trainloader)
        print(f"Epoch [{epoch + 1}/{epochs}], Train Loss: {train_loss:.4f}")

        net.eval()

        val_loss = validate(criterion, valloader, net)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_net.load_state_dict(net.state_dict())

        print(f"Epoch [{epoch + 1}/{epochs}], Val Loss: {val_loss:.4f}")

    return net, best_net



In [83]:
criterion = nn.MSELoss()
batch_size = 32
epochs = 100
lr = 0.001


In [84]:
def lin_augment_affine(xs, ys, new_for_each_pair=3):
    new_x = []
    new_y = []
    for i, (x1, y1) in tqdm(enumerate(zip(xs, ys))):
        for j, (x2, y2) in enumerate(zip(xs, ys)):
            if i != j:
                weights = np.random.uniform(-1, 2, size=new_for_each_pair)
                for w in weights:
                    x_new = w*x1 + (1-w)*x2
                    y_new = w*y1 + (1-w)*y2
                    new_x.append(x_new)
                    new_y.append(y_new)
    new_x = np.array(new_x)
    new_y = np.array(new_y)
    return np.concatenate((xs, new_x), axis=0), np.concatenate((ys, new_y), axis=0)


In [85]:
A_train_reps_tensor = torch.tensor(A_train_reps, dtype=torch.float32)
B_train_reps_tensor = torch.tensor(B_train_reps, dtype=torch.float32)
A_train_aug, B_train_aug = lin_augment_affine(A_train_reps_tensor, B_train_reps_tensor)

0it [00:00, ?it/s]

200it [00:09, 20.26it/s]


In [86]:
train_dataset = RepresentationsDataset(x=A_train_reps, y=B_train_reps)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = RepresentationsDataset(x=A_test_reps, y=B_test_reps)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [87]:
lin_net = Linear(384, 384)
lin_empty_net = Linear(384, 384)

optim = torch.optim.Adam(lr=lr, params=lin_net.parameters())

lin_last_net, lin_best_net = train(epochs, optim, criterion, l1_reg, train_loader, test_loader, lin_net, lin_empty_net, reg_lambda=0.0001)

100%|██████████| 7/7 [00:00<00:00, 67.91it/s]


Epoch [1/100], Train Loss: 2.4230
Epoch [1/100], Val Loss: 1.1006


100%|██████████| 7/7 [00:00<00:00, 61.56it/s]

Epoch [2/100], Train Loss: 1.2805





Epoch [2/100], Val Loss: 0.8469


100%|██████████| 7/7 [00:00<00:00, 91.37it/s]

Epoch [3/100], Train Loss: 1.0739





Epoch [3/100], Val Loss: 0.6774


100%|██████████| 7/7 [00:00<00:00, 104.14it/s]

Epoch [4/100], Train Loss: 0.9176





Epoch [4/100], Val Loss: 0.5518


100%|██████████| 7/7 [00:00<00:00, 69.28it/s]

Epoch [5/100], Train Loss: 0.8001





Epoch [5/100], Val Loss: 0.4852


100%|██████████| 7/7 [00:00<00:00, 54.57it/s]

Epoch [6/100], Train Loss: 0.7371





Epoch [6/100], Val Loss: 0.4397


100%|██████████| 7/7 [00:00<00:00, 94.39it/s]

Epoch [7/100], Train Loss: 0.6871





Epoch [7/100], Val Loss: 0.4040


100%|██████████| 7/7 [00:00<00:00, 94.05it/s]

Epoch [8/100], Train Loss: 0.6292





Epoch [8/100], Val Loss: 0.3767


100%|██████████| 7/7 [00:00<00:00, 63.54it/s]

Epoch [9/100], Train Loss: 0.6076





Epoch [9/100], Val Loss: 0.3543


100%|██████████| 7/7 [00:00<00:00, 108.29it/s]

Epoch [10/100], Train Loss: 0.5728





Epoch [10/100], Val Loss: 0.3352


100%|██████████| 7/7 [00:00<00:00, 92.50it/s]

Epoch [11/100], Train Loss: 0.5465





Epoch [11/100], Val Loss: 0.3197


100%|██████████| 7/7 [00:00<00:00, 47.15it/s]

Epoch [12/100], Train Loss: 0.5208





Epoch [12/100], Val Loss: 0.3067


100%|██████████| 7/7 [00:00<00:00, 54.82it/s]

Epoch [13/100], Train Loss: 0.4996





Epoch [13/100], Val Loss: 0.2949


100%|██████████| 7/7 [00:00<00:00, 84.81it/s]

Epoch [14/100], Train Loss: 0.4786





Epoch [14/100], Val Loss: 0.2846


100%|██████████| 7/7 [00:00<00:00, 49.20it/s]

Epoch [15/100], Train Loss: 0.4638





Epoch [15/100], Val Loss: 0.2762


100%|██████████| 7/7 [00:00<00:00, 43.51it/s]

Epoch [16/100], Train Loss: 0.4453





Epoch [16/100], Val Loss: 0.2669


100%|██████████| 7/7 [00:00<00:00, 50.34it/s]

Epoch [17/100], Train Loss: 0.4313





Epoch [17/100], Val Loss: 0.2608


100%|██████████| 7/7 [00:00<00:00, 85.02it/s]

Epoch [18/100], Train Loss: 0.4129





Epoch [18/100], Val Loss: 0.2529


100%|██████████| 7/7 [00:00<00:00, 108.91it/s]

Epoch [19/100], Train Loss: 0.4097





Epoch [19/100], Val Loss: 0.2469


100%|██████████| 7/7 [00:00<00:00, 42.51it/s]

Epoch [20/100], Train Loss: 0.3874





Epoch [20/100], Val Loss: 0.2406


100%|██████████| 7/7 [00:00<00:00, 52.08it/s]

Epoch [21/100], Train Loss: 0.3761





Epoch [21/100], Val Loss: 0.2351


100%|██████████| 7/7 [00:00<00:00, 85.70it/s]

Epoch [22/100], Train Loss: 0.3639





Epoch [22/100], Val Loss: 0.2296


100%|██████████| 7/7 [00:00<00:00, 98.01it/s]

Epoch [23/100], Train Loss: 0.3498





Epoch [23/100], Val Loss: 0.2253


100%|██████████| 7/7 [00:00<00:00, 102.05it/s]

Epoch [24/100], Train Loss: 0.3394





Epoch [24/100], Val Loss: 0.2211


100%|██████████| 7/7 [00:00<00:00, 105.74it/s]

Epoch [25/100], Train Loss: 0.3338





Epoch [25/100], Val Loss: 0.2167


100%|██████████| 7/7 [00:00<00:00, 70.96it/s]

Epoch [26/100], Train Loss: 0.3246





Epoch [26/100], Val Loss: 0.2135


100%|██████████| 7/7 [00:00<00:00, 91.17it/s]

Epoch [27/100], Train Loss: 0.3154





Epoch [27/100], Val Loss: 0.2099


100%|██████████| 7/7 [00:00<00:00, 97.50it/s]

Epoch [28/100], Train Loss: 0.3065





Epoch [28/100], Val Loss: 0.2072


100%|██████████| 7/7 [00:00<00:00, 72.32it/s]

Epoch [29/100], Train Loss: 0.2941





Epoch [29/100], Val Loss: 0.2034


100%|██████████| 7/7 [00:00<00:00, 24.63it/s]


Epoch [30/100], Train Loss: 0.2896
Epoch [30/100], Val Loss: 0.2001


100%|██████████| 7/7 [00:00<00:00, 95.27it/s]

Epoch [31/100], Train Loss: 0.2815





Epoch [31/100], Val Loss: 0.1970


100%|██████████| 7/7 [00:00<00:00, 87.62it/s]

Epoch [32/100], Train Loss: 0.2776





Epoch [32/100], Val Loss: 0.1946


100%|██████████| 7/7 [00:00<00:00, 62.50it/s]

Epoch [33/100], Train Loss: 0.2706





Epoch [33/100], Val Loss: 0.1918


100%|██████████| 7/7 [00:00<00:00, 50.11it/s]

Epoch [34/100], Train Loss: 0.2637





Epoch [34/100], Val Loss: 0.1887


100%|██████████| 7/7 [00:00<00:00, 73.68it/s]

Epoch [35/100], Train Loss: 0.2585





Epoch [35/100], Val Loss: 0.1869


100%|██████████| 7/7 [00:00<00:00, 93.12it/s]

Epoch [36/100], Train Loss: 0.2501





Epoch [36/100], Val Loss: 0.1845


100%|██████████| 7/7 [00:00<00:00, 17.43it/s]


Epoch [37/100], Train Loss: 0.2480
Epoch [37/100], Val Loss: 0.1820


100%|██████████| 7/7 [00:00<00:00, 43.31it/s]

Epoch [38/100], Train Loss: 0.2422





Epoch [38/100], Val Loss: 0.1805


100%|██████████| 7/7 [00:00<00:00, 98.94it/s]

Epoch [39/100], Train Loss: 0.2330





Epoch [39/100], Val Loss: 0.1777


100%|██████████| 7/7 [00:00<00:00, 81.24it/s]

Epoch [40/100], Train Loss: 0.2307





Epoch [40/100], Val Loss: 0.1762


100%|██████████| 7/7 [00:00<00:00, 64.09it/s]

Epoch [41/100], Train Loss: 0.2264





Epoch [41/100], Val Loss: 0.1736


100%|██████████| 7/7 [00:00<00:00, 122.29it/s]

Epoch [42/100], Train Loss: 0.2200





Epoch [42/100], Val Loss: 0.1728


100%|██████████| 7/7 [00:00<00:00, 68.51it/s]

Epoch [43/100], Train Loss: 0.2171





Epoch [43/100], Val Loss: 0.1703


100%|██████████| 7/7 [00:00<00:00, 48.89it/s]

Epoch [44/100], Train Loss: 0.2128





Epoch [44/100], Val Loss: 0.1685


100%|██████████| 7/7 [00:00<00:00, 97.41it/s]

Epoch [45/100], Train Loss: 0.2103





Epoch [45/100], Val Loss: 0.1678


100%|██████████| 7/7 [00:00<00:00, 98.34it/s]

Epoch [46/100], Train Loss: 0.2090





Epoch [46/100], Val Loss: 0.1656


100%|██████████| 7/7 [00:00<00:00, 79.78it/s]

Epoch [47/100], Train Loss: 0.2030





Epoch [47/100], Val Loss: 0.1642


100%|██████████| 7/7 [00:00<00:00, 41.29it/s]

Epoch [48/100], Train Loss: 0.1993





Epoch [48/100], Val Loss: 0.1618


100%|██████████| 7/7 [00:00<00:00, 74.37it/s]

Epoch [49/100], Train Loss: 0.1967





Epoch [49/100], Val Loss: 0.1607


100%|██████████| 7/7 [00:00<00:00, 103.80it/s]

Epoch [50/100], Train Loss: 0.1937





Epoch [50/100], Val Loss: 0.1596


100%|██████████| 7/7 [00:00<00:00, 41.46it/s]

Epoch [51/100], Train Loss: 0.1911





Epoch [51/100], Val Loss: 0.1586


100%|██████████| 7/7 [00:00<00:00, 61.95it/s]

Epoch [52/100], Train Loss: 0.1893





Epoch [52/100], Val Loss: 0.1571


100%|██████████| 7/7 [00:00<00:00, 87.80it/s]

Epoch [53/100], Train Loss: 0.1866





Epoch [53/100], Val Loss: 0.1559


100%|██████████| 7/7 [00:00<00:00, 73.50it/s]

Epoch [54/100], Train Loss: 0.1824





Epoch [54/100], Val Loss: 0.1542


100%|██████████| 7/7 [00:00<00:00, 56.80it/s]

Epoch [55/100], Train Loss: 0.1804





Epoch [55/100], Val Loss: 0.1528


100%|██████████| 7/7 [00:00<00:00, 56.47it/s]

Epoch [56/100], Train Loss: 0.1773





Epoch [56/100], Val Loss: 0.1516


100%|██████████| 7/7 [00:00<00:00, 95.29it/s]

Epoch [57/100], Train Loss: 0.1750





Epoch [57/100], Val Loss: 0.1504


100%|██████████| 7/7 [00:00<00:00, 86.45it/s]

Epoch [58/100], Train Loss: 0.1731





Epoch [58/100], Val Loss: 0.1499


100%|██████████| 7/7 [00:00<00:00, 76.17it/s]

Epoch [59/100], Train Loss: 0.1711





Epoch [59/100], Val Loss: 0.1482


100%|██████████| 7/7 [00:00<00:00, 98.26it/s]

Epoch [60/100], Train Loss: 0.1705





Epoch [60/100], Val Loss: 0.1477


100%|██████████| 7/7 [00:00<00:00, 98.19it/s]

Epoch [61/100], Train Loss: 0.1657





Epoch [61/100], Val Loss: 0.1460


100%|██████████| 7/7 [00:00<00:00, 27.45it/s]


Epoch [62/100], Train Loss: 0.1652
Epoch [62/100], Val Loss: 0.1456


100%|██████████| 7/7 [00:00<00:00, 61.51it/s]

Epoch [63/100], Train Loss: 0.1631





Epoch [63/100], Val Loss: 0.1441


100%|██████████| 7/7 [00:00<00:00, 100.66it/s]

Epoch [64/100], Train Loss: 0.1611





Epoch [64/100], Val Loss: 0.1432


100%|██████████| 7/7 [00:00<00:00, 73.12it/s]

Epoch [65/100], Train Loss: 0.1590





Epoch [65/100], Val Loss: 0.1424


100%|██████████| 7/7 [00:00<00:00, 80.18it/s]

Epoch [66/100], Train Loss: 0.1557





Epoch [66/100], Val Loss: 0.1414


100%|██████████| 7/7 [00:00<00:00, 56.47it/s]

Epoch [67/100], Train Loss: 0.1539





Epoch [67/100], Val Loss: 0.1403


100%|██████████| 7/7 [00:00<00:00, 91.30it/s]

Epoch [68/100], Train Loss: 0.1548





Epoch [68/100], Val Loss: 0.1393


100%|██████████| 7/7 [00:00<00:00, 72.58it/s]

Epoch [69/100], Train Loss: 0.1513





Epoch [69/100], Val Loss: 0.1391


100%|██████████| 7/7 [00:00<00:00, 53.06it/s]

Epoch [70/100], Train Loss: 0.1499





Epoch [70/100], Val Loss: 0.1379


100%|██████████| 7/7 [00:00<00:00, 71.71it/s]

Epoch [71/100], Train Loss: 0.1500





Epoch [71/100], Val Loss: 0.1371


100%|██████████| 7/7 [00:00<00:00, 97.33it/s]

Epoch [72/100], Train Loss: 0.1494





Epoch [72/100], Val Loss: 0.1360


100%|██████████| 7/7 [00:00<00:00, 68.84it/s]

Epoch [73/100], Train Loss: 0.1454





Epoch [73/100], Val Loss: 0.1354


100%|██████████| 7/7 [00:00<00:00, 59.36it/s]

Epoch [74/100], Train Loss: 0.1460





Epoch [74/100], Val Loss: 0.1346


100%|██████████| 7/7 [00:00<00:00, 101.66it/s]

Epoch [75/100], Train Loss: 0.1434





Epoch [75/100], Val Loss: 0.1337


100%|██████████| 7/7 [00:00<00:00, 92.69it/s]

Epoch [76/100], Train Loss: 0.1421





Epoch [76/100], Val Loss: 0.1326


100%|██████████| 7/7 [00:00<00:00, 41.19it/s]

Epoch [77/100], Train Loss: 0.1401





Epoch [77/100], Val Loss: 0.1319


100%|██████████| 7/7 [00:00<00:00, 72.44it/s]

Epoch [78/100], Train Loss: 0.1404





Epoch [78/100], Val Loss: 0.1309


100%|██████████| 7/7 [00:00<00:00, 106.26it/s]

Epoch [79/100], Train Loss: 0.1386





Epoch [79/100], Val Loss: 0.1304


100%|██████████| 7/7 [00:00<00:00, 92.46it/s]

Epoch [80/100], Train Loss: 0.1381





Epoch [80/100], Val Loss: 0.1299


100%|██████████| 7/7 [00:00<00:00, 38.26it/s]

Epoch [81/100], Train Loss: 0.1354





Epoch [81/100], Val Loss: 0.1293


100%|██████████| 7/7 [00:00<00:00, 70.31it/s]

Epoch [82/100], Train Loss: 0.1343





Epoch [82/100], Val Loss: 0.1284


100%|██████████| 7/7 [00:00<00:00, 94.13it/s]

Epoch [83/100], Train Loss: 0.1330





Epoch [83/100], Val Loss: 0.1278


100%|██████████| 7/7 [00:00<00:00, 56.87it/s]

Epoch [84/100], Train Loss: 0.1341





Epoch [84/100], Val Loss: 0.1269


100%|██████████| 7/7 [00:00<00:00, 75.19it/s]

Epoch [85/100], Train Loss: 0.1353





Epoch [85/100], Val Loss: 0.1269


100%|██████████| 7/7 [00:00<00:00, 68.57it/s]

Epoch [86/100], Train Loss: 0.1309





Epoch [86/100], Val Loss: 0.1260


100%|██████████| 7/7 [00:00<00:00, 71.05it/s]

Epoch [87/100], Train Loss: 0.1309





Epoch [87/100], Val Loss: 0.1260


100%|██████████| 7/7 [00:00<00:00, 98.02it/s]

Epoch [88/100], Train Loss: 0.1298





Epoch [88/100], Val Loss: 0.1246


100%|██████████| 7/7 [00:00<00:00, 89.20it/s]

Epoch [89/100], Train Loss: 0.1286





Epoch [89/100], Val Loss: 0.1241


100%|██████████| 7/7 [00:00<00:00, 99.83it/s]

Epoch [90/100], Train Loss: 0.1272





Epoch [90/100], Val Loss: 0.1237


100%|██████████| 7/7 [00:00<00:00, 98.28it/s]

Epoch [91/100], Train Loss: 0.1270





Epoch [91/100], Val Loss: 0.1228


100%|██████████| 7/7 [00:00<00:00, 68.04it/s]

Epoch [92/100], Train Loss: 0.1260





Epoch [92/100], Val Loss: 0.1217


100%|██████████| 7/7 [00:00<00:00, 71.86it/s]

Epoch [93/100], Train Loss: 0.1253





Epoch [93/100], Val Loss: 0.1214


100%|██████████| 7/7 [00:00<00:00, 93.60it/s]

Epoch [94/100], Train Loss: 0.1246





Epoch [94/100], Val Loss: 0.1212


100%|██████████| 7/7 [00:00<00:00, 59.60it/s]

Epoch [95/100], Train Loss: 0.1248





Epoch [95/100], Val Loss: 0.1206


100%|██████████| 7/7 [00:00<00:00, 47.02it/s]

Epoch [96/100], Train Loss: 0.1227





Epoch [96/100], Val Loss: 0.1197


100%|██████████| 7/7 [00:00<00:00, 57.55it/s]

Epoch [97/100], Train Loss: 0.1222





Epoch [97/100], Val Loss: 0.1194


100%|██████████| 7/7 [00:00<00:00, 98.00it/s]

Epoch [98/100], Train Loss: 0.1210





Epoch [98/100], Val Loss: 0.1186


100%|██████████| 7/7 [00:00<00:00, 56.85it/s]

Epoch [99/100], Train Loss: 0.1194





Epoch [99/100], Val Loss: 0.1178


100%|██████████| 7/7 [00:00<00:00, 38.74it/s]

Epoch [100/100], Train Loss: 0.1199





Epoch [100/100], Val Loss: 0.1177


In [65]:
mlp = MLP(384, 32, 384)
empty_mpl = MLP(384, 32, 384)

optim = torch.optim.Adam(lr=lr, params=mlp.parameters())

mlp_last_net, mlp_best_net = train(epochs, optim, criterion, l1_reg, train_loader, test_loader, mlp, empty_mpl, reg_lambda=0.0)

RuntimeError: mat1 and mat2 must have the same dtype, but got Double and Float

In [None]:
test_loss_lin = validate(criterion, test_loader, lin_last_net)
print(f"lin test: {test_loss_lin}")

test_loss_mlp = validate(criterion, test_loader, mlp_last_net)
print(f"lin test: {test_loss_mlp}")
