In [1]:
from PIL import Image
import numpy as np
import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode

import albumentations as A
from albumentations.pytorch import ToTensorV2
import argparse
import json
import pathlib
import unittest
from typing import Tuple


In [2]:
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname("."))))
import utils
args = argparse.Namespace()
args.num_clients = 10

args.dirichlet_alpha = 0.1
N_parties = args.num_clients

alpha = args.dirichlet_alpha
args.unique_patients = True

In [19]:
import torchxrayvision
from skmultilearn.model_selection import iterative_train_test_split
from skmultilearn.model_selection.measures import get_combination_wise_output_matrix

class Custom_RSNA_Dataset(torchxrayvision.datasets.RSNA_Pneumonia_Dataset):
    def __getitem__(self, index):
        sample = super().__getitem__(index)
        return sample['img'], sample['lab']
    
class Custom_NIH_Dataset(torchxrayvision.datasets.NIH_Dataset):
    def __getitem__(self, index):
        sample = super().__getitem__(index)
        return sample['img'], sample['lab']

class Custom_CheX_Dataset(torchxrayvision.datasets.CheX_Dataset):
    def __getitem__(self, index):
        sample = super().__getitem__(index)
        return sample['img'], sample['lab']

NIH_dataset = Custom_NIH_Dataset(imgpath="/home/suncheol/.data/NIH_224/images", views=["PA"], unique_patients=args.unique_patients)

labels = [label for _, label in NIH_dataset]
labels = np.array(labels)

indices = np.arange(len(labels)).reshape(-1, 1)
# drop no finding

no_finding_indices = np.where(labels.sum(axis=1) == 0)[0]
indices = np.delete(indices, no_finding_indices, axis=0)
labels = np.delete(labels, no_finding_indices, axis=0)

indices_train, labels_train, indices_test, labels_test = iterative_train_test_split(indices, labels, test_size=0.2)

In [42]:
len(indices_train), len(indices_test)

(22507, 5501)

In [43]:
splitdir = pathlib.Path("./splitfile" ) / "no_unique_patients"
splitdir.mkdir(exist_ok=True, parents=True)
NIH_dataset.csv.iloc[indices_train.flatten()].reset_index(drop=True).to_csv(splitdir / "NIH_dataset_train.csv", index=False)
NIH_dataset.csv.iloc[indices_test.flatten()].reset_index(drop=True).to_csv(splitdir / "NIH_dataset_test.csv", index=False)

In [46]:
train_dataset = Custom_NIH_Dataset(imgpath="/home/suncheol/.data/NIH_224/images",
                                 csvpath=splitdir / "NIH_dataset_train.csv",
                                 views=["PA"], unique_patients=args.unique_patients)
print(len(train_dataset))

22507


In [47]:
test_dataset = Custom_NIH_Dataset(imgpath="/home/suncheol/.data/NIH_224/images",
                                    csvpath=splitdir / "NIH_dataset_test.csv",
                                    views=["PA"], unique_patients=args.unique_patients)
print(len(test_dataset))

5501


In [48]:
public_dataset = Custom_RSNA_Dataset(imgpath="/home/suncheol/.data/rsna/stage_2_train_images",
                                    views=["PA"])

In [49]:
labels = np.array([label for _, label in train_dataset])

In [50]:
def replace_nonzero_values(matrix):
    # 행렬의 크기를 구합니다.
    rows, cols = matrix.shape
    
    # 모든 열에 대해 반복합니다.
    for col in range(cols):
        # col 번째 열에서 0이 아닌 값의 위치를 찾습니다.
        nonzero_indices = np.nonzero(matrix[:, col])[0]
        
        # 해당 위치의 값을 col로 치환합니다.
        matrix[nonzero_indices, col] = col
    
    return matrix

labels = replace_nonzero_values(labels)
 

In [51]:
def create_label_to_id_map(labels):
    label_to_id = {}
    index = 0
    for label in labels:
        label = frozenset(label)
        if label not in label_to_id:
            label_to_id[label] = index
            index += 1
    return label_to_id

