In [None]:
from google.colab import drive #1397

# Mount Google Drive; the fine-tuned model is saved to and loaded from a path under it.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


# Setup and Imports

In [None]:
# crash colab to get more RAM
# !kill -9 -1

In [None]:
! pip install datasets transformers rouge-score nltk sentencepiece



In [None]:
VERSION = "1.8.1"
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

In [None]:
import transformers
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# Record which transformers release this run used; later cells depend on 4.9.x-era APIs.
installed_version = transformers.__version__
print(installed_version)

4.9.2


# Data Preprocessing

In [None]:
from datasets import load_dataset, load_metric

# Multi-News: multi-document news articles paired with reference summaries.
dataset_id = "multi_news"
metric_id = "rouge"

data = load_dataset(dataset_id)
rouge = load_metric(metric_id)

Downloading:   0%|          | 0.00/1.61k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/918 [00:00<?, ?B/s]

Using custom data configuration default


Downloading and preparing dataset multi_news/default (download: 245.06 MiB, generated: 667.74 MiB, post-processed: Unknown size, total: 912.80 MiB) to /root/.cache/huggingface/datasets/multi_news/default/1.0.0/2e145a8e21361ba4ee46fef70640ab946a3e8d425002f104d2cda99a9efca376...


Downloading: 0.00B [00:00, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset multi_news downloaded and prepared to /root/.cache/huggingface/datasets/multi_news/default/1.0.0/2e145a8e21361ba4ee46fef70640ab946a3e8d425002f104d2cda99a9efca376. Subsequent calls will reuse this data.


Downloading:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

In [None]:
# Work on a small subset of every split to keep iteration fast; set to None to use the full data.
N_SAMPLES = 200

if N_SAMPLES is not None:
    for split_name in data:
        split = data[split_name]
        # min() keeps select() within range if a split has fewer than N_SAMPLES rows.
        data[split_name] = split.select(range(min(N_SAMPLES, len(split))))

data

DatasetDict({
    train: Dataset({
        features: ['document', 'summary'],
        num_rows: 200
    })
    validation: Dataset({
        features: ['document', 'summary'],
        num_rows: 200
    })
    test: Dataset({
        features: ['document', 'summary'],
        num_rows: 200
    })
})

In [None]:
# import pandas as pd

# length_distribution = pd.Series([len(i) for i in data['train']['document']])#.value_counts(normalize = True)

# sum(length_distribution <= 2048 * 8) / len(length_distribution)

In [None]:
from transformers import AutoTokenizer

# Base checkpoint to fine-tune.
# Other candidate checkpoints: allenai/led-large-16384-arxiv, microsoft/prophetnet-large-uncased-cnndm
model_name = "t5-small"

# Task prefix prepended to every input document before tokenization.
prefix = "summarize: "

tokenizer = AutoTokenizer.from_pretrained(model_name)

Downloading:   0%|          | 0.00/1.20k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

## Tokenization

In [None]:
max_input_length = 2048
max_target_length = 512

def preprocess_function(examples):
    """Tokenize a batch of documents and reference summaries for seq2seq training.

    Args:
        examples: batch dict (as passed by ``Dataset.map(batched=True)``) with
            list-valued ``"document"`` and ``"summary"`` columns.

    Returns:
        The tokenizer output for the prefixed documents (padded/truncated to
        ``max_input_length``), plus a ``"labels"`` entry holding the summaries'
        token ids padded/truncated to ``max_target_length``, with every padding
        position replaced by -100.
    """
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, padding='max_length')

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True, padding='max_length')

    # Labels are padded to a fixed length. Replace the pad id with -100 (the index
    # PyTorch's cross-entropy ignores by default, and which compute_metrics maps back
    # to the pad id). Otherwise the model is trained to emit padding tokens.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in label_ids]
        for label_ids in labels["input_ids"]
    ]
    return model_inputs

In [None]:
# Raw text columns are no longer needed once they have been tokenized.
raw_text_columns = ["document", "summary"]
processed_data = data.map(preprocess_function, batched=True, remove_columns=raw_text_columns)

# Have rows come back as PyTorch tensors.
processed_data.set_format(type='torch')

processed_data

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'labels'],
        num_rows: 200
    })
    validation: Dataset({
        features: ['attention_mask', 'input_ids', 'labels'],
        num_rows: 200
    })
    test: Dataset({
        features: ['attention_mask', 'input_ids', 'labels'],
        num_rows: 200
    })
})

