In [1]:
%load_ext autoreload
%autoreload 2
from helpers import *
from networks import *
from tqdm import tqdm
import torch.optim as optim
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


### Define some tunable variables

In [2]:
batch_size = 5 # > 10 is bad
lr = 0.001
mom = 0.5
epochs = 50
networ_type = 'CNN1' #FNN CNN1 CNN2
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [3]:
device

device(type='mps')

### Prepare data for training

In [4]:
train_x, train_y, test_x, test_y = init_data("train.csv", 0.7)
train = MyDataset(train_x, train_y)
test = MyDataset(test_x, test_y)
train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False)

In [5]:
num_features = train.x.shape[1]
num_features

305

### Build the neural network

In [6]:
net = eval(networ_type)()
net.to(device)
net

CNN1(
  (bn0): BatchNorm1d(305, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (drop0): Dropout(p=0.25, inplace=False)
  (prep): Linear(in_features=305, out_features=1000, bias=True)
  (bn1): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv1): Conv1d(50, 100, kernel_size=(5,), stride=(1,), padding=(2,), groups=50)
  (pool1): AdaptiveAvgPool1d(output_size=10)
  (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (drop2): Dropout(p=0.25, inplace=False)
  (conv2): Conv1d(100, 100, kernel_size=(3,), stride=(1,), padding=(1,))
  (bn3): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (drop3): Dropout(p=0.25, inplace=False)
  (conv3): Conv1d(100, 100, kernel_size=(3,), stride=(1,), padding=(1,))
  (bn4): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), groups=1

### Train

In [7]:
criterion = nn.MSELoss()
# criterion = nn.L1Loss()
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=mom)

In [8]:
los = []
tlos = []
for epoch in range(epochs):
    tot_loss = 0.0
    net.train()
    for i, (inputs, labels) in enumerate(tqdm(train_loader), 0):
        inputs, labels = inputs.to(device), labels.to(device) # Move to GPU / CPU
        optimizer.zero_grad() # zero the parameter gradients

        labels = labels.unsqueeze(-1) # Batch color value
        # forward + backward + optimize
        # inputs = inputs.unsqueeze(1)
        outputs = net(inputs)
        # print(outputs.shape, labels.shape)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # print statistics
        tot_loss += loss.item()
    # print(tot_loss)
    ttot_loss = 0.0
    net.eval()
    for i, (inputs, labels) in enumerate(test_loader, 0):
        inputs, labels = inputs.to(device), labels.to(device) # Move to GPU / CPU
        with torch.no_grad():
            labels = labels.unsqueeze(-1) # Batch color value
            # inputs = inputs.unsqueeze(1) #.to_sparse()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
        ttot_loss += loss.item()
    los.append(tot_loss)
    tlos.append(ttot_loss)
    print(epoch, tot_loss, ttot_loss)

100%|██████████| 205/205 [00:02<00:00, 76.67it/s] 


0 149.00148969516158 1314131.1241624206


100%|██████████| 205/205 [00:01<00:00, 122.77it/s]


1 120.93343037553132 2157674.2691277955


100%|██████████| 205/205 [00:01<00:00, 122.87it/s]


2 102.01324370177463 686827.79284551


100%|██████████| 205/205 [00:01<00:00, 122.64it/s]


3 92.96643630415201 8987558.079794832


100%|██████████| 205/205 [00:01<00:00, 124.19it/s]


4 105.4042049460113 4164088.1633860543


100%|██████████| 205/205 [00:01<00:00, 124.28it/s]


5 95.75618972443044 2560293.6720994716


100%|██████████| 205/205 [00:01<00:00, 122.82it/s]


6 98.72612855862826 5063362.473621556


100%|██████████| 205/205 [00:01<00:00, 123.58it/s]


7 102.34305085078813 469668.96289390326


100%|██████████| 205/205 [00:01<00:00, 122.98it/s]


8 104.50453153438866 1693405.4935083129


100%|██████████| 205/205 [00:01<00:00, 122.68it/s]


9 100.64125384553336 737471.9906241726


100%|██████████| 205/205 [00:01<00:00, 123.59it/s]


10 95.60716384090483 2451733.1202833913


100%|██████████| 205/205 [00:01<00:00, 123.55it/s]


11 100.23782822862267 2352465.433681084


100%|██████████| 205/205 [00:01<00:00, 122.48it/s]


12 89.01783407665789 3763399.2486348078


100%|██████████| 205/205 [00:01<00:00, 123.70it/s]


13 92.1628926891135 73812.47921886854


100%|██████████| 205/205 [00:01<00:00, 122.86it/s]


14 96.93117791041732 2692365.5154578015


100%|██████████| 205/205 [00:01<00:00, 122.36it/s]


15 95.79325799737126 1108123.650194548


100%|██████████| 205/205 [00:01<00:00, 123.16it/s]


16 91.25780818238854 165289.03132419847


100%|██████████| 205/205 [00:01<00:00, 123.25it/s]


17 98.41072463802993 284061.84324703924


100%|██████████| 205/205 [00:01<00:00, 123.27it/s]


18 88.25742550333962 79484.12255770527


100%|██████████| 205/205 [00:01<00:00, 122.77it/s]


19 93.34357580449432 196438.22815392353


100%|██████████| 205/205 [00:01<00:00, 122.29it/s]


20 99.96492861025035 395506.1931642946


100%|██████████| 205/205 [00:01<00:00, 122.42it/s]


