In [None]:
import torch
import json
import pickle
from pathlib import Path
from ccb.experiment.experiment import Job, get_model_generator
from ruamel.yaml import YAML
from ccb.torch_toolbox.dataset import DataModule
from ccb.torch_toolbox.model import Model
import matplotlib.pyplot as plt
from ccb.dataset_converters.inspect_tools import float_image_to_uint8

In [None]:
# JSON index mapping model -> training partition -> dataset to the directory
# of the best hyper-parameter sweep for that combination.
best_config_dict_path = "/mnt/home/climate-change-benchmark/ccb/configs/best_hparams_segmentation.json"

config_dict = json.loads(Path(best_config_dict_path).read_text())

## Specify which model, training partition and dataset to visualize

In [None]:
# Keys into config_dict (model -> partition -> dataset) that select which
# best-hyperparameter sweep is loaded and visualized below.
MODEL_NAME, PARTITION_SIZE, DATASET = (
    "resnet101_DeepLabV3",  # model key
    "1.00x_train",  # training partition key
    "NeonTree_segmentation",  # dataset key
)

# Samples per test batch requested from the DataModule.
BATCH_SIZE = 8

In [None]:
# Resolve the best sweep for the selected (model, partition, dataset) and load
# what is needed to rebuild the model: its config, the task specs and a checkpoint.
best_sweep_path = Path(config_dict[MODEL_NAME][PARTITION_SIZE][DATASET])
config_path = best_sweep_path / "config.yaml"

yaml = YAML()
with open(config_path, "r") as fd:
    config = yaml.load(fd)

# Task specs and checkpoints live in the grandparent of the sweep directory.
task_specs_path = best_sweep_path.parents[1] / "task_specs.pkl"

# NOTE(review): pickle.load can execute arbitrary code; only load task specs
# from trusted experiment directories.
with open(task_specs_path, "rb") as f:
    task_specs = pickle.load(f)

checkpoint_dir = best_sweep_path.parents[1] / "checkpoint"
# Path.glob yields files in arbitrary, filesystem-dependent order, so taking
# [-1] of its raw output picks an unpredictable checkpoint. Sort by
# modification time (name as tie-break) so the choice is deterministic.
ckpt_candidates = sorted(
    checkpoint_dir.glob("*.ckpt"),
    key=lambda p: (p.stat().st_mtime, p.name),
)
if not ckpt_candidates:
    raise FileNotFoundError(f"No *.ckpt files found in {checkpoint_dir}")
# NOTE(review): if the directory holds several checkpoints (e.g. a "best" and a
# "last" one), confirm that the most recently written one is the intended choice.
best_ckpt = ckpt_candidates[-1]


## Set up the model and the test data loader for prediction

In [None]:
# Rebuild the model architecture from the sweep config, restore its trained
# weights, and build the DataModule that serves the test split.
model_gen = get_model_generator(config["model"]["model_generator_module_name"])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = model_gen.generate_model(task_specs=task_specs, config=config)

# map_location remaps stored tensors onto `device`; without it a checkpoint
# saved on GPU cannot be loaded on a CPU-only machine.
ckpt = torch.load(best_ckpt, map_location=device)

model.load_state_dict(ckpt["state_dict"])
model.to(device)
# The model is only used for inference in this notebook: switch dropout /
# batch-norm layers to evaluation behaviour.
model.eval()

datamodule = DataModule(
    task_specs=task_specs,
    benchmark_dir=config["experiment"]["benchmark_dir"],
    partition_name=config["experiment"]["partition_name"],
    batch_size=BATCH_SIZE,
    num_workers=config["dataloader"]["num_workers"],
    train_transform=model_gen.get_transform(task_specs=task_specs, config=config, train=True),
    eval_transform=model_gen.get_transform(task_specs=task_specs, config=config, train=False),
    collate_fn=model_gen.get_collate_fn(task_specs, config),
    band_names=config["dataset"]["band_names"],
    format=config["dataset"]["format"],
)

test_loader = datamodule.test_dataloader()

## Predict on one test batch and visualize the results

In [None]:
# Run the model on the first test batch only.
batch = next(iter(test_loader))

# eval() makes this cell correct on its own (dropout / batch-norm in inference
# mode); no_grad() avoids building an autograd graph that inference never uses
# and that would otherwise keep activations alive in memory.
model.eval()
with torch.no_grad():
    preds = model(batch["input"].to(device)).cpu()

In [None]:
# One row per sample: input image | ground-truth mask | predicted mask.
# The batch can hold fewer than BATCH_SIZE samples (e.g. a small test set), so
# the row count follows the actual predictions; squeeze=False keeps `axs` 2-D
# even when only a single row is drawn.
n_rows = min(BATCH_SIZE, len(preds))
fig, axs = plt.subplots(nrows=n_rows, ncols=3, figsize=(20, n_rows * 5), squeeze=False)

# Column titles only need to be set once, on the first row.
for ax, title in zip(axs[0], ["Input image", "Ground Truth", "Prediction"]):
    ax.set_title(title)

for i in range(n_rows):
    # assumes input is channels-first (C, H, W); permute to (H, W, C) for imshow — TODO confirm
    input_image = batch["input"][i].permute(1, 2, 0).numpy()
    # roughly unnormalize image
    orig_input = float_image_to_uint8([input_image])[0]

    ground_truth = batch["label"][i]
    pred = preds[i]

    axs[i][0].imshow(orig_input)
    axs[i][1].imshow(ground_truth.numpy())
    # presumably dim 0 holds per-class scores, so argmax gives the class per pixel
    axs[i][2].imshow(pred.argmax(0).numpy())

    for ax in axs[i]:
        ax.axis("off")

fig.suptitle(f"{MODEL_NAME} on {DATASET} with partition {PARTITION_SIZE}")
plt.show()