Skip to content

Commit

Permalink
Resolve #2230 - DeepSSM Augmentation should write VTK files in additi…
Browse files Browse the repository at this point in the history
…on to particles
  • Loading branch information
akenmorris committed Apr 9, 2024
1 parent fcce488 commit b059a68
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 32 deletions.
6 changes: 3 additions & 3 deletions Examples/Python/deep_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def Run_Pipeline(args):
DeepSSMUtils.groom_training_shapes(project)
project.save(spreadsheet_file)

reference_index = DeepSSMUtils.get_reference_index(project)
reference_index = sw.utils.get_reference_index(project)
print("Reference index: " + str(reference_index))
# print reference mesh name
print("Reference mesh: " + project_path + project.get_subjects()[reference_index].get_original_filenames()[0])
Expand Down Expand Up @@ -458,7 +458,7 @@ def Run_Pipeline(args):
'''
mean_MSE, std_MSE = DeepSSMUtils.analyzeMSE(predicted_val_world_particles, val_world_particles)
print("Validation world particle MSE: " + str(mean_MSE) + " +- " + str(std_MSE))
reference_index = DeepSSMUtils.get_reference_index(project)
reference_index = sw.utils.get_reference_index(project)
template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0]
template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0]
# Get distance between clipped true and predicted meshes
Expand Down Expand Up @@ -520,7 +520,7 @@ def Run_Pipeline(args):
clipped true mesh and clipped mesh generated from predicted local particles
'''

reference_index = DeepSSMUtils.get_reference_index(project)
reference_index = sw.utils.get_reference_index(project)
template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0]
template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from shapeworks.utils import sw_message
from shapeworks.utils import sw_progress
from shapeworks.utils import sw_check_abort
import shapeworks as sw

################################# Augmentation Pipelines ###############################################

