# Quantizing RNN Models

In this example, we show how to quantize recurrent models.  
Using a pretrained model `model.RNNModel`, we convert the built-in pytorch implementation of LSTM to our own, modular implementation.  
The pretrained model was generated with:  
```time python3 main.py --cuda --emsize 1500 --nhid 1500 --dropout 0.65 --tied --wd=1e-6```  
The reason we replace the LSTM that is because the inner operations in the pytorch implementation are not accessible to us, but we still want to quantize these operations. <br />
Afterwards we can try different techniques to quantize the whole model.  

_NOTE_: We use `tqdm` to plot progress bars, since it's not in `requirements.txt` you should install it using 
`pip install tqdm`.

In [1]:
from model import PerformanceRNN
from distiller_model import DistillerPerformanceRNN
import torch
from torch import nn
import distiller
from distiller.modules import DistillerLSTM as LSTM
from distiller.modules import convert_model_to_distiller_lstm
from tqdm import tqdm # for pretty progress bar
import numpy as np

### Preprocess the data:

Skip this, my data is already preprocessed.

In [None]:
corpus = Corpus('./data/wikitext-2/')

In [None]:
def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)
device = 'cuda:0'
batch_size = 20
eval_batch_size = 10
train_data = batchify(corpus.train, batch_size)
val_data = batchify(corpus.valid, eval_batch_size)
test_data = batchify(corpus.test, eval_batch_size)

### Loading the model and converting to our own implementation.

In [None]:
# Original (keep for sake of comparison)
rnn_model = torch.load('./checkpoint.pth.tar.best')
rnn_model = rnn_model.to(device)
rnn_model

In [2]:
assert torch.cuda.is_available()
device = 'cuda:0'
sess_path = "save/LSTM_model.sess"
state = torch.load(sess_path)
rnn_model = PerformanceRNN(**state['model_config']).to(device)
rnn_model.load_state_dict(state['model_state'])

In [3]:
# Trying out something I saw in the source code.
man_model = convert_model_to_distiller_lstm(rnn_model)

In [4]:
man_model.gru

DistillerLSTM(512, 512, num_layers=3, dropout=0.30, bidirectional=False)

Here we convert the pytorch LSTM implementation to our own, by calling `LSTM.from_pytorch_impl`:

In [None]:
# Original, I think I did this with convert_model_to_distiller_lstm above.
def manual_model(pytorch_model_: 'RNNModel'):
    nlayers, ninp, nhid, ntoken, tie_weights = \
        pytorch_model_.nlayers, \
        pytorch_model_.ninp, \
        pytorch_model_.nhid, \
        pytorch_model_.ntoken, \
        pytorch_model_.tie_weights

    model = DistillerRNNModel(nlayers=nlayers, ninp=ninp, nhid=nhid, ntoken=ntoken, tie_weights=tie_weights).to(device)
    model.eval()
    model.encoder.weight = nn.Parameter(pytorch_model_.encoder.weight.clone().detach())
    model.decoder.weight = nn.Parameter(pytorch_model_.decoder.weight.clone().detach())
    model.decoder.bias = nn.Parameter(pytorch_model_.decoder.bias.clone().detach())
    model.rnn = LSTM.from_pytorch_impl(pytorch_model_.rnn)

    return model

man_model = manual_model(rnn_model)
torch.save(man_model, 'manual.checkpoint.pth.tar')
man_model

### Batching the data for evaluation:

In [None]:
# Original
sequence_len = 35
def get_batch(source, i):
    seq_len = min(sequence_len, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].view(-1)
    return data, target

hidden = rnn_model.init_hidden(eval_batch_size)
data, targets = get_batch(test_data, 0)

In [None]:
sequence_len = 35
def get_batch(source, i):
    seq_len = min(sequence_len, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].view(-1)
    return data, target

batch_size = 20
eval_batch_size = 10

hidden = rnn_model.init_hidden(eval_batch_size)
data, targets = get_batch(test_data, 0)

### Check that the convertion has succeeded:

In [5]:
rnn_model.eval()

