In [13]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader



In [14]:
# Dummy dataset: generates random sequences of token IDs
class DummySequenceDataset(Dataset):
    """Synthetic map-style dataset that produces random token-ID sequences.

    Each lookup draws a new sequence with ``torch.randint``, so the index only
    determines whether the access is in range, not which values come back.

    Args:
        vocab_size: Exclusive upper bound for generated token IDs; every ID
            lies in ``[0, vocab_size)``.
        seq_len: Number of token IDs in each generated sequence.
        num_samples: Number of items reported by ``len()``.
    """

    def __init__(self, vocab_size=100, seq_len=10, num_samples=5):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_samples = num_samples

    def __len__(self):
        """Return the number of samples the dataset exposes."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return a random 1-D tensor holding ``seq_len`` token IDs.

        Args:
            idx: Sample index; negative values count from the end, as with a
                list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Direct iteration (``for x in dataset`` / ``list(dataset)``) falls
        # back to calling __getitem__ with 0, 1, 2, ... until IndexError is
        # raised. Without this bound check such a loop would never terminate.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.num_samples}"
            )
        return torch.randint(0, self.vocab_size, (self.seq_len,))



In [15]:
# Learnable positional encoding module
class LearnablePositionalEncoding(nn.Module):
    """Adds a trainable per-position vector to a batch of embeddings.

    The position table is a single ``nn.Parameter`` of shape
    ``(1, max_len, d_model)`` initialised from a standard normal distribution.

    Args:
        max_len: Number of positions the table stores; inputs may not be
            longer than this along dimension 1.
        d_model: Size of the last dimension of the table.
    """

    def __init__(self, max_len, d_model):
        super().__init__()
        # The leading singleton dimension lets the table broadcast over the
        # batch dimension of the input.
        self.position_embedding = nn.Parameter(torch.randn(1, max_len, d_model))

    def forward(self, x):
        """Return ``x`` plus the position vectors for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, d_model]``.

        Raises:
            RuntimeError: If ``seq_len`` exceeds the ``max_len`` the module was
                built with.
        """
        # x: [batch_size, seq_len, d_model]
        seq_len = x.size(1)
        max_len = self.position_embedding.size(1)
        # Slicing past the end of the table silently clamps to max_len rows.
        # That normally surfaces as an opaque broadcast error, but with
        # max_len == 1 it would broadcast one vector over every position
        # without complaint. RuntimeError keeps the exception type that the
        # shape mismatch produced before.
        if seq_len > max_len:
            raise RuntimeError(
                f"sequence length {seq_len} exceeds max_len {max_len} "
                "of the positional encoding"
            )
        return x + self.position_embedding[:, :seq_len, :]



In [12]:
# Minimal demo showing learnable positional encoding in action
def run_demo(batch_size=2, seq_len=10, vocab_size=100, d_model=16, seed=None):
    """Print one batch of token IDs, its embedding shape, and the encoded output.

    Builds a ``DummySequenceDataset``, embeds the first batch with a freshly
    initialised ``nn.Embedding``, adds a ``LearnablePositionalEncoding`` and
    prints the first four feature dimensions of the first sample.

    Args:
        batch_size: Number of sequences per batch drawn from the loader.
        seq_len: Length of each token sequence; also used as ``max_len`` of
            the positional encoding.
        vocab_size: Vocabulary size shared by the dataset and the embedding.
        d_model: Embedding dimension.
        seed: If not ``None``, passed to ``torch.manual_seed`` before anything
            random is created so the printed values are reproducible. Note
            that this reseeds torch's global RNG.

    Returns:
        None. Results are only printed.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Dummy token sequences
    dataset = DummySequenceDataset(vocab_size=vocab_size, seq_len=seq_len)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Embedding + Positional Encoding
    embedding = nn.Embedding(vocab_size, d_model)
    pos_encoding = LearnablePositionalEncoding(max_len=seq_len, d_model=d_model)

    # Only the first batch is shown; an empty loader prints nothing.
    batch = next(iter(loader), None)
    if batch is None:
        return

    print("Token IDs:\n", batch)
    embedded = embedding(batch)  # [batch, seq_len, d_model]
    print("\nToken Embeddings Shape:", embedded.shape)

    output = pos_encoding(embedded)
    # Show 4 dims of the first sample to keep the printout short.
    print("\nOutput After Adding Learnable Positional Encoding:\n", output[0, :, :4])

# Invoke the demo; it prints the sampled token IDs, the embedding shape, and
# a slice of the position-encoded output.
run_demo()


Token IDs:
 tensor([[23, 45,  2, 33, 22,  8, 32, 41, 45, 36],
        [50, 71, 93, 65, 33, 68, 41, 16, 30, 61]])

Token Embeddings Shape: torch.Size([2, 10, 16])

Output After Adding Learnable Positional Encoding:
 tensor([[ 0.8761, -0.5855, -0.4641, -0.8456],
        [ 0.5093, -0.3224, -0.2590, -2.6530],
        [ 1.0682,  1.1623, -3.2055, -1.7464],
        [ 1.1452, -1.7065,  0.2862, -0.2955],
        [ 1.2297, -0.1780, -1.4761,  0.2236],
        [-2.1803, -1.9979, -2.9829, -0.3825],
        [ 0.7024,  1.8152,  0.8138, -0.0145],
        [ 0.4441,  1.5403, -1.4557,  2.4941],
        [ 2.1826, -2.1226, -0.2999,  1.3709],
        [-1.8351,  0.4538,  2.1026,  0.8340]], grad_fn=<SliceBackward0>)
