In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

from tqdm import tqdm

import matplotlib.pyplot as plt

from models import GPT
from tokenizers import Tokenizer

from torch.optim import AdamW

import pandas as pd 

from utils import BatchLoader, estimate_loss, train_loop, load, save

# hyperparameters
batch_size = 16  # number of independent sequences that'll be processed in parallel
block_size = 128  # maximum context length for the preds
max_iters = 1000
eval_interval = 200
learning_rate = 3e-4
device = "mps" if torch.backends.mps.is_available() else "cpu"
eval_iters = 200
n_embd = 256
n_head = 4
n_blocks = 4
dropout = 0.2
# --------------

torch.manual_seed(1337)

# data preparation
text = open("dataset/tinyshakespeare.txt", "r").read()
# set up the vocabulary
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
tokenizer = Tokenizer(chars)

data = torch.tensor(tokenizer.encode(text), dtype=torch.long)

n = int(0.9 * len(data))  # first 90% will be the training set
n1 = int(0.95 * len(data))  # 90-95% will be the validation set and the last 5% will be the calibration set for the paper

train_data = data[:n]
val_data = data[n:n1]
calibrate_data = data[n1:]

train_loader = BatchLoader(train_data, block_size, batch_size, device, name="train")
val_loader = BatchLoader(val_data, block_size, batch_size, device, name="val")
calibration_loader = BatchLoader(calibrate_data, block_size, batch_size, device, name="calibrate")


In [2]:
model = GPT(vocab_size, block_size, n_embd, n_head, n_blocks, device, dropout)
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

training_losses = train_loop(model, optimizer, vocab_size, train_loader, [train_loader, val_loader], max_iters, eval_interval, eval_iters)

print("training is done!")

plt.title("training losses")
plt.plot(training_losses)
plt.savefig("training_losses.png")

idx = torch.zeros((1, 1), dtype=torch.long, device=device)
print(tokenizer.decode(model.generate(idx, max_new_tokens=500)[0].tolist()))

model_params = {
    "vocab_size": vocab_size,
    "block_size": block_size,
    "n_embd": n_embd,
    "n_head": n_head,
    "n_blocks": n_blocks,
    "dropout": dropout,
    "device": device
}

save(model, tokenizer, model_params, "model")

In [3]:
def attn_head_importance_hook(module, ins, outs) -> None: # TODO: does the importance calculation returns the correct values for each head? 
    """ calculates the multi-head-attention layer's importance per head """
    # outs.shape = (B, T, E) where B: batch_size, T: num tokens, E: embedding size
    # the importance is calculated as summing the L2 norm of the attn outputs on B and T dimensions
    outs_flat = outs.view(-1, outs.shape[-1]) # (b,t,e) -> (b*t, e)
    importance = torch.linalg.vector_norm(outs_flat.detach().cpu(), ord=2, dim=-1).sum()

    module.calculated_importance = importance
    
    # print(outs_flat.shape)
    # print("module:", module.__class__.__name__, end=" ")
    # print("importance:", importance)
    # print(f"{module.__class__.__name__} importance: {importance.shape}")

def neuron_importance_hook(module, ins, outs) -> None:
    """ calculates the neuron importance for the given layer """ 
    
    # the ffwd linear weights should be in the shape of (out, in)
    # the paper sums up the values of (X * W_i^T) meaning (B, T, in) x (in, 1)= (B,T,1) -> (1, ) (summed up)
    
    # thus, in order to vectorize this operation, we'll need to hook this function to the first linear layer itself rather than the whole ffwd block. 

    # for each neuron in the ffwd layer, we can simply sum up the output columns

    # as they're the activations of individual neurons
    # calculate the importances
    # importance = outs.detach().sum()
    importance = outs.detach().cpu().sum(dim=(0,1))
    # print(f"{module.__class__.__name__} importance.shape: {importance.shape}")

    module.calculated_importance = importance

