In [10]:
import torch
import torch.nn as nn
from tqdm import tqdm

from utils import Config, set_random_seed
from dataset import get_dataset, train_test_split, RIIIDDataset
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, accuracy_score
from model import SaintPlusTransformer

%load_ext autoreload
%autoreload 2

set_random_seed(0)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [25]:
config = Config('config.yaml')
config.__dict__

{'DATASOURCE': 'data/train_sample.csv',
 'TR_FRAC': 0.9,
 'MAX_LEN': 100,
 'BATCH_SIZE': 64}

In [26]:
data = get_dataset(config)
print(len(data))

980084


In [27]:
tr_data, va_data = train_test_split(data=data, config=config)

tr_data = RIIIDDataset(dataset=tr_data, config=config)
va_data = RIIIDDataset(dataset=va_data, config=config)

tr_dataloader = DataLoader(dataset=tr_data, batch_size=config.BATCH_SIZE, shuffle=True)
va_dataloader = DataLoader(dataset=va_data, batch_size=config.BATCH_SIZE, shuffle=True)

In [28]:
tr_data.__getitem__(123)

{'ex': tensor([   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0, 7901, 7877,  176,
         1279, 2066, 2065, 2064, 3364, 3366, 3365, 2949, 2947, 2948, 2594, 2595,
         2596, 4493, 4121, 4697, 6117, 6174, 6371, 6880, 6881, 6878, 6879, 7219,
         7218, 7217, 7220], dtype=torch.int32),
 'ac': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2,
         1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2

In [29]:
for tr_batch in tqdm(tr_dataloader):
    pass

100%|██████████| 54/54 [00:01<00:00, 34.56it/s]


# Train

In [37]:
data['content_id'].max()

13522

In [47]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SaintPlusTransformer(
    question_vocab_size=13523+1,
    answer_corr_vocab_size=2+1,
    max_len=config.MAX_LEN-1,
    embed_size=128,
    hidden_size=4*128,
    dropout=0.1,
    heads=8,
    N=1
).to(device)

In [48]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4, eps=1e-9)
loss_fn = nn.CrossEntropyLoss(ignore_index=0).to(device)

for epoch in range(1):
    
    tr_true = []
    tr_pred_label = []
    tr_pred_probs = []

    torch.cuda.empty_cache()
    model.train()
    batch_iterator = tqdm(tr_dataloader, desc=f"Processing Epoch {epoch:02d}")
    for tr_batch in batch_iterator:
        
        # enc
        ex = tr_batch['ex'].to(device) 

        # dec
        ac = tr_batch['ac'].to(device) 

        # label
        label = tr_batch['label'].to(device).long()

        # masks
        src_mask = tr_batch['src_mask'].to(device)
        tgt_mask = tr_batch['tgt_mask'].to(device)

        preds = model(
            src=ex, 
            tgt=ac,
            src_mask=src_mask, 
            tgt_mask=tgt_mask,
        )
        
        _, pred_labels = torch.max(preds[:,-1], dim=1)
        pred_probs = torch.softmax(preds[:,-1][:,1:3], dim=-1)
        
        loss = loss_fn(
            target=label.view(-1), # (batch_size * max_len)
            input=preds.view(-1, 3), # (batch_size * max_len, vocab_size)
        )

        batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

        # backpropagate the loss
        loss.backward()

        # update the weights
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        
        tr_true.extend(label[:,-1].cpu().detach().numpy())
        tr_pred_label.extend(pred_labels.cpu().detach().numpy())
        tr_pred_probs.extend(pred_probs[:,1].cpu().detach().numpy())
    
    print(f'TRAIN ACCURACY: {accuracy_score(tr_true, tr_pred_label):.4f}')
    print(f'TRAIN ROC-AUC: {roc_auc_score(tr_true, tr_pred_probs):.4f}')
    
    ########################
    ########################

    va_true = []
    va_pred_label = []
    va_pred_probs = []

    model.eval()
    with torch.no_grad():
        for va_batch in tqdm(va_dataloader):

            # enc
            ex = va_batch['ex'].to(device)

            # dec
            ac = va_batch['ac'].to(device) 

            # label
            label = va_batch['label'].to(device).long()

            # masks
            src_mask = va_batch['src_mask'].to(device)
            tgt_mask = va_batch['tgt_mask'].to(device)

            preds = model(
                src=ex, 
                tgt=ac,
                src_mask=src_mask, 
                tgt_mask=tgt_mask,
            )

            _, pred_labels = torch.max(preds[:,-1], dim=1)
            pred_probs = torch.softmax(preds[:,-1][:,1:3], dim=-1)

            va_true.extend(label[:,-1].cpu().detach().numpy())
            va_pred_label.extend(pred_labels.cpu().detach().numpy())
            va_pred_probs.extend(pred_probs[:,1].cpu().detach().numpy())

        print(f'VALID ACCURACY: {accuracy_score(va_true, va_pred_label):.4f}')
        print(f'VALID ROC-AUC: {roc_auc_score(va_true, va_pred_probs):.4f}')

Processing Epoch 00:   0%|          | 0/54 [00:00<?, ?it/s]

Processing Epoch 00: 100%|██████████| 54/54 [00:32<00:00,  1.65it/s, loss=0.697]


TRAIN ACCURACY: 0.5498
TRAIN ROCAUC: 0.5663


100%|██████████| 6/6 [00:01<00:00,  5.96it/s]

VALID ACCURACY: 0.6066
VALID ROCAUC: 0.6487



