In [1]:
from torch import nn
import torch
from Models import LSTMClassifier
from utils import get_device, get_loaders, prediction_binary

In [2]:
# Build the three data splits; how they are split/batched is decided inside
# utils.get_loaders (presumably torch DataLoaders — verify there).
(train_loader,
 val_loader,
 test_loader) = get_loaders()

In [3]:
# Seed torch's global RNGs (CPU and all CUDA devices) so weight initialisation
# and, if the loaders shuffle with the global generator, the per-epoch batch
# order are repeatable. Some GPU kernels may still be non-deterministic.
SEED = 42
torch.manual_seed(SEED)

device = get_device()

# Constructor arguments, as reflected in the printed architecture:
# LSTMCell(input_size=47, hidden_size=256) followed by Linear(256 -> 1) + Sigmoid.
N_FEATURES = 47    # input features per time step fed to the LSTMCell
HIDDEN_SIZE = 256  # LSTM hidden-state size
N_OUTPUTS = 1      # one output unit -> one sigmoid probability (binary task)
# NOTE(review): the 4th argument is the target device; presumably used by the
# model to allocate its hidden state — confirm in Models.LSTMClassifier.
model = LSTMClassifier(N_FEATURES, HIDDEN_SIZE, N_OUTPUTS, device)
model.to(device)

print(device)
print(model)

cuda:0
LSTMClassifier(
  (rnn): LSTMCell(47, 256)
  (fc): Linear(in_features=256, out_features=1, bias=True)
  (activation): Sigmoid()
)


In [4]:
# Optimiser and criterion. The model's final layer is a Sigmoid, so it already
# outputs probabilities and plain BCELoss (not BCEWithLogitsLoss) is the match.
LEARNING_RATE = 1e-4
opt = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.BCELoss().to(device)

# Best validation AUROC seen so far; the training cell writes a checkpoint
# whenever a new epoch beats it. AUROC cannot be negative, so 0 is a safe floor.
best = 0


In [5]:
# Training loop: one optimisation pass over `train_loader` per epoch, then a
# validation pass via `prediction_binary`. Whenever validation AUROC beats
# `best` (initialised in the optimiser cell), the whole model is pickled to
# ./ihm_model.
N_EPOCHS = 100

TL = []  # mean training loss per epoch (Python floats)
VL = []  # validation loss per epoch, as returned by prediction_binary
VA = []  # validation AUROC per epoch

for epoch in range(N_EPOCHS):
    # Set train mode once per epoch. Nothing inside the batch loop changes the
    # mode; prediction_binary below may switch to eval mode (TODO confirm).
    model.train()
    running_loss = 0.0

    for inputs, label in train_loader:
        opt.zero_grad()

        inputs = inputs.to(device=device, dtype=torch.float32)
        label = label.to(device=device, dtype=torch.float32)

        pred = model(inputs)
        # Presumably pred is (batch, 1) probabilities; column 0 is compared
        # against the per-sample binary label.
        loss = loss_fn(pred[:, 0], label)

        loss.backward()
        # Optional gradient clipping, currently disabled:
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 25)
        opt.step()
        # .item() returns a Python float, so no device tensor is accumulated.
        running_loss += loss.item()

    val_loss, auc = prediction_binary(model, val_loader, loss_fn, device)

    if auc > best:
        best = auc
        torch.save(model, './ihm_model')

    # max(..., 1) avoids ZeroDivisionError if the loader yields no batches.
    train_loss = running_loss / max(len(train_loader), 1)

    print(f'Epoch: {epoch} || Train Loss: {train_loss:.4f}  Val Loss: {val_loss:.4f}  Val AUROC: {auc:.4f}')
    TL.append(train_loss)
    VL.append(val_loss)
    VA.append(auc)


Epoch : 0.0 Train Loss 0.4450 Val Loss 0.3633 Val AUROC 0.6865
Epoch : 1.0 Train Loss 0.3749 Val Loss 0.3505 Val AUROC 0.7374
Epoch : 2.0 Train Loss 0.3622 Val Loss 0.3433 Val AUROC 0.7489
Epoch : 3.0 Train Loss 0.3561 Val Loss 0.3402 Val AUROC 0.7562
Epoch : 4.0 Train Loss 0.3513 Val Loss 0.3370 Val AUROC 0.7585
Epoch : 5.0 Train Loss 0.3482 Val Loss 0.3363 Val AUROC 0.7643
Epoch : 6.0 Train Loss 0.3440 Val Loss 0.3358 Val AUROC 0.7642
Epoch : 7.0 Train Loss 0.3420 Val Loss 0.3345 Val AUROC 0.7642
Epoch : 8.0 Train Loss 0.3402 Val Loss 0.3348 Val AUROC 0.7662
Epoch : 9.0 Train Loss 0.3374 Val Loss 0.3326 Val AUROC 0.7701
Epoch : 10.0 Train Loss 0.3359 Val Loss 0.3303 Val AUROC 0.7758
Epoch : 11.0 Train Loss 0.3329 Val Loss 0.3314 Val AUROC 0.7749
Epoch : 12.0 Train Loss 0.3297 Val Loss 0.3278 Val AUROC 0.7754
Epoch : 13.0 Train Loss 0.3282 Val Loss 0.3256 Val AUROC 0.7820
Epoch : 14.0 Train Loss 0.3252 Val Loss 0.3282 Val AUROC 0.7772
Epoch : 15.0 Train Loss 0.3230 Val Loss 0.3253 Val

In [6]:

# Reload the checkpoint with the best validation AUROC and score it on the
# held-out test split.
# The training cell pickles the whole nn.Module (torch.save(model, ...)), so
# weights_only=False is required: since PyTorch 2.6 torch.load defaults to
# weights_only=True and would refuse to unpickle the module. Unpickling can
# execute arbitrary code — only load checkpoints you produced yourself.
# map_location keeps a checkpoint saved on one GPU loadable on `device`.
model = torch.load('./ihm_model', map_location=device, weights_only=False)
# Ensure inference mode regardless of the mode the module was pickled in.
model.eval()
loss, auc = prediction_binary(model, test_loader, loss_fn, device)
print(auc)


0.8347456131625979
