In [1]:
import copy as copy
import json
import math
import os
import pickle as pkl
from typing import Optional, List, Union, Tuple

import matplotlib.pyplot as plt
# import plotly.graph_objects as go
import numpy as np
from numpy.linalg import norm
import pandas as pd

import torch
from torch import nn
from torch.nn import MultiheadAttention, TransformerEncoderLayer, TransformerEncoder, Linear, Dropout, LayerNorm
import torch.nn.functional as F
from torch.autograd import handle_torch_function, has_torch_function
from torch.utils.data import DataLoader

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO
from pyro.optim import Adam, ClippedAdam

from treelib import Node, Tree
# import igraph
# from igraph import Graph, EdgeSeq
from graphviz import Source, Digraph
from IPython.core.display import SVG

import data_read, data_utils, dataset, model_utils, topic_metrics

In [2]:
# Root folders used by the rest of the notebook:
#   DATA_DIR - where the derived csv files are written and read back
#   BASE_DIR - where the raw evaluation csv is read from
DATA_DIR, BASE_DIR = (
    '/proj/nlp/users/ndeas/acos/data/',
    '/home/ndeas/abstract_cos/Abstract-Contrastive-Opinion-Summarization/data/',
)

# Data Processing 

In [3]:
# Raw evaluation file; BASE_DIR already ends with a path separator.
UP_SOURCE = f'{BASE_DIR}polisumm_eval_short.csv'

In [4]:
# Load the raw evaluation table into memory.
data = pd.read_csv(filepath_or_buffer = UP_SOURCE)

In [28]:
# `all_texts` packs several source texts into one string, separated by '|'.
data['split_texts'] = data['all_texts'].str.split('|')

# Number of source texts found in each row.
data['num_src'] = data['split_texts'].map(len)

# One copy of the row's title per source text, so that texts and titles can be
# flattened in lock-step afterwards. Built with a comprehension instead of a
# row-wise DataFrame.apply(axis=1): the result is the same column of lists, but
# it avoids constructing a Series per row and still yields a Series (not a
# DataFrame) when the frame is empty.
data['repeat_title'] = pd.Series(
    [[title] * int(count) for title, count in zip(data['title'], data['num_src'])],
    index = data.index,
    dtype = object,
)

In [29]:
# Flatten the per-row lists into two parallel flat lists (one entry per source text).
split_texts, titles = [], []
for _row_texts, _row_titles in zip(data['split_texts'].values, data['repeat_title'].values):
    split_texts.extend(_row_texts)
    titles.extend(_row_titles)

# One row per source text, paired with the title of the row it came from.
topic_df = pd.DataFrame({'text': split_texts, 'title': titles})

In [33]:
# Persist the flattened (text, title) table without the row index.
# `index` is documented as a boolean; `index = None` only worked because None is falsy.
topic_df.to_csv(DATA_DIR + 'polisumm_tam_src.csv', index = False)

In [3]:
# Rebinds UP_SOURCE: from here on it points at the flattened (text, title) csv
# under DATA_DIR rather than the raw evaluation file.
UP_SOURCE = f'{DATA_DIR}polisumm_tam_src.csv'

## Preprocessing and Cleaning

In [4]:
# Validation run: hold out 20% of the documents for test and 10% for validation.
# NOTE(review): despite the '.csv' suffix, the captured log shows OUTPUT being
# created as a directory - confirm against data_utils.read_and_process_data.
OUTPUT = f'{DATA_DIR}polisumm_tam_val.csv'
data_utils.read_and_process_data(
    UP_SOURCE,
    OUTPUT,
    test_frac = 0.2,
    val_frac = 0.1,
    max_vocab = 10000,
    ext_verbs = False,
)

Reading data from /proj/nlp/users/ndeas/acos/data/polisumm_tam_src.csv
Cleaning text data 


100%|███████████████████████████████████████████████████████████████████| 125873/125873 [00:59<00:00, 2115.95it/s]
100%|█████████████████████████████████████████████████████████████████████| 35964/35964 [00:17<00:00, 2012.03it/s]
100%|█████████████████████████████████████████████████████████████████████| 17983/17983 [00:08<00:00, 2111.34it/s]


