In [None]:
# Inspect the NVIDIA GPU attached to this runtime; later cells move the model to "cuda".
!nvidia-smi

In [None]:
from google.colab import drive

# Make Google Drive available under /content/drive so the dataset CSV can be read below.
drive.mount(mountpoint='/content/drive')

In [None]:
!pip install -U datasets

In [None]:
%%capture
# %%capture suppresses this cell's pip output.
# transformers is pinned to 4.19.2; rouge_score is presumably required by the
# "rouge" metric loaded later with load_metric.
!pip install transformers==4.19.2
!pip install rouge_score

In [None]:
import pandas as pd
from datasets import load_metric  # used later to load the ROUGE metric

# Read the article dataset from the mounted Drive and preview its first rows.
df = pd.read_csv("/content/drive/MyDrive/Code Cycle/articlesSet.csv")
df.head(n=5)

In [None]:
# Remove every row with a missing value in any column; the shape is printed before
# and after so the number of discarded rows is visible.
print(df.shape)
df = df.dropna(axis="index", how="any")
print(df.shape)

In [None]:
# Word count of the `paragraph` column. str.split() with no separator splits on any
# run of whitespace (spaces, tabs, newlines). split(" ") would count the empty strings
# between repeated spaces as words and would merge newline-separated words.
# NOTE(review): the model is trained on `summary`/`content`; confirm `paragraph` is the
# column whose length should drive the length filter below.
df['length'] = df['paragraph'].str.split().str.len()

In [None]:
from matplotlib import pyplot as plt

# Per-row word counts computed in the `length` column.
numOfWords = df.length

# Uneven bin edges: 500-word bins up to 4000, then 1000-word bins up to 9000.
# Rows longer than 9000 words fall outside the last edge and are not drawn.
WORD_COUNT_BINS = [0, 500, 1000, 1500, 2000, 2500, 3000, 3500, 4000,
                   5000, 6000, 7000, 8000, 9000]

fig, ax = plt.subplots(figsize=(5, 3))
ax.hist(numOfWords.to_numpy(), bins=WORD_COUNT_BINS)
ax.set(
    title="Word count distribution",
    xlabel="Words per row (paragraph column)",
    ylabel="Number of rows",
)
plt.show()

In [None]:
# Keep only rows whose word count lies in the half-open range [100, 800).
tempDf = df[(df.length >= 100) & (df.length < 800)]
tempDf.shape

In [None]:
from transformers import AutoTokenizer

# Tokenizer of the pretrained LED checkpoint that is fine-tuned further below.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="allenai/led-base-16384",
)

In [None]:
max_input_length = 1024
max_target_length = 1024
batch_size = 4

def process_data_to_model_inputs(batch):
    """Tokenize a batch of rows into LED seq2seq training features.

    Intended for ``Dataset.map(..., batched=True)``: ``batch`` maps column names to
    lists of values. The ``summary`` column is the encoder input and the ``content``
    column is the generation target (this notebook trains summary -> content).

    Adds these keys to ``batch`` and returns it:
      - ``input_ids`` / ``attention_mask``: the summary, padded/truncated to
        ``max_input_length``.
      - ``global_attention_mask``: one independent list per row, 1 at position 0
        and 0 elsewhere.
      - ``labels``: the content ids, padded/truncated to ``max_target_length``,
        with padding ids replaced by -100 so the loss ignores them.

    Uses the module-level ``tokenizer``.
    """
    inputs = tokenizer(
        batch["summary"],
        padding="max_length",
        truncation=True,
        max_length=max_input_length,
    )
    outputs = tokenizer(
        batch["content"],
        padding="max_length",
        truncation=True,
        max_length=max_target_length,
    )

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask

    # Global attention on the first token of every row. Each row gets its own
    # list, so the result does not depend on list aliasing (`n * [row]`). An
    # empty batch yields an empty list instead of raising IndexError.
    batch["global_attention_mask"] = [
        [1 if position == 0 else 0 for position in range(len(ids))]
        for ids in batch["input_ids"]
    ]

    # Padding positions in the targets become -100 so they are ignored by the loss.
    pad_id = tokenizer.pad_token_id
    batch["labels"] = [
        [-100 if token == pad_id else token for token in labels]
        for labels in outputs.input_ids
    ]

    return batch

In [None]:
import numpy as np

