In [1]:
import torch
from torch import nn

from src.data_loading import wdl_to_cp
from src.data_loading_halfkp import HalfKpDataset, M

from src.core import train, test

In [2]:
# Passed as the second argument to HalfKpDataset for both the train and test
# splits (presumably the number of positions per batch - TODO confirm).
BATCH_SIZE = 64

In [3]:
from src.patches import TRAIN_DATASET_PATCH, TEST_DATASET_PATCH

# Build the train and test datasets from the locations exported by src.patches
# (presumably file paths - TODO confirm), both using the same BATCH_SIZE.
train_dataset = HalfKpDataset(TRAIN_DATASET_PATCH, BATCH_SIZE)
test_dataset = HalfKpDataset(TEST_DATASET_PATCH, BATCH_SIZE)
# Show both lengths as the cell output (presumably counted in batches rather
# than individual positions - TODO confirm against HalfKpDataset.__len__).
len(train_dataset), len(test_dataset)

(13149, 3287)

In [4]:
def accuracy(out, truth):
    """Score predictions by their absolute error, mapped through ``wdl_to_cp``.

    Args:
        out: Model predictions.
        truth: Target values, broadcast-compatible with ``out``.

    Returns:
        ``wdl_to_cp(|truth - out|)``. Despite the name this is an error
        measure, so lower is better (presumably expressed in centipawns -
        TODO confirm against ``src.data_loading.wdl_to_cp``).
    """
    abs_error = torch.abs(truth - out)
    return wdl_to_cp(abs_error)

In [5]:
# Layer widths used by the Model class defined in this cell.
# Input width: twice M, which is imported from src.data_loading_halfkp
# (presumably the per-side HalfKP feature count - TODO confirm).
x0 = 2 * M
# Width of the first hidden layer (256).
x1 = 2 ** 8
# Width of the second hidden layer (32).
x2 = 2 ** 5


class Model(nn.Module):
    """Fully connected regressor: x0 -> x1 -> x2 -> 1 with ReLU between layers.

    The three ``nn.Linear`` layers are exposed individually as ``layer1``,
    ``layer2`` and ``layer3`` (other cells read their weights and biases) and
    are also chained, with ReLU activations in between, in ``classifier``.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(x0, x1)
        self.layer2 = nn.Linear(x1, x2)
        self.layer3 = nn.Linear(x2, 1)
        # The Sequential wraps the very same Linear modules (no copies), so
        # each parameter appears in state_dict() under two keys: "layerN.*"
        # and "classifier.*". Both key sets are kept for compatibility with
        # previously saved checkpoints.
        self.classifier = nn.Sequential(self.layer1,
                                        nn.ReLU(),
                                        self.layer2,
                                        nn.ReLU(),
                                        self.layer3)

    def forward(self, X):
        """Run ``X`` through the layer stack.

        Args:
            X: Input tensor whose last dimension has size ``x0``.

        Returns:
            Tensor with the same leading dimensions as ``X`` and a last
            dimension of size 1.
        """
        # Invoke the module itself rather than ``.forward`` so that hooks
        # registered on ``classifier`` are executed.
        return self.classifier(X)


# Instantiate the network with PyTorch's default random initialisation.
# NOTE(review): no random seed is set in the visible cells, so the initial
# weights (and therefore the results below) differ between kernel runs.
model = Model()

In [9]:
# Evaluate on the test split with MSE loss and the `accuracy` metric. This
# cell sits before the training cell, so on a top-to-bottom run it reports
# the baseline of the freshly initialised model.
test(test_dataset, model, nn.MSELoss(), accuracy)

Dataset size: 3287


KeyboardInterrupt: 

In [7]:
# Same pre-training baseline, measured on the training split; test() reports
# the loss, the `accuracy` metric and the elapsed time.
test(train_dataset, model, nn.MSELoss(), accuracy)

Dataset size: 13149
test_loss: 0.7432454234951394, test_accuracy: 291.39394797875553, time: 733.759s


In [8]:
# Train with mean-squared-error loss and plain SGD (lr=0.001). test_dataset
# and `accuracy` are handed to train() as well, presumably for evaluation
# during training - TODO confirm against src.core.train.
# model.classifier.parameters() covers every trainable tensor because the
# Sequential wraps layer1, layer2 and layer3.
# NOTE(review): the final argument (300) is presumably the number of epochs;
# confirm, and consider naming it alongside the learning rate in a config cell.
train(train_dataset,
      test_dataset,
      model,
      nn.MSELoss(),
      torch.optim.SGD(model.classifier.parameters(), lr=0.001),
      accuracy,
      300)

Dataset size: 13149


KeyboardInterrupt: 

In [None]:
# Post-training evaluation on the test split (same loss and metric as the
# baseline cells, so the numbers are directly comparable).
test(test_dataset, model, nn.MSELoss(), accuracy)

In [None]:
# Post-training evaluation on the training split, for comparison with the
# test-split result.
test(train_dataset, model, nn.MSELoss(), accuracy)

In [None]:
# Persist the model's state_dict so it can be restored later with
# Model().load_state_dict(torch.load('halfkp_state.pt')).
# The state_dict is saved as-is: wrapping it in dict() would discard the
# `_metadata` attribute (per-module version info) that PyTorch attaches to it.
torch.save(model.state_dict(), 'halfkp_state.pt')

In [None]:
# Export the raw parameters as nested Python lists, one [weight, bias] pair
# per linear layer in forward order (layer1, layer2, layer3).
torch.save(
    [[layer.weight.tolist(), layer.bias.tolist()]
     for layer in (model.layer1, model.layer2, model.layer3)],
    "halfkp_wb.pt",
)