In [1]:
import pandas as pd
import numpy as np
import neptune
import torch
from neptune.utils import stringify_unsupported
from misc import save_json, load_json, get_clean_batch_sz
from utils_neptune import get_latest_dataset, check_neptune_existance
from models import get_encoder, get_head, MultiHead, create_model_wrapper, create_pytorch_model_from_info, CompoundModel
from train4 import train_compound_model, create_dataloaders_old, CompoundDataset
import os
import shutil
from collections import defaultdict
import re

NEPTUNE_API_TOKEN = 'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiIxMGM5ZDhiMy1kOTlhLTRlMTAtOGFlYy1hOTQzMDE1YjZlNjcifQ=='

PROJECT_ID = 'revivemed/Survival-RCC'


In [2]:
def get_pretrained_encoder(dropout_rate=None,use_rand_init=False,load_dir=None,verbose=False):
    if load_dir is None:
        load_dir = os.path.expanduser('~/PRETRAINED_MODELS/2925')
        os.makedirs(load_dir,exist_ok=True)

    encoder_kwargs_path = os.path.join(load_dir,'encoder_kwargs.json')
    encoder_state_path =  os.path.join(load_dir,'encoder_state.pt')

    if (not os.path.exists(encoder_kwargs_path)) or (not os.path.exists(encoder_state_path)):

        neptune_model = neptune.init_model(project='revivemed/Survival-RCC',
            api_token= NEPTUNE_API_TOKEN,
            with_id='SUR-MOD',
            mode="read-only")

        if not os.path.exists(encoder_state_path):
            neptune_model['model/encoder_state'].download(encoder_state_path)
        if not os.path.exists(encoder_kwargs_path):
            encoder_kwargs = neptune_model['original_kwargs/encoder_kwargs'].fetch()
            if 'input_size' not in encoder_kwargs:
                encoder_kwargs['input_size'] = 2736
            if 'kind' not in encoder_kwargs:
                encoder_kwargs['kind'] = neptune_model['original_kwargs/encoder_kind'].fetch()
            if 'hidden_size' not in encoder_kwargs:
                if 'hidden_size_mult' in encoder_kwargs:
                    latent_size = encoder_kwargs['latent_size']
                    encoder_kwargs['hidden_size'] = int(encoder_kwargs['hidden_size_mult']*latent_size)
                else:
                    raise ValueError()
                # remove the hidden_size_mult key
                encoder_kwargs.pop('hidden_size_mult')
            save_json(encoder_kwargs,encoder_kwargs_path)

        neptune_model.stop()

    encoder_kwargs = load_json(encoder_kwargs_path)
    if (dropout_rate is not None):
        if verbose: print('Setting dropout rate to',dropout_rate)
        encoder_kwargs['dropout_rate'] = dropout_rate

    encoder = get_encoder(**encoder_kwargs)

    if not use_rand_init:
        encoder.load_state_dict(torch.load(encoder_state_path))

    return encoder




def get_model_heads(head_kwargs_dict,backup_input_size=None,verbose=False):
    if head_kwargs_dict is None:
        return MultiHead([]) #get_head(kind='Dummy')
    
    if isinstance(head_kwargs_dict,dict):
        head_kwargs_list = [head_kwargs_dict[k] for k in head_kwargs_dict.keys()]
    elif isinstance(head_kwargs_dict,list):
        head_kwargs_list = head_kwargs_dict
    else:
        raise ValueError(f'Invalid head_kwargs_dict type: {type(head_kwargs_dict)}')

    head_list = []
    if len(head_kwargs_list) == 0:
       return MultiHead([])

    for h_kwargs in head_kwargs_list:
            
        if 'input_size' not in h_kwargs:
            if backup_input_size is None:
                raise ValueError('backup_input_size is None')
            if verbose: print('Setting input_size to',backup_input_size)
            h_kwargs['input_size'] = backup_input_size

        head = get_head(**h_kwargs)
        head_list.append(head)


    head = MultiHead(head_list)
    return head