Fitting vectorizer
Vectorizer trained with vocab size 10000
Filtering documents with less than 2 tokens
Saving data to /proj/nlp/users/ndeas/acos/data/polisumm_tam_val.csv/['train', 'test', 'val'].csv
/proj/nlp/users/ndeas/acos/data/polisumm_tam_val.csv does not exist, creating directory.
Saved all data, vectorizer, and author mapping to /proj/nlp/users/ndeas/acos/data/polisumm_tam_val.csv


In [4]:
# Production run: no held-out test or validation split, every document is used.
# NOTE(review): despite the '.csv' suffix, the captured log shows OUTPUT being
# created as a directory - confirm against data_utils.read_and_process_data.
OUTPUT = f'{DATA_DIR}polisumm_tam_prod.csv'
data_utils.read_and_process_data(
    UP_SOURCE,
    OUTPUT,
    test_frac = 0,
    val_frac = 0,
    max_vocab = 10000,
    ext_verbs = False,
)

Reading data from /proj/nlp/users/ndeas/acos/data/polisumm_tam_src.csv
Cleaning text data 


100%|███████████████████████████████████████████████████████████████████| 179820/179820 [01:18<00:00, 2289.81it/s]


Fitting vectorizer
Vectorizer trained with vocab size 10000
Filtering documents with less than 2 tokens
Saving data to /proj/nlp/users/ndeas/acos/data/polisumm_tam_prod.csv/['train'].csv
/proj/nlp/users/ndeas/acos/data/polisumm_tam_prod.csv does not exist, creating directory.
Saved all data, vectorizer, and author mapping to /proj/nlp/users/ndeas/acos/data/polisumm_tam_prod.csv


# Model Definition 

## Layers 

In [5]:
class Encoder(nn.Module):
    """
        Document encoder used within the guide.

        Maps a batch of bag-of-words vectors to the location and scale of a
        diagonal Normal over `num_topics` latent dimensions.

        Parameters:
            vocab_size: int
                Width of the input bag-of-words vectors
            num_topics: int
                Number of latent dimensions produced
            hidden: int
                Width of the two hidden layers
            dropout: float
                Dropout probability applied after the second hidden layer
            encode_len: str (optional)
                If 'hidden_deep', the summed input counts of each document are
                appended to the first hidden activation before the second
                hidden layer. Any other value leaves the length out.
            init_exp_len_root: float (optional)
                Stored on the instance but not otherwise used by this class.
    """

    def __init__(self, vocab_size, num_topics, hidden, dropout, encode_len: str = None, init_exp_len_root: float = None):

        super().__init__()

        self.vocab_size  = vocab_size
        self.num_topics  = num_topics
        self.hidden_size = hidden
        self.dropout     = dropout

        # forward() branches on encode_len, so it has to live on the instance;
        # previously it was accepted but never stored (AttributeError in forward).
        self.encode_len        = encode_len
        self.init_exp_len_root = init_exp_len_root

        self.drop = nn.Dropout(dropout)  # to avoid component collapse
        self.fc1 = nn.Linear(vocab_size, hidden)

        # In 'hidden_deep' mode one extra feature (the document length) is
        # concatenated onto the hidden activation, so fc2 must accept hidden + 1.
        fc2_in = hidden + 1 if encode_len == 'hidden_deep' else hidden
        self.fc2 = nn.Linear(fc2_in, hidden)

        self.fcmu = nn.Linear(hidden, num_topics)

        self.fclv = nn.Linear(hidden, num_topics)

        # NB: here we set `affine=False` to reduce the number of learning parameters
        # See https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
        # for the effect of this flag in BatchNorm1d
        self.bnmu = nn.BatchNorm1d(num_topics, affine=False)  # to avoid component collapse
        self.bnlv = nn.BatchNorm1d(num_topics, affine=False)  # to avoid component collapse

    def forward(self, inputs):
        """
            Encode a batch of documents.

            Parameters:
                inputs: torch.Tensor
                    Floating-point tensor whose last dimension is `vocab_size`
                    (the BatchNorm1d layers expect a 2-D (batch, features) input)

            Returns:
                (logtheta_loc, logtheta_scale): location and strictly positive
                scale, each with last dimension `num_topics`
        """
        h = F.softplus(self.fc1(inputs))

        if self.encode_len == 'hidden_deep':
            # Per-document sum of the input vector, kept as a trailing column.
            len_tens = inputs.sum(-1, keepdim = True)
            h = torch.cat((h, len_tens), dim = -1)
        h = F.softplus(self.fc2(h))

        h = self.drop(h)

        # μ and Σ are the outputs
        logtheta_loc = self.bnmu(self.fcmu(h))

        logtheta_logvar = self.bnlv(self.fclv(h))
        # exp(0.5 * logvar) turns a log-variance into a standard deviation; the
        # small constant defensively enforces positivity.
        logtheta_scale = 1.0e-10 + (0.5 * logtheta_logvar).exp()

        return logtheta_loc, logtheta_scale
    

