In [22]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
import sys
import os
# Make the parent of the current working directory importable (the local
# `model` package is imported from there). The path is normalised and only
# added once, so re-running this cell does not pile up duplicate
# "<cwd>/../" entries in sys.path.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if project_root not in sys.path:
  sys.path.append(project_root)

In [24]:
import pandas as pd
import warnings

import json
from numpy import random
from dataclasses import dataclass

from model.utils import report_results
from model.train import train_classifier

from sklearn.ensemble import HistGradientBoostingClassifier, AdaBoostClassifier, ExtraTreesClassifier, RandomForestClassifier

DEFAULT_RANDOM_SEED = 774
# Seed NumPy's legacy global generator through the public API instead of
# reaching into the private `random.mtrand._rand` singleton.
random.seed(DEFAULT_RANDOM_SEED)
# 100 seeds that are later handed to report_results (presumably one per
# repetition -- confirm in model.utils).
# `random_integers` is deprecated; `randint` with an exclusive upper bound of
# 2**32 draws from the same inclusive range [0, 2**32 - 1]. The explicit int64
# dtype keeps that bound representable on platforms whose default integer is
# 32 bits wide.
seed_list = random.randint(low=0, high=2**32, size=100, dtype="int64")
# NOTE: this silences *every* warning for the rest of the session, including
# deprecation and failed-fit warnings raised during grid search.
warnings.filterwarnings("ignore")

In [25]:
@dataclass
class RunConfiguration:
  """Hyper-parameter settings for one model evaluation run.

  Consumed by get_parameters, which either runs a grid search or falls back
  to the fixed defaults stored here.
  """
  # When False, get_parameters skips the search and returns
  # default_parameters unchanged.
  run_grid_search: bool
  # Forwarded to train_classifier as `grid_search_params`; in this notebook
  # it maps each parameter name to a tuple of candidate values.
  grid_search_params: dict
  # Parameters returned as-is when run_grid_search is False.
  default_parameters: dict

In [26]:
def get_parameters(df: pd.DataFrame, model_factory, run_config: "RunConfiguration"):
  """Return the hyper-parameters to evaluate a model with.

  Args:
    df: Data forwarded to train_classifier as `data`; the target column is
      "subtype".
    model_factory: Zero-argument callable producing a fresh estimator.
    run_config: Decides whether to search and supplies the grid / defaults.

  Returns:
    `run_config.default_parameters` (the same object) when grid search is
    disabled. Otherwise a dict holding, for every name in
    `run_config.grid_search_params` that the trained model also exposes via
    `get_params()`, the value the trained model ended up with (presumably the
    best estimator found -- confirm in model.train). Keys follow the order in
    which the grid was declared.
  """
  if not run_config.run_grid_search:
    return run_config.default_parameters

  response = train_classifier(model_factory(), target="subtype", data=df, grid_search_params=run_config.grid_search_params)
  # Fetch the parameter dict once instead of once per key.
  fitted_params = response.model.get_params()
  # Walk the grid in its declared order so the printed/returned dict is
  # deterministic (a set intersection's order depends on string hashing).
  # Grid names the model does not expose are skipped, as before.
  parameters = {
    name: fitted_params[name]
    for name in run_config.grid_search_params
    if name in fitted_params
  }
  print(parameters)
  return parameters

def run_tests(model_factory, category: str, not_biased_config: RunConfiguration):
  """Evaluate one model family on a preprocessed gene table and print the report.

  Reads `genes.csv` for the given preprocessing `category`, discards the
  `sample_id` column, prints how many columns remain, resolves the
  hyper-parameters through get_parameters and finally prints the `report`
  attribute of whatever report_results returns.
  """
  genes_path = f"../../preprocessed/{category}/genes.csv"
  features = pd.read_csv(genes_path)
  features = features.drop(columns=["sample_id"])

  column_count = len(features.columns)
  print(f"Data has {column_count} columns")

  # Parameters are resolved before report_results is invoked, so any output
  # from the grid search appears ahead of the evaluation progress.
  chosen_parameters = get_parameters(features, model_factory, not_biased_config)
  outcome = report_results(features, model_factory, chosen_parameters, seed_list)
  print(outcome.report)

