## Global Settings and Imports

In [1]:
# Enable autoreload so that changes to the source of modules imported in this notebook are picked up automatically.
# Mode 2 reloads all modules before executing each cell.
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
import yaml
from dotmap import DotMap
from os import path
import numpy as np
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Trainer
from models.lstur import LSTUR
from models.nrms import NRMS
from models.naml import NAML
from models.naml_simple import NAML_Simple
from models.sentirec import SENTIREC
from models.robust_sentirec import ROBUST_SENTIREC
from data.dataset import BaseDataset
from tqdm import tqdm

## Prepare parameters

In [4]:
# Experiment config to load and the trained checkpoint to evaluate.
args = argparse.Namespace(
    config="config/model/nrms/exp1.yaml",
    ckpt="logs/lightning_logs/checkpoints/nrms/exp1/epoch=9-val_auc_epoch=0.6697.ckpt",
)

# safe_load only builds plain YAML types (mappings, lists, strings, numbers), which is
# all a config file needs, and never instantiates arbitrary Python objects.
# The explicit encoding avoids depending on the OS locale codec (e.g. on Windows).
with open(args.config, "r", encoding="utf-8") as ymlfile:
    config = DotMap(yaml.safe_load(ymlfile))

supported_models = ["lstur", "nrms", "naml", "naml_simple", "sentirec", "robust_sentirec"]
assert config.name in supported_models, (
    f"Unsupported model name {config.name!r}; expected one of {supported_models}"
)

# Fixed seed so evaluation runs are reproducible.
pl.seed_everything(1234)

logger = TensorBoardLogger(
    **config.logger
)

Seed set to 1234


## Load data

In [5]:
# Directory holding the preprocessed files for the configured dataset size
# (kept with a trailing separator, as other cells concatenate onto it).
preprocess_path = f"{config.preprocess_data_path}/{config.dataset_size}/"

# path.join takes the directory and the file name as separate arguments; passing an
# already-concatenated string made the join a no-op.
test_dataset = BaseDataset(
    path.join(preprocess_path, config.test_behavior),
    path.join(preprocess_path, config.test_news),
    config)
test_loader = DataLoader(
    test_dataset,
    **config.test_dataloader)

100%|██████████| 18723/18723 [00:00<00:00, 26686.43it/s]
100%|██████████| 7538/7538 [00:02<00:00, 2757.52it/s]


In [6]:
# Load the pre-trained word-embedding matrix: each line of the file becomes one row of floats.
with open(path.join(preprocess_path, config.embedding_weights), 'r') as file:
    lines = file.readlines()

# split() with no argument tolerates trailing spaces, tabs, '\r' and repeated separators;
# split(" ") turned those into tokens such as "\n" or "" that float() rejects.
embedding_weights = [
    [float(w) for w in line.split()]
    for line in tqdm(lines)
]
pretrained_word_embedding = torch.from_numpy(
    np.array(embedding_weights, dtype=np.float32)
)

100%|██████████| 42562/42562 [00:02<00:00, 19540.68it/s]


## Load model from checkpoint

In [7]:
print(config.name)

# Every supported model is restored from the checkpoint with the same arguments,
# so a name -> class table replaces the per-model copy-pasted branches.
# Register upcoming models by adding an entry here.
model_classes = {
    "lstur": LSTUR,
    "nrms": NRMS,
    "naml": NAML,
    "naml_simple": NAML_Simple,
    "sentirec": SENTIREC,
    "robust_sentirec": ROBUST_SENTIREC,
}

# Fail loudly instead of leaving `model` undefined, or silently reusing a `model`
# left in the kernel by an earlier run.
if config.name not in model_classes:
    raise ValueError(
        f"Unknown model name {config.name!r}; expected one of {sorted(model_classes)}"
    )

model = model_classes[config.name].load_from_checkpoint(
    args.ckpt,
    config=config,
    pretrained_word_embedding=pretrained_word_embedding,
)

nrms


## Test model

In [8]:
# Lightning trainer configured from the `trainer` section of the config, reporting to `logger`.
trainer = Trainer(logger=logger, **config.trainer)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [9]:
# Evaluate the restored model on the test loader; the returned list of metric dicts is shown as the cell output.
trainer.test(model=model, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\USER\anaconda3\envs\newsrec\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         Test metric                 DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auc_epoch             0.5560961365699768
test_ils_senti@10_bert_epoch      0.4856659173965454
test_ils_senti@10_vader_epoch     0.2443128377199173
 test_ils_senti@5_bert_epoch      0.47626766562461853
test_ils_senti@5_vader_epoch      0.25722575187683105
   test_ils_topic@10_epoch        0.09298563003540039
   test_ils_topic@5_epoch         0.12804388999938965
       test_mrr_epoch             0.24117176234722137
     test_ndcg@10_epoch           0.3258844017982483
      test_ndcg@5_epoch           0.25681784749031067
  test_senti@10_bert_epoch        0.3998955190181732
  test_senti@10_vader_epoch       0.02770446240901947
   test_senti@5_bert_epoch        0.26882338523864746
  test_sent

[{'test_auc_epoch': 0.5560961365699768,
  'test_mrr_epoch': 0.24117176234722137,
  'test_ndcg@10_epoch': 0.3258844017982483,
  'test_ndcg@5_epoch': 0.25681784749031067,
  'test_senti@10_vader_epoch': 0.02770446240901947,
  'test_senti@5_vader_epoch': 0.02272985316812992,
  'test_senti_mrr_vader_epoch': 0.01998109742999077,
  'test_senti@10_bert_epoch': 0.3998955190181732,
  'test_senti@5_bert_epoch': 0.26882338523864746,
  'test_senti_mrr_bert_epoch': 0.18272240459918976,
  'test_topic_div@10_epoch': 0.4310603439807892,
  'test_topic_div@5_epoch': 0.3462444245815277,
  'test_topic_mrr_epoch': 0.4286380708217621,
  'test_ils_senti@10_vader_epoch': 0.2443128377199173,
  'test_ils_senti@5_vader_epoch': 0.25722575187683105,
  'test_ils_senti@10_bert_epoch': 0.4856659173965454,
  'test_ils_senti@5_bert_epoch': 0.47626766562461853,
  'test_ils_topic@10_epoch': 0.09298563003540039,
  'test_ils_topic@5_epoch': 0.12804388999938965}]

In [1]:
# Inspect the first test sample. Requires `test_dataset` from the "Load data" section;
# the NameError recorded below comes from running this cell on a fresh kernel before that cell.
test_dataset[0]

NameError: name 'test_dataset' is not defined

In [30]:
# Inspect the 'c_title' field of test sample 11 (an arbitrary example index).
# The output is a 2-D tensor whose rows end in zeros, so rows are presumably zero-padded
# word-index sequences, and 'c_' presumably means candidate news. TODO confirm against BaseDataset.
test_dataset[11]['c_title']

tensor([[ 2679,  4349,  4726,  2851,   550,    29,  5067,   164,  3585, 10645,
          2250,  1471,     0,     0,     0,     0,     0,     0,     0,     0],
        [  887,   128,  3392,  1401,  1764,   126,  4545,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0]])