# Qwen in PyTorch

In [1]:
import math
import os

from pathlib import Path
from safetensors import safe_open

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load environment variables via python-dotenv before they are read below
# (MODEL_DIRECTORY is fetched with os.getenv in a later cell).
# NOTE(review): presumably a .env file sits next to the notebook — confirm.
from dotenv import load_dotenv
load_dotenv()

True

In [3]:
def load_sharded_safetensors(model_dir: str, device: str = "mps") -> dict:
    """Merge every ``*.safetensors`` file in a directory into one state dict.

    Args:
        model_dir: Directory that holds one or more ``.safetensors`` files.
        device: Device string handed to ``safe_open``; tensors are created
            there. Defaults to ``"mps"`` (the previously hard-coded value).

    Returns:
        A dict mapping tensor name to tensor, merged across all files. The
        dict is empty when the directory contains no ``.safetensors`` file.

    Raises:
        RuntimeError: If the same tensor name appears in more than one file.
    """
    # Sort by file name so the load order is deterministic. The previous key
    # converted the last "-"-separated field to int; for names shaped like
    # "model-00001-of-00004.safetensors" that field is the shard *count*, so
    # it ordered nothing, and for an unsharded "model.safetensors" it raised
    # ValueError. Zero-padded shard names sort correctly as plain strings.
    shard_paths = sorted(Path(model_dir).glob("*.safetensors"), key=lambda p: p.name)

    state_dict = {}
    for shard_path in shard_paths:
        with safe_open(shard_path, framework="pt", device=device) as f:
            for key in f.keys():
                # A tensor must live in exactly one shard; a repeat means the
                # directory mixes checkpoints or a shard is damaged.
                if key in state_dict:
                    raise RuntimeError(f"Duplicate key detected: {key}. Corrupted shards?")
                state_dict[key] = f.get_tensor(key)

    return state_dict

In [4]:
# Checkpoint directory comes from the environment; None when MODEL_DIRECTORY is unset.
model_dir = os.environ.get("MODEL_DIRECTORY")

In [5]:
# Read all checkpoint tensors from model_dir into a single name -> tensor dict.
state_dict = load_sharded_safetensors(model_dir=model_dir)

In [6]:
# Print the tensor names found in the checkpoint. Only the keys are needed,
# so iterate the dict directly instead of unpacking unused values.
for key in state_dict:
    print(key)

model.embed_tokens.weight
model.layers.0.input_layernorm.weight
model.layers.0.mlp.down_proj.weight
model.layers.0.mlp.gate_proj.weight
model.layers.0.mlp.up_proj.weight
model.layers.0.post_attention_layernorm.weight
model.layers.0.self_attn.k_proj.bias
model.layers.0.self_attn.k_proj.weight
model.layers.0.self_attn.o_proj.weight
model.layers.0.self_attn.q_proj.bias
model.layers.0.self_attn.q_proj.weight
model.layers.0.self_attn.v_proj.bias
model.layers.0.self_attn.v_proj.weight
model.layers.1.input_layernorm.weight
model.layers.1.mlp.down_proj.weight
model.layers.1.mlp.gate_proj.weight
model.layers.1.mlp.up_proj.weight
model.layers.1.post_attention_layernorm.weight
model.layers.1.self_attn.k_proj.bias
model.layers.1.self_attn.k_proj.weight
model.layers.1.self_attn.o_proj.weight
model.layers.1.self_attn.q_proj.bias
model.layers.1.self_attn.q_proj.weight
model.layers.1.self_attn.v_proj.bias
model.layers.1.self_attn.v_proj.weight
model.layers.10.input_layernorm.weight
model.layers.10.mlp

