In [None]:
# Change to your main working directory (uncomment the lines below and replace the placeholder path)
# %cd [your-cwd]
# %pwd

# Load the trainer and evaluate PCK (Percentage of Correct Keypoints)

In [None]:
import torch
from spacejam.trainer import trainer_from_checkpoint, Trainer
from utilities.run_utils import forward_inference_all

# Path to the saved checkpoint; replace the placeholder before running.
model_pt_path = 'your_pt_path.pt'
trainer: Trainer = trainer_from_checkpoint(model_pt_path)

# Run the trained components (autoencoder, STN, atlas/loss handlers) over the trainer's
# dataset. The three returned lists are indexed per sample by the cells below.
transformers_list, encoded_features_list, inputs_dicts_list = forward_inference_all(
    trainer.dataset,
    trainer.autoencoder,
    trainer.stn,
    trainer.atlas_handler,
    trainer.loss_handler,
    train_with_reflections=trainer.train_with_reflections,
)

In [None]:
# Evaluate the loaded trainer and print every returned metric as a plain Python scalar.
eval_dict = trainer.eval()
for metric_name, metric_value in eval_dict.items():
    print(f'{metric_name}: {metric_value.item()}')

# Generate warped images and visualizations

In [None]:
# Number of dataset samples used by the visualization cells below.
NUM_IMAGES_TO_SHOW: int = 4

## Access the homography transformation parameters

In [None]:
from spacejam.models.transformers.homography_transformer import HomographyTransformer

# Print the homography parameters (theta) of the first few samples.
# Clamp the count so a dataset smaller than NUM_IMAGES_TO_SHOW does not raise IndexError.
num_params_to_show = min(NUM_IMAGES_TO_SHOW, len(transformers_list))
for i in range(num_params_to_show):
    # Assumes the first sub-transformer of each sample is the homography one — TODO confirm.
    transformer: HomographyTransformer = transformers_list[i].transformers[0]  # type: ignore
    # detach() so .numpy() also works when theta is a tensor with requires_grad=True.
    print(transformer.theta.detach().cpu().numpy())
    print()

## Show some warped images

In [None]:
from matplotlib import pyplot as plt
from torchvision.utils import make_grid

# Show each input image next to its warped version, one column per sample.
# Clamp the count so a dataset smaller than NUM_IMAGES_TO_SHOW does not raise IndexError.
num_warped_to_show = min(NUM_IMAGES_TO_SHOW, len(transformers_list))

images = []
# Visualization only: no autograd graph is needed, and without it the result can be
# converted to numpy below.
with torch.no_grad():
    for i in range(num_warped_to_show):
        image = inputs_dicts_list[i]['images']
        transformed_image = transformers_list[i](image)
        # Take element 0 along dim 0 (assumed batch) and concatenate on dim 1, which stacks
        # original above warped if tensors are (C, H, W) — TODO confirm layout.
        vis = torch.cat([image[0], transformed_image[0]], dim=1)
        images.append(vis)

grid = make_grid(images, nrow=num_warped_to_show, padding=0)
plt.figure(figsize=(5 * num_warped_to_show, 5))
# detach() guards against tensors that still track gradients (.numpy() raises on those).
plt.imshow(grid.permute(1, 2, 0).detach().cpu().numpy())
plt.axis('off')
plt.show()

## Create the atlas after training

In [None]:
# Warp every sample's image, encoded features and mask with that sample's own transformer.
# Visualization only, so skip autograd (also keeps the result numpy-convertible).
with torch.no_grad():
    images_warped = torch.stack(
        [tr(inputs_dicts['images'])[0] for tr, inputs_dicts in zip(transformers_list, inputs_dicts_list)]
    )
    # Use a distinct loop name so the comprehension does not shadow the `encoded_features` global.
    encoded_features = torch.stack(
        [tr(feats)[0] for tr, feats in zip(transformers_list, encoded_features_list)]
    )
    masks_warped = torch.stack(
        [tr(inputs_dicts['masks'])[0] for tr, inputs_dicts in zip(transformers_list, inputs_dicts_list)]
    )

# Global min-max normalization to [0, 1] for display. Clamping the range keeps a constant
# feature tensor at 0 instead of producing NaN from 0 / 0.
feat_min, feat_max = encoded_features.min(), encoded_features.max()
encoded_features = (encoded_features - feat_min) / (feat_max - feat_min).clamp_min(1e-8)

# Atlas visualization: mean of the warped features over samples (dim 0), gated by the
# per-pixel median of the warped masks. Assumes the feature channel count is one imshow
# can render (3 or 4) — TODO confirm.
vis = encoded_features.mean(dim=0) * masks_warped.median(dim=0).values
plt.imshow(vis.permute(1, 2, 0).detach().cpu().numpy())
plt.axis('off')
plt.show()
