In [1]:
import sys
from pathlib import Path
from typing import List
%config Completer.use_jedi = False # fix autocomplete nor working

project_dir = Path('/home/al.thomas/sync/development/data2text/')
sys.path.insert(0, str(project_dir))
from hdfs_utils import copy_from_hdfs_to_local, copy_from_local_to_hdfs

def copy_artifacts(run_id: str, source_dir: str, model_checkpoints: List[str], source_files: List[str]):
    """
    run_id: mlflow run id
    source_dir: name of the directory containing source files in mlflow artifacts directory (on hdfs)
    model_checkpoints: file names of the checkpoints to load
    source_files: files useful in this notebook (source files, config, etc)
    """
    
    # copy mlflow model checkpoints to local
    artifact_path = f'viewfs:///user/al.thomas/mlflow_artifacts/{run_id}/artifacts'
    checkpoints = []
    for m in model_checkpoints:
        checkpoints += copy_from_hdfs_to_local(f'{artifact_path}/{m}', str(project_dir / f'models/{run_id}'))
    print('Copied checkpoints:\n' + '\n'.join(checkpoints))

    # copy source code used to train model to local
    copied_files = []
    for f in source_files:
        copied_files += copy_from_hdfs_to_local(artifact_path + f'/{source_dir}/{f}', str(project_dir / f'models/{run_id}/artifact_code'))
    print('Copied source files:\n' + '\n'.join(copied_files))

def get_relations_as_plain_text(ex):
    """
    ex: one sample of batch_raw, containing raw relations and entity text
    """
    # list of entities in the sentence as plain text
    ents = [[y for y in vocab["entity"](x) if y[0] != "<"] for x in ex["ent_text"]]
    
    # convert relations from indices (entity number, relation tokens) to text
    rel_text = [(ents[e1], vocab["relation"](r), ents[e2]) for e1, r, e2 in ex["raw_relation"]]
    rel_text = [f"{' '.join(e1)} -> {r} -> {' '.join(e2)}" for e1, r, e2 in rel_text]

    return rel_text

def get_pz_qz(model, batch):
    """
    model: g2t model
    batch: g2t batch, with graph, corresponding entities/relations, and target text
    
    return: (mu_p, log_sigma_p, mu_q, log_sigma_q), parameters of the conditional 
        prior and variational posterior
    """
    def len2mask(lens, device):
        max_len = max(lens)
        mask = torch.arange(max_len, device=device).unsqueeze(0).expand(len(lens), max_len)
        mask = mask >= torch.LongTensor(lens).to(mask).unsqueeze(1)
        return mask

    # get the embedding of the graph g_root
    ent_mask = len2mask(batch["ent_len"], batch["ent_text"].device)
    ent_text_mask = batch["ent_text"] == 0  # (sum(num_ent_i), max_ent_len)
    rel_mask = batch["rel"] == 0  # (bs, max_num_rel), 0 means the <PAD>
    g_ent, g_root, ent_enc = model.enc_forward(
        batch, ent_mask, ent_text_mask, batch["ent_len"], rel_mask
    )  # (bs, max_num_ent, d), except for g_root which is missing the 1st dim
    
    # get text embedding tar_inp, with entities (bs, max_sent_len, d)
    device = torch.device("cpu")
    outs = []
    _mask = (batch["text"] >= len(model.text_vocab)).long()  # 0 if token is in vocab, 1 if entity or unknown
    _inp = (_mask * 3 + (1.0 - _mask) * batch["text"])  # 3 is <UNK>, otherwise use token index
    tar_inp = model.tar_emb(_inp.long())
    # embeddings for tokens in text vocab (0. if unknown or entity)
    embeddings_text = (1.0 - _mask[:, :, None]) * tar_inp  # (bs, max_sent_len, d)
    # embeddings for entity tokens (0. elsewhere)
    embeddings_ent = ent_enc[
        torch.arange(len(batch["text"]))[:, None].to(device),
        ((batch["text"] - len(model.text_vocab)) * _mask).long(),  # 0 for ENT_0 and other tokens, i for ENT_i
    ]  # (bs, max_sent_len, d)
    embeddings_ent = (embeddings_ent * _mask[:, :, None])  # set to 0. if not entity
    tar_inp = embeddings_text + embeddings_ent
    
    # get prior p(z|y)
    mu_p, log_sigma_p = model.get_vae_pz(g_root)
    std_p = torch.exp(0.5*log_sigma_p)  # log_sigma is actually log(sigma**2) in cycleGT code
    
    # get variational posterior q(z|x)
    mu_q, log_sigma_q = model.get_vae_qz(tar_inp)
    std_q = torch.exp(0.5*log_sigma_q)
    
    return mu_p, std_p, mu_q, std_q

