In [1]:
import pandas
import re, json
import csv

import torch
import torch.nn as nn
from datasets import load_metric,Dataset,DatasetDict, load_dataset, Sequence, Value
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, BartForConditionalGeneration
from transformers import AutoTokenizer, Trainer

import evaluate

import numpy as np
import nltk
import os
import random
from sklearn.model_selection import train_test_split
from typing import List, Optional, Tuple, Union, Dict, Any

In [2]:
seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
_numpy_rng = np.random.default_rng(seed)
random.seed(seed)
np.random.seed(seed)
torch.use_deterministic_algorithms(False)
os.environ['PYTHONHASHSEED'] = str(seed)
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
max_input_length = 256
max_target_length = 128

In [5]:
model_checkpoint = "facebook/bart-large"
metric = evaluate.load("rouge")
model = BartForConditionalGeneration.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [6]:
dataset = load_dataset('samsum')

In [7]:
def tokenize_and_align_labels(examples):
    inputs = [doc for doc in examples['dialogue']]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, return_tensors='pt', padding=True)

    with tokenizer.as_target_tokenizer():
        tokenized_inputs = tokenizer(examples["summary"], truncation=True, return_tensors='pt', padding=True)
        
    model_inputs['labels'] = tokenized_inputs['input_ids']
    return model_inputs

In [8]:
tokenized_datasets = dataset.map(tokenize_and_align_labels, batched=True)

In [9]:
tokenized_datasets['train'] = tokenized_datasets['train'].remove_columns(['id','dialogue', 'summary'])
tokenized_datasets['validation'] = tokenized_datasets['validation'].remove_columns(['id','dialogue', 'summary'])
tokenized_datasets['test'] = tokenized_datasets['test'].remove_columns(['id','dialogue', 'summary'])

In [10]:
training_args = Seq2SeqTrainingArguments(
    output_dir="checkpoints/",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    weight_decay=0.01,
    save_total_limit=4,
    num_train_epochs=10,
    predict_with_generate=True,
    do_train=True,
    do_eval=True,
    fp16=True,
    logging_steps=1,
    save_strategy="epoch",
    greater_is_better=True,
    load_best_model_at_end=True,
    seed=42,
    generation_max_length=max_target_length,
)

In [11]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [12]:
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    print(f"Generated summary: {decoded_preds[0]}")

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    print(f"Gold summary: {decoded_labels[0]}")

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    # print(result)
    # result = {"rouge": result["score"]}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

In [13]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [14]:
trainer.train()

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,0.7456,0.56735,0.4979,0.2533,0.4079,0.408,26.0782
2,0.4781,0.501947,0.5256,0.2841,0.4364,0.4363,28.368
3,0.4099,0.497809,0.5227,0.2807,0.4333,0.4333,28.5024
4,0.5148,0.491243,0.5231,0.2847,0.4363,0.4363,27.3472
5,0.2558,0.496187,0.5317,0.2918,0.4395,0.4395,28.1161
6,0.1866,0.495583,0.5305,0.2913,0.4411,0.4406,26.5966
7,0.4257,0.499362,0.5323,0.2914,0.4408,0.4405,27.6345
8,0.2905,0.500727,0.5295,0.2871,0.436,0.4357,29.3521
9,0.3899,0.503964,0.5272,0.2874,0.4345,0.4342,29.0978
10,0.3661,0.504823,0.5315,0.2919,0.4395,0.4394,28.8655


Generated summary: A wants to get a puppy for her son. She took him to the animal shelter last Monday. He showed her one that he really liked. He wanted to take it home right away.
Gold summary: A will go to the animal shelter tomorrow to get a puppy for her son. They already visited the shelter last Monday and the son chose the puppy. 
Generated summary: A wants to get a puppy for her son. A took him to the animal shelter last Monday and he really liked it. A will get him one of those little dogs.
Gold summary: A will go to the animal shelter tomorrow to get a puppy for her son. They already visited the shelter last Monday and the son chose the puppy. 
Generated summary: A wants to get a puppy for her son. B will go with her to the animal shelter tomorrow afternoon. A took her son to the shelter last Monday and he liked the puppy. A will get him one of those little dogs. 
Gold summary: A will go to the animal shelter tomorrow to get a puppy for her son. They already visited the shelte

TrainOutput(global_step=2310, training_loss=0.7229478450951639, metrics={'train_runtime': 1943.9178, 'train_samples_per_second': 75.785, 'train_steps_per_second': 1.188, 'total_flos': 7.981446249578496e+16, 'train_loss': 0.7229478450951639, 'epoch': 10.0})

In [15]:
trainer.evaluate(tokenized_datasets['test'])

Generated summary: Hannah doesn't have Betty's number. She doesn't know him well, but he called her last time they were at the park together. She will text him.
Gold summary: Hannah needs Betty's number but Amanda doesn't have it. She needs to contact Larry.


{'eval_loss': 0.5453633666038513,
 'eval_rouge1': 0.4821,
 'eval_rouge2': 0.239,
 'eval_rougeL': 0.3989,
 'eval_rougeLsum': 0.3983,
 'eval_gen_len': 25.6606,
 'eval_runtime': 69.0753,
 'eval_samples_per_second': 11.857,
 'eval_steps_per_second': 0.188,
 'epoch': 10.0}

In [1]:
!nvidia-smi

Wed Nov 15 16:16:49 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          Off | 00000000:A1:00.0 Off |                    0 |
| N/A   33C    P0              63W / 300W |  20016MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    