In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.ao.quantization as quant

import math
import time
import matplotlib.pyplot as plt

In [99]:
class BitLinear(nn.Module):
    """Linear layer whose weight and input are fake-quantized in training mode (QAT).

    In training mode the weight is fake-quantized (quantize-then-dequantize) onto a
    signed ``bits``-bit grid and the input onto an unsigned ``bits``-bit grid before
    ``F.linear``. Both quantizers track their ranges with moving-average min/max
    observers. In eval mode the wrapped ``self.linear`` is called directly, so this
    layer's own fake quantizers are bypassed.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        bits: quantization bit-width in [1, 8]. The grids are stored in the 8-bit
            ``qint8`` / ``quint8`` dtypes, so wider grids are rejected.
            ``bits=8`` gives weights in [-128, 127] and activations in [0, 255].

    Raises:
        ValueError: if ``bits`` is outside [1, 8].
    """

    def __init__(self, in_features, out_features, bits=8):
        super().__init__()
        if not 1 <= bits <= 8:
            raise ValueError(f"bits must be between 1 and 8 (8-bit quantized dtypes), got {bits}")
        self.in_features = in_features
        self.out_features = out_features
        self.bits = bits

        # Underlying float layer; owns the trainable weight and bias.
        self.linear = nn.Linear(in_features, out_features)

        # Weight quantizer: signed grid [-2**(bits-1), 2**(bits-1) - 1].
        self.weight_fake_quant = quant.fake_quantize.FusedMovingAvgObsFakeQuantize(
            observer=quant.observer.MovingAverageMinMaxObserver,
            quant_min=-(2 ** (bits - 1)),
            quant_max=2 ** (bits - 1) - 1,
            dtype=torch.qint8,
        )

        # Activation quantizer: unsigned grid [0, 2**bits - 1].
        self.activation_fake_quant = quant.fake_quantize.FusedMovingAvgObsFakeQuantize(
            observer=quant.observer.MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=2 ** bits - 1,
            dtype=torch.quint8,
        )

    def forward(self, x):
        """Apply the layer; weight and input are fake-quantized only in training mode."""
        if self.training:
            # QAT path: fake-quantize weight and input, then run the float matmul.
            w_q = self.weight_fake_quant(self.linear.weight)
            x_q = self.activation_fake_quant(x)
            return F.linear(x_q, w_q, self.linear.bias)
        # Eval path: this layer's own fake quantizers are NOT applied.
        # NOTE(review): after prepare_qat, self.linear is presumably a QAT Linear that
        # fake-quantizes its weight by itself -- confirm whether eval should match the
        # training-time quantization.
        return self.linear(x)

Transformer building blocks

In [3]:
class Head(nn.Module):
    """Single causal self-attention head whose projections are BitLinear layers.

    Args:
        n_embd: channel count C of the input.
        head_size: channel count of this head's output.
        block_size: maximum sequence length; sizes the causal mask buffer.
        quantization_bits: bit-width forwarded to the key/query/value BitLinear layers.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, n_embd, head_size, block_size, quantization_bits=8, dropout=0.2):
        super().__init__()
        self.head_size = head_size
        self.key = BitLinear(n_embd, head_size, quantization_bits)
        self.query = BitLinear(n_embd, head_size, quantization_bits)
        self.value = BitLinear(n_embd, head_size, quantization_bits)
        # Lower-triangular ones: row t has ones only at columns <= t.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Causal attention for one head: (B, T, C) -> (B, T, head_size)."""
        _, seq_len, _ = x.shape
        keys = self.key(x)        # (B, T, hs)
        queries = self.query(x)   # (B, T, hs)
        # Scaled dot-product affinities: (B, T, hs) @ (B, hs, T) -> (B, T, T).
        scores = queries @ keys.transpose(-2, -1) * self.head_size ** -0.5
        # Positions above the diagonal (the "future") are masked out.
        future_mask = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future_mask, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        values = self.value(x)    # (B, T, hs)
        return weights @ values   # (B, T, T) @ (B, T, hs) -> (B, T, hs)
    

    
class MultiHeadAttention(nn.Module):
    """Several causal attention heads run in parallel, concatenated and projected.

    Args:
        num_heads: number of parallel Head modules.
        n_embd: input/output channel count.
        head_size: output channels of each head.
        block_size: maximum sequence length (causal mask size).
        quantization_bits: bit-width forwarded to every BitLinear layer.
        dropout: dropout probability for the heads' attention weights and for the
            projected output.
    """

    def __init__(self, num_heads, n_embd, head_size, block_size, quantization_bits=8, dropout=0.2):
        super().__init__()
        self.heads = nn.ModuleList([
            Head(n_embd, head_size, block_size, quantization_bits, dropout)
            for _ in range(num_heads)
        ])
        # The concatenated head outputs have num_heads * head_size channels, which
        # equals n_embd only when n_embd is divisible by num_heads.
        self.proj = BitLinear(num_heads * head_size, n_embd, quantization_bits)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Concatenate over the channel dim, project back to n_embd, then regularize.
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(out))
  