def embedding_importance_hook(module, ins, outs) -> None:
    # the first block's first processing layer will be the 
    # layer norm
    # so we'll just sum up the layer norm outputs after getting them
    # calculate the importances

    importance = outs.detach().sum(dim=(0,1))
    # print("importance.shape:", importance.shape)
    # print("n_embd: ", outs.size(-1))
    # print("module:", module.__class__.__name__)
    # print("outs.shape:", outs.shape) # probably (B, T, E)
 
    module.calculated_importance = importance

    # print(f"{module.__class__.__name__} importance.shape: {importance.shape}")

def block_importance_hook(module, ins, outs) -> None:
   
    in_vectors = ins[0].detach()  # (B, T, E)
    out_vectors = outs.detach()   # (B, T, E)
    
    # Calculate cosine similarity for each sample and time step
    dot_product = torch.sum(in_vectors * out_vectors, dim=-1)  # (B, T)
    in_norm = torch.norm(in_vectors, p=2, dim=-1)  # (B, T)
    out_norm = torch.norm(out_vectors, p=2, dim=-1)  # (B, T)
    
    cosine_sim = dot_product / (in_norm * out_norm + 1e-8)  # (B, T)
    
    # Calculate BI by taking the expectation (mean) and subtracting from 1
    block_importance = 1 - torch.mean(cosine_sim)
    
    # print("Block Importance:", block_importance.item())
    # print("module:", module.__class__.__name__)
    # print("outs.shape:", outs.shape)  # (B, T, E)
 
    module.calculated_importance = block_importance

    # print(f"{module.__class__.__name__} importance.shape: {block_importance.shape}")

In [4]:
# set up the initial hooks for all the corresponding layers
from models import Block, GPT

def delete_importance_attr(layer: nn.Module):
    if hasattr(layer, "calculated_importance"):
        del layer.calculated_importance

def remove_all_forward_hooks(model: GPT):
    if not isinstance(model, GPT):
        raise NotImplementedError("Only GPT models are supported for now")
    
    for module in model.modules():
        if isinstance(module, Block):
            for head in module.sa.heads:
                head._forward_hooks.clear()

                head.key._forward_hooks.clear()
                head.value._forward_hooks.clear()
                head.query._forward_hooks.clear()

                delete_importance_attr(head)
                
                delete_importance_attr(head.key)
                delete_importance_attr(head.query)
                delete_importance_attr(head.value)

            module.ffwd.net[0]._forward_hooks.clear()
            module.ln1._forward_hooks.clear()
            module.sa._forward_hooks.clear()
            module.sa.proj._forward_hooks.clear()
            delete_importance_attr(module.ffwd.net[0])
            delete_importance_attr(module.ln1)
            delete_importance_attr(module.sa)
            delete_importance_attr(module.sa.proj)

def register_all_forward_hooks(model: GPT):
    if not isinstance(model, GPT):
        raise NotImplementedError("Only GPT models are supported for now")

    num_blocks = 0
    for module in model.modules():
        if isinstance(module, Block):
            num_blocks += 1
            for head in module.sa.heads:
                head.register_forward_hook(attn_head_importance_hook)

                head.key.register_forward_hook(neuron_importance_hook)
                head.value.register_forward_hook(neuron_importance_hook) 
                head.query.register_forward_hook(neuron_importance_hook)

            module.ffwd.net[0].register_forward_hook(neuron_importance_hook) # register the forward hook to the linear layer inside of the ffwd block
            module.sa.proj.register_forward_hook(neuron_importance_hook)
            module.ln1.register_forward_hook(embedding_importance_hook)
            module.register_forward_hook(block_importance_hook)

In [5]:
def reinit_models():
    loaded_model, tokenizer = load(GPT, "model")
    loaded_model.to(device);

    remove_all_forward_hooks(loaded_model)
    register_all_forward_hooks(loaded_model)

    return loaded_model, tokenizer

In [6]:
def get_model():
    model, tokenizer = reinit_models()
    s = 0
    for k in model.parameters():
        if k.requires_grad:
            s += k.numel()
    print("# trainable parameters:", s)

    sample_batch = calibrate_data[:batch_size*block_size]
    sample_batch = sample_batch.view(batch_size, block_size)
    sample_batch = sample_batch.to(device)

    model(sample_batch);

    return model, s

