# TODO:

1. [x] Load the dataset into a tensor and convert it to float32
    - [x] Apply normalization
2. [x] Implement DataParallel training
    - [x] Increase the minibatch size to 128 (32 per device across 4 GPUs)
3. [ ] Try training and benchmark the speed
4. [ ] Fix the simulation script so it produces the correct labels, then retrain

GPU datasheet (we have the SXM version): https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf

TODO: switch from DataParallel to DistributedDataParallel at some point, and move everything from the notebook into a training script.

Links:
https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices


In [None]:
import numpy as np
import scipy.io
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset, DataLoader
from pathlib import Path
import h5py
import numpy as np
import dask.array as da
from torchvision.transforms import Normalize
from sklearn.model_selection import train_test_split
import sklearn
import pandas as pd

In [None]:
# Print the name of every visible CUDA device, one per line.
for device_idx in range(torch.cuda.device_count()):
    device_name = torch.cuda.get_device_name(device_idx)
    print(device_name)

# use devices 0-3

#### Load MATLAB data

In [39]:
# feats = h5py.File('samplesChirp.mat', 'r')
# labels = h5py.File('labelsChirp.mat', 'r')

# Features and labels live in separate MATLAB files; pull out the arrays.
feats = scipy.io.loadmat('output/feats4T_.1R.mat')['features']
labels = scipy.io.loadmat('output/labels4T_.1R.mat')['lp']

# Cast to float32 and reverse the axis order (.T) so samples come first.
labels = labels.astype('float32').T
feats = feats.astype('float32').T

# Merge every axis between the first and the last into one, then insert a
# singleton channel axis at position 1 for the Conv2d input.
n_samples, n_last = feats.shape[0], feats.shape[-1]
feats = feats.reshape((n_samples, -1, n_last))[:, None, :, :]
feats.shape, feats.dtype



((42581, 1, 16, 48), dtype('float32'))

In [None]:
# feats['samples'], labels['labels']['position']

In [None]:
# feats_da = da.from_array(feats['samples']).astype('float32') # cast to float32
# feats_da = feats_da[:,None,:,:] # add channel dimension
# feats_da

In [None]:
# labels_da = da.from_array(labels['labels']['position']).astype('float32')
# labels_da

In [40]:
# torch.as_tensor replaces the legacy torch.Tensor(ndarray) constructor, which
# always copies. Both arrays are already float32, so as_tensor reuses their
# NumPy buffers instead of duplicating the whole dataset in memory.
X = torch.as_tensor(feats, dtype=torch.float32)
Y = torch.as_tensor(labels, dtype=torch.float32)
X.shape, Y.shape, X.dtype, Y.dtype

(torch.Size([42581, 1, 16, 48]),
 torch.Size([42581, 3]),
 torch.float32,
 torch.float32)

In [41]:
# Hold out 20% of the samples for testing; the fixed seed makes the split reproducible.
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y,
    random_state=42,
    test_size=0.2,
)

In [42]:
# Scalar mean/std over the training split only, so the test split does not
# influence the normalization statistics.
X_train_mean, X_train_std = X_train.mean(), X_train.std()
X_train_mean, X_train_std

(tensor(0.0168), tensor(0.0659))

NORMALIZATION:
1. Create a custom Dataset class, modeled on TensorDataset, that applies a normalization transform if one is provided.
2. Create the train and test datasets, passing in X_train_mean and X_train_std (the training-set statistics are used for both splits).

In [43]:
class CustomTensorDataset(Dataset):
    """TensorDataset-like container that can transform the feature samples.

    Args:
        tensors: sequence of array-likes indexed along their first axis; the
            first element holds the features, the second the targets. All
            elements must have the same first-dimension size.
        transforms: optional callable applied to each feature sample when it
            is fetched; targets are returned unchanged.
    """

    def __init__(self, tensors, transforms=None):
        # Every tensor must provide exactly one entry per sample.
        assert all(t.shape[0] == tensors[0].shape[0] for t in tensors)
        self.tensors = tensors
        self.transforms = transforms

    def __getitem__(self, index):
        features, targets = self.tensors[0], self.tensors[1]
        sample = features[index]
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample, targets[index]

    def __len__(self):
        return self.tensors[0].shape[0]

In [44]:
# Both splits are normalized with the *training* statistics.
normalize = Normalize(X_train_mean, X_train_std)
train_dataset = CustomTensorDataset([X_train, Y_train], normalize)
test_dataset = CustomTensorDataset([X_test, Y_test], normalize)
train_dataset, test_dataset

