In [1]:
import torch
from torch import nn
from data import get_batch

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
def train(model, loss_fn, optim, iters=25):
    """Run ``iters`` optimisation steps and return the mean training loss.

    Args:
        model: Module to optimise; it is switched to training mode first.
        loss_fn: Callable invoked as ``loss_fn(prediction, target)``.
        optim: Optimizer whose ``step()`` is called once per batch.
        iters: Number of batches pulled from ``get_batch()``.

    Returns:
        The arithmetic mean of the per-batch loss values (a Python float).

    Raises:
        ZeroDivisionError: If no batch is processed (``iters`` <= 0).
    """
    model.train()
    batch_losses = []
    for _step in range(iters):
        features, targets = get_batch()
        batch_loss = loss_fn(model(features), targets)
        # Record the scalar value before the parameters are updated.
        batch_losses.append(batch_loss.item())

        # Drop gradients left over from the previous step, then update.
        model.zero_grad()
        batch_loss.backward()
        optim.step()
    return sum(batch_losses) / len(batch_losses)


def test(model, loss_fn, iters=5):
    model.eval()
    Loss = []
    for _ in range(iters):
        input, output = get_batch(test=True)
        out = model(input)

        loss = loss_fn(out, output)
        Loss.append(loss.item())
    return sum(Loss) / len(Loss)

In [4]:
# Fully connected network: 384 inputs -> 512 -> 512 -> 256 -> 64 outputs.
# Each hidden Linear layer is followed by a ReLU; the final Linear layer is
# followed by a Sigmoid so the outputs lie in (0, 1) as nn.BCELoss expects.
hidden = [
    layer
    for n_in, n_out in ((384, 512), (512, 512), (512, 256))
    for layer in (nn.Linear(n_in, n_out), nn.ReLU())
]
model = nn.Sequential(*hidden, nn.Linear(256, 64), nn.Sigmoid()).to(device)

# Binary cross-entropy on the sigmoid outputs, optimised with Adam.
loss_fn = nn.BCELoss()
optim = torch.optim.Adam(model.parameters(), lr=3e-4)

In [5]:
# Lowest test loss seen so far. Start at +inf so the first evaluated epoch
# always produces a checkpoint, whatever the scale of the loss (a fixed
# sentinel such as 1 would silently skip saving if the loss never dropped
# below it).
best = float('inf')

In [6]:
# Train until the test loss has failed to improve for 300 consecutive epochs,
# checkpointing the whole model every time a new best test loss is reached.
epoch = 0
last_save = 0
while True:
    # Patience exhausted: 300 epochs have passed since the last checkpoint.
    if epoch - last_save >= 300:
        break

    epoch += 1
    print(f'Epoch {epoch}...')

    loss = train(model, loss_fn, optim)
    print(f'Train loss: {loss}')

    loss = test(model, loss_fn)
    print(f'Test loss: {loss}')

    # Nothing to save unless this epoch beat the best test loss so far.
    if not (loss < best):
        continue

    torch.save(model, '../model.pt')
    best, last_save = loss, epoch
    print('Model saved')

print(f'Last save on {last_save} epoch')

Epoch 1...
Train loss: 0.590489696264267
Test loss: 0.30027714371681213
Model saved
Epoch 2...
Train loss: 0.19218393564224243
Test loss: 0.16832111775875092
Model saved
Epoch 3...
Train loss: 0.15260956764221192
Test loss: 0.14461359977722169
Model saved
Epoch 4...
Train loss: 0.14179442703723907
Test loss: 0.13934496641159058
Model saved
Epoch 5...
Train loss: 0.14035926043987274
Test loss: 0.13655787110328674
Model saved
Epoch 6...
Train loss: 0.13877664208412172
Test loss: 0.13705521523952485
Epoch 7...
Train loss: 0.13790787637233734
Test loss: 0.13594107329845428
Model saved
Epoch 8...
Train loss: 0.1371196061372757
Test loss: 0.13309445977210999
Model saved
Epoch 9...
Train loss: 0.13400350749492645
Test loss: 0.13375204801559448
Epoch 10...
Train loss: 0.1325165206193924
Test loss: 0.1297953099012375
Model saved
Epoch 11...
Train loss: 0.1308620947599411
Test loss: 0.12893966734409332
Model saved
Epoch 12...
Train loss: 0.13107216775417327
Test loss: 0.13018912672996522
Epoch 1

In [8]:
# Reload the best checkpoint written by the training loop and report its
# test loss. map_location pins the restored tensors to this session's device,
# so a checkpoint saved on a GPU still loads on a CPU-only machine.
# NOTE(review): the file is a whole pickled module (written with
# torch.save(model, ...)), so only load it from a trusted source. PyTorch >= 2.6
# defaults torch.load to weights_only=True, which rejects such files unless
# weights_only=False is passed -- confirm against the installed version.
model = torch.load('../model.pt', map_location=device)
print(test(model, loss_fn))

0.12071859836578369
