In [1]:
import torch
import pickle
import numpy as np
from PIL import Image
#from datasets.BaseDataset import BaseDataset

In [2]:
import torch
import numpy as np


class BaseDataset(torch.utils.data.Dataset):
    """Common base for datasets described by a metadata dataframe.

    The class stores the metadata and bookkeeping shared by concrete datasets
    and offers helpers to derive sensitive-attribute codes, group/label
    proportions and per-sample resampling weights.

    Subclasses are expected to set ``self.A`` (binary sensitive attribute) and
    ``self.Y`` (target labels) to objects exposing ``tolist()`` and to
    implement ``__getitem__``.

    Args:
        dataframe: metadata table; only ``shape[0]`` and column access are used here.
        path_to_images: stored as-is for subclasses.
        sens_name: name of the sensitive attribute ('Sex', 'Age', 'Race',
            'skin_type' or 'Insurance').
        sens_classes: number of classes of the sensitive attribute.
        transform: stored as-is for subclasses.
    """

    # Dataframe column that holds the group code for a given
    # (sensitive attribute, number of classes) combination.
    _GROUP_COLUMNS = {
        'Age': {2: 'Age_binary', 5: 'Age_multi', 4: 'Age_multi4'},
        'skin_type': {2: 'skin_binary', 6: 'skin_type'},
        'Insurance': {2: 'Insurance_binary', 5: 'Insurance'},
    }

    def __init__(self, dataframe, path_to_images, sens_name, sens_classes, transform):
        super(BaseDataset, self).__init__()

        self.dataframe = dataframe
        self.dataset_size = self.dataframe.shape[0]
        self.transform = transform
        self.path_to_images = path_to_images
        self.sens_name = sens_name
        self.sens_classes = sens_classes

        self.A = None
        self.Y = None
        self.AY_proportion = None

    def get_AY_proportions(self):
        """Return the joint proportions of binary A and Y as ``[[A0Y0, A0Y1], [A1Y0, A1Y1]]``.

        The result is cached in ``self.AY_proportion``.

        Raises:
            AssertionError: if some (A, Y) pair is not one of the four binary
                combinations.
        """
        if self.AY_proportion:
            return self.AY_proportion

        A = self.A.tolist()
        Y = self.Y.tolist()
        ttl = len(A)

        # Single pass over the pairs; float codes such as (0.0, 1.0) hash and
        # compare equal to the integer keys below.
        counts = {(a, y): 0 for a in (0, 1) for y in (0, 1)}
        for ay in zip(A, Y):
            if ay in counts:
                counts[ay] += 1

        assert sum(counts.values()) == ttl, "Problem computing train set AY proportion."

        self.AY_proportion = [
            [counts[(0, 0)] / ttl, counts[(0, 1)] / ttl],
            [counts[(1, 0)] / ttl, counts[(1, 1)] / ttl],
        ]
        return self.AY_proportion

    def get_A_proportions(self):
        """Return ``[P(A=0), P(A=1)]`` marginalised from the joint proportions."""
        AY = self.get_AY_proportions()
        ret = [AY[0][0] + AY[0][1], AY[1][0] + AY[1][1]]
        np.testing.assert_almost_equal(np.sum(ret), 1.0)
        return ret

    def get_Y_proportions(self):
        """Return ``[P(Y=0), P(Y=1)]`` marginalised from the joint proportions."""
        AY = self.get_AY_proportions()
        ret = [AY[0][0] + AY[1][0], AY[0][1] + AY[1][1]]
        np.testing.assert_almost_equal(np.sum(ret), 1.0)
        return ret

    def set_A(self, sens_name):
        """Build the binary sensitive-attribute array for ``sens_name``.

        Despite the name, nothing is stored on the instance: the float array
        of 0./1. codes is returned and the caller assigns it.

        Raises:
            ValueError: if ``sens_name`` is not a supported attribute.
        """
        if sens_name == 'Sex':
            A = np.asarray(self.dataframe['Sex'].values != 'M').astype('float')
        elif sens_name == 'Age':
            A = np.asarray(self.dataframe['Age_binary'].values.astype('int') == 1).astype('float')
        elif sens_name == 'Race':
            # White is coded 0, matching group_counts() and get_sensitive().
            A = np.asarray(self.dataframe['Race'].values != 'White').astype('float')
        elif sens_name == 'skin_type':
            A = np.asarray(self.dataframe['skin_binary'].values != 0).astype('float')
        elif sens_name == 'Insurance':
            A = np.asarray(self.dataframe['Insurance_binary'].values != 0).astype('float')
        else:
            raise ValueError("Does not contain {}".format(sens_name))
        return A

    def get_weights(self, resample_which):
        """Return one sampling weight per sample: the inverse size of its group.

        ``resample_which`` is forwarded to :meth:`group_counts`.
        """
        sens_attr, group_num = self.group_counts(resample_which)
        # An empty group has no member that could index its weight, so any
        # finite placeholder avoids the division by zero.
        group_weights = [1 / count if count > 0 else 0.0 for count in group_num.tolist()]
        sample_weights = [group_weights[int(i)] for i in sens_attr]
        return sample_weights

    def _sensitive_group_array(self):
        """Return the per-sample group codes of ``self.sens_name`` as a list."""
        if self.sens_name == 'Sex':
            mapping = {'M': 0, 'F': 1}
            return [*map(mapping.get, self.dataframe['Sex'].values)]
        if self.sens_name == 'Race':
            mapping = {'White': 0, 'non-White': 1}
            return [*map(mapping.get, self.dataframe['Race'].values)]
        if self.sens_name not in self._GROUP_COLUMNS:
            raise ValueError("sensitive attribute does not defined in BaseDataset")

        column = self._GROUP_COLUMNS[self.sens_name].get(self.sens_classes)
        if column is None:
            raise ValueError(
                "{} does not support sens_classes={}".format(self.sens_name, self.sens_classes)
            )
        groups = self.dataframe[column].values
        if column == 'Age_multi4':
            groups = groups.astype('int')
        return groups.tolist()

    def group_counts(self, resample_which = 'group'):
        """Compute per-sample bucket codes and the size of every bucket.

        Args:
            resample_which: 'group' buckets by sensitive attribute, 'class'
                buckets by ``self.Y``, 'balanced' buckets by the combination
                ``group * num_labels + label``.

        Returns:
            ``(group_array, counts)`` where ``group_array`` is the list of
            per-sample codes and ``counts`` is a float tensor with one entry
            per bucket. Both are also kept in ``self._group_array`` (as a long
            tensor) and ``self._group_counts``.

        Raises:
            ValueError: for an unsupported sensitive attribute, number of
                classes or ``resample_which``.
        """
        if resample_which == 'group' or resample_which == 'balanced':
            group_array = self._sensitive_group_array()
            num_buckets = self.sens_classes

            if resample_which == 'balanced':
                labels = self.Y.tolist()
                num_labels = len(set(labels))
                num_groups = len(set(group_array))
                group_array = (np.asarray(group_array) * num_labels + np.asarray(labels)).tolist()
                num_buckets = num_labels * num_groups
        elif resample_which == 'class':
            group_array = self.Y.tolist()
            num_buckets = len(set(group_array))
        else:
            raise ValueError("Unknown resample_which: {}".format(resample_which))

        # Codes may be floats (self.Y is typically a float array); convert
        # through numpy and cast so the comparison below sees integer codes.
        self._group_array = torch.as_tensor(np.asarray(group_array)).long()
        self._group_counts = (
            torch.arange(num_buckets).unsqueeze(1) == self._group_array
        ).sum(1).float()
        return group_array, self._group_counts

    def __len__(self):
        return self.dataset_size

    def get_labels(self):
        """Return sensitive-attribute labels, e.g. for imbalance-aware samplers.

        NOTE(review): the 5- and 4-class branches always read the Age columns,
        whatever ``self.sens_name`` is, and other class counts return None.
        Confirm against the callers before relying on this for non-Age attributes.
        """
        # for sensitive attribute imbalance
        if self.sens_classes == 2:
            return self.A
        elif self.sens_classes == 5:
            return self.dataframe['Age_multi'].values.tolist()
        elif self.sens_classes == 4:
            return self.dataframe['Age_multi4'].values.tolist()

    def get_sensitive(self, sens_name, sens_classes, item):
        """Return the integer sensitive-attribute code of one metadata row.

        Args:
            sens_name: sensitive attribute to read.
            sens_classes: number of classes of that attribute.
            item: a row supporting ``item[column]`` lookups.

        Raises:
            ValueError: if the attribute, or the class count for 'Age' /
                'Insurance', is not supported.
        """
        if sens_name == 'Sex':
            return 0 if item['Sex'] == 'M' else 1
        if sens_name == 'Race':
            return 0 if item['Race'] == 'White' else 1
        if sens_name == 'skin_type':
            # Any class count other than 2 reads the full skin_type code.
            column = 'skin_binary' if sens_classes == 2 else 'skin_type'
            return int(item[column])
        if sens_name in ('Age', 'Insurance'):
            column = self._GROUP_COLUMNS[sens_name].get(sens_classes)
            if column is None:
                raise ValueError(
                    "{} does not support sens_classes={}".format(sens_name, sens_classes)
                )
            return int(item[column])
        raise ValueError('Please check the sensitive attributes.')

