In [1491]:
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(8827)
n = 10_000
numbers = torch.randint(0, 30_000, (n,))

In [1492]:
# Split dataset
split_percentage = int(0.9 * len(numbers))
train_dataset = numbers[:split_percentage]
val_dataset = numbers[split_percentage:]

print(train_dataset.shape)
print(val_dataset.shape)

torch.Size([9000])
torch.Size([1000])


In [1604]:
batch_size = 8
context_length = 64


def get_batch(split):
    data = train_dataset if split == 'train' else val_dataset
    random_offsets = torch.randint(len(data) - context_length, (batch_size,))
    xb = torch.stack([data[i:context_length + i] for i in random_offsets])
    yb = xb ** 2

    return xb.float(), yb.float()


xb, yb = get_batch('train')
print(f"xb: {xb.shape}")
print(f"yb: {yb.shape}")

xb: torch.Size([8, 64])
yb: torch.Size([8, 64])


In [1726]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()

        self.sequential = nn.Sequential(
            nn.Linear(context_length, context_length),
            # nn.Linear(context_length, 128),
            # nn.ReLU(),
            # nn.Linear(128, 512),
            # nn.Linear(512, 512),
            # nn.ReLU(),
            # nn.Linear(512, 128),
            # nn.ReLU(),
            # nn.Linear(512, context_length),
        )
        self.ln_norm = nn.LayerNorm(context_length)

    def forward(self, x, y):
        x_mean_val = x.mean(dim=1, keepdim=True)
        x_std_dev = x.std(dim=1, keepdim=True)
        x = (x - x_mean_val) / x_std_dev

        y_mean_val = y.mean(dim=1, keepdim=True)
        y_std_dev = y.std(dim=1, keepdim=True)
        y = (y - y_mean_val) / y_std_dev
        logits = self.sequential(x)
        # logits = self.ln_norm(logits)
        # logits = self.sequential(logits)

        loss = F.mse_loss(logits, y)

        return logits, loss


model = MLP()
logits, loss = model(xb, yb)
# print(logits, loss)
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [1746]:


# training loop
for _ in range(1000):
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(loss)

tensor(0.0576, grad_fn=<MseLossBackward0>)


In [1751]:
model.eval()
xb, yb = get_batch('val')
yb_mean_val = yb.mean(dim=1, keepdim=True)
yb_std_dev = yb.std(dim=1, keepdim=True)
logits, loss = model(xb, yb)
print(f"prediction loss: {loss}")
print(f"prediction: {torch.round(logits[:1][0] * yb_std_dev[:1][0] + yb_mean_val[:1][0])}")
# print(f"xb: {xb[:1][0]}")
print(f"yb: {yb[:1][0]}")
model.train()

prediction loss: 0.061485182493925095
prediction: tensor([ 1.2993e+08,  6.7717e+08,  7.1782e+08,  3.7606e+08,  5.5574e+08,
         2.9279e+08,  6.4761e+08, -1.1749e+08,  1.0407e+08, -1.2491e+08,
         7.0917e+08,  2.6072e+08,  1.4688e+08,  1.4141e+08,  2.6037e+08,
         5.0927e+08, -1.9124e+07,  2.4122e+08, -3.6975e+06,  4.7170e+08,
         4.9689e+08,  5.3483e+08,  1.3696e+08,  3.1839e+08, -5.4629e+07,
         4.9811e+08,  4.7726e+08,  2.5151e+08,  4.2580e+08, -1.1610e+08,
         6.9708e+08, -9.0739e+07, -3.7830e+07,  2.7469e+08,  4.6142e+08,
         3.2809e+08,  3.0613e+08, -2.5958e+07,  2.6648e+07, -7.7005e+07,
         6.3536e+08,  8.9144e+07,  4.6075e+08,  6.3454e+08,  3.3495e+08,
         3.2281e+08,  6.5684e+08,  2.9651e+07,  1.7446e+08,  3.9048e+08,
         4.9318e+06,  4.5096e+08,  5.5807e+08,  4.7422e+08,  4.8498e+08,
         4.2955e+08,  2.0177e+08, -2.2773e+07,  4.5633e+07, -9.5346e+07,
        -2.6363e+07,  2.8752e+08,  1.8145e+08,  7.1107e+07],
       grad_f

MLP(
  (sequential): Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
  )
  (ln_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
)