In [4]:
# Simple Transformer using PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
#Toy data generation
vocab_size= 100
seq_length=10
batch_size=32

X=torch.randint(0, vocab_size, (batch_size, seq_length)).to(device)
Y= X.clone().to(device)

In [6]:
type(X)

torch.Tensor

In [7]:
X.shape

torch.Size([32, 10])

In [8]:
X[0]

tensor([24, 12, 79, 22, 66, 55, 90, 40, 36, 68])

In [23]:
# Model Definition

class TransformerModel(nn.Module):
  def __init__(self,vocab_size,embed_dim,num_heads, num_layers, ff_dim, max_len=100):
    super().__init__()
    self.embedding= nn.Embedding(vocab_size, embed_dim)
    self.pos_encoding= nn.Parameter(torch.zeros(1,max_len, embed_dim))

    # Encoder + Decoder
    self.transformer= nn.Transformer(
        d_model=embed_dim,
        nhead=num_heads,
        num_encoder_layers=num_layers,
        num_decoder_layers=num_layers,
        dim_feedforward=ff_dim,
        dropout=0.1,
        batch_first=True

    )
    self.fc_out= nn.Linear(embed_dim, vocab_size)

  def forward(self, src, tgt):
    #src and tgt [bacth, seq_len]
    src_emb= self.embedding(src) + self.pos_encoding[:, :src.size(1), :]
    tgt_emb= self.embedding(tgt) + self.pos_encoding[:, : tgt.size(1), :]

    out= self.transformer(src_emb, tgt_emb)
    return self.fc_out(out)


In [24]:
# training loop
model= TransformerModel(
    vocab_size= vocab_size,
    embed_dim=64,
    num_heads=4,
    num_layers=2,
    ff_dim=2048
).to(device)


In [25]:
criterion= nn.CrossEntropyLoss()
optimizer= optim.Adam(model.parameters(), lr=0.001)


In [26]:
for epoch in range(10):
  optimizer.zero_grad()

  output= model(X,Y[:,:-1])
  loss= criterion(output.reshape(-1, vocab_size),Y[:,1:].reshape(-1))
  loss.backward()
  optimizer.step()

  print(f"Epoch {epoch+1}. loss: {loss.item():.4f}")

Epoch 1. loss: 4.8237
Epoch 2. loss: 4.4637
Epoch 3. loss: 4.2098
Epoch 4. loss: 3.9642
Epoch 5. loss: 3.7837
Epoch 6. loss: 3.6050
Epoch 7. loss: 3.4409
Epoch 8. loss: 3.3192
Epoch 9. loss: 3.1858
Epoch 10. loss: 3.0709


In [29]:
# Inference Demo
test_seq= torch.randint(0, vocab_size, (1,seq_length)).to(device)
print("Input IDs:", test_seq)

with torch.no_grad():
    pred= model(test_seq, test_seq[:,:-1])
    predicted_tokens= pred.argmax(dim=-1)
    print("predicted IDs:", predicted_tokens)

Input IDs: tensor([[77, 36, 12, 85, 40, 65, 19, 85, 23, 51]])
predicted IDs: tensor([[ 1, 90, 68, 74, 36, 55, 83, 74, 36]])
