**Create quantizised version of GPT-2 and save the memory =)**

Egor Shvetsov, Viktoriia Chekalina

In [1]:
# ! pip install transformers

In [2]:
device = 'cuda:0'

## Class Quantizator

In [3]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import matplotlib.pyplot as plt
from tqdm.auto import tqdm

![Image](https://drive.google.com/uc?export=view&id=1qG2l66t1tZNk-V-CQs80e-DiyevAF2ea)

In [4]:
class Quantizer(nn.Module):
    def __init__(self, bit):

        super(Quantizer, self).__init__()
        self.bit = bit
        self.thd_neg = -(2 ** (bit - 1))
        self.thd_pos = 2 ** (bit - 1) - 1
        #print ("bit, pos, neg", bit, self.thd_neg, self.thd_pos)
        self.s = nn.Parameter(torch.ones(1))

    def init_from(self, x):
        #print ("x.max(), x.min()", x.max(), x.min())
        s = (x.max() - x.min()) / (self.thd_pos - self.thd_neg )
        #print ("s", s)
        self.s = nn.Parameter(s)

    def skip_grad_scale(self,x, scale):
        y = x
        y_grad = x * scale
        return (y - y_grad).detach() + y_grad

    def round_pass(self, x): # round ans clip are not differentiable
        y = x.round()
        y_grad = x
        return (y - y_grad).detach() + y_grad

    def forward(self, x):
        if self.bit >= 32:
            return x

        s_grad_scale = 1.0 / ((self.thd_pos * x.numel()) ** 0.5)

        #print ("s_grad_scale", s_grad_scale)
        device = x.device

        s_scale = self.skip_grad_scale(self.s, s_grad_scale).to(device)

        #print ("s_scale", s_scale)

        x = x / (s_scale) # go to new space
        x = torch.clamp(x, self.thd_neg, self.thd_pos) # clip
        x = self.round_pass(x) #go back
        x = x * (s_scale)
        return x

In [5]:
vector =(np.arange(-10, 10, 1))**2

In [6]:
vector

array([100,  81,  64,  49,  36,  25,  16,   9,   4,   1,   0,   1,   4,
         9,  16,  25,  36,  49,  64,  81])

In [7]:
quantizer = Quantizer(4)
vector = torch.tensor(vector)
quantizer.init_from(vector)
quantizer(vector).detach().numpy()

array([46.666664 , 46.666664 , 46.666664 , 46.666664 , 33.333332 ,
       26.666666 , 13.333333 ,  6.6666665,  6.6666665,  0.       ,
        0.       ,  0.       ,  6.6666665,  6.6666665, 13.333333 ,
       26.666666 , 33.333332 , 46.666664 , 46.666664 , 46.666664 ],
      dtype=float32)

In [8]:
quantizer.s

Parameter containing:
tensor(6.6667, requires_grad=True)

In [9]:
quantizer = Quantizer(16)
vector = torch.tensor(vector)
quantizer.init_from(vector)
quantizer(vector).detach().numpy()

  vector = torch.tensor(vector)


array([49.999237  , 49.999237  , 49.999237  , 48.99977   , 36.00061   ,
       25.000381  , 16.00061   ,  8.999771  ,  3.9993896 ,  0.99946594,
        0.        ,  0.99946594,  3.9993896 ,  8.999771  , 16.00061   ,
       25.000381  , 36.00061   , 48.99977   , 49.999237  , 49.999237  ],
      dtype=float32)

In [10]:
# vector =(np.arange(-10, 10, 1))**2

# vector = torch.tensor(vector)

# plt.figure(figsize=(10,7))

# f = plt.figure(figsize = (10, 5))

# plt.plot(range(len(vector)), vector,label='orig', marker='o')
# for bit in [2,3,4, 8, 16]:
#     quantizer = Quantizer(bit)
#     quantizer.init_from(vector) # define step
#     dequantized = quantizer(vector) #

#     dequantized = dequantized.detach().numpy()

#     plt.plot(range(len(vector)),dequantized,label=f'int_{bit}')

# plt.title('DE - Quantized');
# plt.legend();

In [11]:
def plot_weight_distribution(model, bitwidth=32):

    fig, axes = plt.subplots(1,len([(p,n) for (p,n) in model.named_parameters()])-1, figsize=(15, 5))

    quantizer = Quantizer(bitwidth)
    qmin, qmax = quantizer.thd_neg,  quantizer.thd_pos,
    for i, (name, param) in enumerate(model.named_parameters()):
        if param.dim() > 1:
            quantizer.init_from(param)
            dequantized = quantizer(param)
            dequantized = dequantized.detach().cpu().numpy()

            axes[i].hist(dequantized.flatten(),  density=True, color = 'blue', alpha = 0.5,
                    edgecolor='black' if bitwidth <= 4 else None)

            axes[i].set_xlabel(name)
            axes[i].set_ylabel('density')

    fig.suptitle(f'Histogram of Weights (bitwidth={bitwidth} bits)')
    fig.tight_layout()
    fig.subplots_adjust(top=0.925)
    plt.show()

In [12]:
class QALinear(nn.Module):
    def __init__(self, fc_w, fc_b, bit):
        super(QALinear, self).__init__()
        self.bit = bit
        self.in_features, self.out_features = fc_w.shape
        self.fc = nn.Linear(in_features = self.in_features,
                      out_features = self.out_features,
                      bias = True)
        self.fc.weight = torch.nn.Parameter(torch.t(fc_w))
        self.fc.bias = fc_b
        self.define_q_fucntions(self.bit)


    # can be used to modify bits during the training
    def define_q_fucntions(self, bit):
        self.quantizer_act = Quantizer(bit)
        self.quantizer_weigh = Quantizer(bit)
        self.quantizer_weigh.init_from(self.fc.weight)

        self.quantizer_bias = Quantizer(bit)

    def forward(self, input_x):
        quantized_weight = self.quantizer_weigh(self.fc.weight)
        quantized_bias = self.quantizer_weigh(self.fc.bias)
        quantized_act = self.quantizer_act(input_x)
        out = nn.functional.linear(quantized_act, quantized_weight, bias=quantized_bias) #torch.nn.functional.linear(input, weight, bias)
        return out

# Transformer-based model quantization

### GPT-2 quantization pipeline

1) Extract fully-connected (transformers.conv1D()) from pretrained GPT-2