(<__main__.CustomTensorDataset at 0x7fefd056a820>,
 <__main__.CustomTensorDataset at 0x7fefd056a910>)

### Sample shape for the spectrogram dataset: (minibatch_size, 1, 300, 1024) — the current dataset is (minibatch_size, 1, 16, 48)

In [45]:
# Only the training data is shuffled. Reshuffling every epoch helps SGD. The
# test pass just averages the loss over all samples with the model in eval mode,
# so shuffling it only adds a random permutation per epoch and makes per-batch
# evaluation non-reproducible.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
train_loader, test_loader

(<torch.utils.data.dataloader.DataLoader at 0x7fefd03fb9a0>,
 <torch.utils.data.dataloader.DataLoader at 0x7fefd03fbee0>)

### Define models and functions

In [46]:
class MyCNN(nn.Module):
    """Small CNN regressor: three conv stages followed by a four-layer MLP head.

    Each conv stage is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2, 2), taking channels 1 -> 16 -> 32 -> 64. The flattened
    feature map goes through Linear layers to 500 -> 100 -> 20 -> 3 outputs.

    Args:
        input_hw: ``(H, W)`` spatial size of the single-channel input, i.e.
            inputs are ``(batch, 1, H, W)``. It is used only to size
            ``linear1``. The default ``(16, 48)`` reproduces the previous
            hard-coded flatten size of 768; e.g. ``(300, 1024)`` sizes the
            head for the spectrogram dataset.

    Raises:
        ValueError: if ``H`` or ``W`` is smaller than 8, so the three pooling
            stages would leave an empty feature map.
    """

    def __init__(self, input_hw=(16, 48)):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # nn.Dropout2d(p=0.2),

            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # nn.Dropout2d(p=0.2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # nn.Dropout2d(p=0.2),
        )
        # The padded 3x3 convs keep H and W unchanged. Each MaxPool2d(2, 2)
        # with the default ceil_mode=False floors them: n -> n // 2. (A float
        # formula like 300/2/2/2 skips this flooring and overestimates.)
        pooled_h, pooled_w = input_hw
        for _ in range(3):
            pooled_h, pooled_w = pooled_h // 2, pooled_w // 2
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"input_hw={tuple(input_hw)} is too small for three 2x2 pooling stages"
            )
        linear_in_dim = 64 * pooled_h * pooled_w  # 768 for the default (16, 48)

        self.linear1 = nn.Linear(linear_in_dim, 500)
        # self.dropout1 = nn.Dropout(p=0.2)
        self.linear2 = nn.Linear(500, 100)
        self.linear2_2 = nn.Linear(100, 20)
        # self.dropout2 = nn.Dropout(p=0.2)
        self.linear3 = nn.Linear(20, 3)

    def forward(self, x):
        """Map ``(batch, 1, H, W)`` inputs to ``(batch, 3)`` outputs."""
        out = self.seq(x)
        out = torch.flatten(out, start_dim=1)  # (batch, 64 * H' * W')
        out = F.relu(self.linear1(out))
        # out = self.dropout1(out)
        out = F.relu(self.linear2(out))
        out = F.relu(self.linear2_2(out))
        # out = self.dropout2(out)
        return self.linear3(out)
        
class SimpleNN(nn.Module):
    """Fully connected baseline: flatten -> 20 -> 10 -> 3 with ReLUs in between.

    Args:
        in_features: number of values per sample after flattening. The
            default of 768 matches the (1, 16, 48) feature shape used in this
            notebook and the previously hard-coded value.
    """

    def __init__(self, in_features=768):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, 20),
            nn.ReLU(),
            nn.Linear(20, 10),
            nn.ReLU(),
            nn.Linear(10, 3)
        )

    def forward(self, x):
        """Flatten everything but the batch axis and return ``(batch, 3)``."""
        return self.seq(x)

