In [6]:
# set working directory
import os
os.chdir('/home/katerchen/Code/DrVAE-master/src')

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
## add src path that is missing when running with Slurm 
sys.path.append(os.path.dirname(os.getcwd())+"/src")
## print python version, machine hostname, and hash+summary of the latest git commit
print('hostname:', os.uname()[1])
print(sys.version)
os.system('git log -1 |cat')

import argparse
import numpy as np
import sklearn.utils
import sklearn.metrics
from scipy import stats
from collections import Counter, OrderedDict
from pprint import pprint
import json
import h5py

import torch
import torch.utils.data
from torch.autograd import Variable

from DrVAE import DrVAE, DrVAEDataset, wrap_in_DrVAEDataset
import utils as utl

hostname: DESKTOP-G507M9H
3.10.9 (main, Jan 11 2023, 15:21:40) [GCC 11.2.0]


fatal: not a git repository (or any of the parent directories): .git


In [7]:
# enforce pytorch version 0.3.x, refactoring is required for 0.4.x
print('pytorch version:', torch.__version__)
# if not torch.__version__.startswith('0.3'):
#    raise Exception('pytorch version 0.3.x is required')
# set number of CPU parallel threads to 4, performance doesn't scale beyond 4
print('orig num threads:', torch.get_num_threads())
torch.set_num_threads(4)
print('now num threads:', torch.get_num_threads())
print('-----')

pytorch version: 2.0.1+cu117
orig num threads: 4
now num threads: 4
-----


In [8]:
def load_from_HDF(fname):
    """Load data from a HDF5 file to a dictionary."""
    data = dict()
    with h5py.File(fname, 'r') as f:
        for key in f:
            data[key] = np.asarray(f[key])
            if isinstance(data[key][0], np.bytes_):
                data[key] = data[key].astype(str)
            # print(key + ":", f[key])
    return data

In [9]:
data = load_from_HDF('/home/katerchen/Code/DrVAE-master/workspace/datafiles/CTRPv2+L1000_FDAdrugs6h_v2.1.h5')
print('Loaded data from:', data)
    
utl.make_out_dirs('./out')

