In [None]:
import torch
import sympy as sp
from IPython.display import display
import matplotlib.pyplot as plt
import matplotx
plt.style.use(matplotx.styles.pacoty)

In [None]:
# 12-feature input, 4-class labelling
class Dataset:
    def __init__(self, num_examples=100, seed=10):
        """
        Synthetic 4-class classification dataset.

        Labels come from a fixed, mostly sparse linear map W : R^12 -> R^4:
            y = argmax(x @ W + b)

        Args:
            num_examples: number of rows drawn for ``x``.
            seed: value passed to ``torch.manual_seed`` before drawing ``x``.
                NOTE: this seeds torch's *global* RNG, so anything drawn from it
                afterwards (e.g. model initialisation in later cells) is
                reproducible too.

        Attributes:
            x: (num_examples, 12) float32 standard-normal features.
            y: (num_examples,) int64 class labels in [0, 4).
        """
        torch.manual_seed(seed)

        # Mostly sparse W with small scale and precision to represent them;
        # this could be inferenced with fp4 if our model recovers it.
        self._W = torch.tensor([
            [-0.15, 0.0,   0.0,   0.0],
            [ 0.0, 0.15,  0.0,   0.0],
            [ 0.05, 0.0,  0.0,   0.0],
            [ 0.0, -0.1,  0.0,   0.0],
            [ 0.0, 0.0,   0.0,   0.22],
            [ 0.0, 0.0,   0.0,   0.0],
            [ 0.0, 0.0,   0.0,  0.21],
            [ 0.0, 0.0,   0.05,  0.0],
            [ 0.0, 0.0,   0.0,   0.0],
            [ 0.01, 0.0,  0.0,   0.0],
            [ 0.0, 0.0,   0.0,   0.0],
            [ 0.0, 0.0,   0.18,  0.0]
        ], dtype=torch.float32)

        self._bias = 0.1 * torch.tensor([0.2, -0.1, 0.3, -0.05], dtype=torch.float32)

        # Feature count is taken from W so the two can never disagree.
        n_features = self._W.shape[0]
        self.x = torch.randn(num_examples, n_features, dtype=torch.float32)
        logits = self.x @ self._W + self._bias
        self.y = torch.argmax(logits, dim=-1)

    def show_answer(self, show=4):
        """Display the ground-truth W, b and the labelling equation with SymPy.

        Args:
            show: currently unused; kept only for backward compatibility.
        """
        sp_W = sp.Matrix(self._W)
        # Row vector of symbolic inputs x0..x11 so that x @ W is well-formed.
        sp_x = sp.Matrix([sp.Symbol(f"x{i}") for i in range(self._W.shape[0])]).T
        # A 1-D tensor becomes a column vector in SymPy; transpose to a row.
        sp_b = sp.Matrix(self._bias).T
        sp_y = sp.Symbol("y")
        sp_func = sp.Function("argmax")

        sp_eq = sp.Eq(sp_y, sp_func(sp_x @ sp_W + sp_b), evaluate=False)

        print("True Pop Params")
        display(sp.Eq(sp.Symbol("W"), sp_W, evaluate=False))
        display(sp.Eq(sp.Symbol("b"), sp_b, evaluate=False))
        display(sp_eq)

# Build the default 100-example dataset and render its ground-truth parameters.
data = Dataset(num_examples=100)
data.show_answer()


In [None]:
# Class balance of the generated labels. Bin edges sit at k - 0.5 so each bar is
# centred on one integer class; matplotlib's default 10 bins would split and
# merge the 4 integer labels unevenly.
n_classes = data._W.shape[1]
fig, ax = plt.subplots()
ax.hist(data.y, bins=[k - 0.5 for k in range(n_classes + 1)], rwidth=0.8)
ax.set(title="Label distribution", xlabel="class", ylabel="count", xticks=list(range(n_classes)));

In [None]:
# Simple 1-layer NN with the same shape as the true W (12 -> 4), using GELU.
# Register `torch.gelu` because code may look it up on the torch module.
# NOTE(review): monkeypatching the torch namespace is fragile (hidden state that
# depends on this cell having run); prefer torch.nn.functional.gelu at call sites.
torch.gelu = torch.nn.GELU()
x = torch.linspace(-10, 10, 100)
y = torch.gelu(x)

# GELU is non-monotonic: it dips to about -0.17 near x ~ -0.75 before rising.
fig, ax = plt.subplots()
ax.plot(x, y)
ax.set(title="GELU activation", xlabel="x", ylabel="GELU(x)");

