# 🧙🏻‍♂️ GANDALF

Gated Adaptive Network for Deep Automated Learning of Features (GANDALF): 
 - [Paper](https://arxiv.org/abs/2207.08548) 
 - [Model](https://pytorch-tabular.readthedocs.io/en/latest/models/#gated-adaptive-network-for-deep-automated-learning-of-features-gandalf)

# 📦 Setup

In [1]:
import datetime
import json
import os

import matplotlib.pyplot as plt
import pandas as pd
import pytorch_tabular
import seaborn as sns
import torch
import wandb
from pytorch_tabular import TabularModel
from pytorch_tabular.config import (
    DataConfig,
    ExperimentConfig,
    OptimizerConfig,
    TrainerConfig,
)
from pytorch_tabular.models import (
    GANDALFConfig,
)
from rich.pretty import pprint

In [2]:
# Choose the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    backend = "mps"
elif torch.cuda.is_available():
    backend = "cuda"
else:
    backend = "cpu"
device = torch.device(backend)
pprint(f"Using device: {device}")

# GPU details are only queried when CUDA was selected.
if device.type == "cuda":
    gpu_count = torch.cuda.device_count()
    gpu_name = torch.cuda.get_device_name(0)
    pprint(f"GPU Count: {gpu_count} | GPU Name: {gpu_name}")

# Authenticate with Weights & Biases before any run or sweep is created.
wandb.login()

# Record library versions alongside the results.
pprint(
    f"Versions: Torch: {torch.__version__}, PyTorch Tabular: {pytorch_tabular.__version__}"
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mcatherine-chahrour[0m ([33mcatherine-chahrour-university-of-oxford[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# Experiment identity: these values name the W&B project/group and the
# on-disk results directory.
model = "GANDALF_SEM"
project = "SEM_MLL-N_TF"
region_name = "promoters_1024bp"
target = "MLL-N"
task = "singlelabel_regression"

# Minute-resolution timestamp taken once, when this cell runs.
start_time = f"{datetime.datetime.now():%Y-%m-%d_%H%M}"
group = "_".join((model, region_name, target, task))
results_dir = "/".join(("results", project, f"{group}_{start_time}"))

# Create the output directory and point W&B at it.
os.makedirs(results_dir, exist_ok=True)
os.environ["WANDB_DIR"] = results_dir

pprint(f"Project: {project} | Group: {group}")

# 📊 Load Data

In [4]:
# Dataset root. Defaults to the original local path; set the GMS_DATA_DIR
# environment variable to run the notebook on another machine without editing
# this cell.
data_dir = os.environ.get("GMS_DATA_DIR", "/Users/catherine/GMS/project/datasets")
data = pd.read_parquet(f"{data_dir}/data_{region_name}/{region_name}.parquet")

# Downcast every float64 column to float32 in a single pass.
float64_cols = data.select_dtypes(include=["float64"]).columns
data = data.astype({col: "float32" for col in float64_cols})

# Missing values in METH columns are replaced by the sentinel -1.
meth_cols = [col for col in data.columns if "METH" in col]
data[meth_cols] = data[meth_cols].fillna(-1)

# Features: every SEM column that does not belong to the target.
# Label: the SEM_CAT_1 column of the target, derived from `target` so that the
# feature filter and the label always agree.
target_col = f"SEM_CAT_1_{target}"
X_data = data[[col for col in data.columns if "SEM" in col and target not in col]]
y_data = data[[target_col]]

dataset = pd.concat([X_data, y_data], axis=1)

# Split on the index prefix: "chr8" rows -> validation, "chr9" rows -> test,
# everything else -> train.
# NOTE(review): assumes the index holds strings starting with a chromosome
# name — confirm against the parquet file.
train_data = dataset[~dataset.index.str.startswith(("chr8", "chr9"))]
val_data = dataset[dataset.index.str.startswith("chr8")]
test_data = dataset[dataset.index.str.startswith("chr9")]

# ⚙️ Config

In [5]:
# Partition the columns once: names containing `target` are labels, all other
# columns are treated as continuous input features.
feature_cols, target_cols = [], []
for name in train_data.columns:
    (target_cols if target in name else feature_cols).append(name)

data_config = DataConfig(
    target=target_cols,
    continuous_cols=feature_cols,
    continuous_feature_transform="quantile_uniform",
    normalize_continuous_features=True,
    validation_split=0,
    num_workers=10,
    pin_memory=True,
    dataloader_kwargs={"persistent_workers": True},
)

In [6]:
# Optimizer settings left at the library defaults; this single instance is
# passed to the TabularModel built inside train().
optimizer_config = OptimizerConfig()


def train():
    """Run one sweep trial: build, fit and evaluate a GANDALF model.

    Takes no arguments so that it can be handed to ``wandb.agent`` as the
    sweep function. Hyperparameters are read from ``run.config``:
    ``batch_size``, ``max_epochs``, ``embedding_dropout``, ``gflu_dropout``,
    ``gflu_feature_init_sparsity`` and ``gflu_stages``.

    Relies on notebook globals defined in earlier cells: ``project``,
    ``group``, ``results_dir``, ``device``, ``data_config``,
    ``optimizer_config``, ``train_data``, ``val_data`` and ``test_data``.

    NOTE(review): the sweep configs logged by the agent also contain ``lr``,
    which this function never reads; with ``auto_lr_find=True`` the learning
    rate presumably comes from the LR finder instead. Confirm whether ``lr``
    should be dropped from the sweep or wired in here.
    """
    # Close any run left open by a previous trial before starting a new one.
    if wandb.run is not None:
        wandb.finish()
    with wandb.init(
        name=f"run_{datetime.datetime.now().strftime('%Y-%m-%d_%H%M')}",
        project=project,
        group=group,
        job_type="sweep",
        dir=f"{results_dir}/wandb",
        reinit="finish_previous",
    ) as run:
        config = run.config

        # Lightning accelerator names differ from torch device types
        # ("cuda" -> "gpu"). Fall back to "cpu" so a machine with neither MPS
        # nor CUDA does not request a GPU that is not there.
        accelerator = {"mps": "mps", "cuda": "gpu"}.get(device.type, "cpu")

        trainer_config = TrainerConfig(
            accelerator=accelerator,
            auto_lr_find=True,
            batch_size=config.batch_size,
            check_val_every_n_epoch=5,
            checkpoints_path=f"{results_dir}/checkpoints",
            early_stopping_mode="min",
            early_stopping_patience=3,
            early_stopping="valid_loss",
            load_best=True,
            max_epochs=config.max_epochs,
            progress_bar="rich",
            trainer_kwargs=dict(enable_model_summary=False),
        )

        experiment_config = ExperimentConfig(
            exp_log_freq=5,
            exp_watch="gradients",
            log_logits=False,
            log_target="wandb",
            project_name=project,
            run_name=run.name,
        )

        metrics = ["r2_score", "mean_squared_error"]
        model_config = GANDALFConfig(
            embedding_dropout=config.embedding_dropout,
            gflu_dropout=config.gflu_dropout,
            gflu_feature_init_sparsity=config.gflu_feature_init_sparsity,
            gflu_stages=config.gflu_stages,
            head="LinearHead",
            loss="MSELoss",
            metrics=metrics,
            # One independent (empty) params dict per metric, so the entries
            # never alias the same object.
            metrics_params=[{} for _ in metrics],
            seed=42,
            target_range=[(0, 1)],
            task="regression",
        )

        # Named `tabular_model` so it does not shadow the notebook-level
        # `model` string used to build the group name.
        tabular_model = TabularModel(
            data_config=data_config,
            experiment_config=experiment_config,
            model_config=model_config,
            optimizer_config=optimizer_config,
            trainer_config=trainer_config,
            verbose=False,
            suppress_lightning_logger=True,
        )

        tabular_model.fit(train=train_data, validation=val_data)
        # The returned predictions are not used here.
        tabular_model.predict(test_data)

# 🧹 Sweep

In [None]:
# Read the sweep definition from disk and name the sweep after the group.
with open("config/sweep_config.json") as config_file:
    sweep_config = json.loads(config_file.read())
sweep_config.update(name=group)

# Register the sweep, then let a local agent call train() for up to 50 trials.
sweep_id = wandb.sweep(sweep_config, project=project)
wandb.agent(sweep_id=sweep_id, function=train, project=project, count=50)



Create sweep with ID: u4lhvzff
Sweep URL: https://wandb.ai/catherine-chahrour-university-of-oxford/SEM_MLL-N_TF/sweeps/u4lhvzff


[34m[1mwandb[0m: Agent Starting Run: n9pjob5i with config:
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	embedding_dropout: 0.134332830078693
[34m[1mwandb[0m: 	gflu_dropout: 0.0127280376735874
[34m[1mwandb[0m: 	gflu_feature_init_sparsity: 0.2003119120000716
[34m[1mwandb[0m: 	gflu_stages: 4
[34m[1mwandb[0m: 	lr: 1.0041171739034491
[34m[1mwandb[0m: 	max_epochs: 62


/Users/catherine/miniforge3/envs/model/lib/python3.11/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]



Output()

  return torch.load(f, map_location=map_location)


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇▇█████
train_loss,▇█▃▁▂▃▁▃▂▃▅█▅▆▃▁▄▄▂▄▇▄▅▇▄▃▇▃▂▄▅▃▄▃▃▅▁▄▂▄
train_mean_squared_error,█▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁
train_r2_score,▁▇▇▇▇▇▇▇▇▇██▇███████
trainer/global_step,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇██
valid_loss,▁▁█▅
valid_mean_squared_error,▁▁█▅
valid_r2_score,██▁▅

0,1
epoch,19.0
train_loss,0.00467
train_mean_squared_error,0.00459
train_r2_score,0.89154
trainer/global_step,3119.0
valid_loss,0.00396
valid_mean_squared_error,0.00396
valid_r2_score,0.91008


[34m[1mwandb[0m: Agent Starting Run: wze74knp with config:
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	embedding_dropout: 0.1173764817113448
[34m[1mwandb[0m: 	gflu_dropout: 0.04845070125721372
[34m[1mwandb[0m: 	gflu_feature_init_sparsity: 0.2223861472079489
[34m[1mwandb[0m: 	gflu_stages: 6
[34m[1mwandb[0m: 	lr: 1.0051682991853323
[34m[1mwandb[0m: 	max_epochs: 139


/Users/catherine/miniforge3/envs/model/lib/python3.11/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
/Users/catherine/miniforge3/envs/model/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/catherine/GMS/project/models/2025-04-11_gandalf/results/SEM_MLL-N_TF/GANDALF_SEM_promoters_1024bp_MLL-N_singlelabel_regression_2025-04-26_1510/checkpoints exists and is not empty.


Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]



Output()

Traceback (most recent call last):
  File "/var/folders/bs/nx8yhh5s2tn0r0307mtl2mf40000gn/T/ipykernel_27169/43235070.py", line 66, in train
    model.fit(train=train_data, validation=val_data)
  File "/Users/catherine/miniforge3/envs/model/lib/python3.11/site-packages/pytorch_tabular/tabular_model.py", line 806, in fit
    return self.train(model, datamodule, callbacks, max_epochs, min_epochs, handle_oom)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/catherine/miniforge3/envs/model/lib/python3.11/site-packages/pytorch_tabular/tabular_model.py", line 680, in train
    self.trainer.fit(self.model, train_loader, val_loader)
  File "/Users/catherine/miniforge3/envs/model/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/Users/catherine/miniforge3/envs/model/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 47, in _call_and_handle_inte

0,1
epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇█
train_loss,▃▅▃▅▅▂▆▄▇▅▆▆▃▇▆▂▁▄▄▆▄▅▄▃▂▃▆▅▄█▅▅▅▆▄▆▂▄▇▃
train_mean_squared_error,█▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
train_r2_score,▁▇▇▇▇▇█████████████
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇█
valid_loss,█▁▆
valid_mean_squared_error,█▁▆
valid_r2_score,▁█▃

0,1
epoch,19.0
train_loss,0.0037
train_mean_squared_error,0.0048
train_r2_score,0.88582
trainer/global_step,6099.0
valid_loss,0.00398
valid_mean_squared_error,0.00398
valid_r2_score,0.89634


[34m[1mwandb[0m: Agent Starting Run: rpcbf01a with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	embedding_dropout: 0.1930468698206162
[34m[1mwandb[0m: 	gflu_dropout: 0.10534681707304792
[34m[1mwandb[0m: 	gflu_feature_init_sparsity: 0.4622232075800971
[34m[1mwandb[0m: 	gflu_stages: 10
[34m[1mwandb[0m: 	lr: 1.0077279916387143
[34m[1mwandb[0m: 	max_epochs: 181


# 🚂 Train

In [None]:
# Print the ID of the sweep held in `sweep_id`; the next cell uses it to query
# the sweep's runs.
pprint(sweep_id)
# cwotlgqx
# NOTE(review): the ID recorded above differs from the one in the saved sweep
# output (u4lhvzff); presumably it is from an earlier sweep — confirm or remove.

In [None]:
api = wandb.Api()
sweep = api.sweep(f"catherine-chahrour-university-of-oxford/{project}/{sweep_id}")

# Rank only runs that logged a numeric validation R². Runs that failed before
# the first validation pass have no score; they must not be given a default of
# 0, because R² can be negative and a score-less run would then outrank a
# finished one.
scored_runs = [
    r for r in sweep.runs
    if isinstance(r.summary.get("valid_r2_score"), (int, float))
]
if not scored_runs:
    raise RuntimeError(
        f"No run in sweep {sweep_id} has logged 'valid_r2_score' yet; "
        "cannot select a best run."
    )

# Highest validation R² wins; on ties the earliest run in sweep order is kept.
best_run = max(scored_runs, key=lambda r: r.summary["valid_r2_score"])
config = best_run.config
pprint(f"Best run: {best_run.id} | R²: {best_run.summary['valid_r2_score']}")

In [None]:
# Finish the current W&B run so none is left open at the end of the notebook.
wandb.finish()