<a href="https://colab.research.google.com/github/ahsanabbas123/dynamic-quantization-LSTM/blob/main/Quantization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%matplotlib inline

In [2]:
# imports
import os
from io import open
import time
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.quantization

### Preprocess 

In [3]:
"""
We define a dictionary and corpus class to hold our WikiText-2 dataset
"""

class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r', encoding="utf8") as f:
            idss = []
            for line in f:
                words = line.split() + ['<eos>']
                ids = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids).type(torch.int64))
            ids = torch.cat(idss)

        return ids


In [4]:
# tokenise and preprocess the dataset
corpus = Corpus('wikitext-2')

In [5]:
# setting variable values

bptt = 35 # sequence length
log_interval = 200 # training output logging frequency
clip = .25 # gradient clipping
lr = 20 # learning rate for optimizer

best_val_loss = None # best model amongst epochs
epochs = 10
batch_size = 32
dropout = 0.2
save = 'model.pt' # path to save our model
criterion = nn.CrossEntropyLoss() 
eval_batch_size = 10

device = torch.device("cuda")

In [6]:
# divide the dataset into batches and trim off excess tokens
def batchify(data, bs):
    nbatch = data.size(0) // bs
    data = data.narrow(0, 0, nbatch * bs)
    data = data.view(bs, -1).t().contiguous()
    return data 

"""
sending the train and val data to GPU but not test data as quantization inference works only on CPU
and we need test data for comparing the loss of normal model vs quantized model at the end
"""
train_data = batchify(corpus.train, batch_size).to(device)
val_data = batchify(corpus.valid, eval_batch_size).to(device)
test_data = batchify(corpus.test, eval_batch_size)

### Model

In [7]:
"""
This is a basic LSTM based model with an encoder and decoder module
"""

class LSTMModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(LSTMModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output)
        return decoded, hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))

In [8]:
ntokens = len(corpus.dictionary)

model = LSTMModel(
    ntoken = ntokens,
    ninp = 512,
    nhid = 256,
    nlayers = 5,
)

# sending the model to GPU for training
model.to(device)
print(model)

LSTMModel(
  (drop): Dropout(p=0.5, inplace=False)
  (encoder): Embedding(33278, 512)
  (rnn): LSTM(512, 256, num_layers=5, dropout=0.5)
  (decoder): Linear(in_features=256, out_features=33278, bias=True)
)


### Utility functions

In [9]:
# Evaluation functions
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].view(-1)
    return data, target

# taken from the pytorch documentation
def repackage_hidden(h):
  """Wraps hidden states in new Tensors, to detach them from their history."""
  if isinstance(h, torch.Tensor):
      return h.detach()
  else:
      return tuple(repackage_hidden(v) for v in h)

def evaluate(model_, data_source):
    # Turn on evaluation mode which disables dropout.
    model_.eval()
    total_loss = 0.
    hidden = model_.init_hidden(eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output, hidden = model_(data, hidden)
            hidden = repackage_hidden(hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)

In [10]:
def train(model, corpus, train_data, bptt, criterion, clip, lr, log_interval, batch_size, epoch):
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0.
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)

        model.zero_grad()

        hidden = repackage_hidden(hidden)
        output, hidden = model(data, hidden)
        output_flat = output.view(-1, ntokens)
        
        loss = criterion(output_flat, targets)
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        for p in model.parameters():
            p.data.add_(p.grad, alpha=-lr)

        total_loss += loss.item()

        # log the output
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // bptt, lr,
                elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

### Training

In [11]:
def train_epochs(model, corpus, train_data, val_data, bptt, criterion, clip, lr, log_interval, batch_size, epochs, save):
    # Loop over epochs.
    best_val_loss = None
    try:
        for epoch in range(1, epochs+1):
            epoch_start_time = time.time()
            train(model, corpus, train_data, bptt, criterion, clip, lr, log_interval, batch_size, epoch)
            val_loss = evaluate(model, val_data)
            print('-' * 89)
            print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                    'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                            val_loss, math.exp(val_loss)))
            print('-' * 89)
            # Save the model if the validation loss is the best we've seen so far.
            if not best_val_loss or val_loss < best_val_loss:
                # with open(save, 'wb') as f:
                #     torch.save(model, f)
                best_val_loss = val_loss
                best_model = model
            else:
                # Anneal the learning rate if no improvement has been seen in the validation dataset.
                lr /= 4.0
    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')

    return best_model

In [12]:
model = train_epochs(model, corpus, train_data, val_data, bptt, criterion, clip, lr, log_interval, batch_size, epochs, save)

| epoch   1 |   200/ 1864 batches | lr 20.00 | ms/batch 41.76 | loss  7.74 | ppl  2301.78
| epoch   1 |   400/ 1864 batches | lr 20.00 | ms/batch 40.58 | loss  7.33 | ppl  1523.39
| epoch   1 |   600/ 1864 batches | lr 20.00 | ms/batch 40.56 | loss  7.22 | ppl  1361.55
| epoch   1 |   800/ 1864 batches | lr 20.00 | ms/batch 40.57 | loss  7.19 | ppl  1332.12
| epoch   1 |  1000/ 1864 batches | lr 20.00 | ms/batch 40.58 | loss  7.18 | ppl  1317.58
| epoch   1 |  1200/ 1864 batches | lr 20.00 | ms/batch 40.57 | loss  7.15 | ppl  1276.18
| epoch   1 |  1400/ 1864 batches | lr 20.00 | ms/batch 40.53 | loss  7.17 | ppl  1295.56
| epoch   1 |  1600/ 1864 batches | lr 20.00 | ms/batch 40.52 | loss  7.14 | ppl  1259.36
| epoch   1 |  1800/ 1864 batches | lr 20.00 | ms/batch 40.50 | loss  7.14 | ppl  1262.25
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 79.58s | valid loss  6.92 | valid ppl  1012.49
--------------------------

### Dynamic Quantization

We quantize the LSTM and Linear layers of our model using the PyTorch API

In [13]:
model.to("cpu")
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
print(quantized_model)

LSTMModel(
  (drop): Dropout(p=0.5, inplace=False)
  (encoder): Embedding(33278, 512)
  (rnn): DynamicQuantizedLSTM(512, 256, num_layers=5, dropout=0.5)
  (decoder): DynamicQuantizedLinear(in_features=256, out_features=33278, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)


### Memory and Performance Comparison 

In [14]:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

In [15]:
print_size_of_model(model)
print_size_of_model(quantized_model)

Size (MB): 113.945551
Size (MB): 79.73967


In [16]:
torch.set_num_threads(1)
def time_model_evaluation(model, test_data):
    s = time.time()
    loss = evaluate(model, test_data)
    elapsed = time.time() - s
    print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))


In [17]:
time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)

loss: 5.152
elapsed time (seconds): 112.1
loss: 5.151
elapsed time (seconds): 83.7
