In [34]:
import torch
from torch import nn

from data_loading import load_dataset_1d

from core import train, test

In [35]:
# Seed torch's default RNG up front so weight initialisation (and anything else that
# draws from torch's global generator) is reproducible under Restart & Run All.
# NOTE(review): this seeds torch only; if core / data_loading use Python's `random`
# or NumPy for shuffling, seed those as well — TODO check.
SEED = 42
torch.manual_seed(SEED)

# Samples per batch; passed to load_dataset_1d and to train/test below.
BATCH_SIZE = 64

In [36]:
# Load the games data. Judging by the shape check in the next cell, each element is
# one pre-built batch whose first item is a (BATCH_SIZE, 512) tensor — presumably an
# (inputs, targets) pair; TODO confirm in data_loading.load_dataset_1d.
dataset = load_dataset_1d(
    'games_dataset.csv',
    BATCH_SIZE,
)
len(dataset)  # number of batches, not samples

9117

In [37]:
dataset[0][0].size()  # first item of the first batch (presumably the inputs): (BATCH_SIZE, 512)

torch.Size([64, 512])

In [38]:
# Number of leading *batches* of `dataset` used for training; the remaining batches
# form the held-out test split. Each element of `dataset` is a whole batch, so this
# is not a sample count.
DATASET_SIZE = 8_000

In [39]:
# Contiguous, unshuffled split: the first DATASET_SIZE batches train the model and the
# tail is held out for testing.
# NOTE(review): if games_dataset.csv is ordered (e.g. grouped by game), neighbouring
# batches may be correlated and the tail may not be representative — consider
# shuffling before splitting.
train_dataset, test_dataset = dataset[:DATASET_SIZE], dataset[DATASET_SIZE:]
(len(train_dataset), len(test_dataset))

(8000, 1117)

In [40]:
# Loaded from a different file (puzzle_dataset.csv) than the training data
# (games_dataset.csv), so it measures transfer to another data source rather than
# serving as a held-out split of the training distribution.
validation_dataset = load_dataset_1d(
    'puzzle_dataset.csv',
    BATCH_SIZE,
)
len(validation_dataset)  # number of batches

727

In [41]:
def accuracy(out, truth):
    """Element-wise absolute error ``|truth - out|`` between predictions and targets.

    Despite the name this is an error measure, not an accuracy: lower is better and
    0 means an exact match. No reduction is applied; the result has the broadcast
    shape of ``out`` and ``truth``. Any averaging is presumably done by the caller
    (core.train / core.test) — TODO confirm.
    """
    error = truth - out
    return error.abs()

In [42]:
class Model(nn.Module):
    """Thin ``nn.Module`` wrapper around an ``nn.Sequential`` held in ``self.classifier``.

    All arguments are forwarded unchanged to ``nn.Sequential`` (pass the layers in
    order). Keeping the layers under the ``classifier`` attribute fixes the
    ``state_dict`` key prefix (``classifier.0.weight``, ...), so checkpoints saved
    from this class stay loadable.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.classifier = nn.Sequential(*args, **kwargs)

    def forward(self, X):
        """Run ``X`` through the wrapped layers and return the result."""
        # Call the submodule rather than its ``.forward`` so ``nn.Module.__call__``
        # runs, including any hooks registered on ``self.classifier``.
        return self.classifier(X)

# Hidden-layer widths of the MLP.
x1 = 1024  # 2**10
x2 = 32    # 2**5
x3 = 8     # 2**3

# 512 input features -> x1 -> x2 -> x3 -> 1 output, with ReLU between layers and no
# activation on the single output unit (trained as a regression with nn.MSELoss).
model = Model(
    nn.Linear(512, x1), nn.ReLU(),
    nn.Linear(x1, x2), nn.ReLU(),
    nn.Linear(x2, x3), nn.ReLU(),
    nn.Linear(x3, 1),
)

In [43]:
# Fit the MLP on the training split with plain SGD (lr=0.02) on mean-squared error,
# reporting the absolute-error `accuracy` function alongside the loss.
# The final positional argument is presumably the epoch count — TODO confirm against
# core.train.
# NOTE(review): the log recorded below reads "Epoch [i/200]" and is cut off mid-line,
# so it appears to come from an earlier run with a different epoch count; re-run this
# cell to refresh it.
train(
    train_dataset,
    BATCH_SIZE,
    model,
    nn.MSELoss(),
    torch.optim.SGD(model.parameters(), lr=0.02),  # all of Model's parameters live in model.classifier
    accuracy,
    500,
)

Dataset size: 8000
Epoch [1/200], train_loss: 5.879448705345392, train_accuracy: 1.7779126019477844, time: 6.929s
Epoch [2/200], train_loss: 3.649482143178582, train_accuracy: 1.3867391659244894, time: 6.234s
Epoch [3/200], train_loss: 2.7160354040861128, train_accuracy: 1.1928198103308678, time: 5.868s
Epoch [4/200], train_loss: 2.210546361260116, train_accuracy: 1.076112511008978, time: 5.856s
Epoch [5/200], train_loss: 1.8848249880373478, train_accuracy: 0.9942407237738371, time: 5.882s
Epoch [6/200], train_loss: 1.6625629074797035, train_accuracy: 0.9343729684948922, time: 6.339s
Epoch [7/200], train_loss: 1.4993363173529506, train_accuracy: 0.8875179653987288, time: 6.635s
Epoch [8/200], train_loss: 1.3680359908789397, train_accuracy: 0.8492243732586503, time: 6.547s
Epoch [9/200], train_loss: 1.2677803937532008, train_accuracy: 0.8182686973512173, time: 8.104s
Epoch [10/200], train_loss: 1.180795012138784, train_accuracy: 0.7904159463234246, time: 7.715s
Epoch [11/200], train_los

In [44]:
# Held-out tail of games_dataset.csv (the batches after DATASET_SIZE).
test(
    test_dataset,
    BATCH_SIZE,
    model,
    nn.MSELoss(),
    accuracy,
)

Dataset size: 1117
test_loss: 2.458127800652035, test_accuracy: 1.0434147403404557, time: 0.289s


In [45]:
# Same metrics on the training split. A training loss far below the held-out loss
# indicates overfitting (in the recorded run: ~0.40 here vs ~2.46 on test_dataset).
test(
    train_dataset,
    BATCH_SIZE,
    model,
    nn.MSELoss(),
    accuracy,
)

Dataset size: 8000
test_loss: 0.40339428310841324, test_accuracy: 0.4606853594407439, time: 2.015s


In [46]:
# Evaluation on puzzle_dataset.csv, a different source than the training data, so
# this measures transfer across datasets rather than in-distribution generalisation.
test(
    validation_dataset,
    BATCH_SIZE,
    model,
    nn.MSELoss(),
    accuracy,
)

Dataset size: 727
test_loss: 23.974502469027385, test_accuracy: 3.9960329030862045, time: 0.173s


In [48]:
# Persist only the learned parameters (the state_dict), not the module object.
# To reload: rebuild the same Model(...) layer stack, then
# model.load_state_dict(torch.load('model.pt')).
torch.save(obj=model.state_dict(), f='model.pt')