In [1]:
import torch
from misc import normalize_loss
import json
import os
import torch.nn.functional as F
from models import AE, MultiClassClassifier, BinaryClassifier, VAE
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier

from semisupervised_train_feb25_2 import train_semisupervised, CompoundDataset#, evaluate_head_auc, evaluate_adversary_auc
from misc import get_dropbox_dir

from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score

In [2]:
def _masked_auc(outputs, targets, model):
    """ROC AUC of ``model``'s probabilities against ``targets``, ignoring NaN targets.

    Parameters
    ----------
    outputs : torch.Tensor
        Raw model outputs (logits); converted with ``model.logits_to_proba``.
    targets : torch.Tensor
        Labels aligned with ``outputs``; NaN marks an unlabeled sample.
    model : object
        Classifier exposing ``logits_to_proba`` and ``num_classes``.

    Returns
    -------
    float
        * 0 when ``outputs`` is empty (legacy sentinel kept for callers, not a real AUC);
        * NaN when fewer than two distinct classes remain after dropping NaN targets,
          because ROC AUC is undefined there (sklearn would raise ValueError);
        * otherwise the ROC AUC (weighted one-vs-one when ``model.num_classes > 2``).
    """
    if len(outputs) == 0:
        return 0
    # .cpu() so this also works when tensors live on a GPU; .numpy() only accepts CPU tensors.
    probs = model.logits_to_proba(outputs.detach()).cpu().numpy()
    targets_np = targets.detach().cpu().numpy()
    labeled = ~np.isnan(targets_np)
    y_true = targets_np[labeled]
    # e.g. a small test subset where every labeled sample has the same class.
    if np.unique(y_true).size < 2:
        return float('nan')
    y_score = probs[labeled]
    if model.num_classes > 2:
        return roc_auc_score(y_true, y_score, average='weighted', multi_class='ovo')
    return roc_auc_score(y_true, y_score, average='weighted')


def evaluate_head_auc(head_ouputs, head_targets, adversary_outputs, adversary_targets, head, adversary):
    """End-of-training metric: ROC AUC of the head on its (non-NaN) targets.

    The adversary arguments are unused; they are accepted so every entry in
    ``end_state_eval_funcs`` shares one call signature. See ``_masked_auc`` for
    the return-value conventions (0 for empty input, NaN when undefined).
    """
    return _masked_auc(head_ouputs, head_targets, head)


def evaluate_adversary_auc(head_ouputs, head_targets, adversary_outputs, adversary_targets, head, adversary):
    """End-of-training metric: ROC AUC of the adversary on its (non-NaN) targets.

    The head arguments are unused; they are accepted so every entry in
    ``end_state_eval_funcs`` shares one call signature. See ``_masked_auc`` for
    the return-value conventions (0 for empty input, NaN when undefined).
    """
    return _masked_auc(adversary_outputs, adversary_targets, adversary)

# Metric name -> evaluator, handed to train_semisupervised. The metric names
# reappear as keys under output.json['end_state_eval'][<split>] (see the summary cell).
end_state_eval_funcs = dict(
    head_auc=evaluate_head_auc,
    adversary_auc=evaluate_adversary_auc,
)

In [8]:
# ---- Paths -----------------------------------------------------------------
# Alternative study: f'{dropbox_dir}/development_CohortCombination/benefit_study_feb20'
dropbox_dir = get_dropbox_dir()
data_dir = f'{dropbox_dir}/development_CohortCombination/mskcc_prediction_study_feb19'
output_dir = os.path.join(data_dir,'finetune_adversarial_network_Feb25')
os.makedirs(output_dir, exist_ok=True)

# Pretrained encoder checkpoint directory ('<pretrain_model_id>_0').
pretrain_model_id = 'feb26_AE_4_with_head-0.2_and_adversary-5'
pretrain_load_dir = f'{dropbox_dir}/development_CohortCombination/reconstruction_study_feb16/adversarial_network_Feb25/{pretrain_model_id}_0'

