In [1]:
import pandas
import re, json
import csv

import torch
import torch.nn as nn
from datasets import load_metric,Dataset,DatasetDict, load_dataset
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, BartForConditionalGeneration
from transformers import AutoTokenizer, Trainer

import evaluate

import numpy as np
import nltk
import os
import random
from sklearn.model_selection import train_test_split
from typing import List, Optional, Tuple, Union, Dict, Any
from jointbart import myBartForConditionalGeneration

In [2]:
seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
_numpy_rng = np.random.default_rng(seed)
random.seed(seed)
np.random.seed(seed)
torch.use_deterministic_algorithms(False)
os.environ['PYTHONHASHSEED'] = str(seed)
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
def convert_to_iob(d):
    for i in range(len(d)):
        for j in range(len(d[i])):
            if d[i][j] != 'O':
                d[i][j] = 'B-' + d[i][j]
    
    return d

In [5]:
model_checkpoint = "facebook/bart-large"
metric = evaluate.load("rouge")

In [6]:
max_input_length = 256
max_target_length = 128

In [7]:
model = myBartForConditionalGeneration.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)

Some weights of myBartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
dataset = load_dataset('pvisnrt/capstone_hal')

In [9]:
dataset

DatasetDict({
    train: Dataset({
        features: ['source', 'summary_target', 'tags'],
        num_rows: 80
    })
    validation: Dataset({
        features: ['source', 'summary_target', 'tags'],
        num_rows: 10
    })
    test: Dataset({
        features: ['source', 'summary_target', 'tags'],
        num_rows: 10
    })
})

In [10]:
def tokenize_and_align_labels(examples):
    inputs = [doc for doc in examples['source']]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, is_split_into_words=True, return_tensors='pt', padding=True)

    with tokenizer.as_target_tokenizer():
        tokenized_inputs = tokenizer(examples["summary_target"], truncation=True, is_split_into_words=True, return_tensors='pt', padding=True)

    labels = []
    for i, label in enumerate(examples["tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)# Map tokens to their respective word.
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:  # Set the special tokens to -100.
            
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:  # Only label the first token of a given word.
                label_ids.append(label[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)

    model_inputs['labels'] = tokenized_inputs['input_ids']

    model_inputs["decoder_tags"] = labels
    
    return model_inputs

In [11]:
tokenized_datasets = dataset.map(tokenize_and_align_labels, batched=True)

In [12]:
tokenized_datasets['train'] = tokenized_datasets['train'].remove_columns(['source','summary_target', 'tags'])
tokenized_datasets['validation'] = tokenized_datasets['validation'].remove_columns(['source','summary_target', 'tags'])
tokenized_datasets['test'] = tokenized_datasets['test'].remove_columns(['source','summary_target', 'tags'])

tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'decoder_tags'],
        num_rows: 80
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'decoder_tags'],
        num_rows: 10
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'decoder_tags'],
        num_rows: 10
    })
})

In [13]:
len(tokenized_datasets['validation']['decoder_tags'][0])

44

