In [13]:
import os
import sys

# Root directory of the dashboard project. This is a hard-coded absolute path;
# it is not derived from the location of this notebook.
current_dir = "/import/nlp/social_media_timeline_dashboard/"

# Put the THVAE-summary directory at the front of the module search path
# (only once, so re-running the cell does not add duplicates).
thvae_dir = os.path.abspath(os.path.join(current_dir, "THVAE-summary"))
already_registered = thvae_dir in sys.path
if not already_registered:
    sys.path.insert(0, thvae_dir)

from thvae.data_pipelines.assemblers import assemble_vocab_pipeline
from thvae.data_pipelines.assemblers import assemble_train_pipeline, \
    assemble_eval_pipeline, assemble_vocab_pipeline, assemble_infer_pipeline
from thvae.utils.helpers.io import get_rev_number
from thvae.utils.helpers.run import gen_summs
from transformers import BartTokenizer
from thvae.utils.fields import InpDataF
from mltoolkit.mldp.utils.tools import Vocabulary
from thvae.utils.hparams import ModelHP, RunHP
from thvae.modelling import ThVAE as Model
from thvae.modelling.interfaces import IThVAE
from mltoolkit.mlmo.utils.tools.annealing import KlCycAnnealing
from thvae.utils.tools import SeqPostProcessor
from thvae.utils.fields import ModelF
from thvae.utils.constants import VOCAB_DEFAULT_SYMBOLS

from mltoolkit.mlutils.helpers.paths_and_files import comb_paths, get_file_name

from functools import partial

from thvae.modelling.interfaces import IDevThVAE as IDev, IThVAE as IModel

In [6]:
import json
import pandas as pd

In [20]:
# Create a dataset in the input format used for THVAE from one user's posts.
# Every selected post becomes one row with three columns:
#   group_id    - the post id
#   review_text - the post title and body joined by a newline
#   category    - the constant string "post"
# NOTE(review): each post gets its own group_id, so every group holds a single
# text - confirm this is the grouping the summariser should see.
user_id = "tomorrowistomato"
post_ids = [
        "ialh84",
        "iam4l5",
        "ii04gh",
        "iik5wz",
        "ioou4c",
        "iozs0z",
        "iu7ufs",
        "iusum0",
        "iuuvc3",
        "ixj0zx",
        "iymeea",
        "j2eb4i",
        "j3dkjk",
        "j3pp6r",
        "j72puw",
        "j9j4nu",
        "jawbob"
    ]

# The path is relative, so it is resolved against the kernel's working directory.
# The file is read as UTF-8 explicitly instead of relying on the platform's
# default text encoding.
with open(f"public/data/{user_id}_posts.json", "r", encoding="utf-8") as f:
    posts = json.load(f)

# Fail early, naming every requested post that is absent, instead of stopping
# at the first missing key halfway through the build.
missing_post_ids = [post_id for post_id in post_ids if post_id not in posts]
if missing_post_ids:
    raise KeyError(
        f"post ids not found in public/data/{user_id}_posts.json: {missing_post_ids}"
    )

THVAE_test_set = {
    "group_id": list(post_ids),
    "review_text": [
        f"{posts[post_id]['title']}\n{posts[post_id]['body']}"
        for post_id in post_ids
    ],
    "category": ["post"] * len(post_ids),
}

test_set_df = pd.DataFrame(THVAE_test_set)

In [22]:
# Write the test set to `thvae_test.csv` (relative to the kernel's working
# directory) without the DataFrame index column.
# NOTE(review): the inference cell further down reads `.../infer_input.csv`,
# not this file - confirm how this CSV is meant to reach the model.
test_set_df.to_csv("thvae_test.csv", index=False)

In [2]:
# Run-level hyper-parameters, rooted at the THVAE-summary directory under `current_dir`.
run_hp = RunHP(
    root_path=os.path.join(current_dir, "THVAE-summary"),
)

In [3]:
# Model hyper-parameters with the defaults declared by `ModelHP`
# (read below for `ext_vocab_size` when building the vocabulary).
model_hp = ModelHP()

In [4]:
# Load the tokenizer of the `facebook/bart-base` checkpoint via Hugging Face transformers.
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

In [5]:
# Data source handed to `word_vocab.load_or_create` in the next cell; it points
# at the training file path stored on `run_hp`.
vocab_data_source = {"data_path": run_hp.train_fp}