PerformanceRNN(
  (inithid_fc): Linear(in_features=32, out_features=1536, bias=True)
  (inithid_fc_activation): Tanh()
  (event_embedding): Embedding(240, 240)
  (concat_input_fc): Linear(in_features=265, out_features=512, bias=True)
  (concat_input_fc_activation): LeakyReLU(negative_slope=0.1, inplace)
  (gru): DistillerLSTM(512, 512, num_layers=3, dropout=0.30, bidirectional=False)
  (output_fc): Linear(in_features=1536, out_features=240, bias=True)
  (output_fc_activation): Softmax()
)

In [6]:
man_model.eval()

PerformanceRNN(
  (inithid_fc): Linear(in_features=32, out_features=1536, bias=True)
  (inithid_fc_activation): Tanh()
  (event_embedding): Embedding(240, 240)
  (concat_input_fc): Linear(in_features=265, out_features=512, bias=True)
  (concat_input_fc_activation): LeakyReLU(negative_slope=0.1, inplace)
  (gru): DistillerLSTM(512, 512, num_layers=3, dropout=0.30, bidirectional=False)
  (output_fc): Linear(in_features=1536, out_features=240, bias=True)
  (output_fc_activation): Softmax()
)

In [None]:
y_t, h_t = rnn_model(data, hidden)
y_p, h_p = man_model(data, hidden)

print("Max error in y: %f" % (y_t-y_p).abs().max().item())

### Defining the evaluation:

In [None]:
criterion = nn.CrossEntropyLoss()
def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)
    

