In [7]:
import torch
from torch import nn
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import MultiStepLR
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from transformers import AutoTokenizer
from utils.data_utils import AG_NEWS_DATASET
from utils.constants import *
from utils.training import Learner
from training_ema import Learner as ema_learner

# from quantization.fully_quantize import Model
from quantization.transformer import Transformer
from quantization.pytorch_api import ModelQuant
from quantization.quantize import quantizer
from quantization.fully_quantize import Model as fullyQuantModel

%load_ext autoreload
%autoreload 2

In [8]:
# Tokenizer and AG News data loaders (train / test).
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
train_dl, test_dl = AG_NEWS_DATASET(tokenizer).load_data()

# Baseline Transformer, sized by the BASELINE_MODEL_* constants.
original_model = Transformer(
    4,  # presumably the number of AG News classes -- TODO confirm against Transformer's signature
    tokenizer.vocab_size,
    BASELINE_MODEL_NUMBER_OF_LAYERS,
    BASELINE_MODEL_NUMBER_OF_HEADS,
    BASELINE_MODEL_DIM,
)

# Quantized counterpart produced from the baseline model.
# NOTE(review): 8 and True are positional quantizer settings (presumably the
# bit-width and an on/off flag) -- confirm against quantization.quantize.
# NOTE(review): it is not visible here whether quantizer() returns a copy or
# mutates original_model in place -- verify, since both models are trained.
model = quantizer(original_model, 8, True)


# model = fullyQuantModel(4,
#                 tokenizer.vocab_size,
#                 BASELINE_MODEL_NUMBER_OF_LAYERS,
#                 BASELINE_MODEL_NUMBER_OF_HEADS,
#                 BASELINE_MODEL_DIM)

# Cross-entropy loss, handed to the learner through the training config.
loss_fn = nn.CrossEntropyLoss()

# One Adam optimizer and one MultiStepLR scheduler per model.
# NOTE(review): the training config sets 'epochs': 10, so the first milestone
# (10) is only reached at the very end of training if the learner steps the
# schedulers once per epoch -- confirm the milestones are intended.
optim_original = Adam(original_model.parameters(), lr=1e-4)
scheduler_original = MultiStepLR(optim_original, milestones=[10, 15], gamma=0.1)

optim = Adam(model.parameters(), lr=1e-4)
# The quantized model's scheduler must wrap `optim` (not `optim_original`):
# otherwise `optim` keeps a constant learning rate and `optim_original` is
# decayed by two schedulers at every milestone.
scheduler = MultiStepLR(optim, milestones=[10, 15], gamma=0.1)

# Bundle every object the learner consumes into a single config dict.
train_config = {
    # models
    'model_original': original_model,
    'model': model,
    # objective and per-model optimizers
    'loss_fn': loss_fn,
    'optim_original': optim_original,
    'optim': optim,
    # data and run length
    'datasets': [train_dl, test_dl],
    'epochs': 10,
    'batch_size': BATCH_SIZE,
    # per-model LR schedulers
    'scheduler_original': scheduler_original,
    'scheduler': scheduler,
    # experiment bookkeeping
    'exp_name': "quant_all",
    # presumably the epoch at which quantization starts -- TODO confirm in training_ema
    'epoch_start_quantization': 1,
}

# Build the learner (training_ema.Learner, imported as ema_learner) from the config.
learner_ag_news = ema_learner(train_config)

  "Lambda function is not supported for pickle, please use "


original model received!


In [9]:
# Print the module tree of the model returned by quantizer() to inspect its layers.
print(model)

Transformer(
  (input_embeddings): Embeddings(
    (token_embedding): Embedding(30522, 512)
    (pos_embedding): Embedding(512, 512)
  )
  (input_encodings): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (sublayer_attention): ModuleList(
    (0): sublayerConnectionAttention(
      (multiheads): MultiheadAttention(
        (heads): ModuleList(
          (0): QuantizedLinear(in_features=512, out_features=512, bias=True)
          (1): QuantizedLinear(in_features=512, out_features=512, bias=True)
          (2): QuantizedLinear(in_features=512, out_features=512, bias=True)
        )
        (output): QuantizedLinear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (layernorm): LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): sublayerConnectionAttention(
      (multiheads): MultiheadAttention(
        (heads): ModuleList(
          (0): QuantizedLinear(in_features=512, out_features=

In [None]:
# Run training on the learner that was built from train_config.
learner_ag_news.train()