In [None]:
import json
import os
import argparse

import torch
from monai.networks.layers import Norm
from monai.networks.nets import UNet
from csnet.models.csnet import CSNet
from csnet.models.csnet_orig import CSNetOrig
from csnet.utils.plot import plot_projections
from csnet.utils.predict import predict

In [None]:
# Paths: images to run the model on, where `predict` writes its output, and
# the trained checkpoint (its directory must also contain config.json).
input_dir, output_dir = 'data/semantic_3D/test/img', 'predictions/test'
model_path = 'model_test/dry-bush-38/best_model.pth'

# Batch size forwarded to `predict`.
batch_size = 4

### Set up and load the model

In [None]:
with open(os.path.join(os.path.dirname(model_path), 'config.json')) as f:
    config = json.load(f)
config = argparse.Namespace(**config)
config

In [None]:
# Rebuild the architecture named in config.json, then load the trained weights.
model_name = config.model.lower()

# 'unet' and 'csnet' are constructed with identical keyword arguments, so they
# share one code path (prevents the two copies drifting apart). As before,
# config.n_channels / config.num_res_units are only read for these two models.
configurable_nets = {'unet': UNet, 'csnet': CSNet}

if model_name in configurable_nets:
    net = configurable_nets[model_name](
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=config.n_channels,
        # one stride-2 step between each pair of consecutive channel entries
        strides=(2,) * (len(config.n_channels) - 1),
        num_res_units=config.num_res_units,
        norm=Norm.BATCH,
    )
elif model_name == 'csnet_orig':
    # NOTE(review): positional args (2, 1) — presumably (out_channels,
    # in_channels) to match the models above; confirm against CSNetOrig.
    net = CSNetOrig(2, 1)
else:
    raise NotImplementedError(
        rf'{config.model} is an invalid model; must be one of ["unet", "csnet", "csnet_orig"]')

# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
# machine; load_state_dict copies the tensors into net's parameters either way.
# torch.load unpickles the file, so only point model_path at trusted checkpoints.
load_result = net.load_state_dict(torch.load(model_path, map_location='cpu'))
# Inference only: switch BatchNorm (Norm.BATCH) and similar layers to eval mode
# so stored running statistics are used; harmless if `predict` also does this.
net.eval()
load_result

### Predict

In [None]:
# Run the loaded network on the images in input_dir, with output going to
# output_dir. roi_size comes from the training config (presumably the
# inference patch size — confirm in csnet.utils.predict). With
# return_last=True, `predict` returns an (image, prediction) pair, which is
# plotted below.
image, predicted = predict(
    input_dir,
    output_dir,
    net,
    roi_size=config.roi_size,
    return_last=True,
    batch_size=batch_size,
)

In [None]:
# Visualise projections of the input image next to the network's prediction
# (both returned by `predict` above).
plot_projections(
    [image, predicted],
    panel_size=6,
)