This notebook assumes that the script *aging_diffs.py* has already been executed.

### Loading

In [None]:
from pathlib import Path
from ants import image_read, image_list_to_matrix, matrix_to_images

def keep_negative_values(imgs_list):
    """
    Return new images that hold only the magnitude of the negative voxels.

    For every image in ``imgs_list`` the array obtained from ``img.numpy()``
    has its strictly positive entries set to zero, the absolute value of the
    remainder is taken, and the result is wrapped with ``img.new_image_like``.
    The output list has the same length and order as the input.
    """
    def _negative_magnitude(img):
        voxels = img.numpy()
        voxels[voxels > 0] = 0
        return img.new_image_like(abs(voxels))

    return [_negative_magnitude(img) for img in imgs_list]

# Directory under which the figures and NMF outputs of this notebook are written
save_path = Path('aging') / 'decomposition'
# Mask image handed to image_list_to_matrix / matrix_to_images further down
mask_img = image_read('MNI152_T1_1mm_brain_mask.nii.gz')

In [None]:
# Gather the *.nii.gz files of the 'aged' and 'rejuvenated' sub-folders and read them all
age_changes_dir = Path('evaluation', 'general', 'test', 'age_invariant', 'e100', 'age_changes')
aged_diffs = [*(age_changes_dir / 'aged').glob('*.nii.gz')]
rejuvenated_diffs = [*(age_changes_dir / 'rejuvenated').glob('*.nii.gz')]
diffs = [*aged_diffs, *rejuvenated_diffs]
imgs_list = [image_read(str(nii_path)) for nii_path in diffs]

### PCA

In [None]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import seaborn as sns
import matplotlib.pyplot as plt

# Flatten the images into a 2-D matrix restricted to the mask
# (per the ANTsPy docs: one row per image).
imgs_matrix = image_list_to_matrix(imgs_list, mask_img)
# The image list is dropped as soon as the matrix exists (presumably to limit
# memory use); any later cell that needs the images has to load them again.
del imgs_list
# with_std=False: only the per-column mean is removed, no variance scaling
scaler = StandardScaler(with_std=False)
centered_data = scaler.fit_transform(imgs_matrix)
del imgs_matrix
n_components = 30
pca = PCA(n_components=n_components)
# Scores of every sample on the fitted components
principal_components = pca.fit_transform(centered_data)
# Fraction of the total variance attributed to each component
explained_variance = pca.explained_variance_ratio_

#### Plot variance explained

In [None]:
import numpy as np

# Bar + line plot of the explained-variance ratio of the leading principal components.
sns.set_theme()
sns.set_style('white')
fig, ax = plt.subplots(figsize=(6, 7))
n_components_to_plot = n_components

# 1-based component indices on the x-axis
x_positions = np.arange(1, n_components_to_plot + 1)
variances_to_plot = explained_variance[:n_components_to_plot]

bar_color = sns.color_palette("deep")[0]
line_color = "#143A80"

# Bars touch each other (width=1.0) and are separated by a thin white edge
bars = ax.bar(x_positions, variances_to_plot, color=bar_color, width=1.0, alpha=0.9,
              edgecolor='white', linewidth=0.5)

# Overlay the same values as a line with small markers
ax.plot(x_positions, variances_to_plot, color=line_color, alpha=0.9, linewidth=2,
        marker='o', markersize=3, markerfacecolor=line_color)

# Keep only the left and bottom spines
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
for side in ('left', 'bottom'):
    ax.spines[side].set_linewidth(0.8)
    ax.spines[side].set_color('#333333')

ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')

# x-axis: tick at 1 and then every 5th component, derived from the number of
# plotted components instead of being hard-coded for 30
ax.set_xlim(0.5, n_components_to_plot + 0.5)
ax.set_xticks([1, *range(5, n_components_to_plot + 1, 5)])
# ax.set_xlabel('Principal Component', fontsize=15, fontweight='bold', color='#333333')

