In [1]:
from importlib.metadata import version
from safetensors.torch import load_file
import torch,tiktoken,torch.nn as nn
# Record the installed library versions so the outputs below can be tied to an environment.
pkgs=['numpy','torch','safetensors']
for p in pkgs:
    installed_version=version(p)
    print(f'{p} version: {installed_version}')

numpy version: 2.4.1
torch version: 2.9.1
safetensors version: 0.7.0


In [2]:
# Settings shared by every GPT-2 size.
BASE_CONFIG={
    'vocab_size':50257,
    'context_length':1024,
    'drop_rate':0,
    'qkv_bias':True,
}
# Per-size architecture hyperparameters; only the 124M model is wired up here.
model_configs={
    'gpt2-small (124M)':{'emb_dim':768,'n_layers':12,'n_heads':12},
}
CHOOSE_MODEL='gpt2-small (124M)'
# Merge the chosen size's settings after the shared ones (same keys/order as dict.update).
BASE_CONFIG={**BASE_CONFIG,**model_configs[CHOOSE_MODEL]}
# Maps the model choice to the name used in the local checkpoint file.
URL_DIR={'gpt2-small (124M)':'gpt2'}
output_file=f'model-{URL_DIR[CHOOSE_MODEL]}.safetensors'
# Flat name -> tensor mapping read from the safetensors checkpoint in the working directory.
state_dict=load_file(output_file)
def assign(left,right):
    """Return ``right`` as a new ``torch.nn.Parameter`` after checking it matches ``left``'s shape.

    Args:
        left: the existing parameter/tensor whose shape must be matched.
        right: replacement values -- a ``torch.Tensor`` or any array-like accepted by
            ``torch.as_tensor`` (e.g. a NumPy array or nested list). Its dtype is kept.

    Returns:
        A ``torch.nn.Parameter`` wrapping ``right`` detached from any autograd graph.
        For tensor inputs the storage is shared with ``right`` (no copy is made).

    Raises:
        ValueError: if the shapes of ``left`` and ``right`` differ.
    """
    # as_tensor is a no-op for tensors and lets NumPy/list checkpoints be loaded too;
    # without it, non-tensor inputs would fail at ``.detach()``.
    right=torch.as_tensor(right)
    if left.shape!=right.shape:
        raise ValueError(f'Shape mismatch.\nLeft: {left.shape} | Right: {right.shape}')
    return torch.nn.Parameter(right.detach())