In [7]:
# neuron and head pruning? 

# start with neuron pruning

def prune_neurons(model, ratio=0.2) -> None:
    # goal: trim the MLP layer weights
    # 1 - argsort the importances of the `ffwd` layers defined in the model
    # 2 - remove the weights with respect to the given ratio
    for module in model.modules():
        if isinstance(module, Block):
            importances = module.ffwd.net[0].calculated_importance 
            num_neurons = int((1-ratio) * importances.size(0))
            idx = importances.argsort(descending=True)[:num_neurons]
            # reinitialize the weights along with the layer
            dense1 = module.ffwd.net[0]
            dense2 = module.ffwd.net[2]

            module.ffwd.net[0] = nn.Linear(dense1.in_features, num_neurons).to(model.device) # weights.shape = (num_neurons, dense1.in_features)
            module.ffwd.net[2] = nn.Linear(num_neurons, dense2.out_features).to(model.device) # weights.shape = (dense2.out_features = emb)

            # now we need to set the weights to the new layers.

            dense1.weight.data = dense1.weight.data[idx, :]
            dense1.bias.data = dense1.bias.data[idx]

            dense2.weight.data = dense2.weight.data[idx, :]
            dense2.bias.data = dense2.bias.data[idx]

            module.ffwd.net[0].calculated_importance = importances[idx]
            module.ffwd.net[2].calculated_importance = importances[idx]

    return model


def prune_heads(model, ratio=0.2) -> None:
    # goal: trim the attention heads' layer weights using the same approach as the `prune_neurons`
    for module in model.modules():
        if isinstance(module, Block):
            # now the multi-head attention
            for head in module.sa.heads:
                # key,value,query weight shape: (head_size, n_embd) # n_embd
                k,v,q = head.key, head.value, head.query

                key_importances = head.key.calculated_importance
                value_importances = head.value.calculated_importance
                query_importances = head.query.calculated_importance

                num_neurons = int((1-ratio) * key_importances.size(0))


                k_idx = key_importances.argsort(descending=True)[:num_neurons]
                v_idx = value_importances.argsort(descending=True)[:num_neurons]
                q_idx = query_importances.argsort(descending=True)[:num_neurons]

                head.key = nn.Linear(k.in_features, num_neurons, bias=False).to(model.device) 
                head.value = nn.Linear(v.in_features, num_neurons, bias=False).to(model.device) 
                head.query = nn.Linear(q.in_features, num_neurons, bias=False).to(model.device) 

                head.key.weight.data = k.weight.data[k_idx, :] # (head_size, num_dense_embd)
                head.value.weight.data = v.weight.data[v_idx, :] # (head_size, num_dense_embd)
                head.query.weight.data = q.weight.data[q_idx, :] # (head_size, num_dense_embd)

                head.key.calculated_importance = key_importances[k_idx]
                head.value.calculated_importance = value_importances[v_idx]
                head.query.calculated_importance = query_importances[q_idx]

                # TODO: only the weights in the embedding layers are prunned (1st strategy)
                # TODO: need to follow the correct implementation from the paper (pruning every linear layer?)

            proj = module.sa.proj
            proj_importances = module.sa.proj.calculated_importance
            num_neurons = int((1-ratio) * key_importances.size(0)) * module.sa.num_heads
            idx = proj_importances.argsort(descending=True)[:num_neurons]

            module.sa.proj = nn.Linear(num_neurons, proj.out_features).to(model.device)

            module.sa.proj.weight.data = proj.weight.data[:, idx]
            module.sa.proj.bias.data = proj.bias.data

            module.sa.proj.calculated_importance = proj_importances[idx]

