<a href="https://colab.research.google.com/github/Tanish-Sarkar/Elite-Transformers/blob/main/Module1%20-%20Transformer%20Fundamentals/Simple_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Implementing a Simple Transformer**

In [1]:
import torch
import torch.nn as nn
import math

In [2]:
class SimpleTransformer(nn.Module):
  """Minimal single-layer, post-norm Transformer encoder block.

  Pipeline: token embedding -> toy positional offset -> multi-head
  self-attention -> residual + LayerNorm -> position-wise feed-forward
  (d_model -> 4*d_model -> d_model with ReLU) -> residual + LayerNorm.

  Args:
    vocab_size: Number of rows in the embedding table; token ids passed to
      ``forward`` must lie in ``[0, vocab_size)``.
    d_model: Embedding / hidden width. ``nn.MultiheadAttention`` requires it
      to be divisible by ``n_heads``.
    n_heads: Number of attention heads.
  """

  def __init__(self, vocab_size=1000, d_model=512, n_heads=8):
    super().__init__()
    self.embed = nn.Embedding(vocab_size, d_model)
    # forward() works on (batch, seq, d_model) tensors. MultiheadAttention
    # defaults to (seq, batch, d_model); without batch_first=True it would
    # attend across the batch axis and mix unrelated samples together.
    self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
    self.ff = nn.Sequential(
        nn.Linear(d_model, 4*d_model),
        nn.ReLU(),
        nn.Linear(4*d_model, d_model)
    )
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)

  def forward(self, x):
    """Encode a batch of token ids.

    Args:
      x: Integer tensor of token ids with shape ``(batch, seq_len)``.

    Returns:
      Tensor of shape ``(batch, seq_len, d_model)``.
    """
    x = self.embed(x)

    # Toy positional signal (not a real sinusoidal/learned encoding): add
    # 0.01 * position to every feature of the token at that position.
    # Shape (1, seq_len, 1) broadcasts over the batch and feature axes, and
    # building it in x.dtype keeps the sum in the embedding's dtype.
    seq_len = x.size(1)
    pos = torch.arange(seq_len, device=x.device, dtype=x.dtype).view(1, seq_len, 1)
    x = x + pos * 0.01

    # Self-attention sub-layer with residual connection (post-norm).
    attn_out, _ = self.attn(x, x, x)
    x = self.norm1(x + attn_out)

    # Feed-forward sub-layer with residual connection (post-norm).
    ff_out = self.ff(x)
    x = self.norm2(x + ff_out)
    return x

**Test: run a dummy batch through the model and check the output shape**

In [3]:
# Smoke test: push a random batch of token ids through the model and confirm
# the output has one d_model-sized vector per input token.
torch.manual_seed(0)  # reproducible weights and inputs on Restart & Run All

BATCH_SIZE, SEQ_LEN = 4, 20

model = SimpleTransformer()
# Draw ids from the model's own vocabulary size rather than a hard-coded 1000,
# so this cell stays valid if vocab_size is changed.
dummy_input = torch.randint(0, model.embed.num_embeddings, (BATCH_SIZE, SEQ_LEN))  # batch 4, seq 20
output = model(dummy_input)

assert output.shape == (BATCH_SIZE, SEQ_LEN, model.embed.embedding_dim), output.shape

print(f"Input size: {dummy_input.size()}")
print(f"Output size: {output.size()}")

Input size: torch.Size([4, 20])
Output size: torch.Size([4, 20, 512])