def convert_id_to_label_map(label_to_id):
    return {v: k for k, v in label_to_id.items()}

def convert_labels_to_ids(labels, label_to_id):
    return [label_to_id[frozenset(label)] for label in labels]

label_to_id = create_label_to_id_map(labels)
id_to_label = convert_id_to_label_map(label_to_id)
label_ids = convert_labels_to_ids(labels, label_to_id)

In [52]:
args.num_classes = len(label_to_id)
y_data = label_ids

399

In [54]:
def get_clientid_from_index(split_dirichlet_data_index_dict):
    clientid_from_index = {}
    for clientid, indices in split_dirichlet_data_index_dict.items():
        for index in indices:
            clientid_from_index[index] = clientid
    return clientid_from_index

In [56]:
import pandas as pd
splitfile_csv = pd.read_csv(splitdir / "NIH_dataset_train.csv", index_col=None, header=0)

alphas = [0.1, 0.5, 1.0]
for alpha in alphas:
    utils.set_random_seed(42)
    dirichlet_accmul_count = utils.get_dirichlet_distribution_count(N_class, N_parties, y_data, alpha)
    split_dirichlet_data_index_dict = utils.get_split_data_index(y_data, dirichlet_accmul_count)
    splitfile_csv[f"client_id_{alpha}"] = splitfile_csv.index.map(get_clientid_from_index(split_dirichlet_data_index_dict))
    print(splitfile_csv[f"client_id_{alpha}"].value_counts())

party_id: 0, num of samples: 1997
party_id: 1, num of samples: 1039
party_id: 2, num of samples: 1649
party_id: 3, num of samples: 1476
party_id: 4, num of samples: 2033
party_id: 5, num of samples: 2531
party_id: 6, num of samples: 2557
party_id: 7, num of samples: 1476
party_id: 8, num of samples: 1865
party_id: 9, num of samples: 5884
client_id_0.1
9    5884
6    2557
5    2531
4    2033
0    1997
8    1865
2    1649
7    1476
3    1476
1    1039
Name: count, dtype: int64
party_id: 0, num of samples: 2328
party_id: 1, num of samples: 1173
party_id: 2, num of samples: 1932
party_id: 3, num of samples: 2544
party_id: 4, num of samples: 4077
party_id: 5, num of samples: 2317
party_id: 6, num of samples: 2180
party_id: 7, num of samples: 1942
party_id: 8, num of samples: 2240
party_id: 9, num of samples: 1774
client_id_0.5
4    4077
3    2544
0    2328
5    2317
8    2240
6    2180
7    1942
2    1932
9    1774
1    1173
Name: count, dtype: int64
party_id: 0, num of samples: 2168
party_

In [57]:
splitfile_csv.tail(20)

Unnamed: 0,index,Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Gender,View Position,OriginalImage[Width,Height],...,y],view,has_masks,patientid,age_years,sex_male,sex_female,client_id_0.1,client_id_0.5,client_id_1.0
22487,112033,00030746_000.png,Consolidation|Effusion,0,30746,51,F,PA,2021,2021,...,0.194311,PA,False,30746,51.0,False,True,9,8,9
22488,112038,00030751_000.png,Infiltration,0,30751,36,M,PA,2020,2021,...,0.194314,PA,False,30751,36.0,True,False,9,8,9
22489,112047,00030753_007.png,Effusion|Pleural_Thickening,7,30753,54,F,PA,2021,2021,...,0.194311,PA,False,30753,54.0,False,True,9,8,9
22490,112050,00030753_010.png,Atelectasis|Effusion,10,30753,54,F,PA,2021,2011,...,0.194311,PA,False,30753,54.0,False,True,9,9,9
22491,112052,00030753_012.png,Atelectasis|Mass,12,30753,54,F,PA,2021,2021,...,0.194311,PA,False,30753,54.0,False,True,9,9,9
22492,112053,00030754_000.png,Infiltration,0,30754,53,M,PA,2020,2021,...,0.194314,PA,False,30754,53.0,True,False,9,8,9
22493,112054,00030755_000.png,Infiltration,0,30755,49,F,PA,2021,2021,...,0.194311,PA,False,30755,49.0,False,True,9,8,9
22494,112055,00030756_000.png,Infiltration,0,30756,30,M,PA,3056,2544,...,0.139,PA,False,30756,30.0,True,False,9,8,9
22495,112056,00030757_000.png,Infiltration,0,30757,54,M,PA,2020,2021,...,0.194314,PA,False,30757,54.0,True,False,9,9,9
22496,112058,00030759_000.png,Consolidation,0,30759,50,M,PA,2021,2021,...,0.194311,PA,False,30759,50.0,True,False,8,9,9