# y-axis: percentage ticks; coarser 2% steps when the largest ratio exceeds 8%
max_variance = max(variances_to_plot)
if max_variance > 0.08:
    # The 8% tick was missing here, leaving a gap between 6% and 10%
    y_ticks = [0, 0.02, 0.04, 0.06, 0.08, 0.10]
    y_labels = ['0%', '2%', '4%', '6%', '8%', '10%']
else:
    y_ticks = [0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06]
    y_labels = ['0%', '1%', '2%', '3%', '4%', '5%', '6%']

ax.set_yticks(y_ticks)
ax.set_yticklabels(y_labels)
ax.set_ylim(0, max_variance)
# ax.set_ylabel('Explained Variance Ratio', fontsize=15, fontweight='bold', color='#333333')

# ax.grid(True, axis='y', alpha=0.3, linestyle='-', linewidth=0.5)
ax.set_axisbelow(True)

ax.tick_params(axis='both', which='major', labelsize=17, colors='#333333')
ax.tick_params(axis='x', which='major', length=4, width=1.0)
ax.tick_params(axis='y', which='major', length=4, width=1.0)

# Annotate the variance explained jointly by the plotted components
cumulative_variance = np.cumsum(explained_variance[:n_components_to_plot])
ax.text(0.52, 0.98, f'First {n_components_to_plot} PCs: {cumulative_variance[-1]:.1%}',
        transform=ax.transAxes, fontsize=18, verticalalignment='top',
        horizontalalignment='right', bbox=dict(boxstyle='round,pad=0.3',
        facecolor='white', alpha=1.0, edgecolor="#8b8b8b"))

# parents=True: save_path is nested, so its parent directory may not exist yet
save_path.mkdir(parents=True, exist_ok=True)
fig.patch.set_facecolor('white')
plt.tight_layout()

# Save the figure with a transparent background
fig.savefig(save_path / 'explained_variance.png', dpi=300, bbox_inches='tight',
            transparent=True)

plt.show()

# Summary statistics. The ndarray .sum() method is used so the result does not
# depend on whether the name `sum` currently refers to the builtin or to numpy.sum.
print(f'Explained variance ratio: {explained_variance}')
print(f'Total explained variance: {explained_variance.sum():.3f}')
print(f'First 5 components explain: {explained_variance[:5].sum():.1%}')
print(f'First 10 components explain: {explained_variance[:10].sum():.1%}')

### NMF

#### Keep negative values only

In [None]:
# `imgs_list` is deleted in the PCA cell once its matrix has been built, so the
# images are read again from `diffs` here instead of relying on that name.
imgs_list = keep_negative_values([image_read(str(img)) for img in diffs])
neg_diffs = image_list_to_matrix(imgs_list, mask_img)
del imgs_list

#### Compute NMF

In [None]:
from sklearn.decomposition import NMF
from numpy import save

def nmf_decomposition(imgs_matrix, mask_img, n_components=3, filename='nmf_neg'):
    """
    Factorise ``imgs_matrix`` with NMF and write the results to disk.

    The model is created with ``init='random'`` and ``random_state=0``. The
    components (H) are written as ``.nii.gz`` volumes through
    ``save_3d_components`` and, like the weights (W), as ``.npy`` arrays in
    ``save_path / '<n_components>_components'``. Shapes, component count and
    reconstruction error are printed. Nothing is returned.
    """
    model = NMF(n_components=n_components, init='random', random_state=0)
    weights = model.fit_transform(imgs_matrix)
    components = model.components_

    print(f'W shape: {weights.shape}', f'H shape: {components.shape}', sep='\n')
    print(f'Number of components: {model.n_components_}')
    print(f'NMF reconstruction error: {model.reconstruction_err_}')

    target_dir = save_path / f'{n_components}_components'
    target_dir.mkdir(parents=True, exist_ok=True)
    print(f'Saving NMF components to {target_dir}')

    save_3d_components(components, mask_img, target_dir, filename)
    for suffix, array in (('components', components), ('weights', weights)):
        save(target_dir / f'{filename}_{suffix}.npy', array)

