## Welcome to this tutorial on how to train your own nanotextural demixing model using the `NanTex` library! 
In this tutorial, we will focus on the second part of the process: constructing and training the network using the `Hypnos` module of the `NanTex` library.

### Requirements
- Synthetic overlay data generated with the `Tekhne` module of the `NanTex` package (see `tekhne_tutorial.ipynb`)
- Tested installation of PyTorch, ideally with GPU support (see `installation_guide.md`)
- Tested installation of TensorBoard (see `installation_guide.md`)
- `NanTex` package installed (see `installation_guide.md`)
- Basic knowledge of Python and Jupyter Notebooks

If you haven't completed the first part of the process, please refer to the `tekhne_tutorial.ipynb` notebook to generate synthetic overlay data. This data will be used to train the demixing model in this tutorial.
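Before you continue, you can check that the required packages can be imported from the notebook's kernel. The sketch below uses only the standard library (`importlib.util.find_spec`). The helper name `check_requirements` is hypothetical. The package names are also assumptions: `nantex` matches the imports used in this notebook, and `tensorboard` is the usual module name for TensorBoard.

```python
import importlib.util


def check_requirements(packages=("torch", "tensorboard", "nantex")):
    """Map each top-level package name to True if it can be found, else False.

    find_spec locates a top-level package without importing it, so this check
    is fast and has no side effects. (Hypothetical helper, not part of NanTex.)
    """
    return {name: importlib.util.find_spec(name) is not None for name in packages}


# print one line per package so missing installs stand out
for name, found in check_requirements().items():
    print(f"{name:12s} {'found' if found else 'MISSING'}")
```

A `MISSING` entry usually means the package was installed into a different environment from the one the kernel is running in.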

## Part 0: Dependencies

In [None]:
## Dependencies
import os
import json
import torch
from pprint import pprint

# Model (Hypnos)
from nantex.deep_learning.dl_building_blocks import train
from nantex.deep_learning.dl_model_assembly import (
    assembled_model,
    hyperparameters,
)

# Dataloaders (Harmonia)
from nantex.batching import Harmonia

## Part I: Configure and Instantiate Harmonia

In [None]:
## Generate Config

# define and create config directory
config_dir = "../configs/"
os.makedirs(config_dir, exist_ok=True)

# generate boilerplate config file
if not os.path.exists(os.path.join(config_dir, "harmonia_config.json")):
    Harmonia.generate_boilerplate_config_file(config_dir)

# show the generated config file
with open(os.path.join(config_dir, "harmonia_config.json"), "r") as f:
    print(f.read())

**Heads-up:** Some of our modules contain convenience functions that interface with the Windows file system. If you are using a UNIX-based system (Linux, macOS), you will need to provide the directory paths manually in the configuration files. Please refer to the docstrings of the respective functions for more information.

## UNIX

In [None]:
## define train and validation directories
raw_source: str = "../path/to/directory/containing/training/overlays/"
val_source: str = "../path/to/directory/containing/validation/overlays/"

# load the existing config file
with open(os.path.join(config_dir, "harmonia_config.json"), "r") as f:
    config = json.load(f)

# update config with new paths
config["raw_source"] = raw_source
config["val_source"] = val_source

# write updated config back to file
with open(os.path.join(config_dir, "harmonia_config.json"), "w") as f:
    json.dump(config, f, indent=4)
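If you update the config more than once, you can wrap the read-modify-write steps above in a small helper. `update_json_config` is a hypothetical helper and is not part of `NanTex`. It writes to a temporary file in the same directory first and then moves that file into place with `os.replace`. As a result, an interrupted write cannot leave a truncated `harmonia_config.json` behind.

```python
import json
import os
import tempfile


def update_json_config(config_path, **updates):
    """Merge `updates` into the JSON object stored at `config_path` and return it.

    Hypothetical helper (not part of NanTex). The new content is written to a
    temporary file next to the config, then atomically swapped in.
    """
    with open(config_path, "r") as f:
        config = json.load(f)
    config.update(updates)

    # temp file in the same directory so os.replace stays on one filesystem
    directory = os.path.dirname(os.path.abspath(config_path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".json")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(config, f, indent=4)
        os.replace(tmp_path, config_path)
    except BaseException:
        # clean up the temp file if anything went wrong before the swap
        os.remove(tmp_path)
        raise
    return config
```

With this helper, the cell above reduces to `update_json_config(os.path.join(config_dir, "harmonia_config.json"), raw_source=raw_source, val_source=val_source)`.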

In [None]:
## Instantiate
BatchProvider: Harmonia
BatchProvider = Harmonia.from_config(
    config_file_path="../configs/harmonia_config.json", datatype="npy", DEBUG=True
)

# the configuration can also be passed as a dictionary
with open("../configs/harmonia_config.json", "r") as f:
    config = json.load(f)

BatchProvider = Harmonia(config=config, datatype="npy", DEBUG=True)

## Windows

In [None]:
## Instantiate
BatchProvider: Harmonia
BatchProvider = Harmonia.from_config(
    config_file_path="../configs/harmonia_config.json", datatype="npy", DEBUG=True
)

### Checkpoint I: Instantiate the `Harmonia` class

In [None]:
## Let's check the configuration
BatchProvider.pprint_config()
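`pprint_config` shows what `Harmonia` was given, but it does not tell you whether the source directories exist. A quick check like the one below catches typos in `raw_source` and `val_source` before `build()` is called. `check_source_dirs` is a hypothetical helper that reads the JSON file directly. Relative paths are checked against the current working directory, which is assumed to be the notebook's folder, as elsewhere in this tutorial.

```python
import json
import os


def check_source_dirs(config_path, keys=("raw_source", "val_source")):
    """Return {key: (path, exists)} for directory entries in a JSON config.

    Hypothetical helper. Missing keys are reported as (None, False); relative
    paths are resolved against the current working directory.
    """
    with open(config_path, "r") as f:
        config = json.load(f)
    report = {}
    for key in keys:
        path = config.get(key)
        report[key] = (path, isinstance(path, str) and os.path.isdir(path))
    return report
```

Running `check_source_dirs("../configs/harmonia_config.json")` should report `True` for both keys before you proceed.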

## Part II: Configure CUDA and Hyperparameters

In [None]:
## Check if CUDA is available

# set backend flags
torch.backends.cudnn.enabled = True  # use cuDNN (NVIDIA CUDA Deep Neural Network library)
torch.backends.cudnn.benchmark = True  # benchmark several convolution algorithms and keep the fastest (works best with fixed input sizes)

print("Checking CUDA availability...")
print("  version:", torch.__version__)
print("  CUDA available:", torch.cuda.is_available())
print("  cuDNN available:", torch.backends.cudnn.is_available())
print()
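print("Checking on GPU...")
print("  GPU count:", torch.cuda.device_count())
if torch.cuda.is_available():
    # current_device() raises an error on machines without a usable CUDA device
    print("  Current device:", torch.cuda.current_device())
    print("  Current device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
else:
    print("  No CUDA device found.")
print()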
print("cuDNN properties:")
print("  Version:", torch.backends.cudnn.version())
print("  Enabled:", torch.backends.cudnn.enabled)
print("  Benchmark:", torch.backends.cudnn.benchmark)

In [None]:
## Define and Check Model Hyperparameters

# modify hyperparameters as needed
hyperparameters.update(
    {
        "epochs": 32,
        "steps_per_epoch": 256,
        "val_per_epoch": 128,
    }
)

# add path to store training logs and model checkpoints
save_dir: str = "../model_checkpoints/"
os.makedirs(save_dir, exist_ok=True)
hyperparameters["save_dir"] = save_dir

# check hyperparameters
pprint(hyperparameters)
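To estimate how long a run will take, it helps to work out the total training budget. The sketch below assumes that `steps_per_epoch` counts training batches per epoch and that `val_per_epoch` counts validation batches per epoch. Check the docstring of `train` to confirm how `NanTex` interprets these keys. `training_budget` is a hypothetical helper.

```python
def training_budget(hp):
    """Total training and validation steps implied by the hyperparameters.

    Assumes steps_per_epoch = training batches per epoch and
    val_per_epoch = validation batches per epoch (hypothetical interpretation).
    """
    epochs = hp["epochs"]
    return {
        "train_steps": epochs * hp["steps_per_epoch"],
        "val_steps": epochs * hp["val_per_epoch"],
    }


print(training_budget({"epochs": 32, "steps_per_epoch": 256, "val_per_epoch": 128}))
# -> {'train_steps': 8192, 'val_steps': 4096}
```

With the values set above, this comes to 8,192 training steps and 4,096 validation steps. To use your own settings, call `training_budget(hyperparameters)`.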

### Checkpoint II
By now you should have:
* Data providers
* A configured CUDA environment
* Custom hyperparameters

## Part III: Construct Data Iterators and Conduct Model Training

In [None]:
## Load pre-trained model if available
import os

state_dict_path: str = ".../path/to/pretrained/model_state_dict.pth"
if os.path.exists(state_dict_path):
    print(f"Loading pre-trained model from {state_dict_path}...")
    assembled_model["model"].load_state_dict(
        torch.load(state_dict_path, weights_only=True)
    )
    print("Pre-trained model loaded.")

In [None]:
## Build Data Iterators
train_batcher, validation_batcher = BatchProvider.build()

In [None]:
## Model Setup Test
import traceback

try:
    train(
        train_loader=train_batcher,
        val_loader=validation_batcher,
        net=assembled_model["model"],
        activation=assembled_model["activation"],
        optimizer=assembled_model["optimizer"],
        loss_fn=assembled_model["loss_fn"],
        device=assembled_model["device"],
        **hyperparameters,
    )
except Exception:
    # print the full traceback (not only the message) so the failing line is visible
    print("An error occurred during training:")
    traceback.print_exc()
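Because the cell above runs the full training budget, you may want to confirm first that a complete epoch, including validation and logging, runs end to end with a tiny budget. The hypothetical helper below returns a shallow copy of the hyperparameters with `epochs`, `steps_per_epoch`, and `val_per_epoch` reduced, and it leaves the original dictionary untouched. It assumes that `train` honors these keys when they are passed in via `**hyperparameters`. Note that the dry run may still write logs and checkpoints to `save_dir`.

```python
def smoke_test_hyperparameters(hp, epochs=1, steps_per_epoch=2, val_per_epoch=2):
    """Return a copy of `hp` with a tiny training budget for a quick dry run.

    Hypothetical helper. A shallow copy is enough because only top-level
    keys are overwritten; `hp` itself is not modified.
    """
    quick = dict(hp)
    quick.update(
        {
            "epochs": epochs,
            "steps_per_epoch": steps_per_epoch,
            "val_per_epoch": val_per_epoch,
        }
    )
    return quick
```

For the dry run, pass `**smoke_test_hyperparameters(hyperparameters)` to `train` in place of `**hyperparameters`, then rerun the cell with the full settings. If validation only runs every few epochs, raise `epochs` so the dry run reaches it.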

### Checkpoint III: Model Training
Now you are ready to train your model! The `train` function, imported from `nantex.deep_learning.dl_building_blocks` in Part 0, handles training. In addition to the data loaders and the components of `assembled_model` (network, activation, optimizer, loss function, and device), it accepts the hyperparameters as keyword arguments, including the number of epochs, steps per epoch, validation frequency, and more.

The `train` function runs validation automatically after a specified number of epochs, so you can monitor the model's performance on the validation dataset while it trains. This helps you detect overfitting early and check that the model generalizes to unseen data.
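Monitoring the validation loss is most useful when you act on it. A common pattern is early stopping, which ends training once the validation loss has not improved for a set number of checks. Nothing in this tutorial indicates that `train` implements early stopping. The class below is therefore a generic, hypothetical sketch that you could apply to validation losses read from your logs or use in your own training loop.

```python
import math


class EarlyStopping:
    """Signal a stop when the validation loss stops improving (generic sketch).

    A loss counts as an improvement only if it beats the best loss seen so far
    by more than `min_delta`. After `patience` consecutive non-improving
    checks, step() returns True.
    """

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.bad_checks = 0

    def step(self, val_loss):
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience
```

For example, with `patience=3`, the loss sequence `1.0, 0.8, 0.81, 0.82, 0.83` triggers a stop on the fifth check, and `best` remains `0.8`.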

You can find the training logs and model checkpoints in the directory given by the `save_dir` hyperparameter. By default, this is `./model_checkpoints/`. In this tutorial, Part II sets it to `../model_checkpoints/`.
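To find the newest checkpoint after a run, you can list the files in `save_dir` sorted by modification time. The `*.pth` pattern is an assumption borrowed from the state-dict path used when loading a pre-trained model in Part III. Adjust it to match the file names that `train` actually writes. `list_checkpoints` is a hypothetical helper.

```python
from pathlib import Path


def list_checkpoints(save_dir, pattern="*.pth"):
    """Return files under `save_dir` (searched recursively) matching `pattern`, newest first.

    Hypothetical helper; the default "*.pth" pattern is an assumption.
    """
    files = [p for p in Path(save_dir).rglob(pattern) if p.is_file()]
    return sorted(files, key=lambda p: p.stat().st_mtime, reverse=True)
```

Calling `list_checkpoints(hyperparameters["save_dir"])[0]` returns the most recently written match. TensorBoard is listed in the requirements, so the logs are presumably TensorBoard event files. If so, running `tensorboard --logdir ../model_checkpoints/` from the notebook's folder should display them.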