In [None]:
from pandas import read_csv, concat
from pathlib import Path
import seaborn as sns
from matplotlib import pyplot as plt

# Root folders of the two cohorts, relative to the working directory.
general_path = Path('datasets') / 'general'
diseased_path = Path('datasets') / 'diseased'

# Dataset folder names looked up under each root.
general_datasets = [
    'cam-can',
    'nkirs',
    'sald',
    'hcp',
    'hcp-aging',
    'ukbb',
]
diseased_datasets = [
    'oasis',
    'aibl',
    'adni',
]

# Names of the per-dataset split files (used as CSV file stems).
splits = [
    'train',
    'val',
    'test',
]

# Read the datasets and combine them
def combine_datasets(datasets, splits, dataset_path):
    """Load every available split CSV of the given datasets into one DataFrame.

    Files are looked up at ``dataset_path / <dataset> / 'splits' / '<split>.csv'``.
    Files that do not exist, and files that contain no rows, are skipped.
    Every loaded frame gets a ``split`` column holding the split name it was
    read from.

    Args:
        datasets: Iterable of dataset folder names.
        splits: Iterable of split names (CSV file stems).
        dataset_path: ``pathlib.Path`` of the folder holding the dataset folders.

    Returns:
        One DataFrame with the rows of all loaded files, indexed 0..n-1.

    Raises:
        ValueError: If not a single non-empty split file was found.
    """
    dataframes = []
    for dataset in datasets:
        for split in splits:
            file_path = dataset_path / dataset / 'splits' / f'{split}.csv'
            if not file_path.exists():
                continue
            df = read_csv(file_path)
            if not df.empty:
                dataframes.append(df.assign(split=split))

    if not dataframes:
        # concat() would fail with an opaque "No objects to concatenate";
        # say which location was searched instead.
        raise ValueError(f'No non-empty split files found under {dataset_path}')

    # Each file carries its own 0..k index; renumber so row labels are unique
    # in the combined frame.
    return concat(dataframes, ignore_index=True)

# Load both cohorts, then tidy the labels used by the tables and figures below.
general_df = combine_datasets(general_datasets, splits, general_path)
diseased_df = combine_datasets(diseased_datasets, splits, diseased_path)

# Diseased cohort: discard rows labelled 'Dementia Unspecified' and expose
# the 'dx' column as 'Diagnosis'.
has_specific_dx = diseased_df['dx'] != 'Dementia Unspecified'
diseased_df = diseased_df[has_specific_dx].rename(columns={'dx': 'Diagnosis'})

# General cohort: 'split' becomes 'Split', values are capitalised and 'Val'
# is spelled out as 'Validation'.
split_labels = general_df['split'].str.capitalize().replace({'Val': 'Validation'})
general_df = general_df.rename(columns={'split': 'Split'}).assign(Split=split_labels)

In [None]:
# Load every non-empty CSV of the pre-training experiment and keep one row
# per image_id.
pretrained_exp_path = Path('pretrain_exp', 'datasets')
# glob() yields files in a filesystem-dependent order; sort them so that
# drop_duplicates() below keeps the same row for a repeated image_id on
# every run and every machine.
pretrained_csvs = sorted(pretrained_exp_path.glob('*.csv'))
pretrained_dfs = []
for csv_path in pretrained_csvs:
    df = read_csv(csv_path)
    if not df.empty:
        pretrained_dfs.append(df)
if not pretrained_dfs:
    # concat() on an empty list fails with "No objects to concatenate";
    # point at the folder that was searched instead.
    raise ValueError(f'No non-empty CSV files found in {pretrained_exp_path}')
pretrained_df = concat(pretrained_dfs)
pretrained_df = pretrained_df.rename(columns={'dx': 'Diagnosis'})
pretrained_df = pretrained_df.drop_duplicates(subset=['image_id'])
print('Size of pretrained_df:', pretrained_df.shape[0])
print('Size of diseased_df:', diseased_df.shape[0])

In [None]:
# Compare pretrained_df and diseased_df in both directions, first by subject
# and then by image.
pretrained_subject_ids = set(pretrained_df['subject_id'])
diseased_subject_ids = set(diseased_df['subject_id'])
print(
    'Pretrained subject IDs not in diseased:',
    len(pretrained_subject_ids.difference(diseased_subject_ids)),
)
print(
    'Diseased subject IDs not in pretrained:',
    len(diseased_subject_ids.difference(pretrained_subject_ids)),
)

