In [1]:
import torch
import transformers
from transformers import pipeline
from transformers import DistilBertModel, DistilBertForMaskedLM, DistilBertTokenizer
#from optimum.quanto import freeze, quantize, qint8, WeightQBytesTensor
import datasets
from transformers import TrainingArguments
import numpy as np
import evaluate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_id = "distilbert-base-uncased"
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

In [3]:
dataset = datasets.load_dataset("imdb")
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})


In [4]:
def preprocess(data):
    tokens = tokenizer(data["text"], truncation=True, padding = 'max_length',  max_length=512)
    tokens["label"] = data["label"]
    return tokens

In [5]:
tokens = dataset.map(preprocess, batched = True)

In [6]:
labels = tokens['train'].features['label'].names
num_labels = len(labels)
label2id, id2label = {}, {}

for idx, lbl in enumerate(labels):
    label2id[lbl] = idx
    id2label[idx] = lbl

In [7]:
from transformers import DistilBertForSequenceClassification, AutoModelForSequenceClassification, DistilBertConfig, DataCollatorWithPadding

model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels = num_labels,
    id2label = id2label,
    label2id = label2id,
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
import evaluate
import numpy as np

accuracy = evaluate.load("accuracy")

In [12]:
print(model)

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


In [15]:
import torch

from accelerate.test_utils.testing import get_backend

device, _, _ = get_backend() # automatically detects the underlying device type (CUDA, CPU, XPU, MPS, etc.)

model.to(device)

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


In [18]:
from torchao.quantization import (
    quantize_,
    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    FromIntXQuantizationAwareTrainingConfig,
    IntXQuantizationAwareTrainingConfig,
)

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)

quantize_(
    model,
    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)

In [None]:
device  = 'cpu'
from transformers import TrainingArguments, Trainer

EPOCHS = 1
BATCH_SIZE = 16
LEARNING_RATE = 0.00005

training_args = TrainingArguments(
    output_dir = './imdb_tune_distilbert_qat',
    num_train_epochs = EPOCHS,
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size = BATCH_SIZE,
    learning_rate = LEARNING_RATE,
    logging_dir = './logs',
    load_best_model_at_end= True,
    metric_for_best_model="accuracy",
    eval_strategy="epoch",
    eval_steps = 500,
    save_strategy="epoch",
    save_total_limit=2,
    report_to=['tensorboard'],
)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
trainer = Trainer(
    model=model,                         
    args=training_args,                  
    train_dataset=tokens["train"].shuffle(seed=11),         
    eval_dataset=tokens["test"].shuffle(seed=72).select(range(5000)),
    tokenizer = tokenizer,
    data_collator = data_collator,
)

  0%|          | 0/13 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
100%|██████████| 13/13 [01:34<00:00,  6.09s/it]

In [20]:
#model = qat_quantizer.convert(model)
print(model)


DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): FakeQuantizedLinear(
              in_features=768, out_features=768, bias=True
              (activation_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int8, granularity=PerToken(), mapping_type=<MappingType.ASYMMETRIC: 3>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
              (weight_fake_quantizer): FakeQuantizer(Fak

In [25]:
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
print(model)

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, weight=LinearActivationQuantizedTensor(activation=<function _int8_asymm_per_token_quant at 0x7fc6bebe3380>, weight=AffineQuantizedTensor(shape=torch.Size([768, 768]), block_size=(1, 32), device=cpu, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=-8, quant_max=7)))
            (k_lin): Linear(in_features=768, out_features=768, weight=LinearActivationQuantizedTensor(activation=<function

In [26]:
print(model.distilbert.transformer.layer[0].attention.q_lin.state_dict())

OrderedDict({'weight': LinearActivationQuantizedTensor(AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=tensor([[ 0,  2, -2,  ...,  2,  6,  5],
        [ 1,  2, -4,  ..., -1,  7,  1],
        [ 0,  3,  2,  ..., -1,  1, -5],
        ...,
        [ 1,  0,  4,  ...,  3,  2, -1],
        [ 0,  6,  6,  ..., -1,  7,  1],
        [-1, -7,  1,  ..., -4, -8, -2]], dtype=torch.int8)... , scale=tensor([[0.0108, 0.0136, 0.0108,  ..., 0.0153, 0.0103, 0.0091],
        [0.0143, 0.0141, 0.0114,  ..., 0.0083, 0.0135, 0.0183],
        [0.0107, 0.0078, 0.0107,  ..., 0.0195, 0.0097, 0.0077],
        ...,
        [0.0177, 0.0142, 0.0139,  ..., 0.0116, 0.0142, 0.0220],
        [0.0143, 0.0119, 0.0133,  ..., 0.0117, 0.0152, 0.0155],
        [0.0121, 0.0106, 0.0142,  ..., 0.0130, 0.0164, 0.0130]])... , zero_point=tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 

In [22]:
print(tokens["train"]["input_ids"][0].dtype)

torch.int64


In [23]:
for batch in train_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    print(model(**batch))
    break

SequenceClassifierOutput(loss=tensor(0.0180, grad_fn=<NllLossBackward0>), logits=tensor([[ 1.8608, -1.8304],
        [ 2.2928, -2.0456],
        [ 2.1260, -2.0180],
        [ 2.0767, -1.9183],
        [ 2.1248, -2.0246],
        [ 2.0277, -1.8520],
        [ 2.0801, -2.0569],
        [ 2.0196, -1.8730]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)


In [24]:
for batch in train_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    print(qmodel(**batch))
    break

NameError: name 'qmodel' is not defined

In [None]:
#freeze(model)

In [None]:
#print("acc after freeze:", trainer.evaluate())

quantize(model)
trainer.train()
freeze(model)
trainer.evaluate()