def build_model_components(head_kwargs_dict,adv_kwargs_dict=None,dropout_rate=None,use_rand_init=False):

    encoder = get_pretrained_encoder(dropout_rate=dropout_rate,use_rand_init=use_rand_init)

    #TODO we need better handling of head input size to account for the size of the other_vars
    head = get_model_heads(head_kwargs_dict,backup_input_size=encoder.latent_size+1)
    if head.kind == 'MultiHead':
        head.name = 'HEAD'

    adv = get_model_heads(adv_kwargs_dict,backup_input_size=encoder.latent_size)
    if adv.kind == 'MultiHead':
        adv.name = 'ADVERSARY'

    return encoder, head, adv




In [3]:
def fit_model_wrapper(X,y,model_components_dict={},run_dict={},**train_kwargs):

    assert isinstance(model_components_dict,dict), 'model_components_dict should be a dictionary'

    if isinstance(run_dict, neptune.metadata_containers.run.Run) or isinstance(run_dict, neptune.handler.Handler):
        print('Record the model fitting to Neptune')
        use_neptune= True
    else:
        use_neptune = False
        
    ### Model Component Defaults
    y_head_cols = model_components_dict.get('y_head_cols',None)
    y_adv_cols = model_components_dict.get('y_adv_cols',None)
    head_kwargs_dict = model_components_dict.get('head_kwargs_dict',None)
    adv_kwargs_dict = model_components_dict.get('adv_kwargs_dict',None)

    if y_head_cols is None:
        # select all of the numeric columns
        y_head_cols = list(y.select_dtypes(include=[np.number]).columns)

    if y_adv_cols is None:
        y_adv_cols = []

    ### Train Defaults
    dropout_rate = train_kwargs.get('dropout_rate', None)
    use_rand_init = train_kwargs.get('use_rand_init', False)
    batch_size = train_kwargs.get('batch_size', 64)
    holdout_frac = train_kwargs.get('holdout_frac', 0)
    early_stopping_patience = train_kwargs.get('early_stopping_patience', 0)
    scheduler_kind = train_kwargs.get('scheduler_kind', None)
    train_name = train_kwargs.get('train_name', 'train')

    ### Prepare the Data Loader
    X_size = X.shape[1]
    if (holdout_frac > 0) and (early_stopping_patience < 1) and (scheduler_kind is None):
        # raise ValueError('holdout_frac > 0 and early_stopping_patience < 1 is not recommended')
        print('holdout_frac > 0 and early_stopping_patience < 1 is not recommended, set hold out frac to 0')
        print('UNLESS you are using a scheduler, in which case the holdout_frac is used for the scheduler')
        holdout_frac = 0
        batch_size = get_clean_batch_sz(X_size, batch_size)
    else:
        batch_size = get_clean_batch_sz(X_size*(1-holdout_frac), batch_size)

    train_dataset = CompoundDataset(X,y[y_head_cols], y[y_adv_cols])
    # stratify on the adversarial column (stratify=2)
    # this is probably not the most memory effecient method, would be better to do stratification before creating the dataset
    # train_loader_dct = create_dataloaders(train_dataset, batch_size, holdout_frac, set_name=train_name, stratify=2)
    train_loader_dct = create_dataloaders_old(train_dataset, batch_size, holdout_frac, set_name=train_name)


    ### Build the Model Components
    encoder, head, adv = build_model_components(head_kwargs_dict=head_kwargs_dict,
                                                adv_kwargs_dict=adv_kwargs_dict,
                                                dropout_rate=dropout_rate,
                                                use_rand_init=use_rand_init)

    if train_dataset is not None:
        #TODO load the class weights from the model info json
        head.update_class_weights(train_dataset.y_head)
        adv.update_class_weights(train_dataset.y_adv)

    if len(adv.heads)==0:
        train_kwargs['adversary_weight'] = 0

    ### Train the Model
    encoder, head, adv = train_compound_model(train_loader_dct, 
                                            encoder, head, adv, 
                                            run=run_dict, **train_kwargs)
    if encoder is None:
        raise ValueError('Encoder is None after training, training failed')

    
    return run_dict, encoder, head, adv



