In [None]:
from torch.utils.data import DataLoader
from dataset import BaseDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
from config import model_name
from tqdm import tqdm
import os
from pathlib import Path
from evaluate import evaluate
import importlib
import datetime
from view import Graph

# Resolve the model class and its config class from the name chosen in config.py:
# the class `model_name` is looked up in module `model.<model_name>`, and
# `<model_name>Config` is looked up in module `config`.
try:
    Model = getattr(importlib.import_module(f"model.{model_name}"), model_name)
    config = getattr(importlib.import_module('config'), f"{model_name}Config")
    print(f"model : {model_name}")
except (ImportError, AttributeError) as error:
    # ImportError covers a missing model/<model_name>.py (ModuleNotFoundError is a
    # subclass); AttributeError covers a module that lacks the expected class.
    print(f"{model_name} not included!")
    # Show the underlying cause so that an import failure *inside* the model
    # module is not mistaken for an unknown model name.
    print(f"Reason: {error}")
    exit()

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print("mode : GPU_mode" if device.type == "cuda" else "mode : CPU_mode")

In [3]:
class EarlyStopping:
    """Track a validation loss and report when it has stopped improving.

    Call the instance once per validation round with the latest loss
    (lower is better). A loss strictly below the best one seen so far counts
    as an improvement and resets the patience counter; any other value
    increments it.

    Args:
        patience: number of consecutive non-improving calls after which
            ``early_stop`` is reported as True.
    """

    def __init__(self, patience=5):
        self.patience = patience
        # Consecutive calls without improvement since the last best loss.
        self.counter = 0
        # float("inf") instead of np.Inf: that alias was removed in NumPy 2.0
        # and raises AttributeError there. The value and type are identical.
        self.best_loss = float("inf")

    def __call__(self, val_loss):
        """Record ``val_loss`` and return ``(early_stop, get_better)``.

        Returns:
            early_stop: True once ``patience`` consecutive calls have failed to
                beat the best loss.
            get_better: True when ``val_loss`` is a new best.
        """
        # bool() keeps the return type a plain Python bool even when
        # val_loss is a NumPy scalar.
        get_better = bool(val_loss < self.best_loss)
        if get_better:
            self.counter = 0
            self.best_loss = val_loss
            return False, True

        self.counter += 1
        early_stop = self.counter >= self.patience
        return early_stop, False

In [4]:
def time_since(since):
    """
    Format the time elapsed since ``since`` as an ``HH:MM:SS`` string.

    ``since`` is a timestamp as returned by ``time.time()``. Fractions of a
    second are dropped. The hour field is not wrapped at 24, so a run of
    25 hours is reported as ``25:00:00``. A ``since`` that lies in the
    future (e.g. after a clock adjustment) is reported as ``00:00:00``.
    """
    elapsed_seconds = max(int(time.time() - since), 0)
    hours, remainder = divmod(elapsed_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}"

In [5]:
# Module-level histories that train() appends to.

# Validation metrics: train() adds one entry to each per validation round.
AUC, MRR = [], []
NDCG_5, NDCG_10 = [], []
# NOTE(review): HIT_10 is never written anywhere in the visible code.
HIT_10 = []

# Bookkeeping recorded alongside each validation round.
epochs, steps = [], []

# Training-loss samples, one entry each time train() reports the loss.
batches = []
train_avg_loss, train_recent_loss = [], []

In [6]:
def _new_batch_iterator(dataset):
    """Return a fresh, reshuffled iterator over full-size training minibatches."""
    return iter(
        DataLoader(dataset,
                   batch_size=config.batch_size,
                   shuffle=True,
                   num_workers=config.num_workers,
                   drop_last=True,
                   pin_memory=True))


def _validate_and_record(model, epoch, step, batch_index, start_time):
    """Evaluate on '../data/val', append the metrics to the module-level histories, return the AUC."""
    model.eval()
    # NOTE(review): 200000 is passed positionally; presumably a cap on the
    # number of validation samples — confirm against evaluate().
    val_auc, val_mrr, val_ndcg5, val_ndcg10 = evaluate(
        model, '../data/val', config.num_workers, 200000)
    model.train()

    # Append only after evaluate() succeeded so the histories stay aligned.
    epochs.append(epoch)
    steps.append(step)
    AUC.append(val_auc)
    MRR.append(val_mrr)
    NDCG_5.append(val_ndcg5)
    NDCG_10.append(val_ndcg10)
    tqdm.write(
        f"Time {time_since(start_time)}, batches {batch_index}, validation AUC: {val_auc:.4f}, validation MRR: {val_mrr:.4f}, validation nDCG@5: {val_ndcg5:.4f}, validation nDCG@10: {val_ndcg10:.4f}"
    )
    return val_auc


