In [1]:
import os
ARCHITECTURE = 'unet'
os.environ['ARCHITECTURE'] = ARCHITECTURE

In [3]:
from pathlib import Path
import itertools
import reprlib
import os

import numpy as np
import torch
from torchvision.transforms.v2.functional import InterpolationMode, resize

from landnet.modelling.segmentation.models import UNetBuilder
from landnet.modelling.dataset import get_default_transform
from landnet.modelling.segmentation.lightning import LandslideImageSegmentation, LandslideImageSegmenter
from landnet.modelling.segmentation.dataset import ConcatLandslideImageSegmentation
from landnet.enums import GeomorphometricalVariable, Mode
from landnet.features.tiles import TileConfig, TileSize
from landnet.features.grids import Grid
from landnet.config import GRIDS, PROCESSED_DATA_DIR
from landnet.typing import TuneSpace



In [4]:
ckpt = Path(f'/media/alex/alex/python-modules-packages-utils/landnet/notebooks/lightning_logs/version_251/checkpoints/epoch=37-step=4750.ckpt')
grids_dir = GRIDS / 'laki'
outdir = PROCESSED_DATA_DIR / 'segmentation_results' / 'laki'
os.makedirs(outdir, exist_ok=True)

variables = [
    GeomorphometricalVariable('shade'),
    GeomorphometricalVariable('tpi'),
    GeomorphometricalVariable('dem'),
    GeomorphometricalVariable('nego'),
    GeomorphometricalVariable('tri'),
    GeomorphometricalVariable('eastness'),
    GeomorphometricalVariable('clo'),
    GeomorphometricalVariable('area'),
    GeomorphometricalVariable('slope'),
    GeomorphometricalVariable('croto'),
]
tile_config = TileConfig(TileSize(100, 100), overlap=0)
grid_paths = [(grids_dir / variable.value).with_suffix('.tif') for variable in variables]
grids = [Grid(path, tile_config, Mode.INFERENCE) for path in grid_paths]
model_config: TuneSpace = {
    'batch_size': 4,
    'learning_rate': 0.00001,
    'tile_config': tile_config,
}

In [5]:
def get_segmenter():
    return LandslideImageSegmenter.load_from_checkpoint(
        ckpt,
        model=UNetBuilder(len(variables), 2).build(
            in_channels=len(variables), mode=Mode.INFERENCE
        ),
        tune_space=model_config
    )
def write_indices(grid: Grid, indices: list[int], masks: torch.Tensor):
    assert masks.shape[0] == len(indices)
    for i, j in enumerate(indices):
        mask = masks.select(0, i)
        resized_mask = resize(
            mask.unsqueeze(0),
            [grid.tile_config.size.height, grid.tile_config.size.width],
            interpolation=InterpolationMode.NEAREST_EXACT
        )
        grid.write_tile(j, resized_mask.cpu().numpy(), prefix=str(j), out_dir=outdir)
    
length = grids[0].get_tiles_length()
start_index = 16400
transform = get_default_transform()
segmenter = get_segmenter()
segmenter.eval()


batch_size = 64

with torch.no_grad():
    for indices in itertools.batched(range(start_index, length), batch_size):
        print(f'Starting indices {reprlib.repr(indices)}')
        samples = {}
        for i in indices:
            images = [transform(grid.get_tile(i)[1].squeeze(0)) for grid in grids]
            if np.isnan(images[0]).all():
                continue
            samples[i] = torch.cat(images, dim=0).unsqueeze(0)
        if not samples:
            continue
        batch = torch.cat(list(samples.values()), dim=0).to(segmenter.device, non_blocking=True)
        masks = segmenter(batch).softmax(dim=1).argmax(dim=1)
        write_indices(grids[0], list(samples), masks)

Starting indices (16400, 16401, 16402, 16403, 16404, 16405, ...)


  if np.isnan(images[0]).all():


Starting indices (16464, 16465, 16466, 16467, 16468, 16469, ...)
Starting indices (16528, 16529, 16530, 16531, 16532, 16533, ...)
Starting indices (16592, 16593, 16594, 16595, 16596, 16597, ...)
Starting indices (16656, 16657, 16658, 16659, 16660, 16661, ...)
Starting indices (16720, 16721, 16722, 16723, 16724, 16725, ...)
Starting indices (16784, 16785, 16786, 16787, 16788, 16789, ...)
Starting indices (16848, 16849, 16850, 16851, 16852, 16853, ...)
Starting indices (16912, 16913, 16914, 16915, 16916, 16917, ...)
Starting indices (16976, 16977, 16978, 16979, 16980, 16981, ...)
Starting indices (17040, 17041, 17042, 17043, 17044, 17045, ...)
Starting indices (17104, 17105, 17106, 17107, 17108, 17109, ...)
Starting indices (17168, 17169, 17170, 17171, 17172, 17173, ...)
Starting indices (17232, 17233, 17234, 17235, 17236, 17237, ...)
Starting indices (17296, 17297, 17298, 17299, 17300, 17301, ...)
Starting indices (17360, 17361, 17362, 17363, 17364, 17365, ...)
Starting indices (17424, 