In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import CenterCrop, ToTensor
from terratorch.datasets.m_chesapeake_landcover import MChesapeakeLandcoverNonGeo
import numpy as np
from pathlib import Path
from PIL import Image

# ------------------- Config -------------------
root = "./chesapeake"          # local cache
splits = ["de-test"]           # or ["pa-test"], ["md-train"], etc.
patch_size = 224
num_samples = 12
out_dir = Path("./chesapeake_samples")
out_dir.mkdir(parents=True, exist_ok=True)

# ------------------- Class mapping -------------------
class_map = {
    0: "Water",
    1: "Tree canopy",
    2: "Low vegetation",
    3: "Barren",
    4: "Impervious",
    5: "Wetlands",
    6: "Crops",
}

# fixed colors for visualization
class_colors = {
    0: (0, 0, 255),       # blue for water
    1: (34, 139, 34),     # green for tree canopy
    2: (124, 252, 0),     # light green for low vegetation
    3: (210, 180, 140),   # tan for barren
    4: (128, 128, 128),   # gray for impervious
    5: (0, 255, 255),     # cyan for wetlands
    6: (255, 255, 0),     # yellow for crops
}

# ------------------- Transforms -------------------
def chesapeake_transform(sample):
    image, mask = sample["image"], sample["mask"]

    crop = CenterCrop(patch_size)

    # convert both to tensor first
    image = ToTensor()(image)   # (C,H,W), float in [0,1] if 8-bit
    mask = torch.as_tensor(np.array(mask), dtype=torch.long)  # (H,W)

    # apply same crop
    image = crop(image)
    mask = crop(mask.unsqueeze(0)).squeeze(0)  # keep it 2D (H,W)

    return {"image": image, "mask": mask}

# ------------------- Dataset -------------------
dataset = MChesapeakeLandcoverNonGeo(
    data_root='.',
)

loader = DataLoader(dataset, batch_size=1, shuffle=True)

# ------------------- Save samples -------------------
for i, sample in enumerate(loader):
    image = sample["image"].squeeze(0)  # (C,H,W), usually 4 channels
    mask = sample["mask"].squeeze(0)    # (H,W)

    # ---- Save 4-channel image for inference ----
    img4 = (image.numpy().transpose(1, 2, 0) * 255).astype(np.uint8)  # (H,W,4)
    img4_pil = Image.fromarray(img4)
    img4_resized = img4_pil.resize((224, 224), resample=Image.BILINEAR)  # or Image.LANCZOS for even higher quality
    img4_resized.save(out_dir / f"sample_{i}_x.png")

    # ---- Save RGB visualization (natural look) ----
    # Remove batch dimension for each key in the sample dict
    single_sample = {k: v.squeeze(0) for k, v in sample.items()}  # Now single image/mask

    fig = dataset.plot(single_sample)
    fig.savefig(out_dir / f"sample_{i}_z.png")
    import matplotlib.pyplot as plt
    plt.close(fig)







    # ---- Save mask with colors ----
    mask_np = mask.numpy()
    color_mask = np.zeros((mask_np.shape[0], mask_np.shape[1], 3), dtype=np.uint8)
    for cls_id, color in class_colors.items():
        color_mask[mask_np == cls_id] = color
    Image.fromarray(color_mask).save(out_dir / f"sample_{i}_y.png")

    # Print info
    unique_classes = torch.unique(mask).tolist()
    class_names = [class_map[c] for c in unique_classes]
    print(f"Saved sample {i}: classes {unique_classes} → {class_names}")

    if i + 1 >= num_samples:
        break

print(f"Samples stored in: {out_dir.resolve()}")


In [None]:
import onnxruntime as ort
import numpy as np
import onnx

# Load the ONNX model
model_path = "/home/romeokienzler/gitco/tmtinyonnxwebdemo/src/model_chesapeake.onnx"
session = ort.InferenceSession(model_path)

# Print model input and output info
print("Model Input Info:")
for input_meta in session.get_inputs():
    print(f"  Name: {input_meta.name}, Shape: {input_meta.shape}, Type: {input_meta.type}")

print("\nModel Output Info:")
for output_meta in session.get_outputs():
    print(f"  Name: {output_meta.name}, Shape: {output_meta.shape}, Type: {output_meta.type}")