2) Create quantized QALinear object over it

3) Replace initial fully-connected layer with quantized layer

4) Fine-tune model to recover performance degradation


## Let's go through all of the stages in the pipeline

In [13]:
import numpy as np

In [14]:
from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, TextDataset
from torch.utils.data import DataLoader

# Initializing a GPT2 configuration
configuration = GPT2Config()

device = 'cuda'


model = GPT2LMHeadModel(configuration).from_pretrained("gpt2",
                                                return_dict=True,
                                                is_decoder=True)
outt = model.to(device)


In [15]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

We will validate a pre-trained language GPT-2 model on a validation part **Wikitext-103** dataset.

To create a validation dataset we should:
    
- download archieved dataset
- extract it
- wrap tran and test parts in TextDataset and further Dataloader class

In [16]:
from datasets import load_dataset
dataset = load_dataset("Salesforce/wikitext", "wikitext-103-v1")

In [17]:
import os

os.makedirs("data", exist_ok=True)

with open("data/wiki.valid.tokens", "w", encoding="utf-8") as f:
    for line in dataset['validation']['text']:
        f.write(line + "\n")

with open("data/wiki.test.tokens", "w", encoding="utf-8") as f:
    for line in dataset['test']['text']:
        f.write(line + "\n")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

dataset_valid = TextDataset(
    tokenizer=tokenizer,
    file_path="data/wiki.valid.tokens",
    block_size=512
)

dataset_test = TextDataset(
    tokenizer=tokenizer,
    file_path="data/wiki.test.tokens",
    block_size=512
)

eval_dataloader = DataLoader(dataset_valid, batch_size=10)

print(f"Размер валидационного набора: {len(dataset_valid)}")
print(f"Размер тестового набора: {len(dataset_test)}")

Размер валидационного набора: 484
Размер тестового набора: 553




Create a validation dataset and measure perplexity of regular GPT-2

Validate model and calculate a perplexity:

In [18]:
model.eval()
import torch

device = 'cuda'

def evaluate(model):
    losses = []
    eval_loss = 0.0
    perplexity = 0.0
    nb_eval_steps = 0
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        inputs, labels = (batch, batch)
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            outputs = model(inputs, labels=labels)
            lm_loss = outputs[0]
            eval_loss += lm_loss.mean().item()
            perplexity += torch.exp(torch.tensor(eval_loss))
            losses.append(eval_loss)
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps

    perplexity = torch.exp(torch.tensor(eval_loss))
    perplexity = perplexity
    return perplexity