Loaded data from: {'drug_drug': array(['axitinib', 'bortezomib', 'bosutinib', 'chlorambucil',
       'ciclosporin', 'cimetidine', 'clofarabine', 'crizotinib',
       'dasatinib', 'decitabine', 'dexamethasone', 'docetaxel',
       'erlotinib', 'etoposide', 'fluvastatin', 'fulvestrant',
       'gefitinib', 'gemcitabine', 'imatinib', 'itraconazole',
       'lovastatin', 'mitomycin', 'niclosamide', 'nilotinib',
       'omacetaxine mepesuccinate', 'paclitaxel', 'pazopanib', 'PLX-4032',
       'procarbazine', 'prochlorperazine', 'ruxolitinib', 'sildenafil',
       'simvastatin', 'sirolimus', 'sitagliptin', 'sorafenib',
       'tacrolimus', 'temozolomide', 'teniposide', 'thalidomide',
       'topotecan', 'tretinoin', 'trifluoperazine', 'valdecoxib',
       'vincristine', 'vorinostat'], dtype='<U25'), 'drug_m': array([[0.52658381, 0.        , 0.35714286, ..., 0.        , 0.        ,
        0.        ],
       [0.20688762, 0.        , 0.71428571, ..., 0.        , 0.        ,
        0.        

In [5]:
drug_name = 'bortezomib'

drug_list_26 = ['omacetaxine mepesuccinate', 'bortezomib', 'vorinostat', 'paclitaxel', 'docetaxel', 'topotecan',
                    'niclosamide', 'valdecoxib','teniposide', 'vincristine', 'prochlorperazine', 'mitomycin', 'lovastatin',
                    'gemcitabine', 'dasatinib', 'fluvastatin', 'clofarabine', 'sirolimus', 'etoposide', 'sitagliptin',
                    'decitabine', 'PLX-4032', 'fulvestrant', 'bosutinib', 'trifluoperazine', 'ciclosporin']
drug_list_26 = sorted(drug_list_26)
if drug_name == 'all':
    drug_list = drug_list_26
    for d in sorted(data['drug_drug']): # add all drugs from the data
        if d not in drug_list:          # except those already in the list
            drug_list.append(d)         # (to avoid duplicates)
elif drug_name == '26':
    drug_list = drug_list_26
else:
    if drug_name in data['drug_drug']:  # check if the drug is in the data
        drug_list = [drug_name]         # if yes, use it
    else:
        raise ValueError('Selected drug not found: ' + args.drug)

In [10]:
# Here we set the random seed for numpy, torch, and cuda

rseed = 123
cuda = 12345

all_stats = {'train': dict(), 'valid': dict(), 'test': dict()}
for selected_drug in drug_list:
    ## ignore drugs that don't have enough perturbations
    if selected_drug in ["abiraterone", "azacitidine", "cyclophosphamide", "methotrexate", "fluorouracil",
                        "ifosfamide", "ciclopirox"]:
        print("Ignoring: ", selected_drug)
        continue
    rnds = sklearn.utils.check_random_state(rseed)
    np.random.seed(rseed)
    torch.manual_seed(rseed)
    if cuda:
        torch.cuda.manual_seed(rseed)

In [15]:
def selectFromDict(d, keys):
    '''select the chosen @keys in dict @d'''
    res = dict()
    for k in keys:
        if k not in d.keys():
            raise ValueError("Key " + str(k) + " not found")
        else:
            res[k] = d[k]
    return res


In [16]:
def subsetDict(d, ind):
    '''subset each numpy array in dict @d to indices @ind'''
    res = dict()
    if not isinstance(ind, np.ndarray):
        ind = np.asarray(ind)
    for k in d.keys():
        if isinstance(d[k], np.ndarray):
            res[k] = d[k][ind]
    return res

In [12]:
def split_data_sd(data, selectDrug, dataMode, fold, n_folds=5, rnds=None, noPairTest=False,
                  unlab_token=-47, verbose=True):
    ''' Get data of a selected drug from the data set and split for cross-validation:
        * pair (perturbation) data:
            - split to train/valid/test by GroupKFold
            - if @noPairTest set to True, then split just to train/valid (possibly used for SSVAE, DrVAE)
        * singleton data:
            - assign cell lines that also have perturbation pair data to train/valid/test according to pair data split
            - split the rest to train/valid/test by StratifiedKFold
    '''
    #### list of drugs
    all_drugs = data['drug_drug']
    if not selectDrug in all_drugs: raise Exception("drug data not found for " + selectDrug)
    drugIndex = np.where(all_drugs == selectDrug)[0][0]
    if verbose: print(selectDrug, ": index =", drugIndex)

    #### paired perturbation data for the chosen drug
    ## "pair_x1", "pair_x2", "pair_cid", "pair_tid", "pair_s", "pair_y", "pair_ycont", "pair_drug", "pair_m", "pair_conc", "pair_dur"
    alldrugs_ppair_data = selectFromDict(data,
                                         np.asarray(list(data.keys()))[[k.startswith('pair_') for k in data.keys()]])
    ## subset to selectDrug
    all_ppairs = subsetDict(alldrugs_ppair_data, alldrugs_ppair_data['pair_drug'] == selectDrug)
    # ## select duration of ~6h
    # all_ppairs = subsetDict(all_ppairs, np.logical_and(all_ppairs['pair_dur'] > 5, all_ppairs['pair_dur'] < 7))
    ## strip 'pair_' prefix from the keys
    for k in list(all_ppairs.keys()):
        if k.startswith('pair_'): all_ppairs[k[len('pair_'):]] = all_ppairs.pop(k)
    all_ppairs['y'] = np.asarray(all_ppairs['y'], dtype=np.int32)
    # print("all_ppairs", sorted(all_ppairs.keys()))
    # for k,v in all_ppairs.items(): print(k, v.shape)

    ## denote unlabeled data as @unlab_token
    assert np.all((all_ppairs['y'] == -1) == np.isnan(all_ppairs['ycont']))
    all_ppairs['has_y'] = all_ppairs['y'] != -1
    all_ppairs['y'][np.logical_not(all_ppairs['has_y'])] = unlab_token
    all_ppairs['ycont'][np.logical_not(all_ppairs['has_y'])] = unlab_token
    if verbose:
        print("ppair labels:", Counter(all_ppairs['y']))
        print("ppair labels:", stats.describe(all_ppairs['ycont'][all_ppairs['ycont'] != unlab_token]),
              "unlabeled:", np.sum(all_ppairs['ycont'] == unlab_token))

    #### singleton sensitivity data for the chosen drug
    ## "sing_x1", "sing_cid", "sing_tid", "sing_s", "sing_y", "sing_ycont"
    all_sing = selectFromDict(data, np.asarray(list(data.keys()))[[k.startswith('sing_') for k in data.keys()]])
    ## strip 'sing_' prefix from the keys
    for k in list(all_sing.keys()):
        if k.startswith('sing_'): all_sing[k[len('sing_'):]] = all_sing.pop(k)
    ## subset to selectDrug
    # all_sing['x1'] = sklearn.preprocessing.minmax_scale(all_sing['x1']) ######################################
    # all_sing['x1'] = sklearn.preprocessing.scale(all_sing['x1']) #############################################
    # all_sing['x1'] = sklearn.preprocessing.robust_scale(all_sing['x1']) ######################################
    # scaled_tmp = sklearn.preprocessing.robust_scale(np.concatenate((all_sing['x1'], all_ppairs['x1']))) ######
    # all_sing['x1'] = scaled_tmp[:all_sing['x1'].shape[0]]
    # all_ppairs['x1'] = scaled_tmp[all_sing['x1'].shape[0]:]

    all_sing['y'] = np.asarray(all_sing['y'][:, drugIndex], dtype=np.int32)
    all_sing['ycont'] = all_sing['ycont'][:, drugIndex]
    all_sing['drug'] = np.array([data['drug_drug'][drugIndex]] * all_sing['x1'].shape[0])
    all_sing['m'] = np.array([data['drug_m'][drugIndex]] * all_sing['x1'].shape[0])
    # print("all_sing", sorted(all_sing.keys()))
    # for k,v in all_sing.items(): print(k, v.shape)

    ## denote unlabeled data as -47
    assert np.all((all_sing['y'] == -1) == np.isnan(all_sing['ycont']))
    all_sing['has_y'] = all_sing['y'] != -1
    all_sing['y'][np.logical_not(all_sing['has_y'])] = unlab_token
    all_sing['ycont'][np.logical_not(all_sing['has_y'])] = unlab_token
    if verbose:
        print("sing labels:", Counter(all_sing['y']))
        print("sing labels:", stats.describe(all_sing['ycont'][all_sing['ycont'] != unlab_token]),
              "unlabeled:", np.sum(all_sing['ycont'] == unlab_token))

    ############ Train - Valid - Test split
    ### split perturbation pairs:
    # no intersection of cell lines between the k folds
    if verbose: print("cell-line-wise Group K-Fold split on perturbation data")
    ind_tr, ind_va, ind_te = getSplitsByGroupKFold(all_ppairs['cid'], fold, n_folds, True, rnds)
    if noPairTest == False:
        ppair = subsetDict(all_ppairs, ind_tr)
        ppairv = subsetDict(all_ppairs, ind_va)
        ppairt = subsetDict(all_ppairs, ind_te)
    else:  #### no perturbation test set
        ind_tr = np.concatenate((ind_tr, ind_te))
        del ind_te
        ppair = subsetDict(all_ppairs, ind_tr)
        ppairv = subsetDict(all_ppairs, ind_va)
        ## no test set, use empty dummy with a matching dimensionality and dtype
        ppairt = dict([(k, np.empty([0] + list(v.shape[1:]), v.dtype)) for k, v in all_ppairs.items()])

    ### split singleton data:
    # dataMode = 'strictC2C' ## no patient data used; even in unlabeled data
    # dataMode = 'C2C' ## no labeled patient data used    
    # dataMode = 'C2P' ## hold out labeled patient data for testing; use unlabeled patient data in training
    # dataMode = 'CP2C' ##use patient data (both labeled and unlabeled) in training
    if dataMode in ['strictC2C']:
        ## Set aside cell lines that are also in perturbation experiments
        ## (and later add them to respective train/valid/test set)
        sing_wpert_tr = subsetDict(all_sing, np.asarray([x in ppair['cid'] for x in all_sing['cid']]))
        sing_wpert_va = subsetDict(all_sing, np.asarray([x in ppairv['cid'] for x in all_sing['cid']]))
        sing_wpert_te = subsetDict(all_sing, np.asarray([x in ppairt['cid'] for x in all_sing['cid']]))
        # only in viability screens
        unique_idx = np.logical_not(np.asarray([x in all_ppairs['cid'] for x in all_sing['cid']]))
        sing_unique = subsetDict(all_sing, unique_idx)

        ## stratified splits by response class, i.e. preserving the percentage of samples for each class in each fold
        if verbose: print("response-wise Stratified K-Fold split on singleton data")
        ind_tr, ind_va, ind_te = getSplitsByStratifiedKFold(sing_unique['y'], fold, n_folds, True, rnds)
        # apply the splits
        sing = subsetDict(sing_unique, ind_tr)
        singv = subsetDict(sing_unique, ind_va)
        singt = subsetDict(sing_unique, ind_te)

        ## merge cell lines w/perturbations to train/valid/test set
        sing = concatDictElemetwise(sing, sing_wpert_tr)
        singv = concatDictElemetwise(singv, sing_wpert_va)
        singt = concatDictElemetwise(singt, sing_wpert_te)
    else:
        raise Exception('Data mode not supported: ' + dataMode)

    return sing, singv, singt, ppair, ppairv, ppairt

In [17]:
### select the drug data and get CV-split of all data types
downlabel_to = None
semi_supervised = True
type_y = 'discrete'
data_mode = 'both'
fold = 0
pair_data_only = False

y_unlab_token = -47                # token for unlabeled data
y_key = 'y' if type_y == 'discrete' else 'ycont' # key for the response variable
sing, singv, singt, pair, pairv, pairt = split_data_sd( # split data into singletons and pairs
    data, selected_drug, data_mode, fold=fold, n_folds=5, rnds=rnds, 
    noPairTest=True, unlab_token=y_unlab_token, verbose=False
    )

concat_flag = 'pair_only' if pair_data_only else 'both'
train_dataset, train_ddict = wrap_in_DrVAEDataset(sing, pair, y_key, concat=concat_flag,
                                                downlabel_to=downlabel_to,
                                                remove_unlabeled=not semi_supervised)
valid_dataset, valid_ddict = wrap_in_DrVAEDataset(singv, pairv, y_key, concat=concat_flag,
                                                remove_unlabeled=not semi_supervised)
test_dataset, test_ddict = wrap_in_DrVAEDataset(singt, pairt, y_key, concat=concat_flag)

AssertionError: 

In [9]:
N = len(train_dataset)
dim_x = sing['x1'].shape[1]
dim_s = np.unique(sing['s']).shape[0]
if args.type_y == 'discrete':
    class_sizes = np.bincount(sing['y'][ sing['has_y'] ])
    print(class_sizes)
    dim_y = len(class_sizes)
    data_prior_y = class_sizes / (1.*sum(class_sizes))
else:
    dim_y = 1
    _tmp_ycont = sing['ycont'][ sing['has_y'] ]
    data_prior_y = np.array([_tmp_ycont.mean(), _tmp_ycont.std()])

if args.clf_dataprior:
    prior_y = data_prior_y
else:
    prior_y = 'uniform'

NameError: name 'train_dataset' is not defined

In [None]:
if args.modelid in ["auto", "'auto'"]:
            args.modelid = 'RS{}_L{}_YR{:.0f}_FOLD{}'.format(args.rseed, args.L, args.yloss_rate, args.fold)

        print("DrVAE on a single drug, modelid: ", args.modelid)
        print("selected_drug: ", selected_drug)
        print("train data prior: ", data_prior_y)
        print("using prior: ", prior_y)
        print("concat_flag: ", concat_flag)
        print(N, dim_x, dim_y)
        print("sensitivity data")
        if args.type_y == 'discrete':
            print("    train data (Y, S):", Counter(sing['y']), Counter(sing['s']))
            print("    valid data (Y, S):", Counter(singv['y']), Counter(singv['s']))
            print("    test data  (Y, S):", Counter(singt['y']), Counter(singt['s']))
        else:
            print("    train data (Ycont):", stats.describe(sing['ycont'][sing['ycont'] != y_unlab_token]))
            print("    valid data (Ycont):", stats.describe(singv['ycont'][singv['ycont'] != y_unlab_token]))
            print("    test data  (Ycont):", stats.describe(singt['ycont'][singt['ycont'] != y_unlab_token]))

In [None]:
### create balanced sampler for training set
        ## balance by cell line ids
        balanced_weights = utl.compute_balanced_weights(utl.cid2numid(train_ddict['cid']), unlabeled_data_ratio=None)
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(balanced_weights, len(balanced_weights))

In [None]:
### create balanced sampler for validation set
        ## balance validation sampler by cell line ids
        balanced_weights = utl.compute_balanced_weights(utl.cid2numid(valid_ddict['cid']), unlabeled_data_ratio=None)
        valid_sampler = torch.utils.data.sampler.WeightedRandomSampler(balanced_weights, len(balanced_weights))

In [None]:
## create train/valid/test data loaders
        dl_kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
        dl_kwargs['batch_size'] = args.batch_size
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   drop_last=(len(train_dataset) >= args.batch_size),
                                                   sampler=train_sampler, **dl_kwargs)
        valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                                   drop_last=(len(valid_dataset) >= args.batch_size),
                                                   sampler=valid_sampler, **dl_kwargs)

        print('Train set length: {} batches; {} examples'.format(len(train_loader), len(train_loader.dataset)))
        print('Valid set length: {} batches; {} examples'.format(len(valid_loader), len(valid_loader.dataset)))
        print('Test set length: {} examples'.format(len(test_dataset)))

In [None]:
model_filename = 'models/DrVAE_SD_{}_{}.pth'.format(args.modelid, selected_drug)
        model = DrVAE(dim_x=dim_x, dim_s=dim_s, dim_y=dim_y, dim_z1=args.dim_z1, dim_z3=args.dim_z3,
                      dim_h_en_z1=args.enc_z1, dim_h_en_z3=args.enc_z3, dim_h_en_z2Fz1=args.enc_z2Fz1,
                      dim_h_de_z1=args.dec_z1, dim_h_de_x=args.dec_x, type_rec='diag_gaussian',
                      dim_h_clf=args.class_y, type_y=args.type_y, prior_y=prior_y,
                      optim_alg='adam', batch_size=args.batch_size, epochs=300,
                      nonlinearity='elu', L=args.L,
                      weight_decay=0.05, dropout_rate=0., input_x_dropout=args.x_dropout, add_noise_var=args.noise_var,
                      learning_rate=0.0005, yloss_rate=args.yloss_rate, clf_z1z2=True, clf_1sig=args.clf_1sig,
                      anneal_yloss_offset=args.anneal_yloss_offset,
                      kl_qz2pz2_rate=1., pertloss_rate=0.05, anneal_perturb_rate_itermax=1, anneal_perturb_rate_offset=0,
                      use_s=args.useS, use_MMD=args.useMMD, kernel_MMD='rbf_fourier', mmd_rate=args.mmd_rate,  ## <= settings for "fairness" (like VFAE)
                      random_seed=args.rseed, log_txt=None  #'DrVAE_SD_{}_{}.txt'.format(args.modelid, selected_drug)

In [None]:
if args.cuda:
            model.cuda()

        if not args.test_only:
            try:
                model.fit(train_loader=train_loader, valid_loader=valid_loader, add_noise=args.train_w_noise,
                          verbose=True, early_stop=args.stopearly, model_filename=model_filename)
            except Exception as e:
                print("  >>> TRAINING CRASHED <<<")
                print(e)

In [None]:
print ('Loading trained model from file...')
        model.load_params_from_file(model_filename)
        print ('...done.')

        #### test the model and baselines
        print ('Testing...')
        ## run our model
        model_train_perf, train_perfstr = model.evaluate_performance_on_dataset(train_dataset, return_full_data=True)
        model_valid_perf, valid_perfstr = model.evaluate_performance_on_dataset(valid_dataset, return_full_data=True)
        model_test_perf, test_perfstr = model.evaluate_performance_on_dataset(test_dataset, return_full_data=True)
        model.w2log('Train set performance:\t', train_perfstr)
        model.w2log('Valid set performance:\t', valid_perfstr)
        model.w2log('Test set performance:\t', test_perfstr)

In [None]:
## run baselines
        # all_stats['train'][selected_drug] = utl.compile_baseline_stats(args.type_y,
        #                                                         tr=train_ddict, ev=train_ddict,
        #                                                         model_tr=model_train_perf, model_ev=model_train_perf,
        #                                                         svmkernel='rbf', rseed=args.rseed)
        all_stats['valid'][selected_drug] = utl.compile_baseline_stats(args.type_y,
                                                                tr=train_ddict, ev=valid_ddict,
                                                                model_tr=model_train_perf, model_ev=model_valid_perf,
                                                                svmkernel='rbf', rseed=args.rseed)
        all_stats['test'][selected_drug] = utl.compile_baseline_stats(args.type_y,
                                                                tr=train_ddict, ev=test_ddict,
                                                                model_tr=model_train_perf, model_ev=model_test_perf,
                                                                svmkernel='rbf', rseed=args.rseed)

In [None]:

    drug_list_26 = ['omacetaxine mepesuccinate', 'bortezomib', 'vorinostat', 'paclitaxel', 'docetaxel', 'topotecan',
                    'niclosamide', 'valdecoxib','teniposide', 'vincristine', 'prochlorperazine', 'mitomycin', 'lovastatin',
                    'gemcitabine', 'dasatinib', 'fluvastatin', 'clofarabine', 'sirolimus', 'etoposide', 'sitagliptin',
                    'decitabine', 'PLX-4032', 'fulvestrant', 'bosutinib', 'trifluoperazine', 'ciclosporin']
    drug_list_26 = sorted(drug_list_26)
    if args.drug == 'all':
        drug_list = drug_list_26
        for d in sorted(data['drug_drug']): # add all drugs from the data
            if d not in drug_list:          # except those already in the list
                drug_list.append(d)         # (to avoid duplicates)
    elif args.drug == '26':
        drug_list = drug_list_26
    else:
        if args.drug in data['drug_drug']:  # check if the drug is in the data
            drug_list = [args.drug]         # if yes, use it
        else:
            raise ValueError('Selected drug not found: ' + args.drug)

    all_stats = {'train': dict(), 'valid': dict(), 'test': dict()}
    for selected_drug in drug_list:
        ## ignore drugs that don't have enough perturbations
        if selected_drug in ["abiraterone", "azacitidine", "cyclophosphamide", "methotrexate", "fluorouracil",
                             "ifosfamide", "ciclopirox"]:
            print("Ignoring: ", selected_drug)
            continue
        ## initialize random state
        rnds = sklearn.utils.check_random_state(args.rseed)
        np.random.seed(args.rseed)
        torch.manual_seed(args.rseed)
        if args.cuda:
            torch.cuda.manual_seed(args.rseed)

        ### select the drug data and get CV-split of all data types
        y_unlab_token = -47                # token for unlabeled data
        y_key = 'y' if args.type_y == 'discrete' else 'ycont' # key for the response variable
        sing, singv, singt, pair, pairv, pairt = utl.split_data_sd( # split data into singletons and pairs
            data, selected_drug, args.data_mode, fold=args.fold, n_folds=5, rnds=rnds, 
            noPairTest=True, unlab_token=y_unlab_token, verbose=False
        )

        concat_flag = 'pair_only' if args.pair_data_only else 'both'
        train_dataset, train_ddict = wrap_in_DrVAEDataset(sing, pair, y_key, concat=concat_flag,
                                                          downlabel_to=args.downlabel_to,
                                                          remove_unlabeled=not args.semi_supervised)
        valid_dataset, valid_ddict = wrap_in_DrVAEDataset(singv, pairv, y_key, concat=concat_flag,
                                                          remove_unlabeled=not args.semi_supervised)
        test_dataset, test_ddict = wrap_in_DrVAEDataset(singt, pairt, y_key, concat=concat_flag)

        N = len(train_dataset)
        dim_x = sing['x1'].shape[1]
        dim_s = np.unique(sing['s']).shape[0]
        if args.type_y == 'discrete':
            class_sizes = np.bincount(sing['y'][ sing['has_y'] ])
            print(class_sizes)
            dim_y = len(class_sizes)
            data_prior_y = class_sizes / (1.*sum(class_sizes))
        else:
            dim_y = 1
            _tmp_ycont = sing['ycont'][ sing['has_y'] ]
            data_prior_y = np.array([_tmp_ycont.mean(), _tmp_ycont.std()])

        if args.clf_dataprior:
            prior_y = data_prior_y
        else:
            prior_y = 'uniform'

        if args.modelid in ["auto", "'auto'"]:
            args.modelid = 'RS{}_L{}_YR{:.0f}_FOLD{}'.format(args.rseed, args.L, args.yloss_rate, args.fold)

        print("DrVAE on a single drug, modelid: ", args.modelid)
        print("selected_drug: ", selected_drug)
        print("train data prior: ", data_prior_y)
        print("using prior: ", prior_y)
        print("concat_flag: ", concat_flag)
        print(N, dim_x, dim_y)
        print("sensitivity data")
        if args.type_y == 'discrete':
            print("    train data (Y, S):", Counter(sing['y']), Counter(sing['s']))
            print("    valid data (Y, S):", Counter(singv['y']), Counter(singv['s']))
            print("    test data  (Y, S):", Counter(singt['y']), Counter(singt['s']))
        else:
            print("    train data (Ycont):", stats.describe(sing['ycont'][sing['ycont'] != y_unlab_token]))
            print("    valid data (Ycont):", stats.describe(singv['ycont'][singv['ycont'] != y_unlab_token]))
            print("    test data  (Ycont):", stats.describe(singt['ycont'][singt['ycont'] != y_unlab_token]))

        ### create balanced sampler for training set
        ## balance by cell line ids
        balanced_weights = utl.compute_balanced_weights(utl.cid2numid(train_ddict['cid']), unlabeled_data_ratio=None)
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(balanced_weights, len(balanced_weights))

        ### create balanced sampler for validation set
        ## balance validation sampler by cell line ids
        balanced_weights = utl.compute_balanced_weights(utl.cid2numid(valid_ddict['cid']), unlabeled_data_ratio=None)
        valid_sampler = torch.utils.data.sampler.WeightedRandomSampler(balanced_weights, len(balanced_weights))
        
        ## create train/valid/test data loaders
        dl_kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
        dl_kwargs['batch_size'] = args.batch_size
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   drop_last=(len(train_dataset) >= args.batch_size),
                                                   sampler=train_sampler, **dl_kwargs)
        valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                                   drop_last=(len(valid_dataset) >= args.batch_size),
                                                   sampler=valid_sampler, **dl_kwargs)

        print('Train set length: {} batches; {} examples'.format(len(train_loader), len(train_loader.dataset)))
        print('Valid set length: {} batches; {} examples'.format(len(valid_loader), len(valid_loader.dataset)))
        print('Test set length: {} examples'.format(len(test_dataset)))
            
        model_filename = 'models/DrVAE_SD_{}_{}.pth'.format(args.modelid, selected_drug)
        model = DrVAE(dim_x=dim_x, dim_s=dim_s, dim_y=dim_y, dim_z1=args.dim_z1, dim_z3=args.dim_z3,
                      dim_h_en_z1=args.enc_z1, dim_h_en_z3=args.enc_z3, dim_h_en_z2Fz1=args.enc_z2Fz1,
                      dim_h_de_z1=args.dec_z1, dim_h_de_x=args.dec_x, type_rec='diag_gaussian',
                      dim_h_clf=args.class_y, type_y=args.type_y, prior_y=prior_y,
                      optim_alg='adam', batch_size=args.batch_size, epochs=300,
                      nonlinearity='elu', L=args.L,
                      weight_decay=0.05, dropout_rate=0., input_x_dropout=args.x_dropout, add_noise_var=args.noise_var,
                      learning_rate=0.0005, yloss_rate=args.yloss_rate, clf_z1z2=True, clf_1sig=args.clf_1sig,
                      anneal_yloss_offset=args.anneal_yloss_offset,
                      kl_qz2pz2_rate=1., pertloss_rate=0.05, anneal_perturb_rate_itermax=1, anneal_perturb_rate_offset=0,
                      use_s=args.useS, use_MMD=args.useMMD, kernel_MMD='rbf_fourier', mmd_rate=args.mmd_rate,  ## <= settings for "fairness" (like VFAE)
                      random_seed=args.rseed, log_txt=None  #'DrVAE_SD_{}_{}.txt'.format(args.modelid, selected_drug)
                      )
        if args.cuda:
            model.cuda()

        if not args.test_only:
            try:
                model.fit(train_loader=train_loader, valid_loader=valid_loader, add_noise=args.train_w_noise,
                          verbose=True, early_stop=args.stopearly, model_filename=model_filename)
            except Exception as e:
                print("  >>> TRAINING CRASHED <<<")
                print(e)
    
        print ('Loading trained model from file...')
        model.load_params_from_file(model_filename)
        print ('...done.')

        #### test the model and baselines
        print ('Testing...')
        ## run our model
        model_train_perf, train_perfstr = model.evaluate_performance_on_dataset(train_dataset, return_full_data=True)
        model_valid_perf, valid_perfstr = model.evaluate_performance_on_dataset(valid_dataset, return_full_data=True)
        model_test_perf, test_perfstr = model.evaluate_performance_on_dataset(test_dataset, return_full_data=True)
        model.w2log('Train set performance:\t', train_perfstr)
        model.w2log('Valid set performance:\t', valid_perfstr)
        model.w2log('Test set performance:\t', test_perfstr)

        ## run baselines
        # all_stats['train'][selected_drug] = utl.compile_baseline_stats(args.type_y,
        #                                                         tr=train_ddict, ev=train_ddict,
        #                                                         model_tr=model_train_perf, model_ev=model_train_perf,
        #                                                         svmkernel='rbf', rseed=args.rseed)
        all_stats['valid'][selected_drug] = utl.compile_baseline_stats(args.type_y,
                                                                tr=train_ddict, ev=valid_ddict,
                                                                model_tr=model_train_perf, model_ev=model_valid_perf,
                                                                svmkernel='rbf', rseed=args.rseed)
        all_stats['test'][selected_drug] = utl.compile_baseline_stats(args.type_y,
                                                                tr=train_ddict, ev=test_ddict,
                                                                model_tr=model_train_perf, model_ev=model_test_perf,
                                                                svmkernel='rbf', rseed=args.rseed)
        ## print the stats
        for k in all_stats['test'][selected_drug].keys():
            valid_v = all_stats['valid'][selected_drug][k]
            test_v = all_stats['test'][selected_drug][k]
            k = k.replace('|', '\t')
            if isinstance(valid_v, float):
                print('{}\t{:.4f}\t{:.4f}'.format(k, valid_v, test_v))
            else:
                print(k, valid_v, test_v)

        ## save performance reports as json
        for evalset in ('valid', 'test'):
            ## all stats
            results_fname = 'results/DrVAE_SD_{}_all_{}.json'.format(evalset, args.modelid)
            with open(results_fname, 'wb') as f:
                jd = json.dumps(all_stats[evalset], sort_keys=False, indent=4, separators=(',', ': '))
                f.write(jd.encode())

    return all_stats


if __name__ == '__main__':
    ### Parse command line arguments
    parser = argparse.ArgumentParser(description='Drug response VAE (DrVAE)')
    parser.add_argument('--cuda', action='store_true', default=False, help='enables CUDA training')
    parser.add_argument('--modelid', type=str, required=True, help='model ID')
    parser.add_argument('--datafile', type=str, required=True, help='input data file')
    parser.add_argument('--outdir', type=str, default=None, help='output directory')
    parser.add_argument('--fold', type=int, default=1, help='which data fold to run (1 to 5)')
    parser.add_argument('--test-only', action='store_true', default=False, help='load saved parameters and run tests')
    parser.add_argument('--batch-size', type=int, default=200, help='minibatch size')
    parser.add_argument('--L', type=int, default=1, help='number of samples from Q (default 1)')
    parser.add_argument('--stopearly', action="store_true", dest="stopearly", default=False, help='train with early stopping (default False)')
    parser.add_argument('--yloss-rate', type=float, default=50., help='weight of prediction loss on variable Y')
    parser.add_argument('--rseed', type=int, default=12345, help='random seed')
    parser.add_argument('--use-s', action="store_true", dest="useS", default=False, help='use model with nuisance variable S')
    parser.add_argument('--use-mmd', action="store_true", dest="useMMD", default=False, help='include MMD loss to improve independence of Z from S')
    parser.add_argument('--no-mf', action="store_true", dest="useMF", default=False, help='use molecular features of drugs')
    parser.add_argument('--mmd-rate', type=float, default=1., help='weight of MMD loss')
    parser.add_argument('--data-mode', type=str, default="strictC2C", help='which data to use for training and testing')
    parser.add_argument('--drug', type=str, default="26", help='select one drug to run or set of "26" or "all"')
    parser.add_argument('--fully-supervised', action="store_false", dest="semi_supervised", default=True, help='use only labeled data')
    parser.add_argument('--anneal-yloss-offset', type=int, default=1, help='offset (in num of interations) for annealing "Y" loss')
    parser.add_argument('--pair-data-only', action='store_true', default=False, help='use only perturbation pair data (ignore singletons)')
    parser.add_argument('--train-w-noise', action='store_true', default=False, help='add random Gaussian noise to the gene expression during training')
    parser.add_argument('--noise-var', type=float, default=0.01, help='variance of the Gaussian noise to agument input data')
    parser.add_argument('--x-dropout', type=float, default=0., help='Input dropout rate')
    parser.add_argument('--downlabel-to', type=int, default=None, help='Reduce number of labeled data in training set to given number by masking labeled data as unlabeled')
    # architecture:
    parser.add_argument('--dim-z1', type=int, default=50, help='size of z1 & z2')
    parser.add_argument('--dim-z3', type=int, default=50, help='size of z3')
    parser.add_argument('--enc-z1', type=int, nargs='+', default=[200,200], help='NN size of encoder q(z_k|x_k)')
    parser.add_argument('--enc-z2Fz1', type=int, nargs='+', default=[], help='NN size of encoder p(z2|z1)')
    parser.add_argument('--enc-z3', type=int, nargs='+', default=[200], help='NN size of encoder q(z3|z1,y)')
    parser.add_argument('--dec-z1', type=int, nargs='+', default=[200], help='NN size of decoder p(z1|z3,y)')
    parser.add_argument('--dec-x', type=int, nargs='+', default=[200,200], help='NN size of decoder p(x_k|z_k)')
    parser.add_argument('--class-y', type=int, nargs='+', default=[], help='NN size of classifier p(y|z1,z2)')
    parser.add_argument('--type-y', type=str, default='discrete', help='("discrete" or "cont"): classification vs regression')
    parser.add_argument('--clf-1sig', action='store_true', default=False, help='if the task is binary classification, use 1 sigmoid unit instead of softmax over 2 units')
    parser.add_argument('--clf-dataprior', action='store_true', default=False, help='use training data distribution prior, otherwise uniform')

    args = parser.parse_args()
    print(args)

    # args.cuda = args.cuda and torch.cuda.is_available()    
    assert (args.cuda == False), 'GPU is not supported yet'

    availableDataModes = ['strictC2C'] ## no patient data used; even in unlabeled data
    assert (args.data_mode in availableDataModes), 'Unsupported data mode'

    assert (args.type_y in ['discrete', 'cont']), 'Invalid type-y'
    assert (args.useMF == False), 'use of mol.features not supported'

    main(args)

