In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch import optim
import pickle
from matplotlib import pyplot as plt
from libs_unet.models import unet_001
from libs_unet.training.libs_train import train_loop, test_loop
from pathlib import Path
#from torch.utils.tensorboard import SummaryWriter
import datetime
# Timestamp taken when this cell runs, formatted as YYYY-MM-DD-HH-MM-SS
now_time = f"{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}"
# Root folder for data files (relative path)
datapath = "./data/"


In [None]:
#Leverage PyTorch native Dataset and DataLoader
#Define Train/Test sets from 20 element data samples
#NOTE(review): pickle.load can execute arbitrary code; only point this at trusted files
#The file holds four objects pickled one after another, so they are read back in order
with open(datapath + 'training/10k_nomods.pickle', 'rb') as f:
    fracs = pickle.load(f)
    wave = pickle.load(f)
    x_data = pickle.load(f)
    y_data = pickle.load(f)

#create dataset
#input needs a placeholder "channel" dimension since single channel
#learned labels already has max_z + 2 channels from spec_array
#data has to match weights which default to float() so cast data as same
scale_factor = 1
x_data = torch.tensor(x_data[:,None,:].astype('float32'))
y_data = torch.tensor(y_data.astype('float32'))
spec_ds = TensorDataset(scale_factor * x_data, scale_factor * y_data)
#batch sizes
train_bs = 50
test_bs = 100
#create random split for training and validation (80% / 20%)
train_len = int(0.8 * len(x_data))
test_len = len(x_data) - train_len
#seed the split: without an explicit generator random_split draws from the
#global RNG, so train/test membership would differ on every run
split_seed = 0
split_rng = torch.Generator().manual_seed(split_seed)
train_ds, test_ds = random_split(spec_ds, [train_len, test_len], generator=split_rng)
#no shuffle=True, so batches come out in the same order on every pass
train_dl = DataLoader(train_ds, batch_size=train_bs)
test_dl = DataLoader(test_ds, batch_size=test_bs)

In [13]:
#model / optimiser hyper-parameters
el_count = 20 #first n elements used to construct model
wl_points = 760 #number of wavelength point measurements in data
learning_rate = 1

#build the network, an MSE loss (mean reduction) and plain SGD over all parameters
model = unet_001.LIBSUNet(el_count, wl_points)
loss_fn = nn.MSELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

#mirrors code from the training module: put the model in training mode
model.train()
#report the L2 norm of every weight / bias tensor as-initialized
for param_name, parameter in model.named_parameters():
    l2_norm = torch.linalg.vector_norm(parameter)
    print(f"{param_name},{l2_norm}") #all non-zero



down_conv_1.0.weight,2.5268189907073975
down_conv_1.0.bias,0.8911404609680176
down_conv_1.2.weight,2.709373712539673
down_conv_1.2.bias,0.23561222851276398
down_conv_2.0.weight,3.8001809120178223
down_conv_2.0.bias,0.35201114416122437
down_conv_2.2.weight,3.8223915100097656
down_conv_2.2.bias,0.24843235313892365
down_conv_3.0.weight,5.395349025726318
down_conv_3.0.bias,0.3506188690662384
down_conv_3.2.weight,5.4237799644470215
down_conv_3.2.bias,0.2295835018157959
down_conv_4.0.weight,7.640052795410156
down_conv_4.0.bias,0.3270268738269806
down_conv_4.2.weight,7.669570446014404
down_conv_4.2.bias,0.23282566666603088
down_conv_5.0.weight,10.832441329956055
down_conv_5.0.bias,0.32407405972480774
down_conv_5.2.weight,10.830806732177734
down_conv_5.2.bias,0.23400621116161346
up_trans_1.weight,10.846290588378906
up_trans_1.bias,0.4168807566165924
up_conv_1.0.weight,7.656768321990967
up_conv_1.0.bias,0.17372988164424896
up_conv_1.2.weight,7.668402194976807
up_conv_1.2.bias,0.2346291691064834