In [58]:
[max(value) for value in split_dirichlet_data_index_dict.values()]

[18530, 21425, 21372, 21973, 21277, 22351, 22306, 22460, 22442, 22506]

In [60]:
splitfile_csv.to_csv(splitdir / "NIH_dataset_train.csv", index=False)

In [112]:
args.unique_patients = False
datadir = pathlib.Path("/home/suncheol/.data/")
train_dataset = Custom_CheX_Dataset(imgpath= str(datadir / "CheXpert-v1.0-small"),
                                    csvpath= str(datadir / "CheXpert-v1.0-small/train.csv"),
                                    views=["PA"], unique_patients=args.unique_patients)



labels = [label for _, label in train_dataset]
labels = np.array(labels)
# replace -1, nan to 0
labels[labels == -1] = 0
labels[np.isnan(labels)] = 0

indices = np.arange(len(labels)).reshape(-1, 1)
# drop no finding

no_finding_indices = np.where(labels.sum(axis=1) == 0)[0]
indices = np.delete(indices, no_finding_indices, axis=0)
labels = np.delete(labels, no_finding_indices, axis=0)

if args.unique_patients:
    splitdir = pathlib.Path("./splitfile" ) / "unique_patients"
else:
    splitdir = pathlib.Path("./splitfile" ) / "no_unique_patients"
splitdir.mkdir(exist_ok=True, parents=True)
indices_train, labels_train, indices_test, labels_test = iterative_train_test_split(indices, labels, test_size=0.2)
train_dataset.csv.iloc[indices_train.flatten()].reset_index(drop=True).to_csv(splitdir / "CheX_dataset_train.csv", index=False)
train_dataset.csv.iloc[indices_test.flatten()].reset_index(drop=True).to_csv(splitdir / "CheX_dataset_train_test.csv", index=False)

In [113]:
len(indices_train), len(indices_test)

(11979, 2942)

In [73]:
valid_dataset = Custom_CheX_Dataset(imgpath= str(datadir / "CheXpert-v1.0-small"),
                                    csvpath= str(datadir / "CheXpert-v1.0-small/valid.csv"),
                                    views=["PA"], unique_patients=args.unique_patients)

labels = [label for _, label in valid_dataset]
labels = np.array(labels)
labels[labels == -1] = 0
labels[np.isnan(labels)] = 0

print(labels.shape, labels.sum(axis=1).shape)
indices = np.arange(len(labels)).reshape(-1, 1)
# drop no finding

no_finding_indices = np.where(labels.sum(axis=1) == 0)[0]
indices = np.delete(indices, no_finding_indices, axis=0)
labels = np.delete(labels, no_finding_indices, axis=0)
print(len(indices))
valid_dataset.csv.iloc[indices.flatten()].reset_index(drop=True).to_csv(splitdir / "CheX_dataset_valid.csv", index=False)

(33, 13) (33,)
15


In [132]:
train_dataset = Custom_CheX_Dataset(imgpath= str(datadir / "CheXpert-v1.0-small"),
                                    csvpath= str(splitdir / "CheX_dataset_train.csv"),
                                    views=["PA"], unique_patients=args.unique_patients)