In [19]:
perplexity = evaluate(model.to(device))
perplexity

Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

tensor(36.5398)

In [20]:
model.transformer.h[1].mlp.c_fc.weight.shape

torch.Size([768, 3072])

Let's generate text to ensure that model is good.

In [21]:
seq_len = 512

In [22]:
i = 0
for batch_idx, sample in tqdm(enumerate(eval_dataloader)):
    input_ids, label_ids = sample[0][:seq_len - 1], sample[0][1:seq_len]
    input_ids = input_ids.to(device=device)
    label_ids = label_ids.to(device=device)
    outputs = model(input_ids)
    print ("\n")
    print("input_ids[0][-20:]", tokenizer.decode(input_ids[-10:]))
    print ("\n")
    print("label_ids[0][-21:]", tokenizer.decode(input_ids[-9:]))
    print ("\n")
    print("output", tokenizer.decode(outputs.logits.argmax(dim=-1)[-10:]))
    #print("GENERATED", tokenizer.decode(outputs.logits.argmax(dim=-1)[0][-1:]))
    print("\n\n\n")
    i += 1
    if i > 10: break
    print ("\n\n")

0it [00:00, ?it/s]



input_ids[0][-20:]  americanus . The two species are very similar ,


label_ids[0][-21:] anus . The two species are very similar ,


output anus . It American species are closely similar in but









input_ids[0][-20:]  The Michigan Department of Transportation ( MDOT ) ,


label_ids[0][-21:]  Michigan Department of Transportation ( MDOT ) ,


output  current State of Transportation (MOT ) has which









input_ids[0][-20:]  these help downtown become more active during the day and


label_ids[0][-21:]  help downtown become more active during the day and


output  businesses to residents a attractive and the day and provide









input_ids[0][-20:]  turn of the last century . Architecture students from around


label_ids[0][-21:]  of the last century . Architecture students from around


output  of the 20 century . The is and the the









input_ids[0][-20:]  the American Civil War . 

 Dill Harris ,


label_ids[0][-21:]  American Civil War . 

 Dill Harris ,


output  Ci

In [23]:
for bit in [16, 8, 4]:
    model = GPT2LMHeadModel(configuration).from_pretrained("gpt2",
                                                return_dict=True,
                                                is_decoder=True)
    for i in [6, 8, 9, 10, 11]:
        fc_w = model.transformer.h[i].mlp.c_fc
        fp_w = model.transformer.h[i].mlp.c_proj
        model.transformer.h[i].mlp.c_fc = QALinear(fc_w.weight, fc_w.bias, bit)
        model.transformer.h[i].mlp.c_proj = QALinear(fp_w.weight, fp_w.bias, bit)
    perplexity = evaluate(model.to(device))
    print(f'perplexity для квантизации до {bit}bit: {perplexity}')

Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity для квантизации до 16bit: 234.6772003173828


Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity для квантизации до 8bit: 235.67401123046875


Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity для квантизации до 4bit: 143.77352905273438


perplexity для квантизации до 4bit - ниже чем до 8-16. из объяснений это нестабильность модели при такой квантизации

## Fine-tune the model

In [24]:
model = GPT2LMHeadModel(configuration).from_pretrained("gpt2",
                                                return_dict=True,
                                                is_decoder=True)
for i in [6, 8, 9, 10, 11]:
    fc_w = model.transformer.h[i].mlp.c_fc
    fp_w = model.transformer.h[i].mlp.c_proj
    model.transformer.h[i].mlp.c_fc = QALinear(fc_w.weight, fc_w.bias, 8)
    model.transformer.h[i].mlp.c_proj = QALinear(fp_w.weight, fp_w.bias, 8)
perplexity = evaluate(model.to(device))
print(f'perplexity для квантизации до 8bit: {perplexity}')

Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity для квантизации до 8bit: 235.67401123046875


In [25]:
train_dataloader = DataLoader(dataset_valid, batch_size=4)

In [26]:
from transformers import get_cosine_schedule_with_warmup

In [27]:

epochs = 10
gradient_accumulation_steps = 10
num_train_batches = len(train_dataloader)
cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.eos_token_id)
training_steps = num_train_batches // gradient_accumulation_steps * epochs
optimizer = torch.optim.Adam(model.parameters(), lr=1.25e-4)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=800, num_training_steps=training_steps)


