Merge pull request #2228 from SCIInstitute/2226-deepssm-meshes
2226 deepssm meshes
akenmorris committed Apr 6, 2024
2 parents 006888b + 37b8993 commit fcce488
Showing 7 changed files with 194 additions and 101 deletions.
2 changes: 1 addition & 1 deletion Examples/Python/deep_ssm.py
@@ -402,7 +402,7 @@ def Run_Pipeline(args):
with open(config_file, "w") as outfile:
json.dump(model_parameters, outfile, indent=2)
# Train
- DeepSSMUtils.trainDeepSSM(config_file)
+ DeepSSMUtils.trainDeepSSM(project, config_file)
open(status_dir + "step_10.txt", 'w').close()

######################################################################################
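Note: for context on the signature change above, here is a minimal sketch of how the updated call might be driven from a standalone script. The spreadsheet path and config filename are hypothetical; only DeepSSMUtils.trainDeepSSM(project, config_file) is taken from the diff.

import shapeworks as sw
import DeepSSMUtils

# Load an existing ShapeWorks project (hypothetical spreadsheet name).
project = sw.Project()
project.load("deep_ssm_project.xlsx")

# The JSON config written earlier with json.dump(model_parameters, ...).
config_file = "deepssm_config.json"

# trainDeepSSM now takes the project along with the config file.
DeepSSMUtils.trainDeepSSM(project, config_file)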
4 changes: 2 additions & 2 deletions Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py
@@ -41,9 +41,9 @@ def prepareConfigFile(config_filename, model_name, embedded_dim, out_dir, loader
fine_tune_epochs, fine_tune_learning_rate)


- def trainDeepSSM(config_file):
+ def trainDeepSSM(project, config_file):
testPytorch()
- trainer.train(config_file)
+ trainer.train(project, config_file)
return


37 changes: 6 additions & 31 deletions Python/DeepSSMUtilsPackage/DeepSSMUtils/eval_utils.py
@@ -1,10 +1,8 @@
# Jadie Adams
import os
- import re
import shutil
- import subprocess
import numpy as np

import shapeworks as sw
+ from shapeworks.utils import *


def get_distance_meshes(out_dir, DT_dir, prediction_dir, mean_prefix):
@@ -45,37 +43,14 @@ def get_prefix(path):
return prefix


- def get_mesh_from_DT(DT_list, mesh_dir):
-     if not os.path.exists(mesh_dir):
-         os.makedirs(mesh_dir)
-     mesh_files = []
-     for input_file in DT_list:
-         print(' ' + get_prefix(input_file))
-         output_vtk = mesh_dir + "original_" + get_prefix(input_file) + ".vtk"
-         image = sw.Image(input_file)
-         image.toMesh(isovalue=0).write(output_vtk)
-         mesh_files.append(output_vtk)
-     return sorted(mesh_files)
-
-
- def get_mesh(filename, iso_value=0):
-     if filename.endswith('.nrrd'):
-         image = sw.Image(filename)
-         return image.toMesh(iso_value)
-     else:
-         return sw.Mesh(filename)


def get_mesh_from_particles(particle_list, mesh_dir, template_particles, template_mesh, planes=None):
if not os.path.exists(mesh_dir):
os.makedirs(mesh_dir)

- warp = sw.MeshWarper()

# Create mesh from file (mesh or segmentation)
- sw_mesh = get_mesh(template_mesh)
+ sw_mesh = get_mesh_from_file(template_mesh)
sw_particles = np.loadtxt(template_particles)
- warp.generateWarp(sw_mesh, sw_particles)
+ initialize_mesh_warper(sw_mesh, sw_particles)

particle_dir = os.path.dirname(particle_list[0]) + '/'

@@ -85,7 +60,7 @@ def get_mesh_from_particles(particle_list, mesh_dir, template_particles, templat
out_filename = out_filename.replace(particle_dir, mesh_dir)
out_mesh_filenames.append(out_filename)
sw_particles = np.loadtxt(particle_list[i])
- out_mesh = warp.buildMesh(sw_particles)
+ out_mesh = reconstruct_mesh(sw_particles)
out_mesh.write(out_filename)

if planes is not None:
@@ -144,7 +119,7 @@ def get_mesh_distances(pred_particle_files, mesh_list, template_particles, templ
mean_distances.append(-1)
continue
print(f"Computing distance between {mesh_list[index]} and {pred_mesh_list[index]}")
- orig_mesh = get_mesh(mesh_list[index], iso_value=0.5)
+ orig_mesh = get_mesh_from_file(mesh_list[index], iso_value=0.5)
if planes is not None:
orig_mesh.clip(planes[index][0], planes[index][1], planes[index][2])
pred_mesh = sw.Mesh(pred_mesh_list[index])
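Note: the eval_utils.py changes swap the locally constructed sw.MeshWarper for the shared helpers in shapeworks.utils. A condensed sketch of that reconstruction flow, assuming the helpers are importable by name (the diff itself uses a star import) and using hypothetical file names:

import numpy as np
from shapeworks.utils import get_mesh_from_file, initialize_mesh_warper, reconstruct_mesh

# Template mesh (or segmentation) and its matching particle file -- hypothetical paths.
template_mesh = get_mesh_from_file("template_mesh.vtk")
template_particles = np.loadtxt("template.particles")

# Initialize the shared mesh warper once ...
initialize_mesh_warper(template_mesh, template_particles)

# ... then rebuild a mesh from each predicted particle file.
for particle_file in ["pred_00.particles", "pred_01.particles"]:
    particles = np.loadtxt(particle_file)
    mesh = reconstruct_mesh(particles)
    mesh.write(particle_file.replace(".particles", ".vtk"))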
68 changes: 44 additions & 24 deletions Python/DeepSSMUtilsPackage/DeepSSMUtils/train_viz.py
@@ -1,30 +1,50 @@
import os
import numpy as np
+ from shapeworks.utils import *


# Writes particles and error scalars for best, median, and worst
# pred_particles and true_particles are numpy array with dims: # in set, # of particles, 3 coordinates
def write_examples(pred_particles, true_particles, filenames, out_dir):
-     if not os.path.exists(out_dir):
-         os.makedirs(out_dir)
-     # get min, mean, and max errors
-     mses = np.mean(np.mean((pred_particles - true_particles)**2, axis=2), axis=1)
-     median_index = np.argsort(mses)[len(mses)//2]
-     indices = [np.argmin(mses), median_index, np.argmax(mses)]
-     names = ["best", "median", "worst"]
-     for i in range(3):
-         # get particles
-         pred = pred_particles[indices[i]]
-         # write particle file
-         out_particle_file = out_dir + names[i] + ".particles"
-         np.savetxt(out_particle_file, pred)
-         # get scalar field for error
-         out_scalar_file = out_dir + names[i] + ".scalars"
-         scalars = np.mean((pred - true_particles[indices[i]])**2, axis=1)
-         np.savetxt(out_scalar_file, scalars)
-         # write index out to file as an integer
-         out_index_file = out_dir + names[i] + ".index"
-         f = open(out_index_file, "w")
-         f.write(filenames[indices[i]])
-         f.close()


+     if not os.path.exists(out_dir):
+         os.makedirs(out_dir)
+     # get min, mean, and max errors
+     mses = np.mean(np.mean((pred_particles - true_particles) ** 2, axis=2), axis=1)
+     median_index = np.argsort(mses)[len(mses) // 2]
+     indices = [np.argmin(mses), median_index, np.argmax(mses)]
+     names = ["best", "median", "worst"]
+     for i in range(3):
+         # get particles
+         pred = pred_particles[indices[i]]
+
+         # write particle file
+         out_particle_file = out_dir + names[i] + ".particles"
+         np.savetxt(out_particle_file, pred)
+
+         # get scalar field for error
+         out_scalar_file = out_dir + names[i] + ".scalars"
+         scalars = np.mean((pred - true_particles[indices[i]]) ** 2, axis=1)
+         np.savetxt(out_scalar_file, scalars)
+
+         # write index out to file as an integer
+         out_index_file = out_dir + names[i] + ".index"
+         f = open(out_index_file, "w")
+         f.write(filenames[indices[i]])
+         f.close()
+
+         # reconstruct mesh
+         mesh = reconstruct_mesh(pred)
+         # interpolate scalars to mesh
+
+         # reshape pred to be 1D
+         pred = pred.flatten()
+
+         # print type of pred
+         print(f"pred type: {type(pred)}")
+         print(f"pred shape: {pred.shape}")
+         print(f"scalars type: {type(scalars)}")
+         print(f"scalars shape: {scalars.shape}")
+
+         mesh.interpolate_scalars_to_mesh("deepssm_error", pred, scalars)
+         out_mesh_file = out_dir + names[i] + ".vtk"
+         mesh.write(out_mesh_file)
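Note: the best/median/worst selection above is plain numpy. A tiny self-contained illustration with synthetic arrays (shapes chosen arbitrarily) shows the quantities involved:

import numpy as np

# 4 examples, 5 particles each, 3 coordinates -- same layout as pred_particles/true_particles.
rng = np.random.default_rng(0)
true_particles = rng.normal(size=(4, 5, 3))
pred_particles = true_particles + 0.1 * rng.normal(size=(4, 5, 3))

# Per-example MSE, averaged over particles and coordinates -> shape (4,).
mses = np.mean(np.mean((pred_particles - true_particles) ** 2, axis=2), axis=1)

# Indices of the best, median, and worst examples.
median_index = np.argsort(mses)[len(mses) // 2]
indices = [np.argmin(mses), median_index, np.argmax(mses)]

# Per-particle error scalars for the best example -> shape (5,); this is what the .scalars file holds.
scalars = np.mean((pred_particles[indices[0]] - true_particles[indices[0]]) ** 2, axis=1)
print(indices, mses.shape, scalars.shape)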
14 changes: 10 additions & 4 deletions Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py
@@ -13,9 +13,8 @@
from DeepSSMUtils import losses
from DeepSSMUtils import train_viz
from DeepSSMUtils import loaders
- from shapeworks.utils import sw_message
- from shapeworks.utils import sw_progress
- from shapeworks.utils import sw_check_abort
+ import DeepSSMUtils
+ from shapeworks.utils import *

'''
Train helper
@@ -66,7 +65,14 @@ def set_scheduler(opt, sched_params):
return scheduler


- def train(config_file):
+ def train(project, config_file):
+     subjects = project.get_subjects()
+     project_path = project.get_project_path() + "/"
+     reference_index = DeepSSMUtils.get_reference_index(project)
+     template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0]
+     template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0]
+     initialize_mesh_warper_from_files(template_mesh, template_particles)

with open(config_file) as json_file:
parameters = json.load(json_file)
if parameters["tl_net"]["enabled"]:
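Note: train() now resolves its warp template from the project's reference subject before reading the JSON config. A hypothetical pre-flight check mirroring that lookup (the spreadsheet name is assumed; the accessor calls are the ones used in the diff):

import os
import shapeworks as sw
import DeepSSMUtils

project = sw.Project()
project.load("deep_ssm_project.xlsx")  # hypothetical project spreadsheet

subjects = project.get_subjects()
reference_index = DeepSSMUtils.get_reference_index(project)
project_path = project.get_project_path() + "/"

template_mesh = project_path + subjects[reference_index].get_groomed_filenames()[0]
template_particles = project_path + subjects[reference_index].get_local_particle_filenames()[0]

for path in (template_mesh, template_particles):
    if not os.path.exists(path):
        raise FileNotFoundError("train() needs this file to initialize the mesh warper: " + path)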
2 changes: 1 addition & 1 deletion Python/shapeworks/shapeworks/__init__.py
@@ -6,7 +6,7 @@
from .conversion import sw2vtkImage, sw2vtkMesh
from .plot import plot_meshes, plot_volumes, plot_meshes_volumes_mix, add_mesh_to_plotter, add_volume_to_plotter, plot_mesh_contour,plot_pca_metrics,\
pca_loadings_violinplot,plot_mode_line,visualize_reconstruction,lda_plot
- from .utils import num_subplots, postive_factors, save_images, get_file_with_ext, find_reference_image_index, find_reference_mesh_index, load_mesh
+ from .utils import num_subplots, positive_factors, save_images, get_file_with_ext, find_reference_image_index, find_reference_mesh_index, load_mesh
from .data import get_file_list, sample_images, sample_meshes
from .stats import compute_pvalues_for_group_difference,lda
from .network_analysis import NetworkAnalysis