def EucLoss(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Mean Euclidean (L2) distance between corresponding 3-D vectors.

    Args:
        a: predictions, shape ``(..., 3)``.
        b: targets, same shape as ``a``.

    Returns:
        0-dim tensor: the L2 norm of ``a - b`` along the last axis, averaged
        over all remaining positions.

    ``torch.linalg.vector_norm`` is used instead of ``sum(...).sqrt()``. The
    derivative of ``sqrt`` is infinite at 0, so any row where the prediction
    exactly matches the target produced NaN gradients (0 * inf) that then
    corrupt every parameter. ``vector_norm`` defines that gradient as 0.
    """
    assert a.shape == b.shape, f"shape mismatch: {tuple(a.shape)} vs {tuple(b.shape)}"
    assert b.shape[-1] == 3, f"expected last dimension of size 3, got {b.shape[-1]}"
    return torch.linalg.vector_norm(a - b, dim=-1).mean()

def EucLossSquared(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Mean squared Euclidean distance between corresponding 3-D vectors.

    Args:
        a: predictions, shape ``(..., 3)``.
        b: targets, same shape as ``a``.

    Returns:
        0-dim tensor: the squared L2 norm of ``a - b`` along the last axis,
        averaged over all remaining positions.
    """
    assert a.shape == b.shape and b.shape[-1] == 3
    diff = a - b
    squared_dist = diff.square().sum(dim=-1)
    return squared_dist.mean()


In [47]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = SimpleNN().to(device)

# Replicate MyCNN across GPUs 0-3. DataParallel needs the parameters on
# device_ids[0]; .cuda() moves them to the current CUDA device (cuda:0 unless changed).
model = nn.DataParallel(MyCNN(), device_ids=list(range(4)))
model = model.cuda()
model

DataParallel(
  (module): MyCNN(
    (seq): Sequential(
      (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (4): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (8): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU()
      (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (linear1): Linear(in_features=768, out_features=500, bias=True)
    (linear2): Linear(in_features=500, out_features=1

In [48]:
# Loss: mean Euclidean distance between predicted and true labels.
crit = EucLoss
# crit = nn.L1Loss()

# Plain SGD with momentum; RMSprop is kept below as an alternative.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01, momentum=0.9)
# optimizer = torch.optim.RMSprop(model.parameters())


#### Training/Evaluating NN

In [None]:
num_epochs = 500
# Row = epoch; column 0 = mean per-sample training loss, column 1 = test loss.
loss_tracker = np.zeros((num_epochs, 2))

num_train_batches = len(train_loader)
num_test_batches = len(test_loader)

# Number of epochs whose training AND test losses were both recorded. An
# interrupted run then plots only finished epochs instead of trailing zeros.
epochs_done = 0

try:
    for epoch in range(num_epochs):
        train_loss = 0
        test_loss = 0

        # ---- training pass ----
        total_els = 0
        model.train()
        for ft, lbl in train_loader:
            # ft stays on the CPU; DataParallel scatters it across its devices.
            optimizer.zero_grad()
            output = model(ft)
            # DataParallel gathers outputs on device_ids[0]; .cuda() moves the
            # labels to the current CUDA device, assumed to be that same GPU.
            lbl = lbl.cuda()
            loss = crit(output, lbl)
            loss.backward()
            optimizer.step()
            # crit averages over the batch, so weight by batch size to get a
            # true per-sample mean even when the last batch is smaller.
            train_loss += loss.item() * ft.shape[0]
            total_els += ft.shape[0]
        assert total_els == len(train_dataset)
        train_loss /= len(train_dataset)
        loss_tracker[epoch, 0] = train_loss

        # ---- evaluation pass ----
        total_els = 0
        model.eval()
        with torch.no_grad():
            for ft, lbl in test_loader:
                output = model(ft)
                lbl = lbl.cuda()
                loss = crit(output, lbl)
                test_loss += loss.item() * ft.shape[0]  # total loss for this mini-batch
                total_els += ft.shape[0]
        assert total_els == len(test_dataset)
        test_loss /= len(test_dataset)  # average loss per sample over the whole test set
        loss_tracker[epoch, 1] = test_loss
        epochs_done = epoch + 1

        print('Epoch {} | Training loss = {} | Test loss = {}'.format(epoch, train_loss, test_loss))
except KeyboardInterrupt:
    # Stopping a long run from the notebook should still save a loss plot for
    # the epochs that finished.
    print('Training interrupted after {} completed epoch(s)'.format(epochs_done))

img_dir = Path('./loss_plots')
img_dir.mkdir(parents=True, exist_ok=True)

plt.figure()
plt.plot(loss_tracker[:epochs_done])
# Each point is one epoch's per-sample mean loss (batch losses are re-weighted
# by batch size above), not a per-batch mean.
plt.title('Training vs Testing Loss (Mean Loss Per Sample)')
plt.legend(['Training loss', 'Test loss'])
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.savefig(img_dir / '04-12_spectrogram-dset_MyCNN-extended_euc-loss.png')

# save model
# model_dir = Path('./models')
# model_dir.mkdir(parents=True, exist_ok=True)
# torch.save(model.state_dict(), model_dir / 'spectrogram_dset_dgx.pth')



Epoch 0 | Training loss = 2.1579369687317347 | Test loss = 1.8354578472798988
Epoch 1 | Training loss = 1.494756790661039 | Test loss = 1.3717727515707836
Epoch 2 | Training loss = 1.2566582799909707 | Test loss = 1.2103155349584984
Epoch 3 | Training loss = 1.1246605706696446 | Test loss = 1.1068601154505653
Epoch 4 | Training loss = 1.0143610300061399 | Test loss = 0.975734796905596
Epoch 5 | Training loss = 0.9683300960562273 | Test loss = 0.991350787261485
Epoch 6 | Training loss = 0.9277311212415145 | Test loss = 0.9637288499063235
Epoch 7 | Training loss = 0.8903996142606435 | Test loss = 0.9196464922753693
Epoch 8 | Training loss = 0.8708204357430318 | Test loss = 1.1441906854917958
Epoch 9 | Training loss = 0.8434846121249148 | Test loss = 0.8928457128479543
Epoch 10 | Training loss = 0.8179104809696095 | Test loss = 0.885201906995537
Epoch 11 | Training loss = 0.8067966478707478 | Test loss = 1.0056982117806632
Epoch 12 | Training loss = 0.7949048688345641 | Test loss = 1.1827

### Notes:

The CNN seemed to help accuracy, as did adding more linear layers. However, the model is overfitting heavily. BatchNorm didn't make a noticeable difference, and dropout seems to make things worse. Try running on Stampede.


#### The spectrogram dataset by itself doesn't perform well (test loss bottoms out around 2.55 while training loss keeps falling, i.e. it overfits):

Epoch 0 | Training loss = 4.512690603733063 | Test loss = 4.193840344746907
Epoch 1 | Training loss = 3.207171857357025 | Test loss = 2.855196555455526
Epoch 2 | Training loss = 2.788858652114868 | Test loss = 2.7461801369984946
Epoch 3 | Training loss = 2.683992842833201 | Test loss = 2.7123541831970215
Epoch 4 | Training loss = 2.7258232831954956 | Test loss = 2.6222329139709473
Epoch 5 | Training loss = 2.623843808968862 | Test loss = 2.6298372745513916
Epoch 6 | Training loss = 2.6232070525487265 | Test loss = 2.868154525756836
Epoch 7 | Training loss = 2.6465203563372293 | Test loss = 2.767595052719116
Epoch 8 | Training loss = 2.66033927599589 | Test loss = 2.743466377258301
Epoch 9 | Training loss = 2.652188718318939 | Test loss = 2.6217918395996094
Epoch 10 | Training loss = 2.60043211778005 | Test loss = 2.60316801071167
Epoch 11 | Training loss = 2.5947894056638083 | Test loss = 2.6221278508504233
Epoch 12 | Training loss = 2.578959862391154 | Test loss = 2.5901806354522705
Epoch 13 | Training loss = 2.576880931854248 | Test loss = 2.594467878341675
Epoch 14 | Training loss = 2.616027057170868 | Test loss = 2.591064214706421
Epoch 15 | Training loss = 2.6497509876887 | Test loss = 2.662203232447306
Epoch 16 | Training loss = 2.6083373626073203 | Test loss = 2.6001358032226562
Epoch 17 | Training loss = 2.605268637339274 | Test loss = 2.5887749195098877
Epoch 18 | Training loss = 2.6101978619893393 | Test loss = 2.6554483572642007
Epoch 19 | Training loss = 2.6777352492014566 | Test loss = 2.6398793856302896
Epoch 20 | Training loss = 2.7068386475245156 | Test loss = 2.7072757879892984
Epoch 21 | Training loss = 2.627480169137319 | Test loss = 2.6925466855367026
Epoch 22 | Training loss = 2.6190430919329324 | Test loss = 2.584099292755127
Epoch 23 | Training loss = 2.566866397857666 | Test loss = 2.5893328189849854
Epoch 24 | Training loss = 2.671197772026062 | Test loss = 2.942407210667928
Epoch 25 | Training loss = 2.698890527089437 | Test loss = 2.6265974839528403
Epoch 26 | Training loss = 2.62491504351298 | Test loss = 2.707094192504883
Epoch 27 | Training loss = 2.680932104587555 | Test loss = 2.661949793497721
Epoch 28 | Training loss = 2.6023327708244324 | Test loss = 2.5905651251475015
Epoch 29 | Training loss = 2.5522359013557434 | Test loss = 2.57867161432902
Epoch 30 | Training loss = 2.5873255928357444 | Test loss = 2.660228888193766
Epoch 31 | Training loss = 2.6767356594403586 | Test loss = 2.644167423248291
Epoch 32 | Training loss = 2.5707041025161743 | Test loss = 2.603447516759237
Epoch 33 | Training loss = 2.572120189666748 | Test loss = 2.5676135222117105
Epoch 34 | Training loss = 2.5445960760116577 | Test loss = 2.5792222817738852
Epoch 35 | Training loss = 2.5703064997990928 | Test loss = 2.5740973154703775
Epoch 36 | Training loss = 2.5989076296488443 | Test loss = 2.6521519819895425
Epoch 37 | Training loss = 2.6060789426167807 | Test loss = 2.653982162475586
Epoch 38 | Training loss = 2.5480491320292153 | Test loss = 2.56284761428833
Epoch 39 | Training loss = 2.546206255753835 | Test loss = 2.582702080408732
Epoch 40 | Training loss = 2.53329328695933 | Test loss = 2.5605823198954263
Epoch 41 | Training loss = 2.5466638803482056 | Test loss = 2.6241183280944824
Epoch 42 | Training loss = 2.5401841600735984 | Test loss = 2.554289976755778
Epoch 43 | Training loss = 2.553957482179006 | Test loss = 2.613287925720215
Epoch 44 | Training loss = 2.545074760913849 | Test loss = 2.5625597635904946
Epoch 45 | Training loss = 2.5467688043912253 | Test loss = 2.565520763397217
Epoch 46 | Training loss = 2.5416312416394553 | Test loss = 2.5513343016306558
Epoch 47 | Training loss = 2.5294887820879617 | Test loss = 2.5814336140950522
Epoch 48 | Training loss = 2.530319571495056 | Test loss = 2.6669086615244546
Epoch 49 | Training loss = 2.5534868637720742 | Test loss = 2.567220369974772
Epoch 50 | Training loss = 2.545904219150543 | Test loss = 2.5498883724212646
Epoch 51 | Training loss = 2.519070307413737 | Test loss = 2.5748438040415444
Epoch 52 | Training loss = 2.5424413681030273 | Test loss = 2.548670689264933
Epoch 53 | Training loss = 2.5219553510348 | Test loss = 2.5572537581125894
Epoch 54 | Training loss = 2.5094886223475137 | Test loss = 2.560319662094116
Epoch 55 | Training loss = 2.511307974656423 | Test loss = 2.580026149749756
Epoch 56 | Training loss = 2.5115503470102944 | Test loss = 2.5456807613372803
Epoch 57 | Training loss = 2.5005688468615213 | Test loss = 2.547955592473348
Epoch 58 | Training loss = 2.5198761622111 | Test loss = 2.6861111323038735
Epoch 59 | Training loss = 2.5831815600395203 | Test loss = 2.637284517288208
Epoch 60 | Training loss = 2.572909891605377 | Test loss = 2.562631607055664
Epoch 61 | Training loss = 2.5002904136975608 | Test loss = 2.568574905395508
Epoch 62 | Training loss = 2.4539151986440024 | Test loss = 2.6927483081817627
Epoch 63 | Training loss = 2.4777592420578003 | Test loss = 2.6340691248575845
Epoch 64 | Training loss = 2.552095274130503 | Test loss = 2.943647782007853
Epoch 65 | Training loss = 2.5343843499819436 | Test loss = 2.6281658013661704
Epoch 66 | Training loss = 2.511139174302419 | Test loss = 2.5869100093841553
Epoch 67 | Training loss = 2.4871119459470115 | Test loss = 2.563323895136515
Epoch 68 | Training loss = 2.5255517959594727 | Test loss = 2.5753180980682373
Epoch 69 | Training loss = 2.462038576602936 | Test loss = 2.5954321225484214
Epoch 70 | Training loss = 2.494799772898356 | Test loss = 2.5883309046427407
Epoch 71 | Training loss = 2.426811178525289 | Test loss = 2.598881483078003
Epoch 72 | Training loss = 2.4209784666697183 | Test loss = 2.7488608360290527
Epoch 73 | Training loss = 2.46446826060613 | Test loss = 2.6066841284434
Epoch 74 | Training loss = 2.468072255452474 | Test loss = 2.587620576222738
Epoch 75 | Training loss = 2.3849696119626365 | Test loss = 2.6258848508199057
Epoch 76 | Training loss = 2.3631367683410645 | Test loss = 2.606775919596354
Epoch 77 | Training loss = 2.400774916013082 | Test loss = 2.6000608603159585
Epoch 78 | Training loss = 2.334807813167572 | Test loss = 2.7142237027486167
Epoch 79 | Training loss = 2.330035666624705 | Test loss = 2.6695587635040283
Epoch 80 | Training loss = 2.2572935819625854 | Test loss = 2.6696623961130777
Epoch 81 | Training loss = 2.2240594824155173 | Test loss = 2.754479726155599
Epoch 82 | Training loss = 2.295351425806681 | Test loss = 2.61575714747111
Epoch 83 | Training loss = 2.239086707433065 | Test loss = 2.7227691809336343
Epoch 84 | Training loss = 2.207997610171636 | Test loss = 2.7276058991750083
Epoch 85 | Training loss = 2.217043568690618 | Test loss = 2.6612442334493003
Epoch 86 | Training loss = 2.2218828002611795 | Test loss = 2.7417481740315757
Epoch 87 | Training loss = 2.2744336128234863 | Test loss = 2.69406795501709
Epoch 88 | Training loss = 2.1858057777086892 | Test loss = 2.731642246246338
Epoch 89 | Training loss = 2.1654090086619058 | Test loss = 2.6873814264933267
Epoch 90 | Training loss = 2.127222160498301 | Test loss = 2.7915919621785483
Epoch 91 | Training loss = 2.051015784343084 | Test loss = 2.6736954053243003
Epoch 92 | Training loss = 2.0790861745675406 | Test loss = 2.634180943171183
Epoch 93 | Training loss = 2.025949855645498 | Test loss = 2.751603285471598
Epoch 94 | Training loss = 1.8675897320111592 | Test loss = 2.831928094228109
Epoch 95 | Training loss = 1.9419245719909668 | Test loss = 2.8254872957865396
Epoch 96 | Training loss = 1.8731929957866669 | Test loss = 2.7153898080190024
Epoch 97 | Training loss = 1.8347639242808025 | Test loss = 2.763599236806234
Epoch 98 | Training loss = 2.005356421073278 | Test loss = 2.7253236770629883
Epoch 99 | Training loss = 1.8378906548023224 | Test loss = 2.7293864091237388
Epoch 100 | Training loss = 1.8166932662328084 | Test loss = 2.7688406308492026
Epoch 101 | Training loss = 1.7781188090642293 | Test loss = 2.717549959818522
Epoch 102 | Training loss = 1.7398491303126018 | Test loss = 2.9368110497792563
Epoch 103 | Training loss = 1.7139520943164825 | Test loss = 2.8264620304107666
Epoch 104 | Training loss = 1.7621939480304718 | Test loss = 2.7718935012817383
Epoch 105 | Training loss = 1.6212974886099498 | Test loss = 2.84261155128479
Epoch 106 | Training loss = 1.559253732363383 | Test loss = 2.854212681452433
Epoch 107 | Training loss = 1.4119056860605876 | Test loss = 2.96927809715271
Epoch 108 | Training loss = 1.4034680724143982 | Test loss = 2.9812479813893638
Epoch 109 | Training loss = 1.4534359474976857 | Test loss = 2.7374819119771323
Epoch 110 | Training loss = 1.4576045274734497 | Test loss = 2.7759761810302734
Epoch 111 | Training loss = 1.3911385635534923 | Test loss = 2.953104337056478
Epoch 112 | Training loss = 1.3195232152938843 | Test loss = 2.8024230003356934
Epoch 113 | Training loss = 1.2665107349554698 | Test loss = 2.9180691242218018
Epoch 114 | Training loss = 1.2094118297100067 | Test loss = 2.8747410774230957
Epoch 115 | Training loss = 1.2180902461210887 | Test loss = 2.9325674374898276
Epoch 116 | Training loss = 1.0766840328772862 | Test loss = 2.876356840133667
Epoch 117 | Training loss = 1.0015811175107956 | Test loss = 2.9061977863311768
Epoch 118 | Training loss = 1.0212817738453548 | Test loss = 2.981333017349243
Epoch 119 | Training loss = 1.015288641055425 | Test loss = 2.979546387990316
Epoch 120 | Training loss = 0.9890188823143641 | Test loss = 2.894641160964966
Epoch 121 | Training loss = 0.9872632374366125 | Test loss = 2.87300697962443
Epoch 122 | Training loss = 0.8896637111902237 | Test loss = 2.922434409459432
Epoch 123 | Training loss = 0.8756431738535563 | Test loss = 2.9475392500559487
Epoch 124 | Training loss = 0.8730523735284805 | Test loss = 2.929404338200887
Epoch 125 | Training loss = 0.8583066364129385 | Test loss = 2.959799289703369
Epoch 126 | Training loss = 0.9471468329429626 | Test loss = 3.007989486058553
Epoch 127 | Training loss = 0.8694671442111334 | Test loss = 2.8354373772939048
Epoch 128 | Training loss = 0.793950746456782 | Test loss = 3.010768493016561
Epoch 129 | Training loss = 0.7966095705827078 | Test loss = 2.877655824025472
Epoch 130 | Training loss = 0.8992835581302643 | Test loss = 2.9013328552246094
Epoch 131 | Training loss = 0.839479943116506 | Test loss = 2.894512971242269
Epoch 132 | Training loss = 0.7980612516403198 | Test loss = 2.909008026123047
Epoch 133 | Training loss = 0.7796344210704168 | Test loss = 2.997387647628784
Epoch 134 | Training loss = 0.8155859808127085 | Test loss = 2.9685846964518228
Epoch 135 | Training loss = 0.7698417057593664 | Test loss = 2.964077870051066
Epoch 136 | Training loss = 0.653485839565595 | Test loss = 2.8753408590952554
Epoch 137 | Training loss = 0.5882637848456701 | Test loss = 2.8730435371398926
Epoch 138 | Training loss = 0.5726363807916641 | Test loss = 2.8583150704701743
Epoch 139 | Training loss = 0.5824364374081293 | Test loss = 2.8840437730153403
Epoch 140 | Training loss = 0.6097027460734049 | Test loss = 2.9701621532440186
Epoch 141 | Training loss = 0.6112793187300364 | Test loss = 2.878804922103882
Epoch 142 | Training loss = 0.6694032202164332 | Test loss = 2.877445936203003
Epoch 143 | Training loss = 0.5986573696136475 | Test loss = 2.8629702726999917
Epoch 144 | Training loss = 0.5496392001708349 | Test loss = 2.8680386543273926
Epoch 145 | Training loss = 0.5609491864840189 | Test loss = 2.8797436555226645
Epoch 146 | Training loss = 0.5688396791617075 | Test loss = 2.8715103467305503
Epoch 147 | Training loss = 0.5568703909715017 | Test loss = 2.8297619024912515
Epoch 148 | Training loss = 0.5489857320984205 | Test loss = 2.932460149129232
Epoch 149 | Training loss = 0.5497378011544546 | Test loss = 2.8952627976735434
Epoch 150 | Training loss = 0.5418885300556818 | Test loss = 2.8697479565938315
Epoch 151 | Training loss = 0.5180931886037191 | Test loss = 2.817224899927775
Epoch 152 | Training loss = 0.4848233411709468 | Test loss = 2.8670266469319663
Epoch 153 | Training loss = 0.5095981508493423 | Test loss = 2.8884111245473227
Epoch 154 | Training loss = 0.4795256058375041 | Test loss = 2.9046644369761148
Epoch 155 | Training loss = 0.5045426338911057 | Test loss = 2.8604812622070312
Epoch 156 | Training loss = 0.49995659043391544 | Test loss = 2.8988219102223716
Epoch 157 | Training loss = 0.49334516127904254 | Test loss = 2.8423540592193604
Epoch 158 | Training loss = 0.4781015043457349 | Test loss = 2.8826920986175537
Epoch 159 | Training loss = 0.5287989204128584 | Test loss = 2.8374322255452475
Epoch 160 | Training loss = 0.5487982630729675 | Test loss = 2.93006165822347
Epoch 161 | Training loss = 0.4847252294421196 | Test loss = 2.8738359610239663
Epoch 162 | Training loss = 0.4757622430721919 | Test loss = 2.890534003575643
Epoch 163 | Training loss = 0.45206674685080844 | Test loss = 2.8310720125834146
Epoch 164 | Training loss = 0.41246474782625836 | Test loss = 2.8837059338887534
Epoch 165 | Training loss = 0.40487995743751526 | Test loss = 2.8537563482920327
Epoch 166 | Training loss = 0.43476181974013645 | Test loss = 2.941516081492106
Epoch 167 | Training loss = 0.44687868654727936 | Test loss = 2.8688937028249106
Epoch 168 | Training loss = 0.4426387498776118 | Test loss = 2.888704538345337
Epoch 169 | Training loss = 0.4189794734120369 | Test loss = 2.8439768155415854
Epoch 170 | Training loss = 0.3893149172266324 | Test loss = 2.864601214726766
Epoch 171 | Training loss = 0.40359241763750714 | Test loss = 2.9454919497172036
Epoch 172 | Training loss = 0.41057723263899487 | Test loss = 2.8415912787119546
Epoch 173 | Training loss = 0.36598721891641617 | Test loss = 2.8684465090433755
Epoch 174 | Training loss = 0.3647552008430163 | Test loss = 2.8665629227956138
Epoch 175 | Training loss = 0.3340824767947197 | Test loss = 2.858413060506185
Epoch 176 | Training loss = 0.3284987856944402 | Test loss = 2.866041421890259
Epoch 177 | Training loss = 0.36035139113664627 | Test loss = 2.8429741859436035
Epoch 178 | Training loss = 0.3578016261259715 | Test loss = 2.8614677588144937
Epoch 179 | Training loss = 0.33256562054157257 | Test loss = 2.8851611614227295
Epoch 180 | Training loss = 0.3387721429268519 | Test loss = 2.8600749174753823
Epoch 181 | Training loss = 0.3714219356576602 | Test loss = 2.88550074895223
Epoch 182 | Training loss = 0.3537074451645215 | Test loss = 2.857067267100016
Epoch 183 | Training loss = 0.34837787101666134 | Test loss = 2.838027000427246
Epoch 184 | Training loss = 0.3505762368440628 | Test loss = 2.842094580332438
Epoch 185 | Training loss = 0.38562504450480145 | Test loss = 2.835160414377848
Epoch 186 | Training loss = 0.33873580644528073 | Test loss = 2.8579559326171875
Epoch 187 | Training loss = 0.3242340013384819 | Test loss = 2.873668988545736
Epoch 188 | Training loss = 0.3271553839246432 | Test loss = 2.8571038246154785
Epoch 189 | Training loss = 0.31058179835478467 | Test loss = 2.8258886337280273
Epoch 190 | Training loss = 0.3144623264670372 | Test loss = 2.8323076566060386
Epoch 191 | Training loss = 0.31062446782986325 | Test loss = 2.8996430238087973
Epoch 192 | Training loss = 0.33441001425186795 | Test loss = 2.8261852264404297
Epoch 193 | Training loss = 0.3811745047569275 | Test loss = 2.925759792327881
Epoch 194 | Training loss = 0.3513134370247523 | Test loss = 2.856234312057495
Epoch 195 | Training loss = 0.37271806846062344 | Test loss = 2.865464528401693
Epoch 196 | Training loss = 0.3859856625398 | Test loss = 2.845203479131063
Epoch 197 | Training loss = 0.388036107023557 | Test loss = 2.8866395950317383
Epoch 198 | Training loss = 0.3276962439219157 | Test loss = 2.890552282333374
Epoch 199 | Training loss = 0.3188909739255905 | Test loss = 2.821582555770874
Epoch 200 | Training loss = 0.3096032291650772 | Test loss = 2.8967040379842124
Epoch 201 | Training loss = 0.3415669451157252 | Test loss = 2.789390484491984
Epoch 202 | Training loss = 0.3384142865737279 | Test loss = 2.926778554916382
Epoch 203 | Training loss = 0.31698766350746155 | Test loss = 2.837609132130941
Epoch 204 | Training loss = 0.29759829118847847 | Test loss = 2.8420074780782065
Epoch 205 | Training loss = 0.27972996855775517 | Test loss = 2.8413052558898926
Epoch 206 | Training loss = 0.2958216095964114 | Test loss = 2.905440409978231
Epoch 207 | Training loss = 0.2879531954725583 | Test loss = 2.8942856788635254
Epoch 208 | Training loss = 0.3514970491329829 | Test loss = 2.8464094003041587
Epoch 209 | Training loss = 0.3302631974220276 | Test loss = 2.8374098936716714
