<a href="https://colab.research.google.com/github/archyyu/GPT-from-MLP-to-RNN-to-Transformer/blob/main/GPT_by_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
class Head(nn.Module):
  """Single head of causal (masked) scaled dot-product self-attention.

  Args:
    input_size: Size of the last dimension of the input fed to the
      query/key/value projections.
    sequence_length: Maximum number of positions covered by the causal mask.
    head_size: Output width of the query, key and value projections.
  """

  def __init__(self, input_size, sequence_length, head_size):
    super(Head, self).__init__()
    self.C = input_size
    self.L = sequence_length
    self.head_size = head_size
    self.q = nn.Linear(self.C, head_size, bias=False)
    self.k = nn.Linear(self.C, head_size, bias=False)
    self.v = nn.Linear(self.C, head_size, bias=False)
    # Build the lower-triangular mask once instead of on every forward call.
    # As a buffer it follows the module across .to()/.cuda(); persistent=False
    # keeps it out of the state_dict so existing checkpoints still load.
    self.register_buffer(
        'tril', torch.tril(torch.ones(self.L, self.L)), persistent=False)

  def forward(self, x):
    """Apply causal self-attention to ``x``.

    Args:
      x: Tensor whose last two dimensions are (T, input_size), with
        T <= sequence_length.

    Returns:
      Tensor with the same leading dimensions as ``x`` and a last dimension
      of ``head_size``; position t is a weighted average of the value
      vectors at positions 0..t.

    Raises:
      ValueError: If T exceeds the ``sequence_length`` given at construction.
    """
    T = x.shape[-2]
    if T > self.L:
      raise ValueError(
          f"sequence length {T} exceeds the maximum of {self.L}")

    q = self.q(x)
    k = self.k(x)
    v = self.v(x)

    # Scale by 1/sqrt(head_size) so the logits do not grow with head_size.
    wei = (q @ k.transpose(-2, -1)) * self.head_size ** -0.5
    # Hide future positions; slicing the mask allows inputs shorter than L.
    wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
    # Softmax maps the -inf entries to weight 0 and normalizes each row;
    # without it the -inf values would propagate into the output.
    wei = torch.softmax(wei, dim=-1)

    out = wei @ v
    return out


In [3]:
class MultiHeadAttention(nn.Module):
  """Applies several ``Head`` modules to the same input and merges them.

  Every head receives the same ``x``; their outputs are concatenated along
  the last dimension and passed through a linear layer that maps
  ``num_heads * head_size`` features to ``input_size`` features.
  """

  def __init__(self, num_heads, input_size, sequence_length, head_size):
    super().__init__()
    self.num_heads = num_heads

    # Construct the heads one by one, then register them as submodules.
    head_modules = []
    for _ in range(num_heads):
      head_modules.append(Head(input_size, sequence_length, head_size))
    self.heads = nn.ModuleList(head_modules)

    # Maps the concatenated head outputs to input_size features.
    self.final_linear = nn.Linear(num_heads * head_size, input_size)

  def forward(self, x):
    """Return the linear projection of the concatenated head outputs."""
    per_head = []
    for attention_head in self.heads:
      per_head.append(attention_head(x))
    merged = torch.cat(per_head, dim=-1)
    return self.final_linear(merged)

In [7]:
class BlockAttention(nn.Module):
  """Multi-head attention followed by a residual add and layer norm.

  The input is added to the output of ``MultiHeadAttention`` and the sum is
  normalized with ``nn.LayerNorm(input_size)``.
  """

  def __init__(self, num_heads, input_size, sequence_length, head_size):
    super().__init__()
    self.multiheads = MultiHeadAttention(
        num_heads, input_size, sequence_length, head_size)
    self.norm = nn.LayerNorm(input_size)

  def forward(self, x):
    """Return ``LayerNorm(x + multiheads(x))``."""
    attended = self.multiheads(x)
    residual_sum = x + attended
    return self.norm(residual_sum)

In [9]:
import torch

# Hyperparameters for the smoke test.
batch_size, sequence_length = 4, 8
input_size = 32
num_heads, head_size = 4, 12

# Random input of shape (batch_size, sequence_length, input_size).
x = torch.randn(batch_size, sequence_length, input_size)

# The attention block under test.
head = BlockAttention(num_heads, input_size, sequence_length, head_size)

# Run one forward pass and print the shapes before and after it.
print(x.shape)
output = head(x)
print(output.shape)

torch.Size([4, 8, 32])
torch.Size([4, 8, 32])