In [None]:
class SuperSimpleNN(torch.nn.Module):
    """Single affine layer followed by GELU: ``gelu(x @ W + b)``.

    The default shapes mirror the ground-truth Dataset parameters
    (12 features -> 4 classes), so the learned ``W``/``b`` can be compared
    directly with the true ones.
    """

    def __init__(self, in_features: int = 12, out_features: int = 4):
        super().__init__()
        # Near-zero init drawn from torch's global RNG (W first, then b, as
        # before) so training starts from an almost uninformative model.
        self.W = torch.nn.Parameter(
            torch.randn(in_features, out_features, dtype=torch.float32) * 0.001
        )
        self.b = torch.nn.Parameter(
            torch.randn(1, out_features, dtype=torch.float32) * 1e-6
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Call the functional GELU directly rather than a `torch.gelu` attribute
        # monkeypatched in another cell, so the model works regardless of cell
        # execution order. This is the same exact (erf-based) GELU as
        # torch.nn.GELU().
        # NOTE(review): GELU is non-monotonic (min ~ -0.17 near x ~ -0.75), so
        # argmax of these outputs can differ from argmax of the raw logits, and
        # cross_entropy sees GELU-squashed logits. Kept to preserve the experiment.
        return torch.nn.functional.gelu(x @ self.W + self.b)


simple_model = SuperSimpleNN()

In [None]:
print("At init")
with torch.no_grad():
    # Render each freshly initialised parameter as a SymPy equation, e.g. W_init = [...].
    for param_name, param_value in simple_model.named_parameters():
        label = sp.Symbol(f"{param_name}_init")
        display(sp.Eq(label, sp.Matrix(param_value), evaluate=False))

In [None]:
# Run on the GPU when available; fall back to CPU instead of crashing.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
x_examples = data.x.to(DEVICE)
y_examples = data.y.to(DEVICE)
# Allow TF32 matmuls on supported GPUs (faster, slightly lower precision).
torch.set_float32_matmul_precision("high")


def train(
    epochs: int = 100_000,
    lr: float = 1e-3,
    weight_decay: float = 1e-3,
    log_every: int = 10_000,
    compile_model: bool = True,
) -> torch.nn.Module:
    """Fit a fresh SuperSimpleNN to (x_examples, y_examples) with full-batch AdamW.

    Args:
        epochs: number of full-batch optimisation steps.
        lr: AdamW learning rate (1e-3 is AdamW's default, as before).
        weight_decay: AdamW decoupled weight decay.
        log_every: update the progress-bar loss every this many steps
            (0 disables logging).
        compile_model: wrap the model's forward in ``torch.compile``.

    Returns:
        The trained (uncompiled) model. It shares its parameters with the
        compiled wrapper, and the gradients from the final step remain in
        ``.grad``.
    """
    from tqdm.auto import tqdm

    model = SuperSimpleNN().to(DEVICE)
    # Compile only the forward pass. Decorating the whole loop (model creation,
    # tqdm iteration, .item(), optimizer) with torch.compile just produces graph
    # breaks and recompilations.
    forward = torch.compile(model) if compile_model else model
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    pbar = tqdm(range(epochs))
    for epoch in pbar:
        logits = forward(x_examples)
        loss = torch.nn.functional.cross_entropy(input=logits, target=y_examples)

        # .item() forces a device sync, so only do it occasionally.
        if log_every and epoch % log_every == 0:
            pbar.set_postfix(loss=loss.item())

        optim.zero_grad()
        loss.backward()
        optim.step()
    return model


model = train()

In [None]:
with torch.no_grad():
    # Training-set accuracy: these are the same examples the model was fit on
    # (the notebook has no held-out split).
    response = torch.argmax(model(x_examples), dim=-1)
    n_correct = int((response == y_examples).sum())
    n_total = y_examples.numel()
    accuracy = n_correct / n_total if n_total else 0.0
    print(f"Score: {n_correct}/{n_total} ({accuracy:.1%})")

    # Learned parameters, for comparison with the ground-truth W / b shown earlier.
    # detach().cpu() makes the SymPy conversion work whatever device the model is on.
    sp_Wresult = sp.Matrix(model.W.detach().cpu())
    sp_Bresult = sp.Matrix(model.b.detach().cpu())

    display(sp.Eq(sp.Symbol("W_result"), sp_Wresult, evaluate=False))
    display(sp.Eq(sp.Symbol("B_result"), sp_Bresult, evaluate=False))

    print("==========")
    # Gradients left over from the final optimisation step. They are None if
    # backward() never ran (e.g. train(epochs=0)).
    if model.W.grad is None or model.b.grad is None:
        print("No gradients recorded (model was not trained).")
    else:
        sp_Wgrad = sp.Matrix(model.W.grad.detach().cpu())
        sp_bgrad = sp.Matrix(model.b.grad.detach().cpu())

        display(sp.Eq(sp.Symbol("Wgrad_result"), sp_Wgrad, evaluate=False))
        display(sp.Eq(sp.Symbol("Bgrad_result"), sp_bgrad, evaluate=False))
    
    