In [None]:
import nibabel as nib
from nilearn import image, plotting
import nibabel.freesurfer as fs
from nilearn.image import resample_to_img
import matplotlib.colors as mcolors
from nibabel.freesurfer import read_annot

from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

import pandas as pd

import numpy as np
import matplotlib.image as mpimg
import re
import datetime
import os
#from dotenv import load_dotenv
import flywheel

from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.utils import ImageReader

# Wall-clock time at which this cell was executed.
# NOTE: a later cell runs `from datetime import datetime`, which rebinds the name
# `datetime` to the class; this line only works because `import datetime` above is
# re-executed together with it whenever this cell is run.
today = datetime.datetime.now()


In [None]:
import os
import subprocess

# Source zshrc and get the env var you want
def get_env_from_zshrc(var_name):
    """Return the value of ``var_name`` as seen by a zsh that sourced ``~/.zshrc``.

    The variable is looked up by running ``zsh -c "source ~/.zshrc && echo $VAR"``.
    If zsh is not installed, or the shell prints nothing (for example because
    ``~/.zshrc`` could not be sourced), the value from the current process
    environment is used instead.

    Parameters
    ----------
    var_name : str
        Name of the environment variable. Must be a plain shell identifier,
        because it is interpolated into a shell command.

    Returns
    -------
    str
        The variable's value with surrounding whitespace stripped, or ``""``
        when it is not set anywhere.

    Raises
    ------
    ValueError
        If ``var_name`` is not a valid shell variable name.
    """
    # The name ends up inside a shell command line, so refuse anything that is
    # not a bare identifier rather than risk executing arbitrary text.
    if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", var_name):
        raise ValueError(f"Invalid environment variable name: {var_name!r}")

    command = f"source ~/.zshrc && echo ${var_name}"
    try:
        result = subprocess.run(['zsh', '-c', command], stdout=subprocess.PIPE, text=True)
        value = result.stdout.strip()
    except FileNotFoundError:
        # zsh is not available on this machine
        value = ""

    return value or os.environ.get(var_name, "").strip()

# Read the Flywheel API key from the user's zsh configuration
api_key = get_env_from_zshrc('FW_CLI_API_KEY')

# Connect to your Flywheel instance
fw = flywheel.Client(api_key=api_key)

# Fetch the user once; the original issued the same request twice for one line of output
current_user = fw.get_current_user()
display(f"User: {current_user.firstname} {current_user.lastname}")

In [None]:
import logging

log = logging.getLogger(__name__)

fw = flywheel.Client(api_key=api_key)

display(f"User: {fw.get_current_user().id}")

project = fw.projects.find_first('label=LONISAC')
if project is None:
    raise RuntimeError("Project 'LONISAC' was not found (or is not visible with this API key).")
display(f"Project: {project.label}")
project = project.reload()

# Find the most recent run of orbit, and get its outlier output
analyses = project.analyses

last_run_date_orbit = None
all_runs = []
run_states = []
orbit = None
for asys in analyses:
    try:
        # Reload once (the original reloaded twice per analysis) and read both
        # fields before appending so the two lists always stay aligned.
        reloaded = asys.reload()
        run_name = reloaded.gear_info.get('name')
        run_state = reloaded.job.get('state')
        all_runs.append(run_name)
        run_states.append(run_state)

        gear_name = asys.gear_info.get('name')
        created_date = asys.created

        # Only replace `orbit` when this run is at least as recent as the best
        # one seen so far; previously `orbit` was simply the last orbit analysis
        # iterated, which need not be the newest.
        if gear_name == "orbit" and (last_run_date_orbit is None or created_date >= last_run_date_orbit):
            last_run_date_orbit = created_date
            orbit = asys

    except Exception as e:
        # WARNING rather than INFO so the message is visible with default logging
        log.warning(f'Exception caught:  {e}')

if orbit is None:
    raise RuntimeError("No 'orbit' analysis found on the project; cannot locate the outliers file.")