def save_3d_components(H, mask_img, out_path, filename):
    """
    Write each row of ``H`` as its own image file.

    Row ``k`` (0-based) is converted with ``matrix_to_images`` using
    ``mask_img`` and saved as ``<filename><k+1>_aging.nii.gz`` in ``out_path``.
    """
    for number, row in enumerate(H, start=1):
        volume = matrix_to_images(row[None, :], mask_img)[0]
        target = out_path / f'{filename}{number}_aging.nii.gz'
        volume.to_filename(str(target))


# Run the 3-component factorisation on the negative-only matrix
nmf_decomposition(imgs_matrix=neg_diffs, mask_img=mask_img, n_components=3)

#### Threshold components

In [None]:
import nibabel as nib
from numpy import stack, any, argmax, where, sum, sqrt, newaxis


def create_thresholded_binary_maps(comp_imgs, threshold=0.001, output_file='component_{id}_binary_map.nii.gz'):
    """
    Write one binary "winner-takes-all" map per component.

    A voxel is set to 1 in the map of component ``k`` when component ``k``
    holds the largest value at that voxel (first one on ties) and at least one
    component exceeds ``threshold`` there; every other voxel is 0.

    Parameters
    ----------
    comp_imgs : sequence of nibabel images
        Component volumes; they are stacked along a new last axis, so they
        must all share the same shape.
    threshold : float
        Minimum value that at least one component must exceed at a voxel.
    output_file : str
        Path template with an ``{id}`` placeholder, filled with the 1-based
        component number.

    Returns
    -------
    list of str
        The paths that were written, in component order.
    """
    all_comps = stack([img.get_fdata() for img in comp_imgs], axis=-1)
    # Method forms are used so the result does not depend on `any` / `argmax`
    # being numpy's versions rather than the builtins.
    above_threshold = (all_comps > threshold).any(axis=-1)
    winner_idx = all_comps.argmax(axis=-1)

    written = []
    for i, img in enumerate(comp_imgs):
        binary_vol = where(above_threshold & (winner_idx == i), 1, 0)
        target = output_file.format(id=i + 1)
        # Affine and header are taken from the matching component image
        nib.save(nib.Nifti1Image(binary_vol, img.affine, img.header), target)
        written.append(target)
    return written


def normalize_components(component_files):
    """
    Load component volumes and scale each one to unit L2 norm.

    Parameters
    ----------
    component_files : iterable of paths
        Files readable by ``nib.load``; the volumes are stacked along a new
        last axis and the norm is taken over the first three axes, so they
        are expected to be 3-D and of identical shape.

    Returns
    -------
    list of nib.Nifti1Image
        One image per input file, in input order, reusing the affine and
        header of the file it came from. A component that is entirely zero
        is returned unchanged (all zeros).
    """
    imgs = [nib.load(path) for path in component_files]
    all_comps = stack([img.get_fdata() for img in imgs], axis=-1)

    # One norm per component (last axis)
    norms = sqrt((all_comps ** 2).sum(axis=(0, 1, 2)))
    # A zero norm would turn the whole component into NaN (0 / 0), and NaN then
    # propagates into any later comparison or argmax; divide by 1 instead.
    safe_norms = where(norms > 0, norms, 1.0)
    normalized_comps = all_comps / safe_norms  # broadcasts over the last axis

    return [
        nib.Nifti1Image(normalized_comps[..., i], img.affine, img.header)
        for i, img in enumerate(imgs)
    ]

components_path = save_path / '3_components'
# Path.glob yields entries in no guaranteed order. Sorting keeps the i-th image,
# and so 'component_{i}_binary_map.nii.gz', aligned with 'nmf_neg{i}_aging.nii.gz'.
# A plain name sort is enough for the single-digit indices used here.
components = sorted(components_path.glob('nmf_neg*_aging.nii.gz'))
normalized_imgs = normalize_components(components)
thresholded_output = components_path / 'component_{id}_binary_map.nii.gz'
result = create_thresholded_binary_maps(normalized_imgs, threshold=0.001, output_file=str(thresholded_output))