print(len(train_dataset))

17859


In [115]:
test_dataset = Custom_CheX_Dataset(imgpath= str(datadir / "CheXpert-v1.0-small"),
                                    csvpath= str(splitdir / "CheX_dataset_train_test.csv"),
                                    views=["PA"], unique_patients=args.unique_patients)
print(len(test_dataset))

2942


In [75]:
valid_dataset = Custom_CheX_Dataset(imgpath= str(datadir / "CheXpert-v1.0-small"),
                                    csvpath= str(splitdir / "CheX_dataset_valid.csv"),
                                    views=["PA"], unique_patients=args.unique_patients)
print(len(valid_dataset))

15


In [None]:
public_dataset = Custom_RSNA_Dataset(imgpath="/home/suncheol/.data/rsna/stage_2_train_images",
                                    views=["PA"])

In [133]:
labels = np.array([label for _, label in train_dataset])
labels = np.array(labels)
labels[labels == -1] = 0
labels[np.isnan(labels)] = 0

In [134]:
def replace_nonzero_values(matrix):
    # 행렬의 크기를 구합니다.
    rows, cols = matrix.shape
    
    # 모든 열에 대해 반복합니다.
    for col in range(cols):
        # col 번째 열에서 0이 아닌 값의 위치를 찾습니다.
        nonzero_indices = np.nonzero(matrix[:, col])[0]
        
        # 해당 위치의 값을 col로 치환합니다.
        matrix[nonzero_indices, col] = col
    
    return matrix

labels = replace_nonzero_values(labels)
 

In [135]:
def create_label_to_id_map(labels):
    label_to_id = {}
    index = 0
    for label in labels:
        label = frozenset(label)
        if label not in label_to_id:
            label_to_id[label] = index
            index += 1
    return label_to_id

def convert_id_to_label_map(label_to_id):
    return {v: k for k, v in label_to_id.items()}

def convert_labels_to_ids(labels, label_to_id):
    return [label_to_id[frozenset(label)] for label in labels]

label_to_id = create_label_to_id_map(labels)
id_to_label = convert_id_to_label_map(label_to_id)
label_ids = convert_labels_to_ids(labels, label_to_id)

In [136]:
args.num_classes = len(label_to_id)
y_data = label_ids
N_class = args.num_classes
print(N_class)

508


In [137]:
def get_clientid_from_index(split_dirichlet_data_index_dict):
    clientid_from_index = {}
    for clientid, indices in split_dirichlet_data_index_dict.items():
        for index in indices:
            clientid_from_index[index] = clientid
    return clientid_from_index

In [138]:
import pandas as pd
splitfile_csv = pd.read_csv(splitdir / "CheX_dataset_train.csv", index_col=None, header=0)

alphas = [0.1, 0.5, 1.0]
for alpha in alphas:
    utils.set_random_seed(42)
    dirichlet_accmul_count = utils.get_dirichlet_distribution_count(N_class, N_parties, y_data, alpha)
    split_dirichlet_data_index_dict = utils.get_split_data_index(y_data, dirichlet_accmul_count)
    splitfile_csv[f"client_id_{alpha}"] = splitfile_csv.index.map(get_clientid_from_index(split_dirichlet_data_index_dict))
    print(splitfile_csv[f"client_id_{alpha}"].value_counts())

