In [5]:
from torch import nn
from torch.optim import Adam
import numpy as np
import matplotlib.pyplot as plt

from transformers import AutoTokenizer
from utils.data_utils import AG_NEWS_DATASET
from utils.constants import *
from model.transformer import Transformer
from training import Learner

from quantization.binarize import binarize

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
# Data: BERT (uncased) tokenizer plus the AG News train/test loaders produced by the
# project's AG_NEWS_DATASET helper (batch size comes from the BATCH_SIZE constant).
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
train_dl, test_dl = (
    AG_NEWS_DATASET(tokenizer, batch_size=BATCH_SIZE)
    .load_data()
)

# Build the full-precision Transformer classifier. The BASELINE_* sizes are not defined in this
# notebook; presumably they come from the `utils.constants` star import (TODO confirm).
model = Transformer(
    model_dimension=BASELINE_MODEL_DIMENSION,
    src_vocab_size=tokenizer.vocab_size,  # embedding table sized to the tokenizer's vocabulary
    number_of_heads=BASELINE_MODEL_NUMBER_OF_HEADS,
    number_of_layers=BASELINE_MODEL_NUMBER_OF_LAYERS,
    dropout_probability=BASELINE_MODEL_DROPOUT_PROB,
    dim_classification=1024,  # NOTE(review): presumably the classifier-head width; consider a named constant
    num_classes=4,            # AG News has 4 topic labels
)
model.model_name = 'transformer'

# Derive the binarized variant. In the printed module tree, its attention projections show up
# as BinarizedLinear layers.
# NOTE(review): if binarize() mutates `model` in place, `model` and `model_b` are the same
# object and the 'transformer' name set above gets overwritten. Confirm binarize's copy semantics.
model_b = binarize(model, binarize_all_linear=True)
model_b.model_name = 'binary_transformer'

# Multi-class classification loss (presumably the Transformer returns raw class logits; TODO confirm).
loss_fn = nn.CrossEntropyLoss()

# The optimizer must own the parameters of the model that is actually trained, i.e. model_b
# (that is what goes into train_config below). binarize() puts new BinarizedLinear modules in
# place of the linear layers, and may return a copy. Building the optimizer from `model` could
# therefore leave model_b's weights unregistered and never updated. Using model_b is correct
# whether binarize copies or works in place.
# TODO: plain Adam with a fixed learning rate is a placeholder.
optim = Adam(model_b.parameters(), lr=0.001)
scheduler = None  # not passed to the Learner below; placeholder only

train_config = {
    'model': model_b,
    'loss_fn': loss_fn,
    'optim': optim,
    'datasets': [train_dl, test_dl],
    'epochs': 10,
    'batch_size': BATCH_SIZE,
}

# Wrap the setup in the project's Learner; training is started separately via .train().
learner_ag_news = Learner(train_config)

In [7]:
# NOTE(review): the training call is commented out, so running this notebook top to bottom
# never trains model_b. Uncomment it to actually run training with the Learner configured above.
# learner_ag_news.train()

In [8]:
# Show the module tree to check which layers binarize() replaced (e.g. the attention
# qkv/out projections appear as BinarizedLinear). The full repr is long, and the saved output is truncated.
print(model_b)

Transformer(
  (src_embedding): Embedding(
    (embeddings_table): Embedding(30522, 512)
  )
  (src_pos_embedding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): Encoder(
    (encoder_layers): ModuleList(
      (0): EncoderLayer(
        (sublayers): ModuleList(
          (0): SublayerLogic(
            (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): SublayerLogic(
            (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (multi_headed_attention): MultiHeadedAttention(
          (qkv_nets): ModuleList(
            (0): BinarizedLinear (512 -> 512)
            (1): BinarizedLinear (512 -> 512)
            (2): BinarizedLinear (512 -> 512)
          )
          (out_projection_net): BinarizedLinear (512 -> 512)
          (attention_dropout): Dropout(p=0.1, inplac

In [14]:
from quantization.binarize import BinarizedLinear
import torch

In [18]:
# Standalone sanity check of BinarizedLinear on a tiny batch.
# Seed torch first: the layer's weights are presumably randomly initialised, so without a seed
# the values shown for test_out change on every run (re-run the cells below to refresh them).
torch.manual_seed(0)
test_input = torch.ones((2, 64))  # 2 identical all-ones rows with 64 features each
layer = BinarizedLinear(64, 10)   # maps 64 input features to 10 outputs

In [21]:
# Forward pass through the standalone layer. It is called via __call__ (not .forward) so any
# module hooks run, and the result keeps its autograd graph (grad_fn is visible in the repr).
test_out = layer(test_input)

In [23]:
# Display the layer output. Both input rows are identical, so both output rows should match.
test_out

tensor([[  8.,   4.,  -6., -10.,  12.,   4., -12.,  -2.,  -8.,   0.],
        [  8.,   4.,  -6., -10.,  12.,   4., -12.,  -2.,  -8.,   0.]],
       grad_fn=<BinaryLinearFunctionBackward>)