#### Reconstruct AD and HC with precomputed aging components

In [None]:
from numpy import load
from sklearn.decomposition import NMF

def load_age_changes(path):
    """
    Read every ``*.nii.gz`` file found directly in ``path`` plus the brain mask.

    Returns a tuple ``(imgs_list, mask_img)``: the images read with
    ``image_read`` and the mask read from ``'MNI152_T1_1mm_brain_mask.nii.gz'``.
    """
    nii_files = list(Path(path).glob('*.nii.gz'))
    mask_img = image_read('MNI152_T1_1mm_brain_mask.nii.gz')
    return [image_read(str(nii_file)) for nii_file in nii_files], mask_img

def project_onto_components(imgs_matrix, components):
    """
    Compute NMF weights of ``imgs_matrix`` for a fixed set of components.

    An unfitted ``NMF`` estimator receives ``components`` as its
    ``components_`` attribute (together with the matching ``n_components_``)
    and its ``transform`` method is then applied to ``imgs_matrix``. The
    resulting weight matrix is returned.
    """
    k = components.shape[0]
    projector = NMF(n_components=k, init='custom', random_state=0)
    # Inject the precomputed factors instead of calling fit()
    projector.n_components_ = k
    projector.components_ = components
    return projector.transform(imgs_matrix)
    
ad_age_changes = Path('evaluation') / 'diseased' / 'test' / 'age_invariant' / 'e100' / 'ad_rejuvenated'
hc_age_changes = Path('evaluation') / 'diseased' / 'test' / 'age_invariant' / 'e100' / 'hc_rejuvenated'
aging_components = load(Path('aging') / 'decomposition' / '3_components' / 'nmf_neg_components.npy')

# AD group: load -> keep negative part -> matrix -> weights on the aging components
ad_imgs, ad_mask = load_age_changes(ad_age_changes)
ad_imgs = image_list_to_matrix(keep_negative_values(ad_imgs), ad_mask)
ad_projected = project_onto_components(ad_imgs, aging_components)

# HC group: same pipeline
hc_imgs, hc_mask = load_age_changes(hc_age_changes)
hc_imgs = image_list_to_matrix(keep_negative_values(hc_imgs), hc_mask)
hc_projected = project_onto_components(hc_imgs, aging_components)

# Only the projected weights are needed from here on
del ad_imgs, hc_imgs

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from scipy import stats

def _normalize_rows(weights, group_name):
    """Scale every row to sum to 1, dropping rows whose sum is not positive."""
    weights = np.asarray(weights, dtype=float)
    if weights.ndim != 2:
        raise ValueError(f'{group_name}: expected a 2-D weight matrix, got {weights.ndim}-D')
    totals = weights.sum(axis=1, keepdims=True)
    keep = totals[:, 0] > 0
    n_dropped = int((~keep).sum())
    if n_dropped:
        # Dividing such rows by their sum would only produce NaN values
        print(f'Warning: dropping {n_dropped} row(s) with non-positive weight sum from {group_name}')
    return weights[keep] / totals[keep]


def _significance_label(p_value):
    """Map a p-value to star notation ('ns' when it is not below 0.05)."""
    for cutoff, label in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p_value < cutoff:
            return label
    return 'ns'


