# Train transformer

In [7]:
import torch
from torch import nn
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import MultiStepLR
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from transformers import AutoTokenizer
from utils.data_utils import AG_NEWS_DATASET
from utils.constants import *
from model.transformer import Transformer as Transformer_origin
from utils.training import Learner

from quantization.binarize import binarize, binarize_origin
from quantization.transformer_raw import Transformer

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
# Experiment setup: data, binarized Transformer, loss, optimizer, scheduler and
# the Learner that drives training. Everything a later cell needs (`model`,
# `train_dl`, `test_dl`, `train_config`, `learner_ag_news`, ...) is defined here.

# Seed the RNGs before the model is built so that parameter initialisation is
# repeatable across "Restart Kernel -> Run All".
# NOTE(review): whether data shuffling is covered depends on how
# AG_NEWS_DATASET builds its loaders - confirm in utils.data_utils.
RANDOM_SEED = 0
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Single source of truth for the binarization mode: it is passed to binarize()
# and also embedded in the experiment name, so the two cannot drift apart when
# this cell is re-run with a different mode.
BINARIZE_MODE = 'FFN_ONLY'

# load dataset
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
train_dl, test_dl = AG_NEWS_DATASET(tokenizer, batch_size=BATCH_SIZE).load_data()

# create model
model = Transformer(
    d_model=BASELINE_MODEL_DIM,
    d_ff=BASELINE_FFN_DIM,
    d_hidden=BASELINE_HIDDEN_DIM,
    h=BASELINE_MODEL_NUMBER_OF_HEADS,
    n_layers=BASELINE_MODEL_NUMBER_OF_LAYERS,
    n_class=4,
    vocab=tokenizer.vocab_size,
)

# binarize() is called for its effect on `model` (its return value is unused).
# It stays ahead of the optimizer so that `model.parameters()` is read only
# after the call has run.
binarize(model, BINARIZE_MODE)
print(model)

# loss func
loss_fn = nn.CrossEntropyLoss()

# baseline training config -> do not change!
# NOTE(review): with epochs=10 the first milestone (10) is not reached while
# training, which matches the constant "current lr 1.00000e-04" in the recorded
# log - presumably intentional for the baseline; confirm.
optim = Adam(model.parameters(), lr=1e-4)
scheduler = MultiStepLR(optim, milestones=[10, 15], gamma=0.1)

train_config = {
    'model': model,
    'loss_fn': loss_fn,
    'optim': optim,
    'scheduler': scheduler,
    'datasets': [train_dl, test_dl],
    'epochs': 10,
    'batch_size': BATCH_SIZE,
    'exp_name': f'transformer_binarization_{BINARIZE_MODE}',
}

# training
learner_ag_news = Learner(train_config)