In [None]:
import numpy as np
import torch
import torchvision.transforms as transforms
import datasets
import pandas as pd
import random
import torchio as tio
from utils.spatial_transforms import ToTensor

from torchvision.transforms._transforms_video import (
    NormalizeVideo,
)

from torch.utils.data import WeightedRandomSampler


def get_dataset(opt):
    """Build the train/val/test datasets and their data loaders from ``opt``.

    Depending on ``opt['is_3d']`` / ``opt['is_tabular']`` the datasets are fed
    from the image feature path, from per-split CSV files, or from per-split
    pickle files. The dataset class is looked up by ``opt['dataset_name']`` in
    the ``datasets`` module.

    Returns:
        ``(train_data, val_data, test_data, train_loader, val_loader,
        test_loader, val_meta, test_meta)``.
    """
    data_setting = opt['data_setting']
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    if opt['is_3d']:
        mean_3d = [0.45, 0.45, 0.45]
        std_3d = [0.225, 0.225, 0.225]
        sizes = {'ADNI': (192, 192, 128), 'ADNI3T': (192, 192, 128), 'OCT': (192, 192, 96), 'COVID_CT_MD': (224, 224, 80)}

        def volume_steps():
            # Fresh transform instances for every pipeline.
            return [
                tio.transforms.CropOrPad(sizes[opt['dataset_name']]),
                ToTensor(),
                NormalizeVideo(mean_3d, std_3d),
            ]

        augment_steps = []
        if data_setting['augment']:
            augment_steps = [
                tio.transforms.RandomFlip(),
                tio.transforms.RandomAffine((-15, 15)),
            ]
        transform_train = transforms.Compose(augment_steps + volume_steps())
        transform_test = transforms.Compose(volume_steps())
    elif opt['is_tabular']:
        # Tabular datasets receive no transform.
        pass
    else:
        if data_setting['augment']:
            train_steps = [
                transforms.Resize(256),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation((-15, 15)),
                transforms.RandomCrop((224, 224)),
            ]
        else:
            train_steps = [
                transforms.Resize(256),
                transforms.CenterCrop(224),
            ]
        transform_train = transforms.Compose(train_steps + [transforms.ToTensor(), normalize])
        transform_test = transforms.Compose(
            [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize]
        )

    g = torch.Generator()
    g.manual_seed(opt['random_seed'])

    def seed_worker(worker_id):
        # Every worker seeds numpy and random with the same configured seed.
        np.random.seed(opt['random_seed'] )
        random.seed(opt['random_seed'])

    image_path = data_setting['image_feature_path']
    train_meta = pd.read_csv(data_setting['train_meta_path'])
    val_meta = pd.read_csv(data_setting['val_meta_path'])
    test_meta = pd.read_csv(data_setting['test_meta_path'])

    # -1 means "evaluate with the same number of sensitive classes as training".
    val_test_classes = opt['bianry_train_multi_test']
    if val_test_classes == -1:
        val_test_classes = opt['sens_classes']

    dataset_name = getattr(datasets, opt['dataset_name'])
    if opt['is_3d']:
        train_source = val_source = test_source = image_path
        train_transform, eval_transform = transform_train, transform_test
    elif opt['is_tabular']:
        # different format
        data_train_path = data_setting['data_train_path']
        data_val_path = data_setting['data_val_path']
        data_test_path = data_setting['data_test_path']

        train_source = pd.read_csv(data_train_path)
        val_source = pd.read_csv(data_val_path)
        test_source = pd.read_csv(data_test_path)
        train_transform = eval_transform = None
    else:
        train_source = data_setting['pickle_train_path']
        val_source = data_setting['pickle_val_path']
        test_source = data_setting['pickle_test_path']
        train_transform, eval_transform = transform_train, transform_test

    train_data = dataset_name(train_meta, train_source, opt['sensitive_name'], opt['sens_classes'], train_transform)
    val_data = dataset_name(val_meta, val_source, opt['sensitive_name'], val_test_classes, eval_transform)
    test_data = dataset_name(test_meta, test_source, opt['sensitive_name'], val_test_classes, eval_transform)

    print('loaded dataset ', opt['dataset_name'])

    # These experiments draw training samples through a weighted sampler,
    # which replaces plain shuffling.
    uses_sampler = opt['experiment'] in ('resampling', 'GroupDRO', 'resamplingSWAD')
    sampler = None
    if uses_sampler:
        weights = train_data.get_weights(resample_which = opt['resample_which'])
        sampler = WeightedRandomSampler(weights, len(weights), replacement=True, generator = g)

    loader_options = dict(
        batch_size=opt['batch_size'],
        num_workers=8,
        worker_init_fn=seed_worker,
        generator=g,
        pin_memory=True,
    )
    train_loader = torch.utils.data.DataLoader(
        train_data, sampler=sampler, shuffle=not uses_sampler, **loader_options)
    val_loader = torch.utils.data.DataLoader(val_data, shuffle=True, **loader_options)
    test_loader = torch.utils.data.DataLoader(test_data, shuffle=True, **loader_options)

    return train_data, val_data, test_data, train_loader, val_loader, test_loader, val_meta, test_meta