# Fine Tuning

In [None]:
# Base checkpoint with gradient checkpointing enabled (trades compute for memory on
# the long 2048-token inputs) and the decoder cache disabled.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    gradient_checkpointing=True,
    use_cache=False,
)
model.train()

# Wrapped once here; each process started by xmp.spawn moves it to its own XLA device.
WRAPPED_MODEL = xmp.MpModelWrapper(model)

Downloading:   0%|          | 0.00/242M [00:00<?, ?B/s]

In [None]:
import nltk
import numpy as np

def compute_metrics(pred):
    """Decode generated and reference token ids and score them with ROUGE-2.

    Args:
        pred: evaluation output with ``predictions`` (generated token ids) and
            ``label_ids`` (reference token ids, where -100 marks ignored positions).

    Returns:
        dict with ROUGE-2 precision, recall and F-measure (``mid`` aggregate),
        each rounded to 4 decimals.
    """
    pad_id = tokenizer.pad_token_id
    # The tokenizer cannot decode -100, so map it back to the pad id in both arrays.
    # Predictions can contain -100 too; NOTE(review): the Trainer presumably pads
    # concatenated eval batches of different lengths with it.
    # np.where returns new arrays, so the caller's arrays are not mutated.
    pred_ids = np.asarray(pred.predictions)
    pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)
    labels_ids = np.asarray(pred.label_ids)
    labels_ids = np.where(labels_ids == -100, pad_id, labels_ids)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(
        predictions=pred_str, references=label_str, rouge_types=["rouge2"]
    )["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

# Training on TPU

In [None]:
def train_loop(model, batch_size=2, save_dir='/content/drive/MyDrive/Summarization/T5'):
    """
    Fine-tune ``model`` on ``processed_data["train"]`` inside one TPU process, then save it.

    Args:
        model: the seq2seq model, already moved to this process's XLA device by the caller.
        batch_size: per-core batch size for both training and evaluation.
        save_dir: directory the final weights are written to.
    """
    print("Training... ", end="")

    training_args = Seq2SeqTrainingArguments(
        "Summarization",
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=1,
        tpu_num_cores=8,
        # Per-epoch evaluation reports loss only (no generation during training).
        prediction_loss_only=True
    )

    trainer = Seq2SeqTrainer(
        model,
        training_args,
        train_dataset=processed_data["train"],
        eval_dataset=processed_data["validation"],
        # NOTE: not invoked while prediction_loss_only=True above.
        compute_metrics=compute_metrics
    )

    # The caller has already placed the model on the XLA device.
    trainer.place_model_on_device = False
    trainer.train()

    # This function runs in every process started by xmp.spawn. Calling
    # model.save_pretrained here would make every process write the same Drive files
    # with XLA-resident tensors. On TPU, Trainer.save_model goes through the Trainer's
    # TPU save path instead (master-only write via xm.save).
    trainer.save_model(save_dir)

In [None]:
def _mp_fn(index):
    """Per-process entry point handed to ``xmp.spawn``.

    ``index`` is the argument ``xmp.spawn`` passes in; it is not used here.
    Moves the shared wrapped model onto this process's XLA device and trains it.
    """
    train_loop(WRAPPED_MODEL.to(xm.xla_device()))


xmp.spawn(_mp_fn, start_method="fork")

***** Running training *****
  Num examples = 44972
  Num Epochs = 1
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 2811


Training... 

Epoch,Training Loss,Validation Loss


Training... 

Epoch,Training Loss,Validation Loss


Training... 

Epoch,Training Loss,Validation Loss


Training... 

Epoch,Training Loss,Validation Loss


Training... 

Epoch,Training Loss,Validation Loss


Training... 

Epoch,Training Loss,Validation Loss


Training... 

Epoch,Training Loss,Validation Loss


Training... 

Epoch,Training Loss,Validation Loss


Saving model checkpoint to Summarization/checkpoint-500
Configuration saved in Summarization/checkpoint-500/config.json
Model weights saved in Summarization/checkpoint-500/pytorch_model.bin
Saving model checkpoint to Summarization/checkpoint-1000
Configuration saved in Summarization/checkpoint-1000/config.json
Model weights saved in Summarization/checkpoint-1000/pytorch_model.bin


KeyboardInterrupt: ignored

# Evaluation

In [None]:
# Reload the fine-tuned weights from Drive for evaluation.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "/content/drive/MyDrive/Summarization/T5"
)  #20528

