# Introduction

In this notebook you will implement a JAX version of GPT from [this](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) paper. Read the paper first to better understand the model. In particular, pay attention to how the pre-trained model is applied to fine-tuning and to zero-shot tasks.

Afterwards, the notebook will walk you through several experiments using your pre-trained model.

In [1]:
# basic explanation of the model

In [2]:
# jax explanation

# Setup

In [3]:
!pip install flax
!pip install optax
!pip install tensorflow







In [5]:
# imports
import os

import numpy as np
from tqdm.auto import tqdm

import jax
import jax.numpy as jnp
from jax import random

import flax
from flax import linen as nn
from flax.training import train_state, checkpoints

import optax


# Helper Functions

These are functions you may find helpful in your implementation.

In [None]:
class TransformerGELU(nn.Module):
    """
    Applies the GELU activation element-wise
    """
    # Flax modules are dataclasses: configuration is declared as a field, not passed to setup()
    approximate: bool = False

    def __call__(self, x):
        return nn.gelu(x, approximate=self.approximate)
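The `approximate` flag chooses between the exact GELU, $x\,\Phi(x)$ with $\Phi$ the standard normal CDF, and its tanh approximation. As a quick sanity check, the sketch below writes both formulas out by hand and compares them with `jax.nn.gelu`, which is the function Flax re-exports as `nn.gelu`. It assumes only `jax` is installed.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import erf

x = jnp.linspace(-3.0, 3.0, 13)

# Exact GELU: x * Phi(x), where Phi is the standard normal CDF
exact = 0.5 * x * (1.0 + erf(x / jnp.sqrt(2.0)))

# Tanh approximation of the same function
approx = 0.5 * x * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3)))

err_exact = float(jnp.max(jnp.abs(jax.nn.gelu(x, approximate=False) - exact)))
err_approx = float(jnp.max(jnp.abs(jax.nn.gelu(x, approximate=True) - approx)))
gap = float(jnp.max(jnp.abs(exact - approx)))  # how far apart the two variants are

print(err_exact, err_approx, gap)
```

The two variants differ only slightly, so either choice is reasonable as long as it is used consistently between pre-training and fine-tuning.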



# Implementation

In this section you will implement x parts of the Flax/JAX GPT model. Specifically: (list what we end up deciding)



You will also be coding task-specific input transformations for fine-tuning.
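The paper avoids task-specific architectures by turning each structured input into a single token sequence, using a start token, a delimiter token (`$`) and an extract token. The functions below are a minimal sketch of those transformations on lists of token ids. The special-token ids are hypothetical: they assume the 10,000-token vocabulary used later in this notebook is extended by three entries, so `vocab_size` would need to be at least 10003.

```python
from typing import List

# Hypothetical ids for the three special tokens, appended after a 10,000-token vocabulary.
# In the paper these are randomly initialized tokens that are learned during fine-tuning.
START, DELIM, EXTRACT = 10000, 10001, 10002


def classification_input(text: List[int]) -> List[int]:
    # <start> text <extract>
    return [START] + text + [EXTRACT]


def pair_input(first: List[int], second: List[int]) -> List[int]:
    # <start> first $ second <extract>  (entailment: premise $ hypothesis)
    return [START] + first + [DELIM] + second + [EXTRACT]


def similarity_inputs(text_a: List[int], text_b: List[int]) -> List[List[int]]:
    # Similarity has no natural order, so both orderings are encoded;
    # the paper adds the two final hidden states element-wise before the linear head.
    return [pair_input(text_a, text_b), pair_input(text_b, text_a)]


def multiple_choice_inputs(context: List[int], answers: List[List[int]]) -> List[List[int]]:
    # One sequence per candidate answer (context = document followed by question);
    # the per-sequence scores are then normalized with a softmax over the answers.
    return [pair_input(context, answer) for answer in answers]


print(pair_input([5, 6], [7]))
# [10000, 5, 6, 10001, 7, 10002]
```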


## (1) Implementing Attention and Multi-Headed Attention

(Description of how GPT attention might differ from non-gpt attention)
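GPT is decoder-only. It keeps the masked (causal) self-attention of the original Transformer decoder, where position $i$ may only attend to positions $j \le i$, and drops the encoder-decoder attention sub-layer entirely. The sketch below is a functional version with explicit weight matrices. The name `causal_self_attention` and its arguments are hypothetical, and its interface differs from the `MultiHeadAttention` module that `TransformerDecoderBlock` expects. Note the mask convention: here `True` means "may attend", which is the opposite of the mask built in `TransformerDecoder`, so pick one convention and use it consistently.

```python
import jax
import jax.numpy as jnp


def causal_self_attention(x, w_qkv, w_out, n_heads):
    """Masked multi-head self-attention (functional sketch, no biases or dropout).

    x:     (batch, seq_len, d_model) input activations
    w_qkv: (d_model, 3 * d_model) fused query/key/value projection
    w_out: (d_model, d_model) output projection
    """
    batch, seq_len, d_model = x.shape
    head_dim = d_model // n_heads

    q, k, v = jnp.split(x @ w_qkv, 3, axis=-1)

    def split_heads(t):
        # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
        return t.reshape(batch, seq_len, n_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    # Scaled dot-product scores: (batch, heads, seq, seq)
    scores = q @ k.transpose(0, 1, 3, 2) / jnp.sqrt(head_dim)

    # Causal mask: True where query i may attend to key j, i.e. j <= i
    allowed = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(allowed, scores, -1e9)
    weights = jax.nn.softmax(scores, axis=-1)

    # Merge heads back: (batch, seq, d_model)
    out = (weights @ v).transpose(0, 2, 1, 3).reshape(batch, seq_len, d_model)
    return out @ w_out, weights


# Usage: changing the last token must not change the outputs at earlier positions.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (2, 5, 8))
w_qkv = 0.1 * jax.random.normal(k2, (8, 24))
w_out = 0.1 * jax.random.normal(k3, (8, 8))

out, weights = causal_self_attention(x, w_qkv, w_out, n_heads=2)
out_changed, _ = causal_self_attention(x.at[:, -1].set(0.0), w_qkv, w_out, n_heads=2)
print(jnp.allclose(out[:, :-1], out_changed[:, :-1]))  # True
```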

## (2) Embedding Layer

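(GPT uses learned positional embeddings, added to the token embeddings, rather than the fixed sinusoidal encodings of the original Transformer)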
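In the paper the input to the first decoder block is $h_0 = U W_e + W_p$, where $U$ holds the one-hot context tokens, $W_e$ is the token embedding matrix and $W_p$ the learned position embedding matrix. The output distribution reuses the token embeddings, $P(u) = \mathrm{softmax}(h_n W_e^T)$, which is what `self.token_embedding.attend` computes in `TransformerDecoder` later in this notebook. The sketch below (hypothetical function names, explicit arrays instead of Flax modules) checks that an embedding lookup is the same as the one-hot product.

```python
import jax
import jax.numpy as jnp


def gpt_embed(tokens, w_e, w_p):
    """h_0 = U W_e + W_p for a batch of token ids.

    tokens: (batch, seq_len) int token ids
    w_e:    (vocab_size, embed_size) token embedding matrix
    w_p:    (max_len, embed_size) learned position embedding matrix
    """
    seq_len = tokens.shape[-1]
    tok = w_e[tokens]               # row lookup, equivalent to one_hot(tokens) @ w_e
    pos = w_p[jnp.arange(seq_len)]  # (seq_len, embed_size), broadcast over the batch
    return tok + pos


def lm_logits(h, w_e):
    # Output layer tied to the input embedding: logits = h_n W_e^T
    return h @ w_e.T


vocab_size, max_len, embed_size = 11, 6, 4
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
w_e = jax.random.normal(k1, (vocab_size, embed_size))
w_p = jax.random.normal(k2, (max_len, embed_size))
tokens = jnp.array([[3, 1, 4], [1, 5, 9]])

h0 = gpt_embed(tokens, w_e, w_p)
print(h0.shape, lm_logits(h0, w_e).shape)  # (2, 3, 4) (2, 3, 11)
```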

## (3) Decoder Block

In [None]:
class TransformerDecoderBlock(nn.Module):
    """A decoder block in the style of Attention Is All You Need (https://arxiv.org/pdf/1706.03762.pdf),
    without the encoder-decoder attention sub-layer.
    :param inputs: Tensor of decoder inputs with shape [batch_size, decoding_sequence_length, channels]
    :param self_attention_mask: mask over future positions, passed through to MultiHeadAttention
    :return: output: Tensor with the same shape as inputs
    """
    input_size : int
    n_heads : int
    filter_size : int
    hidden_size : int
    dropout : float

    def setup(self):
        # nn.LayerNorm infers the feature size from its input; its first positional argument is epsilon
        self.norm_1 = nn.LayerNorm()
        self.attention = MultiHeadAttention(self.n_heads, self.input_size)
        self.norm_2 = nn.LayerNorm()
        self.feed_forward = TransformerFeedForward(self.input_size, self.filter_size, self.hidden_size, self.dropout)

    def __call__(self, inputs, self_attention_mask=None):
        # Pre-norm residual layout (as in GPT-2); the original GPT applies LayerNorm after each residual addition
        # Sub-layer 1: masked self-attention
        attention = self.attention(self.norm_1(inputs), mask=self_attention_mask)
        res_attention = attention + inputs
        # Sub-layer 2: position-wise feed-forward network
        output = res_attention + self.feed_forward(self.norm_2(res_attention))
        return output

## (4) Putting it all together: Transformer Decoder and GPT

We have implemented the TransformerFeedForward class for you. 



In [None]:
# transformer decoder
class TransformerDecoder(nn.Module):
    """
        Decoder-only Transformer (GPT). Embeds tokens and positions into embed_size dimensions, applies a stack of
        n_layers TransformerDecoderBlocks with causally masked self-attention (there is no encoder, so there is no
        attention on a source sequence), and projects the result back onto the vocabulary.
    """
    embed_size : int
    vocab_size : int
    n_layers : int = 6
    n_heads : int = 8
    d_model : int = 512  # also used as the maximum context length (size of the position table and the mask)
    d_filter : int = 2048
    dropout : float = 0.1

    def setup(self):
        self.token_embedding = nn.Embed(self.vocab_size, self.embed_size)
        # Learned position embeddings, one row per position up to d_model
        self.pos_embedding = nn.Embed(self.d_model, self.embed_size)

        # Task-specific output layer; only called (and therefore only initialized) when fine_tune=True
        self.output_layer = nn.Dense(self.vocab_size, use_bias=False)

        # Flax registers a list of submodules assigned in setup as decoding_stack_0, decoding_stack_1, ...
        self.decoding_stack = [
            TransformerDecoderBlock(self.embed_size, self.n_heads, self.d_filter, self.d_model, self.dropout)
            for _ in range(self.n_layers)
        ]
        # Lower-triangular ones: entry (i, j) is 1 when position i may attend to position j <= i
        self.attention_mask = jnp.reshape(jnp.tril(jnp.ones((self.d_model, self.d_model))), (1, 1, self.d_model, self.d_model))
        self.norm = nn.LayerNorm()
        self.drop = nn.Dropout(self.dropout)

    # self_attention_mask is an upper-triangular mask that prevents attending to future positions
    # (no padding mask is applied here)
    def __call__(self, inputs, fine_tune=False, train=True):
        """
            Args:
                inputs: int32 Tensor of token ids with shape [batch_size, sequence_length]
                fine_tune: if True, also apply the task-specific output layer
                train: if True, apply dropout (requires a 'dropout' PRNG key)
            Returns:
                a tuple of (embedding_output, output)
                    embedding_output: language-modeling logits with shape [batch_size, sequence_length, vocab_size],
                        computed with the transposed token embedding matrix (tied weights)
                    output: output_layer(decoder_output) if fine_tune, else None
        """
        seq_len = inputs.shape[-1]

        pos = jnp.expand_dims(jnp.arange(seq_len), 0)

        tok_embed = self.token_embedding(inputs)  # (batch_size, sequence_length, embed_size)
        pos_embed = self.pos_embedding(pos)       # (1, sequence_length, embed_size)

        decoder_output = self.drop(tok_embed + pos_embed, deterministic=not train)

        # True marks future positions that must not be attended to
        self_attention_mask = (self.attention_mask[:, :, :seq_len, :seq_len] == 0)

        for decoder in self.decoding_stack:
            decoder_output = decoder(decoder_output, self_attention_mask=self_attention_mask)

        decoder_output = self.norm(decoder_output)

        embedding_output = self.token_embedding.attend(decoder_output)
        output = None
        if fine_tune:
            output = self.output_layer(decoder_output)

        return embedding_output, output
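`TransformerDecoder` builds its causal mask by comparing a lower-triangular matrix of ones to zero, so the resulting boolean mask is `True` at the positions that must be hidden. The sketch below reproduces that construction on a small size (`max_len = 8` is an arbitrary stand-in for `d_model`) to show what `MultiHeadAttention` receives. Flax's built-in attention layers use the opposite convention, where a mask value of `True` means the position may be attended to, so a mask built this way would have to be inverted before being passed to them.

```python
import jax.numpy as jnp

max_len, seq_len = 8, 4  # max_len plays the role of d_model in TransformerDecoder

# Same construction as TransformerDecoder.setup / __call__
attention_mask = jnp.reshape(jnp.tril(jnp.ones((max_len, max_len))), (1, 1, max_len, max_len))
self_attention_mask = attention_mask[:, :, :seq_len, :seq_len] == 0

# True (1) marks keys in the future of the query, which must be hidden
print(self_attention_mask[0, 0].astype(int))
# [[0 1 1 1]
#  [0 0 1 1]
#  [0 0 0 1]
#  [0 0 0 0]]

# Convention used by Flax attention layers: True means "may attend"
allowed = ~self_attention_mask
```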


In [None]:
# gpt block

In [None]:
# pretrain OR import pretrained weights

## (5) Task-specific Head

In [None]:
# import a test task
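In the paper the task head is a single linear layer $W_y$ applied to the final-layer activation of the last (extract) token, $P(y \mid x^1, \ldots, x^m) = \mathrm{softmax}(h_l^m W_y)$, and fine-tuning optimizes $L_3 = L_2 + \lambda L_1$, where $L_1$ is the language-modeling objective and $\lambda = 0.5$. The sketch below uses hypothetical names and plain arrays. It writes the losses as averaged negative log-likelihoods rather than the summed log-likelihoods of the paper, and it assumes all sequences have the same length (no padding mask).

```python
import jax
import jax.numpy as jnp


def classify_logits(hidden, extract_positions, w_y):
    """Linear head on the final-layer state of each sequence's extract token.

    hidden:            (batch, seq_len, d_model) output of the last decoder block
    extract_positions: (batch,) index of the extract token in each sequence
    w_y:               (d_model, n_classes) task-specific weights
    """
    h_extract = hidden[jnp.arange(hidden.shape[0]), extract_positions]  # (batch, d_model)
    return h_extract @ w_y


def mean_nll(logits, targets):
    # Mean negative log-likelihood of integer targets under softmax(logits)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.take_along_axis(log_probs, targets[..., None], axis=-1))


def fine_tune_loss(class_logits, labels, lm_out, tokens, lam=0.5):
    """Task loss plus lam times the auxiliary next-token language-modeling loss."""
    task_loss = mean_nll(class_logits, labels)        # corresponds to -L2
    lm_loss = mean_nll(lm_out[:, :-1], tokens[:, 1:])  # corresponds to -L1
    return task_loss + lam * lm_loss


# Uniform predictions over 3 classes and a 5-token vocabulary
loss = fine_tune_loss(jnp.zeros((2, 3)), jnp.array([0, 2]),
                      jnp.zeros((2, 4, 5)), jnp.array([[1, 2, 3, 4], [0, 1, 2, 3]]))
print(loss)  # equals log(3) + 0.5 * log(5)
```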


# Experiments

In this section you will (train) and evaluate models with different pre-training strategies. (Note: if necessary, we could reduce the number of parameters for this part.)

These models are:

1. No unsupervised pretraining, only fine-tuning
2. Pretraining on the same dataset as the fine-tuning task
3. Pretraining on a dataset that combines data from several tasks
4. Pretraining on an unrelated dataset. This pretrained model is provided.

Before starting, consider (1) how well you expect each model to perform on its related fine-tuning task, and (2) how well each model will generalize to other tasks.

In [None]:
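import numpy as np

def build_pretrain_batch(dataset, seq_length, batch_size):
    """Sample batch_size random windows of seq_length consecutive tokens from a 1-D token sequence."""
    # Choose start indices so that every window lies entirely inside the dataset
    indices = np.random.randint(0, len(dataset) - seq_length + 1, size=batch_size)

    batch_input = np.stack([np.asarray(dataset[i:i + seq_length]) for i in indices])

    return batch_input  # shape (batch_size, seq_length)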

In [None]:
CHECKPOINT_PATH = "checkpoints/"

import os

from flax.metrics.tensorboard import SummaryWriter  # TensorBoard logging; requires tensorflow


# Generic training loop: task-specific subclasses implement batch_to_input and get_loss_function
class TrainerModule:

    def __init__(self, model_name, exmp_batch, max_iters, lr=1e-3, warmup=100, seed=42, **model_kwargs):
        """
        Inputs:
            model_name - Name of the model. Used for saving and checkpointing
            exmp_batch - Example batch to the model for initialization
            max_iters - Maximum number of iterations the model is trained for. This is needed for the warmup + cosine decay learning rate schedule
            lr - Peak learning rate of the optimizer
            warmup - Number of warmup steps. Usually between 50 and 500
            seed - Seed to use for model init
            model_kwargs - Keyword arguments passed on to TransformerDecoder
        """
        self.model_name = model_name
        self.max_iters = max_iters
        self.lr = lr
        self.warmup = warmup
        self.seed = seed
        # Create empty model. Note: no parameters yet
        self.model = TransformerDecoder(**model_kwargs)
        # Prepare logging
        self.log_dir = os.path.join(CHECKPOINT_PATH, self.model_name)
        self.logger = SummaryWriter(self.log_dir)
        # Create jitted training and eval functions
        self.create_functions()
        # Initialize model
        self.init_model(exmp_batch)

    def batch_to_input(self, exmp_batch):
        # Map batch to input data to the model
        # To be implemented in a task specific sub-class
        raise NotImplementedError

    def get_loss_function(self):
        # Return a function that calculates the loss for a batch
        # To be implemented in a task specific sub-class
        raise NotImplementedError

    def create_functions(self):
        # Create jitted train and eval functions
        calculate_loss = self.get_loss_function()

        # Training function
        def train_step(state, rng, batch):
            loss_fn = lambda params: calculate_loss(params, rng, batch, train=True)
            (loss, (acc, rng)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
            state = state.apply_gradients(grads=grads)
            return state, rng, loss, acc
        self.train_step = jax.jit(train_step)

        # Evaluation function
        def eval_step(state, rng, batch):
            _, (acc, rng) = calculate_loss(state.params, rng, batch, train=False)
            return acc, rng
        self.eval_step = jax.jit(eval_step)

    def init_model(self, exmp_batch):
        # Initialize model
        self.rng = jax.random.PRNGKey(self.seed)
        self.rng, init_rng, dropout_init_rng = jax.random.split(self.rng, 3)
        exmp_input = self.batch_to_input(exmp_batch)
        params = self.model.init({'params': init_rng, 'dropout': dropout_init_rng}, exmp_input, train=True)['params']
        # Initialize learning rate schedule and optimizer
        lr_schedule = optax.warmup_cosine_decay_schedule(
            init_value=0.0,
            peak_value=self.lr,
            warmup_steps=self.warmup,
            decay_steps=self.max_iters,
            end_value=0.0
        )
        optimizer = optax.chain(
            optax.clip_by_global_norm(1.0),  # Clip gradients at norm 1
            optax.adam(lr_schedule)
        )
        # Initialize training state
        self.state = train_state.TrainState.create(apply_fn=self.model.apply, params=params, tx=optimizer)

    def train_model(self, train_loader, val_loader, num_epochs=500):
        # Train model for defined number of epochs, evaluating every 5 epochs
        best_acc = 0.0
        for epoch_idx in tqdm(range(1, num_epochs+1)):
            self.train_epoch(train_loader, epoch=epoch_idx)
            if epoch_idx % 5 == 0:
                eval_acc = self.eval_model(val_loader)
                self.logger.scalar('val/accuracy', eval_acc, step=epoch_idx)
                if eval_acc >= best_acc:
                    best_acc = eval_acc
                    self.save_model(step=epoch_idx)
                self.logger.flush()

    def train_epoch(self, train_loader, epoch):
        # Train model for one epoch, and log avg loss and accuracy
        accs, losses = [], []
        for batch in tqdm(train_loader, desc='Training', leave=False):
            self.state, self.rng, loss, accuracy = self.train_step(self.state, self.rng, batch)
            losses.append(loss)
            accs.append(accuracy)
        avg_loss = np.stack(jax.device_get(losses)).mean()
        avg_acc = np.stack(jax.device_get(accs)).mean()
        self.logger.scalar('train/loss', avg_loss, step=epoch)
        self.logger.scalar('train/accuracy', avg_acc, step=epoch)

    def eval_model(self, data_loader):
        # Test model on all data points of a data loader and return avg accuracy
        correct_class, count = 0, 0
        for batch in data_loader:
            acc, self.rng = self.eval_step(self.state, self.rng, batch)
            # Weight each batch's accuracy by its number of examples
            batch_size = self.batch_to_input(batch).shape[0]
            correct_class += acc * batch_size
            count += batch_size
        eval_acc = (correct_class / count).item()
        return eval_acc

    def save_model(self, step=0):
        # Save current model at certain training iteration
        checkpoints.save_checkpoint(ckpt_dir=self.log_dir, target=self.state.params, step=step)

    def load_model(self, pretrained=False):
        # Load model. We use a different checkpoint for the pretrained model
        if not pretrained:
            params = checkpoints.restore_checkpoint(ckpt_dir=self.log_dir, target=self.state.params)
        else:
            params = checkpoints.restore_checkpoint(ckpt_dir=os.path.join(CHECKPOINT_PATH, f'{self.model_name}.ckpt'), target=self.state.params)
        self.state = train_state.TrainState.create(apply_fn=self.model.apply, params=params, tx=self.state.tx)

    def checkpoint_exists(self):
        # Check whether a pretrained model exists for this Transformer
        return os.path.isfile(os.path.join(CHECKPOINT_PATH, f'{self.model_name}.ckpt'))
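`init_model` uses `optax.warmup_cosine_decay_schedule`. As we read the optax documentation, the learning rate rises linearly from `init_value` to `peak_value` over `warmup_steps`, then follows a cosine curve down to `end_value`, with `decay_steps` counting the whole schedule including warmup. The plain-Python sketch below (hypothetical name `warmup_cosine`, assuming `decay_steps > warmup_steps`) is only meant to make those numbers concrete; if in doubt, compare it against `lr_schedule(step)` from optax itself.

```python
import math


def warmup_cosine(step, peak_value, warmup_steps, decay_steps, init_value=0.0, end_value=0.0):
    """Linear warmup to peak_value, then cosine decay to end_value.

    decay_steps is the total length of the schedule, warmup included.
    """
    if step < warmup_steps:
        # Linear ramp from init_value towards peak_value
        return init_value + (peak_value - init_value) * step / warmup_steps
    # Fraction of the decay phase completed, clipped at 1 after the schedule ends
    progress = min(step - warmup_steps, decay_steps - warmup_steps) / (decay_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_value + (peak_value - end_value) * cosine


# Values for the settings used by PreTrainer below (lr=1e-3, warmup=100, max_iters=1000)
for step in (0, 50, 100, 550, 1000):
    print(step, warmup_cosine(step, 1e-3, 100, 1000))
```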

In [None]:
class PreTrainer(TrainerModule):
    def batch_to_input(self, exmp_batch):
        return exmp_batch['input']

    def get_loss_function(self):
        def calculate_loss(params, rng, batch, train):
            rng, dropout_apply_rng = random.split(rng)
            # The model returns (language-modeling logits, fine-tuning output); only the first is needed here
            logits, _ = self.model.apply({'params': params}, batch['input'],
                                         train=train,
                                         rngs={'dropout': dropout_apply_rng})
            # Next-token prediction: the output at position t is scored against the input token at t + 1
            logits = logits[:, :-1, :]
            labels = batch['input'][:, 1:]
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
            acc = (logits.argmax(axis=-1) == labels).astype(jnp.float32).mean()
            return loss, (acc, rng)
        return calculate_loss
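The pre-training objective in the paper is $L_1(\mathcal{U}) = \sum_i \log P(u_i \mid u_{i-k}, \ldots, u_{i-1}; \Theta)$: every token is predicted from the tokens before it. Because the decoder's output at position $t$ is a prediction for token $t+1$, the logits and labels are offset by one. The sketch below (hypothetical name `next_token_loss`, `jax` only) computes the averaged negative of that objective; per position it is the same quantity that `optax.softmax_cross_entropy_with_integer_labels` returns in `PreTrainer`.

```python
import jax
import jax.numpy as jnp


def next_token_loss(logits, tokens):
    """Mean negative log-likelihood of each token given the tokens before it.

    logits: (batch, seq_len, vocab_size) model output for every position
    tokens: (batch, seq_len) int token ids that were fed to the model
    """
    # Drop the last prediction (it has no target) and the first token (nothing predicts it)
    pred = logits[:, :-1, :]
    target = tokens[:, 1:]
    log_probs = jax.nn.log_softmax(pred, axis=-1)
    nll = -jnp.take_along_axis(log_probs, target[..., None], axis=-1)[..., 0]
    accuracy = jnp.mean(jnp.argmax(pred, axis=-1) == target)
    return jnp.mean(nll), accuracy


tokens = jnp.array([[1, 2, 3, 4]])
# Uniform logits over a 5-token vocabulary: the loss is log(5)
uniform_loss, _ = next_token_loss(jnp.zeros((1, 4, 5)), tokens)
# Logits with a large score on the true next token at every position
next_tokens = jnp.array([[2, 3, 4, 0]])  # the last entry is never scored
perfect_loss, perfect_acc = next_token_loss(100.0 * jax.nn.one_hot(next_tokens, 5), tokens)
print(uniform_loss, perfect_loss, perfect_acc)
```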

In [None]:
model_args = {'embed_size': 512, 'vocab_size': 10000, 'n_layers': 6, 'n_heads': 16, 'd_model': 512, 'd_filter': 2048, 'dropout': .1}
# The example batch only fixes the input shape for initialization: [batch_size, sequence_length]
trainer = PreTrainer(model_name='PreTrain',
                     exmp_batch={'input': jnp.array([[1, 0, 1, 2]])},
                     max_iters=1000, **model_args)

## Experiment 1: The value of pretraining

In this section we will fine-tune a randomly initialized GPT model on (task 1). We will also fine-tune the pre-trained model on the same task. 

Compare the results. (Which model has better performance? Which converges faster?)

In [None]:
# initialize a blank GPT model

# fine-tune on task 1

# fine-tune pretrained model on task 1

# graph results

Q: 

## Experiment 2: Pretraining on related datasets

In this section we will remove the labels from the (task 1) dataset, and use it to pretrain our GPT implementation. We will then fine-tune the model on (task 1) and (task 2), and evaluate the respective models. 
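One way to remove the labels is to keep only the token sequences and concatenate them into a single stream, which can then be sampled with `build_pretrain_batch`. The sketch below assumes a hypothetical dataset format of `(token_ids, label)` pairs and a hypothetical separator id `eos_id`. Only the training split should be used, so that the held-out (task 1) data used for evaluation is never seen during pretraining.

```python
import numpy as np


def strip_labels(examples, eos_id):
    """Drop labels and concatenate token sequences into one pretraining stream.

    examples: iterable of (token_ids, label) pairs (hypothetical format)
    eos_id:   hypothetical separator id appended after every example
    """
    stream = []
    for token_ids, _label in examples:
        stream.extend(token_ids)
        stream.append(eos_id)
    return np.asarray(stream, dtype=np.int32)


print(strip_labels([([1, 2], 0), ([3], 1)], eos_id=9))  # [1 2 9 3 9]
```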






In [None]:
# construct dataset using a subset of (task 1) labels.

# pretrain a blank GPT model on this dataset OR import the weights directly

# fine-tune on (task 1) 

# fine-tune on (task 2)

# evaluate task 1 on held-out task 1 data

# evaluate task 1 on task 2 data

# fine-tune for both tasks using model 4 as the pretrained model

# graph results

Q: How did the model perform on (task 1)?  

Now we will see how a model pretrained on multiple tasks performs. 

In [None]:
# pre-train using combined dataset of task 1 and 2 (model 3.1)

# pre-train using combined dataset of task 1,2,3 (model 3.2)

# evaluate on task 1 and task 2. 


Q: How did model 3.1 perform on task 1? How about model 3.2? Explain the difference in performance.

Q: 