class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4 * n_embd, ReLU, project back, then dropout.

    Both linear layers are BitLinear with ``quantization_bits``.
    """

    def __init__(self, n_embd, quantization_bits=8, dropout=0.2):
        super().__init__()
        hidden = 4 * n_embd  # 4x hidden expansion
        layers = [
            BitLinear(n_embd, hidden, quantization_bits),
            nn.ReLU(),
            BitLinear(hidden, n_embd, quantization_bits),  # projection back to n_embd
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
  

class Block(nn.Module):
    """Transformer block: residual self-attention followed by residual feed-forward.

    Args:
        num_heads: number of attention heads; each gets n_embd // num_heads channels.
        n_embd: residual stream width.
        block_size: maximum sequence length (causal mask size).
        sa_bits: bit-width for the attention BitLinear layers.
        ffwd_bits: bit-width for the feed-forward BitLinear layers.
    """

    def __init__(self, num_heads, n_embd, block_size, sa_bits=8, ffwd_bits=8):
        super().__init__()
        per_head = n_embd // num_heads
        self.sa = MultiHeadAttention(num_heads, n_embd, per_head, block_size, sa_bits)
        self.ffwd = FeedForward(n_embd, quantization_bits=ffwd_bits)

    def forward(self, x):
        # No LayerNorm inside the block: plain residual additions.
        attended = x + self.sa(x)
        return attended + self.ffwd(attended)

Character-level language model

In [4]:
class BigramLanguageModel(nn.Module):
    """Character-level transformer language model built from BitLinear blocks.

    Token and position embeddings are summed, then passed through ``num_layers``
    Block modules, a final LayerNorm, and a float ``nn.Linear`` head that produces
    vocabulary logits. ``block_size`` bounds the sequence length: it sizes the
    position table and each block's causal mask. ``bits`` is forwarded to every
    block's BitLinear layers.
    """

    def __init__(self, vocab_size, num_layers, num_heads, n_embd, block_size=8, bits=8):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)

        self.blocks = nn.Sequential(*[Block(num_heads, n_embd, block_size, sa_bits=bits, ffwd_bits=bits) for _ in range(num_layers)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.linear = nn.Linear(n_embd, vocab_size)

        # Quantization stubs, only routed through in eval mode (see forward).
        self.quant = quant.QuantStub()
        self.dequant = quant.DeQuantStub()

        # True once prepare_qat has run on this model (see prepare_for_quantization).
        self._qat_prepared = False

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            idx: integer token ids of shape (B, T), with T <= block_size.
            targets: optional integer ids of shape (B, T).

        Returns:
            (logits, loss): logits have shape (B, T, vocab_size) when ``targets`` is
            None, and are flattened to (B*T, vocab_size) otherwise. loss is None when
            ``targets`` is None.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb

        if not self.training:
            x = self.quant(x)  # Apply quantization before passing through layers

        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.linear(x)

        if not self.training:
            logits = self.dequant(logits)  # Dequantize before returning output

        if targets is None:
            loss = None
        else:
            logits = logits.reshape(B * T, logits.size(-1))
            targets = targets.reshape(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def prepare_for_quantization(self):
        """Prepare the model in place for quantization-aware training (QAT).

        Switches to training mode, attaches the default fbgemm QAT qconfig to every
        module except ``nn.Embedding`` layers, and runs ``prepare_qat``. It is safe
        to call repeatedly: once the model is prepared, later calls only switch to
        training mode.
        """
        self.train()
        if getattr(self, "_qat_prepared", False):
            # Re-running prepare_qat would insert fresh observers, discarding the
            # collected statistics, and register additional observer hooks on every
            # observed module.
            return

        # Standard QAT config for all layers *except* Embedding.
        self.qconfig = quant.get_default_qat_qconfig("fbgemm")

        # Exclude Embedding layers from quantization.
        for module in self.modules():
            if isinstance(module, nn.Embedding):
                module.qconfig = None

        quant.prepare_qat(self, inplace=True)
        self._qat_prepared = True

    def convert_to_quantized(self):
        """Run ``quant.convert`` in place with a mapping that only covers ``nn.Linear``.

        NOTE(review): after ``prepare_for_quantization`` the Linear layers are
        presumably QAT Linear modules. Their type is not a key of ``mapping``, so
        they may not be swapped at all -- verify which modules this actually
        converts.
        """
        mapping = {nn.Linear: torch.ao.nn.quantized.Linear}
        quant.convert(self, mapping=mapping, inplace=True)
        # convert() defaults to remove_qconfig=True, dropping the qconfig set during
        # preparation; require a fresh prepare before any further QAT.
        self._qat_prepared = False


Quantization-aware training (QAT) loop

In [None]:
def QAT(model, data_loader, optimizer_block):
    """Run one pass of quantization-aware training over ``data_loader``.

    The model is prepared with ``model.prepare_for_quantization()`` only on the
    first call (tracked with ``model._qat_prepared``). QAT can therefore be called
    once per epoch without re-inserting observers.

    Parameters:
    - model: the model to train. It must provide ``prepare_for_quantization()``,
      and ``model(inputs, targets)`` must return ``(logits, loss)``.
    - data_loader: iterable of ``(inputs, targets)`` batches. Each batch is moved
      to the device of the model's parameters.
    - optimizer_block: optimizer over the model's parameters; stepped once per batch.
    """
    model.train()
    if not getattr(model, "_qat_prepared", False):
        model.prepare_for_quantization()  # Enable QAT (first call only)
        model._qat_prepared = True

    # Every parameter should receive gradients; the parameter set is fixed, so set once.
    for param in model.parameters():
        param.requires_grad = True

    device = next(model.parameters()).device
    for inputs, targets in data_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer_block.zero_grad()
        _, loss = model(inputs, targets)
        loss.backward()
        optimizer_block.step()

Data loading, character encoding, and train/validation split

In [6]:
# data loading
def get_batch(split, train_data, val_data, block_size, batch_size):
    
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

@torch.no_grad()
def estimate_loss(model, train_data, val_data, block_size, batch_size, eval_iters=200):
    """Estimate the mean train/val loss of a converted copy of ``model``.

    A deep copy of ``model`` is put in eval mode and passed through
    ``convert_to_quantized``. The live ``model`` keeps its modules, observers and
    train/eval mode, so QAT can continue afterwards.

    Returns:
        dict mapping 'train' and 'val' to the mean loss over ``eval_iters``
        batches from ``get_batch`` (0-d tensors).
    """
    import copy  # local import: this cell may run before any cell that imports copy

    device = next(model.parameters()).device
    qmodel = copy.deepcopy(model)
    qmodel.eval()
    qmodel.convert_to_quantized()

    out = {}
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, train_data, val_data, block_size, batch_size)
            # get_batch leaves tensors on the data's device; move them to the model's.
            _, loss = qmodel(X.to(device), Y.to(device))
            losses[k] = loss.item()
        out[split] = losses.mean()
    return out

In [108]:
import copy

@torch.no_grad()
def estimate_loss(model, train_data, val_data, block_size, batch_size, eval_iters=200):
    """Estimate the mean train/val loss using a converted deep copy of ``model``.

    The copy is switched to eval mode and passed through ``convert_to_quantized``.
    ``model`` itself is not modified and keeps its current train/eval mode.

    Returns:
        dict mapping 'train' and 'val' to the mean loss over ``eval_iters``
        batches from ``get_batch`` (0-d tensors).
    """
    device = next(model.parameters()).device

    # Evaluate a converted copy so the live training model is left untouched.
    qmodel = copy.deepcopy(model)
    qmodel.eval()
    qmodel.convert_to_quantized()

    out = {}
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, train_data, val_data, block_size, batch_size)
            # get_batch leaves tensors on the data's device; move them to the model's.
            _, loss = qmodel(X.to(device), Y.to(device))
            losses[k] = loss.item()
        out[split] = losses.mean()
    return out

In [7]:
# Load the raw corpus as one string.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Vocabulary: every distinct character in the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables between characters and integer token ids.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Encode a string as a list of integer token ids."""
    return [stoi[c] for c in s]