In [None]:
# Grid-search a HistGradientBoostingClassifier on the "min_tpm_5" gene table
# and print its evaluation report.
run_tests(
  category="min_tpm_5",
  model_factory=HistGradientBoostingClassifier,
  not_biased_config=RunConfiguration(
    run_grid_search=True,
    grid_search_params={
      "learning_rate": (0.05, 0.1, 0.5, 1),
      # HistGradientBoostingClassifier only accepts a float fraction in
      # (0, 1] for max_features; the former "sqrt" candidate is rejected by
      # scikit-learn's parameter validation, so a quarter of the grid could
      # never fit. 1.0 (all features, the library default) takes its place.
      "max_features": (0.1, 0.2, 0.5, 1.0),
      "l2_regularization": (0, 0.5, 1),
      "max_depth": (16, 32, 64, None),
    },
    default_parameters={'learning_rate': 0.1, 'max_depth': 32, 'max_features': 0.1, 'l2_regularization': 0.5}
  )
)

Data has 13711 columns


In [28]:
# Grid-search an ExtraTreesClassifier on the "min_tpm_5" gene table and print
# its evaluation report.
run_tests(
  category="min_tpm_5",
  model_factory=ExtraTreesClassifier,
  not_biased_config=RunConfiguration(
    run_grid_search=True,
    grid_search_params={
      "n_estimators": (8, 16, 32, 64),
      # For tree ensembles an *int* max_features is an absolute feature
      # count, so the previous `1` meant "one feature per split", not 100%.
      # Written as the float 1.0 so it reads as a fraction like 0.2 and 0.5.
      # NOTE(review): revert to the int 1 if single-feature splits were
      # actually intended.
      "max_features": (0.2, 0.5, 1.0, "sqrt"),
      "max_depth": (16, 32, 64, None),
    },
    default_parameters={'n_estimators': 32, 'max_depth': 16, 'max_features': 0.5}
  )
)

Data has 13711 columns
{'max_depth': None, 'n_estimators': 64, 'max_features': 0.5}


100%|██████████| 100/100 [25:33<00:00, 15.33s/it]

           Metric           Overall             Male           Female
0   F1 (Weighted)   0.8609 ± 0.0252  0.8430 ± 0.0348  0.8676 ± 0.0338
1      F1 (Macro)   0.8139 ± 0.0360  0.7884 ± 0.0490  0.8170 ± 0.0513
2  Recall (Macro)   0.8044 ± 0.0322  0.7897 ± 0.0435  0.8132 ± 0.0437
3         ROC AUC   0.9814 ± 0.0085  0.9791 ± 0.0101  0.9863 ± 0.0118
4        Accuracy   0.8715 ± 0.0237  0.8571 ± 0.0327  0.8806 ± 0.0283
5        Duration  10.4790 ± 7.8348                0                0





In [29]:
# Grid-search a RandomForestClassifier on the "min_tpm_5" gene table and print
# its evaluation report.
run_tests(
  model_factory=RandomForestClassifier,
  category="min_tpm_5",
  not_biased_config=RunConfiguration(
    run_grid_search=True,
    # Candidate values tried for each hyper-parameter.
    grid_search_params=dict(
      n_estimators=(32, 64, 128),
      max_features=(0.2, 0.5, "sqrt"),
      max_depth=(None, 32),
    ),
    # Used instead of the search result when run_grid_search is False.
    default_parameters=dict(max_depth=None, n_estimators=128, max_features=0.5),
  ),
)

Data has 13711 columns
{'max_depth': 32, 'n_estimators': 128, 'max_features': 0.2}


100%|██████████| 100/100 [1:39:20<00:00, 59.60s/it]

           Metric           Overall             Male           Female