In [6]:
#   PIPELINES AND VOCAB   #
# Word vocabulary backed by a pipeline that reads the review-text field.
vocab_pipeline = assemble_vocab_pipeline(text_fname=InpDataF.REV_TEXT)
word_vocab = Vocabulary(vocab_pipeline, name_prefix="word")

# `bart_tokenizer` is already loaded by an earlier cell and is not used in this
# one, so it is not reloaded here.

# adding special symbols before creating vocab, so they would appear on top
for st in VOCAB_DEFAULT_SYMBOLS:
    if st not in word_vocab:
        word_vocab.add_special_symbol(st)

# Presumably loads an existing vocabulary from `run_hp.words_vocab_fp` and
# otherwise builds one from `vocab_data_source`, capped at
# `model_hp.ext_vocab_size` entries - confirm against the mltoolkit Vocabulary API.
word_vocab.load_or_create(run_hp.words_vocab_fp,
                          data_source=vocab_data_source,
                          max_size=model_hp.ext_vocab_size, sep=' ',
                          data_fnames=InpDataF.REV_TEXT)

In [7]:
# Assemble the data pipeline stored as `summary_pipeline`.
# NOTE(review): this cell originally called `assemble_sum_pipeline`, which no
# visible cell imports or defines, so it raised NameError on a fresh kernel.
# `assemble_train_pipeline` is the imported assembler used instead - confirm it
# is the intended one and that it accepts the `bart_tokenizer` keyword.
# NOTE(review): the minimum and maximum reviews per group are both set to
# `run_hp.max_rev_per_group` - confirm this is deliberate.
summary_pipeline = assemble_train_pipeline(
    word_vocab,
    bart_tokenizer=bart_tokenizer,
    max_groups_per_batch=run_hp.val_max_groups_per_batch,
    min_revs_per_group=run_hp.max_rev_per_group,
    max_revs_per_group=run_hp.max_rev_per_group,
    seed=run_hp.seed,
    workers=1,
)

In [None]:
def gen_summs(data_iter, output_file_path, summ_gen_func):
    """Generates summaries and saves them to a txt file. Order is preserved.

    Note: this definition replaces the `gen_summs` imported from
    `thvae.utils.helpers.run` at the top of the notebook.

    Args:
        data_iter: iterable of batches; each batch is passed unchanged to
            `summ_gen_func`.
        output_file_path: path of the UTF-8 text file to (over)write.
        summ_gen_func: callable that takes one batch and returns an iterable
            of summary strings.

    Each summary is written on its own line. The file is opened in a `with`
    block so that it is flushed and closed even when iterating `data_iter` or
    calling `summ_gen_func` raises.
    """
    with open(output_file_path, encoding='utf-8', mode='w') as out_file:
        for batch in data_iter:
            for summ in summ_gen_func(batch):
                out_file.write(summ + '\n')

In [None]:
# Inference: generate summaries for the groups in `infer_inp_file_path` and
# write them to `<input file name>.out.txt` under `run_hp.output_path`.
infer_bsz = 40  # passed to the infer pipeline as `max_groups_per_chunk`
infer_inp_file_path = "/import/nlp/social_media_timeline_dashboard/THVAE-summary/thvae/artifacts/amazon/data/infer_input.csv"
# Validate the input path before anything uses it.
assert infer_inp_file_path is not None

out_file_name = get_file_name(infer_inp_file_path)
infer_out_file_path = comb_paths(run_hp.output_path,
                                 f'{out_file_name}.out.txt')
rev_num = get_rev_number(infer_inp_file_path)

print("Performing inference/summary generation")
infer_data_pipeline = assemble_infer_pipeline(word_vocab, max_reviews=rev_num,
                                              tokenization_func=run_hp.tok_func,
                                              max_groups_per_chunk=infer_bsz)
summ_pproc = SeqPostProcessor(tokenizer=lambda x: x.split(),
                              detokenizer=run_hp.detok_func,
                              tcaser=run_hp.true_case_func)

# NOTE(review): `idev` is not created in any visible cell (only the `IDev` /
# `IModel` classes are imported), so this call fails with NameError on a fresh
# kernel until the model and its dev interface are constructed.
print(f"Saving summaries to: '{infer_out_file_path}'")
gen_summs(infer_data_pipeline.iter(data_path=infer_inp_file_path),
          output_file_path=infer_out_file_path,
          # The post-processor built above is named `summ_pproc`; the original
          # passed the undefined name `summ_post_proc`.
          summ_gen_func=partial(idev.summ_generator,
                                summ_post_proc=summ_pproc))