In [1]:
import utils.hgg_utils as hu
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm.notebook import tqdm 
from model import unet
from utils.dice import dice_loss as dice
from utils.dice import dice_coef as dice_coef
from sklearn.utils import shuffle
from IPython import display
from tensorflow.keras.mixed_precision import experimental as mixed_precision
import time
import pickle
import gc

In [2]:
# Global Keras dtype policy: "mixed_float16" computes in float16 while keeping variables in
# float32. Swap the string for "float32" to train in full precision.
policy = mixed_precision.Policy("mixed_float16")
mixed_precision.set_policy(policy)

In [3]:
num_to_load = 5   # brains held in memory (and trained on) per chunk
n_slices = 155    # slices stored per brain volume


# The value that varies between experiments: `ds` (passed to `unet` and embedded in every output filename)

In [4]:
ds = 4  # forwarded to unet(..., ds=ds); also prefixes every pickle / weights / JSON filename

### Collect the scan and mask paths, then remove outliers

In [5]:
# Per-patient folder lists: normalized scans feed the inputs, the plain HGG folders feed the masks.
# remove_outliers is applied to both lists; presumably it drops the same patients from each so the
# lists stay index-aligned — TODO confirm, since the paired shuffle below relies on that.
patients = hu.remove_outliers(hu.get_each_normalized_hgg_folder())
masks = hu.remove_outliers(hu.get_each_hgg_folder())

# One seeded permutation applied to both lists keeps scan i paired with mask i.
patients, masks = shuffle(patients, masks, random_state=1)

### Train/test split indices

In [6]:
train_data_ratio = 0.8  # share of the shuffled patients used for training; the rest form the test set

In [7]:
# Training patients occupy the half-open range [train_start, train_stop); the ratio is
# rounded to a whole number of patients.
train_start = 0
train_stop = int(np.round(len(patients) * train_data_ratio))
print(train_start)
train_stop

0


194

In [8]:
# Test patients are everything after the training range: [test_start, test_stop).
test_stop = len(patients)
test_start = train_stop
print(test_start)
test_stop

194


243

### Save the training-set paths

In [9]:
# Training split: every patient before test_start.
train_data = patients[:test_start]
train_masks = masks[:test_start]

fname_train_data = "ds_{}_train_data.pkl".format(ds)
fname_train_masks = "ds_{}_train_masks.pkl".format(ds)

# Persist both path lists so the exact split can be reloaded later.
for fname, payload in ((fname_train_data, train_data), (fname_train_masks, train_masks)):
    with open(fname, 'wb') as file_pi:
        pickle.dump(payload, file_pi)

### Save the test-set paths

In [10]:
# Test split: every patient from test_start onward.
test_data = patients[test_start:]
test_masks = masks[test_start:]

fname_test_data = "ds_{}_test_data.pkl".format(ds)
fname_test_masks = "ds_{}_test_masks.pkl".format(ds)

# Persist both path lists so the held-out patients can be reloaded for evaluation.
for fname, payload in ((fname_test_data, test_data), (fname_test_masks, test_masks)):
    with open(fname, 'wb') as file_pi:
        pickle.dump(payload, file_pi)

### Preallocate buffers for one chunk of scans and masks

In [11]:
# Buffers for one chunk: num_to_load brains x n_slices slices each, 240x240 spatial
# (matching unet's input_size). Allocated directly as float32 — the dtype the training loop
# converts to before fit() — instead of numpy's default float64. The scan buffer is then
# ~0.7 GB instead of ~1.4 GB, and the first chunk needs no float64 -> float32 copy.
some_data = np.ones([num_to_load * n_slices, 240, 240, 4], dtype=np.float32)
some_masks = np.ones([num_to_load * n_slices, 240, 240, 1], dtype=np.float32)

