In [None]:
!pip install transformers
!pip install rouge-score
!pip install tqdm

In [None]:
from transformers import AutoTokenizer, BartForConditionalGeneration
from rouge_score import rouge_scorer
import pandas as pd
import numpy as np
import random
import torch

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
# Load the tokenizer and the summarization model from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
# Move the weights to the device selected earlier in the notebook.
model = model.to(device)

In [None]:
class CNN_Mail(torch.utils.data.Dataset):
    """Dataset backed by a CSV file with ``article`` and ``highlights`` columns.

    Indexing returns an ``(article, highlights)`` pair for the requested row(s).
    """

    def __init__(self, file_name):
        """Read the whole CSV at ``file_name`` into a DataFrame."""
        self.data = pd.read_csv(file_name)

    def __len__(self):
        """Return the number of rows in the loaded frame."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(article, highlights)`` for positional index ``idx``."""
        # Tensor indices are converted to plain Python values before iloc.
        position = idx.tolist() if torch.is_tensor(idx) else idx
        record = self.data.iloc[position]
        return record.article, record.highlights

In [None]:
# Test split, read from the CSV at this path.
test_data = CNN_Mail('/content/test.csv')

In [None]:
# Sanity check: print the reference summary of the first example, then stop.
for sample_article, sample_highlights in test_data:
  print(sample_highlights)
  break

In [None]:
from tqdm import tqdm

# Running sums of the per-example score tuples for each ROUGE variant.
# Re-created on every run of this cell so re-executing it never double counts.
results = {'rouge1': np.zeros(3), 'rouge2': np.zeros(3), 'rougeL': np.zeros(3)}

# The scorer configuration is the same for every article, so build it once
# instead of once per iteration.
scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL', 'rouge2'], use_stemmer=True)

for article, ref in tqdm(test_data):
  # truncation=True is what makes max_length take effect; without it the
  # tokenizer only warns and over-long articles are passed to the model whole.
  inputs = tokenizer(article, return_tensors="pt", truncation=True, max_length=1024).to(device)

  summary_ids = model.generate(
      inputs["input_ids"],
      attention_mask=inputs["attention_mask"],
      num_beams=2,
      min_length=0,
      max_length=100,
  )

  res = tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

  # RougeScorer.score expects (target, prediction): the reference goes first.
  # With the arguments reversed, precision and recall are swapped.
  scores = scorer.score(ref, res)
  for key, score in scores.items():
    results[key] += np.array(score)



In [None]:
# Turn the accumulated sums into per-example means.
# Note: this overwrites `results`, so running the cell twice divides twice.
num_examples = len(test_data)
results = {name: totals / num_examples for name, totals in results.items()}
results

In [None]:
# final results on test data
# Recorded output from an earlier run, kept here for reference only.
# NOTE(review): each array presumably holds (precision, recall, fmeasure) in
# that order, as accumulated from the ROUGE score tuples — confirm against the
# scorer argument order used for the run that produced these numbers.
'''
{'rouge1': array([0.45841702, 0.44183878, 0.44161249]),
 'rouge2': array([0.23055623, 0.21361195, 0.21270519]),
 'rougeL': array([0.33750412, 0.31121467, 0.35731044])}
  '''


In [None]:
# Print average F1 metrics.
# Each array in `results` has three entries, accumulated from the scorer's
# (precision, recall, fmeasure) tuples, so F1 is at index 2. Index 3 is out
# of bounds for a length-3 array and raises IndexError.
print("average R1: ", results['rouge1'][2])
print("average R2: ", results['rouge2'][2])
print("average RL: ", results['rougeL'][2])

average R1: 0.44161249
average R2: 0.21270519
average RL: 0.35731044
