Skip to content

Commit

Permalink
adjusted code to also save masks for dsistudio folder for fa map etc
Browse files Browse the repository at this point in the history
  • Loading branch information
arefks committed Jan 23, 2024
1 parent 23427aa commit 8e7f9d3
Showing 1 changed file with 32 additions and 2 deletions.
34 changes: 32 additions & 2 deletions bin/register_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import os
import glob
import csv

import nibabel as nib
import numpy as np
def main():
parser = argparse.ArgumentParser(description="Image Registration and Resampling")

Expand Down Expand Up @@ -89,6 +90,35 @@ def main():
print(f"Errors:\n{result.stderr.decode()}")
except Exception as e:
print(f"An error occurred: {e}")

# Adjust by flipping for the fa0 and ad files
# Locate all files in the "RegisteredTractMasks" folder
registered_masks_path = os.path.join(args.input,"**","RegisteredTractMasks", "*.nii.gz")
registered_masks_files = glob.glob(registered_masks_path, recursive=True)

for mask_file in registered_masks_files:
# Load the mask file
mask_nifti = nib.load(mask_file)
mask_data = mask_nifti.get_fdata()

# Flip the mask along the Y and Z axes
#flipped_mask = np.flip(np.flip(mask_data, axis=1), axis=2)
flipped_mask = np.flip(mask_data, axis=2)


# Determine the directory to save the adjusted mask
mask_dir = os.path.dirname(os.path.dirname(mask_file))
adjusted_masks_dir = os.path.join(mask_dir, "DSI_studio", "RegisteredTractMasks_adjusted")

# Create the directory if it doesn't exist
if not os.path.exists(adjusted_masks_dir):
os.makedirs(adjusted_masks_dir)

# Define the new file name and save the flipped mask
adjusted_mask_file = os.path.join(adjusted_masks_dir, os.path.basename(mask_file).replace(".nii.gz","_flipped.nii.gz"))
nib.save(nib.Nifti1Image(flipped_mask, mask_nifti.affine), adjusted_mask_file)

print("Flipping of RegisteredTractMasks and saving to RegisteredTractMasks_adjusted: DONE\n")

# Save the missing matrix paths to a CSV file
missing_matrix_csv_path = os.path.join(args.input, "missing_matrix_paths.csv")
Expand All @@ -104,4 +134,4 @@ def main():


if __name__ == "__main__":
main()
main()

0 comments on commit 8e7f9d3

Please sign in to comment.