In [1]:
import logging
import os
from datetime import datetime

import pandas as pd
import torch

from naiad import NAIAD, load_naiad_data, ActiveLearner, ActiveLearnerReplicates

In [2]:
# Notebook-wide runtime configuration.
# Emit log records at INFO level and above.
logging.basicConfig(level=logging.INFO)
# Turn on pandas' "mode.copy_on_write" option for the whole session.
pd.set_option("mode.copy_on_write", True)

# Ask the inline matplotlib backend to render figures as SVG.
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('svg')

# Use CUDA when torch reports it is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Dataset to run on; must be one of the keys of `data_map`
# (an unknown key raises KeyError on the lookup below).
data_source = 'norman'
data_map = {'norman': '../../data/norman/norman_gamma.csv',
            'simpson': '../../data/simpson/simpson_gamma.csv',
            'horlbeck_jurkat': '../../data/horlbeck/horlbeck_jurkat_gamma.csv',
            'horlbeck_k562': '../../data/horlbeck/horlbeck_k562_gamma.csv'}
data_file = data_map[data_source]

# Every run writes into its own timestamped directory under ./results.
time_label = datetime.now().strftime(r'%Y-%m-%d-%H-%M-%S')
result_dir = os.path.join('./results/', f'{data_source}_{time_label}')

# makedirs(exist_ok=True) creates the parent ./results when it is missing and
# does not fail if the target already exists (e.g. the cell is re-run within
# the same second), avoiding the check-then-create race of exists() + mkdir().
os.makedirs(result_dir, exist_ok=True)

In [4]:
# For each sampling method: True when the lowest scores should be selected,
# False when the highest scores should be selected.
method_mins = {
    # strong cell viability scores are negative, so the mean is minimized
    'mean': True,
    # points with the maximum standard deviation are preferred
    'std': False,
    # likewise, points with the maximum residual+std are preferred
    'residual+std': False,
}

In [5]:
# Dataset-specific hyperparameters, selected in a single branch so that
# model_args, n_sample and n_epoch are always defined together.
# NOTE(review): model_args and n_sample each hold five entries, presumably one
# per active-learning round (ActiveLearner is built with n_round = 5) - confirm.
if data_source == 'norman':
    model_args = [
        {'model_type': 'both', 'd_embed': 2, 'd_pheno_hid': 64, 'p_dropout': 0.1},
        {'model_type': 'both', 'd_embed': 4, 'd_pheno_hid': 64, 'p_dropout': 0.1},
        {'model_type': 'both', 'd_embed': 4, 'd_pheno_hid': 64, 'p_dropout': 0.1},
        {'model_type': 'both', 'd_embed': 8, 'd_pheno_hid': 64, 'p_dropout': 0.1},
        {'model_type': 'both', 'd_embed': 16, 'd_pheno_hid': 64, 'p_dropout': 0.1}
    ]
    n_sample = [100, 200, 300, 400, 500]
    n_epoch = 200
elif data_source in ['simpson', 'horlbeck_jurkat', 'horlbeck_k562']:
    model_args = [
        {'model_type': 'both', 'd_embed': 2, 'd_pheno_hid': 256, 'p_dropout': 0.1},
        {'model_type': 'both', 'd_embed': 8, 'd_pheno_hid': 256, 'p_dropout': 0.1},
        {'model_type': 'both', 'd_embed': 8, 'd_pheno_hid': 256, 'p_dropout': 0.1},
        {'model_type': 'both', 'd_embed': 16, 'd_pheno_hid': 256, 'p_dropout': 0.1},
        {'model_type': 'both', 'd_embed': 32, 'd_pheno_hid': 256, 'p_dropout': 0.1}
    ]
    n_sample = [500, 1000, 1500, 2000, 2500]
    n_epoch = 100
else:
    # Fail here with a clear message rather than with a NameError further down.
    raise ValueError(f'Unsupported data_source: {data_source!r}')

# Settings shared by every dataset.
batch_size = 4096
model_optimizer_settings = {'pheno_lr': 1e-2, 'embed_lr': 1e-2, 'weight_decay': 0}
test_frac = 0.3

In [6]:
# Read the dataset selected by `data_file`.
data = load_naiad_data(data_file)

# One active-learning run: NAIAD models configured by `model_args` and
# `model_optimizer_settings`, with the schedule and split settings chosen above.
active_learner = ActiveLearner(
    n_round=5,
    data=data,
    model=NAIAD,
    model_args=model_args,
    model_optimizer_settings=model_optimizer_settings,
    n_ensemble=5,
    n_epoch=n_epoch,
    n_sample=n_sample,
    test_frac=test_frac,
    early_stop=False,
    device=device,
    batch_size=batch_size,
)

# Wrap the learner so it can be run as seeded replicates; outputs are written
# to `result_dir` using the dataset name as the file prefix.
active_learner_reps = ActiveLearnerReplicates(
    n_rep=5,
    overall_seed=0,
    active_learner=active_learner,
    save_dir=result_dir,
    save_prefix=data_source,
)

## Sample Mean

In [7]:
# Run the replicates with the 'mean' sampling method. The min/max direction is
# read from `method_mins` (defined in the configuration cell) so the flag is
# declared in one place only.
active_learner_reps.set_method('mean', method_mins['mean'])
results = active_learner_reps.run_replicates(parallel=True)
aggregated_results = active_learner_reps.aggregate_replicate_metrics(return_value=True)
active_learner_reps.save_aggregated_results(data_source)

## Sample Std

In [8]:
# Run the replicates with the 'std' sampling method. The min/max direction is
# read from `method_mins` (defined in the configuration cell) so the flag is
# declared in one place only.
active_learner_reps.set_method('std', method_mins['std'])
active_learner_reps.run_replicates(parallel=True)
active_learner_reps.aggregate_replicate_metrics(return_value=False)
active_learner_reps.save_aggregated_results(data_source)

## Sample Residual+Std

In [9]:
# Run the replicates with the 'residual+std' sampling method. The min/max
# direction is read from `method_mins` (defined in the configuration cell) so
# the flag is declared in one place only.
active_learner_reps.set_method('residual+std', method_mins['residual+std'])
active_learner_reps.run_replicates(parallel=True)
active_learner_reps.aggregate_replicate_metrics(return_value=False)
active_learner_reps.save_aggregated_results(data_source)

## Compare Results

In [10]:
# Save the aggregated results again before plotting.
# NOTE(review): the same call already runs at the end of each sampling section
# above, so this one looks redundant - confirm before removing.
active_learner_reps.save_aggregated_results(data_source)

In [None]:
# Compare the three sampling methods on the 'mse' and 'tpr' metrics for the
# 'overall' split, with a horizontal layout.
active_learner_reps.plot_aggregated_results(
    metrics=['mse', 'tpr'],
    splits='overall',
    methods=['mean', 'std', 'residual+std'],
    orientation='horizontal',
)