In [1]:
# https://johanwind.github.io/2022/07/06/dln_classifier.html

In [2]:
import torch
import torch.nn.functional as F
import torch.nn as nn
torch.manual_seed(0)
import numpy as np

In [3]:
# Set constants
d = 20
k = 10
n = 100

In [4]:
# Generate data
X = torch.rand(n*2, d)
y = ((torch.atan2(X[:, 0], X[:, 1])/np.pi+1)/2*k).long()
# train test split
X_train, y_train = X[:n, :], y[:n]
X_test, y_test = X[n:, :], y[n:]


In [5]:
class SimpleLinear(nn.Module):
    def __init__(self, d, k, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.linear = nn.Linear(d, k, bias=False)
    def forward(self, x):
        return self.linear(x)

# test layer
model = SimpleLinear(d, k)
pred_y = model(X_train)
print(pred_y.shape, y_train.shape)
print(model.linear.weight.shape)

torch.Size([100, 10]) torch.Size([100])
torch.Size([10, 20])


In [6]:
device = torch.device("cuda")
model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_value = 10
X_train, y_train = X_train.to(device), y_train.to(device)

iteration = 1
while True:
    if loss_value < 0.1:
        break
    model.train()
    pred_y = model(X_train)
    loss = F.cross_entropy(pred_y, y_train)
    optimizer.zero_grad()
    loss.backward()
    loss_value = loss.item()
    optimizer.step()
    print(iteration, loss_value)
    iteration += 1

1 2.336678981781006
2 2.206367015838623
3 2.089176654815674
4 1.9845492839813232
5 1.8916938304901123
6 1.809648036956787
7 1.7373530864715576
8 1.673719882965088
9 1.6176878213882446
10 1.5682605504989624
11 1.5245338678359985
12 1.4857048988342285
13 1.4510730504989624
14 1.4200352430343628
15 1.3920773267745972
16 1.3667631149291992
17 1.343725323677063
18 1.32265305519104
19 1.3032853603363037
20 1.2854030132293701
21 1.2688194513320923
22 1.2533775568008423
23 1.2389445304870605
24 1.225406289100647
25 1.2126659154891968
26 1.2006398439407349
27 1.1892560720443726
28 1.1784521341323853
29 1.1681740283966064
30 1.1583740711212158
31 1.1490108966827393
32 1.1400474309921265
33 1.1314513683319092
34 1.1231938600540161
35 1.1152487993240356
36 1.1075935363769531
37 1.1002070903778076
38 1.0930705070495605
39 1.086167573928833
40 1.0794821977615356
41 1.0730006694793701
42 1.0667107105255127
43 1.0606002807617188
44 1.0546585321426392
45 1.0488765239715576
46 1.0432448387145996
47 1.03

In [7]:
# Calculate accuracy
model.eval()
X_test, y_test = X_test.to(device), y_test.to(device)
acc = 0

with torch.no_grad():
    y_pred = model(X_test)
    y_pred = torch.argmax(y_pred, dim=1)
    correct = torch.sum(y_pred == y_test).item()
    print(correct/X_test.shape[0])

0.78
