In [None]:
import sys
import os
# Put the parent directory first on the import path so project packages
# (project_config, models, utils, ml, experiment_configs) can be imported.
# Assumes the notebook is run from a direct subdirectory of the project root.
sys.path.insert(0, os.path.abspath('..'))

# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
from google.cloud import storage
from project_config import GCP_PROJECT_NAME

# One Cloud Storage client, bound to the project's GCP project, shared by the
# data-loading cells below (it is passed to observation_factory).
gcp_client = storage.Client(
    project=GCP_PROJECT_NAME,
)

In [None]:
from models.satmae.util import lr_decay as lrd
from models.satmae.satmae_encoder_custom_decoder.satmae_encoder_linear_decoder import (
    SatMaeSegmenterWithLinearDecoder,
)

# Build layer-wise learning-rate-decay parameter groups for the SatMAE encoder
# (weight decay 0.05, layer decay 0.75; positional/class/distillation tokens
# are excluded from weight decay).
# NOTE(review): `param_groups` is not used anywhere else in this notebook and
# `model` is reassigned in the training cell; this looks like a leftover
# sanity check of lrd.param_groups_lrd — confirm before relying on it.
model = SatMaeSegmenterWithLinearDecoder()
param_groups = lrd.param_groups_lrd(
    model.encoder,
    0.05,
    no_weight_decay_list={'pos_embed', 'cls_token', 'dist_token'},
    layer_decay=0.75,
)


### Configuration

In [None]:
from experiment_configs.unet_fs_config import unet_fs_config
from experiment_configs.satmae_ft_config import (
    satmae_ft_doubleupsampling_config,
    satmae_ft_lineardecoder_config,
)

# Experiment to run; the other imported configs are alternatives to swap in here.
config = satmae_ft_doubleupsampling_config

# Any scene whose id contains one of these site names is held out for validation.
VALIDATION_SITES = [
    "Ken_Banda",
    "Sone_Rohtas",
]

### Create Raster Vision datasets

In [None]:
from torch.utils.data import ConcatDataset

from utils.rastervision_pipeline import observation_to_scene, scene_to_training_ds, scene_to_validation_ds, warn_if_nan_in_raw_raster
from utils.data_management import observation_factory

def is_validation(scene, validation_sites=None):
    """Return True if ``scene`` belongs to a held-out validation site.

    A scene counts as validation when any site name is a substring of its id.

    Args:
        scene: Any object with a string ``id`` attribute (here: the scenes
            produced by ``observation_to_scene``).
        validation_sites: Iterable of site-name substrings. Defaults to the
            notebook-level ``VALIDATION_SITES``. The default is looked up at
            call time, so edits to the config cell take effect on re-run.

    Returns:
        bool: True for validation scenes, False otherwise.
    """
    sites = VALIDATION_SITES if validation_sites is None else validation_sites
    # Generator (not a list) so `any` can stop at the first matching site.
    return any(site in scene.id for site in sites)


def is_training(scene, validation_sites=None):
    """Return True for every scene that ``is_validation`` rejects.

    Args:
        scene: Any object with a string ``id`` attribute.
        validation_sites: Forwarded to ``is_validation``; defaults to
            ``VALIDATION_SITES``.

    Returns:
        bool: True for training scenes, False for validation scenes.
    """
    return not is_validation(scene, validation_sites)


# Set to True to run warn_if_nan_in_raw_raster on every scene's raw raster
# before building datasets.
CHECK_RASTERS_FOR_NAN = False

all_observations = observation_factory(gcp_client)
all_scenes = [observation_to_scene(config, observation) for observation in all_observations]

if CHECK_RASTERS_FOR_NAN:
    for scene_to_check in all_scenes:
        warn_if_nan_in_raw_raster(scene_to_check.raster_source)

training_scenes = [scene for scene in all_scenes if is_training(scene)]
validation_scenes = [scene for scene in all_scenes if is_validation(scene)]

# Fail early with a readable message instead of letting ConcatDataset reject an
# empty list below (e.g. after a typo in VALIDATION_SITES).
assert training_scenes, "No training scenes: every scene matched VALIDATION_SITES."
assert validation_scenes, f"No scene id contains any of VALIDATION_SITES={VALIDATION_SITES}."

training_datasets = [scene_to_training_ds(config, scene) for scene in training_scenes]
validation_datasets = [scene_to_validation_ds(config, scene) for scene in validation_scenes]
# Every scene ends up in exactly one split.
assert len(training_datasets) + len(validation_datasets) == len(all_scenes)

