
DeepSSM input verification #2196

Merged · 28 commits · Feb 6, 2024

Commits
16afe85
Throw exception when number of original domains doesn't match.
akenmorris Jan 29, 2024
79ba42b
When images or original files are not present, disallow deepssm mode
akenmorris Jan 29, 2024
c898602
Throw useful exception when image size will be too small.
akenmorris Jan 29, 2024
a65ec94
Add timing to steps of prep.
akenmorris Jan 30, 2024
ef01a55
Add one more timer
akenmorris Jan 31, 2024
3e48049
Set Eigen threads to the same as TBB.
akenmorris Jan 31, 2024
51f547d
Add Fine Tuning plot, combine csv
akenmorris Jan 31, 2024
d53d007
Fix logger being closed for FT
akenmorris Jan 31, 2024
bb64c35
Allow decay LR to be off
akenmorris Jan 31, 2024
d9de5be
Fix formatting
akenmorris Jan 31, 2024
a74920d
Reinitialize optimizer for fine tuning, use fine tuning learning rate
akenmorris Jan 31, 2024
3743976
Parse table doubles and limit precision for display.
akenmorris Jan 31, 2024
8dae42f
Fix table display (digits)
akenmorris Jan 31, 2024
6c9dc8a
Update parameter names and tooltips
akenmorris Feb 1, 2024
76f557b
Reorganize DeepSSM Prep dialog a bit.
akenmorris Feb 1, 2024
ddcf331
Update screenshot.
akenmorris Feb 1, 2024
9a0b7f4
Improve test mesh loading
akenmorris Feb 5, 2024
3703ba2
Merge branch 'deepssm_changes' into deepssm_input_verification
akenmorris Feb 5, 2024
681b729
Fix TL-DeepSSM when Decay Learning is off.
akenmorris Feb 5, 2024
8498ba3
Replace shapeworks executable usage with Python API of MeshWarper.
akenmorris Feb 5, 2024
c084adb
Fix typos
akenmorris Feb 5, 2024
7d0cb0b
Simplify and improve get_mesh_distance, also write back distance fiel…
akenmorris Feb 5, 2024
b75fdf9
Add more room for image spacing.
akenmorris Feb 5, 2024
e78bb0d
Add ability to set reconstructed meshes directly.
akenmorris Feb 5, 2024
12df6ec
Shift test mesh distance calc and results into Python from Studio.
akenmorris Feb 5, 2024
2ce8bb1
Fix compile
akenmorris Feb 5, 2024
17ac0e7
Fix image name for train/test when image is not already selected.
akenmorris Feb 5, 2024
ef3bcd8
Fix problem with boost create_directories
akenmorris Feb 6, 2024
11 changes: 6 additions & 5 deletions Examples/Python/deep_ssm.py
@@ -345,7 +345,7 @@ def Run_Pipeline(args):
         "model_name": model_name,
         "num_latent_dim": int(embedded_dim),
         "paths": {
-            "out_dir": output_directory,
+            "out_dir": deepssm_dir,
             "loader_dir": loader_dir,
             "aug_dir": aug_dir
         },
@@ -446,7 +446,7 @@ def Run_Pipeline(args):
     predicted_val_local_particles = []
     for particle_file, transform in zip(predicted_val_world_particles, val_transforms):
         particles = np.loadtxt(particle_file)
-        local_particle_file = particle_file.replace("FT_Predictions/", "local_predictions/")
+        local_particle_file = particle_file.replace("world_predictions/", "local_predictions/")
         local_particles = sw.utils.transformParticles(particles, transform, inverse=True)
         np.savetxt(local_particle_file, local_particles)
         predicted_val_local_particles.append(local_particle_file)
@@ -468,8 +468,6 @@ def Run_Pipeline(args):
 
     print("Validation mean mesh surface-to-surface distance: " + str(mean_dist))
 
-    # If tiny test or verify, check results and exit
-    check_results(args, mean_dist)
     open(status_dir + "step_11.txt", 'w').close()
 
     ######################################################################################
@@ -512,7 +510,7 @@ def Run_Pipeline(args):
     predicted_test_local_particles = []
     for particle_file, transform in zip(predicted_test_world_particles, test_transforms):
         particles = np.loadtxt(particle_file)
-        local_particle_file = particle_file.replace("FT_Predictions/", "local_predictions/")
+        local_particle_file = particle_file.replace("world_predictions/", "local_predictions/")
         local_particles = sw.utils.transformParticles(particles, transform, inverse=True)
         np.savetxt(local_particle_file, local_particles)
         predicted_test_local_particles.append(local_particle_file)
@@ -530,6 +528,9 @@ def Run_Pipeline(args):
         template_particles, template_mesh, test_out_dir,
         planes=test_planes)
     print("Test mean mesh surface-to-surface distance: " + str(mean_dist))
+
+    # If tiny test or verify, check results and exit
+    check_results(args, mean_dist)
     open(status_dir + "step_12.txt", 'w').close()
 
     print("All steps complete")
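Note on the path fix above: eval.py (rewritten below) now writes world-space predictions under world_predictions/, so the string replacement must key off that directory. The local particles themselves come from inverting each subject's alignment transform. A minimal sketch of that inversion, assuming the transform is a 4x4 homogeneous matrix and that sw.utils.transformParticles(..., inverse=True) behaves equivalently; the helper and file paths here are illustrative, not part of the PR:

    import numpy as np

    def world_to_local(world_particles, transform):
        """Map Nx3 world-space particles back to local space with the inverse of a 4x4 transform."""
        inverse = np.linalg.inv(transform)
        homogeneous = np.hstack([world_particles, np.ones((world_particles.shape[0], 1))])
        return (homogeneous @ inverse.T)[:, :3]

    # Usage mirroring the pipeline: load world particles, invert, save a local copy.
    particles = np.loadtxt("world_predictions/subject.particles")
    np.savetxt("local_predictions/subject.particles", world_to_local(particles, np.eye(4)))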
5 changes: 4 additions & 1 deletion Libs/Analyze/Shape.cpp
@@ -134,6 +134,9 @@ MeshGroup Shape::get_reconstructed_meshes(bool wait) {
   return reconstructed_meshes_;
 }
 
+//---------------------------------------------------------------------------
+void Shape::set_reconstructed_meshes(MeshGroup meshes) { reconstructed_meshes_ = meshes; }
+
 //---------------------------------------------------------------------------
 void Shape::reset_groomed_mesh() { groomed_meshes_ = MeshGroup(subject_->get_number_of_domains()); }
 
@@ -572,7 +575,7 @@ std::shared_ptr<Image> Shape::get_image_volume(std::string image_volume_name) {
       std::shared_ptr<Image> image = std::make_shared<Image>(filename);
       image_volume_ = image;
       image_volume_filename_ = filename;
-    } catch (std::exception &ex) {
+    } catch (std::exception& ex) {
      SW_ERROR("Unable to open file: {}", filename);
    }
  }
3 changes: 3 additions & 0 deletions Libs/Analyze/Shape.h
@@ -74,6 +74,9 @@ class Shape {
   //! Retrieve the reconstructed meshes
   MeshGroup get_reconstructed_meshes(bool wait = false);
 
+  //! Set the reconstructed meshes
+  void set_reconstructed_meshes(MeshGroup meshes);
+
   //! Reset the groomed mesh so that it will be re-created
   void reset_groomed_mesh();
 
1 change: 1 addition & 0 deletions Libs/Common/ShapeworksUtils.cpp
@@ -85,6 +85,7 @@ void ShapeWorksUtils::setup_threads() {
     num_threads = std::max(1, atoi(num_threads_env));
   }
   SW_DEBUG("TBB using {} threads", num_threads);
+  Eigen::setNbThreads(num_threads);
   tbb::global_control c(tbb::global_control::max_allowed_parallelism, num_threads);
 }
 
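The Eigen::setNbThreads call pins Eigen's internal parallelism to the same count TBB was just configured with, so the two runtimes stop oversubscribing cores. The same policy can be mirrored from Python when driving ShapeWorks pipelines; a hypothetical sketch using PyTorch (the environment variable name is illustrative, and none of this is part of the PR):

    import os
    import torch

    # Respect an explicit override, otherwise fall back to the hardware thread count.
    num_threads = max(1, int(os.environ.get("SW_NUM_THREADS", os.cpu_count() or 1)))
    torch.set_num_threads(num_threads)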
9 changes: 7 additions & 2 deletions Libs/Groom/Groom.cpp
@@ -453,6 +453,10 @@ int Groom::get_total_ops() {
   for (int i = 0; i < domains.size(); i++) {
     auto params = GroomParameters(project_, domains[i]);
 
+    if (project_->get_original_domain_types().size() <= i) {
+      throw std::runtime_error("invalid domain, number of original file types does not match number of domains");
+    }
+
     if (project_->get_original_domain_types()[i] == DomainType::Image) {
       num_tools += params.get_isolate_tool() ? 1 : 0;
       num_tools += params.get_fill_holes_tool() ? 1 : 0;
@@ -526,7 +530,6 @@ bool Groom::run_alignment() {
   std::vector<Mesh> reference_meshes;
   std::vector<Mesh> meshes;
   for (size_t i = 0; i < subjects.size(); i++) {
-
     if (!subjects[i]->is_excluded()) {
       Mesh mesh = get_mesh(i, domain, true);
       // if fixed subjects are present, only add the fixed subjects
@@ -711,7 +714,9 @@ std::string Groom::get_output_filename(std::string input, DomainType domain_type
   path = base + "/" + prefix;
 
   try {
-    boost::filesystem::create_directories(path);
+    if (!boost::filesystem::exists(path)) {
+      boost::filesystem::create_directories(path);
+    }
   } catch (std::exception& e) {
     throw std::runtime_error("Unable to create groom output directory: \"" + path + "\"");
   }
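The new guard in get_total_ops fails fast with a readable message instead of indexing past the end of the original-domain-type list when a project is malformed. The same defensive check, rendered as a hypothetical Python helper for project-validation scripts (not part of this PR):

    def validate_original_domains(num_domains, original_domain_types):
        """Raise early when a project lists fewer original file types than domains."""
        if len(original_domain_types) < num_domains:
            raise RuntimeError(
                "invalid domain, number of original file types does not match number of domains")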
4 changes: 3 additions & 1 deletion Libs/Image/Image.cpp
@@ -323,7 +323,9 @@ Image& Image::write(const std::string& filename, bool compressed) {
 
   // if the directory doesn't exist, create it
   boost::filesystem::path dir(filename);
-  boost::filesystem::create_directories(dir.parent_path());
+  if (dir.has_parent_path() && !boost::filesystem::exists(dir.parent_path())) {
+    boost::filesystem::create_directories(dir.parent_path());
+  }
 
   using WriterType = itk::ImageFileWriter<ImageType>;
   WriterType::Pointer writer = WriterType::New();
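Both directory fixes follow one pattern: only call boost::filesystem::create_directories when the target does not already exist (and, for image writes, only when the filename has a parent path at all). In Python-side pipeline code the equivalent check folds into a single call; a small sketch (the helper name is illustrative):

    import os

    def ensure_parent_dir(filename):
        """Create the parent directory of filename if it is missing; no-op for bare filenames."""
        parent = os.path.dirname(filename)
        if parent:  # mirrors dir.has_parent_path() in the C++ change
            os.makedirs(parent, exist_ok=True)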
2 changes: 1 addition & 1 deletion Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py
@@ -9,7 +9,7 @@
 
 from .run_utils import create_split, groom_training_shapes, groom_training_images, get_reference_index, \
     run_data_augmentation, groom_val_test_images, prep_project_for_val_particles, groom_validation_shapes, \
-    prepare_data_loaders, get_deepssm_dir, get_split_indices, optimize_training_particles
+    prepare_data_loaders, get_deepssm_dir, get_split_indices, optimize_training_particles, process_test_predictions
 
 from .config_file import prepare_config_file
 
167 changes: 83 additions & 84 deletions Python/DeepSSMUtilsPackage/DeepSSMUtils/eval.py
@@ -13,92 +13,91 @@
 '''
 Network Test Function
 predicts the PCA scores using the trained networks
-returns the error measures and saves the predicted and poriginal particles for comparison
+returns the error measures and saves the predicted and original particles for comparison
 '''
 
 
 def test(config_file, loader="test"):
-    with open(config_file) as json_file:
-        parameters = json.load(json_file)
-        model_dir = parameters["paths"]["out_dir"] + parameters["model_name"]+ '/'
-        pred_dir = model_dir + loader + '_predictions/'
-        loaders.make_dir(pred_dir)
-        if parameters["use_best_model"]:
-            model_path = model_dir + 'best_model.torch'
-        else:
-            model_path = model_dir + 'final_model.torch'
-        if parameters["fine_tune"]["enabled"]:
-            model_path_ft = model_path.replace(".torch", "_ft.torch")
-        else:
-            model_path_ft = model_path
-        loader_dir = parameters["paths"]["loader_dir"]
+    with open(config_file) as json_file:
+        parameters = json.load(json_file)
+    model_dir = parameters["paths"]["out_dir"] + parameters["model_name"] + '/'
+    pred_dir = model_dir + loader + '_predictions/'
+    loaders.make_dir(pred_dir)
+    if parameters["use_best_model"]:
+        model_path = model_dir + 'best_model.torch'
+    else:
+        model_path = model_dir + 'final_model.torch'
+    if parameters["fine_tune"]["enabled"]:
+        model_path_ft = model_path.replace(".torch", "_ft.torch")
+    else:
+        model_path_ft = model_path
+    loader_dir = parameters["paths"]["loader_dir"]
 
+    # load the loaders
+    sw_message("Loading " + loader + " data loader...")
+    test_loader = torch.load(loader_dir + loader)
+
+    # initialization
+    sw_message("Loading trained model...")
+    if parameters['tl_net']['enabled']:
+        model_tl = model.DeepSSMNet_TLNet(config_file)
+        model_tl.load_state_dict(torch.load(model_path))
+        device = model_tl.device
+        model_tl.to(device)
+        model_tl.eval()
+    else:
+        model_pca = model.DeepSSMNet(config_file)
+        model_pca.load_state_dict(torch.load(model_path))
+        device = model_pca.device
+        model_pca.to(device)
+        model_pca.eval()
+        model_ft = model.DeepSSMNet(config_file)
+        model_ft.load_state_dict(torch.load(model_path_ft))
+        model_ft.to(device)
+        model_ft.eval()
+
-        # load the loaders
-        sw_message("Loading "+ loader + " data loader...")
-        test_loader = torch.load(loader_dir + loader)
-        print("Done.\n")
-        # initalizations
-        sw_message("Loading trained model...")
-        if parameters['tl_net']['enabled']:
-            model_tl = model.DeepSSMNet_TLNet(config_file)
-            model_tl.load_state_dict(torch.load(model_path))
-            device = model_tl.device
-            model_tl.to(device)
-            model_tl.eval()
-        else:
-            model_pca = model.DeepSSMNet(config_file)
-            model_pca.load_state_dict(torch.load(model_path))
-            device = model_pca.device
-            model_pca.to(device)
-            model_pca.eval()
-            model_ft = model.DeepSSMNet(config_file)
-            model_ft.load_state_dict(torch.load(model_path_ft))
-            model_ft.to(device)
-            model_ft.eval()
+    # Get test names
+    test_names_file = loader_dir + loader + '_names.txt'
+    f = open(test_names_file, 'r')
+    test_names_string = f.read()
+    f.close()
+    test_names_string = test_names_string.replace("[", "").replace("]", "").replace("'", "").replace(" ", "")
+    test_names = test_names_string.split(",")
+    sw_message(f"Predicting for {loader} images...")
+    index = 0
+    pred_scores = []
+
-        # Get test names
-        test_names_file = loader_dir + loader + '_names.txt'
-        f = open(test_names_file, 'r')
-        test_names_string = f.read()
-        f.close()
-        test_names_string = test_names_string.replace("[","").replace("]","").replace("'","").replace(" ","")
-        test_names = test_names_string.split(",")
-        sw_message(f"Predicting for {loader} images...")
-        index = 0
-        pred_scores = []
+    pred_path = pred_dir + 'world_predictions/'
+    loaders.make_dir(pred_path)
+    pred_path_pca = pred_dir + 'pca_predictions/'
+    loaders.make_dir(pred_path_pca)
+
-        if parameters['tl_net']['enabled']:
-            predPath_tl = pred_dir + '/TL_Predictions'
-            loaders.make_dir(predPath_tl)
-        else:
-            predPath_ft = pred_dir + 'FT_Predictions/'
-            predPath_pca = pred_dir + 'PCA_Predictions/'
-            loaders.make_dir(predPath_ft)
-            loaders.make_dir(predPath_pca)
-        predicted_particle_files = []
-        for img, _, mdl, _ in test_loader:
-            if sw_check_abort():
-                sw_message("Aborted")
-                return
-            sw_message(f"Predicting {index+1}/{len(test_loader)}")
-            sw_progress((index+1) / len(test_loader))
-            img = img.to(device)
-            if parameters['tl_net']['enabled']:
-                mdl = torch.FloatTensor([1]).to(device)
-                [pred_tf, pred_mdl_tl] = model_tl(mdl, img)
-                pred_scores.append(pred_tf.cpu().data.numpy())
-                # save the AE latent space as shape descriptors
-                nmpred = predPath_tl + '/' + test_names[index] + '.npy'
-                np.save(nmpred, pred_tf.squeeze().detach().cpu().numpy())
-                nmpred = predPath_tl + '/' + test_names[index] + '.particles'
-                np.savetxt(nmpred, pred_mdl_tl.squeeze().detach().cpu().numpy())
-            else:
-                [pred, pred_mdl_pca] = model_pca(img)
-                [pred, pred_mdl_ft] = model_ft(img)
-                pred_scores.append(pred.cpu().data.numpy()[0])
-                nmpred = predPath_pca + '/predicted_pca_' + test_names[index] + '.particles'
-                np.savetxt(nmpred, pred_mdl_pca.squeeze().detach().cpu().numpy())
-                nmpred = predPath_ft + '/predicted_ft_' + test_names[index] + '.particles'
-                np.savetxt(nmpred, pred_mdl_ft.squeeze().detach().cpu().numpy())
-                predicted_particle_files.append(nmpred)
-            index += 1
-        sw_message("Test completed.")
-        return predicted_particle_files
+    predicted_particle_files = []
+    for img, _, mdl, _ in test_loader:
+        if sw_check_abort():
+            sw_message("Aborted")
+            return
+        sw_message(f"Predicting {index + 1}/{len(test_loader)}")
+        sw_progress((index + 1) / len(test_loader))
+        img = img.to(device)
+        particle_filename = pred_path + test_names[index] + '.particles'
+        if parameters['tl_net']['enabled']:
+            mdl = torch.FloatTensor([1]).to(device)
+            [pred_tf, pred_mdl_tl] = model_tl(mdl, img)
+            pred_scores.append(pred_tf.cpu().data.numpy())
+            # save the AE latent space as shape descriptors
+            filename = pred_path + test_names[index] + '.npy'
+            np.save(filename, pred_tf.squeeze().detach().cpu().numpy())
+            np.savetxt(particle_filename, pred_mdl_tl.squeeze().detach().cpu().numpy())
+        else:
+            [pred, pred_mdl_pca] = model_pca(img)
+            [pred, pred_mdl_ft] = model_ft(img)
+            pred_scores.append(pred.cpu().data.numpy()[0])
+            filename = pred_path_pca + '/predicted_pca_' + test_names[index] + '.particles'
+            np.savetxt(filename, pred_mdl_pca.squeeze().detach().cpu().numpy())
+            np.savetxt(particle_filename, pred_mdl_ft.squeeze().detach().cpu().numpy())
+        print("Predicted particle file: ", particle_filename)
+        predicted_particle_files.append(filename)
+        index += 1
+    sw_message("Test completed.")
+    return predicted_particle_files
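After this rewrite, every model variant (PCA, fine-tuned, and TL-DeepSSM) writes its particle prediction to the same world_predictions/<name>.particles location, so downstream steps such as the local-particle conversion in deep_ssm.py no longer care which network produced the file. A small sketch of consuming that layout (the directory path is illustrative):

    import glob
    import os

    import numpy as np

    pred_dir = "deepssm/model/test_predictions/world_predictions/"
    for particle_file in sorted(glob.glob(os.path.join(pred_dir, "*.particles"))):
        particles = np.loadtxt(particle_file)
        print(os.path.basename(particle_file), particles.shape)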