last_run_date_orbit = last_run_date_orbit.strftime('%Y-%m-%d')
print("Date ORBIT last ran: ", last_run_date_orbit)

# NOTE(review): absolute, user-specific path; consider moving to a config cell.
download_dir = Path("/Users/Hajer/unity/fw-notebooks/QC")
download_path = None  # reset so a stale value from an earlier run is never reused
for f in orbit.files:
    if f.name.endswith('outliers.csv'):
        outlier_file = f
        download_dir.mkdir(parents=True, exist_ok=True)
        download_path = os.path.join(download_dir, outlier_file.name)
        outlier_file.download(download_path)
        break

if download_path is None:
    raise FileNotFoundError("The latest orbit analysis has no file ending in 'outliers.csv'.")

df_outliers = pd.read_csv(download_path)
df_outliers.head(3)

In [None]:

import skimage
import plotly.express as px
from skimage.transform import resize

def preprocess_nifti(nifti_path, target_height=300):
    """Load a NIfTI volume, rescale its intensities to 0-255 and resize one plane.

    The orientation is inferred from the path text: the first of ``SAG``,
    ``COR`` or ``AXI`` (case-insensitive, checked in that order) found in
    ``nifti_path`` selects which two array axes are resized and which axis is
    reported as the slicing axis.

    Parameters
    ----------
    nifti_path : str or os.PathLike
        Path of the NIfTI file; it must contain one of the orientation tokens.
    target_height : int, optional
        Size given to the second in-plane axis; the first in-plane axis is
        scaled by the same factor.

    Returns
    -------
    tuple
        ``(data, first_in_plane_axis, slice_axis)`` where ``data`` is the
        resized float array.

    Raises
    ------
    ValueError
        If no orientation token is present in ``nifti_path``.
    """
    # (token, in-plane axes, slicing axis); the order is the lookup priority.
    orientations = (
        ("SAG", (1, 2), 0),
        ("COR", (0, 2), 1),
        ("AXI", (0, 1), 2),
    )
    path_upper = str(nifti_path).upper()  # str() so pathlib.Path inputs work too
    for token, plane, ax in orientations:
        if token in path_upper:
            break
    else:
        # Previously this fell through and failed later with an unbound-variable error.
        raise ValueError(
            f"Cannot determine orientation (SAG/COR/AXI) from path: {nifti_path}"
        )

    # load data
    data = nib.load(nifti_path).get_fdata()
    data = np.nan_to_num(data)
    xyz = data.shape

    # Normalise the entire array to the 0-255 range
    min_val, max_val = np.min(data), np.max(data)
    if max_val > min_val:
        data = 255 * (data - min_val) / (max_val - min_val)
    else:
        data = np.zeros_like(data)  # Black image if constant

    # Scale the in-plane axes so the second one equals target_height
    w, h = xyz[plane[0]], xyz[plane[1]]
    scale_factor = target_height / h
    new_width = max(1, int(w * scale_factor))  # never collapse an axis to zero

    new_shape = list(xyz)
    new_shape[plane[0]], new_shape[plane[1]] = new_width, target_height
    data = skimage.transform.resize(data, new_shape, mode='constant', preserve_range=True)

    return data, plane[0], ax
        