In [12]:
def load_n_brains(data, start, stop, paths, end, num_slices=155):
    """Fill ``data`` in place with the slices of brains ``paths[start:stop]``, stopping at ``end``.

    Brains are packed back-to-back: the k-th brain loaded occupies rows
    ``[k * num_slices, (k + 1) * num_slices)`` of ``data``. Loading stops as soon
    as an index reaches ``end`` (exclusive), so a chunk that straddles the
    train/test boundary only loads the brains before it. Rows after the last
    loaded brain are left untouched and keep whatever they held before.

    Parameters
    ----------
    data : ndarray
        Preallocated buffer, modified in place. It needs at least
        ``(stop - start) * num_slices`` rows. Assumed layout is
        (slice, height, width, channel) — TODO confirm against the ``hu`` helpers.
    start, stop : int
        Half-open range of indices into ``paths`` to load.
    paths : sequence
        Per-patient paths accepted by ``hu.get_a_multimodal_tensor``.
    end : int
        Exclusive upper bound on the indices that may be loaded (e.g. ``train_stop``).
    num_slices : int, optional
        Number of leading slices copied per brain (default 155).

    Returns
    -------
    ndarray
        The same ``data`` object that was passed in.
    """
    brains_seen = 0
    for idx in range(start, stop):
        # ``>=`` instead of ``!=``: also refuses to read past ``end`` when ``start`` already exceeds it.
        if idx >= end:
            break
        # Element 0 of get_a_multimodal_tensor's result is taken as the image array
        # (presumably an (image, ...) tuple — TODO confirm).
        four_channel_scan = hu.reshape_tensor_with_slices_first(
            hu.get_a_multimodal_tensor(paths[idx])[0]
        )
        offset = brains_seen * num_slices
        data[offset:offset + num_slices] = four_channel_scan[:num_slices]
        brains_seen += 1
    return data

In [13]:
def load_n_masks(data, start, stop, paths, end, num_slices=155):
    """Fill ``data`` in place with binarized masks for ``paths[start:stop]``, stopping at ``end``.

    Each mask is read with ``hu.get_a_mask_tensor``, binarized with
    ``hu.convert_mask_to_binary_mask``, and reordered with
    ``hu.reshape_tensor_with_slices_first``. Its first ``num_slices`` slices are
    written to rows ``[k * num_slices, (k + 1) * num_slices)`` for the k-th mask
    loaded. Loading stops as soon as an index reaches ``end`` (exclusive). Rows
    after the last loaded mask are left untouched.

    Parameters
    ----------
    data : ndarray
        Preallocated buffer, modified in place. It needs at least
        ``(stop - start) * num_slices`` rows. Assumed layout is
        (slice, height, width, 1) — TODO confirm against the ``hu`` helpers.
    start, stop : int
        Half-open range of indices into ``paths`` to load.
    paths : sequence
        Per-patient paths accepted by ``hu.get_a_mask_tensor``.
    end : int
        Exclusive upper bound on the indices that may be loaded (e.g. ``train_stop``).
    num_slices : int, optional
        Number of leading slices copied per mask (default 155).

    Returns
    -------
    ndarray
        The same ``data`` object that was passed in.
    """
    masks_seen = 0
    for idx in range(start, stop):
        # ``>=`` instead of ``!=``: also refuses to read past ``end`` when ``start`` already exceeds it.
        if idx >= end:
            break
        mask = hu.reshape_tensor_with_slices_first(
            hu.convert_mask_to_binary_mask(hu.get_a_mask_tensor(paths[idx]))
        )
        offset = masks_seen * num_slices
        data[offset:offset + num_slices] = mask[:num_slices]
        masks_seen += 1
    return data

### Train: 5 independent runs of 20 epochs each, streaming the training set through memory in chunks

In [14]:
# Number of num_to_load-sized chunks needed to cover every training brain (ceiling division).
# This is 39 for 194 training brains, the value that used to be hard-coded, and it now
# follows train_stop automatically.
chunks = (train_stop + num_to_load - 1) // num_to_load

# Row range of the chunk buffers handed to fit(): a full chunk uses all num_to_load brains...
beg = 0
end = num_to_load * n_slices
# ...while the final chunk may hold fewer brains, and only its first `truncated` rows are valid.
truncated = n_slices * (train_stop - (chunks - 1) * num_to_load)

# NOTE: once used, an optimizer instance carries state (iteration count, Adam moment slots).
# Give each model its own instance rather than passing this object to several compile() calls.
my_opt = tf.keras.optimizers.Adam(learning_rate=1e-5)

In [15]:

start_time = time.time()

n_runs = 5     # independent trainings from scratch
n_epochs = 20  # passes over the full training set per run

