In [3]:
import torch
import torch.nn as nn
import math

In [4]:
class InputEmbeddings(nn.Module):
    """Token-embedding lookup whose output is scaled by sqrt(d_model).

    Args:
        vocab_size: number of distinct token ids (valid ids are 0..vocab_size-1).
        d_model: size of each embedding vector.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        # Learnable lookup table of shape (vocab_size, d_model).
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.vocab_size = vocab_size
        self.d_model = d_model

    def forward(self, x):
        """Embed integer token ids ``x`` and multiply by sqrt(d_model).

        The result has the shape of ``x`` with a trailing ``d_model`` axis.
        """
        scale = math.sqrt(self.d_model)
        vectors = self.embedding(x)
        return vectors * scale



# Now we do positional encoding

In [5]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to a batch of embeddings.

    Even feature columns hold sin(pos * w_i) and odd columns hold cos(pos * w_i),
    where w_i = 10000^(-2i / d_model). The table is precomputed for
    ``max_seq_length`` positions and registered as the buffer ``Pe`` with shape
    (1, max_seq_length, d_model), so it is not a trainable parameter.

    Args:
        d_model: feature dimension of the embeddings (odd values are supported).
        max_seq_length: number of positions to precompute.
    """

    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.max_seq_length = max_seq_length
        pe = torch.zeros(max_seq_length, d_model)
        positions = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # One frequency per even column: exp(-2i * ln(10000) / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(positions * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so only the first d_model // 2 frequencies are used for the cosines.
        pe[:, 1::2] = torch.cos(positions * div_term[: d_model // 2])
        # Leading dimension of 1 lets the table broadcast over the batch axis.
        self.register_buffer('Pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Assumes x is (batch, seq_len, d_model) with seq_len <= max_seq_length;
        a longer sequence fails to broadcast against the stored table.
        """
        return x + self.Pe[:, :x.size(1)]


In [6]:
# Demo: embed a random token sequence and add positional encodings.
VOCAB_SIZE = 1000
D_MODEL = 6
MAX_SEQ_LENGTH = 1000
SEQ_LENGTH = 1000  # must not exceed MAX_SEQ_LENGTH

# Seed before building layers so embedding weights and token ids are reproducible.
torch.manual_seed(0)

# Initialize layers
embedding_layer = InputEmbeddings(VOCAB_SIZE, D_MODEL)
positional_encoding = PositionalEncoding(D_MODEL, MAX_SEQ_LENGTH)

# Example input: random token ids, shape (batch_size, seq_length)
input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LENGTH))
print(input_ids.shape)

# Apply embedding -> shape (batch_size, seq_length, d_model) = (1, 1000, 6)
embedded = embedding_layer(input_ids)

# Apply positional encoding (shape is unchanged)
output = positional_encoding(embedded)
assert output.shape == (1, SEQ_LENGTH, D_MODEL)

print(output)

torch.Size([1, 1000])
tensor([[[ 0.3100,  3.6809,  1.5058,  1.7296,  1.6718, -0.3308],
         [ 1.0574,  6.8504, -2.6296,  1.4777,  4.8250, -0.3278],
         [ 4.3780, -0.7868,  3.2376,  4.3001,  2.9273,  1.0790],
         ...,
         [ 1.7624,  0.5264,  2.5049,  3.9940, -2.8293, -0.7895],
         [ 1.3450,  0.6577, -2.2282,  0.8381,  4.8667, -2.7859],
         [ 0.0357,  1.2243, -1.6068,  2.3189,  1.6121, -5.1969]]],
       grad_fn=<AddBackward0>)


# Multi-Head Attention

In [8]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are projected with bias-free linear layers, split
    into ``num_heads`` heads of size ``d_model // num_heads``, attended with
    softmax(Q K^T / sqrt(head_dim)) V per head, concatenated back to
    ``d_model`` features and passed through a final linear projection.

    Args:
        d_model: model dimension; must be divisible by ``num_heads``.
        max_seq_length: stored as ``self.max_seq_length``; not used by the
            attention computation itself.
        num_heads: number of attention heads (default 1, i.e. single-head).

    Raises:
        ValueError: if ``num_heads`` is not a positive divisor of ``d_model``.
    """

    def __init__(self, d_model, max_seq_length, num_heads=1):
        super(MultiHeadAttention, self).__init__()
        # Check positivity first so d_model % 0 can never be evaluated.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.max_seq_length = max_seq_length
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.query_linear = nn.Linear(d_model, d_model, bias=False)
        self.key_linear = nn.Linear(d_model, d_model, bias=False)
        self.value_linear = nn.Linear(d_model, d_model, bias=False)
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq, d_model) -> (batch, num_heads, seq, head_dim)."""
        seq_length = x.size(1)
        # Split the feature dimension into num_heads chunks of head_dim each.
        x = x.reshape(batch_size, seq_length, self.num_heads, self.head_dim)
        # Move the head axis forward so attention runs independently per head.
        return x.permute(0, 2, 1, 3)

    def compute_attention(self, query, key, value, mask=None):
        """Scaled dot-product attention over the last two dimensions.

        ``mask`` (optional) must broadcast to the score tensor
        (batch, num_heads, q_len, k_len); positions where it equals 0 are
        excluded. A query row whose mask is entirely 0 yields NaN after softmax.
        """
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = torch.softmax(scores, dim=-1)
        return torch.matmul(attention_weights, value)

    def combine_heads(self, x, batch_size):
        """Inverse of split_heads: (batch, num_heads, seq, head_dim) -> (batch, seq, d_model)."""
        # permute leaves a non-contiguous tensor; view requires contiguous memory.
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(batch_size, -1, self.d_model)

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` over ``key``/``value`` and return (batch, q_len, d_model).

        Assumes inputs are (batch, seq, d_model) sharing the same batch size.
        """
        batch_size = query.size(0)

        query = self.split_heads(self.query_linear(query), batch_size)
        key = self.split_heads(self.key_linear(key), batch_size)
        value = self.split_heads(self.value_linear(value), batch_size)

        attended = self.compute_attention(query, key, value, mask)
        output = self.combine_heads(attended, batch_size)
        return self.output_linear(output)


# Decoder Model

In [1]:
class Decoder(nn.Module):
    """Transformer decoder — placeholder, not implemented yet.

    TODO: implement. Until a ``forward`` is defined, calling an instance
    raises NotImplementedError (nn.Module's default behaviour).
    """

In [None]:
class Encoder(nn.Module):