# With original CycleGT code

In [2]:
run_id = '98491bf7557e4c93bdde85ceff3fb098'
g2t_checkpoint = 'g2t_model.pt_best_ep49'
artifact_dir = project_dir / f'models/{run_id}'
copy_artifacts(run_id,  source_dir='code', 
               model_checkpoints = [g2t_checkpoint],
               source_files=['g2t_model.py', 'main.py', 'data.py', 
                             'config.yaml', 'tmp_vocab.pt', 
                             'train.json', 'dev.json', 'test.json'])
# and add the source code to python path
sys.path.insert(0, str(project_dir / f'models/{run_id}'))

Copied checkpoints:
/home/al.thomas/sync/development/data2text/models/98491bf7557e4c93bdde85ceff3fb098/g2t_model.pt_best_ep49
Copied source files:
/home/al.thomas/sync/development/data2text/models/98491bf7557e4c93bdde85ceff3fb098/artifact_code/g2t_model.py
/home/al.thomas/sync/development/data2text/models/98491bf7557e4c93bdde85ceff3fb098/artifact_code/main.py
/home/al.thomas/sync/development/data2text/models/98491bf7557e4c93bdde85ceff3fb098/artifact_code/data.py
/home/al.thomas/sync/development/data2text/models/98491bf7557e4c93bdde85ceff3fb098/artifact_code/config.yaml
/home/al.thomas/sync/development/data2text/models/98491bf7557e4c93bdde85ceff3fb098/artifact_code/tmp_vocab.pt
/home/al.thomas/sync/development/data2text/models/98491bf7557e4c93bdde85ceff3fb098/artifact_code/train.json
/home/al.thomas/sync/development/data2text/models/98491bf7557e4c93bdde85ceff3fb098/artifact_code/dev.json
/home/al.thomas/sync/development/data2text/models/98491bf7557e4c93bdde85ceff3fb098/artifact_code/tes

In [3]:
from artifact_code.g2t_model import GraphWriter
from artifact_code.main import prep_data, write_txt, eval_g2t
from artifact_code.data import batch2tensor_g2t
from itertools import islice
import copy
import yaml
import torch
from torch.distributions import Normal

config = yaml.safe_load(open(artifact_dir/'artifact_code/config.yaml', "r"))
config["main"]["train_file"] = str(artifact_dir/'artifact_code/train.json')
config["main"]["dev_file"] = str(artifact_dir/'artifact_code/dev.json')
config["main"]["test_file"] = str(artifact_dir/'artifact_code/train.json')
dim_z = config["g2t"]["vae_dim"]

Using backend: pytorch
INFO:root:Start Logging
INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/al.thomas/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


In [4]:
# load data
pool, vocab = prep_data(config["main"], load=str(artifact_dir/'artifact_code/tmp_vocab.pt'))

INFO:root:MAX_LEN 31


In [5]:
# load model
model = GraphWriter(copy.deepcopy(config["g2t"]), vocab)
model.load_state_dict(torch.load(artifact_dir/ g2t_checkpoint))

<All keys matched successfully>

In [6]:
def get_pred(m, voc, batch, z):
    pred = m(batch, beam_size=5, vae_z=z)
    pred = write_txt(batch, pred, voc["text"])
    return pred