In [3]:
class PAPILA(BaseDataset):
    """Dataset that serves images from a pickle file alongside a metadata dataframe.

    The pickle at ``path_to_pickles`` is loaded once at construction and is
    indexed positionally, in step with the rows of ``dataframe``.
    """

    def __init__(self, dataframe, path_to_pickles, sens_name, sens_classes, transform):
        super(PAPILA, self).__init__(dataframe, path_to_pickles, sens_name, sens_classes, transform)

        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at pickle files you produced yourself.
        with open(path_to_pickles, 'rb') as f:
            self.tol_images = pickle.load(f)

        self.A = self.set_A(sens_name)
        # Binary target: any positive Diagnosis code counts as the positive class.
        self.Y = (np.asarray(self.dataframe['Diagnosis'].values) > 0).astype('float')
        self.AY_proportion = None

    def __getitem__(self, idx):
        """Return ``(idx, transformed image, label tensor, sensitive code)`` for row ``idx``."""
        item = self.dataframe.iloc[idx]
        img = Image.fromarray(self.tol_images[idx])
        img = self.transform(img)

        # Binarise exactly like self.Y so that the served label agrees with
        # the labels used for proportions and resampling weights.
        label = torch.FloatTensor([float(item['Diagnosis'] > 0)])

        sensitive = self.get_sensitive(self.sens_name, self.sens_classes, item)

        return idx, img, label, sensitive