21 89.44123843125999 9316261.311684005


100%|██████████| 205/205 [00:01<00:00, 123.02it/s]


22 92.29922773689032 1700841.947327124


100%|██████████| 205/205 [00:01<00:00, 122.39it/s]


23 93.62737448234111 163581.13250433095


100%|██████████| 205/205 [00:01<00:00, 122.49it/s]


24 85.39303106814623 156060.74892094173


100%|██████████| 205/205 [00:01<00:00, 122.51it/s]


25 85.71652023307979 115609.5518753673


100%|██████████| 205/205 [00:01<00:00, 122.44it/s]


26 89.60534842731431 461600.0971460715


100%|██████████| 205/205 [00:01<00:00, 122.32it/s]


27 83.47436611354351 184610.92947934242


100%|██████████| 205/205 [00:01<00:00, 123.87it/s]


28 93.21495229937136 534438.6292705657


100%|██████████| 205/205 [00:01<00:00, 123.09it/s]


29 93.57365889567882 177676.10436145216


100%|██████████| 205/205 [00:01<00:00, 123.34it/s]


30 83.75260004587471 1106296.5445414325


100%|██████████| 205/205 [00:01<00:00, 123.91it/s]


31 90.45272192265838 716399.200844937


100%|██████████| 205/205 [00:01<00:00, 123.44it/s]


32 93.42153437063098 247202.62368221348


100%|██████████| 205/205 [00:01<00:00, 124.04it/s]


33 85.39053469523787 2011875.6203706246


100%|██████████| 205/205 [00:01<00:00, 124.10it/s]


34 82.49681625887752 2090478.1672078818


100%|██████████| 205/205 [00:01<00:00, 123.13it/s]


35 82.7358672885457 1588657.5804014765


100%|██████████| 205/205 [00:01<00:00, 123.63it/s]


36 88.55063721351326 1529705.7558947913


100%|██████████| 205/205 [00:01<00:00, 123.69it/s]


37 91.17138454277301 78350.86894639954


100%|██████████| 205/205 [00:01<00:00, 122.89it/s]


38 90.0650410298258 810795.583864402


100%|██████████| 205/205 [00:01<00:00, 123.78it/s]


39 79.46571178734303 307266.31821385585


100%|██████████| 205/205 [00:01<00:00, 123.56it/s]


40 89.81581134721637 575554.4434964135


100%|██████████| 205/205 [00:01<00:00, 122.98it/s]


41 81.72183162532747 1097232.776591327


100%|██████████| 205/205 [00:01<00:00, 123.52it/s]


42 81.60252719372511 169918.85240530595


100%|██████████| 205/205 [00:01<00:00, 123.75it/s]


43 91.79725949955173 200248.3773618153


100%|██████████| 205/205 [00:01<00:00, 123.05it/s]


44 83.83840093202889 2632714.0893380307


100%|██████████| 205/205 [00:21<00:00,  9.61it/s] 


45 82.5721289254725 689007.7887831293


100%|██████████| 205/205 [00:01<00:00, 118.19it/s]


46 83.35199515148997 416704.81401626114


100%|██████████| 205/205 [00:01<00:00, 123.51it/s]


47 86.40182356350124 82913.11425922066


100%|██████████| 205/205 [00:01<00:00, 122.85it/s]


48 86.03145030263113 477491.4327265993


100%|██████████| 205/205 [00:01<00:00, 123.72it/s]


49 77.0422130394727 1078744.2192449383


In [1]:
plt.plot(los, label="train")
plt.plot(tlos, label="test")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()

NameError: name 'plt' is not defined

In [10]:
net.eval()
all_loss_train = []
with torch.no_grad():
    tot_loss = 0
    for i, (inputs, labels) in enumerate(tqdm(train_loader), 0):
        inputs, labels = inputs.to(device), labels.to(device)
        inputs = inputs.unsqueeze(1) #.to_sparse()
        outputs = net(inputs).ravel()
        all_loss_train.extend((abs(outputs - labels) / abs(labels)).tolist())
        

  0%|          | 0/205 [00:00<?, ?it/s]


RuntimeError: running_mean should contain 1 elements not 305

In [None]:
print(test.x.shape, test.x.device)
print(test.y.shape, test.y.device)

In [None]:
all_loss_test = []
with torch.no_grad():
    tot_loss = 0
    for i, (inputs, labels) in enumerate(tqdm(test_loader), 0):
        inputs, labels = inputs.to(device), labels.to(device)
        inputs = inputs.unsqueeze(1)
        outputs = net(inputs).ravel()
        all_loss_test.extend((abs(outputs - labels) / abs(labels)).tolist())
    

In [None]:
all_loss_train = np.array(all_loss_train)
all_loss_test = np.array(all_loss_test)


In [None]:
len(np.where(all_loss_train < 1)[0]) / len(all_loss_train), len(np.where(all_loss_test < 1)[0]) / len(all_loss_test)

In [None]:
good_pred_train = all_loss_train[np.where(all_loss_train < 1)]
good_pred_test = all_loss_test[np.where(all_loss_test < 1)]
good_pred_train.mean(), good_pred_train.std(), good_pred_test.mean(), good_pred_test.std()

In [None]:
bins = range(101)
plt.hist([good_pred_train*100, good_pred_test*100], bins=bins, alpha=0.5, label=["train", "test"])
plt.legend()
plt.show()