In [1]:
import torch
from torch import nn

from src.data_loading import wdl_to_cp
from src.data_loading_1d import Dataset1D

from src.core import train, test

In [2]:
BATCH_SIZE = 2 ** 6  # passed to Dataset1D when the splits are built; presumably samples per batch

In [3]:
from src.patches import TRAIN_DATASET_PATCH, TEST_DATASET_PATCH

# Build the train split first, then the test split, passing BATCH_SIZE to each,
# and display their lengths.
train_dataset, test_dataset = (Dataset1D(source, BATCH_SIZE)
                               for source in (TRAIN_DATASET_PATCH, TEST_DATASET_PATCH))
len(train_dataset), len(test_dataset)

(13149, 3287)

In [4]:
def accuracy(out, truth):
    """Absolute prediction error passed through ``wdl_to_cp``.

    Despite its name this is an error metric (lower is better). The
    elementwise absolute difference between ``truth`` and ``out`` is handed to
    ``wdl_to_cp``; no reduction happens in this function.

    Presumably both tensors hold win/draw/loss-scale evaluations and the
    result is in centipawns. TODO: confirm against ``src.data_loading``.
    Note that the conversion is applied to the *difference* of the two values,
    not to each value separately.

    NOTE(review): if ``out`` is ``(batch, 1)`` while ``truth`` is ``(batch,)``,
    the subtraction broadcasts to ``(batch, batch)``. Confirm the shapes match.
    """
    error = truth - out
    return wdl_to_cp(torch.abs(error))

In [5]:
# Layer widths of the fully connected network:
# x0 (input features) -> x1 -> x2 -> x3 -> 1 (scalar output).
x0 = 768     # input feature count of the first Linear layer; TODO confirm against Dataset1D
x1 = 2 ** 12
x2 = 2 ** 7
x3 = 2 ** 4


class Model(nn.Module):
    """Four-layer ReLU MLP mapping an ``x0``-dimensional input to one scalar.

    The linear layers are stored both as ``layer1``..``layer4`` attributes and
    inside the ``classifier`` Sequential. Both registrations refer to the same
    modules, so ``state_dict()`` lists each parameter under two keys
    (``layerN.*`` and ``classifier.<idx>.*``).

    Both names are kept deliberately: saved checkpoints and code that uses
    ``model.classifier`` (e.g. the optimizer setup) depend on them.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(x0, x1)
        self.layer2 = nn.Linear(x1, x2)
        self.layer3 = nn.Linear(x2, x3)
        self.layer4 = nn.Linear(x3, 1)
        self.classifier = nn.Sequential(self.layer1,
                                        nn.ReLU(),
                                        self.layer2,
                                        nn.ReLU(),
                                        self.layer3,
                                        nn.ReLU(),
                                        self.layer4)

    def forward(self, X):
        """Run ``X`` through the MLP.

        ``X`` is assumed to have shape ``(batch, x0)`` (TODO confirm).
        Returns a tensor whose last dimension is 1.
        """
        # Invoke the module instead of `.forward()` so that forward hooks and
        # pre-hooks registered on `classifier` are actually run.
        return self.classifier(X)


model = Model()

In [6]:
# Baseline: evaluate the freshly initialised model on the test split.
mse_loss = nn.MSELoss()
test(test_dataset, model, mse_loss, accuracy)

Dataset size: 3287
test_loss: 0.7354396714205288, test_accuracy: 289.0905156891361, time: 1.118s


In [7]:
# Baseline: evaluate the freshly initialised model on the training split.
mse_loss = nn.MSELoss()
test(train_dataset, model, mse_loss, accuracy)

Dataset size: 13149
test_loss: 0.735663146340391, test_accuracy: 289.3062527217251, time: 3.18s


In [8]:
# Train with plain SGD on an MSE loss. Only the parameters reachable through
# `model.classifier` are handed to the optimizer (these are the same tensors
# as layer1..layer4).
LEARNING_RATE = 1e-3
N_EPOCHS = 300  # NOTE(review): at ~15 s/epoch in the captured log this is over an hour

optimizer = torch.optim.SGD(model.classifier.parameters(), lr=LEARNING_RATE)
train(train_dataset,
      test_dataset,
      model,
      nn.MSELoss(),
      optimizer,
      accuracy,
      N_EPOCHS)

Dataset size: 13149
Epoch [1/300],  train: 0.73095,288.0    val: 0.72624, 287.0  test: 0.72596,286.7,  time: 15.691s
Epoch [2/300],  train: 0.71511,284.7    val: 0.70567, 283.0  test: 0.70590,282.7,  time: 15.408s
Epoch [3/300],  train: 0.68958,280.1    val: 0.68972, 279.7  test: 0.69039,279.7,  time: 15.171s
Epoch [4/300],  train: 0.65579,272.5    val: 0.67064, 276.2  test: 0.67235,276.4,  time: 14.49s
Epoch [5/300],  train: 0.60783,260.7    val: 0.64975, 271.2  test: 0.65304,272.0,  time: 14.324s
Epoch [6/300],  train: 0.55599,247.2    val: 0.62042, 264.4  test: 0.62557,265.6,  time: 14.204s
Epoch [7/300],  train: 0.50892,234.5    val: 0.59051, 258.0  test: 0.59833,259.8,  time: 14.739s
Epoch [8/300],  train: 0.47688,225.3    val: 0.57111, 253.3  test: 0.58177,255.8,  time: 14.423s
Epoch [9/300],  train: 0.45843,219.8    val: 0.55743, 250.0  test: 0.57092,253.1,  time: 14.071s
Epoch [10/300],  train: 0.44486,215.7    val: 0.54574, 247.0  test: 0.56222,250.9,  time: 13.183s
Epoch [11/


KeyboardInterrupt



In [None]:
# Re-evaluate on the test split after training to compare with the baseline.
mse_loss = nn.MSELoss()
test(test_dataset, model, mse_loss, accuracy)

In [None]:
# Re-evaluate on the training split after training to compare with the baseline.
mse_loss = nn.MSELoss()
test(train_dataset, model, mse_loss, accuracy)

In [None]:
# Persist the model weights as a plain dict.
# The `layerN.*` and `classifier.<idx>.*` keys alias the same tensors, because
# Model registers each Linear layer under both names.
trained_weights = dict(model.state_dict())
torch.save(trained_weights, 'model.pt')