In [1]:
import os

import numpy as np
import torch
from transformers import AutoImageProcessor, AutoModel
import matplotlib.pyplot as plt
from PIL import Image
from tifffile import imread

In [2]:

# Satellite-pretrained DINOv3 backbone ("sat493m" weights). Judging by its name
# ("vitl16") this is the ViT-Large/16 variant, not a small model.
model_name = "facebook/dinov3-vitl16-pretrain-sat493m"

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

processor = AutoImageProcessor.from_pretrained(model_name)
# The backbone is only used as a frozen feature extractor in this notebook
# (the optimiser below only sees the head's parameters), so put it in eval mode.
model = AutoModel.from_pretrained(model_name)
model = model.eval().to(device)

print("device:", device)

device: cuda


In [3]:
# Read the whole label raster into memory as a numpy array. The pairing code
# further down does not use this array; it re-opens the file with rasterio.
labels = imread(
    "/run/media/mak/Partition of 1TB disk/SH_dataset/planet_labels_2022.tif"
)


In [4]:
# Folder holding the image patches. Keep the trailing slash so that plain
# string concatenation with a glob pattern also yields a valid path.
# The label raster itself is already loaded into `labels` by the previous cell
# (same file), so it is deliberately not read a second time here.
images_folder = "/home/mak/PycharmProjects/SegEdge/experiments/get_data_from_api/patches_mt/"


In [5]:
import rasterio
from rasterio.warp import reproject, Resampling
from rasterio.mask import mask
from shapely.geometry import box
import glob as glob


def subset_label_to_image(img_path, lab_path):
    """Return the label raster resampled onto the pixel grid of an image.

    The image at ``img_path`` defines the reference grid (CRS, affine
    transform, width, height). Every band of the label raster at ``lab_path``
    is warped onto exactly that grid with nearest-neighbour resampling, so
    class ids are never interpolated and the output is pixel-aligned with the
    image. This holds whether or not the two rasters share a CRS.

    Parameters
    ----------
    img_path : str
        Path to the image raster that defines the target grid.
    lab_path : str
        Path to the label raster.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(label_band_count, image_height, image_width)`` with
        the label raster's dtype. Image pixels not covered by the label raster
        keep the reprojection's destination nodata value (0 unless a nodata
        value is set on the label raster).
    """
    # --- 1. The image defines the reference grid ---
    with rasterio.open(img_path) as src_img:
        dst_crs = src_img.crs
        dst_transform = src_img.transform
        dst_height, dst_width = src_img.height, src_img.width

    # --- 2. Warp every label band onto that grid ---
    with rasterio.open(lab_path) as src_lab:
        # Reproject even when the CRS already matches. Cropping with
        # rasterio.mask only clips to the image bounds and keeps the label
        # raster's own resolution and pixel origin. Its shape can then differ
        # from the image, and patch labels no longer line up with the image's
        # feature grid.
        labels_aligned = np.zeros(
            (src_lab.count, dst_height, dst_width), dtype=src_lab.dtypes[0]
        )
        for band_idx in range(1, src_lab.count + 1):
            reproject(
                source=rasterio.band(src_lab, band_idx),
                destination=labels_aligned[band_idx - 1],
                src_transform=src_lab.transform,
                src_crs=src_lab.crs,
                dst_transform=dst_transform,
                dst_crs=dst_crs,
                resampling=Resampling.nearest,
            )

    return labels_aligned  # (bands, H, W) on the image grid



# Create (image path, aligned label raster) pairs.
# glob returns files in arbitrary, filesystem-dependent order. Sort the paths so
# the pair order is reproducible across runs; anything indexed by position,
# e.g. image_label_pairs[0], depends on it.
label_path = "/run/media/mak/Partition of 1TB disk/SH_dataset/planet_labels_2022.tif"
image_paths = sorted(glob.glob(os.path.join(images_folder, "*.tif")))
print(f"Found {len(image_paths)} images")
assert image_paths, f"no .tif images found in {images_folder!r}"

image_label_pairs = [
    (img_path, subset_label_to_image(img_path, label_path))
    for img_path in image_paths
]

print(f"Created {len(image_label_pairs)} image-label pairs")



Found 96 images
Created 96 image-label pairs


In [6]:
from func import *


# Sanity-check feature extraction on the first image patch. `feats` is reused
# by the training cell to size the classification head.
# NOTE(review): the image is assumed to load as (H, W, C); confirm the
# on-disk band layout of these TIFFs.
img = imread(image_label_pairs[0][0])
feats, hp, hw = extract_patch_features_single_scale(
    img, model, processor, device
)



In [10]:
import numpy as np

