In [1]:
import pandas as pd

data = pd.read_csv('../data/diagnoses_10.csv')

# Getting the top 5 diagnoses only
data = data.drop(columns=['Dx_55827005'])
data.head()
data.info()


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 45152 entries, 0 to 45151
Data columns (total 6 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   ID            45152 non-null  object
 1   Dx_426177001  45152 non-null  bool  
 2   Dx_426783006  45152 non-null  bool  
 3   Dx_164890007  45152 non-null  bool  
 4   Dx_427084000  45152 non-null  bool  
 5   Dx_164934002  45152 non-null  bool  
dtypes: bool(5), object(1)
memory usage: 573.3+ KB


In [10]:
label_counts = data.drop(columns='ID').sum().sort_values(ascending=False)
print(label_counts)
target_count = label_counts.min()  # or set manually if too low
print(target_count)

Dx_426177001    16559
Dx_426783006     8125
Dx_164890007     8060
Dx_427084000     7255
Dx_164934002     7043
dtype: int64
7043


In [11]:
from collections import defaultdict
import pandas as pd

# Create a dict to hold sample indices for each label
label_to_indices = defaultdict(set)

# Build a mapping from label → rows where it's 1
for label in data.columns:
    if label == 'ID': continue
    label_to_indices[label] = set(data[data[label] == 1].index)

print(label_to_indices)

defaultdict(<class 'set'>, {'Dx_426177001': {32768, 1, 2, 32769, 4, 32770, 32771, 7, 8, 9, 32772, 11, 12, 13, 32773, 15, 16, 17, 18, 32778, 20, 32780, 32781, 23, 24, 25, 32785, 32786, 28, 32788, 30, 32790, 32791, 33, 32793, 35, 32795, 32796, 32797, 32798, 40, 32800, 32801, 32802, 32803, 45, 32805, 32806, 32807, 49, 50, 51, 52, 32812, 32813, 32814, 32815, 32816, 58, 32818, 32819, 61, 62, 32822, 32823, 32824, 32825, 32826, 32827, 69, 32829, 71, 32831, 73, 32833, 75, 32835, 32836, 78, 79, 80, 32840, 82, 83, 84, 85, 86, 87, 32847, 89, 32849, 91, 32851, 32852, 94, 95, 96, 32856, 98, 99, 32859, 101, 32861, 103, 32863, 32864, 106, 32866, 108, 109, 110, 111, 112, 113, 32873, 115, 116, 32876, 118, 32878, 32879, 121, 122, 123, 32883, 125, 32885, 32886, 128, 129, 32889, 32890, 132, 133, 32893, 32894, 136, 32896, 32897, 139, 140, 141, 142, 32902, 144, 145, 32905, 32906, 148, 149, 32909, 151, 152, 153, 154, 32914, 156, 157, 158, 159, 160, 32920, 162, 32922, 32923, 165, 32925, 167, 32927, 169, 32929

In [12]:
import random

selected_indices = set()
label_counter = defaultdict(int)

# Keep picking until all labels reach the target
while min(label_counter.values(), default=0) < target_count:
    for label in label_to_indices:
        # Skip if target is met
        if label_counter[label] >= target_count:
            continue

        # Choose an unused row for this label
        remaining = list(label_to_indices[label] - selected_indices)
        if not remaining:
            continue  # Exhausted

        chosen = random.choice(remaining)
        selected_indices.add(chosen)

        # Update label counts
        row = data.loc[chosen]
        for l in label_to_indices:
            if row[l] == 1:
                label_counter[l] += 1

In [13]:
balanced_df = data.loc[list(selected_indices)].copy()


In [14]:
balanced_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 29016 entries, 0 to 44979
Data columns (total 6 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   ID            29016 non-null  object
 1   Dx_426177001  29016 non-null  bool  
 2   Dx_426783006  29016 non-null  bool  
 3   Dx_164890007  29016 non-null  bool  
 4   Dx_427084000  29016 non-null  bool  
 5   Dx_164934002  29016 non-null  bool  
dtypes: bool(5), object(1)
memory usage: 595.1+ KB


In [15]:
label_counts = balanced_df.drop(columns='ID').sum().sort_values(ascending=False)
print(label_counts)
target_count = label_counts.min()  # or set manually if too low
print(target_count)

Dx_426177001    7043
Dx_426783006    7043
Dx_164890007    7043
Dx_427084000    7043
Dx_164934002    7043
dtype: int64
7043


In [16]:
# Drop the ID column just in case
label_cols = balanced_df.drop(columns='ID')

# Calculate percentage of samples with each diagnosis
diagnosis_percentages = label_cols.mean() * 100

# Optional: sort from most common to least
diagnosis_percentages = diagnosis_percentages.sort_values(ascending=False)

print(diagnosis_percentages)


Dx_426177001    24.272815
Dx_426783006    24.272815
Dx_164890007    24.272815
Dx_427084000    24.272815
Dx_164934002    24.272815
dtype: float64


In [17]:
balanced_df.to_csv('../data/diagnoses_cut2.csv', index=False)