def nifti_overlay_animation(native_path, seg_path, sub, outliers, target_height=300, alpha=0.4, cmap='viridis'):
    """Show a slice-by-slice Plotly animation of a segmentation over its native scan.

    Both volumes are loaded from disk, the native scan is rescaled to 8-bit
    grayscale, both are resized so the first array axis equals
    ``target_height``, and every slice along the third array axis becomes one
    animation frame in which labelled voxels (label > 0) are alpha-blended with
    a colour taken from ``cmap``.

    Parameters
    ----------
    native_path, seg_path : str
        Paths of the native scan and of the label volume; the two arrays must
        have identical shapes.
    sub : str
        Subject label, used only in the progress message.
    outliers : list
        Outlier ROI names, rendered in the figure title.
    target_height : int, optional
        Size of the first array axis after resizing.
    alpha : float, optional
        Weight of the label colour in the blend.
    cmap : str, optional
        Name of a matplotlib colormap.

    Raises
    ------
    ValueError
        If the two volumes do not have the same shape.
    """
    # nan_to_num so NaN voxels cannot poison the min/max normalisation below
    native_data = np.nan_to_num(nib.load(native_path).get_fdata())
    seg_data = np.nan_to_num(nib.load(seg_path).get_fdata())

    if native_data.shape != seg_data.shape:
        raise ValueError("Native and segmentation volumes must have the same shape.")

    # Normalize native scan to 0-255 (grayscale). A constant volume stays all
    # zeros instead of triggering a 0/0 division.
    native_data = native_data - np.min(native_data)
    peak = np.max(native_data)
    if peak > 0:
        native_data = (native_data / peak) * 255
    native_data = native_data.astype(np.uint8)

    # Resize so the first axis equals target_height, keeping the in-plane aspect
    scale = target_height / native_data.shape[0]
    new_shape = (target_height, max(1, int(native_data.shape[1] * scale)), native_data.shape[2])
    native_data_resized = resize(native_data, new_shape, preserve_range=True, order=1).astype(np.uint8)
    # Label maps: nearest-neighbour with anti-aliasing switched off. skimage
    # enables Gaussian anti-aliasing by default when downsampling float input,
    # which would blend neighbouring label values. int32 (not uint8) so label
    # values above 255 do not wrap around.
    seg_data_resized = np.rint(
        resize(seg_data, new_shape, preserve_range=True, order=0, anti_aliasing=False)
    ).astype(np.int32)

    colormap = plt.get_cmap(cmap)
    max_label = int(seg_data_resized.max()) if seg_data_resized.size else 0

    # Create overlay for each slice
    overlays = []
    for i in range(native_data_resized.shape[2]):
        background = np.stack([native_data_resized[:, :, i]] * 3, axis=-1)  # Convert to RGB
        mask = seg_data_resized[:, :, i]
        labelled = mask > 0

        overlay = background.copy()
        if labelled.any():
            # Pass floats in (0, 1] so labels spread over the whole colormap;
            # integer input would index the colormap table directly and give
            # small label values nearly identical colours.
            label_rgb = colormap(mask[labelled] / max_label)[:, :3] * 255  # drop alpha channel
            overlay[labelled] = ((1 - alpha) * background[labelled] + alpha * label_rgb).astype(np.uint8)

        overlays.append(overlay)

    # Convert to 4D array (frame, row, col, rgb) for plotly
    overlays_np = np.stack(overlays, axis=0)

    print(f'Loading overlaid animation for {sub}, this may take a few seconds...')

    fig = px.imshow(overlays_np, animation_frame=0, aspect='equal')
    fig.update_layout(
        title={
            'text': f"<b>⚠️ Outlier ROIs (z-score and cov):</b><br>{outliers}",
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(size=10)
        },
        autosize=True,
        height=600,
        coloraxis_showscale=False,
        xaxis=dict(showticklabels=False),
        yaxis=dict(showticklabels=False),
        dragmode=False,
    )

    # Show the plot
    fig.show()

