In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, T5Tokenizer, T5ForConditionalGeneration, RobertaTokenizer

def get_device_map() -> str:
    """Pick the torch device string for this kernel.

    Returns:
        ``'cuda'`` when PyTorch reports a usable CUDA device, otherwise ``'cpu'``.
    """
    if not torch.cuda.is_available():
        return 'cpu'
    return 'cuda'


device = get_device_map()

In [2]:
import pandas as pd

# Load the sampled code comments; the displayed frame has comment_text, satd and comment_length columns.
comment_df = pd.read_csv(filepath_or_buffer='sampled_data.csv')
comment_df

Unnamed: 0,comment_text,satd,comment_length
0,"// if we are the dest and is a call action, cr...",0,112
1,// TR#18 1.2,0,12
2,//Ignore manifest entries. They're bound to c...,0,167
3,"//NOTE: unlike all other Loaders, this one is ...",0,87
4,// no error as default,0,22
...,...,...,...
617,// the path to the plugin.xml descriptor file ...,0,100
618,// Test to see if correct suffix was used to c...,0,65
619,// TODO: figure out why bind variables aren't ...,1,53
620,// i'th argument,0,16


In [3]:
class Summarizer:
    """Summarise text columns of a DataFrame with a Hugging Face seq2seq model.

    Subclasses may override ``set_tokenizer`` / ``set_model`` to customise loading.
    """

    DEFAULT_PROMPT = "Produce a summary of the following text:"

    def __init__(self, model_checkpoint, device=None):
        """Load the tokenizer and model for ``model_checkpoint``.

        Args:
            model_checkpoint: Hugging Face hub id or local path
                (e.g. ``"Falconsai/text_summarization"``).
            device: device to place the model on. Defaults to ``"cuda:0"`` when CUDA
                is available and ``"cpu"`` otherwise.
        """
        self.model_checkpoint = model_checkpoint
        # Must be set before set_model(), which reads it.
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.set_tokenizer(self.model_checkpoint)
        self.set_model(self.model_checkpoint)
        self.model_name = model_checkpoint.split("/")[-1]

    def set_tokenizer(self, model_checkpoint):
        """Load the tokenizer for ``model_checkpoint`` into ``self.tokenizer``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

    def set_model(self, model_checkpoint):
        """Load the seq2seq model for ``model_checkpoint`` directly onto ``self.device``."""
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, device_map=self.device)

    def generate_summary(self, df, text_column='comment_text', prompt=None, **generate_kwargs):
        """Summarise every entry of ``df[text_column]``.

        Args:
            df: DataFrame holding the texts to summarise.
            text_column: name of the column to summarise.
            prompt: instruction prepended to each text; ``None`` uses ``DEFAULT_PROMPT``.
            **generate_kwargs: extra keyword arguments forwarded to ``model.generate``
                (e.g. ``max_new_tokens``); they override the defaults
                ``do_sample=True`` and ``num_return_sequences=1``.

        Returns:
            A copy of ``df`` with an added ``summary`` column. ``df`` itself is not modified.
        """
        if prompt is None:
            prompt = self.DEFAULT_PROMPT

        generation_args = {"do_sample": True, "num_return_sequences": 1, **generate_kwargs}
        # Inputs must live on the same device as the model weights, wherever they were placed.
        model_device = self.model.device

        summaries = []
        for comment in df[text_column]:
            # truncation=True keeps over-long comments within the tokenizer's maximum input length.
            inputs = self.tokenizer(f"{prompt} {comment}", return_tensors="pt", truncation=True)
            outputs = self.model.generate(
                inputs["input_ids"].to(model_device),
                attention_mask=inputs["attention_mask"].to(model_device),
                **generation_args,
            )
            # Only the first returned sequence is kept.
            summaries.append(self.tokenizer.decode(outputs[0], skip_special_tokens=True))

        return df.assign(summary=summaries)

    def export_summaries(self, df, output_dir="summarization_results"):
        """Write ``df`` to ``<output_dir>/<model_name>_summaries.xlsx``.

        The output directory is created if it does not exist.

        Returns:
            The path of the written workbook.
        """
        import os

        os.makedirs(output_dir, exist_ok=True)
        path = os.path.join(output_dir, f"{self.model_name}_summaries.xlsx")
        df.to_excel(path, index=False)
        return path


In [4]:
class FlanSummarizer(Summarizer):
    """Summarizer for FLAN-T5 checkpoints (e.g. ``jordiclive/flan-t5-3b-summarizer``).

    Tokenizer/model loading and generation are inherited unchanged from ``Summarizer``.
    No ``__init__`` override is needed: the base constructor already stores
    ``model_checkpoint``, and inheriting it keeps every base-constructor argument
    available to callers.
    """
        
class BartLargeSummarizer(Summarizer):
    """Summarizer for BART checkpoints such as ``facebook/bart-large-cnn``.

    Unlike the base class, the model is loaded without ``device_map`` and moved to
    the target device afterwards with ``.to()``.
    """

    def set_model(self, model_checkpoint):
        """Load ``model_checkpoint`` into ``self.model`` and move it to the target device.

        The target is ``self.device`` when the instance defines one; otherwise
        ``"cuda:0"`` if CUDA is available, else ``"cpu"``.
        """
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
        # NOTE(review): re-loading the model's own state dict with assign=True looks like a
        # workaround (presumably for tensors that .to() cannot copy); kept unchanged — confirm.
        self.model.load_state_dict(self.model.state_dict(), assign=True)
        target_device = getattr(self, "device", None)
        if target_device is None:
            target_device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model.to(target_device)

In [None]:
# Release cached, unused GPU memory held by PyTorch's allocator before loading a model.
torch.cuda.empty_cache()

# generate_summary samples (do_sample=True); seed so re-runs are reproducible
# (on the same hardware/software stack).
SEED = 42
torch.manual_seed(SEED)

# Alternative checkpoints tried in this notebook:
# summarizer = BartLargeSummarizer("facebook/bart-large-cnn")
# summarizer = FlanSummarizer("jordiclive/flan-t5-3b-summarizer")
summarizer = Summarizer("Falconsai/text_summarization")

try:
    summaries = summarizer.generate_summary(comment_df)
    summarizer.export_summaries(summaries)
finally:
    # Drop the model reference even if generation/export fails, so its GPU memory
    # can be returned by empty_cache() instead of lingering in the kernel.
    del summarizer
    torch.cuda.empty_cache()