In [3]:
import pandas as pd
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


from transformers import BartTokenizer
from transformers import BartForConditionalGeneration
from transformers.modeling_utils import PreTrainedModel
from transformers.generation_utils import top_k_top_p_filtering, BeamSearchScorer
from transformers.pytorch_utils import torch_int_div


import data_utils, dataset, model_utils, topic_metrics

from tqdm import tqdm
tqdm.pandas()

In [4]:
device = torch.device('cuda:0')

# Loading Tokenizers 

In [5]:
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

In [10]:
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn').to(device)

# Custom Decoding Function 

In [187]:
class ToGLDecoder:
    
    def __init__(self,
                model: PreTrainedModel,
                tokenizer,
                top_p: float = 1.0,
                togl_func: str = 'sum',
                togl_func_kwargs: dict = None,
                device = None):
        '''
            Parameters:
                -model: PreTrainedModel
                    A Huggingface pretrained model capable of generating text
                -top_p: float
                    Parameter for top-p sampling decoding method
                -togl_func
                    Function used to combine model predictions and topic model word distribution.
                    Defaults to the sum of the generation and topic model word distributions with weight 1.
                -togl_func_kwargs
                    Keyword arguments to pass to the togl_func beyond word distribution parameters
                -device
                    Torch/Cuda device to use while generating
        '''
        
        self.model = model
        self.tokenizer = tokenizer
        self.vocab_size = self.model.lm_head.out_features
        
        self.top_p = top_p
        if type(togl_func) == str:
            assert togl_func in ('sum'), f'togl_func {togl_func} has not been implemented'
            if togl_func == 'sum':
                self.togl_func = self.togl_sum
        else:
            self.togl_func = togl_func
        
        self.togl_func_kwargs = togl_func_kwargs
        
        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
    @torch.no_grad()
    def generate(self,
                    inputs: torch.Tensor,
                    togl_tuple: tuple,
                    togl_start: int = 2,
                    togl_weight: int = 0.1,
                    use_cache: bool = True,
                    decoder_start_token_id = None,
                    num_beams: int  = 3,
                    no_repeat_ngram_size = 3,
                    min_length: int = 16,
                    max_length: int = 1024,
                    early_stopping: bool = True,
                    **model_kwargs):
        '''
            Generates a sequence using beam search sampling incorporating ToGL-Decoding.
            
            Code drawn from 
                - https://github.com/huggingface/transformers/blob/v4.25.1/src/transformers/generation/utils.py#L998
                    - beam_sample function
                - https://github.com/megagonlabs/cocosum/blob/main/decode.py
                    - generate_function
        '''
        
        batch_size = 1
        
        inputs           = inputs.to(self.device)
        
        togl_probs       = self.togl_convert(togl_tuple, togl_weight)
        togl_probs       = togl_probs.to(self.device)
            
        inputs_t, model_input_name, model_kwargs = self.model._prepare_model_inputs(inputs, self.tokenizer.bos_token_id, model_kwargs)
        
        model_kwargs = self.model._prepare_encoder_decoder_kwargs_for_generation(
            inputs_t, model_kwargs, model_input_name
        )
        
        input_ids = self.model._prepare_decoder_input_ids_for_generation(
                batch_size,
                decoder_start_token_id=self.tokenizer.bos_token_id,
                bos_token_id=self.tokenizer.bos_token_id,
                model_kwargs=model_kwargs,
                device=self.device,
            )
        
        logits_processor = self.model.model._get_logits_processor(
            repetition_penalty = None,
            no_repeat_ngram_size = no_repeat_ngram_size,
            encoder_no_repeat_ngram_size=None,
            input_ids_seq_length = input_ids.shape[-1],
            encoder_input_ids = inputs_t,
            min_length=min_length,
            max_length=max_length,
            eos_token_id=self.tokenizer.eos_token_id,
            forced_bos_token_id=None,
            forced_eos_token_id=None,
            num_beams=num_beams,
            num_beam_groups=None,
            diversity_penalty=None,
            remove_invalid_values=None,
            bad_words_ids = None,
            prefix_allowed_tokens_fn = None,
            exponential_decay_length_penalty = None,
            logits_processor = [],
            renormalize_logits = None,
        )
        
        
        stopping_criteria = self.model.model._get_stopping_criteria(
            max_length = max_length, max_time = None, stopping_criteria = []
        )
        
        # Setup beam scorer for searching generations
        beam_scorer = BeamSearchScorer(
            batch_size = batch_size,
            num_beams = num_beams,
            device = self.device,
            do_early_stopping = early_stopping,
            num_beam_hyps_to_keep = 1
        )
        
        input_ids, model_kwargs = self.model._expand_inputs_for_generation(
            input_ids, expand_size=num_beams, is_encoder_decoder=True, **model_kwargs
        )
        
        batch_size = len(beam_scorer._beam_hyps)
        
        batch_beam_size, cur_len = input_ids.shape
        
        beam_scores = torch.zeros((batch_size, num_beams), 
                                  dtype = torch.float, 
                                  device = self.device)
        beam_scores[:, 1:] = -1e-9
        beam_scores = beam_scores.view((batch_size * num_beams,))
        beam_indices = (None)
        
        
        
        while True:            
        
            model_in = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self.model(**model_in, 
                                 return_dict=True,
                                 output_attentions = False,
                                 output_hidden_states = False)
            
            # Update togl_logits to zero selected terms in output
            togl_probs[input_ids] = 0. #float('-inf')
                        
            # Modify logits
            raw_logits = outputs.logits
            if cur_len >= togl_start:
                mod_logits = self.togl_func(raw_logits, togl_probs)
            else:
                mod_logits = raw_logits
            mod_logits = mod_logits[:, -1, :]
            
            next_logits = self.model.model.adjust_logits_during_generation(mod_logits, cur_len = cur_len)
            next_scores = F.log_softmax(next_logits, dim = -1)
            
            next_scores_pp = logits_processor(input_ids, next_scores)
            next_scores = next_scores_pp + beam_scores[:, None].expand_as(next_scores)
            
            vocab_size  = next_scores.shape[-1]
            next_scores = next_scores.view(batch_size, num_beams * vocab_size)
            
            next_scores, next_tokens = torch.topk(
                next_scores, 2 * num_beams, dim = 1, largest = True, sorted = True
            )
            
            next_idxs = torch_int_div(next_tokens, vocab_size)
            next_tokens = next_tokens % vocab_size
            
            beam_outputs = beam_scorer.process(
                input_ids,
                next_scores,
                next_tokens,
                next_idxs,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                beam_indices=beam_indices,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]
            
            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim = -1)
            
            model_kwargs = self.model.model._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder = True
            )
            
            cur_len += 1
            
            if beam_scorer.is_done or stopping_criteria(input_ids, None):
                break
            
        seq_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_idxs,
            pad_token_id = self.tokenizer.pad_token_id,
            eos_token_id = self.tokenizer.eos_token_id,
            max_length = stopping_criteria.max_length,
            beam_indices = beam_indices,
        )
        
        return seq_outputs['sequences']
    
    def togl_sum(self, raw_out, togl_probs):
        norms = raw_out.norm(dim = -1)
        togl_probs = togl_probs.unsqueeze(0).repeat((raw_out.shape[0], 1))
        mod_logits = ((raw_out.squeeze()/norms) + togl_probs) * norms
        mod_logits = mod_logits.unsqueeze(1)
        return mod_logits
    
    def togl_convert(self, togl_tuple, togl_weight):
        full_dist = torch.zeros((self.vocab_size), device = self.device)
        probs = togl_tuple[0]
        idxs  = togl_tuple[1]
        full_dist[idxs] = probs
        if full_dist.sum() > 0:
            full_dist = (full_dist / full_dist.sum()) * togl_weight
        
        print(full_dist.sum())
        
        return full_dist
        

