# Train transformer

In [65]:
import torch
from torch import nn
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import MultiStepLR
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from transformers import AutoTokenizer
from utils.data_utils import AG_NEWS_DATASET
from utils.constants import *
from utils.training import Learner

from quantization.binarize import binarize, IRLinear
from quantization.transformer import Transformer
from quantization.quantize import quantizer
from quantization.pytorch_api import ModelQuant
from quantization.fully_quantize import Model as fullyQuantModel

from utils.train_utils import change_t
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [69]:
def create_model(vocab_size, quant_type=None, quant_method=None, bit_num=None, quant_pattern=None):
    '''
    Create the training model for the specified quant_type
    ----------
    Arguments:
    vocab_size    - vocabulary size, forwarded to whichever model class is constructed
    quant_type    - quant type, should be one of [None, 'quantization', 'binarization']
    quant_method  - quant method to use, if quant_type is None, it should also be None
                    For 'quantization', should be one of ['basic', 'pytorch', 'fully']
                    For 'binarization', should be one of ['basic', 'ir']
    bit_num       - bit number for each parameter, only works when quant_type is 'quantization'
                    should be one of [8,4,2]
    quant_pattern - quantization pattern, should be one of ['MHA', 'FFN', 'CLS', 'ALL']
    ----------
    Returns:
    The constructed model (always built with n_class=4):
      - the baseline Transformer when quant_type is None,
      - the result of quantizer(...) / a ModelQuant / a fullyQuantModel for 'quantization',
      - the baseline Transformer after binarize(...) has been applied to it for 'binarization'.
    ----------
    Raises:
    AssertionError - when an argument is outside its allowed set, or when a
                     required argument for the chosen quant_type is None
    '''
    __quant_type__ = [None,'quantization','binarization']
    __bit_num__ = [None,8,4,2]
    __quant_pattern__ = [None,'MHA', 'FFN', 'CLS', 'ALL']

    assert quant_type in __quant_type__, f"Unimplemented quantization type, should be one of {__quant_type__}, got '{quant_type}'!"
    assert bit_num in __bit_num__, f"Unimplemented bit number, should be one of {__bit_num__}, got '{bit_num}'!"
    assert quant_pattern in __quant_pattern__, f"Unimplemented quantization pattern, should be one of {__quant_pattern__}, got '{quant_pattern}'!"

    # The baseline model is the result for quant_type None and the starting
    # point for 'basic' quantization and for binarization.
    model = Transformer(d_model=BASELINE_MODEL_DIM,
                        d_ff=BASELINE_FFN_DIM,
                        d_hidden=BASELINE_HIDDEN_DIM,
                        h=BASELINE_MODEL_NUMBER_OF_HEADS,
                        n_layers=BASELINE_MODEL_NUMBER_OF_LAYERS,
                        n_class=4,
                        vocab=vocab_size
                       )

    if quant_type is None:
        # Quantization-only arguments are ignored here; tell the caller so.
        if quant_method is not None:
            print(f"Quant method {quant_method} will not work in baseline model!")
        if bit_num is not None:
            print(f"Bit number {bit_num} will not work in baseline model!")
        if quant_pattern is not None:
            print(f"Quant pattern {quant_pattern} will not work in baseline model!")

    elif quant_type == 'quantization':
        __quant_method__ = ['basic', 'pytorch', 'fully']

        assert quant_method in __quant_method__, f"Unimplemented quantization method, should be one of {__quant_method__}, got '{quant_method}'!"
        assert bit_num is not None, "Bit number can not be None!"
        assert quant_pattern is not None, "Quant pattern can not be None!"

        if quant_method == 'basic':
            if quant_pattern != 'ALL':
                print(f"Current quant method {quant_method} can only quantize the whole network, quant pattern {quant_pattern} will not work!")
            model = quantizer(model, bit_num, True)

        elif quant_method == 'pytorch':
            # Use the vocab_size argument, not a notebook-global tokenizer, so the
            # function works on a fresh kernel and honours what the caller passed.
            model = ModelQuant(d_model=BASELINE_MODEL_DIM,
                               d_ff=BASELINE_FFN_DIM,
                               d_hidden=BASELINE_HIDDEN_DIM,
                               h=BASELINE_MODEL_NUMBER_OF_HEADS,
                               n_layers=BASELINE_MODEL_NUMBER_OF_LAYERS,
                               n_class=4,
                               vocab=vocab_size,
                               quant_ffn=quant_pattern in ('FFN', 'ALL'),
                               quant_mha=quant_pattern in ('MHA', 'ALL'),
                               quant_classifier=quant_pattern in ('CLS', 'ALL'),
                               bit_num=bit_num)

        elif quant_method == 'fully':
            print("For fully_quantized model, bit number and quant pattern will not work!")
            model = fullyQuantModel(4,
                                    vocab_size,
                                    BASELINE_MODEL_NUMBER_OF_LAYERS,
                                    BASELINE_MODEL_NUMBER_OF_HEADS,
                                    BASELINE_MODEL_DIM)

    elif quant_type == 'binarization':
        __quant_method__ = ['basic', 'ir']
        assert quant_method in __quant_method__, f"Unimplemented quantization method, should be one of {__quant_method__}, got '{quant_method}'!"
        assert quant_pattern is not None, "Quant pattern can not be None!"
        print("For binarization model, bit num will not work!")

        # Previously quant_method was validated but never forwarded, so 'ir' and
        # 'basic' produced the same model. Only 'ir' is forwarded explicitly;
        # 'basic' keeps binarize()'s own default layer kind.
        # NOTE(review): confirm that binarize()'s default corresponds to 'basic'.
        layer_kwargs = {'binarize_layer': 'ir'} if quant_method == 'ir' else {}
        binarize(model, quant_pattern, skip_final=True, qk_only=True, **layer_kwargs)

    return model
    

