In [None]:
import pandas as pd
from prettytable import PrettyTable
import os
import matplotlib.pyplot as plt

In [None]:
# One row per planned experiment: the data each model trains and tests on,
# the classes it distinguishes, and what the experiment is meant to show.
# NOTE(review): if the augmented images are generated from the original scans,
# "Augmented" train vs "Original" test may share source images — confirm the
# sets are disjoint before reading the AugTrain_OrgTest runs as generalization.
exp_data = {
    "Name": ["FullData-Multiclass", "FullData-Binary", "AugTrain_OrgTest-Multiclass", "AugTrain_OrgTest-Binary"],
    "Train Set": ["Original + Augmented", "Original + Augmented", "Augmented", "Augmented"],
    "Test Set": ["Original + Augmented", "Original + Augmented", "Original", "Original"],
    "Classes": ["All(4)", "NonDemented vs VeryMildDemented", "All(4)", "NonDemented vs VeryMildDemented"],
    "Purpose": ["Baseline multiclass performance", "Binary classification for early detection", "Generalization on real data", "Early detection generalization"]
}

exp_df = pd.DataFrame(exp_data)

# Rich HTML display of the experiment plan (last expression of the cell).
exp_df

In [None]:
def count_images_per_class(folder, extensions=None):
    """Count the image files in each class sub-directory of ``folder``.

    ``folder`` is expected to hold one sub-directory per class. Every
    non-hidden regular file directly inside a class directory is counted.

    Args:
        folder: Path to the dataset root.
        extensions: Optional suffix or iterable of suffixes (e.g. ``".jpg"`` or
            ``(".jpg", ".png")``). When given, only files ending in one of them
            (case-insensitive) are counted. ``None`` (default) counts every
            non-hidden regular file.

    Returns:
        Dict mapping class name -> file count. Keys are inserted in sorted
        order, so downstream tables and plots get a deterministic class order.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)  # avoid iterating a single suffix char by char
    suffixes = None if extensions is None else tuple(ext.lower() for ext in extensions)

    counts = {}
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for cls in sorted(os.listdir(folder)):
        cls_dir = os.path.join(folder, cls)
        # Skip stray top-level files (README, .DS_Store, ...) rather than
        # crashing with NotADirectoryError when trying to list them.
        if cls.startswith(".") or not os.path.isdir(cls_dir):
            continue
        n_files = 0
        for name in os.listdir(cls_dir):
            # Hidden files and nested directories are not images.
            if name.startswith(".") or not os.path.isfile(os.path.join(cls_dir, name)):
                continue
            if suffixes is None or name.lower().endswith(suffixes):
                n_files += 1
        counts[cls] = n_files
    return counts

# Kaggle input locations of the augmented and original variants of the MRI dataset.
aug_path, org_path = (
    "/kaggle/input/augmented-alzheimer-mri-dataset/AugmentedAlzheimerDataset",
    "/kaggle/input/augmented-alzheimer-mri-dataset/OriginalDataset",
)

# Per-class image counts for each variant (augmented first, then original).
aug_counts, org_counts = map(count_images_per_class, (aug_path, org_path))

In [None]:
# Every class seen in either dataset, alphabetised so rows line up.
all_classes = sorted(set(aug_counts) | set(org_counts))

# Column-oriented source for the summary frame; a class absent from one
# dataset is reported with a count of 0 there.
mri_data = {"Classes": all_classes}
mri_data["Augmented"] = [aug_counts.get(name, 0) for name in all_classes]
mri_data["Original"] = [org_counts.get(name, 0) for name in all_classes]

mri_df = pd.DataFrame(mri_data)

# Render the same frame as a plain-text table, one row per class.
mri_table = PrettyTable(field_names=list(mri_df.columns))
for row in mri_df.itertuples(index=False):
    mri_table.add_row(row)

print(mri_table)

In [None]:
# Side-by-side class distributions. Both panels read from `mri_df`, whose rows
# are sorted by class name, so the x-axis order is identical in both panels and
# matches the table (raw dict order from os.listdir is arbitrary and may differ
# between the two dataset folders). A class missing from one dataset shows as 0.
panel_specs = [
    ("Augmented", "Augmented Dataset", "skyblue"),
    ("Original", "Original Dataset", "salmon"),
]

fig, axs = plt.subplots(1, len(panel_specs), figsize=(12, 4))
for ax, (column, title, color) in zip(axs, panel_specs):
    # Plain lists so matplotlib treats class names as categorical x values.
    ax.bar(mri_df["Classes"].tolist(), mri_df[column].tolist(), color=color)
    ax.set_title(title)
    ax.set_ylabel("Number of Images")
    ax.tick_params(axis="x", rotation=45)

plt.tight_layout()
plt.show()