# Example: Train the entire pipeline with a new dataset
This Jupyter notebook shows how to train your own GAN to generate synthetic images that match your dataset. In the second step, we use this GAN to train a segmentation network on synthetic images adapted to your dataset, with the goal of maximizing segmentation performance on your data.

## Prerequisites:

For this example you will need:
- A dataset of 2D 3x3 mm² macular OCTA images. We recommend at least 200 good-quality images, a resolution of at least 304x304 pixels, and training only (or mainly) on healthy samples.
- An NVIDIA GPU compatible with CUDA version >= 8 and with at least 30 GB of VRAM
- A clean [conda](https://docs.conda.io/en/main/miniconda.html) environment with Python 3 and [PyTorch](https://pytorch.org/get-started/locally/) (tested with Python 3.10, pytorch==2.0.1, and torchvision==0.15.2)
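Before installing anything, you may want to check that your machine matches these prerequisites. The following is a minimal sketch, not part of the repository: the helper name `check_environment` is hypothetical, it treats Python 3.11 and newer as unsupported because of the `open3d` note, it interprets "30 GB" as GiB, and it only inspects the first CUDA device, and only if PyTorch is already installed.

```python
import sys


def check_environment(min_vram_gb: float = 30.0) -> dict:
    """Hypothetical helper: compare this machine against the listed prerequisites."""
    major, minor = sys.version_info[:2]
    report = {
        "python": f"{major}.{minor}",
        # Python 3 is required; open3d was not yet available for 3.11 (Nov 2023)
        "python_ok": major == 3 and minor < 11,
        "torch": None,
        "cuda_available": False,
        "vram_gb": None,
        "vram_ok": False,
    }
    try:
        import torch  # only checked if PyTorch is already installed
    except ImportError:
        return report
    report["torch"] = str(torch.__version__)
    if torch.cuda.is_available():
        # total_memory is reported in bytes; convert to GiB
        vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
        report.update(
            cuda_available=True,
            vram_gb=round(vram_gb, 1),
            vram_ok=vram_gb >= min_vram_gb,
        )
    return report


print(check_environment())
```

The check is read-only and never fails hard, so it can safely be run in a fresh environment before `requirements.txt` is installed.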

Install the required dependencies:
 > ⚠️ **_NOTE:_** As of Nov 22, 2023, the package `open3d` is not yet available for Python 3.11.

In [None]:
!pip install -r requirements.txt

In [None]:
# Import libraries for this notebook
from matplotlib import pyplot as plt
import yaml
import os
from glob import glob
from natsort import natsorted
from PIL import Image

## 1. GAN training
We first train a new GAN model to generate realistic synthetic images that fit your dataset. For this, you first need to configure a `config.yml` file.

### 1.1 Configure GAN config file

In [None]:
with open("./configs/config_gan_ves_seg.yml", "r") as stream:
    config: dict[str,dict] = yaml.safe_load(stream)
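To see which keys you can set in the loaded `config` dict, it can help to list all nested key paths. A minimal sketch; `flatten_keys` is a hypothetical helper, and it treats lists (such as augmentation lists) as leaves rather than descending into them.

```python
from typing import Any, Dict, List


def flatten_keys(d: Dict[str, Any], prefix: str = "") -> List[str]:
    """Return dotted key paths for every leaf value in a nested dict."""
    paths: List[str] = []
    for key, value in d.items():
        path = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict) and value:
            paths.extend(flatten_keys(value, path))
        else:
            paths.append(path)  # leaf: scalar, list, None, or empty dict
    return paths


# Hypothetical dict that only mimics the key layout used in this notebook
example = {"Output": {"save_dir": "./results"}, "Train": {"data": {"real_B": {"files": None}}}}
print(flatten_keys(example))
```

In the notebook you would call `flatten_keys(config)` and scan the output for the `files`, `split`, and `save_dir` entries you need to change.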

In [None]:
# TODO Enter the path to your dataset as a glob pattern (e.g. "/path/to/your/images/*.png"):
YOUR_DATASET_PATH = ...

# You may want to choose your own folder for this
config["Output"]["save_dir"] = os.path.abspath("./results/custom-gan-ves-seg")

# Your real OCTA images are used to train the GAN
config["Train"]["data"]["real_B"]["files"] = YOUR_DATASET_PATH
# We use our existing dataset of synthetic vessel maps 
config["Train"]["data"]["real_A"]["files"] = os.path.abspath("./datasets/vessel_graphs/*.csv")
# We use our existing dataset of synthetic vessel maps (Make sure that these are the same vessel maps!)
config["Train"]["data"]["real_A_seg"]["files"] = os.path.abspath("./datasets/vessel_graphs/*.csv")
# We use our existing dataset of synthetic background vessel maps.
config["Train"]["data"]["background"]["files"] = os.path.abspath("./datasets/background_images/*.png")


# We want to use the GAN part of this model during inference
config["General"]["inference"] = "generator"

# In case you want to segment your dataset with the implicitly trained segmentor, uncomment the following lines:
# config["General"]["inference"] = "segmentor"
# config["Test"]["data"]["real_B"]["files"] = YOUR_DATASET_PATH
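Because assigning to a mistyped key path silently adds a new entry instead of changing the intended setting, you might prefer a helper that only overwrites keys that already exist in the loaded config. A minimal sketch; `set_existing` is a hypothetical helper, not part of the repository, and the dataset path in the example is a placeholder.

```python
from typing import Any, Dict, Sequence


def set_existing(config: Dict[str, Any], keys: Sequence[str], value: Any) -> None:
    """Set config[k0][k1]...[kn] = value, raising KeyError if any key is missing."""
    node = config
    for i, key in enumerate(keys[:-1]):
        if not isinstance(node, dict) or key not in node:
            raise KeyError(f"Missing config key: {'.'.join(keys[: i + 1])}")
        node = node[key]
    if not isinstance(node, dict) or keys[-1] not in node:
        raise KeyError(f"Missing config key: {'.'.join(keys)}")
    node[keys[-1]] = value


example = {"Test": {"data": {"real_B": {"files": None}}}}
set_existing(example, ["Test", "data", "real_B", "files"], "/path/to/your/images/*.png")
try:
    # Mistyped path (missing "data"): raises instead of creating a new key
    set_existing(example, ["Test", "real_B", "files"], "/path/to/your/images/*.png")
except KeyError as err:
    print(err)
```

With the loaded config, a call such as `set_existing(config, ["Train", "data", "real_B", "files"], YOUR_DATASET_PATH)` fails loudly if the key layout differs from what the notebook expects.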

In [None]:
assert isinstance(YOUR_DATASET_PATH, str), "Please provide a valid path to your dataset"
dataset_paths = natsorted(glob(YOUR_DATASET_PATH))
assert len(dataset_paths) > 0, "No images found! Please check your path again."

# Plot an example of your dataset
Image.open(dataset_paths[0])
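The prerequisites recommend at least 200 images with a resolution of at least 304x304 pixels. A minimal sketch for checking this; `image_sizes` and `check_recommendations` are hypothetical helpers, and the default thresholds are taken from the prerequisites list.

```python
from typing import Dict, Iterable, Iterator, Tuple


def image_sizes(paths: Iterable[str]) -> Iterator[Tuple[int, int]]:
    """Yield (width, height) for each image; PIL opens files lazily, so pixels are not decoded."""
    from PIL import Image  # Pillow is already used in this notebook

    for path in paths:
        with Image.open(path) as img:
            yield img.size


def check_recommendations(
    sizes: Iterable[Tuple[int, int]], min_count: int = 200, min_side: int = 304
) -> Dict[str, object]:
    """Count images and flag those smaller than min_side x min_side pixels."""
    sizes = list(sizes)
    too_small = [s for s in sizes if min(s) < min_side]
    return {
        "num_images": len(sizes),
        "num_too_small": len(too_small),
        "ok": len(sizes) >= min_count and not too_small,
    }


# Toy sizes for illustration; one image is below 304x304
print(check_recommendations([(304, 304), (512, 512), (256, 256)], min_count=2))
```

In the notebook, `check_recommendations(image_sizes(dataset_paths))` would summarize your actual dataset.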

In [None]:
# Save your custom yaml file
with open('./configs/my_custom_gan_config.yml', 'w') as f:
    yaml.dump(config, f)

### 1.2 Train the generator

In [None]:
# Path to your custom config file
CONFIG_FILE_PATH = os.path.abspath("./configs/my_custom_gan_config.yml")
# Number of CPU cores for data loading. If not set, half of the available cores are used.
NUM_WORKERS = None 

# Train a new Generator network
!python train.py --config_file $CONFIG_FILE_PATH --num_workers $NUM_WORKERS
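With `NUM_WORKERS = None`, the shell substitution passes the literal string `None` to `--num_workers`. If `train.py` does not accept that, one option is to build the command in Python and omit the flag when it is unset, so the script can fall back to its default. A minimal sketch; `build_command` is a hypothetical helper that only uses the flags shown in this notebook, and it relies on omitting a flag being equivalent to "not set".

```python
import shlex
import sys
from typing import List, Optional


def build_command(script: str, config_file: str, num_workers: Optional[int] = None, **extra_args) -> List[str]:
    """Assemble the argument list for train.py / test.py, skipping unset options."""
    # sys.executable runs the script with the same interpreter as this kernel
    cmd = [sys.executable, script, "--config_file", config_file]
    if num_workers is not None:
        cmd += ["--num_workers", str(num_workers)]
    for name, value in extra_args.items():
        if value is not None:
            cmd += [f"--{name}", str(value)]
    return cmd


cmd = build_command("train.py", "./configs/my_custom_gan_config.yml", num_workers=None)
print(shlex.join(cmd))
# To actually start training: import subprocess; subprocess.run(cmd, check=True)
```

The same helper covers the test calls, e.g. `build_command("test.py", CONFIG_FILE_PATH, num_workers=NUM_WORKERS, epoch=EPOCH, num_samples=NUM_SAMPLES)`.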

### 1.3 Validate your generator (optional)

In [None]:
# TODO Enter the path of the config.yml file that was created during training.
CONFIG_FILE_PATH: str = ...
# TODO Enter the epoch of the checkpoint you want to load. In our paper, we use epoch 50, but the best choice depends on your dataset.
EPOCH: int = ...

# For a simple test we will just create 3 images
NUM_SAMPLES = 3
# Number of CPU cores for data loading. If not set, half of the available cores are used.
NUM_WORKERS = None

# Test your trained generator:
!python test.py --config_file $CONFIG_FILE_PATH --epoch $EPOCH --num_workers $NUM_WORKERS --num_samples $NUM_SAMPLES

In [None]:
test_image_paths = natsorted(glob(CONFIG_FILE_PATH.replace("config.yml", "Test/*.png")))
test_images = [Image.open(p) for p in test_image_paths]
_, axes = plt.subplots(nrows=1, ncols=3, figsize=(9, 3))
# Show one generated image per subplot (zip stops after 3 images)
for ax, img in zip(axes, test_images):
    ax.imshow(img, cmap="gray")
plt.show()
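The cell above derives the output folder with `CONFIG_FILE_PATH.replace("config.yml", "Test/*.png")`, i.e. it assumes the generated images are stored in a `Test` folder next to the training config. A slightly more explicit sketch that builds the pattern from the config's directory; `find_test_images` is hypothetical, it uses plain `sorted` instead of `natsorted` to stay dependency-free, and the path in the example is a placeholder.

```python
import os
from glob import glob
from typing import List


def find_test_images(config_file: str, pattern: str = "*.png") -> List[str]:
    """Return the images in the `Test` folder next to config_file, sorted by name."""
    test_dir = os.path.join(os.path.dirname(os.path.abspath(config_file)), "Test")
    return sorted(glob(os.path.join(test_dir, pattern)))


# Placeholder path; prints [] if the folder does not exist
print(find_test_images("./results/custom-gan-ves-seg/config.yml"))
```

Because it only touches the directory part of the path, this variant does not depend on where else the string `config.yml` might appear.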

## 2. Vessel segmentation training

Now that we have a trained GAN, we can use it to augment our synthetic training images and train the segmentation network.

### 2.1 Configure vessel segmentation config file

> ⚠️ **_NOTE:_** In the following, we assume that you use `config_ves_seg-S_GAN.yml` without changes. If you add further training data augmentations, make sure that the index (6 by default) still points to the `ImageToImageTranslationd` augmentation.

In [None]:
with open("./configs/config_ves_seg-S_GAN.yml", "r") as stream:
    config: dict[str,dict] = yaml.safe_load(stream)
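Instead of relying on the hard-coded index 6, you could look up the position of the `ImageToImageTranslationd` entry in `config["Train"]["data_augmentation"]`. A minimal sketch: since the exact structure of each augmentation entry is not shown here, it simply searches each entry's string representation for the transform name. `find_augmentation_index` is hypothetical, and the example list only mimics a possible layout.

```python
from typing import Any, Sequence


def find_augmentation_index(augmentations: Sequence[Any], name: str = "ImageToImageTranslationd") -> int:
    """Return the index of the single augmentation entry that mentions `name`."""
    matches = [i for i, entry in enumerate(augmentations) if name in str(entry)]
    if len(matches) != 1:
        raise ValueError(f"Expected exactly one '{name}' entry, found {len(matches)}")
    return matches[0]


# Hypothetical augmentation list; the real entries may be structured differently
example_augs = [{"LoadImaged": {}}, {"ImageToImageTranslationd": {}}]
print(find_augmentation_index(example_augs))
```

With the loaded config, `find_augmentation_index(config["Train"]["data_augmentation"])` should return 6 for the unmodified file; a `ValueError` signals that the entry was removed or duplicated.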

In [None]:
# TODO Enter the path to your generator checkpoint that you want to use
GAN_CHECKPOINT_PATH = ...
config["Train"]["data_augmentation"][6] = GAN_CHECKPOINT_PATH

# You may want to choose your own folder for this
config["Output"]["save_dir"] = os.path.abspath("./results/custom-ves-seg-S_GAN")

# We use our existing dataset of synthetic vessel maps 
config["Train"]["data"]["image"]["files"] = os.path.abspath("./datasets/vessel_graphs/*.csv")
# We use our existing dataset of synthetic vessel maps (Make sure that these are the same vessel maps!)
config["Train"]["data"]["label"]["files"] = os.path.abspath("./datasets/vessel_graphs/*.csv")
# We use our existing dataset of synthetic background vessel maps.
config["Train"]["data"]["background"]["files"] = os.path.abspath("./datasets/background_images/*.png")


# You can use your dataset for validation (although the metrics mean little without labels)
config["Validation"]["data"]["image"]["files"] = YOUR_DATASET_PATH
config["Validation"]["data"]["image"]["split"] = None

# If you have labels for your dataset, use them here. Otherwise, you can reuse your dataset path; a threshold of 0.5 is then applied to create a dummy label map.
config["Validation"]["data"]["label"]["files"] = YOUR_DATASET_PATH
config["Validation"]["data"]["label"]["split"] = None

# You can use your dataset for inference
config["Test"]["data"]["image"]["files"] = YOUR_DATASET_PATH
config["Test"]["data"]["image"]["split"] = None
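Without labels, the validation data uses a dummy label map obtained by thresholding at 0.5. Conceptually, this amounts to the following sketch, which assumes intensities are first scaled to [0, 1] (integer images divided by their dtype maximum, e.g. 255 for 8-bit) and that pixels strictly above the threshold count as vessel; the exact preprocessing inside the pipeline may differ, and `dummy_label` is a hypothetical helper.

```python
import numpy as np


def dummy_label(image: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Binarize an image into a 0/1 label map after scaling it to [0, 1]."""
    img = image.astype(np.float32)
    if np.issubdtype(image.dtype, np.integer):
        img /= np.iinfo(image.dtype).max  # e.g. 255 for uint8
    return (img > threshold).astype(np.uint8)


octa = np.array([[0, 100], [128, 255]], dtype=np.uint8)
print(dummy_label(octa))
```

Such a label only reflects bright pixels, which is why validation scores computed against it are of limited value.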

In [None]:
assert isinstance(config["Train"]["data"]["image"]["files"], str), "Please provide a valid path to your training dataset"
dataset_paths = natsorted(glob(config["Train"]["data"]["image"]["files"]))
assert len(dataset_paths) > 0, "No images found! Please check your train path again."

assert isinstance(config["Validation"]["data"]["image"]["files"], str), "Please provide a valid path to your validation dataset"
dataset_paths = natsorted(glob(config["Validation"]["data"]["image"]["files"]))
assert len(dataset_paths) > 0, "No images found! Please check your validation path again."

assert isinstance(config["Test"]["data"]["image"]["files"], str), "Please provide a valid path to your test dataset"
dataset_paths = natsorted(glob(config["Test"]["data"]["image"]["files"]))
assert len(dataset_paths) > 0, "No images found! Please check your test path again."

assert os.path.isfile(GAN_CHECKPOINT_PATH), "The given path to the generator checkpoint is not valid!"

In [None]:
# Save your custom yaml file
with open('./configs/my_custom_ves_seg_config.yml', 'w') as f:
    yaml.dump(config, f)

### 2.2 Train the segmentation network

In [None]:
# Path to your custom vessel segmentation config file
CONFIG_FILE_PATH = os.path.abspath("./configs/my_custom_ves_seg_config.yml")
# Number of CPU cores for data loading. If not set, half of the available cores are used.
NUM_WORKERS = None 

# Train a new segmentation network
!python train.py --config_file $CONFIG_FILE_PATH --num_workers $NUM_WORKERS

### 2.3 Test the segmentation model

In [None]:
# TODO Enter the path of the config.yml file that was created during training.
CONFIG_FILE_PATH: str = ...

# Enter the epoch of the checkpoint you want to load. You can simply use 'latest' for now.
EPOCH: str = "latest"
# For a simple test, we just segment 10 images. If you do not set this, all images will be segmented.
NUM_SAMPLES = 10
# Number of CPU cores for data loading. If not set, half of the available cores are used.
NUM_WORKERS = None

# Test your trained segmentation network:
!python test.py --config_file $CONFIG_FILE_PATH --epoch $EPOCH --num_workers $NUM_WORKERS --num_samples $NUM_SAMPLES

In [None]:
test_image_paths = natsorted(glob(CONFIG_FILE_PATH.replace("config.yml", "Test/*.png")))
test_images = [Image.open(p) for p in test_image_paths]
_, axes = plt.subplots(nrows=1, ncols=3, figsize=(9, 3))
# Show the first 3 segmentations, one per subplot
for ax, img in zip(axes, test_images):
    ax.imshow(img, cmap="gray")
plt.show()

# What's next?

Congratulations, you made it through the example! 🎉

You can now start to optimize your pipeline. Possible things you might want to try next:
- Select optimal GAN and segmentor checkpoints
- Add further data augmentations
- Change the GAN model
- Experiment with other hyperparameters of our pipeline
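For selecting checkpoints, one simple approach is to evaluate several epochs on a labeled validation set and keep the one with the best score. A minimal sketch; `select_best_epoch` is hypothetical, and the scores in the example are made-up placeholders, not results.

```python
from typing import Dict, Union

Epoch = Union[int, str]


def select_best_epoch(scores: Dict[Epoch, float], higher_is_better: bool = True) -> Epoch:
    """Return the epoch with the best validation score (e.g. a Dice score)."""
    if not scores:
        raise ValueError("No scores given")
    pick = max if higher_is_better else min
    return pick(scores, key=scores.get)


placeholder_scores = {30: 0.71, 50: 0.74, 70: 0.73}  # illustrative values only
print(select_best_epoch(placeholder_scores))
```

Set `higher_is_better=False` for loss-like metrics; the returned epoch can then be passed as `EPOCH` to `test.py`.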