In [8]:
# --- data: tokenizer plus the AG News train / test loaders ---
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
ag_news = AG_NEWS_DATASET(tokenizer, batch_size=BATCH_SIZE)
train_dl, test_dl = ag_news.load_data()

# --- model: baseline-sized 4-class Transformer ---
transformer_kwargs = dict(
    d_model=BASELINE_MODEL_DIM,
    d_ff=BASELINE_FFN_DIM,
    d_hidden=BASELINE_HIDDEN_DIM,
    h=BASELINE_MODEL_NUMBER_OF_HEADS,
    n_layers=BASELINE_MODEL_NUMBER_OF_LAYERS,
    n_class=4,
    vocab=tokenizer.vocab_size,
)
model = Transformer(**transformer_kwargs)

# binarize() is applied to `model` itself (its return value is not used);
# printing the model afterwards shows which layers were replaced.
binarize(model, 'ALL', binarize_layer='ir', skip_final=True, qk_only=False)
print(model)

# --- objective ---
loss_fn = nn.CrossEntropyLoss()

# --- baseline training config -> do not change! ---
# NOTE(review): with 'epochs' set to 10 the first milestone (10) looks like it is
# only reached once training is over, so the learning rate stays at 1e-4.
optim = Adam(model.parameters(), lr=1e-4)
scheduler = MultiStepLR(optim, milestones=[10, 15], gamma=0.1)

train_config = dict(
    model=model,
    loss_fn=loss_fn,
    optim=optim,
    scheduler=scheduler,
    datasets=[train_dl, test_dl],
    epochs=10,
    batch_size=BATCH_SIZE,
    exp_name='transformer_IRNET',
)

# --- learner (ir=True is forwarded to the project's Learner) ---
learner_ag_news = Learner(train_config, ir=True)

