## Global Settings and Imports

In [1]:
# Enable IPython's autoreload extension: local modules imported below
# (e.g. models.*, data.*) are re-imported automatically when their .py source
# changes, so edits show up here without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
import yaml
from dotmap import DotMap
from os import path
import numpy as np
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Trainer
from models.lstur import LSTUR
from models.nrms import NRMS
from models.naml import NAML
from models.naml_simple import NAML_Simple
from models.sentirec import SENTIREC
from models.robust_sentirec import ROBUST_SENTIREC
from data.dataset import BaseDataset
from tqdm import tqdm

## Prepare parameters

In [3]:
# Run parameters. There is no command line in a notebook, so an argparse.Namespace
# stands in for parsed CLI arguments.
args = argparse.Namespace(
    config="config/model/nrms/exp1.yaml",
    resume=None,  # path to a Lightning checkpoint to resume from, or None to train from scratch
)

# Model names the model-construction cell knows how to build.
SUPPORTED_MODELS = ("lstur", "nrms", "naml", "naml_simple", "sentirec", "robust_sentirec")

with open(args.config, 'r', encoding='utf-8') as ymlfile:
    config = DotMap(yaml.load(ymlfile, Loader=yaml.FullLoader))

# Fail fast with a readable message. A bare `assert` is skipped under `python -O`
# and does not say which name was wrong.
if config.name not in SUPPORTED_MODELS:
    raise ValueError(
        f"Unsupported model name {config.name!r} in {args.config}; "
        f"expected one of {SUPPORTED_MODELS}"
    )

pl.seed_everything(1234)

# TensorBoard logger configured entirely from the `logger` section of the YAML config.
logger = TensorBoardLogger(
    **config.logger
)

Seed set to 1234


In [4]:
# Model checkpointing; every option is taken from the `checkpoint` section of the YAML config.
checkpoint_callback = ModelCheckpoint(**config.checkpoint)

## Load data

In [5]:
# Preprocessed artefacts live under <preprocess_data_path>/<dataset_size>/.
# Keep the trailing slash: other cells build paths from this string.
preprocess_path = f"{config.preprocess_data_path}/{config.dataset_size}/"

# os.path.join needs the components as separate arguments. The original
# `path.join(preprocess_path + name)` concatenated first, so the join did nothing.
train_dataset = BaseDataset(
    path.join(preprocess_path, config.train_behavior),
    path.join(preprocess_path, config.train_news),
    config)
# NOTE(review): validation reuses config.train_news rather than a separate validation
# news file. Confirm that every news id referenced by val_behavior is present in it.
val_dataset = BaseDataset(
    path.join(preprocess_path, config.val_behavior),
    path.join(preprocess_path, config.train_news),
    config)

train_loader = DataLoader(train_dataset, **config.train_dataloader)
val_loader = DataLoader(val_dataset, **config.val_dataloader)

100%|██████████| 26740/26740 [00:01<00:00, 14749.41it/s]
100%|██████████| 28994/28994 [00:14<00:00, 2066.49it/s]
100%|██████████| 26740/26740 [00:01<00:00, 13958.46it/s]
100%|██████████| 2204/2204 [00:02<00:00, 1075.78it/s]


In [6]:
# Load the pre-trained word-embedding matrix: one vector per line, values separated by
# whitespace. Row order is presumably the word-index order used by the datasets
# (TODO: confirm against the preprocessing script).
embedding_path = path.join(preprocess_path, config.embedding_weights)
with open(embedding_path, 'r', encoding='utf-8') as file:
    lines = file.readlines()

# str.split() with no argument tolerates repeated or trailing spaces. split(" ") turns
# those into '' or '\n' tokens, which float() rejects.
embedding_weights = [
    [float(w) for w in line.split()]
    for line in tqdm(lines)
]

embedding_matrix = np.array(embedding_weights, dtype=np.float32)
# An empty file would otherwise silently become a 1-D tensor and fail much later.
if embedding_matrix.ndim != 2:
    raise ValueError(
        f"{embedding_path}: expected a 2-D embedding matrix, got shape {embedding_matrix.shape}"
    )
pretrained_word_embedding = torch.from_numpy(embedding_matrix)

100%|██████████| 42562/42562 [00:03<00:00, 13072.91it/s]


## Build model

In [7]:
# Map each config name to its model class. Every class is constructed with the same
# (config, pretrained_word_embedding) arguments.
MODEL_CLASSES = {
    "lstur": LSTUR,
    "nrms": NRMS,
    "naml": NAML,
    "naml_simple": NAML_Simple,
    "sentirec": SENTIREC,
    "robust_sentirec": ROBUST_SENTIREC,
}

# str() guards against a missing key: DotMap returns an empty DotMap rather than raising.
model_class = MODEL_CLASSES.get(str(config.name))
if model_class is None:
    # Without this, an unknown name left `model` undefined and surfaced later as a NameError.
    raise ValueError(
        f"Unknown model name {config.name!r}; expected one of {sorted(MODEL_CLASSES)}"
    )
model = model_class(config, pretrained_word_embedding)

## Train model

In [8]:
# Early stopping configured entirely from the `early_stop` section of the YAML config.
early_stop_callback = EarlyStopping(
    **config.early_stop
)

if args.resume is not None:
    # `load_from_checkpoint` is a classmethod that returns a *new* module. Recent Lightning
    # 2.x releases raise TypeError when it is called on an instance, so call it on the
    # model's class. The keyword arguments are forwarded to the model's __init__.
    # This restores weights only. Optimizer state and epoch/step counters are restored
    # by `trainer.fit(..., ckpt_path=...)`.
    model = type(model).load_from_checkpoint(
        args.resume,
        config=config,
        pretrained_word_embedding=pretrained_word_embedding)

# `resume_from_checkpoint` was removed from the Trainer constructor in Lightning 2.0 and
# now raises TypeError. Resuming is requested through `trainer.fit(ckpt_path=...)`, so a
# single Trainer construction serves both the fresh and the resumed case.
trainer = Trainer(
    **config.trainer,
    callbacks=[early_stop_callback, checkpoint_callback],
    logger=logger,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [9]:
# With `args.resume` set, `ckpt_path` makes Lightning restore the full training state from
# that checkpoint before continuing. This is the Lightning 2.x replacement for the removed
# `Trainer(resume_from_checkpoint=...)`. With None, training starts from the current weights.
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    ckpt_path=args.resume,
)

You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                                   | Type             | Params | Mode 
-------------------------------------------------------------------------------------
0  | news_encoder                           | TimeDistributed  | 13.2 M | train
1  | user_encoder                           | UserEncoder      | 421 K  | train
2  | val_performance_metrics                | MetricCollection | 0      | train
3  | val_sentiment_diversity_metrics_vader  | MetricCollection | 0      | train
4  | val_sentiment_diversity_metrics_bert   | MetricCollection | 0      | train
5  | test_perfor

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

c:\Users\nclud\anaconda3\envs\newsrec\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


                                                                           

c:\Users\nclud\anaconda3\envs\newsrec\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 0:  35%|███▌      | 159/453 [00:19<00:35,  8.25it/s, v_num=exp1]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined