<a href="https://colab.research.google.com/github/aslanbakirov/AlgorithmVisualizer/blob/gh-pages/Aslan_x_Cohere_PyTorch_Transformer_Challenge.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import copy
import torch
import math
import torch.nn as nn
import torch.optim as optim
from torch.nn.parameter import Parameter

# Reference paper: Attention is All You Need! https://arxiv.org/abs/1706.03762


**Helper functions, don't worry about these.**

In [None]:
class GPT2Config:
    """Hyper-parameters for the small GPT-2 style model defined in this notebook.

    Args:
        vocab_size_or_config_json_file: vocabulary size. Despite the name,
            the value is stored as-is in ``vocab_size``; no JSON file is read.
        n_positions: number of learned position embeddings.
        n_ctx: sequence length the attention causal mask is built for.
        n_embd: hidden (embedding) width.
        n_layer: number of transformer blocks.
        n_head: number of attention heads per block.
        layer_norm_epsilon: epsilon passed to every LayerNorm.
        initializer_range: stored for reference; not read by the code shown here.
    """

    def __init__(self, vocab_size_or_config_json_file=50257, n_positions=1024,
                 n_ctx=1024, n_embd=768, n_layer=6, n_head=4,
                 layer_norm_epsilon=1e-5, initializer_range=0.02):
        # Vocabulary / sequence geometry.
        self.vocab_size = vocab_size_or_config_json_file
        self.n_ctx = n_ctx
        self.n_positions = n_positions
        # Network width and depth.
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        # Numerical / initialization settings.
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range

def gelu(x):
    """Element-wise GELU using the tanh approximation.

    Computes 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))).
    """
    inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1 + torch.tanh(inner))

class LayerNorm(nn.Module):
    """Normalize over the last dimension, then apply a learnable scale and shift.

    Uses the biased (population) variance, like the usual LayerNorm.

    Args:
        hidden_size: size of the last dimension of the input.
        eps: constant added to the variance before the square root.
    """

    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias

class Conv1D(nn.Module):
    """Affine map over the last dimension: ``x @ weight + bias``.

    The weight is stored with shape ``(nx, nf)`` (input features first), so
    the layer maps ``[..., nx]`` to ``[..., nf]``.

    Args:
        nf: number of output features.
        nx: number of input features.
    """

    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = Parameter(w)
        self.bias = Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        # reshape (not view): flattening leading dims of a non-contiguous input
        # (e.g. a transposed tensor) cannot be expressed as a view and would raise.
        x = torch.addmm(self.bias, x.reshape(-1, x.size(-1)), self.weight)
        # addmm returns a fresh contiguous tensor, so view is safe here.
        return x.view(*size_out)

class MLP(nn.Module):
    """Position-wise feed-forward network: expand, GELU, project back.

    Args:
        n_state: hidden width of the expansion (callers here pass 4 * n_embd).
        config: object providing ``n_embd``, the input/output width.
    """

    def __init__(self, n_state, config):
        super().__init__()
        width = config.n_embd
        self.c_fc = Conv1D(n_state, width)    # width -> n_state
        self.c_proj = Conv1D(width, n_state)  # n_state -> width
        self.act = gelu

    def forward(self, x):
        return self.c_proj(self.act(self.c_fc(x)))

**Your implementation goes here:**

In [None]:

import torch.nn.functional as F
class MultiheadAttention(nn.Module):
    """Causal (masked) multi-head self-attention.

    Queries, keys and values are separate ``Conv1D`` projections of the input.
    Each is split into ``config.n_head`` heads, and scaled dot-product
    attention is computed per head under a lower-triangular mask, so a
    position attends only to itself and earlier positions. The heads are then
    concatenated and passed through the output projection ``c_proj``.

    Args:
        embed_size: model width; must be divisible by ``config.n_head``.
        n_ctx: longest sequence length the causal mask is built for.
        config: object providing ``n_head``.
    """

    def __init__(self, embed_size, n_ctx, config):
        super(MultiheadAttention, self).__init__()
        assert embed_size % config.n_head == 0, "embed_size must be divisible by n_head"
        # Lower-triangular causal mask, shaped (1, 1, n_ctx, n_ctx) to broadcast
        # over the batch and head dimensions.
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = embed_size
        self.linear_q = Conv1D(embed_size, embed_size)
        self.linear_k = Conv1D(embed_size, embed_size)
        self.linear_v = Conv1D(embed_size, embed_size)
        self.c_proj = Conv1D(embed_size, embed_size)

    def _split_heads(self, t):
        """Reshape [batch, seq, embed] -> [batch, n_head, seq, embed // n_head]."""
        batch_size, seq_length, embed_dim = t.size()
        t = t.view(batch_size, seq_length, self.n_head, embed_dim // self.n_head)
        return t.transpose(1, 2)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape [batch, seq_length, embed_size], with
                seq_length <= n_ctx.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if seq_length exceeds the mask size n_ctx.
        """
        batch_size, seq_length, embed_dim = x.size()
        max_len = self.bias.size(-1)
        if seq_length > max_len:
            raise ValueError(f"sequence length {seq_length} exceeds n_ctx={max_len}")
        head_dim = embed_dim // self.n_head

        q = self._split_heads(self.linear_q(x))  # [batch, n_head, seq, head_dim]
        k = self._split_heads(self.linear_k(x))
        v = self._split_heads(self.linear_v(x))

        # [batch, n_head, seq, seq]; true division by sqrt(d_k) keeps scores fractional.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
        # Block attention to future positions before normalizing.
        mask = self.bias[:, :, :seq_length, :seq_length]
        scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)  # normalize over key positions

        context = torch.matmul(weights, v)  # [batch, n_head, seq, head_dim]
        # Merge heads back: [batch, seq, n_head * head_dim] == [batch, seq, embed].
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_length, embed_dim)
        return self.c_proj(context)
  
class Block(nn.Module):
    """Pre-LayerNorm transformer block with residual connections.

    Computes ``x = x + attn(ln_1(x))`` followed by ``x = x + mlp(ln_2(x))``.

    Args:
        n_ctx: longest sequence length, forwarded to the attention mask.
        config: object providing ``n_embd``, ``n_head`` and ``layer_norm_epsilon``.
    """

    def __init__(self, n_ctx, config):
        super(Block, self).__init__()
        n_embed = config.n_embd
        self.ln_1 = LayerNorm(n_embed, eps=config.layer_norm_epsilon)
        self.attn = MultiheadAttention(n_embed, n_ctx, config)
        self.ln_2 = LayerNorm(n_embed, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * n_embed, config)

    def forward(self, x):
        """Map x of shape [batch, seq_length, embed_dim] to the same shape."""
        # Residual paths keep the un-normalized signal flowing between blocks;
        # each sub-layer sees a normalized copy of its input.
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


**GPT2 Model**

In [None]:
class GPT2Model(nn.Module):
    """GPT-2 style decoder-only language model.

    Token embeddings (``wte``) plus learned position embeddings (``wpe``) are
    passed through ``config.n_layer`` copies of ``Block``, then a final
    LayerNorm, then a linear ``decoder`` (not weight-tied to ``wte``) that
    produces vocabulary logits.
    """

    def __init__(self, config):
        super(GPT2Model, self).__init__()
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd
        self.n_vocab = config.vocab_size

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        block = Block(config.n_ctx, config)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        embed_shape = self.wte.weight.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)

        # Label value -1 marks positions excluded from the loss.
        self.loss_func = nn.CrossEntropyLoss(ignore_index=-1)

    def forward(self, input_ids, lm_labels=None):
        """Compute the LM loss, or the logits when no labels are given.

        Args:
            input_ids: integer token ids whose last dimension is the sequence,
                e.g. [batch, seq_length]; a 1-D [seq_length] tensor also works.
                seq_length must not exceed n_positions / n_ctx.
            lm_labels: target ids with the same number of elements as
                ``input_ids``, already aligned with the predictions by the
                caller (no shifting is done here); -1 entries are ignored.
                If None, the logits are returned instead of a loss.

        Returns:
            A scalar cross-entropy loss if ``lm_labels`` is given, otherwise
            logits of shape ``input_ids.shape + (vocab_size,)``.
        """
        input_shape = input_ids.size()
        seq_length = input_shape[-1]

        # Positions 0..seq_length-1, broadcast over all leading dimensions
        # (expand adds them as needed, so 1-D input works too).
        position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.expand(input_shape)

        # Collapse any leading dims into one batch dim; reshape tolerates
        # non-contiguous inputs where view would raise.
        input_ids = input_ids.reshape(-1, seq_length)
        position_ids = position_ids.reshape(-1, seq_length)

        hidden_states = self.wte(input_ids) + self.wpe(position_ids)
        for block in self.h:
            hidden_states = block(hidden_states)
        hidden_states = self.ln_f(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        hidden_states = hidden_states.view(*output_shape)
        lm_logits = self.decoder(hidden_states)

        if lm_labels is None:
            return lm_logits
        return self.loss_func(lm_logits.reshape(-1, lm_logits.size(-1)), lm_labels.reshape(-1))


**Time to test it!**

**Here is a naive training loop; the loss should go down.**

In [None]:

config = GPT2Config()
model = GPT2Model(config)

# Toy task: the sequence 1..n_ctx as a single batch, with each position's
# target being its own id + 1. torch.arange has an exclusive end, so
# n_ctx + 1 yields the same 1..n_ctx values as the deprecated torch.range.
dum_input = torch.arange(1, config.n_ctx + 1, dtype=torch.int64)
dum_input = torch.reshape(dum_input, (1, config.n_ctx))

# The data never changes, so build inputs and labels once, outside the loop.
inputs = dum_input
labels = inputs + 1  # max id n_ctx + 1, well inside vocab_size

# train loop
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
for i in range(50):
    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    loss = model(input_ids=inputs, lm_labels=labels)
    loss.backward()
    optimizer.step()

    print("step ", i, " loss: ", loss.item())



  after removing the cwd from sys.path.
  del sys.path[0]


torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
step  0  loss:  10.988716125488281
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
step  1  loss:  10.98796558380127
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
step  2  loss:  10.986539840698242
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
step  3  loss:  10.984503746032715
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
torch.Size([1, 1024, 768])
step  4  loss:  10.98192024230957
torch.Size([1, 10