def _save_checkpoint(model, path):
    """Write the model's state_dict to ``path``, creating the parent directory if needed."""
    try:
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), path)
        # tqdm.write keeps the progress bar intact; the text is unchanged.
        tqdm.write("Saved successfully!!!")
    except OSError as error:
        tqdm.write(f"OS error: {error}")


def train(best_model_path="./path/LSTUR/〇〇.pth"):
    """Train the configured model and checkpoint the best validation AUC.

    The loop runs ``config.num_epochs * len(dataset) // config.batch_size``
    minibatches. Each time the training iterator is exhausted the model is
    validated, the metrics are appended to the module-level histories
    (``epochs``, ``steps``, ``AUC``, ``MRR``, ``NDCG_5``, ``NDCG_10``), early
    stopping is consulted, and the weights are saved when the AUC improved.
    Every ``config.num_batches_show_loss`` batches the running loss is logged
    and appended to ``batches``, ``train_avg_loss`` and ``train_recent_loss``.

    Args:
        best_model_path: where the best model's ``state_dict`` is written.

    Raises:
        NotImplementedError: if ``model_name`` is not ``'LSTUR'``, the only
            model whose forward call is implemented here.
    """
    # Fail before the expensive loading instead of hitting an unbound
    # ``y_pred`` on the first training step.
    if model_name != 'LSTUR':
        raise NotImplementedError(
            f"train() only implements the forward call for LSTUR, got {model_name}")

    try:
        pretrained_word_embedding = torch.from_numpy(
            np.load('../data/train/pretrained_word_embedding.npy')).float()
    except FileNotFoundError:
        # No pretrained embedding on disk: the model is built without one.
        pretrained_word_embedding = None

    model = Model(config, pretrained_word_embedding).to(device)

    print(model)

    dataset = BaseDataset('../data/train/behaviors_parsed.tsv',
                          '../data/train/news_parsed.tsv')

    print(f"Load training dataset with size {len(dataset)}.")

    dataloader = _new_batch_iterator(dataset)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    start_time = time.time()
    loss_full = []
    step = 0
    epoch = 0
    last_validated_step = 0
    stopped_early = False
    early_stopping = EarlyStopping()
    total_batches = config.num_epochs * len(dataset) // config.batch_size

    for i in tqdm(range(1, total_batches + 1), desc="Training"):
        try:
            minibatch = next(dataloader)
        except StopIteration:
            # One full pass over the data is done: validate, then reshuffle.
            epoch += 1
            tqdm.write(f"Training data exhausted for {epoch} times after {i} batches, reuse the dataset.")

            val_auc = _validate_and_record(model, epoch, step, i, start_time)
            last_validated_step = step

            # EarlyStopping minimises its input, so feed it the negated AUC.
            early_stop, get_better = early_stopping(-val_auc)
            if early_stop:
                tqdm.write('Early stop.')
                stopped_early = True
                break
            if get_better:
                _save_checkpoint(model, best_model_path)

            dataloader = _new_batch_iterator(dataset)
            minibatch = next(dataloader)

        step += 1

        y_pred = model(minibatch["user"], minibatch["clicked_news_length"],
                       minibatch["candidate_news"], minibatch["clicked_news"])

        # Every row's target class is index 0 — presumably the dataset puts the
        # clicked candidate first; confirm against BaseDataset.
        y = torch.zeros(len(y_pred), dtype=torch.long, device=device)
        loss = criterion(y_pred, y)
        loss_value = loss.item()
        loss_full.append(loss_value)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % config.num_batches_show_loss == 0:
            average_loss = np.mean(loss_full)
            batches.append(i)
            train_avg_loss.append(average_loss)
            # NOTE(review): the plotted "recent" loss averages the last 100
            # batches while the log line below averages the last 256 —
            # confirm which window is intended and unify.
            train_recent_loss.append(np.mean(loss_full[-100:]))
            tqdm.write(
                f"Time {time_since(start_time)}, batches {i}, current loss {loss_value:.4f}, average loss: {average_loss:.4f}, latest average loss: {np.mean(loss_full[-256:]):.4f}"
            )

    # The in-loop validation only fires when the *next* epoch starts, so the
    # final epoch would otherwise never be validated or checkpointed.
    if (not stopped_early and epoch < config.num_epochs
            and step > last_validated_step):
        epoch += 1
        val_auc = _validate_and_record(model, epoch, step, step, start_time)
        _, get_better = early_stopping(-val_auc)
        if get_better:
            _save_checkpoint(model, best_model_path)

In [None]:
# Entry point: announce the configured model, then run the training loop.
if __name__ == "__main__":
    print(f'Training model {model_name}')
    train()

# Visualization of results: hand the histories filled in by train() to Graph.
Graph(
    AUC,
    NDCG_5,
    NDCG_10,
    MRR,
    epochs,
    train_recent_loss,
    batches,
)