In [1]:
import sys
import os
# NOTE(review): `List` is not used in any visible cell.
from typing import List
# Prepend the parent of the current working directory to the import path so the
# project-local packages imported later (`utils`, `config`, `models`) can be found.
# Presumably the notebook lives one directory below the project root — confirm.
sys.path.insert(0, os.path.abspath('..'))

# Automatically reload edited project modules before each cell executes, so
# changes to `utils/`, `config`, etc. are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Output directory for this run: the learner writes checkpoints here and the
# evaluation section saves result images under it.
EXPERIMENT_DIR = '../out/0725-more-data'

# Edge length of the square sliding windows used for both training and validation.
TILE_SIZE = 110

# Any scene whose id contains one of these site names is held out for
# validation; every other scene is used for training.
VALIDATION_SITES = [
    "Ken_Banda",
    "Sone_Rohtas",
]

In [3]:
from google.cloud import storage

# Cloud Storage bucket holding the labelled data (the cached rasters and
# annotations are fetched from storage.googleapis.com/sand_mining/...).
# NOTE(review): not referenced in any other visible cell.
BUCKET_NAME = "sand_mining"
# GCP project the storage client is created under.
GCP_PROJECT_NAME = "gee-sand"

# Client handed to `observation_factory` when building the list of observations.
gcp_client = storage.Client(
    project=GCP_PROJECT_NAME,
)

In [6]:
import albumentations as A
from rastervision.pytorch_learner import SemanticSegmentationSlidingWindowGeoDataset

from utils.rastervision_utils import create_scene
from utils.data_management import observation_factory
from utils.schemas import ObservationPointer

from config import CLASS_CONFIG_BINARY_SAND as class_config
from config import S2_CHANNELS


# Training-only augmentation pipeline (used by `scene_to_training_ds`):
# random 90-degree rotations, horizontal/vertical flips, and CoarseDropout
# with up to 3 holes of at most 32x32.
augmentation = A.Compose(
    [
        A.RandomRotate90(),
        A.HorizontalFlip(),
        A.VerticalFlip(),
        A.CoarseDropout(max_holes=3, max_height=32, max_width=32),
    ]
)

def observation_to_scene(observation: ObservationPointer):
    """Build a Raster Vision scene for a single observation.

    The observation's raster URI (`uri_to_bs`), annotation URI and name are
    forwarded to `create_scene`, together with the notebook-level
    `S2_CHANNELS` band selection and `class_config`.
    """
    bs_uri = observation.uri_to_bs
    annotations_uri = observation.uri_to_annotations
    obs_name = observation.name
    return create_scene(bs_uri, annotations_uri, obs_name, S2_CHANNELS, class_config)

def is_validation(scene):
    """Return True when the scene id contains any name listed in VALIDATION_SITES.

    Matching is by substring, so every scene of a held-out site (whatever
    suffix its id carries) lands in the validation split.
    """
    return any(site in scene.id for site in VALIDATION_SITES)

def is_training(scene):
    """Return True for every scene that is not held out for validation."""
    if is_validation(scene):
        return False
    return True

def scene_to_validation_ds(scene):
    """Wrap `scene` in an un-augmented sliding-window dataset for validation.

    Windows are TILE_SIZE wide with a stride of TILE_SIZE (non-overlapping).
    A padding of TILE_SIZE is applied in the 'end' direction, and no
    transform is applied.
    """
    window_kwargs = dict(
        size=TILE_SIZE,
        stride=TILE_SIZE,
        padding=TILE_SIZE,
        pad_direction='end',
    )
    return SemanticSegmentationSlidingWindowGeoDataset(scene, transform=None, **window_kwargs)

def scene_to_training_ds(scene):
    """Wrap `scene` in an augmented, half-overlapping sliding-window dataset for training.

    Windows are TILE_SIZE wide with a stride of half a tile. Padding is left
    to the dataset's default (`padding=None`) with `pad_direction='both'`, and
    every sample goes through the module-level `augmentation` pipeline.
    """
    half_tile_stride = int(TILE_SIZE / 2)
    return SemanticSegmentationSlidingWindowGeoDataset(
        scene,
        transform=augmentation,
        size=TILE_SIZE,
        stride=half_tile_stride,
        padding=None,
        pad_direction='both',
    )


# Build one scene per observation, split the scenes by site into training and
# validation, then tile each scene into a sliding-window dataset.
all_observations = observation_factory(gcp_client)
all_scenes = list(map(observation_to_scene, all_observations))

