In [1]:
import sys
sys.path.append('../../Nina repo')
import scipy.optimize._minimize
import torchvision
from dataloader.dataloader import DISEASE_LABELS_NIH,NIHDataResampleModule, DISEASE_LABELS_CHE, CheXpertDataResampleModule
from prediction.models import ResNet,DenseNet



import os
import torch


import matplotlib.pyplot as plt
plt.style.use("ggplot")

from tqdm import tqdm
from argparse import ArgumentParser
import numpy as np
from misc import *
from gmm_and_mix_class import *



In [2]:
def _str2bool(value):
    """Return True only for the text 'true' (any casing); every other value is False."""
    return str(value).lower() == 'true'


# ---------------------------------------------------------------------------
# Command-line / notebook hyper-parameters.
# Defaults live in the *_default variables so they can be edited in one place;
# help texts use argparse's %(default)s so they cannot drift from the defaults.
# ---------------------------------------------------------------------------
parser = ArgumentParser()
#parser.add_argument('--gpus', default=1)
parser.add_argument('--dev', default=0)

#dataset_default = "chexpert"
#img_dir_default = r"C:\Users\Karlu\Desktop\11\Learning From Noisy Data\archivepreproc_224x224"

dataset_default = "NIH"
img_dir_default = r"C:\Users\Karlu\Desktop\11\Learning From Noisy Data\NIH\preproc_224x224"

disease_label_default = ['Pneumothorax']
female_percent_in_training_default = "50"
npp_default = 1 #Number per patient, could be integer or None (no sampling)
run_dir_default = r"C:\Users\Karlu\Desktop\11\Learning From Noisy Data\runs"

epochs_default = 38
pretrained_default = True
save_model_default = False
num_workers_default = 0
augmentation_default = True
model_scale_default = '50'
batch_size_default = 85
lr_default = 1e-4

# hps that need to chose when training
parser.add_argument('-s','--dataset',default=dataset_default,help='Dataset', choices =['NIH','chexpert'])
parser.add_argument('-d','--disease_label',default=disease_label_default, help='Chosen disease label', type=str, nargs='*')
# NOTE(review): a Jupyter kernel is normally launched with `-f <connection file>`, which this
# parser reads as --female_percent_in_training. The parsed value is not consulted by the visible
# cells (the set is pinned to [50] further down) — confirm before relying on this option.
parser.add_argument('-f', '--female_percent_in_training', default=female_percent_in_training_default,
                    help='Female percentage in training set, should be any of [0, 50, 100]', type=str, nargs='+')
parser.add_argument('-n', '--npp',default=npp_default, help='Number per patient, could be integer or None (no sampling)',type=int)
parser.add_argument('-r', '--random_state', default='0-10', help='random state')
parser.add_argument('-p','--img_dir', default=img_dir_default, help='your img dir path here',type=str)
parser.add_argument('-rd','--run_dir', default=run_dir_default, help='where the runs are saved',type=str)

# hps that set as defaults
# Numeric options carry an explicit type: without one, a value given on the command line
# would be stored as str while the default is a number.
parser.add_argument('--lr', default=lr_default, type=float, help='learning rate, default=%(default)s')
parser.add_argument('--bs', default=batch_size_default, type=int, help='batch size, default=%(default)s')
parser.add_argument('--epochs', default=epochs_default, type=int, help='number of epochs, default=%(default)s')
parser.add_argument('--model', default='resnet', help='model, default=\'ResNet\'')
parser.add_argument('--model_scale', default=model_scale_default, help='model scale, default=50',type=str)
parser.add_argument('--pretrained', default=pretrained_default, help='pretrained or not, True or False, default=True',type=_str2bool)
parser.add_argument('--augmentation', default=augmentation_default, help='augmentation during training or not, True or False, default=True',type=_str2bool)
parser.add_argument('--is_multilabel',default=False,help='training with multilabel or not, default=False, single label training',type=_str2bool)
parser.add_argument('--image_size', default=224,help='image size',type=int)
# NOTE(review): --crop has no type, so a command-line value is passed on as str — confirm what
# the data module expects.
parser.add_argument('--crop',default=None,help='crop the bottom part of the image, the percentage of cropped part, default=%(default)s (no cropping)')
parser.add_argument('--prevalence_setting',default='separate',
                    help='which kind of prevalence are being used when spliting, '
                         'choose from [separate, equal, total]',
                    choices=['separate','equal','total'])
