In [1]:
import pandas as pd
import numpy as np
import json

import torch
from torch.utils.data import DataLoader

from transformers import BartTokenizer

from tam_lib import data_utils, modeling, evaluation

In [2]:
POLISUM_SRC = '../../../../data/polisum_clean.csv'
CLEAN_OUTPUT = '../../../../data/polisum_tam'
TAM_SAVE_PATH = '../../../../models/tam_model'
TOGL_DIST_OUT = '../../../../results/togl_decoding/'

TEXT_COL        = 'sm_text'
CLEAN_TEXT_COL  = 'text_clean'
SENT_SPLIT_TOK  = '|||'
SENT_RSPLIT_TOK = '\|\|\|'

BATCH_SIZE = 64
NUM_WORKERS = 10

DEVICE = torch.device('cuda:0')

# Loading Data 

In [3]:
data_train, data_val, vectorizer = data_utils.load_data(CLEAN_OUTPUT)

Reading data files
Reading vectorizer


In [4]:
polisum = pd.read_csv(POLISUM_SRC)

In [5]:
train_ds = data_utils.DocDataset(pd.concat((data_train, data_val), axis = 0), 
                              text_col = CLEAN_TEXT_COL, 
                              vectorizer = vectorizer)

In [6]:
VOCAB_SIZE = len(vectorizer.vocabulary_)
VECT_VOCAB = vectorizer.vocabulary_

In [9]:
dl = DataLoader(train_ds, batch_size = 32, num_workers = NUM_WORKERS, shuffle = True)

# Loading Model 

In [10]:
tam = modeling.TAM.from_pretrained(TAM_SAVE_PATH, device = DEVICE)

# Predicting Aspects

In [11]:
asps = {'asp1': None, 'asp2': None, 'title_date': []}

i = 0
for batch in dl:
    bow = batch['bow'].to(DEVICE).squeeze()
    asp1, asp2 = tam.pred_aspect_dists(bow)
    
    if asps['asp1'] is not None:
        asps['asp1'] = torch.concat((asps['asp1'], asp1.cpu().detach()), axis = 0)
        asps['asp2'] = torch.concat((asps['asp2'], asp2.cpu().detach()), axis = 0)
    else:
        asps['asp1'] = asp1.cpu().detach()
        asps['asp2'] = asp2.cpu().detach()
    asps['title_date'].append(batch['title_date'])
    print(f'Finished batch {i}', end = '\r')
    i+=1

Finished batch 1895

In [12]:
asps['title_date'] = [title for title_l in asps['title_date'] for title in title_l]

In [15]:
title_ix = {}
for i, title in enumerate(asps['title_date']):
    if title in title_ix.keys():
        title_ix[title].append(i)
    else:
        title_ix[title] = [i]
title_ix = {key: np.array(val) for key,val in title_ix.items()}

In [71]:
asps_1 = {title_date: (asps['asp1'][idxs].sum(dim = 0).softmax(dim = -1).sort(descending = True).values[:10], 
                       asps['asp1'][idxs].sum(dim = 0).softmax(dim = -1).sort(descending = True).indices[:10]) for title_date, idxs in title_ix.items()}
print('Finished asps 1')
asps_2 = {title_date: (asps['asp2'][idxs].sum(dim = 0).softmax(dim = -1).sort(descending = True).values[:10], 
                       asps['asp2'][idxs].sum(dim = 0).softmax(dim = -1).sort(descending = True).indices[:10]) for title_date, idxs in title_ix.items()}
print('Finished asps 2')

Finished asps 1
Finished asps 2


In [72]:
asps_1 = {td: (probs[0].numpy().tolist(), probs[1].numpy().tolist()) for td, probs in asps_1.items()}
asps_2 = {td: (probs[0].numpy().tolist(), probs[1].numpy().tolist()) for td, probs in asps_2.items()}

In [74]:
with open(TOGL_DIST_OUT + 'asp1_top10.json', 'w') as f:
    json.dump(asps_1, f)
with open(TOGL_DIST_OUT + 'asp2_top10.json', 'w') as f:
    json.dump(asps_2, f)

# Vocabulary Mapping 

In [76]:
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-xsum')

In [77]:
bart_vocab = tokenizer.get_vocab()

In [89]:
vocab_map = {}

In [90]:
for i, (term, idx) in enumerate(VECT_VOCAB.items()):
    for tok in tokenizer.tokenize(term):
        tok2 = tok[0].upper() + tok[1:]
        tok3 = 'Ġ' + tok

        vocab_map[idx] = []
        for t in (tok, tok2, tok3):
            if t in bart_vocab.keys():
                vocab_map[idx].append(bart_vocab[t])
        if len(vocab_map[idx]) == 0:
            del vocab_map[idx]

        print(f'Finished {i} of {VOCAB_SIZE}', end = '\r')

Finished 9999 of 10000

In [97]:
vocab_map = {int(key): val for key, val in vocab_map.items()}

In [98]:
with open(TOGL_DIST_OUT + 'vocab_map.json', 'w') as f:
    json.dump(vocab_map, f)

# Finalize Togl Distributions 

In [105]:
def expand_dist(dist, vocab_map):
    probs = dist[0]
    ids = dist[1]
    
    new_probs, new_ids = [], []
    
    for prob, idx in zip(probs, ids):
        if idx in vocab_map.keys():
            new_ids += vocab_map[idx]
            new_probs += [prob] * len(vocab_map[idx])
    
    return (new_probs, new_ids)

In [107]:
asps_1_mapped = {key: expand_dist(val, vocab_map) for key, val in asps_1.items()}
print('Finished asps_1')
asps_2_mapped = {key: expand_dist(val, vocab_map) for key, val in asps_2.items()}
print('Finished asps_2')

Finished asps_1
Finished asps_2


In [110]:
with open(TOGL_DIST_OUT + 'asp1_top10_mapped.json', 'w') as f:
    json.dump(asps_1_mapped, f)
with open(TOGL_DIST_OUT + 'asp2_top10_mapped.json', 'w') as f:
    json.dump(asps_2_mapped, f)