def prune_embeddings(model, ratio=0.2) -> None:
    # goal: trim the embedding dimension of the weight matrices in MLP, MHA, and LayerNorm layers.
    # TODO: check how embedding importance is calculated!
    
    for module in model.modules():
        if isinstance(module, Block):
            # start with pruning the MLP layers
            importances = module.ln1.calculated_importance

            dense1 = module.ffwd.net[0] # weights.shape = (emb, 4 * emb)
            dense2 = module.ffwd.net[2] # weights.shape = (4 * emb, emb)

            num_dense_embd = int((1-ratio) * dense1.in_features)
            idx = importances.argsort(descending=True)[:num_dense_embd]

            module.ffwd.net[0] = nn.Linear(num_dense_embd, dense1.out_features).to(model.device) # weights.shape = (num_dense_embd, dense1.in_features)
            module.ffwd.net[2] = nn.Linear(dense2.in_features, num_dense_embd).to(model.device) # weights.shape = (dense2.out_features = emb)

            module.ffwd.net[0].weight.data = dense1.weight.data[:, idx]
            module.ffwd.net[0].bias.data = dense1.bias.data
            module.ffwd.net[2].weight.data = dense2.weight.data[idx, :]
            module.ffwd.net[2].bias.data = dense2.bias.data[idx]        


            # now the multi-head attention
            for head in module.sa.heads:
                # key,value,query weight shape: (head_size, n_embd) # n_embd
                k,v,q = head.key, head.value, head.query
                  
                head.key = nn.Linear(num_dense_embd, k.out_features, bias=False).to(model.device) 
                head.value = nn.Linear(num_dense_embd, v.out_features, bias=False).to(model.device) 
                head.query = nn.Linear(num_dense_embd, q.out_features, bias=False).to(model.device) 

                head.key.weight.data = k.weight.data[:, idx] # (head_size, num_dense_embd)
                head.value.weight.data = v.weight.data[:, idx] # (head_size, num_dense_embd)
                head.query.weight.data = q.weight.data[:, idx] # (head_size, num_dense_embd)

                head.key.calculated_importance = k.calculated_importance
                head.value.calculated_importance = v.calculated_importance 
                head.query.calculated_importance = q.calculated_importance

            ln1 = module.ln1
            ln2 = module.ln2

            module.ln1 = nn.LayerNorm(num_dense_embd).to(model.device) 
            module.ln1.weight.data = ln1.weight.data[idx]
            module.ln1.bias.data = ln1.bias.data[idx]

            module.ln2 = nn.LayerNorm(num_dense_embd).to(model.device) 
            module.ln2.weight.data = ln2.weight.data[idx]
            module.ln2.bias.data = ln2.bias.data[idx]

            proj = module.sa.proj
            module.sa.proj = nn.Linear(proj.in_features, num_dense_embd).to(model.device) 
            module.sa.proj.weight.data = proj.weight.data[idx, :] # (num_dense_embd, n_embd)
            module.sa.proj.bias.data = proj.bias.data[idx]
            
            module.sa.proj.calculated_importance = proj.calculated_importance

    
    temb_table = model.token_embedding_table
    pemb_table = model.position_embedding_table

    model.token_embedding_table = nn.Embedding(vocab_size, num_dense_embd).to(device)
    model.position_embedding_table = nn.Embedding(model.block_size, num_dense_embd).to(device)

    model.token_embedding_table.weight.data = temb_table.weight.data[:, idx]
    model.position_embedding_table.weight.data = pemb_table.weight.data[:, idx]

    lnf = model.ln_f
    ln_head = model.ln_head

    model.ln_f = nn.LayerNorm(num_dense_embd).to(device)
    model.ln_head = nn.Linear(num_dense_embd, ln_head.out_features).to(device) 

    model.ln_f.weight.data = lnf.weight.data[idx]
    model.ln_f.bias.data = lnf.bias.data[idx]
    model.ln_head.weight.data = ln_head.weight.data[:, idx] # weight.shape = (vocab_size, embd)
    model.ln_head.bias.data = ln_head.bias.data

In [8]:
def get_num_params(model):
    t = 0
    for k in model.parameters():
        if k.requires_grad:
            t += k.numel()
    
    return t

In [9]:
calibrate_data.shape[0] / train_data.shape[0]

0.05555588760915432

