In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

from tqdm import tqdm

import matplotlib.pyplot as plt

from models import GPT
from tokenizers import Tokenizer

from utils import BatchLoader, estimate_loss, train_loop

# hyperparameters
batch_size = 16  # number of independent sequences that'll be processed in parallel
block_size = 128  # maximum context length for the preds
max_iters = 1000
eval_interval = 200
learning_rate = 3e-4
device = "mps" if torch.backends.mps.is_available() else "cpu"
eval_iters = 200
n_embd = 256
n_head = 4
n_blocks = 4
dropout = 0.2
# --------------

torch.manual_seed(1337)

# data preparation
text = open("dataset/tinyshakespeare.txt", "r").read()
# set up the vocabulary
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
tokenizer = Tokenizer(chars)

data = torch.tensor(tokenizer.encode(text), dtype=torch.long)

n = int(0.9 * len(data))  # first 90% will be the training set
n1 = int(0.98 * len(data))  # 90-98% will be the validation set and the last 2% will be the calibration set for the paper

train_data = data[:n]
val_data = data[n:n1]
calibrate_data = data[n1:]

train_loader = BatchLoader(train_data, block_size, batch_size, device, name="train")
val_loader = BatchLoader(val_data, block_size, batch_size, device, name="val")
calibration_loader = BatchLoader(calibrate_data, block_size, batch_size, device, name="calibrate")


In [2]:
from utils import load

loaded_model, tokenizer = load(GPT, "model")
loaded_model.to(device);

  model.load_state_dict(torch.load(save_dir / "model.pth"))


In [3]:
def attn_head_importance_hook(module, ins, outs) -> None: # TODO: does the importance calculation returns the correct values for each head? 
    """ calculates the multi-head-attention layer's importance per head """

    # outs.shape = (B, T, E) where B: batch_size, T: num tokens, E: embedding size
    # the importance is calculated as summing the L2 norm of the attn outputs on B and T dimensions
    outs_flat = outs.view(-1, outs.shape[-1]) # (b,t,e) -> (b*t, e)
    importance = torch.linalg.vector_norm(outs_flat.detach().cpu(), ord=2, dim=-1).sum()

    module.calculated_importance = importance
    
    # print(outs_flat.shape)
    # print("module:", module.__class__.__name__, end=" ")
    # print("importance:", importance)
    # print(f"{module.__class__.__name__} importance: {importance.shape}")

def neuron_importance_hook(module, ins, outs) -> None:
    """ calculates the neuron importance for the given layer """ 
    
    # the ffwd linear weights should be in the shape of (out, in)
    # the paper sums up the values of (X * W_i^T) meaning (B, T, in) x (in, 1)= (B,T,1) -> (1, ) (summed up)
    
    # thus, in order to vectorize this operation, we'll need to hook this function to the first linear layer itself rather than the whole ffwd block. 

    # for each neuron in the ffwd layer, we can simply sum up the output columns

    # as they're the activations of individual neurons
    # calculate the importances
    # importance = outs.detach().sum()
    importance = outs.detach().cpu().sum(dim=(0,1))
    # print(f"{module.__class__.__name__} importance.shape: {importance.shape}")

    module.calculated_importance = importance

def embedding_importance_hook(module, ins, outs) -> None:
    # the first block's first processing layer will be the 
    # layer norm,
    # so we'll just sum up the layer norm outputs after getting them
   # calculate the importances
    importance = outs.detach().sum(dim=(0,1))
    # print("importance.shape:", importance.shape)
    # print("n_embd: ", outs.size(-1))
    # print("module:", module.__class__.__name__)
    # print("outs.shape:", outs.shape) # probably (B, T, E)
 
    module.calculated_importance = importance

    # print(f"{module.__class__.__name__} importance.shape: {importance.shape}")

def block_importance_hook(module, ins, outs) -> None:
   
    in_vectors = ins[0].detach()  # (B, T, E)
    out_vectors = outs.detach()   # (B, T, E)
    
    # Calculate cosine similarity for each sample and time step
    dot_product = torch.sum(in_vectors * out_vectors, dim=-1)  # (B, T)
    in_norm = torch.norm(in_vectors, p=2, dim=-1)  # (B, T)
    out_norm = torch.norm(out_vectors, p=2, dim=-1)  # (B, T)
    
    cosine_sim = dot_product / (in_norm * out_norm + 1e-8)  # (B, T)
    
    # Calculate BI by taking the expectation (mean) and subtracting from 1
    block_importance = 1 - torch.mean(cosine_sim)
    
    # print("Block Importance:", block_importance.item())
    # print("module:", module.__class__.__name__)
    # print("outs.shape:", outs.shape)  # (B, T, E)
 
    module.calculated_importance = block_importance

    # print(f"{module.__class__.__name__} importance.shape: {block_importance.shape}")

In [4]:
# set up the initial hooks for all the corresponding layers
from models import Block, GPT

def delete_importance_attr(layer: nn.Module):
    if hasattr(layer, "calculated_importance"):
        del layer.calculated_importance

def remove_all_forward_hooks(model: GPT):
    if not isinstance(model, GPT):
        raise NotImplementedError("Only GPT models are supported for now")
    
    for module in model.modules():
        if isinstance(module, Block):
            for head in module.sa.heads:
                head._forward_hooks.clear()
                delete_importance_attr(head)

            module.ffwd.net[0]._forward_hooks.clear()
            module.ln1._forward_hooks.clear()
            module.sa._forward_hooks.clear()

            delete_importance_attr(module.ffwd.net[0])
            delete_importance_attr(module.ln1)
            delete_importance_attr(module.sa)

