In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch import optim
import pickle
from matplotlib import pyplot as plt
from libs_unet.models import unet_001
from libs_unet.training.libs_train import train_loop, test_loop
from pathlib import Path
#from torch.utils.tensorboard import SummaryWriter
import datetime

#directory the kernel was started in; the data folder is looked up beneath it
top_dir = Path.cwd()
datapath = top_dir.joinpath('data')


In [3]:
#Leverage PyTorch native Dataset and DataLoader
#Define Train/Test sets from 20 element data samples
#NOTE: pickle.load can execute arbitrary code - only open pickle files from a trusted source
with open(datapath / 'training/10k_nomods.pickle', 'rb') as f:
    #four objects are stored back-to-back in the file; each load call reads the next one,
    #so fracs and wave must be read even though only x_data / y_data are used below
    fracs = pickle.load(f)
    wave = pickle.load(f)
    x_data = pickle.load(f)
    y_data = pickle.load(f)

#create dataset
#input needs a placeholder "channel" dimension since single channel
#learned labels already has max_z + 2 channels from spec_array
#data has to match weights which default to float() so cast data as same
scale_factor = 1
x_data = torch.tensor(x_data[:,None,:].astype('float32'))
y_data = torch.tensor(y_data.astype('float32'))
spec_ds = TensorDataset(scale_factor * x_data, scale_factor * y_data)
#batch sizes
train_bs = 50
test_bs = 100
#create random split for training and validation (80% / 20%)
train_len = int(0.8 * len(x_data))
test_len = len(x_data) - train_len
#seed the split explicitly: without a generator random_split draws from the global RNG,
#so train/test membership (and every batch below) would differ on each kernel restart
split_seed = 0
split_rng = torch.Generator().manual_seed(split_seed)
train_ds, test_ds = random_split(spec_ds, [train_len, test_len], generator=split_rng)
train_dl = DataLoader(train_ds, batch_size=train_bs) #took out , shuffle=True for repeatability
test_dl = DataLoader(test_ds, batch_size=test_bs)
#

In [4]:
#hyper-parameters for this gradient walkthrough
el_count = 20 #first n elements used to construct model
wl_points = 760 #number of wavelength point measurements in data
learning_rate = 1

#build the loss, the network and a plain SGD optimizer over all of its parameters
loss_fn = nn.MSELoss(reduction='mean')
model = unet_001.LIBSUNet(el_count, wl_points)
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

#code from training module
model.train() #put the module in training mode
#baseline: L2 norm of every weight / bias tensor as-initialized
for layer_name, weights in model.named_parameters():
    initial_norm = torch.linalg.vector_norm(weights)
    print(f"{layer_name},{initial_norm}") #all non-zero



down_conv_1.0.weight,2.896204710006714
down_conv_1.0.bias,0.976987361907959
down_conv_1.2.weight,2.6898374557495117
down_conv_1.2.bias,0.19388988614082336
down_conv_2.0.weight,3.8123490810394287
down_conv_2.0.bias,0.3473900854587555
down_conv_2.2.weight,3.840686559677124
down_conv_2.2.bias,0.24461036920547485
down_conv_3.0.weight,5.400777339935303
down_conv_3.0.bias,0.35707327723503113
down_conv_3.2.weight,5.411995887756348
down_conv_3.2.bias,0.2372264415025711
down_conv_4.0.weight,7.655261993408203
down_conv_4.0.bias,0.32639551162719727
down_conv_4.2.weight,7.652919769287109
down_conv_4.2.bias,0.24704717099666595
down_conv_5.0.weight,10.845921516418457
down_conv_5.0.bias,0.3335774540901184
down_conv_5.2.weight,10.837780952453613
down_conv_5.2.bias,0.2319386750459671
up_trans_1.weight,10.840964317321777
up_trans_1.bias,0.416425883769989
up_conv_1.0.weight,7.6594390869140625
up_conv_1.0.bias,0.1695244461297989
up_conv_1.2.weight,7.660733699798584
up_conv_1.2.bias,0.24643990397453308
up_

