In [None]:
!pip install terratorch==1.0.1 gdown tensorrt onnx onnxruntime polygraphy numpy pycuda numba

In [None]:
import os
import sys
import torch
import gdown
import terratorch
import albumentations
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from pathlib import Path
from terratorch.datamodules import GenericNonGeoSegmentationDataModule
import warnings
import time
import tensorrt as trt
from pytorch_lightning import Trainer


In [None]:
if not os.path.isfile('hls_burn_scars.tar.gz'):
    gdown.download("https://drive.google.com/uc?id=1yFDNlGqGPxkc9lh9l1O70TuejXAQYYtC")

if not os.path.isdir('hls_burn_scars/'):
    !tar -xzvf hls_burn_scars.tar.gz

In [None]:
# Root folder of the extracted dataset; the data and split paths below are built from it.
dataset_path = Path('hls_burn_scars')

In [None]:
# Data module for the burn-scar chips. Images and masks share one folder and are
# told apart by the glob patterns; the split files select the samples per stage.
datamodule = GenericNonGeoSegmentationDataModule(
    batch_size=1,
    num_workers=0,
    num_classes=2,

    # Image folders (one per stage)
    train_data_root=dataset_path / 'data/',
    val_data_root=dataset_path / 'data/',
    test_data_root=dataset_path / 'data/',

    # Label folders (same directory as the images)
    train_label_data_root=dataset_path / 'data/',
    val_label_data_root=dataset_path / 'data/',
    test_label_data_root=dataset_path / 'data/',

    # Split definition files
    train_split=dataset_path / 'splits/train.txt',
    val_split=dataset_path / 'splits/val.txt',
    test_split=dataset_path / 'splits/test.txt',

    # File name patterns for images and masks
    img_grep='*_merged.tif',
    label_grep='*.mask.tif',

    # Augmentation is applied to the training stage only
    train_transform=[
        albumentations.D4(),  # Random flips and rotation
        albumentations.pytorch.transforms.ToTensorV2(),
    ],
    val_transform=None,  # Using ToTensor() by default
    test_transform=None,

    # Per-band standardization values (six bands)
    means=[
        0.0333497067415863, 0.0570118552053618, 0.0588974813200132,
        0.2323245113436119, 0.1972854853760658, 0.1194491422518656,
    ],
    stds=[
        0.0226913556882377, 0.0268075602230702, 0.0400410984436278,
        0.0779173242367269, 0.0870873883814014, 0.0724197947743781,
    ],

    # Replacement values for missing pixels / labels
    no_data_replace=0,
    no_label_replace=-1,
    # All six bands are used, so dataset_bands and output_bands are left undefined.
)

In [None]:
# Prepare only the test stage; training and validation data are not set up here.
datamodule.setup("test")
test_dataset = datamodule.test_dataset
# Last expression of the cell: displays the number of test samples.
len(test_dataset)

In [None]:
def get_model():
    """Create the burn-scar segmentation task.

    The task is assembled by the "EncoderDecoderFactory" from a pretrained
    ``prithvi_eo_v2_300`` backbone, three necks and a ``UNetDecoder`` with a
    two-class head. The backbone is configured as frozen, the decoder is not.

    Returns:
        A new ``terratorch.tasks.SemanticSegmentationTask`` instance.
    """
    # Necks turn the selected transformer outputs into feature maps for the decoder.
    neck_pipeline = [
        {
            "name": "SelectIndices",
            # "indices": [2, 5, 8, 11] # indices for prithvi_eo_v1_100
            "indices": [5, 11, 17, 23],  # indices for prithvi_eo_v2_300
            # "indices": [7, 15, 23, 31] # indices for prithvi_eo_v2_600
        },
        {"name": "ReshapeTokensToImage"},
        {"name": "LearnedInterpolateToPyramidal"},
    ]

    factory_args = {
        # Backbone: one of prithvi_eo_v1_100, prithvi_eo_v2_300, prithvi_eo_v2_300_tl,
        # prithvi_eo_v2_600, prithvi_eo_v2_600_tl
        "backbone": "prithvi_eo_v2_300",
        "backbone_pretrained": True,
        "backbone_num_frames": 1,  # 1 is the default value
        "backbone_img_size": 512,
        "backbone_bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
        # "backbone_coords_encoding": [], # use ["time", "location"] for time and location metadata

        # Necks
        "necks": neck_pipeline,

        # Decoder
        "decoder": "UNetDecoder",
        "decoder_channels": [512, 256, 128, 64],

        # Head
        "head_dropout": 0.1,
        "num_classes": 2,
    }

    return terratorch.tasks.SemanticSegmentationTask(
        model_factory="EncoderDecoderFactory",
        model_args=factory_args,
        loss="ce",
        optimizer="AdamW",
        lr=1e-4,
        ignore_index=-1,
        freeze_backbone=True,  # Only to speed up fine-tuning
        freeze_decoder=False,
        plot_on_val=True,
        class_names=['no burned', 'burned'],  # optionally define class names
    )

