# Iterative summarization training notebook

## Goals

* This notebook fine-tunes GPT-2 for summarization in order to interpret this process.
* The reverse task (generating a review from its summary) is also trained, to create an exemplification model.

The training is facilitated by Neel Nanda's library [TransformerLens](https://github.com/neelnanda-io/TransformerLens). 
See this project's [GitHub](https://github.com/Xmaster6y/Iterative_summarisation) for more details.

## Notes

* For training, run this notebook on a GPU runtime: `Runtime>Change runtime type>GPU`.

## Imports

### Pip installs

In [None]:
!pip install git+https://github.com/neelnanda-io/TransformerLens.git
!pip install evaluate
!pip install rouge_score

### Classic libraries imports

In [None]:
import os
import json
import evaluate

import torch
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader

from transformers import get_scheduler

### External toolboxes

In [None]:
from transformer_lens import HookedTransformer

### Dataset import

Note that the original dataset is very large: it could not be loaded even with `linecache`, and I couldn't even split it with Linux command-line tools. In any case, I wouldn't have had time to train on such a large dataset, hence the reduction to its first lines.

Also note that the dataset is not a single well-formed JSON document: each line is a separate JSON object (JSON Lines), so it is parsed line by line.

In [None]:
# Download the Movies_and_TV_5 review archive, but only when neither the
# compressed archive nor the already-extracted dataset.json is present.
# NOTE(review): --no-check-certificate disables TLS certificate verification;
# consider dropping it if the server's certificate validates.
if not os.path.exists('./Movies_and_TV_5.json.gz') and not os.path.exists('./dataset.json'):
  !wget https://jmcauley.ucsd.edu/data/amazon_v2/categoryFilesSmall/Movies_and_TV_5.json.gz --no-check-certificate

In [None]:
if not os.path.exists('./dataset.json'):
  if not os.path.exists('./Movies_and_TV_5.json.gz'):
    raise FileNotFoundError
  else:
    !gzip -d Movies_and_TV_5.json.gz
    !mv Movies_and_TV_5.json dataset.json
n = 2000
!head -n $n dataset.json > mini_dataset.json
out = !wc -l mini_dataset.json
n = int(out[0].split()[0])
n

In [None]:
# Split mini_dataset.json in two: the first train_size lines become the
# training file, the last eval_size lines the evaluation file.
train_size = n // 2
eval_size = n - train_size  # gets the extra line when n is odd
!head -n $train_size mini_dataset.json > train.json
!tail -n $eval_size mini_dataset.json > eval.json

In [None]:
class SummarizationDataset(Dataset):
    """Text dataset for summarisation.

    Each record is rendered as a ``[review]: <text>`` line followed by a
    ``[summary]: <text>`` line, so a causal language model learns to continue
    a review with its summary.
    """
    def __init__(self, dataset_path='./mini_dataset.json'):
        """
        Args:
            dataset_path (string): Path to a JSON Lines file (one JSON object
                per line) of review records.
        """
        with open(dataset_path, 'r', encoding='utf-8') as f:
            # Blank lines (e.g. a trailing newline) are not valid JSON; skip them.
            raw_records = [json.loads(line) for line in f if line.strip()]
        # __getitem__ reads both fields, so keep only records that have both;
        # otherwise indexing would raise KeyError in the middle of an epoch.
        self.records = [
            r for r in raw_records if 'reviewText' in r and 'summary' in r
        ]

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        return f"[review]: {self.records[idx]['reviewText']}\n[summary]: {self.records[idx]['summary']}"

  
class ExamplificationDataset(SummarizationDataset):
    """Text dataset for examplification: the reverse task, summary -> review.

    Reuses the parent's record loading; only the prompt order is swapped.
    """
    def __getitem__(self, idx):
        record = self.records[idx]
        summary, review = record['summary'], record['reviewText']
        return f"[summary]: {summary}\n[review]: {review}"

## Model loading and fine-tuning

The model is fine-tuned on the chosen task using the datasets defined above.

* To avoid redoing the training, the weights are automatically downloaded from Google Drive and loaded instead of training, unless `re_train` is set to `True`.



In [None]:
# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Free any cached CUDA memory before loading the model.
torch.cuda.empty_cache()
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

In [None]:
def train(model, optimizer, lr_scheduler, dl, epochs, pb, log_every=50):
    """Fine-tune ``model`` on every batch of ``dl`` for ``epochs`` epochs.

    Args:
        model: module called as ``model(batch, return_type="loss")`` that
            returns a scalar loss tensor (a HookedTransformer in this notebook).
        optimizer: torch optimizer over the model's parameters.
        lr_scheduler: learning-rate scheduler, stepped once per batch.
        dl: sized iterable of batches (e.g. a DataLoader).
        epochs: number of passes over ``dl``.
        pb: progress bar advanced by one per batch (anything with ``update``).
        log_every: print the loss every ``log_every`` batches within an
            epoch; 0 disables logging.
    """
    model.train()
    steps_per_epoch = len(dl)
    # Force gradients on even if the caller is inside a no_grad context.
    with torch.set_grad_enabled(True):
        for epoch in range(epochs):
            for idx, batch in enumerate(dl):
                optimizer.zero_grad()
                loss = model(batch, return_type="loss")
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                pb.update(1)
                if log_every and idx % log_every == 0:
                    # Second value is the global step across epochs.
                    print({"loss": float(loss)}, idx + epoch * steps_per_epoch)

