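Please run `prithvi_v2_eo_300_tl_unet_burnscars_tensorrt.ipynb` first: this notebook loads the fine-tuned weights from `checkpoint.pt` and the serialized TensorRT engine from `model.trt` in the working directory.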

In [None]:
!pip install terratorch==1.0.1 gdown tensorrt onnx onnxruntime polygraphy numpy pycuda numba

In [None]:
import os
import sys
import torch
import gdown
import terratorch
import albumentations
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from pathlib import Path
from terratorch.datamodules import GenericNonGeoSegmentationDataModule
import warnings
import time
import tensorrt as trt
from pytorch_lightning import Trainer


In [None]:
# Dataset layout: every split reads images and masks from the same data directory;
# the split files decide which samples belong to train / val / test.
dataset_path = Path('hls_burn_scars')
tiles_dir = dataset_path / 'data/'
splits_dir = dataset_path / 'splits'

# Standardization values, one entry per band.
band_means = [
    0.0333497067415863,
    0.0570118552053618,
    0.0588974813200132,
    0.2323245113436119,
    0.1972854853760658,
    0.1194491422518656,
]
band_stds = [
    0.0226913556882377,
    0.0268075602230702,
    0.0400410984436278,
    0.0779173242367269,
    0.0870873883814014,
    0.0724197947743781,
]

datamodule = GenericNonGeoSegmentationDataModule(
    batch_size=1,
    num_workers=0,
    num_classes=2,

    # Image and label roots for each split
    train_data_root=tiles_dir,
    train_label_data_root=tiles_dir,
    val_data_root=tiles_dir,
    val_label_data_root=tiles_dir,
    test_data_root=tiles_dir,
    test_label_data_root=tiles_dir,

    # Split definition files
    train_split=splits_dir / 'train.txt',
    val_split=splits_dir / 'val.txt',
    test_split=splits_dir / 'test.txt',

    # Glob patterns that tell images and masks apart inside the shared directory
    img_grep='*_merged.tif',
    label_grep='*.mask.tif',

    train_transform=[
        albumentations.D4(),  # Random flips and rotation
        albumentations.pytorch.transforms.ToTensorV2(),
    ],
    val_transform=None,  # Using ToTensor() by default
    test_transform=None,

    means=band_means,
    stds=band_stds,
    no_data_replace=0,
    no_label_replace=-1,
    # We use all six bands of the data, so we don't need to define dataset_bands and output_bands.
)

In [None]:
# Only the test stage is set up; the rest of this notebook evaluates on the test split only.
datamodule.setup("test")
test_dataset = datamodule.test_dataset
# Last expression of the cell: displays the number of test samples.
len(test_dataset)

In [None]:
def get_model():
    """Build the segmentation task used for evaluation.

    The task combines a Prithvi backbone, three necks and a UNet decoder with a
    two-class head. All settings are fixed inside this function.

    Returns:
        The object created by ``terratorch.tasks.SemanticSegmentationTask``.
    """
    # Backbone can be either prithvi_eo_v1_100, prithvi_eo_v2_300, prithvi_eo_v2_300_tl,
    # prithvi_eo_v2_600 or prithvi_eo_v2_600_tl.
    backbone_name = "prithvi_eo_v2_300"

    # SelectIndices values per backbone:
    #   prithvi_eo_v1_100 -> [2, 5, 8, 11]
    #   prithvi_eo_v2_300 -> [5, 11, 17, 23]  (used here)
    #   prithvi_eo_v2_600 -> [7, 15, 23, 31]
    selected_layers = [5, 11, 17, 23]

    necks = [
        {"name": "SelectIndices", "indices": selected_layers},
        {"name": "ReshapeTokensToImage"},
        {"name": "LearnedInterpolateToPyramidal"},
    ]

    model_args = {
        # Backbone
        "backbone": backbone_name,
        "backbone_pretrained": True,
        "backbone_num_frames": 1,  # 1 is the default value
        "backbone_img_size": 512,
        "backbone_bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
        # "backbone_coords_encoding": [],  # use ["time", "location"] for time and location metadata
        # Necks
        "necks": necks,
        # Decoder
        "decoder": "UNetDecoder",
        "decoder_channels": [512, 256, 128, 64],
        # Head
        "head_dropout": 0.1,
        "num_classes": 2,
    }

    task = terratorch.tasks.SemanticSegmentationTask(
        model_factory="EncoderDecoderFactory",
        model_args=model_args,
        loss="ce",
        optimizer="AdamW",
        lr=1e-4,
        ignore_index=-1,
        freeze_backbone=True,  # Only to speed up fine-tuning
        freeze_decoder=False,
        plot_on_val=True,
        class_names=['no burned', 'burned'],  # optionally define class names
    )
    return task

In [None]:
# Fix the global seed before the model is constructed.
pl.seed_everything(0)
# Build the segmentation task; the fine-tuned weights are loaded from checkpoint.pt afterwards.
model = get_model()

In [None]:
# NOTE(review): torch.load unpickles the file, so only load checkpoints you produced
# yourself; consider weights_only=True if the file is a plain tensor state dict.
state_dict = torch.load("checkpoint.pt", map_location=torch.device('cuda'))

# strict=False tolerates key mismatches. Report them, otherwise a checkpoint that does
# not match this architecture would leave layers unloaded without any visible sign.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        "checkpoint.pt does not fully match the model: "
        f"{len(load_result.missing_keys)} missing key(s) {load_result.missing_keys[:5]}, "
        f"{len(load_result.unexpected_keys)} unexpected key(s) {load_result.unexpected_keys[:5]}"
    )
model.eval()

# Compute PyTorch inference metrics

In [None]:
import torch
from pytorch_lightning import LightningModule
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score

class ClassifierWrapper(LightningModule):
    """Test-only Lightning wrapper that scores a segmentation model with binary metrics.

    Each test batch is run through the wrapped model, the class with the highest score
    is taken per pixel, and accuracy / precision / recall / F1 are accumulated over all
    pixels. Pixels whose label is -1 are ignored by every metric.
    """

    def __init__(self, base_model: LightningModule):
        """Store ``base_model`` (switched to eval mode) and create the four metrics."""
        super().__init__()
        self.base_model = base_model.eval()
        self.accuracy = BinaryAccuracy(ignore_index=-1)
        self.precision = BinaryPrecision(ignore_index=-1)
        self.recall = BinaryRecall(ignore_index=-1)
        self.f1 = BinaryF1Score(ignore_index=-1)

    def forward(self, x):
        """Delegate to the wrapped model and return its raw output object."""
        return self.base_model(x)

    def test_step(self, batch, batch_idx):
        """Update all metrics with the predictions for one batch."""
        images, target = batch["image"], batch["mask"]

        with torch.no_grad():
            # `.output` holds the per-class scores; presumably (B, num_classes, H, W)
            # with num_classes == 2 for this model -- TODO confirm.
            scores = self.forward(images).output
            labels = torch.argmax(scores, dim=1)  # class index per pixel

        # The metrics are fed one flat vector of pixels per batch.
        flat_labels = labels.reshape(-1)
        flat_target = target.reshape(-1)

        for metric in (self.accuracy, self.precision, self.recall, self.f1):
            metric(flat_labels, flat_target)

    def on_test_epoch_end(self):
        """Log the accumulated metric values, then reset the metric states."""
        tracked = {
            "test/accuracy": self.accuracy,
            "test/precision": self.precision,
            "test/recall": self.recall,
            "test/f1": self.f1,
        }

        for tag, metric in tracked.items():
            self.log(tag, metric.compute(), prog_bar=True)

        for metric in tracked.values():
            metric.reset()


In [None]:
class NormalizedClassifierWrapper(ClassifierWrapper):
    """ClassifierWrapper that runs ``datamodule.aug`` on every batch before ``test_step``."""

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # Trainer.test() receives a bare dataloader below, so the datamodule is never
        # attached to the trainer and its own batch hooks do not run. Apply the same
        # `datamodule.aug` call that the TensorRT evaluation applies, so both sets of
        # metrics are computed on identically preprocessed inputs.
        return datamodule.aug(batch)


classifier = NormalizedClassifierWrapper(model)

# devices=1 is valid for whichever accelerator "auto" selects; devices=None is rejected
# by Lightning 2.x when it falls back to the CPU accelerator.
trainer = Trainer(accelerator="auto", devices=1)
trainer.test(classifier, dataloaders=datamodule.test_dataloader())

# Compute TensorRT inference metrics

In [None]:
import torch
import numpy as np
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score
import pycuda.driver as cuda
import pycuda.autoinit
import tensorrt as trt
import time

def run_tensorrt_evaluation(datamodule, engine_path='model.trt', output_shape=(1, 2, 512, 512)):
    """Evaluate a serialized TensorRT engine on the test split and print binary metrics.

    Every batch is cut down to its first sample, passed through ``datamodule.aug``,
    executed by the engine, and the argmax over axis 1 of the engine output is compared
    with the batch mask. Pixels labelled -1 are ignored by all metrics.

    Args:
        datamodule: Object providing ``test_dataloader()`` and ``aug(batch)``.
        engine_path: Path of the serialized TensorRT engine file.
        output_shape: Shape of the engine's float32 output. The default assumes one
            sample with 2 class maps of 512x512 pixels.

    Raises:
        RuntimeError: If the engine cannot be deserialized, no execution context can be
            created, or an inference call reports failure.
    """
    # Initialize metrics
    accuracy = BinaryAccuracy(ignore_index=-1).cuda()
    precision = BinaryPrecision(ignore_index=-1).cuda()
    recall = BinaryRecall(ignore_index=-1).cuda()
    f1 = BinaryF1Score(ignore_index=-1).cuda()

    test_loader = datamodule.test_dataloader()
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

    # Load engine
    with open(engine_path, 'rb') as f:
        engine_data = f.read()
    runtime = trt.Runtime(TRT_LOGGER)
    engine = runtime.deserialize_cuda_engine(engine_data)
    if engine is None:
        # Deserialization signals failure by returning None rather than raising.
        raise RuntimeError(f"Failed to deserialize TensorRT engine from {engine_path!r}")
    context = engine.create_execution_context()
    if context is None:
        raise RuntimeError(f"Failed to create a TensorRT execution context for {engine_path!r}")

    # Host and device output buffers are allocated once and reused for every batch.
    output_shape = tuple(output_shape)
    output_data = np.empty(output_shape, dtype=np.float32)
    d_output = cuda.mem_alloc(output_data.nbytes)

    # The device input buffer is created on the first batch and only re-created if a
    # later batch needs more room, instead of allocating a new buffer per iteration.
    d_input = None
    d_input_nbytes = 0

    try:
        for batch in test_loader:
            # Ensure batch size 1
            for key in batch:
                batch[key] = batch[key][:1]

            batch = datamodule.aug(batch)
            masks = batch["mask"].to('cuda')

            # memcpy_htod needs a contiguous host buffer of the engine's input dtype.
            input_data = np.ascontiguousarray(batch["image"].cpu().numpy(), dtype=np.float32)

            if d_input is None or input_data.nbytes > d_input_nbytes:
                if d_input is not None:
                    d_input.free()
                    d_input = None
                d_input = cuda.mem_alloc(input_data.nbytes)
                d_input_nbytes = input_data.nbytes

            # Transfer input
            cuda.memcpy_htod(d_input, input_data)

            # Bindings are passed as [input, output]; this assumes the engine has
            # exactly one input followed by one output.
            if not context.execute_v2([int(d_input), int(d_output)]):
                raise RuntimeError("TensorRT inference failed (execute_v2 returned False)")
            cuda.Context.synchronize()

            # Get output
            cuda.memcpy_dtoh(output_data, d_output)

            # argmax returns a new tensor, so reusing output_data next iteration is safe.
            preds = torch.argmax(torch.from_numpy(output_data), dim=1).to('cuda')

            # Flatten for metrics
            preds_flat = preds.reshape(-1)
            masks_flat = masks.reshape(-1)

            # Apply metrics
            accuracy(preds_flat, masks_flat)
            precision(preds_flat, masks_flat)
            recall(preds_flat, masks_flat)
            f1(preds_flat, masks_flat)
    finally:
        # Release device memory even when an iteration fails.
        if d_input is not None:
            d_input.free()
        d_output.free()

    # Final metrics
    print("\n=== TensorRT Evaluation Metrics ===")
    print(f"Accuracy:  {accuracy.compute().item():.4f}")
    print(f"Precision: {precision.compute().item():.4f}")
    print(f"Recall:    {recall.compute().item():.4f}")
    print(f"F1 Score:  {f1.compute().item():.4f}")


In [None]:
# Evaluate the default engine file ('model.trt') on the test split and print the metrics.
run_tensorrt_evaluation(datamodule)