# CARE Data Generation Notebook

This notebook generates training patches for a CARE model from raw data, augments them, and saves the combined patch set for training.

## Setup

Make sure you have the required packages installed:
```bash
pip install tensorflow csbdeep tifffile numpy matplotlib
```

For GPU support with TensorFlow, ensure you have the appropriate CUDA drivers installed.

In [None]:
import tensorflow as tf
# Ask TensorFlow which GPU devices it can see; an empty list means none was detected.
visible_gpus = tf.config.list_physical_devices('GPU')
print("GPUs:", visible_gpus)

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from tifffile import imread
from csbdeep.utils import download_and_extract_zip_file, plot_some
from csbdeep.data import RawData, create_patches

### Create Training Patches

**Note:** Update `basepath` and `save_file` below to match your local directory structure.

In [None]:
# TODO: Update this path to match your local directory structure
basepath = 'path/to/your/data/fixed/CARE/raw_data/for_training_5ms'

# Pair every stack in <basepath>/noisy with the same-named stack in <basepath>/clean.
raw_data = RawData.from_folder(
    basepath    = basepath,
    source_dirs = ['noisy'],
    target_dir  = 'clean',
    axes        = 'ZYX',
)

# Extract 64x64x64 patches from each image pair and write them to disk.
# `save_file` is the name of the .npz file to write, not a directory, so it is
# spelled out explicitly to avoid any ambiguity about where the patches end up.
X, Y, XY_axes = create_patches(
    raw_data            = raw_data,
    patch_size          = (64,64,64),
    n_patches_per_image = 2000,
    save_file           = basepath + '/patches/patches.npz',
)

### Data Augmentation Functions

In [None]:
import numpy as np
import random
import tifffile
import os

# --- intensity augmentations ---
def random_gamma_contrast(img, gamma_range=(0.5, 2), contrast_range=(0.8, 1.2), p=0.8):
    """Min-max normalize ``img`` to [0, 1] and apply a random gamma curve.

    Args:
        img: Array-like image or patch of any shape.
        gamma_range: ``(low, high)`` bounds of the gamma exponent, drawn
            uniformly with ``random.uniform``.
        contrast_range: Kept for backward compatibility. The contrast
            adjustment is currently disabled, so this value is not used.
        p: Kept for backward compatibility. The probability gate is currently
            disabled, so the augmentation is always applied.

    Returns:
        ``np.float32`` array with the same shape as ``img`` and values in
        [0, 1]. A constant input yields all zeros.
    """
    img = np.asarray(img)
    img = img - img.min()
    peak = img.max()
    if peak == 0:
        # Flat patch: dividing by the zero range would fill the result with NaNs.
        return np.zeros(img.shape, dtype=np.float32)
    img = img / peak

    gamma = random.uniform(*gamma_range)
    img = np.power(img, gamma)
    return np.clip(img, 0, 1).astype(np.float32)

# --- geometric augmentations ---
def random_flip_rotate(img, axes=(-3, -2, -1)):
    """Randomly flip a patch along its spatial axes and rotate it in the YX plane.

    Axes are counted from the end, so the same code works for a plain ZYX
    volume and for a patch with leading non-spatial axes (e.g. CZYX). For a
    3-D input the behavior is identical to using axes ``(0, 1, 2)`` and
    rotating in axes ``(1, 2)``.

    Args:
        img: Array whose last three axes are expected to be Z, Y, X.
        axes: Axes that are each flipped independently with probability 0.5.

    Returns:
        The transformed array. Rotation by an odd multiple of 90 degrees swaps
        the sizes of the last two axes, so the shape is only preserved when
        Y and X have equal length.
    """
    for ax in axes:
        if random.random() < 0.5:
            img = np.flip(img, axis=ax)
    k = random.randint(0, 3)
    # The last two axes are Y and X regardless of any leading channel axis.
    img = np.rot90(img, k, axes=(-2, -1))
    return img


