In [186]:
import torch
import torch.nn as nn

from tqdm import tqdm
from torch.optim import SGD
from sklearn.datasets import make_moons
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [187]:
class LogisticRegression(nn.Module):
    def __init__(self):
        super().__init__()
        self.coef = nn.Parameter(torch.randn((1, 2)))
        self.bias   = nn.Parameter(torch.randn(1))
    
    def forward(self, x):
        return 1 / (1 + torch.exp(-(self.bias + torch.sum(self.coef * x, dim = 1, keepdims = True))))

In [188]:
class Dataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.from_numpy(X)
        self.y = torch.from_numpy(y)

    def __len__(self):
        return self.X.shape[0]
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx].unsqueeze(dim = -1).to(torch.float64)

In [189]:
X, y = make_moons(n_samples = 10000)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = .2)

In [190]:
train_dataset = Dataset(X_train, y_train)
test_dataset  = Dataset(X_test,  y_test)
train_dataloader = DataLoader(train_dataset, batch_size = 32, shuffle = True)
test_dataloader  = DataLoader(test_dataset,  batch_size = 1, shuffle = False)

In [191]:
model = LogisticRegression()
optimizer = SGD(model.parameters())
criterion = nn.BCELoss()

In [192]:
model.train()

n_epochs = 250
for epoch in range(n_epochs):
    print(f'Epoch {epoch + 1}/{n_epochs}')

    total_loss = 0
    for x, y_true in tqdm(train_dataloader):
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y_true)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
    
    print('Training loss:', total_loss / len(train_dataloader))



Epoch 1/250


100%|██████████| 250/250 [00:00<00:00, 508.34it/s]


Training loss: 0.860366883761883
Epoch 2/250


100%|██████████| 250/250 [00:00<00:00, 534.14it/s]


Training loss: 0.8191184048474529
Epoch 3/250


100%|██████████| 250/250 [00:00<00:00, 446.70it/s]


Training loss: 0.7809012059798482
Epoch 4/250


100%|██████████| 250/250 [00:00<00:00, 541.89it/s]


Training loss: 0.7456233331811369
Epoch 5/250


100%|██████████| 250/250 [00:00<00:00, 494.81it/s]


Training loss: 0.71315773847221
Epoch 6/250


100%|██████████| 250/250 [00:00<00:00, 508.37it/s]


Training loss: 0.6833650680913506
Epoch 7/250


100%|██████████| 250/250 [00:00<00:00, 531.77it/s]


Training loss: 0.6560871479074959
Epoch 8/250


100%|██████████| 250/250 [00:00<00:00, 481.90it/s]


Training loss: 0.6311559235391353
Epoch 9/250


100%|██████████| 250/250 [00:00<00:00, 608.09it/s]


Training loss: 0.6083900203043038
Epoch 10/250


100%|██████████| 250/250 [00:00<00:00, 602.43it/s]


Training loss: 0.5876139156129668
Epoch 11/250


100%|██████████| 250/250 [00:00<00:00, 621.30it/s]


Training loss: 0.5686579615858265
Epoch 12/250


100%|██████████| 250/250 [00:00<00:00, 631.60it/s]


Training loss: 0.5513576631738228
Epoch 13/250


100%|██████████| 250/250 [00:00<00:00, 598.62it/s]


Training loss: 0.535560161478643
Epoch 14/250


100%|██████████| 250/250 [00:00<00:00, 439.67it/s]


Training loss: 0.5211201836986579
Epoch 15/250


100%|██████████| 250/250 [00:00<00:00, 457.28it/s]


Training loss: 0.507903355156372
Epoch 16/250


100%|██████████| 250/250 [00:00<00:00, 540.13it/s]


Training loss: 0.49578661105824934
Epoch 17/250


100%|██████████| 250/250 [00:00<00:00, 533.17it/s]


Training loss: 0.48466217233267944
Epoch 18/250


100%|██████████| 250/250 [00:00<00:00, 561.76it/s]


Training loss: 0.47442893420188026
Epoch 19/250