In [14]:
class MySeq2SeqTrainer(Seq2SeqTrainer):
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.
        Subclass and override to inject custom behavior.
        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.
                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).
        """

        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )

        has_labels = "decoder_tags" in inputs
        inputs = self._prepare_inputs(inputs)
        
        # print("prediction_step inputs: {}".format(inputs.keys()))

        # XXX: adapt synced_gpus for fairscale as well
        gen_kwargs = self._gen_kwargs.copy()
        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
            gen_kwargs["max_length"] = self.model.config.max_length
        gen_kwargs["num_beams"] = (
            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
        )
        # default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
        default_synced_gpus = False
        gen_kwargs["synced_gpus"] = (
            gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
        )

        if "attention_mask" in inputs:
            gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
        if "global_attention_mask" in inputs:
            gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)

        # prepare generation inputs
        # some encoder-decoder models can have varying encoder's and thus
        # varying model input names
        if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
            generation_inputs = inputs[self.model.encoder.main_input_name]
        else:
            generation_inputs = inputs[self.model.main_input_name]

        tags = inputs["decoder_tags"]
        gen_kwargs.update({"decoder_tags": tags})
        # print(f"Gen kwargs: {gen_kwargs}")
        # print(f"Gen inputs:{generation_inputs}")
        generated_tokens = self.model.generate(
            generation_inputs,
            **gen_kwargs,
        )
        # in case the batch is shorter than max length, the output should be padded
        if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
        elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
            gen_kwargs["max_new_tokens"] + 1
        ):
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)

        with torch.no_grad():
            if has_labels:
                with self.compute_loss_context_manager():
                    outputs = model(**inputs) # linear_logits as output
                if self.label_smoother is not None:
                    loss = self.label_smoother(outputs, inputs["decoder_tags"]).mean().detach()
                else:
                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
            else:
                loss = None

        if self.args.prediction_loss_only:
            return (loss, None, None)

        if has_labels:
            labels = inputs["decoder_tags"]
            if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
                labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
            elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
                gen_kwargs["max_new_tokens"] + 1
            ):
                labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
        else:
            labels = None
        # print(labels)

        return (loss, generated_tokens, labels)

In [15]:
training_args = Seq2SeqTrainingArguments(
    output_dir="checkpoints/",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=1,
    weight_decay=0.01,
    save_total_limit=4,
    num_train_epochs=15,
    predict_with_generate=True,
    do_train=True,
    do_eval=True,
    fp16=True,
    logging_steps=1,
    save_strategy="epoch",
    metric_for_best_model="eval_rouge1",
    greater_is_better=True,
    load_best_model_at_end=True,
    seed=42,
    generation_max_length=max_target_length,
)

In [16]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [17]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    # print("In compute metrics")
    # print(predictions[0])
    # print(labels[0])
    
    preds = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    print(result.items())
    #result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

In [18]:
trainer = MySeq2SeqTrainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [19]:
trainer.train()

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,9.5795,6.072535,0.2667,0.0,0.2667,0.2667,9.0
2,6.8896,4.142837,0.0331,0.0095,0.0331,0.0331,78.3
3,4.6912,3.5249,0.0533,0.0154,0.0535,0.0535,32.4
4,4.8622,3.159065,0.0571,0.0167,0.0571,0.0571,28.5
5,4.3967,2.873141,0.0606,0.0167,0.0606,0.0606,25.8
6,3.5974,2.680887,0.0708,0.0182,0.0708,0.0708,25.9
7,3.0621,2.443796,0.0717,0.0222,0.0717,0.0717,23.0
8,2.8287,2.193723,0.0628,0.0182,0.0628,0.0628,22.5
9,2.5288,2.03026,0.0619,0.0167,0.0619,0.0619,22.5
10,2.2865,1.881976,0.0674,0.0182,0.0674,0.0674,23.0


dict_items([('rouge1', 0.26666666666666666), ('rouge2', 0.0), ('rougeL', 0.26666666666666666), ('rougeLsum', 0.26666666666666666)])
dict_items([('rouge1', 0.03308361204013377), ('rouge2', 0.009523809523809523), ('rougeL', 0.03308361204013377), ('rougeLsum', 0.03308361204013378)])
dict_items([('rouge1', 0.053333333333333344), ('rouge2', 0.015384615384615385), ('rougeL', 0.05345238095238096), ('rougeLsum', 0.05345238095238096)])
dict_items([('rouge1', 0.05714285714285714), ('rouge2', 0.01666666666666667), ('rougeL', 0.05714285714285715), ('rougeLsum', 0.05714285714285715)])
dict_items([('rouge1', 0.06062271062271063), ('rouge2', 0.01666666666666667), ('rougeL', 0.06062271062271063), ('rougeLsum', 0.06062271062271063)])
dict_items([('rouge1', 0.07076923076923076), ('rouge2', 0.01818181818181818), ('rougeL', 0.07076923076923076), ('rougeLsum', 0.07076923076923076)])
dict_items([('rouge1', 0.07174825174825175), ('rouge2', 0.02222222222222222), ('rougeL', 0.07174825174825175), ('rougeLsum', 

TrainOutput(global_step=300, training_loss=3.891774559020996, metrics={'train_runtime': 372.0294, 'train_samples_per_second': 3.226, 'train_steps_per_second': 0.806, 'total_flos': 650142716313600.0, 'train_loss': 3.891774559020996, 'epoch': 15.0})

In [20]:
!nvidia-smi

Mon Oct 23 15:21:00 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          Off | 00000000:A1:00.0 Off |                    0 |
| N/A   57C    P0             285W / 300W |  76340MiB / 81920MiB |     99%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    