# Materialize the splits as lists rather than lazy `filter` iterators: a
# `filter` object is exhausted after one pass, so it would look empty if
# inspected or reused in a later cell.
training_scenes = [scene for scene in all_scenes if is_training(scene)]
validation_scenes = [scene for scene in all_scenes if is_validation(scene)]

# Fail fast on an empty split, or on a VALIDATION_SITES entry that matches no
# scene (e.g. a typo); otherwise that site's scenes would silently be used
# for training.
assert training_scenes, "No training scenes found after removing VALIDATION_SITES"
assert validation_scenes, f"No scenes matched VALIDATION_SITES={VALIDATION_SITES}"
unmatched_sites = [
    site for site in VALIDATION_SITES
    if not any(site in scene.id for scene in validation_scenes)
]
assert not unmatched_sites, f"Validation sites with no matching scenes: {unmatched_sites}"

training_datasets = list(map(scene_to_training_ds, training_scenes))
validation_datasets = list(map(scene_to_validation_ds, validation_scenes))


2023-07-25 16:49:12:rastervision.pipeline.file_system.utils: INFO - Using cached file /data/tmp/cache/http/storage.googleapis.com/sand_mining/labels/Betwa_Hamirpur_79-81_25-91/bs/Betwa_Hamirpur_79-81_25-91_2022-03-22_bs.tif.
2023-07-25 16:49:12:rastervision.pipeline.file_system.utils: INFO - Using cached file /data/tmp/cache/http/storage.googleapis.com/sand_mining/labels/Betwa_Hamirpur_79-81_25-91/annotations/Betwa_Hamirpur_79-81_25-91_2022-03-22_annotations.geojson.
2023-07-25 16:49:12:rastervision.pipeline.file_system.utils: INFO - Using cached file /data/tmp/cache/http/storage.googleapis.com/sand_mining/labels/Betwa_Hamirpur_79-81_25-91/bs/Betwa_Hamirpur_79-81_25-91_2022-04-01_bs.tif.
2023-07-25 16:49:12:rastervision.pipeline.file_system.utils: INFO - Using cached file /data/tmp/cache/http/storage.googleapis.com/sand_mining/labels/Betwa_Hamirpur_79-81_25-91/annotations/Betwa_Hamirpur_79-81_25-91_2022-04-01_annotations.geojson.
2023-07-25 16:49:12:rastervision.pipeline.file_system.ut

In [None]:
from utils.visualizing import show_windows, show_image_in_dataset
from config import DISPLAY_GROUPS

# Arbitrary validation scene to inspect; change the index to look at another one.
ds_to_visualize = validation_datasets[2]

# Show the dataset's sliding windows against the scene's full raster.
show_windows(ds_to_visualize.scene.raster_source[:, :], ds_to_visualize.windows)

# Render the dataset with class_config and DISPLAY_GROUPS
# (DISPLAY_GROUPS presumably selects band combinations — see config).
show_image_in_dataset(ds_to_visualize, class_config, DISPLAY_GROUPS)

In [10]:
from torch.utils.data import ConcatDataset

# The learner takes a single train and a single validation dataset, so the
# per-scene datasets are chained end to end.
train_dataset_merged = ConcatDataset(datasets=training_datasets)
val_dataset_merged = ConcatDataset(datasets=validation_datasets)

## Train

In [12]:
from rastervision.pytorch_learner import SemanticSegmentationGeoDataConfig, SolverConfig, SemanticSegmentationLearnerConfig, SemanticSegmentationLearner

from models.unet.unet_small import UNetSmall

# The raster source shape is unpacked channels-last; the model takes one
# input channel per band of the first training scene.
_, _, n_channels = training_datasets[0].scene.raster_source.shape
# Derive the class count from class_config instead of hard-coding it. This
# keeps the model head, the loss weights and `num_classes=len(class_config)`
# (used when assembling predictions) in agreement.
n_classes = len(class_config)

model = UNetSmall(n_channels, n_classes)

data_cfg = SemanticSegmentationGeoDataConfig(
    class_names=class_config.names,
    class_colors=class_config.colors,
    # Load batches in the main process (no DataLoader worker subprocesses).
    num_workers=0,
)

# One loss weight per class id: class 1 is weighted 10x relative to class 0.
class_loss_weights = [1., 10]
assert len(class_loss_weights) == n_classes, (
    f"class_loss_weights has {len(class_loss_weights)} entries, "
    f"but class_config defines {n_classes} classes"
)

