In [30]:
from src.lab1.shakespeare_trainer import ShakespeareModule

In [31]:
module = ShakespeareModule.load_from_checkpoint("src/lab1/checkpoints/float-best.ckpt")

In [68]:
import torch
import torch.nn as nn
from typing import Dict 

class QLinear(nn.Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 weight_bitwidth: int = 8,
                 act_bitwidth: int = 8):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.has_bias = bias
        self.weight_bitwidth = weight_bitwidth
        self.act_bitwidth = act_bitwidth

        # buffers (not parameters) to hold quantized weight + scale
        self.register_buffer("qweight", torch.zeros(out_features, in_features, dtype=torch.int8))
        self.register_buffer("weight_scale", torch.ones(1))  # single scalar

        # treat bias as a float32 buffer (optional)
        if bias:
            self.register_buffer("bias", torch.zeros(out_features, dtype=torch.float32))
        else:
            self.bias = None

    @staticmethod
    def _quantize_tensor(x: torch.Tensor, bitwidth: int):
        q_max = 2 ** (bitwidth - 1) - 1
        r_max = x.abs().max()
        # avoid division by zero
        scale = r_max / q_max if r_max > 0 else torch.tensor(1.0, device=x.device)
        q = torch.clamp(torch.round(x / scale), -q_max, +q_max).to(torch.int8)
        return q, scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 1) Quantize activations dynamically
        q_act, act_scale = self._quantize_tensor(x, self.act_bitwidth)

        # 2) Integer GEMM (q_act: int8, qweight: int8) → int32 accumulator
        qout_int = torch.matmul(q_act, self.qweight.t().to(q_act.dtype))

        # 3) Dequantize: y = qout_int * act_scale * weight_scale
        y = qout_int.to(torch.float32) * act_scale * self.weight_scale

        # 4) Add (float) bias if present
        if self.has_bias:
            y = y + self.bias

        return y

    def __repr__(self): 
        return f"QLinear(in_features={self.in_features}, out_features={self.out_features}, weight_bitwidth={self.weight_bitwidth})"


def quantize_linear(module: nn.Linear, weight_bitwidth: int = 8, act_bitwidth: int = 8) -> QLinear:
    """
    Given a trained nn.Linear, return a QLinear with:
     - weights quantized to `weight_bitwidth` bits
     - single shared scale factor stored in module.weight_scale
     - bias copied (fp32)
    """
    # 1) instantiate QLinear with same dimensions
    qmod = QLinear(module.in_features,
                   module.out_features,
                   bias=(module.bias is not None),
                   weight_bitwidth=weight_bitwidth,
                   act_bitwidth=act_bitwidth)

    # 2) quantize the floating-point weights
    q_w, w_scale = QLinear._quantize_tensor(module.weight.data, weight_bitwidth)
    # copy into buffers
    qmod.qweight.copy_(q_w)
    qmod.weight_scale.copy_(w_scale)

    # 3) copy bias (if any)
    if module.bias is not None:
        qmod.bias.copy_(module.bias.data)

    return qmod


def quantize_model(module: nn.Module, 
                   qconfig: Dict[str, int]) -> nn.Module:
    """
    Given an nn.Module and a qconfig dict mapping module‐paths (e.g.
    "model.transformer_blocks.3.attn_proj") to weight_bitwidths,
    replace each specified nn.Linear with a QLinear using that bitwidth
    (and the same for activations).
    """
    for full_name, weight_bw in qconfig.items():
        # split off the last component to find the parent container
        *parent_path, child_name = full_name.split('.')
        
        # navigate to the parent module
        if parent_path:
            parent = module.get_submodule('.'.join(parent_path))
        else:
            parent = module  # top‐level
        
        orig = getattr(parent, child_name)
        if not isinstance(orig, nn.Linear):
            raise ValueError(f"Expected nn.Linear at '{full_name}', "
                             f"but found {type(orig)}")
        
        # quantize weight (and use same bw for activations)
        qlin = quantize_linear(orig,
                               weight_bitwidth=weight_bw,
                               act_bitwidth=weight_bw)
        
        # overwrite in the parent
        setattr(parent, child_name, qlin)
    
    return module

    

In [69]:
print(module)
    

ShakespeareModule(
  (model): AutoRegressiveTransformer(
    (embedding): Embedding(1024, 256)
    (pos_encoder): Embedding(1024, 256)
    (transformer_blocks): ModuleList(
      (0-7): 8 x TransformerBlock(
        (self_attn): MultiHeadAttention(
          (qkv_proj): QLinear()
          (out_proj): QLinear()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): FeedForward(
          (linear1): QLinear()
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): QLinear()
          (activation): GELU(approximate='none')
        )
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (linear): Linear(in_features=256, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)


In [70]:
def get_init_qconfig(module):
    blocks = module.model.transformer_blocks
    qconfig = {}
    for name, module in blocks.named_modules(): 
        if isinstance(module, nn.Linear): 
            qconfig[f"model.transformer_blocks.{name}"] = 8
    return qconfig

In [71]:
module = ShakespeareModule.load_from_checkpoint("src/lab1/checkpoints/float-best.ckpt")
qconfig = get_init_qconfig(module)
print(qconfig)

{'model.transformer_blocks.0.self_attn.qkv_proj': 8, 'model.transformer_blocks.0.self_attn.out_proj': 8, 'model.transformer_blocks.0.feed_forward.linear1': 8, 'model.transformer_blocks.0.feed_forward.linear2': 8, 'model.transformer_blocks.1.self_attn.qkv_proj': 8, 'model.transformer_blocks.1.self_attn.out_proj': 8, 'model.transformer_blocks.1.feed_forward.linear1': 8, 'model.transformer_blocks.1.feed_forward.linear2': 8, 'model.transformer_blocks.2.self_attn.qkv_proj': 8, 'model.transformer_blocks.2.self_attn.out_proj': 8, 'model.transformer_blocks.2.feed_forward.linear1': 8, 'model.transformer_blocks.2.feed_forward.linear2': 8, 'model.transformer_blocks.3.self_attn.qkv_proj': 8, 'model.transformer_blocks.3.self_attn.out_proj': 8, 'model.transformer_blocks.3.feed_forward.linear1': 8, 'model.transformer_blocks.3.feed_forward.linear2': 8, 'model.transformer_blocks.4.self_attn.qkv_proj': 8, 'model.transformer_blocks.4.self_attn.out_proj': 8, 'model.transformer_blocks.4.feed_forward.linear

In [74]:
qmodule = quantize_model(module, qconfig)

ValueError: Expected nn.Linear at 'model.transformer_blocks.0.self_attn.qkv_proj', but found <class '__main__.QLinear'>

In [75]:
qmodule

ShakespeareModule(
  (model): AutoRegressiveTransformer(
    (embedding): Embedding(1024, 256)
    (pos_encoder): Embedding(1024, 256)
    (transformer_blocks): ModuleList(
      (0): TransformerBlock(
        (self_attn): MultiHeadAttention(
          (qkv_proj): QLinear(in_features=256, out_features=768, weight_bitwidth=8)
          (out_proj): QLinear(in_features=256, out_features=256, weight_bitwidth=4)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): FeedForward(
          (linear1): QLinear(in_features=256, out_features=1024, weight_bitwidth=8)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): QLinear(in_features=1024, out_features=256, weight_bitwidth=8)
          (activation): GELU(approximate='none')
        )
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): D