def sample_predictions(batch, batch_raw):
    # input graph and text ground truth
    graph = [get_relations_as_plain_text(ex) for ex in batch_raw]
    target = write_txt(batch, batch["tgt"], vocab["text"])

    # predictions using prior
    pred_mu_p = get_pred(model, vocab, batch, p_z.loc)
    preds_p = []
    for z in p_z.sample((10,)):
        preds_p.append(get_pred(model, vocab, batch, z))

    # predicitons using posterior
    pred_mu_q = get_pred(model, vocab, batch, q_z.loc)
    preds_q = []
    for z in q_z.sample((10,)):
        preds_q.append(get_pred(model, vocab, batch, z))

    
    for i in range(len(target)):
        print(f"-----------\n{i}")
        print(f"Graph y:\n\t{graph[i]}")
        print(f"Target text x:\n\t{target[i]}")
        print(f"Predictions with z=mu_p: \n\t{pred_mu_p[i]}")
        print("Predictions with samples z~p(z|y) from prior:")
        for pred in preds_p:
            print(f"\t{pred[i]}")
        print(f"Predictions with z=mu_q: \n\t{pred_mu_q[i]}")
        print("Predictions with samples z~q(z|x) from posterior:")
        for pred in preds_q:
            print(f"\t{pred[i]}")

In [7]:
b = 15  # batch

batch_raw = next(islice(pool.draw_with_type(batch_size=32, shuffle=False, _type="dev"), b, b+1))
batch = batch2tensor_g2t(batch_raw, 'cpu', vocab)

mu_p, std_p, mu_q, std_q = get_pz_qz(model, batch)
p_z = Normal(mu_p, std_p)
q_z = Normal(mu_q, std_q)

sample_predictions(batch, batch_raw)

-----------
0
Graph y:
	["American Journal of Mathematics -> impactFactor -> `` 1.337 ''"]
Target text x:
	["the impact factor of the  American Journal of Mathematics  is  `` 1.337 ''  ."]
Predictions with z=mu_p: 
	["the  American Journal of Mathematics  is an impact factor of  `` 1.337 ''  ."]
Predictions with samples z~p(z|y) from prior:
	["the  American Journal of Mathematics  is published in the year  `` 1.337 ''  ."]
	["the  American Journal of Mathematics  is in the city of  `` 1.337 ''  ."]
	["the  American Journal of Mathematics  has an impact factor of  `` 1.337 ''  ."]
	["the  American Journal of Mathematics  is published in  `` 1.337 ''  ."]
	["the  American Journal of Mathematics  has an impact factor of  `` 1.337 ''  ."]
	["the  American Journal of Mathematics  has an impact factor of  `` 1.337 ''  ."]
	["the  American Journal of Mathematics  is based in  `` 1.337 ''  ."]
	["the  American Journal of Mathematics  has an impact factor of  `` 1.337 ''  ."]
	["the impact fact

In [88]:
b = 15  # batch

batch_raw = next(islice(pool.draw_with_type(batch_size=32, shuffle=False, _type="dev"), b, b+1))
batch = batch2tensor_g2t(batch_raw, 'cpu', vocab)

mu_p, std_p, mu_q, std_q = get_pz_qz(model, batch)
p_z = Normal(mu_p, std_p)
q_z = Normal(mu_q, std_q)

sample_predictions(batch, batch_raw)

-----------
0

Graph y:
	["American Journal of Mathematics -> impactFactor -> `` 1.337 ''"]
Target text x:
	["the impact factor of the  American Journal of Mathematics  is  `` 1.337 ''  ."]
Predictions with z=mu_p: 
	["the  American Journal of Mathematics  has an impact factor of  `` 1.337 ''  ."]
Predictions with samples z~p(z|y) from prior:
	["the  American Journal of Mathematics  has an impact factor of  `` 1.337 ''  ."]
	["the  American Journal of Mathematics  has an impact factor of  `` 1.337 ''  ."]
	["the  American Journal of Mathematics  has an impact factor of  `` 1.337 ''  ."]
	["the  American Journal of Mathematics  has an impact factor of  `` 1.337 ''  ."]
	["the  American Journal of Mathematics  has an impact factor of  `` 1.337 ''  ."]
	["the  American Journal of Mathematics  has an impact factor of  `` 1.337 ''  ."]
	["the  American Journal of Mathematics  has an impact factor of  `` 1.337 ''  ."]
	['the  American Journal of Mathematics  is published by the  American Jou