In [1]:
import sys
from pathlib import Path
from typing import List
%config Completer.use_jedi = False # fix autocomplete nor working

project_dir = Path('/home/al.thomas/sync/development/data2text/')
sys.path.insert(0, str(project_dir))
from hdfs_utils import copy_from_hdfs_to_local, copy_from_local_to_hdfs

def copy_artifacts(run_id: str, source_dir: str, source_files: List[str]):
    """
    run_id: mlflow run id
    source_dir: name of the directory containing source files in mlflow artifacts directory (on hdfs)
    source_files: files useful in this notebook (source files, config, etc)
    """
    
    # copy mlflow model checkpoints to local
    artifact_path = f'viewfs:///user/al.thomas/mlflow_artifacts/{run_id}/artifacts'
    checkpoints = []
    checkpoints += copy_from_hdfs_to_local(artifact_path + '/t2g_model.ptXbest', str(project_dir / f'models/{run_id}'))
    checkpoints += copy_from_hdfs_to_local(artifact_path + '/g2t_model.ptXbest', str(project_dir / f'models/{run_id}'))
    print('Copied checkpoints:\n' + '\n'.join(checkpoints))

    # copy source code used to train model to local
    copied_files = []
    for f in source_files:
        copied_files += copy_from_hdfs_to_local(artifact_path + f'/{source_dir}/{f}', str(project_dir / f'models/{run_id}/artifact_code'))
    print('Copied source files:\n' + '\n'.join(copied_files))
    # and add the source code to python path
    sys.path.insert(0, str(project_dir / f'models/{run_id}'))

# With original CycleGT code

In [35]:
run_id = 'f0c02d2aa4704d8ebfaf8f37fbf28176'
artifact_dir = project_dir / f'models/{run_id}'
copy_artifacts(run_id,  source_dir='code', source_files=['g2t_model.py', 'main.py', 'data.py', 
                                                         'config.yaml', 'tmp_vocab.pt', 
                                                         'train.json', 'dev.json', 'test.json'])

Copied checkpoints:
/home/al.thomas/sync/development/data2text/models/f0c02d2aa4704d8ebfaf8f37fbf28176/t2g_model.ptXbest
/home/al.thomas/sync/development/data2text/models/f0c02d2aa4704d8ebfaf8f37fbf28176/g2t_model.ptXbest
Copied source files:
/home/al.thomas/sync/development/data2text/models/f0c02d2aa4704d8ebfaf8f37fbf28176/artifact_code/g2t_model.py
/home/al.thomas/sync/development/data2text/models/f0c02d2aa4704d8ebfaf8f37fbf28176/artifact_code/main.py
/home/al.thomas/sync/development/data2text/models/f0c02d2aa4704d8ebfaf8f37fbf28176/artifact_code/data.py
/home/al.thomas/sync/development/data2text/models/f0c02d2aa4704d8ebfaf8f37fbf28176/artifact_code/config.yaml
/home/al.thomas/sync/development/data2text/models/f0c02d2aa4704d8ebfaf8f37fbf28176/artifact_code/tmp_vocab.pt
/home/al.thomas/sync/development/data2text/models/f0c02d2aa4704d8ebfaf8f37fbf28176/artifact_code/train.json
/home/al.thomas/sync/development/data2text/models/f0c02d2aa4704d8ebfaf8f37fbf28176/artifact_code/dev.json
/hom

In [57]:
from artifact_code.g2t_model import GraphWriter
from artifact_code.main import prep_data, write_txt
from artifact_code.data import batch2tensor_g2t
from itertools import islice
import copy
import yaml
import torch

In [44]:
# load data
config = yaml.safe_load(open(artifact_dir/'artifact_code/config.yaml', "r"))
config["main"]["train_file"] = str(artifact_dir/'artifact_code/train.json')
config["main"]["dev_file"] = str(artifact_dir/'artifact_code/dev.json')
config["main"]["test_file"] = str(artifact_dir/'artifact_code/train.json')
pool, vocab = prep_data(config["main"], load=str(artifact_dir/'artifact_code/tmp_vocab.pt'))

INFO:root:MAX_LEN 31


In [45]:
# load model
model = GraphWriter(copy.deepcopy(config["g2t"]), vocab)

In [62]:
i = 10
batch = next(islice(pool.draw_with_type(batch_size=32, shuffle=False, _type="test"), i, i+1))
batch = batch2tensor_g2t(batch, 'cpu', vocab)

In [63]:
seq = model(batch, beam_size=5)
pred = write_txt(batch, seq, vocab["text"])
target = write_txt(batch, batch["tgt"], vocab["text"])

In [64]:
pred[0], target[0]

([' 8.0   8.0   8.0   8.0   8.0   8.0   8.0   8.0   Andrews County Airport   8.0   8.0   8.0   8.0   8.0   8.0   8.0   8.0   8.0   8.0   8.0   Andrews County Airport   Andrews County Airport   8.0   Andrews County Airport   8.0   8.0   8.0   8.0   8.0   8.0   8.0   8.0   8.0   Andrews County Airport   8.0   8.0   8.0   8.0   8.0   8.0   Andrews County Airport   8.0   8.0   8.0   8.0   8.0   8.0   8.0   8.0 '],
 [' Andrews County Airport  runway length is  8.0  .'])

In [65]:
seq[0]

tensor([   1, 2288, 2288, 2288, 2288, 2288, 2288, 2288, 2288, 2289, 2288, 2288,
        2288, 2288, 2288, 2288, 2288, 2288, 2288, 2288, 2288, 2289, 2289, 2288,
        2289, 2288, 2288, 2288, 2288, 2288, 2288, 2288, 2288, 2288, 2289, 2288,
        2288, 2288, 2288, 2288, 2288, 2289, 2288, 2288, 2288, 2288, 2288, 2288,
        2288, 2288,    0])