# Shuffle the length-filtered rows reproducibly, then cut them positionally into
# 40% train / 10% validation / 50% test. The cut points are computed from the frame
# being split (tempDf), not from df. df still contains the rows removed by the
# length filter, so its length would give cut points past the end of tempDf.
shuffled = tempDf.sample(frac=1, random_state=42)
n_rows = len(shuffled)
train_end, val_end = int(.4 * n_rows), int(.5 * n_rows)
train = shuffled.iloc[:train_end]
validate = shuffled.iloc[train_end:val_end]
test = shuffled.iloc[val_end:]


In [None]:
# Report (rows, columns) of the train and validation splits.
print(*(part.shape for part in (train, validate)))

In [None]:
# Keep small fixed positional slices (presumably to keep fine-tuning short):
# the first 250 training rows and validation rows 25..49.
train = train.iloc[0:250]
validate = validate.iloc[25:50]

print(train.shape, validate.shape)

In [None]:
from datasets import Dataset

# Wrap the pandas splits as Hugging Face datasets. The shuffled pandas index is kept
# as an "__index_level_0__" column, which the tokenization step removes.
train_dataset, val_dataset = (Dataset.from_pandas(frame) for frame in (train, validate))

In [None]:
# Tokenize the training split `batch_size` rows at a time. The raw text columns, the
# word-count column and the pandas index column are dropped from the result.
train_dataset = train_dataset.map(
    function=process_data_to_model_inputs,
    batch_size=batch_size,
    batched=True,
    remove_columns=["content", "summary", "length", "__index_level_0__"],
)

In [None]:
# Same tokenization for the validation split: raw text, word-count and pandas index
# columns are dropped after process_data_to_model_inputs has run.
val_dataset = val_dataset.map(
    function=process_data_to_model_inputs,
    batch_size=batch_size,
    batched=True,
    remove_columns=["content", "summary", "length", "__index_level_0__"],
)

In [None]:
# Both splits return torch tensors restricted to the four model-input columns.
_model_columns = ["input_ids", "attention_mask", "global_attention_mask", "labels"]
for _split in (train_dataset, val_dataset):
    _split.set_format(type="torch", columns=list(_model_columns))

In [None]:
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments

# Fine-tune from the same pretrained LED checkpoint as the tokenizer, with gradient
# checkpointing on and use_cache disabled for training.
led = AutoModelForSeq2SeqLM.from_pretrained(
    "allenai/led-base-16384",
    gradient_checkpointing=True,
    use_cache=False,
)

# Generation hyper-parameters stored on the model config; generate() reads them
# when no explicit values are passed.
generation_settings = {
    "num_beams": 2,
    "max_length": 1024,
    "min_length": 512,
    "length_penalty": 2.0,
    "early_stopping": True,
    "no_repeat_ngram_size": 3,
}
for setting, value in generation_settings.items():
    setattr(led.config, setting, value)

rouge = load_metric("rouge")

def compute_metrics(pred):
    """Compute ROUGE-2 between generated and reference texts for the Trainer.

    ``pred.predictions`` holds generated token ids (``predict_with_generate=True``)
    and ``pred.label_ids`` holds label ids in which -100 marks ignored positions.

    Returns ``rouge2_precision``, ``rouge2_recall`` and ``rouge2_fmeasure`` taken from
    the ``mid`` value of the aggregate score, each rounded to 4 decimals. If the
    metric returns a plain aggregated float instead of an aggregate score object,
    only ``rouge2_fmeasure`` is reported.

    Uses the module-level ``tokenizer`` and ``rouge``.
    """
    pad_id = tokenizer.pad_token_id
    # -100 is not a decodable token id, so map it back to padding before decoding.
    # np.where builds new arrays, so the Trainer's arrays are not modified in place.
    labels_ids = np.where(np.asarray(pred.label_ids) == -100, pad_id, pred.label_ids)
    pred_ids = np.where(np.asarray(pred.predictions) == -100, pad_id, pred.predictions)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    # compute() returns a dict keyed by ROUGE type; the score lives under "rouge2".
    rouge2 = rouge.compute(
        predictions=pred_str, references=label_str, rouge_types=["rouge2"]
    )["rouge2"]

    mid = getattr(rouge2, "mid", None)
    if mid is None:
        # Aggregated float F-measure (no precision/recall breakdown available).
        return {"rouge2_fmeasure": round(float(rouge2), 4)}
    return {
        "rouge2_precision": round(mid.precision, 4),
        "rouge2_recall": round(mid.recall, 4),
        "rouge2_fmeasure": round(mid.fmeasure, 4),
    }

