In [None]:
!pip install torch



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import math

# Positional Encoding Layer
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    A table of shape ``(1, max_len, d_model)`` is precomputed once: even
    feature columns hold ``sin(pos * freq)`` and odd columns hold
    ``cos(pos * freq)``, with ``freq = 10000 ** (-2i / d_model)``.

    Args:
        d_model: Size of the feature (last) dimension of the inputs.
        max_len: Largest sequence length the table covers.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim the angle table to fit instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Register as a buffer so the table follows the module through
        # .to()/.cuda()/.half(). persistent=False keeps it out of state_dict,
        # so checkpoints stay compatible with the plain-attribute version.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` with the positional table added along dimension 1.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
               dimension is ``d_model`` (i.e. ``(batch, seq_len, d_model)``).

        Raises:
            ValueError: If the sequence length exceeds ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the positional encoding max_len {max_len}"
            )
        # .to() is a no-op when the module already lives on x's device; it is
        # kept so a standalone, un-moved instance still works with any input.
        return x + self.pe[:, :seq_len].to(x.device)

# Transformer Model
class Transformer(nn.Module):
    """Sequence-to-sequence model: shared token embedding + nn.Transformer + linear head.

    Args:
        input_dim: Source vocabulary size.
        output_dim: Target vocabulary size; also the width of the output logits.
        d_model: Embedding / model feature size.
        nhead: Number of attention heads.
        num_layers: Number of encoder layers and of decoder layers.
        dim_feedforward: Hidden size of the feed-forward sublayers.
        max_len: Longest sequence the positional encoding table covers.
    """

    def __init__(self, input_dim, output_dim, d_model=128, nhead=8, num_layers=3, dim_feedforward=512, max_len=100):
        super().__init__()
        # One table embeds both source and target ids, so it must cover the
        # larger vocabulary; sizing it by input_dim alone makes valid target
        # ids index out of range whenever output_dim > input_dim.
        self.embedding = nn.Embedding(max(input_dim, output_dim), d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )

        self.fc_out = nn.Linear(d_model, output_dim)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Compute output logits for every target position.

        Args:
            src: Integer token ids, ``(batch, src_len)``.
            tgt: Integer token ids, ``(batch, tgt_len)``.
            src_mask, tgt_mask, memory_mask: Optional attention masks passed
                straight through to ``nn.Transformer``.
            src_key_padding_mask, tgt_key_padding_mask, memory_key_padding_mask:
                Optional padding masks passed straight through to
                ``nn.Transformer``.

        Returns:
            Tensor of shape ``(batch, tgt_len, output_dim)``.

        Note:
            No causal mask is applied by default, so each target position can
            attend to later target positions. For autoregressive training pass
            a ``tgt_mask`` such as the one built by
            ``nn.Transformer.generate_square_subsequent_mask``.
        """
        src = self.positional_encoding(self.embedding(src))
        tgt = self.positional_encoding(self.embedding(tgt))

        output = self.transformer(
            src,
            tgt,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.fc_out(output)

# Example usage: build a model and run one random batch through it.
input_dim = 1000   # Vocabulary size
output_dim = 1000  # Output vocabulary size

# Model with default hyper-parameters
model = Transformer(input_dim, output_dim)

# Dummy token ids drawn from each vocabulary, shaped (batch_size, sequence_length)
src = torch.randint(high=input_dim, size=(1, 10))
tgt = torch.randint(high=output_dim, size=(1, 10))

# Forward pass; logits are expected as (batch_size, seq_length, output_dim)
output = model(src, tgt)
print("Output shape:", output.shape)


Output shape: torch.Size([1, 10, 1000])
