In [None]:
!pip install torch_geometric
!pip install dipy
!pip install boto3

In [None]:
from sklearn.metrics import pairwise_distances
import os
from dipy.io import read_bvals_bvecs
from dipy.core.gradients import gradient_table
from dipy.reconst.dti import TensorModel
from dipy.io.image import load_nifti
from dipy.reconst.csdeconv import auto_response_ssst, ConstrainedSphericalDeconvModel
from dipy.direction import ProbabilisticDirectionGetter
from dipy.data import default_sphere
from dipy.tracking import utils
from dipy.tracking.local_tracking import LocalTracking
from dipy.io.image import save_nifti
from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion
from tqdm import tqdm
import itertools
import pickle
from dipy.tracking import Streamlines
from dipy.segment.clustering import QuickBundles
import boto3
from collections import defaultdict
import torch
from torch_geometric.data import HeteroData
import nibabel as nib
import numpy as np
from dipy.align.imaffine import transform_centers_of_mass
from scipy.ndimage import affine_transform
import pandas as pd
from pathlib import Path
from dipy.tracking.streamline import length, values_from_volume



In [None]:
# Mount Google Drive at /content/drive so later cells can read and write files
# under /content/drive/MyDrive. This only works inside Google Colab
# (google.colab is not importable elsewhere).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# S3 client for the HCP open-access bucket.
# Credentials are read from the environment instead of being pasted into the
# notebook, where they would be saved (and possibly shared) with the .ipynb.
# Set AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY beforehand, e.g. via
# `%env AWS_ACCESS_KEY_ID=...` in a cell that is not saved.
# If they are unset, the values are None and boto3 falls back to its default
# credential chain.
aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID')
aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY')
session = boto3.Session(
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key,
    region_name='us-east-1'
)

bucket = 'hcp-openaccess'


s3 = session.client('s3')

In [None]:
# Farthest Point downsampling
def _direction_distance(dirs, ref, antipodal):
    """Angular dissimilarity between each row of the unit-vector array `dirs` and unit vector `ref`.

    With `antipodal=True` returns 1 - |cos|, so g and -g count as the same direction.
    With `antipodal=False` returns plain cosine distance 1 - cos.
    """
    cos = dirs @ ref
    return 1.0 - (np.abs(cos) if antipodal else cos)


