# Pruning BERT/SQuAD Model
This notebook shows how to reduce the size of a model by pruning its parameters. It assumes 
that a trained model already exists in the ```Models``` directory. Please refer to the notebook [Question Answering (BERT/SQuAD)](BertSquad.ipynb) for more info about training and using a BERT/SQuAD model.

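If you want to prune a low-rank model, first use the [Reducing number of parameters of BERT/SQuAD Model](BertSquad-Reduce.ipynb) notebook
to reduce the number of parameters in ```BERT/SQuAD```. The cells below load ```Models/BertSquadRR.fbm```, a model that was reduced and then retrained.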

## Load and evaluate the original pretrained model

In [1]:
from fireball import Model, myPrint
from fireball.datasets.squad import SquadDSet
import time, os

gpus = "upto4"    # Use up to 4 GPUs

# Test dataset (SQuAD version 1)
testDs = SquadDSet.makeDatasets("Test", batchSize=128, version=1)

orgFileName = "Models/BertSquadRR.fbm"    # Reduced - Retrained

# Load the model, print its layers, and evaluate it on the test dataset
model = Model.makeFromFile(orgFileName, testDs=testDs, gpus=gpus)
model.printLayersInfo()
model.initSession()
results = model.evaluate()

Initializing tokenizer from "/data/SQuAD/vocab.txt" ... Done. (Vocab Size: 30522)

Reading from "Models/BertSquadRR.fbm" ... Done.
Creating the fireball model "Bert-SQuAD" ... Done.

Scope            InShape       Comments                 OutShape      Activ.   Post Act.        # of Params
---------------  ------------  -----------------------  ------------  -------  ---------------  -----------
IN_EMB           ≤512 2                                 ≤512 768      None                      23,835,648 
S1_L1_LN         ≤512 768                               ≤512 768      None     DO:0.1           1,536      
S2_L1_BERT       ≤512 768      768/3072, 12 heads       ≤512 768      GELU                      5,097,216  
S2_L2_BERT       ≤512 768      768/3072, 12 heads       ≤512 768      GELU                      5,177,088  
S2_L3_BERT       ≤512 768      768/3072, 12 heads       ≤512 768      GELU                      5,207,808  
S2_L4_BERT       ≤512 768      768/3072, 12 heads       ≤512 

## Pruning the model
Here we prune the model using the [pruneModel](https://interdigitalinc.github.io/Fireball/html/source/model.html#fireball.model.Model.pruneModel) class method of the ```Model``` class. It reads the original model file and saves the pruned model to a new file.
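
For intuition, an MSE upper bound such as `mseUb` can be read as a budget on how much a tensor may change when its small values are set to zero. The sketch below is an assumption for illustration only, not Fireball's actual algorithm: the helper `prune_to_mse_bound` is hypothetical, and the random tensor merely stands in for a weight matrix. It uses NumPy to zero the smallest-magnitude entries while the mean squared error against the original stays within the bound.

```python
import numpy as np

def prune_to_mse_bound(tensor, mse_ub):
    """Hypothetical helper (not part of Fireball): zero the smallest-magnitude
    entries of `tensor` while the MSE against the original stays <= mse_ub."""
    flat = tensor.ravel()
    order = np.argsort(np.abs(flat))                 # smallest magnitudes first
    # MSE that results from zeroing the first k entries in `order`, for every k
    cum_mse = np.cumsum(flat[order] ** 2) / flat.size
    # Largest k whose cumulative MSE is still within the bound
    num_pruned = int(np.searchsorted(cum_mse, mse_ub, side="right"))
    pruned = flat.copy()
    pruned[order[:num_pruned]] = 0.0
    return pruned.reshape(tensor.shape), num_pruned

# Random stand-in for a 768x768 weight matrix (illustrative values only)
rng = np.random.default_rng(0)
weights = rng.normal(0.0, 0.05, size=(768, 768))

pruned, num_pruned = prune_to_mse_bound(weights, mse_ub=5e-5)
mse = float(np.mean((weights - pruned) ** 2))
print(f"Zeroed {num_pruned:,} of {weights.size:,} values, MSE = {mse:.2e}")
```

A larger bound lets more values be zeroed at the cost of a larger deviation from the original tensor, which is the trade-off the `mseUb=.00005` setting in the next cell controls.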

In [2]:
prunedFileName = orgFileName.replace('.fbm', 'P.fbm')  # Append 'P' to the filename for "Pruned"
pResults = Model.pruneModel(orgFileName, prunedFileName, mseUb=.00005)


Reading model parameters from "Models/BertSquadRR.fbm" ... Done.
Pruning 271 tensors using 36 workers ... 
   Pruning Parameters:
        mseUb ................ 0.000050
Pruning process complete (8.76 Sec.)
Now saving to "Models/BertSquadRRP.fbm" ... Done.

Number of parameters: 89,385,986 -> 53,047,718 (36,338,268 pruned)
Model File Size: 357,560,138 -> 223,370,708 bytes
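
The totals reported in the log above can be turned into relative reductions with a few lines of arithmetic. This is a minimal sketch that only uses the numbers printed by `pruneModel`; the variable names are chosen here for illustration.

```python
# Totals copied from the pruning log above
params_before, params_after = 89_385_986, 53_047_718
bytes_before, bytes_after = 357_560_138, 223_370_708

params_pruned = params_before - params_after
param_reduction = params_pruned / params_before      # fraction of parameters removed
size_reduction = 1 - bytes_after / bytes_before      # fraction of file size saved

# Average storage cost per remaining parameter, before and after pruning
bytes_per_param_before = bytes_before / params_before
bytes_per_param_after = bytes_after / params_after

print(f"Parameters pruned: {params_pruned:,} ({param_reduction:.1%})")
print(f"File size saved:   {bytes_before - bytes_after:,} bytes ({size_reduction:.1%})")
print(f"Bytes/parameter:   {bytes_per_param_before:.2f} -> {bytes_per_param_after:.2f}")
```

The file shrinks proportionally less than the parameter count. That would be consistent with the pruned file storing some extra information about which values were kept, although the log itself does not say so.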


## Evaluate the pruned model
Compare the new number of parameters with the original. Let's see the impact of this reduction on the performance of the model.

In [3]:
model = Model.makeFromFile(prunedFileName, testDs=testDs, gpus=gpus)   
model.printLayersInfo()
model.initSession()
results = model.evaluate()


Reading from "Models/BertSquadRRP.fbm" ... Done.
Creating the fireball model "Bert-SQuAD" ... Done.

Scope            InShape       Comments                 OutShape      Activ.   Post Act.        # of Params
---------------  ------------  -----------------------  ------------  -------  ---------------  -----------
IN_EMB           ≤512 2                                 ≤512 768      None                      15,974,954 
S1_L1_LN         ≤512 768                               ≤512 768      None     DO:0.1           1,536      
S2_L1_BERT       ≤512 768      768/3072, 12 heads       ≤512 768      GELU                      2,838,763  
S2_L2_BERT       ≤512 768      768/3072, 12 heads       ≤512 768      GELU                      2,899,141  
S2_L3_BERT       ≤512 768      768/3072, 12 heads       ≤512 768      GELU                      2,910,253  
S2_L4_BERT       ≤512 768      768/3072, 12 heads       ≤512 768      GELU                      3,037,189  
S2_L5_BERT       ≤512 768      768
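
To compare the two layer tables row by row, the parameter counts can be put side by side. The sketch below uses only the rows that are fully visible in both tables of this notebook; rows that are cut off in the output are left out rather than guessed.

```python
# (original, pruned) parameter counts copied from the two layer tables
layer_params = {
    "IN_EMB":     (23_835_648, 15_974_954),
    "S1_L1_LN":   (1_536, 1_536),
    "S2_L1_BERT": (5_097_216, 2_838_763),
    "S2_L2_BERT": (5_177_088, 2_899_141),
    "S2_L3_BERT": (5_207_808, 2_910_253),
}

# Fraction of parameters removed in each layer
pruned_fraction = {
    scope: 1 - after / before
    for scope, (before, after) in layer_params.items()
}

for scope, (before, after) in layer_params.items():
    print(f"{scope:<12} {before:>12,} -> {after:>12,}  "
          f"({pruned_fraction[scope]:.1%} pruned)")
```

In these rows the layer normalization parameters are unchanged, while the visible BERT layers lose a larger share of their parameters than the embedding layer does.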

## Re-training after pruning
Here we create a new model for re-training from the pruned model file created above. We then call the [train](https://interdigitalinc.github.io/Fireball/html/source/model.html#fireball.model.Model.train) method of the model to start the re-training.

After re-training, we call the [evaluate](https://interdigitalinc.github.io/Fireball/html/source/model.html#fireball.model.Model.evaluate) method again to see how much the re-training improved the performance
of the model.

The re-trained model is then saved to a new file whose name is the pruned model's file name with the letter 'R' (for Re-trained) appended.

In [4]:
trainDs = SquadDSet.makeDatasets("Train", batchSize=128, version=1 )

model = Model.makeFromFile(prunedFileName, trainDs=trainDs, testDs=testDs,
                           batchSize=32, numEpochs=2,
                           regFactor=0.0001,
                           learningRate=(2e-5,4e-6), optimizer='Adam',
                           saveBest=False,
                           gpus=gpus)
model.printNetConfig()
model.initSession()
model.train()
results = model.evaluate()

retrainedFileName = prunedFileName.replace('.fbm', 'R.fbm')  # Append 'R' to the filename for "Retrained"
model.save(retrainedFileName)

Initializing tokenizer from "/data/SQuAD/vocab.txt" ... Done. (Vocab Size: 30522)

Reading from "Models/BertSquadRRP.fbm" ... Done.
Creating the fireball model "Bert-SQuAD" ... Done.

Network configuration:
  Input:                     A tuple of TokenIds and TokenTypes.
  Output:                    2 logit vectors (with length ≤ 512) for start and end indexes of the answer.
  Network Layers:            16
  Tower Devices:             GPU0, GPU1, GPU2, GPU3
  Total Network Parameters:  53,047,718
  Total Parameter Tensors:   271
  Trainable Tensors:         271
  Training Samples:          87,844
  Test Samples:              10,833
  Num Epochs:                2
  Batch Size:                32
  L2 Reg. Factor:            0.0001
  Global Drop Rate:          0   
  Learning Rate: (Exponential Decay)
    Initial Value:           0.00002      
    Final Value:             0.000004     
  Optimizer:                 Adam

+--------+---------+---------------+-----------+-------------------+
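
The configuration above reports an exponentially decaying learning rate that starts at 2e-5 and ends at 4e-6. The sketch below shows one common way such a schedule can be computed. It is an assumption for illustration: the function `exp_decay_lr` is hypothetical, and how Fireball actually counts batches and how often it updates the rate may differ.

```python
import math

def exp_decay_lr(step, total_steps, lr_start=2e-5, lr_end=4e-6):
    """Hypothetical schedule: interpolate geometrically from lr_start to lr_end."""
    progress = step / total_steps                     # 0.0 at start, 1.0 at end
    return lr_start * (lr_end / lr_start) ** progress

# Values taken from the network configuration printed above
training_samples, batch_size, num_epochs = 87_844, 32, 2

# Assumes the last, partially filled batch of each epoch is also used
steps_per_epoch = math.ceil(training_samples / batch_size)
total_steps = num_epochs * steps_per_epoch

for step in (0, total_steps // 2, total_steps):
    print(f"step {step:>5}: learning rate = {exp_decay_lr(step, total_steps):.2e}")
```

With a geometric schedule like this one, the rate shrinks by the same factor at every step, so the value halfway through training is the geometric mean of the initial and final values.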


## Also look at

[Quantizing BERT/SQuAD Model](BertSquad-Quantize.ipynb)

[Exporting BERT/SQuAD Model to ONNX](BertSquad-ONNX.ipynb)

[Exporting BERT/SQuAD Model to TensorFlow](BertSquad-TF.ipynb)

[Exporting BERT/SQuAD Model to CoreML](BertSquad-CoreML.ipynb)

________________

[Fireball Playgrounds](../Contents.ipynb)

[Question Answering (BERT/SQuAD)](BertSquad.ipynb)

[Reducing number of parameters of BERT/SQuAD Model](BertSquad-Reduce.ipynb)
