# Script to fine-tune a BART model

This script fine-tunes a BART model on the IN-Abs fine-tuning data.

## Basic Setup

In [None]:
# Install dependencies
!pip install pytorch-lightning
!pip install rouge-score nltk



In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import sys
sys.path.insert(0, '../')
import transformers
import pandas as pd
import numpy as np
import glob
import math
import random
import re
import argparse
import nltk
from transformers import Trainer, TrainingArguments

## Documentation: Text Summarization with PyTorch Lightning

## **Overview**
This code implements a framework for fine-tuning transformer models (e.g., BART) for text summarization tasks using PyTorch Lightning. Below are the key components and their functionality:

---

## **Imports**
- **`nltk`, `transformers`**: For text preprocessing and working with pre-trained transformer models.
- **PyTorch (`torch`, `torch.nn.functional`, `torch.utils.data`)**: Provides tools for neural network building, training, and dataset handling.
- **`pytorch_lightning`**: A wrapper for PyTorch to streamline training workflows.
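- **`ModelCheckpoint`**: Lightning callback for saving checkpoints (for example, the best model by a monitored metric). It is imported but not configured here; the final model is saved manually with `trainer.save_checkpoint`.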

---

## **LitModel (LightningModule)**
This is the core model training and inference logic encapsulated in a PyTorch Lightning module.

1. **Initialization (`__init__`)**:
   - Takes inputs like the transformer model, tokenizer, and learning rate.
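   - Freezes the encoder and the embedding layers by default (the flags are hard-coded in `self.hparams`), leaving only the remaining decoder parameters trainable to reduce memory use and training time.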

2. **Methods**:
   - `forward`: Defines how data passes through the model.
   - `configure_optimizers`: Configures the optimizer (Adam).
   - `training_step`: Implements the training logic, computing loss using cross-entropy.
   - `validation_step`: Similar to `training_step`, used for validation loss computation.
   - `generate_text`: Generates summaries with beam search (`num_beams=eval_beams`) through the model's `generate` method and decodes the token ids back to text.
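As a small illustration of the loss used in `training_step` and `validation_step`, the toy sketch below checks that `CrossEntropyLoss(ignore_index=pad_token_id)` averages only over non-padding positions. The random logits, the 5-token vocabulary, and the target ids are assumptions chosen for the example; pad id 1 follows BART's convention.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pad_token_id = 1   # BART's <pad> id; the tiny vocabulary is a toy assumption
vocab_size = 5

# One target sequence of length 4 whose last position is padding
tgt_ids = torch.tensor([[0, 3, 2, pad_token_id]])
lm_logits = torch.randn(1, 4, vocab_size)

# Same call pattern as training_step: flatten (batch, seq, vocab) -> (batch*seq, vocab)
ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))

# Reference: mean cross-entropy over the three non-pad positions only
reference = F.cross_entropy(lm_logits[0, :3], tgt_ids[0, :3])
print(torch.allclose(loss, reference))  # True
```

Because padded positions are ignored, padding every target to 512 tokens does not dilute the loss.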

---

## **SummaryDataModule (LightningDataModule)**
Handles data preparation, splitting, and loading for training.

1. **`prepare_data`**:
   - Splits the data into training, validation, and test sets (60/20/20).

2. **`setup`**:
   - Encodes source and target sentences using the tokenizer.

3. **Data Loaders**:
   - `train_dataloader`: Prepares batched training data.
   - `val_dataloader`: Prepares batched validation data.
   - `test_dataloader`: Prepares batched test data.
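The 60/20/20 split in `prepare_data` can be sketched on a toy DataFrame. This sketch slices the shuffled frame with `iloc`, an alternative to `np.split`, which goes through the deprecated `DataFrame.swapaxes` on recent pandas. The ten-row frame and `random_state=0` are assumptions added for reproducibility.

```python
import pandas as pd

# Toy stand-in for data_pd: ten source/target pairs
df = pd.DataFrame({"source": [f"judgement {i}" for i in range(10)],
                   "target": [f"summary {i}" for i in range(10)]})

# Shuffle, then cut at 60% and 80% of the rows
shuffled = df.sample(frac=1, random_state=0)
train_end, val_end = int(.6 * len(shuffled)), int(.8 * len(shuffled))
train = shuffled.iloc[:train_end]
validate = shuffled.iloc[train_end:val_end]
test = shuffled.iloc[val_end:]

print(len(train), len(validate), len(test))  # 6 2 2
```

Each slice is still a DataFrame, so `setup` can keep indexing it with `['source']` and `['target']`.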

---

## **Utility Functions**
1. **`freeze_params`**:
   - Sets `requires_grad = False` on every parameter of the given module (e.g. the encoder or the embeddings), so the optimizer leaves those weights unchanged.

2. **`shift_tokens_right`**:
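   - Builds the decoder inputs for teacher forcing: each target sequence is shifted one position to the right, and the last non-padding token (usually `</s>`) is moved to position 0 as the decoder start token.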

3. **`encode_sentences`**:
   - Tokenizes source texts (padded/truncated to 1024 tokens) and target summaries (padded/truncated to 512 tokens) into `input_ids`, `attention_mask`, and `labels` tensors.
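To see what `shift_tokens_right` produces, here is a toy run of the same logic as the notebook's function. The ids are assumptions in BART's convention: 0 = `<s>`, 1 = `<pad>`, 2 = `</s>`, and 10 to 12 stand in for word tokens.

```python
import torch

def shift_tokens_right(input_ids, pad_token_id):
  # Same logic as the notebook: move the last non-pad token (</s>) to position 0
  # and shift everything else one step to the right
  prev_output_tokens = input_ids.clone()
  index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
  prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
  prev_output_tokens[:, 1:] = input_ids[:, :-1]
  return prev_output_tokens

# Two padded label sequences: <s> w w </s> <pad> <pad> and <s> w </s> <pad> <pad> <pad>
labels = torch.tensor([[0, 10, 11, 2, 1, 1],
                       [0, 12, 2, 1, 1, 1]])
decoder_input_ids = shift_tokens_right(labels, pad_token_id=1)
print(decoder_input_ids.tolist())  # [[2, 0, 10, 11, 2, 1], [2, 0, 12, 2, 1, 1]]
```

At every step t the decoder sees the label from step t-1, starting from `</s>`, while the loss is computed against the unshifted `labels`.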

---

## **Key Features**
- **Freezing Layers**: Freezes encoder and embeddings for efficiency.
- **Efficient Data Handling**: Uses PyTorch DataLoader for batched data processing.
- **Text Generation**: Implements summarization using beam search.
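Freezing works by switching off `requires_grad` on a module's parameters. The toy sketch below uses a small `nn.Sequential` as a stand-in for BART's embedding and encoder modules (an assumption for illustration). It also shows that a misspelled attribute such as `requires_grade` is silently accepted and freezes nothing.

```python
import torch
import torch.nn as nn

def freeze_params(module):
  ''' Disable gradients for every parameter of `module` '''
  for param in module.parameters():
    param.requires_grad = False

def count_params(module, trainable):
  ''' Count parameters whose requires_grad flag equals `trainable` '''
  return sum(p.numel() for p in module.parameters() if p.requires_grad == trainable)

# Toy model: a 10x4 embedding (40 params) followed by a 4->4 linear layer (20 params)
model = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 4))
freeze_params(model[0])
print(count_params(model, True), count_params(model, False))  # 20 40

# A typo creates a new, unused attribute instead of raising an error
p = nn.Parameter(torch.zeros(2))
p.requires_grade = False
print(p.requires_grad)  # True
```

Lightning's model summary ("Trainable params" vs "Non-trainable params") is a quick way to confirm that freezing actually took effect.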

---

## **How to Use**
1. Load your dataset into a pandas DataFrame.
2. Initialize the tokenizer and model (e.g., from Hugging Face's Transformers library).
3. Create instances of `LitModel` and `SummaryDataModule`.
4. Use PyTorch Lightning's `Trainer` to train the model.


In [None]:
# Import
import glob
from nltk import tokenize
import nltk
import transformers
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler, Dataset
import pandas as pd
import numpy as np
import torch.nn.functional as F
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint

#Source - https://colab.research.google.com/drive/1Cy27V-7qqYatqMA7fEqG2kgMySZXw9I4?usp=sharing&pli=1
class LitModel(pl.LightningModule):
  # Instantiate the model
  def __init__(self, learning_rate, tokenizer, model):
    super().__init__()
    self.tokenizer = tokenizer
    self.model = model
    self.learning_rate = learning_rate
    # Training options are hard-coded rather than passed to __init__:
    # freeze the encoder and embeddings, and use 4 beams for generation
    self.hparams.freeze_encoder = True
    self.hparams.freeze_embeds = True
    self.hparams.eval_beams = 4

    if self.hparams.freeze_encoder:
      freeze_params(self.model.get_encoder())

    if self.hparams.freeze_embeds:
      self.freeze_embeds()

  def freeze_embeds(self):
    ''' freeze the positional embedding parameters of the model; adapted from finetune.py '''
    freeze_params(self.model.model.shared)
    for d in [self.model.model.encoder, self.model.model.decoder]:
      freeze_params(d.embed_positions)
      freeze_params(d.embed_tokens)

  # Do a forward pass through the model
  def forward(self, input_ids, **kwargs):
    return self.model(input_ids, **kwargs)

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr = self.learning_rate)
    return optimizer

  def training_step(self, batch, batch_idx):
    # Load the data into variables
    src_ids, src_mask = batch[0], batch[1]
    tgt_ids = batch[2]
    # Shift the decoder tokens right (but NOT the tgt_ids)
    decoder_input_ids = shift_tokens_right(tgt_ids, self.tokenizer.pad_token_id)

    # Run the model and get the logits
    outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
    lm_logits = outputs[0]
    # Create the loss function
    ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
    # Calculate the loss on the un-shifted tokens
    loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))

    # Log the training loss
    self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    return {'loss':loss}

  def validation_step(self, batch, batch_idx):

    src_ids, src_mask = batch[0], batch[1]
    tgt_ids = batch[2]

    decoder_input_ids = shift_tokens_right(tgt_ids, self.tokenizer.pad_token_id)

    # Run the model and get the logits
    outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
    lm_logits = outputs[0]

    ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
    val_loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))

    # Log the validation loss
    self.log('val_loss', val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    return {'loss': val_loss}

  # Method that generates text using the BartForConditionalGeneration's generate() method
  def generate_text(self, text, eval_beams, early_stopping = True, max_len = 1024):
    ''' Generate summaries from tokenized input (a dict with "input_ids" and "attention_mask") '''
    generated_ids = self.model.generate(
        text["input_ids"],
        attention_mask=text["attention_mask"],
        use_cache=True,
        # Start decoding with the model's configured start token (</s> for BART),
        # which matches the token shift_tokens_right puts at position 0 during training
        decoder_start_token_id = self.model.config.decoder_start_token_id,
        num_beams= eval_beams,
        max_length = max_len,
        early_stopping = early_stopping
    )
    return [self.tokenizer.decode(w, skip_special_tokens=True, clean_up_tokenization_spaces=True) for w in generated_ids]

def freeze_params(model):
  ''' Function that takes a model as input (or part of a model) and freezes the layers for faster training
      adapted from finetune.py '''
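  for param in model.parameters():
    # The attribute is requires_grad; a misspelling such as requires_grade is
    # silently accepted as a new attribute and freezes nothing
    param.requires_grad = False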


# Create a dataloading module as per the PyTorch Lightning Docs
class SummaryDataModule(pl.LightningDataModule):
  def __init__(self, tokenizer, df, batch_size):
    super().__init__()
    self.tokenizer = tokenizer
    self.batch_size = batch_size
    self.data = df

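  # Shuffles the data and splits it into training, validation and test sets with a 60/20/20 split
  def prepare_data(self):
    shuffled = self.data.sample(frac=1)
    train_end, val_end = int(.6*len(shuffled)), int(.8*len(shuffled))
    # Slice with iloc rather than np.split, which relies on the deprecated DataFrame.swapaxes in recent pandas
    self.train, self.validate, self.test = shuffled.iloc[:train_end], shuffled.iloc[train_end:val_end], shuffled.iloc[val_end:]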

  # encode the sentences using the tokenizer
  def setup(self, stage):
    self.train = encode_sentences(self.tokenizer, self.train['source'], self.train['target'])
    self.validate = encode_sentences(self.tokenizer, self.validate['source'], self.validate['target'])
    self.test = encode_sentences(self.tokenizer, self.test['source'], self.test['target'])

  # Wrap the training, validation and test sets in PyTorch DataLoaders (only the training loader shuffles)
  def train_dataloader(self):
    dataset = TensorDataset(self.train['input_ids'], self.train['attention_mask'], self.train['labels'])
    train_data = DataLoader(dataset, sampler = RandomSampler(dataset), batch_size = self.batch_size)
    return train_data

  def val_dataloader(self):
    dataset = TensorDataset(self.validate['input_ids'], self.validate['attention_mask'], self.validate['labels'])
    val_data = DataLoader(dataset, batch_size = self.batch_size)
    return val_data

  def test_dataloader(self):
    dataset = TensorDataset(self.test['input_ids'], self.test['attention_mask'], self.test['labels'])
    test_data = DataLoader(dataset, batch_size = self.batch_size)
    return test_data



def shift_tokens_right(input_ids, pad_token_id):
  """ Shift input ids one token to the right, and wrap the last non pad token (usually <eos>).
      This is taken directly from modeling_bart.py
  """
  prev_output_tokens = input_ids.clone()
  index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
  prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
  prev_output_tokens[:, 1:] = input_ids[:, :-1]
  return prev_output_tokens

def encode_sentences(tokenizer, source_sentences, target_sentences, max_length=1024, max_target_length=512, pad_to_max_length=True, return_tensors="pt"):
  ''' Function that tokenizes source and target sentences
      Args: tokenizer - the BART tokenizer; source_sentences, target_sentences - iterables of source and target texts;
            max_length / max_target_length - token limits for sources / targets (longer texts are truncated)
      Returns: Dictionary with keys: input_ids, attention_mask, labels
  '''

  input_ids = []
  attention_masks = []
  target_ids = []

  for sentence in source_sentences:
    encoded_dict = tokenizer(
          sentence,
          max_length=max_length,
          # Pad every example to the same length so the tensors can be concatenated
          padding="max_length" if pad_to_max_length else False,
          truncation=True,
          return_tensors=return_tensors,
          add_prefix_space = True
      )

    input_ids.append(encoded_dict['input_ids'])
    attention_masks.append(encoded_dict['attention_mask'])

  input_ids = torch.cat(input_ids, dim = 0)
  attention_masks = torch.cat(attention_masks, dim = 0)

  for sentence in target_sentences:
    encoded_dict = tokenizer(
          sentence,
          max_length=max_target_length,
          padding="max_length" if pad_to_max_length else False,
          truncation=True,
          return_tensors=return_tensors,
          add_prefix_space = True
      )
    # Labels are kept unshifted; training_step builds the decoder inputs with shift_tokens_right
    target_ids.append(encoded_dict['input_ids'])

  target_ids = torch.cat(target_ids, dim = 0)


  batch = {
      "input_ids": input_ids,
      "attention_mask": attention_masks,
      "labels": target_ids,
  }

  return batch

## **Documentation: Data Loading Function**

### **Overview**
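The `load_data` function loads text data for training and testing from directories containing case files (judgements) and their summaries. It reads both the training and testing splits and structures them into source/summary pairs suitable for training a model. Both splits end up in the same list, so `SummaryDataModule` later reshuffles and re-splits the combined data 60/20/20; the dataset's original train/test boundary is not preserved.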

### **Function: `load_data(data)`**

This function loads the source (case details) and target (summaries) text files into a list of dictionaries. The function expects a list, `data`, which will be populated with dictionaries where each dictionary contains a "source" and "target" key.

---

### **Steps in the Function**

1. **Loading Training Data**:
    - **Path**: `/content/drive/MyDrive/IN-Abs/train-data/summary`
      - Loads all `.txt` files containing summaries into the `summary` list.
    - **Path**: `/content/drive/MyDrive/IN-Abs/train-data/judgement`
      - Loads all `.txt` files containing case details into the `source` list.
    - Pairs the case details (`source`) with their corresponding summaries (`target`) and appends them to the `data` list in dictionary format.

2. **Loading Testing Data**:
    - **Path**: `/content/drive/MyDrive/IN-Abs/test-data/summary`
      - Loads all `.txt` files containing test summaries into the `summary` list.
    - **Path**: `/content/drive/MyDrive/IN-Abs/test-data/judgement`
      - Loads all `.txt` files containing test case details into the `source` list.
    - Pairs the test case details (`source`) with their corresponding summaries (`target`) and appends them to the `data` list in dictionary format.

3. **Returning the Data**:
    - The function returns the populated `data` list, which contains the source-summary pairs for both training and testing.

---

### **Parameters**
- `data`: A list to store the source-summary pairs. This list is populated by the function.

### **Returns**
- `data`: A list of dictionaries where each dictionary has:
  - `"source"`: The case details (text from `.txt` files in the "judgement" directories).
  - `"target"`: The corresponding summary (text from `.txt` files in the "summary" directories).
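Pairing sources with summaries is only correct if the i-th judgement really matches the i-th summary, and `glob.glob` returns files in arbitrary order. A minimal sketch of pairing by file name follows; `pair_by_filename` is a hypothetical helper, and it assumes each summary has the same file name as its judgement. The temporary directory stands in for one IN-Abs split.

```python
import glob
import os
import tempfile

def pair_by_filename(judgement_dir, summary_dir):
  ''' Hypothetical helper: pair judgement/summary .txt files that share a file name '''
  pairs = []
  for path in sorted(glob.glob(os.path.join(judgement_dir, "*.txt"))):
    summary_path = os.path.join(summary_dir, os.path.basename(path))
    if not os.path.exists(summary_path):
      continue  # no matching summary; skip this case
    with open(path, encoding="utf-8") as f:
      source = f.read()
    with open(summary_path, encoding="utf-8") as f:
      target = f.read()
    pairs.append({"source": source, "target": target})
  return pairs

# Build a tiny fake split with two cases, then pair it
with tempfile.TemporaryDirectory() as root:
  for sub in ("judgement", "summary"):
    os.makedirs(os.path.join(root, sub))
  for name in ("10", "2"):
    with open(os.path.join(root, "judgement", name + ".txt"), "w", encoding="utf-8") as f:
      f.write("judgement " + name)
    with open(os.path.join(root, "summary", name + ".txt"), "w", encoding="utf-8") as f:
      f.write("summary " + name)
  data = pair_by_filename(os.path.join(root, "judgement"), os.path.join(root, "summary"))

print(data)
# [{'source': 'judgement 10', 'target': 'summary 10'}, {'source': 'judgement 2', 'target': 'summary 2'}]
```

Matching on the file name keeps pairs aligned even if one directory has an extra or missing file, which a plain `zip` over two listings would not.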

---

### **Example Usage**
```python
data = []
data = load_data(data)
```


In [None]:
import glob
import os

def load_data(data):
  ''' Append {"source": judgement, "target": summary} pairs from the IN-Abs
      train and test splits to `data` and return it '''
  root = "/content/drive/MyDrive/IN-Abs"

  for split in ["train-data", "test-data"]:
    #Paths to the case files (judgements) and the summaries of this split
    judgement_dir = os.path.join(root, split, "judgement")
    summary_dir = os.path.join(root, split, "summary")

    #Get all the txt files with case details; sort them because glob.glob
    #returns files in arbitrary order
    all_files = sorted(glob.glob(judgement_dir + "/*.txt"))
    print(all_files)

    #Prepare the dataset: pair each case file with the summary that has the
    #same file name (zipping two separately globbed lists can misalign them)
    for filename in all_files:
      summary_file = os.path.join(summary_dir, os.path.basename(filename))
      if not os.path.exists(summary_file):
        continue  # no matching summary; skip this case
      with open(filename, 'r') as f:
        src = f.read()
      with open(summary_file, 'r') as f:
        summ = f.read()
      data.append({"source" : src , "target" : summ})

  return data

In [None]:
#Load the data
data=list()
data=load_data(data)


In [None]:
# Convert the list of dicts to a DataFrame with 'source' and 'target' columns
data_pd=pd.DataFrame(data)

## Load the BART Model, Tokenizer and Dataset

In [None]:
# Loading Model and tokenizer
from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)

bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")



In [None]:
# Add special tokens if required

new_tokens = ['<F>', '<RLC>', '<A>', '<S>', '<P>', '<R>', '<RPC>']

special_tokens_dict = {'additional_special_tokens': new_tokens}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
bart_model.resize_token_embeddings(len(tokenizer))

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


BartScaledWordEmbedding(50272, 1024, padding_idx=1)

In [None]:
summary_data = SummaryDataModule(tokenizer, data_pd, batch_size = 1)
model = LitModel(learning_rate = 2e-5, tokenizer = tokenizer, model = bart_model)

## Train the Model
**Explanation:**

1. **Install PyTorch Lightning:**
   - `!pip install pytorch-lightning --upgrade`: This line ensures that the latest version of PyTorch Lightning is installed.

2. **Import Necessary Modules:**
   - `import pytorch_lightning as pl`: Imports the main PyTorch Lightning library.
   - `from pytorch_lightning.callbacks import TQDMProgressBar`: Imports the `TQDMProgressBar` callback for displaying a progress bar during training.

3. **Configure the Trainer:**
   - `trainer = pl.Trainer(...)`: Creates a `Trainer` object, which is the central component of PyTorch Lightning for managing the training process.
   - **`accelerator="gpu"`:** Specifies that training runs on a GPU.
   - **`devices=1`:** Uses one GPU.
   - **`max_epochs=3`:** Sets the maximum number of training epochs.
   - **`min_epochs=2`:** Sets the minimum number of training epochs.
   - **`callbacks=[TQDMProgressBar(refresh_rate=5)]`:** Adds a progress bar that refreshes every 5 batches.
   - **`precision=16`:** Enables 16-bit automatic mixed precision (AMP). Recent Lightning versions flag this spelling as discouraged and prefer the equivalent `precision="16-mixed"`.




In [None]:
!pip install pytorch-lightning --upgrade  # Ensure PyTorch Lightning is up-to-date

import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar


# Use TQDMProgressBar to control the refresh rate
trainer = pl.Trainer(accelerator="gpu",  # Use "gpu" for GPU training
                     devices=1,           # Specify the number of GPUs to use
                     max_epochs=3,
                     min_epochs=2,
                     callbacks=[TQDMProgressBar(refresh_rate=5)],  # Set refresh rate here
                     precision=16)



/usr/local/lib/python3.10/dist-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model, summary_data)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                         | Params | Mode
--------------------------------------------------------------
0 | model | BartForConditionalGeneration | 406 M  | eval
--------------------------------------------------------------
406 M     Trainable params
0         Non-trainable params
406 M     Total params
1,625.194 Total estimated model params size (MB)
0         Modules in train mode
348       Modules in eval mode



INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=3` reached.


In [None]:
#Save the model
trainer.save_checkpoint("/content/drive/MyDrive/IN-Abs/model/output.ckpt")

## Results
Training completed 3 epochs successfully; the loop stopped because `max_epochs=3` was reached.