In [29]:
import torch
import torch.nn as nn

In [30]:
GPT_CONFIG_124M = {
    'vocab_size':50257,
    'context_length':1024,
    'emb_dim' : 768,
    'n_head':12,
    'n_layers':12,
    'drop_rate':0.1,
    'qkv_bias':False
}

In [31]:
class Layer_Norm(nn.Module):
  def __init__(self,emb_dim):
    super().__init__()
    self.scale = nn.Parameter(torch.ones(emb_dim))
    self.shift = nn.Parameter(torch.ones(emb_dim))
    self.eps = 1e-5

  def forward(self,x):
    mean = x.mean(dim=-1,keepdim=True)
    var = x.var(dim=-1,keepdim=True, unbiased=True)
    return self.scale*(x-mean)/torch.sqrt(var+self.eps) + self.shift


class GELU(nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self,x):
    return 0.5*x*(1+torch.tanh(torch.tensor((2/torch.pi))**0.5)*(x+0.044715*x**3))


class FeedForward(nn.Module):
  def __init__(self,cfg):
        super().__init__()
        self.layers =  nn.Sequential(
        nn.Linear(cfg['emb_dim'],4*cfg['emb_dim']),    #expansion
        GELU(),   #non linear activation
        nn.Linear(4*cfg['emb_dim'],cfg['emb_dim'])  #compression
    )
  def forward(self,x):
    return self.layers(x)

In [32]:
class MultiheadAttention(nn.Module):
  def __init__(self,din,dout,context_length,dropout,num_heads,qkv_bias=False):
    super().__init__()
    self.w_queries = nn.Linear(din,dout,qkv_bias)
    self.w_keys = nn.Linear(din,dout,qkv_bias)
    self.w_values = nn.Linear(din,dout,qkv_bias)
    self.dropout=nn.Dropout(dropout)
    self.context_length=context_length
    self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length)))
    self.num_heads=num_heads
    self.out_proj = nn.Linear(dout,dout)

  def forward(self,x):
    b,contextlength, emd_size = x.shape
    num_heads=self.num_heads
    head_dim = emd_size//num_heads

    queries = self.w_queries(x)
    keys = self.w_keys(x)
    values = self.w_values(x)

    queries = queries.view(b,contextlength,num_heads,head_dim)
    keys = queries.view(b,contextlength,num_heads,head_dim)
    values = queries.view(b,contextlength,num_heads,head_dim)

    queries = queries.transpose(1,2)
    keys = keys.transpose(1,2)
    values = values.transpose(1,2)

    attention_scores = queries @ keys.transpose(2,3)
    attention_scores.masked_fill(self.mask.bool()[:contextlength, :contextlength],-torch.inf)
    attention_weights = torch.softmax(attention_scores/num_heads**0.5,dim=-1)
    attention_weights = self.dropout(attention_weights)
    context_vectors = (attention_weights @ values).transpose(1,2)
    context_vectors = context_vectors.contiguous().view(b,contextlength,emd_size)
    context_vectors = self.out_proj(context_vectors)
    return context_vectors


In [33]:
class Transformer(nn.Module):

  def __init__(self,cfg):
      super().__init__()
      self.norm1 = Layer_Norm(cfg['emb_dim'])
      self.norm2 = Layer_Norm(cfg['emb_dim'])
      self.att = MultiheadAttention(din=cfg['emb_dim'],
                                    dout=cfg['emb_dim'],
                                    context_length=cfg['context_length'],
                                    dropout=cfg['drop_rate'],
                                    num_heads=cfg['n_head'])
      self.ff = FeedForward(cfg)
      self.drop_shortcut = nn.Dropout(cfg['drop_rate'])

  def forward(self,x):

      shortcut = x
      x = self.norm1(x)
      x = self.att(x)
      x = self.drop_shortcut(x)
      x = x+shortcut
      shortcut = x
      x = self.norm2(x)
      x = self.ff(x)
      x = self.drop_shortcut(x)
      x = x+shortcut

      return x


