In [1]:
import torch
from torch import nn

from src.data_loading import wdl_to_cp
from src.data_loading_halfkp import HalfKpDataset, M

from src.core import train, test

In [2]:
BATCH_SIZE = 64

In [3]:
from src.patches import TRAIN_DATASET_PATCH, TEST_DATASET_PATCH

train_dataset = HalfKpDataset(TRAIN_DATASET_PATCH, BATCH_SIZE)
test_dataset = HalfKpDataset(TEST_DATASET_PATCH, BATCH_SIZE)
len(train_dataset), len(test_dataset)

(13149, 3287)

In [4]:
def accuracy(out, truth):
    return wdl_to_cp(torch.abs(truth - out))

In [5]:
x0 = 2 * M
x1 = 2 ** 8
x2 = 2 ** 5


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layer1 = nn.Linear(x0, x1)
        self.layer2 = nn.Linear(x1, x2)
        self.layer3 = nn.Linear(x2, 1)
        self.classifier = nn.Sequential(self.layer1,
                                        nn.ReLU(),
                                        self.layer2,
                                        nn.ReLU(),
                                        self.layer3)

    def forward(self, X):
        return self.classifier.forward(X)


model = Model()
optimizer = torch.optim.SGD(model.classifier.parameters(), lr=0.001)

In [6]:
LOAD_FLAG = False

if LOAD_FLAG:
    checkpoint = torch.load('halfkp_checkpoint.pt')
else:
    checkpoint = {'epoch': 0,
                  'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'history': []}

In [10]:
train(train_dataset,
      test_dataset,
      model,
      nn.MSELoss(),
      optimizer,
      accuracy,
      300,
      checkpoint)

Epoch [2/300],  train: 0.72496,286.7    san_check: 0.71481, 284.7  test: 0.71474,284.4,  time: 2.172257s


KeyboardInterrupt: 

In [None]:
test(test_dataset, model, nn.MSELoss(), accuracy)

In [None]:
test(train_dataset, model, nn.MSELoss(), accuracy)

In [8]:
torch.save(checkpoint, 'halfkp_checkpoint.pt')
checkpoint

{'epoch': 0,
 'model': OrderedDict([('layer1.weight',
               tensor([[ 0.0019, -0.0034,  0.0014,  ...,  0.0022, -0.0020, -0.0010],
                       [-0.0025,  0.0013, -0.0027,  ..., -0.0033, -0.0031, -0.0008],
                       [ 0.0002, -0.0022,  0.0027,  ..., -0.0019, -0.0004,  0.0023],
                       ...,
                       [-0.0018, -0.0015, -0.0020,  ...,  0.0017,  0.0032,  0.0016],
                       [ 0.0031,  0.0027,  0.0012,  ...,  0.0010,  0.0029, -0.0032],
                       [ 0.0004,  0.0031,  0.0018,  ..., -0.0028,  0.0019, -0.0031]])),
              ('layer1.bias',
               tensor([ 2.0156e-03, -1.5195e-03, -2.5822e-03, -2.8514e-03, -1.1796e-03,
                        9.1030e-04,  1.6223e-03, -4.8537e-04,  1.1740e-03, -3.7497e-04,
                       -2.4401e-03,  1.6131e-04, -2.2212e-03, -1.7369e-03, -2.7221e-03,
                        1.0225e-03, -3.4781e-03,  6.6799e-04,  1.2151e-03,  1.3230e-03,
                       

In [5]:
torch.save([[model.layer1.weight.tolist(), model.layer1.bias.tolist()],
            [model.layer2.weight.tolist(), model.layer2.bias.tolist()],
            [model.layer3.weight.tolist(), model.layer3.bias.tolist()]],
           "halfkp_wb.pt")