def get_data(sub_label, ses_label, gear, input_gear, v):
    """Download the newest completed ``gear`` output and its input for one session.

    Looks up the subject and session by label in the module-level ``project``,
    selects the most recently created analysis whose label starts with ``gear``
    (``gear + "-gambas"`` when ``input_gear`` starts with ``"gambas"``) and
    whose job state is ``"complete"``, then downloads the first ``.nii.gz``
    output file and the analysis' first input into ``./QC/<subject>/<session>/``.

    Parameters
    ----------
    sub_label, ses_label : str
        Subject and session labels.
    gear : str
        Label prefix of the analysis to look for.
    input_gear : str
        Name of the gear that produced the analysis input.
    v : str
        Version of the input gear; accepted for call compatibility, not used.

    Returns
    -------
    The first entry of the selected analysis' ``inputs``, or ``None`` when the
    subject/session/analysis/output could not be found or a download failed.
    """
    subject = project.subjects.find_first(f'label="{sub_label}"')
    if subject is None:
        print(f"Subject {sub_label} not found in project, skipping.")
        return None
    subject = subject.reload()
    sub_label = subject.label

    session = subject.sessions.find_first(f'label="{ses_label}"')
    if session is None:
        print(f"Session {ses_label} not found for subject {sub_label}, skipping.")
        return None
    session = session.reload()
    ses_label = session.label
    print(gear, input_gear)

    analyses = session.analyses

    # If there are no analyses containers, we know that this gear was not run
    if len(analyses) == 0:
        print('No analysis containers')
        return None

    # Name of the file currently being handled, for the error message only.
    # (The original handler read `parcellation.name`, which could be unbound.)
    current_file = None
    try:
        if input_gear.startswith("gambas"):
            gear = gear + "-gambas"
            print("Looking for anaylyses from ", gear)
        # NOTE(review): startswith() means gear "minimorph" also matches labels
        # beginning with "minimorph-gambas" - confirm that is intended.
        matches = [asys for asys in analyses if asys.label.startswith(gear) and asys.job.get('state') == "complete"]
        print("Matches: ", len(matches), [asys.label for asys in matches])

        # If there are no matches, the gear didn't run
        if not matches:
            print(f"Did not find any matched, {gear} did not run.")
            return None

        # Newest completed run; with a single match this is simply that match.
        target = max(matches, key=lambda asys: asys.created)

        print("All files...", len(target.files), [f.name for f in target.files])
        parcellation = next((f for f in target.files if f.name.endswith('.nii.gz')), None)
        if parcellation is None:
            print(f"No .nii.gz output found in {target.label}")
            return None
        current_file = parcellation.name
        print("Found ", parcellation.name)

        download_dir = Path(f'./QC/{sub_label}/{ses_label}/')
        download_dir.mkdir(parents=True, exist_ok=True)
        download_path = os.path.join(download_dir, parcellation.name)
        parcellation.download(download_path)
        print('Downloaded parcellation ', download_path)

        # Also fetch the scan that was fed into the gear
        input_file = target.inputs[0]
        current_file = input_file.name
        print(f"N inputs for {target.label}", len(target.inputs), input_file.name)
        download_path = os.path.join(download_dir, input_file.name)
        fw.download_input_from_analysis(target.id, input_file.name, download_path)

        return input_file
    except Exception as e:
        print(f"Exception caught for {sub_label} {ses_label} {current_file}: ", e)
        return None
            

# Work from the local copy of the outlier table
df_outliers = pd.read_csv("/Users/Hajer/unity/fw-notebooks/QC/minimorph_outliers.csv")

project.label

# Download the segmentation and its input scan for every flagged session
for _, row in df_outliers.iterrows():
    gear_fields = row["input gear v"].split("/")  # "<gear>/<version>"
    input_gear, v = gear_fields[0], gear_fields[1]
    sub_label = row["subject"]
    ses_label = row["session"]
    gear = "minimorph"
    input_file = get_data(sub_label, ses_label, gear, input_gear, v)

# Questions put to the rater; the final entry is the free-text comment field
metrics = [
    "Are the main brain structures correctly segmented (e.g., gray/white matter, ventricles)? (Yes/No)",
    "Are there any major errors or artifacts? (No/Minor/Major)",
    "Overall segmentation quality (Good/Acceptable/Poor)",
    "Comments",
]


