In [15]:
import torch
from torch import nn

from src.data_loading import load_dataset_1d

from src.core import train, test

In [16]:
# Seed torch's global RNG so the weight initialisation in Model() (and anything
# else in this notebook that draws from torch's default generator) is
# reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Number of samples per batch handed to load_dataset_1d and to train/test.
BATCH_SIZE = 64

In [17]:
# Main data source. Judging by the shape probe below, each item is a batch
# whose first element is a (BATCH_SIZE, 512) feature tensor, so len() counts
# batches rather than rows.
GAMES_CSV = 'games_dataset.csv'
dataset = load_dataset_1d(GAMES_CSV, BATCH_SIZE)
len(dataset)

9117

In [18]:
# Sanity-check the feature tensor of the first batch.
first_batch = dataset[0]
first_batch[0].shape

torch.Size([64, 512])

In [19]:
# Number of leading batches of `dataset` used for training; the remainder is
# held out for testing.
DATASET_SIZE = 8_000

In [20]:
# Contiguous split with no shuffling: the first DATASET_SIZE batches train and
# the rest test. NOTE(review): if games_dataset.csv is ordered (e.g. by game),
# the tail may not be representative -- confirm.
train_dataset, test_dataset = dataset[:DATASET_SIZE], dataset[DATASET_SIZE:]
len(train_dataset), len(test_dataset)

(8000, 1117)

In [21]:
# Separate source (puzzle_dataset.csv) used to check how the model fares on
# data drawn from a different file than the one it was trained on.
PUZZLE_CSV = 'puzzle_dataset.csv'
validation_dataset = load_dataset_1d(PUZZLE_CSV, BATCH_SIZE)
len(validation_dataset)

7318

In [22]:
def accuracy(out, truth):
    """Element-wise absolute error between predictions and targets.

    Despite its name this is an error metric (lower is better): it returns
    ``|truth - out|`` per element and leaves any reduction to the caller.

    Args:
        out: prediction tensor from the model.
        truth: target tensor.

    Returns:
        Tensor of absolute errors. If ``out`` and ``truth`` differ in shape but
        hold the same number of elements, ``out`` is first reshaped to
        ``truth.shape``, so the result has the shape of ``truth``.
    """
    # A single-output head (nn.Linear(..., 1)) yields (N, 1) predictions. If
    # the targets are (N,), `truth - out` would silently broadcast to (N, N)
    # and compare every prediction against every target. Align the shapes
    # first whenever the element counts agree.
    if out.shape != truth.shape and out.numel() == truth.numel():
        out = out.reshape(truth.shape)
    return torch.abs(truth - out)

In [23]:
x0 = 512  # input width; matches the 512 features per row in the loaded batches
x1 = 2 ** 10
x2 = 2 ** 5
x3 = 2 ** 3


class Model(nn.Module):
    """Single-output MLP: x0 -> x1 -> x2 -> x3 -> 1 with ReLU between layers.

    The four ``nn.Linear`` layers are registered twice: as ``layer1``..``layer4``
    and as entries of ``self.classifier``. ``parameters()`` de-duplicates the
    shared tensors, but ``state_dict()`` contains both key sets (for example
    ``layer1.weight`` and ``classifier.0.weight``). Keep this layout so that
    state dicts saved from this model (such as ``model.pt``) load back.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(x0, x1)
        self.layer2 = nn.Linear(x1, x2)
        self.layer3 = nn.Linear(x2, x3)
        self.layer4 = nn.Linear(x3, 1)
        self.classifier = nn.Sequential(
            self.layer1,
            nn.ReLU(),
            self.layer2,
            nn.ReLU(),
            self.layer3,
            nn.ReLU(),
            self.layer4,
        )

    def forward(self, X):
        """Map an (..., x0) input tensor to an (..., 1) prediction."""
        # Call the module rather than `.forward()` so that hooks registered on
        # `classifier` itself actually run.
        return self.classifier(X)


model = Model()

In [24]:
# Plain SGD on mean-squared error; `accuracy` is passed through as the metric.
LEARNING_RATE = 0.05
N_EPOCHS = 100
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.classifier.parameters(), lr=LEARNING_RATE)
train(train_dataset, BATCH_SIZE, model, criterion, optimizer, accuracy, N_EPOCHS)

Dataset size: 8000
Epoch [1/100], train_loss: 6.418233809739351, train_accuracy: 1.8785780140310526, time: 6.42s
Epoch [2/100], train_loss: 3.891790317773819, train_accuracy: 1.4449316529035567, time: 6.624s
Epoch [3/100], train_loss: 2.914921775408089, train_accuracy: 1.2474415203258395, time: 6.497s
Epoch [4/100], train_loss: 2.37084598056972, train_accuracy: 1.1235895558223128, time: 6.675s
Epoch [5/100], train_loss: 2.0305103674158453, train_accuracy: 1.0380273125320674, time: 6.445s
Epoch [6/100], train_loss: 1.7980657447054982, train_accuracy: 0.9782051668986678, time: 6.294s
Epoch [7/100], train_loss: 1.6116149662658572, train_accuracy: 0.9262188586443663, time: 6.541s
Epoch [8/100], train_loss: 1.4639298454783858, train_accuracy: 0.8833762713745237, time: 6.657s
Epoch [9/100], train_loss: 1.3538724211677908, train_accuracy: 0.8505016731396318, time: 6.619s
Epoch [10/100], train_loss: 1.2659078379422426, train_accuracy: 0.8229015044383705, time: 6.658s
Epoch [11/100], train_loss

In [25]:
# Score the held-out batches (index DATASET_SIZE onward of games_dataset.csv).
heldout_loss_fn = nn.MSELoss()
test(test_dataset, BATCH_SIZE, model, heldout_loss_fn, accuracy)

Dataset size: 1117
test_loss: 2.2645393526756883, test_accuracy: 1.0157604236205775, time: 0.337s


In [26]:
# Re-score the training batches; the gap to the held-out score shows how much
# the model is over-fitting.
train_loss_fn = nn.MSELoss()
test(train_dataset, BATCH_SIZE, model, train_loss_fn, accuracy)

Dataset size: 8000
test_loss: 0.5259334992729128, test_accuracy: 0.5336011244058609, time: 2.234s


In [27]:
# Score puzzle_dataset.csv, a different source file from the training data.
validation_loss_fn = nn.MSELoss()
test(validation_dataset, BATCH_SIZE, model, validation_loss_fn, accuracy)

Dataset size: 7318
test_loss: 22.470364613101665, test_accuracy: 3.88891692020164, time: 2.032s


In [28]:
# Persist the weights. dict() turns the OrderedDict returned by state_dict()
# into a plain dict before it is pickled to model.pt.
weights = dict(model.state_dict())
torch.save(weights, 'model.pt')