In [60]:
import torch
from torch import nn
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import MultiStepLR
import numpy as np
import matplotlib.pyplot as plt

from transformers import AutoTokenizer
from utils.data_utils import AG_NEWS_DATASET
from utils.constants import *
# from model.transformer import Transformer
from training import Learner

from quantization.binarize import binarize
from quantization.transformer_raw import Transformer

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [70]:
# --- Reproducibility: seed the RNGs used by weight init and data shuffling ---
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- Hyper-parameters (collected here so they are visible and tunable in one place) ---
FFN_DIMENSION = 512      # hidden size of the position-wise feed-forward sublayer (d_ff)
N_CLASSES = 4            # size of the AG News classification head
LEARNING_RATE = 1e-4
EPOCHS = 10
# Decay the LR at 50% and 75% of training. The previous hard-coded milestones
# [10, 15] could never fire inside a 10-epoch run, so the schedule had no effect.
# NOTE(review): assumes Learner calls scheduler.step() once per epoch — confirm.
LR_MILESTONES = [EPOCHS // 2, (3 * EPOCHS) // 4]
LR_GAMMA = 0.1

# --- Data: BERT uncased tokenizer + AG News train/test loaders ---
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
train_dl, test_dl = AG_NEWS_DATASET(tokenizer, batch_size=BATCH_SIZE).load_data()

# --- Model: full-precision transformer classifier ---
# For the binarized variant, wrap it instead:
#   binarize(model, binarize_all_linear=True) with model_name 'binary_transformer'.
model = Transformer(d_model=BASELINE_MODEL_DIMENSION,
                    vocab=tokenizer.vocab_size,
                    h=BASELINE_MODEL_NUMBER_OF_HEADS,
                    n_layers=BASELINE_MODEL_NUMBER_OF_LAYERS,
                    d_ff=FFN_DIMENSION,
                    n_class=N_CLASSES)
model.model_name = 'transformer'

# --- Objective, optimizer and LR schedule ---
loss_fn = nn.CrossEntropyLoss()
# TODO: plain Adam + step decay is a baseline; tune further.
optim = Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = MultiStepLR(optim, milestones=LR_MILESTONES, gamma=LR_GAMMA)

train_config = {'model': model,
                'loss_fn': loss_fn,
                'optim': optim,
                'scheduler': scheduler,
                'datasets': [train_dl, test_dl],
                'epochs': EPOCHS,
                'batch_size': BATCH_SIZE,
                }

# --- Learner wraps the training loop; training itself is triggered in a later cell ---
learner_ag_news = Learner(train_config)

  "Lambda function is not supported for pickle, please use "


In [71]:
# Training is expensive, so it is opt-in: set RUN_TRAINING = True to train the model.
RUN_TRAINING = False
if RUN_TRAINING:
    learner_ag_news.train()

In [72]:
# Module tree of the model (layer names and their configurations).
# Left as the cell's last expression so Jupyter renders it as the cell output.
model

Transformer(
  (input_embeddings): Embeddings(
    (token_embedding): Embedding(30522, 512)
    (pos_embedding): Embedding(512, 512)
  )
  (input_encodings): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (sublayer_attention): ModuleList(
    (0): sublayerConnectionAttention(
      (multiheads): MultiheadAttention(
        (output): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (layernorm): LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): sublayerConnectionAttention(
      (multiheads): MultiheadAttention(
        (output): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (layernorm): LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (sublayer_ffn): ModuleList(
    (0): sublayerConnectionFFN(
      (ffn): PositionalWiseFFN(
        (w_1): Linear(in_features=512, out_features=512, bi

In [73]:
from utils.utils import count_memory_size

In [74]:
# Total memory footprint of the whole model as reported by count_memory_size.
# NOTE(review): units are not visible here — presumably bytes; confirm in utils.utils.
model_memory = count_memory_size(model)
model_memory

71987216

In [75]:
# Share of the model's memory taken by each top-level child module.
# Shares are computed relative to the footprint *excluding* the input embeddings,
# so the embeddings themselves cannot be expressed as a share. They are reported
# separately as a multiple of that non-embedding footprint (previously they were
# printed as a "fraction" > 1, which was misleading).
embedding_memory = count_memory_size(model.input_embeddings)
non_embedding_memory = count_memory_size(model) - embedding_memory
assert non_embedding_memory > 0, 'model has no memory outside its input embeddings'

for name, layer in model.named_children():
    ratio = count_memory_size(layer) / non_embedding_memory
    if layer is model.input_embeddings:
        print(f'{name}: {ratio:.2f}x the non-embedding total (excluded from shares)')
    else:
        print(f'{name}: {ratio:.2%}')

input_embeddings: 7.5398301980263795
input_encodings: 0.0
sublayer_attention: 0.2502424793441764
sublayer_ffn: 0.49951314323458906
classifier: 0.25024437742123457
