In [None]:
# Mount Google Drive so the project checkout under MyDrive (used in the next
# cell) is reachable from this Colab runtime.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
import os

# Project root on the mounted Google Drive. The path is only a default: set the
# NLP_PROJECT_DIR environment variable to run from a different checkout
# without editing the notebook. Relative paths used later (e.g. ./dataset/...)
# resolve against this directory.
PROJECT_DIR = os.environ.get(
    "NLP_PROJECT_DIR",
    "/content/drive/MyDrive/MS_DS/NLP/Final project/nlp-text-summarisation",
)
os.chdir(PROJECT_DIR)

In [None]:
!git submodule init
!git submodule update

In [None]:
%pip install transformers
%pip install torch
%pip install rouge

### Load the Pegasus-XSum model and tokenizer

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import torch

# NOTE(review): `pipeline` is imported but not used in the visible cells.

# Decide the device once; the model is moved there now, and input tensors are
# moved to the same `device` at inference time.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pegasus fine-tuned on XSum (single-sentence, highly abstractive summaries).
tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum")
model = AutoModelForSeq2SeqLM.from_pretrained("google/pegasus-xsum").to(device)


### Summarize test articles with the model

In [None]:
def summarize(text, min_length=30, max_length=100, max_input_tokens=512):
    """Generate an abstractive summary of ``text`` with the loaded seq2seq model.

    Relies on the module-level ``tokenizer``, ``model`` and ``device`` created
    in the model-loading cell.

    Args:
        text: Raw article text.
        min_length: Minimum length of the generated summary, in tokens.
        max_length: Maximum length of the generated summary, in tokens.
        max_input_tokens: Encoded input is truncated to at most this many tokens.

    Returns:
        The decoded summary string, with special tokens removed.
    """
    # Turn line breaks into spaces. Deleting them outright would glue the last
    # word of one line onto the first word of the next ("end.Next").
    preprocess_text = text.strip().replace("\n", " ")
    tokenized_text = tokenizer.encode(
        preprocess_text,
        return_tensors="pt",
        max_length=max_input_tokens,
        # Truncate explicitly instead of relying on the tokenizer's implicit
        # fallback when only max_length is given (which emits a warning).
        truncation=True,
    ).to(device)
    summary_ids = model.generate(
        tokenized_text,
        min_length=min_length,
        max_length=max_length,
        early_stopping=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

In [None]:
import json
from generators import get_cnn_dm_both_generator

# Articles with len(article) above this are skipped. Presumably this is a
# character count; the generator's item type isn't visible here, so confirm.
# NOTE(review): the filter changes which articles are evaluated; report the
# skipped count next to the ROUGE numbers.
MAX_ARTICLE_LEN = 5000
PEGASUS_OUTPUT_PATH = 'pegasus_output_000.json'

test_data_path = './dataset/chunked/test_000.bin'

output = []
n_skipped = 0

for article, abstract in get_cnn_dm_both_generator(test_data_path):
    if len(article) > MAX_ARTICLE_LEN:
        n_skipped += 1
        continue

    pegasus_abstract = summarize(article)
    output.append({
        'article': article,
        'abstract': abstract,
        'pegasus_abstract': pegasus_abstract,
    })

print(f"Summarized {len(output)} articles; skipped {n_skipped} "
      f"with length > {MAX_ARTICLE_LEN}.")

with open(PEGASUS_OUTPUT_PATH, 'w') as fout:
    json.dump(output, fout, indent=2)

### Examine output

In [None]:
import pandas as pd

# Reload the summaries from disk rather than reusing the in-memory list, so
# this section only depends on the JSON file written by the generation cell.
# The file holds a list of {article, abstract, pegasus_abstract} records.
df = pd.read_json("pegasus_output_000.json", orient="records")
df.head()

### ROUGE evaluation

In [None]:
from rouge import Rouge

# Score with ROUGE-1, ROUGE-2 and ROUGE-L (the package's default metric set,
# spelled out so the report's contents are visible here).
rouge = Rouge(metrics=["rouge-1", "rouge-2", "rouge-l"])

In [None]:
import pandas as pd

# Hypotheses are the model's summaries; references are the dataset abstracts.
pred_str = df['pegasus_abstract']
label_str = df['abstract']

# Per-article scores: the `rouge` package returns one dict per pair, keyed by
# metric ("rouge-1", ...) and then by stat ("f", "p", "r").
# NOTE(review): get_scores raises ValueError on an empty hypothesis; consider
# ignore_empty=True if any generated summary can be empty.
rouge_output = rouge.get_scores(pred_str, label_str)

# Flatten to one column per metric/stat (e.g. "rouge-1.f") and show the
# corpus-level mean and spread, instead of printing every article's scores.
rouge_per_article = pd.json_normalize(rouge_output)
rouge_per_article.agg(['mean', 'std']).T