# data preprocessing

In [5]:
import h5py
import pandas as pd
import numpy as np
#import cv2
import os
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.model_selection import train_test_split
import pickle
import time

## Preprocess metadata

In [None]:
# Load the raw metadata table.
# NOTE(review): the directory is named PAPILA while the file is the HAM10000
# metadata CSV - confirm the intended location.
path = 'data/PAPILA/'
metadata_csv = path + 'HAM10000_metadata.csv'

demo_data = pd.read_csv(metadata_csv)
demo_data

In [None]:
# Tally how many rows carry each value of the 'dataset' column.
Counter(demo_data['dataset'])

In [None]:
# Derive each image's relative file path from its image_id and store it in 'Path'.
pathlist = demo_data['image_id'].values.tolist()
paths = list(map(lambda image_id: 'HAM10000_images/' + image_id + '.jpg', pathlist))
demo_data['Path'] = paths

In [None]:
# Keep only the rows where both age and sex are present.
demo_data = demo_data.dropna(subset=['age', 'sex'])
demo_data

In [None]:
# Unify the sensitive-attribute values: 'male' -> 'M', 'female' -> 'F'.
# Work on a copy of the column's array: editing the array returned by
# `.values` in place would also rewrite the original 'sex' column, and fails
# outright when pandas hands out a read-only array (Copy-on-Write).
sex = demo_data['sex'].values.copy()
sex[sex == 'male'] = 'M'
sex[sex == 'female'] = 'F'
demo_data['Sex'] = sex
demo_data

