# OpenEarhMap Semantinc Segmentation

This demo code demonstrates training and testing of UNet-EfficientNet-B4 for the OpenEarthMap dataset (https://open-earth-map.org/). This demo code is based on the work from the "segmentation_models.pytorch" repository by qubvel, available at: https://github.com/qubvel/segmentation_models.pytorch. We extend our sincere appreciation to the original author for their invaluable contributions to the field of semantic segmentation and for providing this open-source implementation.

---

### Requirements

In [8]:
%cd open-earth-map/
%ls

/home/jovyan/open-earth-map
[0m[01;34mGeoSeg[0m/           OpenEarthMap.zip   deep_lab.ipynb  [01;34moutputs[0m/      xBD.zip
[01;34mLSKNet[0m/           [01;34mOpenEarthMap_old[0m/  [01;34mdemo[0m/           [01;34mpredictions[0m/  [01;34mxBD_old[0m/
OEM_demo.ipynb    [01;34mSFA-Net[0m/           example.png     [01;34msource[0m/
OEM_lsknet.ipynb  [01;34maerial-former[0m/     [01;34mmodels[0m/         [01;34mtest[0m/
[01;34mOpenEarthMap[0m/     compile_xbd.py     oem.ipynb       [01;34mxBD[0m/


In [9]:
!pip install torch
!pip install rasterio
!pip install pretrainedmodels
!pip install efficientnet_pytorch
!pip install timm
!pip install albumentations
!pip install segmentation_models_pytorch

Collecting rasterio
  Using cached rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio)
  Using cached affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Collecting cligj>=0.5 (from rasterio)
  Using cached cligj-0.7.2-py3-none-any.whl.metadata (5.0 kB)
Collecting click-plugins (from rasterio)
  Using cached click_plugins-1.1.1-py2.py3-none-any.whl.metadata (6.4 kB)
Using cached rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.2 MB)
Using cached cligj-0.7.2-py3-none-any.whl (7.1 kB)
Using cached affine-2.4.0-py3-none-any.whl (15 kB)
Using cached click_plugins-1.1.1-py2.py3-none-any.whl (7.5 kB)
Installing collected packages: cligj, click-plugins, affine, rasterio
Successfully installed affine-2.4.0 click-plugins-1.1.1 cligj-0.7.2 rasterio-1.4.3
Collecting pretrainedmodels
  Using cached pretrainedmodels-0.7.4-py3-none-any.whl
Collecting munch (from pretrainedmodels)
  Using cached munch-4

In [10]:
# MM Rotate Installation

!pip install -U openmim
!mim install mmcv-full
!mim install mmdet
!pip install mmrotate

Collecting openmim
  Downloading openmim-0.3.9-py2.py3-none-any.whl.metadata (16 kB)
Collecting model-index (from openmim)
  Downloading model_index-0.1.11-py3-none-any.whl.metadata (3.9 kB)
Collecting opendatalab (from openmim)
  Downloading opendatalab-0.0.10-py3-none-any.whl.metadata (6.4 kB)
Collecting rich (from openmim)
  Using cached rich-13.9.4-py3-none-any.whl.metadata (18 kB)
Collecting tabulate (from openmim)
  Downloading tabulate-0.9.0-py3-none-any.whl.metadata (34 kB)
Collecting markdown (from model-index->openmim)
  Using cached Markdown-3.7-py3-none-any.whl.metadata (7.0 kB)
Collecting ordered-set (from model-index->openmim)
  Downloading ordered_set-4.1.0-py3-none-any.whl.metadata (5.3 kB)
Collecting pycryptodome (from opendatalab->openmim)
  Downloading pycryptodome-3.21.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.4 kB)
Collecting openxlab (from opendatalab->openmim)
  Downloading openxlab-0.1.2-py3-none-any.whl.metadata (3.8 kB)
Collecting 

### Import
---

In [None]:
import sys
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
import source
import glob
import torchvision.transforms.functional as TF
import math
import cv2
from PIL import Image
import time
import warnings
from pathlib import Path

warnings.filterwarnings("ignore")
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [24]:
from models.lsknet import LSKNet

smp.encoders.encoders["lsk_net"] = {
    "encoder": LSKNet,
    "params": {
        "img_size": 224,
        "in_chans": 3,
        "embed_dims": [64, 128, 256, 512],
        "mlp_ratios": [8, 8, 4, 4],
        "drop_rate": 0.0,
        "drop_path_rate": 0.0,
        "norm_layer": partial(nn.LayerNorm, eps=1e-6),
        "depths": [3, 4, 6, 3],
        "num_stages": 4,
        "pretrained": None,
        "init_cfg": None,
        "norm_cfg": None,
    },
}

ModuleNotFoundError: No module named 'mmcv'

### Define main parameters

In [12]:
OEM_ROOT = "./demo/"
OEM_DATA_DIR = "OpenEarthMap/"
TRAIN_LIST = OEM_DATA_DIR+"train.txt"
VAL_LIST = OEM_DATA_DIR+"val.txt"
TEST_LIST = OEM_DATA_DIR+"test.txt"
WEIGHT_DIR = OEM_ROOT+"weight" # path to save weights
OUT_DIR = OEM_ROOT+"result/" # path to save prediction images
os.makedirs(WEIGHT_DIR, exist_ok=True)

seed = 0
learning_rate = 0.0001
batch_size = 4
n_epochs = 5
classes = [1, 2, 3, 4, 5, 6, 7, 8]
n_classes = len(classes)+1
classes_wt = np.ones([n_classes], dtype=np.float32)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Number of epochs   :", n_epochs)
print("Number of classes  :", n_classes)
print("Batch size         :", batch_size)
print("Device             :", device)

Number of epochs   : 5
Number of classes  : 9
Batch size         : 4
Device             : cuda


### Prepare training and validation file lists

In this demo for Google Colab, we use only two regions, i.e., Tokyo and Kyoto for training. To train with the full set, please download the OpenEarthMap dataset from https://zenodo.org/record/7223446. Note for xBD data preparation is available at https://github.com/bao18/open_earth_map.

In [13]:
img_pths = [f for f in Path(OEM_DATA_DIR).rglob("*.tif") if "/labels/" in str(f)]

train_test_pths = [str(f) for f in img_pths if f.name in np.loadtxt(TRAIN_LIST, dtype=str)]
val_pths = [str(f) for f in img_pths if f.name in np.loadtxt(VAL_LIST, dtype=str)]

print("Total samples            :", len(img_pths))
print("Training/Testing samples :", len(train_test_pths))
print("Validation samples       :", len(val_pths))

Total samples            : 3501
Training/Testing samples : 3000
Validation samples       : 500


In [14]:
import random

random.shuffle(train_test_pths)

training_pths = train_test_pths[:2500]
testing_pths = train_test_pths[2500:]

print(f"Training list contains {len(training_pths)} elements.")
print(f"Testing list contains {len(testing_pths)} elements.")

Training list contains 2500 elements.
Testing list contains 500 elements.


### Define training and validation dataloaders

In [15]:
val_pths[0]

'OpenEarthMap/aachen/labels/aachen_11.tif'

In [16]:
trainset = source.dataset.Dataset(training_pths, classes=classes, size=512, train=True)
validset = source.dataset.Dataset(val_pths, classes=classes, train=False)
testset = source.dataset.Dataset(testing_pths, classes=classes, train=False)

train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)
valid_loader = DataLoader(validset, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=0)

### Setup network

In [17]:
network = smp.Unet(
    classes=n_classes,
    activation=None,
    encoder_weights="imagenet",
    encoder_name="lsk_net",
    decoder_attention_type="scse",
)

# count parameters
params = 0
for p in network.parameters():
    if p.requires_grad:
        params += p.numel()

criterion = source.losses.CEWithLogitsLoss(weights=classes_wt)
criterion_name = 'CE'
metric = source.metrics.IoU2()
optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
network_fout = f"{network.name}_s{seed}_{criterion.name}"
OUT_DIR += network_fout # path to save prediction images
os.makedirs(OUT_DIR, exist_ok=True)

print("Model output name  :", network_fout)
print("Number of parameters: ", params)

if torch.cuda.device_count() > 1:
    print("Number of GPUs :", torch.cuda.device_count())
    network = torch.nn.DataParallel(network)
    optimizer = torch.optim.Adam(
        [dict(params=network.module.parameters(), lr=learning_rate)]
    )

KeyError: "Wrong encoder name `lsk_net`, supported encoders: ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_32x8d', 'resnext101_32x16d', 'resnext101_32x32d', 'resnext101_32x48d', 'dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn107', 'dpn131', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', 'se_resnext101_32x4d', 'densenet121', 'densenet169', 'densenet201', 'densenet161', 'inceptionresnetv2', 'inceptionv4', 'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3', 'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7', 'mobilenet_v2', 'xception', 'timm-efficientnet-b0', 'timm-efficientnet-b1', 'timm-efficientnet-b2', 'timm-efficientnet-b3', 'timm-efficientnet-b4', 'timm-efficientnet-b5', 'timm-efficientnet-b6', 'timm-efficientnet-b7', 'timm-efficientnet-b8', 'timm-efficientnet-l2', 'timm-tf_efficientnet_lite0', 'timm-tf_efficientnet_lite1', 'timm-tf_efficientnet_lite2', 'timm-tf_efficientnet_lite3', 'timm-tf_efficientnet_lite4', 'timm-resnest14d', 'timm-resnest26d', 'timm-resnest50d', 'timm-resnest101e', 'timm-resnest200e', 'timm-resnest269e', 'timm-resnest50d_4s2x40d', 'timm-resnest50d_1s4x24d', 'timm-res2net50_26w_4s', 'timm-res2net101_26w_4s', 'timm-res2net50_26w_6s', 'timm-res2net50_26w_8s', 'timm-res2net50_48w_2s', 'timm-res2net50_14w_8s', 'timm-res2next50', 'timm-regnetx_002', 'timm-regnetx_004', 'timm-regnetx_006', 'timm-regnetx_008', 'timm-regnetx_016', 'timm-regnetx_032', 'timm-regnetx_040', 'timm-regnetx_064', 'timm-regnetx_080', 'timm-regnetx_120', 'timm-regnetx_160', 'timm-regnetx_320', 'timm-regnety_002', 'timm-regnety_004', 'timm-regnety_006', 'timm-regnety_008', 'timm-regnety_016', 'timm-regnety_032', 'timm-regnety_040', 'timm-regnety_064', 'timm-regnety_080', 'timm-regnety_120', 'timm-regnety_160', 'timm-regnety_320', 'timm-skresnet18', 'timm-skresnet34', 'timm-skresnext50_32x4d', 'timm-mobilenetv3_large_075', 'timm-mobilenetv3_large_100', 'timm-mobilenetv3_large_minimal_100', 'timm-mobilenetv3_small_075', 'timm-mobilenetv3_small_100', 'timm-mobilenetv3_small_minimal_100', 'timm-gernet_s', 'timm-gernet_m', 'timm-gernet_l', 'mit_b0', 'mit_b1', 'mit_b2', 'mit_b3', 'mit_b4', 'mit_b5', 'mobileone_s0', 'mobileone_s1', 'mobileone_s2', 'mobileone_s3', 'mobileone_s4']"

### Visualization functions

In [9]:
class_rgb = {
    "Bareland": [128, 0, 0],
    "Grass": [0, 255, 36],
    "Pavement": [148, 148, 148],
    "Road": [255, 255, 255],
    "Tree": [34, 97, 38],
    "Water": [0, 69, 255],
    "Cropland": [75, 181, 73],
    "buildings": [222, 31, 7],
}

class_gray = {
    "Bareland": 1,
    "Grass": 2,
    "Pavement": 3,
    "Road": 4,
    "Tree": 5,
    "Water": 6,
    "Cropland": 7,
    "buildings": 8,
}

def label2rgb(a):
    """
    a: labels (HxW)
    """
    out = np.zeros(shape=a.shape + (3,), dtype="uint8")
    for k, v in class_gray.items():
        out[a == v, 0] = class_rgb[k][0]
        out[a == v, 1] = class_rgb[k][1]
        out[a == v, 2] = class_rgb[k][2]
    return out

### Training

In [10]:
start = time.time()

max_score = 0
train_hist = []
valid_hist = []

for epoch in range(n_epochs):
  print(f"\nEpoch: {epoch + 1}")

  logs_train = source.runner.train_epoch(
      model=network,
      optimizer=optimizer,
      criterion=criterion,
      metric=metric,
      dataloader=train_loader,
      device=device,
  )

  logs_valid = source.runner.valid_epoch(
      model=network,
      criterion=criterion,
      metric=metric,
      dataloader=valid_loader,
      device=device,
  )

  train_hist.append(logs_train)
  valid_hist.append(logs_valid)

  score = logs_valid[metric.name]

  if max_score < score:
      max_score = score
      torch.save(network.state_dict(), os.path.join(WEIGHT_DIR, f"{network_fout}.pth"))
      print("Model saved!")

end = time.time()
print('Processing time:',end - start)


Epoch: 1


Train: 100%|██████████| 750/750 [47:28<00:00,  3.80s/it, CELoss=1.54, mIoU=0.244] 
Valid: 100%|██████████| 125/125 [12:15<00:00,  5.88s/it, CELoss=1.11, mIoU=0.337]


Model saved!

Epoch: 2


Train: 100%|██████████| 750/750 [46:12<00:00,  3.70s/it, CELoss=1.14, mIoU=0.369]
Valid: 100%|██████████| 125/125 [09:11<00:00,  4.41s/it, CELoss=0.912, mIoU=0.393]


Model saved!

Epoch: 3


Train: 100%|██████████| 750/750 [45:46<00:00,  3.66s/it, CELoss=1.01, mIoU=0.4]  
Valid: 100%|██████████| 125/125 [08:46<00:00,  4.21s/it, CELoss=0.836, mIoU=0.411]


Model saved!

Epoch: 4


Train: 100%|██████████| 750/750 [47:03<00:00,  3.76s/it, CELoss=0.919, mIoU=0.425]
Valid: 100%|██████████| 125/125 [09:02<00:00,  4.34s/it, CELoss=0.815, mIoU=0.447]


Model saved!

Epoch: 5


Train: 100%|██████████| 750/750 [45:57<00:00,  3.68s/it, CELoss=0.863, mIoU=0.462]
Valid: 100%|██████████| 125/125 [09:14<00:00,  4.43s/it, CELoss=0.772, mIoU=0.466]

Model saved!
Processing time: 16860.144696712494





### Testing


In [10]:
import os
import zipfile

def zip_files(output_dir, zip_filename):
    with zipfile.ZipFile(zip_filename, 'w') as zipf:
        for root, _, files in os.walk(output_dir):
            for file in files:
                file_path = os.path.join(root, file)
                zipf.write(file_path, os.path.relpath(file_path, output_dir))
                os.remove(file_path)


In [44]:
# load network
network.load_state_dict(torch.load(os.path.join(WEIGHT_DIR, f"{network_fout}.pth")))
network.to(device).eval()

for fn_img in testing_pths:
  fn_img = fn_img.replace("/labels/", "/images/")
  img = source.dataset.load_multiband(fn_img)
  h, w = img.shape[:2]
  power = math.ceil(np.log2(h) / np.log2(2))
  shape = (2 ** power, 2 ** power)
  img = cv2.resize(img, shape)

  # test time augmentation
  imgs = []
  imgs.append(img.copy())
  imgs.append(img[:, ::-1, :].copy())
  imgs.append(img[::-1, :, :].copy())
  imgs.append(img[::-1, ::-1, :].copy())

  input = torch.cat([TF.to_tensor(x).unsqueeze(0) for x in imgs], dim=0).float().to(device)

  pred = []
  with torch.no_grad():
      msk = network(input)
      msk = torch.softmax(msk[:, :, ...], dim=1)
      msk = msk.cpu().numpy()
      pred = (msk[0, :, :, :] + msk[1, :, :, ::-1] + msk[2, :, ::-1, :] + msk[3, :, ::-1, ::-1])/4
  pred = pred.argmax(axis=0).astype("uint8")
  size = pred.shape[0:]
  y_pr = cv2.resize(pred, (w, h), interpolation=cv2.INTER_NEAREST)

  # save image as png
  filename = os.path.splitext(os.path.basename(fn_img))[0]
  Image.fromarray(y_pr).save(os.path.join(OUT_DIR, filename+'.tif'))


In [46]:
import numpy as np
from PIL import Image

def calculate_iou(gt_mask, pred_mask, num_classes):
    """Calculates Intersection over Union (IoU) for each class.

    Args:
        gt_mask: Ground truth segmentation mask (NumPy array).
        pred_mask: Predicted segmentation mask (NumPy array).
        num_classes: Number of classes.

    Returns:
        A NumPy array containing IoU for each class.
    """

    iou_per_class = np.zeros(num_classes)
    for class_id in range(num_classes):
        intersection = np.sum((gt_mask == class_id) & (pred_mask == class_id))
        union = np.sum((gt_mask == class_id) | (pred_mask == class_id))

        if union == 0:  # Handle cases where the union is zero (prevent division by zero)
            iou_per_class[class_id] = 0.0  # Or np.nan if you prefer
        else:
            iou_per_class[class_id] = intersection / union

    return iou_per_class


def calculate_miou(gt_mask, pred_mask, num_classes):
    """Calculates mean Intersection over Union (mIoU).

    Args:
        gt_mask: Ground truth segmentation mask (NumPy array).
        pred_mask: Predicted segmentation mask (NumPy array).
        num_classes: Number of classes.

    Returns:
        The mIoU score.
    """
    iou_per_class = calculate_iou(gt_mask, pred_mask, num_classes)
    miou = np.nanmean(iou_per_class) #Handle nan values that could arise from classes not present

    return miou


def load_mask(image_path):
    """Loads a segmentation mask from an image file.

    Args:
        image_path: Path to the image file.

    Returns:
        A NumPy array representing the mask.
    """
    mask = Image.open(image_path)
    mask = np.array(mask)
    return mask

def calculate_miou_for_dataset(gt_dir, pred_dir, num_classes):
    """Calculates mIoU over a dataset of images.

    Args:
        gt_dir: Path to the directory containing ground truth masks.
        pred_dir: Path to the directory containing predicted masks.
        num_classes: Number of classes.

    Returns:
        The mean mIoU over the dataset.
        A dictionary containing the mIoU for each image.
    """

    image_names = set(os.listdir(gt_dir)) & set(os.listdir(pred_dir)) #Find common file names
    miou_scores = []
    image_miou_dict = {}

    if not image_names:
        raise ValueError("No matching files found in the ground truth and prediction directories.")

    for image_name in image_names:
        gt_path = os.path.join(gt_dir, image_name)
        pred_path = os.path.join(pred_dir, image_name)

        try:
            gt_mask = load_mask(gt_path)
            pred_mask = load_mask(pred_path)

            if gt_mask.shape != pred_mask.shape:
                print(f"Warning: Masks for {image_name} have different shapes. Skipping.")
                continue #skip this image

            miou = calculate_miou(gt_mask, pred_mask, num_classes)
            miou_scores.append(miou)
            image_miou_dict[image_name] = miou

        except FileNotFoundError:
            print(f"Warning: File not found: {image_name}. Skipping.")
        except Exception as e:
            print(f"An error occurred processing {image_name}: {e}. Skipping.")

    if not miou_scores: #if no valid images were processed
      return 0.0, {}

    dataset_miou = np.mean(miou_scores)
    return dataset_miou, image_miou_dict

# Example paths (replace with your actual paths)
gt_dir = "test/labels"  # Directory containing multiple GT masks
pred_dir = OUT_DIR    # Directory containing multiple predicted masks
num_classes = 8

try:
    dataset_miou, image_miou_dict = calculate_miou_for_dataset(gt_dir, pred_dir, num_classes)

    print(f"Dataset mIoU: {dataset_miou}")

    if image_miou_dict: #Print individual miou only if dictionary is not empty
      print("mIoU per image:")
      for image_name, miou in image_miou_dict.items():
          print(f"- {image_name}: {miou}")

except ValueError as e:
    print(f"ValueError: {e}")
except Exception as e:
    print(f"An error occurred: {e}")


An error occurred processing .ipynb_checkpoints: [Errno 21] Is a directory: '/home/jovyan/open-earth-map/test/labels/.ipynb_checkpoints'. Skipping.
Dataset mIoU: 0.26474113901459717
mIoU per image:
- aachen_4.tif: 0.26405045335720734
- aachen_5.tif: 0.29852855377949916
- aachen_1.tif: 0.312184203275812
- aachen_2.tif: 0.24481229375161148
- aachen_6.tif: 0.20413019090885565


### Testing a model for a large Geotiff image

A sample image is provided by the Geospatial Information Authority of Japan at https://cyberjapandata.gsi.go.jp/xyz/seamlessphoto/{z}/{x}/{y}.jpg


In [12]:
start = time.time()

# load network
network.load_state_dict(torch.load(os.path.join(WEIGHT_DIR, f"{network_fout}.pth")))
network.to(device).eval()

test_large = TEST_DIR+"35_1_op_2023.jpg"

# process large Geotiff image
img0 = source.dataset.load_multiband(test_large)

# get crs and transform
crs, trans = source.dataset.get_crs(test_large)

if img0.shape[2] > 3:
    img0 = img0[:, :, :3]
width = img0.shape[1]
band = img0.shape[2]

patch_size = 512
stride = 256
C = int(np.ceil( (width - patch_size) / stride ) + 1)
R = int(np.ceil( (height - patch_size) / stride ) + 1)

# weight matrix B for avoiding boundaries of patches
if patch_size > stride:
    w = patch_size
    s1 = stride
    s2 = w - s1
    d = 1/(1+s2)
    B1 = np.ones((w,w))
    B1[:,s1::] = np.dot(np.ones((w,1)),(-np.arange(1,s2+1)*d+1).reshape(1,s2))
    B2 = np.flip(B1)
    B3 = B1.T
    B4 = np.flip(B3)
    B = B1*B2*B3*B4
else:
    B = np.ones((w,w))

img1 = np.zeros((patch_size+stride*(R-1), patch_size+stride*(C-1),3))
img1[0:height,0:width,:] = img0.copy()

pred_all = np.zeros((9,patch_size+stride*(R-1), patch_size+stride*(C-1)))
weight = np.zeros((patch_size+stride*(R-1), patch_size+stride*(C-1)))

for r in range(R):
    for c in range(C):
        img = img1[r*stride:r*stride+patch_size,c*stride:c*stride+patch_size,:].copy().astype(np.float32)/255
        imgs = []
        imgs.append(img.copy())
        imgs.append(img[:, ::-1, :].copy())
        imgs.append(img[::-1, :, :].copy())
        imgs.append(img[::-1, ::-1, :].copy())

        input = torch.cat([TF.to_tensor(x).unsqueeze(0) for x in imgs], dim=0).float().to(device)

        pred = []
        with torch.no_grad():
            msk = network(input)
            msk = torch.softmax(msk[:, :, ...], dim=1)
            msk = msk.cpu().numpy()

            pred = (msk[0, :, :, :] + msk[1, :, :, ::-1] + msk[2, :, ::-1, :] + msk[3, :, ::-1, ::-1])/4

        pred_all[:,r*stride:r*stride+patch_size,c*stride:c*stride+patch_size] += pred.copy()*B
        weight[r*stride:r*stride+patch_size,c*stride:c*stride+patch_size] += B

for b in range(9):
    pred_all[b,:,:] = pred_all[b,:,:]/weight
    if b == 0:
        pred_all[b,:,:] = 0

pred_all = pred_all.argmax(axis=0).astype("uint8")

filename = os.path.splitext(os.path.basename(test_large))[0]
pr_rgb = label2rgb(pred_all)
Image.fromarray(pr_rgb[0:height,0:width,:]).save(os.path.join(OUT_DIR, filename+'_pr.png'))

# save geotiff
pr_rgb = np.transpose(pr_rgb[0:height,0:width,:], (2,0,1))
source.dataset.save_img(os.path.join(OUT_DIR, filename+'_pr.tif'),pr_rgb,crs,trans)

end = time.time()
print('Processing time:',end - start)

NameError: name 'network' is not defined

In [9]:
start = time.time()

# Load network
network.load_state_dict(torch.load(os.path.join(WEIGHT_DIR, f"{network_fout}.pth")))
network.to(device).eval()

# Load large GeoTIFF image and metadata
test_large = TEST_DIR + "35_1_op_2023.jpg"
img0 = source.dataset.load_multiband(test_large)
crs, trans = source.dataset.get_crs(test_large)

# Ensure 3-band (RGB) data
if img0.shape[2] > 3:
    img0 = img0[:, :, :3]

height, width, _ = img0.shape
patch_size = 512

# Padding image to ensure even division into patches
pad_h = (patch_size - height % patch_size) % patch_size
pad_w = (patch_size - width % patch_size) % patch_size
img_padded = np.pad(img0, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')

# Prepare output arrays
pred_all = np.zeros((9, height + pad_h, width + pad_w), dtype=np.float32)

# Divide into non-overlapping patches
patches = [
    img_padded[r:r+patch_size, c:c+patch_size].astype(np.float32) / 255
    for r in range(0, img_padded.shape[0], patch_size)
    for c in range(0, img_padded.shape[1], patch_size)
]

# Batch process patches
batch_size = 8  # Adjust based on GPU memory
for i in range(0, len(patches), batch_size):
    print("processing patch ", i, "\n")
    batch_patches = patches[i:i+batch_size]
    
    # Test time augmentation
    augmented_patches = []
    for img in batch_patches:
        augmented_patches.extend([
            img.copy(),  # Original
            img[:, ::-1, :].copy(),  # Flip horizontally
            img[::-1, :, :].copy(),  # Flip vertically
            img[::-1, ::-1, :].copy()  # Flip both axes
        ])
    
    input_tensor = torch.cat([TF.to_tensor(x).unsqueeze(0) for x in augmented_patches], dim=0).float().to(device)
    
    with torch.no_grad():
        msk = network(input_tensor)
        msk = torch.softmax(msk, dim=1).cpu().numpy()
    
    # Aggregate predictions
    for idx, img in enumerate(batch_patches):
        r, c = divmod(i + idx, img_padded.shape[1] // patch_size)
        pred = (
            msk[idx*4] +
            msk[idx*4+1][:, :, ::-1] +
            msk[idx*4+2][:, ::-1, :] +
            msk[idx*4+3][:, ::-1, ::-1]
        ) / 4
        
        pred_all[:, r*patch_size:(r+1)*patch_size, c*patch_size:(c+1)*patch_size] = pred

# Finalize predictions
pred_all = pred_all.argmax(axis=0).astype("uint8")[:height, :width]

# Save outputs
filename = os.path.splitext(os.path.basename(test_large))[0]
pr_rgb = label2rgb(pred_all)
Image.fromarray(pr_rgb).save(os.path.join(OUT_DIR, filename + '_pr.png'))

# Save GeoTIFF
# pr_rgb = np.transpose(pr_rgb, (2, 0, 1))
# source.dataset.save_img(os.path.join(OUT_DIR, filename + '_pr.tif'), pr_rgb, crs, trans)

end = time.time()
print("Processing time:", end - start)


processing patch  0 

processing patch  8 

processing patch  16 

processing patch  24 

processing patch  32 

processing patch  40 

processing patch  48 

processing patch  56 

processing patch  64 

processing patch  72 

processing patch  80 

processing patch  88 

processing patch  96 

processing patch  104 

processing patch  112 

processing patch  120 

processing patch  128 

processing patch  136 

processing patch  144 

processing patch  152 

processing patch  160 

processing patch  168 

processing patch  176 

processing patch  184 

processing patch  192 

processing patch  200 

processing patch  208 

processing patch  216 

processing patch  224 

processing patch  232 

processing patch  240 

processing patch  248 

processing patch  256 

processing patch  264 

processing patch  272 

processing patch  280 

processing patch  288 

processing patch  296 

processing patch  304 

processing patch  312 

processing patch  320 

processing patch  328 

process