In [14]:
#now predict / train on first batch from data loader and check gradients
#take exactly one batch so X, y, pred and loss all refer to the same data
#(iterating further would fetch an extra batch and rebind X, y to it)
batch = 0 #index of the batch held in X, y
X, y = next(iter(train_dl))
pred = model(X)
loss = loss_fn(pred, y)

# clear gradients before back prop
optimizer.zero_grad()
#check if gradients still None or explicit zeros
for name, param in model.named_parameters():
    print(f"{name},{param.grad}") #All None: backward() has not been called yet

down_conv_1.0.weight,None
down_conv_1.0.bias,None
down_conv_1.2.weight,None
down_conv_1.2.bias,None
down_conv_2.0.weight,None
down_conv_2.0.bias,None
down_conv_2.2.weight,None
down_conv_2.2.bias,None
down_conv_3.0.weight,None
down_conv_3.0.bias,None
down_conv_3.2.weight,None
down_conv_3.2.bias,None
down_conv_4.0.weight,None
down_conv_4.0.bias,None
down_conv_4.2.weight,None
down_conv_4.2.bias,None
down_conv_5.0.weight,None
down_conv_5.0.bias,None
down_conv_5.2.weight,None
down_conv_5.2.bias,None
up_trans_1.weight,None
up_trans_1.bias,None
up_conv_1.0.weight,None
up_conv_1.0.bias,None
up_conv_1.2.weight,None
up_conv_1.2.bias,None
up_trans_2.weight,None
up_trans_2.bias,None
up_conv_2.0.weight,None
up_conv_2.0.bias,None
up_conv_2.2.weight,None
up_conv_2.2.bias,None
up_trans_3.weight,None
up_trans_3.bias,None
up_conv_3.0.weight,None
up_conv_3.0.bias,None
up_conv_3.2.weight,None
up_conv_3.2.bias,None
up_trans_4.weight,None
up_trans_4.bias,None
up_conv_4.0.weight,None
up_conv_4.0.bias,None
up

In [15]:
#back-propagate the loss, then show the gradient tensor stored on every parameter
loss.backward()
for param_name, parameter in model.named_parameters():
    grad = parameter.grad
    print(f"{param_name},{grad}") #we see non-zero and zero gradient tensors through the graph