parser.add_argument('--save_model',default=save_model_default,help='save model parameters or not',type=_str2bool)
parser.add_argument('--num_workers', default=num_workers_default, type=int, help='number of workers, default=%(default)s')

# parse_known_args accepts everything parse_args did, but reports (instead of aborting on)
# arguments the parser does not define, e.g. extra flags of the process that launched the kernel.
args, _ignored_argv = parser.parse_known_args()
if _ignored_argv:
    print('Ignoring unrecognised command-line arguments: {}'.format(_ignored_argv))


# other hps
# num_classes: number of dataset labels in the multilabel setting, 1 otherwise.
if args.is_multilabel:
    args.num_classes = len(DISEASE_LABELS_NIH) if args.dataset == 'NIH' else len(DISEASE_LABELS_CHE)
else:
    args.num_classes = 1


# Image directory derived from the base directory, the dataset name and the resolution.
# NOTE(review): the pieces are concatenated without a path separator, so args.img_dir is
# expected to end with one; the notebook default does not and already ends in
# "NIH\preproc_224x224" — confirm the resulting path is the intended one.
if args.image_size == 224:
    args.img_data_dir = args.img_dir+'{}/preproc_224x224/'.format(args.dataset)
elif args.image_size == 1024:
    args.img_data_dir = args.img_dir+'{}/images/'.format(args.dataset)
else:
    # Without this branch args.img_data_dir would stay unset and only fail later, when read.
    raise ValueError('Unsupported image_size {}: expected 224 or 1024.'.format(args.image_size))

# CSV listing the images of the chosen dataset.
if args.dataset == 'NIH':
    args.csv_file_img = '../datafiles/'+'Data_Entry_2017_v2020_clean_split.csv'
elif args.dataset == 'chexpert':
    args.csv_file_img = '../datafiles/'+'chexpert.sample.allrace.csv'
else:
    raise Exception('Not implemented.')

#print('hyper-parameters:')
#print(args)

# --random_state is either a single seed "N" or a range "A-B".
# rs_max is exclusive: the single-seed form maps to (N, N + 1).
_rs_parts = args.random_state.split('-')
if len(_rs_parts) == 1:
    rs_min = int(_rs_parts[0])
    rs_max = rs_min + 1
elif len(_rs_parts) == 2:
    rs_min, rs_max = int(_rs_parts[0]), int(_rs_parts[1])
else:
    raise Exception('Something wrong with args.random_state : {}'.format(args.random_state))

# female_percent_in_training_set = [int(percent) for percent in args.female_percent_in_training.split(" ")]
# The female share is fixed to a single value here; args.female_percent_in_training is not read.
female_percent_in_training_set = [50]
print(f'female_percent_in_training_set:{female_percent_in_training_set}')

# Requested labels; the single entry 'all' stands for every label of the chosen dataset.
disease_label_list = args.disease_label #[''.join(each) for each in args.disease_label]
if list(disease_label_list) == ['all']:
    if args.dataset == 'NIH':
        disease_label_list = DISEASE_LABELS_NIH
    else:
        disease_label_list = DISEASE_LABELS_CHE
print(f'disease_label_list:{disease_label_list}')

female_percent_in_training_set:[50]
disease_label_list:['Pneumothorax']


In [None]:
# Training variants accepted by main(); _training_step_loss implements one branch per entry.
_TRAINING_METHODS = ("plain", "gmm and mixup", "mixup", "gmm")


def _resolve_model_type(model_name):
    """Map the --model option to its model class; raise ValueError for an unknown name."""
    if model_name == 'resnet':
        return ResNet
    if model_name == 'densenet':
        return DenseNet
    raise ValueError("Unknown model '{}': expected 'resnet' or 'densenet'.".format(model_name))