100%|██████████| 250/250 [00:00<00:00, 499.13it/s]


Training loss: 0.4649995594800824
Epoch 20/250


100%|██████████| 250/250 [00:00<00:00, 566.70it/s]


Training loss: 0.4562924776850954
Epoch 21/250


100%|██████████| 250/250 [00:00<00:00, 514.04it/s]


Training loss: 0.44823664168131505
Epoch 22/250


100%|██████████| 250/250 [00:00<00:00, 569.92it/s]


Training loss: 0.4407673369388497
Epoch 23/250


100%|██████████| 250/250 [00:00<00:00, 578.41it/s]


Training loss: 0.4338284047064383
Epoch 24/250


100%|██████████| 250/250 [00:00<00:00, 604.42it/s]


Training loss: 0.4273707774622223
Epoch 25/250


100%|██████████| 250/250 [00:00<00:00, 521.53it/s]


Training loss: 0.42134758798932204
Epoch 26/250


100%|██████████| 250/250 [00:00<00:00, 506.37it/s]


Training loss: 0.4157183538012509
Epoch 27/250


100%|██████████| 250/250 [00:00<00:00, 482.77it/s]


Training loss: 0.41044527154773414
Epoch 28/250


100%|██████████| 250/250 [00:00<00:00, 493.13it/s]


Training loss: 0.40549940707533283
Epoch 29/250


100%|██████████| 250/250 [00:00<00:00, 505.31it/s]


Training loss: 0.4008525299091844
Epoch 30/250


100%|██████████| 250/250 [00:00<00:00, 484.68it/s]


Training loss: 0.3964776359985035
Epoch 31/250


100%|██████████| 250/250 [00:00<00:00, 460.72it/s]


Training loss: 0.39235237391813876
Epoch 32/250


100%|██████████| 250/250 [00:00<00:00, 519.10it/s]


Training loss: 0.3884553444496454
Epoch 33/250


100%|██████████| 250/250 [00:00<00:00, 545.98it/s]


Training loss: 0.3847693734430087
Epoch 34/250


100%|██████████| 250/250 [00:00<00:00, 448.17it/s]


Training loss: 0.38127891453307994
Epoch 35/250


100%|██████████| 250/250 [00:00<00:00, 470.92it/s]


Training loss: 0.3779670040076729
Epoch 36/250


100%|██████████| 250/250 [00:00<00:00, 472.69it/s]


Training loss: 0.3748197929882455
Epoch 37/250


100%|██████████| 250/250 [00:00<00:00, 487.00it/s]


Training loss: 0.3718257773590521
Epoch 38/250


100%|██████████| 250/250 [00:00<00:00, 470.40it/s]


Training loss: 0.3689748206015767
Epoch 39/250


100%|██████████| 250/250 [00:00<00:00, 449.19it/s]


Training loss: 0.3662558023444943
Epoch 40/250


100%|██████████| 250/250 [00:00<00:00, 471.81it/s]


Training loss: 0.3636603941703742
Epoch 41/250


100%|██████████| 250/250 [00:00<00:00, 506.67it/s]


Training loss: 0.36118040924087597
Epoch 42/250


100%|██████████| 250/250 [00:00<00:00, 529.37it/s]


Training loss: 0.358806046765908
Epoch 43/250


100%|██████████| 250/250 [00:00<00:00, 465.29it/s]


Training loss: 0.3565325447797724
Epoch 44/250


100%|██████████| 250/250 [00:00<00:00, 596.77it/s]


Training loss: 0.35435235004926624
Epoch 45/250


100%|██████████| 250/250 [00:00<00:00, 449.45it/s]


Training loss: 0.35226061612280457
Epoch 46/250


100%|██████████| 250/250 [00:00<00:00, 658.08it/s]


Training loss: 0.3502508096474263
Epoch 47/250


100%|██████████| 250/250 [00:00<00:00, 659.70it/s]


Training loss: 0.34831756259473456
Epoch 48/250


100%|██████████| 250/250 [00:00<00:00, 609.44it/s]


