# PyTorch EO Semantic Segmentation Example
## Step 5: Model Inference
*Rob Knapen, Wageningen Environmental Research*
<br>

This notebook provides an example of using a trained model that has been converted to TorchScript for inference, taking a geotiff file as input. GDAL is used to read and write geotiff files, since we want to be able to convert the method to C++ in the end.

In [2]:
import os
import h5py
import torch
import torch.nn as nn

In [3]:
from osgeo import gdal

In [14]:
# try to open the input file
dataset = gdal.Open("../data/raw/sentinel2_2018_flevopolder_10m_7x4bands.tif")
if not dataset:
    print("could not read the file")

In [15]:
# get some information from the file
print("Driver: {}/{}".format(dataset.GetDriver().ShortName, dataset.GetDriver().LongName))
print("Size is {} x {} x {}".format(dataset.RasterXSize, dataset.RasterYSize, dataset.RasterCount))
print("Projection is {}".format(dataset.GetProjection()))
geotransform = dataset.GetGeoTransform()
if geotransform:
    print("Origin = ({}, {})".format(geotransform[0], geotransform[3]))
    print("Pixel Size = ({}, {})".format(geotransform[1], geotransform[5]))

Driver: GTiff/GeoTIFF
Size is 5490 x 2170 x 28
Projection is PROJCS["WGS 84 / UTM zone 31N",GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]],PROJECTION["Transverse_Mercator"],PARAMETER["latitude_of_origin",0],PARAMETER["central_meridian",3],PARAMETER["scale_factor",0.9996],PARAMETER["false_easting",500000],PARAMETER["false_northing",0],UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH],AUTHORITY["EPSG","32631"]]
Origin = (674900.0, 5853960.0)
Pixel Size = (10.0, -10.0)


In [27]:
# load the torchscript model
model = torch.jit.load("../models/rvo_crops_segnet_224x224x28_77classes_100epochs_model_full_traced.pt")
print(model)

RecursiveScriptModule(
  original_name=SegNet
  (maxpool): RecursiveScriptModule(original_name=MaxPool2d)
  (unpool): RecursiveScriptModule(original_name=MaxUnpool2d)
  (c1): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(original_name=Conv2d)
    (1): RecursiveScriptModule(original_name=BatchNorm2d)
    (2): RecursiveScriptModule(original_name=ReLU)
  )
  (c2): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(original_name=Conv2d)
    (1): RecursiveScriptModule(original_name=BatchNorm2d)
    (2): RecursiveScriptModule(original_name=ReLU)
  )
  (c3): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(original_name=Conv2d)
    (1): RecursiveScriptModule(original_name=BatchNorm2d)
    (2): RecursiveScriptModule(original_name=ReLU)
  )
  (c4): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(original_name=Conv2d)
    (1): RecursiveScriptModule(original_na

In [49]:
# get some raster band information
band = dataset.GetRasterBand(1)
print("Band Type={}".format(gdal.GetDataTypeName(band.DataType)))

min = band.GetMinimum()
max = band.GetMaximum()
if not min or not max:
    (min,max) = band.ComputeRasterMinMax(True)
print("Min={:.3f}, Max={:.3f}".format(min,max))

if band.GetOverviewCount() > 0:
    print("Band has {} overviews".format(band.GetOverviewCount()))

if band.GetRasterColorTable():
    print("Band has a color table with {} entries".format(band.GetRasterColorTable().GetCount()))

Band Type=UInt16
Min=1.000, Max=7267.000


In [52]:
# read multiple bands from the dataset into a numpy array
arr = dataset.ReadAsArray(
    xoff=500,
    yoff=500,
    xsize=256,
    ysize=256,
    buf_obj=None,
    buf_xsize=None,
    buf_ysize=None,
    buf_type=None,
    resample_alg=0,
    callback=None,
    callback_data=None,
    interleave='band',
    band_list=None
).astype("float32")

arr.shape

(28, 256, 256)

In [53]:
# normalize the band data before the inference
def normalize(ds, tile_np):
    for index in range(tile_np.shape[0]):
        band = ds.GetRasterBand(index + 1)
        (_, b_max) = band.ComputeRasterMinMax(True)
        tile_np[index] /= b_max

normalize(dataset, arr)
arr

array([[[0.13774598, 0.14338791, 0.12948947, ..., 0.21122885,
         0.2055869 , 0.20531169],
        [0.13224164, 0.13361773, 0.13898446, ..., 0.2022843 ,
         0.19017476, 0.18797302],
        [0.12852621, 0.13444337, 0.144764  , ..., 0.19017476,
         0.18742259, 0.18700977],
        ...,
        [0.24494289, 0.21411861, 0.2105408 , ..., 0.15769918,
         0.16691895, 0.17063437],
        [0.24714462, 0.23489748, 0.22086143, ..., 0.15205725,
         0.16389157, 0.16554287],
        [0.19526628, 0.20090821, 0.19554149, ..., 0.15260768,
         0.164442  , 0.16100179]],

       [[0.15785852, 0.15194054, 0.14271952, ..., 0.22006606,
         0.22873658, 0.2152491 ],
        [0.14712359, 0.14423342, 0.1507019 , ..., 0.20272502,
         0.21043216, 0.20093587],
        [0.14753647, 0.14726122, 0.14987613, ..., 0.19818331,
         0.20217451, 0.20988163],
        ...,
        [0.28901735, 0.23713185, 0.23630609, ..., 0.17547481,
         0.17602532, 0.18620974],
        [0.2

In [56]:
# re-order the dimensions (if needed) and add a batch dimension
tile = arr.transpose(0,1,2)
tile_t = torch.from_numpy(tile).unsqueeze(dim=0)
tile_t.shape

torch.Size([1, 28, 256, 256])

In [59]:
# infer the output, take the max, and remove the batch dimension
model.eval()
pred = model(tile_t).argmax(1).squeeze()

# the predicted crop class per cell
pred

tensor([[9, 9, 9,  ..., 9, 9, 9],
        [9, 9, 9,  ..., 9, 9, 9],
        [9, 9, 9,  ..., 9, 9, 9],
        ...,
        [9, 9, 9,  ..., 9, 9, 9],
        [9, 9, 9,  ..., 9, 9, 9],
        [9, 9, 9,  ..., 9, 9, 9]])

In [60]:
pred.shape

torch.Size([256, 256])