def save_model_wrapper(encoder, head, adv, save_dir=None,run_dict={},prefix='training_run'):

        if isinstance(run_dict, neptune.metadata_containers.run.Run) or isinstance(run_dict, neptune.handler.Handler):
            print('Save models to Neptune')
            use_neptune= True
        else:
            use_neptune = False

        delete_after_upload = False
        if save_dir is None:
            save_dir = os.path.expanduser('~/TEMP_MODELS')
            if use_neptune:
                delete_after_upload = True
                if os.path.exists(save_dir):
                    # delete the directory
                    shutil.rmtree(save_dir)
            os.makedirs(save_dir,exist_ok=True)


        upload_models_to_gcp = False    

        encoder.save_state_to_path(save_dir,save_name='encoder_state.pt')
        encoder.save_info(save_dir,save_name='encoder_info.json')
        head.save_state_to_path(save_dir,save_name='head_state.pt')
        head.save_info(save_dir,save_name='head_info.json')
        adv.save_state_to_path(save_dir,save_name='adversary_state.pt')
        adv.save_info(save_dir,save_name='adversary_info.json')

        # torch.save(head.state_dict(), f'{save_dir}/{setup_id}_head_state_dict.pth')
        # torch.save(adv.state_dict(), f'{save_dir}/{setup_id}_adv_state_dict.pth')
        if use_neptune:
            run_dict[f'{prefix}/models/encoder_state'].upload(f'{save_dir}/encoder_state.pt')
            run_dict[f'{prefix}/models/encoder_info'].upload(f'{save_dir}/encoder_info.json')
            run_dict[f'{prefix}/models/head_state'].upload(f'{save_dir}/head_state.pt')
            run_dict[f'{prefix}/models/head_info'].upload(f'{save_dir}/head_info.json')
            run_dict[f'{prefix}/models/adv_state'].upload(f'{save_dir}/adv_state.pt')
            run_dict[f'{prefix}/models/adv_info'].upload(f'{save_dir}/adv_info.json')
            run_dict.wait()

        if upload_models_to_gcp:
            raise NotImplementedError('upload_models_to_gcp not implemented')

        if use_neptune and delete_after_upload:
            run_dict.wait()
            shutil.rmtree(save_dir)

        return run_dict


In [4]:
def evaluate_model_wrapper(encoder, head, adv, X_data_eval, y_data_eval,y_cols,y_head=None):

    if y_head is None:
        chosen_head = head
    else:
        multihead_name_list = head.heads_names
        if y_head not in multihead_name_list:
            raise ValueError(f'Invalid head name: {y_head}')
        chosen_head_idx = multihead_name_list.index(y_head)
        chosen_head = head.heads[chosen_head_idx]

        if isinstance(chosen_head.y_idx,list):
            if len(y_cols) != len(chosen_head.y_idx):
                raise ValueError(f'Invalid y_cols length: {len(y_cols)} vs {len(chosen_head.y_idx)}')
            if len(y_cols) == 1:
                chosen_head.y_idx = [0]
            else:
                chosen_head.y_idx = list(range(len(y_cols)))
        else:
            if len(y_cols) != 1:
                raise ValueError(f'Invalid y_cols length: {len(y_cols)} vs {len(chosen_head.y_idx)}')
            chosen_head.y_idx = 0

        # if len(y_cols) != len(chosen_head.y_idx):
        #     raise ValueError(f'Invalid y_cols length: {len(y_cols)} vs {len(chosen_head.y_idx)}')
        # if len(y_cols) == 1:
        #     chosen_head.y_idx = 0
        # else:
        #     chosen_head.y_idx = list(range(len(y_cols)))
        
    model = CompoundModel(encoder, chosen_head)
    skmodel = create_pytorch_model_from_info(full_model=model)

    return skmodel.score(X_data_eval.to_numpy(),y_data_eval[y_cols].to_numpy())