In [15]:
inputs = ['The top US hostage affairs official on Sunday reflected on conducting the prisoner swap that led to Brittney Griner’s release, saying the WNBA star immediately thanked the crew returning her to the United States.\n“When she finally got on to the US plane, I said, ‘Brittney, you must have been through a lot over the last 10 months. Here’s your seat. Please feel free to decompress. We’ll give you your space,’” Special Presidential Envoy for Hostage Affairs Roger Carstens told CNN’s Dana Bash on “State of the Union.”\n“And she said, ‘Oh no. I’ve been in prison for 10 months now listening to Russian, I want to talk. But first of all, who are these guys?’ And she moved right past me and went to every member on that crew, looked them in the eyes, shook their hands and asked about them and got their names, making a personal connection with them. It was really amazing,” Carstens recalled. “And then later on, on an 18 hour flight, she probably spent 12 hours just talking and we talked about everything under the sun.', 'goodbye my fellow compatriate']

In [16]:
tok_inputs = bart_tokenizer(inputs, padding = True, return_tensors = 'pt')

In [17]:
tok_inputs = tok_inputs.to(device)

In [49]:
output = model.generate(tok_inputs['input_ids'], min_length = 10, max_length = 20, 
                        num_beams = 3, no_repeat_ngram_size = 3)

In [50]:
output

tensor([[    2,     0, 14323,   382, 16301,  5185,   781,  7680,    15,     5,
         16796, 12313,    14,   669,     7, 16278,  2596,  2974,  5101,     2],
        [    2,     0,  8396, 33542,   127,  2598, 20840,  1069,   877,     4,
             2,     1,     1,     1,     1,     1,     1,     1,     1,     1]],
       device='cuda:0')