# Define the expected input shape based on your web app code (224x224 with 4 channels)
input_shape = [1, 4, 224, 224]

# Create a random input tensor with the correct shape and data type
# Using a normal distribution to get a variety of floating-point values
# Scale them to be in the [0, 1] range like the image data
random_input = np.random.randn(*input_shape).astype(np.float32)
random_input = (random_input - random_input.min()) / (random_input.max() - random_input.min())

# Run inference with the random input
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
outputs = session.run([output_name], {input_name: random_input})

# Get the logits (raw output) from the model
logits = outputs[0]

# Print the logits to see if they're still constant
print("\nLogits from Random Input (first 100 values):")
print(logits.flatten()[:100])

# Perform an argmax to find the predicted class for each pixel
# This is a good way to see if there's any pattern in the output
num_classes = logits.shape[1]
num_pixels = logits.shape[2] * logits.shape[3]
logits_reshaped = np.transpose(logits, (0, 2, 3, 1)).reshape(-1, num_classes)
predicted_classes = np.argmax(logits_reshaped, axis=1)

print("\nPredicted Classes from Random Input (first 100 pixels):")
print(predicted_classes[:100])

In [3]:
from terratorch.tasks import SemanticSegmentationTask
from terratorch.datasets.m_chesapeake_landcover import MChesapeakeLandcoverNonGeo
from terratorch.datamodules.m_chesapeake_landcover import MChesapeakeLandcoverNonGeoDataModule
from jsonargparse import ArgumentParser

def load_dm_from_config(config_path: str) -> MChesapeakeLandcoverNonGeoDataModule:
    parser = ArgumentParser()
    parser.add_class_arguments(MChesapeakeLandcoverNonGeoDataModule, "data")
    cfg = parser.parse_path(config_path)
    namespace = parser.instantiate_classes(cfg)
    return namespace.data

def load_task_from_config(config_path: str) -> SemanticSegmentationTask:
    parser = ArgumentParser()
    parser.add_class_arguments(SemanticSegmentationTask, "model")
    cfg = parser.parse_path(config_path)
    namespace = parser.instantiate_classes(cfg)
    return namespace.model

dm = load_dm_from_config('/home/romeokienzler/Downloads/chesapeake_best_iterate_data.yaml')
task = load_task_from_config('/home/romeokienzler/Downloads/chesapeake_best_iterate.yaml')
# restore weights
task = SemanticSegmentationTask.load_from_checkpoint('/home/romeokienzler/Downloads/best-epoch=89-val=0.0000.ckpt', **task.hparams)

model = task.model
model.eval()  # Set the model to evaluation mode

dm.setup("fit")
dm.setup("test")    
sample = dm.test_dataset[0]
sample = dm.aug(sample)
x = sample["image"]   # tensor (C,H,W)
y = sample["mask"]    # tensor (H,W)

print(type(x), x.shape)
print(type(y), y.shape)

#x = x.unsqueeze(0).unsqueeze(2)   # [1, C, 1, H, W]



<class 'torch.Tensor'> torch.Size([1, 4, 224, 224])
<class 'torch.Tensor'> torch.Size([1, 1, 224, 224])


In [None]:
x.shape

In [None]:
import torch
with torch.no_grad():
    output = model(x)

predicted_classes = torch.argmax(output.output, dim=1)
print("Predicted classes for random input:")
print(predicted_classes.flatten()[:10000])
print((predicted_classes.flatten() == 3).all())
hist = torch.histc(predicted_classes.flatten().float(), bins=7, min=-10, max=10)
print(hist)