In [10]:
strategies = {
    "width_head": prune_heads,
    "width_neuron": prune_neurons,
    "width_embedding": prune_embeddings
}

def experiment(pruning_strategies: list[list[str]] = [[("width_head", 0.1), ("width_neuron", 0.1), ("width_embedding", 0.1)]], learning_rate: float=2e-3):
    results = []

    model, num_params = get_model()
    base_loss = estimate_loss(model, val_loader)['val'].item()

    for run in range(len(pruning_strategies)):
        print("-"*50)
        strategy = pruning_strategies[run]

        pruning_funcs = [strategies[s] for s, ratio in strategy]
        pruning_func_names = [s for s, ratio in strategy]
        ratios = [ratio for s, ratio in strategy]

        print(f"RUN {run+1} | RATIO: {ratios} | STRATEGIES: {pruning_func_names}")
        model, num_params = get_model()
        print(f"{'Number of trainable parameters before pruning:':60}", num_params)
        # prune
        for f, r in zip(pruning_funcs, ratios):
            f(model, r)
        #
        pruned_num_params = get_num_params(model)
        param_diff_ratio = ((num_params-pruned_num_params)/num_params)
        print(f"{'Number of training parameters after pruning:':60} {pruned_num_params}")
        print(f"{'Ratio of the pruned weights to the base model:':60} {param_diff_ratio*100:.2f}%")
        pruned_eval = estimate_loss(model, val_loader)['val'].item()
        print(f"{'Pruned evaluation loss (before calibration):':60} {pruned_eval:.4f}")
        #
        print("Starting the calibration")
        optimizer = AdamW(model.parameters(), lr=learning_rate)
        losses = train_loop(model, optimizer, vocab_size, calibration_loader, [calibration_loader, val_loader], max_iters = 200, eval_interval=50, eval_iters=50)
        #
        calibrated_eval = estimate_loss(model, val_loader)['val'].item()
        print(f"{'Pruned evaluation loss (after calibration):':60} {calibrated_eval:.4f}")

        
        result = {
            "run": run+1,
            "base_num_params": num_params,
            "pruned_num_params": pruned_num_params,
            "pruning_ratio": ratios,
            "param_diff_ratio": param_diff_ratio,
            "before_calibration_loss": pruned_eval,
            "after_calibration_loss": calibrated_eval,
            "base_loss": base_loss,
            "learning_rate": learning_rate,
            "pruning_strategies": pruning_func_names,
            "training_losses": losses
        }


        results.append(result)
        run_df = pd.DataFrame(results)
        run_df.to_csv(f"run_results.csv", index=False)
    

    return results

In [21]:
import itertools
import numpy as np

def get_config_combinations(start: float=0.1, end: float=0.5, step: float=0.15):
    # Define the range and step
    # Create the list of values for widths
    values = np.arange(start, end + step, step)

    # Initialize the experiment config list
    experiment_config = []

    for s in ['width_head', 'width_neuron', 'width_embedding']:
        config = [
            [(s, round(v, 2))]
            for v in values
        ]
        experiment_config.extend(config)

    # Setup 1: Vary width_head and width_neuron
    config1 = [
        [("width_head", round(wh, 2)), ("width_neuron", round(wn, 2))]
        for wh, wn in itertools.product(values, values)
    ]
    experiment_config.extend(config1)

    # Setup 2: Vary width_head and width_embedding
    config2 = [
        [("width_head", round(wh, 2)), ("width_embedding", round(we, 2))]
        for wh, we in itertools.product(values, values)
    ]
    experiment_config.extend(config2)

    # Setup 3: Vary width_neuron and width_embedding
    config3 = [
        [("width_neuron", round(wn, 2)), ("width_embedding", round(we, 2))]
        for wn, we in itertools.product(values, values)
    ]
    experiment_config.extend(config3)

    # Setup 4: Vary all three - width_head, width_neuron, and width_embedding
    config4 = [
        [("width_head", round(wh, 2)), ("width_neuron", round(wn, 2)), ("width_embedding", round(we, 2))]
        for wh, wn, we in itertools.product(values, values, values)
    ]
    experiment_config.extend(config4)

    # Show an example of what the experiment_config looks like
    print(f"Total configurations: {len(experiment_config)}")