def _build_data_module(args, out_dir, cur_version, female_perc_in_training, random_state, chose_disease_str):
    """Create the resampling data module that matches args.dataset.

    Raises:
        Exception: for an unsupported dataset, or when cropping is requested for chexpert.
    """
    shared_kwargs = dict(
        img_data_dir=args.img_data_dir,
        csv_file_img=args.csv_file_img,
        image_size=args.image_size,
        pseudo_rgb=False,
        batch_size=args.bs,  # 90 is limit i.e. 10.9gb vram
        num_workers=args.num_workers,
        augmentation=args.augmentation,
        outdir=out_dir,
        version_no=cur_version,
        female_perc_in_training=female_perc_in_training,
        chose_disease=chose_disease_str,
        random_state=random_state,
        num_classes=args.num_classes,
        num_per_patient=args.npp,
        prevalence_setting=args.prevalence_setting,
    )
    if args.dataset == 'NIH':
        return NIHDataResampleModule(crop=args.crop, **shared_kwargs)
    if args.dataset == 'chexpert':
        if args.crop is not None:
            raise Exception('Crop experiment not implemented for chexpert.')
        return CheXpertDataResampleModule(**shared_kwargs)
    raise Exception('not implemented')


def _female_masks(data, dataset, device):
    """Return (is_female_val, is_female_train).

    is_female_val is a bool tensor on `device` built from data.df_valid;
    is_female_train is the plain numpy comparison result built from data.df_train.
    """
    if dataset == "NIH":
        column, female_value = "Patient Gender", "F"
    else:
        column, female_value = "sex", "Female"
    is_female_val = torch.from_numpy(data.df_valid[column].values == female_value).to(torch.bool).to(device)
    is_female_train = data.df_train[column].values == female_value
    return is_female_val, is_female_train


def _training_step_loss(method, model, loss_fn, x, y, beta, device, epoch):
    """Run the forward pass(es) of one training batch.

    Returns:
        (loss, logits_pure): the loss to back-propagate and the logits of the un-mixed batch
        (the latter are used by the caller for the running training statistics).

    Raises:
        AssertionError: if `method` is not one of _TRAINING_METHODS.
    """
    if method == "plain":
        logits_pure = model.forward(x)
        loss = loss_fn.bce_with_logits_plain(logits=logits_pure, y=y, reduce=True)

    elif method == "gmm and mixup":
        x_mix, y_mix, _delta, perm = mixup(x, y, beta, device)
        logit_mixup = model.forward(x_mix)

        with torch.no_grad():
            logits_pure = model.forward(x)
            # NOTE(review): loss_pure is computed from logit_mixup although logits_pure was just
            # computed and the "gmm" branch uses logits_pure — confirm this is intended.
            loss_pure = loss_fn.bce_with_logits_plain(logit_mixup, y)

        loss1 = loss_fn.bce_mix_up_gmm(logit_mixup, loss_pure, y_mix, perm=None, epoch=epoch)
        loss2 = loss_fn.bce_mix_up_gmm(logit_mixup, loss_pure, y_mix, perm=perm, epoch=epoch)
        loss = 0.5*(loss1 + loss2)

    elif method == "mixup":
        x_mix, y_mix, _delta, perm = mixup(x, y, beta, device)
        logit_mixup = model.forward(x_mix)
        loss = loss_fn.bce_with_logits_plain(logit_mixup, y_mix, s=loss_fn.class_imb_scaling, reduce=True)
        # Second, gradient-free pass on the un-mixed batch, only to obtain logits_pure.
        with torch.no_grad():
            logits_pure = model.forward(x)

    elif method == "gmm":
        logits_pure = model.forward(x)
        with torch.no_grad():
            loss_pure = loss_fn.bce_with_logits_plain(logits_pure, y)
        loss = loss_fn.bce_mix_up_gmm(logits_pure, loss_pure, y, perm=None, epoch=epoch)

    else:
        # The original built this exception without raising it.
        raise AssertionError("pick a proper model!")

    return loss, logits_pure


