# Prerequisites (results below)

In [1]:
from google.colab import drive
import sys


drive.mount('/content/drive')

sys.path.append('/content/drive/MyDrive/dnn_model_optimization')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install -q torchmetrics torchinfo

In [3]:
from torch.utils.data import DataLoader
from utils.torch_helpers import train_model, validate_model, warmup_torch_model
from utils.torch_model import CRNN
from utils.torch_pruning import prune_torch_model, get_layers_to_prune
from utils.data import decode_texts, load_data, OCRDataset
import torch
from torch import nn
from torchinfo import summary
from copy import deepcopy
import matplotlib.pyplot as plt

((train_imgs, train_abits), train_labels), ((val_imgs, val_abits), val_labels), alphabet = load_data('/content/drive/MyDrive/dnn_model_optimization/data', split=True)

train_dataset = OCRDataset(train_imgs, train_abits, train_labels)
val_dataset = OCRDataset(val_imgs, val_abits, val_labels)

train_loader = DataLoader(train_dataset, batch_size=128)
val_loader = DataLoader(val_dataset, batch_size=128)

Load the model once; the pruning experiments below work on copies of it, so we don't have to reload it from disk every time something goes wrong.

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = CRNN(len(alphabet))
model.load_state_dict(torch.load('/content/drive/MyDrive/dnn_model_optimization/weights/crnn_common_fields_.pt', map_location=torch.device(device)))
summary(model, input_size=[(32, 1, 32, 400), (32, 50, 2)], device=device, depth=1)

Layer (type:depth-idx)                   Output Shape              Param #
CRNN                                     [32, 50, 46]              --
├─Sequential: 1-1                        [32, 256, 1, 50]          425,856
├─LSTM: 1-2                              [32, 50, 256]             528,384
├─LSTM: 1-3                              [32, 50, 256]             526,336
├─Sequential: 1-4                        [32, 50, 46]              11,822
Total params: 1,492,398
Trainable params: 1,492,398
Non-trainable params: 0
Total mult-adds (G): 7.49
Input size (MB): 1.65
Forward/backward pass size (MB): 413.47
Params size (MB): 5.97
Estimated Total Size (MB): 421.09

In [5]:
print('Original model before warmup: ', dict(zip(['batch_time', 'loss', 'metric'], [round(e, 6) for e in validate_model(model, val_loader, alphabet, device=device)])))
warmup_torch_model(model, [(32, 1, 32, 400), (32, 50, 2)], device)
print('Original model after warmup: ', dict(zip(['batch_time', 'loss', 'metric'], [round(e, 6) for e in validate_model(model, val_loader, alphabet, device=device)])))

Original model before warmup:  {'batch_time': 0.004436, 'loss': 14.042188, 'metric': 0.049073}
Original model after warmup:  {'batch_time': 0.004552, 'loss': 14.042188, 'metric': 0.049073}


# Get ALL layers we want to prune (and others that will be affected).

Here I try to prune every layer, although, as the experiments showed, it is better to leave the RNN layers out of pruning. Partial pruning is shown below.

Also keep in mind that layers with optional trainable parameters, such as BatchNorm, have to be pruned as well so that the shapes stay consistent.
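
The notebook's actual pruning code (utils/torch_pruning.py) is not shown here. As a rough illustration of why BatchNorm has to follow the convolution, below is a minimal sketch of structured (filter) pruning for a Conv2d → BatchNorm2d → Conv2d chain. It assumes filters are ranked by the L1 norm of their weights, which may differ from the criterion the real helper uses; `l1_keep_indices` and `prune_conv_bn` are hypothetical names.

```python
import torch
from torch import nn


def l1_keep_indices(conv: nn.Conv2d, keep: int) -> torch.Tensor:
    # Assumed criterion: rank output filters by the L1 norm of their weights
    scores = conv.weight.detach().abs().sum(dim=(1, 2, 3))
    # Keep the `keep` strongest filters, in their original order
    return torch.argsort(scores, descending=True)[:keep].sort().values


@torch.no_grad()
def prune_conv_bn(conv, bn, next_conv, keep):
    """Hypothetical helper: prune `conv` to `keep` filters and fix up `bn` and `next_conv`."""
    idx = l1_keep_indices(conv, keep)

    # 1. Keep only the selected output filters of the convolution
    new_conv = nn.Conv2d(conv.in_channels, keep, conv.kernel_size, conv.stride,
                         conv.padding, bias=conv.bias is not None)
    new_conv.weight.copy_(conv.weight[idx])
    if conv.bias is not None:
        new_conv.bias.copy_(conv.bias[idx])

    # 2. BatchNorm has one entry per channel: slice affine params and running stats
    new_bn = nn.BatchNorm2d(keep, eps=bn.eps, momentum=bn.momentum)
    new_bn.weight.copy_(bn.weight[idx])
    new_bn.bias.copy_(bn.bias[idx])
    new_bn.running_mean.copy_(bn.running_mean[idx])
    new_bn.running_var.copy_(bn.running_var[idx])
    new_bn.num_batches_tracked.copy_(bn.num_batches_tracked)

    # 3. The next convolution loses the matching input channels
    new_next = nn.Conv2d(keep, next_conv.out_channels, next_conv.kernel_size,
                         next_conv.stride, next_conv.padding,
                         bias=next_conv.bias is not None)
    new_next.weight.copy_(next_conv.weight[:, idx])
    if next_conv.bias is not None:
        new_next.bias.copy_(next_conv.bias)
    return new_conv, new_bn, new_next


# Usage on layers shaped like the first two convs of the CRNN
conv1 = nn.Conv2d(1, 32, 3, padding='same')
bn1 = nn.BatchNorm2d(32)
conv2 = nn.Conv2d(32, 64, 3, padding='same')
c1, b1, c2 = prune_conv_bn(conv1, bn1, conv2, keep=24)

x = torch.randn(2, 1, 32, 400)
print(c2(b1(c1(x))).shape)  # torch.Size([2, 64, 32, 400])
```

Dropping a filter removes one output channel of the conv, one entry from every BatchNorm parameter and running statistic, and one input channel of the next conv. Leaving the BatchNorm at 32 channels while the conv emits 24 would fail with a shape error at the first forward pass.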

In [6]:
model_to_prune = deepcopy(model)

model_to_prune.to('cpu')

layers = get_layers_to_prune(model_to_prune)
layers_to_prune = list(filter(lambda x: isinstance(x, (nn.Conv2d, nn.BatchNorm2d, nn.LSTM, nn.Linear)), layers))
# layers_to_prune = list(filter(lambda x: isinstance(x, (nn.Conv2d, nn.BatchNorm2d)), layers))
layers_to_prune

[Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=same),
 BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same),
 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same),
 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same),
 BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=same),
 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 LSTM(258, 256, batch_first=True),
 LSTM(256, 256, batch_first=True),
 Linear(in_features=256, out_features=46, bias=True)]

The pruning function has been moved to utils/torch_pruning.py. It is written in pure PyTorch, so check it out if you're interested.

In [7]:
prune_torch_model(layers_to_prune)

In [8]:
summary(model_to_prune, input_size=[(32, 1, 32, 400), (32, 50, 2)], device=device, depth=1)

Layer (type:depth-idx)                   Output Shape              Param #
CRNN                                     [32, 50, 46]              --
├─Sequential: 1-1                        [32, 231, 1, 50]          348,867
├─LSTM: 1-2                              [32, 50, 231]             430,584
├─LSTM: 1-3                              [32, 50, 231]             428,736
├─Sequential: 1-4                        [32, 50, 46]              10,672
Total params: 1,218,859
Trainable params: 1,218,859
Non-trainable params: 0
Total mult-adds (G): 6.15
Input size (MB): 1.65
Forward/backward pass size (MB): 374.68
Params size (MB): 4.88
Estimated Total Size (MB): 381.21

As we can see, the model lost about 270K parameters during pruning (from 1,492,398 to 1,218,859, roughly 18%). Let's validate it, fine-tune it a little, and validate again.
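
The parameter counts in the summaries can be checked by hand, which also shows where the savings come from. The sketch below uses the standard parameter count of a single-layer, unidirectional `nn.LSTM` with biases (four gates, each with an input matrix, a recurrent matrix and two bias vectors). It assumes that the first LSTM's 258 inputs are the 256 conv features plus the 2 extra per-timestep features passed as the second model input (the `(32, 50, 2)` tensor); `lstm_params` and `linear_params` are illustrative helpers.

```python
def lstm_params(input_size: int, hidden_size: int) -> int:
    # weight_ih: (4*hidden, input), weight_hh: (4*hidden, hidden),
    # bias_ih and bias_hh: (4*hidden,) each
    return 4 * hidden_size * (input_size + hidden_size) + 2 * 4 * hidden_size


def linear_params(in_features: int, out_features: int) -> int:
    return in_features * out_features + out_features


# Original: 256 conv features + 2 extra features feed the first LSTM
print(lstm_params(258, 256), lstm_params(256, 256), linear_params(256, 46))
# -> 528384 526336 11822
# Full pruning: 231 conv features + 2 extra, LSTM hidden size also 231
print(lstm_params(233, 231), lstm_params(231, 231), linear_params(231, 46))
# -> 430584 428736 10672
# Partial pruning (later in the notebook): only the first LSTM's input shrinks
print(lstm_params(233, 256))
# -> 502784
```

These numbers match the torchinfo summaries: full pruning shrinks the conv output and both LSTM hidden sizes from 256 to 231 (and the Linear head follows), whereas partial pruning only narrows the first LSTM's input matrix.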

In [9]:
warmup_torch_model(model_to_prune, [(32, 1, 32, 400), (32, 50, 2)], device)
print('Full pruned model w/o tuning: ', dict(zip(['batch_time', 'loss', 'metric'], [round(e, 6) for e in validate_model(model_to_prune, val_loader, alphabet, device=device)])))

Full pruned model w/o tuning:  {'batch_time': 0.003171, 'loss': 16.876544, 'metric': 1.40559}


In [10]:
best_state, _ = train_model(model_to_prune,  alphabet, 2, train_loader, val_loader, lr=5e-4, device=device)
model_to_prune.load_state_dict(best_state)

Epoch 0, 103/103, loss: 22.635205, cer: 0.554532, val_loss: 8.270752, val_cer: 0.231835
Epoch 1, 103/103, loss: 1.505179, cer: 0.049289, val_loss: 3.5633, val_cer: 0.121446


<All keys matched successfully>

In [11]:
print('Full pruned model w/ tuning: ', dict(zip(['batch_time', 'loss', 'metric'], [round(e, 6) for e in validate_model(model_to_prune, val_loader, alphabet, device=device)])))

Full pruned model w/ tuning:  {'batch_time': 0.003432, 'loss': 14.100338, 'metric': 0.050777}


In [12]:
warmup_torch_model(model, [(32, 1, 32, 400), (32, 50, 2)], device)
print('Original model: ', dict(zip(['batch_time', 'loss', 'metric'], [round(e, 6) for e in validate_model(model, val_loader, alphabet, device=device)])))

Original model:  {'batch_time': 0.003332, 'loss': 14.042188, 'metric': 0.049073}


## Results for full pruning

* Original model: {'batch_time': 0.003332, 'loss': 14.042188, 'metric': 0.049073}
* After pruning (w/o tuning): {'batch_time': 0.003171, 'loss': 16.876544, 'metric': 1.40559}
* After pruning (w/ tuning): {'batch_time': 0.003432, 'loss': 14.100338, 'metric': 0.050777}

# Partial pruning (only convs)

In [13]:
model_to_prune = deepcopy(model)

model_to_prune.to('cpu')

layers = get_layers_to_prune(model_to_prune)
layers_to_prune = list(filter(lambda x: isinstance(x, (nn.Conv2d, nn.BatchNorm2d, nn.LSTM, nn.Linear)), layers))[:-2]
layers_to_prune

[Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=same),
 BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same),
 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same),
 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same),
 BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=same),
 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 LSTM(258, 256, batch_first=True)]

In [14]:
prune_torch_model(layers_to_prune)

In [15]:
summary(model_to_prune, input_size=[(32, 1, 32, 400), (32, 50, 2)], device=device, depth=1)

Layer (type:depth-idx)                   Output Shape              Param #
CRNN                                     [32, 50, 46]              --
├─Sequential: 1-1                        [32, 231, 1, 50]          348,867
├─LSTM: 1-2                              [32, 50, 256]             502,784
├─LSTM: 1-3                              [32, 50, 256]             526,336
├─Sequential: 1-4                        [32, 50, 46]              11,822
Total params: 1,389,809
Trainable params: 1,389,809
Non-trainable params: 0
Total mult-adds (G): 6.42
Input size (MB): 1.65
Forward/backward pass size (MB): 375.32
Params size (MB): 5.56
Estimated Total Size (MB): 382.53

In [16]:
warmup_torch_model(model_to_prune, [(32, 1, 32, 400), (32, 50, 2)], device)
print('Partial pruned model w/o tuning: ', dict(zip(['batch_time', 'loss', 'metric'], [round(e, 6) for e in validate_model(model_to_prune, val_loader, alphabet, device=device)])))

Partial pruned model w/o tuning:  {'batch_time': 0.003353, 'loss': 14.052087, 'metric': 0.055323}


In [17]:
best_state, _ = train_model(model_to_prune,  alphabet, 2, train_loader, val_loader, lr=5e-4, device=device)
model_to_prune.load_state_dict(best_state)
print('Partial pruned model w/ tuning: ', dict(zip(['batch_time', 'loss', 'metric'], [round(e, 6) for e in validate_model(model_to_prune, val_loader, alphabet, device=device)])))

Epoch 0, 103/103, loss: 0.217662, cer: 0.038555, val_loss: 3.066838, val_cer: 0.104845
Epoch 1, 103/103, loss: 0.142083, cer: 0.037157, val_loss: 3.587152, val_cer: 0.116859
Partial pruned model w/ tuning:  {'batch_time': 0.003161, 'loss': 14.0428, 'metric': 0.049154}


In [18]:
warmup_torch_model(model, [(32, 1, 32, 400), (32, 50, 2)], device)
print('Original model: ', dict(zip(['batch_time', 'loss', 'metric'], [round(e, 6) for e in validate_model(model, val_loader, alphabet, device=device)])))

Original model:  {'batch_time': 0.003282, 'loss': 14.042188, 'metric': 0.049073}


## Results for partial pruning

* Original model: {'batch_time': 0.003282, 'loss': 14.042188, 'metric': 0.049073}
* After pruning (w/o tuning): {'batch_time': 0.003353, 'loss': 14.052087, 'metric': 0.055323}
* After pruning (w/ tuning): {'batch_time': 0.003161, 'loss': 14.0428, 'metric': 0.049154}

# Overall results

<pre>
+---------------------------------+-------------+-------------+-------------+------------+
|      Model                      | #Params     |   Val loss  | Val CER     | Batch time |
+---------------------------------+-------------+-------------+-------------+------------+
| Original model                  | 1.492M      | 14.042188   | 0.049073    | 0.003332   |
+---------------------------------+-------------+-------------+-------------+------------+
| Full pruning (before tuning)    | 1.219M      | 16.876544   | 1.40559     | 0.003171   |
+---------------------------------+-------------+-------------+-------------+------------+
| Full pruning (after tuning)     | 1.219M      | 14.100338   | 0.050777    | 0.003432   |
+---------------------------------+-------------+-------------+-------------+------------+
| Partial pruning (before tuning) | 1.390M      | 14.052087   | 0.055323    | 0.003353   |
+---------------------------------+-------------+-------------+-------------+------------+
| Partial pruning (after tuning)  | 1.390M      | 14.0428     | 0.049154    | 0.003161   |
+---------------------------------+-------------+-------------+-------------+------------+
</pre>
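
To put the table in relative terms, here is a short calculation over the logged numbers (parameter counts from the torchinfo summaries, CER after tuning). `relative_change` is just an illustrative helper.

```python
ORIGINAL = {"params": 1_492_398, "cer": 0.049073}
PRUNED = {
    "Full pruning": {"params": 1_218_859, "cer": 0.050777},
    "Partial pruning": {"params": 1_389_809, "cer": 0.049154},
}


def relative_change(variant):
    # Fraction of parameters removed and absolute CER change w.r.t. the original
    removed = 1 - variant["params"] / ORIGINAL["params"]
    cer_delta = variant["cer"] - ORIGINAL["cer"]
    return removed, cer_delta


for name, variant in PRUNED.items():
    removed, cer_delta = relative_change(variant)
    print(f"{name}: {removed:.1%} fewer params, CER change {cer_delta:+.6f}")
# -> Full pruning: 18.3% fewer params, CER change +0.001704
# -> Partial pruning: 6.9% fewer params, CER change +0.000081
```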

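As we can see, partial pruning works better, presumably because it puts less stress on the network: after tuning, its validation CER (0.049154) is practically identical to the original model's (0.049073), while the fully pruned model remains slightly worse (0.050777), and even before tuning it loses far less accuracy. The recurrent part does not have to be restructured either: only the input size of the first LSTM shrinks (from 258 to 233), while the hidden-state size stays at 256, so the recurrent weights, the second LSTM and the output layer keep their original shapes. Note that the batch times of all pruned variants (0.003161 to 0.003432) lie within the spread measured for the original model across runs (0.003282 to 0.004552), so these measurements do not show a clear speedup.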
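
To illustrate the last point, here is a minimal sketch of pruning only the input side of a single-layer `nn.LSTM`. The helper `prune_lstm_input` is hypothetical, not the notebook's own code. It assumes the 2 extra features occupy the last two input positions (indices 256 and 257) and, purely for illustration, that the first 231 conv channels are the ones kept; in practice the kept indices would come from the conv pruning step.

```python
import torch
from torch import nn


@torch.no_grad()
def prune_lstm_input(lstm: nn.LSTM, keep_idx: torch.Tensor) -> nn.LSTM:
    # Sketch for a single-layer, unidirectional LSTM with biases
    assert lstm.num_layers == 1 and not lstm.bidirectional and lstm.bias
    new = nn.LSTM(len(keep_idx), lstm.hidden_size, batch_first=lstm.batch_first)
    # Only weight_ih depends on the input size: keep the matching columns
    new.weight_ih_l0.copy_(lstm.weight_ih_l0[:, keep_idx])
    # Recurrent weights and biases depend only on hidden_size and are reused as-is
    new.weight_hh_l0.copy_(lstm.weight_hh_l0)
    new.bias_ih_l0.copy_(lstm.bias_ih_l0)
    new.bias_hh_l0.copy_(lstm.bias_hh_l0)
    return new


lstm = nn.LSTM(258, 256, batch_first=True)
# Assumed layout: 256 conv features followed by 2 extra features
keep_idx = torch.cat([torch.arange(231), torch.tensor([256, 257])])
pruned = prune_lstm_input(lstm, keep_idx)

print(pruned)  # LSTM(233, 256, batch_first=True)
print(sum(p.numel() for p in pruned.parameters()))  # 502784

# With the dropped inputs set to zero, both LSTMs produce the same output
x_small = torch.randn(2, 5, 233)
x_full = torch.zeros(2, 5, 258)
x_full[..., keep_idx] = x_small
with torch.no_grad():
    print(torch.allclose(lstm(x_full)[0], pruned(x_small)[0], atol=1e-5))  # True
```

The parameter count matches the first LSTM in the partial-pruning summary (502,784), which is consistent with the recurrent weights being kept intact.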