In [144]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import pandas as pd

In [145]:
device = torch.device("cuda")
device

device(type='cuda')

In [146]:
# setting seeds
np.random.seed(252)
torch.manual_seed(252)

<torch._C.Generator at 0x1b5559c12f0>

## Model

In [147]:
class SimplePredictor(nn.Module):
    def __init__(self, num_inputs, num_hidden, num_outputs):
        super().__init__()
        self.linear1 = nn.Linear(num_inputs, num_hidden)
        self.bn1 = nn.BatchNorm1d(num_hidden)
        self.fn = nn.LeakyReLU()
        self.d1 = nn.Dropout(0.4)
        self.linear2 = nn.Linear(num_hidden, num_outputs)

    def forward(self, x):
        x = self.linear1(x)
        x = self.bn1(x)
        x = self.fn(x)
        x = self.d1(x)
        x = self.linear2(x)
        return x

In [148]:
def get_accuracy(model, loader):
    correct = 0
    total = 0
    model.eval()
    for x, labels in loader:
        x, labels = x.to(device), labels.to(device)
        output = model(x)
        pred = torch.round(output)
        correct += torch.logical_and(pred.gt(labels.view_as(pred) - 0.1 * labels.view_as(pred)), pred.lt(labels.view_as(pred) + 0.1 * labels.view_as(pred))).sum().item()
        total += x.shape[0]
    return correct / total

## Data

In [149]:
data06 = pd.read_csv("../data/processed/target06.csv")

In [150]:
train_indices = np.random.rand(len(data06))>0.25

numerical_data = torch.from_numpy(data06.values[train_indices,:-1]).float()
targets = torch.from_numpy(data06.values[train_indices,-1]).float()
test_numerical_data = torch.from_numpy(data06.values[~train_indices,:-1]).float()
test_targets = torch.from_numpy(data06.values[~train_indices,-1]).float()

train_dataset = data.TensorDataset(numerical_data, targets)
test_dataset = data.TensorDataset(test_numerical_data, test_targets)

In [151]:
data_loader = data.DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=True)
test_data_loader = data.DataLoader(train_dataset, batch_size=128, shuffle=False, drop_last=False)

## Prepare model

In [152]:
model = SimplePredictor(num_inputs=8, num_hidden=128, num_outputs=1)
model.to(device)
print(model)

SimplePredictor(
  (linear1): Linear(in_features=8, out_features=128, bias=True)
  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fn): LeakyReLU(negative_slope=0.01)
  (d1): Dropout(p=0.4, inplace=False)
  (linear2): Linear(in_features=128, out_features=1, bias=True)
)


In [153]:
loss_fun = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.00005)
epochs = 200

## Training

In [154]:
model.train()

