In [33]:
import pandas as pd
import numpy as np
import json

import torch
from torch.utils.data import DataLoader

from transformers import BartTokenizer

from tam_lib import data_utils, modeling, evaluation

In [29]:
POLISUM_SRC = '../../../../data/polisum_clean.csv'
CLEAN_OUTPUT = '../../../../data/polisum_tam'
TAM_SAVE_PATH = '../../../../models/tam_model'
TOGL_DIST_OUT = '../../../../results/togl_decoding/'

TEXT_COL        = 'sm_text'
CLEAN_TEXT_COL  = 'text_clean'
SENT_SPLIT_TOK  = '|||'
SENT_RSPLIT_TOK = '\|\|\|'

BATCH_SIZE = 64
NUM_WORKERS = 10

DEVICE = torch.device('cuda:0')

# Loading Data 

In [3]:
data_train, data_val, vectorizer = data_utils.load_data(CLEAN_OUTPUT)

Reading data files
Reading vectorizer


In [4]:
polisum = pd.read_csv(POLISUM_SRC)

In [5]:
train_ds = data_utils.DocDataset(pd.concat((data_train, data_val), axis = 0), 
                              text_col = CLEAN_TEXT_COL, 
                              vectorizer = vectorizer)

In [6]:
VOCAB_SIZE = len(vectorizer.vocabulary_)
VECT_VOCAB = vectorizer.vocabulary_

In [7]:
dl = DataLoader(train_ds, batch_size = 16, num_workers = NUM_WORKERS, shuffle = True)

# Loading Model 

In [8]:
tam = modeling.TAM.from_pretrained(TAM_SAVE_PATH, device = DEVICE)

# Predicting Aspects

In [20]:
asps = {'asp1': None, 'asp2': None, 'title': []}

i = 0
for batch in dl:
    bow = batch['bow'].to(DEVICE).squeeze()
    asp1, asp2 = tam.pred_aspect_dists(bow)
    
    # asps['asp1'].append(asp1.cpu().detach())
    # asps['asp2'].append(asp2.cpu().detach())
    if asps['asp1'] is not None:
        asps['asp1'] = torch.concat((asps['asp1'], asp1.cpu().detach()), axis = 0)
        asps['asp2'] = torch.concat((asps['asp2'], asp2.cpu().detach()), axis = 0)
    else:
        asps['asp1'] = asp1.cpu().detach()
        asps['asp2'] = asp2.cpu().detach()
    asps['title'].append(batch['title'])
    print(f'Finished batch {i}', end = '\r')
    i+=1

Finished batch 3791

In [26]:
asps['title'] = [title for title_l in asps['title'] for title in title_l]

In [34]:
torch.save(asps['asp1'], TOGL_DIST_OUT + '/asp1.pt')
torch.save(asps['asp2'], TOGL_DIST_OUT + '/asp1.pt')
with open(TOGL_DIST_OUT + '/asp1.pt', 'w') as f:
    json.dump(asps['title'], f)

In [None]:
c_asps1 = {key: None for key in set(asps['title'])}
c_asps2 = {key: None for key in set(asps['title'])}

for i, title in enumerate(asps['title']):
    if c_asps1[title] is not None:
        c_asps1[title] += asps['asp1']
    else:
        c_asps1[title] = asps['asp1']
        
    if c_asps2[title] is not None:
        c_asps2[title] += asps['asp2']
    else:
        c_asps2[title] = asps['asp2']
    
    print(f'Finished {i} of {len(asps["title"])}', end = '\r')

Finished 13441 of 60671

# Vocabulary Mapping 

In [22]:
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-xsum')

Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.51k [00:00<?, ?B/s]

In [23]:
bart_vocab = tokenizer.get_vocab()

In [43]:
vocab_map = {}

In [45]:
for i, (term, idx) in enumerate(VECT_VOCAB.items()):
    term2 = term[0].upper() + term[1:]
    term3 = 'Ġ' + term
    
    vocab_map[idx] = []
    for t in (term, term2, term3):
        if term in bart_vocab.keys():
            vocab_map[idx].append(bart_vocab[term])
    if len(vocab_map[idx]) == 0:
        del vocab_map[idx]
    
    print(f'Finished {i} of {VOCAB_SIZE}', end = '\r')

Finished 9999 of 10000

In [50]:
torch.tensor([1, 2, 5, 3, 7, 4]).sort().values

(tensor([1, 2, 3, 4, 5, 7]), tensor([0, 1, 3, 5, 2, 4]))