train_dataset_merged = ConcatDataset(training_datasets)
val_dataset_merged = ConcatDataset(validation_datasets)


In [None]:
from utils.visualizing import show_windows

# Index into training_datasets of the dataset whose sampling windows are plotted.
VISUALIZE_DS_INDEX = 11
assert 0 <= VISUALIZE_DS_INDEX < len(training_datasets), (
    f"Only {len(training_datasets)} training datasets exist; "
    f"choose a smaller VISUALIZE_DS_INDEX (currently {VISUALIZE_DS_INDEX})."
)

# Select explicitly from training_datasets; do not rely on loop variables
# (such as `ds`) that only exist after later cells have run.
ds_to_visualize = training_datasets[VISUALIZE_DS_INDEX]
show_windows(
    ds_to_visualize.scene.raster_source[:],
    ds_to_visualize.windows,
)

# Train

In [None]:
# torch.cuda.empty_cache()

In [None]:
from ml.rv_ml import construct_semantic_segmentation_learner
from models.model_factory import model_factory

# True: quick development/debugging run on one training scene and one
# held-out validation scene. False: train on all training scenes and validate
# on all validation scenes (train_dataset_merged / val_dataset_merged).
USE_SINGLE_SCENE_SUBSET = True

# The last dimension of the raster shape is the channel count the model needs.
_, _, n_channels = training_datasets[0].scene.raster_source.shape
model = model_factory(
    config,
    n_channels=n_channels,
)

if USE_SINGLE_SCENE_SUBSET:
    # Validate on a held-out site even in debug mode: validating on a training
    # scene would report leaked, over-optimistic metrics.
    learner_training_ds = training_datasets[0]
    learner_validation_ds = validation_datasets[0]
else:
    learner_training_ds = train_dataset_merged
    learner_validation_ds = val_dataset_merged

learner = construct_semantic_segmentation_learner(
    config=config,
    model=model,
    training_ds=learner_training_ds,
    validation_ds=learner_validation_ds,
)
learner.log_data_stats()

## Check GPU Activity

You can continuously monitor GPU activity by running the following command in a terminal:


`watch -d -n 0.5 nvidia-smi`

In [None]:
# Run this cell if you want to log the run to W&B. You might need to authenticate to W&B.
# The evaluation cell at the end of the notebook asserts `wandb.run is not None`;
# this cell is presumably what starts that run, so run it before evaluating.
learner.initialize_wandb_run()

In [None]:
N_EPOCHS = 10

# Train for a fixed number of epochs, then persist the model bundle.
learner.train(epochs=N_EPOCHS)
learner.save_model_bundle()

## Evaluate 

Create `SlidingWindowGeoDataset`s on the validation scenes for prediction.

In [None]:
from utils.rastervision_pipeline import scene_to_prediction_ds

# Build one prediction dataset per validation scene. A prediction dataset
# differs from the corresponding validation dataset only in its sliding-window
# configuration; the underlying scene data is the same.
prediction_ds = [
    scene_to_prediction_ds(config, validation_scene)
    for validation_scene in validation_scenes
]

Run inference on each prediction (validation) site and log the segmentation images to W&B. This requires an active W&B run.

In [None]:
import wandb
from utils.wandb_utils import create_semantic_segmentation_image
from ml.rv_ml import predict_class_for_site

# Raster channels used as the RGB background of each result image.
# TODO: confirm against the raster band order. slice(1, 4) selects channels
# 1, 2, 3, which is only true RGB if the source stores them as R, G, B.
RGB_CHANNELS = slice(1, 4)

assert wandb.run is not None, "No active W&B run: run learner.initialize_wandb_run() first."

segmentation_result_images = []
for pred_ds in prediction_ds:
    pred_scene = pred_ds.scene
    print(pred_scene.id)
    rgb_img = pred_scene.raster_source.get_image_array()[:, :, RGB_CHANNELS]
    predicted_mask = predict_class_for_site(learner, pred_ds, crop_sz=0)
    ground_truth_mask = pred_scene.label_source.get_label_arr()
    segmentation_result_images.append(
        create_semantic_segmentation_image(
            rgb_img, predicted_mask, ground_truth_mask, pred_scene.id
        )
    )

wandb.log({'Segmentation results': segmentation_result_images})