def labels_to_patch_labels(label_hw, ps=16):
    """Reduce a 2-D pixel mask to one binary label per ``ps`` x ``ps`` patch.

    label_hw: (H, W) numpy array, expected to hold {0, 1}.
    ps: patch edge length in pixels.

    Trailing rows/columns that do not fill a whole patch are dropped. A patch
    is labelled 1 when the mean of its pixels is strictly greater than 0.5.
    For a {0, 1} mask that means more than half the pixels are 1; an exact
    tie gives 0.

    returns: int64 array of shape (H // ps, W // ps) with values in {0, 1}.
    """
    n_patch_rows, n_patch_cols = (dim // ps for dim in label_hw.shape)
    trimmed = label_hw[: n_patch_rows * ps, : n_patch_cols * ps]

    # (rows, ps, cols, ps) -> (rows, cols, ps, ps): move each patch's pixels
    # onto the two trailing axes so they can be averaged together.
    tiles = np.swapaxes(trimmed.reshape(n_patch_rows, ps, n_patch_cols, ps), 1, 2)
    positive_fraction = tiles.mean(axis=(2, 3))
    return np.greater(positive_fraction, 0.5).astype(np.int64)


In [11]:
from torch.utils.data import Dataset
import torch
from tifffile import imread

class DinoPatchDataset(Dataset):
    """Map-style dataset yielding (patch features, patch labels) per image.

    Each item reads one image from disk, runs it through the backbone via
    ``extract_patch_features_single_scale`` and collapses the matching label
    raster to one label per ``ps`` x ``ps`` patch with
    ``labels_to_patch_labels``.

    Parameters
    ----------
    image_label_pairs : sequence of (str, numpy.ndarray)
        Image path and its label raster, either ``(H, W)`` or
        ``(bands, H, W)``; for 3-D rasters only the first band is used.
    model, processor, device
        Forwarded unchanged to ``extract_patch_features_single_scale``.
    ps : int
        Patch size in pixels, forwarded to both helpers.
    cache_features : bool, default False
        When True, keep each item's ``(x, y)`` after its first access so the
        backbone runs once per image instead of once per epoch. Only valid
        while the backbone is frozen and feature extraction is deterministic.
        It costs host memory proportional to all features in the dataset.
    """

    def __init__(self, image_label_pairs, model, processor, device, ps=16,
                 cache_features=False):
        self.pairs = image_label_pairs
        self.model = model
        self.processor = processor
        self.device = device
        self.ps = ps
        self.cache_features = cache_features
        self._cache = {}  # idx -> (x, y); only populated when cache_features

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``x`` as float ``(C, Hp, Wp)`` and ``y`` as int64 ``(Hp, Wp)``.

        Raises
        ------
        ValueError
            If the patch-label grid does not match the feature grid. This
            usually means the label raster is not aligned with the image.
        """
        if self.cache_features and idx in self._cache:
            return self._cache[idx]

        img_path, label_raster = self.pairs[idx]
        img = imread(img_path)  # assumed (H, W, C) by the extractor — TODO confirm
        feats, _, _ = extract_patch_features_single_scale(
            img, self.model, self.processor, self.device, ps=self.ps
        )
        labels_hw = label_raster[0] if label_raster.ndim == 3 else label_raster
        patch_labels = labels_to_patch_labels(labels_hw, ps=self.ps)  # (Hp, Wp)

        # feats is (Hp, Wp, C); the head expects channels first.
        x = torch.from_numpy(feats).permute(2, 0, 1).float()  # (C, Hp, Wp)
        y = torch.from_numpy(patch_labels).long()             # (Hp, Wp)

        # Fail here with a readable message instead of deep inside the loss.
        feat_grid, label_grid = tuple(x.shape[1:]), tuple(y.shape)
        if feat_grid != label_grid:
            raise ValueError(
                f"feature grid {feat_grid} != label grid {label_grid} for "
                f"{img_path!r}; is the label raster aligned with the image?"
            )

        if self.cache_features:
            self._cache[idx] = (x, y)
        return x, y


import torch.nn as nn

class PatchHead1x1(nn.Module):
    """Per-patch linear classifier implemented as a 1x1 convolution.

    Maps a (B, in_channels, Hp, Wp) feature map to (B, num_classes, Hp, Wp)
    logits; every patch position is classified independently of its
    neighbours.
    """

    def __init__(self, in_channels, num_classes):
        super(PatchHead1x1, self).__init__()
        # Attribute name `conv` is part of the state_dict keys; keep it stable.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=num_classes,
            kernel_size=1,
        )

    def forward(self, x):
        logits = self.conv(x)
        return logits



In [12]:



from torch.utils.data import DataLoader
import torch.optim as optim
import torch

# Reproducible head initialisation and shuffling order.
torch.manual_seed(0)

dataset = DinoPatchDataset(image_label_pairs, model, processor, device, ps=16)

# Only `head` is optimised below, so the backbone never changes. It is in eval
# mode, so its features should be identical every epoch (assuming the
# extractor is deterministic). Extract them once up front instead of
# re-running the backbone for every image on every epoch. Set CACHE_FEATURES
# to False if all features do not fit in host memory.
CACHE_FEATURES = True
train_samples = [dataset[i] for i in range(len(dataset))] if CACHE_FEATURES else dataset
loader = DataLoader(train_samples, batch_size=1, shuffle=True)

C = feats.shape[-1]        # feature channels, from the sanity-check extraction above
K = 2                      # woody / non-woody
head = PatchHead1x1(C, K).to(device)

optimizer = optim.Adam(head.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

NUM_EPOCHS = 20
for epoch in range(NUM_EPOCHS):
    head.train()
    epoch_loss = 0.0
    for x, y in loader:
        x = x.to(device)                # (B, C, Hp, Wp)
        y = y.to(device)                # (B, Hp, Wp)

        optimizer.zero_grad()
        logits = head(x)                # (B, K, Hp, Wp)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # One summary line per epoch instead of one per image keeps output readable.
    print(f"Epoch {epoch}, Loss: {epoch_loss / max(len(loader), 1):.4f}")


KeyboardInterrupt: 