def decode(l):
    """Decode a list of integer token ids back into a string."""
    return ''.join(itos[i] for i in l)


# Train/validation split: the first 90% of tokens train, the rest validate.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [153]:
# hyperparameters
batch_size = 32        # sequences per batch
block_size = 8         # tokens per sequence (context length)
n_embd = 32            # embedding / residual width
n_head = 4             # attention heads per block
n_layer = 6            # number of transformer blocks
learning_rate = 1e-3
num_epochs = 10        # QAT(...) passes per training run
epoch_length = 200     # batches pre-sampled into each run's data_loader
eval_iters = 200       # batches per split in estimate_loss
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's RNG so weight init, batch sampling and dropout are reproducible
# under Restart & Run All.
seed = 1337
torch.manual_seed(seed)

In [154]:
# 8-bit QAT run: build the model and record its untrained loss. Then alternate one
# QAT pass over a fixed set of pre-sampled batches with a loss estimate.
qmodel8 = BigramLanguageModel(vocab_size, n_layer, n_head, n_embd, block_size=block_size).to(device)

initial_loss = estimate_loss(qmodel8, train_data, val_data, block_size, batch_size, eval_iters)
print(f"Beginning at: train loss {initial_loss['train']:.4f}, val loss {initial_loss['val']:.4f}\n")