In [None]:
# ------------------- Save samples -------------------
for i, sample in enumerate(loader):
    image = sample["image"].squeeze(0)  # (C,H,W), usually 4 channels
    mask = sample["mask"].squeeze(0)    # (H,W)

    # ---- Save 4-channel image for inference ----
    img4 = (image.numpy().transpose(1, 2, 0) * 255).astype(np.uint8)  # (H,W,4)
    img4_pil = Image.fromarray(img4)
    img4_resized = img4_pil.resize((224, 224), resample=Image.BILINEAR)  # or Image.LANCZOS for even higher quality
    img4_resized.save(out_dir / f"sample_{i}_x.png")

    # ---- Save RGB visualization (natural look) ----
    # Remove batch dimension for each key in the sample dict
    single_sample = {k: v.squeeze(0) for k, v in sample.items()}  # Now single image/mask

    fig = dataset.plot(single_sample)
    fig.savefig(out_dir / f"sample_{i}_z.png")
    import matplotlib.pyplot as plt
    plt.close(fig)







    # ---- Save mask with colors ----
    mask_np = mask.numpy()
    color_mask = np.zeros((mask_np.shape[0], mask_np.shape[1], 3), dtype=np.uint8)
    for cls_id, color in class_colors.items():
        color_mask[mask_np == cls_id] = color
    Image.fromarray(color_mask).save(out_dir / f"sample_{i}_y.png")

    # Print info
    unique_classes = torch.unique(mask).tolist()
    class_names = [class_map[c] for c in unique_classes]
    print(f"Saved sample {i}: classes {unique_classes} → {class_names}")

    if i + 1 >= num_samples:
        break

print(f"Samples stored in: {out_dir.resolve()}")


In [None]:
import torch
from pytorch_lightning.utilities.cloud_io import atomic_save
from terratorch.tasks import SemanticSegmentationTask

# Paths
pt_path = "Prithvi_EO_V2_tiny_TL.pt"
ckpt_path = "Prithvi_EO_V2_tiny_TL.ckpt"

# 1. Load raw state dict from .pt
raw = torch.load(pt_path, map_location="cpu")
if "state_dict" in raw:
    state_dict = raw["state_dict"]
else:
    state_dict = raw

# 3. Load backbone weights into task (ignore missing heads if needed)
task.load_state_dict(state_dict, strict=False)

# 4. Build checkpoint dict
checkpoint = {
    "state_dict": task.state_dict(),
    "pytorch-lightning_version": "2.3.0",  # or your installed version
    "hyper_parameters": task.hparams,
    "epoch": 0,
    "global_step": 0,
}

# 5. Save as Lightning checkpoint
atomic_save(checkpoint, ckpt_path)
print(f"Converted {pt_path} → {ckpt_path}")


In [9]:
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import CenterCrop
import numpy as np
from pathlib import Path
from PIL import Image
from terratorch.datamodules.m_chesapeake_landcover import MChesapeakeLandcoverNonGeoDataModule
from jsonargparse import ArgumentParser

# ------------------- Config -------------------
config_path = '/home/romeokienzler/Downloads/chesapeake_best_iterate_data.yaml'
patch_size = 224
num_samples = 30
out_dir = Path("./chesapeake_samples")
out_dir.mkdir(parents=True, exist_ok=True)

# ------------------- Class mapping -------------------
class_map = {
    0: "Water",
    1: "Tree canopy",
    2: "Low vegetation",
    3: "Barren",
    4: "Impervious",
    5: "Wetlands",
    6: "Crops",
}

class_colors = {
    0: (0, 0, 255),
    1: (34, 139, 34),
    2: (124, 252, 0),
    3: (210, 180, 140),
    4: (128, 128, 128),
    5: (0, 255, 255),
    6: (255, 255, 0),
}

# ------------------- Load DataModule -------------------
def load_dm_from_config(config_path: str) -> MChesapeakeLandcoverNonGeoDataModule:
    parser = ArgumentParser()
    parser.add_class_arguments(MChesapeakeLandcoverNonGeoDataModule, "data")
    cfg = parser.parse_path(config_path)
    namespace = parser.instantiate_classes(cfg)
    return namespace.data

dm = load_dm_from_config(config_path)
dm.setup("test")
loader = DataLoader(dm.test_dataset, batch_size=1, shuffle=True)

# ------------------- Transform -------------------
crop = CenterCrop(patch_size)

