In [1]:
import torch
import numpy as np

In [40]:
data = np.loadtxt('data/iris.txt')
np.random.shuffle(data)

x = torch.tensor(data[:, :4], dtype=torch.float32)
y = torch.tensor(data[:, -1], dtype=torch.float32)
y = y.reshape(len(y), 1)

split_index = int(len(data) * 0.8) 

x_train = x[:split_index]
x_test = x[split_index:]

y_train = y[:split_index]
y_test = y[split_index:]

In [3]:
print(f"X_train shape: {x_train.shape}")
print(f"X_test shape: {x_test.shape}")
print(f"Y_train shape: {y_train.shape}")
print(f"Y_test shape: {y_test.shape}")

X_train shape: torch.Size([120, 4])
X_test shape: torch.Size([30, 4])
Y_train shape: torch.Size([120, 1])
Y_test shape: torch.Size([30, 1])


In [70]:
weights1 = torch.rand((4, 20), requires_grad=True)
biases1 = torch.rand(20, requires_grad=True)

weights2 = torch.rand((20, 1), requires_grad=True)
biases2 = torch.rand(1, requires_grad=True)

In [62]:
learning_rate = 0.0001
batch_size = 10

In [71]:
for epoch in range(500):
    total_loss = 0
    for batch_i in range(0, len(x_train), batch_size):
        x_batch = x_train[batch_i:batch_i+batch_size]
        y_batch = y_train[batch_i:batch_i+batch_size]

        h = torch.mm(x_batch, weights1) + biases1
        y_hat = torch.mm(h, weights2) + biases2
        
        loss = sum((y_hat - y_batch) ** 2)
        
        loss.backward()
        total_loss += loss
        total_loss.detach()
        
        weights2 = weights2 - learning_rate * weights2.grad
        biases2 = biases2 - learning_rate * biases2.grad
        weights1 = weights1 - learning_rate * weights1.grad
        biases1 = biases1 - learning_rate * biases1.grad
        
        weights1 = weights1.detach().requires_grad_()
        biases1 = biases1.detach().requires_grad_()
        weights2 = weights2.detach().requires_grad_()
        biases2 = biases2.detach().requires_grad_()
    
    print(f"Epoch = {epoch} | Loss = {total_loss}")

Epoch = 0 | Loss = tensor([62462.4258], grad_fn=<AddBackward0>)
Epoch = 1 | Loss = tensor([99.2819], grad_fn=<AddBackward0>)
Epoch = 2 | Loss = tensor([46.7582], grad_fn=<AddBackward0>)
Epoch = 3 | Loss = tensor([41.5431], grad_fn=<AddBackward0>)
Epoch = 4 | Loss = tensor([37.5024], grad_fn=<AddBackward0>)
Epoch = 5 | Loss = tensor([33.8672], grad_fn=<AddBackward0>)
Epoch = 6 | Loss = tensor([30.6240], grad_fn=<AddBackward0>)
Epoch = 7 | Loss = tensor([27.7424], grad_fn=<AddBackward0>)
Epoch = 8 | Loss = tensor([25.1889], grad_fn=<AddBackward0>)
Epoch = 9 | Loss = tensor([22.9322], grad_fn=<AddBackward0>)
Epoch = 10 | Loss = tensor([20.9427], grad_fn=<AddBackward0>)
Epoch = 11 | Loss = tensor([19.1932], grad_fn=<AddBackward0>)
Epoch = 12 | Loss = tensor([17.6586], grad_fn=<AddBackward0>)
Epoch = 13 | Loss = tensor([16.3157], grad_fn=<AddBackward0>)
Epoch = 14 | Loss = tensor([15.1434], grad_fn=<AddBackward0>)
Epoch = 15 | Loss = tensor([14.1222], grad_fn=<AddBackward0>)
Epoch = 16 | Lo

In [73]:
h = torch.mm(x_test, weights1) + biases1
y_hat = torch.tanh(torch.mm(h, weights2) + biases2)

print(y_hat)
print(y_test)

weights1 = weights1.detach().requires_grad_()
biases1 = biases1.detach().requires_grad_()
weights2 = weights2.detach().requires_grad_()
biases2 = biases2.detach().requires_grad_()

tensor([[ 0.2220],
        [ 0.5404],
        [ 0.6919],
        [-0.9074],
        [ 0.7238],
        [-0.7809],
        [ 0.5711],
        [ 0.5954],
        [-0.8157],
        [ 0.0651],
        [ 0.7948],
        [-0.8606],
        [ 0.3387],
        [-0.7826],
        [ 0.5846],
        [ 0.1818],
        [ 0.0635],
        [ 0.2486],
        [ 0.5709],
        [-0.8240],
        [ 0.7310],
        [-0.8111],
        [-0.8183],
        [ 0.2103],
        [ 0.6524],
        [-0.8823],
        [ 0.0937],
        [ 0.5304],
        [ 0.6406],
        [-0.7694]], grad_fn=<TanhBackward>)
tensor([[ 0.],
        [ 1.],
        [ 1.],
        [-1.],
        [ 1.],
        [-1.],
        [ 1.],
        [ 1.],
        [-1.],
        [ 0.],
        [ 1.],
        [-1.],
        [ 0.],
        [-1.],
        [ 1.],
        [ 0.],
        [ 0.],
        [ 0.],
        [ 1.],
        [-1.],
        [ 1.],
        [-1.],
        [-1.],
        [ 0.],
        [ 1.],
        [-1.],
        [ 0.],
