In [1]:
import os
import torch
import time
import numpy as np
import pickle
from torch import nn
from logging import getLogger
from data import Vocab, NLP, S2SDataset
from utils import build_optimizer, init_seed, init_logger, init_device, read_configuration, collate_fn_graph_text, \
    format_time
from module import GraphEncoder, GraphReconstructor, GraphPointer
from transformers import BartTokenizer, BartForConditionalGeneration, BertModel, BertTokenizer
from torch.utils.data import Dataset, DataLoader

OSError: dlopen(/usr/local/anaconda3/envs/bishe/lib/python3.7/site-packages/torch_sparse/_convert_cpu.so, 6): Symbol not found: __ZN2at8internal13_parallel_runExxxRKNSt3__18functionIFvxxmEEE
  Referenced from: /usr/local/anaconda3/envs/bishe/lib/python3.7/site-packages/torch_sparse/_convert_cpu.so
  Expected in: /usr/local/anaconda3/envs/bishe/lib/python3.7/site-packages/torch/lib/libtorch_cpu.dylib
 in /usr/local/anaconda3/envs/bishe/lib/python3.7/site-packages/torch_sparse/_convert_cpu.so

In [None]:
def test(config):
    """Run generation on the test split and write generated/reference text pairs to disk.

    The function builds the node/relation vocabularies, restores the student
    ``GraphEncoder`` from the ``"student"`` entry of the checkpoint at
    ``config["external_model"]``, loads the fine-tuned BART model and tokenizer
    from ``config["fine_tuned_plm_dir"]``, and then, for every test batch,
    feeds the student's node embeddings to ``plm.generate`` as ``inputs_embeds``
    (beam search, 4 beams). Decoded generations and decoded references are
    written, two lines per example, to
    ``<output_dir>/<dataset>-<num_samples>.res``.

    Args:
        config: Mapping of run settings. It is passed as a whole to
            ``init_logger`` and ``init_device``; the keys read directly here
            are ``"seed"``, ``"reproducibility"``, ``"node_vocab"``,
            ``"relation_vocab"``, ``"gnn_layers"``, ``"embedding_size"``,
            ``"node_embedding"``, ``"external_model"``,
            ``"fine_tuned_plm_dir"``, ``"data_dir"``, ``"dataset"``,
            ``"test_batch_size"``, ``"max_seq_length"``, ``"num_samples"``
            and ``"output_dir"``.

    Returns:
        None. The only outputs are log records and the result file.

    Raises:
        AssertionError: If the number of generated texts differs from the
            number of reference texts.
    """
    init_logger(config)
    logger = getLogger()

    logger.info(config)
    init_seed(config["seed"], config["reproducibility"])
    device = init_device(config)

    # Resolve the output location up front and make sure it exists, so that a
    # missing directory fails fast instead of after the whole generation run.
    output_dir = config["output_dir"]
    saved_file = "{}-{}.res".format(config["dataset"], config["num_samples"])
    saved_file_path = os.path.join(output_dir, saved_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    logger.info("Build node and relation vocabularies.")
    vocabs = {
        "node": Vocab(config["node_vocab"]),
        "relation": Vocab(config["relation_vocab"]),
    }

    # NOTE: the teacher BART model is intentionally not built here; that code
    # path was disabled and generation only uses the student encoder + PLM.

    logger.info("Build Student Model.")
    student = GraphEncoder(vocabs["node"].size(), vocabs["relation"].size(),
                           config["gnn_layers"], config["embedding_size"], config["node_embedding"])
    # Load the checkpoint onto CPU first so that a checkpoint saved from a GPU
    # run can still be restored on a CPU-only machine; the model is moved to
    # the target device right after.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(config["external_model"], map_location="cpu")
    student.load_state_dict(checkpoint["student"])
    student.to(device)

    logger.info("Build PLM Model.")
    bart_tokenizer = BartTokenizer.from_pretrained(config["fine_tuned_plm_dir"])
    plm = BartForConditionalGeneration.from_pretrained(config["fine_tuned_plm_dir"])
    plm.to(device)

    logger.info("Create testing dataset.")
    test_dataset = S2SDataset(data_dir=config["data_dir"], dataset=config["dataset"],
                              tokenizer=bart_tokenizer, node_vocab=vocabs["node"],
                              relation_vocab=vocabs["relation"],
                              num_samples="all", usage="test")
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=config["test_batch_size"],
        shuffle=False,
        num_workers=4,
        drop_last=False,
        collate_fn=collate_fn_graph_text,
        pin_memory=True)

    student.eval()
    plm.eval()
    generated_text = []
    reference_text = []
    with torch.no_grad():
        for idx, batch in enumerate(test_dataloader, start=1):
            # The collate function yields 14 fields; only the graph inputs,
            # the node mask and the reference token ids are needed here.
            (nodes, edges, types, node_masks,
             _kd_description, _kd_description_masks, _kd_positions,
             _recon_relations, _recon_positions, _recon_masks,
             gen_outputs, _gen_masks, _pointer, _pointer_masks) = batch

            nodes = nodes.to(device)
            student_embeddings = student(nodes, edges, types)

            node_masks = node_masks.to(device)
            # The PLM is driven by embeddings instead of token ids.
            generated_ids = plm.generate(input_ids=None,
                                         inputs_embeds=student_embeddings,
                                         attention_mask=node_masks,
                                         num_beams=4,
                                         max_length=config["max_seq_length"],
                                         early_stopping=True)

            generated_text.extend(
                bart_tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
            reference_text.extend(
                bart_tokenizer.batch_decode(gen_outputs, skip_special_tokens=True))

            logger.info("Finish {}-th example.".format(idx))

    assert len(generated_text) == len(reference_text)

    # Context manager guarantees the handle is closed even if a write fails;
    # explicit UTF-8 avoids locale-dependent encode errors on non-ASCII text.
    with open(saved_file_path, "w", encoding="utf-8") as fout:
        for generated_line, reference_line in zip(generated_text, reference_text):
            fout.write("Generated text: " + generated_line.strip() + "\n")
            fout.write("Reference text: " + reference_line.strip() + "\n")
