The objectives of the Mayo Clinic Study of Aging (MCSA) were to determine, in the population of Olmsted County, Minnesota: (1) the prevalence of mild cognitive impairment (MCI); (2) the incidence of MCI; (3) the rates of conversion from MCI to dementia or Alzheimer's disease (AD); (4) the risk factors for MCI; and (5) the risk factors for progression from MCI to dementia or AD. The long-term goals of the MCSA are to develop tools to predict and prevent cognitive decline and dementia, to develop risk-prediction models for cognitive impairment, and to conduct aging-related research that promotes successful aging.

# Setup and Libraries

In [None]:
# Cell 0: get current notebook path, its parent, and project root (parent of parent)
import os
from pathlib import Path

# Data manipulation
import pandas as pd
import numpy as np

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Optional: consistent plot appearance for every figure in this notebook.
# `sns.set` is only a deprecated alias of `sns.set_theme`. `sns.histplot`,
# used throughout this notebook, already requires seaborn >= 0.11, which is
# also the release that introduced `set_theme`.
sns.set_theme(style="whitegrid")

# Project root = the parent directory of the current working directory
# (exactly one level up from where the notebook is launched, not two).
# Data paths such as ROOT_DIR / "csv_dir" are resolved relative to it.
ROOT_DIR = Path.cwd().resolve().parent


# Load the MCSA CSVs, merge them, and parse the Imaging Protocol

In [None]:
CSV_DIR = ROOT_DIR / "csv_dir"
df1 = pd.read_csv(CSV_DIR / "MCSA_all.csv")
df2 = pd.read_csv(CSV_DIR / "MCSA_Data.csv")

# Normalise the join keys so both sides compare equal: subject IDs as
# strings, visit numbers as numbers (unparseable values become NaN).
df1["Subject ID"] = df1["Subject ID"].astype(str)
df2["MCSA_ID"] = df2["MCSA_ID"].astype(str)

df1["Visit"] = pd.to_numeric(df1["Visit"], errors="coerce")
df2["visit_num"] = pd.to_numeric(df2["visit_num"], errors="coerce")

# A left merge yields exactly one output row per df1 row only if the
# right-hand keys are unique. Duplicated (MCSA_ID, visit_num) pairs would
# silently multiply the matching df1 rows, so report them up front.
n_duplicate_keys = int(df2.duplicated(subset=["MCSA_ID", "visit_num"]).sum())
if n_duplicate_keys:
    print(
        f"WARNING: {n_duplicate_keys} duplicated (MCSA_ID, visit_num) key(s) in "
        "MCSA_Data.csv; the left merge will duplicate the matching rows."
    )

original_df = pd.merge(
    df1,
    df2,
    left_on=["Subject ID", "Visit"],
    right_on=["MCSA_ID", "visit_num"],
    how="left",
    indicator=True,
)

# "both" = matched a row in MCSA_Data.csv, "left_only" = no match found.
print(original_df["_merge"].value_counts())

# Drop the merge indicator and the right-hand key copies, then replace the
# Age/Sex columns with calc_age_vis/Male (presumably from MCSA_Data.csv).
original_df = original_df.drop(columns=["_merge", "MCSA_ID", "visit_num"])
original_df = original_df.drop(columns=["Age"]).rename(columns={"calc_age_vis": "Age"})
# NOTE(review): "Male" is presumably a 0/1 indicator, so "Sex" now carries
# that coding rather than the original labels. Confirm before interpreting.
original_df = original_df.drop(columns=["Sex"]).rename(columns={"Male": "Sex"})


# Dtype clean-up: IDs as categoricals, 0 in Weight/Age treated as missing,
# unparseable study dates become NaT.
original_df["Subject ID"] = original_df["Subject ID"].astype("category")
original_df["Image ID"] = original_df["Image ID"].astype("category")
original_df["Weight"] = original_df["Weight"].replace(0, np.nan)
original_df["Age"] = original_df["Age"].replace(0, np.nan)
original_df["Study Date"] = pd.to_datetime(original_df["Study Date"], errors="coerce")

print(original_df.columns)
print(len(original_df))

# Keys extracted from the "Imaging Protocol" string; each becomes a column.
fields = [
    "Acquisition Plane", "Slice Thickness", "Matrix Z", "Acquisition Type",
    "Manufacturer", "Mfg Model", "Field Strength", "Weighting",
]

# The subset of the keys above whose values are converted to numbers.
numeric_fields = ["Slice Thickness", "Matrix Z", "Field Strength"]


