<a name="table-of-contents"></a>
# Protein Localization Prediction using Kolmogorov-Arnold Networks.
GitHub: https://github.com/JinMaxx/Protein-Localization-using-KANs<br>
For optimal performance, please use a T4 GPU runtime or better.

## Table of Contents
- [01. Install pip Dependencies (Must Execute)](#1-install-pip-dependencies-must-execute)
- [02. Check Files (Must Execute)](#2-check-files-must-execute)
- [03. Settings (Must Execute)](#3-settings-must-execute)
- [04. Create Encodings](#4-create-encodings)
- [05. Visualize Data](#5-visualize-data)
- [06. Train Model](#6-train-model)
- [07. Continue Training a Saved Model](#7-continue-training-a-saved-model)
- [08. Model Comparison](#8-model-comparison)
- [09. Explore Figures](#9-explore-figures)

---

License: MIT<br>
This notebook and its code are made available under the MIT License.<br>
See [https://opensource.org/licenses/MIT](https://opensource.org/licenses/MIT) for details.<br>

---

<a name="1-install-pip-dependencies-must-execute"></a>
### 1. Install pip Dependencies (Must Execute)
#### [↑](#table-of-contents) [→](#2-check-files-must-execute)

In [None]:
!git clone https://github.com/JinMaxx/Protein-Localization-using-KANs.git /content/project_root

!pip install -q \
    optuna \
    dotenv \
    pyfaidx \
    colorcet \
    biopython \
    umap_learn \
    transformers \
    sentencepiece \
    typing-extensions \
    kaleido==0.2.1 \
    plotly==5.5.0 \
    git+https://github.com/AthanasiosDelis/faster-kan.git \
    "huggingface_hub[hf_xet]"  # to ignore missing account warnings

# git+https://github.com/ZiyaoLi/fast-kan.git


%load_ext rpy2.ipython

__cell1 = True

<a name="2-check-files-must-execute"></a>
### 2. Check Files (Must Execute)
#### [←](#1-install-pip-dependencies-must-execute) [↑](#table-of-contents) [→](#3-settings-must-execute)

In [None]:
assert '__cell1' in globals(), "You must execute cell 1 before!"

import os
import sys
import sysconfig
from dotenv import load_dotenv

# --- Project Path Setup ---
# The root path of the project is the directory of the cloned the project repository
__project_root_path = "/content/project_root"
os.chdir(__project_root_path)
if __project_root_path not in sys.path:
    sys.path.append(__project_root_path)

# --- Environment Variables ---
__dotenv_path = os.path.join(__project_root_path, ".env")
if os.path.exists(__dotenv_path):
    load_dotenv(dotenv_path=__dotenv_path)
    print("Successfully loaded environment variables from .env file.")
else:
    raise FileNotFoundError(".env file not found in the repository root.")

# For compatibility with subsequent cells, set BASE_COLAB programmatically
os.environ['BASE_COLAB'] = __project_root_path


# --- GPU Check ---
# for google colab checking if running with GPU
__gpu_info = !nvidia-smi
__gpu_info = '\n'.join(__gpu_info)
if __gpu_info.find('failed') >= 0: print('Not connected to a GPU')
else: print(__gpu_info)


# --- File Listing (for verification) ---
def _list_files(path) -> list[str] | None:
    # Check if the folder exists and list files
    if os.path.exists(path):
        files = os.listdir(path)  # List all files and directories
        if files:
            print(f"Files and directories in '{path}':")
            for file in files: print(f"- {file}")
            return files
        else:
            print(f"The folder '{path}' is empty.")
            return []
    else:
        print(f"The folder '{path}' does not exist.")
        return None

_list_files(__project_root_path)


# --- APPLY PATCHES ---
# Each entry patches a file inside site-packages with a patch file shipped in
# this repository (path relative to the project root, which is the cwd).
patches_to_apply = [
    # {
    #     "target_rel_path": os.path.join("<PACKAGE_DIR_NAME>", "<FILE>.py"),
    #     "patch_file_path": "source/patches/<FILE>.patch",
    #     "description": "<PACKAGE> <FILE>.py"
    # },
]

if patches_to_apply:
    import subprocess

    print("\n--- Applying patches ---")

    # Find the site-packages directory once
    site_packages_path = sysconfig.get_path('purelib')
    print(f"Located site-packages directory at: {site_packages_path}")

    for patch_info in patches_to_apply:
        description = patch_info["description"]
        target_file = os.path.join(site_packages_path, patch_info["target_rel_path"])
        patch_file = patch_info["patch_file_path"]

        print(f"\nAttempting to patch {description}...")

        # Check that both files exist before attempting the patch
        if os.path.exists(target_file) and os.path.exists(patch_file):
            print(f"  - Found target file: {target_file}")
            print(f"  - Found patch file: {patch_file}")

            # Equivalent of `patch target < patch_file`, but the exit status is
            # checked so a rejected patch is not announced as a success.
            with open(patch_file, "rb") as patch_stream:
                result = subprocess.run(
                    ["patch", target_file],
                    stdin=patch_stream,
                    capture_output=True,
                    text=True,
                )
            print(result.stdout, end="")

            if result.returncode == 0:
                print(f"  => Patch for {description} applied successfully.")
            else:
                print(f"  => ERROR: Patch for {description} failed (exit code {result.returncode}).")
                print(result.stderr, end="")
        else:
            print(f"  => ERROR: Patch for {description} failed.")
            if not os.path.exists(target_file):
                print(f"    - Target file not found at: {target_file}")
            if not os.path.exists(patch_file):
                print(f"    - Patch file not found at: {patch_file}")

    print("\n--- Finished applying all patches ---")
    # --- END OF PATCH SECTION ---


# Configure the PyTorch CUDA caching allocator; set here so it is in place
# before later cells load the models.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Marker checked by the next cell to enforce execution order.
__cell2 = True

<a name="3-settings-must-execute"></a>
### 3. Settings (Must Execute)
#### [←](#2-check-files-must-execute) [↑](#table-of-contents) [→](#4-create-encodings)

In [None]:
assert '__cell2' in globals(), "You must execute cell 2 before!"

# @markdown Choose a protein language model:
encoding_model_name = "Rostlab/prot_t5_xl_uniref50" # @param ["onehot", "Rostlab/prot_t5_xl_uniref50", "Rostlab/prot_t5_xl_half_uniref50-enc", "ElnaggarLab/ankh-base", "ElnaggarLab/ankh-large", "facebook/esm2_t6_8M_UR50D", "facebook/esm2_t12_35M_UR50D", "facebook/esm2_t30_150M_UR50D", "facebook/esm2_t33_650M_UR50D"]

# Fail early with a readable message instead of a TypeError from
# os.path.join(None, ...) when the .env file lacks one of these variables.
__required_env_vars = (
    "ENCODINGS_INPUT_DIR_COLAB",
    "ENCODINGS_OUTPUT_DIR_COLAB",
    "MODEL_SAVE_DIR_COLAB",
    "FIGURES_SAVE_DIR_COLAB",
    "STUDIES_SAVE_DIR_COLAB",
    "LOG_FILE_PATH_COLAB",
    "TRAINING_METRICS_FILE_PATH_COLAB",
    "HYPER_PARAM_METRICS_FILE_PATH_COLAB",
    "CONFIG_PATH_COLAB",
)
__missing_env_vars = [__var for __var in __required_env_vars if os.getenv(__var) is None]
if __missing_env_vars:
    raise EnvironmentError(f"Missing environment variables (expected in .env): {', '.join(__missing_env_vars)}")

_fasta_input_dir: str      = os.getenv("ENCODINGS_INPUT_DIR_COLAB")
_encodings_output_dir: str = os.path.join(os.getenv('ENCODINGS_OUTPUT_DIR_COLAB'), encoding_model_name)

train_file_name = "deeploc_our_train_set"  # @param {type:"string"}
val_file_name   = "deeploc_our_val_set"    # @param {type:"string"}
test_file_name  = "setHARD"                # @param {type:"string"}
                # "deeploc_test_set"

# Encodings are stored as "<input file stem>.h5" inside the per-model output directory.
_train_encodings_file_path = os.path.join(
    _encodings_output_dir,
    f"{os.path.splitext(os.path.basename(train_file_name))[0]}.h5"
)
_val_encodings_file_path   = os.path.join(
    _encodings_output_dir,
    f"{os.path.splitext(os.path.basename(val_file_name))[0]}.h5"
)
_test_encodings_file_path  = os.path.join(
    _encodings_output_dir,
    f"{os.path.splitext(os.path.basename(test_file_name))[0]}.h5"
)


# Outputs go into sub-folders named after the encoding model, so results of
# different encoders cannot be mixed up.

_model_save_dir: str                = os.path.join(os.getenv("MODEL_SAVE_DIR_COLAB"), encoding_model_name)
_figures_save_dir: str              = os.path.join(os.getenv("FIGURES_SAVE_DIR_COLAB"), encoding_model_name)
_studies_save_dir: str              = os.path.join(os.getenv("STUDIES_SAVE_DIR_COLAB"), encoding_model_name)

_log_file_path: str                 = os.getenv("LOG_FILE_PATH_COLAB")
_training_metrics_file_path: str    = os.getenv("TRAINING_METRICS_FILE_PATH_COLAB")
_hyper_param_metrics_file_path: str = os.getenv("HYPER_PARAM_METRICS_FILE_PATH_COLAB")

# Insert the model name between the directory and the file name: <dir>/<model>/<file>
_log_file_path = os.path.join(os.path.dirname(_log_file_path), encoding_model_name, os.path.basename(_log_file_path))
_training_metrics_file_path = os.path.join(os.path.dirname(_training_metrics_file_path), encoding_model_name, os.path.basename(_training_metrics_file_path))
_hyper_param_metrics_file_path = os.path.join(os.path.dirname(_hyper_param_metrics_file_path), encoding_model_name, os.path.basename(_hyper_param_metrics_file_path))


print(f"fasta_input_dir: ............. {_fasta_input_dir}")
if not os.path.isdir(_fasta_input_dir): raise FileNotFoundError(f"Input directory {_fasta_input_dir} does not exist.")
if not _list_files(_fasta_input_dir): raise FileNotFoundError(f"Input directory {_fasta_input_dir} is empty.")

print()

print(f"encodings_output_dir: ........ {_encodings_output_dir}\n")
os.makedirs(_encodings_output_dir, exist_ok=True)

print(f"train_encodings_file_path: ... {_train_encodings_file_path}")
print(f"val_encodings_file_path: ..... {_val_encodings_file_path}")
print(f"test_encodings_file_path: .... {_test_encodings_file_path}\n")

# Each output directory is created right after it is reported.
print(f"model_save_dir: .............. {_model_save_dir}")
os.makedirs(_model_save_dir, exist_ok=True)

print(f"figures_save_dir: ............ {_figures_save_dir}")
os.makedirs(_figures_save_dir, exist_ok=True)

print(f"studies_save_dir: ............ {_studies_save_dir}")
os.makedirs(_studies_save_dir, exist_ok=True)

print(f"log_file_path: ............... {_log_file_path}")
print(f"training_metrics_file_path: .. {_training_metrics_file_path}")
print(f"hyper_param_metrics_file_path: {_hyper_param_metrics_file_path}")

print("\n")


# @markdown Edit config.yaml?
edit_config = False  # @param {type:"boolean"}

__config_file_path = os.getenv("CONFIG_PATH_COLAB")

from IPython.display import display, Markdown

if edit_config:
    import ipywidgets as widgets

    # Editable text area pre-filled with the current config; "Save" overwrites the file.
    with open(__config_file_path, "r", encoding="utf-8") as file:
        __content: str = file.read()

    __textarea = widgets.Textarea(
        value = __content,
        layout = widgets.Layout(width='100%', height='400px')
    )
    __save_button = widgets.Button(description="Save", button_style='success')
    __output = widgets.Output()

    def save_config(_):
        """Write the text area's content back to the config file (button callback)."""
        with __output:
            try:
                with open(__config_file_path, 'w', encoding="utf-8") as file:
                    file.write(__textarea.value)
                print("Configuration saved!")
            except Exception as e:
                print(f"Error: {e}")

    __save_button.on_click(save_config)

    display(Markdown("**Edit your config below:**"))
    display(__textarea)
    display(__save_button, __output)

else:
    # Read-only view of the current config.
    with open(__config_file_path, "r", encoding="utf-8") as file:
        __content = file.read()
    display(Markdown(f"```yaml\n{__content}\n```"))


# Marker checked by later cells to enforce execution order.
__cell3 = True

<a name="4-create-encodings"></a>
### 4. Create Encodings
#### [←](#3-settings-must-execute) [↑](#table-of-contents) [→](#5-visualize-data)

In [None]:
assert '__cell3' in globals(), "You must execute cell 3 before!"

import time
from source.data_scripts.encodings import main as generate_embeddings

# @markdown Choose batch size (might impact memory):
batch_size = 18  # @param {type:"slider", min:1, max:64, step:1}
# reaching ~10GB GPU Memory Peaks

# @markdown Choose number of threads (ignored if CUDA/GPU availible)
threads = 8  # @param {type:"slider", min:1, max:16, step:1}

# @markdown Automatically terminate colab runtime:
terminate = True  # @param {type:"boolean"}

if not _list_files(_fasta_input_dir):
    raise FileNotFoundError("Input directory is empty or does not exist.")


try:
    await generate_embeddings(
        model_name = encoding_model_name,
        input_dir = _fasta_input_dir,
        output_dir = _encodings_output_dir,
        batch_size = batch_size,
        threads = threads
    )
except Exception as e:
    print("Error during encoding generation:")
    print(str(e))
    terminate = True

time.sleep(60)  # "Wait a minute." - Kazoo Kid
# Prematurely terminating would make some file transfers to drive incomplete.

if terminate:
  from google.colab import runtime
  runtime.unassign()

<a name="5-visualize-data"></a>
### 5. Visualize Data
#### [←](#4-create-encodings) [↑](#table-of-contents) [→](#6-train-model)

In [None]:
assert '__cell3' in globals(), "You must execute cell 3 before!"

from source.data_scripts.data_figures import DataFiguresCollection, PoolingType
from source.data_scripts.encodings import load_metadata_from_hdf5, stream_seq_enc_data_from_hdf5


if not _list_files(_encodings_output_dir):
    raise FileNotFoundError(f"Encodings directory {_encodings_output_dir} is empty or does not exist.")

# @markdown Automatically terminate colab runtime:
terminate = True  # @param {type:"boolean"}


# Register the figures once; they are filled per data set in the loop below.
__figures: DataFiguresCollection = DataFiguresCollection(save_dir=_figures_save_dir)
__figures.class_distribution()
__figures.pca_embedding(pooling_type=PoolingType.Per_Protein_Mean)
__figures.tsne_embedding(pooling_type=PoolingType.Per_Protein_Mean)
__figures.umap_embedding(pooling_type=PoolingType.Per_Protein_Mean)
__figures.pca_embedding(pooling_type=PoolingType.Per_Protein_Max)
__figures.tsne_embedding(pooling_type=PoolingType.Per_Protein_Max)
__figures.umap_embedding(pooling_type=PoolingType.Per_Protein_Max)
__figures.raw_sequence_length_distribution(bins=150, log_y=False)
__figures.embedding_length_distribution(bins=150, log_y=False)
__figures.embedding_length_distribution(bins=150, log_y=True)
# Pairwise distance distribution is intentionally disabled: the number of
# pairwise comparisons produces too many values for the figure to handle.
# __figures.pairwise_distance_distribution(bins=150, distance_metric="euclidean", pooling_type=__pooling_type)


# One pass per split, plus one over all three files combined.
for __file_name, __file_path in [
    (train_file_name, _train_encodings_file_path),
    (val_file_name,   _val_encodings_file_path),
    (test_file_name,  _test_encodings_file_path),
    ("Aggregated", [_train_encodings_file_path, _val_encodings_file_path, _test_encodings_file_path])
]:

    __encoding_dim, __count, __label_count = load_metadata_from_hdf5(file_path=__file_path)
    # Bind the current path as a default argument: a plain closure would look
    # `__file_path` up only when called, which may be after the loop moved on.
    __seq_enc_data_generator_supplier = (
        lambda __path=__file_path: stream_seq_enc_data_from_hdf5(file_path=__path)
    )

    __identifier = f"{encoding_model_name}_{__file_name}"

    __figures.update(
        identifier = __identifier.replace("/", "_"),  # model names contain "/"
        label_count = __label_count,
        seq_enc_data_generator_supplier = __seq_enc_data_generator_supplier,
        file_name = __file_name,
        encoding_model = encoding_model_name,
    )

    __figures.save("data_visualization")

    print(f"{__identifier}")
    print(f"Encoding dimension: {__encoding_dim}")
    print(f"Number of sequences: {__count}")

del __figures


if terminate:
    from google.colab import runtime
    runtime.unassign()

<a name="6-train-model"></a>
### 6. Train Model
#### [←](#5-visualize-data) [↑](#table-of-contents) [→](#7-continue-training-a-saved-model)

In [None]:
assert '__cell3' in globals(), "You must execute cell 3 before!"

from source.metrics.metrics import Metrics
from source.models.abstract import AbstractModel
from source.training.utils.save_state import SaveState

from source.models.ffn import MLP, MLPpp, FastKAN
from source.models.other.attention_lstm_hybrid import AttentionLstmHybrid
from source.models.other.lstm_reduction_hybrid import LstmReductionHybrid
from source.models.other.light_attention import (
    LightAttention,
    LightAttentionFastKAN,
)
from source.models.reduced_ffn import (
    MaxPoolFastKAN,
    MaxPoolMLP,
    AvgPoolFastKAN,
    AvgPoolMLP,
    LinearFastKAN,
    LinearMLP,
    AttentionFastKAN,
    AttentionMLP,
    PositionalFastKAN,
    PositionalMLP,
    UNetFastKAN,
    UNetMLP
)


# Maps the names offered in the form below to model classes (not instances).
__model_name_to_class: dict[str, type[AbstractModel]] = {
    "MLP": MLP,
    "MLP_PerProtein": MLPpp,
    "FastKAN": FastKAN,

    "LightAttention": LightAttention,
    "LightAttentionFastKAN": LightAttentionFastKAN,

    "MaxPoolFastKAN": MaxPoolFastKAN,
    "MaxPoolMLP": MaxPoolMLP,
    "AvgPoolFastKAN": AvgPoolFastKAN,
    "AvgPoolMLP": AvgPoolMLP,
    "LinearFastKAN": LinearFastKAN,
    "LinearMLP": LinearMLP,
    "AttentionFastKAN": AttentionFastKAN,
    "AttentionMLP": AttentionMLP,
    "PosFastKAN": PositionalFastKAN,
    "PosMLP": PositionalMLP,
    "UNetFastKAN": UNetFastKAN,
    "UNetMLP": UNetMLP,

    "AttentionLstmHybrid": AttentionLstmHybrid,
    "LstmReductionHybrid": LstmReductionHybrid
}

# @markdown Select Model to train (adjust parameters in config.yaml):
model_name = "AttentionFastKAN"  # @param ["MLP", "MLP_PerProtein", "FastKAN", "MaxPoolFastKAN", "MaxPoolMLP", "AvgPoolFastKAN", "AvgPoolMLP", "LinearFastKAN", "LinearMLP", "AttentionFastKAN", "AttentionMLP", "PosFastKAN", "PosMLP", "UNetFastKAN", "UNetMLP", "AttentionLstmHybrid", "LstmReductionHybrid", "LightAttention", "LightAttentionFastKAN"]

use_config_defaults = True  # @param {type:"boolean"}

# @markdown Settings predefined for hyper parameter tuning (overridden when use_config_defaults is set):
epochs = 20  # @param {type:"slider", min:1, max:100, step:1}
# consider higher patience. Sometimes models recalibrate themselves after overfitting on epoch 3
# Recommended: patience 8 for hyper param tuning
patience = 14  # @param {type:"slider", min:1, max:20, step:1}
batch_size = 28  # @param {type:"slider", min:1, max:64, step:1}

learning_rate = 0.00005  # @param ["0.001", "0.0005", "0.0001", "0.00005"] {"type":"raw"}
learning_rate_decay = 0.99  # @param ["0.80", "0.85", "0.90", "0.95", "0.975", "0.99", "0.999", "1.0"] {"type":"raw"}
use_weights = False  # @param {type:"boolean"}
weight_decay = 0.0001  # @param ["0.001", "0.0005", "0.0001", "0.00005", "0.00001"] {"type":"raw"}

# @markdown Select if you want to perform a hyperparameter search or just train the model:
do_hyper_param_search = False  # @param {type:"boolean"}

# @markdown Automatically terminate colab runtime:
terminate = True  # @param {type:"boolean"}

if not _list_files(_encodings_output_dir):
    raise FileNotFoundError(f"Encodings directory {_encodings_output_dir} is empty or does not exist.")

__model_class = __model_name_to_class[model_name]

if use_config_defaults:
    from source.config import TrainingConfig, ConfigType, parse_config
    training_config: TrainingConfig = parse_config(ConfigType.Training)
    epochs = training_config.epochs
    patience = training_config.patience
    batch_size = training_config.batch_size
    learning_rate = training_config.learning_rate
    learning_rate_decay = training_config.learning_rate_decay
    use_weights = training_config.use_weights
    weight_decay = training_config.weight_decay
    print(f"Loaded parameters: epochs={epochs}, patience={patience}, batch_size={batch_size}, learning_rate={learning_rate}, learning_rate_decay={learning_rate_decay}, use_weights={use_weights}, weight_decay={weight_decay}")
else:
    print(f"Using manually defined parameters: epochs={epochs}, patience={patience}, batch_size={batch_size}, learning_rate={learning_rate}, learning_rate_decay={learning_rate_decay}, use_weights={use_weights}, weight_decay={weight_decay}")


try:

    if do_hyper_param_search:
        from source.training.hyper_param import main as tune

        __n_trials = 5
        __timeout = None  # presumably "no time limit" for the study — confirm in hyper_param.main

        tune(
            model_class = __model_class,
            train_encodings_file_path = _train_encodings_file_path,
            val_encodings_file_path = _val_encodings_file_path,
            epochs = epochs,
            patience = patience,
            batch_size = batch_size,
            use_weights = use_weights,
            weight_decay = weight_decay,
            learning_rate = learning_rate,
            learning_rate_decay = learning_rate_decay,
            n_trials = __n_trials,
            timeout = __timeout,
            study_name = __model_class.name(),
            studies_save_dir = _studies_save_dir,
            model_save_dir = _model_save_dir,
            figures_save_dir = _figures_save_dir,
            metrics_file_path= _hyper_param_metrics_file_path,
            log_file_path = _log_file_path
        )

    else:
        from source.training.train_model import main as train
        train(
            model = __model_class,
            train_encodings_file_path = _train_encodings_file_path,
            val_encodings_file_path = _val_encodings_file_path,
            epochs = epochs,
            patience = patience,
            batch_size = batch_size,
            use_weights = use_weights,
            weight_decay = weight_decay,
            learning_rate = learning_rate,
            learning_rate_decay = learning_rate_decay,
            model_save_dir = _model_save_dir,
            figures_save_dir = _figures_save_dir,
            metrics_file_path= _training_metrics_file_path,
            log_file_path = _log_file_path
        )

except Exception as e:
    print("Error during training:")
    print(str(e))
    # Release the runtime even after a crash; the `finally` clause below runs
    # before the exception propagates to the notebook.
    terminate = True
    raise

finally:
    _list_files(_model_save_dir)

    if terminate:
        from google.colab import runtime
        runtime.unassign()

<a name="7-continue-training-a-saved-model"></a>
### 7. Continue Training a Saved Model
#### [←](#6-train-model) [↑](#table-of-contents) [→](#8-model-comparison)

In [None]:
assert '__cell3' in globals(), "You must execute cell 3 before!"

import ipywidgets as widgets
from IPython.display import display
from source.training.train_model import main as train
from source.training.utils.save_state import SaveState


use_config_defaults = True  # @param {type:"boolean"}
epochs = 150  # @param {type:"slider", min:1, max:250, step:1}
# consider higher patience. Sometimes models recalibrate themselves after overfitting on epoch 3
patience = 25  # @param {type:"slider", min:1, max:50, step:1}
batch_size = 28  # @param {type:"slider", min:1, max:64, step:1}

learning_rate = 0.0005  # @param ["0.001", "0.005", "0.0005", "0.0001"] {"type":"raw"}
learning_rate_decay = 0.95  # @param ["0.80", "0.85", "0.90", "0.95", "0.975", "0.99"] {"type":"raw"}

use_weights = True  # @param {type:"boolean"}
weight_decay = 0.0001 # @param ["0.001", "0.0005", "0.0001", "0.00005", "0.00001"] {"type":"raw"}

# @markdown Automatically terminate colab runtime:
terminate = True  # @param {type:"boolean"}


if use_config_defaults:
    from source.config import TrainingConfig, ConfigType, parse_config
    training_config: TrainingConfig = parse_config(ConfigType.Training)
    epochs = training_config.epochs
    patience = training_config.patience
    batch_size = training_config.batch_size
    learning_rate = training_config.learning_rate
    learning_rate_decay = training_config.learning_rate_decay
    use_weights = training_config.use_weights
    weight_decay = training_config.weight_decay
    print(f"Loaded parameters: epochs={epochs}, patience={patience}, batch_size={batch_size}, learning_rate={learning_rate}, learning_rate_decay={learning_rate_decay}, use_weights={use_weights}, weight_decay={weight_decay}")
else:
    print(f"Using manually defined parameters: epochs={epochs}, patience={patience}, batch_size={batch_size}, learning_rate={learning_rate}, learning_rate_decay={learning_rate_decay}, use_weights={use_weights}, weight_decay={weight_decay}")


__files = _list_files(_model_save_dir)
if not __files: raise FileNotFoundError(f"Model save directory {_model_save_dir} is empty or does not exist.")


def __train(selected_filename: str):
    """Continue training the saved model ``selected_filename`` from ``_model_save_dir``.

    The hyper-parameter globals of this cell are read at call (click) time.
    On failure the error is printed and re-raised. The runtime is released
    afterwards when ``terminate`` is set or when training failed.
    """
    # Use a local flag: assigning to `terminate` inside this function would make
    # it a local name, and the `finally` check would raise UnboundLocalError.
    should_terminate = terminate

    try:
        # NOTE(review): unlike the training cell, metrics_file_path is not
        # passed here — confirm whether continued runs should be recorded too.
        train(
            model = os.path.join(_model_save_dir, selected_filename),
            train_encodings_file_path = _train_encodings_file_path,
            val_encodings_file_path = _val_encodings_file_path,
            epochs = epochs,
            patience = patience,
            batch_size = batch_size,
            use_weights = use_weights,
            weight_decay = weight_decay,
            learning_rate = learning_rate,
            learning_rate_decay = learning_rate_decay,
            model_save_dir = _model_save_dir,
            figures_save_dir = _figures_save_dir,
            log_file_path = _log_file_path
        )

    except Exception as e:
        print("Error during continued training:")
        print(str(e))
        should_terminate = True
        raise

    finally:
        if should_terminate:
            from google.colab import runtime
            runtime.unassign()


__dropdown = widgets.Dropdown(
    options = __files,
    description = 'Select:',
    value = None
)
__button = widgets.Button(description="Continue")

def __on_button_clicked(b):
    """Start continued training for the model chosen in the dropdown."""
    if __dropdown.value:
        print(f"Model selected: {__dropdown.value}")
        __train(__dropdown.value)
    else: print("No model selected!")

__button.on_click(__on_button_clicked)
display(__dropdown, __button)

<a name="8-model-comparison"></a>
### 8. Model Comparison
#### [←](#7-continue-training-a-saved-model) [↑](#table-of-contents) [→](#9-explore-figures)

In [None]:
assert '__cell3' in globals(), "You must execute cell 3 before!"

import ipywidgets as widgets
from IPython.display import display
from typing_extensions import List
from source.metrics.metrics import Metrics
from source.models.abstract import AbstractModel
from source.evaluation.evaluation import evaluate_models


iterations = 100  # @param {type:"slider", min:10, max:1000, step:10}
batch_size = 28  # @param {type:"slider", min:1, max:64, step:1}


__model_file_names: List[str] = _list_files(_model_save_dir)

if not __model_file_names:
    raise FileNotFoundError(f"No models found in directory: {_model_save_dir}")


__selection = widgets.SelectMultiple(
    options = __model_file_names,
    description = 'Models',
    disabled = False,
    layout = widgets.Layout(width='100%')
)

__button = widgets.Button(description="Load Selected Models")


def on_button_clicked(_):
    """Evaluate the selected saved models on the test encodings (button callback).

    Errors from ``evaluate_models`` are printed and re-raised.
    """
    if not __selection.value:
        print("No model selected!")
        return

    model_file_paths = [
        os.path.join(_model_save_dir, model_file_name)
        for model_file_name in __selection.value
    ]
    print("Model selected:")
    for model_file_path in model_file_paths: print(model_file_path)
    try:
        evaluate_models(
            model_file_paths = model_file_paths,
            test_encodings_file_path = _test_encodings_file_path,
            figures_save_dir = _figures_save_dir,
            iterations = iterations,
            batch_size = batch_size
        )
    except Exception as e:
        print("Error during model comparison:")
        print(str(e))
        raise

__button.on_click(on_button_clicked)

display(widgets.VBox([__selection, __button]))

<a name="9-explore-figures"></a>
### 9. Explore Figures
#### [←](#8-model-comparison) [↑](#table-of-contents)

In [None]:
assert '__cell3' in globals(), "You must execute cell 3 before!"

import ipywidgets as widgets
from typing_extensions import List
from IPython.display import display
from source.abstract_figures import AbstractFiguresCollection


__figures: AbstractFiguresCollection = AbstractFiguresCollection()

# Only pickled figure files (*.pkl) are offered for loading. `_list_files`
# returns None for a missing directory, hence the `or []`.
__figures_file_names: List[str] = [
    __file_name
    for __file_name in (_list_files(_figures_save_dir) or [])
    if __file_name.lower().endswith(".pkl")
]

if not __figures_file_names:
    raise FileNotFoundError(f"No figures found in directory: {_figures_save_dir}")


__selection = widgets.SelectMultiple(
    options = __figures_file_names,
    description = 'Figures',
    disabled = False,
    layout = widgets.Layout(width='100%')
)

__button = widgets.Button(description="Load Selected Figures")


def on_button_clicked(_):
    """Load every selected figure file into the collection, then display it (button callback)."""
    if not __selection.value:
        print("No figure selected!")
        return

    for file_name in __selection.value:
        __figures.load(os.path.join(_figures_save_dir, file_name))
    print(f"Loaded and displayed {len(__selection.value)} figures.")
    __figures.display()

__button.on_click(on_button_clicked)

display(widgets.VBox([__selection, __button]))