In [1]:
import sys; sys.path.append('..')
import numpy as np
import torch
from osgeo import gdal
from main import UNet
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm.contrib import itertools

INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.10 (you have 1.4.8). Upgrade using: pip install --upgrade albumentations
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Trained UNet checkpoints (Lightning). The IRRG model is active; to run the RGB model
# instead, swap which CHECKPOINT line is commented out.
#   RGB:  CHECKPOINT = '../experiments/exp/lightning_logs/version_0/checkpoints/epoch=99-step=1400.ckpt'
CHECKPOINT = '../experiments/irrg/lightning_logs/version_1/checkpoints/epoch=199-step=2800.ckpt'  # IRRG

# Input raster to classify and the prediction raster to create; both live in the same directory.
# NOTE(review): the input is named "RGIR" while the active model is "IRRG", and tiles are fed to
# the model in the file's band order — confirm the band order matches what the model was trained on.
PREDICT_DIR = '/tank2/home/public/iceplant/predict'
TARGET = f'{PREDICT_DIR}/Dangermond2018_RGIR.tif'
OUTPUT = f'{PREDICT_DIR}/Dangermond2018_RGIR_pred.tif'

In [3]:
# Restore the trained UNet selected by CHECKPOINT.
# NOTE(review): no map_location/device is passed here, so the device the weights land on is
# whatever UNet.load_from_checkpoint defaults to; inference below runs on `model.device` —
# confirm this is the intended device (e.g. GPU) for a full-raster run.
model = UNet.load_from_checkpoint(CHECKPOINT)

In [4]:
patch_size = 128  # edge length (pixels) of the square tile fed to the model
overlap = 32      # pixels shared by neighbouring tiles; trimmed from inner tile edges before writing

# Per-tile preprocessing: scale 8-bit pixel values by 1/255 and convert the
# channel-last numpy tile into a torch tensor.
transforms = A.Compose([
    A.ToFloat(max_value=255),
    ToTensorV2(),
])

# Open the target raster. With GDAL exceptions disabled, gdal.Open returns None instead of
# raising, which would otherwise only surface later as an opaque AttributeError.
ds = gdal.Open(TARGET)
if ds is None:
    raise FileNotFoundError(f'Could not open input raster: {TARGET}')
transform = ds.GetGeoTransform()
projection = ds.GetProjection()
x_size = ds.RasterXSize
y_size = ds.RasterYSize
num_bands = ds.RasterCount  # NOTE(review): unused below; every band is read and fed to the model in file order

# Create the single-band, LZW-compressed uint8 output raster on the same grid
# (geotransform + projection) as the input; it receives one predicted class index per pixel.
driver = gdal.GetDriverByName('GTiff')
out_ds = driver.Create(OUTPUT, x_size, y_size, 1, gdal.GDT_Byte, options=['COMPRESS=LZW'])
if out_ds is None:
    raise RuntimeError(f'Could not create output raster: {OUTPUT}')
out_ds.SetGeoTransform(transform)
out_ds.SetProjection(projection)


def read_patch(dataset, x_off, y_off, x_size, y_size):
    """Read a pixel window from a GDAL dataset as a (rows, cols, bands) uint8 array.

    Args:
        dataset: open GDAL dataset (anything exposing ``ReadAsArray(xoff, yoff, xsize, ysize)``).
        x_off, y_off: pixel offset of the window's top-left corner.
        x_size, y_size: window width and height in pixels.

    Returns:
        ``np.ndarray`` of shape ``(y_size, x_size, bands)`` with dtype uint8, i.e. the
        channel-last layout that is passed to the albumentations pipeline.
    """
    patch = dataset.ReadAsArray(x_off, y_off, x_size, y_size)
    if patch.ndim == 2:
        # Single-band rasters come back as (rows, cols); add the band axis so the
        # transpose below works for any band count.
        patch = patch[np.newaxis]
    # NOTE: astype(np.uint8) wraps values above 255 — this assumes 8-bit imagery,
    # consistent with the ToFloat(max_value=255) normalisation used downstream.
    return patch.transpose(1, 2, 0).astype(np.uint8)