In [6]:
class AspEncoder(nn.Module):
    """
        Aspect encoder: maps a batch of bag-of-words vectors to the location
        and scale of a diagonal Normal over `num_aspects` latent dimensions.

        Parameters:
            vocab_size: int
                Width of the input bag-of-words vectors
            num_aspects: int
                Number of latent dimensions produced
            hidden: int
                Width of the two hidden layers
            dropout: float
                Dropout probability applied after the second hidden layer
    """

    def __init__(self, vocab_size, num_aspects, hidden, dropout):

        super().__init__()

        self.vocab_size  = vocab_size
        self.num_aspects = num_aspects
        # The original assigned `num_topics`, a name that is not a parameter of
        # this constructor (NameError, or a stale notebook global). The
        # attribute is kept, now holding the aspect count, for any reader of it.
        self.num_topics  = num_aspects
        self.hidden_size = hidden
        self.dropout     = dropout

        self.drop = nn.Dropout(dropout)  # to avoid component collapse
        self.fc1 = nn.Linear(vocab_size, hidden)

        self.fc2 = nn.Linear(hidden, hidden)

        self.fcmu = nn.Linear(hidden, num_aspects)

        self.fclv = nn.Linear(hidden, num_aspects)

        # affine=False: normalisation only, no learnable scale/shift
        self.bnmu = nn.BatchNorm1d(num_aspects, affine=False)  # to avoid component collapse
        self.bnlv = nn.BatchNorm1d(num_aspects, affine=False)  # to avoid component collapse


    def forward(self, inputs):
        """
            Encode a batch of documents.

            Parameters:
                inputs: torch.Tensor
                    Floating-point tensor whose last dimension is `vocab_size`

            Returns:
                (logtheta_loc, logtheta_scale): location and strictly positive
                scale, each with last dimension `num_aspects`
        """
        h = F.softplus(self.fc1(inputs))
        h = F.softplus(self.fc2(h))

        h = self.drop(h)

        # μ and Σ are the outputs
        logtheta_loc = self.bnmu(self.fcmu(h))

        logtheta_logvar = self.bnlv(self.fclv(h))
        # exp(0.5 * logvar) is a standard deviation; the constant defensively
        # enforces positivity.
        logtheta_scale = 1.0e-10 + (0.5 * logtheta_logvar).exp()

        return logtheta_loc, logtheta_scale
    

In [7]:
class Decoder(nn.Module):
    """
        Document decoder used in the model.

        Applies dropout to a batch of topic proportions, projects it onto the
        vocabulary with a bias-free linear layer (`beta`) and batch-normalises
        the result. No softmax is applied here.
    """

    def __init__(self, vocab_size, num_topics, dropout):
        super().__init__()

        # Remember the construction arguments as plain attributes.
        self.vocab_size, self.num_topics, self.dropout = vocab_size, num_topics, dropout

        # Topic -> vocabulary projection; its weight matrix is the topic-word table.
        self.beta = nn.Linear(in_features=num_topics, out_features=vocab_size, bias=False)
        # Normalisation only (affine=False: no learnable scale/shift).
        self.bn = nn.BatchNorm1d(num_features=vocab_size, affine=False)
        self.drop = nn.Dropout(p=dropout)

    def forward(self, inputs):
        """Return batch-normalised word scores for a batch of topic proportions."""
        dropped = self.drop(inputs)
        projected = self.beta(dropped)
        return self.bn(projected)

## Model 