In [None]:
# Bucket subjects by age.
# Age_multi: 0 for -1..19, 1 for 20..39, 2 for 40..59, 3 for 60..79, 4 for 80+.
# The buckets are applied lowest-first, so a code that has already been written
# (0..3) is never matched again by one of the later, higher ranges.
demo_data['Age_multi'] = demo_data['age'].values.astype('int')
for bucket, (low, high) in enumerate([(-1, 19), (20, 39), (40, 59), (60, 79)]):
    in_bucket = demo_data['Age_multi'].between(low, high)
    demo_data['Age_multi'] = np.where(in_bucket, bucket, demo_data['Age_multi'])
is_oldest = demo_data['Age_multi'] >= 80
demo_data['Age_multi'] = np.where(is_oldest, 4, demo_data['Age_multi'])

# Age_binary: 0 for -1..60, 1 above.
# NOTE(review): between() includes both ends, so age 60 falls into group 0 and
# the later ">= 60" test only ever sees ages above 60; Age_multi instead starts
# its "60..79" bucket at 60. Confirm which side 60 should belong to.
demo_data['Age_binary'] = demo_data['age'].values.astype('int')
is_younger = demo_data['Age_binary'].between(-1, 60)
demo_data['Age_binary'] = np.where(is_younger, 0, demo_data['Age_binary'])
is_older = demo_data['Age_binary'] >= 60
demo_data['Age_binary'] = np.where(is_older, 1, demo_data['Age_binary'])
demo_data

In [None]:
# Collapse the diagnosis codes in 'dx' into a binary label:
#   1 (malignant): akiec, mel
#   0 (benign):    every other code (bcc, bkl, dermatofibroma, nv, vasc)
malignant_codes = ['akiec', 'mel']
labels = demo_data['dx'].isin(malignant_codes).values.astype('int')

demo_data['binaryLabel'] = labels
demo_data

## Split dataset into train/validation/test

In [None]:
def split_811(all_meta, patient_ids):
    """Split ``all_meta`` 80/10/10 into train/val/test by ``lesion_id``.

    ``patient_ids`` are the ids to partition; each returned frame holds the
    rows of ``all_meta`` whose ``lesion_id`` fell into the matching partition,
    so rows sharing an id never straddle two splits.
    """
    train_ids, holdout_ids = train_test_split(patient_ids, test_size=0.2, random_state=0)
    val_ids, test_ids = train_test_split(holdout_ids, test_size=0.5, random_state=0)
    return tuple(
        all_meta[all_meta['lesion_id'].isin(ids)]
        for ids in (train_ids, val_ids, test_ids)
    )

lesion_ids = np.unique(demo_data['lesion_id'])
sub_train, sub_val, sub_test = split_811(demo_data, lesion_ids)

In [None]:
# Write every split's metadata to its own CSV file.
for split_frame, csv_path in (
    (sub_train, 'your_path/fariness_data/HAM10000/split/new_train.csv'),
    (sub_val, 'your_path/fariness_data/HAM10000/split/new_val.csv'),
    (sub_test, 'your_path/fariness_data/HAM10000/split/new_test.csv'),
):
    split_frame.to_csv(csv_path)