optimizer = torch.optim.AdamW(qmodel8.parameters(), lr=learning_rate)

# Every epoch reuses the same epoch_length batches, moved to the model's device once.
data_loader = [
    tuple(t.to(device) for t in get_batch('train', train_data, val_data, block_size, batch_size))
    for _ in range(epoch_length)
]
loss_val8 = []
loss_train8 = []

start = time.perf_counter()
for epoch in range(num_epochs):
    QAT(qmodel8, data_loader, optimizer)
    losses = estimate_loss(qmodel8, train_data, val_data, block_size, batch_size, eval_iters)
    loss_val8.append(losses['val'])
    loss_train8.append(losses['train'])
    print(f"step {epoch}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

end = time.perf_counter()
total_time = end - start
print(total_time)

Beginning at: train loss 4.3150, val loss 4.3141

step 0: train loss 2.6291, val loss 2.6160
step 1: train loss 2.4518, val loss 2.4675
step 2: train loss 2.3832, val loss 2.3946
step 3: train loss 2.3164, val loss 2.3274
step 4: train loss 2.2774, val loss 2.2902
step 5: train loss 2.2510, val loss 2.2603
step 6: train loss 2.2276, val loss 2.2401
step 7: train loss 2.2049, val loss 2.2171
step 8: train loss 2.1815, val loss 2.2120
step 9: train loss 2.1568, val loss 2.1956
676.4785394668579


In [156]:
# Display the module tree, including each fake quantizer's current scale and zero_point.
qmodel8

BigramLanguageModel(
  (token_embedding_table): Embedding(65, 32)
  (position_embedding_table): Embedding(8, 32)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0): Head(
            (key): BitLinear(
              (linear): Linear(in_features=32, out_features=8, bias=True)
              (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
                fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([0.0022]), zero_point=tensor([-2], dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=False
                (activation_post_process): MovingAverageMinMaxObserver(min_val=-0.27910614013671875, max_val=0.2838379442691803)
              )
              (activation_fake_quant): FusedMovingAvgObsFakeQuantize(
                fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([0.0377]), zero_point=tensor([140], dtype

In [157]:
# Second 8-bit QAT run with a freshly initialized model, for run-to-run comparison.
qmodel8_2 = BigramLanguageModel(vocab_size, n_layer, n_head, n_embd, block_size=block_size).to(device)

initial_loss_2 = estimate_loss(qmodel8_2, train_data, val_data, block_size, batch_size, eval_iters)
print(f"Beginning at: train loss {initial_loss_2['train']:.4f}, val loss {initial_loss_2['val']:.4f}\n")

optimizer = torch.optim.AdamW(qmodel8_2.parameters(), lr=learning_rate)

# Every epoch reuses the same epoch_length batches, moved to the model's device once.
data_loader = [
    tuple(t.to(device) for t in get_batch('train', train_data, val_data, block_size, batch_size))
    for _ in range(epoch_length)
]
loss_val8_2 = []
loss_train8_2 = []

start_2 = time.perf_counter()
for epoch in range(num_epochs):
    QAT(qmodel8_2, data_loader, optimizer)
    losses_2 = estimate_loss(qmodel8_2, train_data, val_data, block_size, batch_size, eval_iters)
    loss_val8_2.append(losses_2['val'])
    loss_train8_2.append(losses_2['train'])
    print(f"step {epoch}: train loss {losses_2['train']:.4f}, val loss {losses_2['val']:.4f}")

end_2 = time.perf_counter()
total_time_2 = end_2 - start_2
print(total_time_2)

Beginning at: train loss 4.3689, val loss 4.3841

step 0: train loss 2.6279, val loss 2.6407
step 1: train loss 2.4489, val loss 2.4643
step 2: train loss 2.3795, val loss 2.3788
step 3: train loss 2.3177, val loss 2.3376
step 4: train loss 2.2713, val loss 2.3085
step 5: train loss 2.2513, val loss 2.2731
step 6: train loss 2.2163, val loss 2.2419
step 7: train loss 2.2105, val loss 2.2492
step 8: train loss 2.1769, val loss 2.2034
step 9: train loss 2.1708, val loss 2.2126
645.8131875991821