In [113]:

class GPTModel(nn.Module):
  def __init__(self,cfg):
      super().__init__()
      self.tok_embedding = nn.Embedding(cfg['vocab_size'],cfg['emb_dim'])
      self.pos_embedding = nn.Embedding(cfg['context_length'],cfg['emb_dim'])
      self.drop_emb = nn.Dropout(cfg['drop_rate'])
      self.trf_blocks = nn.Sequential(
          *[Transformer(cfg) for _ in range(cfg['n_layers'])]
      )
      self.final_norm = Layer_Norm(cfg['emb_dim'])
      self.out_head = nn.Linear(cfg['emb_dim'],cfg['vocab_size'],bias=False)


  def forward(self,x):
      batch_size, seq_len = x.shape
      tok_embeds = self.tok_embedding(x)
      pos_embeds = self.pos_embedding(torch.arange(seq_len,device=x.device))
      x = tok_embeds + pos_embeds
      x = self.drop_emb(x)
      x = self.trf_blocks(x)
      x = self.final_norm(x)
      logits = self.out_head(x)
      return logits

In [114]:
torch.manual_seed(123)
x = torch.tensor([[6109,3626,6100,345],[6109,1110,6622,257]])
model = GPTModel(GPT_CONFIG_124M)
op = model.forward(x)
print(op.shape)

torch.Size([2, 4, 50257])


In [115]:
print(op)

tensor([[[-0.6212, -1.0549, -0.7801,  ..., -0.1406, -1.4150,  0.3253],
         [ 0.5819, -1.4818,  0.0543,  ..., -0.7155, -1.3151,  0.1606],
         [-0.4098, -0.8797, -0.1308,  ..., -0.6430, -1.2205,  0.2275],
         [-0.4459, -0.5298, -0.5528,  ...,  0.7475, -0.2636, -0.2226]],

        [[-0.6213, -0.7357, -0.4994,  ..., -0.4088, -1.1084,  0.4430],
         [ 0.7592, -1.3652,  0.7986,  ...,  0.3807, -1.1525,  0.4100],
         [-0.1499, -1.1212, -0.5046,  ..., -0.5478, -0.8977,  0.4323],
         [-0.5898, -1.3921,  0.1194,  ...,  0.6098, -1.3501,  0.8240]]],
       grad_fn=<UnsafeViewBackward0>)


In [116]:
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,}")

163,009,536


In [117]:
model.out_head.weight.shape

torch.Size([50257, 768])

In [118]:
model.tok_embedding.weight.shape

torch.Size([50257, 768])

Reuse the tok_emb parameters for the out_head parameters

In [119]:
final_total_params = total_params - sum(p.numel() for p in model.out_head.parameters())
print(f"{final_total_params:,}")

124,412,160


In [120]:
total_size_bytes = total_params *4
total_size_mb = total_size_bytes/(1024*1024)
print(f"{total_size_mb:,} MB")

621.83203125 MB


In [121]:
def generate_text_simple(model,idx,max_new_tokens,context_length):
  for _ in range(max_new_tokens):
    idx_cond = idx[:,-context_length:]
    with torch.no_grad():
      logits = model(idx_cond)
    logits = logits[:,-1,:]
    probas = torch.softmax(logits,dim=-1)
    idx_next = torch.argmax(probas,dim=-1,keepdim=True)
    idx= torch.cat((idx,idx_next),dim=1)

  return idx


In [122]:
import tiktoken

In [123]:
tokenizer = tiktoken.get_encoding('gpt2')

In [138]:
start_context ="Hello, I am"
encoded = tokenizer.encode(start_context)
print(encoded)
encoded = torch.tensor(encoded).unsqueeze(0)
# model.eval()
op=generate_text_simple(model,encoded,6,GPT_CONFIG_124M['context_length'])

[15496, 11, 314, 716]


In [139]:
tokenizer.decode(op.squeeze(0).tolist())

'Hello, I am39 proliferation proliferation proliferation nation assuming'