party_id: 0, num of samples: 2051
party_id: 1, num of samples: 550
party_id: 2, num of samples: 1261
party_id: 3, num of samples: 1158
party_id: 4, num of samples: 1603
party_id: 5, num of samples: 3336
party_id: 6, num of samples: 2253
party_id: 7, num of samples: 1241
party_id: 8, num of samples: 2531
party_id: 9, num of samples: 1875
client_id_0.1
5    3336
8    2531
6    2253
0    2051
9    1875
4    1603
2    1261
7    1241
3    1158
1     550
Name: count, dtype: int64
party_id: 0, num of samples: 2210
party_id: 1, num of samples: 1468
party_id: 2, num of samples: 1483
party_id: 3, num of samples: 1585
party_id: 4, num of samples: 2805
party_id: 5, num of samples: 2293
party_id: 6, num of samples: 2079
party_id: 7, num of samples: 1335
party_id: 8, num of samples: 1144
party_id: 9, num of samples: 1457
client_id_0.5
4    2805
5    2293
0    2210
6    2079
3    1585
2    1483
1    1468
9    1457
7    1335
8    1144
Name: count, dtype: int64
party_id: 0, num of samples: 1958
party_i

In [139]:
splitfile_csv.tail(20)

Unnamed: 0,Path,Sex,Age,Frontal/Lateral,AP/PA,No Finding,Enlarged Cardiomediastinum,Cardiomegaly,Lung Opacity,Lung Lesion,...,Fracture,Support Devices,view,patientid,age_years,sex_male,sex_female,client_id_0.1,client_id_0.5,client_id_1.0
17839,CheXpert-v1.0-small/train/patient34573/study3/...,Male,72,Frontal,PA,,,,1.0,1.0,...,,,PA,34573,72.0,True,False,9,8,8
17840,CheXpert-v1.0-small/train/patient34573/study1/...,Male,71,Frontal,PA,,,0.0,1.0,,...,,,PA,34573,71.0,True,False,9,9,9
17841,CheXpert-v1.0-small/train/patient34573/study2/...,Male,72,Frontal,PA,,0.0,,,1.0,...,,,PA,34573,72.0,True,False,8,9,9
17842,CheXpert-v1.0-small/train/patient34574/study5/...,Male,81,Frontal,PA,,-1.0,1.0,,,...,,,PA,34574,81.0,True,False,6,9,9
17843,CheXpert-v1.0-small/train/patient34576/study1/...,Male,38,Frontal,PA,,,,1.0,,...,,,PA,34576,38.0,True,False,8,9,9
17844,CheXpert-v1.0-small/train/patient34579/study1/...,Male,87,Frontal,PA,,,,,,...,1.0,1.0,PA,34579,87.0,True,False,7,9,9
17845,CheXpert-v1.0-small/train/patient34583/study2/...,Male,39,Frontal,PA,,1.0,,,,...,,,PA,34583,39.0,True,False,9,8,9
17846,CheXpert-v1.0-small/train/patient34584/study6/...,Female,59,Frontal,PA,,,,,,...,,,PA,34584,59.0,False,True,8,9,9
17847,CheXpert-v1.0-small/train/patient34584/study5/...,Female,58,Frontal,PA,,,,1.0,,...,,,PA,34584,58.0,False,True,8,9,9
17848,CheXpert-v1.0-small/train/patient34587/study1/...,Male,73,Frontal,PA,,-1.0,,,,...,,1.0,PA,34587,73.0,True,False,8,8,9


In [140]:
[max(value) for value in split_dirichlet_data_index_dict.values()]

[12807, 16096, 17837, 17526, 17857, 17742, 17654, 17773, 17839, 17858]

In [141]:
splitfile_csv.to_csv(splitdir / "CheX_dataset_train.csv", index=False)

In [151]:
# # files in /home/suncheol/code/FedTest/0_FedMHAD_vit/_prepare_dataset/splitfile/no_unique_patients

# path = pathlib.Path("/home/suncheol/code/FedTest/0_FedMHAD_vit/_prepare_dataset/splitfile/unique_patients")

# for filename in path.glob("*.csv"):
#     # replace -1, nan to 0
#     if "CheX" in filename.name:
#         splitfile_csv = pd.read_csv(filename, index_col=None, header=0)
#         splitfile_csv.fillna(0, inplace=True)
#         splitfile_csv.replace(-1, 0, inplace=True)
#         splitfile_csv.to_csv(filename, index=False)