def load_ratings(RATINGS_FILE):
    """Return the ratings table stored at ``RATINGS_FILE``.

    When the file does not exist, an empty DataFrame is returned whose columns
    are the five identification fields followed by the entries of ``metrics``.
    """
    if not os.path.exists(RATINGS_FILE):
        header = ["User", "Timestamp", "Project", "Subject", "Session", *metrics]
        return pd.DataFrame(columns=header)
    return pd.read_csv(RATINGS_FILE)
    
# Function to simplify acquisition labels
def simplify_label(label):
    """Collapse an acquisition label into a short ``ORIENT_CONTRAST_SPEED`` tag.

    All checks are case-insensitive. At most one tag is taken from each group,
    in this order, and the tags found are joined with ``_``:

    * orientation: ``AXI``, ``COR``, ``SAG``, or ``LOC`` for "localizer"
    * contrast: ``T1`` or ``T2``
    * protocol: ``FAST`` or ``STANDARD``
    * ``GrayWhite`` when the label mentions "gray"

    Parameters
    ----------
    label : str
        Acquisition label. Non-string values (e.g. NaN) are returned unchanged.

    Returns
    -------
    str
        The simplified tag, or the original ``label`` when nothing matched.
    """
    if not isinstance(label, str):
        return label

    upper = label.upper()

    def first_tag(candidates):
        """Return the tag of the first (token, tag) pair whose token occurs in the label."""
        return next((tag for token, tag in candidates if token in upper), None)

    groups = (
        # Orientation
        (("AXI", "AXI"), ("COR", "COR"), ("SAG", "SAG"), ("LOCALIZER", "LOC")),
        # T1/T2
        (("T1", "T1"), ("T2", "T2")),
        # Fast vs. standard
        (("FAST", "FAST"), ("STANDARD", "STANDARD")),
        # Gray_White: compared in upper case; the original compared 'Gray'
        # against an upper-cased label and therefore could never match.
        (("GRAY", "GrayWhite"),),
    )

    result = [tag for tag in map(first_tag, groups) if tag is not None]

    # Return combined result or original label if no matches
    return '_'.join(result) if result else label

def save_rating(ratings_file, responses):
    """Append one row of ``responses`` to the ratings CSV and rewrite the file.

    The existing table is obtained from ``load_ratings``; ``responses`` must
    supply one value per column of that table, in column order.
    """
    existing = load_ratings(ratings_file)
    appended = pd.concat(
        [existing, pd.DataFrame([responses], columns=existing.columns)],
        ignore_index=True,
    )
    appended.to_csv(ratings_file, index=False)
    print(f"\nSaved rating: {ratings_file}")
    
    
    
def find_csv_file(directory, username):
    """Return the first ``.csv`` under ``directory`` whose name contains the username.

    Spaces are removed from ``username`` before matching. The tree is walked
    with ``os.walk``; ``None`` is returned when no file matches.
    """
    token = username.replace(" ", "")
    hits = (
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(directory)
        for name in names
        if token in name and name.endswith(".csv")
    )
    return next(hits, None)

In [None]:
# Rater identity; it is written to every rating row and used in the ratings file name
username = input("Enter your username: ")

# Folder scanned for the per-subject/session images in the rating cell
download_dir = "/Users/Hajer/unity/GF Sprint25/QC"


In [None]:
from tqdm import tqdm
from datetime import datetime
from IPython.display import display, clear_output


import tempfile  # local to this cell: only needed for resampled segmentations

# One timestamp for the whole rating session
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

pbar = tqdm(total=len(df_outliers), desc="Subjects evaluated", position=0, leave=True)

# Fallback set of accepted answers when a question does not list its options
ALL_ANSWERS = ["YES", "NO", "ACCEPTABLE", "GOOD", "POOR", "MAJOR", "MINOR"]


