# Introduction

In this notebook you will implement a JAX version of GPT from [this paper](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf). Please read it to better understand the model. In particular, pay attention to how a pre-trained model is applied to fine-tuning and few-shot learning.

Afterwards, the notebook will walk you through several experiments using your pre-trained model.

In [1]:
# basic explanation of the model

In [2]:
# jax explanation

# Setup

In [3]:
!pip install flax
!pip install optax
!pip install tensorflow

Collecting flax
  Using cached flax-0.6.2-py3-none-any.whl (189 kB)
Collecting jax>=0.3.16
  Using cached jax-0.3.25-py3-none-any.whl
Collecting rich>=11.1
  Using cached rich-12.6.0-py3-none-any.whl (237 kB)
Collecting tensorstore
  Downloading tensorstore-0.1.28-cp39-cp39-macosx_10_14_x86_64.whl (9.2 MB)
Collecting optax
  Using cached optax-0.1.4-py3-none-any.whl (154 kB)
Collecting commonmark<0.10.0,>=0.9.0
  Using cached commonmark-0.9.1-py2.py3-none-any.whl (51 kB)
Collecting chex>=0.1.5
  Using cached chex-0.1.5-py3-none-any.whl (85 kB)
Collecting jaxlib>=0.1.37
  Downloading jaxlib-0.3.25-cp39-cp39-macosx_10_14_x86_64.whl (66.2 MB)
Collecting dm-tree>=0.1.5
  Downloading dm_tree-0.1.7-cp39-cp39-macosx



In [4]:
!pip3 install tensorflow



In [5]:
# imports

import jax
import jax.numpy as jnp
from jax import random

import flax
from flax import linen as nn
# from flax.training import train_state, checkpoints

import optax

RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source.

# Helper Functions

These are functions you may find helpful in your implementation.

In [None]:
class TransformerGELU(nn.Module):
    """
    Applies the GELU activation element-wise.
    """
    # Flax modules are dataclasses: hyperparameters are declared as fields
    # rather than passed to setup() or __init__().
    approximate: bool = False

    def __call__(self, x):
        return nn.gelu(x, approximate=self.approximate)



# Implementation

In this section you will implement x parts of the Flax/JAX GPT model. Specifically: (list what we end up deciding)



You will also be coding task-specific input transformations for fine-tuning.
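As an illustration of what such an input transformation can look like, here is a minimal sketch for an entailment-style task, following the paper's description of concatenating premise and hypothesis with a delimiter between them and wrapping the sequence in start and extract tokens. The special-token ids (`START`, `DELIM`, `EXTRACT`), the function name `transform_entailment`, and the truncation policy are assumptions made for this example; the real ids depend on your vocabulary.

```python
# Hypothetical ids for the special tokens; real values depend on the vocabulary.
START, DELIM, EXTRACT = 0, 1, 2

def transform_entailment(premise, hypothesis, max_len):
    """Build [START] premise [DELIM] hypothesis [EXTRACT], at most max_len tokens.

    premise and hypothesis are lists of integer token ids.
    """
    budget = max_len - 3  # room left after the three special tokens
    if budget < 0:
        raise ValueError("max_len must be at least 3")
    # Naive truncation (an assumption for this sketch): keep as much of the
    # hypothesis as fits, then give the remaining room to the premise.
    hypothesis = hypothesis[:budget]
    premise = premise[:budget - len(hypothesis)]
    return [START] + premise + [DELIM] + hypothesis + [EXTRACT]

tokens = transform_entailment([10, 11, 12], [20, 21], max_len=16)
```

The model's output at the final (`EXTRACT`) position is what a task-specific head would typically consume.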


## (1) Implementing Attention and Multi-Headed Attention

(Description of how GPT attention might differ from non-gpt attention)
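One difference worth keeping in mind: GPT is a decoder-only model, so its self-attention is masked so that each position can attend only to itself and earlier positions. The following is a minimal single-head sketch of that idea in plain JAX, not the reference solution for this notebook; the function name `causal_self_attention` and the unbatched `(seq_len, d_head)` shapes are assumptions chosen for brevity.

```python
import jax
import jax.numpy as jnp

def causal_self_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask.

    q, k, v: arrays of shape (seq_len, d_head).
    Returns the attended values and the attention weights.
    """
    seq_len, d_head = q.shape
    scores = q @ k.T / jnp.sqrt(d_head)  # (seq_len, seq_len)
    # Lower-triangular mask: position i may attend only to positions j <= i.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v, weights

key = jax.random.PRNGKey(0)
q, k, v = jax.random.normal(key, (3, 5, 8))  # three arrays of shape (5, 8)
out, weights = causal_self_attention(q, k, v)
```

Multi-headed attention applies the same computation to several learned projections of the input in parallel and concatenates the results.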

In [None]:
# copy paste implementation here

## (2) Embedding Layer

(GPT uses learned positional embeddings rather than the fixed sinusoidal encodings of the original Transformer)
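Note that the paper reports using learned position embeddings in place of the sinusoidal encoding from the original Transformer. A minimal sketch of that idea with plain JAX arrays follows; the table names, the sizes, and the 0.02 initialization scale are illustrative assumptions, and in your Flax implementation both tables would be module parameters.

```python
import jax
import jax.numpy as jnp

vocab_size, max_len, d_model = 100, 16, 32  # illustrative sizes
k_tok, k_pos = jax.random.split(jax.random.PRNGKey(0))

# Both tables are trainable; the 0.02 scale is an assumption for this sketch.
token_table = 0.02 * jax.random.normal(k_tok, (vocab_size, d_model))
position_table = 0.02 * jax.random.normal(k_pos, (max_len, d_model))

def embed(tokens):
    """tokens: int array (batch, seq_len) -> embeddings (batch, seq_len, d_model)."""
    positions = jnp.arange(tokens.shape[-1])
    # Look up a row per token id and add the row for its position
    # (the position rows broadcast across the batch dimension).
    return token_table[tokens] + position_table[positions]

tokens = jnp.array([[5, 7, 9], [1, 2, 3]])
x = embed(tokens)
```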

## (3) Decoder Block

## (4) Putting it all together: Transformer Decoder Block and GPT

We have implemented the TransformerFeedForward class for you. 



In [None]:
# transformer decoder block

In [None]:
# gpt block

In [None]:
# pretrain OR import pretrained weights

## (5) Task-specific Head

In [None]:
# import a test task


# Experiments

In this section you will (train) and evaluate models with different pre-training strategies. (Note: if necessary, we could reduce the number of parameters for this part)

These models are:
(1) No unsupervised pretraining, only fine-tuning
(2) Pretraining on the same dataset as the fine-tuning task
(3) Pretraining on a dataset that combines data from several tasks
(4) Pretraining on an unrelated dataset. This pretrained model is provided.

Before starting, consider (1) how you expect these models to perform on their related fine-tuning task, and (2) how well you expect them to generalize to other tasks.

In [None]:
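import numpy as np

def build_pretrain_batch(dataset, seq_length, batch_size):
    # Sample start offsets so that every window of seq_length tokens fits in the dataset.
    indices = np.random.randint(0, len(dataset) - seq_length + 1, size=batch_size)

    # Each example is a contiguous window of seq_length tokens.
    batch_input = [dataset[i:i+seq_length] for i in indices]

    return batch_input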

In [None]:
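# import default gpt model
from transformer import TransformerDecoder

class TransformerPreTrainer:
    """Holds the model, its parameters, and the optimizer state.

    Flax modules are stateless, so the trainer is a plain Python class rather than an nn.Module.
    """
    def __init__(self, vocab_size, d_model, input_length, output_length, n_layers, d_filter, dropout=0, learning_rate=1e-3, seed=0):
        self.model = TransformerDecoder(vocab_size=vocab_size, d_model=d_model, n_layers=n_layers, d_filter=d_filter)

        # Assumption: the model is called as model(input=tokens) with integer tokens of shape
        # (batch, seq_length) and returns logits of shape (batch, seq_length, vocab_size).
        dummy_input = jnp.zeros((1, input_length), dtype=jnp.int32)
        self.params = self.model.init(random.PRNGKey(seed), input=dummy_input)

        self.learning_rate = learning_rate
        self.optimizer = optax.adam(learning_rate)
        self.opt_state = self.optimizer.init(self.params)

    def loss_fn(self, params, batch):
        # Language-modeling loss: the logits at position t are scored against token t+1.
        pred_logits = self.model.apply(params, **batch)
        losses = optax.softmax_cross_entropy_with_integer_labels(pred_logits[:, :-1], batch['input'][:, 1:])
        return losses.mean()

    def forward(self, batch, optimize=True):
        if not optimize:
            return self.loss_fn(self.params, batch)

        # JAX computes gradients functionally; there is no zero_grad()/backward()/step().
        loss, grads = jax.value_and_grad(self.loss_fn)(self.params, batch)
        updates, self.opt_state = self.optimizer.update(grads, self.opt_state, self.params)
        self.params = optax.apply_updates(self.params, updates)

        return loss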
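To make the shifted language-modeling loss concrete, here is a small self-contained sketch in plain JAX. It assumes batch-first shapes, `(batch, seq_len, vocab)` for logits and `(batch, seq_len)` for tokens; the function name `next_token_loss` is hypothetical. The per-position quantity it averages is the same one `optax.softmax_cross_entropy_with_integer_labels(logits, labels)` returns.

```python
import jax
import jax.numpy as jnp

def next_token_loss(logits, tokens):
    """Mean cross-entropy for predicting token t+1 from position t.

    logits: (batch, seq_len, vocab) float array.
    tokens: (batch, seq_len) int array.
    """
    # The last position has no next token to predict, so drop it.
    log_probs = jax.nn.log_softmax(logits[:, :-1], axis=-1)
    # The first token is never a prediction target, so drop it.
    targets = tokens[:, 1:]
    picked = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)
    return -picked.mean()

tokens = jnp.array([[3, 1, 4, 1, 5]])
uniform_logits = jnp.zeros((1, 5, 10))
# With uniform logits over 10 classes, every target has probability 1/10,
# so the loss equals log(10).
loss = next_token_loss(uniform_logits, tokens)
```

Slicing along the sequence axis (axis 1) rather than the batch axis is the detail that is easiest to get wrong when inputs are batched.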

## Experiment 1: The value of pretraining

In this section we will fine-tune a randomly initialized GPT model on (task 1). We will also fine-tune the pre-trained model on the same task. 

Compare the results. (Which model has better performance? Which converges faster?)

In [None]:
# initialize a blank GPT model

# fine-tune on task 1

# fine-tune pretrained model on task 1

# graph results

Q: 

## Experiment 2: Pretraining on related datasets

In this section we will remove the labels from the (task 1) dataset, and use it to pretrain our GPT implementation. We will then fine-tune the model on (task 1) and (task 2), and evaluate the respective models. 






In [None]:
# construct dataset using a subset of (task 1) labels.

# pretrain a blank GPT model on this dataset OR import the weights directly

# fine-tune on (task 1) 

# fine-tune on (task 2)

# evaluate task 1 on held-out task 1 data

# evaluate task 1 on task 2 data

# fine-tune for both tasks using model 4 as the pretrained model

# graph results

Q: How did the model perform on (task 1)?  

Now we will see how a model pretrained on multiple tasks performs. 

In [None]:
# pre-train using combined dataset of task 1 and 2 (model 3.1)

# pre-train using combined dataset of task 1,2,3 (model 3.2)

# evaluate on task 1 and task 2. 


Q: How did model 3.1 perform on task 1? How about model 3.2? Explain the difference in performance.

Q: 