def evaluate(model, data_source):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0.
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(eval_batch_size)
    with torch.no_grad():
        # The line below was fixed as per: https://github.com/pytorch/examples/issues/214
        for i in tqdm(range(0, data_source.size(0), sequence_len)):
            data, targets = get_batch(data_source, i)
            output, hidden = model(data, hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
            hidden = repackage_hidden(hidden)
    return total_loss / len(data_source)

# Quantizing the model:

## Collect activation statistics:

The model uses activation statistics to determine how big the quantization range is. The bigger the range - the larger the round off error after quantization which leads to accuracy drop.  
Our goal is to minimize the range s.t. it contains the absolute most of our data.  
After that, we divide the range into chunks of equal size, according to the number of bits, and transform the data according to this scale factor.  
Read more on scale factor calculation [in our docs](https://nervanasystems.github.io/distiller/algo_quantization.html).

The class `QuantCalibrationStatsCollector` collects the statistics for defining the range $r = max - min$.  

Each forward pass, the collector records the values of inputs and outputs, for each layer:
- absolute over all batches min, max (stored in `min`, `max`)
- average over batches, per batch min, max (stored in `avg_min`, `avg_max`)
- mean
- std
- shape of output tensor  

All these values can be used to define the range of quantization, e.g. we can use the absolute `min`, `max` to define the range.

In [None]:
# Original.
import os
from distiller.data_loggers import QuantCalibrationStatsCollector, collector_context

man_model = torch.load('./manual.checkpoint.pth.tar')
distiller.utils.assign_layer_fq_names(man_model)
collector = QuantCalibrationStatsCollector(man_model)

if not os.path.isfile('manual_lstm_pretrained_stats.yaml'):
    with collector_context(collector) as collector:
        val_loss = evaluate(man_model, val_data)
        collector.save('manual_lstm_pretrained_stats.yaml')

In [None]:
assert torch.cuda.is_available()
device = 'cuda:0'
sess_path = "save/LSTM_model.sess"
state = torch.load(sess_path)
man_model = PerformanceRNN(**state['model_config']).to(device)
rnn_model.load_state_dict(state['model_state'])

Check that `man_model` has the same weights as `rnn_model`.

In [9]:
man_model.output_fc.weight

Parameter containing:
tensor([[-0.1581,  0.1355,  0.1228,  ...,  0.0996, -0.0048,  0.2377],
        [-0.1011,  0.0526,  0.0331,  ...,  0.0476, -0.0766, -0.0931],
        [-0.0936,  0.1523,  0.0625,  ...,  0.0372, -0.0348, -0.0918],
        ...,
        [-0.0509,  0.0499,  0.0362,  ...,  0.1872, -0.0104, -0.0966],
        [ 0.0243, -0.0208,  0.0151,  ...,  0.2488, -0.0781, -0.1401],
        [-0.0366,  0.1303,  0.0476,  ...,  0.1466, -0.0225, -0.0725]],
       device='cuda:0', requires_grad=True)

In [12]:
rnn_model.output_fc.weight

Parameter containing:
tensor([[-0.1581,  0.1355,  0.1228,  ...,  0.0996, -0.0048,  0.2377],
        [-0.1011,  0.0526,  0.0331,  ...,  0.0476, -0.0766, -0.0931],
        [-0.0936,  0.1523,  0.0625,  ...,  0.0372, -0.0348, -0.0918],
        ...,
        [-0.0509,  0.0499,  0.0362,  ...,  0.1872, -0.0104, -0.0966],
        [ 0.0243, -0.0208,  0.0151,  ...,  0.2488, -0.0781, -0.1401],
        [-0.0366,  0.1303,  0.0476,  ...,  0.1466, -0.0225, -0.0725]],
       device='cuda:0', requires_grad=True)

Check that `man_model` is on the GPU.

In [11]:
next(man_model.parameters()).is_cuda

True

In [16]:
# My version.
import os
from distiller.data_loggers import QuantCalibrationStatsCollector, collector_context

# Commented line is probably not necessary.
#man_model = torch.load('./manual.checkpoint.pth.tar')
distiller.utils.assign_layer_fq_names(man_model)
collector = QuantCalibrationStatsCollector(man_model)

# Random numbers.
batch_size = 64
max_len = 100

if not os.path.isfile('performance_rnn_pretrained_stats.yaml'):
    with collector_context(collector) as collector:
        init = torch.randn(batch_size, man_model.init_dim).to(device)
        output = man_model.generate(init, max_len)
        collector.save('performance_rnn_pretrained_stats.yaml')

In [14]:
man_model.init_dim

32

## Quantize Model:
  
We quantize the model after the training has completed.  
Here we check the baseline model perplexity, to have an idea how good the quantization is.

In [None]:
from distiller.quantization import PostTrainLinearQuantizer, LinearQuantMode
from copy import deepcopy

# Load and evaluate the baseline model.
man_model = torch.load('./manual.checkpoint.pth.tar')
val_loss = evaluate(man_model, val_data)
print('val_loss:%8.2f\t|\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))

Now we do our magic - __Quantizing the model__.  
The quantizer replaces the layers in out model with their quantized versions.  
We can see that our model has changed:

In [19]:
from distiller.quantization import PostTrainLinearQuantizer, LinearQuantMode
from copy import deepcopy
# Define the quantizer
quantizer = PostTrainLinearQuantizer(
    deepcopy(man_model),
    model_activation_stats='performance_rnn_pretrained_stats.yaml')

# Quantizer magic:
quantizer.prepare_model()

In [20]:
quantizer.model

PerformanceRNN(
  (inithid_fc): RangeLinearQuantParamLayerWrapper(
    mode=SYMMETRIC, num_bits_acts=8, num_bits_params=8, num_bits_accum=32, clip_acts=NONE, per_channel_wts=False, scale_approx_mult_bits=None
    preset_activation_stats=True
    w_scale=854.0115, w_zero_point=0.0000
    in_scale=38.1417, in_zero_point=0.0000
    out_scale=92.7936, out_zero_point=0.0000
    (wrapped_module): Linear(in_features=32, out_features=1536, bias=True)
  )
  (inithid_fc_activation): Tanh()
  (event_embedding): RangeLinearEmbeddingWrapper(
    (wrapped_module): Embedding(240, 240)
  )
  (concat_input_fc): RangeLinearQuantParamLayerWrapper(
    mode=SYMMETRIC, num_bits_acts=8, num_bits_params=8, num_bits_accum=32, clip_acts=NONE, per_channel_wts=False, scale_approx_mult_bits=None
    preset_activation_stats=True
    w_scale=91.4019, w_zero_point=0.0000
    in_scale=127.0000, in_zero_point=0.0000
    out_scale=47.6283, out_zero_point=0.0000
    (wrapped_module): Linear(in_features=265, out_features

In [None]:
val_loss = evaluate(quantizer.model.to(device), val_data)

In [None]:
print('val_loss:%8.2f\t|\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))

As we can see here, the perplexity has increased much - meaning our quantization has damaged the accuracy of our model.  
Let's try quantizing each channel separately, and making the range of the quantization asymmetric.  
Also - we replaced the `min`, `max` boundaries manually in the file.  
The idea is - the quantizer takes the absolute `min`, `max` boundaries by default, and in the original file many of the activations had a very large range that makes our quants very big - while we want to minimize their size since each quant corresponds to a roundoff error.  
The activations in every LSTM are either `sigmoid` or `tanh`, and since these are bounded respectively by
$[0,1]$, $[-1,1]$ and they saturate very quickly - we can clip the inputs to be between in the range of $[-6,6]$.

In [None]:
quantizer = PostTrainLinearQuantizer(
    deepcopy(man_model),
    model_activation_stats='./manual_lstm_pretrained_stats_new.yaml',
    mode=LinearQuantMode.ASYMMETRIC_SIGNED,
    per_channel_wts=True
)
quantizer.prepare_model()
quantizer.model

In [None]:
val_loss = evaluate(quantizer.model.to(device), val_data)
print('val_loss:%8.2f\t|\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))

A tiny bit better, but still no good. Let us try the half precision version of the model:

In [None]:
model_fp16 = deepcopy(man_model).half()
val_loss = evaluate(model_fp16, val_data)
print('val_loss: %8.6f\t|\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))

The result is very close to our original model! That means that the roundoff when quantizing lineary is what hurts our accuracy. Let's try then quantizing everything except elemtentwise operations, as stated in 
[`Effective Quantization Methods for Recurrent Neural Networks`](https://arxiv.org/abs/1611.10176) :

In [None]:
overrides_yaml = """
.*eltwise.*:
    fp16: true
encoder:
    fp16: true
decoder:
    fp16: true
"""
overrides = distiller.utils.yaml_ordered_load(overrides_yaml)
quantizer = PostTrainLinearQuantizer(
    deepcopy(man_model),
    model_activation_stats='./manual_lstm_pretrained_stats_new.yaml',
    mode=LinearQuantMode.ASYMMETRIC_SIGNED,
    overrides=overrides,
    per_channel_wts=True
)
quantizer.prepare_model()
val_loss = evaluate(quantizer.model.to(device), val_data)
print('val_loss:%8.6f\t|\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))

In [None]:
quantizer.model

The accuracy is still holding up very well, even though we quantized the inner linear layers!  
Now, lets try to choose different boundaries for `min`, `max` -  
Instead of using absolute ones, we take the average of all batches (`avg_min`, `avg_max`), which is an indication of where usually most of the boundaries lie. This is done by specifying the `clip_acts` parameter to `ClipMode.AVG` or `"AVG"` in the quantizer ctor:

In [None]:
overrides_yaml = """
encoder:
    fp16: true
decoder:
    fp16: true
"""
overrides = distiller.utils.yaml_ordered_load(overrides_yaml)
quantizer = PostTrainLinearQuantizer(
    deepcopy(man_model),
    model_activation_stats='./manual_lstm_pretrained_stats.yaml',
    mode=LinearQuantMode.ASYMMETRIC_SIGNED,
    overrides=overrides,
    per_channel_wts=True,
    clip_acts="AVG"
)
quantizer.prepare_model()
val_loss = evaluate(quantizer.model.to(device), val_data)
print('val_loss:%8.6f\t|\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))

Great! Even though we quantized all of the layers except the embedding and the decoder - we got almost no accuracy penalty. Lets try quantizing them as well:

In [None]:
quantizer = PostTrainLinearQuantizer(
    deepcopy(man_model),
    model_activation_stats='./manual_lstm_pretrained_stats_new.yaml',
    mode=LinearQuantMode.ASYMMETRIC_SIGNED,
    per_channel_wts=True,
    clip_acts="AVG"
)
quantizer.prepare_model()
val_loss = evaluate(quantizer.model.to(device), val_data)
print('val_loss:%8.6f\t|\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))

In [None]:
quantizer.model

Here we see that sometimes quantizing with the right boundaries gives better results than actually using floating point operations (even though they are half precision). 

## Conclusion

Choosing the right boundaries for quantization  was crucial for achieving almost no degradation in accrucay of LSTM.  
  
Here we showed how to use the distiller quantization API to quantize an RNN model, by converting the pytorch implementation into a modular one and then quantizing each layer separately.