In [None]:
# Fix the random seed so model construction is reproducible.
pl.seed_everything(0)

# No Trainer or checkpoint callback is created: this notebook does not fine-tune,
# it loads weights from "checkpoint.pt" in a later cell.
# Model
model = get_model()

In [None]:
# Download the fine-tuned weights once. The guard must test for the checkpoint
# itself (the dataset archive already exists at this point), and the output name
# is fixed to the path the later cells load from.
if not os.path.isfile('checkpoint.pt'):
    gdown.download(
        "https://drive.google.com/uc?id=1-I_DiiO2T1mjBTi3OAJaVeRWKHtAG63N",
        output='checkpoint.pt',
    )

In [None]:
# Load on CPU: the model is not moved to the GPU in this notebook before export.
checkpoint = torch.load("checkpoint.pt", map_location="cpu")
# The same file is later passed to load_from_checkpoint, so it may be a full
# training checkpoint that keeps the weights under "state_dict"; unwrap it if so.
state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
# strict=False tolerates key mismatches, so report them instead of ignoring them silently.
load_result = model.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(load_result.missing_keys)}, unexpected keys: {len(load_result.unexpected_keys)}")
# Trailing semicolon suppresses the full module repr in the cell output.
model.eval();

# Export model to ONNX

In [None]:
# Export the inner network to ONNX once; skipped when the file already exists.
if not os.path.isfile('model2.onnx'):

    class _LogitsOnly(torch.nn.Module):
        """Adapter that returns only the ``.output`` attribute of the wrapped model's result.

        Using a wrapper instead of overwriting ``forward`` leaves ``model`` and
        ``model.model`` unmodified, whether or not the export runs.
        """

        def __init__(self, wrapped):
            super().__init__()
            self.wrapped = wrapped

        def forward(self, x):
            return self.wrapped(x).output

    core_model = model.model
    core_model.eval()
    # Create the example input on the device that holds the weights.
    dummy_input = torch.randn(1, 6, 512, 512, device=next(core_model.parameters()).device)
    torch.onnx.export(
        _LogitsOnly(core_model),
        dummy_input,
        "model2.onnx",
        export_params=True,
        opset_version=17,  # or 13/17 depending on TRT support
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        # Only the batch dimension is exported as dynamic.
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
    )

# Convert model from ONNX to TensorRT

In [None]:
# Build and serialize a TensorRT engine from the ONNX file; skipped when 'model.trt' exists.
if not os.path.isfile('model.trt'):
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(TRT_LOGGER)
    explicit_batch_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(flags=explicit_batch_flag)
    parser = trt.OnnxParser(network, TRT_LOGGER)

    # Read the ONNX graph and hand it to the parser; print every parser error on failure.
    with open("model2.onnx", "rb") as f:
        onnx_bytes = f.read()
    if not parser.parse(onnx_bytes):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError("ONNX parsing failed")

    # Builder configuration with a 1 GiB workspace limit.
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)

    # Optimization profile for the single network input. The input is 4D,
    # (batch, 6, 512, 512), and min/opt/max are identical, i.e. one fixed shape.
    input_tensor = network.get_input(0)
    input_name = input_tensor.name
    min_shape = opt_shape = max_shape = (1, 6, 512, 512)
    profile = builder.create_optimization_profile()
    profile.set_shape(input_name, min=min_shape, opt=opt_shape, max=max_shape)
    config.add_optimization_profile(profile)

    # Build engine
    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        raise RuntimeError("Failed to build engine")

    # Save engine
    with open("model.trt", "wb") as f:
        f.write(serialized_engine)


In [None]:
# Checkpoint file passed to the PyTorch timing run below.
best_ckpt_path = "checkpoint.pt"