# ---- Label columns -----------------------------------------------------------
# Head target: MSKCC risk group. INTERMEDIATE maps to NaN, i.e. treated as unlabeled.
# (Benefit-study alternative: head_col = 'Benefit',
#  head_col_mapper = {'CB': 1, 'NCB': 0, 'ICB': np.nan})
head_col = 'MSKCC'
head_col_mapper = {'FAVORABLE': 1, 'POOR': 0, 'INTERMEDIATE': np.nan}

# Adversary target: sex.
adv_col = 'Sex'
adv_col_mapper = {'M': 1, 'F': 0}

# ---- Architecture ------------------------------------------------------------
# These must match the pretrained encoder: its state dict is loaded into an AE
# built from them, and a size mismatch makes load_state_dict fail.
latent_dim = 26
hidden_size = 14
activation = 'tanh'
use_batch_norm = False
dropout_rate = 0

# ---- Loss weights --------------------------------------------------------------
# Only the head term is non-zero. Presumably this means head-only fine-tuning;
# confirm how train_semisupervised uses each weight.
encoder_weight = 0
head_weight = 1
adversary_weight = 0
beta = 0

# ---- Optimisation ----------------------------------------------------------------
batch_size = 64
num_epochs = 200
learning_rate = 2.7224595522366062e-05
noise_factor = 0.05
val_frac = 0  # 0 disables the validation split

# ---- Run identity --------------------------------------------------------------
finetune_id = 2
model_id = pretrain_model_id + f'_finetune_{finetune_id}'

# Fine-tune the pretrained encoder together with a freshly initialised head and
# adversary on each of the 30 train/test subsets; each subset gets its own
# directory under output_dir.
#
# TODO: allow adversary to have extra training
# TODO: retrain adversary on fixed latent space at the end of training
# TODO: allow for multiple heads and multiple adversaries
# TODO define loss-weights in the model object
# TODO ordinal category loss and head
# TODO: Update the TGEM model to be a encoder-style model class "TGEM_Encoder"
# TODO: add a "task" (binary, regression, multi, etc) to the model class
# TODO: add a "goal" (reduce, adversarial, primary) (method: define_goal)

base_seed = 0          # subset n is seeded with base_seed + n
skip_existing = False  # True: skip subsets whose output.json exists (resume an interrupted sweep)


def load_subset_split(subset_num, split):
    """Load X_/y_{split}_{subset_num}.csv from data_dir and encode the label columns.

    The head and adversary columns are passed through their mappers; any value
    not in a mapper becomes NaN (pandas ``Series.map`` semantics).
    Returns ``(X, y)`` DataFrames indexed by the CSVs' first column.
    """
    X = pd.read_csv(os.path.join(data_dir, f'X_{split}_{subset_num}.csv'), index_col=0)
    y = pd.read_csv(os.path.join(data_dir, f'y_{split}_{subset_num}.csv'), index_col=0)
    y[head_col] = y[head_col].map(head_col_mapper)
    y[adv_col] = y[adv_col].map(adv_col_mapper)
    return X, y


# The pretrained weights are the same for every subset, so read them once.
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only host.
# load_state_dict copies the values into the model, so reusing this dict is safe.
pretrained_encoder_state = torch.load(os.path.join(pretrain_load_dir, 'encoder.pth'),
                                      map_location='cpu')

