In [1]:
from __future__ import annotations

import os

# Run configuration for this notebook, exported as environment variables.
# NOTE(review): this cell runs before the landnet imports in the next cell,
# presumably because landnet.config reads these variables at import time —
# confirm before reordering cells.
os.environ.update(
    {
        # Workspace setting PyTorch's reproducibility notes call for when
        # deterministic cuBLAS behaviour is required.
        'CUBLAS_WORKSPACE_CONFIG': ':4096:8',
        'ARCHITECTURE': 'convnext',
        'TILE_SIZE': '50',
        'PRETRAINED': '1',
        'OVERLAP': '25',
        'BATCH_SIZE': '50',
        'EXPERIMENTS_NAME': 'univariate',
        'EPOCHS': '10',
        'NUM_SAMPLES': '5',
        'TRAIN_NUM_SAMPLES': '1000',
        'OVERWRITE': '0',
    }
)

In [2]:
import ray
from ray.tune import Result

from landnet.config import CPUS, GPUS, MODELS_DIR, EXPERIMENTS_NAME
from landnet.enums import GeomorphometricalVariable
from landnet.logger import create_logger
from landnet.modelling import torch_clear
from landnet.modelling.classification.train import train_model
from landnet.modelling.tune import MetricSorter, get_results_df

# Clear torch state left over from earlier runs in this kernel (presumably what
# landnet's torch_clear does — TODO confirm); skipped when GPUS is falsy (e.g. 0).
if GPUS:
    torch_clear()

logger = create_logger(__name__)

# Re-executing this cell must not call ray.init on an already-running instance.
ray_running = ray.is_initialized()
if not ray_running:
    ray.init(num_cpus=CPUS, num_gpus=GPUS)

2025-11-15 08:04:35,657	INFO worker.py:1917 -- Started a local Ray instance.


# Configuration

Metric used to rank tuning results and keep the best checkpoint.

In [3]:
# Ranking used to keep only the best checkpoint of each tuning run:
# results are ordered by validation F2 score, where higher is better ('max').
SORT_METRIC = 'val_f2_score'
SORT_MODE = 'max'
sorter = MetricSorter(SORT_METRIC, SORT_MODE)

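# Train univariate models

Train one model per geomorphometrical variable, then save a summary of the best results to `results.csv`.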

In [None]:
# Train one model per geomorphometrical variable, keep the best result of each,
# and write a summary table to <MODELS_DIR>/<EXPERIMENTS_NAME>/results.csv.
#
# NOTE(review): the saved run of this cell failed right after the first
# "Tuning model" log line with
#   TypeError: unsupported operand type(s) for /: 'PosixPath' and 'list'
# In the original cell, the first path join evaluated after that log line is
# `MODELS_DIR / EXPERIMENTS_NAME` (though train_model may also join paths
# internally). Check whether landnet.config exposes EXPERIMENTS_NAME as a list —
# TODO confirm.
experiment_dir = MODELS_DIR / EXPERIMENTS_NAME

results: list[Result] = []
for variable in GeomorphometricalVariable:
    logger.info('Tuning model with variable %s' % variable)
    best_result = train_model(
        variables=[variable],
        model_name=variable.value,
        sorter=sorter,
        out_dir=experiment_dir / variable.value,
    )
    results.append(best_result)

# Log the number of collected results instead of the loop variable that leaked
# out of the loop above (it only ever named the last variable).
logger.info('Converting %d results to dataframe' % len(results))
df = get_results_df(results, sorter, fix_missing_predictions=True)

df.to_csv(experiment_dir / 'results.csv', index=False)

INFO: Tuning model with variable GeomorphometricalVariable.SLOPE


TypeError: unsupported operand type(s) for /: 'PosixPath' and 'list'

In [None]:
# Summary of the best results collected above (the same table written to
# results.csv); presumably one row per trained model.
df

Unnamed: 0,model,train_accuracy,train_f1_score,train_f2_score,train_f3_score,train_negative_predictive_value,train_positive_predictive_value,train_roc_auc,train_sensitivity,train_specificity,...,validation_negative_predictive_value,validation_positive_predictive_value,validation_roc_auc,validation_sensitivity,validation_specificity,learning_rate,batch_size,tile_config,checkpoint,epoch
0,slope,0.855,0.716381,0.855724,0.915053,0.99537,0.563462,0.982497,0.983221,0.825653,...,0.975806,0.76986,0.969922,0.973412,0.786566,2.5e-05,8,"TileConfig(size=TileSize(width=50, height=50),...",checkpoint_000006,6
1,shade,0.878125,0.746424,0.862898,0.910244,0.990257,0.609342,0.982508,0.963087,0.858679,...,0.961588,0.789988,0.966151,0.955687,0.813651,3.1e-05,8,"TileConfig(size=TileSize(width=50, height=50),...",checkpoint_000008,8
2,tri,0.865625,0.727503,0.852644,0.904507,0.990081,0.584521,0.979528,0.963087,0.843318,...,0.963918,0.787621,0.965889,0.958641,0.810401,2.5e-05,8,"TileConfig(size=TileSize(width=50, height=50),...",checkpoint_000005,5
3,poso,0.825,0.677419,0.834279,0.904059,0.996117,0.515789,0.981966,0.986577,0.788018,...,0.969144,0.738444,0.958659,0.967504,0.748646,3.1e-05,8,"TileConfig(size=TileSize(width=50, height=50),...",checkpoint_000009,9
4,nego,0.886875,0.763399,0.880048,0.927279,0.994704,0.625268,0.983711,0.979866,0.865591,...,0.956919,0.772182,0.954863,0.951256,0.79415,1.5e-05,4,"TileConfig(size=TileSize(width=50, height=50),...",checkpoint_000008,8
5,tpi,0.854375,0.7162,0.857143,0.917317,0.996286,0.562141,0.986167,0.986577,0.824117,...,0.963415,0.75406,0.962423,0.960118,0.770314,2.5e-05,8,"TileConfig(size=TileSize(width=50, height=50),...",checkpoint_000009,9
6,cprof,0.824375,0.675144,0.830017,0.898738,0.994192,0.514991,0.98475,0.979866,0.788786,...,0.967787,0.738149,0.960065,0.966027,0.748646,1.5e-05,4,"TileConfig(size=TileSize(width=50, height=50),...",checkpoint_000009,9
7,cgene,0.849375,0.710684,0.856977,0.920112,0.998122,0.553271,0.986521,0.993289,0.816436,...,0.961792,0.770511,0.965702,0.957164,0.790899,3.1e-05,8,"TileConfig(size=TileSize(width=50, height=50),...",checkpoint_000009,9
8,ioc,0.739375,0.579213,0.761273,0.85037,0.987872,0.414141,0.94069,0.963087,0.688172,...,0.952522,0.696544,0.932141,0.952733,0.695558,3.1e-05,8,"TileConfig(size=TileSize(width=50, height=50),...",checkpoint_000006,6
9,conv,0.18625,0.314015,0.533668,0.695937,0.0,0.18625,0.507673,1.0,0.0,...,0.0,0.423125,0.546848,1.0,0.0,6e-06,2,"TileConfig(size=TileSize(width=50, height=50),...",checkpoint_000009,9