def parse_imaging_protocol(text):
    """Split a ``key=value;key=value`` protocol string into a dict.

    Items without an ``=`` are ignored. Only the first ``=`` separates the
    key from the value, and both sides are whitespace-stripped. When a key
    repeats, its last occurrence wins. Missing input (NaN/None) gives ``{}``.
    """
    if pd.isna(text):
        return {}
    pairs = (item.split("=", 1) for item in text.split(";") if "=" in item)
    return {key.strip(): value.strip() for key, value in pairs}


# Parse each protocol string once into a {key: raw string value} dict.
protocol_parsed = original_df["Imaging Protocol"].apply(parse_imaging_protocol)

# One new column per requested key. A key absent from a row's protocol
# yields NaN for that row.
for field in fields:
    original_df[field] = protocol_parsed.map(
        lambda parsed, key=field: parsed.get(key, np.nan)
    )

# Parsed values are strings; convert the numeric ones (unparseable -> NaN).
for field in numeric_fields:
    original_df[field] = pd.to_numeric(original_df[field], errors="coerce")

print(original_df.columns)

# Filtered dataframe

In [None]:
# Keep original 3D T1-weighted MRI rows with Matrix Z > 100 and
# Slice Thickness < 1.4. Rows where any tested field is missing fail the
# comparison and are dropped.
weighting_filter = original_df["Weighting"] == "T1"

df = original_df.loc[
    weighting_filter
    & original_df["Modality"].eq("MRI")
    & original_df["Acquisition Type"].eq("3D")
    & original_df["Type"].eq("Original")
    & original_df["Matrix Z"].gt(100)
    & original_df["Slice Thickness"].lt(1.4)
].copy()

print(
    f"Filtered dataset size: {len(df)} images from {df['Subject ID'].nunique()} subjects."
)

### No filter

In [None]:
# df = original_df
# print(
#     f"Filtered dataset size: {df.shape[0]} images from {df['Subject ID'].nunique()} subjects."
# )

In [None]:
df.iloc[:2]  # peek at the first two rows

# Basic descriptive statistics

In [None]:
# -------------------------------
# Column Names
# -------------------------------
print("=== Columns in the Dataset ===")
display(pd.DataFrame(df.columns, columns=["Column Names"]))

# -------------------------------
# Summary of Numeric Features
# -------------------------------
print("\n=== Numeric Features Summary ===")
display(df.describe().round(2))  # round to 2 decimals

# -------------------------------
# Summary of Categorical Features
# -------------------------------
print("\n=== Categorical Features Summary ===")
display(df.describe(include=["object", "category"]))

# -------------------------------
# Missing Values
# -------------------------------
missing_count = df.isnull().sum()
missing_percent = (missing_count / len(df) * 100).round(2)
missing_df = pd.DataFrame(
    {"Missing Count": missing_count, "Missing %": missing_percent}
).sort_values(by="Missing Count", ascending=False)

print("\n=== Missing Values by Column ===")
display(missing_df)

# Columns with at least one missing value, in dataframe column order.
cols_with_missing = missing_count[missing_count > 0].index
print(f"Columns with missing values ({len(cols_with_missing)}):\n")

# Print one line per affected column (most-missing first) so the header
# above is never left dangling with nothing under it.
missing_only = missing_df[missing_df["Missing Count"] > 0]
for col, n_missing, pct in zip(
    missing_only.index, missing_only["Missing Count"], missing_only["Missing %"]
):
    print(f"  {col}: {n_missing} ({pct}%)")

# Verbose option: show example rows where each column is missing.
# for col in cols_with_missing:
#     print(f"--- {col} ---")
#     display(df[df[col].isnull()].head(3))

# Histograms

In [None]:
# Numeric columns get histograms. The categorical columns to plot are listed
# explicitly in the second half of this cell.
numeric_cols = df.select_dtypes(include=np.number).columns

# -------------------------------
# Numeric Columns Histograms
# -------------------------------
for col in numeric_cols:
    plt.figure(figsize=(6, 4))

    # Plain counts (no KDE) so the per-bin labels below are exact.
    ax = sns.histplot(df[col].dropna(), bins=30, kde=False)

    plt.title(f"Histogram of {col}")
    plt.xlabel(col)
    plt.ylabel("Count")

    # Annotate counts on top of each non-empty bin.
    for patch in ax.patches:
        height = patch.get_height()
        if height > 0:
            ax.text(
                patch.get_x() + patch.get_width() / 2,  # center of bin
                height + 0.5,  # slightly above the bar
                int(height),  # show integer count
                ha="center",
                va="bottom",
                fontsize=8,
            )

    plt.show()

