# Train cryo-CARE Network

In this notebook we initialize a new cryo-CARE model and train it on the training data generated in notebook [03].

In [None]:
from train_cryo_care import CryoCARE
from csbdeep.models import Config
from csbdeep.utils import plot_history
import subprocess
import numpy as np

In [None]:
# GPU selection: edit GPUs_to_use only.
import os

GPUs_to_use = "0" # <<< use up to 3 GPUs (reportedly crashes with 4) (ids 0-7, format e.g. = '1,2,4')

# Parse the comma-separated id list, tolerating spaces and stray/trailing commas
# (e.g. "1, 2," -> ["1", "2"]). A plain len(split(",")) would count empty entries as GPUs.
gpu_ids = [gpu_id.strip() for gpu_id in GPUs_to_use.split(",") if gpu_id.strip()]

# An empty selection hides every GPU. Still count it as one device so the
# single-device code paths (e.g. the `number_of_GPUs > 1` check when saving) are taken.
number_of_GPUs = max(len(gpu_ids), 1)

# NOTE(review): the previous cell already imports csbdeep / train_cryo_care, which presumably
# import TensorFlow. This setting only takes effect if CUDA has not been initialised yet.
# Verify on first run, or move this cell above the imports.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids)

## Create Model

We use the standard `CSBDeep` config.

**[Note]** Set 'axes' to 'YX' for 2D (slice-based) training, or to 'ZYX' for 3D (volume-based) training. Choose whichever matches the version of notebook [03] you used to generate the training data.

In [None]:
# Training schedule: each epoch should pass over every training volume exactly once, i.e.
#   train_steps_per_epoch * train_batch_size == num_train_volumes   (75 * 16 = 1200)
TRAIN_BATCH_SIZE = 16
TRAIN_STEPS_PER_EPOCH = 75

conf = Config(
    axes='ZYX',  # 3D (volume-based) training; see the note above for the 2D setting
    train_loss='mse',
    train_epochs=200,
    train_steps_per_epoch=TRAIN_STEPS_PER_EPOCH,
    train_batch_size=TRAIN_BATCH_SIZE,
)
vars(conf)

In [None]:
# Build the network from `conf` under the name 'denoiser_model' (empty basedir).
# CryoCARE ships its own data augmentation: patches are randomly rotated
# by 90 degrees about the Y-axis.
MODEL_NAME = 'denoiser_model'
model = CryoCARE(conf, MODEL_NAME, basedir='')

## Load Train/Validation Data

In [None]:
# Load the training/validation patches written by the data-generation notebook.
# np.load on an .npz returns a lazily-read NpzFile that keeps the file open, so copy
# every array out inside a `with` block and close the handle immediately.
# `data` remains indexable by key (data['X'], data['X_val'], ...) as before.
with np.load('train_data/train_data.npz') as npz_file:
    data = {key: npz_file[key] for key in npz_file.files}

In [None]:
# Unpack inputs (X) and targets (Y) for training and validation.
X = data['X']
Y = data['Y']
X_val = data['X_val']
Y_val = data['Y_val']

# Fail fast, before a long training run, if the input/target pairs don't line up.
assert X.shape == Y.shape, f"X {X.shape} and Y {Y.shape} differ in shape"
assert X_val.shape == Y_val.shape, f"X_val {X_val.shape} and Y_val {Y_val.shape} differ in shape"
assert X.shape[1:] == X_val.shape[1:], "training and validation patches differ in shape"

## Train Model

In [None]:
# Train on (X, Y) and validate on (X_val, Y_val). The GPU count comes from the GPU settings cell.
validation_pairs = (X_val, Y_val)
history = model.train(X, Y, validation_pairs, numGPU=number_of_GPUs)

In [None]:
# Metric names recorded during training (the plot below uses 'loss' and 'val_loss').
list(history.history.keys())

In [None]:
# Plot training vs. validation loss.
# plt.show() renders the figure right away. Without it, the plot may not appear when the
# whole notebook is run in one go (presumably because the final cell shuts the kernel down).
from matplotlib import pyplot as plt
plt.figure(figsize=(16,5))
plot_history(history, 'loss', 'val_loss')
plt.show()

In [None]:
# With multiple GPUs the trained network is wrapped in a multi-GPU model and the true
# weights sit one layer down (layers[-2]), so save that inner model's weights instead.
if number_of_GPUs > 1:
    model_dir = 'denoiser_model'
    one_gpu_model = model.keras_model.layers[-2]
    one_gpu_model.save_weights(os.path.join(model_dir, 'multi_gpu_model.h5'))
    # Remove the not-so-useful wrapper checkpoints. This is done in Python rather than via
    # `rm` in a shell so it is portable. A missing file is not an error (the old
    # `rm` call did not raise either).
    for leftover in ('weights_best.h5', 'weights_last.h5'):
        try:
            os.remove(os.path.join(model_dir, leftover))
        except FileNotFoundError:
            pass

In [None]:
# Shut the kernel down so the GPUs it holds are freed - not a very elegant way.
# This will result in a pop up dialog saying 'The kernel appears to have died. It will restart automatically.'
# You can ignore this and accept; the script should have completed successfully.
# NOTE(review): this also stops "Run All" here, so any cell added after this one will not run.

exit()