Total configurations: 74087


[[('width_head', np.float64(0.1))],
 [('width_head', np.float64(0.11))],
 [('width_head', np.float64(0.12))],
 [('width_head', np.float64(0.13))],
 [('width_head', np.float64(0.14))],
 [('width_head', np.float64(0.15))],
 [('width_head', np.float64(0.16))],
 [('width_head', np.float64(0.17))],
 [('width_head', np.float64(0.18))],
 [('width_head', np.float64(0.19))],
 [('width_head', np.float64(0.2))],
 [('width_head', np.float64(0.21))],
 [('width_head', np.float64(0.22))],
 [('width_head', np.float64(0.23))],
 [('width_head', np.float64(0.24))],
 [('width_head', np.float64(0.25))],
 [('width_head', np.float64(0.26))],
 [('width_head', np.float64(0.27))],
 [('width_head', np.float64(0.28))],
 [('width_head', np.float64(0.29))],
 [('width_head', np.float64(0.3))],
 [('width_head', np.float64(0.31))],
 [('width_head', np.float64(0.32))],
 [('width_head', np.float64(0.33))],
 [('width_head', np.float64(0.34))],
 [('width_head', np.float64(0.35))],
 [('width_head', np.float64(0.36))],
 [('

In [11]:
experiment_config = get_config_combinations()

exp_results = experiment(pruning_strategies=experiment_config)

  model.load_state_dict(torch.load(save_dir / "model.pth"))


# trainable parameters: 3222593
--------------------------------------------------
RUN 1 | RATIO: [0.1, 0.1, 0.1] | STRATEGIES: ['width_head', 'width_neuron', 'width_embedding']
# trainable parameters: 3222593
Number of trainable parameters before pruning:               3222593
Number of training parameters after pruning:                 2602749
Ratio of the pruned weights to the base model:               19.23%
Pruned evaluation loss (before calibration):                 4.0400
Starting the calibration
UNIFORM BASELINE:  4.174387454986572


step 150: calibrate loss 2.2714, val loss 2.3592,  	 | baseline (uniform random): 4.1744: 100%|██████████| 200/200 [00:33<00:00,  5.95it/s]


Pruned evaluation loss (after calibration):                  2.2003
--------------------------------------------------
RUN 2 | RATIO: [0.2, 0.2, 0.2] | STRATEGIES: ['width_head', 'width_neuron', 'width_embedding']


  model.load_state_dict(torch.load(save_dir / "model.pth"))


# trainable parameters: 3222593
Number of trainable parameters before pruning:               3222593
Number of training parameters after pruning:                 2063741
Ratio of the pruned weights to the base model:               35.96%
Pruned evaluation loss (before calibration):                 3.9875
Starting the calibration
UNIFORM BASELINE:  4.174387454986572


step 150: calibrate loss 2.2191, val loss 2.3000,  	 | baseline (uniform random): 4.1744: 100%|██████████| 200/200 [00:32<00:00,  6.15it/s]


Pruned evaluation loss (after calibration):                  2.1605


In [14]:
pd.read_csv("run_results.csv")

Unnamed: 0,run,base_num_params,pruned_num_params,pruning_ratio,param_diff_ratio,before_calibration_loss,after_calibration_loss,base_loss,learning_rate,pruning_strategies,training_losses
0,1,3222593,2602749,"[0.1, 0.1, 0.1]",0.192343,4.040036,2.200298,2.009167,0.002,"['width_head', 'width_neuron', 'width_embedding']","[0.608637809753418, 0.5644298791885376, 0.5361..."
1,2,3222593,2063741,"[0.2, 0.2, 0.2]",0.359602,3.987487,2.160472,2.009167,0.002,"['width_head', 'width_neuron', 'width_embedding']","[0.5986542701721191, 0.5519154667854309, 0.527..."