# Training loop
for epoch in range(epochs):
    for data_inputs, data_labels in data_loader:
        data_inputs = data_inputs.to(device)
        data_labels = data_labels.to(device)
        optimizer.zero_grad()
        preds = model(data_inputs).squeeze(dim=1)
        loss = loss_fun(preds, data_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch: {epoch}, loss: {loss.item():.3}")

Epoch: 0, loss: 5.87e+02
Epoch: 1, loss: 7.28e+02
Epoch: 2, loss: 7.34e+02
Epoch: 3, loss: 6.84e+02
Epoch: 4, loss: 6.94e+02
Epoch: 5, loss: 6.43e+02
Epoch: 6, loss: 6.99e+02
Epoch: 7, loss: 6.73e+02
Epoch: 8, loss: 5.84e+02
Epoch: 9, loss: 6.83e+02
Epoch: 10, loss: 5.39e+02
Epoch: 11, loss: 5.08e+02
Epoch: 12, loss: 5.53e+02
Epoch: 13, loss: 4.9e+02
Epoch: 14, loss: 6.37e+02
Epoch: 15, loss: 4.85e+02
Epoch: 16, loss: 4.3e+02
Epoch: 17, loss: 6.83e+02
Epoch: 18, loss: 5.27e+02
Epoch: 19, loss: 5.72e+02
Epoch: 20, loss: 6.14e+02
Epoch: 21, loss: 6.89e+02
Epoch: 22, loss: 5.85e+02
Epoch: 23, loss: 4.29e+02
Epoch: 24, loss: 5.1e+02
Epoch: 25, loss: 5.28e+02
Epoch: 26, loss: 5.09e+02
Epoch: 27, loss: 5.57e+02
Epoch: 28, loss: 4.18e+02
Epoch: 29, loss: 4.39e+02
Epoch: 30, loss: 5.15e+02
Epoch: 31, loss: 4.71e+02
Epoch: 32, loss: 4.33e+02
Epoch: 33, loss: 3.59e+02
Epoch: 34, loss: 4.12e+02
Epoch: 35, loss: 3.97e+02
Epoch: 36, loss: 3.62e+02
Epoch: 37, loss: 4.63e+02
Epoch: 38, loss: 3.85e+02

In [155]:
get_accuracy(model, test_data_loader)

0.2008025929927458

In [156]:
data_labels

tensor([ 20.9200,  22.6400,  11.9000,  27.8100,  25.5600,  22.8000,  18.8400,
         39.7800,  10.4100,  24.0900,  46.9900,   7.1000,  37.1800,   8.7000,
          9.7700,  18.6300,   7.0700,  25.3200,  15.0900,  10.6400,  31.9500,
         27.9700,   8.0300,  45.4600,  15.8300,  53.5600,   8.8700,  15.4200,
         14.8400,   9.9000,  19.3600,  10.2800,  21.1100,  49.6700,  11.1900,
         12.2700,   8.0100,  41.5700,  21.5400,   4.7500,  28.4000,   5.0600,
         13.1200,   5.5200,  16.7900,  54.8300,  23.3800,   4.0500,  11.1900,
         25.8400,  58.3900,  10.2500,  22.5100,  29.4600,  15.9800,   6.9600,
         17.1900,   9.2800,  20.0400,  17.6500,  21.8200,  24.0700,  13.7700,
         11.5100,  30.2600,  25.2100,  11.4300,   9.4700,  12.5200,  17.9700,
         11.2400,  59.2900,  21.3200,  14.9300,  17.9600,  14.1600,  19.8900,
         17.3700,  25.7400,  18.8300,  25.0100,   8.9900,  16.7800,  15.8400,
         29.3700,  15.2700,  33.8300,  22.3900,  12.2700,  17.88

In [157]:
preds

tensor([ 18.9709,  49.6116,   8.1770,  14.2525,  20.7363,  14.5305,  27.4487,
         24.2884,  12.4430,  14.1497,  21.1399,  14.4776,  39.7353,   8.2929,
         18.3694,  14.2637,  12.8021,  33.2520,  21.1720,  12.1356,  33.2720,
         18.5826,  15.4705,  32.2479,  16.1407,  48.9864,  22.3402,  11.1683,
         17.1737,  13.1919,  14.0843,  23.6869,  17.2599,  24.0692,  23.4664,
         15.0107,   7.5973,  34.4271,  16.2389,  10.7882,  16.7696,   6.5953,
         39.2492,  20.7286,  10.2109,  36.8062,   8.3447,   9.7610,  13.4944,
         21.3061,   8.5721,  14.2382,  21.2795,  37.4153,   8.8329,  10.6365,
         18.7296,  15.6151,   7.9315,  15.4716,  14.6352,   9.4295,  13.4242,
          8.5209,  22.3609,  26.1189,  14.0130,  17.2038,  13.3810,  11.8236,
         17.5076,  59.2773,  43.5502,  22.5803,  19.3706,  13.3424,  17.6790,
         17.3370,  18.1210,   9.5628,  28.7828,  14.1420,  17.7311,  23.9240,
         40.9087,  23.1944,  11.4875,  20.1528,  20.8879,  16.68