# Iterative summarisation training notebook

## Goals

* This notebook fine-tunes GPT-2 for summarisation so that the process can then be interpreted.
* The reverse training is also performed to create an exemplification model, which expands a summary into a review.

The training is facilitated by Neel Nanda's library [TransformerLens](https://github.com/neelnanda-io/TransformerLens). 
See this project's [GitHub](https://github.com/Xmaster6y/Iterative_summarisation) for more details.

## Notes

* For training, use this notebook with a GPU runtime: `Runtime > Change runtime type > GPU`.

## Imports

### Pip installs

In [1]:
!pip install git+https://github.com/neelnanda-io/TransformerLens.git
!pip install evaluate
!pip install rouge_score

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-uc88565v
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-uc88565v
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 090f63afcf72e8ecd9527bbb6f598874554def1b
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting einops<0.7.0,>=0.6.0
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
Collecting datasets<3.0.0,>=2.7.1
  Downloading datasets-2.8.0-py3-none-any.whl (452 kB)

### Classic libraries imports

In [2]:
import os
import json
import evaluate

import torch
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader

from transformers import get_scheduler

### External toolboxes

In [3]:
from transformer_lens import HookedTransformer

### Dataset import

Note that the original dataset is very large: it was impossible to load, even with `linecache`, and I couldn't even split it with the Linux `split` command. I wouldn't have had time to train on such a large dataset anyway, hence the reduction to a small subset.

Also note that the dataset is ill-formatted: the file as a whole is not a valid JSON document. Each line holds one JSON object, so it has to be parsed line by line.
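Because the notebook parses the file one line at a time with `json.loads`, a handful of records can in principle be read straight from the compressed archive without decompressing or loading the whole file. The following is a minimal sketch of that idea using only the standard library; the helper name `read_first_records` is hypothetical and is not used elsewhere in this notebook.

```python
import gzip
import json
from itertools import islice


def read_first_records(path, n):
    """Return up to the first n JSON records of a (possibly gzipped) file."""
    # Assumption: the file stores one JSON object per line.
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rt", encoding="utf-8") as f:
        # islice stops after n lines, so the rest of the file is never read.
        return [json.loads(line) for line in islice(f, n) if line.strip()]
```

For instance, `read_first_records('./Movies_and_TV_5.json.gz', 2000)` would play the same role as the `gzip -d` and `head` commands used in the cells below, at the cost of keeping the records in memory rather than on disk.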

In [4]:
if not os.path.exists('./Movies_and_TV_5.json.gz') and not os.path.exists('./dataset.json'):
    !wget https://jmcauley.ucsd.edu/data/amazon_v2/categoryFilesSmall/Movies_and_TV_5.json.gz --no-check-certificate

--2023-01-22 14:49:16--  https://jmcauley.ucsd.edu/data/amazon_v2/categoryFilesSmall/Movies_and_TV_5.json.gz
Resolving jmcauley.ucsd.edu (jmcauley.ucsd.edu)... 137.110.160.73
Connecting to jmcauley.ucsd.edu (jmcauley.ucsd.edu)|137.110.160.73|:443... connected.
  Unable to locally verify the issuer's authority.
HTTP request sent, awaiting response... 200 OK
Length: 791322468 (755M) [application/x-gzip]
Saving to: ‘Movies_and_TV_5.json.gz’


2023-01-22 14:49:23 (109 MB/s) - ‘Movies_and_TV_5.json.gz’ saved [791322468/791322468]



In [5]:
if not os.path.exists('./dataset.json'):
    if not os.path.exists('./Movies_and_TV_5.json.gz'):
        raise FileNotFoundError
    else:
        !gzip -d Movies_and_TV_5.json.gz
        !mv Movies_and_TV_5.json dataset.json
n = 2000
!head -n $n dataset.json > mini_dataset.json
out = !wc -l mini_dataset.json
n = int(out[0].split()[0])
n

2000

In [6]:
train_size = n // 2
eval_size = n - train_size
!head -n $train_size mini_dataset.json > train.json
!tail -n $eval_size mini_dataset.json > eval.json

In [7]:
class SummariseDataset(Dataset):
    """Text dataset for summarisation."""
    def __init__(self, dataset_path='./mini_dataset.json'):
        """
        Args:
            dataset_path (string): Path to the dataset of texts.
        """
        with open(dataset_path, 'r') as f:
            lines = f.readlines()
        raw_records = list(map(json.loads, lines))
        # Keep only records with both fields, since __getitem__ reads both.
        self.records = [r for r in raw_records if 'reviewText' in r and 'summary' in r]

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        return f"[review]: {self.records[idx]['reviewText']}\n[summary]: {self.records[idx]['summary']}"

  
class ExamplificationDataset(SummariseDataset):
    """Text dataset for examplification."""
    def __getitem__(self, idx):
        return f"[summary]: {self.records[idx]['summary']}\n[review]: {self.records[idx]['reviewText']}"

## Model loading and fine-tuning

The model is trained on the chosen task using dataloaders built from the datasets defined above.

* To avoid repeating the training, the saved weights are automatically downloaded from Google Drive and loaded, unless `re_train` is set to `True`.



In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.cuda.empty_cache()
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


In [9]:
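def train(model, optimizer, lr_scheduler, dl, epochs, pb):
    """Fine-tune `model` on the text batches of `dl` for `epochs` epochs."""
    model.train()
    for epoch in range(epochs):
        for idx, batch in enumerate(dl):
            with torch.set_grad_enabled(True):
                optimizer.zero_grad()
                # The model tokenises the strings and returns the next-token loss.
                loss = model(batch, return_type="loss")
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                pb.update(1)
                if idx % 50 == 0:
                    # Log the loss together with the global step.
                    print({"loss": float(loss)}, idx + epoch * len(dl))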

In [10]:
task = 'exp'
weight_file = f'./{task}_weights.pt'
re_train = False

In [11]:
batch_size = 2
if task == 'sum':
    train_dataset = SummariseDataset(dataset_path='./train.json')
    eval_dataset = SummariseDataset(dataset_path='./eval.json')
else:
    train_dataset = ExamplificationDataset(dataset_path='./train.json')
    eval_dataset = ExamplificationDataset(dataset_path='./eval.json')

train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_dl = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)


lr = 3e-4
epochs = 1
optimizer = torch.optim.AdamW(params = model.parameters(), lr=lr)
num_training_steps = epochs * len(train_dl)
lr_scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
progress_bar = tqdm(range(num_training_steps))

# Download the saved weights when they are not available locally.
# The inner wget must query the same file id to retrieve the confirmation token.
if not os.path.exists(weight_file):
    if task == 'sum':
        !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1tPU5mHCXcAxZJJHvv9XyT-MzgoeSWUk9' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1tPU5mHCXcAxZJJHvv9XyT-MzgoeSWUk9" -O $weight_file  && rm -rf /tmp/cookies.txt
    else:
        !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1--qi-Rzhff4OcAtknrSrzfDzgNC9z5mQ' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1--qi-Rzhff4OcAtknrSrzfDzgNC9z5mQ" -O $weight_file  && rm -rf /tmp/cookies.txt

# Train when no weights could be obtained or when re-training is requested.
if not os.path.exists(weight_file) or re_train:
    train(model, optimizer, lr_scheduler, train_dl, epochs, progress_bar)
else:
    # map_location lets weights saved on a GPU load on a CPU-only runtime.
    model.load_state_dict(torch.load(weight_file, map_location=device))
    model.eval()

  0%|          | 0/500 [00:00<?, ?it/s]

--2023-01-22 14:50:07--  https://docs.google.com/uc?export=download&confirm=&id=1--qi-Rzhff4OcAtknrSrzfDzgNC9z5mQ
Resolving docs.google.com (docs.google.com)... 142.250.101.101, 142.250.101.139, 142.250.101.113, ...
Connecting to docs.google.com (docs.google.com)|142.250.101.101|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://doc-00-8k-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/4kv17gpv6hrb7aqem398pov9fljc4ihi/1674399000000/13918618242186115589/*/1--qi-Rzhff4OcAtknrSrzfDzgNC9z5mQ?e=download&uuid=40dc85a1-a73d-410a-9fea-3905ef87d111 [following]
--2023-01-22 14:50:07--  https://doc-00-8k-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/4kv17gpv6hrb7aqem398pov9fljc4ihi/1674399000000/13918618242186115589/*/1--qi-Rzhff4OcAtknrSrzfDzgNC9z5mQ?e=download&uuid=40dc85a1-a73d-410a-9fea-3905ef87d111
Resolving doc-00-8k-docs.googleusercontent.com (doc-00-8k-docs.googleusercontent.com)... 74.125.137.13

In [12]:
torch.save(model.state_dict(), weight_file)

## Model evaluation

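The model is evaluated using the ROUGE metric. The evaluation cells are still a draft, so a `NotImplementedError` guard stops a full run before them.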
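To give an intuition of what ROUGE measures, here is a simplified sketch of ROUGE-1, the unigram-overlap variant, reported as an F1 score. It is an illustration only: the function name `rouge1_f1` is hypothetical, and the whitespace tokenisation is an assumption that is cruder than what the `rouge_score` package installed above does, so the numbers will not match it exactly.

```python
from collections import Counter


def rouge1_f1(prediction, reference):
    """Unigram-overlap F1 between two whitespace-tokenised strings."""
    pred_counts = Counter(prediction.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Clipped overlap: a word counts at most as often as it appears in both texts.
    overlap = sum((pred_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


# Two of the three words match on each side, so the score is about 2/3.
print(rouge1_f1("a great movie", "great movie overall"))
```

Precision penalises words the prediction adds, recall penalises reference words it misses, and F1 balances the two.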

In [13]:
raise NotImplementedError

NotImplementedError: ignored

In [None]:
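def model_eval(model, metric, dl):
    """Accumulate the metric over `dl` and return the computed scores."""
    model.eval()
    for batch in dl:
        with torch.no_grad():
            logits = model(batch, return_type="logits")
            # Greedy next-token predictions (teacher forcing), as token ids.
            token_ids = torch.argmax(logits, dim=-1)
        # ROUGE compares strings, so decode the ids before adding the batch.
        predictions = model.to_string(token_ids)
        metric.add_batch(predictions=predictions, references=list(batch))
    return metric.compute()

In [None]:
metric = evaluate.load("rouge")
scores = model_eval(model, metric, eval_dl)
scores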