In [1]:
%load_ext autoreload
%autoreload 2

import os, glob
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import numpy as np
import scipy.io
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import backend as K
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint

import matplotlib as mpl
mpl.rc('image', cmap='inferno')

import models.model_2d as mod
import forward_model as fm
import utils as ut

In [None]:
# Shell out to `gpustat` (must be installed on the host) to inspect the GPUs
# before one is selected in the next cell.
!gpustat

In [None]:
# Pin the GPU selection for this kernel: order devices by PCI bus ID and
# expose only device "0" to CUDA.
# NOTE(review): these only take effect if set before the first GPU is initialised.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

# Training code for 2D spatially-varying deconvolutions

## Make dataset and dataloader for training data

In [None]:
# Mini-batch size shared by the training and validation pipelines
batch_size = 2
# Passed as `num_parallel_calls` so tf.data picks the parallelism level itself
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
# Locations of the paired training data.
# NOTE(review): absolute, machine-specific paths -- adjust for your environment.
target_dir = '/home/kyrollos/LearnedMiniscope3D/Data/Target/'  # path to objects (ground truth)
input_dir = '/home/kyrollos/LearnedMiniscope3D/Data/Train/'    # path to simulated measurements (inputs to deconv.)

# Inputs and targets are paired purely by their position in the sorted
# listings -- presumably matching file names; verify against the data layout.
target_path = sorted(glob.glob(target_dir + '*'))
input_path = sorted(glob.glob(input_dir + '*'))

# Fail early with a clear message instead of deep inside the tf.data pipeline.
if not target_path:
    raise FileNotFoundError('No target files found in ' + target_dir)
if len(input_path) != len(target_path):
    raise ValueError(
        'Number of input files ({}) does not match number of target files ({})'.format(
            len(input_path), len(target_path)))

# Count the files that are actually loaded; os.listdir() would also count
# hidden files and sub-directories that glob('*') skips.
image_count = len(target_path)
print(image_count)

In [None]:
# Create a first dataset of (input path, target path) pairs
dataset = tf.data.Dataset.from_tensor_slices((input_path, target_path))

# Shuffle once with a fixed seed: `reshuffle_each_iteration=False` keeps the
# order stable across epochs, and the seed keeps it stable across kernel
# restarts, so the train/validation split below is reproducible (important
# when previously saved weights are reloaded and evaluated on `val_ds`).
split_seed = 0
dataset = dataset.shuffle(image_count, seed=split_seed, reshuffle_each_iteration=False)

# Split into train/validation: the first 25% of the shuffled pairs are held out
val_size = int(image_count * 0.25)
train_ds = dataset.skip(val_size)
val_ds = dataset.take(val_size)

# Number of image pairs in each split
print(tf.data.experimental.cardinality(train_ds).numpy())
print(tf.data.experimental.cardinality(val_ds).numpy())