In [7]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        # The scale starts at one, so a fresh layer only normalises.
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension (``eps`` added under the root)."""
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalise in float32, cast back to the input dtype, then apply the scale.
        normalised = self._norm(x.float()).type_as(x)
        return self.weight * normalised

In [8]:
def rotate_half(x):
    """Split the last dimension in two halves ``[a, b]`` and return ``[-b, a]``."""
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary position embeddings to query and key tensors.

    Args:
        q: Queries shaped ``[batch_size, seq_len, num_heads, head_dim]``.
        k: Keys shaped ``[batch_size, seq_len, num_kv_heads, head_dim]``.
        cos: Cosine table shaped ``[max_seq_len, head_dim]``.
        sin: Sine table shaped ``[max_seq_len, head_dim]``.
        position_ids: Integer positions shaped ``[batch_size, seq_len]``,
            used as row indices into ``cos`` / ``sin``.

    Returns:
        ``(q_embed, k_embed)`` with the same shapes and dtypes as ``q`` and ``k``.
    """
    # Look up the rows for the requested positions, then add a singleton head
    # axis so they broadcast: [batch_size, seq_len, 1, head_dim].
    cos_pos = cos[position_ids].unsqueeze(2)
    sin_pos = sin[position_ids].unsqueeze(2)

    # The tables are typically float32 while q/k may be half precision.
    # Without the cast the product is promoted to float32 and no longer
    # matches the dtype of the value tensor in the later attention matmul.
    q_embed = (q * cos_pos.to(q.dtype)) + (rotate_half(q) * sin_pos.to(q.dtype))
    k_embed = (k * cos_pos.to(k.dtype)) + (rotate_half(k) * sin_pos.to(k.dtype))
    return q_embed, k_embed