# --- combined augmentation ---
def augment_patch_pair(x, y):
    """Apply augmentations to a patch pair (input, target).

    The geometric transform is the same for both patches so that input and
    target stay spatially aligned. The intensity augmentation is applied to
    the input only.

    Args:
        x: Input (noisy) patch.
        y: Target (clean) patch.

    Returns:
        Tuple ``(x_aug, y_aug)`` of augmented patches.
    """
    # geometric: rewind the RNG before transforming y so it receives exactly
    # the same flips and rotation that were drawn for x
    rng_state = random.getstate()
    x = random_flip_rotate(x)
    random.setstate(rng_state)
    y = random_flip_rotate(y)
    # intensity (input only)
    x = random_gamma_contrast(x)
    return x, y

### Apply Augmentations

**Note:** Update the `patch_file` and `save_path` paths below.

In [None]:
# TODO: Update this path to your patches file
# (it must point at the .npz file written by `create_patches`)
patch_file = 'path/to/your/data/fixed/CARE/raw_data/for_training/patches/patches.npz'

# Read the arrays and close the archive again. The axes string stored with the
# patches describes every dimension of X/Y, including the leading sample axis.
with np.load(patch_file) as data:
    X = data['X']
    Y = data['Y']
    axes = str(data['axes']) if 'axes' in data.files else 'SCZYX'
print(f"Loaded patches: {X.shape} input, {Y.shape} target")
assert X.shape == Y.shape, "input and target patches must have the same shape"
assert len(axes) == X.ndim, f"axes '{axes}' does not match patch array with {X.ndim} dimensions"

# Seed the RNG so that re-running this cell reproduces the same augmentations
AUG_SEED = 42
random.seed(AUG_SEED)

# Apply augmentations
augmented_X, augmented_Y = [], []
n_aug = 2  # number of augmentations per original patch
for i in range(len(X)):
    for _ in range(n_aug):
        xa, ya = augment_patch_pair(X[i], Y[i])
        augmented_X.append(xa)
        augmented_Y.append(ya)
    # Report progress occasionally instead of once per patch
    if (i + 1) % 500 == 0 or i + 1 == len(X):
        print(f"Augmented {i + 1}/{len(X)} patches")

augmented_X = np.stack(augmented_X)
augmented_Y = np.stack(augmented_Y)

# Combine with originals
X_all = np.concatenate([X, augmented_X], axis=0)
Y_all = np.concatenate([Y, augmented_Y], axis=0)
print("X_all shape:", X_all.shape, "dtype:", X_all.dtype)
print("Y_all shape:", Y_all.shape, "dtype:", Y_all.dtype)
print("Number of NaNs in X_all:", np.isnan(X_all).sum())
print("Number of NaNs in Y_all:", np.isnan(Y_all).sum())

print(f"Final patch set shape: {X_all.shape}")

# Save augmented patch set
# TODO: Update this path
save_path = 'path/to/your/data/fixed/CARE/raw_data/for_training/patches/augmented_patches_2.npz'

# Make sure the output folder exists before writing
save_dir = os.path.dirname(save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

# Save manually, reusing the axes string of the source patches so that it has
# one letter per array dimension
np.savez(str(save_path), X=X_all.astype(np.float32), Y=Y_all.astype(np.float32), axes=axes)
print(f"✅ Saved {X_all.shape[0]} patches to {save_path}")

### Visualize Patches

In [None]:
# Draw 20 figures, each showing 8 consecutive patch pairs from X and Y.
for i in range(20):
    start, stop = 8*i, 8*(i+1)
    # Select patches start..stop-1 and index 0 along axis 1
    # (presumably the channel axis of SCZYX patches; confirm against the data).
    sl = (slice(start, stop), 0)
    plt.figure(figsize=(16,4))
    plot_some(X[sl], Y[sl], title_list=[np.arange(start, stop)])
    plt.show()
None;