for epoch in range(2):

    train_loss = 0
    epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
    for train_batch_idx, item in enumerate(epoch_iterator):
        optimizer.zero_grad() # in case there are unused accumulated gradients

        model.train()
        src, tgt = (item, item)
        src = src.to(device=device)
        tgt = tgt.to(device=device)


        output = model(input_ids=src, labels=tgt)
        loss = output.loss
        train_loss += loss.item() # before gradient accumulation step
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps
        loss.backward() # accumulate gradients

        if train_batch_idx % gradient_accumulation_steps == 0:
            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            optimizer.zero_grad()

        if train_batch_idx % 20 == 0:
            print ("loss", loss)
            perplexity = evaluate(model)
            print ("perplexity", perplexity)


0it [00:00, ?it/s]

loss tensor(0.5827, device='cuda:0', grad_fn=<DivBackward0>)


Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity tensor(364.2351)
loss tensor(0.6009, device='cuda:0', grad_fn=<DivBackward0>)


Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity tensor(362.7926)
loss tensor(0.6250, device='cuda:0', grad_fn=<DivBackward0>)


Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity tensor(360.9767)
loss tensor(0.5880, device='cuda:0', grad_fn=<DivBackward0>)


Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity tensor(356.6887)
loss tensor(0.5970, device='cuda:0', grad_fn=<DivBackward0>)


Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity tensor(352.2087)
loss tensor(0.5923, device='cuda:0', grad_fn=<DivBackward0>)


Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity tensor(345.7061)
loss tensor(0.5627, device='cuda:0', grad_fn=<DivBackward0>)


Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity tensor(338.3391)


0it [00:00, ?it/s]

loss tensor(0.5782, device='cuda:0', grad_fn=<DivBackward0>)


Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity tensor(333.5949)
loss tensor(0.5863, device='cuda:0', grad_fn=<DivBackward0>)


Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity tensor(324.2414)
loss tensor(0.6130, device='cuda:0', grad_fn=<DivBackward0>)


Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity tensor(314.8481)
loss tensor(0.5706, device='cuda:0', grad_fn=<DivBackward0>)


Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity tensor(304.6779)
loss tensor(0.5798, device='cuda:0', grad_fn=<DivBackward0>)


Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity tensor(293.7579)
loss tensor(0.5672, device='cuda:0', grad_fn=<DivBackward0>)


Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity tensor(283.1278)
loss tensor(0.5426, device='cuda:0', grad_fn=<DivBackward0>)


Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

perplexity tensor(271.4561)


In [28]:
perplexity = evaluate(model)
perplexity

Evaluating:   0%|          | 0/49 [00:00<?, ?it/s]

tensor(272.5520)

In [29]:
i = 0
for batch_idx, sample in tqdm(enumerate(eval_dataloader)):
    input_ids, label_ids = sample[0][:seq_len - 1], sample[0][1:seq_len]
    input_ids = input_ids.to(device=device)
    label_ids = label_ids.to(device=device)
    outputs = model(input_ids)
    print ("\n")
    print("input_ids[0][-20:]", tokenizer.decode(input_ids[-10:]))
    print ("\n")
    print("label_ids[0][-21:]", tokenizer.decode(input_ids[-9:]))
    print ("\n")
    print("output", tokenizer.decode(outputs.logits.argmax(dim=-1)[-10:]))
    #print("GENERATED", tokenizer.decode(outputs.logits.argmax(dim=-1)[0][-1:]))
    print("\n\n\n")
    i += 1
    if i > 10: break
    print ("\n\n")

0it [00:00, ?it/s]



input_ids[0][-20:]  americanus . The two species are very similar ,


label_ids[0][-21:] anus . The two species are very similar ,


output us , The most of are also close, and









input_ids[0][-20:]  The Michigan Department of Transportation ( MDOT ) ,


label_ids[0][-21:]  Michigan Department of Transportation ( MDOT ) ,


output  area State of the and Department- ) and and









input_ids[0][-20:]  these help downtown become more active during the day and


label_ids[0][-21:]  help downtown become more active during the day and


output  " to and a and and the time , the









input_ids[0][-20:]  turn of the last century . Architecture students from around


label_ids[0][-21:]  of the last century . Architecture students from around


output - the world year . The and and the the









input_ids[0][-20:]  the American Civil War . 

 Dill Harris ,


label_ids[0][-21:]  American Civil War . 

 Dill Harris ,


output  first and and .




, , the









input_ids[0][-2