**Setup, Imports, and Function Definitions**

In [1]:
import environ
import os

# If the kernel's working directory ends with 'notebooks', move up one level so
# that relative paths used later in this notebook (the
# 'static/images/MNI152_T1_2mm_brain_mask.nii.gz' mask and the output
# .nii.gz filename) resolve against the parent directory.
# Placed before the project imports below, presumably so the local
# `sqlalchemy_utils` package is importable from the new cwd — TODO confirm.
# NOTE(review): endswith() also matches any directory name that merely ends
# in 'notebooks' (e.g. 'my_notebooks').
if os.getcwd().endswith('notebooks'):
    os.chdir('..')

from sqlalchemy_utils.db_session import get_session
from sqlalchemy_utils.models_sqlalchemy_orm import Subject, Symptom, Domain, Subdomain
import numpy as np
import nibabel as nib
from PIL import Image
import gzip
from io import BytesIO
import boto3
from botocore.client import Config
from nilearn.maskers import NiftiMasker

# Environment reader; each env('NAME') call below reads an environment
# variable once, when this cell runs.
env = environ.Env()

# DigitalOcean Spaces credentials, bucket name and endpoint. No default is
# given, so env(...) raises if any of these variables is missing.
# NOTE(review): get_s3_client() re-reads the credentials/endpoint and
# fetch_from_s3() re-reads the bucket name from the environment instead of
# using these constants.
DO_ACCESS_KEY_ID = env('DO_SPACES_ACCESS_KEY_ID')
DO_SECRET_ACCESS_KEY = env('DO_SPACES_SECRET_ACCESS_KEY')
DO_STORAGE_BUCKET_NAME = env('DO_SPACES_BUCKET_NAME')
DO_S3_ENDPOINT_URL = env('DO_SPACES_ENDPOINT_URL')
# Spaces region; falls back to 'nyc3' when DO_SPACES_LOCATION is unset.
DO_SPACES_LOCATION = env('DO_SPACES_LOCATION', default='nyc3')
# Key prefix inside the bucket; fetch_from_s3() joins it with each file path.
DO_LOCATION = env('DO_LOCATION')
# Bucket hostname of the form <bucket>.<region>.digitaloceanspaces.com.
# Not referenced elsewhere in the code shown here.
DO_S3_CUSTOM_DOMAIN = f'{DO_STORAGE_BUCKET_NAME}.{DO_SPACES_LOCATION}.digitaloceanspaces.com'

def determine_filetype(filepath: str) -> str:
    """
    Determine a file's type from its path.

    Explicit file extensions are checked first, so a recognised suffix always
    wins (e.g. ``'sub-01_rh.png'`` is ``'png'``, not ``'surf'``, and
    ``'curvature.csv'`` is ``'csv'``, not ``'freesurfer'``). Only when no known
    extension matches are the looser substring heuristics for NIfTI, GIFTI and
    FreeSurfer-style names applied, in a fixed priority order. Matching is
    case-insensitive.

    Args:
        filepath: File name or path (local path or object-store key).

    Returns:
        A short filetype label such as ``'nii.gz'``, ``'nii'``, ``'npy'``,
        ``'png'`` or ``'jpg'`` (``.jpeg`` files are also reported as ``'jpg'``),
        or ``'unknown'`` if nothing matches.
    """
    path = filepath.lower()

    # 1) Strict suffix checks; the first match wins. '.nii.gz' / '.trk.gz'
    #    are listed explicitly so they are not confused with other types.
    suffix_types = (
        ('nii.gz', ('.nii.gz',)),
        ('trk.gz', ('.trk.gz',)),
        ('nii', ('.nii',)),
        ('npy', ('.npy',)),
        ('npz', ('.npz',)),
        ('gii', ('.gii',)),
        ('mgz', ('.mgz',)),
        ('label', ('.label',)),
        ('annot', ('.annot',)),
        ('png', ('.png',)),
        ('jpg', ('.jpg', '.jpeg')),
        ('trk', ('.trk',)),
        ('edge', ('.edge',)),
        ('mat', ('.mat',)),
        ('gif', ('.gif',)),
        ('pdf', ('.pdf',)),
        ('txt', ('.txt',)),
        ('csv', ('.csv',)),
        ('xls', ('.xls',)),
        ('xlsx', ('.xlsx',)),
    )
    for filetype, suffixes in suffix_types:
        if path.endswith(suffixes):
            return filetype

    # 2) Substring heuristics for names without a recognised extension
    #    (e.g. FreeSurfer surfaces like 'lh.pial' or subject dirs like 'bert').
    substring_types = (
        ('nii', ('.nii',)),
        ('gii', ('.gii',)),
        ('surf', ('lh.', 'rh.', '.surf')),
        ('label', ('.label',)),
        ('annot', ('.annot',)),
        ('fsaverage', ('fsaverage',)),
        ('freesurfer', ('aparc', 'aseg', 'bert', 'curv', 'sulc', 'thickness')),
    )
    for filetype, needles in substring_types:
        if any(needle in path for needle in needles):
            return filetype

    return 'unknown'