In [5]:
def run_model_wrapper(data_dir,params,output_dir=None,train_name='train', prefix='training_run',
              eval_name_list=['val'],eval_params_list=None,run_dict=None):
    
    # if run_dict is None:
    #     run_dict = {}

    if isinstance(run_dict, neptune.metadata_containers.run.Run) or isinstance(run_dict, neptune.handler.Handler):
        print('Using Neptune')
        use_neptune= True
        #TODO: check if the models are already trained on neptune
        download_models_from_neptune = check_neptune_existance(run_dict,'models/encoder_info')


    if use_neptune:
        run_dict[f'{prefix}/dataset'].track_files(data_dir)
        run_dict[f'{prefix}/model_name'] = 'Model2925'
        run_dict[f'{prefix}/params'] = stringify_unsupported(params)

    if eval_params_list is None:
        eval_params_list = [{}]

    if output_dir is None:
        output_dir = os.path.expanduser('~/TEMP_MODELS')

    X_filename = 'X_finetune'
    y_filename = 'y_finetune'
    saved_model_dir = os.path.join(output_dir,prefix,'models')        
    os.makedirs(saved_model_dir,exist_ok=True)
    task_components_dict = params['task_kwargs']
    train_kwargs = params['train_kwargs']

    if (not os.path.exists(f'{saved_model_dir}/encoder_info.json')) and (use_neptune) and (download_models_from_neptune): 
        run_dict[f'{prefix}/models/encoder_state'].download(f'{saved_model_dir}/encoder_state.pt')
        run_dict[f'{prefix}/models/encoder_info'].download(f'{saved_model_dir}/encoder_info.json')
        run_dict[f'{prefix}/models/head_state'].download(f'{saved_model_dir}/head_state.pt')
        run_dict[f'{prefix}/models/head_info'].download(f'{saved_model_dir}/head_info.json')
        run_dict[f'{prefix}/models/adv_state'].download(f'{saved_model_dir}/adv_state.pt')
        run_dict[f'{prefix}/models/adv_info'].download(f'{saved_model_dir}/adv_info.json')
    
    if os.path.exists(f'{saved_model_dir}/encoder_info.json'):
        encoder = create_model_wrapper(f'{saved_model_dir}/encoder_info.json',f'{saved_model_dir}/encoder_state.pt')
        head = create_model_wrapper(f'{saved_model_dir}/head_info.json',f'{saved_model_dir}/head_state.pt',is_encoder=False)
        if os.path.exists(f'{saved_model_dir}/adv_info.json'):
            adv = create_model_wrapper(f'{saved_model_dir}/adv_info.json',f'{saved_model_dir}/adv_state.pt',is_encoder=False)
        else:
            adv = MultiHead([])

    else:
        X_data_train = pd.read_csv(f'{data_dir}/{X_filename}_{train_name}.csv', index_col=0)
        y_data_train = pd.read_csv(f'{data_dir}/{y_filename}_{train_name}.csv', index_col=0)

        _, encoder, head, adv = fit_model_wrapper(X=X_data_train,
                                                            y=y_data_train,
                                                            model_components_dict=task_components_dict,
                                                            run_dict=run_dict[prefix],**train_kwargs)

        save_model_wrapper(encoder, head, adv, save_dir=saved_model_dir,run_dict=run_dict,prefix=prefix)


    metrics = defaultdict(dict)
    for eval_name in eval_name_list:
        X_data_eval = pd.read_csv(f'{data_dir}/{X_filename}_{eval_name}.csv', index_col=0)
        y_data_eval = pd.read_csv(f'{data_dir}/{y_filename}_{eval_name}.csv', index_col=0)

        for eval_params in eval_params_list:
            y_col_name = eval_params.get('y_col_name',None)
            y_cols = eval_params.get('y_cols',None)
            y_head = eval_params.get('y_head',None)
            if y_cols is None:
                y_cols = params['task_kwargs']['y_head_cols']

            if y_col_name is None:
                metrics[f'{eval_name}' ].update(evaluate_model_wrapper(encoder, head, adv, X_data_eval, y_data_eval,y_cols=y_cols,y_head=y_head))
            else:
                metrics[f'{eval_name}__{y_col_name}'].update(evaluate_model_wrapper(encoder, head, adv, X_data_eval, y_data_eval,y_cols=y_cols,y_head=y_head))
                
    if use_neptune:
        run_dict[f'{prefix}/metrics'] = metrics
        run_dict.wait()

    return metrics




In [6]:
finetune_structure_kwargs = {
    'encoder_weight': 0,
    'adversary_weight': 0,
    'head_hidden_layers': 0,
    'auxillary_heads' : {'name': 'dummy', 'weight': 0},
    'adversary_heads' : {}
}

finetune_train_kwargs = {
    'batch_size': 64,
    'holdout_frac': 0,
    'early_stopping_patience': 0,
    'scheduler_kind': None,
    'train_name': 'train',
    'num_epochs': 10,
    'learning_rate': 0.001,
    'noise_factor': 0.1,
    'weight_decay': 0,
    'l2_reg_weight': 0,
    'l1_reg_weight': 0,
    'dropout_rate': 0.1,
    'adversarial_start_epoch': 0,
    'clip_grads_with_norm': True,
}

