In [None]:
import pandas as pd
import numpy as np
import tabulate as tb
from typing import Dict


SEED = 42
np.random.seed(SEED)

In [83]:
# file_path = './DeepFake Annotations/A-FF++.csv'
file_path = './Skin Cancer MNIST/HAM10000_metadata.csv'
df = pd.read_csv(file_path, sep=',')
df['file_path'] = 'Skin Cancer MNIST/HAM10000_images_part_1/' + df['image_id'] + '.jpg'

df = df.drop(columns=['image_id', 'lesion_id', 'dx_type', 'age', 'localization'])

def get_balanced_subset(df, class_col, feature_col, feature_value, samples_per_class, randomize=True, reset_index=False):
    tmp = df[df[feature_col] == feature_value]
    
    unique_classes = tmp[class_col].unique()
    for cl in unique_classes:
        amt = len(tmp[tmp[class_col] == cl])
        if amt < samples_per_class:
            raise ValueError(f"Not enough samples for class '{cl}' in feature '{feature_value}'. "
                             f"Required: {samples_per_class}, Available: {len(tmp[tmp[class_col] == cl])}")
   
    tmp = tmp.groupby(class_col).head(samples_per_class)
    if randomize:
        tmp = tmp.sample(len(tmp), random_state=SEED)
    if reset_index:
        tmp = tmp.reset_index(drop=True)
    
    return tmp    

tmp = get_balanced_subset(df=df, class_col='dx', feature_col='sex', feature_value='male', samples_per_class=2)
print(tb.tabulate(tmp, headers='keys', tablefmt='psql'))

+------+-------+-------+-----------------------------------------------------------+
|      | dx    | sex   | file_path                                                 |
|------+-------+-------+-----------------------------------------------------------|
| 2462 | bcc   | male  | Skin Cancer MNIST/HAM10000_images_part_1/ISIC_0028155.jpg |
| 2982 | nv    | male  | Skin Cancer MNIST/HAM10000_images_part_1/ISIC_0031325.jpg |
|    0 | bkl   | male  | Skin Cancer MNIST/HAM10000_images_part_1/ISIC_0027419.jpg |
| 9689 | akiec | male  | Skin Cancer MNIST/HAM10000_images_part_1/ISIC_0029360.jpg |
| 1213 | mel   | male  | Skin Cancer MNIST/HAM10000_images_part_1/ISIC_0027190.jpg |
| 2321 | vasc  | male  | Skin Cancer MNIST/HAM10000_images_part_1/ISIC_0031270.jpg |
|   64 | nv    | male  | Skin Cancer MNIST/HAM10000_images_part_1/ISIC_0024698.jpg |
|    1 | bkl   | male  | Skin Cancer MNIST/HAM10000_images_part_1/ISIC_0025030.jpg |
| 9690 | akiec | male  | Skin Cancer MNIST/HAM10000_images_part_1

In [84]:
def get_exp_data(df, class_col, feature_col, ratio : Dict, size, randomize=True, exclude_column=None, exclude_df=None):
    if randomize:
        df_rnd = df.sample(frac=1, random_state=SEED).reset_index(drop=True)
    else:
        df_rnd = df.copy()
        
    if exclude_column is not None and exclude_df is not None:
        if exclude_column not in df_rnd.columns:
            raise ValueError(f"Column '{exclude_column}' not found in DataFrame.")
        if exclude_column not in exclude_df.columns:
            raise ValueError(f"Column '{exclude_column}' not found in exclude DataFrame.")
        df_rnd = df_rnd[~df_rnd[exclude_column].isin(exclude_df[exclude_column])]
        
    uniq_classes = df_rnd[class_col].unique()
    uniq_features = df_rnd[feature_col].unique()
    
    def get_exp_data_inner(tmp_df, size):
        df_tmp = None
        for uf in uniq_features:
            if ratio.get(uf) is None:
                print(f"Feature '{uf}' not found in ratios. Skipping.")
                continue            
            c_amt = int(size * ratio[uf] / len(uniq_classes))
            if c_amt <= 0:
                raise ValueError(f"Calculated samples per class ({c_amt}) is less than or equal to zero for feature '{uf}' with ratio {ratio}.")
            tmp = get_balanced_subset(df=tmp_df, class_col=class_col, feature_col=feature_col, feature_value=uf, 
                                        samples_per_class=c_amt, randomize=False)
            if df_tmp is None:
                df_tmp = tmp
            else:
                df_tmp = pd.concat([df_tmp, tmp])
        return df_tmp
            
    df_res = get_exp_data_inner(df_rnd, size)
    
    if len(df_res) < size:
        print(f"Samples for ({len(df_res)}) are less than requested ({size}).")
    
    #PRINT RATIOS
    ratios = df_res[feature_col].value_counts(normalize=True).to_dict()
    print(f"Ratios for features: {ratios}")
    
    df_res = df_res.reset_index(drop=True)
    
    return df_res      

In [85]:
train_50_50 = get_exp_data(df, class_col='dx', feature_col='sex', ratio={'male':0.5, 'female': 0.5}, size=500)
# print(tb.tabulate(train_50_50, headers='keys', tablefmt='psql'))

test_50_50 = get_exp_data(df, class_col='dx', feature_col='sex', ratio={'male':0.5, 'female': 0.5}, size=100, exclude_column='file_path', exclude_df=train_50_50)
test_40_60 = get_exp_data(df, class_col='dx', feature_col='sex', ratio={'male':0.4, 'female': 0.6}, size=100, exclude_column='file_path', exclude_df=train_50_50)
test_30_70 = get_exp_data(df, class_col='dx', feature_col='sex', ratio={'male':0.3, 'female': 0.7}, size=100, exclude_column='file_path', exclude_df=train_50_50)

Feature 'unknown' not found in ratios. Skipping.
Samples for (490) are less than requested (500).
Ratios for features: {'male': 0.5, 'female': 0.5}
Feature 'unknown' not found in ratios. Skipping.
Samples for (98) are less than requested (100).
Ratios for features: {'male': 0.5, 'female': 0.5}
Feature 'unknown' not found in ratios. Skipping.
Samples for (91) are less than requested (100).
Ratios for features: {'female': 0.6153846153846154, 'male': 0.38461538461538464}
Feature 'unknown' not found in ratios. Skipping.
Samples for (98) are less than requested (100).
Ratios for features: {'female': 0.7142857142857143, 'male': 0.2857142857142857}
