<a name="table-of-contents"></a>
# Protein Localization Prediction using Kolmogorov-Arnold Networks.
GitHub: https://github.com/JinMaxx/Protein-Localization-using-KANs<br>
For optimal performance, please use a GPU runtime (T4 or better).

## Table of Contents
- [01. Install pip Dependencies (Must Execute)](#1-install-pip-dependencies-must-execute)
- [02. Check Files (Must Execute)](#2-check-files-must-execute)
- [03. Settings (Must execute)](#3-settings-must-execute)
- [04. Create Encodings (Must execute)](#4-create-encodings-must-execute)
- [05. Visualize Data](#5-visualize-data)
- [06. Train Model](#6-train-model)
- [07. Continue Training a Saved Model](#7-continue-training-a-saved-model)
- [08. Model Comparison](#8-model-comparison)
- [09. Explore Figures](#9-explore-figures)

---

License: MIT<br>
This notebook and its code are made available under the MIT License.<br>
See [https://opensource.org/licenses/MIT](https://opensource.org/licenses/MIT) for details.<br>

---

<a name="1-install-pip-dependencies-must-execute"></a>
### 1. Install pip Dependencies (Must Execute)
#### [↑](#table-of-contents) [→](#2-check-files-must-execute)

In [None]:
# Cell 1: fetch the project files via a sparse, shallow Git checkout and
# install the pip dependencies used by the following cells.

# !git clone https://github.com/JinMaxx/Protein-Localization-using-KANs.git /content/project_root

__repo_url = "https://github.com/JinMaxx/Protein-Localization-using-KANs.git"
# Placeholder: must be replaced with a real URL before the download branch of cell 4 (wget) can work.
_encodings_download_url = "YOUR_DIRECT_DOWNLOAD_URL"
_project_root_path = "/content/project_root"
__branch = "main"

# 1. Create project_root and initialize an empty Git repo
print("Initializing Git repository...")
!mkdir -p {_project_root_path}
!cd {_project_root_path} && git init -q

# 2. Add the remote repository and enable sparse checkout
!cd {_project_root_path} && git remote add origin {__repo_url}
!cd {_project_root_path} && git config core.sparseCheckout true

# 3. Define the files and directories needed
# The first echo uses ">" to (re)create the pattern file; the following ones append with ">>".
print("Defining sparse-checkout patterns...")
!echo "config.yaml" > {_project_root_path}/.git/info/sparse-checkout
!echo ".env" >> {_project_root_path}/.git/info/sparse-checkout
!echo "source/" >> {_project_root_path}/.git/info/sparse-checkout
!echo "data/fasta/" >> {_project_root_path}/.git/info/sparse-checkout

# 4. Pull the data from the repository
# --depth=1 creates a shallow clone, fetching only the latest commit, saving time and space.
print(f"Fetching from remote repository (__branch: {__branch})...")
!cd {_project_root_path} && git pull --depth=1 origin {__branch}

# 5. Verify the contents of your project directory
print("\n--- Verification ---")
print(f"Contents of '{_project_root_path}':")
!ls -l {_project_root_path}

# This check ensures the ls command doesn't fail if the source or data/fasta directory for some reason wasn't checked out
print(f"\nContents of '{_project_root_path}/source':")
!if [ -d "{_project_root_path}/source" ]; then ls -l {_project_root_path}/source; else echo "Source directory not found."; fi
print(f"\nContents of '{_project_root_path}/data/fasta':")
!if [ -d "{_project_root_path}/data/fasta" ]; then ls -l {_project_root_path}/data/fasta; else echo "data/fasta directory not found."; fi


# Third-party packages; kaleido and plotly are pinned to fixed versions.
!pip install -q \
    pypdf \
    optuna \
    dotenv \
    pyfaidx \
    colorcet \
    reportlab \
    biopython \
    umap_learn \
    transformers \
    sentencepiece \
    kaleido==0.2.1 \
    plotly==5.5.0 \
    "huggingface_hub[hf_xet]"  # to ignore missing account warnings

# faster-kan is installed from GitHub, pinned to a specific commit.
!pip install git+https://github.com/AthanasiosDelis/faster-kan.git@3bcabc25c1ed5bceb04d58a8c73756a9fe54e81b
# !pip install git+https://github.com/ZiyaoLi/fast-kan.git@17b65401c252334fffb5e63c9852dd8316d29e69

# Load the rpy2 IPython extension.
%load_ext rpy2.ipython

# Marker asserted at the top of cell 2 to enforce the execution order.
__cell1 = True

<a name="2-check-files-must-execute"></a>
### 2. Check Files (Must Execute)
#### [←](#1-install-pip-dependencies-must-execute) [↑](#table-of-contents) [→](#3-settings-must-execute)

In [None]:
assert '__cell1' in globals(), "You must execute cell 1 before!"

import os
import sys
import sysconfig
from dotenv import load_dotenv

# --- Project Path Setup ---
# The root path of the project is the directory of the cloned the project repository
os.chdir(_project_root_path)
if _project_root_path not in sys.path:
    sys.path.append(_project_root_path)

# --- Environment Variables ---
__dotenv_path = os.path.join(_project_root_path, ".env")
if os.path.exists(__dotenv_path):
    load_dotenv(dotenv_path=__dotenv_path)
    print("Successfully loaded environment variables from .env file.")
else:
    raise FileNotFoundError(".env file not found in the repository root.")

# For compatibility with subsequent cells, set BASE_COLAB programmatically
os.environ['BASE_COLAB'] = _project_root_path


# --- GPU Check ---
# for google colab checking if running with GPU
__gpu_info = !nvidia-smi
__gpu_info = '\n'.join(__gpu_info)
if __gpu_info.find('failed') >= 0: print('Not connected to a GPU')
else: print(__gpu_info)


# --- File Listing (for verification) ---
def _list_files(path) -> list[str] | None:
    # Check if the folder exists and list files
    if os.path.exists(path):
        files = os.listdir(path)  # List all files and directories
        if files:
            print(f"Files and directories in '{path}':")
            for file in files: print(f"- {file}")
            return files
        else:
            print(f"The folder '{path}' is empty.")
            return []
    else:
        print(f"The folder '{path}' does not exist.")
        return None

_list_files(_project_root_path)


# The path setup earlier in this cell normally adds the project root already;
# only append when it is missing so that re-running the cell does not pile up
# duplicate sys.path entries.
if _project_root_path not in sys.path:
    sys.path.append(_project_root_path)  # Project root


# --- APPLY PATCHES ---
# Each entry describes one patch for a file of an installed package:
#   target_rel_path: file to patch, relative to the site-packages directory
#   patch_file_path: path of the patch file
#   description:     label used in the progress messages

patches_to_apply = [
    # {
    #     "target_rel_path": os.path.join("<PACKAGE_DIR_NAME>", "<FILE>.py"),
    #     "patch_file_path": "source/patches/<FILE>.patch",
    #     "description": "<PACKAGE> <FILE>.py"
    # },
]

if patches_to_apply:
    import subprocess

    print("\n--- Applying patches ---")

    # Find the site-packages directory once
    site_packages_path = sysconfig.get_path('purelib')
    print(f"Located site-packages directory at: {site_packages_path}")

    for patch_info in patches_to_apply:
        description = patch_info["description"]
        target_file = os.path.join(site_packages_path, patch_info["target_rel_path"])
        patch_file = patch_info["patch_file_path"]

        print(f"\nAttempting to patch {description}...")

        # Check that both files exist before attempting the patch
        if not (os.path.exists(target_file) and os.path.exists(patch_file)):
            print(f"  => ERROR: Patch for {description} failed.")
            if not os.path.exists(target_file):
                print(f"    - Target file not found at: {target_file}")
            if not os.path.exists(patch_file):
                print(f"    - Patch file not found at: {patch_file}")
            continue

        print(f"  - Found target file: {target_file}")
        print(f"  - Found patch file: {patch_file}")

        # Run `patch <target>` with the patch file on stdin (same as the shell
        # redirection `patch "target" < "patch_file"`), but keep the exit
        # status so that success is only reported when patch really succeeded.
        try:
            with open(patch_file, "rb") as patch_stream:
                result = subprocess.run(
                    ["patch", target_file],
                    stdin=patch_stream,
                    capture_output=True,
                    text=True,
                )
        except OSError as error:
            # e.g. the `patch` executable is not installed
            print(f"  => ERROR: Patch for {description} failed: {error}")
            continue

        # Child-process output is captured, so echo it into the cell output.
        if result.stdout:
            print(result.stdout, end="")
        if result.stderr:
            print(result.stderr, end="")

        if result.returncode == 0:
            print(f"  => Patch for {description} applied successfully.")
        else:
            print(f"  => ERROR: Patch for {description} failed (exit code {result.returncode}).")

    print("\n--- Finished applying all patches ---")
    # --- END OF PATCH SECTION ---


os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Marker asserted at the top of cell 3 to enforce the execution order.
__cell2 = True

<a name="3-settings-must-execute"></a>
### 3. Settings (Must execute)
#### [←](#2-check-files-must-execute) [↑](#table-of-contents) [→](#4-create-encodings-must-execute)

In [None]:
assert '__cell2' in globals(), "You must execute cell 2 before!"

# @markdown Choose a protein language model:
encoding_model_name = "Rostlab/prot_t5_xl_uniref50" # @param ["onehot", "Rostlab/prot_t5_xl_uniref50", "Rostlab/prot_t5_xl_half_uniref50-enc", "ElnaggarLab/ankh-base", "ElnaggarLab/ankh-large", "facebook/esm2_t6_8M_UR50D", "facebook/esm2_t12_35M_UR50D", "facebook/esm2_t30_150M_UR50D", "facebook/esm2_t33_650M_UR50D"]

# Input directory with the FASTA files and the per-model output directory
# for the generated encodings.
_fasta_input_dir: str      = os.getenv("ENCODINGS_INPUT_DIR_COLAB")
_encodings_output_dir: str = os.path.join(os.getenv('ENCODINGS_OUTPUT_DIR_COLAB'), encoding_model_name)

train_file_name = "deeploc_our_train_set"  # @param {type:"string"}
val_file_name   = "deeploc_our_val_set"    # @param {type:"string"}
test_file_name  = "setHARD"                # @param {type:"string"}
                # "deeploc_test_set"

# Every dataset maps to "<encodings_output_dir>/<file name without extension>.h5".
_train_encodings_file_path, _val_encodings_file_path, _test_encodings_file_path = (
    os.path.join(
        _encodings_output_dir,
        f"{os.path.splitext(os.path.basename(__dataset_file_name))[0]}.h5"
    )
    for __dataset_file_name in (train_file_name, val_file_name, test_file_name)
)


# Setting to more specified subfolders corresponding to encoding model. -> Less confusion and mistakes

_model_save_dir: str   = os.path.join(os.getenv("MODEL_SAVE_DIR_COLAB"), encoding_model_name)
_figures_save_dir: str = os.path.join(os.getenv("FIGURES_SAVE_DIR_COLAB"), encoding_model_name)
_studies_save_dir: str = os.path.join(os.getenv("STUDIES_SAVE_DIR_COLAB"), encoding_model_name)

# NOTE(review): unlike the log/training/hyper-param paths below, the evaluation
# metrics path is used as-is (no per-model subfolder) - confirm this is intended.
_evaluation_metrics_file_path: str = os.getenv("EVALUATION_METRICS_FILE_PATH_COLAB")

# For the remaining files, insert the encoding model name as a folder
# between the configured directory and the configured file name.
__log_dir, __log_name = os.path.split(os.getenv("LOG_FILE_PATH_COLAB"))
_log_file_path: str = os.path.join(__log_dir, encoding_model_name, __log_name)

__training_dir, __training_name = os.path.split(os.getenv("TRAINING_METRICS_FILE_PATH_COLAB"))
_training_metrics_file_path: str = os.path.join(__training_dir, encoding_model_name, __training_name)

__hyper_dir, __hyper_name = os.path.split(os.getenv("HYPER_PARAM_METRICS_FILE_PATH_COLAB"))
_hyper_param_metrics_file_path: str = os.path.join(__hyper_dir, encoding_model_name, __hyper_name)


# Report every configured location and create the output directories.

print(f"fasta_input_dir: ............. {_fasta_input_dir}")
# The FASTA input must already exist and contain files; it is never created here.
if not os.path.isdir(_fasta_input_dir): raise FileNotFoundError(f"Input directory {_fasta_input_dir} does not exist.")
if not _list_files(_fasta_input_dir): raise FileNotFoundError(f"Input directory {_fasta_input_dir} is empty.")

print()

print(f"encodings_output_dir: ........ {_encodings_output_dir}\n")
os.makedirs(_encodings_output_dir, exist_ok=True)

print(f"train_encodings_file_path: ... {_train_encodings_file_path}")
print(f"val_encodings_file_path: ..... {_val_encodings_file_path}")
print(f"test_encodings_file_path: .... {_test_encodings_file_path}\n")

print(f"model_save_dir: .............. {_model_save_dir}")
os.makedirs(_model_save_dir, exist_ok=True)

print(f"figures_save_dir: ............ {_figures_save_dir}")
# Create the figures directory itself: the figure explorer cell lists this
# directory with os.listdir, which fails when it does not exist.
os.makedirs(_figures_save_dir, exist_ok=True)

print(f"studies_save_dir: ............ {_studies_save_dir}")
os.makedirs(_studies_save_dir, exist_ok=True)

print(f"log_file_path: ............... {_log_file_path}")
print(f"training_metrics_file_path: .. {_training_metrics_file_path}")
print(f"hyper_param_metrics_file_path: {_hyper_param_metrics_file_path}")

print("\n")


# @markdown Edit config.yaml?
edit_config = False  # @param {type:"boolean"}

from IPython.display import display, Markdown

__config_file_path = os.getenv("CONFIG_PATH_COLAB")

# Both modes start from the current content of the config file.
__content: str
with open(__config_file_path, "r") as file:
    __content = file.read()

if not edit_config:
    # Read-only mode: render the config as a YAML code block.
    display(Markdown(f"```yaml\n{__content}\n```"))

else:
    # Edit mode: show the config in a text area with a save button.
    import ipywidgets as widgets

    __textarea = widgets.Textarea(
        value = __content,
        layout = widgets.Layout(width='100%', height='400px')
    )
    __save_button = widgets.Button(description="Save", button_style='success')
    __output = widgets.Output()

    def save_config(_):
        """Button callback: write the text area content back to the config file.

        Status and error messages are printed into the output widget.
        """
        with __output:
            try:
                with open(__config_file_path, 'w') as config_file:
                    config_file.write(__textarea.value)
                print("Configuration saved!")
            except Exception as error:
                print(f"Error: {error}")

    __save_button.on_click(save_config)

    display(Markdown("**Edit your config below:**"))
    display(__textarea)
    display(__save_button, __output)


__cell3 = True

<a name="4-create-encodings-must-execute"></a>
### 4. Create Encodings (Must execute)
#### [←](#3-settings-must-execute) [↑](#table-of-contents) [→](#5-visualize-data)
Pre-computed embeddings for the default datasets, generated using the Rostlab/prot_t5_xl_uniref50 model, are included with this project and are ready to download.

In [None]:
assert '__cell3' in globals(), "You must execute cell 3 before!"

import time
from source.data_scripts.encodings import main as generate_embeddings


# @markdown Download preprocessed prot_t5_xl_uniref50 embeddings.
download_embeddings = True  # @param {type:"boolean"}

# @markdown Choose batch size (might impact memory):
batch_size = 18  # @param {type:"slider", min:1, max:64, step:1}
# reaching ~10GB GPU Memory Peaks

# @markdown Choose number of threads (ignored if CUDA/GPU availible)
threads = 8  # @param {type:"slider", min:1, max:16, step:1}

if download_embeddings and encoding_model_name == "Rostlab/prot_t5_xl_uniref50":

    print("Downloading preprocessed prot_t5_xl_uniref50 embeddings.")
    !wget {_encodings_download_url} -O dataset.zip
    !unzip -q dataset.zip -d {_encodings_output_dir}

else:

    print("Generating Embeddings")

    if not _list_files(_fasta_input_dir):
        raise FileNotFoundError("Input directory is empty or does not exist.")

    try:
        await generate_embeddings(
            model_name = encoding_model_name,
            input_dir = _fasta_input_dir,
            output_dir = _encodings_output_dir,
            batch_size = batch_size,
            threads = threads
        )
    except Exception as e:
        print("Error during encoding generation:")
        print(str(e))
        terminate = True

    time.sleep(60)  # "Wait a minute." - Kazoo Kid
    # Prematurely terminating would make some file transfers to drive incomplete.

print(f"Contents of '{_encodings_output_dir}':")
_list_files(_encodings_output_dir)

__cell4 = True

<a name="5-visualize-data"></a>
### 5. Visualize Data
#### [←](#4-create-encodings-must-execute) [↑](#table-of-contents) [→](#6-train-model)

In [None]:
assert '__cell4' in globals(), "You must execute cell 4 before!"

from source.data_scripts.data_figures import DataFiguresCollection, PoolingType
from source.data_scripts.encodings import load_metadata_from_hdf5, stream_seq_enc_data_from_hdf5


if not _list_files(_encodings_output_dir):
    raise FileNotFoundError("Input directory is empty or does not exist.")


# Set up the figure collection; the calls are made in the same order as
# before: class distribution, PCA/t-SNE/UMAP for mean pooling, then for max
# pooling, followed by the length distributions.
__figures: DataFiguresCollection = DataFiguresCollection(save_dir=_figures_save_dir)
__figures.class_distribution()
for __pooling_type in (PoolingType.Per_Protein_Mean, PoolingType.Per_Protein_Max):
    __figures.pca_embedding(pooling_type=__pooling_type)
    __figures.tsne_embedding(pooling_type=__pooling_type)
    __figures.umap_embedding(pooling_type=__pooling_type)
__figures.raw_sequence_length_distribution(bins=150, log_y=False)
for __log_y in (False, True):
    __figures.embedding_length_distribution(bins=150, log_y=__log_y)
# __figures.pairwise_distance_distribution(bins=150, distance_metric="euclidean", pooling_type=__pooling_type)
# Simply too many pairwise comparisons resulting in too many values for the figure to handle...


# One entry per dataset, plus an "Aggregated" entry that passes all three files at once.
__datasets = [
    (train_file_name, _train_encodings_file_path),
    (val_file_name,   _val_encodings_file_path),
    (test_file_name,  _test_encodings_file_path),
    ("Aggregated", [_train_encodings_file_path, _val_encodings_file_path, _test_encodings_file_path]),
]

for __file_name, __file_path in __datasets:

    __identifier = f"{encoding_model_name}_{__file_name}"
    __encoding_dim, __count, __label_count = load_metadata_from_hdf5(file_path=__file_path)

    __figures.update(
        identifier = __identifier.replace("/", "_"),  # model names contain "/"
        label_count = __label_count,
        # The lambda looks up __file_path when it is called, not when it is
        # created; presumably it is consumed within this iteration - confirm.
        seq_enc_data_generator_supplier = lambda: stream_seq_enc_data_from_hdf5(file_path=__file_path),
        file_name = __file_name,
        encoding_model = encoding_model_name,
    )

    __figures.save("data_visualization")

    print(f"{__identifier}")
    print(f"Encoding dimension: {__encoding_dim}")
    print(f"Number of sequences: {__count}")

del __figures

<a name="6-train-model"></a>
### 6. Train Model
#### [←](#5-visualize-data) [↑](#table-of-contents) [→](#7-continue-training-a-saved-model)

In [None]:
assert '__cell4' in globals(), "You must execute cell 4 before!"

from source.models.abstract import AbstractModel
from source.models.ffn import MLP, MLPpp, FastKAN
from source.config import AbstractTrainingConfig, HyperParamConfig, ConfigType, parse_config
from source.models.other.attention_lstm_hybrid import AttentionLstmHybridFastKAN
from source.models.other.lstm_reduction_hybrid import LstmAttentionReductionHybridFastKAN
from source.models.other.light_attention import (
    LightAttention,
    LightAttentionFastKAN,
)
from source.models.reduced_ffn import (
    MaxPoolFastKAN,
    MaxPoolMLP,
    AvgPoolFastKAN,
    AvgPoolMLP,
    LinearFastKAN,
    LinearMLP,
    AttentionFastKAN,
    AttentionMLP,
    PositionalFastKAN,
    PositionalMLP,
    UNetFastKAN,
    UNetMLP
)


# Maps the name offered in the model dropdown below to the model class
# (the class itself, not an instance). Every dropdown option must be a key
# of this dict, otherwise the lookup after this section raises a KeyError.
__model_name_to_class: dict[str, type[AbstractModel]] = {
    "MLP": MLP,
    "MLP_PerProtein": MLPpp,
    "FastKAN": FastKAN,

    "LightAttention": LightAttention,
    "LightAttentionFastKAN": LightAttentionFastKAN,

    "MaxPoolFastKAN": MaxPoolFastKAN,
    "MaxPoolMLP": MaxPoolMLP,
    "AvgPoolFastKAN": AvgPoolFastKAN,
    "AvgPoolMLP": AvgPoolMLP,
    "LinearFastKAN": LinearFastKAN,
    "LinearMLP": LinearMLP,
    "AttentionFastKAN": AttentionFastKAN,
    "AttentionMLP": AttentionMLP,
    "PosFastKAN": PositionalFastKAN,
    "PosMLP": PositionalMLP,
    "UNetFastKAN": UNetFastKAN,
    "UNetMLP": UNetMLP,

    "AttentionLstmHybridFastKAN": AttentionLstmHybridFastKAN,
    "LstmAttentionReductionHybridFastKAN": LstmAttentionReductionHybridFastKAN
}

# @markdown Select Model to train (adjust parameters in config.yaml):
model_name = "AttentionFastKAN"  # @param ["MLP", "MLP_PerProtein", "FastKAN", "MaxPoolFastKAN", "MaxPoolMLP", "AvgPoolFastKAN", "AvgPoolMLP", "LinearFastKAN", "LinearMLP", "AttentionFastKAN", "AttentionMLP", "PosFastKAN", "PosMLP", "UNetFastKAN", "UNetMLP", "AttentionLstmHybridFastKAN", "LstmAttentionReductionHybridFastKAN", "LightAttention", "LightAttentionFastKAN"]

# When enabled, the values below are replaced by the ones from config.yaml.
use_config_defaults = False  # @param {type:"boolean"}

epochs = 50  # @param {type:"slider", min:1, max:100, step:1}
patience = 20  # @param {type:"slider", min:1, max:20, step:1}
batch_size = 28  # @param {type:"slider", min:1, max:64, step:1}

learning_rate = 0.00005  # @param ["0.001", "0.0005", "0.0001", "0.00005"] {"type":"raw"}
learning_rate_decay = 0.98  # @param ["0.80", "0.85", "0.90", "0.925", "0.95", "0.96", "0.97", "0.98", "0.99", "0.995", "1.0"] {"type":"raw"}
l2_penalty = 0.0  # @param ["0.0", "0.001", "0.0005", "0.0001", "0.00005", "0.00001"] {"type":"raw"}
weight_factor = 0.00  # @param {type:"slider", min:0.00, max:1.00, step:0.05}

# @markdown Select if you want to perform a hyperparameter search or just train the model:
do_hyper_param_search = False  # @param {type:"boolean"}


if not _list_files(_encodings_output_dir):
    raise FileNotFoundError(f"Encodings directory {_encodings_output_dir} is empty or does not exist.")

__model_class = __model_name_to_class[model_name]

# Optionally replace the form values with the ones from the config file.
if use_config_defaults:
    __config_type = ConfigType.HyperParam if do_hyper_param_search else ConfigType.Training
    training_config: AbstractTrainingConfig = parse_config(__config_type)
    epochs, patience, batch_size = training_config.epochs, training_config.patience, training_config.batch_size
    learning_rate, learning_rate_decay = training_config.learning_rate, training_config.learning_rate_decay
    l2_penalty, weight_factor = training_config.l2_penalty, training_config.weight_factor
    __parameter_origin = "Loaded parameters"
else:
    __parameter_origin = "Using manually defined parameters"

print(f"{__parameter_origin}: epochs={epochs}, patience={patience}, batch_size={batch_size}, learning_rate={learning_rate}, learning_rate_decay={learning_rate_decay}, l2_penalty={l2_penalty}, weight_factor={weight_factor}")


try:

    # Keyword arguments passed to both the hyperparameter search and plain training.
    __shared_arguments = dict(
        train_encodings_file_path = _train_encodings_file_path,
        val_encodings_file_path = _val_encodings_file_path,
        epochs = epochs,
        patience = patience,
        batch_size = batch_size,
        l2_penalty = l2_penalty,
        weight_factor = weight_factor,
        learning_rate = learning_rate,
        learning_rate_decay = learning_rate_decay,
        model_save_dir = _model_save_dir,
        figures_save_dir = _figures_save_dir,
        log_file_path = _log_file_path,
    )

    if not do_hyper_param_search:
        from source.training.train_model import main as train

        train(
            model = __model_class,
            metrics_file_path = _training_metrics_file_path,
            **__shared_arguments
        )
        _list_files(f"{_model_save_dir}/training")

    else:
        from source.training.hyper_param import main as tune
        hyper_param_config: HyperParamConfig = parse_config(ConfigType.HyperParam)

        tune(
            model_class = __model_class,
            n_trials = hyper_param_config.n_trials,
            timeout = hyper_param_config.timeout,
            study_name = __model_class.name(),
            studies_save_dir = _studies_save_dir,
            metrics_file_path = _hyper_param_metrics_file_path,
            **__shared_arguments
        )
        _list_files(f"{_model_save_dir}/hyper_param")

except Exception as e:
    # Report the failure in the cell output instead of aborting the cell.
    print("Error during training:")
    print(str(e))
    import traceback
    traceback.print_exc()

<a name="7-continue-training-a-saved-model"></a>
### 7. Continue Training a Saved Model
#### [←](#6-train-model) [↑](#table-of-contents) [→](#8-model-comparison)

In [None]:
assert '__cell4' in globals(), "You must execute cell 4 before!"

from IPython.display import display, clear_output
from source.training.train_model import main as train


use_config_defaults = True  # @param {type:"boolean"}
epochs = 100  # @param {type:"slider", min:1, max:250, step:1}
# Consider a higher patience: sometimes models recalibrate themselves after overfitting at epoch 3.
patience = 20  # @param {type:"slider", min:1, max:50, step:1}
batch_size = 28  # @param {type:"slider", min:1, max:64, step:1}

learning_rate = 0.00005  # @param ["0.001", "0.005", "0.001", "0.0005", "0.0001", "0.00005"] {"type":"raw"}
learning_rate_decay = 0.98  # @param ["0.80", "0.85", "0.90", "0.95", "0.98", "0.99"] {"type":"raw"}

# Testing if models can be improved by continuing to train with class weights.
l2_penalty = 0.0001  # @param ["0.0", "0.001", "0.0005", "0.0001", "0.00005", "0.00001"] {"type":"raw"}
weight_factor = 0.5  # @param {type:"slider", min:0.00, max:1.00, step:0.05}


# Optionally replace the form values above with the training section of the config file.
if use_config_defaults:
    from source.config import TrainingConfig, ConfigType, parse_config
    training_config: TrainingConfig = parse_config(ConfigType.Training)
    epochs, patience, batch_size = training_config.epochs, training_config.patience, training_config.batch_size
    learning_rate, learning_rate_decay = training_config.learning_rate, training_config.learning_rate_decay
    l2_penalty, weight_factor = training_config.l2_penalty, training_config.weight_factor
    __parameter_source = "Loaded parameters"
else:
    __parameter_source = "Using manually defined parameters"

print(f"{__parameter_source}: epochs={epochs}, patience={patience}, batch_size={batch_size}, learning_rate={learning_rate}, learning_rate_decay={learning_rate_decay}, l2_penalty={l2_penalty}, weight_factor={weight_factor}")


# Candidates for the model selection below: every entry of the model save directory.
# NOTE(review): other cells read models from the "training" subfolder of this
# directory - confirm that listing the parent directory is intended here.
__files = _list_files(_model_save_dir)
# __files.sort()
if not __files:
    raise FileNotFoundError(f"Model save directory {_model_save_dir} is empty or does not exist.")


import asyncio
import ipywidgets as widgets

def wait_for_click(button: "widgets.Button") -> asyncio.Future:
    """Return a future that resolves to ``True`` on the next click of *button*.

    The click handler unregisters itself when it fires. A repeated click that
    is delivered after the future has already been resolved (or cancelled) is
    ignored instead of raising ``asyncio.InvalidStateError``.

    Args:
        button: Widget offering ``on_click(callback, remove=False)``.

    Returns:
        The future to await; its result is ``True``.
    """
    future = asyncio.Future()

    def on_click(_):
        # Unregister first so that the handler is removed even if resolving
        # the future is skipped below.
        button.on_click(on_click, remove=True)
        if not future.done():
            future.set_result(True)

    button.on_click(on_click)
    return future

# TODO: The code below runs into a deadlock.
# The UI runs in a different thread, but the main thread is blocking the execution.
# Only once the main thread has finished can the button execute its function.
# But because Google Colab then considers the cell idle (main thread has finished), the runtime gets disconnected after some time.
async def run_continue_training():
    """Let the user pick a saved model and continue training it.

    Shows a dropdown with the entries of ``__files`` and a "Continue" button,
    waits for the button click, and then calls ``train`` on the selected entry
    with the parameters defined earlier in this cell. Training errors are
    printed together with their traceback and are not re-raised.
    """
    model_picker = widgets.Dropdown(
        options = __files,
        description = 'Select:',
        value = None  # nothing pre-selected
    )
    continue_button = widgets.Button(description="Continue")

    display(widgets.VBox([model_picker, continue_button]))

    await wait_for_click(continue_button)

    clear_output(wait=True)

    selected_filename = model_picker.value
    if not selected_filename:
        print("No model was selected. Halting execution.")
        return

    print(f"Model selected: {selected_filename}")

    try:
        train(
            model = os.path.join(_model_save_dir, selected_filename),
            train_encodings_file_path = _train_encodings_file_path,
            val_encodings_file_path = _val_encodings_file_path,
            epochs = epochs,
            patience = patience,
            batch_size = batch_size,
            l2_penalty = l2_penalty,
            weight_factor = weight_factor,
            learning_rate = learning_rate,
            learning_rate_decay = learning_rate_decay,
            model_save_dir = _model_save_dir,
            figures_save_dir = _figures_save_dir,
            metrics_file_path = _training_metrics_file_path,
            log_file_path = _log_file_path
        )
        _list_files(f"{_model_save_dir}/training")

    except Exception as error:
        print("Error during continued training:")
        print(str(error))
        import traceback
        traceback.print_exc()


await run_continue_training()

<a name="8-model-comparison"></a>
### 8. Model Comparison
#### [←](#7-continue-training-a-saved-model) [↑](#table-of-contents) [→](#9-explore-figures)

In [None]:
assert '__cell4' in globals(), "You must execute cell 4 before!"

%gui asyncio

# import ipywidgets as widgets
from typing_extensions import List
# from IPython.display import display, clear_output

from source.evaluation.evaluation import main as evaluate


iterations = 100  # @param {type:"slider", min:10, max:500, step:10}
batch_size = 28  # @param {type:"slider", min:1, max:64, step:1}

# using only final trained models
__training_model_save_dir = os.path.join(_model_save_dir, "training")
__model_file_names: List[str] = _list_files(__training_model_save_dir)
if not __model_file_names:
    raise FileNotFoundError(f"No models found in directory: {__training_model_save_dir}")


# Build the full path of every entry and echo the names that get evaluated.
__model_file_paths = []
for __model_file_name in __model_file_names:
    __model_file_path = os.path.join(__training_model_save_dir, __model_file_name)
    __model_file_paths.append(__model_file_path)
    print(f"  - {os.path.basename(__model_file_path)}")

# Evaluate all listed models on the test encodings.
evaluate(
    model_file_paths = __model_file_paths,
    test_encodings_file_path = _test_encodings_file_path,
    figures_save_dir = _figures_save_dir,
    iterations = iterations,
    batch_size = batch_size,
    metrics_file_path = _evaluation_metrics_file_path,
    log_file_path = _log_file_path
)

# TODO: The code below runs into a deadlock. (Same problem as above)
# async def run_model_comparison():
#
#     __selection = widgets.SelectMultiple(
#         options = __model_file_names,
#         value = tuple(__model_file_names),  # pre-select all
#         description = 'Models',
#         disabled = False,
#         layout = widgets.Layout(width='100%')
#     )
#
#     __button = widgets.Button(description="Load Selected Models")
#
#     display(widgets.VBox([__selection, __button]))
#
#     await wait_for_click(__button)
#
#     clear_output(wait=True)
#
#     try:
#         __model_file_paths = [
#             os.path.join(__training_model_save_dir, model_file_name)
#             for model_file_name in __selection.value
#         ]
#
#         if __model_file_paths:
#             print("Model selected:")
#             for model_file_path in __model_file_paths: print(f"  - {os.path.basename(model_file_path)}")
#             evaluate(
#                 model_file_paths = __model_file_paths,
#                 test_encodings_file_path = _test_encodings_file_path,
#                 figures_save_dir = _figures_save_dir,
#                 iterations = iterations,
#                 batch_size = batch_size
#             )
#             print("\nEvaluation Finished Successfully")
#         else:
#             print("No models selected. Please select at least one model.")
#
#     except Exception as e:
#         print("Error during model comparison:")
#         print(str(e))
#         import traceback
#         traceback.print_exc()
#
#
# await run_model_comparison()

<a name="9-explore-figures"></a>
### 9. Explore Figures
#### [←](#8-model-comparison) [↑](#table-of-contents)

In [None]:
assert '__cell4' in globals(), "You must execute cell 4 before!"

import base64
import ipywidgets as widgets

from collections import defaultdict
from typing_extensions import List, Dict, Tuple
from IPython.display import display, clear_output

from source.abstract_figures import ViewFiguresCollection


# Index every figure (.pkl) and report (.pdf) stored below the figures directory,
# grouped by the first-level subdirectory that contains them.
# Structure: { "SubdirectoryName": [("display_label", "full/path/to/file"), ...] }
figures_by_subdir: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
pdfs_by_subdir: Dict[str, List[Tuple[str, str]]] = defaultdict(list)

# File suffix -> collection receiving the matching files (compared case-insensitively).
_collections_by_suffix = (
    (".pkl", figures_by_subdir),
    (".pdf", pdfs_by_subdir),
)

for group_name in sorted(os.listdir(_figures_save_dir)):
    group_dir = os.path.join(_figures_save_dir, group_name)
    if os.path.isdir(group_dir):
        # Walk the complete tree below this subdirectory, nested folders included.
        for current_dir, _subdirs, file_names in os.walk(group_dir):
            for file_name in file_names:
                lowered_name = file_name.lower()
                for suffix, collection in _collections_by_suffix:
                    if lowered_name.endswith(suffix):
                        file_path = os.path.join(current_dir, file_name)
                        # User-friendly label: the path relative to the group's directory.
                        label = os.path.relpath(file_path, group_dir)
                        collection[group_name].append((label, file_path))
                        break

# Nothing to browse at all -> abort the cell with an explicit error.
if not (figures_by_subdir or pdfs_by_subdir):
    raise FileNotFoundError(f"No .pkl or .pdf files found in any subdirectories of '{_figures_save_dir}'")


def _build_file_accordion(
    files_by_subdir: Dict[str, List[Tuple[str, str]]],
    *,
    height: str,
    title_suffix: str = "",
) -> widgets.Accordion:
    """Build an accordion with one multi-select list per subdirectory.

    Args:
        files_by_subdir: Maps a subdirectory name to its ``(display_label, full_path)``
            entries. Each entry list becomes the options of one ``SelectMultiple``.
        height: CSS height applied to every ``SelectMultiple`` (e.g. ``'200px'``).
        title_suffix: Text appended to the subdirectory name to form the section title.

    Returns:
        The accordion, with one titled section per key of ``files_by_subdir``,
        in the mapping's iteration order.
    """
    section_titles = [f"{subdir_name}{title_suffix}" for subdir_name in files_by_subdir]
    selectors = [
        widgets.SelectMultiple(
            # Sort files within each group by label for a consistent order
            options = sorted(files, key=lambda item: item[0]),
            description = ' ',  # An empty description looks cleaner
            disabled = False,
            layout = widgets.Layout(width='95%', height=height)
        )
        for files in files_by_subdir.values()
    ]

    accordion = widgets.Accordion(children=selectors)
    for index, title in enumerate(section_titles):
        accordion.set_title(index, title)
    return accordion


# One accordion for the pickled figures and one for the PDF reports
pkl_accordion = _build_file_accordion(figures_by_subdir, height='200px')
pdf_accordion = _build_file_accordion(pdfs_by_subdir, height='150px', title_suffix=" Reports")

load_button = widgets.Button(description="Load Selected")


def on_button_clicked(_):
    """Load and render everything currently selected in the two accordions.

    Click handler for ``load_button``. The argument passed by the button is ignored.
    Selected ``.pkl`` entries are loaded into a fresh ``ViewFiguresCollection`` and
    displayed; selected ``.pdf`` entries are embedded as base64 data URIs inside
    HTML widgets. A PDF that cannot be rendered is replaced by a red error message
    instead of aborting the remaining ones. Prints "No items selected." when
    nothing was chosen.
    """
    # Local import: only needed to escape text that is interpolated into HTML widgets.
    import html

    clear_output(wait=True)

    # Figures
    figures = ViewFiguresCollection()  # Re-init to clear old figures

    # child_widget.value is a tuple of the full paths selected in that box;
    # flatten the selections of all SelectMultiple widgets in the accordion.
    all_selected_figure_paths = [
        full_path
        for child_widget in pkl_accordion.children
        for full_path in child_widget.value
    ]

    if all_selected_figure_paths:
        for full_path in all_selected_figure_paths:
            figures.load(full_path)
        print(f"Loaded and displayed {len(all_selected_figure_paths)} figures.")
        figures.display(clear=False)

    # PDFs
    all_selected_pdf_paths = [
        full_path
        for child_widget in pdf_accordion.children
        for full_path in child_widget.value
    ]

    if all_selected_pdf_paths:
        print(f"\nDisplaying {len(all_selected_pdf_paths)} PDF reports:")
        pdf_widgets_to_display = []

        for full_path in all_selected_pdf_paths:

            # Deliberately broad: a single unreadable report must not prevent
            # the remaining reports from being shown.
            try:
                display_label = os.path.relpath(full_path, _figures_save_dir)
                # Escape the label: file names may contain '<', '&' or quotes,
                # which would otherwise be interpreted as markup.
                title_widget = widgets.HTML(f"<h4>--- {html.escape(display_label)} ---</h4>")

                # An IFrame given the file path would route the view to localhost:8080,
                # so the file is read here and embedded as a base64 data URI instead.
                with open(full_path, "rb") as f:
                    pdf_bytes = f.read()
                base64_pdf = base64.b64encode(pdf_bytes).decode('utf-8')
                pdf_data_uri = f"data:application/pdf;base64,{base64_pdf}"

                # Rendering each iframe through an HTML widget allows several PDFs to be shown.
                iframe_html = f"<iframe src='{pdf_data_uri}' width='100%' height='600'></iframe>"
                html_widget = widgets.HTML(value=iframe_html)

                pdf_widgets_to_display.append(title_widget)
                pdf_widgets_to_display.append(html_widget)

            except Exception as error:
                # Escape path and exception text for the same reason as the label above.
                error_text = html.escape(f"Error displaying PDF {full_path}: {error}")
                pdf_widgets_to_display.append(widgets.HTML(f"<p style='color:red;'>{error_text}</p>"))

        pdf_container = widgets.VBox(pdf_widgets_to_display)
        display(pdf_container)

    if not all_selected_figure_paths and not all_selected_pdf_paths:
        print("No items selected.")


# Wire the button to its handler
load_button.on_click(on_button_clicked)


# Assemble and display the UI: a heading plus accordion for every file kind
# that has at least one entry, followed by the load button.
ui_elements = []
for entries, heading, accordion in (
    (figures_by_subdir, "<h3>Figure Selector (.pkl)</h3>", pkl_accordion),
    (pdfs_by_subdir, "<h3>Report Selector (.pdf)</h3>", pdf_accordion),
):
    if entries:
        ui_elements.extend([widgets.HTML(heading), accordion])
ui_elements.append(load_button)

display(widgets.VBox(ui_elements))

### Run the cell below to automatically stop the Colab runtime.<br>

In [None]:
# NOTE(review): `google.colab` is presumably only importable inside a Google Colab
# runtime — confirm before running this cell elsewhere.
from google.colab import runtime
# Stops the current Colab runtime, as announced in the heading of this cell.
runtime.unassign()