In [1]:
import torch
from torch import nn

from src.data_loading import wdl_to_cp
from src.data_loading_3d import Dataset3D

from src.core import train, test
from src.patches import TRAIN_DATASET_PATCH, TEST_DATASET_PATCH

In [2]:
# Run configuration.
# SEED fixes torch's global RNG so weight initialisation in the Model cell (and any
# torch-side sampling) is reproducible under Restart & Run All.
# NOTE(review): if Dataset3D or train() shuffle via `random` or numpy, seed those too — TODO confirm.
SEED = 42
torch.manual_seed(SEED);  # trailing ';' suppresses the Generator repr in the notebook
BATCH_SIZE = 64

In [3]:
# Build the train and test splits (same Dataset3D wrapper, same batch size), then show
# their lengths as the cell output.
train_dataset, test_dataset = (
    Dataset3D(patch, BATCH_SIZE) for patch in (TRAIN_DATASET_PATCH, TEST_DATASET_PATCH)
)
tuple(map(len, (train_dataset, test_dataset)))

(13149, 3287)

In [4]:
def accuracy(out, truth):
    """Error metric: the absolute prediction/target gap mapped through ``wdl_to_cp``.

    Despite the name, this behaves as an *error*: in the logs the value falls as the
    MSE loss falls, so lower is better.

    Args:
        out: model predictions (tensor).
        truth: targets; must broadcast against ``out``.

    Returns:
        ``wdl_to_cp(|truth - out|)``. Presumably a per-element tensor that the
        ``test``/``train`` helpers reduce to the logged scalar — TODO confirm.

    NOTE(review): the centipawn conversion is applied to the *difference*. If
    ``wdl_to_cp`` is nonlinear, this differs from the centipawn error
    ``|wdl_to_cp(truth) - wdl_to_cp(out)|`` — confirm which one is intended.
    NOTE(review): the model ends in ``nn.Linear(16, 1)``, so ``out`` has a trailing
    dim of 1. If ``truth`` is 1-D, ``truth - out`` broadcasts to (batch, batch) —
    confirm the target shape.
    """
    gap = (truth - out).abs()
    return wdl_to_cp(gap)

In [7]:
class Model(nn.Module):
    """Thin ``nn.Module`` wrapper around an ``nn.Sequential`` layer stack.

    The constructor arguments are forwarded unchanged to ``nn.Sequential``. The
    resulting stack is stored as ``self.classifier`` and is this module's only
    submodule, so ``self.classifier.parameters()`` is every parameter of the model.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # kwargs are still forwarded to keep the original signature.
        # NOTE(review): nn.Sequential is documented to take positional modules (or a
        # single OrderedDict), so keyword arguments will presumably raise TypeError.
        self.classifier = nn.Sequential(*args, **kwargs)

    def forward(self, X):
        """Run ``X`` through the layer stack and return the result."""
        # Invoke the module (not ``.forward``) so that nn.Module.__call__ runs any
        # forward pre-/post-hooks registered on ``self.classifier``.
        return self.classifier(X)


# Convolutional feature extractor followed by a small MLP head that emits one scalar
# per sample.
# NOTE(review): the spatial-size comments assume each sample is 12x8x8 (12 input planes
# on an 8x8 board) — TODO confirm against Dataset3D. Under that assumption the final
# feature map is 64x1x1, which is what makes Flatten -> nn.Linear(64, 16) line up.
# Layer order is significant: it fixes the order in which weights are initialised.
model = Model(
    nn.Conv2d(in_channels=12, out_channels=32, kernel_size=4, stride=2, padding=4),    # 8 -> 7
    nn.ReLU(),
    nn.Conv2d(in_channels=32, out_channels=128, kernel_size=6, stride=2, padding=4),   # 7 -> 5
    nn.ReLU(),
    nn.Conv2d(in_channels=128, out_channels=512, kernel_size=6, stride=2, padding=2),  # 5 -> 2
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=1),                                             # 2 -> 1
    nn.Conv2d(in_channels=512, out_channels=64, kernel_size=1, stride=2),              # 1 -> 1
    nn.ReLU(),
    nn.Flatten(),                                                                      # -> 64 features
    nn.Linear(in_features=64, out_features=16),
    nn.ReLU(),
    nn.Linear(in_features=16, out_features=1),
)

In [8]:
# Baseline before any training: MSE loss and the cp error metric on the training split.
criterion = nn.MSELoss()
test(train_dataset, model, criterion, accuracy)

Dataset size: 13149
test_loss: 0.7366759070938318, test_accuracy: 289.75068232398297, time: 5.477s


In [9]:
# Baseline before any training: MSE loss and the cp error metric on the test split.
criterion = nn.MSELoss()
test(test_dataset, model, criterion, accuracy)

Dataset size: 3287
test_loss: 0.7361266386610489, test_accuracy: 289.39922680693775, time: 1.495s


In [10]:
# Train for 60 epochs with plain SGD (lr=0.02) on MSE loss, reporting the cp error metric.
# model.classifier.parameters() covers every trainable parameter, because the Sequential
# stack is Model's only submodule.
# NOTE(review): test_dataset is also handed to train() (presumably the `test:` column of
# the per-epoch log). Make sure it is not used for early stopping or model selection, or
# the final test numbers will be optimistic.
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.classifier.parameters(), lr=0.02)
train(train_dataset, test_dataset, model, criterion, optimizer, accuracy, 60)

Dataset size: 13149
Epoch [1/60],  train: 0.68134,275.2    val: 0.73586, 286.7  test: 0.73612,286.9,  time: 28.643s
Epoch [2/60],  train: 0.46506,223.3    val: 0.69375, 278.5  test: 0.69479,278.9,  time: 27.628s
Epoch [3/60],  train: 0.42095,208.1    val: 0.66198, 272.0  test: 0.66496,272.8,  time: 25.408s
Epoch [4/60],  train: 0.40236,201.3    val: 0.63894, 266.4  test: 0.64436,267.7,  time: 25.201s
Epoch [5/60],  train: 0.39185,197.5    val: 0.62501, 263.6  test: 0.63303,265.5,  time: 26.629s
Epoch [6/60],  train: 0.37892,192.9    val: 0.62963, 263.4  test: 0.64078,266.0,  time: 29.283s
Epoch [7/60],  train: 0.36740,188.8    val: 0.62271, 261.0  test: 0.63885,264.6,  time: 26.754s
Epoch [8/60],  train: 0.35519,184.7    val: 0.61546, 258.0  test: 0.63713,262.9,  time: 27.686s
Epoch [9/60],  train: 0.34199,180.2    val: 0.59585, 252.6  test: 0.62485,259.2,  time: 26.643s
Epoch [10/60],  train: 0.32743,175.3    val: 0.57846, 247.3  test: 0.61570,256.0,  time: 25.85s
Epoch [11/60],  trai

In [11]:
# Post-training evaluation on the test split.
# NOTE(review): the training log above stops partway through epoch 11, so these weights
# presumably come from an interrupted run — re-run training to completion before
# trusting these numbers.
criterion = nn.MSELoss()
test(test_dataset, model, criterion, accuracy)

Dataset size: 3287
test_loss: 0.6624998499420881, test_accuracy: 247.70778778791066, time: 1.354s


In [12]:
# Post-training evaluation on the training split. Compare with the test-split result to
# gauge the train/test gap (overfitting).
criterion = nn.MSELoss()
test(train_dataset, model, criterion, accuracy)

Dataset size: 13149
test_loss: 0.4572316542652681, test_accuracy: 192.2754278035751, time: 5.268s