Training loss: 0.34645773837012817
Epoch 49/250


100%|██████████| 250/250 [00:00<00:00, 568.94it/s]


Training loss: 0.3446664784584391
Epoch 50/250


100%|██████████| 250/250 [00:00<00:00, 523.18it/s]


Training loss: 0.3429408234112466
Epoch 51/250


100%|██████████| 250/250 [00:00<00:00, 645.58it/s]


Training loss: 0.3412757163458171
Epoch 52/250


100%|██████████| 250/250 [00:00<00:00, 620.35it/s]


Training loss: 0.3396684683803455
Epoch 53/250


100%|██████████| 250/250 [00:00<00:00, 553.15it/s]


Training loss: 0.3381154588901281
Epoch 54/250


100%|██████████| 250/250 [00:00<00:00, 574.57it/s]


Training loss: 0.3366143806328298
Epoch 55/250


100%|██████████| 250/250 [00:00<00:00, 575.37it/s]


Training loss: 0.3351630778414005
Epoch 56/250


100%|██████████| 250/250 [00:00<00:00, 498.39it/s]


Training loss: 0.3337576884483271
Epoch 57/250


100%|██████████| 250/250 [00:00<00:00, 493.36it/s]


Training loss: 0.3323963812705749
Epoch 58/250


100%|██████████| 250/250 [00:00<00:00, 412.95it/s]


Training loss: 0.331077287961771
Epoch 59/250


100%|██████████| 250/250 [00:00<00:00, 428.68it/s]


Training loss: 0.32979759015380794
Epoch 60/250


100%|██████████| 250/250 [00:00<00:00, 485.64it/s]


Training loss: 0.3285559483520338
Epoch 61/250


100%|██████████| 250/250 [00:00<00:00, 525.96it/s]


Training loss: 0.3273506112544016
Epoch 62/250


100%|██████████| 250/250 [00:00<00:00, 579.98it/s]


Training loss: 0.32617986852288694
Epoch 63/250


100%|██████████| 250/250 [00:00<00:00, 622.87it/s]


Training loss: 0.32504175033891036
Epoch 64/250


100%|██████████| 250/250 [00:00<00:00, 639.31it/s]


Training loss: 0.32393514978261095
Epoch 65/250


100%|██████████| 250/250 [00:00<00:00, 537.34it/s]


Training loss: 0.3228583647714595
Epoch 66/250


100%|██████████| 250/250 [00:00<00:00, 566.96it/s]


Training loss: 0.32181042451408093
Epoch 67/250


100%|██████████| 250/250 [00:00<00:00, 605.88it/s]


Training loss: 0.320789694028315
Epoch 68/250


100%|██████████| 250/250 [00:00<00:00, 460.28it/s]


Training loss: 0.3197955324766035
Epoch 69/250


100%|██████████| 250/250 [00:00<00:00, 521.52it/s]


Training loss: 0.3188266303327485
Epoch 70/250


100%|██████████| 250/250 [00:00<00:00, 543.95it/s]


Training loss: 0.31788179121628307
Epoch 71/250


100%|██████████| 250/250 [00:00<00:00, 475.83it/s]


Training loss: 0.3169598808460202
Epoch 72/250


100%|██████████| 250/250 [00:00<00:00, 583.31it/s]


Training loss: 0.31605995710986257
Epoch 73/250


100%|██████████| 250/250 [00:00<00:00, 529.96it/s]


Training loss: 0.3151814376370198
Epoch 74/250


100%|██████████| 250/250 [00:00<00:00, 536.61it/s]


Training loss: 0.314323322875272
Epoch 75/250


100%|██████████| 250/250 [00:00<00:00, 532.15it/s]


Training loss: 0.3134846520058768
Epoch 76/250


100%|██████████| 250/250 [00:00<00:00, 560.99it/s]


Training loss: 0.3126655911796954
Epoch 77/250


100%|██████████| 250/250 [00:00<00:00, 438.64it/s]