solver_cfg = SolverConfig(
    batch_sz=64,
    lr=3e-2,
    class_loss_weights=class_loss_weights,
)

learner_cfg = SemanticSegmentationLearnerConfig(data=data_cfg, solver=solver_cfg)

# If EXPERIMENT_DIR already contains last-model.pth, the learner loads it, so
# training resumes from that checkpoint instead of starting from scratch.
learner = SemanticSegmentationLearner(
    cfg=learner_cfg,
    output_dir=EXPERIMENT_DIR,
    model=model,
    train_ds=train_dataset_merged,
    valid_ds=val_dataset_merged,
)
learner.log_data_stats()

2023-07-25 16:50:18:rastervision.pytorch_learner.learner: INFO - Loading checkpoint from ../out/0725-more-data/last-model.pth
2023-07-25 16:50:18:rastervision.pytorch_learner.learner: INFO - train_ds: 9643 items
2023-07-25 16:50:18:rastervision.pytorch_learner.learner: INFO - valid_ds: 3048 items


In [14]:
# Train for 9 epochs. If a checkpoint was loaded from EXPERIMENT_DIR, training
# resumes from the saved epoch (the log output shows "Resuming training from
# epoch 21"). NOTE(review): assumes `epochs` counts additional epochs beyond
# the resumed one rather than a total — confirm against rastervision's Learner.train.
learner.train(epochs=9)

2023-07-25 17:04:05:rastervision.pytorch_learner.learner: INFO - Resuming training from epoch 21
2023-07-25 17:04:05:rastervision.pytorch_learner.learner: INFO - epoch: 21
Training:   0%|          | 0/151 [00:00<?, ?it/s]

## Evaluate

In [14]:
from rastervision.core.data import SemanticSegmentationLabels
from rastervision.pytorch_learner import SemanticSegmentationSlidingWindowGeoDataset, SemanticSegmentationLearner


def get_predictions_for_site(
    learner: SemanticSegmentationLearner,
    ds: SemanticSegmentationSlidingWindowGeoDataset,
    class_name: str = 'sandmine',
):
    """Predict a scene-wide score map for one class.

    Runs ``learner`` over every window of ``ds`` and stitches the raw
    per-window outputs back into a scene-sized label object.
    ``smooth=True`` yields labels that keep per-class scores, which are read
    back with ``get_score_arr``. The function returns the score layer for
    ``class_name``.

    Args:
        learner: trained semantic-segmentation learner used for inference.
        ds: sliding-window dataset over a single scene. Its ``windows`` and
            ``scene.extent`` determine where each prediction is placed.
        class_name: class whose scores are returned, looked up with the
            module-level ``class_config.get_class_id``. Defaults to
            ``'sandmine'``, the class this function originally hard-coded.

    Returns:
        The slice of the score array for ``class_name`` over
        ``pred_labels.extent`` (presumably (H, W) probabilities, since
        ``raw_out=True`` — confirm against the learner's ``predict``).
    """
    # Resolve the class id before the slow prediction pass, so an unknown
    # class_name fails immediately instead of after full inference.
    class_id = class_config.get_class_id(class_name)

    predictions = learner.predict_dataset(
        ds,
        raw_out=True,
        numpy_out=True,
        progress_bar=True,
    )
    pred_labels = SemanticSegmentationLabels.from_predictions(
        ds.windows,
        predictions,
        smooth=True,
        extent=ds.scene.extent,
        num_classes=len(class_config),
    )
    scores = pred_labels.get_score_arr(pred_labels.extent)
    return scores[class_id]

### Save prediction visualizations

In [None]:
from utils.visualizing import show_rgb_labels_preds

# Save one RGB / labels / prediction figure per validation scene.
result_images_dir = f"{EXPERIMENT_DIR}/result_images"
# Nothing visible in this notebook creates this directory, and saving a
# figure into a missing directory fails, so create it up front.
os.makedirs(result_images_dir, exist_ok=True)

# Iterate the validation datasets built earlier (`validation_datasets`; the
# previous `val_datasets` name was never defined).
for idx, ds in enumerate(validation_datasets):
    predicted_mine_probability = get_predictions_for_site(learner, ds)
    show_rgb_labels_preds(
        ds.scene.raster_source[:, :],
        ds.scene.label_source[:, :],
        predicted_mine_probability,
        ds.scene.id,
        savefig_path=f"{result_images_dir}/{idx}_{ds.scene.id}.png",
    )