Transformer(
  (input_embeddings): Embeddings(
    (token_embedding): Embedding(30522, 512)
    (pos_embedding): Embedding(512, 512)
  )
  (input_encodings): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (sublayer_attention): ModuleList(
    (0): sublayerConnectionAttention(
      (multiheads): MultiheadAttention(
        (heads): ModuleList(
          (0): IRLinear (512 -> 512)
          (1): IRLinear (512 -> 512)
          (2): IRLinear (512 -> 512)
        )
        (output): IRLinear (512 -> 512)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (layernorm): LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): sublayerConnectionAttention(
      (multiheads): MultiheadAttention(
        (heads): ModuleList(
          (0): IRLinear (512 -> 512)
          (1): IRLinear (512 -> 512)
          (2): IRLinear (512 -> 512)
        )
        (output): IRLinear (512 -> 512)
        (dropout): Dropout(p=0.1, inplace=False)
      )
   

In [9]:
# Inspect the `t` attribute of the layer at index 1 of the first attention
# block's `heads` list (it prints None at this point, i.e. before training).
second_head = model.sublayer_attention[0].multiheads.heads[1]
print(second_head.t)

None


In [10]:
# Run the training loop. The list returned by train() is displayed as this
# cell's output (presumably the per-epoch validation Prec@1 — confirm in Learner).
learner_ag_news.train()

  0%|                                                    | 0/10 [00:00<?, ?it/s]

current lr 1.00000e-04
tensor([0.1000], device='cuda:0')
Epoch: [0][0/3750]	Loss 21165.4707	Prec@1 31.250
Epoch: [0][100/3750]	Loss 24068.9844	Prec@1 31.374
Epoch: [0][200/3750]	Loss 17295.1633	Prec@1 29.058
Epoch: [0][300/3750]	Loss 14339.7965	Prec@1 28.686
Epoch: [0][400/3750]	Loss 12647.1906	Prec@1 28.265
Epoch: [0][500/3750]	Loss 11378.3273	Prec@1 28.412
Epoch: [0][600/3750]	Loss 10572.3136	Prec@1 28.026
Epoch: [0][700/3750]	Loss 9866.4458	Prec@1 27.960
Epoch: [0][800/3750]	Loss 9347.8048	Prec@1 27.809
Epoch: [0][900/3750]	Loss 8874.2556	Prec@1 27.629
Epoch: [0][1000/3750]	Loss 8416.5995	Prec@1 27.691
Epoch: [0][1100/3750]	Loss 8019.8246	Prec@1 27.725
Epoch: [0][1200/3750]	Loss 7681.0817	Prec@1 27.597
Epoch: [0][1300/3750]	Loss 7355.8397	Prec@1 27.693
Epoch: [0][1400/3750]	Loss 7066.3091	Prec@1 27.663
Epoch: [0][1500/3750]	Loss 6801.2076	Prec@1 27.625
Epoch: [0][1600/3750]	Loss 6558.7160	Prec@1 27.567
Epoch: [0][1700/3750]	Loss 6338.2557	Prec@1 27.504
Epoch: [0][1800/3750]	Loss 611

 10%|████▎                                      | 1/10 [01:49<16:22, 109.21s/it]

current lr 1.00000e-04
tensor([0.2887], device='cuda:0')
Epoch: [1][0/3750]	Loss 585.1575	Prec@1 15.625
Epoch: [1][100/3750]	Loss 2129.6004	Prec@1 32.147
Epoch: [1][200/3750]	Loss 1411.7155	Prec@1 30.193
Epoch: [1][300/3750]	Loss 1134.2423	Prec@1 29.527
Epoch: [1][400/3750]	Loss 982.5048	Prec@1 28.858
Epoch: [1][500/3750]	Loss 908.4314	Prec@1 29.148
Epoch: [1][600/3750]	Loss 883.7329	Prec@1 29.201
Epoch: [1][700/3750]	Loss 833.7551	Prec@1 29.017
Epoch: [1][800/3750]	Loss 807.0178	Prec@1 28.792
Epoch: [1][900/3750]	Loss 778.5893	Prec@1 28.548
Epoch: [1][1000/3750]	Loss 745.7690	Prec@1 28.553
Epoch: [1][1100/3750]	Loss 718.5661	Prec@1 28.590
Epoch: [1][1200/3750]	Loss 698.2553	Prec@1 28.450
Epoch: [1][1300/3750]	Loss 685.3663	Prec@1 28.550
Epoch: [1][1400/3750]	Loss 671.2385	Prec@1 28.437
Epoch: [1][1500/3750]	Loss 652.4644	Prec@1 28.423
Epoch: [1][1600/3750]	Loss 637.8115	Prec@1 28.508
Epoch: [1][1700/3750]	Loss 627.9484	Prec@1 28.516
Epoch: [1][1800/3750]	Loss 617.0398	Prec@1 28.585
Ep

 20%|████████▌                                  | 2/10 [03:36<14:25, 108.20s/it]

current lr 1.00000e-04
tensor([0.8337], device='cuda:0')
Epoch: [2][0/3750]	Loss 381.4594	Prec@1 15.625
Epoch: [2][100/3750]	Loss 483.5287	Prec@1 33.045
Epoch: [2][200/3750]	Loss 344.0880	Prec@1 30.737
Epoch: [2][300/3750]	Loss 279.3286	Prec@1 30.866
Epoch: [2][400/3750]	Loss 240.7345	Prec@1 30.884
Epoch: [2][500/3750]	Loss 215.5271	Prec@1 31.524
Epoch: [2][600/3750]	Loss 203.1654	Prec@1 31.474
Epoch: [2][700/3750]	Loss 188.9650	Prec@1 31.624
Epoch: [2][800/3750]	Loss 177.8251	Prec@1 31.726
Epoch: [2][900/3750]	Loss 168.7909	Prec@1 31.642
Epoch: [2][1000/3750]	Loss 160.4991	Prec@1 31.721
Epoch: [2][1100/3750]	Loss 154.9529	Prec@1 31.747
Epoch: [2][1200/3750]	Loss 149.0420	Prec@1 31.841
Epoch: [2][1300/3750]	Loss 143.4316	Prec@1 32.208
Epoch: [2][1400/3750]	Loss 139.5627	Prec@1 32.084
Epoch: [2][1500/3750]	Loss 135.3463	Prec@1 32.174
Epoch: [2][1600/3750]	Loss 132.0684	Prec@1 32.224
Epoch: [2][1700/3750]	Loss 129.3862	Prec@1 32.251
Epoch: [2][1800/3750]	Loss 127.1842	Prec@1 32.425
Epoch

 30%|████████████▉                              | 3/10 [05:22<12:29, 107.03s/it]

current lr 1.00000e-04
tensor([2.4074], device='cuda:0')
Epoch: [3][0/3750]	Loss 1526.2471	Prec@1 0.000
Epoch: [3][100/3750]	Loss 2189.3204	Prec@1 36.231
Epoch: [3][200/3750]	Loss 1474.9435	Prec@1 36.427
Epoch: [3][300/3750]	Loss 1146.7469	Prec@1 37.407
Epoch: [3][400/3750]	Loss 947.3324	Prec@1 38.700
Epoch: [3][500/3750]	Loss 813.3875	Prec@1 40.675
Epoch: [3][600/3750]	Loss 724.4939	Prec@1 41.535
Epoch: [3][700/3750]	Loss 655.9231	Prec@1 42.497
Epoch: [3][800/3750]	Loss 599.4893	Prec@1 43.473
Epoch: [3][900/3750]	Loss 553.6310	Prec@1 44.374
Epoch: [3][1000/3750]	Loss 514.6924	Prec@1 45.174
Epoch: [3][1100/3750]	Loss 482.2177	Prec@1 45.967
Epoch: [3][1200/3750]	Loss 454.8097	Prec@1 46.524
Epoch: [3][1300/3750]	Loss 430.4817	Prec@1 47.451
Epoch: [3][1400/3750]	Loss 408.8771	Prec@1 48.022
Epoch: [3][1500/3750]	Loss 389.3861	Prec@1 48.709
Epoch: [3][1600/3750]	Loss 373.7526	Prec@1 48.969
Epoch: [3][1700/3750]	Loss 360.2273	Prec@1 49.405
Epoch: [3][1800/3750]	Loss 346.6165	Prec@1 50.134
Ep

 40%|█████████████████▏                         | 4/10 [07:07<10:36, 106.14s/it]

current lr 1.00000e-04
tensor([6.9513], device='cuda:0')
Epoch: [4][0/3750]	Loss 1353.9819	Prec@1 6.250
Epoch: [4][100/3750]	Loss 2900.7069	Prec@1 51.918
Epoch: [4][200/3750]	Loss 1796.2151	Prec@1 59.484
Epoch: [4][300/3750]	Loss 1353.6102	Prec@1 62.583
Epoch: [4][400/3750]	Loss 1104.4719	Prec@1 64.885
Epoch: [4][500/3750]	Loss 941.6392	Prec@1 66.791
Epoch: [4][600/3750]	Loss 832.6884	Prec@1 67.752
Epoch: [4][700/3750]	Loss 746.2222	Prec@1 68.848
Epoch: [4][800/3750]	Loss 680.6020	Prec@1 69.573
Epoch: [4][900/3750]	Loss 627.2893	Prec@1 70.255
Epoch: [4][1000/3750]	Loss 583.3240	Prec@1 70.892
Epoch: [4][1100/3750]	Loss 545.4357	Prec@1 71.463
Epoch: [4][1200/3750]	Loss 513.0402	Prec@1 71.865
Epoch: [4][1300/3750]	Loss 484.8997	Prec@1 72.360
Epoch: [4][1400/3750]	Loss 461.9597	Prec@1 72.667
Epoch: [4][1500/3750]	Loss 441.2969	Prec@1 72.843
Epoch: [4][1600/3750]	Loss 422.6132	Prec@1 73.064
Epoch: [4][1700/3750]	Loss 407.5584	Prec@1 73.212
Epoch: [4][1800/3750]	Loss 394.4210	Prec@1 73.510
E

 50%|█████████████████████▌                     | 5/10 [08:52<08:49, 105.94s/it]

current lr 1.00000e-04
tensor([20.0717], device='cuda:0')
Epoch: [5][0/3750]	Loss 321.5333	Prec@1 53.125
Epoch: [5][100/3750]	Loss 927.8808	Prec@1 72.184
Epoch: [5][200/3750]	Loss 659.0438	Prec@1 74.705
Epoch: [5][300/3750]	Loss 530.5450	Prec@1 76.132
Epoch: [5][400/3750]	Loss 449.5120	Prec@1 77.385
Epoch: [5][500/3750]	Loss 402.6307	Prec@1 77.832
Epoch: [5][600/3750]	Loss 363.6493	Prec@1 78.546
Epoch: [5][700/3750]	Loss 335.4043	Prec@1 79.124
Epoch: [5][800/3750]	Loss 312.5256	Prec@1 79.588
Epoch: [5][900/3750]	Loss 294.4162	Prec@1 79.925
Epoch: [5][1000/3750]	Loss 278.3849	Prec@1 80.317
Epoch: [5][1100/3750]	Loss 264.2070	Prec@1 80.660
Epoch: [5][1200/3750]	Loss 252.5225	Prec@1 80.821
Epoch: [5][1300/3750]	Loss 241.7470	Prec@1 81.142
Epoch: [5][1400/3750]	Loss 233.4942	Prec@1 81.312
Epoch: [5][1500/3750]	Loss 225.2546	Prec@1 81.489
Epoch: [5][1600/3750]	Loss 217.8847	Prec@1 81.625
Epoch: [5][1700/3750]	Loss 212.0884	Prec@1 81.691
Epoch: [5][1800/3750]	Loss 207.5187	Prec@1 81.823
Epoc

 60%|█████████████████████████▊                 | 6/10 [10:38<07:03, 105.77s/it]

current lr 1.00000e-04
tensor([57.9565], device='cuda:0')
Epoch: [6][0/3750]	Loss 70.6200	Prec@1 81.250
Epoch: [6][100/3750]	Loss 305.5932	Prec@1 78.218
Epoch: [6][200/3750]	Loss 201.5248	Prec@1 82.136
Epoch: [6][300/3750]	Loss 165.3695	Prec@1 83.731
Epoch: [6][400/3750]	Loss 142.9147	Prec@1 84.882
Epoch: [6][500/3750]	Loss 130.6433	Prec@1 85.492
Epoch: [6][600/3750]	Loss 122.2649	Prec@1 85.727
Epoch: [6][700/3750]	Loss 116.5397	Prec@1 85.966
Epoch: [6][800/3750]	Loss 109.6074	Prec@1 86.326
Epoch: [6][900/3750]	Loss 105.5105	Prec@1 86.380
Epoch: [6][1000/3750]	Loss 102.4837	Prec@1 86.510
Epoch: [6][1100/3750]	Loss 99.9186	Prec@1 86.572
Epoch: [6][1200/3750]	Loss 97.6165	Prec@1 86.529
Epoch: [6][1300/3750]	Loss 95.3666	Prec@1 86.652
Epoch: [6][1400/3750]	Loss 94.1276	Prec@1 86.668
Epoch: [6][1500/3750]	Loss 92.6705	Prec@1 86.651
Epoch: [6][1600/3750]	Loss 91.5177	Prec@1 86.653
Epoch: [6][1700/3750]	Loss 90.5027	Prec@1 86.651
Epoch: [6][1800/3750]	Loss 89.4899	Prec@1 86.684
Epoch: [6][19

 70%|██████████████████████████████             | 7/10 [12:23<05:16, 105.52s/it]

Epoch[6] *Validation*: Prec@1 87.789
current lr 1.00000e-04
tensor([167.3475], device='cuda:0')
Epoch: [7][0/3750]	Loss 12.1944	Prec@1 90.625
Epoch: [7][100/3750]	Loss 255.7659	Prec@1 80.012
Epoch: [7][200/3750]	Loss 163.2374	Prec@1 84.375
Epoch: [7][300/3750]	Loss 132.1622	Prec@1 86.140
Epoch: [7][400/3750]	Loss 115.2240	Prec@1 86.884
Epoch: [7][500/3750]	Loss 104.2879	Prec@1 87.531
Epoch: [7][600/3750]	Loss 96.5109	Prec@1 87.890
Epoch: [7][700/3750]	Loss 89.6567	Prec@1 88.249
Epoch: [7][800/3750]	Loss 84.1652	Prec@1 88.534
Epoch: [7][900/3750]	Loss 81.9153	Prec@1 88.495
Epoch: [7][1000/3750]	Loss 79.9148	Prec@1 88.565
Epoch: [7][1100/3750]	Loss 78.1037	Prec@1 88.624
Epoch: [7][1200/3750]	Loss 76.7946	Prec@1 88.580
Epoch: [7][1300/3750]	Loss 74.9933	Prec@1 88.747
Epoch: [7][1400/3750]	Loss 74.4707	Prec@1 88.667
Epoch: [7][1500/3750]	Loss 73.7586	Prec@1 88.635
Epoch: [7][1600/3750]	Loss 73.0431	Prec@1 88.656
Epoch: [7][1700/3750]	Loss 72.3321	Prec@1 88.700
Epoch: [7][1800/3750]	Loss 71

 80%|██████████████████████████████████▍        | 8/10 [14:08<03:30, 105.38s/it]

current lr 1.00000e-04
tensor([483.2103], device='cuda:0')
Epoch: [8][0/3750]	Loss 42.7400	Prec@1 84.375
Epoch: [8][100/3750]	Loss 114.7661	Prec@1 85.705
Epoch: [8][200/3750]	Loss 86.0169	Prec@1 87.889
Epoch: [8][300/3750]	Loss 75.0524	Prec@1 88.777
Epoch: [8][400/3750]	Loss 68.2836	Prec@1 89.238
Epoch: [8][500/3750]	Loss 66.4842	Prec@1 89.583
Epoch: [8][600/3750]	Loss 64.1516	Prec@1 89.715
Epoch: [8][700/3750]	Loss 62.0181	Prec@1 89.921
Epoch: [8][800/3750]	Loss 59.6606	Prec@1 90.137
Epoch: [8][900/3750]	Loss 58.9358	Prec@1 90.174
Epoch: [8][1000/3750]	Loss 58.9339	Prec@1 90.138
Epoch: [8][1100/3750]	Loss 58.0594	Prec@1 90.222
Epoch: [8][1200/3750]	Loss 56.8734	Prec@1 90.282
Epoch: [8][1300/3750]	Loss 55.5834	Prec@1 90.416
Epoch: [8][1400/3750]	Loss 54.9903	Prec@1 90.389
Epoch: [8][1500/3750]	Loss 54.6402	Prec@1 90.390
Epoch: [8][1600/3750]	Loss 54.3203	Prec@1 90.399
Epoch: [8][1700/3750]	Loss 54.0640	Prec@1 90.362
Epoch: [8][1800/3750]	Loss 54.0598	Prec@1 90.361
Epoch: [8][1900/3750]

 90%|██████████████████████████████████████▋    | 9/10 [15:52<01:45, 105.05s/it]

Epoch[8] *Validation*: Prec@1 87.895
current lr 1.00000e-04
tensor([1395.2537], device='cuda:0')
Epoch: [9][0/3750]	Loss 60.4715	Prec@1 84.375
Epoch: [9][100/3750]	Loss 118.5935	Prec@1 86.170
Epoch: [9][200/3750]	Loss 86.3018	Prec@1 88.650
Epoch: [9][300/3750]	Loss 72.8353	Prec@1 89.826
Epoch: [9][400/3750]	Loss 65.0825	Prec@1 90.461
Epoch: [9][500/3750]	Loss 63.0385	Prec@1 90.625
Epoch: [9][600/3750]	Loss 59.2699	Prec@1 90.979
Epoch: [9][700/3750]	Loss 56.0551	Prec@1 91.200
Epoch: [9][800/3750]	Loss 53.2268	Prec@1 91.487
Epoch: [9][900/3750]	Loss 53.1662	Prec@1 91.478
Epoch: [9][1000/3750]	Loss 52.0621	Prec@1 91.574
Epoch: [9][1100/3750]	Loss 50.7428	Prec@1 91.638
Epoch: [9][1200/3750]	Loss 50.8852	Prec@1 91.517
Epoch: [9][1300/3750]	Loss 50.0650	Prec@1 91.593
Epoch: [9][1400/3750]	Loss 49.8631	Prec@1 91.537
Epoch: [9][1500/3750]	Loss 49.8908	Prec@1 91.514
Epoch: [9][1600/3750]	Loss 49.6807	Prec@1 91.511
Epoch: [9][1700/3750]	Loss 49.3541	Prec@1 91.509
Epoch: [9][1800/3750]	Loss 48.96

100%|██████████████████████████████████████████| 10/10 [17:36<00:00, 105.70s/it]

Epoch[9] *Validation*: Prec@1 87.855





[30.894736842105264,
 34.89473684210526,
 54.25,
 78.46052631578948,
 84.44736842105263,
 88.36842105263158,
 87.78947368421052,
 88.42105263157895,
 87.89473684210526,
 87.85526315789474]

# Memory footprint

In [67]:
from utils.utils import count_memory_size

In [68]:
# Memory size of the whole model as reported by count_memory_size
# (unit is whatever that helper uses — TODO confirm bytes vs. bits).
count_memory_size(model)

82497552

In [69]:
# Express every top-level child's size relative to the model size *without*
# the input embeddings. Because the embeddings are excluded from the
# denominator, their own ratio can exceed 1.
model_size = count_memory_size(model)
embedding_size = count_memory_size(model.input_embeddings)
total = model_size - embedding_size

for name, layer in model.named_children():
    print(f'{name}: {count_memory_size(layer)/total}')

input_embeddings: 3.355749760294658
input_encodings: 0.0
sublayer_attention: 0.44420377699589014
sublayer_ffn: 0.4439875142028055
classifier: 0.11180870880130434
