This script takes the image tiles produced by gjengroingMetaModelDownloadImageTiles.Rmd and passes them to Meta's HighResCanopyHeight model for canopy-height inference. See https://github.com/facebookresearch/HighResCanopyHeight for instructions on installing the model's dependencies.

In [None]:
import argparse
import os
import torch
from pathlib import Path
import torch.nn as nn
import rasterio
import torch.nn.functional as F
from models.backbone import SSLVisionTransformer
from models.dpt_head import DPTHead
import pytorch_lightning as pl
from models.regressor import RNet
import torchvision.transforms.functional as TF
import numpy as np

First, a helper that blends values in the overlapping regions of adjacent tiles as a weighted average. If your output shows strong edge effects, try adding overlap when downloading the tiles; blending reduces these edge effects by making adjacent tiles match better after normalisation. Note that the evaluation loop below does not pass an overlap region, so blending is only applied if you supply one.

In [None]:
def blend_overlap(region1, region2, weight=0.5):
    """Return the weighted average ``weight * region1 + (1 - weight) * region2``.

    Works element-wise on anything that supports scalar multiplication and
    addition (Python numbers, NumPy arrays, torch tensors).
    """
    complement = 1 - weight
    return weight * region1 + complement * region2

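This step performs image normalisation. First, each tile is passed to a pretrained model that predicts target low/high percentile values for each band. The tile is then linearly rescaled so that its own 20th/80th-percentile values map onto those targets, matching the aerial image histogram to that of the Maxar imagery the canopy-height model was trained on. A second step then standardises each band using precalculated band means and standard deviations.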

In [None]:
def normalize_image(image, model_norm, device, overlap=None, weight=0.5,
                    input_percentiles=(20, 80)):
    """Histogram-match one RGB tile to the reference imagery via a percentile model.

    ``model_norm`` is run on the tile, and its outputs ``[0][0..2]`` /
    ``[0][3..5]`` are read as per-band target low / high values. Each band of
    ``image`` is then rescaled linearly so that its own ``input_percentiles``
    values land on those targets.

    Args:
        image: 3-channel, channels-first tensor (``image.shape[0] == 3``).
        model_norm: callable returning a tensor indexable as ``out[0][0..5]``.
        device: device the (batched) image is moved to before ``model_norm``.
        overlap: optional 3-channel tensor blended into the result with
            ``blend_overlap``.
        weight: weight of the normalised tile when blending with ``overlap``
            (``overlap`` gets ``1 - weight``).
        input_percentiles: ``(low, high)`` percentiles of each input band that
            are mapped onto the model's targets. Defaults to the previously
            hard-coded ``(20, 80)``.

    Returns:
        A tensor with the same shape, dtype and device as ``image``.

    Raises:
        ValueError: if ``image`` does not have exactly 3 channels.
    """
    if image.shape[0] != 3:
        raise ValueError(f"Expected 3-channel image, but got {image.shape[0]} channels.")

    x = torch.unsqueeze(image, dim=0).to(device)
    # Inference only: avoid building an autograd graph for the percentile model.
    with torch.no_grad():
        target = model_norm(x).detach()
    target_lo = [target[0][c].item() for c in range(3)]
    target_hi = [target[0][c + 3].item() for c in range(3)]

    # NumPy needs a CPU, grad-free array; the caller may pass a tensor that
    # lives on an accelerator or requires grad.
    image_np = image.detach().cpu().numpy()
    q_lo, q_hi = input_percentiles
    in_lo = [np.percentile(image_np[c], q_lo) for c in range(3)]
    in_hi = [np.percentile(image_np[c], q_hi) for c in range(3)]

    normalized = image.clone()
    for c in range(3):
        in_range = in_hi[c] - in_lo[c]
        # A band whose low and high percentiles coincide (e.g. a flat tile)
        # would divide by zero and fill the band with inf/NaN; fall back to a
        # pure shift onto the target low value instead.
        scale = (target_hi[c] - target_lo[c]) / in_range if in_range != 0 else 1.0
        normalized[c, :, :] = (image[c, :, :] - in_lo[c]) * scale + target_lo[c]

    # If an overlap region is provided, blend it into every band.
    if overlap is not None:
        for c in range(3):
            normalized[c, :, :] = blend_overlap(normalized[c, :, :], overlap[c, :, :], weight=weight)

    return normalized

# Apply final normalization using the general image normalization values
def final_normalize(image, norm_values):
    """Standardise each channel of a channels-first tensor: ``(image - mean) / std``.

    Implemented with plain tensor ops, so it needs no torchvision transforms
    module. It matches torchvision's ``Normalize(mean, std)`` for a
    ``(C, H, W)`` (or ``(..., C, H, W)``) tensor.

    Args:
        image: floating-point tensor whose channel axis is third from last.
        norm_values: ``(mean, std)`` pair of per-channel sequences.

    Returns:
        A new tensor; ``image`` is not modified.

    Raises:
        ValueError: if any std value is zero (division by zero).
    """
    mean, std = norm_values[0], norm_values[1]
    mean_t = torch.as_tensor(mean, dtype=image.dtype, device=image.device)
    std_t = torch.as_tensor(std, dtype=image.dtype, device=image.device)
    if (std_t == 0).any():
        raise ValueError(f"std evaluated to zero after conversion to {image.dtype}, leading to division by zero.")
    # Reshape to (C, 1, 1) so the per-channel stats broadcast over H and W.
    return (image - mean_t.view(-1, 1, 1)) / std_t.view(-1, 1, 1)

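Model definitions: the canopy-height network (SSLAE) and the Lightning wrapper (SSLModule) that loads its checkpoint.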

In [None]:
class SSLAE(nn.Module):
    """Canopy-height network: an SSL vision-transformer backbone plus a DPT head.

    Args:
        pretrained: forwarded to ``SSLVisionTransformer``.
        classify: forwarded to ``DPTHead``.
        n_bins: forwarded to ``DPTHead`` in both configurations (previously it
            was ignored and 256 was always used). The default of 256 keeps the
            old behaviour; NOTE(review): assumes ``DPTHead``'s own default is
            also 256 for the ``huge`` branch — confirm in models/dpt_head.py.
        huge: if True, build the larger 32-layer, 1280-dim configuration;
            otherwise use the backbone/head defaults.
    """

    def __init__(self, pretrained=None, classify=True, n_bins=256, huge=False):
        super().__init__()
        if huge:
            width = 1280
            self.backbone = SSLVisionTransformer(
                embed_dim=width,
                num_heads=20,
                out_indices=(9, 16, 22, 29),
                depth=32,
                pretrained=pretrained,
            )
            self.decode_head = DPTHead(
                classify=classify,
                in_channels=(width, width, width, width),
                embed_dims=width,
                post_process_channels=[160, 320, 640, 1280],
                n_bins=n_bins,
            )
        else:
            self.backbone = SSLVisionTransformer(pretrained=pretrained)
            self.decode_head = DPTHead(classify=classify, n_bins=n_bins)

    def forward(self, x):
        """Run the backbone, then decode its features with the DPT head."""
        features = self.backbone(x)
        return self.decode_head(features)

class SSLModule(pl.LightningModule):
    """Lightning wrapper that builds an ``SSLAE`` and loads its checkpoint.

    The architecture and checkpoint format are chosen from the checkpoint path:
    a path containing ``'huge'`` builds the large SSLAE. A path containing
    ``'compressed'`` is loaded into a dynamically int8-quantised copy of the
    network. Any other path must be a dict with a ``'state_dict'`` entry.

    Args:
        ssl_path: checkpoint path (``str`` or ``pathlib.Path``).
    """

    def __init__(self, ssl_path="compressed_SSLbaseline.pth"):
        super().__init__()
        path_str = str(ssl_path)  # accept pathlib.Path as well as str

        self.chm_module_ = SSLAE(classify=True, huge='huge' in path_str).eval()

        # Always deserialise onto CPU so GPU-saved checkpoints also load on
        # CPU-only machines; callers move the module with .to(device) later.
        ckpt = torch.load(ssl_path, map_location='cpu')

        if 'compressed' in path_str:
            # The compressed weights belong to a dynamically quantised model,
            # so quantise the freshly built network before loading them.
            self.chm_module_ = torch.quantization.quantize_dynamic(
                self.chm_module_,
                {torch.nn.Linear, torch.nn.Conv2d, torch.nn.ConvTranspose2d},
                dtype=torch.qint8)
            # NOTE(review): strict=False silently ignores missing/unexpected keys.
            self.chm_module_.load_state_dict(ckpt, strict=False)
        else:
            self.chm_module_.load_state_dict(ckpt['state_dict'])

        # Affine rescaling of the raw network output.
        # NOTE(review): presumably converts to canopy height in metres — confirm upstream.
        self.chm_module = lambda x: 9.51 * self.chm_module_(x) + 0.4

    def forward(self, x):
        """Return the rescaled canopy-height prediction for a batch ``x``."""
        return self.chm_module(x)

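This function runs the model over the tiles downloaded by the R script. It normalises each batch, predicts canopy height, and writes each prediction as a GeoTIFF that keeps the source tile's georeferencing.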

In [None]:
def _collate_as_tuples(batch):
    """Turn a list of ``(image, path)`` samples into ``(images, paths)`` tuples.

    Defined at module level rather than as a lambda so DataLoader worker
    processes can pickle it when the 'spawn' start method is used.
    """
    return tuple(zip(*batch))


def _write_prediction(band, img_path, out_dir):
    """Write a 2-D prediction array as ``<out_dir>/<tile stem>_pred.tif``.

    The output copies the source tile's raster metadata (which includes its
    transform and CRS) so the prediction stays georeferenced. It is written
    as a single float32 band.
    """
    with rasterio.open(img_path) as src:
        meta = src.meta.copy()
    # NOTE(review): meta also carries the source's nodata value; if the input
    # tiles declare nodata=0, predicted heights of exactly 0 read as nodata.
    meta.update(dtype=rasterio.float32, count=1)
    output_path = Path(out_dir) / f'{Path(img_path).stem}_pred.tif'
    with rasterio.open(output_path, 'w', **meta) as dst:
        dst.write(np.asarray(band, dtype=np.float32), 1)
    return output_path


def evaluate(model, model_norm, norm_values, name, bs=32, device='cuda:0', display=False,
             src_img_dir='data/encroachment/metamodel/tiles', num_workers=8):
    """Run canopy-height inference over every tile and save the predictions.

    Args:
        model: canopy-height model applied to normalised batches.
        model_norm: percentile-prediction model used by ``normalize_image``.
        norm_values: ``(mean, std)`` passed to ``final_normalize``.
        name: output directory for ``*_pred.tif`` files (created if missing).
        bs: batch size.
        device: device the model is run on.
        display: if True, plot each prediction.
        src_img_dir: directory of input tiles (default: the previous
            hard-coded location).
        num_workers: DataLoader worker processes.
    """
    out_dir = Path(name)
    out_dir.mkdir(parents=True, exist_ok=True)

    ds = NorwayDataset(src_img_dir=src_img_dir)
    dataloader = torch.utils.data.DataLoader(
        ds, batch_size=bs, shuffle=False, num_workers=num_workers,
        collate_fn=_collate_as_tuples,
    )

    with torch.no_grad():
        for images, img_paths in dataloader:
            # Samples are TF.to_tensor(...) of rasterio's (bands, rows, cols)
            # array. to_tensor assumes an HWC ndarray and transposes (2, 0, 1),
            # so each sample is (cols, bands, rows). After stacking, this
            # permute restores (batch, channels, height, width).
            images = torch.stack(images).permute(0, 2, 3, 1)

            # Normalise on CPU: normalize_image takes NumPy percentiles of its
            # input and moves its own model input to `device`.
            images = torch.stack([normalize_image(img, model_norm, device) for img in images])
            images = torch.stack([final_normalize(img, norm_values) for img in images])

            pred = model(images.to(device)).cpu().relu()

            for pred_i, img_path in zip(pred, img_paths):
                _write_prediction(pred_i.numpy()[0], img_path, out_dir)

                if display:
                    # NOTE(review): `plt` is not imported in the visible source;
                    # display=True raises NameError unless matplotlib.pyplot is
                    # imported as plt.
                    plt.imshow(pred_i[0].numpy())
                    plt.title(f'Prediction for {img_path.name}')
                    plt.show()

This defines the dataset passed to the model. Null (all-zero) tiles are skipped: when one is encountered, the next non-empty tile is returned in its place.

In [None]:
class NorwayDataset(torch.utils.data.Dataset):
    """RGB GeoTIFF tiles from one directory, skipping all-zero (null) tiles.

    Each item is ``(image, path)``, where ``image`` is ``TF.to_tensor`` applied
    to bands 1-3 as read by rasterio. ``to_tensor`` treats an ndarray as HWC
    and transposes it, so for rasterio's (bands, rows, cols) array the sample
    comes out as (cols, bands, rows). ``evaluate`` permutes batches back to
    channels-first.

    An all-zero tile is replaced by the next non-empty tile (wrapping around),
    so a non-empty tile may appear more than once per pass.

    Args:
        src_img_dir: directory searched (non-recursively) for ``*.tif`` files.
    """

    def __init__(self, src_img_dir='data/encroachment/metamodel/tiles'):
        self.src_img_dir = Path(src_img_dir)
        # Sorted so the order (and which tile stands in for a null one) is
        # reproducible; raw glob order depends on the filesystem.
        self.img_files = sorted(self.src_img_dir.glob('*.tif'))  # Adjust extension if needed

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, i):
        """Return ``(image, path)`` for tile ``i``, or the next non-empty tile.

        Raises:
            IndexError: if ``i`` is out of range.
            RuntimeError: if every tile is all-zero. The previous recursive
                implementation hit RecursionError in that case.
        """
        n = len(self.img_files)
        if not -n <= i < n:
            raise IndexError(f"tile index {i} out of range for {n} tiles")

        # Iterative scan (at most one full lap) instead of unbounded recursion.
        for step in range(n):
            img_path = self.img_files[(i + step) % n]
            with rasterio.open(img_path) as src:
                img = TF.to_tensor(src.read([1, 2, 3]))  # first 3 bands, assumed RGB
            if not torch.all(img == 0):
                return img, img_path

        raise RuntimeError(f"All {n} tiles in {self.src_img_dir} are all-zero; nothing to run inference on.")

This parses the command-line arguments, loads the canopy-height and normalisation models, and runs inference.

In [None]:
def _str2bool(value):
    """Parse a command-line truthy/falsy string (``true``/``false``, ``1``/``0``, ...)."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the inference script's command-line options.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``checkpoint``, ``name`` and ``display``.

    ``--display`` accepts either a bare flag or an explicit boolean value
    (``--display False``). Previously ``type=bool`` made any non-empty string,
    including "False", evaluate to True. Unrecognised arguments (e.g. the
    ``-f <kernel.json>`` a Jupyter kernel adds) are ignored with a warning
    instead of aborting.
    """
    parser = argparse.ArgumentParser(description='test a model')
    parser.add_argument('--checkpoint', type=str, help='CHM pred checkpoint file', default='data/encroachment/metamodel/saved_checkpoints/compressed_SSLhuge_aerial.pth')
    parser.add_argument('--name', type=str, help='run name', default='data/encroachment/metamodel/output_inference')
    parser.add_argument('--display', type=_str2bool, nargs='?', const=True, default=False,
                        help='saving outputs in images')
    args, unknown = parser.parse_known_args(argv)
    if unknown:
        import warnings
        warnings.warn(f"Ignoring unrecognised command-line arguments: {unknown}")
    return args

def main():
    """Parse CLI arguments, load both models and run tile inference.

    Loads the canopy-height model from ``--checkpoint`` and the RNet
    percentile predictor used for histogram matching. It then calls
    ``evaluate`` with batch size 16, writing predictions under ``--name``.
    """
    args = parse_args()

    # Both checkpoint variants run on CPU. 'compressed' checkpoints are
    # dynamically quantised by SSLModule, and quantize_dynamic targets CPU
    # backends. NOTE(review): uncompressed checkpoints could use 'cuda:0'
    # once the whole pipeline is confirmed device-safe.
    device = 'cpu'

    os.makedirs(args.name, exist_ok=True)

    # Load SSL model
    model = SSLModule(ssl_path=args.checkpoint)
    model.to(device)
    model = model.eval()

    # Load normalization model (RNet for histogram matching)
    norm_path = 'data/encroachment/metamodel/saved_checkpoints/aerial_normalization_quantiles_predictor.ckpt'
    ckpt = torch.load(norm_path, map_location=device)
    model_norm = RNet(n_classes=6).to(device).eval()
    # Strip 'backbone.' from checkpoint keys, presumably so they match RNet's
    # own parameter names (load_state_dict is strict here).
    state_dict = {k.replace('backbone.', ''): v for k, v in ckpt['state_dict'].items()}
    model_norm.load_state_dict(state_dict)

    # Image normalization values: per-band (mean, std)
    norm_values = ((0.420, 0.411, 0.296), (0.213, 0.156, 0.143))

    # Inference only: no autograd graphs needed.
    with torch.no_grad():
        evaluate(model=model, model_norm=model_norm, norm_values=norm_values, name=args.name, bs=16, device=device, display=args.display)

# Run the full inference pipeline when this file is executed directly.
if __name__ == '__main__':
    main()