## Imports & Config

In [None]:
import torch
import yaml
from arcade_dataset import load_dataset, visualize_batch
import matplotlib.pyplot as plt
from model import VesselSegmentationModel
import pytorch_lightning as pl
from time import time
# Free unoccupied memory held by PyTorch's CUDA caching allocator before the model is loaded.
torch.cuda.empty_cache()

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the experiment configuration. The `with` block guarantees the file handle
# is closed. FullLoader is kept because this is a trusted, local config file;
# switch to yaml.safe_load if it ever comes from an external source.
with open('model_overfit_config.yaml', 'r', encoding='utf-8') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

In [None]:
# The configured input shape is stored as a 3-element sequence, unpacked here
# as (modalities, height, width).
(modalities, H_in, W_in) = config["input"]["image_shape"]

## Model Definition

In [None]:
# Restore the trained model from a Lightning checkpoint.
# NOTE(review): the trailing '...' looks like an elided file name; point this
# at the actual checkpoint file before running.
model_path = 'models/final_model/...'
model = VesselSegmentationModel.load_from_checkpoint(checkpoint_path=model_path)

In [None]:
# Run Lightning's test loop on the restored model; accelerator="auto" lets
# Lightning choose the hardware. No dataloaders are passed, so this presumably
# relies on the model defining its own test dataloader (TODO confirm).
trainer = pl.Trainer(accelerator="auto")
trainer.test(model=model)

### Visualize model predictions on validation images

In [None]:
# Validation split. Print the fields one sample provides; a bare expression
# here would not be displayed because it is not the cell's last line.
dataset = load_dataset(split='val')
print(next(iter(dataset)).keys())
# Batches of 3. shuffle=True means every fresh iterator yields a different
# random ordering, so `next(iter(dataloader))` gives a new random batch.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, shuffle=True)

In [None]:
# Draw one batch; the DataLoader shuffles, so a fresh iterator returns a random batch.
batch = next(iter(dataloader))

In [None]:
# Quick look at the batch with the dataset's own plotting helper: 3 images,
# 25 classes (presumably the number of mask channels; TODO confirm).
visualize_batch(batch, num_images=3, num_classes=25)

In [None]:
# Prepare the model for manual inference in the loop below: switch to eval mode
# (affects layers such as dropout / batch-norm, if present) and move it to `device`.
model.eval()
model.to(device)

In [None]:
from pathlib import Path

from tqdm.notebook import tqdm

# Plot model predictions for several random validation batches.
# For image `img` of each batch, every mask channel that is non-empty in BOTH
# the ground truth and the decoder output gets one row:
#     input image | ground-truth mask | decoder output
# Each figure is saved as an SVG under plots/outputs/ and shown inline.
num_batches = 30       # how many random batches to inspect
num_classes = 25       # mask channels checked (same value passed to visualize_batch)
img = 0                # index of the image within each batch to plot
presence_eps = 1e-6    # a channel counts as non-empty if its mean exceeds this
# NOTE(review): the mean-based presence test assumes decoder_output is
# non-negative (e.g. probabilities); for raw logits it may be misleading. Confirm.

output_dir = Path('plots/outputs')
output_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create missing folders
run_stamp = int(time())  # shared prefix so all figures from one run sort together

for batch_idx in tqdm(range(num_batches)):
    # A fresh iterator over the shuffled DataLoader yields a new random batch.
    batch = next(iter(dataloader))
    x, y_gt = batch['transformed_image'], batch['separate_masks']

    # Inference only: skip autograd bookkeeping, and feed inputs on the model's device.
    with torch.no_grad():
        decoder_output, vae_output, _, labels, _ = model(x.to(device))
    decoder_output = decoder_output.cpu()

    valid_channels = [
        c for c in range(num_classes)
        if y_gt[img, c].mean() > presence_eps
        and decoder_output[img, c].mean() > presence_eps
    ]
    if not valid_channels:
        # plt.subplots(nrows=0) raises, so skip batches with nothing to show.
        tqdm.write(
            f'Batch {batch_idx}: no channel is non-empty in both ground truth '
            'and prediction; skipping.'
        )
        continue

    n_rows = len(valid_channels)
    fig, axes = plt.subplots(
        figsize=(10, max(10, 3 * n_rows)), ncols=3, nrows=n_rows,
        sharex=True, sharey=True, squeeze=False,  # squeeze=False keeps axes 2-D even for one row
    )
    input_image = x[img, 0].cpu().numpy()  # first input channel
    for row, c in enumerate(valid_channels):
        panels = (
            (input_image, 'Input Image'),
            (y_gt[img, c].cpu().numpy(), 'Ground Truth Mask'),
            (decoder_output[img, c].numpy(), 'Decoder Output'),
        )
        for ax, (image, title) in zip(axes[row], panels):
            ax.imshow(image, cmap='gray')
            ax.set_title(title)
            ax.axis('off')

    fig.tight_layout()
    fig.savefig(
        output_dir / f'model_output_{run_stamp}_{batch_idx:02d}.svg',
        format='svg', bbox_inches='tight',
    )
    plt.show()
    plt.close(fig)  # release the figure so 30 iterations don't accumulate open figures

  0%|          | 0/30 [00:00<?, ?it/s]