In [None]:
import sys
import os
from typing import List
sys.path.insert(0, os.path.abspath('..'))

%load_ext autoreload
%autoreload 2

In [None]:
from google.cloud import storage
from project_config import GCP_PROJECT_NAME

# Google Cloud Storage client bound to the configured project; it is passed to
# observation_factory below. No credentials are passed explicitly, so the client
# falls back to credentials inferred from the environment.
gcp_client = storage.Client(project=GCP_PROJECT_NAME)

### Configuration

In [None]:
import random

import albumentations as A
import numpy as np
import torch

# Scenes whose id contains any of these substrings are held out for validation.
VALIDATION_SITES = ["Ken_Banda", "Sone_Rohtas"]
# Tile size handed to the dataset builders (presumably pixels per side — TODO confirm).
TILE_SIZE = 110

# Seed the global RNGs so weight initialisation and other torch/numpy/python
# randomness start from a fixed state (full determinism also depends on
# CUDA/cuDNN settings).
# NOTE(review): some albumentations versions draw from their own generator;
# confirm augmentation determinism for the installed version.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Train-time augmentation: random 90-degree rotations, horizontal/vertical flips,
# and dropout of up to 3 rectangular patches of at most 32x32.
# NOTE(review): newer albumentations releases deprecate max_holes/max_height/max_width
# in favour of *_range arguments — update if the installed version warns.
AUGMENTATION_TRAIN = A.Compose([
    A.RandomRotate90(),
    A.HorizontalFlip(),
    A.VerticalFlip(),
    A.CoarseDropout(max_height=32, max_width=32, max_holes=3)
])

# Where the learner writes this experiment's outputs (passed as experiment_dir).
EXPERIMENT_DIR = '../out/0725-more-data'

### Create Raster Vision datasets

In [None]:
from torch.utils.data import ConcatDataset

from utils.rastervision_utils import observation_to_scene, scene_to_training_ds, scene_to_validation_ds
from utils.data_management import observation_factory
from utils.schemas import ObservationPointer

from project_config import S2_CHANNELS


def is_validation(scene, sites=None) -> bool:
    """Return True if ``scene`` belongs to a held-out validation site.

    A scene counts as validation when any site name occurs as a substring of
    its ``scene.id``.

    Args:
        scene: Object exposing a string ``id`` attribute.
        sites: Iterable of site-name substrings to match against ``scene.id``.
            Defaults to the notebook-level ``VALIDATION_SITES``.

    Returns:
        True if at least one site name is contained in ``scene.id``.
    """
    if sites is None:
        sites = VALIDATION_SITES
    # A generator (not a list) lets `any` stop at the first matching site.
    return any(site in scene.id for site in sites)


def is_training(scene, sites=None) -> bool:
    """Return True if ``scene`` is not a validation scene (complement of ``is_validation``).

    Args:
        scene: Object exposing a string ``id`` attribute.
        sites: Forwarded to ``is_validation``; defaults to ``VALIDATION_SITES``.
    """
    return not is_validation(scene, sites)


# Load every observation from GCS and convert it to a scene over the S2 channels.
all_observations = observation_factory(gcp_client)
all_scenes = [
    observation_to_scene(observation, S2_CHANNELS)
    for observation in all_observations
]
assert all_scenes, "observation_factory returned no observations"

# Split by site. Materialise as lists: `filter` objects are one-shot iterators,
# so reusing these names in a later cell would silently yield nothing.
training_scenes = [scene for scene in all_scenes if is_training(scene)]
validation_scenes = [scene for scene in all_scenes if is_validation(scene)]
# ConcatDataset rejects an empty list with an opaque assertion; fail early and clearly.
assert training_scenes, "No training scenes: every scene matched VALIDATION_SITES"
assert validation_scenes, f"No scene id contains any of {VALIDATION_SITES}"

training_datasets = [
    scene_to_training_ds(scene, TILE_SIZE, AUGMENTATION_TRAIN)
    for scene in training_scenes
]
validation_datasets = [
    scene_to_validation_ds(scene, TILE_SIZE)
    for scene in validation_scenes
]
# Every scene must land in exactly one split.
assert len(training_datasets) + len(validation_datasets) == len(all_scenes)

train_dataset_merged = ConcatDataset(training_datasets)
val_dataset_merged = ConcatDataset(validation_datasets)


In [None]:
from utils.visualizing import show_windows, show_image_in_dataset
from project_config import DISPLAY_GROUPS, CLASS_CONFIG

# Sanity-check one validation dataset: draw its sampling windows over the full
# raster, then display sample(s) from the dataset with the project's class and
# display-group configuration.
# NOTE(review): the hard-coded index 2 requires at least three validation datasets.
ds_to_visualize = validation_datasets[2]
vis_scene = ds_to_visualize.scene
show_windows(vis_scene.raster_source[:, :], ds_to_visualize.windows)
show_image_in_dataset(ds_to_visualize, CLASS_CONFIG, DISPLAY_GROUPS)

### Train

In [None]:
from utils.rastervision_utils import construct_semantic_segmentation_learner
from models.unet.unet_small import UNetSmall

# Input channel count is read from the last axis of the first training raster's
# shape — assumes (height, width, channels); TODO confirm.
_, _, n_channels = training_datasets[0].scene.raster_source.shape
n_classes = 2
# Per-class loss weights; the second class is up-weighted 10x (presumably to
# counter class imbalance). Must stay in sync with n_classes.
class_loss_weights = [1., 10.]
assert len(class_loss_weights) == n_classes, "need exactly one loss weight per class"
model = UNetSmall(n_channels, n_classes)

learner = construct_semantic_segmentation_learner(
    model=model,
    training_ds=train_dataset_merged,
    validation_ds=val_dataset_merged,
    batch_size=64,
    learning_rate=3e-2,
    class_loss_weights=class_loss_weights,
    experiment_dir=EXPERIMENT_DIR
)
learner.log_data_stats()

In [None]:
# Train the learner built above for 10 epochs. Re-running this cell continues from
# the current model weights rather than re-initialising them.
learner.train(epochs=10)

### Predict on validation sites
Predict mine probability for each validation scene and log side-by-side RGB / label / prediction figures to TensorBoard.

In [None]:
from utils.rastervision_utils import get_predictions_for_site
from utils.visualizing import show_rgb_labels_preds

# Log one RGB / ground-truth / prediction figure per validation scene to TensorBoard
# under the 'Predictions' tag, using the scene's index as the global step.
for idx, ds in enumerate(validation_datasets):
    predicted_mine_probability = get_predictions_for_site(learner, ds)
    fig = show_rgb_labels_preds(
        ds.scene.raster_source[:, :],
        ds.scene.label_source[:, :],
        predicted_mine_probability,
        ds.scene.id,
    )
    # close=True closes the matplotlib figure after logging so figures don't pile up.
    learner.tb_writer.add_figure('Predictions', fig, idx, close=True)
# The writer (presumably a torch SummaryWriter) buffers events and flushes only
# periodically; flush now so the figures show up in TensorBoard immediately.
learner.tb_writer.flush()
    