In [9]:
class GQAttention(nn.Module):
    """Causal grouped-query self-attention with rotary position embeddings.

    Reads ``hidden_size``, ``num_attention_heads``, ``num_key_value_heads``,
    ``rope_theta`` and ``max_position_embeddings`` from ``config``. Queries use
    ``num_attention_heads`` heads; keys and values use ``num_key_value_heads``
    heads that are each repeated ``n_rep`` times to match the query heads.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.n_rep = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=False)

        self.rope_theta = config.rope_theta
        self.max_position_embeddings = config.max_position_embeddings
        # The rope tables are derived data, built lazily on first use. They are
        # non-persistent so they never show up in state_dict() and cannot
        # clash with checkpoint keys.
        self.register_buffer("cos", None, persistent=False)
        self.register_buffer("sin", None, persistent=False)

    def _init_rope(self, device):
        """Build the ``[max_position_embeddings, head_dim]`` cos/sin tables once."""
        if self.cos is not None:
            return
        inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, self.head_dim, 2, device=device).float() / self.head_dim))
        t = torch.arange(self.max_position_embeddings, device=device, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Duplicate the frequencies so the table spans the full head_dim.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos", emb.cos(), persistent=False)
        self.register_buffer("sin", emb.sin(), persistent=False)

    def _causal_attention(self, q, k, v):
        """Scaled dot-product attention where position i only sees positions <= i.

        Args:
            q, k, v: Tensors shaped ``[batch, heads, seq_len, head_dim]`` with
                the same sequence length.

        Returns:
            Attention output shaped like ``q``.
        """
        seq_len = q.shape[-2]
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)

        # Mask out future positions (column index > row index). Without this a
        # decoder-only model lets every token attend to tokens that follow it.
        idx = torch.arange(seq_len, device=q.device)
        future = idx.unsqueeze(0) > idx.unsqueeze(1)
        scores = scores.masked_fill(future, float("-inf"))

        # Softmax in float32 for stability with half-precision inputs.
        attn_weights = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        return torch.matmul(attn_weights, v)

    def forward(self, x, position_ids):
        """Run self-attention over ``x`` (``[batch, seq_len, hidden_size]``).

        ``position_ids`` (``[batch, seq_len]``) selects the rotary angles; the
        causal mask is based on the order of tokens along the sequence axis.
        """
        batch_size, seq_len, _ = x.shape

        q = self.q_proj(x).view(batch_size, seq_len, self.num_attention_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)

        # Apply RoPE
        self._init_rope(x.device)
        q, k = apply_rotary_pos_emb(q, k, self.cos, self.sin, position_ids)
        # The rope tables are float32, so q/k can come back promoted when the
        # model runs in half precision; keep them in the dtype of v so the
        # matmuls below see matching dtypes. No-op when dtypes already agree.
        q = q.to(v.dtype)
        k = k.to(v.dtype)

        # Repeat KV heads for GQA
        k = torch.repeat_interleave(k, self.n_rep, dim=2)
        v = torch.repeat_interleave(v, self.n_rep, dim=2)

        # [batch, seq_len, heads, head_dim] -> [batch, heads, seq_len, head_dim]
        attn_output = self._causal_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))

        # Back to [batch, seq_len, heads * head_dim] for the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, -1)

        return self.o_proj(attn_output)


In [10]:
class SwiGLU(nn.Module):
    """Gated feed-forward block: ``down_proj(silu(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        # Three bias-free projections: two expand to intermediate_size, one contracts back.
        expand = dict(in_features=hidden_size, out_features=intermediate_size, bias=False)
        self.gate_proj = nn.Linear(**expand)
        self.up_proj = nn.Linear(**expand)
        self.down_proj = nn.Linear(in_features=intermediate_size, out_features=hidden_size, bias=False)

    def forward(self, x):
        # The SiLU-activated gate multiplies the "up" branch elementwise.
        gate = F.silu(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)

In [11]:
class DecoderLayer(nn.Module):
    """One pre-norm transformer block: attention and MLP, each wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        hidden, eps = config.hidden_size, config.rms_norm_eps
        self.input_layernorm = RMSNorm(hidden, eps=eps)
        self.self_attn = GQAttention(config)
        self.post_attention_layernorm = RMSNorm(hidden, eps=eps)
        self.mlp = SwiGLU(hidden, config.intermediate_size)

    def forward(self, x, position_ids):
        # Attention sub-block: normalise, attend, add back onto the input.
        attn_out = self.self_attn(self.input_layernorm(x), position_ids)
        x = x + attn_out

        # Feed-forward sub-block: normalise, run the MLP, add back.
        mlp_out = self.mlp(self.post_attention_layernorm(x))
        return x + mlp_out

In [12]:
class Qwen2Model(nn.Module):
    """Decoder-only language model: token embedding, a stack of ``DecoderLayer``s, final norm and LM head.

    The backbone lives in a ``ModuleDict`` called ``model`` so parameter names
    carry a ``model.`` prefix (``model.embed_tokens.weight``,
    ``model.layers.<i>...``, ``model.norm.weight``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
        norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.model = nn.ModuleDict({"embed_tokens": embed_tokens, "layers": layers, "norm": norm})

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            # Share one parameter between the input embedding and the output projection.
            self.lm_head.weight = embed_tokens.weight

    def forward(self, input_ids, position_ids=None):
        """Return logits of shape ``[batch, seq_len, vocab_size]`` for ``input_ids`` (``[batch, seq_len]``)."""
        if position_ids is None:
            # Default to 0..seq_len-1 for every row of the batch.
            seq_len = input_ids.shape[1]
            positions = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
            position_ids = positions.unsqueeze(0).expand_as(input_ids)

        hidden = self.model["embed_tokens"](input_ids)
        for layer in self.model["layers"]:
            hidden = layer(hidden, position_ids)

        return self.lm_head(self.model["norm"](hidden))

In [13]:
def fix_state_dict(state_dict):
    """Return a copy of ``state_dict`` with a leading ``"model."`` removed from each key.

    Only the prefix is stripped. Keys without the prefix, and later
    occurrences of ``"model."`` inside a key, are left untouched (the previous
    ``str.replace`` removed every occurrence, e.g. turning ``"submodel.w"``
    into ``"subw"``).
    """
    prefix = "model."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

In [14]:
class Qwen2Config:
    """Plain attribute bag holding the model hyper-parameters.

    Every default below can be overridden, and extra attributes added, by
    passing keyword arguments to the constructor.
    """

    def __init__(self, **kwargs):
        # Default values (described in the original notes as taken from config.json).
        defaults = {
            "hidden_size": 2048,
            "num_attention_heads": 16,
            "num_key_value_heads": 2,
            "num_hidden_layers": 36,
            "intermediate_size": 11008,
            "rms_norm_eps": 1e-6,
            "vocab_size": 151936,
            "max_position_embeddings": 32768,
            "rope_theta": 1e6,
            "tie_word_embeddings": True,
            "sliding_window": 32768,
            "use_sliding_window": False,
            "attention_dropout": 0.0,
            "hidden_act": "silu",
        }

        # Caller-supplied values win over the defaults.
        for name, value in {**defaults, **kwargs}.items():
            setattr(self, name, value)

In [None]:
# Instantiate the network with the default hyper-parameters, cast it to
# half precision and move it to the Apple-silicon (MPS) device.
config = Qwen2Config()
model = Qwen2Model(config)
model = model.to(torch.float16).to('mps')
# Full-precision alternative: skip the .to(torch.float16) cast above.

# Copy the checkpoint tensors into the model. strict=False tolerates keys
# that are absent from the checkpoint; the recorded output lists
# lm_head.weight as the only missing key.
model.load_state_dict(state_dict, strict=False)

_IncompatibleKeys(missing_keys=['lm_head.weight'], unexpected_keys=[])

## Inference

In [16]:
# Tokenizer for the "Qwen/Qwen2.5-3B" checkpoint, resolved by name through transformers.
# NOTE(review): trust_remote_code=True permits running code shipped with the
# repository; presumably not needed for this tokenizer — confirm and drop if so.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B", trust_remote_code=True)

In [17]:
# Prompt used for the single forward-pass check in the following cells.
prompt = "Please tell me what the capital of France is:"

In [18]:
# Tokenize the prompt into a PyTorch tensor of token ids and move it to the MPS device.
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to('mps')

In [19]:
# One forward pass over the prompt with gradient tracking disabled (inference only).
with torch.no_grad():
    outputs = model(input_ids)

RuntimeError: MPS backend out of memory (MPS allocated: 18.12 GB, other allocations: 2.78 MB, max allowed: 18.13 GB). Tried to allocate 16.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:
# Inspect the logits shape: [batch, seq_len, vocab_size].
outputs.shape

torch.Size([1, 10, 151936])

## Answer generation

In [1]:
def generate_text(model, tokenizer, prompt, max_new_tokens=50, temperature=0.8):
    """
    Autoregressively extend ``prompt`` and return the decoded text.

    The whole sequence is fed through the model again at every step (there is
    no key/value cache), and only the first batch row is used.

    Args:
        model: Callable as ``model(input_ids, position_ids)`` returning logits
            shaped ``[batch, seq_len, vocab_size]``; must expose ``parameters()``
            and ``eval()``.
        tokenizer: Callable returning an object with ``input_ids``; must
            provide ``decode`` and ``eos_token_id``.
        prompt: Input text prompt.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature (1.0 = unscaled, <1.0 = more
            focused). Values <= 0 select greedy decoding (argmax) instead of
            dividing the logits by zero.

    Returns:
        The decoded prompt plus generated tokens, with special tokens skipped.
    """
    # Tokenize the prompt and place it on the same device as the model weights.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)

    # Evaluation mode (note: this changes the state of the model passed in).
    model.eval()

    greedy = temperature <= 0

    with torch.no_grad():
        for _ in range(max_new_tokens):
            position_ids = torch.arange(0, input_ids.shape[1], device=device).unsqueeze(0)

            outputs = model(input_ids, position_ids)

            # Logits for the token that follows the current sequence. Work in
            # float32 so softmax/sampling is not done in half precision.
            next_token_logits = outputs[0, -1, :].float()

            if greedy:
                next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            else:
                probs = torch.softmax(next_token_logits / temperature, dim=-1)
                next_token_id = torch.multinomial(probs, num_samples=1)

            # [1] -> [1, 1] so it can be appended along the sequence axis.
            next_token_id = next_token_id.unsqueeze(0)
            input_ids = torch.cat([input_ids, next_token_id], dim=1)

            # Stop once the end-of-sequence token has been produced.
            if next_token_id.item() == tokenizer.eos_token_id:
                break

    # Decode the full sequence (prompt included).
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return generated_text

In [None]:
# Sample a continuation of up to 50 new tokens at temperature 0.7 and show it.
prompt = "Please tell me what the capital of France is:"
generated_text = generate_text(model, tokenizer, prompt, max_new_tokens=50, temperature=0.7)
print(generated_text)

Please tell me what the capital of France is.
The capital city of France, which you're asking about is Paris.  }
