# Data augmentation of ADNI Alzheimer's disease (AD) subjects for class balancing

In [2]:
import os
import cv2
import pandas as pd
from tqdm import tqdm
from glob import glob
import albumentations as A

# --- CONFIGURATION ---
# Input label table, preprocessed per-subject slice folders, and the output
# root for the augmented AD copies (created up front if it does not exist).
csv_path = '/Volumes/Samsung_PSSD_T7_Shield/ADNI/ADNI1_Baseline_Only.csv'
data_dir = "/Volumes/Samsung_PSSD_T7_Shield/ADNI/ADNI FULLY PREPPROCCSED"
aug_dir = "/Volumes/Samsung_PSSD_T7_Shield/ADNI/AUGMENTED_AD_ONLY_SUBJECT_LEVEL"
os.makedirs(aug_dir, exist_ok=True)

# How many augmented copies to generate for every AD subject.
AUG_PER_SUBJECT = 3

# Read the baseline label table and normalise PTID separators from '_' to '-'
# so they can be compared against subject folder names after the same
# substitution.
df = pd.read_csv(csv_path)
df['PTID'] = df['PTID'].str.replace('_', '-', regex=False)

# PTIDs (in table order) whose baseline diagnosis column DX_bl equals 'AD'.
ad_subjects = df.loc[df['DX_bl'].eq('AD'), 'PTID'].tolist()

# Define Replayable transform
def get_replayable_transform():
    """Build the subject-level augmentation pipeline.

    The pipeline is wrapped in ``A.ReplayCompose`` so the random parameters
    sampled on one slice can be replayed on every other slice of the same
    subject, keeping all slices of an augmented copy consistent.

    Steps: random 192x192 crop, resize to 224x224, then an optional
    horizontal flip, brightness/contrast jitter and a +/-10% rescale.
    ``RandomScale`` changes the image size, so the rescaled image is
    zero-padded (if it shrank) and center-cropped (if it grew) back to
    224x224.  Without that last step the outputs would be roughly 201-246
    pixels per side instead of the fixed 224x224 produced by ``Resize``.

    Returns:
        A.ReplayCompose: call it as ``tfm(image=img)`` and keep
        ``result['replay']`` to reuse with ``A.ReplayCompose.replay``.
    """
    return A.ReplayCompose([
        A.RandomCrop(width=192, height=192),
        A.Resize(224, 224),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.RandomScale(scale_limit=0.1, p=0.5),
        # RandomScale alters the output size; restore a fixed 224x224 canvas
        # so augmented slices match the size set by Resize above.
        A.PadIfNeeded(
            min_height=224,
            min_width=224,
            border_mode=cv2.BORDER_CONSTANT,
        ),
        A.CenterCrop(height=224, width=224),
    ])

# Process only AD subjects.
#
# For each AD subject folder containing exactly EXPECTED_SLICES sagittal PNG
# slices, write AUG_PER_SUBJECT augmented copies to
# <aug_dir>/<subject_folder>_aug<k>/ using the original slice filenames.
# The random transform parameters are sampled once per copy (on the first
# slice) and replayed on every slice, so all slices of one copy receive the
# same crop / flip / brightness / scale.
EXPECTED_SLICES = 50  # subjects with any other slice count are skipped


def _read_rgb(path):
    """Read an image with OpenCV and return it converted from BGR to RGB.

    Raises:
        IOError: if OpenCV cannot read or decode the file
            (``cv2.imread`` returns ``None`` instead of raising).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise IOError(f"Could not read image: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


ad_subject_ids = set(ad_subjects)  # O(1) membership instead of a list scan
skipped_subjects = []  # (folder, reason) pairs, reported after the loop

for subject_folder in tqdm(sorted(os.listdir(data_dir))):
    subject_id = subject_folder.replace("_", "-")
    if subject_id not in ad_subject_ids:
        continue

    slice_dir = os.path.join(data_dir, subject_folder, "sagittal_slices")
    if not os.path.isdir(slice_dir):
        skipped_subjects.append((subject_folder, "no sagittal_slices directory"))
        continue

    slice_files = sorted(glob(os.path.join(slice_dir, "*.png")))
    if len(slice_files) != EXPECTED_SLICES:
        skipped_subjects.append((subject_folder, f"{len(slice_files)} slices"))
        continue

    # Decode every slice once and reuse it for all augmented copies; this
    # also fails before any output is written if a slice is unreadable.
    slices = [_read_rgb(path) for path in slice_files]

    for aug_idx in range(AUG_PER_SUBJECT):
        out_dir = os.path.join(aug_dir, f"{subject_folder}_aug{aug_idx}")
        os.makedirs(out_dir, exist_ok=True)

        # Sample the random parameters on the first slice, then replay them.
        replay = get_replayable_transform()(image=slices[0])['replay']

        for slice_file, img in zip(slice_files, slices):
            transformed = A.ReplayCompose.replay(replay, image=img)['image']
            out_path = os.path.join(out_dir, os.path.basename(slice_file))
            # cv2.imwrite signals failure by returning False, not by raising.
            ok = cv2.imwrite(out_path, cv2.cvtColor(transformed, cv2.COLOR_RGB2BGR))
            if not ok:
                raise IOError(f"Could not write image: {out_path}")

if skipped_subjects:
    print(f"Skipped {len(skipped_subjects)} AD subject(s):")
    for folder, reason in skipped_subjects:
        print(f"  {folder}: {reason}")


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 640/640 [00:49<00:00, 13.01it/s]
