In [45]:
import torch
from torch import nn
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import MultiStepLR
import numpy as np
import matplotlib.pyplot as plt

from transformers import AutoTokenizer
from utils.data_utils import AG_NEWS_DATASET
from utils.constants import *
# from model.transformer import Transformer
from training import Learner

from quantization.binarize import binarize
from quantization.transformer_raw import Transformer

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [20]:
# --- Reproducibility --------------------------------------------------------
# Seed before weights are initialised and DataLoaders are built so that model
# init and (torch-driven) shuffling are repeatable across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)  # seeds the RNG on all devices
np.random.seed(SEED)

# --- Run hyper-parameters ---------------------------------------------------
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
N_CLASSES = 4  # size of the classification head for AG News
LR_GAMMA = 0.1
# Decay the LR at 50% and 75% of training. Milestones are derived from
# NUM_EPOCHS so they always fall inside the run: fixed values such as [10, 15]
# are never reached when training for 10 epochs, leaving the LR constant.
# NOTE(review): assumes Learner calls scheduler.step() once per epoch (the log
# shows lr still 1e-4 at the start of epoch 1) -- confirm in training.Learner.
LR_MILESTONES = [NUM_EPOCHS // 2, (3 * NUM_EPOCHS) // 4]
assert all(0 < m < NUM_EPOCHS for m in LR_MILESTONES), \
    "LR milestones must fall inside the training run"

# Set True to train the binarized variant instead of the full-precision model.
USE_BINARY_MODEL = False

# --- Data -------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
train_dl, test_dl = AG_NEWS_DATASET(tokenizer, batch_size=BATCH_SIZE).load_data()

# --- Model ------------------------------------------------------------------
model = Transformer(d_model=BASELINE_MODEL_DIMENSION,
                    vocab=tokenizer.vocab_size,
                    h=BASELINE_MODEL_NUMBER_OF_HEADS,
                    n_layers=BASELINE_MODEL_NUMBER_OF_LAYERS,
                    n_class=N_CLASSES
                    )
model.model_name = 'transformer'

if USE_BINARY_MODEL:
    # Assumes binarize() returns the (possibly new) binarized module; it must
    # run before the optimizer is built so Adam receives the right parameters.
    model = binarize(model, binarize_all_linear=True)
    model.model_name = 'binary_transformer'

# --- Optimisation -----------------------------------------------------------
loss_fn = nn.CrossEntropyLoss()

# TODO: simple optimizer -> to improve
optim = Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = MultiStepLR(optim, milestones=LR_MILESTONES, gamma=LR_GAMMA)

train_config = {'model': model,
                'loss_fn': loss_fn,
                'optim': optim,
                'scheduler': scheduler,
                'datasets': [train_dl, test_dl],
                'epochs': NUM_EPOCHS,
                'batch_size': BATCH_SIZE
                }

# Build the Learner; training itself is launched in the next cell.
learner_ag_news = Learner(train_config)

In [21]:
# Launch training with the Learner constructed from train_config in the previous cell.
# Observed cost in the captured output: ~364 s per epoch (tqdm "s/it"), i.e. roughly
# one hour for the 10-epoch configuration.
# NOTE(review): no checkpointing is visible in this notebook; confirm whether Learner
# saves state before interrupting a long run, otherwise progress is lost.
learner_ag_news.train()

  0%|                                                    | 0/10 [00:00<?, ?it/s]

current lr 1.00000e-04
Epoch: [0][0/3750]	Loss 0.3273	Prec@1 87.500
Epoch: [0][100/3750]	Loss 3.6267	Prec@1 34.220
Epoch: [0][200/3750]	Loss 2.5518	Prec@1 34.748
Epoch: [0][300/3750]	Loss 2.1056	Prec@1 38.694
Epoch: [0][400/3750]	Loss 1.8062	Prec@1 45.114
Epoch: [0][500/3750]	Loss 1.5801	Prec@1 51.166
Epoch: [0][600/3750]	Loss 1.4141	Prec@1 55.735
Epoch: [0][700/3750]	Loss 1.2767	Prec@1 59.816
Epoch: [0][800/3750]	Loss 1.1698	Prec@1 63.011
Epoch: [0][900/3750]	Loss 1.0910	Prec@1 65.365
Epoch: [0][1000/3750]	Loss 1.0252	Prec@1 67.386
Epoch: [0][1100/3750]	Loss 0.9680	Prec@1 69.213
Epoch: [0][1200/3750]	Loss 0.9168	Prec@1 70.785
Epoch: [0][1300/3750]	Loss 0.8741	Prec@1 72.067
Epoch: [0][1400/3750]	Loss 0.8369	Prec@1 73.238
Epoch: [0][1500/3750]	Loss 0.8058	Prec@1 74.184
Epoch: [0][1600/3750]	Loss 0.7801	Prec@1 74.977
Epoch: [0][1700/3750]	Loss 0.7550	Prec@1 75.792
Epoch: [0][1800/3750]	Loss 0.7316	Prec@1 76.496
Epoch: [0][1900/3750]	Loss 0.7148	Prec@1 77.001
Epoch: [0][2000/3750]	Loss 0.

 10%|████▎                                      | 1/10 [06:04<54:40, 364.45s/it]

current lr 1.00000e-04
Epoch: [1][0/3750]	Loss 0.6588	Prec@1 78.125
Epoch: [1][100/3750]	Loss 0.8462	Prec@1 80.136


 10%|████▎                                      | 1/10 [06:23<57:30, 383.33s/it]


KeyboardInterrupt: 