for run in tqdm(range(n_runs)):

    print("********************")
    print("Run:", run)

    model = unet(input_size=(240, 240, 4), ds=ds)
    # Save architecture (same file name every run, so later runs overwrite it).
    model_json_name = "unet_ds_{}.json".format(ds)
    with open(model_json_name, "w") as json_file:
        json_file.write(model.to_json())

    # Build a fresh optimizer per run from my_opt's hyperparameters. Reusing the single my_opt
    # instance across compile() calls would carry its iteration count (which drives Adam's
    # bias correction) from one run into the next, so the runs would not be independent.
    run_opt = type(my_opt).from_config(my_opt.get_config())
    model.compile(optimizer=run_opt, loss=dice, metrics=[dice_coef])

    run_history = []  # run_history[epoch][chunk] -> History.history dict of that fit() call

    for epoch in tqdm(range(n_epochs)):

        epoch_history = []

        for i in range(chunks):
            chunk_start = num_to_load * i
            chunk_stop = chunk_start + num_to_load

            # Refill the chunk buffers in place. copy=False skips a redundant float32 copy
            # once the buffers already are float32.
            some_data = load_n_brains(some_data, chunk_start, chunk_stop, patients, train_stop).astype(np.float32, copy=False)
            some_masks = load_n_masks(some_masks, chunk_start, chunk_stop, masks, train_stop).astype(np.float32, copy=False)

            # When the last chunk holds fewer than num_to_load brains, only its first `truncated`
            # rows are refreshed; the remaining rows still contain the previous chunk's slices.
            n_rows = end if chunk_stop <= train_stop else truncated

            # Truncate BEFORE shuffling so those stale rows can never enter this chunk.
            # The fixed seed gives the same permutation every epoch, so validation_split
            # (the last 20% of the arrays) holds out the same slices each time.
            x_chunk, y_chunk = shuffle(some_data[beg:n_rows], some_masks[beg:n_rows], random_state=1)

            history = model.fit(x_chunk, y_chunk, validation_split=0.2, epochs=1, batch_size=16, verbose=0)

            epoch_history.append(history.history)
            del x_chunk, y_chunk
            gc.collect()
        print("Epoch", epoch, "completed")
        print("Elapsed time:", (time.time() - start_time)/60.0, "minutes" )
        run_history.append(epoch_history)

    print()
    print("Saving run", run, "loss etc.")
    history_name = "ds_{}_run_{}_histories.pkl".format(ds, run)

    with open(history_name, 'wb') as file_pi:
        pickle.dump(run_history, file_pi)

    model_weights_name = "ds_{}_run_{}_model_weights.h5".format(ds, run)

    print("Saving run", run, "model weights as", model_weights_name)
    model.save_weights(model_weights_name)

    del model, run_opt
    gc.collect()

print("Total time:", (time.time() - start_time)/60.0, "minutes")

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))

********************
Run: 0
Epoch 0 completed
Elapsed time: 6.281456716855367 minutes
Epoch 1 completed
Elapsed time: 12.42411817709605 minutes
Epoch 2 completed
Elapsed time: 18.56976407766342 minutes
Epoch 3 completed
Elapsed time: 24.738404874006907 minutes
Epoch 4 completed
Elapsed time: 30.881374418735504 minutes
Epoch 5 completed
Elapsed time: 37.043267011642456 minutes
Epoch 6 completed
Elapsed time: 43.203418425718944 minutes
Epoch 7 completed
Elapsed time: 49.35093146165212 minutes
Epoch 8 completed
Elapsed time: 55.574229045708975 minutes
Epoch 9 completed
Elapsed time: 61.73570133050283 minutes
Epoch 10 completed
Elapsed time: 67.88629736502965 minutes
Epoch 11 completed
Elapsed time: 74.0407768646876 minutes
Epoch 12 completed
Elapsed time: 80.1841358780861 minutes
Epoch 13 completed
Elapsed time: 86.34899410009385 minutes
Epoch 14 completed
Elapsed time: 92.49589199225107 minutes
Epoch 15 completed
Elapsed time: 98.64944149653117 minutes
Epoch 16 completed
Elapsed time: 10

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

In [16]:
# run_history[epoch][chunk] is the History.history dict returned by that chunk's fit() call, e.g.:
#print(run_history[0][0])