0   F1 (Weighted)   0.8565 ± 0.0278  0.8480 ± 0.0365  0.8541 ± 0.0353
1      F1 (Macro)   0.8045 ± 0.0386  0.7858 ± 0.0500  0.7972 ± 0.0528
2  Recall (Macro)   0.7938 ± 0.0338  0.7910 ± 0.0431  0.7936 ± 0.0452
3         ROC AUC   0.9843 ± 0.0067  0.9823 ± 0.0084  0.9869 ± 0.0090
4        Accuracy   0.8681 ± 0.0258  0.8571 ± 0.0330  0.8657 ± 0.0303
5        Duration  59.2933 ± 1.1473                0                0





In [30]:
# Grid-search AdaBoost over a fixed ExtraTreesClassifier base estimator on the
# "min_tpm_5" gene table and print its evaluation report.
run_tests(
  category="min_tpm_5",
  model_factory=lambda **kwargs: AdaBoostClassifier(ExtraTreesClassifier(n_estimators=32, max_depth=16, max_features=0.5), **kwargs),
  not_biased_config=RunConfiguration(
    run_grid_search=True,
    grid_search_params={
      "n_estimators": (8, 16, 32, 64),
      # The grid used to read (0.01, 0.5, 0.1, 0.5): 0.5 appeared twice, so a
      # quarter of the search re-evaluated identical settings. The first 0.5
      # is replaced by 0.05 to restore an increasing sequence.
      # NOTE(review): confirm 0.05 is the value that was intended.
      "learning_rate": (0.01, 0.05, 0.1, 0.5),
    },
    default_parameters={'n_estimators': 16, 'learning_rate': 0.1}
  )
)

Data has 13711 columns
{'n_estimators': 16, 'learning_rate': 0.1}


100%|██████████| 100/100 [21:55<00:00, 13.16s/it]

           Metric           Overall             Male           Female
0   F1 (Weighted)   0.8535 ± 0.0259  0.8398 ± 0.0336  0.8631 ± 0.0356
1      F1 (Macro)   0.7966 ± 0.0359  0.7748 ± 0.0447  0.8078 ± 0.0515
2  Recall (Macro)   0.7940 ± 0.0317  0.7789 ± 0.0405  0.8083 ± 0.0445
3         ROC AUC   0.8879 ± 0.0176  0.8808 ± 0.0227  0.8965 ± 0.0243
4        Accuracy   0.8611 ± 0.0243  0.8442 ± 0.0317  0.8806 ± 0.0318
5        Duration  12.9857 ± 0.7528                0                0





In [31]:
# Grid-search AdaBoost over a fixed RandomForestClassifier base estimator on
# the "min_tpm_5" gene table and print its evaluation report.
run_tests(
  model_factory=lambda **kwargs: AdaBoostClassifier(
    # A fresh base forest is built for every model the factory produces.
    RandomForestClassifier(n_estimators=32, max_depth=32, max_features=0.5),
    **kwargs,
  ),
  category="min_tpm_5",
  not_biased_config=RunConfiguration(
    run_grid_search=True,
    # Candidate values tried for each boosting hyper-parameter.
    grid_search_params=dict(
      n_estimators=(32, 64, 128, 256),
      learning_rate=(1, 2, 3, 5),
    ),
    # Used instead of the search result when run_grid_search is False.
    default_parameters=dict(n_estimators=64, learning_rate=2),
  ),
)

Data has 13711 columns
{'n_estimators': 64, 'learning_rate': 1}


100%|██████████| 100/100 [1:03:51<00:00, 38.32s/it]

           Metric           Overall             Male           Female
0   F1 (Weighted)   0.8392 ± 0.0226  0.8294 ± 0.0322  0.8450 ± 0.0316
1      F1 (Macro)   0.7847 ± 0.0308  0.7684 ± 0.0452  0.7848 ± 0.0473
2  Recall (Macro)   0.7787 ± 0.0273  0.7683 ± 0.0381  0.7850 ± 0.0410
3         ROC AUC   0.8822 ± 0.0164  0.8745 ± 0.0229  0.8850 ± 0.0237
4        Accuracy   0.8472 ± 0.0211  0.8442 ± 0.0289  0.8657 ± 0.0279
5        Duration  36.3761 ± 8.3093                0                0



