This notebook runs inference on one or more tiles using the specified model. If `out_dir` is set, it also saves the predictions to disk.

In [None]:
import torch
import rasterio
import sys
sys.path.append("..")
import utils.data as dt
import utils.model_utils as mu

The block below includes all the parameters for inference. We can pass in different parameters using `papermill`.

In [None]:
# some parameters about the datasets on which to perform inference
paths = [
    "/datadrive/snake/lakes/le7-2015/splits/train/GL083789E28642N-20150801.tif", 
    "/datadrive/snake/lakes/le7-2015/splits/train/GL087427E28754N-20151101.tif"
]
stats = "/datadrive/snake/lakes/le7-2015/splits/train/processed/statistics.csv"
dataset_opts = {}

# path to the model to load, and default options
model_fn = "/datadrive/snake/glaciers/test_models/snake_200.pt"
model_opts = {"model": "DELSE", "pth_model": "/datadrive/snake/lakes/models/MS_DeepLab_resnet_trained_VOC.pth"}

# directory to save predictions in; if None, nothing is saved
out_dir = "/datadrive/snake/lakes/le7-2015/preds/"
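
papermill only overrides variables in the cell tagged `parameters`: it injects a new cell with the supplied values right after that cell. This means the parameter cell above needs that tag in the notebook metadata. Below is a minimal sketch of driving the notebook from Python with `papermill.execute_notebook`. The helper names and notebook file names are hypothetical.

```python
def inference_parameters(paths, out_dir=None, model_fn=None):
    """Build a papermill `parameters` dict (hypothetical helper).

    Keys must match variable names in the notebook's parameters cell.
    """
    params = {"paths": list(paths), "out_dir": out_dir}
    if model_fn is not None:
        params["model_fn"] = model_fn
    return params


def run_inference_notebook(notebook, output, params):
    """Execute `notebook` with `params` injected, writing the executed copy to `output`."""
    import papermill as pm  # lazy import: building parameters doesn't require papermill

    return pm.execute_notebook(notebook, output, parameters=params)


# Example (hypothetical file names):
# run_inference_notebook(
#     "inference.ipynb",
#     "inference-2015.ipynb",
#     inference_parameters(["/path/to/tile.tif"], out_dir="/path/to/preds/"),
# )
```

Writing each run to a separate output notebook keeps a record of exactly which parameters produced which predictions.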

The block below loads the model that we've previously trained. Any options that were used to customize the model architecture need to be passed in through the `model_opts` parameter.

In [None]:
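import torch
from models.unet import UnetModel
from models.networks import backend_cnn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# specify the model; the "model" key only selects the architecture, so it is
# dropped before the remaining options go to the constructor (assumed not to be an argument)
arch_opts = {k: v for k, v in model_opts.items() if k != "model"}
if model_opts["model"] == "U-Net":
    model = UnetModel(**arch_opts)
elif model_opts["model"] == "DELSE":
    model = backend_cnn(**arch_opts)
else:
    raise ValueError(f"unknown model type: {model_opts['model']}")

# load the saved state dict (torch.load reads the file; load_state_dict expects a dict)
model.load_state_dict(torch.load(model_fn, map_location=device))
model.to(device)
model.eval()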
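
The architecture options must be passed again because a state dict only stores named weight tensors, not the architecture that produced them. The toy sketch below illustrates this. `make_model` is a hypothetical stand-in for `UnetModel`/`backend_cnn`, and an in-memory buffer replaces the `.pt` file.

```python
import io

import torch
import torch.nn as nn


def make_model(in_channels: int, width: int) -> nn.Module:
    # Toy configurable architecture (hypothetical; not UnetModel or backend_cnn)
    return nn.Sequential(
        nn.Conv2d(in_channels, width, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(width, 1, 1),
    )


# "Train" a model with custom options and save only its state dict
opts = {"in_channels": 7, "width": 16}
buffer = io.BytesIO()
torch.save(make_model(**opts).state_dict(), buffer)

# Rebuilding with the same options lets the weights load cleanly
buffer.seek(0)
state = torch.load(buffer)
restored = make_model(**opts)
restored.load_state_dict(state)

# Rebuilding with different options fails: the tensor shapes no longer match
try:
    make_model(in_channels=7, width=32).load_state_dict(state)
    mismatch_detected = False
except RuntimeError:
    mismatch_detected = True
```

A `RuntimeError` from `load_state_dict` reporting a size mismatch is usually the first sign that `model_opts` differs from the options used in training.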

Next, we construct a dataloader over all the paths we want to run inference on. This assumes the dataset class handles imputation of missing values and renormalization, and that it does not crop the input image. If the dataset class isn't that flexible, we can write a small `inference` dataset class that only does the imputation and renormalization.

In [None]:
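from torch.utils.data import DataLoader

# create a dataloader, and perform inference
# `DataSet` stands in for the dataset class (e.g. from utils.data); substitute the actual name
dataset = DataSet(paths, **dataset_opts)
loader = DataLoader(dataset)  # batch_size=1 by default
mu.save_inferences(model, loader, out_dir)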
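
If the existing dataset class can't handle this, the small inference dataset suggested above might look like the sketch below. It is hypothetical in several ways. The layout of `statistics.csv` isn't shown, so `means` and `stds` are assumed to be per-channel arrays already parsed from it. NaN is assumed to mark missing pixels, and nodata values are not converted. The class implements `__len__` and `__getitem__`, which is all `DataLoader` needs from a map-style dataset.

```python
import numpy as np


class InferenceDataset:
    """Hypothetical minimal dataset: reads whole tiles, imputes NaNs, and standardizes."""

    def __init__(self, paths, means, stds, read_fn=None):
        self.paths = list(paths)
        self.means = np.asarray(means, dtype=np.float32)[:, None, None]  # (C, 1, 1)
        self.stds = np.asarray(stds, dtype=np.float32)[:, None, None]
        self.read_fn = read_fn or self._read_tif

    @staticmethod
    def _read_tif(path):
        import rasterio  # lazy import so a custom read_fn doesn't need rasterio

        with rasterio.open(path) as src:
            return src.read().astype(np.float32)  # (C, H, W), whole tile, no cropping

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        x = self.read_fn(self.paths[i]).astype(np.float32)
        # impute missing pixels with the per-channel mean, so they become 0 after standardizing
        x = np.where(np.isnan(x), self.means, x)
        return (x - self.means) / self.stds
```

With the default `batch_size=1`, `DataLoader(InferenceDataset(paths, means, stds))` can pass tiles of different sizes through uncropped, since no two tiles are ever stacked into one batch.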

Finally, it's convenient to look at a few predictions directly in the notebook, so we don't have to copy all the TIFFs and open them in QGIS just to check a couple of them.

In [None]:
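import numpy as np
import matplotlib.pyplot as plt

x = next(iter(loader))  # one batch; assumes the loader yields input tensors only
with torch.no_grad():
    y_hat = model(x.to(device)).squeeze(1).cpu().numpy()  # assumes (N, [1,] H, W) output
for i in range(y_hat.shape[0]):
    x_ = np.transpose(x[i].numpy(), (1, 2, 0))[:, :, :3]  # first 3 bands as RGB
    _, axes = plt.subplots(1, 2)  # input and prediction side by side
    axes[0].imshow(np.clip(10 * x_ / np.nanmax(x_), 0, 1))
    axes[1].imshow(y_hat[i])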
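
The fixed `10 * x_ / np.nanmax(x_)` scaling depends on a single brightest pixel, so some tiles can come out washed out or too dark. A per-band percentile stretch is a common alternative. It is swapped in here and is not part of the original notebook. `percentile_stretch` is a hypothetical helper that expects an (H, W, C) array.

```python
import numpy as np


def percentile_stretch(img, low=2.0, high=98.0):
    """Rescale each band of an (H, W, C) image to [0, 1] between its low/high percentiles.

    NaNs are ignored when computing percentiles and mapped to 0 in the output.
    """
    img = np.asarray(img, dtype=np.float64)
    lo = np.nanpercentile(img, low, axis=(0, 1))   # per-band lower bound, shape (C,)
    hi = np.nanpercentile(img, high, axis=(0, 1))  # per-band upper bound, shape (C,)
    scale = np.where(hi > lo, hi - lo, 1.0)        # avoid division by zero on flat bands
    out = np.clip((img - lo) / scale, 0.0, 1.0)
    return np.nan_to_num(out, nan=0.0)
```

For example, `plt.imshow(percentile_stretch(x_[:, :, :3]))` would replace the max-based scaling for a three-band preview.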