def _allowed_answers(question):
    """Return the upper-cased options from a trailing "(A/B/C)" in ``question``.

    Falls back to ``ALL_ANSWERS`` when the question has no such suffix.
    """
    found = re.search(r"\(([^()]*)\)\s*$", question)
    options = [opt.strip().upper() for opt in found.group(1).split("/")] if found else []
    return [opt for opt in options if opt] or ALL_ANSWERS


# Outlier flag columns do not depend on the row, so compute them once.
filtered_cols = [
    col for col in df_outliers.columns
    if (col.endswith("_zscore") or col.endswith("_cov")) and not col.startswith("n_roi_outliers")
]

resample_dir = None  # created lazily, only if a segmentation needs resampling

for _, row in df_outliers.iterrows():
    gear_fields = row["input gear v"].split("/")
    input_gear, v = gear_fields[0], gear_fields[1]
    # Kept separate from `project`, which must remain the Flywheel project
    # object that get_data() relies on.
    sub_label, ses_label, project_label = row["subject"], row["session"], row["project_label"]

    # Reset per row so paths from the previous subject are never reused
    native_scan_path, segmentation_path = None, None
    session_dir = os.path.join(download_dir, str(sub_label), str(ses_label))
    if os.path.isdir(session_dir):
        for file in sorted(os.listdir(session_dir)):
            if file.endswith('segmentation.nii.gz'):
                segmentation_path = os.path.join(session_dir, file)
            elif file.endswith(('.nii', '.nii.gz')):
                # Only NIfTI files qualify; stray files such as .DS_Store are ignored
                native_scan_path = os.path.join(session_dir, file)

    if native_scan_path is None or segmentation_path is None:
        print(f"Skipping {sub_label} {ses_label}: native scan or segmentation not found in {session_dir}")
        pbar.update(1)
        continue

    native_scan = nib.load(native_scan_path)
    segmentation = nib.load(segmentation_path)

    # Resample the segmentation to the native scan grid if needed and hand the
    # resampled file to the overlay (which requires identical shapes).
    overlay_seg_path = segmentation_path
    if segmentation.shape != native_scan.shape:
        if resample_dir is None:
            resample_dir = tempfile.mkdtemp(prefix="qc_resampled_")
        resampled_segmentation = image.resample_to_img(segmentation, native_scan, interpolation="nearest")
        overlay_seg_path = os.path.join(resample_dir, f"{sub_label}_{ses_label}_segmentation_resampled.nii.gz")
        nib.save(resampled_segmentation, overlay_seg_path)

    # Now the user fills in the responses
    responses = [username, timestamp, project_label, sub_label, ses_label]

    print("\n")

    # Keep only the flag columns whose value is 1 for this row
    outlier_rois = [col for col in filtered_cols if row[col] == 1]

    print("Outlier regions for this subject (using z score and cov):")
    print(f"{set(outlier_rois)}")
    outliers = sorted(set(outlier_rois))  # sorted for a stable figure title
    nifti_overlay_animation(
        native_path=native_scan_path,
        seg_path=overlay_seg_path,
        sub=sub_label,
        outliers=outliers,
        target_height=300,
        alpha=0.4,
        cmap='magma'  # Any matplotlib colormap works
    )

    for metric in metrics[:-1]:
        # Accept only the options listed in this particular question
        allowed = _allowed_answers(metric)
        response = ""
        while response.upper() not in allowed:
            print('Answer with one of the options listed in the question.\n')
            response = input(f"{metric}:").strip()

        responses.append(response)

    # Free-text comments to elaborate on the ratings
    comments = input(f"   Any {metrics[-1]}: ")
    responses.append(comments)

    print(responses)
    ratings_file = os.path.join(download_dir, f'Parcellation_QC_{username.replace(" ","")}.csv')
    save_rating(ratings_file, responses)
    pbar.update(1)

    should_continue = input("Do you want to continue to the next record? (Y/n): ").strip().lower()
    if should_continue == "n":
        print("Exiting after current response.")
        break

    clear_output(wait=True)
    tqdm.write(" ")

pbar.close()