def get_s3_client():
    """
    Build a boto3 S3 client pointed at DigitalOcean Spaces.

    Connection settings are read from the environment on every call:
    DO_SPACES_LOCATION (region, default 'nyc3'), DO_SPACES_ENDPOINT_URL,
    DO_SPACES_ACCESS_KEY_ID and DO_SPACES_SECRET_ACCESS_KEY. The client uses
    virtual-hosted-style bucket addressing.
    """
    settings = environ.Env()
    spaces_session = boto3.session.Session()

    addressing = Config(s3={'addressing_style': 'virtual'})
    region = settings('DO_SPACES_LOCATION', default='nyc3')
    endpoint = settings('DO_SPACES_ENDPOINT_URL')
    access_key = settings('DO_SPACES_ACCESS_KEY_ID')
    secret_key = settings('DO_SPACES_SECRET_ACCESS_KEY')

    return spaces_session.client(
        's3',
        config=addressing,
        region_name=region,
        endpoint_url=endpoint,
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    )

def fetch_from_s3(filepath):
    """
    Fetch a file from DigitalOcean Spaces and load it according to its type.

    The object key is ``DO_LOCATION`` joined with ``filepath`` using ``/``
    separators; S3 keys are always POSIX-style, regardless of the local OS.

    Parameters:
    filepath (str): Path to the file within the bucket, relative to DO_LOCATION.

    Returns:
    Various: ``nib.Nifti1Image`` for ``.nii`` / ``.nii.gz``, the result of
    ``np.load`` for ``.npy``, a ``PIL.Image.Image`` for ``.png`` / ``.jpg`` /
    ``.jpeg``.

    Raises:
    Exception: If the object cannot be fetched; the underlying error is
        chained as ``__cause__``.
    ValueError: If the file type is not one of the supported types above.
    """
    import posixpath

    env = environ.Env()
    bucket_name = env('DO_SPACES_BUCKET_NAME')

    # Get S3 client
    s3_client = get_s3_client()

    # S3 keys always use '/', so join with posixpath rather than os.path
    # (which would insert '\\' on Windows). On POSIX this is identical to
    # os.path.join, including dropping the prefix for an absolute filepath.
    key = posixpath.join(DO_LOCATION, filepath)

    try:
        response = s3_client.get_object(Bucket=bucket_name, Key=key)
        file_data = response['Body'].read()
    except Exception as e:
        # Keep the original error attached for debugging.
        raise Exception(f"Error fetching file from S3: {str(e)}") from e

    # Determine file type and process accordingly
    extension = determine_filetype(filepath)

    if extension in ('nii.gz', 'nii'):
        # Decompress .nii.gz up front so nibabel reads header and data from a
        # plain in-memory buffer instead of a GzipFile stream (whose backward
        # seeks may require re-decompressing from the start).
        raw = gzip.decompress(file_data) if extension == 'nii.gz' else file_data
        fh = nib.FileHolder(fileobj=BytesIO(raw))
        return nib.Nifti1Image.from_file_map({'header': fh, 'image': fh})

    if extension == 'npy':
        # NOTE(review): allow_pickle=True can execute arbitrary code from the
        # file; only acceptable if the bucket contents are trusted.
        return np.load(BytesIO(file_data), allow_pickle=True)

    if extension in ('png', 'jpg', 'jpeg'):
        return Image.open(BytesIO(file_data))

    raise ValueError(f"Unsupported file type: {extension}")