pretrained_images = set(pretrained_df['image_id'])
diseased_images = set(diseased_df['image_id'])
print(
    'Pretrained images not in diseased:',
    len(pretrained_images.difference(diseased_images)),
)
print(
    'Diseased images not in pretrained:',
    len(diseased_images.difference(pretrained_images)),
)

In [None]:
# Rows of diseased_df whose subject does not occur in pretrained_df.
subject_is_pretrained = diseased_df['subject_id'].isin(pretrained_df['subject_id'])
in_diseased_but_not_in_pretrained = diseased_df[~subject_is_pretrained]
print('Number of images in diseased_df not in pretrained_df:', len(in_diseased_but_not_in_pretrained))

# Age histogram of those rows.
fig_age, ax_age = plt.subplots(figsize=(10, 5))
sns.histplot(data=in_diseased_but_not_in_pretrained, x='age_at_scan', bins=50, ax=ax_age)
ax_age.set(
    title='Age Distribution of Extra Subjects in Diseased Dataset',
    xlabel='Age',
    ylabel='Count',
)
plt.show()

# Count of those rows per value of the 'Diagnosis' column.
fig_dx, ax_dx = plt.subplots(figsize=(5, 5))
sns.countplot(data=in_diseased_but_not_in_pretrained, x='Diagnosis', ax=ax_dx)
ax_dx.set(
    title='Diagnosis Distribution of Extra Subjects in Diseased Dataset',
    xlabel='Diagnosis',
    ylabel='Count',
)
plt.setp(ax_dx.get_xticklabels(), rotation=45)
plt.show()

## General population

In [None]:
# Per dataset and split: number of samples, mean age, first and third age
# quartiles, share of female subjects, and BMI mean / standard deviation.
stats = general_df.groupby(['dataset', 'Split']).agg(
    n_samples=('subject_id', 'count'),
    mean_age=('age_at_scan', 'mean'),
    quartile_1_age=('age_at_scan', lambda x: x.quantile(0.25)),
    # The 0.75 quantile is the third quartile (the second quartile would be
    # the median), so the column is named accordingly.
    quartile_3_age=('age_at_scan', lambda x: x.quantile(0.75)),
    # value_counts() ignores missing values, so this is the share among
    # rows with a recorded gender; 0 when 'female' never occurs in the group.
    female_ratio=('gender', lambda x: x.value_counts(normalize=True).get('female', 0)),
    bmi=('bmi', 'mean'),
    std_bmi=('bmi', 'std'))

stats

### Plot age

In [None]:
# Age histogram (with KDE) of the general population, one colour per split.
sns.set_theme(font='Roboto', style='white')
sns.set_style('ticks')
fig, ax = plt.subplots(figsize=(4.5, 5))
# The split labels live in the 'Split' column; no cell in this notebook
# creates a 'Conjunto' column, so using it as hue fails on a fresh run.
histplot = sns.histplot(data=general_df, x='age_at_scan', hue='Split', alpha=0.8, legend=True, palette='GnBu_r', kde=True, ax=ax)
sns.move_legend(histplot, 'upper left', fontsize='large', title_fontsize='large')
ax.tick_params(axis='both', which='major', labelsize=13)
# Drop the first y tick before fixing the y range.
ax.set_yticks(ax.get_yticks()[1:])
# Legend without border or background.
histplot.legend_.get_frame().set_linewidth(0.0)
histplot.legend_.get_frame().set_facecolor('none')
# Axis labels are intentionally left blank.
ax.set_xlabel('')
ax.set_ylabel('')
ax.set_ylim(0, 1100)
sns.despine()
plt.savefig('age_distribution_general.png', dpi=300, transparent=True, bbox_inches='tight')
plt.show()

### Plot sex