In [None]:
# Task to run: 'sum' (review -> summary) or 'exp' (summary -> review).
task = 'exp'
# Set to True to train even when a weight file is already present.
re_train = False
weight_file = './' + task + '_weights.pt'

In [None]:
batch_size = 2
# 'sum' trains review -> summary; 'exp' trains the reverse task.
if task == 'sum':
    dataset_cls = SummarizationDataset
elif task == 'exp':
    dataset_cls = ExamplificationDataset
else:
    raise ValueError(f"Unknown task {task!r}; expected 'sum' or 'exp'.")
train_dataset = dataset_cls(dataset_path='./train.json')
eval_dataset = dataset_cls(dataset_path='./eval.json')

train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# No shuffling for evaluation: evaluation is capped at a number of batches,
# so a fixed order keeps ROUGE scores reproducible between runs.
eval_dl = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)


lr = 3e-4
epochs = 1
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)
num_training_steps = epochs * len(train_dl)
# Linear schedule with no warmup, spanning all training steps.
lr_scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
progress_bar = tqdm(range(num_training_steps))

if not os.path.exists(weight_file):
    if task == 'sum':
        !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=FILEID' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1tPU5mHCXcAxZJJHvv9XyT-MzgoeSWUk9" -O $weight_file  && rm -rf /tmp/cookies.txt
    else:
        !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=FILEID' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1--qi-Rzhff4OcAtknrSrzfDzgNC9z5mQ" -O $weight_file  && rm -rf /tmp/cookies.txt

if not os.path.exists(weight_file) or re_train:
    train(model, optimizer, lr_scheduler, train_dl, epochs, progress_bar)
else:
    model.load_state_dict(torch.load(weight_file))
    model.eval()

In [None]:
# Persist the current weights (trained or freshly loaded) to weight_file.
trained_state = model.state_dict()
torch.save(trained_state, weight_file)

## Model evaluation

The model is evaluated using the ROUGE metric. This only gives a rough indication of whether the training was meaningful.

In [None]:
def split_func(batch, sep="summary"):
    """Turn full training strings into generation prompts.

    Each string is cut just before its first ``"\\n[<sep>]: "`` marker, and
    the marker is re-appended, so the model is asked to generate the ``sep``
    field. Strings without the marker get the marker appended to their end.
    """
    marker = f'\n[{sep}]: '
    return [text.partition(marker)[0] + marker for text in batch]

In [None]:
def _strip_prompt(text, prompt):
    """Return ``text`` without its leading ``prompt``, or ``text`` unchanged if it does not start with it."""
    return text[len(prompt):] if text.startswith(prompt) else text


def model_eval(model, metric, dl, split_func, max_iter=100, max_new_tokens=15):
    """Add up to ``max_iter`` batches of generated continuations to ``metric``.

    Args:
        model: object with ``eval()`` and ``generate(prompt, max_new_tokens=...)``.
            The generation is assumed to be the prompt followed by the new
            text (HookedTransformer with a string prompt; confirm). If it does
            not start with the prompt, it is scored as-is.
        metric: metric exposing ``add_batch(predictions=..., references=...)``.
        dl: iterable of batches, each a list of full "prompt + target" strings.
        split_func: maps a batch to its list of prompts.
        max_iter: maximum number of batches to score.
        max_new_tokens: generation budget per prompt.
    """
    model.eval()
    with torch.no_grad():
        for i, batch in enumerate(dl):
            # Checked before generating so that exactly max_iter batches are scored.
            if i >= max_iter:
                break
            prompts = split_func(batch)
            generations = [model.generate(p, max_new_tokens=max_new_tokens) for p in prompts]
            # Score only the continuations: the prompt is shared by the
            # prediction and the reference and would otherwise inflate ROUGE.
            predictions = [_strip_prompt(g, p) for g, p in zip(generations, prompts)]
            references = [_strip_prompt(t, p) for t, p in zip(batch, prompts)]
            metric.add_batch(predictions=predictions, references=references)

In [None]:
metric = evaluate.load("rouge")
# Prompts are cut at the field the model has to generate.
target_field = "summary" if task == "sum" else "review"


def sf(b, _sep=target_field):
    return split_func(b, sep=_sep)


max_new_tokens = 15
max_iter = 100
model_eval(model, metric, eval_dl, sf, max_iter=max_iter, max_new_tokens=max_new_tokens)

In [None]:
# Aggregate the accumulated predictions/references into ROUGE scores and display them.
rouge_scores = metric.compute()
rouge_scores