def _group_metrics(model, predictions, labels, mask):
    """Return (accuracy, AUROC) of model.accu_func / model.auroc_func on the rows selected by mask."""
    group_predictions = predictions[mask]
    group_labels = labels[mask]
    accu = model.accu_func(group_predictions, group_labels).item()
    auroc = model.auroc_func(group_predictions, group_labels).item()
    return accu, auroc


def _save_progress_plot(path, title, losses, accs, batch_num, accu_male_ALL, accu_female_ALL,
                        auroc_male_ALL, auroc_female_ALL, lrs, batch_size):
    """Draw the training/validation curves, save them to <path>/plot.png and return the figure."""
    fig = plt.figure(figsize=[7, 7])
    plt.plot(ewma(torch.stack(losses).cpu().numpy().astype(np.float64), 50), label="TRAIN loss (smooth)")
    plt.plot(ewma(torch.stack(accs).cpu().numpy().astype(np.float64), 50), label ="TRAIN accu (smooth)")
    plt.plot(batch_num, accu_male_ALL, marker='o', ls="--", label =f"val acc male: {accu_male_ALL[-1]:.3f}")
    plt.plot(batch_num, accu_female_ALL, marker='o', label =f"val acc female: {accu_female_ALL[-1]:.3f}")

    plt.plot(batch_num, auroc_male_ALL, marker='*', ls="--", label =f"val AUROC male: {auroc_male_ALL[-1]:.3f}")
    plt.plot(batch_num, auroc_female_ALL, marker='*', label =f"val AUROC female: {auroc_female_ALL[-1]:.3f}")
    plt.plot(batch_num, np.asarray(lrs) / np.asarray(lrs).max(), label= "LR (normalized)",
             marker='o', alpha=0.5)
    plt.title(title)
    # The x axis counts batches; tick labels are rescaled to training samples.
    batch_num_array = np.linspace(0, batch_num[-1], 5, dtype=int)
    plt.xticks(batch_num_array, batch_num_array*batch_size)
    plt.xlabel("training samples")
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(path, "plot.png"), dpi=150)
    return fig


def _collect_holdout_outputs(model, holdout_loader, device):
    """Run the model in eval mode over the hold-out loader.

    Returns:
        (logits, labels, ids) as CPU tensors concatenated over all batches.
    """
    forward_logits, ys, ids = [], [], []
    with torch.no_grad():
        model.eval()
        for xy in holdout_loader:
            forward_logits.append(model.forward(xy["image"].to(device)))
            ys.append(xy["label"])
            ids.append(xy["id"])
    return (torch.cat(forward_logits, dim=0).cpu(),
            torch.cat(ys, dim=0).cpu(),
            torch.cat(ids, dim=0).cpu())


