In [1]:
import os
import glob
import torch
import torch.nn as nn

from tqdm import tqdm
from torch.optim import SGD
from sklearn.datasets import make_moons
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split

In [2]:
class LogisticRegression(nn.Module):
    """Binary logistic-regression classifier: p(y = 1 | x) = sigmoid(x . w + b).

    Args:
        n_features: number of input features (width of the last input axis).
            Defaults to 2, the width of the two-moons data used in this notebook.

    ``forward`` takes ``x`` of shape ``(batch, n_features)`` and returns
    probabilities of shape ``(batch, 1)``.
    """

    def __init__(self, n_features: int = 2):
        super().__init__()
        self.coef = nn.Parameter(torch.randn((1, n_features)))
        self.bias = nn.Parameter(torch.randn(1))

    def forward(self, x):
        # Elementwise product + sum (instead of a matmul) lets float64 inputs mix with
        # the float32 parameters through type promotion; matmul requires equal dtypes.
        logits = self.bias + torch.sum(self.coef * x, dim = 1, keepdim = True)
        # torch.sigmoid is the built-in, numerically stable form of 1 / (1 + exp(-z)).
        return torch.sigmoid(logits)

In [3]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over in-memory numpy feature and label arrays.

    The base class is referenced by its fully qualified name. This class reuses
    the name ``Dataset``, so ``class Dataset(Dataset)`` would make a re-run of this
    cell inherit from the previous run's class instead of torch's base class.

    Args:
        X: numpy array of features, indexed along axis 0.
        y: numpy array of labels, one entry per row of ``X``.
    """

    def __init__(self, X, y):
        # torch.from_numpy shares memory with the numpy arrays (no copy).
        self.X = torch.from_numpy(X)
        self.y = torch.from_numpy(y)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        # Labels get a trailing axis and float64 dtype so that batches line up with the
        # (batch, 1) probabilities produced by the model and can be passed to BCELoss.
        return self.X[idx], self.y[idx].unsqueeze(dim = -1).to(torch.float64)

In [4]:
class Saver:
    """Saves and restores model ``state_dict`` checkpoints in a directory.

    Periodic checkpoints are written as ``<model_dir>/<model_name>_<epoch>.pth``.
    Only the newest one is kept. The final model is written as
    ``<model_dir>/<model_name>.pth``, and saving it removes every epoch checkpoint.
    As a result, a finished run leaves nothing to resume, and the next call to
    ``start_or_resume_training`` starts again from epoch 1.
    Optimizer state is not saved.

    Args:
        model_name: base file name for the checkpoints.
        model_dir: directory holding the checkpoints (created if missing).
    """

    def __init__(self, model_name, model_dir):
        self.model_base_name = model_name
        self.model_dir = model_dir

        os.makedirs(self.model_dir, exist_ok = True)

    def _epoch_checkpoints(self):
        """Return ``(epoch, path)`` pairs for all epoch checkpoints, sorted by epoch number."""
        pattern = os.path.join(glob.escape(self.model_dir),
                               glob.escape(self.model_base_name) + '_*.pth')
        checkpoints = []
        for path in glob.glob(pattern):
            suffix = os.path.splitext(path)[0].rsplit('_', 1)[-1]
            # Skip files that merely share the prefix (e.g. "<name>_best.pth").
            if suffix.isdigit():
                checkpoints.append((int(suffix), path))
        # Sort numerically: lexicographic order would put "_10" before "_2".
        return sorted(checkpoints)

    def start_or_resume_training(self, model):
        """Load the newest epoch checkpoint into ``model`` if one exists.

        Returns:
            ``(first_epoch_to_run, model)``. The first value is 1 when there is no
            checkpoint; otherwise it is the epoch after the last saved (completed) one.
        """
        checkpoints = self._epoch_checkpoints()
        if not checkpoints:
            return 1, model
        last_model_epoch, last_model_path = checkpoints[-1]
        next_epoch = last_model_epoch + 1  # the saved epoch already finished training
        print(f'Resuming training from epoch {next_epoch}...')
        model.load_state_dict(torch.load(last_model_path))
        return next_epoch, model

    def save_model(self, model, epoch = None):
        """Save ``model.state_dict()``, then delete every other epoch checkpoint.

        Args:
            model: module whose ``state_dict`` is saved.
            epoch: epoch number appended to the file name. ``None`` saves the final,
                epoch-less model file.
        """
        model_name = self.model_base_name
        if epoch is not None:
            model_name += f'_{epoch}'
        model_path = os.path.join(self.model_dir, model_name + '.pth')

        # Write the new file before deleting old ones so a failed save never leaves
        # the directory without a usable checkpoint.
        torch.save(model.state_dict(), model_path)

        for _, previous_model_path in self._epoch_checkpoints():
            if os.path.abspath(previous_model_path) != os.path.abspath(model_path):
                os.remove(previous_model_path)
        

In [5]:
# Two interleaving half-circles with 10000 points, split 80/20 into train and test.
# Fixed seeds make the sampling and the split, and therefore the reported metrics,
# reproducible under "Restart & Run All".
X, y = make_moons(n_samples = 10000, random_state = 42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = .2, random_state = 42)

In [6]:
# Wrap the numpy splits as torch datasets.
train_dataset, test_dataset = Dataset(X_train, y_train), Dataset(X_test, y_test)

# Training batches are shuffled, 32 samples each; the test loader yields one
# sample per batch, in a fixed order.
train_dataloader = DataLoader(dataset = train_dataset, batch_size = 32, shuffle = True)
test_dataloader  = DataLoader(dataset = test_dataset, batch_size = 1, shuffle = False)

In [7]:
# Seed torch so the random parameter initialisation is reproducible.
torch.manual_seed(42)

model = LogisticRegression()
# Explicit learning rate. Older PyTorch releases have no default lr for SGD and raise
# without one; 1e-3 is the default used by releases that do provide one.
optimizer = SGD(model.parameters(), lr = 1e-3)
criterion = nn.BCELoss()
writer = SummaryWriter()
saver = Saver(model_name = 'model', model_dir = 'models')
# Flag checked by the training loop so the model graph is logged only once.
graph_logged = False

In [8]:
n_epochs = 250
save_interval = 10
initial_epoch, model = saver.start_or_resume_training(model)

model.train()
for epoch in range(initial_epoch, n_epochs + 1):
    total_loss = 0.0
    # leave=False clears each epoch's bar when it finishes, so the cell output keeps
    # one summary line per epoch instead of a progress bar per epoch.
    batches = tqdm(train_dataloader, desc = f'Epoch {epoch}/{n_epochs}', leave = False)
    for i, (x, y_true) in enumerate(batches):
        optimizer.zero_grad()
        y_pred = model(x)
        if not graph_logged:
            # Trace and log the model graph once, on the first batch.
            writer.add_graph(model, x)
            graph_logged = True
        loss = criterion(y_pred, y_true)
        loss_value = loss.item()
        # Global step counts batches across epochs, so a resumed run continues the curve.
        step = (epoch - 1) * len(train_dataloader) + i
        writer.add_scalar("Loss/batch_train_loss", loss_value, step)
        # nn.BCELoss with its default reduction returns the batch mean. Weight it by the
        # batch size so a smaller final batch does not skew the epoch average.
        total_loss += loss_value * x.size(0)
        loss.backward()
        optimizer.step()

    # Per-sample mean over the whole training set.
    epoch_loss = total_loss / len(train_dataloader.dataset)
    writer.add_scalar("Loss/epoch_train_loss", epoch_loss, epoch)
    print(f'Epoch {epoch}/{n_epochs} - training loss: {epoch_loss}')

    if epoch % save_interval == 0:
        saver.save_model(model, epoch)

# Save the final model under the epoch-less file name.
saver.save_model(model)



Epoch 1/250


  0%|          | 0/250 [00:00<?, ?it/s]

100%|██████████| 250/250 [00:00<00:00, 269.37it/s]


Training loss: 1.352207160867811
Epoch 2/250


100%|██████████| 250/250 [00:00<00:00, 467.06it/s]


Training loss: 1.25716627093661
Epoch 3/250


100%|██████████| 250/250 [00:00<00:00, 444.72it/s]


Training loss: 1.1670515939172716
Epoch 4/250


100%|██████████| 250/250 [00:00<00:00, 503.35it/s]


Training loss: 1.0824220881076423
Epoch 5/250


100%|██████████| 250/250 [00:00<00:00, 514.41it/s]


Training loss: 1.0038124996654996
Epoch 6/250


100%|██████████| 250/250 [00:00<00:00, 473.94it/s]


Training loss: 0.9316385847530537
Epoch 7/250


100%|██████████| 250/250 [00:00<00:00, 465.36it/s]


Training loss: 0.8661686612322135
Epoch 8/250


100%|██████████| 250/250 [00:00<00:00, 436.96it/s]


Training loss: 0.8074634491597085
Epoch 9/250


100%|██████████| 250/250 [00:00<00:00, 440.52it/s]


Training loss: 0.7554050334514767
Epoch 10/250


100%|██████████| 250/250 [00:00<00:00, 455.04it/s]


Training loss: 0.7096790408431513
Epoch 11/250


100%|██████████| 250/250 [00:00<00:00, 502.82it/s]


Training loss: 0.6698323758045851
Epoch 12/250


100%|██████████| 250/250 [00:00<00:00, 510.46it/s]


Training loss: 0.6352990818243068
Epoch 13/250


100%|██████████| 250/250 [00:00<00:00, 525.22it/s]


Training loss: 0.6054727036277687
Epoch 14/250


100%|██████████| 250/250 [00:00<00:00, 444.33it/s]


Training loss: 0.5797479327827499
Epoch 15/250


100%|██████████| 250/250 [00:00<00:00, 514.76it/s]


Training loss: 0.5575383275721101
Epoch 16/250


100%|██████████| 250/250 [00:00<00:00, 521.83it/s]


Training loss: 0.538317019006377
Epoch 17/250


100%|██████████| 250/250 [00:00<00:00, 517.09it/s]


Training loss: 0.521617920439712
Epoch 18/250


100%|██████████| 250/250 [00:00<00:00, 502.19it/s]


Training loss: 0.5070387344296101
Epoch 19/250


100%|██████████| 250/250 [00:00<00:00, 506.19it/s]


Training loss: 0.4942446939245492
Epoch 20/250


100%|██████████| 250/250 [00:00<00:00, 408.21it/s]


Training loss: 0.4829453778094561
Epoch 21/250


100%|██████████| 250/250 [00:00<00:00, 495.03it/s]


Training loss: 0.47290901039230593
Epoch 22/250


100%|██████████| 250/250 [00:00<00:00, 475.19it/s]


Training loss: 0.4639459349215582
Epoch 23/250


100%|██████████| 250/250 [00:00<00:00, 503.08it/s]


Training loss: 0.4558865468647544
Epoch 24/250


100%|██████████| 250/250 [00:00<00:00, 494.26it/s]


Training loss: 0.4485986172265845
Epoch 25/250


100%|██████████| 250/250 [00:00<00:00, 517.51it/s]


Training loss: 0.441974126897996
Epoch 26/250


100%|██████████| 250/250 [00:00<00:00, 518.56it/s]


Training loss: 0.4359212100882389
Epoch 27/250


100%|██████████| 250/250 [00:00<00:00, 538.73it/s]


Training loss: 0.43036198640271245
Epoch 28/250


100%|██████████| 250/250 [00:00<00:00, 478.00it/s]


Training loss: 0.4252298848398536
Epoch 29/250


100%|██████████| 250/250 [00:00<00:00, 535.43it/s]


Training loss: 0.4204707558823628
Epoch 30/250


100%|██████████| 250/250 [00:00<00:00, 520.51it/s]


Training loss: 0.4160422486570173
Epoch 31/250


100%|██████████| 250/250 [00:00<00:00, 430.54it/s]


Training loss: 0.41190530011092735
Epoch 32/250


100%|██████████| 250/250 [00:00<00:00, 501.98it/s]


Training loss: 0.40802899256675
Epoch 33/250


100%|██████████| 250/250 [00:00<00:00, 534.62it/s]


Training loss: 0.40438397628291356
Epoch 34/250


100%|██████████| 250/250 [00:00<00:00, 536.85it/s]


Training loss: 0.400944904324424
Epoch 35/250


100%|██████████| 250/250 [00:00<00:00, 522.05it/s]


Training loss: 0.39769203663924413
Epoch 36/250


100%|██████████| 250/250 [00:00<00:00, 510.18it/s]


Training loss: 0.39460842016464537
Epoch 37/250


100%|██████████| 250/250 [00:00<00:00, 518.95it/s]


Training loss: 0.39167959184911455
Epoch 38/250


100%|██████████| 250/250 [00:00<00:00, 544.56it/s]


Training loss: 0.3888911990751127
Epoch 39/250


100%|██████████| 250/250 [00:00<00:00, 529.74it/s]


Training loss: 0.3862300290984538
Epoch 40/250


100%|██████████| 250/250 [00:00<00:00, 540.52it/s]


Training loss: 0.38368775813068423
Epoch 41/250


100%|██████████| 250/250 [00:00<00:00, 509.84it/s]


Training loss: 0.3812543321260343
Epoch 42/250


100%|██████████| 250/250 [00:00<00:00, 438.67it/s]


Training loss: 0.3789209661981547
Epoch 43/250


100%|██████████| 250/250 [00:00<00:00, 474.66it/s]


Training loss: 0.3766805644954065
Epoch 44/250


100%|██████████| 250/250 [00:00<00:00, 523.73it/s]


Training loss: 0.3745255349127061
Epoch 45/250


100%|██████████| 250/250 [00:00<00:00, 514.78it/s]


Training loss: 0.37245155799640295
Epoch 46/250


100%|██████████| 250/250 [00:00<00:00, 534.23it/s]


Training loss: 0.37045305733532313
Epoch 47/250


100%|██████████| 250/250 [00:00<00:00, 531.72it/s]


Training loss: 0.36852471446942353
Epoch 48/250


100%|██████████| 250/250 [00:00<00:00, 528.54it/s]


Training loss: 0.36666229530782796
Epoch 49/250


100%|██████████| 250/250 [00:00<00:00, 535.77it/s]


Training loss: 0.36486262515192086
Epoch 50/250


100%|██████████| 250/250 [00:00<00:00, 542.99it/s]


Training loss: 0.3631212315866632
Epoch 51/250


100%|██████████| 250/250 [00:00<00:00, 519.15it/s]


Training loss: 0.3614347854353062
Epoch 52/250


100%|██████████| 250/250 [00:00<00:00, 523.43it/s]


Training loss: 0.35980019975419153
Epoch 53/250


100%|██████████| 250/250 [00:00<00:00, 422.46it/s]


Training loss: 0.3582149678818575
Epoch 54/250


100%|██████████| 250/250 [00:00<00:00, 516.28it/s]


Training loss: 0.35667590426072543
Epoch 55/250


100%|██████████| 250/250 [00:00<00:00, 526.52it/s]


Training loss: 0.355181143895597
Epoch 56/250


100%|██████████| 250/250 [00:00<00:00, 531.81it/s]


Training loss: 0.3537289112857226
Epoch 57/250


100%|██████████| 250/250 [00:00<00:00, 438.64it/s]


Training loss: 0.35231693234085354
Epoch 58/250


100%|██████████| 250/250 [00:00<00:00, 351.88it/s]


Training loss: 0.35094230480579036
Epoch 59/250


100%|██████████| 250/250 [00:00<00:00, 503.34it/s]


Training loss: 0.3496041940441011
Epoch 60/250


100%|██████████| 250/250 [00:00<00:00, 522.26it/s]


Training loss: 0.3483008801867059
Epoch 61/250


100%|██████████| 250/250 [00:00<00:00, 523.12it/s]


Training loss: 0.3470303818355659
Epoch 62/250


100%|██████████| 250/250 [00:00<00:00, 511.25it/s]


Training loss: 0.345792264063708
Epoch 63/250


100%|██████████| 250/250 [00:00<00:00, 451.99it/s]


Training loss: 0.34458372727247416
Epoch 64/250


100%|██████████| 250/250 [00:00<00:00, 480.53it/s]


Training loss: 0.34340463043071234
Epoch 65/250


100%|██████████| 250/250 [00:00<00:00, 491.29it/s]


Training loss: 0.3422528406516207
Epoch 66/250


100%|██████████| 250/250 [00:00<00:00, 538.70it/s]


Training loss: 0.3411281437968993
Epoch 67/250


100%|██████████| 250/250 [00:00<00:00, 525.04it/s]


Training loss: 0.3400290681541737
Epoch 68/250


100%|██████████| 250/250 [00:00<00:00, 530.56it/s]


Training loss: 0.33895413743277075
Epoch 69/250


100%|██████████| 250/250 [00:00<00:00, 495.01it/s]


Training loss: 0.3379031593447289
Epoch 70/250


100%|██████████| 250/250 [00:00<00:00, 537.47it/s]


Training loss: 0.3368746465097406
Epoch 71/250


100%|██████████| 250/250 [00:00<00:00, 503.68it/s]


Training loss: 0.3358683544056014
Epoch 72/250


100%|██████████| 250/250 [00:00<00:00, 522.41it/s]


Training loss: 0.33488310468204846
Epoch 73/250


100%|██████████| 250/250 [00:00<00:00, 531.96it/s]


Training loss: 0.33391854326854464
Epoch 74/250


100%|██████████| 250/250 [00:00<00:00, 460.99it/s]


Training loss: 0.3329728026529461
Epoch 75/250


100%|██████████| 250/250 [00:00<00:00, 506.87it/s]


Training loss: 0.3320465347441773
Epoch 76/250


100%|██████████| 250/250 [00:00<00:00, 528.20it/s]


Training loss: 0.3311380225122565
Epoch 77/250


100%|██████████| 250/250 [00:00<00:00, 514.54it/s]


Training loss: 0.3302478598857016
Epoch 78/250


100%|██████████| 250/250 [00:00<00:00, 526.25it/s]


Training loss: 0.32937437363248406
Epoch 79/250


100%|██████████| 250/250 [00:00<00:00, 529.12it/s]


Training loss: 0.3285173296386319
Epoch 80/250


100%|██████████| 250/250 [00:00<00:00, 534.41it/s]


Training loss: 0.32767640273819076
Epoch 81/250


100%|██████████| 250/250 [00:00<00:00, 506.71it/s]


Training loss: 0.32685110143641843
Epoch 82/250


100%|██████████| 250/250 [00:00<00:00, 540.83it/s]


Training loss: 0.3260402421013843
Epoch 83/250


100%|██████████| 250/250 [00:00<00:00, 523.45it/s]


Training loss: 0.32524418135340055
Epoch 84/250


100%|██████████| 250/250 [00:00<00:00, 546.33it/s]


Training loss: 0.32446197582124936
Epoch 85/250


100%|██████████| 250/250 [00:00<00:00, 444.01it/s]


Training loss: 0.323693328636937
Epoch 86/250


100%|██████████| 250/250 [00:00<00:00, 514.50it/s]


Training loss: 0.3229384589966011
Epoch 87/250


100%|██████████| 250/250 [00:00<00:00, 517.67it/s]


Training loss: 0.32219588317020537
Epoch 88/250


100%|██████████| 250/250 [00:00<00:00, 516.33it/s]


Training loss: 0.32146524815996
Epoch 89/250


100%|██████████| 250/250 [00:00<00:00, 519.97it/s]


Training loss: 0.32074733279082035
Epoch 90/250


100%|██████████| 250/250 [00:00<00:00, 511.91it/s]


Training loss: 0.3200408496342942
Epoch 91/250


100%|██████████| 250/250 [00:00<00:00, 527.82it/s]


Training loss: 0.3193460795506056
Epoch 92/250


100%|██████████| 250/250 [00:00<00:00, 603.48it/s]


Training loss: 0.31866247001141657
Epoch 93/250


100%|██████████| 250/250 [00:00<00:00, 531.71it/s]


Training loss: 0.31798893425186897
Epoch 94/250


100%|██████████| 250/250 [00:00<00:00, 574.71it/s]


Training loss: 0.31732620364152997
Epoch 95/250


100%|██████████| 250/250 [00:00<00:00, 586.00it/s]


Training loss: 0.31667423685706253
Epoch 96/250


100%|██████████| 250/250 [00:00<00:00, 455.05it/s]


Training loss: 0.31603150372122263
Epoch 97/250


100%|██████████| 250/250 [00:00<00:00, 491.86it/s]


Training loss: 0.31539826502312557
Epoch 98/250


100%|██████████| 250/250 [00:00<00:00, 527.86it/s]


Training loss: 0.31477506956557766
Epoch 99/250


100%|██████████| 250/250 [00:00<00:00, 521.11it/s]


Training loss: 0.3141605615613297
Epoch 100/250


100%|██████████| 250/250 [00:00<00:00, 534.20it/s]


Training loss: 0.31355515348705215
Epoch 101/250


100%|██████████| 250/250 [00:00<00:00, 522.88it/s]


Training loss: 0.31295869049639086
Epoch 102/250


100%|██████████| 250/250 [00:00<00:00, 510.20it/s]


Training loss: 0.312370701109213
Epoch 103/250


100%|██████████| 250/250 [00:00<00:00, 516.48it/s]


Training loss: 0.3117906006567276
Epoch 104/250


100%|██████████| 250/250 [00:00<00:00, 511.13it/s]


Training loss: 0.3112187728306436
Epoch 105/250


100%|██████████| 250/250 [00:00<00:00, 504.61it/s]


Training loss: 0.31065473492290824
Epoch 106/250


100%|██████████| 250/250 [00:00<00:00, 502.02it/s]


Training loss: 0.31009878390440887
Epoch 107/250


100%|██████████| 250/250 [00:00<00:00, 457.30it/s]


Training loss: 0.309550091348625
Epoch 108/250


100%|██████████| 250/250 [00:00<00:00, 519.43it/s]


Training loss: 0.3090087881747783
Epoch 109/250


100%|██████████| 250/250 [00:00<00:00, 520.58it/s]


Training loss: 0.30847458639217734
Epoch 110/250


100%|██████████| 250/250 [00:00<00:00, 531.21it/s]


Training loss: 0.30794780903199453
Epoch 111/250


100%|██████████| 250/250 [00:00<00:00, 490.24it/s]


Training loss: 0.307427635916816
Epoch 112/250


100%|██████████| 250/250 [00:00<00:00, 511.65it/s]


Training loss: 0.3069142968605076
Epoch 113/250


100%|██████████| 250/250 [00:00<00:00, 524.19it/s]


Training loss: 0.30640738490169767
Epoch 114/250


100%|██████████| 250/250 [00:00<00:00, 514.26it/s]


Training loss: 0.30590655031615277
Epoch 115/250


100%|██████████| 250/250 [00:00<00:00, 528.62it/s]


Training loss: 0.3054118983999687
Epoch 116/250


100%|██████████| 250/250 [00:00<00:00, 519.98it/s]


Training loss: 0.30492419488969164
Epoch 117/250


100%|██████████| 250/250 [00:00<00:00, 534.66it/s]


Training loss: 0.3044418687981087
Epoch 118/250


100%|██████████| 250/250 [00:00<00:00, 371.18it/s]


Training loss: 0.30396578308319144
Epoch 119/250


100%|██████████| 250/250 [00:00<00:00, 360.99it/s]


Training loss: 0.3034954761975484
Epoch 120/250


100%|██████████| 250/250 [00:00<00:00, 483.96it/s]


Training loss: 0.3030308716717934
Epoch 121/250


100%|██████████| 250/250 [00:00<00:00, 500.45it/s]


Training loss: 0.3025711198139084
Epoch 122/250


100%|██████████| 250/250 [00:00<00:00, 516.63it/s]


Training loss: 0.3021173081386194
Epoch 123/250


100%|██████████| 250/250 [00:00<00:00, 512.18it/s]


Training loss: 0.3016689691805261
Epoch 124/250


100%|██████████| 250/250 [00:00<00:00, 516.44it/s]


Training loss: 0.30122595640803174
Epoch 125/250


100%|██████████| 250/250 [00:00<00:00, 507.03it/s]


Training loss: 0.3007874272215196
Epoch 126/250


100%|██████████| 250/250 [00:00<00:00, 528.75it/s]


Training loss: 0.3003543507058815
Epoch 127/250


100%|██████████| 250/250 [00:00<00:00, 511.75it/s]


Training loss: 0.2999258985343182
Epoch 128/250


100%|██████████| 250/250 [00:00<00:00, 506.88it/s]


Training loss: 0.29950252367602326
Epoch 129/250


100%|██████████| 250/250 [00:00<00:00, 469.04it/s]


Training loss: 0.2990837061349133
Epoch 130/250


100%|██████████| 250/250 [00:00<00:00, 520.09it/s]


Training loss: 0.2986699844585359
Epoch 131/250


100%|██████████| 250/250 [00:00<00:00, 510.81it/s]


Training loss: 0.29826046762622493
Epoch 132/250


100%|██████████| 250/250 [00:00<00:00, 520.34it/s]


Training loss: 0.29785534156441834
Epoch 133/250


100%|██████████| 250/250 [00:00<00:00, 517.64it/s]


Training loss: 0.29745453876962813
Epoch 134/250


100%|██████████| 250/250 [00:00<00:00, 515.35it/s]


Training loss: 0.2970583663064144
Epoch 135/250


100%|██████████| 250/250 [00:00<00:00, 513.86it/s]


Training loss: 0.296666571474373
Epoch 136/250


100%|██████████| 250/250 [00:00<00:00, 511.17it/s]


Training loss: 0.2962787009975168
Epoch 137/250


100%|██████████| 250/250 [00:00<00:00, 467.73it/s]


Training loss: 0.29589490059462503
Epoch 138/250


100%|██████████| 250/250 [00:00<00:00, 494.27it/s]


Training loss: 0.29551523103154687
Epoch 139/250


100%|██████████| 250/250 [00:00<00:00, 374.58it/s]


Training loss: 0.2951390808794096
Epoch 140/250


100%|██████████| 250/250 [00:00<00:00, 516.78it/s]


Training loss: 0.29476746195125425
Epoch 141/250


100%|██████████| 250/250 [00:00<00:00, 506.48it/s]


Training loss: 0.29439917268602794
Epoch 142/250


100%|██████████| 250/250 [00:00<00:00, 501.55it/s]


Training loss: 0.2940347501978601
Epoch 143/250


100%|██████████| 250/250 [00:00<00:00, 530.92it/s]


Training loss: 0.2936740037465033
Epoch 144/250


100%|██████████| 250/250 [00:00<00:00, 496.09it/s]


Training loss: 0.2933172018334224
Epoch 145/250


100%|██████████| 250/250 [00:00<00:00, 457.92it/s]


Training loss: 0.2929635746938607
Epoch 146/250


100%|██████████| 250/250 [00:00<00:00, 469.74it/s]


Training loss: 0.29261367035937974
Epoch 147/250


100%|██████████| 250/250 [00:00<00:00, 473.00it/s]


Training loss: 0.2922674609583276
Epoch 148/250


100%|██████████| 250/250 [00:00<00:00, 477.07it/s]


Training loss: 0.2919246226340851
Epoch 149/250


100%|██████████| 250/250 [00:00<00:00, 469.26it/s]


Training loss: 0.29158495280689867
Epoch 150/250


100%|██████████| 250/250 [00:00<00:00, 503.17it/s]


Training loss: 0.29124842713054827
Epoch 151/250


100%|██████████| 250/250 [00:00<00:00, 454.47it/s]


Training loss: 0.2909153314178262
Epoch 152/250


100%|██████████| 250/250 [00:00<00:00, 457.09it/s]


Training loss: 0.29058560783458864
Epoch 153/250


100%|██████████| 250/250 [00:00<00:00, 514.91it/s]


Training loss: 0.2902583890272668
Epoch 154/250


100%|██████████| 250/250 [00:00<00:00, 413.38it/s]


Training loss: 0.2899351483895724
Epoch 155/250


100%|██████████| 250/250 [00:00<00:00, 429.76it/s]


Training loss: 0.28961454369342454
Epoch 156/250


100%|██████████| 250/250 [00:00<00:00, 378.12it/s]


Training loss: 0.28929726915324655
Epoch 157/250


100%|██████████| 250/250 [00:00<00:00, 296.12it/s]


Training loss: 0.28898223883555957
Epoch 158/250


100%|██████████| 250/250 [00:00<00:00, 265.58it/s]


Training loss: 0.28867065987495527
Epoch 159/250


100%|██████████| 250/250 [00:00<00:00, 444.17it/s]


Training loss: 0.2883618226538476
Epoch 160/250


100%|██████████| 250/250 [00:00<00:00, 504.07it/s]


Training loss: 0.28805617652831006
Epoch 161/250


100%|██████████| 250/250 [00:00<00:00, 515.34it/s]


Training loss: 0.28775319843794034
Epoch 162/250


100%|██████████| 250/250 [00:00<00:00, 527.14it/s]


Training loss: 0.28745278169976834
Epoch 163/250


100%|██████████| 250/250 [00:00<00:00, 478.25it/s]


Training loss: 0.28715505234230804
Epoch 164/250


100%|██████████| 250/250 [00:00<00:00, 497.60it/s]


Training loss: 0.2868599712619386
Epoch 165/250


100%|██████████| 250/250 [00:00<00:00, 520.36it/s]


Training loss: 0.2865677936606782
Epoch 166/250


100%|██████████| 250/250 [00:00<00:00, 518.10it/s]


Training loss: 0.286277987911565
Epoch 167/250


100%|██████████| 250/250 [00:00<00:00, 531.58it/s]


Training loss: 0.28599061964673167
Epoch 168/250


100%|██████████| 250/250 [00:00<00:00, 434.08it/s]


Training loss: 0.2857062215698092
Epoch 169/250


100%|██████████| 250/250 [00:00<00:00, 485.61it/s]


Training loss: 0.2854242706877607
Epoch 170/250


100%|██████████| 250/250 [00:00<00:00, 484.96it/s]


Training loss: 0.28514473030368037
Epoch 171/250


100%|██████████| 250/250 [00:00<00:00, 501.31it/s]


Training loss: 0.28486746917445016
Epoch 172/250


100%|██████████| 250/250 [00:00<00:00, 472.98it/s]


Training loss: 0.2845926054242224
Epoch 173/250


100%|██████████| 250/250 [00:00<00:00, 447.05it/s]


Training loss: 0.2843200542837689
Epoch 174/250


100%|██████████| 250/250 [00:00<00:00, 424.82it/s]


Training loss: 0.284050462094655
Epoch 175/250


100%|██████████| 250/250 [00:00<00:00, 343.02it/s]


Training loss: 0.283782724529628
Epoch 176/250


100%|██████████| 250/250 [00:00<00:00, 433.77it/s]


Training loss: 0.2835169582851343
Epoch 177/250


100%|██████████| 250/250 [00:00<00:00, 467.78it/s]


Training loss: 0.28325390914532883
Epoch 178/250


100%|██████████| 250/250 [00:00<00:00, 373.97it/s]


Training loss: 0.28299312273279814
Epoch 179/250


100%|██████████| 250/250 [00:00<00:00, 446.28it/s]


Training loss: 0.2827343458863923
Epoch 180/250


100%|██████████| 250/250 [00:00<00:00, 349.48it/s]


Training loss: 0.2824773441028821
Epoch 181/250


100%|██████████| 250/250 [00:00<00:00, 469.13it/s]


Training loss: 0.2822233270292647
Epoch 182/250


100%|██████████| 250/250 [00:00<00:00, 451.73it/s]


Training loss: 0.28197078158144906
Epoch 183/250


100%|██████████| 250/250 [00:00<00:00, 432.17it/s]


Training loss: 0.2817204766211582
Epoch 184/250


100%|██████████| 250/250 [00:00<00:00, 453.88it/s]


Training loss: 0.28147258134126324
Epoch 185/250


100%|██████████| 250/250 [00:00<00:00, 480.67it/s]


Training loss: 0.2812267222745068
Epoch 186/250


100%|██████████| 250/250 [00:00<00:00, 436.24it/s]


Training loss: 0.2809825590462115
Epoch 187/250


100%|██████████| 250/250 [00:00<00:00, 497.28it/s]


Training loss: 0.2807405165868852
Epoch 188/250


100%|██████████| 250/250 [00:00<00:00, 423.86it/s]


Training loss: 0.280500121673247
Epoch 189/250


100%|██████████| 250/250 [00:00<00:00, 488.12it/s]


Training loss: 0.28026251994150775
Epoch 190/250


100%|██████████| 250/250 [00:00<00:00, 493.70it/s]


Training loss: 0.2800257620567549
Epoch 191/250


100%|██████████| 250/250 [00:00<00:00, 464.82it/s]


Training loss: 0.2797913441580206
Epoch 192/250


100%|██████████| 250/250 [00:00<00:00, 500.39it/s]


Training loss: 0.27955913138294664
Epoch 193/250


100%|██████████| 250/250 [00:00<00:00, 493.39it/s]


Training loss: 0.279327973773017
Epoch 194/250


100%|██████████| 250/250 [00:00<00:00, 495.30it/s]


Training loss: 0.27909933526809666
Epoch 195/250


100%|██████████| 250/250 [00:00<00:00, 479.79it/s]


Training loss: 0.27887252994962497
Epoch 196/250


100%|██████████| 250/250 [00:00<00:00, 487.39it/s]


Training loss: 0.2786474234046775
Epoch 197/250


100%|██████████| 250/250 [00:00<00:00, 503.23it/s]


Training loss: 0.2784239319474637
Epoch 198/250


100%|██████████| 250/250 [00:00<00:00, 395.63it/s]


Training loss: 0.2782020722120022
Epoch 199/250


100%|██████████| 250/250 [00:00<00:00, 486.96it/s]


Training loss: 0.27798250371218947
Epoch 200/250


100%|██████████| 250/250 [00:00<00:00, 498.36it/s]


Training loss: 0.2777641224837689
Epoch 201/250


100%|██████████| 250/250 [00:00<00:00, 498.59it/s]


Training loss: 0.2775478291287325
Epoch 202/250


100%|██████████| 250/250 [00:00<00:00, 410.00it/s]


Training loss: 0.2773330489510995
Epoch 203/250


100%|██████████| 250/250 [00:00<00:00, 417.97it/s]


Training loss: 0.27711963504232406
Epoch 204/250


100%|██████████| 250/250 [00:00<00:00, 496.33it/s]


Training loss: 0.2769078816870593
Epoch 205/250


100%|██████████| 250/250 [00:00<00:00, 486.85it/s]


Training loss: 0.27669790119451987
Epoch 206/250


100%|██████████| 250/250 [00:00<00:00, 500.03it/s]


Training loss: 0.2764897623450907
Epoch 207/250


100%|██████████| 250/250 [00:00<00:00, 506.10it/s]


Training loss: 0.27628294172486706
Epoch 208/250


100%|██████████| 250/250 [00:00<00:00, 429.67it/s]


Training loss: 0.2760777576194329
Epoch 209/250


100%|██████████| 250/250 [00:00<00:00, 488.44it/s]


Training loss: 0.2758744041896835
Epoch 210/250


100%|██████████| 250/250 [00:00<00:00, 454.61it/s]


Training loss: 0.2756721138978818
Epoch 211/250


100%|██████████| 250/250 [00:00<00:00, 410.59it/s]


Training loss: 0.27547157228228053
Epoch 212/250


100%|██████████| 250/250 [00:00<00:00, 445.75it/s]


Training loss: 0.2752723450585845
Epoch 213/250


100%|██████████| 250/250 [00:00<00:00, 378.12it/s]


Training loss: 0.2750749172147076
Epoch 214/250


100%|██████████| 250/250 [00:00<00:00, 446.14it/s]


Training loss: 0.274879132145214
Epoch 215/250


100%|██████████| 250/250 [00:00<00:00, 476.62it/s]


Training loss: 0.27468397354945434
Epoch 216/250


100%|██████████| 250/250 [00:00<00:00, 428.19it/s]


Training loss: 0.27449115822255804
Epoch 217/250


100%|██████████| 250/250 [00:00<00:00, 452.63it/s]


Training loss: 0.27429902490958513
Epoch 218/250


100%|██████████| 250/250 [00:00<00:00, 408.55it/s]


Training loss: 0.2741087908256077
Epoch 219/250


100%|██████████| 250/250 [00:00<00:00, 510.97it/s]


Training loss: 0.27391974417516907
Epoch 220/250


100%|██████████| 250/250 [00:00<00:00, 487.57it/s]


Training loss: 0.27373237861909133
Epoch 221/250


100%|██████████| 250/250 [00:00<00:00, 467.67it/s]


Training loss: 0.2735460104430097
Epoch 222/250


100%|██████████| 250/250 [00:00<00:00, 445.38it/s]


Training loss: 0.2733607631986307
Epoch 223/250


100%|██████████| 250/250 [00:00<00:00, 423.42it/s]


Training loss: 0.2731772057562532
Epoch 224/250


100%|██████████| 250/250 [00:00<00:00, 496.15it/s]


Training loss: 0.2729952020072604
Epoch 225/250


100%|██████████| 250/250 [00:00<00:00, 456.86it/s]


Training loss: 0.27281444541039107
Epoch 226/250


100%|██████████| 250/250 [00:00<00:00, 501.71it/s]


Training loss: 0.2726348444772164
Epoch 227/250


100%|██████████| 250/250 [00:00<00:00, 470.36it/s]


Training loss: 0.2724563157408767
Epoch 228/250


100%|██████████| 250/250 [00:00<00:00, 417.50it/s]


Training loss: 0.2722794919136886
Epoch 229/250


100%|██████████| 250/250 [00:00<00:00, 394.38it/s]


Training loss: 0.2721035134279819
Epoch 230/250


100%|██████████| 250/250 [00:00<00:00, 354.68it/s]


Training loss: 0.27192871032415983
Epoch 231/250


100%|██████████| 250/250 [00:00<00:00, 428.95it/s]


Training loss: 0.2717557955831113
Epoch 232/250


100%|██████████| 250/250 [00:00<00:00, 473.16it/s]


Training loss: 0.27158369718112235
Epoch 233/250


100%|██████████| 250/250 [00:00<00:00, 478.49it/s]


Training loss: 0.2714129027240349
Epoch 234/250


100%|██████████| 250/250 [00:00<00:00, 469.87it/s]


Training loss: 0.27124293144617506
Epoch 235/250


100%|██████████| 250/250 [00:00<00:00, 479.13it/s]


Training loss: 0.271074441828616
Epoch 236/250


100%|██████████| 250/250 [00:00<00:00, 447.43it/s]


Training loss: 0.27090730209767355
Epoch 237/250


100%|██████████| 250/250 [00:00<00:00, 475.41it/s]


Training loss: 0.2707409335951394
Epoch 238/250


100%|██████████| 250/250 [00:00<00:00, 409.05it/s]


Training loss: 0.270575961374154
Epoch 239/250


100%|██████████| 250/250 [00:00<00:00, 439.39it/s]


Training loss: 0.2704123474593857
Epoch 240/250


100%|██████████| 250/250 [00:00<00:00, 448.80it/s]


Training loss: 0.27024936907197716
Epoch 241/250


100%|██████████| 250/250 [00:00<00:00, 392.88it/s]


Training loss: 0.27008799390286403
Epoch 242/250


100%|██████████| 250/250 [00:00<00:00, 445.35it/s]


Training loss: 0.26992758993282523
Epoch 243/250


100%|██████████| 250/250 [00:00<00:00, 463.67it/s]


Training loss: 0.26976831694598974
Epoch 244/250


100%|██████████| 250/250 [00:00<00:00, 479.84it/s]


Training loss: 0.2696097214199321
Epoch 245/250


100%|██████████| 250/250 [00:00<00:00, 465.70it/s]


Training loss: 0.26945247816463147
Epoch 246/250


100%|██████████| 250/250 [00:00<00:00, 374.00it/s]


Training loss: 0.2692969183502798
Epoch 247/250


100%|██████████| 250/250 [00:00<00:00, 439.33it/s]


Training loss: 0.2691416465857558
Epoch 248/250


100%|██████████| 250/250 [00:00<00:00, 465.84it/s]


Training loss: 0.268987564023394
Epoch 249/250


100%|██████████| 250/250 [00:00<00:00, 474.69it/s]


Training loss: 0.26883448295922524
Epoch 250/250


100%|██████████| 250/250 [00:00<00:00, 477.72it/s]

Training loss: 0.2686825186070905





In [13]:
model.eval()

with torch.no_grad():
    total_loss = 0.0
    correct = 0
    for x, y_true in tqdm(test_dataloader):
        y_pred = model(x)
        # Vectorised comparison: counts correct predictions for any test batch size.
        # Comparing tensors with `==` inside an `if` only works when the batch size is 1.
        correct += (torch.round(y_pred) == y_true).sum().item()
        # Weight each batch's mean loss by its size so test_loss is a per-sample mean.
        total_loss += criterion(y_pred, y_true).item() * x.size(0)

    n_test = len(test_dataloader.dataset)
    test_loss = total_loss / n_test
    writer.add_scalar("Loss/test_loss-X", test_loss)
    print('Test loss:', test_loss)
    print('Accuracy:', correct / n_test) # TODO: replace with balanced accuracy

100%|██████████| 2000/2000 [00:00<00:00, 2038.57it/s]

Test loss: 0.25924823939506075
Accuracy: 0.877





In [10]:
# Close the TensorBoard SummaryWriter created during setup; nothing is logged after this.
writer.close()