def write_patch(dataset, data, x_off, y_off):
    """Write a 2-D array into ``dataset`` with its top-left pixel at (x_off, y_off).

    Values are cast to uint8 first (the output raster is created as GDT_Byte);
    anything outside 0-255 would wrap, so ``data`` is expected to hold small
    non-negative class indices.
    """
    as_uint8 = data.astype(np.uint8)
    dataset.WriteArray(as_uint8, x_off, y_off)


def _keep_range(origin, tile_len, full_len, patch_size, overlap):
    """Return ``(start, stop)``: the slice of a tile's prediction to write along one axis.

    Args:
        origin: offset of the tile's first pixel along this axis.
        tile_len: number of real (unpadded) raster pixels the tile covers along this axis.
        full_len: raster length along this axis.
        patch_size: model tile edge length.
        overlap: pixels shared by consecutive tiles.

    Inner tile edges drop ``overlap // 2`` leading and ``overlap - overlap // 2`` trailing
    pixels, so the kept ranges of consecutive tiles (``patch_size - overlap`` apart) abut
    exactly. Edges lying on the raster border keep every real pixel.
    """
    start = 0 if origin == 0 else overlap // 2
    if origin + patch_size >= full_len:
        stop = tile_len  # tile reaches the raster border: keep up to the last real pixel
    else:
        stop = patch_size - (overlap - overlap // 2)
    return start, stop


if not 0 <= overlap < patch_size:
    raise ValueError(
        f'overlap must satisfy 0 <= overlap < patch_size, got overlap={overlap}, patch_size={patch_size}'
    )

model.eval()

stride = patch_size - overlap

# Tile origins along each axis. Generation stops at the first tile that reaches the raster
# border (origin + patch_size >= size); later tiles would only re-predict border pixels
# from mostly reflect-padded input.
y_origins = range(0, max(y_size - overlap, 1), stride)
x_origins = range(0, max(x_size - overlap, 1), stride)

for y, x in itertools.product(y_origins, x_origins):
    # Real raster extent of this tile (smaller than patch_size on the right/bottom border).
    x_patch_size = min(patch_size, x_size - x)
    y_patch_size = min(patch_size, y_size - y)
    patch = read_patch(ds, x, y, x_patch_size, y_patch_size)
    if x_patch_size < patch_size or y_patch_size < patch_size:
        # Reflect-pad border tiles up to the fixed model input size.
        patch = np.pad(patch, ((0, patch_size - y_patch_size), (0, patch_size - x_patch_size), (0, 0)), 'reflect')

    patch = transforms(image=patch)['image'].unsqueeze(0).to(model.device)

    with torch.no_grad():
        pred = model.model(patch)
        # argmax over dim 1 (assumed to be the class axis) -> per-pixel class index;
        # squeeze only the batch axis so spatial axes are never collapsed.
        pred = torch.argmax(pred, dim=1).squeeze(0).cpu().numpy()

    # Keep the tile centre (full extent at raster borders) and write it where it was predicted:
    # the kept slice begins `x0`/`y0` pixels into the tile, so it goes to (x + x0, y + y0),
    # not to the tile origin (writing at the origin shifts predictions by overlap // 2).
    x0, x1 = _keep_range(x, x_patch_size, x_size, patch_size, overlap)
    y0, y1 = _keep_range(y, y_patch_size, y_size, patch_size, overlap)
    write_patch(out_ds, pred[y0:y1, x0:x1], x + x0, y + y0)


out_ds.FlushCache()
del out_ds  # drop the last reference so GDAL closes and finalises the GeoTIFF

100%|██████████| 62304/62304 [05:41<00:00, 182.29it/s]
