# Train transformer

In [1]:
import torch
from torch import nn
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import MultiStepLR
import numpy as np
import matplotlib.pyplot as plt

from transformers import AutoTokenizer
from utils.data_utils import AG_NEWS_DATASET
from utils.constants import *
from model.transformer import Transformer as Transformer_origin
from training import Learner

from quantization.binarize import binarize, binarize_origin
from quantization.transformer_raw import Transformer

%load_ext autoreload
%autoreload 2

In [2]:
# load dataset
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
train_dl, test_dl = AG_NEWS_DATASET(tokenizer, batch_size = BATCH_SIZE).load_data()

# create model
model = Transformer(d_model=BASELINE_MODEL_DIM,
                    d_ff=BASELINE_FFN_DIM,
                    d_hidden=BASELINE_HIDDEN_DIM,
                    h=BASELINE_MODEL_NUMBER_OF_HEADS,
                    n_layers=BASELINE_MODEL_NUMBER_OF_LAYERS,
                    n_class=4,
                    vocab=tokenizer.vocab_size
                   )

# binarize(model, 'ALL')


# loss func
loss_fn = nn.CrossEntropyLoss()

# baseline training config -> do not change!
optim = Adam(model.parameters(), lr= 1e-4)
scheduler = MultiStepLR(optim, milestones=[10,15], gamma=0.1)

train_config ={'model': model,
               'loss_fn': loss_fn,
               'optim': optim,
               'scheduler': scheduler,
               'datasets': [train_dl, test_dl],
               'epochs': 10,
               'batch_size': BATCH_SIZE
               }

train_config['exp_name'] = 'transformer_baseline'

# training
learner_ag_news = Learner(train_config)

  "Lambda function is not supported for pickle, please use "


In [3]:
learner_ag_news.train()

  0%|                                                    | 0/10 [00:00<?, ?it/s]

current lr 1.00000e-04
Epoch: [0][0/3750]	Loss 0.8490	Prec@1 59.375
Epoch: [0][100/3750]	Loss 3.2930	Prec@1 33.416
Epoch: [0][200/3750]	Loss 2.3699	Prec@1 33.349
Epoch: [0][300/3750]	Loss 2.0092	Prec@1 35.309
Epoch: [0][400/3750]	Loss 1.7264	Prec@1 42.526
Epoch: [0][500/3750]	Loss 1.5009	Prec@1 49.769
Epoch: [0][600/3750]	Loss 1.3388	Prec@1 54.960
Epoch: [0][700/3750]	Loss 1.2096	Prec@1 59.335
Epoch: [0][800/3750]	Loss 1.1057	Prec@1 62.851
Epoch: [0][900/3750]	Loss 1.0280	Prec@1 65.459
Epoch: [0][1000/3750]	Loss 0.9644	Prec@1 67.557
Epoch: [0][1100/3750]	Loss 0.9076	Prec@1 69.448
Epoch: [0][1200/3750]	Loss 0.8611	Prec@1 70.990


  0%|                                                    | 0/10 [00:30<?, ?it/s]


KeyboardInterrupt: 

# Memory compute

In [67]:
from utils.utils import count_memory_size

In [68]:
count_memory_size(model)

82497552

In [69]:
total = count_memory_size(model) - count_memory_size(model.input_embeddings)
for name, layer in model.named_children():
    print(f'{name}: {count_memory_size(layer)/total}')

input_embeddings: 3.355749760294658
input_encodings: 0.0
sublayer_attention: 0.44420377699589014
sublayer_ffn: 0.4439875142028055
classifier: 0.11180870880130434