Training loss: 0.3118646671208474
Epoch 78/250


100%|██████████| 250/250 [00:00<00:00, 440.62it/s]


Training loss: 0.3110803831698538
Epoch 79/250


100%|██████████| 250/250 [00:00<00:00, 585.21it/s]


Training loss: 0.31031355846096487
Epoch 80/250


100%|██████████| 250/250 [00:00<00:00, 515.48it/s]


Training loss: 0.30956306191947874
Epoch 81/250


100%|██████████| 250/250 [00:00<00:00, 561.64it/s]


Training loss: 0.30882789946016975
Epoch 82/250


100%|██████████| 250/250 [00:00<00:00, 537.23it/s]


Training loss: 0.30810801915206476
Epoch 83/250


100%|██████████| 250/250 [00:00<00:00, 579.64it/s]


Training loss: 0.3074028183463898
Epoch 84/250


100%|██████████| 250/250 [00:00<00:00, 489.92it/s]


Training loss: 0.3067114129320826
Epoch 85/250


100%|██████████| 250/250 [00:00<00:00, 507.40it/s]


Training loss: 0.3060337347194823
Epoch 86/250


100%|██████████| 250/250 [00:00<00:00, 601.12it/s]


Training loss: 0.30536949666497415
Epoch 87/250


100%|██████████| 250/250 [00:00<00:00, 659.16it/s]


Training loss: 0.30471756650647025
Epoch 88/250


100%|██████████| 250/250 [00:00<00:00, 468.38it/s]


Training loss: 0.3040781359991921
Epoch 89/250


100%|██████████| 250/250 [00:00<00:00, 622.86it/s]


Training loss: 0.3034505102563201
Epoch 90/250


100%|██████████| 250/250 [00:00<00:00, 562.54it/s]


Training loss: 0.3028342682907074
Epoch 91/250


100%|██████████| 250/250 [00:00<00:00, 538.57it/s]


Training loss: 0.30222894355922525
Epoch 92/250


100%|██████████| 250/250 [00:00<00:00, 515.12it/s]


Training loss: 0.3016346443046869
Epoch 93/250


100%|██████████| 250/250 [00:00<00:00, 678.58it/s]


Training loss: 0.301050505244054
Epoch 94/250


100%|██████████| 250/250 [00:00<00:00, 570.01it/s]


Training loss: 0.3004766454621712
Epoch 95/250


100%|██████████| 250/250 [00:00<00:00, 530.47it/s]


Training loss: 0.2999132071732845
Epoch 96/250


100%|██████████| 250/250 [00:00<00:00, 483.21it/s]


Training loss: 0.29935870494719835
Epoch 97/250


100%|██████████| 250/250 [00:00<00:00, 507.31it/s]


Training loss: 0.2988132774339141
Epoch 98/250


100%|██████████| 250/250 [00:00<00:00, 625.93it/s]


Training loss: 0.2982772028988864
Epoch 99/250


100%|██████████| 250/250 [00:00<00:00, 492.05it/s]


Training loss: 0.29774962943300093
Epoch 100/250


100%|██████████| 250/250 [00:00<00:00, 644.34it/s]


Training loss: 0.297230490879452
Epoch 101/250


100%|██████████| 250/250 [00:00<00:00, 639.40it/s]


Training loss: 0.29671966634432684
Epoch 102/250


100%|██████████| 250/250 [00:00<00:00, 579.22it/s]


Training loss: 0.29621717830851857
Epoch 103/250


100%|██████████| 250/250 [00:00<00:00, 536.38it/s]


Training loss: 0.2957219503339693
Epoch 104/250


100%|██████████| 250/250 [00:00<00:00, 557.67it/s]


Training loss: 0.29523385769825483
Epoch 105/250


100%|██████████| 250/250 [00:00<00:00, 609.40it/s]


Training loss: 0.2947537017037237
Epoch 106/250


100%|██████████| 250/250 [00:00<00:00, 524.17it/s]


Training loss: 0.2942808333870164
Epoch 107/250