In [5]:
#now predict / train on first batch from data loader and check gradients
#pull exactly one batch: this avoids loading a second batch just to break out of a loop,
#and keeps X, y bound to the same batch that pred and loss were computed from
X, y = next(iter(train_dl))
pred = model(X)
loss = loss_fn(pred, y)

# clear gradients before back prop
optimizer.zero_grad()
#check if gradients still None or explicit zeros
#(loss.backward() has not been called yet at this point in the notebook)
for name, param in model.named_parameters():
    print(f"{name},{param.grad}") #All None

down_conv_1.0.weight,None
down_conv_1.0.bias,None
down_conv_1.2.weight,None
down_conv_1.2.bias,None
down_conv_2.0.weight,None
down_conv_2.0.bias,None
down_conv_2.2.weight,None
down_conv_2.2.bias,None
down_conv_3.0.weight,None
down_conv_3.0.bias,None
down_conv_3.2.weight,None
down_conv_3.2.bias,None
down_conv_4.0.weight,None
down_conv_4.0.bias,None
down_conv_4.2.weight,None
down_conv_4.2.bias,None
down_conv_5.0.weight,None
down_conv_5.0.bias,None
down_conv_5.2.weight,None
down_conv_5.2.bias,None
up_trans_1.weight,None
up_trans_1.bias,None
up_conv_1.0.weight,None
up_conv_1.0.bias,None
up_conv_1.2.weight,None
up_conv_1.2.bias,None
up_trans_2.weight,None
up_trans_2.bias,None
up_conv_2.0.weight,None
up_conv_2.0.bias,None
up_conv_2.2.weight,None
up_conv_2.2.bias,None
up_trans_3.weight,None
up_trans_3.bias,None
up_conv_3.0.weight,None
up_conv_3.0.bias,None
up_conv_3.2.weight,None
up_conv_3.2.bias,None
up_trans_4.weight,None
up_trans_4.bias,None
up_conv_4.0.weight,None
up_conv_4.0.bias,None
up

In [6]:
#back-propagate the first-batch loss, then dump the raw gradient tensor stored on every parameter
loss.backward()
for layer_name, weights in model.named_parameters():
    layer_grad = weights.grad
    print(f"{layer_name},{layer_grad}") #we see non-zero and zero gradient tensors through the graph