# Training configuration. NOTE: mixed precision is NOT enabled here (no `fp16=True`
# is passed, and the TrainingArguments default is fp16=False), so training runs in
# full precision. Add `fp16=True` to train in fp16.
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    # Batching: each optimizer step accumulates 4 batches of `batch_size` rows per device.
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=4,
    num_train_epochs=2,
    # Evaluate with generate() (scored by compute_metrics) every 10 steps; log every 5.
    predict_with_generate=True,
    evaluation_strategy="steps",
    eval_steps=10,
    logging_steps=5,
    # Write checkpoint-<step> directories to output_dir every 10 steps, keeping at most 2.
    save_steps=10,
    save_total_limit=2,
)

In [None]:
# Wire model, data, tokenizer and metric function into a seq2seq Trainer.
trainer = Seq2SeqTrainer(
    model=led,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning; evaluation, logging and checkpointing follow the step schedule in training_args.
trainer.train()

In [None]:
# Reproducible 0.5% sample of the length-filtered rows for a qualitative generation check.
# NOTE(review): tempDf also contains the training rows, so this sample can overlap the
# fine-tuning data; sampling from `test` would give an unseen check.
sample = tempDf.sample(random_state=12, frac=0.005)
sample.shape

In [None]:
# Keep only the input (summary) and target (content) text columns, then show the targets.
sample = sample.loc[:, ['content', 'summary']]
sample['content']

In [None]:
# Summary text (model input) of the sampled row whose index label is 505.
sample.loc[505, 'summary']

In [None]:
from datasets import Dataset

# Despite its name, pubmed_test holds the rows of `sample` from this notebook's CSV.
pubmed_test = Dataset.from_pandas(sample)

import torch

from datasets import load_dataset, load_metric
from transformers import LEDTokenizer, LEDForConditionalGeneration
from transformers.trainer_utils import get_last_checkpoint

# The trainer writes checkpoint-<step> directories under output_dir="./" (presumably
# /content on Colab, as the original path assumed). The visible training setup
# (250 rows, batch size 4, 4 accumulation steps, 2 epochs) runs only about 30 optimizer
# steps, so a hard-coded "checkpoint-100" may not exist. Load the highest-step checkpoint.
checkpoint_dir = get_last_checkpoint("/content")
if checkpoint_dir is None:
    raise FileNotFoundError(
        "no checkpoint-* directory found in /content; run trainer.train() first"
    )

# Load tokenizer
tokenizer = LEDTokenizer.from_pretrained(checkpoint_dir)

# Fine-tuned weights on the GPU in half precision for generation.
model = LEDForConditionalGeneration.from_pretrained(checkpoint_dir).to("cuda").half()

def generate_answer(batch):
    """Generate content text from the ``summary`` column of a batch.

    Intended for ``Dataset.map(..., batched=True)``: ``batch["summary"]`` is a list of
    strings. Inputs are tokenized like the training inputs (padded/truncated to
    ``max_input_length``) and moved to the GPU. Global attention is placed on the
    first token, and the decoded generations are stored in
    ``batch["predicted_content"]``. Generation hyper-parameters come from the loaded
    model's config.

    Uses the module-level ``tokenizer``, ``model`` and ``max_input_length``.
    """
    # Pad/truncate to the length the model was fine-tuned with (the earlier value,
    # 1924, looked like a typo for 1024).
    inputs_dict = tokenizer(
        batch["summary"],
        padding="max_length",
        max_length=max_input_length,
        return_tensors="pt",
        truncation=True,
    )

    input_ids = inputs_dict.input_ids.to("cuda")
    attention_mask = inputs_dict.attention_mask.to("cuda")

    # Global attention on the first token only, matching the training features.
    global_attention_mask = torch.zeros_like(attention_mask)
    global_attention_mask[:, 0] = 1

    predicted_abstract_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        global_attention_mask=global_attention_mask,
    )
    batch["predicted_content"] = tokenizer.batch_decode(
        predicted_abstract_ids, skip_special_tokens=True
    )

    return batch

# Generate content for every sampled row, two rows per generate_answer call.
result = pubmed_test.map(function=generate_answer, batch_size=2, batched=True)

In [None]:
# Reference (ground-truth) content of the second sampled row.
result['content'][1]

In [None]:
# Model-generated content for the same row, to compare against result['content'][1].
result['predicted_content'][1]