def compare_nmf_groups(weights_group1, weights_group2,
                       group1_name='Group 1', group2_name='Group 2',
                       save_path=None):
    """
    Compare row-normalised NMF weights between two groups.

    Each row of the weight matrices is divided by its sum, the distribution of
    every component is drawn as a split violin plot (one half per group), and
    a two-sided Mann-Whitney U test is run per component. Significance markers
    are placed above the violins and a text summary is printed.

    Parameters
    ----------
    weights_group1, weights_group2 : array-like, shape (n_samples, n_components)
        Weight matrices; both must have the same number of columns. Any
        number of components is supported. Rows whose sum is not positive
        are dropped (with a printed warning) because they cannot be normalised.
    group1_name, group2_name : str
        Labels used in the legend and in the printed summary.
    save_path : path-like, optional
        Directory in which 'nmf_violinplot_comparison.png' is saved. Nothing
        is saved when it is None (or otherwise falsy).

    Raises
    ------
    ValueError
        If an input is not 2-D or the groups differ in number of components.
    """
    W1_norm = _normalize_rows(weights_group1, group1_name)
    W2_norm = _normalize_rows(weights_group2, group2_name)

    n_components = W1_norm.shape[1]
    if W2_norm.shape[1] != n_components:
        raise ValueError(
            f'Both groups need the same number of components, got '
            f'{n_components} and {W2_norm.shape[1]}')
    component_labels = [f'Component {i+1}' for i in range(n_components)]

    # Long-form table for seaborn: one row per (sample, component)
    frames = []
    for i, label in enumerate(component_labels):
        for values, group in ((W1_norm[:, i], group1_name), (W2_norm[:, i], group2_name)):
            frames.append(pd.DataFrame(
                {'Component': label, 'Normalized Weight': values, 'Group': group}))
    df = pd.concat(frames, ignore_index=True)

    # Run each test once; the results feed both the plot and the printed summary
    tests = [
        stats.mannwhitneyu(W1_norm[:, i], W2_norm[:, i], alternative='two-sided')
        for i in range(n_components)
    ]

    fig, ax = plt.subplots(figsize=(7, 6))

    # cut=0 keeps the density estimate within the observed data range
    sns.violinplot(data=df, x='Component', y='Normalized Weight', hue='Group',
                   split=True, palette=['steelblue', 'coral'],
                   cut=0, inner='quartile', width=0.5,
                   linecolor='black', linewidth=1.0,
                   ax=ax)

    ax.set_ylabel('Normalized Weight', fontsize=11)
    ax.set_xlabel('')
    ax.tick_params(axis='y', labelsize=12)
    ax.tick_params(axis='x', labelsize=12)
    ax.grid(axis='y', alpha=0.3)
    ax.legend(loc='upper right', fontsize=12)

    # Significance marker slightly above the tallest value of each component
    for i, (_, p_value) in enumerate(tests):
        y_max = max(W1_norm[:, i].max(), W2_norm[:, i].max())
        ax.text(i, y_max * 1.02, _significance_label(p_value), ha='center', fontsize=14)

    plt.tight_layout()
    if save_path:
        # Path(...) lets callers pass either a str or a Path
        fig.savefig(Path(save_path) / 'nmf_violinplot_comparison.png',
                    dpi=300, bbox_inches='tight', transparent=True)
    plt.show()

    print("\n" + "="*70)
    print("STATISTICAL COMPARISON SUMMARY")
    print("="*70)

    for i, (stat, p_value) in enumerate(tests):
        g1, g2 = W1_norm[:, i], W2_norm[:, i]
        print(f"\nComponent {i+1}:")
        print(f"  {group1_name}: mean={g1.mean():.3f}, std={g1.std():.3f}")
        print(f"  {group2_name}: mean={g2.mean():.3f}, std={g2.std():.3f}")

        # Effect size with the root-mean-square of the two (population) standard
        # deviations as denominator; this equals the pooled SD only for
        # equally sized groups.
        rms_std = np.sqrt((g1.std()**2 + g2.std()**2) / 2)
        cohens_d = (g1.mean() - g2.mean()) / rms_std if rms_std > 0 else float('nan')
        print(f"  Mann-Whitney U test: U={stat:.2f}, p={p_value:.4f}")
        print(f"  Effect size (Cohen's d): {cohens_d:.3f}")

    print("="*70 + "\n")

# AD vs. HC comparison of the projected weights; the figure goes to save_path
compare_nmf_groups(
    weights_group1=ad_projected,
    weights_group2=hc_projected,
    group1_name="Alzheimer's Disease",
    group2_name='Healthy Controls',
    save_path=save_path,
)