100%|██████████| 250/250 [00:00<00:00, 488.51it/s]


Training loss: 0.2938149526810265
Epoch 108/250


100%|██████████| 250/250 [00:00<00:00, 595.33it/s]


Training loss: 0.29335559338241424
Epoch 109/250


100%|██████████| 250/250 [00:00<00:00, 587.25it/s]


Training loss: 0.2929028103407941
Epoch 110/250


100%|██████████| 250/250 [00:00<00:00, 604.85it/s]


Training loss: 0.2924569401068564
Epoch 111/250


100%|██████████| 250/250 [00:00<00:00, 620.00it/s]


Training loss: 0.29201660855071787
Epoch 112/250


100%|██████████| 250/250 [00:00<00:00, 650.66it/s]


Training loss: 0.29158280946041604
Epoch 113/250


100%|██████████| 250/250 [00:00<00:00, 669.69it/s]


Training loss: 0.2911547552349186
Epoch 114/250


100%|██████████| 250/250 [00:00<00:00, 666.00it/s]


Training loss: 0.2907325399301904
Epoch 115/250


100%|██████████| 250/250 [00:00<00:00, 650.28it/s]


Training loss: 0.2903164437998471
Epoch 116/250


100%|██████████| 250/250 [00:00<00:00, 657.75it/s]


Training loss: 0.2899056601284717
Epoch 117/250


100%|██████████| 250/250 [00:00<00:00, 673.83it/s]


Training loss: 0.2895001356980618
Epoch 118/250


100%|██████████| 250/250 [00:00<00:00, 732.32it/s]


Training loss: 0.2890999000675506
Epoch 119/250


100%|██████████| 250/250 [00:00<00:00, 690.81it/s]


Training loss: 0.28870491952425603
Epoch 120/250


100%|██████████| 250/250 [00:00<00:00, 682.90it/s]


Training loss: 0.2883151650927008
Epoch 121/250


100%|██████████| 250/250 [00:00<00:00, 652.40it/s]


Training loss: 0.28793021543306685
Epoch 122/250


100%|██████████| 250/250 [00:00<00:00, 686.84it/s]


Training loss: 0.28754975858708165
Epoch 123/250


100%|██████████| 250/250 [00:00<00:00, 589.65it/s]


Training loss: 0.28717408610580686
Epoch 124/250


100%|██████████| 250/250 [00:00<00:00, 666.61it/s]


Training loss: 0.2868032721258781
Epoch 125/250


100%|██████████| 250/250 [00:00<00:00, 685.26it/s]


Training loss: 0.28643697325210893
Epoch 126/250


100%|██████████| 250/250 [00:00<00:00, 674.23it/s]


Training loss: 0.2860753521826978
Epoch 127/250


100%|██████████| 250/250 [00:00<00:00, 658.42it/s]


Training loss: 0.2857171926582277
Epoch 128/250


100%|██████████| 250/250 [00:00<00:00, 660.49it/s]


Training loss: 0.2853639825265095
Epoch 129/250


100%|██████████| 250/250 [00:00<00:00, 672.17it/s]


Training loss: 0.28501491476562146
Epoch 130/250


100%|██████████| 250/250 [00:00<00:00, 672.09it/s]


Training loss: 0.2846691785588136
Epoch 131/250


100%|██████████| 250/250 [00:00<00:00, 675.93it/s]


Training loss: 0.28432802347242697
Epoch 132/250


100%|██████████| 250/250 [00:00<00:00, 667.98it/s]


Training loss: 0.2839908188127149
Epoch 133/250


100%|██████████| 250/250 [00:00<00:00, 668.08it/s]


Training loss: 0.2836573808093003
Epoch 134/250


100%|██████████| 250/250 [00:00<00:00, 659.92it/s]


Training loss: 0.2833273837603958
Epoch 135/250


100%|██████████| 250/250 [00:00<00:00, 688.12it/s]


Training loss: 0.28300117685940024
Epoch 136/250


100%|██████████| 250/250 [00:00<00:00, 659.70it/s]