Expand Down Expand Up @@ -59,9 +60,13 @@ def point_based_aug(out_dir, orig_img_list, orig_point_list, num_samples, num_di
gen_image_dir = out_dir + "Generated-Images/"
if not os.path.exists(gen_image_dir):
os.makedirs(gen_image_dir)
get_mesh_dir = out_dir + "Generated-Meshes/"
if not os.path.exists(get_mesh_dir):
os.makedirs(get_mesh_dir)
gen_embeddings = []
gen_points_paths = []
gen_image_paths = []
gen_mesh_paths = []
if processes != 1:
generate_image_params_list = []
# Sample to generate new examples
Expand All @@ -87,6 +92,9 @@ def point_based_aug(out_dir, orig_img_list, orig_point_list, num_samples, num_di
gen_points_path = gen_point_dir + name + ".particles"
np.savetxt(gen_points_path, gen_points)
gen_points_paths.append(gen_points_path)
# Generate mesh
gen_mesh_path = get_mesh_dir + name + ".vtk"
sw.utils.reconstruct_mesh(gen_points).write(gen_mesh_path)
# Generate image
base_image_path = orig_img_list[base_index]
base_particles_path = orig_point_list[base_index]
Expand Down
2 changes: 1 addition & 1 deletion Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from DeepSSMUtils import image_utils
from DeepSSMUtils import run_utils

from .run_utils import create_split, groom_training_shapes, groom_training_images, get_reference_index, \
from .run_utils import create_split, groom_training_shapes, groom_training_images, \
run_data_augmentation, groom_val_test_images, prep_project_for_val_particles, groom_validation_shapes, \
prepare_data_loaders, get_deepssm_dir, get_split_indices, optimize_training_particles, process_test_predictions

Expand Down
28 changes: 6 additions & 22 deletions Python/DeepSSMUtilsPackage/DeepSSMUtils/run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@

import shapeworks as sw
from bokeh.util.terminal import trace
from shapeworks.utils import sw_message
from shapeworks.utils import sw_progress
from shapeworks.utils import sw_check_abort
from shapeworks.utils import sw_message, sw_progress, sw_check_abort

import DataAugmentationUtils
import DeepSSMUtils
Expand Down Expand Up @@ -125,21 +123,6 @@ def prep_project_for_val_particles(project):
project.set_subjects(subjects)


def get_reference_index(project):
""" Get the index of the reference subject chosen by grooming alignment."""
params = project.get_parameters("groom")
reference_index = params.get("alignment_reference_chosen")
return int(reference_index)


def get_image_filename(subject):
""" Get the image filename for a subject. """
image_map = subject.get_feature_filenames()
# get the first image
image_name = list(image_map.values())[0]
return image_name


def get_deepssm_dir(project):
""" Get the directory for deepssm data"""
project_path = project.get_project_path()
Expand Down Expand Up @@ -197,8 +180,8 @@ def groom_training_images(project):

deepssm_dir = get_deepssm_dir(project)

ref_index = get_reference_index(project)
ref_image = sw.Image(get_image_filename(subjects[ref_index]))
ref_index = sw.utils.get_reference_index(project)
ref_image = sw.Image(sw.utils.get_image_filename(subjects[ref_index]))
ref_mesh = sw.utils.load_mesh(subjects[ref_index].get_groomed_filenames()[0])

# apply alignment transform
Expand Down Expand Up @@ -248,7 +231,7 @@ def groom_training_images(project):
sw_message("Aborted")
return

image_name = get_image_filename(subjects[i])
image_name = sw.utils.get_image_filename(subjects[i])
sw_progress(i / (len(subjects) + 1), f"Grooming Training Image: {image_name}")
image = sw.Image(image_name)
subject = subjects[i]
Expand All @@ -273,6 +256,7 @@ def groom_training_images(project):

def run_data_augmentation(project, num_samples, num_dim, percent_variability, sampler, mixture_num=0, processes=1):
""" Run data augmentation on the training images. """
sw.utils.initialize_project_mesh_warper(project)
deepssm_dir = get_deepssm_dir(project)
aug_dir = deepssm_dir + "augmentation/"

Expand Down Expand Up @@ -381,7 +365,7 @@ def groom_val_test_images(project, indices):
sw_message("Aborted")
return

image_name = get_image_filename(subjects[i])
image_name = sw.utils.get_image_filename(subjects[i])
sw_progress(count / (len(val_test_indices) + 1),
f"Grooming val/test image {image_name} ({count}/{len(val_test_indices)})")
count = count + 1
Expand Down
7 changes: 1 addition & 6 deletions Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,7 @@ def set_scheduler(opt, sched_params):


def train(project, config_file):
subjects = project.get_subjects()
project_path = project.get_project_path() + "/"
reference_index = DeepSSMUtils.get_reference_index(project)
template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0]
template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0]
initialize_mesh_warper_from_files(template_mesh, template_particles)
sw.utils.initialize_project_mesh_warper(project)

with open(config_file) as json_file:
parameters = json.load(json_file)
Expand Down
33 changes: 33 additions & 0 deletions Python/shapeworks/shapeworks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,24 @@ def initialize_mesh_warper_from_files(template_mesh_filename, template_particles
initialize_mesh_warper(sw_mesh, sw_particles)


def initialize_project_mesh_warper(project):
"""
This function initializes a MeshWarper object using the template mesh and particles from a given project.
Parameters:
project (shapeworks.Project): The project to be used for mesh warping.
Returns:
None
"""
subjects = project.get_subjects()
project_path = project.get_project_path() + "/"
reference_index = sw.utils.get_reference_index(project)
template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0]
template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0]
initialize_mesh_warper_from_files(template_mesh, template_particles)


def reconstruct_mesh(particles):
"""
This function uses the global MeshWarper object to build a mesh from a given set of particles.
Expand Down Expand Up @@ -438,3 +456,18 @@ def get_mesh_from_file(filename, iso_value=0):
return image.toMesh(iso_value)
else:
return sw.Mesh(filename)


def get_reference_index(project):
""" Get the index of the reference subject chosen by grooming alignment."""
params = project.get_parameters("groom")
reference_index = params.get("alignment_reference_chosen")
return int(reference_index)


def get_image_filename(subject):
""" Get the image filename for a subject. """
image_map = subject.get_feature_filenames()
# get the first image
image_name = list(image_map.values())[0]
return image_name

0 comments on commit b059a68

Please sign in to comment.