In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from sklearn.datasets import make_moons
X, y = make_moons(n_samples=2000, noise=0.2, random_state=42)
X.shape

(2000, 2)

In [3]:
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)

In [4]:
y[1]

tensor(0.)

In [5]:
class MoonModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = torch.nn.Sequential(
            torch.nn.Linear(2,50),
            torch.nn.Tanh(),
            torch.nn.Linear(50,1)
        )
    def forward(self,x):
        x = self.hidden(x)
        return x

In [6]:
moonmodel = MoonModel()
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(moonmodel.parameters(),lr=1e-3)

In [9]:
num_epochs = 20
for i in range(num_epochs):
    for x,label in zip(X,y):
        print(f"Epoch {i+1}\n-------------------------------")
        pred = moonmodel(x.unsqueeze(0))
        loss = loss_fn(pred,label.unsqueeze(0).unsqueeze(1))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss = loss.item()
        print(f"loss: {loss:>7f}")


Epoch 1
-------------------------------
loss: 0.471577
Epoch 1
-------------------------------
loss: 0.766751
Epoch 1
-------------------------------
loss: 0.621039
Epoch 1
-------------------------------
loss: 0.527114
Epoch 1
-------------------------------
loss: 0.599974
Epoch 1
-------------------------------
loss: 0.698118
Epoch 1
-------------------------------
loss: 0.563138
Epoch 1
-------------------------------
loss: 0.435094
Epoch 1
-------------------------------
loss: 0.903311
Epoch 1
-------------------------------
loss: 0.581494
Epoch 1
-------------------------------
loss: 0.978963
Epoch 1
-------------------------------
loss: 0.618845
Epoch 1
-------------------------------
loss: 0.677681
Epoch 1
-------------------------------
loss: 0.665533
Epoch 1
-------------------------------
loss: 0.806687
Epoch 1
-------------------------------
loss: 0.853930
Epoch 1
-------------------------------
loss: 0.576608
Epoch 1
-------------------------------
loss: 0.971423
Epoch 1
--

In [17]:
theta = torch.concat([param.detach().flatten() for param in moonmodel.parameters()])

In [18]:
theta

tensor([-3.9926e-01,  3.8030e-01,  5.9903e-01,  7.4203e-01,  4.1148e-01,
        -4.3013e-01, -4.7999e-01,  4.4634e-01,  4.0796e-01, -1.5012e-01,
        -5.6980e-01,  8.7072e-01, -7.7815e-01,  2.9247e-01,  3.2276e+00,
        -4.4043e-01,  4.2449e-01, -3.8780e-01, -4.4399e-01,  4.2623e-01,
         3.9536e-01, -5.3465e-01, -5.0965e-01,  3.2366e-01, -2.8064e+00,
        -1.0028e+00, -3.8447e-01,  7.6201e-01, -5.2327e-01,  7.2381e-01,
        -3.5253e+00,  1.0473e-01,  2.9302e+00,  4.9580e-01,  4.4965e-01,
        -6.6472e-01,  4.3593e-01, -5.5129e-01,  5.9775e-01, -2.8826e-01,
         4.2528e-01, -3.3388e-01,  6.1780e-01, -2.8090e-01,  4.4198e-01,
        -3.2565e-01,  4.1477e-01, -3.8488e-01,  4.5526e-01, -2.1376e-01,
         4.9583e-01, -2.4513e-01,  2.9991e+00,  6.7611e-01,  7.2791e-01,
        -3.4707e-01,  4.1972e-01, -2.2841e-01,  4.2142e-01, -2.0416e-01,
         3.1445e-01, -5.8458e-01,  5.2978e-01, -8.0882e-01, -4.7971e-01,
         1.0857e-02,  1.4457e+00,  1.0310e+00,  5.0

In [19]:
dim = theta.numel()

v1 = torch.randn(dim)
v2 = torch.randn(dim)

v1 = v1/v1.norm()
v2 = v2/v2.norm()