def main(args, female_perc_in_training=None, random_state=None, chose_disease_str=None,
         method="mixup", k_folds=5):
    """Train and evaluate one model per cross-validation fold for a single run configuration.

    For every fold of a shuffled k-fold split of data.train_set a fresh model is trained for
    args.epochs epochs. After each epoch the model is evaluated on the validation loader
    separately for female and male samples, and a progress plot is written. After the last
    epoch the fold's hold-out logits, labels, ids, sex flags and the model are saved under
    <cwd>/kfold_mixup_exp/fold<k>/.

    Args:
        args: parsed hyper-parameters (namespace produced by the argument parser).
        female_perc_in_training: forwarded to the data module and used in the run name.
        random_state: forwarded to the data module, used in the run name and as the seed
            of the fold split.
        chose_disease_str: disease label to predict; also the df_train column whose mean is
            used as class prevalence (falls back to args.disease_label[0] when None).
        method: training variant, one of "plain", "gmm and mixup", "mixup", "gmm".
        k_folds: number of cross-validation folds.

    Raises:
        ValueError: for an unknown `method` or args.model.
        Exception: for an unsupported dataset / crop combination.
    """
    import json
    from sklearn.model_selection import KFold

    # Fail fast, before any data is loaded.
    if method not in _TRAINING_METHODS:
        raise ValueError("Unknown method '{}': expected one of {}.".format(method, list(_TRAINING_METHODS)))
    model_type = _resolve_model_type(args.model)

    # NOTE(review): no global numpy / torch seed is set here (pl.seed_everything was disabled
    # in the original), so weight init, mixup and sampling order still vary between runs.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:" + str(args.dev) if use_cuda else "cpu")
    print('DEVICE:{}'.format(device))

    # get run_config
    run_config = f'{args.dataset}-{chose_disease_str}' # dataset and the predicted label
    run_config+= f'-fp{female_perc_in_training}-npp{args.npp}-rs{random_state}' #f_per, npp and rs

    print('------------------------------------------\n'*3)
    print('run_config: {}'.format(run_config))

    # Create output directory. os.path.join keeps the run inside run_dir whether or not
    # run_dir ends with a separator (plain concatenation produced e.g. "...\runsNIH-...").
    out_dir = os.path.join(args.run_dir, run_config)
    os.makedirs(out_dir, exist_ok=True)

    cur_version = get_cur_version(out_dir)

    data = _build_data_module(args, out_dir, cur_version, female_perc_in_training,
                              random_state, chose_disease_str)

    # Seeding the shuffle ties the fold assignment to the run's random_state.
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=random_state)

    # Column of df_train holding the label of the disease trained in this run.
    label_column = chose_disease_str if chose_disease_str is not None else args.disease_label[0]

    # K-fold Cross Validation model evaluation
    for fold, (train_ids, test_ids) in enumerate(kfold.split(data.train_set)):
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)

        # NOTE(review): this directory does not depend on the run configuration, so every
        # call of main() overwrites the fold outputs of the previous call.
        path = os.path.join(os.getcwd(), "kfold_mixup_exp", f"fold{fold}")
        os.makedirs(path, exist_ok=True)

        data_loader = data.train_dataloader(subsampler=train_subsampler)
        # NOTE(review): the hold-out loader is also built by train_dataloader — confirm
        # whether training-time augmentation applies to it.
        data_loader_holdout = data.train_dataloader(subsampler=test_subsampler)

        # model
        model = model_type(num_classes=args.num_classes,lr=args.lr,pretrained=args.pretrained,model_scale=args.model_scale,
                           loss_func_type = 'BCE')

        batch_size = args.bs
        model.to(device)
        epochs = args.epochs
        lr = args.lr
        # Mean of the label column; handed to class_imb_scaling_step after every epoch.
        class_imbalance_train = data.df_train[label_column].values.mean()
        class_imbalance_init = 4.5
        loss_fn = bce_with_logits_mannual(class_imb=class_imbalance_init, n=len(data.df_train))
        optimizer = torch.optim.AdamW(params=model.model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=3)

        # Histories. The validation curves and batch_num get one entry per epoch plus the
        # initial placeholder, so they always have equal length for plotting.
        losses = []
        accs = []
        lrs = [lr]
        accu_female_ALL = [0.5]
        accu_male_ALL = [0.5]
        auroc_male_ALL = [0.5]
        auroc_female_ALL = [0.5]
        batch_num = [0]

        is_female_val, is_female_train = _female_masks(data, args.dataset, device)

        beta = torch.distributions.beta.Beta(20.0, 20.0)
        print(f"using {method} approch/method")

        final_fig = None
        for epoch in range(epochs):
            model.train()
            prog_bar = tqdm(data_loader, unit="batches")
            prog_bar.set_description(f"train epoch {epoch+1}/{epochs}")
            prog_bar.set_postfix({"male val acc":accu_male_ALL[-1], "female val acc": accu_female_ALL[-1] })

            losses_non_reduced = []
            ids = []
            predicted_class_train = []
            ys = []
            for xy in prog_bar:
                x = xy["image"].to(device)
                y = xy["label"].to(device)
                sample_ids = xy["id"].to(device)
                ys.append(y)
                ids.append(sample_ids)

                loss, logits_pure = _training_step_loss(method, model, loss_fn, x, y, beta, device, epoch)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Running training statistics, taken from the un-mixed logits.
                logits_pure = logits_pure.detach()
                losses_non_reduced.append(loss_fn.bce_with_logits_plain(logits_pure, y))
                prediction = torch.sigmoid(logits_pure)
                predicted_class_train.append(prediction > 0.5)
                correct_prediction = ((prediction > 0.5) == y).to(torch.float)
                accs.append(torch.mean(correct_prediction))

                losses.append(loss.detach())

            predicted_class_train = torch.concat(predicted_class_train, dim=0)
            avg_predicted_class_train = predicted_class_train.to(torch.float).mean().item()
            loss_fn.class_imb_scaling_step(0.25, avg_predicted_class_train, class_imbalance_train)

            losses_non_reduced = torch.concat(losses_non_reduced)
            ids = torch.concat(ids)
            if "gmm" in method:
                loss_fn.fit_and_plot_gmm(epoch_num=epoch,
                                         losses_non_reduced=losses_non_reduced,
                                         y=torch.concat(ys).cpu().numpy().squeeze(),
                                         preds=predicted_class_train.to(torch.int).cpu().numpy().squeeze(),
                                         is_female=is_female_train,
                                         ids = ids,
                                         plot_hist=True)

            # Record the learning rate used in this epoch before stepping the schedule.
            lrs.append(scheduler.get_last_lr()[0])
            scheduler.step()

            with torch.no_grad():
                model.eval()
                val_predictions = []
                val_labels = []

                for xy in data.val_dataloader():
                    x = xy["image"].to(device)
                    y = xy["label"].to(device)
                    val_predictions.append(torch.sigmoid(model.forward(x)))
                    val_labels.append(y)

                val_predictions = torch.concat(val_predictions, dim=0)
                val_labels = torch.concat(val_labels, dim=0)

                print(f"E{epoch+1}: "
                      f" avg val pred class:"
                      f" {torch.mean((val_predictions > 0.5).to(torch.float)).item():.4f},"
                      f" avg train pred class:"
                      f" {avg_predicted_class_train:.4f},"
                      f" loss_fn.class_imb_scaling: {loss_fn.class_imb_scaling}"
                      f" curr lr: {scheduler.get_last_lr()[0]:.7f}")

                accu_female, auroc_female = _group_metrics(model, val_predictions, val_labels, is_female_val)
                accu_female_ALL.append(accu_female)
                auroc_female_ALL.append(auroc_female)

                accu_male, auroc_male = _group_metrics(model, val_predictions, val_labels, ~is_female_val)
                accu_male_ALL.append(accu_male)
                auroc_male_ALL.append(auroc_male)

                batch_num.append((epoch+1)*len(data_loader))

            title = method + " " + args.model + args.model_scale + " " + args.dataset + "  epoch: " + str(epoch+1)
            fig = _save_progress_plot(path, title, losses, accs, batch_num,
                                      accu_male_ALL, accu_female_ALL,
                                      auroc_male_ALL, auroc_female_ALL, lrs, batch_size)
            # Every plot is cumulative and overwrites plot.png, so only the last epoch's
            # figure is kept open for display; the others are closed to free their memory.
            if epoch + 1 < epochs:
                plt.close(fig)
            else:
                final_fig = fig

        print(f"fold {fold} has been completed saving data")
        holdout_logits, holdout_ys, holdout_ids = _collect_holdout_outputs(model, data_loader_holdout, device)

        # Sex flag of every hold-out sample, looked up by sample id in the training frame.
        is_fem = torch.from_numpy(is_female_train[holdout_ids.numpy()])
        data_dict = {
            "forward_logits": holdout_logits.tolist(),
            "ys": holdout_ys.tolist(),
            "ids": holdout_ids.tolist(),
            "is_fem": is_fem.tolist()
        }

        plt.show()
        if final_fig is not None:
            plt.close(final_fig)

        torch.save(holdout_logits, os.path.join(path, "logits.pt"))
        torch.save(holdout_ys, os.path.join(path, "ys.pt"))
        torch.save(holdout_ids, os.path.join(path, "ids.pt"))
        torch.save(is_fem, os.path.join(path, "is_fem.pt"))

        with open(os.path.join(path, "meta.json"), 'w') as meta_file:
            json.dump(data_dict, meta_file)
        torch.save(model, os.path.join(path, "model.pt"))
        


