In [5]:
from transformers import (BigBirdPegasusForConditionalGeneration, AutoTokenizer, 
DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer )
from datasets import load_dataset, load_metric
import numpy as np

In [6]:
data_files = 'ccdv/cnn_dailymail'
model_name = 'google/bigbird-pegasus-large-pubmed'

In [7]:
train_dataset = load_dataset(data_files, '3.0.0', split="train")
val_dataset = load_dataset(data_files, '3.0.0', split="validation")

Reusing dataset cnn_dailymail (/Users/johnchou/.cache/huggingface/datasets/ccdv___cnn_dailymail/3.0.0/3.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f)
Reusing dataset cnn_dailymail (/Users/johnchou/.cache/huggingface/datasets/ccdv___cnn_dailymail/3.0.0/3.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f)


For the sake of this trial, I only select a few samples in the database,
please comment the cell below to train the model on full dataset

In [8]:
train_dataset = train_dataset.select(range(100))
val_dataset = val_dataset.select(range(20))

In [9]:
def preprossess_text(example):
    return {'article':example['article'].replace('\'', ''), 
            'highlights':example['highlights'].replace('\'', ''), 
            'length': len(example['article'].split())}

In [10]:
def lowercase(example):
    return{'article':example['article'].lower(),
          'highlights':example['highlights'].lower()}

In [11]:
train_dataset = train_dataset.map(preprossess_text)
train_dataset = train_dataset.filter(lambda x: x['length'] > 1000)
train_dataset = train_dataset.map(lowercase)
train_dataset = train_dataset.remove_columns("id")

Loading cached processed dataset at /Users/johnchou/.cache/huggingface/datasets/ccdv___cnn_dailymail/3.0.0/3.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f/cache-736a077279fb0b7c.arrow
Loading cached processed dataset at /Users/johnchou/.cache/huggingface/datasets/ccdv___cnn_dailymail/3.0.0/3.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f/cache-ef00eadf1766c0ac.arrow
Loading cached processed dataset at /Users/johnchou/.cache/huggingface/datasets/ccdv___cnn_dailymail/3.0.0/3.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f/cache-17eca76de43ae8c7.arrow


In [12]:
len(train_dataset)

23

In [13]:
val_dataset = val_dataset.map(preprossess_text)
val_dataset = val_dataset.filter(lambda x: x['length'] > 1000)
val_dataset = val_dataset.map(lowercase)
val_dataset = val_dataset.remove_columns("id")

Loading cached processed dataset at /Users/johnchou/.cache/huggingface/datasets/ccdv___cnn_dailymail/3.0.0/3.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f/cache-7e644a1c6ca1171e.arrow
Loading cached processed dataset at /Users/johnchou/.cache/huggingface/datasets/ccdv___cnn_dailymail/3.0.0/3.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f/cache-110cc2bebd3a0506.arrow
Loading cached processed dataset at /Users/johnchou/.cache/huggingface/datasets/ccdv___cnn_dailymail/3.0.0/3.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f/cache-d73c9e11c6ad04de.arrow


In [14]:
len(val_dataset)

2

In [15]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [17]:
tokenizer('Let’s walk through this code to understand what’s happening.')

{'input_ids': [1593, 123, 116, 1102, 224, 136, 929, 112, 630, 180, 123, 116, 3114, 107, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [18]:
max_input_length = 4096
max_target_length = 512
batch_size = 3

In [19]:
def tokenize(examples):
    model_inputs = tokenizer(
        examples["article"], max_length=max_input_length, truncation=True, return_tensors='pt'
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            examples["highlights"], max_length=max_target_length, truncation=True, return_tensors='pt'
        )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [20]:
train_dataset_tokenized = train_dataset.map(tokenize, batched=True)

  0%|          | 0/1 [00:00<?, ?ba/s]

In [21]:
val_dataset_tokenized = val_dataset.map(tokenize, batched=True)

  0%|          | 0/1 [00:00<?, ?ba/s]

In [22]:
val_dataset_tokenized

Dataset({
    features: ['article', 'highlights', 'length', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 2
})

In [23]:
train_dataset_tokenized.remove_columns(train_dataset.column_names)
val_dataset_tokenized.remove_columns(val_dataset.column_names)

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 2
})

In [24]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model_name)

In [25]:
model = BigBirdPegasusForConditionalGeneration.from_pretrained(model_name)

In [26]:
rouge = load_metric("rouge")

Downloading builder script:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

In [29]:
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    
    rouge_output = rouge.compute(
        predictions=pred_str, references=label_str, rouge_types=["rouge2"]
    )["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

In [30]:
num_train_epochs = 8
# Show the training loss with every epoch
logging_steps = len(train_dataset_tokenized) // batch_size

In [32]:
args = Seq2SeqTrainingArguments(
    output_dir=f"./model/{model_name}-finetuned-on-cnn_news",
    evaluation_strategy="epoch",
    learning_rate=5.6e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    logging_steps=logging_steps,
)

In [33]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=train_dataset_tokenized,
    eval_dataset=val_dataset_tokenized,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [None]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BigBirdPegasusForConditionalGeneration.forward` and have been ignored: length, highlights, article. If length, highlights, article are not expected by `BigBirdPegasusForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 23
  Num Epochs = 8
  Instantaneous batch size per device = 3
  Total train batch size (w. parallel, distributed & accumulation) = 3
  Gradient Accumulation steps = 1
  Total optimization steps = 64
Input ids are automatically padded from 2044 to 2048 to be a multiple of `config.block_size`: 64


In [None]:
trainer.evaluate()