# **Initial tuning**

Initial tuning of the architectural parameters (hidden size, number of layers and number of attention heads) to obtain a basic model to start from and use in future experiments. These parameter values may change before the final training.

### **Import modules and data**

In [None]:
# Import all modules and functions
import scripts.functions as functions
from scripts.functions import *

In [None]:
# load the train and test data sets
with open("text8_train.txt", "r") as f:
    train_text = f.read()
with open("text8_test.txt", "r") as f:
    test_text = f.read()

Build the vocabulary (26 lowercase letters + space)

In [None]:
# Define a character list with length 27
char_set = list("abcdefghijklmnopqrstuvwxyz ")
# Encode each letter to an integer and vice versa
char_to_int = {ch:i for i,ch in enumerate(char_set)}
int_to_char = {i:ch for ch,i in char_to_int.items()}

def encode(s):
    ids = [char_to_int[c] for c in s]
    return np.array(ids, dtype=np.uint8)  # use np.uint8 to save space
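The notebook defines `int_to_char` but no decoder. The sketch below adds a hypothetical `decode` helper that inverts `encode` with the same 27-character vocabulary, so the encoding can be checked with a round trip. It also shows that `encode` raises `KeyError` on any character outside the vocabulary, which is fine for text8 because that corpus contains only lowercase letters and spaces.

```python
import numpy as np

# Same 27-character vocabulary as above: 26 lowercase letters + space
char_set = list("abcdefghijklmnopqrstuvwxyz ")
char_to_int = {ch: i for i, ch in enumerate(char_set)}
int_to_char = {i: ch for ch, i in char_to_int.items()}

def encode(s):
    # Map each character to its integer id (uint8 is enough for 27 symbols)
    return np.array([char_to_int[c] for c in s], dtype=np.uint8)

def decode(ids):
    # Hypothetical inverse of encode: map integer ids back to characters
    return "".join(int_to_char[int(i)] for i in ids)

sample = "anarchism originated as a term"
print(decode(encode(sample)) == sample)

# Characters outside the vocabulary (uppercase, punctuation) are rejected
try:
    encode("Hello!")
except KeyError as err:
    print("unknown character:", err)
```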

In [None]:
# Encode the text into integers
train_text_int = encode(train_text)
test_text_int = encode(test_text)

# Save the data
np.save("train_text_int.npy", train_text_int)
np.save("test_text_int.npy", test_text_int)

# **Create basic Transformer model**

Initialize the basic model architecture and parameters.
In Flax, model parameters are stored as a PyTree: a nested structure of dictionaries (or lists/tuples) whose leaves are arrays.

**Disclaimer**: Adjusting the model architecture can change the number of model parameters.
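Because the parameters live in a PyTree, the total parameter count is the sum of the sizes of its leaves; with JAX this is typically written as `sum(x.size for x in jax.tree_util.tree_leaves(params))`. The sketch below reproduces the idea with plain NumPy and a hypothetical hand-built parameter tree (the layer names and shapes are illustrative, not taken from `basic_transformer`), so it runs without JAX.

```python
import numpy as np

def tree_leaves(tree):
    # Minimal stand-in for jax.tree_util.tree_leaves on nested dicts/lists/tuples of arrays
    if isinstance(tree, dict):
        return [leaf for sub in tree.values() for leaf in tree_leaves(sub)]
    if isinstance(tree, (list, tuple)):
        return [leaf for sub in tree for leaf in tree_leaves(sub)]
    return [tree]  # anything else is treated as a leaf (an array)

def count_params(params):
    # Total number of scalar parameters across all leaves of the tree
    return sum(int(np.size(leaf)) for leaf in tree_leaves(params))

# Hypothetical parameter tree for a tiny model (names and shapes are illustrative only)
d_model, vocab_size = 256, 27
params = {
    "embed": {"embedding": np.zeros((vocab_size, d_model))},
    "block_0": {
        "attn": {"kernel": np.zeros((d_model, 3 * d_model)), "bias": np.zeros(3 * d_model)},
        "ln": {"scale": np.ones(d_model), "bias": np.zeros(d_model)},
    },
}
print(count_params(params))
```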

In [None]:
# Import model
import models.basic_transformer as basic_transformer

In [None]:
# Function to initialize basic Transformer model and its params
def create_train_state(key, vocab_size, d_model, n_layers, n_heads, max_len):
    model = basic_transformer.DecoderOnlyTransformer(vocab_size, d_model, n_layers, n_heads, max_len)

    # Create dummy input for initialization of batch size 1, seq length min(16, max_len)
    dummy = jnp.zeros((1, min(16, max_len)), dtype=jnp.int32)
    # Initialize the parameters and extract the PyTree of params
    params = model.init({"params": key}, dummy)["params"]
    return model, params

In [None]:
# FIXED vocab size
vocab_size=len(char_set)

# maximum sequence length
max_len=128

### **Optimization step**
Performs a single optimization step: computes the gradients of the training loss and uses the optimizer to update the parameters and the optimizer state.
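`optax.adam`, used in the training loops below, implements the Adam update rule. As a rough illustration of what one update does to a single parameter array, here is a plain NumPy sketch of that rule with the commonly used defaults (`b1=0.9`, `b2=0.999`, `eps=1e-8`), run on a toy quadratic loss; the helper `adam_step` and the toy problem are assumptions for illustration, not part of this notebook.

```python
import numpy as np

def adam_step(param, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update for a single array; t is the 1-based step count."""
    m = b1 * m + (1 - b1) * grad            # first-moment (mean) estimate
    v = b2 * v + (1 - b2) * grad ** 2       # second-moment (uncentred variance) estimate
    m_hat = m / (1 - b1 ** t)               # bias corrections for the zero initialisation
    v_hat = v / (1 - b2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

# Toy problem: minimise L(w) = sum((w - target)^2), whose gradient is 2 * (w - target)
target = np.array([1.0, -2.0, 0.5])
w = np.zeros(3)
m, v = np.zeros_like(w), np.zeros_like(w)
for t in range(1, 5001):
    grad = 2 * (w - target)
    w, m, v = adam_step(w, grad, m, v, t, lr=0.01)
print(np.round(w, 3))
```

In the notebook, `tx.update` plays the role of `adam_step` for every leaf of the parameter PyTree at once, with `opt_state` holding the moment estimates and step count.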

In [None]:
# gradient update
def train_step(params, opt_state, x, y, tx):
    """
    Args:
      params: pytree of model parameters.
      opt_state: optax optimizer state corresponding to `params`.
      x: (B_seq, B_tok) int array input tokens.
      y: (B_seq, B_tok) int array target tokens.
      tx: optax.GradientTransformation (already initialized).

    Returns:
      new_params: updated parameters after one gradient step.
      new_opt_state: updated optimizer state.
      metrics: dict of scalar metrics (loss, acc).
    """
    def loss_fn(params):
        # `model` is read from the enclosing (global) scope when train_step is traced
        logits = model.apply({"params": params}, x)
        loss, metrics = loss_and_metrics(logits, y)
        return loss, metrics

    # compute gradients of loss w.r.t params (loss is scalar, metrics is auxiliary)
    (loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)

    # optax update: update params and new optimizer state
    updates, new_opt_state = tx.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, metrics

# jit: mark `tx` as static because it is a Python object (not an array) and cannot be traced
train_step = jax.jit(train_step, static_argnames=("tx",))

In this section, the architectural parameters that will be tuned are:
- Hidden size
- Number of layers
- Number of attention heads

**Objective**: To find the combination of architectural parameters that gives the best model performance, as measured by validation loss, while also taking training time into account.
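Character-level results on text8 are often quoted in bits per character (BPC). Assuming `loss_and_metrics` returns the mean cross-entropy in nats (the usual convention, but not confirmed here since its definition is not shown), converting a validation loss to BPC is a division by ln 2. A small sketch using two of the final validation losses reported later in this notebook:

```python
import math

def nats_to_bpc(loss_nats):
    # Cross-entropy in nats per character -> bits per character
    return loss_nats / math.log(2)

# Final validation losses reported below for the 8-head models with 4 and 8 layers
for name, loss in [("[256, 8, 4]", 1.3307313919067383), ("[256, 8, 8]", 1.3208262920379639)]:
    print(f"{name}: {loss:.4f} nats/char = {nats_to_bpc(loss):.3f} bits/char")
```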

With GPU memory usage in mind, a reduced parameter grid is used:
- The hidden size is fixed at 256. A later experiment, near the end of the project, will look at increasing the hidden size.
- The number of attention heads does not affect the number of model parameters (in the standard layout each head works on a `d_model / n_heads` slice of the hidden size), so it is less restricted here.
- Each additional layer adds parameters, so the number of layers is kept at 8 or fewer.
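The claims about heads and layers can be checked with a back-of-the-envelope parameter count. The sketch below assumes a standard pre-LayerNorm decoder block (Q, K, V and output projections of size `d_model x d_model` with biases, a 4x MLP, two LayerNorms), learned positional embeddings and an untied output layer; these are assumptions, since the exact layout of `basic_transformer.DecoderOnlyTransformer` is not shown, and `transformer_param_count` is a hypothetical helper.

```python
def transformer_param_count(vocab_size, d_model, n_layers, n_heads, max_len, mlp_ratio=4):
    """Parameter count of a hypothetical standard decoder-only Transformer (see assumptions above)."""
    # Each head works on a d_model // n_heads slice, so n_heads only needs to divide d_model;
    # it does not appear anywhere in the count below
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    d_ff = mlp_ratio * d_model
    attn = 4 * d_model * d_model + 4 * d_model               # Q, K, V and output projections + biases
    mlp = d_model * d_ff + d_ff + d_ff * d_model + d_model   # two dense layers + biases
    norms = 2 * (2 * d_model)                                # two LayerNorms (scale + bias each)
    per_layer = attn + mlp + norms
    embeddings = vocab_size * d_model + max_len * d_model    # token + learned positional embeddings
    head = 2 * d_model + d_model * vocab_size + vocab_size   # final LayerNorm + output projection
    return n_layers * per_layer + embeddings + head

for n_heads in (8, 16):
    for n_layers in (4, 8):
        n = transformer_param_count(27, 256, n_layers, n_heads, 128)
        print(f"{n_heads:>2} heads, {n_layers} layers: {n:,} parameters")
```

Under these assumptions, 8 and 16 heads give identical counts, while going from 4 to 8 layers roughly doubles the total (about 3.2M to 6.4M), because the per-layer weights dominate the small 27-symbol embedding and output layers.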

Batch size and number of iterations are reduced for a faster runtime.
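For a sense of scale, assuming `get_batch` draws `B_seq` windows of `B_tok` characters each (its definition lives in `scripts.functions` and is not shown), the arithmetic below uses the settings from the training loops to estimate how much of the 80M-character training split each run processes. If windows are sampled at random they may overlap, so this counts characters processed, not distinct characters.

```python
B_seq, B_tok = 64, 32            # sequences per batch, characters per sequence (as used below)
niter = 20_000                   # iterations for a full run
n_train_chars = 80_000_000       # len(training_text_int)
early_stop_it = 4_000            # iteration at which early stopping fired in the runs below

chars_per_step = B_seq * B_tok                      # characters seen per update
full_run = chars_per_step * niter                   # whole run
early_run = chars_per_step * (early_stop_it + 1)    # iterations 0..4000 inclusive

print(f"{chars_per_step} chars/step")
print(f"full run:  {full_run:,} chars = {full_run / n_train_chars:.3f} of the training split")
print(f"early run: {early_run:,} chars = {early_run / n_train_chars:.3f} of the training split")
```

So a full run covers roughly half of the training split once, and an early-stopped run about a tenth of it.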

**Disclaimer**: Training and tuning Transformers is very resource-intensive, largely because of the number of model parameters. With this in mind, constraints are applied in all future experiments to avoid hitting the GPU usage limit on Colab. Furthermore, this experiment is only meant to find an initial Transformer model for future small-scale experiments; parameter values may or may not be the same in the final training.

In [None]:
param_grid = {
    "d_model": [256],
    "n_heads": [8, 16],
    "n_layers": [4, 8]
}
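`itertools.product`, used in the loops below, expands the grid into every combination and varies the last iterable fastest. With the arguments passed as `d_model, n_heads, n_layers`, the configurations are visited as [256, 8, 4], [256, 8, 8], [256, 16, 4], [256, 16, 8], which is the order in which results are printed later. A small sketch, including a generic variant (the name `configs_named` is hypothetical) that builds named configurations straight from the dict:

```python
from itertools import product

param_grid = {"d_model": [256], "n_heads": [8, 16], "n_layers": [4, 8]}

# Same argument order as the training loops: d_model, n_heads, n_layers
configs = [list(c) for c in product(param_grid["d_model"], param_grid["n_heads"], param_grid["n_layers"])]
print(len(configs))
for c in configs:
    print(c)

# Generic variant: one dict per configuration, keyed by parameter name (dicts keep insertion order)
configs_named = [dict(zip(param_grid, values)) for values in product(*param_grid.values())]
print(configs_named[0])
```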

In [None]:
# For model tuning, separate train_text_int into training and validation sets
print(len(train_text_int))

# 1/9 of the training set will be used for validation
val_text_int = train_text_int[-10_000_000:]
training_text_int = train_text_int[:80_000_000]
print(len(val_text_int))
print(len(training_text_int))

90000000
10000000
80000000
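The cell above hard-codes the split sizes for a 90M-character array. The sketch below computes the same contiguous 8:1 split from the array length, so it still holds if the training file changes; `split_train_val` and its `val_parts` argument are hypothetical names. Taking the validation part from the end keeps it disjoint from the training part.

```python
import numpy as np

def split_train_val(data, val_parts=9):
    """Contiguous split: the last 1/val_parts of `data` for validation, the rest for training."""
    n_val = len(data) // val_parts
    assert n_val > 0, "data too short for this split"
    return data[:-n_val], data[-n_val:]

# Small stand-in for train_text_int (90M characters in the notebook -> 80M / 10M)
data = np.arange(900) % 27
train_part, val_part = split_train_val(data)
print(len(train_part), len(val_part))
```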


### **Early Stopping**

In [None]:
from itertools import product

min_improvement = 1e-3   # minimum decrease in validation loss that counts as an improvement
new_params = None
validation_losses_of_all = []
times_of_all = []
models = []

for d_model, n_heads, n_layers in product(param_grid["d_model"], param_grid["n_heads"], param_grid["n_layers"]):
    new_params = [d_model, n_heads, n_layers]
    models.append(new_params)
    # Initialize the model architecture and params
    model, params = create_train_state(key, vocab_size, d_model, n_layers, n_heads, max_len)
    print(f"Transformer model initialized with a hidden size of {d_model}, {n_layers} layers and {n_heads} heads")

    learning_rate = 0.001
    tx = optax.adam(learning_rate=learning_rate)
    # Initialize the Adam optimizer state for the current params (learning rate 0.001)
    opt_state = tx.init(params)

    niter = 20_000                    # reduced number of iterations for each model
    B_seq, B_tok = 64, 32             # smaller number of sequences per batch, same number of tokens per sequence
    loss_test_history = []

    time_start = time.time()
    patience_counter = 0
    fully_trained = False
    # Keep track of the best_val_loss for each model
    best_val_loss = 1000
    for it in range(niter):
        batch = get_batch(training_text_int, B_seq, B_tok)
        input, target = batch[0], batch[1]
        params_new, opt_state_new, metrics = train_step(params, opt_state, input, target, tx)

        # update model weights and optimizer state
        params = params_new
        opt_state = opt_state_new

        if it == niter - 1:
            fully_trained = True


        # Early stopping: compute the validation loss every 1000 iterations (and at the last iteration)
        if (it % 1000 == 0) or fully_trained:
            new_time = time.time() - time_start
            B_val, T_val = 1024, 32
            val_batch = get_batch(val_text_int, B_val, T_val)
            val_input, val_target = val_batch[0], val_batch[1]
            val_logits = model.apply({"params": params}, val_input)
            val_loss, val_metrics = loss_and_metrics(val_logits, val_target)

            if val_loss + min_improvement < best_val_loss:
                best_val_loss = val_loss
                # Note: the counter is incremented on an *improvement*, so training stops
                # at the 5th validation check that improves by at least min_improvement
                patience_counter += 1

                if patience_counter >= 5:
                    print(f"Early stopping implemented at iteration {it}")
                    # On early stopping, record the elapsed training time and latest validation loss
                    validation_losses_of_all.append(val_loss)
                    times_of_all.append(new_time)
                    break

            # If training reaches the last iteration without stopping early, record total time and final validation loss
            if fully_trained:
                validation_losses_of_all.append(val_loss)
                times_of_all.append(new_time)

# number of parameter combinations in the grid (1 x 2 x 2), there should only be 4
print(len(param_grid["d_model"]) * len(param_grid["n_heads"]) * len(param_grid["n_layers"]))

Transformer model initialized with a hidden size of 256, 4 layers and 8 heads
Early stopping implemented at iteration 4000
Transformer model initialized with a hidden size of 256, 8 layers and 8 heads
Early stopping implemented at iteration 4000
Transformer model initialized with a hidden size of 256, 4 layers and 16 heads
Early stopping implemented at iteration 4000
Transformer model initialized with a hidden size of 256, 8 layers and 16 heads
Early stopping implemented at iteration 4000
4
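A more conventional patience rule resets the counter whenever the validation loss improves and stops only after `patience` consecutive checks without improvement (in the cell above the counter grows on improvement instead, which is why every run stopped at iteration 4000). A minimal sketch of the conventional rule as a hypothetical `EarlyStopping` helper (not part of `scripts.functions`), applied to a made-up sequence of validation losses:

```python
class EarlyStopping:
    """Stop after `patience` consecutive checks without an improvement of at least `min_delta`."""

    def __init__(self, patience=5, min_delta=1e-3):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_checks = 0

    def update(self, val_loss):
        """Record a validation loss; return True if training should stop."""
        if val_loss + self.min_delta < self.best:
            self.best = val_loss        # new best: reset the counter
            self.bad_checks = 0
        else:
            self.bad_checks += 1        # no meaningful improvement at this check
        return self.bad_checks >= self.patience

# Made-up validation losses, one per check (e.g. every 1000 iterations)
losses = [3.3, 2.1, 1.8, 1.7, 1.69, 1.6895, 1.6892, 1.6897, 1.6894, 1.6896, 1.6893]
stopper = EarlyStopping(patience=5, min_delta=1e-3)
for check, loss in enumerate(losses):
    if stopper.update(loss):
        print(f"stop at check {check}, best loss {stopper.best}")
        break
```

In the training loop, `stopper.update(val_loss)` would replace the `patience_counter` bookkeeping, with the `break` taken when it returns `True`.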


### **Full model training**

In [None]:
from itertools import product

new_params = None
full_validation_losses_of_all = []
full_times_of_all = []
models = []

for d_model, n_heads, n_layers in product(param_grid["d_model"], param_grid["n_heads"], param_grid["n_layers"]):
    new_params = [d_model, n_heads, n_layers]
    models.append(new_params)
    # Initialize the model architecture and params
    model, params = create_train_state(key, vocab_size, d_model, n_layers, n_heads, max_len)
    print(f"Transformer model initialized with a hidden size of {d_model}, {n_layers} layers and {n_heads} heads")

    learning_rate = 0.001
    tx = optax.adam(learning_rate=learning_rate)
    # Initialize optimizer state for current params
    opt_state = tx.init(params)

    niter = 20_000                    # reduced number of iterations for each model
    B_seq, B_tok = 64, 32             # smaller number of sequences per batch, same number of tokens per sequence

    time_start = time.time()
    for it in range(niter):
        batch = get_batch(training_text_int, B_seq, B_tok)
        input, target = batch[0], batch[1]
        params_new, opt_state_new, metrics = train_step(params, opt_state, input, target, tx)

        # update model weights and optimizer state
        params = params_new
        opt_state = opt_state_new

        # After the last iteration, record the total training time and the final validation loss
        if it == niter - 1:
            new_time = time.time() - time_start
            B_val, T_val = 1024, 32
            val_batch = get_batch(val_text_int, B_val, T_val)
            val_input, val_target = val_batch[0], val_batch[1]
            val_logits = model.apply({"params": params}, val_input)
            val_loss, val_metrics = loss_and_metrics(val_logits, val_target)
            full_validation_losses_of_all.append(val_loss)
            full_times_of_all.append(new_time)


# number of parameter combinations in the grid (1 x 2 x 2), there should only be 4
print(len(param_grid["d_model"]) * len(param_grid["n_heads"]) * len(param_grid["n_layers"]))

Transformer model initialized with a hidden size of 256, 4 layers and 8 heads
Transformer model initialized with a hidden size of 256, 8 layers and 8 heads
Transformer model initialized with a hidden size of 256, 4 layers and 16 heads
Transformer model initialized with a hidden size of 256, 8 layers and 16 heads
4


In [None]:
# Evaluation of the runs that stopped early at iteration 4000
for i in range(4):
    print(f"This model with parameters {models[i]} has a training time of {times_of_all[i]} and a validation loss of {validation_losses_of_all[i]}")

This model with parameters [256, 8, 4] has a training time of 70.73276329040527 and a validation loss of 1.4581079483032227
This model with parameters [256, 8, 8] has a training time of 137.31607007980347 and a validation loss of 1.4512369632720947
This model with parameters [256, 16, 4] has a training time of 75.04076051712036 and a validation loss of 1.4715590476989746
This model with parameters [256, 16, 8] has a training time of 146.54027605056763 and a validation loss of 1.467713713645935


In [None]:
# Evaluation for full training
for i in range(4):
    print(f"This model with parameters {models[i]} has a total training time of {full_times_of_all[i]} and a final validation loss of {full_validation_losses_of_all[i]}")

This model with parameters [256, 8, 4] has a total training time of 331.1321680545807 and a final validation loss of 1.3307313919067383
This model with parameters [256, 8, 8] has a total training time of 652.1351611614227 and a final validation loss of 1.3208262920379639
This model with parameters [256, 16, 4] has a total training time of 351.33376145362854 and a final validation loss of 1.3433935642242432
This model with parameters [256, 16, 8] has a total training time of 696.6427783966064 and a final validation loss of 1.3371410369873047


### **Evaluation**

**Decision**: Model with a **hidden size of 256, 4 layers and 8 heads** (printed above as `[256, 8, 4]`, i.e. `[d_model, n_heads, n_layers]`).

1. First result using early stopping (to get an initial sense of model performance)
- Models with 8 heads reach a slightly lower validation loss than their 16-head counterparts with the same number of layers, and also train slightly faster.
- Between [256, 8, 4] and [256, 8, 8], doubling the number of layers nearly doubles the training time (70.7 s to 137.3 s) while the validation loss improves only slightly (1.458 to 1.451).

2. Second result (full model training)
- These results reinforce the decision: comparing the two 8-head configurations, going from 4 to 8 layers improves the final validation loss by only a small margin (1.331 to 1.321) while the training time roughly doubles (331 s to 652 s). There are other ways to improve the model besides increasing the number of layers.
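The trade-off behind the decision can be quantified directly from the printed full-training results. The numbers below are copied from the outputs above; `results` and `compare` are just hypothetical containers for this calculation.

```python
# Full-training results copied from the outputs above: config -> (training time in s, final val loss)
results = {
    (256, 8, 4): (331.1321680545807, 1.3307313919067383),
    (256, 8, 8): (652.1351611614227, 1.3208262920379639),
    (256, 16, 4): (351.33376145362854, 1.3433935642242432),
    (256, 16, 8): (696.6427783966064, 1.3371410369873047),
}

def compare(base, other):
    """Return the time ratio and the drop in validation loss going from base to other."""
    (t0, l0), (t1, l1) = results[base], results[other]
    return t1 / t0, l0 - l1

# Doubling the layers (8 heads): time ratio and loss improvement
ratio, gain = compare((256, 8, 4), (256, 8, 8))
print(f"4 -> 8 layers: {ratio:.2f}x time, val loss lower by {gain:.4f} ({gain / results[(256, 8, 4)][1]:.2%})")

# 16 vs 8 heads at fixed depth: a positive gain means 8 heads is better
for n_layers in (4, 8):
    _, gain = compare((256, 16, n_layers), (256, 8, n_layers))
    print(f"{n_layers} layers: 8 heads beats 16 heads by {gain:.4f} in val loss")
```

Going from 4 to 8 layers costs about 1.97x the training time for a validation-loss reduction just under 0.01 (about 0.7%), while 8 heads beat 16 heads at both depths.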