#shutdown -s -f -t (14000)
# Enumerate every (disease, female percentage, random state) combination up front,
# in the same nesting order as three nested loops, then launch one run per entry.
run_grid = [
    (disease, female_percent, seed)
    for disease in disease_label_list
    for female_percent in female_percent_in_training_set
    for seed in np.arange(rs_min, rs_max)
]
for d, female_perc_in_training, i in run_grid:
    main(args,
         female_perc_in_training=female_perc_in_training,
         random_state=i,
         chose_disease_str=d)

DEVICE:cuda:0
------------------------------------------
------------------------------------------
------------------------------------------

run_config: NIH-Pneumothorax-fp50-npp1-rs0


  df_per_patient = df.groupby(['Patient ID', 'Patient Gender']).mean()
  df_per_patient = df.groupby([self.col_name_patient_id, self.col_name_gender]).mean()


file already exists and is loaded0
['Pneumothorax']


Loading Data: 100%|██████████| 8458/8458 [00:00<00:00, 45178.62it/s]


['Pneumothorax']


Loading Data: 100%|██████████| 1409/1409 [00:00<00:00, 45137.93it/s]


['Pneumothorax']


Loading Data: 100%|██████████| 8459/8459 [00:00<00:00, 49218.18it/s]


#train:  8458
#val:    1409
#test:   8459
using mixup approch/method


