In [1]:
import pandas as pd
import torch
from tqdm.notebook import tqdm
from datasets import load_dataset
from transformers import (AutoModel, AutoConfig, AutoTokenizer,AutoModelForSequenceClassification,
                          pipeline, Trainer, TrainingArguments,EarlyStoppingCallback)
from utils.text_processing import (get_summarizer, perform_summarizer, text_filter, 
                                   get_sentiment_model, get_topic_model,compute_metrics)

In [None]:
def preprocess_input_data(batch):
    """Tokenize the ``'summary'`` entries of a batch.

    Uses the module-level ``tokenizer`` (it must be defined before this
    function is called).

    Args:
        batch: Mapping with a ``'summary'`` key holding the text(s) to encode.

    Returns:
        Whatever ``tokenizer`` returns for those texts when asked to pad to
        ``max_length`` and truncate at 256.
    """
    # Fixed-length encoding: pad up to, and truncate at, 256.
    encoding_options = dict(padding='max_length', truncation=True, max_length=256)
    return tokenizer(batch['summary'], **encoding_options)

## Load Trained Model Weights

In [None]:
# Prefer the GPU when one is present, but keep the notebook runnable on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#######
# Path or hub id of the trained checkpoint -- must be filled in before running.
model_ckpt = None
#######
if model_ckpt is None:
    # Fail fast with an actionable message instead of the opaque error that
    # from_pretrained(None) would produce.
    raise ValueError(
        "model_ckpt is not set: assign the path or hub id of the trained "
        "checkpoint before running this cell."
    )

# Tokenizer and encoder are loaded from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
# NOTE(review): the model is not moved to `device` here; the embedding loop
# builds its input tensors on the CPU, so moving it would need a matching change there.

# Initiate summarization model
summary_model_name = 'facebook/bart-large-cnn' # pretrained model hosted on HuggingFace
summary_model = get_summarizer(summary_model_name)

## Load Text Data & Run Concatenated Summary with CLS

In [None]:
data = pd.read_csv('../data/BERT/articles_2015_2019_train_fold-1.csv')

cls = {}        # timestamp -> mean-pooled embedding (flat list of floats)
summaries = {}  # timestamp -> the text that was embedded


def _embed_text(text):
    """Return the mean of the model's last hidden state for ``text`` as a flat list.

    The text is padded/truncated to 512 token ids before the forward pass.
    """
    tokens = tokenizer(text,
                       padding='max_length',
                       truncation=True,
                       max_length=512,
                       add_special_tokens=True)
    # Keep the input on whatever device the model currently lives on.
    input_tokens = torch.tensor([tokens['input_ids']]).to(model.device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        hidden_states = model(input_tokens).last_hidden_state
    # NOTE(review): the mean runs over all 512 positions, padding included, and
    # no attention mask is passed to the model. Kept as-is so the saved
    # features stay comparable with earlier runs -- confirm this is intended.
    return hidden_states.mean(dim=1).flatten().tolist()


for date_idx, group in tqdm(data.groupby(by = 'timestamp')):
    # Join with a space so the last word of one article is not fused with the
    # first word of the next.
    concat_news = ' '.join(group.text)
    words = concat_news.split()

    if len(words) >= 250:
        # NOTE(review): texts of 5000+ words are cut to 2000 words, while texts
        # of 2000-4999 words are passed through whole -- confirm the thresholds.
        if len(words) >= 5000:
            concat_news = ' '.join(words[:2000])
        summary = perform_summarizer(concat_news, summary_model, ratio = 0.5, return_embeddings = False)
        summary_len = len(summary.split())

        # Keep condensing until the summary is under 250 words.
        while summary_len >= 250:
            shorter = perform_summarizer(summary, summary_model, ratio = 0.8, return_embeddings = False)
            shorter_len = len(shorter.split())
            if shorter_len >= summary_len:
                # No progress: stop instead of looping forever; the tokenizer
                # truncates whatever is still too long.
                break
            summary, summary_len = shorter, shorter_len
    else:
        # Short enough to embed directly; the text itself is what gets recorded
        # (previously a stale or undefined `summary` was stored here).
        summary = concat_news

    cls[date_idx] = _embed_text(summary)
    summaries[date_idx] = summary

### Save Data

In [None]:
cls_path = '../data/BERT/articles_2015_2019_train_fold-1_CLS.txt'
sum_path = '../data/BERT/articles_2015_2019_train_fold-1_concated_summaries.txt'

# Each entry: (destination file, dict to dump, extra keyword arguments for open()).
# Both dicts are written as their Python repr; only the summaries file is
# opened with an explicit UTF-8 encoding.
pending_dumps = [
    (cls_path, cls, {}),
    (sum_path, summaries, {'encoding': 'utf-8'}),
]

for target_path, payload, open_options in pending_dumps:
    with open(target_path, 'w', **open_options) as handle:
        handle.write(str(payload))