In [None]:
import torch
import torchio as tio
import numpy as np
import matplotlib.pyplot as plt
from celluloid import Camera
from IPython.display import HTML

In [None]:
class Segmenter(pl.LightningModule):
    """Minimal LightningModule wrapper used only to reload a trained UNet for inference.

    NOTE(review): ``pl`` (pytorch_lightning) and ``UNet`` are not imported in the
    visible import cell — confirm they are defined before this cell is run.
    """

    def __init__(self):
        super().__init__()
        self.model = UNet()
        # A weighted CrossEntropyLoss registers ``weight`` as a buffer, so this module
        # exposes a ``ce_loss.weight`` state-dict entry. Presumably it mirrors the
        # training module so checkpoint keys match on load — confirm. The length 3
        # presumably equals the number of segmentation classes.
        class_weights = torch.ones(3)
        self.ce_loss = torch.nn.CrossEntropyLoss(weight=class_weights)

    def forward(self, x):
        """Run the wrapped UNet on batch ``x`` and return its raw output."""
        out = self.model(x)
        return out

# Use the first GPU when one is available, otherwise run on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

CHECKPOINT_PATH = "checkpoints/liver-epoch=95-val_loss=0.28.ckpt"

# map_location remaps the stored tensors onto `device` while deserialising, so a
# checkpoint saved on GPU can still be opened on a CPU-only machine (without it the
# behaviour depends on the installed Lightning version's default mapping).
# .eval() switches layers such as dropout/batch-norm to inference behaviour.
model = (
    Segmenter.load_from_checkpoint(CHECKPOINT_PATH, map_location=device)
    .to(device)
    .eval()
)

In [None]:
IDX = 4                     # index of the validation subject to visualise
PATCH_SIZE = 96             # edge length (voxels) of each cubic inference patch
PATCH_OVERLAP = (8, 8, 8)   # overlap (voxels) between neighbouring patches per axis
PATCH_BATCH_SIZE = 4        # patches per forward pass

subject = val_dataset[IDX]
# torchio stores image data channel-first (C, W, H, D); [0] keeps the first channel.
imgs = subject["CT"]["data"][0].numpy()
labels = subject["Label"]["data"][0].numpy()
# The animation below compares slices of these volumes voxel-for-voxel.
assert imgs.shape == labels.shape, "CT and Label volumes must share a spatial shape"

# Tile the subject into overlapping patches; the aggregator later stitches the
# per-patch predictions back into a full-size volume.
grid_sampler = tio.inference.GridSampler(subject, PATCH_SIZE, PATCH_OVERLAP)
aggregator = tio.inference.GridAggregator(grid_sampler)
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=PATCH_BATCH_SIZE)

In [None]:
# Patch-wise inference: gradients are unnecessary, so autograd tracking is disabled.
with torch.no_grad():
    for patches_batch in patch_loader:
        input_tensor, locations = (
            patches_batch["CT"]["data"].to(device),
            patches_batch[tio.LOCATION],
        )
        # Each patch's output is written back at its grid location.
        aggregator.add_batch(model(input_tensor), locations)

# The stitched output is channel-first; argmax over axis 0 (presumably the class
# channel) leaves one label id per voxel.
stitched_vol = aggregator.get_output_tensor()
pred_vol = stitched_vol.argmax(dim=0).numpy()

In [None]:
SLICE_STEP = 2            # animate every 2nd slice along the last axis
OVERLAY_CMAP = "autumn"   # shared by both panels so a class has the same colour in each
OVERLAY_ALPHA = 0.4


def _draw_overlay(ax, image_slice, label_slice):
    """Draw ``image_slice`` in greyscale on ``ax`` with non-zero labels tinted on top."""
    ax.imshow(image_slice, cmap="bone")
    # Mask label 0 so background stays transparent and only foreground is coloured.
    overlay = np.ma.masked_where(label_slice == 0, label_slice)
    # Fixed vmin/vmax keep the label -> colour mapping identical across frames and
    # panels (presumably label ids 0..2 — confirm against the training setup).
    ax.imshow(overlay, alpha=OVERLAY_ALPHA, cmap=OVERLAY_CMAP, vmin=0, vmax=2)


assert pred_vol.shape == imgs.shape == labels.shape, "volumes must align voxel-for-voxel"

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
camera = Camera(fig)

# Titles and axis visibility do not change between frames, so set them once.
for ax, title in ((ax1, "Model Prediction"), (ax2, "Ground Truth (Target)")):
    ax.set_title(title)
    ax.axis("off")

for i in range(0, imgs.shape[2], SLICE_STEP):
    _draw_overlay(ax1, imgs[:, :, i], pred_vol[:, :, i])
    _draw_overlay(ax2, imgs[:, :, i], labels[:, :, i])
    camera.snap()

fig.tight_layout()
animation = camera.animate()
plt.close(fig)  # stop the static figure from also rendering below the video
# NOTE: to_html5_video needs an ffmpeg writer available to matplotlib.
HTML(animation.to_html5_video())