In [None]:
# Stacked bar chart of sex counts in the general population, one colour per split.
sns.set_theme(font='Roboto', style='white')
sns.set_style('ticks')
# Style settings only affect axes created afterwards, so set them before
# the figure is built.
sns.set_style({'axes.grid' : False})
fig, ax = plt.subplots(figsize=(1.9, 4.38))
# Shorten the labels in the data rather than overwriting the tick labels by
# position: the bar order follows the data, so a hard-coded ['M', 'F'] would
# mislabel the bars whenever 'female' comes first.
# Assumes gender is coded as 'male' / 'female'; other values are left as they are.
sex_plot_df = general_df.assign(gender=general_df['gender'].replace({'male': 'M', 'female': 'F'}))
sns.histplot(data=sex_plot_df, x='gender', hue='Split', multiple='stack', alpha=0.8, shrink=.7, legend=False, palette='GnBu_r', ax=ax)
# Thin black outline around every bar segment.
for patch in ax.patches:
    patch.set_edgecolor('black')
    patch.set_linewidth(0.5)
# Axis labels are intentionally left blank.
ax.set_xlabel('')
ax.set_ylabel('')
ax.tick_params(axis='both', which='major', labelsize=13)
# Drop the first y tick.
ax.set_yticks(ax.get_yticks()[1:])
sns.despine()
plt.tight_layout()
plt.savefig('sex_distribution_general.png', dpi=300, transparent=True)
plt.show()

### Plot BMI

In [None]:
# BMI histogram (with KDE) of the general population, one colour per split.
# Rows without a BMI value are excluded.
without_nas_in_bmi = general_df[general_df['bmi'].notna()]

sns.set_theme(font='Roboto', style='white')
sns.set_style('ticks')
fig, ax = plt.subplots(figsize=(4.7, 5))
sns.histplot(
    data=without_nas_in_bmi,
    x='bmi',
    hue='Split',
    palette='GnBu_r',
    alpha=0.8,
    legend=False,
    kde=True,
    ax=ax,
)
ax.tick_params(axis='both', which='major', labelsize=13)
ax.set_xlim(5, 50)
# Drop the first y tick before fixing the y range.
ax.set_yticks(ax.get_yticks()[1:])
# The saved figure carries no axis labels.
ax.set(xlabel='', ylabel='')
ax.set_ylim(0, 1100)
sns.despine()
fig.savefig('bmi_distribution_general.png', dpi=300, transparent=True)
plt.show()

## Diseased

In [None]:
# Per-dataset summary of the diseased cohort: sample count, age range,
# age mean / standard deviation, and number and share of female subjects.
diseased_aggregations = {
    'n_samples': ('subject_id', 'count'),
    'min_age': ('age_at_scan', 'min'),
    'max_age': ('age_at_scan', 'max'),
    'mean_age': ('age_at_scan', 'mean'),
    'std_age': ('age_at_scan', 'std'),
    'females': ('gender', lambda sex: sex.eq('female').sum()),
    'female_ratio': ('gender', lambda sex: sex.value_counts(normalize=True).get('female', 0)),
}
stats = diseased_df.groupby(['dataset']).agg(**diseased_aggregations)

stats

In [None]:
# Overlaid age histograms (with KDE) of the diseased cohort, one layer per diagnosis.
fig = plt.figure(figsize=(7, 5))
sns.set_style('white', {'axes.grid' : False})
ax = fig.add_subplot()
sns.histplot(
    data=diseased_df,
    x='age_at_scan',
    hue='Diagnosis',
    multiple='layer',
    alpha=0.5,
    legend=True,
    kde=True,
    palette='Set2',
    ax=ax,
)
ax.set(
    xlabel='Age',
    ylabel='Count',
    title='Age distribution in Diseased population',
)
ax.tick_params(axis='both', labelsize=12)
sns.despine()
fig.savefig('age_distribution_diseased.png', dpi=300, transparent=True)
plt.show()

In [None]:

def plot_balanced_distributions(balanced_dvh, balanced_dvp, balanced_hvp):
    """Plot age and sex distributions of the three balanced comparison sets.

    The left panel holds one split violin of ``age_at_scan`` per comparison;
    the right column holds one horizontal count plot of ``gender`` per
    comparison. Everything is coloured by the ``dx`` column and a shared
    legend is placed below the panels.

    The input frames are not modified: relabelling is done on copies.

    Args:
        balanced_dvh: DataFrame drawn under the label 'AD vs. HC'.
        balanced_dvp: DataFrame drawn under the label 'AD vs. MCI'.
        balanced_hvp: DataFrame drawn under the label 'MCI vs. HC'.
            Each frame needs the columns ``age_at_scan``, ``gender`` and ``dx``.

    Returns:
        The matplotlib Figure (it is also shown via ``plt.show()``).
    """
    datasets = ['AD vs. HC', 'AD vs. MCI', 'MCI vs. HC']
    sex_labels = {'male': 'M', 'female': 'F'}
    # assign() returns new frames, so the caller's data keeps its original
    # 'gender' coding and gains no 'dataset' column.
    datasets_df = [
        frame.assign(dataset=name, gender=frame['gender'].replace(sex_labels))
        for name, frame in zip(datasets, (balanced_dvh, balanced_dvp, balanced_hvp))
    ]

    sns.set_theme(font='Roboto', style='white')
    sns.set_style('ticks')
    # Only one layout manager is used: tight_layout() is applied at the end,
    # so the figure is not created with constrained_layout (the two conflict
    # and tight_layout() is the one that ends up in effect).
    fig = plt.figure(figsize=(7.7, 4))
    gs = fig.add_gridspec(3, 2, width_ratios=[2, 1], height_ratios=[1, 1, 1])

    # Violin panel spans all rows of the left column
    ax_violin = fig.add_subplot(gs[:, 0])

    # One colour per diagnosis class, shared by every panel. The palette is
    # rotated by one position (last colour first); for three classes this is
    # [colors[2], colors[0], colors[1]].
    all_classes = concat(datasets_df)['dx'].unique()
    colors = sns.color_palette('Set2', n_colors=len(all_classes))
    colors = colors[-1:] + colors[:-1]
    class_color_map = dict(zip(all_classes, colors))

    # Plot each dataset separately to ensure proper split violins
    for frame in datasets_df:
        sns.violinplot(
            data=frame,
            x='dataset',
            y='age_at_scan',
            hue='dx',
            split=True,
            inner='quartile',
            ax=ax_violin,
            # A class -> colour mapping does not depend on the order in
            # which the classes happen to appear in this frame.
            palette=class_color_map,
            edgecolor='black',
            linewidth=0.5
        )

    ax_violin.set_title('Age distribution')
    ax_violin.set_xlabel('')
    ax_violin.set_ylabel('Age')
    ax_violin.grid(False)
    ax_violin.set_axisbelow(True)
    ax_violin.spines['top'].set_visible(False)
    ax_violin.spines['right'].set_visible(False)

    # The per-axes legend is replaced by the shared figure legend below
    violin_legend = ax_violin.get_legend()
    if violin_legend is not None:
        violin_legend.remove()

    # Gender bar plots, stacked vertically, one for each dataset
    for i, (ds, df) in enumerate(zip(datasets, datasets_df)):
        ax_bar = fig.add_subplot(gs[i, 1])
        sns.countplot(
            data=df,
            y='gender',
            hue='dx',
            ax=ax_bar,
            palette=class_color_map,
            edgecolor='black',
            linewidth=0.5
        )
        ax_bar.set_ylabel(ds)
        ax_bar.set_xlabel('')
        if i == 0:
            ax_bar.set_title('Sex distribution')
        ax_bar.grid(False)
        ax_bar.spines['top'].set_visible(False)
        ax_bar.spines['right'].set_visible(False)
        bar_legend = ax_bar.get_legend()
        if bar_legend is not None:
            bar_legend.remove()

    # Create custom legend at center bottom with 3 columns
    legend_elements = [
        plt.Rectangle((0, 0), 1, 1, facecolor=class_color_map[cls], alpha=0.7, label=cls)
        for cls in all_classes
    ]
    fig.legend(handles=legend_elements, loc='lower center', bbox_to_anchor=(0.5, -0.08), ncol=3)
    fig.tight_layout()
    plt.show()
    return fig

In [None]:
# Load the three balanced comparison tables from the working directory.
balanced_dvh, balanced_dvp, balanced_hvp = (
    read_csv(csv_name)
    for csv_name in ('balanced_dvh.csv', 'balanced_dvp.csv', 'balanced_hvp.csv')
)

# Draw the summary figure and save it as a PNG.
fig_balanced = plot_balanced_distributions(balanced_dvh, balanced_dvp, balanced_hvp)
fig_balanced.savefig(
    'diseased_balanced_distributions.png',
    dpi=300,
    bbox_inches='tight',
    transparent=True,
)