In [51]:
bart_tokenizer.decode(output[0], skip_special_tokens = True)

'Top US hostage affairs official reflected on the prisoner swap that led to Brittney Griner'

In [169]:
togl_dist = (
    [0.0005 * i for i in range(15)],
    [i + 5 for i in range(15)]
)

togl_dist = (torch.tensor(togl_dist[0]).to(device), torch.tensor(togl_dist[1]).to(device))

In [191]:
decoder = ToGLDecoder(model, bart_tokenizer,
                     device = device
                     )

In [196]:
output = decoder.generate(tok_inputs['input_ids'][0].unsqueeze(0), 
                          togl_dist, togl_start = 5,
                          togl_weight = 0.1,
                          min_length = 2, max_length = 64,
                          no_repeat_ngram_size = 3,
                          num_beams = 5)

tensor(0.1000, device='cuda:0')


Ġ

In [197]:
data = pd.read_csv('../../../../data/polisum_clean.csv')

Unnamed: 0,article_url,date,title,left_sum,right_sum,linked_arts,left_op,right_op,linked_arts_clean,reddit_text,twitter_text,sm_text,num_reddit,num_twitter,num_sm,h1_text,h2_text,sm_text_primera
0,https:/theflipside.io/archives/checks-in-the-mail,2020-03-20,Checks In The Mail,The left supports cash payments and argues tha...,The right is generally supportive of helping c...,['https://www.politico.com/news/magazine/2020/...,One of the strangest spectacles in the economi...,"Fiscal conservatives, like the Tea Partiers of...",['https://www.politico.com/news/magazine/2020/...,Is the U.S.|||Headed Toward a Short British-St...,"With the bailout, President Trump and Congres...",Adam Brandon: $1 trillion coronavirus economic...,15,54,69,Republicans would have screamed bloody murder ...,Headed Toward a Short British-Style Election?|...,Adam Brandon: $1 trillion coronavirus economic...
1,https:/theflipside.io/archives/eos-regarding-t...,2021-01-26,EOs Regarding Transgender Rights,The left is supportive of both policies.,The right is critical of both policies.,['https://www.spectator.co.uk/article/biden-s-...,[The Trump administration] tried to argue in c...,The existing policy—that of former President D...,['https://www.spectator.co.uk/article/biden-s-...,Missouri state lawmaker charged with selling f...,"By me for @SpecCoffeeHouse: ""boys who identify...","“ This year, state lawmakers also want to rest...",15,76,91,Trump resigns from Screen Actors Guild in rant...,Trans people can serve openly in the US milita...,"“ This year, state lawmakers also want to rest..."
2,https:/theflipside.io/archives/facebook-and-br...,2019-04-02,Facebook and Breaking Up Big Tech,"The left is skeptical of Zuckerberg’s motives,...",The right is disturbed by the free speech impl...,['https://medium.com/@teamwarren/heres-how-we-...,All of this might sound reasonable on its face...,Facebook's nominal raison d'etre is to serve a...,['https://medium.com/@teamwarren/heres-how-we-...,Fox news calls out Trump lie about Wikileaks||...,Remember when people criticized @ewarren for s...,Corporate self-governance has failed in part b...,12,79,91,UN experts warn Assange arrest exposes him to ...,Do you see the fall of Fakebook?|||#facebook #...,Corporate self-governance has failed in part b...
3,https:/theflipside.io/archives/general-electio...,2020-10-28,General Election Update,The left is optimistic about Biden’s chances.,The right is cautiously optimistic about Trump...,['https://theflipside.us15.list-manage.com/tra...,"Right now, Joe Biden is vastly outspending Don...","At this point, you should be saying something ...",['https://projects.fivethirtyeight.com/trump-b...,"Text: Joe Biden, on verge of victory says, ‘We...","It's morning, and like we expected we still do...",“The mistake the Clinton campaign made in Mich...,20,59,79,While Nate makes the case in here no Biden isn...,<URL>|||<URL>|||Pennsylvanians -- history has ...,“The mistake the Clinton campaign made in Mich...
4,https:/theflipside.io/archives/impeachment-hea...,2019-11-22,Impeachment Hearings Continue,"The left supports impeachment, arguing that th...","The right opposes impeachment, arguing that th...",['https://theflipside.us15.list-manage.com/tra...,‘Everyone was in the loop. It was no secret.’ ...,Polls show the vast majority of Americans agre...,['https://nypost.com/2019/11/21/fiona-hill-and...,Trade talks should also be climate talks|||Tru...,Fiona Hill (and Dems) ignore the serious evide...,<UNAME> <UNAME> <UNAME> <UNAME> <UNAME> Can so...,18,68,86,Banker who signed off Trump loans found dead a...,"<URL>|||""And in 2018 House testimony, Nellie O...",<UNAME> <UNAME> <UNAME> <UNAME> <UNAME> Can so...
