In [107]:
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader, random_split

In [108]:
### Generate Synthetic Dataset ###
X = np.random.randn(1000, 2).astype(np.float32)
Y = (np.sum(X, axis=1) > 0).astype(np.float32) # if sum(x_features) > 0 then 1, else 0
X = torch.tensor(X)
Y = torch.tensor(Y)

dataset = TensorDataset(X, Y) # convert dataset to tensor for pytorch

train_size = int(0.8 * len(dataset)) # 80-20 split
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

### Define a Simple NN ###
model = nn.Sequential(
    nn.Linear(2, 32),
    nn.ReLU(),
    nn.Linear(32, 2),
    nn.Softmax(dim=1) # get softmax probabilities per class
)

In [109]:
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

### Training Loop ###
epochs = 10
model.train()
for epoch in range(epochs):
    for x_batch, y_batch in train_loader:
        train_batch_preds = model(x_batch)
        loss = loss_fn(train_batch_preds, y_batch.long())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")


### Evaluation on Test Set ###
model.eval()
total_accuracy = 0
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        test_batch_probs = model(x_batch)
        test_batch_preds = torch.argmax(test_batch_probs, axis=1)
        correct = (test_batch_preds == y_batch).sum()
        batch_accuracy = correct / len(y_batch)
        total_accuracy += batch_accuracy

    total_accuracy /= len(test_loader)
    print(f"total accuracy: {total_accuracy * 100}%")

Epoch 1, Loss: 0.6036
Epoch 2, Loss: 0.5671
Epoch 3, Loss: 0.5486
Epoch 4, Loss: 0.5025
Epoch 5, Loss: 0.4909
Epoch 6, Loss: 0.4559
Epoch 7, Loss: 0.4331
Epoch 8, Loss: 0.4353
Epoch 9, Loss: 0.4284
Epoch 10, Loss: 0.4376
total accuracy: 98.66071319580078%
