### Target: 20 mins

In [1]:
import torch
import numpy as np
import torch.nn as nn

In [2]:
x, y  = torch.randint(0, 27, (10, 3)), torch.randint(0, 27, (10, ))
embd_dim = 3
embedding = torch.rand(27, embd_dim)
x_enc = embedding[x]

In [3]:
class LinearLayer(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(in_features, out_features))
        self.bias = torch.nn.Parameter(torch.rand(out_features))
    
    def forward(self, x):
        return x @ self.weight + self.bias # (10, 3) @ (3, 10) + (10, ) -> (10, 10)

class LayerNorm(nn.Module):
    def __init__(self, in_features: int, eps: float):
        super().__init__()
        self.gamma = torch.nn.Parameter(torch.ones(in_features))
        self.beta = torch.nn.Parameter(torch.zeros(in_features))
        self.eps = eps
    
    def forward(self, x):
        x_norm = (x - x.mean(dim = 0)) / ((x.var(dim = 0) + self.eps) ** 0.5)
        return x_norm * self.gamma + self.beta
    

In [4]:
class MLP(nn.Module):
    def __init__(self, in_features: int, hidden_features: int, out_features: int, eps: float):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.eps = eps

        self.model = nn.Sequential(
            LinearLayer(self.in_features, self.hidden_features),
            LayerNorm(hidden_features, self.eps),
            nn.Tanh(),
            
            LinearLayer(self.hidden_features, self.hidden_features),
            LayerNorm(hidden_features, self.eps),
            nn.Tanh(),

            LinearLayer(self.hidden_features, self.hidden_features),
            LayerNorm(self.hidden_features, self.eps),
            nn.Tanh(),

            LinearLayer(self.hidden_features, self.out_features)
        )
    
    def forward(self, x):
        return self.model(x)

In [5]:
model = MLP(in_features = x_enc.shape[-1], out_features = 27, hidden_features = 10, eps = 10e-4)
optimizer = torch.optim.Adam(params = model.parameters(), lr = 0.01)
x = x_enc.view(-1, 3)
for i in range(100):

    logits = model(x)
    exp_logits = torch.exp(logits)
    probs = exp_logits / torch.sum(exp_logits, dim = 1, keepdim=True)
    loss = -torch.mean(torch.log(probs[torch.arange(0, len(y)), y]))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % 10 == 0:
        print(f"Loss: {loss:.4f}")

Loss: 3.5873
Loss: 2.5153
Loss: 1.9258
Loss: 1.5509
Loss: 1.2390
Loss: 0.9643
Loss: 0.7618
Loss: 0.5924
Loss: 0.4883
Loss: 0.4250