Training loss: 0.2826783956221375
Epoch 137/250


100%|██████████| 250/250 [00:00<00:00, 672.90it/s]


Training loss: 0.2823593018641987
Epoch 138/250


100%|██████████| 250/250 [00:00<00:00, 674.08it/s]


Training loss: 0.2820436346690516
Epoch 139/250


100%|██████████| 250/250 [00:00<00:00, 679.05it/s]


Training loss: 0.28173131891045416
Epoch 140/250


100%|██████████| 250/250 [00:00<00:00, 653.76it/s]


Training loss: 0.28142217733769126
Epoch 141/250


100%|██████████| 250/250 [00:00<00:00, 656.64it/s]


Training loss: 0.28111657093419246
Epoch 142/250


100%|██████████| 250/250 [00:00<00:00, 651.13it/s]


Training loss: 0.28081394057225245
Epoch 143/250


100%|██████████| 250/250 [00:00<00:00, 686.31it/s]


Training loss: 0.2805148422189449
Epoch 144/250


100%|██████████| 250/250 [00:00<00:00, 633.76it/s]


Training loss: 0.28021894727145763
Epoch 145/250


100%|██████████| 250/250 [00:00<00:00, 691.94it/s]


Training loss: 0.2799254549541652
Epoch 146/250


100%|██████████| 250/250 [00:00<00:00, 669.10it/s]


Training loss: 0.2796348859016339
Epoch 147/250


100%|██████████| 250/250 [00:00<00:00, 675.59it/s]


Training loss: 0.279347671239221
Epoch 148/250


100%|██████████| 250/250 [00:00<00:00, 681.96it/s]


Training loss: 0.27906323998883004
Epoch 149/250


100%|██████████| 250/250 [00:00<00:00, 397.72it/s]


Training loss: 0.278782018193367
Epoch 150/250


100%|██████████| 250/250 [00:00<00:00, 480.44it/s]


Training loss: 0.27850267235078563
Epoch 151/250


100%|██████████| 250/250 [00:00<00:00, 616.79it/s]


Training loss: 0.27822700814240503
Epoch 152/250


100%|██████████| 250/250 [00:00<00:00, 677.50it/s]


Training loss: 0.27795311202909667
Epoch 153/250


100%|██████████| 250/250 [00:00<00:00, 668.60it/s]


Training loss: 0.27768186017143676
Epoch 154/250


100%|██████████| 250/250 [00:00<00:00, 679.56it/s]


Training loss: 0.27741405721332224
Epoch 155/250


100%|██████████| 250/250 [00:00<00:00, 674.41it/s]


Training loss: 0.2771486115511874
Epoch 156/250


100%|██████████| 250/250 [00:00<00:00, 661.27it/s]


Training loss: 0.27688551266557376
Epoch 157/250


100%|██████████| 250/250 [00:00<00:00, 655.73it/s]


Training loss: 0.2766246221093809
Epoch 158/250


100%|██████████| 250/250 [00:00<00:00, 671.78it/s]


Training loss: 0.2763664639071993
Epoch 159/250


100%|██████████| 250/250 [00:00<00:00, 631.32it/s]


Training loss: 0.2761100966114572
Epoch 160/250


100%|██████████| 250/250 [00:00<00:00, 680.27it/s]


Training loss: 0.27585683446179
Epoch 161/250


100%|██████████| 250/250 [00:00<00:00, 663.77it/s]


Training loss: 0.2756055458997864
Epoch 162/250


100%|██████████| 250/250 [00:00<00:00, 660.25it/s]


Training loss: 0.27535695939980526
Epoch 163/250


100%|██████████| 250/250 [00:00<00:00, 664.31it/s]


Training loss: 0.2751099901383817
Epoch 164/250


100%|██████████| 250/250 [00:00<00:00, 686.07it/s]


Training loss: 0.27486558902606034
Epoch 165/250


100%|██████████| 250/250 [00:00<00:00, 676.66it/s]


Training loss: 0.2746236213999159
Epoch 166/250


