# 🚀 Fine-Tuning & Evaluating LLaMA 3 with ShortGPT

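In this notebook, we use an implementation of **ShortGPT** (the layer-pruning method from the research paper) to shrink a **LLaMA 3** model before fine-tuning and evaluating it.

Our goals:
- 🧩 Prune the least important layers of LLaMA 3 with ShortGPT, then fine-tune the smaller model.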
- 🧪 Evaluate the model on benchmark tasks.
- ⚙️ Learn to use tools like Hugging Face and lm-eval.

This notebook is beginner-friendly! We’ll explain each step simply and clearly.

Let’s get started! 🚀


In [None]:
!pip install datasets lm_eval

In [2]:
# importing the required libraries
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, Trainer, TrainingArguments
from collections import OrderedDict
from typing import List, Optional
import numpy as np
from tqdm.notebook import tqdm
from datasets import load_dataset
from torch.utils.data import DataLoader
from lm_eval import evaluator, tasks

In [3]:
from huggingface_hub import notebook_login
notebook_login()



In [4]:
def layer_removal(
    model: nn.Module,
    layers_to_remove: OrderedDict
):
    """
    Generic removal implementation
    """

    for layer_name, layer_idx in layers_to_remove.items():
        modules = layer_name.split(".")
        mod = model
        for m in modules[:-1]:
            mod = getattr(mod, m)
        
        if layer_idx is None:
            delattr(mod, modules[-1])
        else:
            # delete a single element of the container (e.g. one block of a ModuleList)
            del getattr(mod, modules[-1])[layer_idx]

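The dotted-path traversal in `layer_removal` can be tried out on plain Python objects. The sketch below is only an illustration of the intended behaviour (delete a whole attribute when the index is `None`, otherwise delete one element of the container); the helper name `remove_by_path` and the toy object are made up for this example.

```python
from collections import OrderedDict
from types import SimpleNamespace

def remove_by_path(root, layers_to_remove):
    """Delete attributes (index None) or single elements addressed by dotted paths."""
    for path, idx in layers_to_remove.items():
        *parents, leaf = path.split(".")
        obj = root
        for name in parents:
            obj = getattr(obj, name)
        if idx is None:
            delattr(obj, leaf)            # drop the whole attribute
        else:
            del getattr(obj, leaf)[idx]   # drop one element of the container

# Hypothetical stand-in for a model: model.layers is a list, model.head is an attribute.
toy = SimpleNamespace(model=SimpleNamespace(layers=["l0", "l1", "l2"], head="h"))
remove_by_path(toy, OrderedDict([("model.layers", 1), ("model.head", None)]))
print(toy.model.layers, hasattr(toy.model, "head"))
```
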
In [5]:
def block_influence(
    input_hidden_state: torch.Tensor,
    output_hidden_state: torch.Tensor,
    angular=False,
):
    """
    input_hidden_state: B, S, D
    output_hidden_state: B, S, D
    """
    _, _, d = input_hidden_state.shape
    input_hidden_state = input_hidden_state.reshape(-1, d)
    output_hidden_state = output_hidden_state.reshape(-1, d)

    norm_input = input_hidden_state.norm(dim=-1)
    norm_output = output_hidden_state.norm(dim=-1)

    # row-wise cosine similarity, computed without building the full N x N matrix
    sim = (input_hidden_state * output_hidden_state).sum(dim=-1) / (norm_input * norm_output)
    sim = sim.nan_to_num(nan=0.5)

    if angular:
        return (torch.arccos(sim) / torch.pi)

    return 1 - sim

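To make the score concrete, here is a minimal plain-Python sketch of the same quantity for a single pair of vectors. It assumes the definition used in `block_influence` above: one minus the cosine similarity, or the arccosine of the similarity divided by π in the angular variant. The helper name `toy_block_influence`, the clamping step, and the example vectors are additions for illustration; zero-length vectors are not handled here.

```python
import math

def toy_block_influence(x, y, angular=False):
    """Block-influence-style score between one input vector x and one output vector y."""
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    cos = dot / (norm_x * norm_y)
    # guard against tiny floating-point overshoot before calling acos
    cos = max(-1.0, min(1.0, cos))
    if angular:
        return math.acos(cos) / math.pi
    return 1.0 - cos

unchanged = toy_block_influence([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])  # same direction
rotated = toy_block_influence([1.0, 0.0], [0.0, 1.0])              # orthogonal
rotated_angular = toy_block_influence([1.0, 0.0], [0.0, 1.0], angular=True)
print(unchanged, rotated, rotated_angular)
```

A layer that only rescales its input scores close to 0, while a layer that rotates it by 90 degrees scores 1 (or 0.5 in the angular variant).
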
In [6]:
class ShortHFModel():

    def __init__(self, model_name: str, layers_path: str, n_prune_layers: Optional[int] = None):
        """
        HuggingFace Model Wrapper

        Args:
            model_name (str): HuggingFace model name
            layers_path (str): String in dot notation demonstrating how to access layers of the model. Ex: "model.layers"
            (Optional) n_prune_layers (int): Number of layers to prune. Defaults to None.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
        # self.model.params = self.model.to_fp16(self.model.params)
        self.model.to("cuda")

        modules = layers_path.split(".")
        mod = self.model
        for m in modules:
            mod = getattr(mod, m)
        self.layers = mod

        self.n_prune_layers = n_prune_layers
        self.importances = [0 for _ in self.layers]  # layer-wise importance scores

    def remove_layers(
        self,
        layers_to_remove: Optional[List[int]] = [],
        angular: Optional[bool] = False
    ):
        if angular:
            assert self.importances, "Need to compute importances with eval_importance()"
            assert self.n_prune_layers, "Need number of layers to prune, set `n_prune_layers`"
            start_layer = np.argsort(np.array(self.importances[:-self.n_prune_layers+1]))[0]
            layers_to_remove = list(range(start_layer, start_layer + self.n_prune_layers))
        elif not layers_to_remove and self.n_prune_layers:
            assert self.importances, "Need to compute importances with eval_importance()"
            layers_to_remove = np.argsort(np.array(self.importances))[:self.n_prune_layers].tolist()

        # remove layers in reverse to avoid indexing errors
        for layer_idx in sorted(layers_to_remove, reverse=True):
            try:
                del self.layers[layer_idx]
            except IndexError:
                print(f"layer {layer_idx} does not exist, function may have already been called")
                return []
        
        return layers_to_remove
    
    def compute_bi(self, hiddens: List[torch.Tensor], angular: bool):
        n = 1
        if angular:
            assert self.n_prune_layers is not None, "Set number of layers to prune to use angular importance"
            n = self.n_prune_layers

        for i in range(len(hiddens) - n):
            in_hidden = hiddens[i]
            out_hidden = hiddens[i+n]
            if angular:
                # use only last token for angular distance as described in section 3.2
                # https://arxiv.org/pdf/2403.17887.pdf
                in_hidden = in_hidden[:,-1:]
                out_hidden = out_hidden[:,-1:]
            
            self.importances[i] += block_influence(
                in_hidden,
                out_hidden,
                angular=angular
            ).sum().cpu().item()

    @torch.inference_mode()
    def eval_importance(
        self,
        prompts: List[str],
        max_seq_len: int,
        stride: int = 256,
        max_gen_len: int = 0,
        temperature: float = 0.6,
        top_p: float = 0.9,
        angular: Optional[bool] = False
    ):
        """
        Computes layer-wise importances over input texts.

        NOTE: The ShortGPT paper performs no generation during importance computation, which suggests `max_gen_len=0`.

        Args:
            prompts (List[str]): List of prompts.
            max_seq_len (int): Maximum sequence length for model input, the sliding window size.
            (Optional) stride (int): Number of tokens to skip/shift between each window inference.
            (Optional) max_gen_len (int): Maximum length of the generated text sequence.
            (Optional) temperature (float): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            (Optional) top_p (float): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            (Optional) angular (bool): Whether to use angular distance. Defaults to False.

        Returns:
            None
        """
        prompt_tokens = self.tokenizer(
            prompts,
            padding=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        input_ids = prompt_tokens.input_ids
        attn_mask = prompt_tokens.attention_mask

        max_prompt_len = max(len(t) for t in input_ids)

        # authors use a sliding window of size 1024 with a shift of 256
        for start in range(0, max_prompt_len, stride):
            seq_ids = (attn_mask.sum(dim=-1) > start).nonzero().squeeze()
            seq_ids = seq_ids.unsqueeze(0) if seq_ids.dim() == 0 else seq_ids  # ensure 1d
            inputs = input_ids[seq_ids, start:start+max_seq_len]
            attn = attn_mask[seq_ids, start:start+max_seq_len]

            if max_gen_len == 0:
                outputs = self.model(
                    input_ids=inputs.to("cuda"),
                    attention_mask=attn.to("cuda"),
                    output_hidden_states=True,
                )
            else:
                outputs = self.model.generate(
                    input_ids=inputs.to("cuda"),
                    attention_mask=attn.to("cuda"),
                    max_new_tokens=max_gen_len, 
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    output_hidden_states=True,
                    return_dict_in_generate=True,
                )
            
            self.compute_bi(outputs.hidden_states, angular=angular)

        return

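The `angular=True` branch of `remove_layers` picks one contiguous block of layers instead of the individually lowest-scoring ones. A minimal sketch of that selection, assuming `block_distances[i]` holds the distance between the input of layer `i` and the output of layer `i + n_prune - 1`; the function name and the numbers are invented for illustration.

```python
def pick_contiguous_block(block_distances, n_prune):
    """Return the n_prune consecutive layer indices whose removal changes the hidden state least."""
    start = min(range(len(block_distances)), key=block_distances.__getitem__)
    return list(range(start, start + n_prune))

# Made-up distances for an 8-layer model and blocks of 3 layers (8 - 3 + 1 = 6 candidate blocks).
toy_distances = [0.42, 0.31, 0.18, 0.09, 0.11, 0.27]
block = pick_contiguous_block(toy_distances, n_prune=3)
print(block)
```
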
### 📚 Loading the Dataset

Now, let's load our dataset!

We are using the **"pg19"** dataset, which contains long English texts from books published before 1919.  
For this notebook, we are loading the **validation split**, and we only take a small sample of **50 examples** to keep things simple and fast.

We also use **streaming** to load the data more efficiently, especially when working with large datasets.

Finally, we create a `DataLoader` to prepare the data for training or evaluation, with a `batch_size` of 1.

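> 📝 Note: The `shuffle` option is commented out. A streamed dataset is iterable-style, and PyTorch's `DataLoader` does not accept `shuffle=True` for iterable-style datasets. If you want to randomize the data order, shuffle the stream itself (for example with the dataset's `.shuffle()` method) before calling `.take(50)`.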


In [7]:
data = load_dataset("pg19", split="validation", streaming=True).take(50)  # stream the first 50 validation examples
dataloader = DataLoader(
    data,
    batch_size=1,
    #shuffle=True,
)


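Streaming means records are produced lazily, so taking 50 examples never touches the rest of the dataset. The sketch below is only an analogy built from the standard library, not how the `datasets` library implements `.take()`; `fake_book_stream` is a hypothetical stand-in for a remote dataset.

```python
from itertools import islice

pulled = []  # records which items were actually produced

def fake_book_stream():
    """Hypothetical endless stream of records, standing in for a remote dataset."""
    i = 0
    while True:
        pulled.append(i)
        yield {"text": f"book {i}"}
        i += 1

# Take the first 50 records; the generator is never asked for record 51.
sample = list(islice(fake_book_stream(), 50))
print(len(sample), len(pulled))
```
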
### 🦙 Initializing the LLaMA 3 Model (ShortGPT Version)

Next, we set up our model!  
We are using **ShortHFModel**, the wrapper class defined above, which implements the ShortGPT approach from the paper.

Here's what we are doing:
- We set `MAX_SEQ_LEN` to **1024**, which means our model will process sequences of up to 1024 tokens at a time.
- We load the **Llama 3.2 3B** model from Hugging Face (`meta-llama/Llama-3.2-3B`).
- We tell the model where to find its layers using `layers_path="model.layers"`.
- We set `n_prune_layers=5`: this means we will **prune (remove) 5 layers** from the model to make it smaller and faster.

> 📝 Note: Pruning helps reduce the size and computation cost of large models, which is especially useful when fine-tuning!

This prepares our model for the next steps: scoring each layer's importance and pruning! 🚀


In [8]:
MAX_SEQ_LEN = 1024
short_model = ShortHFModel(
    model_name="meta-llama/Llama-3.2-3B",
    layers_path="model.layers",
    n_prune_layers=5,
)


In [9]:
short_model.model


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((3072,), eps=1e-05)
    (rotary_emb

In [10]:
short_model.model.config

LlamaConfig {
  "_attn_implementation_autoset": true,
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 3072,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 24,
  "num_hidden_layers": 28,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "float16",
  "transformers_version": "4.51.2",
  "use_cache": true,
  "vocab_size": 128256
}

In [11]:
# sample generation
inputs = short_model.tokenizer(["Dhaka is the capital city of"], return_tensors='pt').to("cuda")
gen = short_model.model.generate(
    **inputs,  # passes input_ids and attention_mask
    max_new_tokens=50,
    pad_token_id=short_model.tokenizer.eos_token_id,
)
short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)


['Dhaka is the capital city of Bangladesh. It is the largest city of Bangladesh and is also the financial and cultural hub of the country. It is a metropolitan city with a population of more than 10 million. The city is divided into two parts, the old city and the new']

### 🔍 Evaluating Layer Importance

Now, let's measure the **importance of each layer** using our dataset!

Here’s what we are doing in this step:
- We loop through each batch in our `dataloader`.
- For each batch, we take the **text prompts**.
- We use the model's `eval_importance()` function to measure how much each layer changes the hidden states it receives (the Block Influence score computed by `block_influence`).

We pass these parameters:
- `prompts`: The text data from our dataset.
- `max_seq_len=MAX_SEQ_LEN`: The size of each sliding window (1024 tokens).
- `stride=256`: This controls how much the window moves across the text. Smaller strides = more overlap.
- `max_gen_len=0`: We set this to 0 because we are not generating new text here, just running forward passes.

> 💡 Layers with a low score barely change their input, so they are the best candidates for pruning. The scores are accumulated in `short_model.importances`, one value per layer.

The `tqdm` progress bar will show us how the process is going. 🚀

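The interaction between `max_seq_len` and `stride` is easiest to see on a concrete length. The sketch below mirrors the `range(0, max_prompt_len, stride)` loop and the `start:start+max_seq_len` slicing in `eval_importance`; the function name `window_bounds` and the 1500-token example are assumptions made for illustration.

```python
def window_bounds(n_tokens, max_seq_len=1024, stride=256):
    """Return (start, end) token positions for each window, end exclusive."""
    return [
        (start, min(start + max_seq_len, n_tokens))
        for start in range(0, n_tokens, stride)
    ]

windows = window_bounds(1500)
for start, end in windows:
    print(start, end)
```

With a window of 1024 and a stride of 256, consecutive full windows share 768 tokens, so most tokens contribute to the scores several times.
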

In [12]:
for i, batch in enumerate(tqdm(dataloader)):
    prompts = batch['text']

    short_model.eval_importance(
        prompts=prompts,
        max_seq_len=MAX_SEQ_LEN,
        stride=256,
        max_gen_len=0
    )

Token indices sequence length is longer than the specified maximum sequence length for this model (429039 > 131072). Running this sequence through the model will result in indexing errors


In [13]:
short_model.importances

[10885348.2734375,
 4647563.234375,
 3736188.578125,
 3694681.962890625,
 3917046.2265625,
 3819230.4384765625,
 3642028.6752929688,
 3543967.5844726562,
 2997535.2583007812,
 2868903.3413085938,
 2771850.0830078125,
 3112493.9892578125,
 2694467.1118164062,
 2916800.7465820312,
 2807367.796875,
 2245364.837890625,
 1786318.9145507812,
 1495307.1870117188,
 1404219.2998046875,
 1446112.8139648438,
 1077190.0073242188,
 925231.9125976562,
 860718.6850585938,
 903481.4384765625,
 988853.53125,
 1140624.490234375,
 1498234.85546875,
 5314456.3046875]

In [14]:
param_size = 0
for param in short_model.model.parameters():
    param_size += param.nelement() * param.element_size()

buffer_size = 0
for buffer in short_model.model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

# Total size in bytes to GB
size_all_gb = (param_size + buffer_size) / 1024**3
print('Model size: {:.3f} GB'.format(size_all_gb))

Model size: 5.984 GB


In [15]:
short_model.remove_layers()

[22, 23, 21, 24, 20]

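The indices returned above can be reproduced by hand. The sketch below copies the scores from the `short_model.importances` output (rounded to two decimals), picks the five smallest the way `np.argsort(...)[:n]` does, and deletes them from a stand-in list in reverse order; the `layers` list of strings is a placeholder for the real `ModuleList`.

```python
# Scores copied from the `short_model.importances` output, rounded to two decimals.
importances = [
    10885348.27, 4647563.23, 3736188.58, 3694681.96, 3917046.23, 3819230.44,
    3642028.68, 3543967.58, 2997535.26, 2868903.34, 2771850.08, 3112493.99,
    2694467.11, 2916800.75, 2807367.80, 2245364.84, 1786318.91, 1495307.19,
    1404219.30, 1446112.81, 1077190.01, 925231.91, 860718.69, 903481.44,
    988853.53, 1140624.49, 1498234.86, 5314456.30,
]
n_prune = 5

# Indices of the n_prune smallest scores, least important first.
to_remove = sorted(range(len(importances)), key=importances.__getitem__)[:n_prune]
print(to_remove)

layers = [f"layer_{i}" for i in range(len(importances))]
# Delete from the highest index down so earlier deletions do not shift later targets.
for idx in sorted(to_remove, reverse=True):
    del layers[idx]
print(len(layers), layers[19:])
```
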
In [16]:
short_model.layers

ModuleList(
  (0-22): 23 x LlamaDecoderLayer(
    (self_attn): LlamaAttention(
      (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
      (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
      (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
      (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
    )
    (mlp): LlamaMLP(
      (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
      (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
      (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
    (post_attention_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
  )
)

In [17]:
short_model.model.config

LlamaConfig {
  "_attn_implementation_autoset": true,
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 3072,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 24,
  "num_hidden_layers": 28,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "float16",
  "transformers_version": "4.51.2",
  "use_cache": true,
  "vocab_size": 128256
}

In [18]:
# after pruning, renumber each attention module's layer_idx so it matches its new position (used for caching)
for layer_idx, module in enumerate(short_model.layers):
    module.self_attn.layer_idx = layer_idx

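Why the renumbering matters: as the cell's comment indicates, each attention module keeps a `layer_idx` that is used for caching, and deleting layers leaves gaps in those numbers. The sketch below shows the gap and the fix on a hypothetical `ToyAttention` class; it does not touch the real Llama modules.

```python
class ToyAttention:
    """Hypothetical stand-in for an attention module that remembers its cache slot."""
    def __init__(self, layer_idx):
        self.layer_idx = layer_idx

# 28 layers, then drop the same five indices the notebook removed.
attns = [ToyAttention(i) for i in range(28)]
for idx in sorted([22, 23, 21, 24, 20], reverse=True):
    del attns[idx]

stale = [a.layer_idx for a in attns]
print(stale[18:])   # the last three entries still carry their old numbers

# Renumber so that every module's index equals its position in the list.
for new_idx, attn in enumerate(attns):
    attn.layer_idx = new_idx
fresh = [a.layer_idx for a in attns]
print(fresh[18:])
```
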
In [19]:
short_model.model.config

LlamaConfig {
  "_attn_implementation_autoset": true,
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 3072,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 24,
  "num_hidden_layers": 28,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "float16",
  "transformers_version": "4.51.2",
  "use_cache": true,
  "vocab_size": 128256
}

In [20]:
inputs = short_model.tokenizer(["Dhaka is the capital city of"], return_tensors='pt').to("cuda")
gen = short_model.model.generate(
    **inputs,  # passes input_ids and attention_mask
    max_new_tokens=50,
    pad_token_id=short_model.tokenizer.eos_token_id,
    use_cache=True
)
short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)


['Dhaka is the capital city of Bangladesh. It is the largest city and the largest metropolitan area in Bangladesh. The name means””""""""""""""""""""""""""""""""']

In [21]:
param_size = 0
for param in short_model.model.parameters():
    param_size += param.nelement() * param.element_size()

buffer_size = 0
for buffer in short_model.model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

# Total size in bytes to GB
size_all_gb = (param_size + buffer_size) / 1024**3
print('Model size: {:.3f} GB'.format(size_all_gb))

Model size: 5.047 GB

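The size drop can be checked with a little arithmetic on the layer shapes printed earlier. This is a back-of-the-envelope sketch that assumes float16 weights (2 bytes per parameter) and ignores buffers; the variable names are chosen for this example.

```python
hidden, kv_dim, mlp_dim = 3072, 1024, 8192  # sizes read from the printed architecture

attn_params = 2 * hidden * hidden + 2 * hidden * kv_dim  # q_proj, o_proj + k_proj, v_proj
mlp_params = 3 * hidden * mlp_dim                        # gate_proj, up_proj, down_proj
norm_params = 2 * hidden                                 # two RMSNorm weight vectors
per_layer = attn_params + mlp_params + norm_params

bytes_per_param = 2  # float16
saved_gib = 5 * per_layer * bytes_per_param / 1024**3
print(per_layer, round(saved_gib, 3))
```

Five decoder layers come to roughly 0.94 GB, in line with the reported drop from 5.984 GB to 5.047 GB.
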

In [22]:
short_model.model.config.num_hidden_layers = len(short_model.layers)

In [23]:
short_model.model.config


LlamaConfig {
  "_attn_implementation_autoset": true,
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 3072,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 24,
  "num_hidden_layers": 23,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "float16",
  "transformers_version": "4.51.2",
  "use_cache": true,
  "vocab_size": 128256
}

In [24]:
param_size = 0
for param in short_model.model.parameters():
    param_size += param.nelement() * param.element_size()

buffer_size = 0
for buffer in short_model.model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

# Total size in bytes to GB
size_all_gb = (param_size + buffer_size) / 1024**3
print('Model size: {:.3f} GB'.format(size_all_gb))

Model size: 5.047 GB


### 💾 Saving the Pruned Model

After scoring layer importance and pruning the model, we now save our work!

Here’s what’s happening:
- We define a name for our new, pruned model: **`shortgpt_llama3_5L_50`**.
- We check if the output directory exists. If not, we create it.
- We save both:
  - The **model weights**.
  - The **tokenizer** (needed to process text inputs for the model).

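> 📝 Note: There is also an optional line to save the model configuration (`new_config.save_pretrained()`), but it's currently commented out. It is not needed here, because calling `save_pretrained()` on the model already writes the updated config alongside the weights.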

Finally, we print a message to confirm that our pruned model has been saved successfully!

Now, we can load and use this smaller, faster model anytime 🚀


In [25]:
import os
new_model_name = 'shortgpt_llama3_5L_50'
output_dir = './' + new_model_name
os.makedirs(output_dir, exist_ok=True)  # create the directory if it does not exist yet

short_model.model.save_pretrained(output_dir)
short_model.tokenizer.save_pretrained(output_dir)
#new_config.save_pretrained(output_dir)
print(f"Pruned model saved to {output_dir}")

Pruned model saved to ./shortgpt_llama3_5L_50

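Loading the pruned model back later could look like the sketch below. It assumes the directory written by the cell above still exists and that `num_hidden_layers` was updated to the pruned layer count before saving, as done earlier in this notebook; the helper name `load_pruned` is made up for this example.

```python
def load_pruned(model_dir="./shortgpt_llama3_5L_50"):
    """Reload the saved checkpoint and its tokenizer from a local directory."""
    # Imported lazily so the helper can be defined before the heavy libraries are loaded.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir)
    return model, tokenizer
```

Calling `model, tokenizer = load_pruned()` in a fresh session should then give a model whose layer list has the pruned length.
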

In [26]:
# Push the model to your Hugging Face repository

short_model.model.push_to_hub(new_model_name, private=False)
short_model.tokenizer.push_to_hub(new_model_name)


CommitInfo(commit_url='https://huggingface.co/rahatneuron/shortgpt_llama3_5L_50/commit/51abf2eefd8f04c9feaa060bc58bf29aa6803911', commit_message='Upload tokenizer', commit_description='', oid='51abf2eefd8f04c9feaa060bc58bf29aa6803911', pr_url=None, repo_url=RepoUrl('https://huggingface.co/rahatneuron/shortgpt_llama3_5L_50', endpoint='https://huggingface.co', repo_type='model', repo_id='rahatneuron/shortgpt_llama3_5L_50'), pr_revision=None, pr_num=None)