Transformer(
  (input_embeddings): Embeddings(
    (token_embedding): Embedding(30522, 512)
    (pos_embedding): Embedding(512, 512)
  )
  (input_encodings): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (sublayer_attention): ModuleList(
    (0): sublayerConnectionAttention(
      (multiheads): MultiheadAttention(
        (heads): ModuleList(
          (0): Linear(in_features=512, out_features=512, bias=True)
          (1): Linear(in_features=512, out_features=512, bias=True)
          (2): Linear(in_features=512, out_features=512, bias=True)
        )
        (output): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (layernorm): LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): sublayerConnectionAttention(
      (multiheads): MultiheadAttention(
        (heads): ModuleList(
          (0): Linear(in_features=512, out_features=512, bias=True)
          (1): Linear(in_feat

In [15]:
# Run the training loop configured in the setup cell (about 15 minutes for the
# 10 epochs of the recorded run). The result is bound to a name so it is not
# left only in the notebook's transient Out[...] cache, and the bare expression
# on the last line still displays it as the cell output.
# NOTE(review): the recorded output suggests train() returns the per-epoch
# validation Prec@1 - confirm against utils.training.Learner.
val_prec1_history = learner_ag_news.train()
val_prec1_history

  0%|                                                    | 0/10 [00:00<?, ?it/s]

current lr 1.00000e-04
Epoch: [0][0/3750]	Loss 790.5084	Prec@1 3.125
Epoch: [0][100/3750]	Loss 794.2428	Prec@1 31.900
Epoch: [0][200/3750]	Loss 483.2325	Prec@1 28.638
Epoch: [0][300/3750]	Loss 364.5698	Prec@1 28.644
Epoch: [0][400/3750]	Loss 302.0592	Prec@1 28.499
Epoch: [0][500/3750]	Loss 262.1268	Prec@1 28.493
Epoch: [0][600/3750]	Loss 234.1791	Prec@1 28.291
Epoch: [0][700/3750]	Loss 211.3285	Prec@1 28.196
Epoch: [0][800/3750]	Loss 193.3300	Prec@1 28.223
Epoch: [0][900/3750]	Loss 178.5055	Prec@1 28.014
Epoch: [0][1000/3750]	Loss 165.4602	Prec@1 28.131
Epoch: [0][1100/3750]	Loss 154.2756	Prec@1 28.247
Epoch: [0][1200/3750]	Loss 144.7231	Prec@1 28.193
Epoch: [0][1300/3750]	Loss 136.2696	Prec@1 28.192
Epoch: [0][1400/3750]	Loss 128.7231	Prec@1 28.129
Epoch: [0][1500/3750]	Loss 121.8323	Prec@1 28.115
Epoch: [0][1600/3750]	Loss 115.6597	Prec@1 28.111
Epoch: [0][1700/3750]	Loss 110.0748	Prec@1 28.072
Epoch: [0][1800/3750]	Loss 104.9048	Prec@1 28.212
Epoch: [0][1900/3750]	Loss 100.1867	Prec

 10%|████▍                                       | 1/10 [01:30<13:35, 90.60s/it]

current lr 1.00000e-04
Epoch: [1][0/3750]	Loss 1.7176	Prec@1 9.375
Epoch: [1][100/3750]	Loss 1.8472	Prec@1 27.785
Epoch: [1][200/3750]	Loss 1.6973	Prec@1 28.172
Epoch: [1][300/3750]	Loss 1.6448	Prec@1 28.270
Epoch: [1][400/3750]	Loss 1.6110	Prec@1 28.296
Epoch: [1][500/3750]	Loss 1.5763	Prec@1 29.279
Epoch: [1][600/3750]	Loss 1.5645	Prec@1 29.134
Epoch: [1][700/3750]	Loss 1.5599	Prec@1 28.950
Epoch: [1][800/3750]	Loss 1.5456	Prec@1 28.983
Epoch: [1][900/3750]	Loss 1.5380	Prec@1 29.030
Epoch: [1][1000/3750]	Loss 1.5301	Prec@1 28.993
Epoch: [1][1100/3750]	Loss 1.5193	Prec@1 29.314
Epoch: [1][1200/3750]	Loss 1.5140	Prec@1 29.366
Epoch: [1][1300/3750]	Loss 1.5064	Prec@1 29.624
Epoch: [1][1400/3750]	Loss 1.5021	Prec@1 29.550
Epoch: [1][1500/3750]	Loss 1.4976	Prec@1 29.620
Epoch: [1][1600/3750]	Loss 1.4932	Prec@1 29.614
Epoch: [1][1700/3750]	Loss 1.4908	Prec@1 29.573
Epoch: [1][1800/3750]	Loss 1.4856	Prec@1 29.782
Epoch: [1][1900/3750]	Loss 1.4835	Prec@1 29.748
Epoch: [1][2000/3750]	Loss 1.4

 20%|████████▊                                   | 2/10 [03:02<12:09, 91.23s/it]

current lr 1.00000e-04
Epoch: [2][0/3750]	Loss 1.3125	Prec@1 15.625
Epoch: [2][100/3750]	Loss 1.7825	Prec@1 38.335
Epoch: [2][200/3750]	Loss 1.4396	Prec@1 44.590
Epoch: [2][300/3750]	Loss 1.2807	Prec@1 48.702
Epoch: [2][400/3750]	Loss 1.1833	Prec@1 51.543
Epoch: [2][500/3750]	Loss 1.1192	Prec@1 54.061
Epoch: [2][600/3750]	Loss 1.0809	Prec@1 55.356
Epoch: [2][700/3750]	Loss 1.0446	Prec@1 56.513
Epoch: [2][800/3750]	Loss 1.0102	Prec@1 57.799
Epoch: [2][900/3750]	Loss 0.9849	Prec@1 58.758
Epoch: [2][1000/3750]	Loss 0.9609	Prec@1 59.828
Epoch: [2][1100/3750]	Loss 0.9330	Prec@1 61.160
Epoch: [2][1200/3750]	Loss 0.9129	Prec@1 62.149
Epoch: [2][1300/3750]	Loss 0.8929	Prec@1 63.127
Epoch: [2][1400/3750]	Loss 0.8743	Prec@1 64.079
Epoch: [2][1500/3750]	Loss 0.8569	Prec@1 64.948
Epoch: [2][1600/3750]	Loss 0.8422	Prec@1 65.650
Epoch: [2][1700/3750]	Loss 0.8285	Prec@1 66.299
Epoch: [2][1800/3750]	Loss 0.8235	Prec@1 66.789
Epoch: [2][1900/3750]	Loss 0.8201	Prec@1 67.021
Epoch: [2][2000/3750]	Loss 0.

 30%|█████████████▏                              | 3/10 [04:33<10:38, 91.23s/it]

current lr 1.00000e-04
Epoch: [3][0/3750]	Loss 0.6024	Prec@1 81.250
Epoch: [3][100/3750]	Loss 1.7343	Prec@1 68.967
Epoch: [3][200/3750]	Loss 1.1656	Prec@1 74.471
Epoch: [3][300/3750]	Loss 0.9301	Prec@1 77.803
Epoch: [3][400/3750]	Loss 0.7926	Prec@1 80.214
Epoch: [3][500/3750]	Loss 0.7105	Prec@1 81.662
Epoch: [3][600/3750]	Loss 0.6533	Prec@1 82.810
Epoch: [3][700/3750]	Loss 0.6064	Prec@1 83.746
Epoch: [3][800/3750]	Loss 0.5690	Prec@1 84.543
Epoch: [3][900/3750]	Loss 0.5454	Prec@1 84.982
Epoch: [3][1000/3750]	Loss 0.5249	Prec@1 85.430
Epoch: [3][1100/3750]	Loss 0.5045	Prec@1 85.845
Epoch: [3][1200/3750]	Loss 0.4882	Prec@1 86.100
Epoch: [3][1300/3750]	Loss 0.4743	Prec@1 86.412
Epoch: [3][1400/3750]	Loss 0.4617	Prec@1 86.684
Epoch: [3][1500/3750]	Loss 0.4515	Prec@1 86.865
Epoch: [3][1600/3750]	Loss 0.4431	Prec@1 87.049
Epoch: [3][1700/3750]	Loss 0.4343	Prec@1 87.221
Epoch: [3][1800/3750]	Loss 0.4266	Prec@1 87.373
Epoch: [3][1900/3750]	Loss 0.4231	Prec@1 87.433
Epoch: [3][2000/3750]	Loss 0.

 40%|█████████████████▌                          | 4/10 [06:05<09:08, 91.38s/it]

current lr 1.00000e-04
Epoch: [4][0/3750]	Loss 0.3630	Prec@1 87.500
Epoch: [4][100/3750]	Loss 0.8764	Prec@1 81.157
Epoch: [4][200/3750]	Loss 0.6307	Prec@1 83.986
Epoch: [4][300/3750]	Loss 0.5300	Prec@1 85.694
Epoch: [4][400/3750]	Loss 0.4646	Prec@1 86.993
Epoch: [4][500/3750]	Loss 0.4359	Prec@1 87.731
Epoch: [4][600/3750]	Loss 0.4074	Prec@1 88.457
Epoch: [4][700/3750]	Loss 0.3805	Prec@1 89.118
Epoch: [4][800/3750]	Loss 0.3620	Prec@1 89.533
Epoch: [4][900/3750]	Loss 0.3521	Prec@1 89.706
Epoch: [4][1000/3750]	Loss 0.3414	Prec@1 89.951
Epoch: [4][1100/3750]	Loss 0.3302	Prec@1 90.174
Epoch: [4][1200/3750]	Loss 0.3224	Prec@1 90.321
Epoch: [4][1300/3750]	Loss 0.3138	Prec@1 90.510
Epoch: [4][1400/3750]	Loss 0.3084	Prec@1 90.623
Epoch: [4][1500/3750]	Loss 0.3050	Prec@1 90.700
Epoch: [4][1600/3750]	Loss 0.3008	Prec@1 90.787
Epoch: [4][1700/3750]	Loss 0.2959	Prec@1 90.897
Epoch: [4][1800/3750]	Loss 0.2913	Prec@1 90.991
Epoch: [4][1900/3750]	Loss 0.2892	Prec@1 91.003
Epoch: [4][2000/3750]	Loss 0.

 50%|██████████████████████                      | 5/10 [07:36<07:36, 91.38s/it]

current lr 1.00000e-04
Epoch: [5][0/3750]	Loss 0.2643	Prec@1 93.750
Epoch: [5][100/3750]	Loss 0.5355	Prec@1 87.005
Epoch: [5][200/3750]	Loss 0.4205	Prec@1 87.951
Epoch: [5][300/3750]	Loss 0.3568	Prec@1 89.327
Epoch: [5][400/3750]	Loss 0.3205	Prec@1 90.212
Epoch: [5][500/3750]	Loss 0.3010	Prec@1 90.787
Epoch: [5][600/3750]	Loss 0.2851	Prec@1 91.285
Epoch: [5][700/3750]	Loss 0.2661	Prec@1 91.833
Epoch: [5][800/3750]	Loss 0.2543	Prec@1 92.150
Epoch: [5][900/3750]	Loss 0.2477	Prec@1 92.266
Epoch: [5][1000/3750]	Loss 0.2417	Prec@1 92.445
Epoch: [5][1100/3750]	Loss 0.2354	Prec@1 92.612
Epoch: [5][1200/3750]	Loss 0.2302	Prec@1 92.709
Epoch: [5][1300/3750]	Loss 0.2241	Prec@1 92.907
Epoch: [5][1400/3750]	Loss 0.2214	Prec@1 92.960
Epoch: [5][1500/3750]	Loss 0.2203	Prec@1 92.998
Epoch: [5][1600/3750]	Loss 0.2174	Prec@1 93.090
Epoch: [5][1700/3750]	Loss 0.2136	Prec@1 93.182
Epoch: [5][1800/3750]	Loss 0.2117	Prec@1 93.261
Epoch: [5][1900/3750]	Loss 0.2122	Prec@1 93.227
Epoch: [5][2000/3750]	Loss 0.

 60%|██████████████████████████▍                 | 6/10 [09:08<06:05, 91.48s/it]

current lr 1.00000e-04
Epoch: [6][0/3750]	Loss 0.4569	Prec@1 81.250
Epoch: [6][100/3750]	Loss 0.3577	Prec@1 89.944
Epoch: [6][200/3750]	Loss 0.2835	Prec@1 91.091
Epoch: [6][300/3750]	Loss 0.2478	Prec@1 91.881
Epoch: [6][400/3750]	Loss 0.2262	Prec@1 92.682
Epoch: [6][500/3750]	Loss 0.2131	Prec@1 93.108
Epoch: [6][600/3750]	Loss 0.2011	Prec@1 93.526
Epoch: [6][700/3750]	Loss 0.1859	Prec@1 94.031
Epoch: [6][800/3750]	Loss 0.1768	Prec@1 94.343
Epoch: [6][900/3750]	Loss 0.1717	Prec@1 94.506
Epoch: [6][1000/3750]	Loss 0.1665	Prec@1 94.687
Epoch: [6][1100/3750]	Loss 0.1620	Prec@1 94.797
Epoch: [6][1200/3750]	Loss 0.1590	Prec@1 94.884
Epoch: [6][1300/3750]	Loss 0.1546	Prec@1 95.016
Epoch: [6][1400/3750]	Loss 0.1519	Prec@1 95.082
Epoch: [6][1500/3750]	Loss 0.1534	Prec@1 95.051
Epoch: [6][1600/3750]	Loss 0.1524	Prec@1 95.101
Epoch: [6][1700/3750]	Loss 0.1499	Prec@1 95.159
Epoch: [6][1800/3750]	Loss 0.1478	Prec@1 95.206
Epoch: [6][1900/3750]	Loss 0.1476	Prec@1 95.164
Epoch: [6][2000/3750]	Loss 0.

 70%|██████████████████████████████▊             | 7/10 [10:37<04:32, 90.78s/it]

Epoch[6] *Validation*: Prec@1 90.250
current lr 1.00000e-04
Epoch: [7][0/3750]	Loss 0.1869	Prec@1 93.750
Epoch: [7][100/3750]	Loss 0.3084	Prec@1 90.996
Epoch: [7][200/3750]	Loss 0.2366	Prec@1 92.615
Epoch: [7][300/3750]	Loss 0.1995	Prec@1 93.677
Epoch: [7][400/3750]	Loss 0.1776	Prec@1 94.319
Epoch: [7][500/3750]	Loss 0.1654	Prec@1 94.736
Epoch: [7][600/3750]	Loss 0.1537	Prec@1 95.097
Epoch: [7][700/3750]	Loss 0.1410	Prec@1 95.502
Epoch: [7][800/3750]	Loss 0.1330	Prec@1 95.751
Epoch: [7][900/3750]	Loss 0.1293	Prec@1 95.852
Epoch: [7][1000/3750]	Loss 0.1251	Prec@1 95.998
Epoch: [7][1100/3750]	Loss 0.1203	Prec@1 96.129
Epoch: [7][1200/3750]	Loss 0.1170	Prec@1 96.225
Epoch: [7][1300/3750]	Loss 0.1129	Prec@1 96.342
Epoch: [7][1400/3750]	Loss 0.1127	Prec@1 96.340
Epoch: [7][1500/3750]	Loss 0.1116	Prec@1 96.367
Epoch: [7][1600/3750]	Loss 0.1093	Prec@1 96.434
Epoch: [7][1700/3750]	Loss 0.1068	Prec@1 96.504
Epoch: [7][1800/3750]	Loss 0.1046	Prec@1 96.582
Epoch: [7][1900/3750]	Loss 0.1044	Prec@1

 80%|███████████████████████████████████▏        | 8/10 [12:10<03:02, 91.36s/it]

current lr 1.00000e-04
Epoch: [8][0/3750]	Loss 0.1151	Prec@1 96.875
Epoch: [8][100/3750]	Loss 0.1754	Prec@1 94.462
Epoch: [8][200/3750]	Loss 0.1393	Prec@1 95.507
Epoch: [8][300/3750]	Loss 0.1215	Prec@1 96.055
Epoch: [8][400/3750]	Loss 0.1086	Prec@1 96.454
Epoch: [8][500/3750]	Loss 0.1023	Prec@1 96.669
Epoch: [8][600/3750]	Loss 0.0958	Prec@1 96.828
Epoch: [8][700/3750]	Loss 0.0870	Prec@1 97.138
Epoch: [8][800/3750]	Loss 0.0810	Prec@1 97.324
Epoch: [8][900/3750]	Loss 0.0810	Prec@1 97.333
Epoch: [8][1000/3750]	Loss 0.0785	Prec@1 97.412
Epoch: [8][1100/3750]	Loss 0.0759	Prec@1 97.499
Epoch: [8][1200/3750]	Loss 0.0742	Prec@1 97.533
Epoch: [8][1300/3750]	Loss 0.0726	Prec@1 97.598
Epoch: [8][1400/3750]	Loss 0.0721	Prec@1 97.624
Epoch: [8][1500/3750]	Loss 0.0723	Prec@1 97.627
Epoch: [8][1600/3750]	Loss 0.0712	Prec@1 97.652
Epoch: [8][1700/3750]	Loss 0.0702	Prec@1 97.682
Epoch: [8][1800/3750]	Loss 0.0692	Prec@1 97.717
Epoch: [8][1900/3750]	Loss 0.0695	Prec@1 97.707
Epoch: [8][2000/3750]	Loss 0.

 90%|███████████████████████████████████████▌    | 9/10 [13:38<01:30, 90.55s/it]

Epoch[8] *Validation*: Prec@1 90.066
current lr 1.00000e-04
Epoch: [9][0/3750]	Loss 0.0880	Prec@1 93.750
Epoch: [9][100/3750]	Loss 0.1418	Prec@1 96.040
Epoch: [9][200/3750]	Loss 0.1071	Prec@1 96.782
Epoch: [9][300/3750]	Loss 0.0917	Prec@1 97.124
Epoch: [9][400/3750]	Loss 0.0799	Prec@1 97.452
Epoch: [9][500/3750]	Loss 0.0733	Prec@1 97.686
Epoch: [9][600/3750]	Loss 0.0697	Prec@1 97.769
Epoch: [9][700/3750]	Loss 0.0641	Prec@1 97.958
Epoch: [9][800/3750]	Loss 0.0603	Prec@1 98.061
Epoch: [9][900/3750]	Loss 0.0582	Prec@1 98.124
Epoch: [9][1000/3750]	Loss 0.0565	Prec@1 98.177
Epoch: [9][1100/3750]	Loss 0.0541	Prec@1 98.232
Epoch: [9][1200/3750]	Loss 0.0527	Prec@1 98.280
Epoch: [9][1300/3750]	Loss 0.0514	Prec@1 98.321
Epoch: [9][1400/3750]	Loss 0.0509	Prec@1 98.323
Epoch: [9][1500/3750]	Loss 0.0516	Prec@1 98.299
Epoch: [9][1600/3750]	Loss 0.0508	Prec@1 98.321
Epoch: [9][1700/3750]	Loss 0.0502	Prec@1 98.341
Epoch: [9][1800/3750]	Loss 0.0493	Prec@1 98.362
Epoch: [9][1900/3750]	Loss 0.0494	Prec@1

100%|███████████████████████████████████████████| 10/10 [15:10<00:00, 91.01s/it]

Epoch[9] *Validation*: Prec@1 90.329





[26.63157894736842,
 51.55263157894737,
 89.6842105263158,
 89.92105263157895,
 90.01315789473684,
 90.26315789473684,
 90.25,
 90.52631578947368,
 90.0657894736842,
 90.32894736842105]

# Memory footprint

In [67]:
from utils.utils import count_memory_size

In [68]:
# Memory size of the whole model as reported by the project helper.
# NOTE(review): the unit is whatever count_memory_size uses (presumably bytes) -
# confirm in utils.utils.
count_memory_size(model)

82497552

In [69]:
# Per-module memory share. The denominator deliberately leaves out the input
# embeddings, so the remaining top-level children are reported as fractions of
# the non-embedding memory, while the `input_embeddings` line is that module's
# size relative to the same denominator (which is why it can exceed 1).
model_size = count_memory_size(model)
embedding_size = count_memory_size(model.input_embeddings)
total = model_size - embedding_size

# One line per direct child of the model, in registration order.
for name, layer in model.named_children():
    print(f'{name}: {count_memory_size(layer)/total}')

input_embeddings: 3.355749760294658
input_encodings: 0.0
sublayer_attention: 0.44420377699589014
sublayer_ffn: 0.4439875142028055
classifier: 0.11180870880130434
