# 🧙🏻‍♂️ GANDALF

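Gated Adaptive Network for Deep Automated Learning of Features (GANDALF). This notebook retrains a final GANDALF model using the best hyperparameter configuration found in a previous sweep: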
 - [Paper](https://arxiv.org/abs/2207.08548) 
 - [Model](https://pytorch-tabular.readthedocs.io/en/latest/models/#gated-adaptive-network-for-deep-automated-learning-of-features-gandalf)

# 📦 Setup

In [None]:
import datetime
import json
import os
import sys

import pandas as pd
import torch
from loguru import logger
from pytorch_tabular import TabularModel
from pytorch_tabular.config import (
    DataConfig,
    ExperimentConfig,
    OptimizerConfig,
    TrainerConfig,
)
from pytorch_tabular.models import GANDALFConfig

In [None]:
# Run configuration: edit these values to retrain a different region/target.
data_dir = "/Users/catherine/GMS/project/datasets"
# Kept distinct from `model`, which later cells bind to the TabularModel.
model_name = "GANDALF_SEM"
project = "SEM_MLL-N_TF"
region_name = "promoters_1024bp"

start_time = datetime.datetime.now().strftime("%Y-%m-%d_%H%M")
target = "MLL-N"
task = "singlelabel_regression"
group = f"{model_name}_{region_name}_{target}_{task}"

# Timestamp suffix of the results folder holding best_run_config.json
# (presumably the hyperparameter-sweep run to reuse; update when re-sweeping).
sweep_run_id = "2025-04-26_1631"
best_config_path = f"results/{project}/{group}_{sweep_run_id}/best_run_config.json"

# Each retraining run writes to its own timestamped directory; wandb's local
# files are directed there too.
results_dir = f"final_results_{start_time}"
os.environ["WANDB_DIR"] = results_dir
os.makedirs(results_dir, exist_ok=True)

logger.info(f"Project: {project} | Group: {group}")


In [None]:
# import argparse
# parser = argparse.ArgumentParser(
#     description="Retrain final model using best sweep config"
# )
# parser.add_argument("--data_dir", type=str, required=True, help="Dataset directory")
# parser.add_argument(
#     "--region_name", type=str, default="promoters_1024bp", help="Region name"
# )
# parser.add_argument("--target", type=str, default="MLL-N", help="Target column")
# parser.add_argument(
#     "--best_config_path",
#     type=str,
#     required=True,
#     help="Path to best_run_config.json",
# )
# parser.add_argument(
#     "--results_dir",
#     type=str,
#     default="final_results",
#     help="Directory to save final model checkpoints",
# )
# args = parser.parse_args()

# os.makedirs(args.results_dir, exist_ok=True)


In [None]:
# Replace loguru's default handler with two INFO-level sinks that share one
# format: a colourised console stream and a plain-text file under
# <results_dir>/logs/run.log.
logger.remove()

log_dir = os.path.join(results_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, "run.log")

shared_sink_options = {
    "level": "INFO",
    "format": "{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
}
logger.add(sink=sys.stderr, colorize=True, **shared_sink_options)
# enqueue=True: loguru hands file writes to a background queue.
logger.add(sink=log_file, colorize=False, enqueue=True, **shared_sink_options)


In [None]:
# Load the hyperparameters selected by the sweep. JSON is UTF-8 by
# specification, so decode explicitly instead of using the locale default.
logger.info(f"Loading best config from {best_config_path}")
with open(best_config_path, encoding="utf-8") as f:
    config = json.load(f)

# Keys read when building the trainer/model configs below. Check them now so a
# bad config fails before the (slow) dataset load rather than after it.
_REQUIRED_CONFIG_KEYS = {
    "batch_size",
    "max_epochs",
    "embedding_dropout",
    "gflu_dropout",
    "gflu_feature_init_sparsity",
    "gflu_stages",
    "lr",
}
_missing_config_keys = _REQUIRED_CONFIG_KEYS - config.keys()
if _missing_config_keys:
    raise KeyError(
        f"{best_config_path} is missing keys: {sorted(_missing_config_keys)}"
    )


In [None]:
# Load the full dataset and split it by chromosome, which is encoded as the
# row-index prefix (e.g. "chr8...").
logger.info("Loading data...")
data = pd.read_parquet(f"{data_dir}/data_{region_name}/{region_name}.parquet")

# Downcast every float64 column to float32 (halves memory; matches torch's
# default float dtype).
data = data.astype(
    {col: "float32" for col in data.select_dtypes(include=["float64"]).columns}
)

# Missing methylation values get an explicit -1 sentinel.
meth_cols = [col for col in data.columns if "METH" in col]
data[meth_cols] = data[meth_cols].fillna(-1)

# Features: every SEM column that does not mention the target.
# Label: the SEM_CAT_1 column for `target`, derived from `target` rather than
# hard-coded so features and label cannot drift apart if the target changes.
target_col = f"SEM_CAT_1_{target}"
X_data = data[[col for col in data.columns if "SEM" in col and target not in col]]
y_data = data[[target_col]]

dataset = pd.concat([X_data, y_data], axis=1)

# chr8 -> validation. chr9 is excluded from training and not used in this
# notebook (presumably reserved as a held-out test set).
train_data = dataset[~dataset.index.str.startswith(("chr8", "chr9"))]
val_data = dataset[dataset.index.str.startswith("chr8")]

if train_data.empty or val_data.empty:
    raise ValueError(
        f"Empty data split: {len(train_data)} train rows, "
        f"{len(val_data)} validation rows"
    )


In [None]:
# Setup configs
# Device preference: Apple MPS, then CUDA, then CPU.
device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)
# Lightning names the CUDA accelerator "gpu". Map every device type explicitly
# so a CPU-only machine does not request a GPU accelerator.
accelerator = {"mps": "mps", "cuda": "gpu", "cpu": "cpu"}[device.type]

data_config = DataConfig(
    # Every non-target column is a continuous feature; normalisation is off.
    continuous_cols=[col for col in train_data.columns if target not in col],
    dataloader_kwargs={"persistent_workers": True},
    normalize_continuous_features=False,
    num_workers=8,
    # Pinned host memory only helps CUDA host->device copies.
    pin_memory=device.type == "cuda",
    target=[col for col in train_data.columns if target in col],
    # An explicit validation set (chr8) is passed to fit(), so no random split.
    validation_split=0,
)

optimizer_config = OptimizerConfig()

trainer_config = TrainerConfig(
    accelerator=accelerator,
    auto_lr_find=False,
    batch_size=config["batch_size"],
    check_val_every_n_epoch=5,
    checkpoints_path=os.path.join(results_dir, "checkpoints"),
    early_stopping_mode="min",
    early_stopping_patience=3,
    early_stopping="valid_loss",
    load_best=True,
    max_epochs=config["max_epochs"],
    progress_bar="rich",
    trainer_kwargs=dict(enable_model_summary=False),
)

experiment_config = ExperimentConfig(
    exp_log_freq=5,
    exp_watch="gradients",
    log_logits=False,
    log_target="wandb",
    project_name=project,
    run_name="full",
)

metrics = ["r2_score", "mean_squared_error"]
model_config = GANDALFConfig(
    embedding_dropout=config["embedding_dropout"],
    gflu_dropout=config["gflu_dropout"],
    gflu_feature_init_sparsity=config["gflu_feature_init_sparsity"],
    gflu_stages=config["gflu_stages"],
    learning_rate=config["lr"],
    head="LinearHead",
    loss="MSELoss",
    metrics=metrics,
    # One independent (empty) params dict per metric, rather than `[{}] * n`,
    # which repeats a single shared dict object.
    metrics_params=[{} for _ in metrics],
    seed=42,
    target_range=[(0, 1)],
    task="regression",
)

# Build the TabularModel once from the configs defined above; the training
# cell calls `model.fit` on it.
logger.info("Building model...")
model = TabularModel(
    data_config=data_config,
    experiment_config=experiment_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    verbose=False,
    suppress_lightning_logger=True,
)


: 

In [9]:
# Train on all non-held-out chromosomes, validating on chr8. With load_best=True
# in the trainer config, the best checkpoint is restored after training.
logger.info("Training final model...")
model.fit(train=train_data, validation=val_data)

# The trainer's `checkpoints_path` holds training checkpoints only. `save_model`
# also writes the configs and datamodule that `TabularModel.load_model` needs
# to reload the model later.
final_model_dir = os.path.join(results_dir, "final_model")
model.save_model(final_model_dir)

logger.success(f"✅ Final model trained and saved to {final_model_dir}!")


: 

: 