In [None]:
!pip -q install accelerate
!apt-get install git-lfs
!git lfs install
!pip -q install evaluate
!pip -q install rouge_score

In [None]:
import os
import nltk
import numpy as np
import pandas as pd
from tqdm import tqdm
from nltk.tokenize import sent_tokenize
# Sentence-tokenizer data for nltk.sent_tokenize (used when post-processing text for
# ROUGE). Recent NLTK releases load the 'punkt_tab' resource instead of 'punkt',
# so fetch both to work across versions.
nltk.download('punkt')
nltk.download('punkt_tab')
# Turn off HF tokenizers' internal parallelism; set before transformers is imported below.
os.environ['TOKENIZERS_PARALLELISM']='false'

import torch
from torch.utils.data import DataLoader

import evaluate
from datasets import Dataset, load_dataset, DatasetDict
from transformers import BartTokenizer, BartForConditionalGeneration

import warnings
# Blanket filter: silences every warning from here on, including library deprecation
# notices. Consider narrowing it when debugging version-related problems.
warnings.filterwarnings("ignore")

In [None]:
# Raw dataset: one row per paper, with full text, reference summary and extracted summary.
csv_path = "/kaggle/input/extract-scien-summ/extracted_summary.csv"
df = pd.read_csv(csv_path)
df.head()

In [None]:
# Word count (whitespace split) of each full text. Vectorised `.str` ops replace the
# row-wise apply(axis=1); a missing text cell now yields NaN instead of raising.
df['text_len'] = df['text'].str.split().str.len()
df.text_len.describe()

In [None]:
# Column dtypes and non-null counts: a quick check for missing text/summary values.
df.info()

In [None]:
# Load the same CSV as a Hugging Face dataset and hold out 1% of rows for evaluation.
# A fixed seed makes the split (and therefore the reported ROUGE scores) reproducible.
ds = load_dataset("csv", data_files="/kaggle/input/extract-scien-summ/extracted_summary.csv")
ds = ds['train'].train_test_split(test_size=0.01, seed=42)
ds

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
# Global seaborn/matplotlib look for every plot drawn after this cell.
sns.set_style("darkgrid")
sns.set_palette("muted")
sns.set_context("notebook", font_scale=1.5, rc={"lines.linewidth": 2.5})

def plot_hist(data, title, xlabel, ylabel, bins=100, color='g'):
    """Draw a histogram of `data` on a new wide figure and display it.

    Args:
        data: sequence of numeric values (in this notebook: per-sample word counts).
        title: axes title.
        xlabel: x-axis label.
        ylabel: y-axis label.
        bins: number of histogram bins (default 100, as before).
        color: matplotlib colour for the bars (default 'g', green, as before).
    """
    fig, ax = plt.subplots(figsize=(16, 5))
    ax.hist(data, bins=bins, color=color)
    ax.set(title=title, xlabel=xlabel, ylabel=ylabel)
    plt.show()

# One word-count histogram per text column of the training split.
length_plots = [
    ('text', 'Text Length Distribution'),
    ('summary', 'Summary Length Distribution'),
    ('extracted_summary', 'Extracted Summary Length Distribution'),
]
for column, plot_title in length_plots:
    word_counts = [len(doc.split()) for doc in ds['train'][column]]
    plot_hist(word_counts, plot_title, 'Length', 'Number of Samples')

In [None]:
# Pretrained checkpoint to fine-tune, plus token budgets used by `tokenize` below.
model_checkpoint = "facebook/bart-large-cnn"
max_input_length = 1024   # truncation length for the model inputs
max_target_length = 150   # truncation length for the reference summaries (labels)

tokenizer = BartTokenizer.from_pretrained(model_checkpoint)

def tokenize(example):
    """Tokenize a batch of rows for seq2seq fine-tuning.

    The model input is the `extracted_summary` column, truncated to
    `max_input_length` tokens. The target is the `summary` column, truncated
    to `max_target_length` tokens.

    Args:
        example: a batch as passed by `Dataset.map(..., batched=True)`, i.e. a
            dict mapping column names to lists of strings.

    Returns:
        The tokenizer encoding of the inputs with an extra `labels` key that
        holds the target token ids.
    """
    model_inputs = tokenizer(
        example['extracted_summary'],
        truncation=True,
        max_length=max_input_length,
    )
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context manager.
    # Targets get their own call so they can use a different max_length.
    labels = tokenizer(
        text_target=example['summary'],
        truncation=True,
        max_length=max_target_length,
    )
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