def downsample(subj_dir, reduction_factor=3, seed=None, antipodal=True):
    """Keep a spread-out subset of diffusion directions and save the reduced DWI.

    Reads `bvals`, `bvecs` and `data.nii.gz` from `subj_dir`. Keeps every volume
    with 0 <= b <= 100, plus `n_dw // reduction_factor` (at least 1) of the
    b > 100 volumes. These are chosen by greedy farthest-point sampling of their
    gradient directions, starting from a random direction.

    Writes into `subj_dir`:
        dwi_subsampled.nii.gz  selected volumes (float32)
        dwi_subsampled.bval    matching b-values
        dwi_subsampled.bvec    matching b-vectors (3 x K layout, like the input)

    Args:
        subj_dir: pathlib.Path of the subject folder.
        reduction_factor: keep roughly 1 / reduction_factor of the diffusion-weighted
            directions (default 3, as before).
        seed: optional seed for the random first pick. None means non-reproducible.
        antipodal: if True (default), measure distance as 1 - |cos|. The diffusion
            signal is symmetric under g -> -g, so a direction and its antipode are
            redundant. Plain cosine distance (False) treats them as maximally far apart.

    NOTE(review): sampling ignores b-value, so with multi-shell input the kept
    directions are not guaranteed to be balanced across shells.
    """
    bvals = np.loadtxt(subj_dir / "bvals")
    bvecs = np.loadtxt(subj_dir / "bvecs")  # indexed below as (3, n_volumes)

    dw_idx = np.where(bvals > 100)[0]
    b0_idx = np.where((bvals >= 0) & (bvals <= 100))[0]

    # Unit-normalise the diffusion-weighted directions; (N_dw, 3).
    dirs = bvecs[:, dw_idx].T
    dirs = dirs / np.maximum(np.linalg.norm(dirs, axis=1, keepdims=True), 1e-6)

    n_dw = len(dirs)
    selected = []
    if n_dw > 0:
        # The original always kept at least the random seed point.
        n_keep = max(n_dw // reduction_factor, 1)
        print(f"Keeping {n_keep} directions out of {n_dw}")

        rng = np.random.default_rng(seed)
        first = int(rng.integers(n_dw))
        selected.append(first)
        # min_dist[i] = distance from direction i to its nearest already-selected direction.
        min_dist = _direction_distance(dirs, dirs[first], antipodal)
        for _ in range(1, n_keep):
            nxt = int(np.argmax(min_dist))
            selected.append(nxt)
            min_dist = np.minimum(min_dist, _direction_distance(dirs, dirs[nxt], antipodal))
    else:
        print("No diffusion-weighted volumes (b > 100) found; keeping only low-b volumes")

    selected_idx_full = list(b0_idx) + [dw_idx[i] for i in selected]

    img = nib.load(subj_dir / "data.nii.gz")
    # float32 and no caching: the full 4-D DWI is large, and a cached float64 copy
    # would stay alive on `img` for the rest of the function.
    data = img.get_fdata(caching='unchanged', dtype=np.float32)
    subset_data = data[..., selected_idx_full]
    del data

    subset_bvals = bvals[selected_idx_full]
    subset_bvecs = bvecs[:, selected_idx_full]

    subset_img = nib.Nifti1Image(subset_data, img.affine)
    nib.save(subset_img, subj_dir / "dwi_subsampled.nii.gz")
    np.savetxt(subj_dir / "dwi_subsampled.bval", subset_bvals[np.newaxis, :], fmt='%.0f')
    np.savetxt(subj_dir / "dwi_subsampled.bvec", subset_bvecs, fmt='%.6f')

In [None]:
def get_streamlines_and_FA(subj_dir, fa_threshold=0.2, max_streamlines=100000,
                           step_size=0.5, max_angle=30., random_seed=None):
    """Fit DTI and CSD on the subsampled DWI and run probabilistic tractography.

    Reads dwi_subsampled.{nii.gz,bval,bvec} and nodif_brain_mask.nii.gz from `subj_dir`.

    Side effects (files written to `subj_dir`):
        fa_map.nii.gz    FA map from the tensor fit, with the DWI affine.
        streamlines.pkl  pickled dipy Streamlines.

    Args:
        subj_dir: pathlib.Path of the subject folder.
        fa_threshold: FA value used both as the tracking stopping threshold and
            for the seed mask (FA > fa_threshold inside the brain mask). Default 0.2.
        max_streamlines: at most this many streamlines are taken from the tracker.
        step_size: tracking step size passed to LocalTracking.
        max_angle: maximum turning angle (degrees) for the direction getter.
        random_seed: optional seed forwarded to seed placement and LocalTracking.
            None (default) keeps the previous non-reproducible behaviour.

    Returns:
        (streamlines, affine): the tracked Streamlines and the DWI affine.
    """
    data, affine = load_nifti(subj_dir / "dwi_subsampled.nii.gz")
    mask, _ = load_nifti(subj_dir / "nodif_brain_mask.nii.gz")
    brain_mask = mask > 0  # explicit boolean mask for the model fits
    bvals, bvecs = read_bvals_bvecs(str(subj_dir / "dwi_subsampled.bval"),
                                    str(subj_dir / "dwi_subsampled.bvec"))
    gtab = gradient_table(bvals, bvecs=bvecs, b0_threshold=100)

    tensor_fit = TensorModel(gtab).fit(data, mask=brain_mask)
    FA = tensor_fit.fa
    save_nifti(subj_dir / "fa_map.nii.gz", FA, affine)

    stopping_criterion = ThresholdStoppingCriterion(FA, fa_threshold)
    wm_mask = (FA > fa_threshold) & brain_mask

    # NOTE(review): this is single-shell CSD (ssst). If the subsampled data still
    # holds several b>100 shells, consider selecting one shell or using a
    # multi-shell model.
    response, ratio = auto_response_ssst(gtab, data, roi_radii=10, fa_thr=0.7)
    csd_fit = ConstrainedSphericalDeconvModel(gtab, response).fit(data, mask=brain_mask)

    prob_dg = ProbabilisticDirectionGetter.from_shcoeff(
        csd_fit.shm_coeff, max_angle=max_angle, sphere=default_sphere)
    seeds = utils.random_seeds_from_mask(wm_mask, affine, random_seed=random_seed)

    streamlines_generator = LocalTracking(
        prob_dg,
        stopping_criterion,
        seeds,
        affine,
        step_size=step_size,
        random_seed=random_seed,
    )
    # The generator is lazy; islice caps the number of streamlines produced.
    streamlines = Streamlines(tqdm(itertools.islice(streamlines_generator, max_streamlines)))

    with open(str(subj_dir / "streamlines.pkl"), "wb") as f:
        pickle.dump(streamlines, f)

    return streamlines, affine

In [None]:
#aligns atlas
def align_images(subj_dir, use_center_of_mass=True):
    """Resample the aparc+aseg label volume onto the FA-map grid.

    Reads aparc+aseg.nii.gz (moving) and fa_map.nii.gz (fixed) from `subj_dir`.
    Writes aparc+aseg_in_fa_space.nii.gz, which has the FA grid shape, the FA
    affine and int32 labels.

    Resampling is delegated to dipy's AffineMap.transform with nearest-neighbour
    interpolation, so label values are copied, never blended. The previous manual
    scipy.ndimage.affine_transform call passed a moving->fixed matrix with a
    mismatched offset. affine_transform needs the output(fixed)->input(moving)
    voxel mapping, so labels were shifted.

    Args:
        subj_dir: pathlib.Path of the subject folder.
        use_center_of_mass: if True (default), pre-align the two images by
            translating their centers of mass. If False, resample purely through
            the two images' affines (identity transform in world space).
            NOTE(review): the center of mass is computed from voxel intensities,
            so label values act as weights. If both images already share a world
            space, False is probably more accurate.
    """
    moving_img = nib.load(subj_dir / "aparc+aseg.nii.gz")
    moving_data = moving_img.get_fdata()
    moving_affine = moving_img.affine

    fixed_img = nib.load(subj_dir / "fa_map.nii.gz")
    fixed_data = fixed_img.get_fdata()
    fixed_affine = fixed_img.affine

    if use_center_of_mass:
        mapping = transform_centers_of_mass(fixed_data, fixed_affine,
                                            moving_data, moving_affine)
    else:
        from dipy.align.imaffine import AffineMap
        mapping = AffineMap(None,
                            domain_grid_shape=fixed_data.shape,
                            domain_grid2world=fixed_affine,
                            codomain_grid_shape=moving_data.shape,
                            codomain_grid2world=moving_affine)

    # Samples the moving image on the fixed (domain) grid.
    resampled = mapping.transform(moving_data, interpolation='nearest')
    labels = np.rint(resampled).astype(np.int32)

    resampled_img = nib.Nifti1Image(labels, fixed_affine)
    nib.save(resampled_img, subj_dir / "aparc+aseg_in_fa_space.nii.gz")


In [None]:
def filter_streamlines_by_length(streamlines, min_len=10, max_len=250):
    """Keep only streamlines whose length is strictly between `min_len` and `max_len`.

    Lengths come from dipy's `length`, in the units of the streamline
    coordinates. The result is `streamlines` indexed by a boolean mask.
    """
    streamline_lengths = np.array(list(length(streamlines)))
    long_enough = streamline_lengths > min_len
    short_enough = streamline_lengths < max_len
    return streamlines[long_enough & short_enough]

In [None]:
def _cluster_nodes(streamlines, clusters, roi_data, roi_affine, fa_data, fa_affine):
    """Build per-cluster features and the set of non-zero ROI labels each cluster touches."""
    total_streamlines = sum(len(c) for c in clusters)
    world_to_roi_vox = np.linalg.inv(roi_affine)
    feats = []
    touched = defaultdict(set)

    for idx, cluster in enumerate(clusters):
        members = [streamlines[i] for i in cluster.indices]
        points = np.concatenate(members)

        # Feature 0: mean FA sampled along all member streamlines.
        fa_samples = np.concatenate(values_from_volume(fa_data, members, fa_affine))
        # Feature 1: share of all kept-cluster streamlines that belong to this cluster.
        feats.append([np.mean(fa_samples), len(cluster.indices) / total_streamlines])

        # Map streamline points to ROI voxel indices and drop those outside the volume.
        vox = np.round(nib.affines.apply_affine(world_to_roi_vox, points)).astype(int)
        in_bounds = np.all((vox >= 0) & (vox < np.array(roi_data.shape)), axis=1)
        vox = vox[in_bounds]
        hit = np.unique(roi_data[vox[:, 0], vox[:, 1], vox[:, 2]])
        touched[idx].update(int(label) for label in hit if label != 0)

    return feats, touched


def _roi_nodes(roi_data, fa_data, roi_label_list=None):
    """Return (labels, label->row index, feature matrix [one-hot | mean FA]) for the ROI nodes."""
    # One pass over the volume instead of one boolean mask per label.
    present, inverse = np.unique(roi_data, return_inverse=True)
    inverse = inverse.ravel()
    sums = np.bincount(inverse, weights=fa_data.ravel())
    counts = np.bincount(inverse)
    mean_fa_by_label = dict(zip(present.tolist(), sums / counts))

    if roi_label_list is None:
        labels = [label for label in present.tolist() if label != 0]
    else:
        labels = [int(label) for label in roi_label_list]
    label_to_index = {label: i for i, label in enumerate(labels)}

    # Labels requested but absent from this subject get mean FA 0.0.
    mean_fa = np.array([mean_fa_by_label.get(label, 0.0) for label in labels], dtype=float)
    feats = np.hstack([np.eye(len(labels)), mean_fa.reshape(-1, 1)])
    return labels, label_to_index, feats


def generate_graph(subj_dir, qb_threshold=10.0, min_cluster_size=20, roi_label_list=None):
    """Build a cluster<->ROI heterogeneous graph for one subject.

    Reads streamlines.pkl, aparc+aseg_in_fa_space.nii.gz and fa_map.nii.gz from `subj_dir`.
    Length-filters the streamlines, clusters them with QuickBundles and keeps
    clusters with more than `min_cluster_size` members.

    Node features:
        'cluster': [mean FA along member streamlines, fraction of kept streamlines].
        'roi': one-hot of the ROI index followed by the ROI's mean FA.

    Edges: ('cluster', 'intersects', 'roi') when any streamline point of the
    cluster falls in a voxel with that label, plus the reversed relation
    ('roi', 'intersects_rev', 'cluster').

    Args:
        subj_dir: pathlib.Path of the subject folder.
        qb_threshold: QuickBundles distance threshold (default 10.0).
        min_cluster_size: clusters must have strictly more members than this (default 20).
        roi_label_list: optional fixed, ordered sequence of label values defining the
            ROI nodes and one-hot width. When None (default), the non-zero labels
            present in this subject are used. That means the ROI node set can differ
            between subjects. Pass a shared list to align graphs across subjects.
            Edges to labels outside the list are dropped.

    Returns:
        torch_geometric HeteroData.
    """
    # pickle.load can execute code from the file; only load files this pipeline wrote.
    with open(subj_dir / "streamlines.pkl", "rb") as f:
        streamlines = Streamlines(pickle.load(f))

    roi_img = nib.load(subj_dir / "aparc+aseg_in_fa_space.nii.gz")
    roi_data = roi_img.get_fdata().astype(int)

    fa_img = nib.load(subj_dir / "fa_map.nii.gz")
    fa_data = fa_img.get_fdata()

    streamlines = filter_streamlines_by_length(streamlines)
    clusters = QuickBundles(threshold=qb_threshold).cluster(streamlines)
    kept_clusters = [c for c in clusters if len(c) > min_cluster_size]
    print("Number of clusters:", len(kept_clusters))

    cluster_feats, cluster_rois = _cluster_nodes(
        streamlines, kept_clusters, roi_data, roi_img.affine, fa_data, fa_img.affine)
    _, roi_label_to_index, roi_feats = _roi_nodes(roi_data, fa_data, roi_label_list)

    edges = [
        [cluster_idx, roi_label_to_index[label]]
        for cluster_idx, labels in cluster_rois.items()
        for label in sorted(labels)
        if label in roi_label_to_index
    ]

    data = HeteroData()
    data['cluster'].x = torch.as_tensor(
        np.asarray(cluster_feats, dtype=np.float32).reshape(-1, 2))
    data['roi'].x = torch.as_tensor(roi_feats.astype(np.float32))

    # An empty list would give a 1-D tensor, so build an explicit (2, 0) edge index.
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
    data[('cluster', 'intersects', 'roi')].edge_index = edge_index
    data[('roi', 'intersects_rev', 'cluster')].edge_index = edge_index.flip(0)

    return data



In [None]:
# Per-subject pipeline: download -> subsample -> track -> align atlas -> build graph.
# Subjects listed in log_file are skipped, so the loop can be resumed after a crash.

required_files = [
    "data.nii.gz",
    "bvals",
    "bvecs",
    "nodif_brain_mask.nii.gz",
    "aparc+aseg.nii.gz"
]

# Large intermediate files removed once a subject's graph is saved.
delete_files = [
    "data.nii.gz",
    "bvals",
    "bvecs",
    "nodif_brain_mask.nii.gz",
    "dwi_subsampled.nii.gz",
    "dwi_subsampled.bval",
    "dwi_subsampled.bvec"
]

save_dir = Path('/content/drive/MyDrive/HCP Data')
raw_dir = save_dir / 'raw_data'
log_file = save_dir / 'logs/processed_subjects.txt'

raw_dir.mkdir(parents=True, exist_ok=True)
log_file.parent.mkdir(parents=True, exist_ok=True)

if log_file.exists():
    with open(log_file) as f:
        processed = set(line.strip() for line in f)
else:
    processed = set()


df = pd.read_csv(save_dir / 'behavioral_data.csv')
filtered_df = df[df['3T_dMRI_Compl'] == True]
subject_list = filtered_df['Subject'].astype(str).tolist()


for subj in subject_list:
    if subj in processed:
        continue

    print(f"Processing subject: {subj}")
    subj_dir = raw_dir / subj
    subj_dir.mkdir(exist_ok=True)

    for fname in required_files:
        # NOTE(review): in the HCP layout aparc+aseg.nii.gz is usually under
        # T1w/ rather than T1w/Diffusion/. Confirm this key exists for that file.
        key = f"HCP_1200/{subj}/T1w/Diffusion/{fname}"
        s3.download_file(bucket, key, str(subj_dir / fname))

    downsample(subj_dir)

    # Writes fa_map.nii.gz and streamlines.pkl. The returned values are not
    # needed here, so they are not kept alive for the rest of the iteration.
    get_streamlines_and_FA(subj_dir)

    align_images(subj_dir)

    graph = generate_graph(subj_dir)

    label = df.loc[df['Subject'] == int(subj), 'Gender'].values[0]

    # The per-label output folder may not exist yet; torch.save does not create it.
    out_dir = save_dir / "gender_graphs" / str(label)
    out_dir.mkdir(parents=True, exist_ok=True)
    torch.save(graph, out_dir / f"{subj}.pt")

    for fname in delete_files:
        file_path = subj_dir / fname
        # Truncate before deleting. Presumably this keeps the Google Drive trash
        # copy at 0 bytes so it does not count against quota (confirm). Opening
        # in 'wb' also creates a missing file, so os.remove never fails here.
        open(file_path, 'wb').close()
        os.remove(file_path)

    # Record completion only after the graph is saved and cleanup has finished.
    with open(log_file, "a") as f:
        f.write(subj + "\n")