for subset_num in range(0, 30):
    model_subdir = f'{model_id}_{subset_num}'
    model_dir = os.path.join(output_dir, model_subdir)
    # output.json in model_dir is what the summary cell reads back.
    if skip_existing and os.path.exists(os.path.join(model_dir, 'output.json')):
        continue
    os.makedirs(model_dir, exist_ok=True)

    # Seed torch and numpy per subset so random_split, shuffling, weight init and
    # (presumably) training noise are reproducible.
    torch.manual_seed(base_seed + subset_num)
    np.random.seed(base_seed + subset_num)

    X_train, y_train = load_subset_split(subset_num, 'train')
    X_test, y_test = load_subset_split(subset_num, 'test')

    train_dataset = CompoundDataset(X_train, y_train[head_col], y_train[adv_col])
    test_dataset = CompoundDataset(X_test, y_test[head_col], y_test[adv_col])

    # Class counts and weights come from the full training set, before any validation split.
    num_classes_head = train_dataset.get_num_classes_head()
    num_classes_adv = train_dataset.get_num_classes_adv()
    weights_head = train_dataset.get_class_weights_head()
    weights_adv = train_dataset.get_class_weights_adv()

    dataloader_dct = {}
    if val_frac > 0:
        train_size = int((1 - val_frac) * len(train_dataset))
        val_size = len(train_dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
        dataloader_dct['val'] = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    dataloader_dct['train'] = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    dataloader_dct['test'] = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Encoder: same architecture as the pretrained model, then load its weights.
    input_dim = X_train.shape[1]
    encoder = AE(input_size=input_dim, latent_size=latent_dim, hidden_size=hidden_size, num_hidden_layers=1,
                 dropout_rate=dropout_rate, use_batch_norm=use_batch_norm, act_on_latent_layer=True,
                 activation=activation)
    encoder.load_state_dict(pretrained_encoder_state)

    if num_classes_head > 2:
        head = MultiClassClassifier(latent_dim, hidden_size=4, num_hidden_layers=0, num_classes=num_classes_head)
    else:
        head = BinaryClassifier(latent_dim, hidden_size=4, num_hidden_layers=0)

    if num_classes_adv > 2:
        adversary = MultiClassClassifier(latent_dim, hidden_size=4, num_hidden_layers=1, num_classes=num_classes_adv)
    else:
        adversary = BinaryClassifier(latent_dim, hidden_size=4, num_hidden_layers=1)

    head.define_loss(class_weight=weights_head)
    adversary.define_loss(class_weight=weights_adv)

    encoder, head, adversary, output_data = train_semisupervised(dataloader_dct, encoder, head, adversary,
                                                                 save_dir=model_dir,
                                                                 num_epochs=num_epochs,
                                                                 learning_rate=learning_rate,
                                                                 noise_factor=noise_factor,
                                                                 encoder_weight=encoder_weight,
                                                                 head_weight=head_weight,
                                                                 beta=beta,
                                                                 adversary_weight=adversary_weight,
                                                                 end_state_eval_funcs=end_state_eval_funcs)

Epoch [1/200], train Loss: 0.9982
train head_auc 0.664381454625357
Epoch [11/200], train Loss: 0.9239
train head_auc 0.8888156449132059
Epoch [21/200], train Loss: 0.8620
train head_auc 0.9158866183256427
Epoch [31/200], train Loss: 0.8423
train head_auc 0.9341683146561196
Epoch [41/200], train Loss: 0.8020
train head_auc 0.9483629971434849
Epoch [51/200], train Loss: 0.7795
train head_auc 0.9601406284333114
Epoch [61/200], train Loss: 0.7433
train head_auc 0.9696330476818282
Epoch [71/200], train Loss: 0.7093
train head_auc 0.9758294880246099
Epoch [81/200], train Loss: 0.6811
train head_auc 0.9813667325862447
Epoch [91/200], train Loss: 0.6445
train head_auc 0.9852340145023072
Epoch [101/200], train Loss: 0.6286
train head_auc 0.987695012085256
Epoch [111/200], train Loss: 0.5933
train head_auc 0.9897165458141067
Epoch [121/200], train Loss: 0.5524
train head_auc 0.9912546693034499
Epoch [131/200], train Loss: 0.5146
train head_auc 0.9924851680949243
Epoch [141/200], train Loss: 0.49

In [10]:
# Collect the end-of-training test-set head AUC for all 30 subsets and report
# the mean and the population standard deviation (np.std default, ddof=0).
vals = []
for subset_num in range(0, 30):
    model_subdir = f'{model_id}_{subset_num}'
    model_dir = os.path.join(output_dir, model_subdir)
    # Context manager closes each file right away instead of leaving it to the GC.
    with open(os.path.join(model_dir, 'output.json')) as f:
        output_data = json.load(f)
    vals.append(output_data['end_state_eval']['test']['head_auc'])

print(np.mean(vals), np.std(vals))

0.9018397814966441 0.04032743132443603
