In [1]:
import datetime
import logging
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from data_utils import get_dataloader
from model.lstm import LSTM
from utils import seed_torch, print_metrics_binary
# Call the project's seeding helper before any loader or model is built.
# NOTE(review): presumably fixes the RNG seeds for reproducibility — confirm in utils.seed_torch.
seed_torch()

In [2]:
# Load the pre-built dataset dictionary; it is indexed below by 'train', 'val' and 'test'.
# The context manager closes the file handle as soon as the pickle has been read.
# NOTE(review): pickle.load can execute arbitrary code — only load files from a trusted source.
with open('/home/common/mover_data/surginf_cleaned/dataset_dict.pkl', 'rb') as dataset_file:
    dataset_dict = pickle.load(dataset_file)


batch_size = 512
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# logging.FileHandler does not create missing directories, so make sure ./log exists
# before the handler tries to open its file.
os.makedirs('./log', exist_ok=True)

# Log to both the notebook output and a timestamped file under ./log.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    handlers=[logging.StreamHandler(),
                              logging.FileHandler("./log/train_lstm_{}.log".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S")))])

# Only the training loader is shuffled; validation and test keep their stored order.
train_loader = get_dataloader(dataset_dict['train'], batch_size=batch_size, shuffle=True)
val_loader = get_dataloader(dataset_dict['val'], batch_size=batch_size, shuffle=False)
test_loader = get_dataloader(dataset_dict['test'], batch_size=batch_size, shuffle=False)

# Number of batches per split (train, val, test).
logging.info('%d, %d, %d', len(train_loader), len(val_loader), len(test_loader))

2023-09-08 02:22:39,950 - INFO - 78, 12, 23


In [3]:
# Peek at the first training batch and print the shapes of its first three
# elements (the training loop unpacks a batch as X, y, mask), then stop.
for batch in train_loader:
    print(*(batch[idx].shape for idx in range(3)))
    break

torch.Size([512, 980, 34]) torch.Size([512]) torch.Size([512, 980])


In [4]:
# Hyperparameters. input_size equals the last dimension of the sample batch
# printed above (torch.Size([512, 980, 34])).
input_size, hidden_size = 34, 64
learning_rate = 1e-3

# Build the network, then place it on the selected device.
model = LSTM(input_dim=input_size, hidden_dim=hidden_size)
model = model.to(device)

# Adam updates every model parameter; the loss is supplied by the model itself.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = model.get_loss


In [None]:
epochs = 100
best_score = 0  # best validation AUPRC seen so far

# torch.save does not create missing directories; without this the first
# checkpoint write would fail only after a full epoch of training.
os.makedirs('weights', exist_ok=True)

for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    for i, batch in enumerate(train_loader):
        X, y, mask = batch
        X = X.to(device)
        y = y.to(device).unsqueeze(-1)  # add a trailing dimension to the labels for the loss
        mask = mask.to(device)
        optimizer.zero_grad()
        loss = loss_fn(X, y, mask)

        loss.backward()
        optimizer.step()
        # Report the running loss every 10 steps.
        if i % 10 == 0:
            logging.info('Epoch: {}, Step: {}/{}, Loss: {}'.format(epoch, i, len(train_loader), loss.item()))

    # ---- validation pass ----
    model.eval()
    with torch.no_grad():
        y_pred = []
        y_true = []
        for batch in tqdm(val_loader):
            X, y, mask = batch
            X = X.to(device)
            mask = mask.to(device)
            # Move each batch of predictions to the CPU right away so the whole
            # validation set is not held in device memory until the end.
            y_pred.append(model(X, mask).cpu())
            # Labels are only needed for the metrics, so they are not sent to the device.
            y_true.append(y)
        y_pred = torch.cat(y_pred, dim=0).squeeze().cpu().numpy()
        y_true = torch.cat(y_true, dim=0).squeeze().cpu().numpy()
        ret = print_metrics_binary(y_true, y_pred, verbose=0)
        # Keep only the checkpoint with the best validation AUPRC.
        if ret['auprc'] > best_score:
            best_score = ret['auprc']
            torch.save(model.state_dict(), 'weights/lstm_model.pt')
            logging.info('Saved model')

2023-09-08 02:22:48,670 - INFO - Epoch: 0, Step: 0/78, Loss: 0.7483769655227661


In [None]:
# Restore the checkpoint written by the training loop. map_location remaps the
# stored tensors onto the current device, so a checkpoint saved on a GPU still
# loads on a CPU-only machine.
model.load_state_dict(torch.load('weights/lstm_model.pt', map_location=device))
model.eval()
with torch.no_grad():
    y_pred = []
    y_true = []
    for batch in tqdm(test_loader):
        X, y, mask = batch
        X = X.to(device)
        mask = mask.to(device)
        # Move each batch of predictions to the CPU right away so the whole
        # test set is not held in device memory until the end.
        y_pred.append(model(X, mask).cpu())
        # Labels are only needed for the metrics, so they are not sent to the device.
        y_true.append(y)
    y_pred = torch.cat(y_pred, dim=0).squeeze().cpu().numpy()
    y_true = torch.cat(y_true, dim=0).squeeze().cpu().numpy()
    ret = print_metrics_binary(y_true, y_pred, verbose=1)
    # The previous '{}' % ret never interpolated the metrics: '{}' is a
    # str.format placeholder, not a %-style one, so the literal '{}' was logged
    # (or a TypeError raised for a non-mapping). Pass ret as a logging argument.
    logging.info('%s', ret)