In [1]:
# Mount google drive
from google.colab import drive
ROOT = "/content/drive"
print(ROOT)                 # print content of ROOT (Optional)

drive.mount(ROOT)

/content/drive
Mounted at /content/drive


In [None]:
!pip install transformers
!pip install datasets

In [None]:
!pip install evaluate

In [None]:
from transformers import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel
from datasets import load_metric, load_dataset
import evaluate
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import re
from os.path import join
import pdb
import tqdm

In [None]:
%cd drive/MyDrive/Repos/llm-sparsification-cvf/
from src.pruning_utils import prune_gpt2_layers, check_gpt_layer_sparsity
from src.evalu_utils import compute_ppl

/content/drive/MyDrive/Repos/llm-sparsification-cvf


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
model = GPT2LMHeadModel.from_pretrained('gpt2-xl')

In [None]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-xl')

In [None]:
tokenizer.pad_token = tokenizer.eos_token

In [None]:
device = 'cuda'
model.to(device)

In [None]:
test = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test').select(range(1000))
encodings = tokenizer('\n\n'.join(test['text']), return_tensors='pt')

Token indices sequence length is longer than the specified maximum sequence length for this model (71448 > 1024). Running this sequence through the model will result in indexing errors


In [None]:
test

Dataset({
    features: ['text'],
    num_rows: 1000
})

In [None]:
sum(p.numel() for p in model.parameters())

1557611200

In [None]:
# sparsify at 10%, 50%, 90%, 95%, 99%

In [None]:
params = model.state_dict() 
#params.keys()

In [None]:
model.transformer.h[10].attn.c_attn.weight

Parameter containing:
tensor([[ 0.0079, -0.0577, -0.0430,  ..., -0.0271, -0.0114, -0.0463],
        [-0.0701,  0.0133, -0.0299,  ..., -0.0358,  0.0244, -0.0000],
        [ 0.0590,  0.0287, -0.0167,  ..., -0.0103, -0.0111,  0.0347],
        ...,
        [-0.0852,  0.0225,  0.0196,  ...,  0.0338,  0.0103,  0.0075],
        [ 0.0124, -0.0000, -0.0302,  ..., -0.0350, -0.0383, -0.0209],
        [ 0.0118,  0.0052,  0.0344,  ..., -0.0099,  0.0522, -0.0143]],
       device='cuda:0', requires_grad=True)

In [None]:
for sparsity_lvl in [0, 0.1, 0.5, 0.9, 0.95, 0.99]:
    print(f"Pruning at sparsity level: {sparsity_lvl}")
    prune_gpt2_layers(model, sparsity_lvl)
    check_gpt_layer_sparsity(model, 10)
    ppl = compute_ppl(model)
    print(f"Model perplexity at sparsity level {sparsity_lvl} is: {ppl.item()}")


Pruning at sparsity level: 0
Sparsity in h.10.attn.c_attn.weight: 0.00%


100%|██████████| 140/140 [02:54<00:00,  1.25s/it]


Model perplexity at sparsity level 0 is: 15.702324867248535
Pruning at sparsity level: 0.1
Sparsity in h.10.attn.c_attn.weight: 10.00%


100%|██████████| 140/140 [02:58<00:00,  1.27s/it]


Model perplexity at sparsity level 0.1 is: 15.717995643615723
Pruning at sparsity level: 0.5
Sparsity in h.10.attn.c_attn.weight: 50.00%


100%|██████████| 140/140 [02:52<00:00,  1.23s/it]


Model perplexity at sparsity level 0.5 is: 205.3623046875
Pruning at sparsity level: 0.9
Sparsity in h.10.attn.c_attn.weight: 90.00%


100%|██████████| 140/140 [02:40<00:00,  1.14s/it]


Model perplexity at sparsity level 0.9 is: 5921.974609375
Pruning at sparsity level: 0.95
Sparsity in h.10.attn.c_attn.weight: 95.00%


100%|██████████| 140/140 [02:38<00:00,  1.13s/it]


Model perplexity at sparsity level 0.95 is: 18557.201171875
Pruning at sparsity level: 0.99
Sparsity in h.10.attn.c_attn.weight: 99.00%


100%|██████████| 140/140 [02:36<00:00,  1.12s/it]

Model perplexity at sparsity level 0.99 is: 27840.955078125





In [None]:
import pandas as pd
results_df = pd.DataFrame({'sparsity' : [0, 0.1, 0.5, 0.9, 0.95, 0.99],
                           'perplexity' : [15.70, 15.72, 205.4, 5921.98, 18557.2, 27840.95],
                           'speed (seconds/iteration)': [1.25, 1.27, 1.23, 1.14, 1.13, 1.12]})
results_df

Unnamed: 0,sparsity,perplexity,speed (seconds/iteration)
0,0.0,15.7,1.25
1,0.1,15.72,1.27
2,0.5,205.4,1.23
3,0.9,5921.98,1.14
4,0.95,18557.2,1.13
5,0.99,27840.95,1.12
