In [1]:
import os

import numpy as np
import pandas as pd
import xarray as xr
import lightgbm as lgb

from otb.config import ROOT_DIR, DATASETS_FP, TASKS_FP, CACHE_DIR
from otb import CACHE
from otb.tasks import TaskFactory
from otb.dataset import Dataset

In [2]:
# Line magic selecting matplotlib's inline backend so figures render in cell output.
# NOTE(review): no plotting call appears in the cells visible here — confirm it is still needed.
%matplotlib inline

In [3]:
# Show the kernel's working directory: the relative dataset path loaded near the
# end of this notebook ("otb/data/mlo_cn2/mlo_cn2.nc") is resolved against it.
os.getcwd()

'/home/cjellen/sources/github/ot-bench'

### setup (might not be needed?)

In [4]:
# Inspect the private `_cache` attribute of the shared CACHE object (read-only peek);
# the recorded output is an empty dict.
CACHE._cache

{}

**TODO:** Consider moving from a design in which tasks are first-class to one in which experiments are first-class.

### load the tasks

In [5]:
# Build the factory that the following cells use to list and fetch benchmark tasks.
task_factory = TaskFactory()

List all supported tasks

In [6]:
# The recorded output is a set of dotted task names; these are the strings get_task() accepts below.
task_factory.list_tasks()

{'regression.mlo_cn2.dropna.Cn2_15m', 'regression.mlo_cn2.full.Cn2_15m'}

Get details for the `mlo_cn2` regression task, with missing values dropped

In [7]:
# Fetch one task by its full dotted name, exactly as reported by list_tasks().
task = task_factory.get_task("regression.mlo_cn2.dropna.Cn2_15m")

In [8]:
# Check which task class the factory returned for this task name.
type(task)

otb.tasks.tasks.RegressionTask

In [9]:
# Display the task's metadata; the recorded output includes the dataset name,
# train/test/val index ranges, target column, evaluation metrics and removed columns.
task.get_all_info()

{'description': 'Regression task for MLO Cn2 data, where the last 12 days are set aside for evaluation',
 'description_long': 'This dataset evaluates regression approaches for predicting the extent of optical turbulence, as measured by Cn2 at an elevation of 15m. Optical turbulence on data collected at the Mauna Loa Solar Observatory between 27 July 2006 and 8 August 2006, inclusive, are used to evaluate prediction accuracy under the root-mean square error metric.',
 'ds_name': 'mlo_cn2',
 'train_idx': ['0:8367'],
 'test_idx': ['8367:10367'],
 'val_idx': ['10367:13943'],
 'dropna': True,
 'log_transform': True,
 'eval_metrics': ['root_mean_square_error',
  'r2_score',
  'mean_absolute_error',
  'mean_absolute_percentage_error'],
 'target': 'Cn2_15m',
 'remove': ['base_time', 'Cn2_6m', 'Cn2_15m', 'Cn2_25m']}

Get the training data

In [10]:
# Unpack the training features and target from the task.
# data_type="pd" presumably requests pandas objects (DataFrame / Series) — TODO confirm.
X_train, y_train = task.get_train_data(data_type="pd")

Train your model

In [11]:
# Default-hyperparameter LightGBM regressor. The seed is pinned so that
# Restart & Run All trains the same model on every execution.
model = lgb.LGBMRegressor(random_state=42)

In [12]:
# Fit the regressor on the training split returned by task.get_train_data() above.
model.fit(X_train, y_train)

Evaluate your model

In [13]:
# Hand the fitted model's predict method to the task for scoring; both transform
# arguments are passed as None. The recorded output has one entry per metric
# listed under 'eval_metrics' in the task info.
task.evaluate_model(
    predict_call=model.predict,
    x_transforms=None,
    x_transform_kwargs=None,
)

{'root_mean_square_error': 0.21550584034473647,
 'r2_score': 0.8936031659084048,
 'mean_absolute_error': 0.15718607854371786,
 'mean_absolute_percentage_error': 0.011320035736497307}

**TODO:** Compute metrics on the evaluation set

**TODO:** Compare against benchmarks

### deprecated

In [14]:
# Load the raw MLO Cn2 NetCDF file directly with xarray, bypassing the task API.
# NOTE(review): the path is relative, so this resolves against the kernel's working
# directory. ROOT_DIR (imported at the top) may be a safer anchor — confirm what it points to.
ds = xr.load_dataset("otb/data/mlo_cn2/mlo_cn2.nc")

In [15]:
ds