def register_all_forward_hooks(model: GPT):
    if not isinstance(model, GPT):
        raise NotImplementedError("Only GPT models are supported for now")

    num_blocks = 0
    for module in model.modules():
        if isinstance(module, Block):
            num_blocks += 1
            for head in module.sa.heads:
                head.register_forward_hook(attn_head_importance_hook)
            module.ffwd.net[0].register_forward_hook(neuron_importance_hook) # register the forward hook to the linear layer inside of the ffwd block
            if num_blocks == 1:
                module.ln1.register_forward_hook(embedding_importance_hook)
            module.register_forward_hook(block_importance_hook)

In [5]:
def reinit_models():
    loaded_model, tokenizer = load(GPT, "model")
    loaded_model.to(device);

    remove_all_forward_hooks(loaded_model)
    register_all_forward_hooks(loaded_model)

    return loaded_model, tokenizer

model, tokenizer = reinit_models()

# estimate_loss(model, val_loader) # 2.0079

In [6]:
s = 0
for k in model.parameters():
    if k.requires_grad:
        s += k.numel()
print(s)

3222593


In [7]:
ixs = torch.randint(0, calibrate_data.size(0), (batch_size*block_size, ))

sample_batch = calibrate_data[ixs]
sample_batch = sample_batch.view(batch_size, block_size)
sample_batch = sample_batch.to(device)

model(sample_batch);

In [8]:
# neuron and head pruning? 

# start with neuron pruning

def prune_neurons(model, ratio=0.2) -> None:
    # goal: trim the MLP layer weights
    # 1 - argsort the importances of the `ffwd` layers defined in the model
    # 2 - remove the weights with respect to the given ratio

    for module in model.modules():
        if isinstance(module, Block):
            importances = module.ffwd.net[0].calculated_importance 
            num_neurons = int((1-ratio) * importances.size(0))
            idx = importances.argsort(descending=True)[:num_neurons]
            # reinitialize the weights along with the layer
            dense1 = module.ffwd.net[0]
            dense2 = module.ffwd.net[2]

            module.ffwd.net[0] = nn.Linear(dense1.in_features, num_neurons).to(model.device) # weights.shape = (num_neurons, dense1.in_features)
            module.ffwd.net[2] = nn.Linear(num_neurons, dense2.out_features).to(model.device) # weights.shape = (dense2.out_features = emb)

            # now we need to set the weights to the new layers.

            dense1.weight.data = dense1.weight.data[idx, :]
            dense1.bias.data = dense1.bias.data[idx]

            dense2.weight.data = dense2.weight.data[idx, :]
            dense2.bias.data = dense2.bias.data[idx]
    
    return model


def prune_heads(model, ratio=0.2) -> None:
    # goal: trim the attention heads' layer weights using the same approach as the `prune_neurons`
    pass


def prune_embeddings(model, ratio=0.2) -> None:
    # goal: trim the embedding dimension of the weight matrices in MLP, MHA, and LayerNorm layers.
    # TODO: check how embedding importance is calculated!
    pass

In [9]:
prune_neurons(model, 0.2)

GPT(
  (token_embedding_table): Embedding(65, 256)
  (position_embedding_table): Embedding(128, 256)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttentionConcat(
        (heads): ModuleList(
          (0-3): 4 x Head(
            (key): Linear(in_features=256, out_features=64, bias=False)
            (query): Linear(in_features=256, out_features=64, bias=False)
            (value): Linear(in_features=256, out_features=64, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
        )
        (proj): Linear(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (ffwd): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=256, out_features=819, bias=True)
          (1): ReLU()
          (2): Linear(in_features=819, out_features=256, bias=True)
          (3): Dropout(p=0.2, inplace=False)
        )
      )
      (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  

In [10]:
t = 0
for k in model.parameters():
    if k.requires_grad:
        t += k.numel()
print(t)

print("total difference: ", s-t)
print("diff ratio: ", (s-t)/s)

2801933
total difference:  420660
diff ratio:  0.1305346346870362


In [11]:
estimate_loss(model, val_loader)

{'val': tensor(3.1024)}

In [12]:
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=learning_rate)

In [13]:
train_loop(model, optimizer, vocab_size, calibration_loader, [calibration_loader, val_loader], max_iters = 200, eval_interval=50, eval_iters=50)

UNIFORM BASELINE:  4.174387454986572


step 150: calibrate loss 1.8349, val loss 2.1691,  	 | baseline (uniform random): 4.1744: 100%|██████████| 200/200 [01:58<00:00,  1.68it/s]


[0.5107743740081787,
 0.4602510333061218,
 0.4431624114513397,
 0.4300287961959839,
 0.42028582096099854,
 0.4176865518093109,
 0.4054933786392212,
 0.3975889980792999,
 0.4035460650920868,
 0.3891201615333557,
 0.380073606967926,
 0.3809751868247986,
 0.3830280900001526,
 0.36982008814811707,
 0.37740984559059143,
 0.3718697130680084,
 0.37604257464408875,
 0.3682430684566498,
 0.3606867790222168,
 0.3529806137084961,
 0.36129114031791687,
 0.34820279479026794,
 0.34358736872673035,
 0.35382091999053955,
 0.34895434975624084,
 0.34823405742645264,
 0.3469099700450897,
 0.35234910249710083,
 0.3467314541339874,
 0.3486853241920471,
 0.3428036570549011,
 0.3350028097629547,
 0.3481793999671936,
 0.3514370918273926,
 0.3434023857116699,
 0.32682523131370544,
 0.3314991295337677,
 0.31998324394226074,
 0.3255039155483246,
 0.32986509799957275,
 0.33886486291885376,
 0.3355449438095093,
 0.32665982842445374,
 0.3302578330039978,
 0.31407833099365234,
 0.31431102752685547,
 0.32778149843215

In [14]:
estimate_loss(model, val_loader)

{'val': tensor(2.1561)}