down_conv_1.0.weight,tensor([[[ 2.1243e-08,  2.2586e-08,  2.4036e-08,  2.4395e-08,  2.3493e-08,
           2.2731e-08]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00]],

        [[-3.2534e-10, -2.8609e-10, -1.7307e-10, -8.5201e-11, -4.3322e-11,
          -2.5433e-11]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00]],

        [[ 9.8554e-09,  9.0545e-09,  7.8447e-09,  6.4903e-09,  6.5946e-09,
           8.0523e-09]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00]],

        [[ 4.6146e-08,  4.6612e-08,  4.7744e-08,  4.8263e-08,  4.8879e-08,
           5.0359e-08]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00]],

        [[ 8.9251e-11,  2.6236e-10,  4.7638e-10,  3.3319e-10, -1.8610e

In [16]:
#summarise each gradient tensor by its L2 norm: a compact zero / non-zero check per parameter
named_params = dict(model.named_parameters())
for layer_name in named_params:
    print(f"{layer_name},{torch.linalg.vector_norm(named_params[layer_name].grad)}") #all non-zero

down_conv_1.0.weight,1.33759712639403e-07
down_conv_1.0.bias,8.643703768029809e-05
down_conv_1.2.weight,0.0003370021586306393
down_conv_1.2.bias,0.00022665204596705735
down_conv_2.0.weight,2.8243284759810194e-06
down_conv_2.0.bias,7.016693416517228e-06
down_conv_2.2.weight,1.2935259292135015e-05
down_conv_2.2.bias,2.27140753850108e-05
down_conv_3.0.weight,2.2594245763229992e-07
down_conv_3.0.bias,4.0226154851552565e-07
down_conv_3.2.weight,6.427567313949112e-07
down_conv_3.2.bias,9.388311923430592e-07
down_conv_4.0.weight,1.2371143220946124e-08
down_conv_4.0.bias,2.4230024209259682e-08
down_conv_4.2.weight,3.656649738559281e-08
down_conv_4.2.bias,5.807896030773918e-08
down_conv_5.0.weight,5.1466924055887375e-09
down_conv_5.0.bias,1.0795077720615609e-08
down_conv_5.2.weight,1.7583513312047216e-08
down_conv_5.2.bias,2.7722281359388035e-08
up_trans_1.weight,1.1300055113849794e-08
up_trans_1.bias,7.633719434352315e-08
up_conv_1.0.weight,2.2477199479453702e-07
up_conv_1.0.bias,1.96062131863

In [7]:
#apply one optimizer update, then re-print the L2 norm of every weight / bias tensor
#so it can be compared with the as-initialized norms
optimizer.step() #leverages tensor gradients from backward() to update weights
for layer_name, weights in model.named_parameters():
    weight_norm = torch.linalg.vector_norm(weights)
    print(f"{layer_name},{weight_norm}") #all non-zero, note norm of first node changes so weights update

down_conv_1.0.weight,2.896204710006714
down_conv_1.0.bias,0.977003276348114
down_conv_1.2.weight,2.6898434162139893
down_conv_1.2.bias,0.19381438195705414
down_conv_2.0.weight,3.8123490810394287
down_conv_2.0.bias,0.34739255905151367
down_conv_2.2.weight,3.840686321258545
down_conv_2.2.bias,0.24460531771183014
down_conv_3.0.weight,5.400777339935303
down_conv_3.0.bias,0.3570733368396759
down_conv_3.2.weight,5.411995887756348
down_conv_3.2.bias,0.23722641170024872
down_conv_4.0.weight,7.655261993408203
down_conv_4.0.bias,0.32639551162719727
down_conv_4.2.weight,7.652919769287109
down_conv_4.2.bias,0.24704717099666595
down_conv_5.0.weight,10.845921516418457
down_conv_5.0.bias,0.3335774540901184
down_conv_5.2.weight,10.837780952453613
down_conv_5.2.bias,0.2319386750459671
up_trans_1.weight,10.840964317321777
up_trans_1.bias,0.416425883769989
up_conv_1.0.weight,7.6594390869140625
up_conv_1.0.bias,0.16952446103096008
up_conv_1.2.weight,7.660733699798584
up_conv_1.2.bias,0.24643996357917786
u

In [8]:
#Now process the second batch and see if we are getting updates to everything
#walk the loader by hand so only the two batches we need are loaded and
#X, y end up bound to the batch that pred and loss were computed from
batch_iter = iter(train_dl)
next(batch_iter) #skip the batch we already processed
X, y = next(batch_iter) #second batch; raises StopIteration if the loader has fewer than two batches
pred = model(X)
loss = loss_fn(pred, y)

In [9]:
# Clear gradients
#set_to_none=False keeps the .grad tensors and fills them with zeros; the PyTorch >= 2.0 default
#(set_to_none=True) would replace them with None and vector_norm below would raise
optimizer.zero_grad(set_to_none=False)
#check if gradients still None or explicit zeros
for name, param in model.named_parameters():
    #guard against parameters that have no gradient tensor at all
    grad_norm = None if param.grad is None else torch.linalg.vector_norm(param.grad)
    print(f"{name},{grad_norm}") #All zeros

down_conv_1.0.weight,0.0
down_conv_1.0.bias,0.0
down_conv_1.2.weight,0.0
down_conv_1.2.bias,0.0
down_conv_2.0.weight,0.0
down_conv_2.0.bias,0.0
down_conv_2.2.weight,0.0
down_conv_2.2.bias,0.0
down_conv_3.0.weight,0.0
down_conv_3.0.bias,0.0
down_conv_3.2.weight,0.0
down_conv_3.2.bias,0.0
down_conv_4.0.weight,0.0
down_conv_4.0.bias,0.0
down_conv_4.2.weight,0.0
down_conv_4.2.bias,0.0
down_conv_5.0.weight,0.0
down_conv_5.0.bias,0.0
down_conv_5.2.weight,0.0
down_conv_5.2.bias,0.0
up_trans_1.weight,0.0
up_trans_1.bias,0.0
up_conv_1.0.weight,0.0
up_conv_1.0.bias,0.0
up_conv_1.2.weight,0.0
up_conv_1.2.bias,0.0
up_trans_2.weight,0.0
up_trans_2.bias,0.0
up_conv_2.0.weight,0.0
up_conv_2.0.bias,0.0
up_conv_2.2.weight,0.0
up_conv_2.2.bias,0.0
up_trans_3.weight,0.0
up_trans_3.bias,0.0
up_conv_3.0.weight,0.0
up_conv_3.0.bias,0.0
up_conv_3.2.weight,0.0
up_conv_3.2.bias,0.0
up_trans_4.weight,0.0
up_trans_4.bias,0.0
up_conv_4.0.weight,0.0
up_conv_4.0.bias,0.0
up_conv_4.2.weight,0.0
up_conv_4.2.bias,0.0


In [10]:
#back-propagate the second-batch loss into the gradients that were just zeroed
loss.backward()
#check the L2 norm and if different than before
for layer_name, weights in model.named_parameters():
    grad_norm = torch.linalg.vector_norm(weights.grad)
    print(f"{layer_name},{grad_norm}") #all non-zero

down_conv_1.0.weight,4.787734724231996e-07
down_conv_1.0.bias,0.00023190787760540843
down_conv_1.2.weight,0.00046656312770210207
down_conv_1.2.bias,0.0005074075306765735
down_conv_2.0.weight,4.160306616540765e-06
down_conv_2.0.bias,1.0564353942754678e-05
down_conv_2.2.weight,1.5857747712288983e-05
down_conv_2.2.bias,2.5834964617388323e-05
down_conv_3.0.weight,3.1246153753272665e-07
down_conv_3.0.bias,5.059141017227375e-07
down_conv_3.2.weight,9.031633680933737e-07
down_conv_3.2.bias,1.0966899708364508e-06
down_conv_4.0.weight,6.350086234618857e-09
down_conv_4.0.bias,1.4009480686638653e-08
down_conv_4.2.weight,2.096642148785577e-08
down_conv_4.2.bias,3.3050440606530174e-08
down_conv_5.0.weight,3.5402138998108512e-09
down_conv_5.0.bias,7.302556781496605e-09
down_conv_5.2.weight,1.2297765472624178e-08
down_conv_5.2.bias,1.9543325180393367e-08
up_trans_1.weight,9.030779679619627e-09
up_trans_1.bias,6.445468869742399e-08
up_conv_1.0.weight,1.674168572662893e-07
up_conv_1.0.bias,1.4792085778

In [11]:
#second optimizer update: collect one "name,norm" line per parameter, then print them in order
optimizer.step() #leverages tensor gradients from backward() to update weights
norm_report = [
    f"{layer_name},{torch.linalg.vector_norm(weights)}"
    for layer_name, weights in model.named_parameters()
]
for report_line in norm_report:
    print(report_line) #all non-zero, note norm of first node changes so weights update

down_conv_1.0.weight,2.896204710006714
down_conv_1.0.bias,0.9770189523696899
down_conv_1.2.weight,2.6898491382598877
down_conv_1.2.bias,0.19375565648078918
down_conv_2.0.weight,3.8123488426208496
down_conv_2.0.bias,0.347395122051239
down_conv_2.2.weight,3.840686559677124
down_conv_2.2.bias,0.2446010708808899
down_conv_3.0.weight,5.400777339935303
down_conv_3.0.bias,0.3570733964443207
down_conv_3.2.weight,5.411995887756348
down_conv_3.2.bias,0.23722641170024872
down_conv_4.0.weight,7.655261993408203
down_conv_4.0.bias,0.3263954818248749
down_conv_4.2.weight,7.652919769287109
down_conv_4.2.bias,0.24704715609550476
down_conv_5.0.weight,10.845921516418457
down_conv_5.0.bias,0.3335774540901184
down_conv_5.2.weight,10.837780952453613
down_conv_5.2.bias,0.2319386750459671
up_trans_1.weight,10.840964317321777
up_trans_1.bias,0.416425883769989
up_conv_1.0.weight,7.6594390869140625
up_conv_1.0.bias,0.16952447593212128
up_conv_1.2.weight,7.660733699798584
up_conv_1.2.bias,0.24644000828266144
up_t