def get_subjects_with_symptom_and_connectivity_files(session, symptom_name):
    """
    Return every subject that has the named symptom and at least one connectivity file.

    Args:
        session (Session): Open SQLAlchemy session used to run the query.
        symptom_name (str): Value that ``Symptom.name`` must equal.

    Returns:
        List[Subject]: The matching ``Subject`` instances from ``Query.all()``.
    """
    query = session.query(Subject).join(Subject.symptoms)
    # Both criteria are AND-ed: the symptom name matches, and the subject's
    # connectivity_files relationship is non-empty (.any() -> EXISTS).
    query = query.filter(
        Symptom.name == symptom_name,
        Subject.connectivity_files.any(),
    )
    return query.all()

def get_subjects_with_subdomain_and_connectivity_files(session, subdomain_name):
    """
    Return every subject linked to the named subdomain that has at least one connectivity file.

    Args:
        session (Session): Open SQLAlchemy session used to run the query.
        subdomain_name (str): Value that ``Subdomain.name`` must equal.

    Returns:
        List[Subject]: The matching ``Subject`` instances from ``Query.all()``.
    """
    query = session.query(Subject).join(Subject.subdomains)
    # Both criteria are AND-ed: the subdomain name matches, and the subject's
    # connectivity_files relationship is non-empty (.any() -> EXISTS).
    query = query.filter(
        Subdomain.name == subdomain_name,
        Subject.connectivity_files.any(),
    )
    return query.all()

def get_subjects_with_domain_and_connectivity_files(session, domain_name):
    """
    Return every subject linked to the named domain that has at least one connectivity file.

    Args:
        session (Session): Open SQLAlchemy session used to run the query.
        domain_name (str): Value that ``Domain.name`` must equal.

    Returns:
        List[Subject]: The matching ``Subject`` instances from ``Query.all()``.
    """
    query = session.query(Subject).join(Subject.domains)
    # Both criteria are AND-ed: the domain name matches, and the subject's
    # connectivity_files relationship is non-empty (.any() -> EXISTS).
    query = query.filter(
        Domain.name == domain_name,
        Subject.connectivity_files.any(),
    )
    return query.all()

# Brain mask used to vectorise the images. The path is relative to the current
# working directory.
_DEFAULT_MASK_IMG = 'static/images/MNI152_T1_2mm_brain_mask.nii.gz'


def _first_nifti_per_subject(subjects):
    """
    Fetch from S3 the first ``.nii`` / ``.nii.gz`` connectivity file of each subject.

    Subjects without any NIfTI connectivity file are skipped.
    """
    images = []
    for subject in subjects:
        for file in subject.connectivity_files:
            if file.path.endswith(('nii.gz', 'nii')):
                images.append(fetch_from_s3(f"{file.path}"))
                break  # only the first NIfTI file per subject is used
    return images


def _percent_overlap_image(fetch_subjects, name, threshold, mask_img):
    """
    Build the signed percent-overlap image for the subjects selected by ``fetch_subjects``.

    Shared implementation of the ``calculate_percent_overlap_at_threshold_for_*``
    functions.

    Args:
        fetch_subjects: Callable ``(session, name) -> list[Subject]``.
        name (str): Symptom / subdomain / domain name passed to ``fetch_subjects``.
        threshold (float): A voxel counts as positive overlap when its value is
            ``>= threshold`` and as negative overlap when ``<= -threshold``.
        mask_img: Mask image (or path) passed to ``NiftiMasker``.

    Returns:
        The image from ``NiftiMasker.inverse_transform``. At each voxel it holds
        the percentage of subjects reaching ``+threshold`` (as a positive value)
        or ``-threshold`` (as a negative value), whichever is larger in
        magnitude; ties take the negative-side value.

    Raises:
        ValueError: If none of the selected subjects has a NIfTI connectivity file.
    """
    session = get_session()
    try:
        subjects = fetch_subjects(session, name)
        nifti_images = _first_nifti_per_subject(subjects)
    finally:
        # Always release the session, even if the query or an S3 fetch fails.
        session.close()

    if not nifti_images:
        raise ValueError(f"No NIfTI connectivity files found for {name!r}")

    masker = NiftiMasker(mask_img=mask_img).fit()
    # One row per image, one column per in-mask voxel.
    image_data = np.atleast_2d(masker.transform(nifti_images))

    # Percentage of images at or beyond +threshold / -threshold per voxel.
    percent_pos = np.mean(image_data >= threshold, axis=0) * 100
    percent_neg = -np.mean(image_data <= -threshold, axis=0) * 100

    combined = np.where(np.abs(percent_pos) > np.abs(percent_neg), percent_pos, percent_neg)
    return masker.inverse_transform(combined)