def compute_metrics(pred):
    """Decode generated and reference token ids and score them with ROUGE-2.

    NOTE(review): this duplicates the compute_metrics defined in the Fine Tuning
    section; keep a single definition.

    Args:
        pred: evaluation output with ``predictions`` (generated token ids) and
            ``label_ids`` (reference token ids, where -100 marks ignored positions).

    Returns:
        dict with ROUGE-2 precision, recall and F-measure (``mid`` aggregate),
        each rounded to 4 decimals.
    """
    pad_id = tokenizer.pad_token_id
    # The tokenizer cannot decode -100, so map it back to the pad id in both arrays.
    # Predictions can contain -100 too; NOTE(review): the Trainer presumably pads
    # concatenated eval batches of different lengths with it.
    # np.where returns new arrays, so the caller's arrays are not mutated.
    pred_ids = np.asarray(pred.predictions)
    pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)
    labels_ids = np.asarray(pred.label_ids)
    labels_ids = np.where(labels_ids == -100, pad_id, labels_ids)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(
        predictions=pred_str, references=label_str, rouge_types=["rouge2"]
    )["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

In [None]:
import gc

# Keep only the tokenized validation split and release the rest to save RAM.
# NOTE: `processed_data` is deleted here, so re-running this cell requires
# re-running the tokenization cells first.
eval_dataset = processed_data["validation"]
del processed_data
gc.collect()

batch_size = 1

args = Seq2SeqTrainingArguments(
    "Summarization",
    per_device_eval_batch_size=batch_size,
    # compute_metrics batch-decodes token ids, so evaluation must generate sequences;
    # without this, `predictions` are model logits and cannot be decoded.
    predict_with_generate=True,
)

trainer = Seq2SeqTrainer(
    model,
    args,
    eval_dataset=eval_dataset,
    # Gives the trainer a pad token id for padding generated sequences.
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

# Generate summaries up to the same length used for the training labels.
trainer.evaluate(max_length=max_target_length)

***** Running Evaluation *****
  Num examples = 200
  Batch size = 1


# Evaluation (Generation)

In [None]:
from datasets import load_dataset, load_metric

# First 1% of the validation split, narrowed to 20 examples for a quick generation check.
validation_slice = load_dataset("multi_news", split='validation[:1%]')
eval_dataset = validation_slice.select(range(20))

rouge = load_metric("rouge")

eval_dataset

Using custom data configuration default
Reusing dataset multi_news (/root/.cache/huggingface/datasets/multi_news/default/1.0.0/2e145a8e21361ba4ee46fef70640ab946a3e8d425002f104d2cda99a9efca376)


Dataset({
    features: ['document', 'summary'],
    num_rows: 20
})

In [None]:
from transformers import AutoTokenizer

# Tokenizer of the base checkpoint; only model weights were written to Drive.
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# Fine-tuned weights from Drive.
model = AutoModelForSeq2SeqLM.from_pretrained("/content/drive/MyDrive/Summarization/T5")

In [None]:
max_input_length = 2048
max_target_length = 512
prefix = "summarize: "

def generate_answer(batch):
    """Summarize one example's ``document`` and store the text in ``batch["predicted"]``.

    Meant for a non-batched ``Dataset.map``. The decoded summary is also printed.
    Returns the same ``batch`` dict, updated in place.
    """
    encoded = tokenizer(
        [prefix + batch["document"]],
        padding="max_length",
        max_length=max_input_length,
        return_tensors="pt",
        truncation=True,
    )
    generated = model.generate(
        **encoded,
        max_length=max_target_length,
        num_beams=3,
        length_penalty=0.8,
        no_repeat_ngram_size=2,
    )
    summary = tokenizer.decode(generated[0], skip_special_tokens=True)
    batch["predicted"] = summary
    print(summary)
    return batch

# Generate a summary for every example; adds a "predicted" column.
result = eval_dataset.map(function=generate_answer)

  0%|          | 0/20 [00:00<?, ?ex/s]

the most donated author to Oxfam's 700 high street shops has sold more than 80 million copies. 'the Lost Symbol' is a cult crime writer responsible to heavy-weight hardbacks, sonny mehta says - and is the second most sold author of the year he has been given away to charity shops. the charity says it is raising money for its first national book festival, Bookfest, in July.
a lack of communication by the agency about the delays has left service members facing severe financial hardships. this week, the VA announced that the information technology fixes will not be completed until the end of 2019 — more than one year past the original Aug. 1 deadline for finishing this work.
a video purportedly from AQAP said it planned and financed the attack. the satirical magazine's editor-in-chief drew threats from militant websites and criticism from the al qaeda savage islamists in the region, he said in an interview with three french fighters praising the attacks on the french newspaper, the media 

In [None]:
# ROUGE-2 of the generated summaries against the complete reference summaries
# (one reference string per example, each passed whole).
rouge_output = rouge.compute(
    predictions=result["predicted"], references=result["summary"], rouge_types=["rouge2"]
)["rouge2"].mid

print({
    "rouge2_precision": round(rouge_output.precision, 4),
    "rouge2_recall": round(rouge_output.recall, 4),
    "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
})

{'rouge2_precision': 0.1658, 'rouge2_recall': 0.0444, 'rouge2_fmeasure': 0.0683}


# Text Generation

In [None]:
# Fine-tuned checkpoint used by the generation demos in this section.
checkpoint_dir = "/content/drive/MyDrive/Summarization/T5"
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_dir)

In [None]:
# First test-set document, used as the demo input below. Index the row first, then
# the column, so only one example is materialized instead of the whole column.
ARTICLE = data['test'][0]['document']

ARTICLE

'Whether a sign of a good read; or a comment on the \'pulp\' nature of some genres of fiction, the Oxfam second-hand book charts have remained in The Da Vinci Code author\'s favour for the past four years. \n  \n Dan Brown has topped Oxfam\'s \'most donated\' list again, his fourth consecutive year. Having sold more than 80 million copies of The Da Vinci Code and had all four of his novels on the New York Times bestseller list in the same week, it\'s hardly surprising that Brown\'s hefty tomes are being donated to charity by readers keen to make some room on their shelves. \n  \n Another cult crime writer responsible to heavy-weight hardbacks, Stieg Larsson, is Oxfam\'s \'most sold\' author for the second time in a row. Both the \'most donated\' and \'most sold\' lists are dominated by crime fiction, trilogies and fantasy, with JK Rowling the only female author listed in either of the Top Fives. \n  \n Click here or on "View Gallery" to see both charts in pictures ||||| A woman reads a

In [None]:
# Encode with the same input length (max_input_length) used for fine-tuning and for
# the generation evaluation, so the model sees as much of the article as it was trained on.
inputs = tokenizer.encode(prefix + ARTICLE, return_tensors="pt", max_length=max_input_length, truncation=True)

outputs = model.generate(
    inputs,
    max_length=150,
    min_length=40,
    num_beams=5,
    no_repeat_ngram_size=2,
    early_stopping=True
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)


Output:
----------------------------------------------------------------------------------------------------
the second-hand book charts have remained in Oxfam's favour for the past four years. the cult crime writer has sold more than 80 million copies of The da Vinci Code a year - his fourth consecutive year in the series'most sold' the list is dominated by crime fiction, trilogies and fantasy, with JK Rowling the only female author listed in either of the Top Fives. "we are seeing historic, record-breaking sales across all types of our accounts in North America for 


## Beam Search

In [None]:
# Beam search returning the top three beams, with repetition discouraged.
beam_kwargs = dict(
    max_length=150,
    min_length=40,
    num_beams=5,
    no_repeat_ngram_size=2,
    repetition_penalty=2.0,
    num_return_sequences=3,
    early_stopping=True,
)
outputs = model.generate(inputs, **beam_kwargs)

print("Output:\n" + '-' * 100)
for rank, sequence in enumerate(outputs):
    print(f"{rank}: {tokenizer.decode(sequence, skip_special_tokens=True)}")

## Sampling

In [None]:
# Top-k / nucleus sampling is random: fix the seed (Python, NumPy and PyTorch RNGs)
# so the printed samples are reproducible across runs.
SAMPLING_SEED = 42
transformers.set_seed(SAMPLING_SEED)

outputs = model.generate(
    inputs,
    do_sample=True,
    max_length=150,
    top_k=50,
    top_p=0.95,
    num_return_sequences=3
)

print("Output:\n" + 100 * '-')
for i, sample_output in enumerate(outputs):
    print("{}: {}".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))

# References

*   https://colab.research.google.com/drive/1dVEfoxGvMAKd0GLnrUJSHZycGtyKt9mr#scrollTo=y1lbgZUyBc8l
*   https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/summarization.ipynb#scrollTo=545PP3o8IrJV



