In [1]:
# Install these packages with these specific versions else the notebook breaks
!pip install transformers==4.5.1
!pip install pytorch_lightning==1.2.10
!pip install sentencepiece

Collecting transformers==4.5.1
  Using cached transformers-4.5.1-py3-none-any.whl (2.1 MB)
Collecting tokenizers<0.11,>=0.10.1
  Using cached tokenizers-0.10.2-cp38-cp38-macosx_10_11_x86_64.whl (2.3 MB)
Installing collected packages: tokenizers, transformers
  Attempting uninstall: tokenizers
    Found existing installation: tokenizers 0.7.0
    Uninstalling tokenizers-0.7.0:
      Successfully uninstalled tokenizers-0.7.0
  Attempting uninstall: transformers
    Found existing installation: transformers 2.9.0
    Uninstalling transformers-2.9.0:
      Successfully uninstalled transformers-2.9.0
Successfully installed tokenizers-0.10.2 transformers-4.5.1
Collecting pytorch_lightning==1.2.10
  Downloading pytorch_lightning-1.2.10-py3-none-any.whl (841 kB)
[K     |████████████████████████████████| 841 kB 6.2 MB/s eta 0:00:01
Collecting tensorboard!=2.5.0,>=2.2.0
  Downloading tensorboard-2.4.1-py3-none-any.whl (10.6 MB)
[K     |████████████████████████████████| 10.6 MB 5.9 MB/s eta 0:0

Collecting async-timeout<4.0,>=3.0
  Downloading async_timeout-3.0.1-py3-none-any.whl (8.2 kB)
Collecting yarl<2.0,>=1.0
  Downloading yarl-1.6.3-cp38-cp38-macosx_10_14_x86_64.whl (124 kB)
[K     |████████████████████████████████| 124 kB 6.9 MB/s eta 0:00:01
Collecting multidict<7.0,>=4.5
  Downloading multidict-5.1.0-cp38-cp38-macosx_10_14_x86_64.whl (49 kB)
[K     |████████████████████████████████| 49 kB 6.7 MB/s  eta 0:00:01
Building wheels for collected packages: PyYAML
  Building wheel for PyYAML (setup.py) ... [?25ldone
[?25h  Created wheel for PyYAML: filename=PyYAML-5.3.1-cp38-cp38-macosx_10_15_x86_64.whl size=44624 sha256=13c6a678e5f473ab49698754fa50d6747359811095d065d66e6fa31e29e062c5
  Stored in directory: /Users/akshatgoel/Library/Caches/pip/wheels/13/90/db/290ab3a34f2ef0b5a0f89235dc2d40fea83e77de84ed2dc05c
Successfully built PyYAML
Installing collected packages: multidict, yarl, async-timeout, fsspec, aiohttp, torchmetrics, tensorboard, PyYAML, pytorch-lightning
  Atte

In [2]:
# Import packages
import argparse
import glob
import pickle
import os
import json
import time
import logging
import random
import re
from tqdm import tqdm
from itertools import chain
from string import punctuation
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu

import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize


import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import sentencepiece


from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup
)

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/akshatgoel/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
class T5FineTuner(pl.LightningModule):
  def __init__(self, hparams):
    super(T5FineTuner, self).__init__()
    
    if type(hparams) is dict: 
      hparams = argparse.Namespace(**hparams)
    
    self.hparams = hparams
    self.model = T5ForConditionalGeneration.from_pretrained(hparams.model_name_or_path)
    self.tokenizer = T5Tokenizer.from_pretrained(hparams.tokenizer_name_or_path)
  
  def is_logger(self):
    return self.trainer.global_rank <= 0
  
  def forward(
      self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, labels=None
  ):
    return self.model(
        input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        labels=labels,
    )

  def _step(self, batch):
    labels = batch["target_ids"]
    labels[labels[:, :] == self.tokenizer.pad_token_id] = -100

    outputs = self(
        input_ids=batch["source_ids"],
        attention_mask=batch["source_mask"],
        labels=labels,
        decoder_attention_mask=batch['target_mask']
    )

    loss = outputs[0]
    
    return loss

  def training_step(self, batch, batch_idx):
    loss = self._step(batch)
    self.log('training_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    tensorboard_logs = {"train_loss": loss}
    return {"loss": loss, "log": tensorboard_logs}
  
  def training_epoch_end(self, outputs):
    avg_train_loss = torch.stack([x["loss"] for x in outputs]).mean()
    self.log('avg_training_loss', avg_train_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
    tensorboard_logs = {"avg_train_loss": avg_train_loss}
    return {"avg_train_loss": avg_train_loss, "log": tensorboard_logs, 'progress_bar': tensorboard_logs}

  def validation_step(self, batch, batch_idx):
    loss = self._step(batch)
    self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
    tensorboard_logs = {"val_loss": loss}
    return {"val_loss": loss, "log": tensorboard_logs}
  
  def validation_epoch_end(self, outputs):
    avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
    print(avg_loss)
    tensorboard_logs = {"val_loss": avg_loss}
    self.log('avg_val_loss', avg_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

  def configure_optimizers(self):
    "Prepare optimizer and schedule (linear warmup and decay)"

    model = self.model
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": self.hparams.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
    self.opt = optimizer
    return [optimizer]
  
  def optimizer_step(self,
                     epoch=None, 
                     batch_idx=None, 
                     optimizer=None, 
                     optimizer_idx=None, 
                     optimizer_closure=None, 
                     on_tpu=None, 
                     using_native_amp=None, 
                     using_lbfgs=None
                     ):

    optimizer.step(closure=optimizer_closure)
    optimizer.zero_grad()
    self.lr_scheduler.step()

  def train_dataloader(self):
    train_dataset = get_dataset(tokenizer=self.tokenizer, type_path="train", args=self.hparams)
    dataloader = DataLoader(train_dataset, batch_size=self.hparams.train_batch_size, drop_last=True, shuffle=True, num_workers=4)
    t_total = (
        (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
        // self.hparams.gradient_accumulation_steps
        * float(self.hparams.num_train_epochs)
    )
    scheduler = get_linear_schedule_with_warmup(
        self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
    )
    self.lr_scheduler = scheduler
    return dataloader

  def val_dataloader(self):
    val_dataset = get_dataset(tokenizer=self.tokenizer, type_path="val", args=self.hparams)
    return DataLoader(val_dataset, batch_size=self.hparams.eval_batch_size, num_workers=4)

In [10]:
def write_weights(checkpoint_dir='/Users/akshatgoel/Desktop/checkpoints/', 
                  checkpoint_name = 'exp_9_rc_noise.ckpt', 
                  state_name = 'exp_9_rc_noise.pt'):
    '''
    ------------
    Input: 
    Output: 
    ------------
    '''
    # Set parameters
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
    filepath = os.path.join(checkpoint_dir, state_name)

    model = T5FineTuner.load_from_checkpoint(checkpoint_path)
    torch.save(model.state_dict(), filepath)
    
    

In [12]:
checkpoint_dir='/Users/akshatgoel/Desktop/checkpoints/'

files = [f for f in os.listdir(checkpoint_dir) if f.endswith('ckpt')]

for i, checkpoint_name in enumerate(files):
    state_name = str(i) + '.pth'
    write_weights(checkpoint_dir, checkpoint_name, state_name)

In [13]:
files

['exp_11_rc_ef_low_reg.ckpt',
 'exp_9_rc_noise.ckpt',
 'exp_10_rc_noise_ef.ckpt',
 'exp_12_rc_low_reg.ckpt']