def load_weights_into_gpt(gpt,params):
    """Overwrite ``gpt``'s parameters in place with weights from the flat mapping ``params``.

    ``params`` is keyed by the names read below (``wpe.weight``, ``wte.weight``,
    ``h.{b}.attn.c_attn.weight``, ``ln_f.bias``, ...). Every value goes through
    ``assign``, so any shape mismatch raises ``ValueError``. Matrices feeding
    ``nn.Linear`` layers are transposed with ``.T`` first -- presumably the checkpoint
    stores them as (in, out) while ``nn.Linear.weight`` is (out, in). The fused
    ``c_attn`` weight and bias are split into three equal query/key/value chunks along
    the last axis. The output head reuses ``wte.weight``.
    """
    def put(module,attr,value):
        # Replace module.<attr> with a shape-checked Parameter built from value.
        setattr(module,attr,assign(getattr(module,attr),value))

    put(gpt.pos_emb,'weight',params['wpe.weight'])
    put(gpt.tok_emb,'weight',params["wte.weight"])
    for b,block in enumerate(gpt.trf_blocks):
        att=block.att
        ff_in,ff_out=block.ff.layers[0],block.ff.layers[2]

        q_w,k_w,v_w=torch.chunk(params[f'h.{b}.attn.c_attn.weight'],3,axis=-1)
        put(att.W_query,'weight',q_w.T)
        put(att.W_key,'weight',k_w.T)
        put(att.W_value,'weight',v_w.T)
        q_b,k_b,v_b=torch.chunk(params[f'h.{b}.attn.c_attn.bias'],3,axis=-1)
        put(att.W_query,'bias',q_b)
        put(att.W_key,'bias',k_b)
        put(att.W_value,'bias',v_b)
        put(att.out_proj,'weight',params[f'h.{b}.attn.c_proj.weight'].T)
        put(att.out_proj,'bias',params[f'h.{b}.attn.c_proj.bias'])

        put(ff_in,'weight',params[f'h.{b}.mlp.c_fc.weight'].T)
        put(ff_in,'bias',params[f'h.{b}.mlp.c_fc.bias'])
        put(ff_out,'weight',params[f'h.{b}.mlp.c_proj.weight'].T)
        put(ff_out,'bias',params[f'h.{b}.mlp.c_proj.bias'])

        put(block.norm1,'scale',params[f'h.{b}.ln_1.weight'])
        put(block.norm1,'shift',params[f'h.{b}.ln_1.bias'])
        put(block.norm2,'scale',params[f'h.{b}.ln_2.weight'])
        put(block.norm2,'shift',params[f'h.{b}.ln_2.bias'])
    put(gpt.final_norm,'scale',params['ln_f.weight'])
    put(gpt.final_norm,'shift',params['ln_f.bias'])
    # Weight tying: the output projection reuses the token-embedding matrix.
    put(gpt.out_head,'weight',params['wte.weight'])
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with separate query/key/value projections.

    Args:
        d_in: size of the input's last dimension.
        d_out: output width, split evenly into ``num_heads`` heads of ``head_dim``.
        context_length: side of the registered upper-triangular ``mask`` buffer, which
            bounds the sequence length ``forward`` can handle.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of heads; must divide ``d_out`` (checked with ``assert``).
        qkv_bias: whether the query/key/value projections carry a bias.

    ``forward`` maps x of shape (batch, num_tokens, d_in) to (batch, num_tokens, d_out);
    position i only attends to positions <= i.
    """
    def __init__(self,d_in,d_out,context_length,dropout,num_heads,qkv_bias=False):
        super().__init__()
        assert d_out%num_heads==0,'d_out must be divisible by n_heads.'
        self.d_out=d_out
        self.num_heads=num_heads
        self.head_dim=d_out//num_heads
        # Creation order (query, key, value, out_proj) is kept so seeded init is unchanged.
        self.W_query=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_key=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_value=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.out_proj=nn.Linear(d_out,d_out)
        self.dropout=nn.Dropout(dropout)
        # 1 strictly above the diagonal marks "future" positions to hide.
        future=torch.ones(context_length,context_length).triu(diagonal=1)
        self.register_buffer('mask',future)
    def _split_heads(self,t,batch,seq):
        # (batch, seq, d_out) -> (batch, heads, seq, head_dim)
        return t.view(batch,seq,self.num_heads,self.head_dim).transpose(1,2)
    def forward(self,x):
        batch,seq,_=x.shape
        q=self._split_heads(self.W_query(x),batch,seq)
        k=self._split_heads(self.W_key(x),batch,seq)
        v=self._split_heads(self.W_value(x),batch,seq)
        scores=q@k.transpose(2,3)
        scores.masked_fill_(self.mask.bool()[:seq,:seq],-torch.inf)
        weights=self.dropout(torch.softmax(scores/self.head_dim**.5,dim=-1))
        # Back to (batch, seq, heads, head_dim), then fuse heads into d_out.
        merged=(weights@v).transpose(1,2).reshape(batch,seq,self.d_out)
        return self.out_proj(merged)
class GELU(nn.Module):
    """Tanh approximation of the Gaussian Error Linear Unit.

    Computes 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))) element-wise.
    """
    # sqrt(2/pi) as a plain Python float. The previous per-call torch.tensor(...) built
    # a fresh float32 CPU tensor and ran a sqrt kernel on every forward call. It also
    # limited the constant to float32 precision even for float64 inputs.
    _SQRT_2_OVER_PI=(2/torch.pi)**.5
    def __init__(self):
        super().__init__()
    def forward(self,x):
        return .5*x*(1+torch.tanh(self._SQRT_2_OVER_PI*(x+.044715*torch.pow(x,3))))
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(emb_dim -> 4*emb_dim), GELU, Linear(4*emb_dim -> emb_dim).

    Args:
        cfg: dict read for ``emb_dim``.
    """
    def __init__(self,cfg):
        super().__init__()
        width=cfg['emb_dim']
        hidden=4*width
        # Built in the same order as before so seeded init is unchanged. Index layout
        # (0 = expand, 1 = GELU, 2 = project) is what load_weights_into_gpt addresses.
        expand=nn.Linear(width,hidden)
        act=GELU()
        project=nn.Linear(hidden,width)
        self.layers=nn.Sequential(expand,act,project)
    def forward(self,x):
        out=self.layers(x)
        return out
