In [1]:
import smart_loading as ldr

import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
class baseModel(nn.Module):
    def __init__(self, inShape: tuple):
        super().__init__()
        h1 = 150
        h2 = 25
        h3 = 2
        self.lin1 = nn.Linear(inShape, h1)
        self.lin2 = nn.Linear(h1, h2)
        self.lin3 = nn.Linear(h2, h3)
        self.dp = nn.Dropout(.20)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=0)
        self.loss = nn.CrossEntropyLoss()
        self.optim = optim.SGD(self.parameters(), lr=.01)
        self.optim.zero_grad()
        return
    
    def forward(self, x):
        y = self.lin1(x)
        y = self.dp(y)
        y = self.tanh(y)
        y = self.lin2(y)
        y = self.dp(y)
        y = self.tanh(y)
        y = self.lin3(y)
        y = self.softmax(y)
        y = y[None, :]
        return y

    def step(self, y, yhat):
        loss = self.loss(yhat, y)
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        return loss.detach().cpu().numpy()

In [3]:
# Use the data processed by very sparse matrix projections?
useSparse = False
epochs = 100
dtype = torch.float32

# NOTE: The data generation scripts can be found in utils/gen_data.ipynb
if useSparse:
    dataShape = 4350
    sm = ldr.SmartLoader(3, load_type='restData_meanTime_spr.pt', data_shape=dataShape, dtype=dtype)
else:
    dataShape = 128*128*49
    sm = ldr.SmartLoader(5, load_type='restData_meanTime_reg.pt', data_shape=dataShape, dtype=dtype)


In [4]:
model = baseModel(dataShape).cuda()
sm.run(model, epochs)

Starting fold 0
Finished loading training data.
((0, 1)): Running loss: 0.7015975206159055
((0, 2)): Running loss: 0.7126651559956372
((0, 3)): Running loss: 0.6974553624168038
((0, 4)): Running loss: 0.7021387284621596
((0, 5)): Running loss: 0.6918039773590863
((0, 6)): Running loss: 0.6934313192032278
((0, 7)): Running loss: 0.6884957859292626
((0, 8)): Running loss: 0.6852442733943462
((0, 9)): Running loss: 0.6728845592588186
((0, 10)): Running loss: 0.6945344218984246
((0, 11)): Running loss: 0.6882271063514054
((0, 12)): Running loss: 0.6991545609198511
((0, 13)): Running loss: 0.6919162925332785
((0, 14)): Running loss: 0.6970834960229695
((0, 15)): Running loss: 0.7060211875941604
((0, 16)): Running loss: 0.6978462710976601
((0, 17)): Running loss: 0.7001422643661499
((0, 18)): Running loss: 0.6919081793166697
((0, 19)): Running loss: 0.6856463910080492
((0, 20)): Running loss: 0.6855084931012243
((0, 21)): Running loss: 0.6846971604973078
((0, 22)): Running loss: 0.6933079916