This notebook shows a minimal example of loading a saved TCD model and running it on a new image. All configuration details are stored in a YAML file and handled by the model runner.

In [None]:
import model
import rasterio
import matplotlib.pyplot as plt
import numpy as np

Let's load the model runner with the default config. Among other things, this configuration file specifies the model to use and the inference parameters (e.g. the maximum number of instances).

In [None]:
runner = model.ModelRunner("default.yaml")

And an image to test on:

In [None]:
image_path = "/home/josh/.darwin/datasets/restor/stratified_test_data/images/5c15321f63d9810007f8b06f_10_00000.tif"
image = rasterio.open(image_path)

Let's predict the results for this image. The first time this function is run, Torch has to load the model, so it will take a while; repeat runs should be much faster. The model is meant to run at the image's native resolution (10 cm) and uses test-time augmentation, which helps it find smaller trees.

In this single-pass mode, however, the results are quite poor, because the image is first reshaped to a maximum size of 1024 px, which lowers its effective resolution.
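To see why the 1024 px resize hurts small trees, it helps to work out the effective ground sample distance (GSD) the model actually sees. This is a sketch that assumes "maximum size" means the longest side is scaled down to 1024 px with the aspect ratio preserved; the `effective_gsd` function and the example image size are hypothetical and not part of the model runner.

```python
def effective_gsd(native_gsd_cm: float, width_px: int, height_px: int,
                  max_size_px: int = 1024) -> float:
    """Return the ground sample distance (cm/px) seen by the model after resizing.

    Assumption: the image is shrunk (never enlarged) so that its longest
    side is at most max_size_px, preserving the aspect ratio.
    """
    longest_side = max(width_px, height_px)
    scale = min(1.0, max_size_px / longest_side)  # <= 1.0: downscale only
    return native_gsd_cm / scale


# Hypothetical example: a 2048 x 2048 px image at 10 cm/px
print(effective_gsd(10.0, 2048, 2048))  # 20.0, i.e. each pixel now covers 20 cm
```

Halving the resolution like this means a crown that spanned 20 px in the original image spans only about 10 px after the resize.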

In [None]:
results = runner.predict_file(image_path)

In [None]:
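score_threshold = 0.65  # minimum confidence for a detection to be counted
# Class 1 = tree, class 0 = canopy; int() turns the tensor sum into a plain number for printing
print(f"Tree instances detected: {int(((results.pred_classes == 1) & (results.scores > score_threshold)).sum())}")
print(f"Canopy instances detected: {int(((results.pred_classes == 0) & (results.scores > score_threshold)).sum())}")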
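With several classes, the same thresholded counts can be computed in a single pass with `np.bincount`. This is a minimal NumPy sketch on made-up predictions; `counts_per_class` is a hypothetical helper. To use it on the runner's output you would first convert the predictions to NumPy (e.g. `results.pred_classes.cpu().numpy()`), which assumes `results` holds PyTorch tensors.

```python
import numpy as np

def counts_per_class(pred_classes: np.ndarray, scores: np.ndarray,
                     threshold: float, num_classes: int) -> np.ndarray:
    """Count detections per class, keeping only those with score > threshold."""
    kept = pred_classes[scores > threshold]          # class IDs of confident detections
    return np.bincount(kept, minlength=num_classes)  # index i -> count for class i


# Made-up predictions: class 0 = canopy, class 1 = tree
pred_classes = np.array([1, 1, 0, 1, 0])
scores = np.array([0.90, 0.50, 0.70, 0.66, 0.30])
print(counts_per_class(pred_classes, scores, threshold=0.65, num_classes=2))  # [1 2]
```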

And visualise:

In [None]:
runner.visualise(image.read().transpose(1,2,0), results, figsize=(15,15))
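The `transpose(1, 2, 0)` call is needed because rasterio's `read()` returns a band-first array of shape `(bands, rows, cols)`, whereas matplotlib's `imshow` expects the bands last, `(rows, cols, bands)`. A quick check on a dummy array (the shape is arbitrary):

```python
import numpy as np

# Dummy 3-band raster in rasterio's (bands, rows, cols) layout
bands_first = np.zeros((3, 4, 5), dtype=np.uint8)

# Reorder to (rows, cols, bands) for plt.imshow; this returns a view, not a copy
bands_last = bands_first.transpose(1, 2, 0)
print(bands_last.shape)  # (4, 5, 3)

# np.moveaxis expresses the same reordering: move axis 0 to the end
assert np.moveaxis(bands_first, 0, -1).shape == bands_last.shape
```

Note that for a 4-band array, `imshow` interprets the last band as an alpha channel (RGBA).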

We can also detect instances in tiled mode, which works much better because the image is processed in 512 px tiles rather than being shrunk as a whole:

In [None]:
results = runner.detect_tiled(image_path, tile_size=512, pad=100, skip_empty=True)
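The runner's tiling code isn't shown here. A common scheme, and the one assumed in this sketch, splits the image into non-overlapping `tile_size` cores and reads each one with `pad` extra pixels of context on every side, clipped at the image edges. The `Tile` class and `make_tiles` function are hypothetical; they only illustrate how `tile_size` and `pad` might interact.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Tile:
    # Core region: the part of the output this tile is responsible for
    row: int
    col: int
    height: int
    width: int
    # Padded read window, clipped to the image bounds
    read_row: int
    read_col: int
    read_height: int
    read_width: int


def make_tiles(img_height: int, img_width: int, tile_size: int, pad: int) -> List[Tile]:
    """Split an image into tile_size cores, each read with `pad` px of context."""
    tiles = []
    for row in range(0, img_height, tile_size):
        for col in range(0, img_width, tile_size):
            height = min(tile_size, img_height - row)  # edge tiles may be smaller
            width = min(tile_size, img_width - col)
            r0, c0 = max(0, row - pad), max(0, col - pad)
            r1 = min(img_height, row + height + pad)
            c1 = min(img_width, col + width + pad)
            tiles.append(Tile(row, col, height, width, r0, c0, r1 - r0, c1 - c0))
    return tiles


tiles = make_tiles(1200, 1000, tile_size=512, pad=100)
print(len(tiles))  # 6 (3 rows x 2 columns of cores)
```

Under this assumption, an interior tile is read as 512 + 2 × 100 = 712 px across, below the 1024 px limit mentioned earlier, so it would not be downscaled.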

Let's plot the results, this time as a simple mask made by flattening the output instances into a single layer. This is much faster and more memory-friendly than using Detectron's built-in visualiser. We'll also draw the tile boundaries as red boxes.

Note that each segmentation layer is displayed as a NumPy masked array, so that pixels with no detection stay transparent and the image underneath remains visible.
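Here is what `np.ma.masked_where` does on a toy layer: pixels where the condition holds are masked, and matplotlib draws masked pixels with the colormap's "bad" colour, which is fully transparent by default. The array values are made up.

```python
import numpy as np

# Toy 2 x 2 layer: 0 means "no detection"
layer = np.array([[0, 2],
                  [1, 0]])

masked = np.ma.masked_where(layer == 0, layer)
print(masked.mask.tolist())  # [[True, False], [False, True]]
print(masked.count())        # 2 (number of unmasked, i.e. visible, pixels)
```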

In [None]:
plt.figure(figsize=(15,15))

# Plot extent in the raster's CRS coordinates: [left, right, bottom, top]
extent = [image.bounds.left, image.bounds.right, image.bounds.bottom, image.bounds.top]
# rasterio returns (bands, rows, cols); imshow expects (rows, cols, bands)
plt.imshow(image.read().transpose((1, 2, 0)), extent=extent)
ax = plt.gca()

# Flatten all tile predictions above the confidence threshold into one mask
threshold = 0.5
image_mask = runner.merge_tiled_results(results, image, threshold)

# Draw each tile's bounding box
for _, bbox in results:
    rect = plt.Rectangle(xy=(bbox.minx, bbox.miny),
                         width=bbox.maxx - bbox.minx,
                         height=bbox.maxy - bbox.miny,
                         alpha=0.25,
                         linewidth=4,
                         edgecolor='red',
                         facecolor='none')
    ax.add_patch(rect)

# Channel 0: canopy (assuming channels follow the class IDs used earlier)
masked = np.ma.masked_where(image_mask[:, :, 0] == 0, image_mask[:, :, 0])
plt.imshow(masked, alpha=0.8, extent=extent, cmap='Blues_r')

# Channel 1: trees
masked = np.ma.masked_where(image_mask[:, :, 1] == 0, image_mask[:, :, 1])
plt.imshow(masked, alpha=0.8, extent=extent, cmap='Reds_r')
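The flattening performed by `merge_tiled_results` isn't shown in this notebook. As a rough illustration only, here is one way to rasterise per-tile instance masks into a single `(H, W, num_classes)` layer, keeping the highest confidence where instances overlap. The `flatten_instances` function and its input format (pixel offsets plus boolean masks, classes and scores per tile) are assumptions, not the runner's actual API.

```python
import numpy as np


def flatten_instances(height, width, num_classes, tiles, threshold=0.5):
    """Rasterise instance masks into one confidence layer per class.

    tiles: iterable of (row_offset, col_offset, masks, classes, scores), where
    masks is an (N, h, w) boolean array in tile pixel coordinates and
    classes/scores hold one entry per instance.
    """
    out = np.zeros((height, width, num_classes), dtype=np.float32)
    for row, col, masks, classes, scores in tiles:
        for mask, cls, score in zip(masks, classes, scores):
            if score < threshold:
                continue  # drop low-confidence instances
            # Clip masks that extend past the image edge
            h = min(mask.shape[0], height - row)
            w = min(mask.shape[1], width - col)
            region = out[row:row + h, col:col + w, cls]  # a view into out
            # Where instances overlap, keep the higher confidence
            np.maximum(region, np.where(mask[:h, :w], score, 0.0), out=region)
    return out


# Two overlapping "tree" instances (class 1) in a tile placed at (2, 2)
masks = np.zeros((2, 4, 4), dtype=bool)
masks[0, :2, :2] = True    # covers image pixels (2..3, 2..3)
masks[1, 1:3, 1:3] = True  # covers image pixels (3..4, 3..4)
layer = flatten_instances(6, 6, 2, [(2, 2, masks, np.array([1, 1]), np.array([0.6, 0.9]))])
print(layer[3, 3, 1])  # overlap pixel keeps the higher score, about 0.9
```

Keeping the maximum score (rather than summing) means overlapping instances from neighbouring padded tiles do not inflate the confidence of shared pixels.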

Here's a more complex example: a large image that we want to resample and then predict on. This could possibly be done natively with torchgeo (TODO), but for now we simply resample the image to the appropriate resolution. The image is roughly 20k x 20k pixels, far too large to process in one go, so this time we need to tile it.
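A back-of-the-envelope check of the sizes involved. The source resolution (5 cm) and dimensions below are hypothetical, and the tile count assumes non-overlapping `tile_size` cores, which may not match exactly how `detect_tiled` steps through the image.

```python
import math


def resampled_shape(height_px, width_px, src_gsd_cm, dst_gsd_cm=10.0):
    """Pixel dimensions after resampling from src_gsd_cm to dst_gsd_cm per pixel."""
    scale = src_gsd_cm / dst_gsd_cm  # < 1 shrinks the image, > 1 enlarges it
    return round(height_px * scale), round(width_px * scale)


def tile_count(height_px, width_px, tile_size):
    """Number of non-overlapping tile_size cores needed to cover the image."""
    return math.ceil(height_px / tile_size) * math.ceil(width_px / tile_size)


# Hypothetical source: 40,000 x 40,000 px at 5 cm/px, resampled to 10 cm/px
h, w = resampled_shape(40_000, 40_000, src_gsd_cm=5.0)
print(h, w)                   # 20000 20000
print(tile_count(h, w, 768))  # 729 tiles at tile_size=768
```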

The tile size here is relatively small because memory consumption becomes a problem with larger tiles and with large numbers of detected instances. The downside is that processing takes a long time. Unfortunately, it's difficult to recover gracefully from out-of-memory (OOM) errors, and this is something we should work around: for example, by processing the image at the largest tile size that fits in memory and then progressively "fixing" regions that are too dense to process.
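A simpler variant of that idea is to retry the whole run with progressively smaller tiles, rather than fixing dense regions individually. The sketch below does this with a hypothetical `detect` callable (for example, a wrapper around `runner.detect_tiled`) and catches Python's built-in `MemoryError`. With PyTorch on a GPU, recent releases raise `torch.cuda.OutOfMemoryError` instead, so the `except` clause would need to catch that.

```python
from typing import Callable, Tuple, TypeVar

T = TypeVar("T")


def detect_with_backoff(detect: Callable[[int], T], tile_size: int,
                        min_tile_size: int = 256) -> Tuple[int, T]:
    """Run detect(tile_size), halving tile_size after each MemoryError.

    Returns (tile_size_used, result). Re-raises once halving would go
    below min_tile_size.
    """
    while True:
        try:
            return tile_size, detect(tile_size)
        except MemoryError:
            if tile_size // 2 < min_tile_size:
                raise  # give up: even the smallest allowed tile does not fit
            tile_size //= 2


# Stand-in detector that "runs out of memory" above 512 px
def fake_detect(tile_size: int) -> str:
    if tile_size > 512:
        raise MemoryError(f"tile size {tile_size} does not fit")
    return f"ok at {tile_size}"


print(detect_with_backoff(fake_detect, 2048))  # (512, 'ok at 512')
```

Retrying the whole image is wasteful when only a few regions are dense, which is why the per-region approach described above would ultimately be preferable.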

In [None]:
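# One-off preprocessing (already run, so commented out): reproject the raw GeoTIFF to a
# projected CRS in place, then resample it (the "_10" suffix presumably means 10 cm)
#from tile import Tiler, convert_to_projected
#tiler = Tiler("./data", "./data")
#convert_to_projected("./data/5a04a02cbac48e5b1c01282b.tiff", inplace=True)
#tiler.resample("./data/5a04a02cbac48e5b1c01282b.tiff", "./data/5a04a02cbac48e5b1c01282b_10.tiff")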

In [None]:
image_path = "./data/5a04a02cbac48e5b1c01282b_10.tiff"
image = rasterio.open(image_path)

In [None]:
results = runner.detect_tiled(image_path, tile_size=768, pad=100)

In [None]:
plt.figure(figsize=(15,15))

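# Plot extent in the raster's CRS coordinates: [left, right, bottom, top]
extent = [image.bounds.left, image.bounds.right, image.bounds.bottom, image.bounds.top]
# The full image is very large, so read a 10x-downsampled copy for display only;
# the shared extent keeps it aligned with the overlays below
preview = image.read(out_shape=(image.count, image.height // 10, image.width // 10))
plt.imshow(preview.transpose((1, 2, 0)), extent=extent)
ax = plt.gca()

# Flatten all tile predictions above the confidence threshold into one mask
threshold = 0.5
image_mask = runner.merge_tiled_results(results, image, threshold)

# Draw each tile's bounding box (fainter here, as there are many tiles)
for _, bbox in results:
    rect = plt.Rectangle(xy=(bbox.minx, bbox.miny),
                         width=bbox.maxx - bbox.minx,
                         height=bbox.maxy - bbox.miny,
                         alpha=0.05,
                         linewidth=4,
                         edgecolor='red',
                         facecolor='none')
    ax.add_patch(rect)

# Channel 0: canopy (assuming channels follow the class IDs used earlier)
masked = np.ma.masked_where(image_mask[:, :, 0] == 0, image_mask[:, :, 0])
plt.imshow(masked, alpha=0.8, extent=extent, cmap='Blues_r')

# Channel 1: trees
masked = np.ma.masked_where(image_mask[:, :, 1] == 0, image_mask[:, :, 1])
plt.imshow(masked, alpha=0.8, extent=extent, cmap='Reds_r')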