class TransformerBlock(nn.Module):
    """One pre-LayerNorm GPT decoder block.

    ``forward`` computes, in order:
        x = drop_resid(att(norm1(x))) + x
        x = drop_resid(ff(norm2(x))) + x

    Args:
        cfg: dict read for ``emb_dim``, ``context_length``, ``n_heads``, ``drop_rate``
            and ``qkv_bias``.
    """
    def __init__(self,cfg):
        super().__init__()
        width=cfg['emb_dim']
        rate=cfg['drop_rate']
        # Submodule names (att, ff, norm1, norm2) are what load_weights_into_gpt targets.
        self.att=MultiHeadAttention(d_in=width,
                                    d_out=width,
                                    context_length=cfg['context_length'],
                                    num_heads=cfg['n_heads'],
                                    dropout=rate,
                                    qkv_bias=cfg['qkv_bias'])
        self.ff=FeedForward(cfg)
        self.norm1=LayerNorm(width)
        self.norm2=LayerNorm(width)
        self.drop_resid=nn.Dropout(rate)
    def _residual(self,x,norm,sublayer):
        # Branch first, skip connection second -- same operand order as the original sum.
        return self.drop_resid(sublayer(norm(x)))+x
    def forward(self,x):
        x=self._residual(x,self.norm1,self.att)
        return self._residual(x,self.norm2,self.ff)
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable ``scale`` and ``shift``.

    Args:
        emb_dim: size of the normalized (last) dimension.
        eps: constant added to the variance for numerical stability (default 1e-5,
            the value previously hard-coded).

    Parameter names ``scale``/``shift`` are kept so existing weight loading still works.
    """
    def __init__(self,emb_dim,eps=1e-5):
        super().__init__()
        self.eps=eps
        self.scale=nn.Parameter(torch.ones(emb_dim))
        self.shift=nn.Parameter(torch.zeros(emb_dim))
    def forward(self,x):
        mean=x.mean(dim=-1,keepdim=True)
        # correction=0 gives the population (biased) variance layer norm needs; it is
        # the documented replacement for the legacy ``unbiased=False`` keyword.
        var=x.var(dim=-1,keepdim=True,correction=0)
        norm_x=(x-mean)/torch.sqrt(var+self.eps)
        return self.scale*norm_x+self.shift
class GPTModel(nn.Module):
    """GPT-2-style decoder-only language model.

    Token embeddings and learned positional embeddings are summed, passed through
    embedding dropout, ``cfg['n_layers']`` TransformerBlocks and a final LayerNorm,
    then projected to vocabulary logits by a bias-free linear head.

    Args:
        cfg: dict read for ``vocab_size``, ``context_length``, ``emb_dim``,
            ``drop_rate`` and ``n_layers`` (TransformerBlock reads further keys).
    """
    def __init__(self,cfg):
        super().__init__()
        self.tok_emb=nn.Embedding(cfg['vocab_size'],
                                  cfg['emb_dim'])
        self.pos_emb=nn.Embedding(cfg['context_length'],
                                  cfg['emb_dim'])
        self.drop_emb=nn.Dropout(cfg['drop_rate'])
        self.trf_blocks=nn.Sequential(*[TransformerBlock(cfg) for _ in range(cfg['n_layers'])])
        self.final_norm=LayerNorm(cfg['emb_dim'])
        self.out_head=nn.Linear(cfg['emb_dim'],cfg['vocab_size'],bias=False)
    def forward(self,in_idx):
        """Map (batch, seq_len) token ids to (batch, seq_len, vocab_size) logits.

        Raises:
            ValueError: if seq_len exceeds the positional-embedding table size
                (``cfg['context_length']``).
        """
        # Unpacking still enforces a 2-D input; the batch size itself is not needed.
        _,seq_len=in_idx.shape
        max_len=self.pos_emb.num_embeddings
        # Checked up front: an out-of-range position index would otherwise fail inside
        # nn.Embedding, which on accelerator backends can surface as an opaque error.
        if seq_len>max_len:
            raise ValueError(f'Sequence length {seq_len} exceeds context length {max_len}.')
        tok_embeds=self.tok_emb(in_idx)
        pos_embeds=self.pos_emb(torch.arange(seq_len,device=in_idx.device))
        x=tok_embeds+pos_embeds
        x=self.drop_emb(x)
        x=self.trf_blocks(x)
        x=self.final_norm(x)
        logits=self.out_head(x)
        return logits
gpt=GPTModel(BASE_CONFIG)
# Prefer Apple-silicon MPS, then CUDA, then CPU. A hard-coded 'mps' device makes
# gpt.to(device) fail on any machine without MPS, breaking Restart & Run All.
if torch.backends.mps.is_available():
    device=torch.device('mps')
elif torch.cuda.is_available():
    device=torch.device('cuda')
else:
    device=torch.device('cpu')
load_weights_into_gpt(gpt,state_dict)
gpt.to(device)
# Inference only: eval() disables dropout (already 0 via BASE_CONFIG, but made explicit).
gpt.eval()
tokenizer=tiktoken.get_encoding('gpt2')
def generate(model,idx,max_new_tokens,context_size,temperature=0,top_k=None,eos_id=None):
    """Autoregressively extend ``idx`` by up to ``max_new_tokens`` tokens.

    Args:
        model: callable mapping a (batch, seq) tensor of token ids to logits of shape
            (batch, seq, vocab).
        idx: (batch, seq) tensor of prompt token ids on the model's device.
        max_new_tokens: maximum number of tokens to append.
        context_size: only the last ``context_size`` tokens are fed to the model.
        temperature: 0 (default) picks greedily with argmax; > 0 divides the logits by
            it and samples from the resulting softmax distribution.
        top_k: if given, only the ``top_k`` highest logits per row stay eligible
            (values larger than the vocabulary are clamped to it).
        eos_id: if given, stop once every row emits this id; that step's tokens are
            not appended.

    Returns:
        ``idx`` with the generated tokens concatenated along dim 1.
    """
    for _ in range(max_new_tokens):
        idx_cond=idx[:,-context_size:]
        with torch.no_grad():
            logits=model(idx_cond)
        logits=logits[:,-1,:]
        if top_k is not None:
            k=min(top_k,logits.size(-1))
            top_logits,_=torch.topk(logits,k)
            # Keep the per-row threshold as a (batch, 1) column so it broadcasts across
            # the vocab axis; a (batch,) vector only broadcasts correctly when batch == 1.
            min_val=top_logits[:,-1:]
            logits=logits.masked_fill(logits<min_val,float('-inf'))
        if temperature>0:
            logits=logits/temperature
            logits=logits-logits.max(dim=-1,
                                     keepdim=True).values
            probs=torch.softmax(logits,dim=-1)
            idx_next=torch.multinomial(probs,num_samples=1)
        else:
            idx_next=torch.argmax(logits,dim=-1,keepdim=True)
        # `if tensor == eos_id` is ambiguous for batch > 1; reduce explicitly.
        if eos_id is not None and bool((idx_next==eos_id).all()):
            break
        idx=torch.cat((idx,idx_next),dim=1)
    return idx
def text_to_token_ids(text,tokenizer):
    """Encode ``text`` into a (1, num_tokens) int64 tensor of token ids.

    ``<|endoftext|>`` is permitted as a special token in ``text``. An empty string
    yields an empty (1, 0) int64 tensor.
    """
    encoded=tokenizer.encode(text,allowed_special={'<|endoftext|>'})
    # Explicit dtype: torch.tensor([]) would default to float32, which nn.Embedding rejects.
    encoded_tensor=torch.tensor(encoded,dtype=torch.long).unsqueeze(0)
    return encoded_tensor
def token_ids_to_text(token_ids,tokenizer):
    """Decode a (1, num_tokens) or (num_tokens,) tensor of token ids back to a string."""
    # Drop the leading batch axis only when there is one: squeezing a 1-D tensor of
    # length 1 would give a 0-d tensor whose tolist() is a bare int, not a list.
    flat=token_ids.squeeze(0) if token_ids.dim()>1 else token_ids
    return tokenizer.decode(flat.tolist())
# Prompt ids must live on the same device as the model (already moved by gpt.to(device)).
prompt_ids=text_to_token_ids('Every effort moves',tokenizer).to(device)
# top_k=1 leaves a single candidate, so sampling at temperature=1 is effectively greedy.
token_ids=generate(model=gpt,
                   idx=prompt_ids,
                   max_new_tokens=30,
                   context_size=BASE_CONFIG['context_length'],
                   top_k=1,
                   temperature=1)
print(f'Output text: {token_ids_to_text(token_ids,tokenizer)}')

Output text: Every effort moves forward, but it's not enough.

"I'm not going to sit here and say, 'I'm not going to do this,'