def calculate_percent_overlap_at_threshold_for_symptom(symptom_name, threshold=5, *, mask_img=_DEFAULT_MASK_IMG):
    """
    Calculate the percent-overlap map at ``threshold`` for a symptom.

    Uses the first NIfTI connectivity file of every subject that has
    ``symptom_name`` and at least one connectivity file.

    Args:
        symptom_name (str): ``Symptom.name`` used to select subjects.
        threshold (float): Absolute-value threshold (default 5, i.e. "T5").
        mask_img: Brain mask for ``NiftiMasker``; defaults to the MNI152 2 mm
            mask path relative to the working directory.

    Returns:
        Signed percent-overlap image with values in [-100, 100].

    Raises:
        ValueError: If no matching subject has a NIfTI connectivity file.
    """
    return _percent_overlap_image(
        get_subjects_with_symptom_and_connectivity_files, symptom_name, threshold, mask_img
    )


def calculate_percent_overlap_at_threshold_for_subdomain(subdomain_name, threshold=5, *, mask_img=_DEFAULT_MASK_IMG):
    """
    Calculate the percent-overlap map at ``threshold`` for a subdomain.

    Uses the first NIfTI connectivity file of every subject linked to
    ``subdomain_name`` that has at least one connectivity file.

    Args:
        subdomain_name (str): ``Subdomain.name`` used to select subjects.
        threshold (float): Absolute-value threshold (default 5, i.e. "T5").
        mask_img: Brain mask for ``NiftiMasker``; defaults to the MNI152 2 mm
            mask path relative to the working directory.

    Returns:
        Signed percent-overlap image with values in [-100, 100].

    Raises:
        ValueError: If no matching subject has a NIfTI connectivity file.
    """
    return _percent_overlap_image(
        get_subjects_with_subdomain_and_connectivity_files, subdomain_name, threshold, mask_img
    )


def calculate_percent_overlap_at_threshold_for_domain(domain_name, threshold=5, *, mask_img=_DEFAULT_MASK_IMG):
    """
    Calculate the percent-overlap map at ``threshold`` for a domain.

    Uses the first NIfTI connectivity file of every subject linked to
    ``domain_name`` that has at least one connectivity file.

    Args:
        domain_name (str): ``Domain.name`` used to select subjects.
        threshold (float): Absolute-value threshold (default 5, i.e. "T5").
        mask_img: Brain mask for ``NiftiMasker``; defaults to the MNI152 2 mm
            mask path relative to the working directory.

    Returns:
        Signed percent-overlap image with values in [-100, 100].

    Raises:
        ValueError: If no matching subject has a NIfTI connectivity file.
    """
    return _percent_overlap_image(
        get_subjects_with_domain_and_connectivity_files, domain_name, threshold, mask_img
    )

**Specify Symptom Name and Threshold, Compute the Percent-Overlap Map, and Save It**

In [2]:
# Symptom to analyse and the absolute-value threshold for the overlap map.
symptom_name = 'ocd'
threshold = 5

# Derive the filename from the threshold so the two cannot drift apart
# (threshold=5 gives 'ocd_percent_overlap_at_t5.nii.gz').
output_file_name = f'{symptom_name}_percent_overlap_at_t{threshold}.nii.gz'

percent_overlap_image = calculate_percent_overlap_at_threshold_for_symptom(
    symptom_name=symptom_name, threshold=threshold
)

# Written relative to the current working directory.
percent_overlap_image.to_filename(output_file_name)