# ------------------- Save samples -------------------
for i, sample in enumerate(loader):
    sample = {k: v.squeeze(0) for k, v in sample.items()}

    # Apply the datamodule's augmentation (normalization, etc.)
    sample = dm.aug(sample)

    image = sample["image"]  # (C,H,W)
    mask = sample["mask"]    # (H,W)

    image_np = image.detach().cpu().numpy()
    channel_names = ["RED", "GREEN", "BLUE", "NIR"]
    for c in range(image_np.shape[0]):
        ch_data = image_np[c].flatten()
        hist, bin_edges = np.histogram(ch_data, bins=12)
        print(f"Sample {i}, Channel {channel_names[c]} histogram:")
        for edge_start, edge_end, count in zip(bin_edges[:-1], bin_edges[1:], hist):
            print(f"  {edge_start:.3f} - {edge_end:.3f}: {count}")

    # ---- Run model and save computed mask ----
    x = image
    if x.ndim == 3:          # [C,H,W] → add batch and temporal
        x = x.unsqueeze(0).unsqueeze(2)
    elif x.ndim == 4:        # [B,C,H,W] → add temporal
        x = x.unsqueeze(2)
    print(x.shape)           # should be [B,C,T,H,W]
    with torch.no_grad():
        output = model(x)
        pred_mask = output.output.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.uint8)


        import numpy as np
        from PIL import Image

        # Example class colors
        class_colors = {
            0: (0, 0, 255),
            1: (34, 139, 34),
            2: (124, 252, 0),
            3: (210, 180, 140),
            4: (128, 128, 128),
            5: (0, 255, 255),
            6: (255, 255, 0),
        }

        # Create RGB mask
        color_mask = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
        for cls_id, color in class_colors.items():
            color_mask[pred_mask == cls_id] = color

        # Save as PNG
        Image.fromarray(color_mask).save(out_dir / f"sample_{i}_pred_mask.png")



    # ---- Save 4-channel input as raw .bin ----
    img_bin_path = out_dir / f"sample_{i}_x.bin"
    image.detach().cpu().numpy().astype(np.float32).tofile(img_bin_path)

    # ---- Save RGB preview for browser/grid ----
    rgb_sample = sample
    rgb_sample["image"] = rgb_sample["image"].squeeze(0)
    fig = dm.test_dataset.plot_rgb(rgb_sample)
    fig.savefig(out_dir / f"sample_{i}_rgb.png", dpi=150, bbox_inches='tight')
    fig.clf()  # clear the figure to free memory

    # ---- Save mask with colors ----
    mask_np = mask.detach().cpu().numpy()
    if mask_np.ndim == 4:
        mask_np = mask_np.squeeze(0).squeeze(0)
    elif mask_np.ndim == 3:
        mask_np = mask_np.squeeze(0)

    color_mask = np.zeros((mask_np.shape[0], mask_np.shape[1], 3), dtype=np.uint8)
    for cls_id, color in class_colors.items():
        color_mask[mask_np == cls_id] = color
    Image.fromarray(color_mask).save(out_dir / f"sample_{i}_y.png")

    # Print info
    unique_classes = torch.unique(mask).tolist()
    class_names = [class_map[c] for c in unique_classes]
    print(f"Saved sample {i}: classes {unique_classes} → {class_names}")

    if i + 1 >= num_samples:
        break

print(f"Samples stored in: {out_dir.resolve()}")


Sample 0, Channel RED histogram:
  -1.727 - -1.344: 490
  -1.344 - -0.962: 2735
  -0.962 - -0.579: 11475
  -0.579 - -0.197: 71303
  -0.197 - 0.185: 20685
  0.185 - 0.568: 33698
  0.568 - 0.950: 33098
  0.950 - 1.333: 15570
  1.333 - 1.715: 7840
  1.715 - 2.097: 2095
  2.097 - 2.480: 1128
  2.480 - 2.862: 587
torch.Size([1, 4, 1, 224, 224])
Saved sample 0: classes [2, 3, 5, 6] → ['Low vegetation', 'Barren', 'Wetlands', 'Crops']
Sample 1, Channel RED histogram:
  -1.865 - -1.584: 1096
  -1.584 - -1.303: 3932
  -1.303 - -1.022: 7490
  -1.022 - -0.741: 15484
  -0.741 - -0.460: 29006
  -0.460 - -0.179: 39698
  -0.179 - 0.102: 31131
  0.102 - 0.383: 18155
  0.383 - 0.664: 11149
  0.664 - 0.945: 16451
  0.945 - 1.226: 26824
  1.226 - 1.507: 288
torch.Size([1, 4, 1, 224, 224])
Saved sample 1: classes [2] → ['Low vegetation']
Sample 2, Channel RED histogram:
  -1.535 - -1.178: 3995
  -1.178 - -0.820: 12973
  -0.820 - -0.463: 23094
  -0.463 - -0.106: 27349
  -0.106 - 0.251: 59099
  0.251 - 0.609

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>

<Figure size 400x400 with 0 Axes>