In [8]:
import os

ARCHITECTURE = 'unet'
os.environ['ARCHITECTURE'] = ARCHITECTURE

In [9]:
from landnet.modelling.segmentation.models import (
    DeepLabV3ResNet50Builder,
    FCNResNet50Builder,
    UNetBuilder,
)
from pathlib import Path
from landnet.enums import GeomorphometricalVariable, Mode
from landnet.features.tiles import TileConfig, TileSize
from landnet.features.grids import get_grid_for_variable
import torch
from landnet.modelling.segmentation.lightning import (
    LandslideImageSegmenter,
    LandslideImageSegmentationDataModule,
)
from landnet.modelling.dataset import (
    get_default_mask_transform,
    get_default_transform,
    get_default_augment_transform,
)
from landnet.modelling.tune import MetricSorter
from landnet.modelling.segmentation.dataset import (
    ConcatLandslideImageSegmentation,
    LandslideImageSegmentation,
)
from landnet.modelling.segmentation.inference import Infer
from landnet.modelling import torch_clear
import lightning as L
from landnet.typing import TuneSpace
import typing as t
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

torch_clear()

In [10]:
variables = [
    GeomorphometricalVariable('shade'),
    GeomorphometricalVariable('tpi'),
    GeomorphometricalVariable('dem'),
    GeomorphometricalVariable('nego'),
    GeomorphometricalVariable('tri'),
    GeomorphometricalVariable('eastness'),
    GeomorphometricalVariable('clo'),
    GeomorphometricalVariable('area'),
    GeomorphometricalVariable('slope'),
    GeomorphometricalVariable('croto'),
]
train_tile_config = TileConfig(TileSize(100, 100), overlap=0)
train_model_config: TuneSpace = {
    'batch_size': 4,
    'learning_rate': 0.000001,
    'tile_config': train_tile_config,
}

test_tile_config = TileConfig(TileSize(100, 100), overlap=0)
test_model_config: TuneSpace = {
    'batch_size': 4,
    'tile_config': test_tile_config,
}

In [11]:
train_grids = [
    get_grid_for_variable(
        variable,
        tile_config=train_tile_config,
        mode=Mode.TRAIN,
    )
    for variable in variables
]

validation_grids = [
    get_grid_for_variable(
        variable,
        tile_config=test_tile_config,
        mode=Mode.VALIDATION,
    )
    for variable in variables
]

train_dataset = ConcatLandslideImageSegmentation(
    landslide_images=[
        LandslideImageSegmentation(
            grid,
            Mode.TRAIN,
            transform=get_default_transform(),
            mask_transform=get_default_mask_transform(),
        )
        for grid in train_grids
    ],
    augment_transform=get_default_augment_transform(),
    # augment_transform=None,
)

validation_dataset = ConcatLandslideImageSegmentation(
    landslide_images=[
        LandslideImageSegmentation(
            grid,
            Mode.VALIDATION,
            transform=get_default_transform(),
            mask_transform=get_default_mask_transform(),
        )
        for grid in validation_grids
    ],
    augment_transform=None,
)

# train_dataset, validation_dataset = torch.utils.data.random_split(
#     dataset, (0.7, 0.3)
# )
# t.cast(
#     ConcatLandslideImageSegmentation, train_dataset
# ).augment_transform = get_default_augment_transform()

test_grids = [
    get_grid_for_variable(
        variable,
        tile_config=test_tile_config,
        mode=Mode.TEST,
    )
    for variable in variables
]
test_dataset = ConcatLandslideImageSegmentation(
    landslide_images=[
        LandslideImageSegmentation(
            grid,
            Mode.TEST,
            transform=get_default_transform(),
            mask_transform=get_default_mask_transform(),
        )
        for grid in test_grids
    ],
    augment_transform=None,
)



In [12]:
model = UNetBuilder(len(variables), 2).build(
    in_channels=len(variables), mode=Mode.TRAIN
)
dm = LandslideImageSegmentationDataModule(
    train_model_config,
    variables,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    validation_dataset=validation_dataset,
)
segmenter = LandslideImageSegmenter(train_model_config, model, 2)
trainer = L.Trainer(
    enable_checkpointing=True,
    callbacks=[EarlyStopping(monitor='val_mIoU', mode='max', patience=5)],
    max_epochs=10,
)

INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs


In [None]:
model = trainer.fit(model=segmenter, datamodule=dm)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name            | Type                         | Params | Mode 
-------------------------------------------------------------------------
0 | model           | Unet                         | 51.5 M | train
1 | criterion       | CrossEntropyLoss             | 0      | train
2 | train_metrics   | SegmentationMetricCollection | 0      | train
3 | val_metrics     | SegmentationMetricCollection | 0      | train
4 | test_metrics    | SegmentationMetricCollection | 0      | train
5 | predict_metrics | SegmentationMetricCollection | 0      | train
-------------------------------------------------------------------------
51.5 M    Trainable params
0         Non-trainable params
51.5 M    Total params
206.141   Total estimated model params size (MB)
376       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

In [7]:
trainer.test(model=segmenter, dataloaders=dm)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_dice_score_epoch': 0.5503259897232056,
  'test_generalized_dice_score_epoch': 0.4524320363998413,
  'test_mIoU_epoch': 0.3366027772426605,
  'test_loss_epoch': 0.6645846366882324}]

In [None]:
infer = Infer(variables, test_model_config)
parent = Path(
    '/media/alex/alex/python-modules-packages-utils/landnet/notebooks/lightning_logs'
)
last_version = sorted(
    map(lambda x: int(x.name.split('_')[-1]), parent.glob('version*'))
)[-1]
last_version = '16'
ckpt = Path(
    f'/media/alex/alex/python-modules-packages-utils/landnet/notebooks/lightning_logs/version_251/checkpoints/epoch=37-step=4750.ckpt'
)
infer.handle_checkpoint(
    ckpt,
    model=UNetBuilder(len(variables), 2).build(
        in_channels=len(variables), mode=Mode.TRAIN
    ),
)

INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |                                                                                                 …

INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |                                                                                                 …

INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |                                                                                                 …