In [None]:
def run_test_and_visual_inspection(model,ckpt_path):
    """Time checkpoint loading, batch loading and one PyTorch forward pass.

    Despite its name, this function only prints timings; it does not plot and
    returns nothing. It reads the notebook-level ``datamodule``.

    Args:
        model: Task whose ``hparams.model_factory`` and ``hparams.model_args`` are
            reused to rebuild the model from the checkpoint. Its weights are not used.
        ckpt_path: Path of the checkpoint passed to ``load_from_checkpoint``.
    """
    # let's run the model on the test set
    #trainer.test(model, datamodule=datamodule, ckpt_path=ckpt_path)

    start = time.time()
    model = terratorch.tasks.SemanticSegmentationTask.load_from_checkpoint(
        ckpt_path,
        model_factory=model.hparams.model_factory,
        model_args=model.hparams.model_args,
    )
    end = time.time()
    print(f'model load time {end - start}')

    # Do not rely on the loader to switch modes: the head is configured with
    # dropout, which must be disabled for inference.
    model.eval()

    test_loader = datamodule.test_dataloader()

    with torch.no_grad():
        start = time.time()
        batch = next(iter(test_loader))
        end = time.time()
        print(f'batch load time {end - start}')

        # Use the batch returned by aug, as the TensorRT variant does, instead of
        # discarding the result and re-reading the original batch.
        batch = datamodule.aug(batch)
        images = batch["image"].to(model.device)

        # CUDA kernels run asynchronously; synchronize so the timer covers the
        # whole forward pass and not just the kernel launches.
        if images.is_cuda:
            torch.cuda.synchronize()
        start = time.time()
        model(images)  # result is discarded, only the latency is measured
        if images.is_cuda:
            torch.cuda.synchronize()
        end = time.time()
        print(f'inference time {end - start}')

In [None]:
# PyTorch baseline: prints load and inference timings for one test batch.
run_test_and_visual_inspection(model, best_ckpt_path)

In [None]:
import time
def run_test_and_visual_inspection_tensorrt():
    """Run one test sample through the serialized TensorRT engine and plot the result.

    Reads the notebook-level ``datamodule`` and ``test_dataset`` and the engine
    file 'model.trt'. Prints the engine load time, the inference time and the
    output shape, then plots the sample with its predicted mask.

    Raises:
        RuntimeError: If the engine cannot be deserialized or execution fails.
    """
    import numpy as np
    import pycuda.driver as cuda
    import pycuda.autoinit
    import tensorrt as trt
    import torch

    test_loader = datamodule.test_dataloader()

    batch = next(iter(test_loader))
    batch_size = 1  # Ensure batch size is 1 for TensorRT execution context

    # Only use the first sample to keep it compatible with TensorRT engine (batch_size = 1)
    for key in batch:
        batch[key] = batch[key][:batch_size]

    # Apply any augmentation and get input image
    batch = datamodule.aug(batch)

    # Host-side float32 input. The data is copied to the device with pycuda below,
    # so no intermediate transfer through a torch CUDA tensor is needed.
    input_data = np.ascontiguousarray(batch["image"].cpu().numpy(), dtype=np.float32)

    # Define input/output shapes
    input_shape = input_data.shape  # Should be (1, 6, 512, 512)
    output_shape = (batch_size, 2, 512, 512)

    # Time the full engine load: reading the file and deserializing it.
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    start = time.time()
    with open('model.trt', 'rb') as f:
        engine_data = f.read()
    runtime = trt.Runtime(TRT_LOGGER)
    engine = runtime.deserialize_cuda_engine(engine_data)
    end = time.time()
    print(f'model load time: {end - start}')

    if engine is None:
        raise RuntimeError("Failed to deserialize TensorRT engine from 'model.trt'")
    context = engine.create_execution_context()

    # Allocate device memory sized from the host buffers
    output_data = np.empty(output_shape, dtype=np.float32)
    d_input = cuda.mem_alloc(input_data.nbytes)
    d_output = cuda.mem_alloc(output_data.nbytes)

    # Transfer input to device
    cuda.memcpy_htod(d_input, input_data)

    start = time.time()
    succeeded = context.execute_v2([int(d_input), int(d_output)])
    end = time.time()
    # A failed execution would otherwise leave uninitialized data in the output buffer.
    if not succeeded:
        raise RuntimeError("TensorRT execution failed")
    print(f'inference time: {end - start}')

    # Get output from device
    cuda.memcpy_dtoh(output_data, d_output)

    print("Inference completed.")
    print("Output shape:", output_data.shape)

    # Post-process output: class index with the highest score per pixel
    preds = torch.argmax(torch.from_numpy(output_data), dim=1).numpy()

    # Visual inspection
    for i in range(batch_size):
        sample = {key: batch[key][i] for key in batch}
        sample["prediction"] = preds[i]
        sample["image"] = sample["image"].cpu()
        sample["mask"] = sample["mask"].cpu()
        test_dataset.plot(sample)



In [None]:
# TensorRT run: prints timings for one test sample and plots its prediction.
run_test_and_visual_inspection_tensorrt()