In [None]:
class TAM(nn.Module):
    """
        ProdLDA-style topic model (Pyro model + guide) with helpers for
        reconstruction, perplexity evaluation and (de)serialisation.

        `num_aspects` is stored and persisted, but the model and guide defined
        in this class only use the topic dimension.
    """

    def __init__(self, num_topics: int, num_aspects: int,
                 hidden_size: int = None,
                 vocab_size: int = None,
                 dropout: float = None,
                 encoder: "Encoder" = None, decoder: "Decoder" = None,
                 encode_len: str = None, len_pow_val: float = None,
                 delta_prior: "pyro.distributions.torch_distribution.TorchDistributionMixin" = None,
                 delta_scale: int = 3,
                 device: torch.device = None):
        """
            ProdLDA implementation including model and guide.

            Parameters:
                num_topics: int
                    Number of latent topics assumed
                num_aspects: int
                    Number of aspects; stored on the instance and written to the
                    saved config
                hidden_size: int (optional)
                    Size of the hidden layer in the Encoder
                vocab_size: int (optional)
                    Number of tokens in the vocabulary
                dropout: float (optional)
                    Dropout rate for the encoder (0.0 is a valid value)
                encoder: Encoder (optional)
                    Encoder layer(s) transforming document BOW's into a topic distribution
                decoder: Decoder (optional)
                    Decoder layer(s) transforming topic distribution into a document BOW.
                    Essentially describes the Beta matrix num_topics x vocab_size
                encode_len: str (optional)
                    Forwarded to the Encoder when one is built here
                len_pow_val: float (optional)
                    Forwarded to the Encoder as `init_exp_len_root` when one is built here.
                    NOTE(review): assumes these two names denote the same quantity - confirm.
                delta_prior: pyro.distributions (optional)
                    Prior distribution on the latent topics; defaults to
                    Normal(0, delta_scale) per topic
                delta_scale: int (optional, default = 3)
                    Scale of the default prior
                device: torch.device (optional)
                    GPU/CPU device to use in training/inference (defaults to CPU)

            Raises:
                ValueError: if neither a ready-made encoder/decoder nor the
                    sizes needed to build one are supplied
        """

        super().__init__()

        self.num_topics = num_topics
        self.num_aspects = num_aspects

        # Explicit None checks: truthiness would reject a legitimate dropout of 0.0.
        if encoder is None and any(v is None for v in (hidden_size, vocab_size, dropout)):
            raise ValueError('Either hidden_size, vocab_size, and dropout or an encoder must be specified.')
        if decoder is None and any(v is None for v in (vocab_size, dropout)):
            raise ValueError('Either vocab_size and dropout or a decoder must be specified')

        self.hidden_size = hidden_size if hidden_size is not None else encoder.hidden_size
        self.vocab_size  = vocab_size if vocab_size is not None else encoder.vocab_size
        self.dropout     = dropout if dropout is not None else encoder.dropout

        self.encode_len = encode_len
        self.len_pow_val = len_pow_val

        if encoder is None:
            # Encoder's keyword for this value is `init_exp_len_root`; passing
            # `len_pow_val=` raised a TypeError.
            encoder = Encoder(self.vocab_size, num_topics, self.hidden_size, self.dropout,
                              encode_len = encode_len, init_exp_len_root = len_pow_val)
        if decoder is None:
            decoder = Decoder(self.vocab_size, num_topics, self.dropout)

        self.encoder = encoder
        self.decoder = decoder

        self.device = device if device is not None else torch.device("cpu")
        self.to(self.device)

        if delta_prior is None:
            delta_prior = dist.Normal(0, delta_scale * torch.ones(num_topics, device=self.device))
        self.delta_prior = delta_prior

    def model(self,
              bows: torch.Tensor,
              num_docs: int,
              h2: torch.Tensor = None):
        """
            The ProdLDA model to generate a corpus of documents

            Parameters:
                bows: torch.Tensor
                    The corpus of source document BOWS for loss calculations
                num_docs: int
                    The total number of documents in the source corpus
                h2: torch.Tensor (optional, default = None)
                    If half (either ordered or random) of the document texts are passed for
                        bows, then the remaining half should be passed as h2
        """

        pyro.module("decoder", self.decoder)

        # Observed counts: the second half when one is given, otherwise bows itself.
        observed = bows if h2 is None else h2

        with pyro.plate("documents", num_docs, subsample = bows):
            delta = pyro.sample("delta", self.delta_prior.to_event(1))

            # Softmax to calculate theta, the distribution over topics: (Docs, Topics)
            theta = F.softmax(delta, -1)

            # Decode the topic distribution to generate distribution over words: (Docs, Vocab Size)
            logits = self.decoder(theta)

            # Maximum document length for multinomial distribution sampling of reconstruction
            total_count = int(observed.sum(-1).max())

            # Sample document reconstruction from multinomial characterized by decoded topic distribution
            pyro.sample(
                'words',
                dist.Multinomial(total_count, logits = logits),
                obs = observed
            )

    def guide(self,
              bows: torch.Tensor,
              num_docs: int,
              h2: torch.Tensor = None):
        """
            The ProdLDA guide for learning latent variables

            Parameters:
                bows: torch.Tensor
                    The corpus of source document BOWS for loss calculations
                num_docs: int
                    The total number of documents in the source corpus
                h2: torch.Tensor (optional, default = None)
                    Accepted so the signature matches `model`; not used here
        """

        pyro.module('encoder', self.encoder)

        # document plate
        with pyro.plate("documents", num_docs, subsample = bows):

            delta_loc, delta_sigma = self.encoder(bows.float())

            pyro.sample("delta", dist.Normal(delta_loc, delta_sigma).to_event(1))

    def beta(self):
        """Return the detached, transposed decoder weight: (num_topics, vocab_size)."""
        # beta matrix elements are the weights of the FC layer on the decoder
        return self.decoder.beta.weight.detach().T

    def get_doc_scale(self, bows):
        """Return the encoder's scale output for a batch of BOW vectors."""
        _, delta_sigma = self.encoder(bows.float())

        return delta_sigma

    def reconstruct_doc(self, bow, num_particles = 50):
        """
            Monte-Carlo estimate of each document's word distribution.

            Puts the module in eval mode (it is not switched back afterwards).

            Parameters:
                bow: torch.Tensor
                    Batch of BOW vectors, first dimension indexes documents
                num_particles: int
                    Number of latent samples averaged per document

            Returns:
                torch.Tensor (num_docs, vocab_size) of word probabilities
        """
        self.eval()

        num_docs = bow.shape[0]

        with torch.no_grad():
            # Cast like the guide does, so integer count tensors are accepted too.
            delta_loc, delta_scale = self.encoder(bow.float())

            delta_samples = dist.Normal(delta_loc, delta_scale).sample((num_particles,))
            theta = F.softmax(delta_samples, dim=-1)

            # decode for reconstruction: fold particles into the batch dimension
            # for the decoder, then unfold again
            theta = theta.reshape(num_docs * num_particles, -1)
            word_logits = self.decoder(theta)
            word_logits = word_logits.view(num_particles, num_docs, -1)

            # average the per-particle word distributions
            word_probs = torch.softmax(word_logits, dim=-1).mean(dim=0)

        return word_probs

    def _perplexity_over(self, triples, num_particles, output_indiv):
        """
            Shared accumulation loop for the two perplexity methods.

            `triples` yields (context, target, full) BOW tensors: `context` is
            encoded, `target` is scored against the reconstruction, and `full`
            is only used for the reported document lengths.
        """
        total_ce = 0
        total_num_words = 0

        lengths = []
        perps   = []

        for context, target, full in triples:
            recon = self.reconstruct_doc(context, num_particles = num_particles)
            ces   = (-target * torch.log(recon))
            total_ce += ces.sum().cpu().item()
            total_num_words += target.sum()

            if output_indiv:
                lengths += list(full.sum(axis = -1).cpu().detach())
                # NOTE(review): this sums exp() of the per-word terms rather than
                # exponentiating a per-document average - kept as in the original; confirm intent.
                perps   += list(torch.exp(ces).sum(-1).cpu().detach())

        if not torch.is_tensor(total_num_words):
            # No batch was seen; the original failed here with a bare ZeroDivisionError.
            raise ValueError('test_half_loader yielded no batches; cannot compute perplexity.')

        ce = total_ce / total_num_words
        perp = torch.exp(ce)

        if output_indiv:
            return perp, ce, lengths, perps
        return perp, ce

    def calc_perplexity(self, test_half_loader, num_particles = 50, output_indiv = False):
        """
        Calculate perplexity

        Each batch must provide 'bow'. Returns (perplexity, cross-entropy), plus
        per-document lengths and scores when `output_indiv` is True.
        """
        def triples():
            for batch in test_half_loader:
                bow = batch['bow'].to(self.device)
                yield bow, bow, bow

        return self._perplexity_over(triples(), num_particles, output_indiv)

    def doc_completion_perplexity(self, test_half_loader, num_particles = 50, output_indiv = False):
        """
        Calculate document completion perplexity

        Each batch must provide 'bow_h1' (encoded) and 'bow_h2' (scored).
        Returns (perplexity, cross-entropy), plus per-document lengths (of
        h1 + h2) and scores when `output_indiv` is True.
        """
        def triples():
            for batch in test_half_loader:
                h1 = batch['bow_h1'].to(self.device)
                h2 = batch['bow_h2'].to(self.device)
                yield h1, h2, h1 + h2

        return self._perplexity_over(triples(), num_particles, output_indiv)

    def save(self, save_path:str):
        """
            Write encoder/decoder weights, the pickled prior and a json config
            into the directory `save_path` (created if missing).
        """

        os.makedirs(save_path, exist_ok = True)

        torch.save(self.encoder.state_dict(), f'{save_path}/encoder.pt')
        torch.save(self.decoder.state_dict(), f'{save_path}/decoder.pt')

        with open(f'{save_path}/delta_prior.pkl', 'wb') as f:
            pkl.dump(self.delta_prior, f)

        # Prefer a `len_pow` parameter if the encoder has one; otherwise fall back
        # to the value given at construction (the Encoder in this notebook defines
        # no `len_pow`, so reading it unconditionally raised AttributeError).
        len_pow = getattr(self.encoder, 'len_pow', None)
        len_pow_val = len_pow.data.item() if len_pow is not None else self.len_pow_val

        config_dict = {
            'model_type': 'prodlda',
            'num_topics': self.num_topics,
            'num_aspects': self.num_aspects,
            'hidden_size': self.hidden_size,
            'vocab_size': self.vocab_size,
            'dropout': self.dropout,
            'encode_len': self.encode_len,
            'len_pow_val': len_pow_val,
        }

        with open(f'{save_path}/config.json', 'w') as f:
            json.dump(config_dict, f)

    @classmethod
    def from_pretrained(cls, model_path, device = None):
        """
            Rebuild a model from a directory written by `save`.

            Parameters:
                model_path: str
                    Directory containing config.json, encoder.pt, decoder.pt
                    and delta_prior.pkl
                device: torch.device (optional)
                    Target device; weights are mapped to CPU when omitted

            Raises:
                Exception: if the stored model_type is not 'prodlda'
        """

        with open(f'{model_path}/config.json', 'r') as f:
            config_dict = json.load(f)

        if config_dict['model_type'] != 'prodlda':
            raise Exception(f'Pretrained model of type {config_dict["model_type"]} is not of type ProdLDA')

        num_topics  = config_dict['num_topics']
        # Optional keys: configs written by earlier versions may not contain them.
        num_aspects = config_dict.get('num_aspects')
        encode_len  = config_dict.get('encode_len')
        len_pow_val = config_dict.get('len_pow_val')

        encoder = Encoder(config_dict['vocab_size'],
                          config_dict['num_topics'],
                          config_dict['hidden_size'],
                          config_dict['dropout'],
                          encode_len = encode_len,
                          init_exp_len_root = len_pow_val)

        decoder = Decoder(config_dict['vocab_size'],
                          config_dict['num_topics'],
                          config_dict['dropout'])

        # Map stored tensors onto the requested device so that weights saved on a
        # GPU can be loaded on a CPU-only machine.
        map_location = device if device is not None else torch.device('cpu')
        encoder.load_state_dict(torch.load(f'{model_path}/encoder.pt', map_location = map_location))
        decoder.load_state_dict(torch.load(f'{model_path}/decoder.pt', map_location = map_location))

        # SECURITY: pickle.load executes arbitrary code - only load model
        # directories from a trusted source.
        with open(f'{model_path}/delta_prior.pkl', 'rb') as f:
            delta_prior = pkl.load(f)

        model = cls(num_topics, num_aspects,
                    encoder = encoder, decoder = decoder,
                    delta_prior = delta_prior,
                    encode_len = encode_len, len_pow_val = len_pow_val,
                    device = device)

        return model