In [7]:
default_sweep_kwargs = {
    'holdout_frac': 0,
    'head_hidden_layers': 0,
    'encoder_kwargs__dropout_rate': 0.2,
    'train_kwargs__num_epochs': 30,
    'train_kwargs__early_stopping_patience': 0,
    'train_kwargs__learning_rate': 0.0005,
    'train_kwargs__l2_reg_weight': 0.0005,
    'train_kwargs__l1_reg_weight': 0.005,
    'train_kwargs__noise_factor': 0.1,
    'train_kwargs__weight_decay': 0,
    'train_kwargs__adversary_weight': 1,
    'train_kwargs__adversarial_start_epoch': 10,
    'train_kwargs__encoder_weight': 0,
    'train_kwargs__clip_grads_with_norm': False,
}


In [8]:


def get_head_kwargs_by_desc(desc_str,num_hidden_layers=0,weight=1,y_cols=None):
    if (desc_str is None) or (desc_str == ''):
        return None, [], []
    
    if y_cols is None:
        y_cols = []

    if 'weight-' in desc_str:
        match = re.search(r'weight-(\d+)', desc_str)
        if match:
            weight = int(match.group(1))
            desc_str = desc_str.replace(match.group(0),'')

    if 'mskcc' in desc_str.lower():
        # if 'mskcc-ord' in desc_str.lower:
            # raise NotImplementedError('not implemented yet')
        y_head_cols = ['MSKCC BINARY']
        head_name = 'MSKCC'
        head_kind = 'Binary'
        num_classes = 2
        y_idx = 0
        plot_latent_space_cols = ['MSKCC']

    elif 'imdc' in desc_str.lower():
        y_head_cols = ['IMDC BINARY']
        head_name = 'IMDC'
        head_kind = 'Binary'
        num_classes = 2
        y_idx = 0
        plot_latent_space_cols = ['IMDC']

    elif 'nivo-benefit' in desc_str.lower():
        raise NotImplementedError()

    elif 'benefit' in desc_str.lower():
        y_head_cols = ['Benefit BINARY']
        head_name = 'Benefit'
        head_kind = 'Binary'
        num_classes = 2
        y_idx = 0
        plot_latent_space_cols = ['Benefit']

    elif 'both-os' in desc_str.lower():
        y_head_cols = ['OS','OS_Event']
        head_name = 'OS'
        head_kind = 'Cox'
        num_classes = 1
        y_idx = [0,1]
        plot_latent_space_cols = ['OS']   

    elif 'both-pfs' in desc_str.lower():
        y_head_cols = ['PFS','PFS_Event']
        head_name = 'PFS'
        head_kind = 'Cox'
        num_classes = 1
        y_idx = [0,1]
        plot_latent_space_cols = ['PFS']        

    elif 'nivo-os' in desc_str.lower():
        y_head_cols = ['NIVO OS','OS_Event']
        head_name = 'NIVO OS'
        head_kind = 'Cox'
        num_classes = 1
        y_idx = [0,1]
        plot_latent_space_cols = ['NIVO OS']   

    elif 'nivo-pfs' in desc_str.lower():
        y_head_cols = ['NIVO PFS','PFS_Event']
        head_name = 'NIVO PFS'
        head_kind = 'Cox'
        num_classes = 1
        y_idx = [0,1]
        plot_latent_space_cols = ['NIVO PFS']         

    elif 'ever-os' in desc_str.lower():
        y_head_cols = ['EVER OS','OS_Event']
        head_name = 'EVER OS'
        head_kind = 'Cox'
        num_classes = 1
        y_idx = [0,1]
        plot_latent_space_cols = ['EVER OS']   

    elif 'ever-pfs' in desc_str.lower():
        y_head_cols = ['EVER PFS','PFS_Event']
        head_name = 'EVER PFS'
        head_kind = 'Cox'
        num_classes = 1
        y_idx = [0,1]
        plot_latent_space_cols = ['EVER OS']            
    else:
        raise ValueError('Unknown desc_str:',desc_str)

    for col in y_head_cols:
        if col not in y_cols:
            y_cols.append(col)

    if len(y_head_cols) == 1:
        y_idx = y_cols.index(y_head_cols[0])
    else:
        y_idx = [y_cols.index(col) for col in y_head_cols]



    head_kwargs = {
            'kind': head_kind,
            'name': head_name,
            'weight': weight,
            'y_idx': y_idx,
            'hidden_size': 4,
            'num_hidden_layers': num_hidden_layers,
            'dropout_rate': 0,
            'activation': 'leakyrelu',
            'use_batch_norm': False,
            'num_classes': num_classes,
            }

    return head_kwargs, y_head_cols, plot_latent_space_cols