# -------------------------------
# Categorical Columns Bar Plots
# -------------------------------
categorical_cols = [
    "Visit",
    "Sex",
    "Research Group",
    "Modality",
    "Type",
    "Structure",
    "Laterality",
    "Image Type",
    "Registration",
    "Description",
    "Tissue",
    # Imaging Protocol-derived categorical columns
    "Acquisition Plane",
    "Acquisition Type",
    # "Manufacturer",
    # "Mfg Model",
    "Weighting",
]

# Only plot the columns that exist in this dataframe.
categorical_cols = [col for col in categorical_cols if col in df.columns]
for col in categorical_cols:
    # Counts include missing values, so percentages are relative to all rows.
    counts = df[col].value_counts(dropna=False)
    total = counts.sum()
    present = counts.index.dropna()  # most frequent first, NaN removed
    if len(present) == 0:
        print(f"Skipping {col}: all {total} values are missing.")
        continue

    plt.figure(figsize=(6, 4))

    # countplot draws no bar for missing values, so NaN is left out of
    # `order`. Each bar is labelled from its own width rather than being
    # paired positionally with `counts.index`, which would shift every label
    # after a NaN entry.
    ax = sns.countplot(y=col, data=df, order=present)

    for p in ax.patches:
        width = p.get_width()
        if not np.isfinite(width):
            continue
        count = int(width)
        ax.text(
            width + 0.5,
            p.get_y() + p.get_height() / 2,
            f"{count} ({100 * count / total:.1f}%)",
            va="center",
        )

    # Missing values have no bar, so report them in the title instead.
    n_missing = int(df[col].isna().sum())
    title = f"Value Counts for {col}"
    if n_missing:
        title += f" ({n_missing} missing, {100 * n_missing / total:.1f}%)"

    plt.title(title)
    plt.xlabel("Count")
    plt.ylabel(col)
    plt.tight_layout()
    plt.show()


# Study dates

In [None]:
# -------------------------------
# Study Date Distribution
# -------------------------------
plt.figure(figsize=(10, 5))

# Histogram of study dates (undated rows dropped), plain counts.
ax = sns.histplot(df["Study Date"].dropna(), bins=30, kde=False)
ax.set_title("Distribution of Study Dates")
ax.set_xlabel("Study Date")
ax.set_ylabel("Number of Studies")

# Write each non-empty bin's count just above its bar.
for bar in ax.patches:
    bar_height = bar.get_height()
    if not bar_height > 0:
        continue
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        bar_height + 0.5,
        int(bar_height),
        ha="center",
        va="bottom",
        fontsize=8,
    )

plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Running total of studies over time. Rows are sorted by date; NaT rows
# sort last and are not drawn.
df_sorted = df.sort_values("Study Date")
df_sorted["Cumulative Count"] = range(1, len(df_sorted) + 1)

plt.figure(figsize=(10, 4))
plt.plot(
    df_sorted["Study Date"],
    df_sorted["Cumulative Count"],
    marker="o",
    linestyle="-",
)
plt.title("Cumulative Study Count Over Time")
plt.xlabel("Study Date")
plt.ylabel("Cumulative Count")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()


# Scans per session / coverage

Does a subject have more than one scan at a given visit?

How often does this happen?

In [None]:
# From here on, visit labels are handled as strings.
df["Visit"] = df["Visit"].astype(str)

# Number of rows (scans) per (subject, visit) session. observed=True keeps
# only combinations that actually occur ("Subject ID" is categorical).
scans_per_subject_visit = (
    df.groupby(["Subject ID", "Visit"], observed=True)
    .size()
    .rename("n_scans")
    .reset_index()
)

# Defensive: drop empty sessions (none are expected with observed=True).
scans_per_subject_visit = scans_per_subject_visit.loc[
    scans_per_subject_visit["n_scans"].gt(0)
].copy()


In [None]:
# How many (subject, visit) sessions contain 1, 2, ... scans.
plt.figure(figsize=(6, 4))
ax = sns.countplot(
    data=scans_per_subject_visit,
    x="n_scans",
    order=np.sort(scans_per_subject_visit["n_scans"].unique()),
)

ax.set_xlabel("Number of Scans per Subject per Visit")
ax.set_ylabel("Number of (Subject, Visit) Sessions")
ax.set_title("Multiplicity of Scans per Session")

# Put each bar's session count just above it.
for bar in ax.patches:
    bar_top = bar.get_height()
    ax.annotate(
        int(bar_top),
        (bar.get_x() + bar.get_width() / 2.0, bar_top),
        ha="center",
        va="bottom",
    )

plt.show()


# Longitudinal analysis (filtered subset)

In [None]:
pd.unique(df["Visit"])  # distinct visit labels, in order of appearance

