# CARE Model Training Notebook

This notebook trains a CARE (Content-Aware Image Restoration) network with the CSBDeep library.

## Setup

Make sure you have the required packages installed:
```bash
pip install tensorflow csbdeep tifffile numpy matplotlib
```

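For GPU support, install an NVIDIA driver plus the CUDA and cuDNN versions that match your TensorFlow release. The cell below lists the GPUs TensorFlow can see; an empty list means training will run on the CPU.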

In [None]:
import tensorflow as tf
print("GPUs:", tf.config.list_physical_devices('GPU'))

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from tifffile import imread
from csbdeep.utils import axes_dict, plot_some, plot_history
from csbdeep.utils.tf import limit_gpu_memory
from csbdeep.io import load_training_data
from csbdeep.models import Config, CARE

### Load Training Data

**Note:** Update the path below to match your local directory structure.

In [None]:
# TODO: Update this path to your patches file
(X,Y), (X_val,Y_val), axes = load_training_data(
    'path/to/your/data/fixed/CARE/raw_data/for_training_5ms/patches/patches.npz', 
    validation_split=0.1, 
    verbose=True
)

c = axes_dict(axes)['C']
n_channel_in, n_channel_out = X.shape[c], Y.shape[c]
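`load_training_data` holds out part of the patches for validation, and `axes_dict(axes)['C']` looks up which array dimension holds the channels. The sketch below mimics both steps in plain NumPy on synthetic arrays. The helper names `split_validation` and `channel_axis`, the `'SYXC'` axes string, and taking the *last* 10% of samples as validation data are assumptions for illustration, not csbdeep's implementation.

```python
import numpy as np

def split_validation(X, Y, validation_split):
    """Hold out the last `validation_split` fraction of samples (assumed strategy)."""
    n_val = int(round(len(X) * validation_split))
    n_train = len(X) - n_val
    return (X[:n_train], Y[:n_train]), (X[n_train:], Y[n_train:])

def channel_axis(axes):
    """Return the index of the 'C' axis in an axes string such as 'SYXC'."""
    c = axes.find('C')
    if c < 0:
        raise ValueError("axes string has no channel axis 'C': %r" % axes)
    return c

# Synthetic stand-in for patches.npz: 100 patches of 64x64 pixels, one channel.
axes_demo = 'SYXC'
rng = np.random.default_rng(0)
X_all = rng.random((100, 64, 64, 1), dtype=np.float32)
Y_all = rng.random((100, 64, 64, 1), dtype=np.float32)

(X_tr, Y_tr), (X_v, Y_v) = split_validation(X_all, Y_all, validation_split=0.1)
c_idx = channel_axis(axes_demo)
n_in, n_out = X_tr.shape[c_idx], Y_tr.shape[c_idx]
print(X_tr.shape, X_v.shape, n_in, n_out)  # (90, 64, 64, 1) (10, 64, 64, 1) 1 1
```

Because the channel axis is looked up by name rather than hard-coded, the same line works whether the patches are 2D (`SYXC`) or 3D (`SZYXC`).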

### Visualize Validation Patches

In [None]:
plt.figure(figsize=(12,5))
plot_some(X_val[:5],Y_val[:5])
plt.suptitle('5 example validation patches (top row: source, bottom row: target)');
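`plot_some` takes care of contrast scaling and tiling internally. For readers who want to build a similar figure by hand, here is a hypothetical NumPy sketch (not `plot_some`'s implementation) that scales each patch between two intensity percentiles and tiles sources above targets. The default percentiles `pmin=2.0` and `pmax=99.8` are assumptions.

```python
import numpy as np

def normalize_percentile(img, pmin=2.0, pmax=99.8, eps=1e-20):
    """Map the pmin-th percentile to 0 and the pmax-th to 1 (values outside are not clipped)."""
    lo, hi = np.percentile(img, [pmin, pmax])
    return (img - lo) / (hi - lo + eps)

def montage(sources, targets):
    """Tile patches into a 2-row image: sources on top, targets below."""
    top = np.concatenate([normalize_percentile(p) for p in sources], axis=1)
    bottom = np.concatenate([normalize_percentile(p) for p in targets], axis=1)
    return np.concatenate([top, bottom], axis=0)

rng = np.random.default_rng(0)
src = rng.random((5, 64, 64))  # 5 synthetic single-channel patches (YX)
tgt = rng.random((5, 64, 64))
grid = montage(src, tgt)
print(grid.shape)  # (128, 320)
```

With the real data, pass `X_val[:5, ..., 0]` and `Y_val[:5, ..., 0]` and show the result with `plt.imshow(grid, clim=(0, 1))`. Per-patch normalization makes dim and bright patches equally visible, at the cost of hiding absolute intensity differences between them.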

### Configure Model

In [None]:
config = Config(
    axes, n_channel_in, n_channel_out,
    train_steps_per_epoch=400,
    train_batch_size=20,
    probabilistic=True,
)
print(config)
vars(config)
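`train_steps_per_epoch` sets how many batches make up one epoch, so an epoch is a fixed number of sampled patches rather than one pass over the training set. With the values above, each epoch draws 400 batches of 20 patches, i.e. 8000 patches. The hypothetical helper below does this bookkeeping; the training set size of 9000 and the 100 epochs are illustrative assumptions, so substitute `len(X)` and the `train_epochs` value shown by `print(config)`.

```python
def training_budget(n_train, steps_per_epoch, batch_size, epochs):
    """Return (patches per epoch, total patches, approximate passes over the training set)."""
    per_epoch = steps_per_epoch * batch_size
    total = per_epoch * epochs
    return per_epoch, total, total / n_train

# n_train=9000 and epochs=100 are illustrative assumptions.
per_epoch, total, passes = training_budget(n_train=9000, steps_per_epoch=400,
                                           batch_size=20, epochs=100)
print(per_epoch, total, round(passes, 1))  # 8000 800000 88.9
```

If the number of passes comes out very large for a small dataset, reducing `train_steps_per_epoch` or the number of epochs is one way to limit overfitting.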

### Train Model

**Note:** Update the `basedir` below to specify where you want to save the trained model. The model (configuration and weights) is saved in `basedir/my_model`.

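To monitor training with TensorBoard, run the following in a separate terminal from the notebook's working directory, then open http://localhost:6006 in a browser: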
```bash
tensorboard --logdir=models
```

In [None]:
# TODO: Update basedir to your desired model save location
model = CARE(config, 'my_model', basedir='models')
history = model.train(X,Y, validation_data=(X_val,Y_val))
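With `probabilistic=True`, the network predicts a per-pixel Laplace distribution instead of a single value, and the config should report a Laplace loss (check `train_loss` in the `print(config)` output). The NumPy function below is a sketch of the standard Laplace negative log-likelihood, not csbdeep's Keras loss. It assumes the output stacks the location `mu` and scale `b` along the last axis, which matches how the evaluation cell below keeps only the first half of the channels.

```python
import numpy as np

def laplace_nll(y_true, y_pred):
    """Mean Laplace negative log-likelihood.

    Assumes y_pred stacks per-pixel location (mu) and scale (b > 0)
    along the last axis: y_pred[..., :n] = mu, y_pred[..., n:] = b.
    """
    n = y_pred.shape[-1] // 2
    mu, b = y_pred[..., :n], y_pred[..., n:]
    # -log p(y | mu, b) for a Laplace distribution: |y - mu| / b + log(2 b)
    return np.mean(np.abs(y_true - mu) / b + np.log(2.0 * b))

y = np.array([[1.0], [3.0]])               # two pixels, one channel
pred = np.array([[1.0, 1.0], [2.0, 0.5]])  # (mu, b) for each pixel
loss = laplace_nll(y, pred)
print(loss)
```

The `log(2 b)` term penalizes large scales, so the network can only lower the loss on a hard pixel by raising `b` there, which is what makes the predicted scale usable as a per-pixel uncertainty estimate.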

### Plot Training History

In [None]:
print(sorted(list(history.history.keys())))
plt.figure(figsize=(16,5))
plot_history(history,['loss','val_loss'],['mse','val_mse','mae','val_mae']);
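`history.history` is a plain dictionary mapping each metric name to a list with one value per epoch. A validation loss that rises while the training loss keeps falling is a typical sign of overfitting. The hypothetical helper below finds the epoch with the lowest validation loss; the toy numbers are made up for illustration, so pass `history.history` in the notebook.

```python
def best_epoch(history_dict, key='val_loss'):
    """Return (1-based epoch, value) of the lowest `key` in a Keras-style history dict."""
    values = history_dict[key]
    i = min(range(len(values)), key=values.__getitem__)
    return i + 1, values[i]

# Toy values shaped like history.history (illustrative only).
toy_history = {'loss': [1.2, 0.9, 0.8, 0.75], 'val_loss': [1.1, 0.85, 0.87, 0.9]}
epoch, value = best_epoch(toy_history)
print(epoch, value)  # 2 0.85
```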

### Evaluate Model on Validation Set

In [None]:
plt.figure(figsize=(12,7))
_P = model.keras_model.predict(X_val[:5])
if config.probabilistic:
    _P = _P[...,:(_P.shape[-1]//2)]
plot_some(X_val[:5],Y_val[:5],_P,pmax=99.5)
plt.suptitle('5 example validation patches\n'
             'top row: input (source),  '
             'middle row: target (ground truth),  '
             'bottom row: predicted from source');
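The figure gives a qualitative check; a number per patch makes it easier to compare models. The sketch below splits a probabilistic prediction into mean and scale (the same slicing as in the cell above) and computes the peak signal-to-noise ratio (PSNR) of the mean against the target. It runs on synthetic arrays; the helper names are hypothetical and `data_range=1.0` is an assumption, so for real patches choose a range that reflects the target intensities (for example `Y_val.max() - Y_val.min()`).

```python
import numpy as np

def split_probabilistic(P):
    """Split a probabilistic prediction into (mean, scale) along the last axis."""
    n = P.shape[-1] // 2
    return P[..., :n], P[..., n:]

def psnr(target, pred, data_range):
    """Peak signal-to-noise ratio in dB for arrays on a known intensity range."""
    mse = np.mean((np.asarray(target, float) - np.asarray(pred, float)) ** 2)
    return float('inf') if mse == 0 else 10.0 * np.log10(data_range ** 2 / mse)

# Synthetic stand-ins: 5 patches (SYXC); the prediction has mean + scale channels.
rng = np.random.default_rng(0)
Y_demo = rng.random((5, 32, 32, 1))
P_demo = np.concatenate([Y_demo + 0.05, np.full_like(Y_demo, 0.1)], axis=-1)

mean, scale = split_probabilistic(P_demo)
for i in range(len(Y_demo)):
    print(i, round(psnr(Y_demo[i], mean[i], data_range=1.0), 2))  # 26.02 for every patch
```

In the notebook, `split_probabilistic(model.keras_model.predict(X_val[:5]))` would give the mean and scale for the same five patches shown in the figure.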