In [9]:
def parse_model_components_dict_from_str(desc_str,sweep_kwargs=None):

    if sweep_kwargs is None:
        sweep_kwargs = default_sweep_kwargs

    model_components_dict = {}

    clean_desc_str = desc_str
    if 'optuna_' in desc_str:
        clean_desc_str = clean_desc_str.replace('optuna_','')
    if 'Optimized_' in clean_desc_str:
        clean_desc_str = clean_desc_str.replace('Optimized_','')
    if '__' in clean_desc_str:
        clean_desc_str = clean_desc_str.split('__')[0]

    if 'ADV' in clean_desc_str:
        adv_desc_str = clean_desc_str.split('ADV')[1]
        head_desc_str = clean_desc_str.split('ADV')[0]
    else:
        adv_desc_str = ''
        head_desc_str = clean_desc_str

    y_head_cols = []
    head_kwargs_list = []
    plot_latent_space_cols = []
    
    if 'AND' in head_desc_str:
        head_desc_str_list = head_desc_str.split('AND')
    else:
        head_desc_str_list = [head_desc_str]

    head_hidden_layers = sweep_kwargs.get('head_hidden_layers',0)

    for h_desc in head_desc_str_list:
        head_weight = sweep_kwargs.get(f'{h_desc}__weight',1)
        head_kwargs, head_cols, plot_latent_space_head_cols = get_head_kwargs_by_desc(h_desc,
                                                                                    num_hidden_layers=head_hidden_layers,
                                                                                    weight=head_weight,y_cols=y_head_cols)
        if head_kwargs is None:
            continue
        head_kwargs_list.append(head_kwargs)
        for col in head_cols:
            if col not in y_head_cols:
                y_head_cols.append(col)

        for col in plot_latent_space_head_cols:
            if col not in plot_latent_space_cols:
                plot_latent_space_cols.append(col)



    y_adv_cols = []
    adv_kwargs_list = []

    if 'AND' in adv_desc_str:
        adv_desc_str_list = adv_desc_str.split('AND')
    else:
        adv_desc_str_list = [adv_desc_str]

    for a_desc in adv_desc_str_list:
        adv_weight = sweep_kwargs.get(f'{a_desc}__weight',1)
        adv_kwargs, adv_cols, plot_latent_space_adv_cols = get_head_kwargs_by_desc(a_desc,
                                                                                   num_hidden_layers=head_hidden_layers,
                                                                                    weight=adv_weight,y_cols=y_adv_cols)
        if adv_kwargs is None:
            continue
        adv_kwargs_list.append(adv_kwargs)
        for col in adv_cols:
            if col not in y_adv_cols:
                y_adv_cols.append(col)
        for col in plot_latent_space_adv_cols:
            if col not in plot_latent_space_cols:
                plot_latent_space_cols.append(col)

    model_components_dict = {
        'y_head_cols': y_head_cols,
        'y_adv_cols': y_adv_cols,
        'head_kwargs_dict': head_kwargs_list,
        'adv_kwargs_dict': adv_kwargs_list,
        'plot_latent_space_cols': plot_latent_space_cols
    }

    return model_components_dict



In [10]:
def get_params(desc_str,sweep_kwargs=None):

    model_components_dict = parse_model_components_dict_from_str(desc_str,sweep_kwargs)

    train_kwargs_list = [
        'batch_size',
        'holdout_frac',
        'early_stopping_patience',
        'scheduler_kind',
        'num_epochs',
        'learning_rate',
        'noise_factor',
        'weight_decay',
        'l2_reg_weight',
        'l1_reg_weight',
        'dropout_rate',
        'adversarial_start_epoch',
        'clip_grads_with_norm',
        'encoder_weight',
        'adversary_weight',
        'train_name'
    ]

    train_kwargs_dict = {k:v for k,v in sweep_kwargs.items() if k in train_kwargs_list}
    params = {
        'task_kwargs': model_components_dict,
        'train_kwargs': train_kwargs_dict
    }
    return params

