Skip to content

Commit

Permalink
updated save_preds to disk for register.py
Browse files Browse the repository at this point in the history
  • Loading branch information
alanqrwang committed Oct 6, 2023
1 parent 3710285 commit 7cbdc54
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 4 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,11 @@ wandb/*
training_output/*
output/*
pretraining_output/*
register_output/*

*.sh
job_out/*
job_err/*
register.ipynb
register.ipynb

*.npy
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ python register.py \

`--moving_seg` and `--fixed_seg` are optional. If provided, the script will compute the Dice score between the registered moving segmentation map and the fixed segmentation map. Otherwise, it will only compute MSE between the registered moving image and the fixed image.

Add the flag `--save_preds` to save outputs to disk. The default location is `./register_output/`.

For all inputs, ensure that pixel values are min-max normalized to the $[0,1]$ range and that the spatial dimensions are $(L, W, H) = (128, 128, 128)$.

## Training KeyMorph
Expand Down
31 changes: 28 additions & 3 deletions register.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from argparse import ArgumentParser
import torchio as tio
from scipy.stats import loguniform
from pathlib import Path

from keymorph.keypoint_aligners import ClosedFormAffine, TPS
from keymorph.net import ConvNetFC, ConvNetCoM
Expand All @@ -26,7 +27,7 @@ def parse_args():
parser.add_argument("--save_dir",
type=str,
dest="save_dir",
default="./training_output/",
default="./register_output/",
help="Path to the folder where outputs are saved")

parser.add_argument('--load_path', type=str, default=None,
Expand Down Expand Up @@ -135,6 +136,11 @@ def _get_tps_lmbda(num_samples, args):
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
print('Number of GPUs: {}'.format(torch.cuda.device_count()))

# Create save path
save_path = Path(args.save_dir)
if not os.path.exists(save_path) and args.save_preds:
os.makedirs(save_path)

# Set seed
random.seed(args.seed)
np.random.seed(args.seed)
Expand Down Expand Up @@ -204,8 +210,8 @@ def _get_tps_lmbda(num_samples, args):
registration_model = KeyMorph(network, kp_aligner, args.num_keypoints, args.dim)
registration_model.eval()

for fixed in fixed_loader:
for moving in moving_loader:
for i, fixed in enumerate(fixed_loader):
for j, moving in enumerate(moving_loader):

# Get images and segmentations from TorchIO subject
img_f, img_m = fixed['img'][tio.DATA], moving['img'][tio.DATA]
Expand All @@ -227,6 +233,7 @@ def _get_tps_lmbda(num_samples, args):
img_a = align_img(grid, img_m)
if seg_available:
seg_a = align_img(grid, seg_m)
points_a = kp_aligner.points_from_points(points_m, points_f, points_m, lmbda=lmbda)

# import matplotlib.pyplot as plt
# fig, axes = plt.subplots(1, 3, figsize=(10, 3))
Expand All @@ -249,5 +256,23 @@ def _get_tps_lmbda(num_samples, args):
metrics['jdstd'] = loss_ops.jdstd(grid)
metrics['jdlessthan0'] = loss_ops.jdlessthan0(grid, as_percentage=True)

if args.save_preds:
assert args.batch_size == 1 # TODO: fix this
img_a_path = save_path / f'img_a_{i}_{j}.npy'
seg_a_path = save_path / f'seg_a_{i}_{j}.npy'
points_f_path = save_path / f'points_f_{i}_{j}.npy'
points_m_path = save_path / f'points_m_{i}_{j}.npy'
points_a_path = save_path / f'points_a_{i}_{j}.npy'
grid_path = save_path / f'grid_{i}_{j}.npy'
print('Saving:\n{}\n{}\n{}\n{}\n{}\n{}'.format(img_a_path, seg_a_path,
points_f_path, points_m_path,
points_a_path, grid_path))
np.save(img_a_path, img_a[0].cpu().detach().numpy())
np.save(seg_a_path, np.argmax(seg_a.cpu().detach().numpy(), axis=1))
np.save(points_f_path, points_f[0].cpu().detach().numpy())
np.save(points_m_path, points_m[0].cpu().detach().numpy())
np.save(points_a_path, points_a[0].cpu().detach().numpy())
np.save(grid_path, grid[0].cpu().detach().numpy())

for name, metric in metrics.items():
print(f'[Eval Stat] {name}: {metric:.5f}')

0 comments on commit 7cbdc54

Please sign in to comment.