In [None]:
# Numeric version of the visit label. In the original cell order, Visit was
# cast to str in the scans-per-session cell.
df["Visit_Num"] = pd.to_numeric(df["Visit"], errors="coerce")

# Sanity check: number of labels that failed to parse.
print(df["Visit_Num"].isna().sum())  # should be 0

# The scheduled visits are just the distinct numeric visit numbers.
scheduled_visits = sorted(df["Visit_Num"].dropna().unique())
print("Scheduled Visits:", scheduled_visits)


scheduled_df = df.copy()

# Distinct visits per subject, deepest follow-up first.
visits_per_subject = scheduled_df.groupby("Subject ID", observed=True)[
    "Visit_Num"
].agg("nunique")
visits_per_subject = visits_per_subject.sort_values(ascending=False)

# Summary
summary_stats = visits_per_subject.describe()
print("Longitudinal Coverage per Subject:")
display(summary_stats)


In [None]:
pd.unique(df["Visit_Num"])  # distinct numeric visits, in order of appearance

In [None]:
# Follow-up depth: how many subjects have exactly N distinct visits. This
# shows how many subjects drop out early, how many are fully followed, and
# whether the dataset is longitudinally shallow or deep.
visit_counts = visits_per_subject.value_counts().sort_index()

plt.figure(figsize=(6, 4))
ax = sns.barplot(x=visit_counts.index, y=visit_counts.values, palette="Blues_d")

# Bars sit at categorical positions 0..k-1 in the (sorted) index order.
for position, n_subjects in enumerate(visit_counts.values):
    ax.text(
        position,
        n_subjects + 0.5,
        str(n_subjects),
        ha="center",
        va="bottom",
        fontsize=10,
    )

ax.set_xlabel("Number of Visits")
ax.set_ylabel("Number of Subjects")
ax.set_title("Longitudinal Follow-up Depth")
ax.set_ylim(0, visit_counts.values.max() * 1.1)
plt.show()


In [None]:
# Subjects contributing at each visit number. This shows where dropout
# happens and which visits are well populated.
retention = (
    scheduled_df.groupby("Visit_Num", observed=True)["Subject ID"]
    .agg("nunique")
    .sort_index()
)

plt.figure(figsize=(7, 4))
plt.plot(retention.index, retention.values, marker="o")

# Label each point with its subject count.
for visit_num, n_subjects in retention.items():
    plt.text(
        visit_num,
        n_subjects + 0.5,
        str(n_subjects),
        ha="center",
        va="bottom",
        fontsize=9,
    )

plt.xlabel("Visit Number")
plt.ylabel("Number of Active Subjects")
plt.title("Subject Retention Over Visits")
plt.grid(True)
plt.ylim(0, retention.max() * 1.1)
plt.show()


In [None]:
# One Age histogram per distinct visit label, laid out on a grid.
df["Visit_Str"] = df["Visit"].astype(str)


def _visit_sort_key(label):
    """Return a key that orders visit labels numerically.

    Any label that parses as a number sorts by numeric value. That includes
    "2" and "10", and also float-formatted labels such as "2.0", which
    appear when Visit was float-typed before the string cast. "nan" comes
    next, and any other label sorts last, alphabetically. A plain int(x)
    key raises on "2.0"/"nan", and falling back to string order would put
    "10.0" before "2.0".
    """
    try:
        value = float(label)
    except ValueError:
        return (2, 0.0, label)
    if np.isnan(value):
        return (1, 0.0, label)
    return (0, value, label)


sorted_visits = sorted(df["Visit_Str"].unique(), key=_visit_sort_key)

# Grid layout: a fixed number of columns, as many rows as needed.
n_visits = len(sorted_visits)
n_cols = 4  # change if you want wider or narrower layout
n_rows = (n_visits + n_cols - 1) // n_cols

# Shared x-range so the panels are directly comparable, computed once.
# Skipped when Age is entirely missing, because set_xlim rejects NaN limits.
age_min, age_max = df["Age"].min(), df["Age"].max()
share_age_range = pd.notna(age_min) and pd.notna(age_max)

plt.figure(figsize=(4 * n_cols, 3 * n_rows))

for i, visit_label in enumerate(sorted_visits):
    ax = plt.subplot(n_rows, n_cols, i + 1)
    visit_df = df[df["Visit_Str"] == visit_label]

    sns.histplot(visit_df["Age"].dropna(), bins=20, kde=False, ax=ax)
    ax.set_title(f"Visit {visit_label}")
    ax.set_xlabel("Age")
    ax.set_ylabel("Count")
    if share_age_range:
        ax.set_xlim(age_min, age_max)

plt.tight_layout()
plt.suptitle("Age Distribution at Each Visit", y=1.02, fontsize=16)
plt.show()