train epoch 1/38: 100%|██████████| 80/80 [00:55<00:00,  1.44batches/s, male val acc=0.5, female val acc=0.5]


E1:  avg val pred class: 0.0000, avg train pred class: 0.0034, loss_fn.class_imb_scaling: 4.75 curr lr: 0.0001000


train epoch 2/38: 100%|██████████| 80/80 [00:53<00:00,  1.50batches/s, male val acc=0.959, female val acc=0.945]


E2:  avg val pred class: 0.0014, avg train pred class: 0.0001, loss_fn.class_imb_scaling: 5.0 curr lr: 0.0000750


train epoch 3/38: 100%|██████████| 80/80 [00:53<00:00,  1.48batches/s, male val acc=0.96, female val acc=0.946]


E3:  avg val pred class: 0.0035, avg train pred class: 0.0059, loss_fn.class_imb_scaling: 5.25 curr lr: 0.0000250


train epoch 4/38: 100%|██████████| 80/80 [00:51<00:00,  1.56batches/s, male val acc=0.962, female val acc=0.946]


E4:  avg val pred class: 0.0163, avg train pred class: 0.0126, loss_fn.class_imb_scaling: 5.5 curr lr: 0.0001000


train epoch 5/38: 100%|██████████| 80/80 [00:52<00:00,  1.53batches/s, male val acc=0.957, female val acc=0.939]


E5:  avg val pred class: 0.0717, avg train pred class: 0.0494, loss_fn.class_imb_scaling: 5.25 curr lr: 0.0000970


train epoch 6/38: 100%|██████████| 80/80 [00:52<00:00,  1.52batches/s, male val acc=0.916, female val acc=0.901]


E6:  avg val pred class: 0.0774, avg train pred class: 0.0522, loss_fn.class_imb_scaling: 5.0 curr lr: 0.0000883


train epoch 7/38: 100%|██████████| 80/80 [00:53<00:00,  1.50batches/s, male val acc=0.926, female val acc=0.891]


E7:  avg val pred class: 0.0781, avg train pred class: 0.0575, loss_fn.class_imb_scaling: 4.75 curr lr: 0.0000750


train epoch 8/38:  18%|█▊        | 14/80 [00:09<00:43,  1.52batches/s, male val acc=0.918, female val acc=0.904]