In [None]:
# You can have a look at an example image here.
# cv2 is not imported in this notebook (its import is commented out), so the
# image is loaded through PIL instead; PIL already yields RGB channel order,
# so no BGR -> RGB conversion is needed before plotting.
from PIL import Image

img = np.asarray(
    Image.open('your_path/fariness_data/HAM10000/HAM10000_images/ISIC_0027419.jpg').convert('RGB')
)
print(img.shape)
plt.imshow(img)

## Save images into pickle files

In [None]:
# Pre-resize every image listed in a split's metadata and store them, in row
# order, in a single pickle file.
from PIL import Image

# NOTE: the variable name is historical; this cell currently processes the
# training split selected by the CSV below.
test_meta = pd.read_csv('your_path/fariness_data/HAM10000/split/new_train.csv')

# The 'Path' column is relative to the dataset root (it starts with
# 'HAM10000_images/'), not to the pickle output directory.
image_root = 'your_path/fariness_data/HAM10000/'
path = 'your_path/fariness_data/HAM10000/pkls/'
images = []
start = time.time()
for relative_path in test_meta['Path']:
    with Image.open(image_root + relative_path) as raw:
        # Resize to the input size in advance to save time during training.
        # Arrays are stored in RGB order, which is what Image.fromarray
        # assumes when the dataset class reads them back.
        resized = raw.convert('RGB').resize((256, 256), Image.BILINEAR)
    images.append(np.array(resized))

end = time.time()
print('loaded {} images in {:.1f}s'.format(len(images), end - start))
with open(path + 'train_images.pkl', 'wb') as f:
    pickle.dump(images, f)

In [6]:
import parse_args
import json
import numpy as np
import pandas as pd
from utils import basics
import glob


def train(model, opt):
    """Train ``model`` for up to ``opt['total_epochs']`` epochs, then record validation.

    ``model.train(epoch)`` is called for epoch 0, 1, ... and training stops
    early as soon as one call returns a truthy value. Returns whatever
    ``model.record_val()`` returns.
    """
    epochs = range(opt['total_epochs'])
    # next() pulls epochs lazily and stops at the first one whose training
    # step asks to break; the epoch number itself is not needed.
    next((epoch for epoch in epochs if model.train(epoch)), None)

    # record val metrics for hyperparameter selection
    return model.record_val()
    

if __name__ == '__main__':

    opt, wandb = parse_args.collect_args()
    if not opt['test_mode']:
        # Train and evaluate with three distinct random seeds, then report the
        # statistics aggregated over the runs.
        random_seeds = np.random.choice(range(100), size = 3, replace=False).tolist()
        val_frames = []
        test_frames = []
        print('Random seed: ', random_seeds)
        for random_seed in random_seeds:
            opt['random_seed'] = random_seed
            model = basics.get_model(opt, wandb)
            val_frames.append(train(model, opt))
            test_frames.append(model.test())

        # Concatenate once instead of growing an initially empty DataFrame.
        val_df = pd.concat(val_frames)
        test_df = pd.concat(test_frames)

        stat_val = basics.avg_eval(val_df, opt, 'val')
        stat_test = basics.avg_eval(test_df, opt, 'test')
        model.log_wandb(stat_val.to_dict())
        model.log_wandb(stat_test.to_dict())
    elif opt['cross_testing']:
        # Evaluate every saved cross-domain checkpoint and report the
        # statistics aggregated over them.
        method_model_path = opt['cross_testing_model_path']
        model_paths = glob.glob(method_model_path + '/cross_domain_*.pth')
        if not model_paths:
            # Without this check the run would fail later with an unrelated
            # NameError, because no model is ever created.
            raise FileNotFoundError(
                "No 'cross_domain_*.pth' checkpoints found in {}".format(method_model_path)
            )

        test_frames = []
        for model_path in model_paths:
            opt['cross_testing_model_path_single'] = model_path
            model = basics.get_model(opt, wandb)
            test_frames.append(model.test())

        test_df = pd.concat(test_frames)
        stat_test = basics.avg_eval(test_df, opt, 'cross_testing')

        model.log_wandb(stat_test.to_dict())

ModuleNotFoundError: No module named 'parse_args'