down_conv_1.0.weight,tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00]],

        [[ 2.3287e-08,  2.3363e-08,  2.3555e-08,  2.3912e-08,  2.4599e-08,
           2.5155e-08]],

        [[-2.0585e-08, -2.0310e-08, -2.0489e-08, -2.0738e-08, -2.0794e-08,
          -2.0692e-08]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00]],

        [[ 3.3252e-08,  3.2803e-08,  3.2988e-08,  3.3532e-08,  3.4404e-08,
           3.5324e-08]],

        [[-1.5091e-08, -1.4063e-08, -1.3234e-08, -1.2907e-08, -1.2942e

In [16]:
#summarise each gradient by its L2 norm: a compact zero / non-zero check per node
for param_name, parameter in model.named_parameters():
    grad_norm = torch.linalg.vector_norm(parameter.grad)
    print(f"{param_name},{grad_norm}") #all non-zero

down_conv_1.0.weight,1.33759712639403e-07
down_conv_1.0.bias,8.643703768029809e-05
down_conv_1.2.weight,0.0003370021586306393
down_conv_1.2.bias,0.00022665204596705735
down_conv_2.0.weight,2.8243284759810194e-06
down_conv_2.0.bias,7.016693416517228e-06
down_conv_2.2.weight,1.2935259292135015e-05
down_conv_2.2.bias,2.27140753850108e-05
down_conv_3.0.weight,2.2594245763229992e-07
down_conv_3.0.bias,4.0226154851552565e-07
down_conv_3.2.weight,6.427567313949112e-07
down_conv_3.2.bias,9.388311923430592e-07
down_conv_4.0.weight,1.2371143220946124e-08
down_conv_4.0.bias,2.4230024209259682e-08
down_conv_4.2.weight,3.656649738559281e-08
down_conv_4.2.bias,5.807896030773918e-08
down_conv_5.0.weight,5.1466924055887375e-09
down_conv_5.0.bias,1.0795077720615609e-08
down_conv_5.2.weight,1.7583513312047216e-08
down_conv_5.2.bias,2.7722281359388035e-08
up_trans_1.weight,1.1300055113849794e-08
up_trans_1.bias,7.633719434352315e-08
up_conv_1.0.weight,2.2477199479453702e-07
up_conv_1.0.bias,1.96062131863

In [17]:
#apply one optimizer update, then re-print the L2 norms of the weights/biases
optimizer.step() #leverages tensor gradients from backward() to update weights
for param_name, parameter in model.named_parameters():
    l2_norm = torch.linalg.vector_norm(parameter)
    #all non-zero; compare with the as-initialized norms to see which tensors updated
    print(f"{param_name},{l2_norm}")

down_conv_1.0.weight,2.5268189907073975
down_conv_1.0.bias,0.8911359906196594
down_conv_1.2.weight,2.709372043609619
down_conv_1.2.bias,0.2355806529521942
down_conv_2.0.weight,3.8001811504364014
down_conv_2.0.bias,0.35200873017311096
down_conv_2.2.weight,3.8223915100097656
down_conv_2.2.bias,0.24842582643032074
down_conv_3.0.weight,5.395349025726318
down_conv_3.0.bias,0.3506188690662384
down_conv_3.2.weight,5.4237799644470215
down_conv_3.2.bias,0.22958344221115112
down_conv_4.0.weight,7.640052795410156
down_conv_4.0.bias,0.3270268738269806
down_conv_4.2.weight,7.669570446014404
down_conv_4.2.bias,0.23282566666603088
down_conv_5.0.weight,10.832441329956055
down_conv_5.0.bias,0.32407405972480774
down_conv_5.2.weight,10.830806732177734
down_conv_5.2.bias,0.23400621116161346
up_trans_1.weight,10.846290588378906
up_trans_1.bias,0.41688072681427
up_conv_1.0.weight,7.656768321990967
up_conv_1.0.bias,0.17372986674308777
up_conv_1.2.weight,7.668402194976807
up_conv_1.2.bias,0.23462910950183868


In [18]:
#Now process the second batch and see if we are getting updates to everything
#restart the loader, drop batch 0 and keep only batch 1 so that X, y, pred and
#loss all refer to the same data (no third batch is fetched)
batch_iter = iter(train_dl)
next(batch_iter) #skip the batch we already processed
batch = 1 #index of the batch held in X, y
X, y = next(batch_iter)
pred = model(X)
loss = loss_fn(pred, y)

In [19]:
# Clear gradients
optimizer.zero_grad()
#check whether the cleared gradients are None or explicit zeros:
#zero_grad() zero-fills them when set_to_none=False and resets them to None when
#set_to_none=True (the default from PyTorch 2.0), so guard the norm against None
for name, param in model.named_parameters():
    grad_norm = None if param.grad is None else torch.linalg.vector_norm(param.grad)
    print(f"{name},{grad_norm}") #all zeros, or all None when gradients are reset to None

down_conv_1.0.weight,0.0
down_conv_1.0.bias,0.0
down_conv_1.2.weight,0.0
down_conv_1.2.bias,0.0
down_conv_2.0.weight,0.0
down_conv_2.0.bias,0.0
down_conv_2.2.weight,0.0
down_conv_2.2.bias,0.0
down_conv_3.0.weight,0.0
down_conv_3.0.bias,0.0
down_conv_3.2.weight,0.0
down_conv_3.2.bias,0.0
down_conv_4.0.weight,0.0
down_conv_4.0.bias,0.0
down_conv_4.2.weight,0.0
down_conv_4.2.bias,0.0
down_conv_5.0.weight,0.0
down_conv_5.0.bias,0.0
down_conv_5.2.weight,0.0
down_conv_5.2.bias,0.0
up_trans_1.weight,0.0
up_trans_1.bias,0.0
up_conv_1.0.weight,0.0
up_conv_1.0.bias,0.0
up_conv_1.2.weight,0.0
up_conv_1.2.bias,0.0
up_trans_2.weight,0.0
up_trans_2.bias,0.0
up_conv_2.0.weight,0.0
up_conv_2.0.bias,0.0
up_conv_2.2.weight,0.0
up_conv_2.2.bias,0.0
up_trans_3.weight,0.0
up_trans_3.bias,0.0
up_conv_3.0.weight,0.0
up_conv_3.0.bias,0.0
up_conv_3.2.weight,0.0
up_conv_3.2.bias,0.0
up_trans_4.weight,0.0
up_trans_4.bias,0.0
up_conv_4.0.weight,0.0
up_conv_4.0.bias,0.0
up_conv_4.2.weight,0.0
up_conv_4.2.bias,0.0


In [20]:
#back-propagate the current loss
loss.backward()
#print the L2 norm of each gradient to see whether it differs from before
for param_name, parameter in model.named_parameters():
    grad_norm = torch.linalg.vector_norm(parameter.grad)
    print(f"{param_name},{grad_norm}") #all non-zero

down_conv_1.0.weight,1.3640787699387147e-07
down_conv_1.0.bias,7.806273060850799e-05
down_conv_1.2.weight,0.0003080574388150126
down_conv_1.2.bias,0.00020718446467071772
down_conv_2.0.weight,2.577882696641609e-06
down_conv_2.0.bias,6.413536993932212e-06
down_conv_2.2.weight,1.1861923667311203e-05
down_conv_2.2.bias,2.082867649733089e-05
down_conv_3.0.weight,2.1045364917426923e-07
down_conv_3.0.bias,3.747288701561047e-07
down_conv_3.2.weight,5.922836976424151e-07
down_conv_3.2.bias,8.651402367831906e-07
down_conv_4.0.weight,1.1400634214453476e-08
down_conv_4.0.bias,2.2329031068579752e-08
down_conv_4.2.weight,3.362794132044655e-08
down_conv_4.2.bias,5.3412080092130054e-08
down_conv_5.0.weight,4.723437641018791e-09
down_conv_5.0.bias,9.907393128116837e-09
down_conv_5.2.weight,1.614991695930712e-08
down_conv_5.2.bias,2.546218169641179e-08
up_trans_1.weight,1.0400833083679117e-08
up_trans_1.bias,7.008755886772633e-08
up_conv_1.0.weight,2.0638456987853715e-07
up_conv_1.0.bias,1.8001657053901

In [21]:
#apply another optimizer update, then re-print the L2 norms of the weights/biases
optimizer.step() #leverages tensor gradients from backward() to update weights
for param_name, parameter in model.named_parameters():
    l2_norm = torch.linalg.vector_norm(parameter)
    #all non-zero; compare with the norms printed after the previous step to see what moved
    print(f"{param_name},{l2_norm}")

down_conv_1.0.weight,2.5268189907073975
down_conv_1.0.bias,0.8911322951316833
down_conv_1.2.weight,2.7093708515167236
down_conv_1.2.bias,0.23555269837379456
down_conv_2.0.weight,3.8001809120178223
down_conv_2.0.bias,0.3520065248012543
down_conv_2.2.weight,3.8223915100097656
down_conv_2.2.bias,0.24841997027397156
down_conv_3.0.weight,5.395349025726318
down_conv_3.0.bias,0.3506189286708832
down_conv_3.2.weight,5.4237799644470215
down_conv_3.2.bias,0.22958336770534515
down_conv_4.0.weight,7.640052795410156
down_conv_4.0.bias,0.327026903629303
down_conv_4.2.weight,7.669570446014404
down_conv_4.2.bias,0.23282566666603088
down_conv_5.0.weight,10.832441329956055
down_conv_5.0.bias,0.32407405972480774
down_conv_5.2.weight,10.830806732177734
down_conv_5.2.bias,0.23400622606277466
up_trans_1.weight,10.846290588378906
up_trans_1.bias,0.4168807566165924
up_conv_1.0.weight,7.656768321990967
up_conv_1.0.bias,0.17372985184192657
up_conv_1.2.weight,7.668402194976807
up_conv_1.2.bias,0.2346290498971939