100%|██████████| 250/250 [00:00<00:00, 663.73it/s]


Training loss: 0.27438326426062465
Epoch 167/250


100%|██████████| 250/250 [00:00<00:00, 671.97it/s]


Training loss: 0.274145552681891
Epoch 168/250


100%|██████████| 250/250 [00:00<00:00, 657.62it/s]


Training loss: 0.2739093796141651
Epoch 169/250


100%|██████████| 250/250 [00:00<00:00, 647.63it/s]


Training loss: 0.273675296307217
Epoch 170/250


100%|██████████| 250/250 [00:00<00:00, 658.65it/s]


Training loss: 0.27344371829954184
Epoch 171/250


100%|██████████| 250/250 [00:00<00:00, 650.55it/s]


Training loss: 0.2732137651675907
Epoch 172/250


100%|██████████| 250/250 [00:00<00:00, 641.08it/s]


Training loss: 0.2729861411389263
Epoch 173/250


100%|██████████| 250/250 [00:00<00:00, 680.87it/s]


Training loss: 0.2727598506538922
Epoch 174/250


100%|██████████| 250/250 [00:00<00:00, 684.76it/s]


Training loss: 0.27253583648791313
Epoch 175/250


100%|██████████| 250/250 [00:00<00:00, 612.85it/s]


Training loss: 0.27231377739348933
Epoch 176/250


100%|██████████| 250/250 [00:00<00:00, 680.92it/s]


Training loss: 0.27209367833491693
Epoch 177/250


100%|██████████| 250/250 [00:00<00:00, 667.61it/s]


Training loss: 0.2718753760927918
Epoch 178/250


100%|██████████| 250/250 [00:00<00:00, 675.12it/s]


Training loss: 0.2716586165176411
Epoch 179/250


100%|██████████| 250/250 [00:00<00:00, 661.93it/s]


Training loss: 0.27144340458745797
Epoch 180/250


100%|██████████| 250/250 [00:00<00:00, 654.23it/s]


Training loss: 0.2712306090200731
Epoch 181/250


100%|██████████| 250/250 [00:00<00:00, 676.78it/s]


Training loss: 0.27101952872266405
Epoch 182/250


100%|██████████| 250/250 [00:00<00:00, 655.64it/s]


Training loss: 0.2708097954244045
Epoch 183/250


100%|██████████| 250/250 [00:00<00:00, 625.70it/s]


Training loss: 0.27060205875451104
Epoch 184/250


100%|██████████| 250/250 [00:00<00:00, 647.97it/s]


Training loss: 0.2703955586904307
Epoch 185/250


100%|██████████| 250/250 [00:00<00:00, 606.79it/s]


Training loss: 0.2701910124932288
Epoch 186/250


100%|██████████| 250/250 [00:00<00:00, 660.04it/s]


Training loss: 0.2699878624116281
Epoch 187/250


100%|██████████| 250/250 [00:00<00:00, 648.53it/s]


Training loss: 0.2697864581634483
Epoch 188/250


100%|██████████| 250/250 [00:00<00:00, 666.27it/s]


Training loss: 0.2695867328034492
Epoch 189/250


100%|██████████| 250/250 [00:00<00:00, 670.71it/s]


Training loss: 0.26938862747166076
Epoch 190/250


100%|██████████| 250/250 [00:00<00:00, 744.97it/s]


Training loss: 0.269192168093476
Epoch 191/250


100%|██████████| 250/250 [00:00<00:00, 686.24it/s]


Training loss: 0.2689966494495365
Epoch 192/250


100%|██████████| 250/250 [00:00<00:00, 686.80it/s]


Training loss: 0.2688033418177068
Epoch 193/250


100%|██████████| 250/250 [00:00<00:00, 667.92it/s]


Training loss: 0.2686110816659208
Epoch 194/250


100%|██████████| 250/250 [00:00<00:00, 649.02it/s]


Training loss: 0.26842045577865503
Epoch 195/250


100%|██████████| 250/250 [00:00<00:00, 645.66it/s]