# Tokenize every split in batches, dropping the raw text columns so only model
# features (input_ids, attention_mask, labels) remain.
raw_columns = ds['train'].column_names
tokenized_ds = ds.map(tokenize, batched=True, remove_columns=raw_columns)
tokenized_ds

In [None]:
from transformers import DataCollatorForSeq2Seq, BartForConditionalGeneration 

# Pretrained weights to fine-tune. The collator pads inputs and labels per batch; it is
# given the model so it can prepare decoder inputs from the labels when supported.
model = BartForConditionalGeneration.from_pretrained(model_checkpoint)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [None]:
# Shuffled batches for training, fixed order for evaluation.
train_dataloader = DataLoader(
    tokenized_ds["train"],
    batch_size=8,
    shuffle=True,
    collate_fn=data_collator,
)
eval_dataloader = DataLoader(
    tokenized_ds["test"],
    batch_size=8,
    collate_fn=data_collator,
)

# Sanity check: pull one collated training batch and show each tensor's shape.
batch = next(iter(train_dataloader))
print({name: tensor.shape for name, tensor in batch.items()})

In [None]:
from transformers import get_scheduler

num_epochs = 10
# NOTE(review): len(train_dataloader) is read before accelerate.prepare(...) runs in a
# later cell. With multi-process training the prepared loader is shorter, so this
# count would overshoot; confirm if running on more than one process.
num_training_steps = len(train_dataloader) * num_epochs

optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
# Linear decay of the learning rate over the whole run, no warmup.
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [None]:
import wandb

# Track the run in Weights & Biases, and also log gradients/parameters of `model`.
run_config = {"learning_rate": 5e-5, "epochs": num_epochs, "batch_size": 8}
wandb.init(project="summarization", config=run_config)
wandb.watch(model, log="all")

In [None]:
from huggingface_hub import notebook_login, Repository, get_full_repo_name

# Show the Hugging Face Hub login prompt. Finish logging in before running the next
# cell, which clones the model repo and later pushes checkpoints to it.
notebook_login()

In [None]:
from accelerate import Accelerator
from huggingface_hub import notebook_login, Repository, get_full_repo_name

# Hub login is done by the dedicated `notebook_login()` cell above. It is not repeated
# here: that call only displays a login widget and returns, so the clone below would
# run before any token had been entered.
repo_name = get_full_repo_name(
    "bart-large-cnn-finetuned-scientific_summarize",
    organization="SmartPy",
)
output_dir = "./output"
# Local git clone of the Hub repo. The training loop saves here and pushes each epoch.
repo = Repository(output_dir, clone_from=repo_name)

# Wrap model/optimizer/dataloaders for the current device/process setup.
# (The training cells refer to this Accelerator instance as `accelerate`.)
accelerate = Accelerator()
model, optimizer, train_dataloader, eval_dataloader = accelerate.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [None]:
# `datasets.load_metric` was removed in datasets 3.0. `evaluate` (imported in the first
# cell) is its replacement, and its ROUGE metric relies on the `rouge_score` package
# installed at the top of the notebook.
rouge_score = evaluate.load("rouge")

def postprocess_text(preds, labels):
    """Normalize decoded predictions and references before ROUGE scoring.

    Each string is stripped, then re-joined with one sentence per line,
    because the ROUGE-Lsum variant splits summaries on newlines.

    Args:
        preds: list of decoded prediction strings.
        labels: list of decoded reference strings.

    Returns:
        (preds, labels): two lists of strings, in the same order as the inputs.
    """
    # Labels must stay plain strings here: nltk.sent_tokenize only accepts a str, so
    # wrapping each label in a list before sentence-splitting raises a TypeError.
    preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in labels]
    return preds, labels

In [None]:
from tqdm import tqdm
def _rouge_to_percent(scores):
    """Convert a ROUGE result dict into {name: F-measure * 100, rounded to 4 places}.

    Accepts both result formats: `datasets.load_metric("rouge")` returns aggregate
    objects exposing `.mid.fmeasure`, while `evaluate.load("rouge")` returns plain
    floats that already are the (aggregated) F-measure.
    """
    converted = {}
    for key, value in scores.items():
        fmeasure = value.mid.fmeasure if hasattr(value, "mid") else value
        converted[key] = round(fmeasure * 100, 4)
    return converted


progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    for batch in train_dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        accelerate.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        wandb.log({"train_loss": loss.item()})

    # ---- evaluation: loss + generated summaries scored with ROUGE ----
    model.eval()
    eval_losses = []
    for batch in eval_dataloader:
        with torch.no_grad():
            outputs = model(**batch)
            eval_losses.append(outputs.loss.item())
            generated_tokens = accelerate.unwrap_model(model).generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
            )

        # Pad to a common length so tensors from every process can be gathered.
        generated_tokens = accelerate.pad_across_processes(
            generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
        )
        labels = accelerate.pad_across_processes(
            batch["labels"], dim=1, pad_index=tokenizer.pad_token_id
        )

        generated_tokens = accelerate.gather(generated_tokens).cpu().numpy()
        labels = accelerate.gather(labels).cpu().numpy()

        # -100 marks ignored label positions; swap it for the pad id so it can be decoded.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        # `rouge_score` is the metric object created in the metric cell above.
        rouge_score.add_batch(predictions=decoded_preds, references=decoded_labels)

    # Aggregate this epoch's metrics (compute() also resets the metric's buffer).
    result = _rouge_to_percent(rouge_score.compute())
    mean_eval_loss = float(np.mean(eval_losses)) if eval_losses else float("nan")
    print(f"Epoch {epoch}:", result)
    wandb.log({"epoch": epoch, "eval_loss": mean_eval_loss, **result})

    # ---- save this epoch's weights and push them to the Hub (non-blocking) ----
    accelerate.wait_for_everyone()
    unwrapped_model = accelerate.unwrap_model(model)
    unwrapped_model.save_pretrained(output_dir, save_function=accelerate.save)
    if accelerate.is_main_process:
        tokenizer.save_pretrained(output_dir)
        repo.push_to_hub(
            commit_message=f"Training in progress epoch {epoch}", blocking=False
        )

In [None]:
# `Accelerator` has no `remove_slurm_checkpoint` method, so that call raised an
# AttributeError. Instead: close the W&B run opened before training so its data is
# flushed, then let the Accelerator release its references to the prepared objects.
wandb.finish()
accelerate.free_memory()

In [None]:
import gc
# Collect unreachable Python objects, then hand PyTorch's cached-but-unused GPU blocks
# back to the driver; gc.collect() alone does not release that cache.
freed_objects = gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
freed_objects

In [None]:

from transformers import pipeline

hub_model_id = "SmartPy/bart-large-cnn-finetuned-scientific_summarize"
# NOTE(review): training pushed with blocking=False, so the Hub copy may lag the last
# epoch; loading from the local `output_dir` instead would avoid that. Confirm.
# Use the first GPU when one is available (device=-1 means CPU).
summarizer = pipeline(
    "summarization",
    model=hub_model_id,
    device=0 if torch.cuda.is_available() else -1,
)

def summarize_research_papers(idx):
    """Print one test-set example alongside the fine-tuned model's summary.

    Args:
        idx: row index into ds["test"].

    Prints the model input (`extracted_summary`), the reference `summary`, and
    the summary generated by `summarizer` for that same input.
    """
    row = ds["test"][idx]  # index once; each ds[...] lookup materialises the row
    paper = row["extracted_summary"]
    reference = row["summary"]
    # Feed the column the model was fine-tuned on. The previous `review_body` key is not
    # a column used anywhere else in this notebook. truncation=True keeps inputs longer
    # than the model's maximum length from raising an error.
    summary = summarizer(paper, truncation=True)[0]["summary_text"]
    print(f"'>>> Paper: {paper}'")
    print(f"\n'>>> Summary: {reference}'")
    print(f"\n'>>> Generated Summary: {summary}'")

# Function is defined as `summarize_research_papers` (plural); the singular name was a NameError.
summarize_research_papers(3)

In [None]:
# import shutil 
# shutil.rmtree("/kaggle/working/bart-large-cnn-finetuned-scientific_summarize/checkpoint-1500")