In [1]:
import argparse
import datetime
import os
import shutil
import time
import random

import dgl
import numpy as np
import torch
from rouge import Rouge

from HiGraph import HSumGraph, HSumDocGraph
from Tester import SLTester
from module.dataloader import ExampleSet, MultiExampleSet, graph_collate_fn
from module.embedding import Word_Embedding
from module.vocabulary import Vocab
from tools.logger import *

  from .autonotebook import tqdm as notebook_tqdm


### Load the model from a checkpoint

In [2]:
# Pre-trained weights (a state_dict) that are loaded into the model further below.
checkpoint_path: str = "checkpoints/multinews.ckpt"

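For reference, a command-line evaluation run with `evaluation.py` is shown below. Note that it uses a different embedding file (`glove/glove.6B.100d.txt`), `--test_model multi` and `--use_pyrouge`. The `args` in the next cell instead use `glove/glove.42B.300d.txt`, `test_model='evalbestmodel'` and `use_pyrouge=False`.

```bash
python evaluation.py --data_dir datasets/multinews --cache_dir cache/MultiNews --embedding_path glove/glove.6B.100d.txt --model HDSG --save_root save/ --log_root log/ -m 3 --test_model multi --use_pyrouge
```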

In [3]:
# Run configuration, built as a Namespace so the field names line up with the
# command-line flags of evaluation.py (e.g. data_dir, cache_dir, model, -m).
# NOTE(review): embedding_path / test_model / use_pyrouge differ from the
# reference CLI command in the markdown cell above - confirm which is intended.
args = argparse.Namespace(
    data_dir='datasets/multinews',
    cache_dir='cache/MultiNews',
    embedding_path='glove/glove.42B.300d.txt',
    model="HDSG",
    test_model='evalbestmodel',
    use_pyrouge=False,
    save_root='save/',
    log_root='log/',
    gpu='0',
    cuda=False,
    vocab_size=50000,
    batch_size=32,
    n_iter=1,
    word_embedding=True,
    word_emb_dim=300,
    embed_train=False,
    feat_embed_size=50,
    n_layers=1,
    lstm_hidden_state=128,
    lstm_layers=2,
    bidirectional=True,
    n_feature_size=128,
    hidden_size=64,
    gcn_hidden_size=128,
    ffn_inner_hidden_size=512,
    n_head=8,
    recurrent_dropout_prob=0.1,
    atten_dropout_prob=0.1,
    ffn_dropout_prob=0.1,
    use_orthnormal_init=True,
    sent_max_len=100,
    doc_max_timesteps=50,
    save_label=False,
    limited=False,
    blocking=False,
    m=3
)

# File paths (DATA_FILE, FILTER_WORD and LOG_PATH are not used by the
# load/re-save steps below; they are kept for parity with the evaluation setup).
DATA_FILE = os.path.join(args.data_dir, "test.label.jsonl")
VOCAB_FILE = os.path.join(args.cache_dir, "vocab")
VOCAL_FILE = VOCAB_FILE  # original (misspelled) name, kept so existing references still work
FILTER_WORD = os.path.join(args.cache_dir, "filter_word.txt")
LOG_PATH = args.log_root

logger.info("Pytorch %s", torch.__version__)
logger.info("[INFO] Create Vocab, vocab path is %s", VOCAB_FILE)
vocab = Vocab(VOCAB_FILE, args.vocab_size)

# Embedding table of shape (vocab.size(), word_emb_dim); it keeps its random
# initialisation unless pretrained vectors are copied in below.
embed = torch.nn.Embedding(vocab.size(), args.word_emb_dim)
if args.word_embedding:
    embed_loader = Word_Embedding(args.embedding_path, vocab)
    vectors = embed_loader.load_my_vecs(args.word_emb_dim)
    pretrained_weight = embed_loader.add_unknown_words_by_avg(vectors, args.word_emb_dim)
    pretrained = torch.as_tensor(pretrained_weight, dtype=embed.weight.dtype)
    # Tensor.copy_ broadcasts its source, so a mis-shaped matrix could be copied
    # silently. Fail loudly instead; presumably a mismatch means word_emb_dim
    # does not match the vector size of the embedding file.
    if pretrained.shape != embed.weight.shape:
        raise ValueError(
            f"pretrained embedding shape {tuple(pretrained.shape)} does not match "
            f"embedding table shape {tuple(embed.weight.shape)}"
        )
    # Write into the parameter under no_grad rather than via the legacy `.data`.
    with torch.no_grad():
        embed.weight.copy_(pretrained)
    embed.weight.requires_grad = args.embed_train

# The model classes and later cells refer to the configuration as `hps`.
hps = args

2024-09-06 16:42:03,449 INFO    : Pytorch 1.12.0
2024-09-06 16:42:03,451 INFO    : [INFO] Create Vocab, vocab path is cache/MultiNews/vocab
2024-09-06 16:42:03,488 INFO    : [INFO] max_size of vocab was specified as 50000; we now have 50000 words. Stopping reading.
2024-09-06 16:42:03,489 INFO    : [INFO] Finished constructing vocabulary of 50000 total words. Last word added: medicated
2024-09-06 16:42:03,658 INFO    : [INFO] Loading external word embedding...
2024-09-06 16:43:23,395 INFO    : [INFO] External Word Embedding iov count: 48908, oov count: 1092


In [4]:
# Build the document-level graph model (HSumDocGraph; args.model is "HDSG") from the
# config and the embedding table prepared above. Its freshly initialised weights are
# overwritten by the checkpoint in the next cell.
model = HSumDocGraph(hps, embed)

In [5]:
# Load the checkpoint
# torch.load unpickles the file, so only load checkpoints from a trusted source.
# Tensors are mapped onto the CPU because the model above was built without being
# moved to a GPU (hps.cuda is False); pass a CUDA device instead when running on GPU.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
# Strict loading (the default) raises on missing/unexpected keys; the returned key
# report is left as the cell's last expression so it is displayed.
model.load_state_dict(checkpoint)

<All keys matched successfully>

### Save the model

Write the loaded weights to `save/eval/` as `bestmodel_0`, `bestmodel_1` and `bestmodel_2`.

In [6]:
def save_model(model, save_file):
    """Write ``model.state_dict()`` to ``save_file`` using ``torch.save``.

    The parent directory of ``save_file`` is created first if it does not
    exist, so saving also works on a fresh checkout where e.g. ``save/eval/``
    has not been created yet.

    Args:
        model: a ``torch.nn.Module`` whose ``state_dict()`` is serialized.
        save_file: destination path; an existing file is overwritten.
    """
    parent_dir = os.path.dirname(save_file)
    if parent_dir:
        # open(..., 'wb') raises FileNotFoundError if the directory is missing.
        os.makedirs(parent_dir, exist_ok=True)
    with open(save_file, 'wb') as f:
        torch.save(model.state_dict(), f)
    logger.info('[INFO] Saving model to %s', save_file)

In [8]:
# Output directory for the re-exported weights, derived from the configured
# save_root instead of being hard-coded (the trailing "" keeps the original
# trailing separator: 'save/' -> 'save/eval/').
eval_dir = os.path.join(hps.save_root, "eval", "")
# Number of bestmodel_<i> copies to write.
# NOTE(review): presumably matches what `evaluation.py --test_model multi`
# (see the reference command above) loads - confirm against evaluation.py.
N_BEST_MODEL_COPIES = 3

# Create the directory up front so the saves do not fail on a fresh checkout.
os.makedirs(eval_dir, exist_ok=True)
for i in range(N_BEST_MODEL_COPIES):
    save_file = os.path.join(eval_dir, f"bestmodel_{i}")
    save_model(model, save_file)

2024-09-06 16:57:49,153 INFO    : [INFO] Saving model to save/eval/bestmodel_0
2024-09-06 16:57:49,197 INFO    : [INFO] Saving model to save/eval/bestmodel_1
2024-09-06 16:57:49,241 INFO    : [INFO] Saving model to save/eval/bestmodel_2