Training loss: 0.26823128915301486
Epoch 196/250


100%|██████████| 250/250 [00:00<00:00, 661.32it/s]


Training loss: 0.2680435498622374
Epoch 197/250


100%|██████████| 250/250 [00:00<00:00, 683.18it/s]


Training loss: 0.2678575678731768
Epoch 198/250


100%|██████████| 250/250 [00:00<00:00, 670.23it/s]


Training loss: 0.2676723493773529
Epoch 199/250


100%|██████████| 250/250 [00:00<00:00, 674.31it/s]


Training loss: 0.2674889247333397
Epoch 200/250


100%|██████████| 250/250 [00:00<00:00, 696.24it/s]


Training loss: 0.26730680572176724
Epoch 201/250


100%|██████████| 250/250 [00:00<00:00, 681.36it/s]


Training loss: 0.2671257332804231
Epoch 202/250


100%|██████████| 250/250 [00:00<00:00, 671.80it/s]


Training loss: 0.2669465824990272
Epoch 203/250


100%|██████████| 250/250 [00:00<00:00, 672.25it/s]


Training loss: 0.26676846058921116
Epoch 204/250


100%|██████████| 250/250 [00:00<00:00, 678.93it/s]


Training loss: 0.2665914677265215
Epoch 205/250


100%|██████████| 250/250 [00:00<00:00, 665.06it/s]


Training loss: 0.26641594832637966
Epoch 206/250


100%|██████████| 250/250 [00:00<00:00, 631.04it/s]


Training loss: 0.26624164310349685
Epoch 207/250


100%|██████████| 250/250 [00:00<00:00, 622.89it/s]


Training loss: 0.26606947191729213
Epoch 208/250


100%|██████████| 250/250 [00:00<00:00, 585.36it/s]


Training loss: 0.2658969135012907
Epoch 209/250


100%|██████████| 250/250 [00:00<00:00, 653.78it/s]


Training loss: 0.26572659859799663
Epoch 210/250


100%|██████████| 250/250 [00:00<00:00, 663.19it/s]


Training loss: 0.26555762289697576
Epoch 211/250


100%|██████████| 250/250 [00:00<00:00, 649.51it/s]


Training loss: 0.26538948011494634
Epoch 212/250


100%|██████████| 250/250 [00:00<00:00, 660.75it/s]


Training loss: 0.2652227670183281
Epoch 213/250


100%|██████████| 250/250 [00:00<00:00, 619.43it/s]


Training loss: 0.2650570609873157
Epoch 214/250


100%|██████████| 250/250 [00:00<00:00, 624.90it/s]


Training loss: 0.26489262106069633
Epoch 215/250


100%|██████████| 250/250 [00:00<00:00, 660.84it/s]


Training loss: 0.2647293087172784
Epoch 216/250


100%|██████████| 250/250 [00:00<00:00, 657.04it/s]


Training loss: 0.26456738393754914
Epoch 217/250


100%|██████████| 250/250 [00:00<00:00, 670.18it/s]


Training loss: 0.26440635829733194
Epoch 218/250


100%|██████████| 250/250 [00:00<00:00, 667.56it/s]


Training loss: 0.2642463228955471
Epoch 219/250


 50%|█████     | 125/250 [00:00<00:00, 623.12it/s]

In [None]:
model.eval()

with torch.no_grad():
    total_loss = 0
    correct = 0
    for x, y_true in tqdm(test_dataloader):
        y_pred = model(x)
        if torch.round(y_true) == torch.round(y_pred):
            correct += 1
        loss = criterion(y_pred, y_true)
        total_loss += loss.item()
        
    print('Test loss:', total_loss / len(test_dataloader))
    print('Accuracy:', correct / len(test_dataloader)) # TODO: substituir por acurácia balanceada

  0%|          | 0/2000 [00:00<?, ?it/s]

100%|██████████| 2000/2000 [00:01<00:00, 1832.20it/s]

Test loss: 0.23434035190377164
Accuracy: 0.889



