
DeepSSM input verification (#2196)
* Throw exception when number of original domains doesn't match.

* When images or original files are not present, disallow DeepSSM mode.

* Throw useful exception when image size will be too small.

* Add timing to steps of prep.

* Add one more timer

* Set Eigen threads to the same as TBB.

This came about due to a bug or unexpected interaction when calling DeepSSM commands from Studio. After the image registration step runs, the OpenMP thread count is set to an excessively high level (e.g. 384 on a 16-core machine), which results in poor Eigen performance per this page:

https://eigen.tuxfamily.org/dox/TopicMultiThreading.html

"Warning: On most OS it is very important to limit the number of threads to the number of physical cores, otherwise significant slowdowns are expected, especially for operations involving dense matrices."

* Add Fine Tuning plot, combine csv

* Fix logger being closed for FT

* Allow decay LR to be off

* Fix formatting

* Reinitialize optimizer for fine tuning, use fine tuning learning rate

* Parse table doubles and limit precision for display.

* Fix table display (digits)

* Update parameter names and tooltips

* Reorganize the DeepSSM Prep dialog a bit.
Change percent variability to be consistent.
Add a read-only training percent to show the user the amount that will be used.

* Update screenshot.

* Improve test mesh loading

* Fix TL-DeepSSM when Decay Learning is off.

* Replace shapeworks executable usage with Python API of MeshWarper.

* Fix typos

* Simplify and improve get_mesh_distance, also write back distance field to prediction mesh.

* Add more room for image spacing.

* Add ability to set reconstructed meshes directly.

* Shift test mesh distance calc and results into Python from Studio.

* Fix compile

* Fix image name for train/test when image is not already selected.

* Fix problem with boost create_directories
akenmorris committed Feb 6, 2024
1 parent acb2f40 commit 077c588
Showing 20 changed files with 1,337 additions and 1,013 deletions.
11 changes: 6 additions & 5 deletions Examples/Python/deep_ssm.py
@@ -345,7 +345,7 @@ def Run_Pipeline(args):
"model_name": model_name,
"num_latent_dim": int(embedded_dim),
"paths": {
"out_dir": output_directory,
"out_dir": deepssm_dir,
"loader_dir": loader_dir,
"aug_dir": aug_dir
},
@@ -446,7 +446,7 @@ def Run_Pipeline(args):
predicted_val_local_particles = []
for particle_file, transform in zip(predicted_val_world_particles, val_transforms):
particles = np.loadtxt(particle_file)
local_particle_file = particle_file.replace("FT_Predictions/", "local_predictions/")
local_particle_file = particle_file.replace("world_predictions/", "local_predictions/")
local_particles = sw.utils.transformParticles(particles, transform, inverse=True)
np.savetxt(local_particle_file, local_particles)
predicted_val_local_particles.append(local_particle_file)
@@ -468,8 +468,6 @@ def Run_Pipeline(args):

print("Validation mean mesh surface-to-surface distance: " + str(mean_dist))

# If tiny test or verify, check results and exit
check_results(args, mean_dist)
open(status_dir + "step_11.txt", 'w').close()

######################################################################################
@@ -512,7 +510,7 @@ def Run_Pipeline(args):
predicted_test_local_particles = []
for particle_file, transform in zip(predicted_test_world_particles, test_transforms):
particles = np.loadtxt(particle_file)
local_particle_file = particle_file.replace("FT_Predictions/", "local_predictions/")
local_particle_file = particle_file.replace("world_predictions/", "local_predictions/")
local_particles = sw.utils.transformParticles(particles, transform, inverse=True)
np.savetxt(local_particle_file, local_particles)
predicted_test_local_particles.append(local_particle_file)
@@ -530,6 +528,9 @@ def Run_Pipeline(args):
template_particles, template_mesh, test_out_dir,
planes=test_planes)
print("Test mean mesh surface-to-surface distance: " + str(mean_dist))

# If tiny test or verify, check results and exit
check_results(args, mean_dist)
open(status_dir + "step_12.txt", 'w').close()

print("All steps complete")
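The two `local_particle_file` hunks above switch the path substitution from `FT_Predictions/` to `world_predictions/`, matching the new prediction layout in eval.py further down. For reference, a minimal sketch of the resulting world-to-local conversion, assuming a ShapeWorks Python install; `sw.utils.transformParticles` is the same call used in the diff:

```python
import numpy as np
import shapeworks as sw

def world_to_local(predicted_world_particles, transforms):
    """Map world-space .particles files back to each subject's local space
    by applying the inverse of its alignment transform (as in the diff above)."""
    local_files = []
    for particle_file, transform in zip(predicted_world_particles, transforms):
        particles = np.loadtxt(particle_file)
        local_file = particle_file.replace("world_predictions/", "local_predictions/")
        local_particles = sw.utils.transformParticles(particles, transform, inverse=True)
        np.savetxt(local_file, local_particles)
        local_files.append(local_file)
    return local_files
```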
5 changes: 4 additions & 1 deletion Libs/Analyze/Shape.cpp
@@ -134,6 +134,9 @@ MeshGroup Shape::get_reconstructed_meshes(bool wait) {
return reconstructed_meshes_;
}

//---------------------------------------------------------------------------
void Shape::set_reconstructed_meshes(MeshGroup meshes) { reconstructed_meshes_ = meshes; }

//---------------------------------------------------------------------------
void Shape::reset_groomed_mesh() { groomed_meshes_ = MeshGroup(subject_->get_number_of_domains()); }

@@ -572,7 +575,7 @@ std::shared_ptr<Image> Shape::get_image_volume(std::string image_volume_name) {
std::shared_ptr<Image> image = std::make_shared<Image>(filename);
image_volume_ = image;
image_volume_filename_ = filename;
} catch (std::exception &ex) {
} catch (std::exception& ex) {
SW_ERROR("Unable to open file: {}", filename);
}
}
3 changes: 3 additions & 0 deletions Libs/Analyze/Shape.h
@@ -74,6 +74,9 @@ class Shape {
//! Retrieve the reconstructed meshes
MeshGroup get_reconstructed_meshes(bool wait = false);

//! Set the reconstructed meshes
void set_reconstructed_meshes(MeshGroup meshes);

//! Reset the groomed mesh so that it will be re-created
void reset_groomed_mesh();

1 change: 1 addition & 0 deletions Libs/Common/ShapeworksUtils.cpp
@@ -85,6 +85,7 @@ void ShapeWorksUtils::setup_threads() {
num_threads = std::max(1, atoi(num_threads_env));
}
SW_DEBUG("TBB using {} threads", num_threads);
Eigen::setNbThreads(num_threads);
tbb::global_control c(tbb::global_control::max_allowed_parallelism, num_threads);
}

9 changes: 7 additions & 2 deletions Libs/Groom/Groom.cpp
@@ -453,6 +453,10 @@ int Groom::get_total_ops() {
for (int i = 0; i < domains.size(); i++) {
auto params = GroomParameters(project_, domains[i]);

if (project_->get_original_domain_types().size() <= i) {
throw std::runtime_error("invalid domain, number of original file types does not match number of domains");
}

if (project_->get_original_domain_types()[i] == DomainType::Image) {
num_tools += params.get_isolate_tool() ? 1 : 0;
num_tools += params.get_fill_holes_tool() ? 1 : 0;
@@ -526,7 +530,6 @@ bool Groom::run_alignment() {
std::vector<Mesh> reference_meshes;
std::vector<Mesh> meshes;
for (size_t i = 0; i < subjects.size(); i++) {

if (!subjects[i]->is_excluded()) {
Mesh mesh = get_mesh(i, domain, true);
// if fixed subjects are present, only add the fixed subjects
@@ -711,7 +714,9 @@ std::string Groom::get_output_filename(std::string input, DomainType domain_type
path = base + "/" + prefix;

try {
boost::filesystem::create_directories(path);
if (!boost::filesystem::exists(path)) {
boost::filesystem::create_directories(path);
}
} catch (std::exception& e) {
throw std::runtime_error("Unable to create groom output directory: \"" + path + "\"");
}
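Two verification changes above are worth noting: `get_total_ops()` now throws when the number of original domain types doesn't match the domain count, and `create_directories` is guarded by an existence check (the "Fix problem with boost create_directories" item). For reference, a sketch of the equivalent guarded-creation idiom in Python, where `exist_ok` plays the role of the boost guard:

```python
import os

def ensure_output_dir(path: str) -> None:
    # Equivalent of the guarded boost::filesystem::create_directories above:
    # create the directory tree only if needed, and surface a clear error.
    try:
        os.makedirs(path, exist_ok=True)
    except OSError as e:
        raise RuntimeError(f'Unable to create groom output directory: "{path}"') from e
```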
4 changes: 3 additions & 1 deletion Libs/Image/Image.cpp
@@ -323,7 +323,9 @@ Image& Image::write(const std::string& filename, bool compressed) {

// if the directory doesn't exist, create it
boost::filesystem::path dir(filename);
boost::filesystem::create_directories(dir.parent_path());
if (dir.has_parent_path() && !boost::filesystem::exists(dir.parent_path())) {
boost::filesystem::create_directories(dir.parent_path());
}

using WriterType = itk::ImageFileWriter<ImageType>;
WriterType::Pointer writer = WriterType::New();
2 changes: 1 addition & 1 deletion Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py
@@ -9,7 +9,7 @@

from .run_utils import create_split, groom_training_shapes, groom_training_images, get_reference_index, \
run_data_augmentation, groom_val_test_images, prep_project_for_val_particles, groom_validation_shapes, \
prepare_data_loaders, get_deepssm_dir, get_split_indices, optimize_training_particles
prepare_data_loaders, get_deepssm_dir, get_split_indices, optimize_training_particles, process_test_predictions

from .config_file import prepare_config_file

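This change exports the new `process_test_predictions` helper alongside the existing run_utils API. A hedged sketch of how these helpers chain together in a prep pipeline — only the function names come from the import list above; the argument lists shown are assumptions, not verified signatures:

```python
import shapeworks as sw
import DeepSSMUtils

# Argument lists below are illustrative assumptions; only the function names
# are taken from the import list in this commit.
project = sw.Project()
project.load("deepssm_project.swproj")

DeepSSMUtils.create_split(project, 80, 10, 10)       # train/val/test percents (assumed)
DeepSSMUtils.groom_training_shapes(project)
DeepSSMUtils.groom_training_images(project)
deepssm_dir = DeepSSMUtils.get_deepssm_dir(project)  # deepssm output directory
```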
167 changes: 83 additions & 84 deletions Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py
@@ -13,92 +13,91 @@
'''
Network Test Function
predicts the PCA scores using the trained networks
returns the error measures and saves the predicted and poriginal particles for comparison
returns the error measures and saves the predicted and original particles for comparison
'''


def test(config_file, loader="test"):
with open(config_file) as json_file:
parameters = json.load(json_file)
model_dir = parameters["paths"]["out_dir"] + parameters["model_name"]+ '/'
pred_dir = model_dir + loader + '_predictions/'
loaders.make_dir(pred_dir)
if parameters["use_best_model"]:
model_path = model_dir + 'best_model.torch'
else:
model_path = model_dir + 'final_model.torch'
if parameters["fine_tune"]["enabled"]:
model_path_ft = model_path.replace(".torch", "_ft.torch")
else:
model_path_ft = model_path
loader_dir = parameters["paths"]["loader_dir"]
with open(config_file) as json_file:
parameters = json.load(json_file)
model_dir = parameters["paths"]["out_dir"] + parameters["model_name"] + '/'
pred_dir = model_dir + loader + '_predictions/'
loaders.make_dir(pred_dir)
if parameters["use_best_model"]:
model_path = model_dir + 'best_model.torch'
else:
model_path = model_dir + 'final_model.torch'
if parameters["fine_tune"]["enabled"]:
model_path_ft = model_path.replace(".torch", "_ft.torch")
else:
model_path_ft = model_path
loader_dir = parameters["paths"]["loader_dir"]

# load the loaders
sw_message("Loading " + loader + " data loader...")
test_loader = torch.load(loader_dir + loader)

# initialization
sw_message("Loading trained model...")
if parameters['tl_net']['enabled']:
model_tl = model.DeepSSMNet_TLNet(config_file)
model_tl.load_state_dict(torch.load(model_path))
device = model_tl.device
model_tl.to(device)
model_tl.eval()
else:
model_pca = model.DeepSSMNet(config_file)
model_pca.load_state_dict(torch.load(model_path))
device = model_pca.device
model_pca.to(device)
model_pca.eval()
model_ft = model.DeepSSMNet(config_file)
model_ft.load_state_dict(torch.load(model_path_ft))
model_ft.to(device)
model_ft.eval()

# load the loaders
sw_message("Loading "+ loader + " data loader...")
test_loader = torch.load(loader_dir + loader)
print("Done.\n")
# initalizations
sw_message("Loading trained model...")
if parameters['tl_net']['enabled']:
model_tl = model.DeepSSMNet_TLNet(config_file)
model_tl.load_state_dict(torch.load(model_path))
device = model_tl.device
model_tl.to(device)
model_tl.eval()
else:
model_pca = model.DeepSSMNet(config_file)
model_pca.load_state_dict(torch.load(model_path))
device = model_pca.device
model_pca.to(device)
model_pca.eval()
model_ft = model.DeepSSMNet(config_file)
model_ft.load_state_dict(torch.load(model_path_ft))
model_ft.to(device)
model_ft.eval()
# Get test names
test_names_file = loader_dir + loader + '_names.txt'
f = open(test_names_file, 'r')
test_names_string = f.read()
f.close()
test_names_string = test_names_string.replace("[", "").replace("]", "").replace("'", "").replace(" ", "")
test_names = test_names_string.split(",")
sw_message(f"Predicting for {loader} images...")
index = 0
pred_scores = []

# Get test names
test_names_file = loader_dir + loader + '_names.txt'
f = open(test_names_file, 'r')
test_names_string = f.read()
f.close()
test_names_string = test_names_string.replace("[","").replace("]","").replace("'","").replace(" ","")
test_names = test_names_string.split(",")
sw_message(f"Predicting for {loader} images...")
index = 0
pred_scores = []
pred_path = pred_dir + 'world_predictions/'
loaders.make_dir(pred_path)
pred_path_pca = pred_dir + 'pca_predictions/'
loaders.make_dir(pred_path_pca)

if parameters['tl_net']['enabled']:
predPath_tl = pred_dir + '/TL_Predictions'
loaders.make_dir(predPath_tl)
else:
predPath_ft = pred_dir + 'FT_Predictions/'
predPath_pca = pred_dir + 'PCA_Predictions/'
loaders.make_dir(predPath_ft)
loaders.make_dir(predPath_pca)
predicted_particle_files = []
for img, _, mdl, _ in test_loader:
if sw_check_abort():
sw_message("Aborted")
return
sw_message(f"Predicting {index+1}/{len(test_loader)}")
sw_progress((index+1) / len(test_loader))
img = img.to(device)
if parameters['tl_net']['enabled']:
mdl = torch.FloatTensor([1]).to(device)
[pred_tf, pred_mdl_tl] = model_tl(mdl, img)
pred_scores.append(pred_tf.cpu().data.numpy())
# save the AE latent space as shape descriptors
nmpred = predPath_tl + '/' + test_names[index] + '.npy'
np.save(nmpred, pred_tf.squeeze().detach().cpu().numpy())
nmpred = predPath_tl + '/' + test_names[index] + '.particles'
np.savetxt(nmpred, pred_mdl_tl.squeeze().detach().cpu().numpy())
else:
[pred, pred_mdl_pca] = model_pca(img)
[pred, pred_mdl_ft] = model_ft(img)
pred_scores.append(pred.cpu().data.numpy()[0])
nmpred = predPath_pca + '/predicted_pca_' + test_names[index] + '.particles'
np.savetxt(nmpred, pred_mdl_pca.squeeze().detach().cpu().numpy())
nmpred = predPath_ft + '/predicted_ft_' + test_names[index] + '.particles'
np.savetxt(nmpred, pred_mdl_ft.squeeze().detach().cpu().numpy())
predicted_particle_files.append(nmpred)
index += 1
sw_message("Test completed.")
return predicted_particle_files
predicted_particle_files = []
for img, _, mdl, _ in test_loader:
if sw_check_abort():
sw_message("Aborted")
return
sw_message(f"Predicting {index + 1}/{len(test_loader)}")
sw_progress((index + 1) / len(test_loader))
img = img.to(device)
particle_filename = pred_path + test_names[index] + '.particles'
if parameters['tl_net']['enabled']:
mdl = torch.FloatTensor([1]).to(device)
[pred_tf, pred_mdl_tl] = model_tl(mdl, img)
pred_scores.append(pred_tf.cpu().data.numpy())
# save the AE latent space as shape descriptors
filename = pred_path + test_names[index] + '.npy'
np.save(filename, pred_tf.squeeze().detach().cpu().numpy())
np.savetxt(particle_filename, pred_mdl_tl.squeeze().detach().cpu().numpy())
else:
[pred, pred_mdl_pca] = model_pca(img)
[pred, pred_mdl_ft] = model_ft(img)
pred_scores.append(pred.cpu().data.numpy()[0])
filename = pred_path_pca + '/predicted_pca_' + test_names[index] + '.particles'
np.savetxt(filename, pred_mdl_pca.squeeze().detach().cpu().numpy())
np.savetxt(particle_filename, pred_mdl_ft.squeeze().detach().cpu().numpy())
print("Predicted particle file: ", particle_filename)
predicted_particle_files.append(filename)
index += 1
sw_message("Test completed.")
return predicted_particle_files
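With this restructure, all world-space predictions land under `<out_dir>/<model_name>/<loader>_predictions/world_predictions/` (plus `pca_predictions/` for the PCA branch). A minimal sketch of driving `test()` directly, assuming the trained model and serialized loaders already exist on disk; the config carries only the keys this excerpt reads, and a real config includes more (e.g. `num_latent_dim`, as in the deep_ssm.py hunk above):

```python
import json
from DeepSSMUtils import eval as deepssm_eval  # module shown in this diff

config = {
    "model_name": "deepssm",
    "paths": {
        "out_dir": "deepssm_output/",        # predictions go under out_dir + model_name
        "loader_dir": "deepssm_output/torch_loaders/",
    },
    "use_best_model": True,
    "fine_tune": {"enabled": True},
    "tl_net": {"enabled": False},
}
with open("model_config.json", "w") as f:
    json.dump(config, f)

# Returns the list of predicted .particles files written to world_predictions/.
predicted_files = deepssm_eval.test("model_config.json", loader="test")
```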