In [11]:
run = neptune.init_run(project=PROJECT_ID,
                                api_token=NEPTUNE_API_TOKEN,
                                with_id='SUR-2481')
                                # mode=neptune_mode,
                                # capture_stdout=yes_logging,
                                # capture_stderr=yes_logging,
                                # capture_hardware_metrics=yes_logging)



[neptune] [info   ] Neptune initialized. Open in the app: https://app.neptune.ai/revivemed/Survival-RCC/e/SUR-2481


In [12]:
data_dir= '/Users/jonaheaton/ReviveMed Dropbox/Jonah Eaton/development_finetune_optimization/April_30_Finetune_Data'


In [None]:
# run_dict = run['training_run']
params = get_params('IMDC AND MSKCC',sweep_kwargs=default_sweep_kwargs)

eval_params_list = [
    {},
    {'y_col_name':'IMDC',
    'y_head':'MSKCC',
    'y_cols': ['IMDC BINARY']},
]

run_model_wrapper(data_dir,params,output_dir=None,train_name='train',prefix='training_run1', 
              eval_name_list=['val'],eval_params_list=eval_params_list,run_dict=run)

In [15]:
# run_dict = run['training_run']
params = get_params('EVER-OS AND NIVO-OS',sweep_kwargs=default_sweep_kwargs)

eval_params_list = [
    {},
    {'y_col_name':'NIVO OS',
    'y_head':'EVER OS',
    'y_cols': ['NIVO OS','OS_Event']},
]

run_model_wrapper(data_dir,params,output_dir=None,train_name='train',prefix='training_run2', 
              eval_name_list=['val'],eval_params_list=eval_params_list,run_dict=run)

Using Neptune


defaultdict(dict,
            {'val': {'Cox_EVER OS': {'Concordance Index': 0.7134986225895317},
              'Cox_NIVO OS': {'Concordance Index': 0.6332842415316642}},
             'val__NIVO OS': {'Concordance Index': 0.6991899852724595}})

In [14]:
params

{'task_kwargs': {'y_head_cols': ['EVER OS', 'OS_Event', 'NIVO OS'],
  'y_adv_cols': [],
  'head_kwargs_dict': [{'kind': 'Cox',
    'name': 'EVER OS',
    'weight': 1,
    'y_idx': [0, 1],
    'hidden_size': 4,
    'num_hidden_layers': 0,
    'dropout_rate': 0,
    'activation': 'leakyrelu',
    'use_batch_norm': False,
    'num_classes': 1},
   {'kind': 'Cox',
    'name': 'NIVO OS',
    'weight': 1,
    'y_idx': [2, 1],
    'hidden_size': 4,
    'num_hidden_layers': 0,
    'dropout_rate': 0,
    'activation': 'leakyrelu',
    'use_batch_norm': False,
    'num_classes': 1}],
  'adv_kwargs_dict': [],
  'plot_latent_space_cols': ['EVER OS', 'NIVO OS']},
 'train_kwargs': {'holdout_frac': 0}}

In [None]:
params = get_params('IMDC AND MSKCC',sweep_kwargs=default_sweep_kwargs)

In [None]:
params

In [None]:
run.stop()

In [None]:
encoder, head, adv = build_model_components(head_kwargs_dict=params['task_kwargs']['head_kwargs_dict'],
                                            adv_kwargs_dict=params['task_kwargs']['adv_kwargs_dict'],
                                            dropout_rate=0,
                                            use_rand_init=False)

In [None]:
m1 = CompoundModel(encoder,head.heads[0])

In [None]:
head.heads_names

In [None]:
filepath = '/Users/jonaheaton/Downloads/head_info.json'
model_info = load_json(filepath)

In [None]:
model_info.values()

In [16]:
run.stop()

[neptune] [info   ] Shutting down background jobs, please wait a moment...
[neptune] [info   ] Done!
[neptune] [info   ] All 0 operations synced, thanks for waiting!
[neptune] [info   ] Explore the metadata in the Neptune app: https://app.neptune.ai/revivemed/Survival-RCC/e/SUR-2481/metadata