# Turn file paths into (input, target) tensors
train_ds = train_ds.map(ut.parse_function, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(ut.parse_function, num_parallel_calls=AUTOTUNE)

# Project helper that takes the batch size -- presumably batches/prefetches; see utils
train_ds = ut.configure_for_performance(train_ds,batch_size)
val_ds = ut.configure_for_performance(val_ds,batch_size)

# Cardinality again, after the pipeline has been configured
print(tf.data.experimental.cardinality(train_ds).numpy())
print(tf.data.experimental.cardinality(val_ds).numpy())

In [None]:
# Visualize one validation pair to make sure all is good
input_batch, target_batch = next(iter(val_ds))
f, ax = plt.subplots(1, 2, figsize=(15, 15))

# Left panel: first input of the batch; right panel: its target (channel 0 of each)
panels = zip(ax, (input_batch, target_batch), ('Input Data', 'Target Data'))
for panel, image_batch, panel_title in panels:
    panel.imshow(image_batch[0, :, :, 0], vmax=1)
    panel.set_title(panel_title)

print(input_batch[0, :, :, 0].shape)

# Load the PSFs and initialize the network to train

Here we initialize with 9 PSFs taken from different parts of the field of view.

In [None]:
# Network variant to train; choices are 'multiwiener', 'wiener', 'unet'
model_type = 'multiwiener'

# .mat file used to initialize the Wiener filters (initialize with 9 PSFs)
# and the key under which the PSF stack is stored inside it
filter_key = 'multiWienerPSFStack_40z'
filter_init_path = '../data/multiWienerPSFStack_40z_aligned.mat'

In [None]:
# Build the network selected by `model_type`.
#
# All three variants share the same U-Net layer configuration; the Wiener
# variants additionally receive PSF(s) loaded from `filter_init_path` and an
# initial K value (presumably the Wiener regularisation term -- see models.model_2d).
unet_kwargs = dict(
    encoding_cs=[24, 64, 128, 256, 512, 1024],
    center_cs=1024,
    decoding_cs=[512, 256, 128, 64, 24, 24],
    skip_connections=[True, True, True, True, True, False],
)

if model_type == 'unet':
    model = mod.UNet(486, 648, **unet_kwargs)

elif model_type == 'wiener':
    registered_psfs_path = filter_init_path
    psfs = scipy.io.loadmat(registered_psfs_path)
    psfs = psfs[filter_key]
    psfs = psfs[:, :, 0, 0]        # single 2D PSF: index 0 on both trailing axes
    psfs = psfs / np.max(psfs)     # normalise so the peak value is 1

    Ks = 1

    model = mod.UNet_wiener(486, 648, psfs, Ks, **unet_kwargs)

    print(psfs.shape, Ks)

elif model_type == 'multiwiener':
    registered_psfs_path = filter_init_path
    psfs = scipy.io.loadmat(registered_psfs_path)
    psfs = psfs[filter_key]

    psfs = psfs[:, :, :, 0]        # keep the whole PSF stack, index 0 on the last axis
    psfs = psfs / np.max(psfs)     # normalise so the global peak value is 1

    # One K per PSF; derived from the stack so it stays in sync with the file
    # (equals (1, 1, 9) for the 9-PSF stack used here).
    Ks = np.ones((1, 1, psfs.shape[-1]))

    model = mod.UNet_multiwiener_resize(486, 648, psfs, Ks, **unet_kwargs)

    print('initialized filter shape:', psfs.shape, 'initialized K shape:', Ks.shape)

else:
    # Without this, a typo in `model_type` silently leaves `model` undefined
    # (or stale from a previous run of this cell).
    raise ValueError(
        "Unknown model_type {!r}; choices are 'multiwiener', 'wiener', 'unet'".format(model_type))

In [None]:
# Build the model for inputs of shape (None, 486, 648, 1) -- the same 486 x 648
# size passed to the model constructor, with the leading (batch) dimension left
# unspecified -- then print the model summary.
model.build((None, 486, 648, 1))

model.summary()

# Train

In [None]:
## Training with TF.Dataset
initial_learning_rate = 1e-4
optimizer = tf.keras.optimizers.Adam(
    learning_rate=initial_learning_rate,
    beta_1=0.9,
    beta_2=0.999,
    amsgrad=False,
)  # 1e-3 diverges

# Per-epoch history, kept for plotting
train_loss_results = []
train_accuracy_results = []
validtate_loss_results = []

num_epochs = 1000
loss_func = ut.SSIMLoss_l1
learning_rate_counter = 0

# Progress-line templates (per step / per epoch)
step_report = "Epoch {:03d}: Step: {:03d}, Loss: {:.3f}, MSE: {:.3}"
epoch_report = "Epoch {:03d}: MSE: {:.3}, Training Loss: {:.3f}, Validation Loss: {:.3f}"

for epoch in range(num_epochs):
    # Fresh running metrics for every epoch
    validation_loss_avg = tf.keras.metrics.Mean()
    epoch_loss_avg = tf.keras.metrics.Mean()
    epoch_accuracy = tf.keras.metrics.MeanSquaredError()

    # ---- Training pass ----
    for iter_num, (x, y) in enumerate(train_ds):
        # One optimisation step: ut.grad returns the batch loss and the gradients
        loss_value, grads = ut.grad(model, loss_func, x, y)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        # Running averages; the MSE uses a forward pass with the updated weights
        epoch_loss_avg.update_state(loss_value)
        epoch_accuracy.update_state(y, model(x))

        # Overwrite the same console line after every step
        print(step_report.format(epoch, iter_num,
                                 epoch_loss_avg.result(),
                                 epoch_accuracy.result()),
              end='\r')

    # ---- End of epoch: record training metrics ----
    train_loss_results.append(epoch_loss_avg.result())
    train_accuracy_results.append(epoch_accuracy.result())

    # ---- Validation pass (no weight updates) ----
    for x_val, y_val in val_ds:
        val_loss_value = loss_func(model, x_val, y_val)
        validation_loss_avg.update_state(val_loss_value)

    validtate_loss_results.append(validation_loss_avg.result())

    # Epoch summary
    print(epoch_report.format(epoch,
                              epoch_accuracy.result(),
                              epoch_loss_avg.result(),
                              validation_loss_avg.result()))

In [None]:
# Optional: uncomment to restore previously saved weights (e.g. to skip training above).
# model.load_weights('./saved_models/multiwiener')

In [None]:
# Test on validation data: reconstruct one validation image and show it next
# to its ground truth.
input_batch, target_batch = next(iter(val_ds))

# Prefer the second image of the batch, but fall back to the first one when
# the batch holds a single image (avoids an out-of-range index).
imnum = min(1, int(input_batch.shape[0]) - 1)

f, ax = plt.subplots(1, 2, figsize=(15,15))
ax[0].imshow((target_batch[imnum,:,:,0]))
ax[0].set_title('Target Data')

# Slicing (rather than indexing + a hard-coded reshape) keeps the batch and
# channel axes, so the image size is not duplicated here. Channel 0 only.
test = model(input_batch[imnum:imnum + 1, :, :, 0:1])

# Drop singleton axes so imshow receives a plain 2D image whether or not the
# model output carries a trailing channel axis.
ax[1].set_title('recon')
ax[1].imshow(np.squeeze(test[0]))


Once training is working, save your model using: 

    model.save_weights('./saved_models/model_name')

You can save after training is complete, or periodically throughout epochs.