In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
import inspect
import time
import tiktoken

device = 'cuda' if torch.cuda.is_available() else 'cpu'
enc = tiktoken.get_encoding("gpt2")

# Model architectures 

# attn-> concat-> linear
class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd


    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)x`
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # flash attention # (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        # output projection
        y = self.c_proj(y)
        return y


# feedforward network: two linear maps with a GELU in between
class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu    = nn.GELU(approximate='tanh')
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x


# transformer block
class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

# Configuration classes
@dataclass
class GPTConfig124M:
    block_size: int = 1024
    vocab_size: int = 50257
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768

@dataclass
class GPTConfig19M:
    block_size: int = 1024
    vocab_size: int = 50257
    n_layer: int = 8
    n_head: int = 8
    n_embd: int = 256


class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        #weight sharing scheme
        self.transformer.wte.weight = self.lm_head.weight

        #init paramsxxxxxxxxxxxxxxxxxxx
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)


    def forward(self, idx, targets=None):
        # idx is of shape (B, T)
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        # forward the token and posisition embeddings
        pos = torch.arange(0, T, dtype=torch.long, device= 'cuda' if torch.cuda.is_available() else "cpu") # shape (T)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
        x = tok_emb + pos_emb
        # forward the blocks of the transformer
        for block in self.transformer.h:
            x = block(x)
        # forward the final layernorm and the classifier
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss


    def configure_optimizers(self, weight_decay, learning_rate, device_type):
        # start with all of the candidate parameters (that require grad)
        param_dict = {pn: p for pn, p in self.named_parameters()}
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)

        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available

        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == "cuda"
        print(f"using fused AdamW: {use_fused}")
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimizer






In [None]:
# notice : use of decorator might results in lose of metadata, @functools.wraps() -> to preserve it

In [28]:
class SpeculativeDecoder:
    def __init__(self, draft_model, target_model):
        self.draft_model = draft_model.to(device)
        self.target_model = target_model.to(device)
        self.draft_model.eval()
        self.target_model.eval()
    
    def sample_token(self, logits, temperature=1.0, top_k=90):
        """Sample token using top-k sampling like your original code."""
        if temperature == 0:
            return torch.argmax(logits, dim=-1).item()
        
        probs = F.softmax(logits / temperature, dim=-1)
        if top_k > 0:
            topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
            ix = torch.multinomial(topk_probs, 1)
            return torch.gather(topk_indices, -1, ix).item()
        else:
            return torch.multinomial(probs, 1).item()
    
    def draft_decode(self, input_ids, num_draft_tokens=4, temperature=1.0, top_k=90):
        """Generate draft tokens using the smaller model."""
        draft_tokens = []
        draft_logits = []
        current_ids = input_ids.clone()
        
        with torch.no_grad():
            for _ in range(num_draft_tokens):
                logits, _ = self.draft_model(current_ids)
                last_logits = logits[:, -1, :]
                
                next_token = self.sample_token(last_logits[0], temperature, top_k)
                draft_tokens.append(next_token)
                draft_logits.append(last_logits[0])
                
                current_ids = torch.cat([current_ids, torch.tensor([[next_token]], device=device)], dim=1)
        
        return draft_tokens, torch.stack(draft_logits)
    
    def verify_and_correct(self, input_ids, draft_tokens, draft_logits, temperature=1.0, top_k=90):
        """Verify draft tokens with target model and apply rejection sampling."""
        extended_ids = torch.cat([
            input_ids, 
            torch.tensor([draft_tokens], device=device)
        ], dim=1)
        
        with torch.no_grad():
            target_logits, _ = self.target_model(extended_ids)
            # FIX: Correct the logits alignment
            start_pos = len(input_ids[0]) - 1  # Position that predicts first draft token
            end_pos = start_pos + len(draft_tokens)
            target_logits_for_draft = target_logits[0, start_pos:end_pos]
        accepted_tokens = []
        
        for i, (draft_token, draft_logit, target_logit) in enumerate(
            zip(draft_tokens, draft_logits, target_logits_for_draft)
        ):
            # Calculate acceptance probability
            draft_prob = F.softmax(draft_logit / temperature, dim=-1)
            target_prob = F.softmax(target_logit / temperature, dim=-1)
            
            acceptance_prob = torch.min(
                torch.tensor(1.0, device=device),
                target_prob[draft_token] / (draft_prob[draft_token] + 1e-10)
            )
            
            if torch.rand(1, device=device) < acceptance_prob:
                accepted_tokens.append(draft_token)
            else:
                # Rejection: sample from corrected distribution
                corrected_prob = torch.clamp(target_prob - draft_prob, min=0)
                corrected_prob = corrected_prob / (corrected_prob.sum() + 1e-10)
                
                if corrected_prob.sum() > 1e-10:
                    if top_k > 0:
                        topk_probs, topk_indices = torch.topk(corrected_prob, min(top_k, corrected_prob.shape[0]))
                        ix = torch.multinomial(topk_probs, 1)
                        corrected_token = torch.gather(topk_indices, -1, ix).item()
                    else:
                        corrected_token = torch.multinomial(corrected_prob, 1).item()
                    accepted_tokens.append(corrected_token)
                else:
                    # Fallback to target model sampling
                    corrected_token = self.sample_token(target_logit, temperature, top_k)
                    accepted_tokens.append(corrected_token)
                
                break  # Stop after first rejection
        
        return accepted_tokens, len(accepted_tokens)
    
    def generate(self, input_ids, max_length=100, num_draft_tokens=4, temperature=1.0, top_k=90):
        """Generate text using speculative decoding."""
        generated_ids = input_ids.clone()
        total_draft_tokens = 0
        total_accepted_tokens = 0
        num_iterations = 0
        
        start_time = time.time()
        
        while generated_ids.shape[1] < max_length:
            num_iterations += 1
            
            # Generate draft tokens
            draft_tokens, draft_logits = self.draft_decode(
                generated_ids, num_draft_tokens, temperature, top_k
            )
            total_draft_tokens += len(draft_tokens)
            
            # Verify and correct with target model
            accepted_tokens, num_accepted = self.verify_and_correct(
                generated_ids, draft_tokens, draft_logits, temperature, top_k
            )
            total_accepted_tokens += num_accepted
            
            # Update generated sequence
            if accepted_tokens:
                new_tokens = torch.tensor([accepted_tokens], device=device)
                generated_ids = torch.cat([generated_ids, new_tokens], dim=1)
            
            if generated_ids.shape[1] >= max_length:
                break
        
        end_time = time.time()
        
        # Calculate statistics
        acceptance_rate = total_accepted_tokens / total_draft_tokens if total_draft_tokens > 0 else 0
        tokens_per_second = (generated_ids.shape[1] - input_ids.shape[1]) / (end_time - start_time)
        
        stats = {
            'total_iterations': num_iterations,
            'total_draft_tokens': total_draft_tokens,
            'total_accepted_tokens': total_accepted_tokens,
            'acceptance_rate': acceptance_rate,
            'tokens_per_second': tokens_per_second,
            'generation_time': end_time - start_time
        }
        
        return generated_ids, stats


In [29]:

def load_models():
    """Load both models."""
    print("Loading 124M parameter model...")
    target_model = GPT(GPTConfig124M())
    checkpoint_124m = torch.load("124mpara.pth", map_location='cpu', weights_only=True)
    # this is cause of '_orig_mod' wrapper caused by dynamo and torch compile during training
    checkpoint_124m = {k.replace('_orig_mod.', ''): v for k, v in checkpoint_124m.items()}
    target_model.load_state_dict(checkpoint_124m)
    print("Loading 19M parameter model...")
    draft_model = GPT(GPTConfig19M())
    
    checkpoint_19m = torch.load("19M.pth", map_location='cpu', weights_only=True)
    draft_model.load_state_dict(checkpoint_19m)

    return draft_model, target_model

def regular_generation(model, input_ids, max_length=100, temperature=1.0, top_k=90):
    """Regular generation for comparison."""
    model.eval()
    xgen = input_ids.clone()
    start_time = time.time()
    
    while xgen.size(1) < max_length:
        with torch.no_grad():
            logits, _ = model(xgen)
            logits = logits[:, -1, :]
            probs = F.softmax(logits / temperature, dim=-1)
            
            if top_k > 0:
                topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
                ix = torch.multinomial(topk_probs, 1)
                xcol = torch.gather(topk_indices, -1, ix)
            else:
                xcol = torch.multinomial(probs, 1)
            
            xgen = torch.cat((xgen, xcol), dim=1)
    
    end_time = time.time()
    tokens_per_second = (xgen.shape[1] - input_ids.shape[1]) / (end_time - start_time)
    
    return xgen, {'tokens_per_second': tokens_per_second, 'generation_time': end_time - start_time}

# Load models
draft_model, target_model = load_models()
# Initialize speculative decoder
decoder = SpeculativeDecoder(draft_model, target_model)


Loading 124M parameter model...
Loading 19M parameter model...


In [None]:
prompt = "love in the air"
tokens = enc.encode(prompt)
idx = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)
print(f"Prompt: {prompt}")
print(f"Input tokens: {len(tokens)}")
print("-" * 50)

# params
max_length = 100
temperature = 1.2
top_k = 90
num_draft_tokens = 3

# Regular generation with target model
print("Regular generation (124M model only):")
regular_ids, regular_stats = regular_generation(
    target_model, idx, max_length, temperature, top_k
)
regular_text = enc.decode(regular_ids[0].tolist())
print(f"Generated text: {regular_text}")
print(f"Speed: {regular_stats['tokens_per_second']:.2f} tokens/sec")
print(f"Time: {regular_stats['generation_time']:.2f} seconds")
print()

# Speculative decoding
print("Speculative decoding (19M draft + 124M target):")
spec_ids, spec_stats = decoder.generate(
    idx, max_length, num_draft_tokens, temperature, top_k
)
spec_text = enc.decode(spec_ids[0].tolist())
print(f"Generated text: {spec_text}")
print(f"Speed: {spec_stats['tokens_per_second']:.2f} tokens/sec")
print(f"Time: {spec_stats['generation_time']:.2f} seconds")
print(f"Acceptance rate: {spec_stats['acceptance_rate']:.2f}")
print(f"Speedup: {spec_stats['tokens_per_second'] / regular_stats['tokens_per_second']:.2f}x")

Prompt: love in the air
Input tokens: 4
--------------------------------------------------
Regular generation (124M model only):
Generated text: love in the air  
When you know it all begins  
Baby it's in your ears  
You must try to help those in your mind  
Tonight  
I dream so far away  
I'd be so in the shadows  
I'd always end behind you and have to decide  
But somehow that it seems you would see me  
Baby it's in your veins  
But how can we part love another
Speed: 7.58 tokens/sec
Time: 12.67 seconds

Speculative decoding (19M draft + 124M target):
Generated text: love in the air  
It never gets colder  
Maybe I can't let go now  
And never leave, never let go  
If I could be, never let go  
To the end if time has come true come true come true  
Come closer don't break a little closer  
Well I won't let go  
If I could be I would be here once again  
  
If I could be
Speed: 9.59 tokens/sec
Time: 10.01 seconds
Acceptance rate: 0.38
Speedup: 1.26x


In [31]:
print(f"Total iterations: {spec_stats['total_iterations']}")
print(f"Total draft tokens: {spec_stats['total_draft_tokens']}")
print(f"Total accepted tokens: {spec_stats['total_accepted_tokens']}")

Total iterations: 25
Total draft tokens: 250
Total accepted tokens: 96
