In [1]:
# Report the GPU attached to this runtime (model, driver/CUDA version, memory in use).
!nvidia-smi

Mon Feb 21 14:41:02 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   47C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [1]:
!pip install -q transformers datasets torchinfo rouge_score git+https://github.com/google-research/bleurt.git

In [2]:
# Mount Google Drive at /gdrive; the dataset and model paths used below live under it.
from google.colab import drive
drive.mount('/gdrive')

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [14]:
# Seed handed to the train/test split so the same partition is produced on every run.
RANDOM_SEED = 42

# Dataset loading

In [None]:
# Load the already-tokenized dataset from Drive and carve out a held-out split.
import os  # needed for os.path.join below; it was previously never imported

from datasets import load_dataset, load_from_disk

DATA_PATH = "/gdrive/MyDrive/final-project/post-refactor/data/"
TOKENIZED_DATASET_PATH = os.path.join(DATA_PATH, "tokenized_bigbird_dataset")

dataset = load_from_disk(TOKENIZED_DATASET_PATH)

# Hold out 10% of the examples; the seed keeps the partition reproducible.
dataset = dataset.train_test_split(test_size=0.10, seed=RANDOM_SEED)

# train_test_split names the held-out part "test", but the cells below read
# dataset["valid"]. Expose the same split under both names so neither spelling
# raises a KeyError.
dataset["valid"] = dataset["test"]

# Model loading

In [None]:
from transformers import BigBirdPegasusForConditionalGeneration
from torchinfo import summary

# Directory that holds the final fine-tuned checkpoint.
FINETUNE_MODEL_PATH = os.path.join(DATA_PATH, "BigBirdModelFineTune/", "final/")

# Keyword overrides forwarded to from_pretrained alongside the checkpoint path.
_model_overrides = dict(
    block_size=16,
    num_random_blocks=3,
    attention_type="block_sparse",
    use_cache=False,  # author's note: required for fp16
)

model = BigBirdPegasusForConditionalGeneration.from_pretrained(
    FINETUNE_MODEL_PATH,
    **_model_overrides,
)
model.gradient_checkpointing_enable()

# Last expression of the cell: the torchinfo summary of the loaded model.
summary(model, dtypes=["torch.IntTensor"])

# Generation

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/bigbird-pegasus-large-bigpatent")

# Index the row first, then the column: this reads a single example instead of
# materializing the whole "input_ids" column just to take its first element.
valid_sample = dataset["valid"][0]["input_ids"]

# Show the decoded source text of that example (what the author calls the "summary").
tokenizer.decode(valid_sample, skip_special_tokens=True)

In [None]:
# Show the decoded target text of the same example (what the author calls the "claim").
# Row-first indexing reads one example rather than the whole "decoder_input_ids" column.
tokenizer.decode(dataset["valid"][0]["decoder_input_ids"], skip_special_tokens=True)

In [None]:
import torch

# Turn the sample's token ids into a tensor.
valid_sample = torch.tensor(valid_sample)

# Decode back to text and tokenize again so the tokenizer itself applies
# truncation (at most 2048 tokens) and returns PyTorch tensors.
_sample_text = tokenizer.decode(valid_sample, skip_special_tokens=True)
inputs = tokenizer(
    [_sample_text],
    truncation=True,
    max_length=2048,
    return_tensors="pt",
)

In [None]:
# Beam-search generation for the single sample prepared above.
# Only the best beam is returned; pass num_return_sequences=k (k <= num_beams)
# to get several candidates back.
beam_outputs = model.generate(
    inputs["input_ids"],
    max_length=50,
    num_beams=5,
    # 1.0 means "no penalty"; values above 1.0 discourage tokens that were
    # already generated. The earlier value of 0.5 did the opposite and
    # rewarded repetition.
    repetition_penalty=1.2,
    early_stopping=True
)

# Print every returned sequence (one, unless num_return_sequences is set).
print("Output:\n" + 100 * '-')
for i, beam_output in enumerate(beam_outputs):
  print("{}: {}".format(i, tokenizer.decode(beam_output, skip_special_tokens=True)))

In [None]:
from datasets import load_metric
import numpy as np
import nltk

# Metric objects used by compute_metrics below; each is loaded once and reused.
ROUGE = load_metric('rouge')
SACREBLEU = load_metric('sacrebleu')
# The second argument selects the 'bleurt-large-512' configuration of the metric.
BLEURT = load_metric('bleurt', 'bleurt-large-512')
SARI = load_metric('sari')

def compute_metrics(input, reference, predicted): 
  """Score a batch of generated sequences against their references.

  All three arguments are batches of token ids that are decoded with the
  notebook-level `tokenizer` before scoring.

  Args:
    input: token ids of the source texts (used as SARI sources).
    reference: tensor of label token ids; positions equal to -100 are
      replaced by the pad token id before decoding.
    predicted: token ids of the generated texts.

  Returns:
    A dict with "sacrebleu", one entry per ROUGE variant (mid F-measure,
    scaled by 100), "bleurt" (batch mean, scaled by 100) and "sari".
  """
  d_input = tokenizer.batch_decode(input, skip_special_tokens=True)
  d_pred = tokenizer.batch_decode(predicted, skip_special_tokens=True)
  # Replace -100 in the labels to actual padding
  reference = torch.where(reference != -100, reference, tokenizer.pad_token_id)
  d_label = tokenizer.batch_decode(reference, skip_special_tokens=True)

  # SacreBLEU and SARI take a *list of references per prediction*, so wrap each
  # label in its own list. For a batch of one this equals the previous
  # `[d_label]`; for larger batches it keeps predictions and references aligned.
  refs_per_sample = [[label] for label in d_label]

  # ROUGE works on text, so score the decoded strings rather than the id tensors.
  rouge_scores = ROUGE.compute(references=d_label, predictions=d_pred)
  rouge_scores = { k: v.mid.fmeasure * 100 for k, v in rouge_scores.items() }

  sacrebleu_score = SACREBLEU.compute(predictions=d_pred, references=refs_per_sample)
  sacrebleu_score = sacrebleu_score["score"]

  # BLEURT yields one score per sample; report the batch mean instead of only
  # the first sample (identical for a batch of one).
  bleurt_scores = BLEURT.compute(predictions=d_pred, references=d_label)["scores"]
  bleurt_score = sum(bleurt_scores) / len(bleurt_scores) * 100

  sari_score = SARI.compute(predictions=d_pred, sources=d_input, references=refs_per_sample)
  sari_score = sari_score["sari"]
  
  return {
      "sacrebleu": sacrebleu_score, 
      **rouge_scores, 
      "bleurt": bleurt_score,
      "sari": sari_score
  }