In [None]:
import sys
import os
# Put the parent of the current working directory (normally the notebook's
# folder) first on the import path so project packages such as `utils`, `ml`
# and `models` resolve. Guarded so re-running this cell does not keep stacking
# duplicate entries onto sys.path.
if sys.path[:1] != [os.path.abspath('..')]:
    sys.path.insert(0, os.path.abspath('..'))

# Re-import edited project modules (utils, ml, models, ...) automatically before
# each cell runs, so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from google.cloud import storage
from project_config import GCP_PROJECT_NAME

# One Cloud Storage client for the whole notebook, bound to the project's GCP
# project; it is handed to observation_factory when the scenes are built.
gcp_client = storage.Client(
    project=GCP_PROJECT_NAME,
)

### Configuration

In [None]:
from experiment_configs.satmae_ft_config import satmae_ft_config as config
# Alternatively, construct new config = SupervisedTrainingConfig(...)

# Scenes whose id contains any of these site names are held out for validation;
# every remaining scene is used for training.
VALIDATION_SITES = [
    "Ken_Banda",
    "Sone_Rohtas",
]

### Create Raster Vision datasets

In [None]:
from torch.utils.data import ConcatDataset

from utils.rastervision_pipeline import observation_to_scene, scene_to_training_ds, scene_to_validation_ds, warn_if_nan_in_raw_raster
from utils.data_management import observation_factory

def is_validation(scene, validation_sites=None):
    """Return True if `scene` belongs to a held-out validation site.

    A scene counts as validation when its ``id`` contains any of the site
    names as a substring.

    Args:
        scene: Object with a string ``id`` attribute (a Raster Vision scene here).
        validation_sites: Iterable of site-name substrings. When None, the
            notebook-level ``VALIDATION_SITES`` is used; it is looked up at call
            time, so edits to that constant apply without redefining this function.

    Returns:
        bool: True if any site name occurs in ``scene.id``.
    """
    sites = VALIDATION_SITES if validation_sites is None else validation_sites
    # A generator (rather than a list) lets any() stop at the first matching site.
    return any(site in scene.id for site in sites)


def is_training(scene, validation_sites=None):
    """Return True if `scene` is used for training, i.e. it is not a validation scene.

    Args:
        scene: Object with a string ``id`` attribute.
        validation_sites: Forwarded to ``is_validation``; None means ``VALIDATION_SITES``.

    Returns:
        bool: The negation of ``is_validation(scene, validation_sites)``.
    """
    return not is_validation(scene, validation_sites)


import warnings

all_observations = observation_factory(gcp_client)
all_scenes = [observation_to_scene(config, observation) for observation in all_observations]

# Optional sanity check for NaNs in the raw rasters. Off by default because it
# presumably has to read every raster; set to True when debugging data issues.
CHECK_RAW_RASTERS_FOR_NAN = False
if CHECK_RAW_RASTERS_FOR_NAN:
    for scene_to_check in all_scenes:
        warn_if_nan_in_raw_raster(scene_to_check.raster_source)

# Split by site; scene order is preserved within each split.
training_scenes = [scene for scene in all_scenes if is_training(scene)]
validation_scenes = [scene for scene in all_scenes if is_validation(scene)]

# A misspelled site name would otherwise silently shrink the validation set.
unmatched_sites = [
    site for site in VALIDATION_SITES
    if not any(site in scene.id for scene in validation_scenes)
]
if unmatched_sites:
    warnings.warn(
        f"No scene id contains validation site(s) {unmatched_sites}; "
        "check VALIDATION_SITES."
    )

# ConcatDataset rejects an empty list with an opaque assertion, so fail early
# with a message that points at the actual cause.
assert training_scenes, "No training scenes: every scene id matched VALIDATION_SITES."
assert validation_scenes, f"No validation scenes: no scene id contains any of {VALIDATION_SITES}."

training_datasets = [scene_to_training_ds(config, scene) for scene in training_scenes]
validation_datasets = [scene_to_validation_ds(config, scene) for scene in validation_scenes]
assert len(training_datasets) + len(validation_datasets) == len(all_scenes)

train_dataset_merged = ConcatDataset(training_datasets)
val_dataset_merged = ConcatDataset(validation_datasets)


In [None]:
from utils.visualizing import show_windows

# Show the first training dataset's windows against its scene's full raster.
ds_to_visualize = training_datasets[0]
show_windows(ds_to_visualize.scene.raster_source[:], ds_to_visualize.windows)

# Train

In [None]:
import torch
# Use the GPU when one is visible to this kernel, otherwise fall back to the
# CPU, and display the choice.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
# torch.cuda.empty_cache()

In [None]:
from ml.rv_ml import construct_semantic_segmentation_learner
from models.model_selection import get_model

# The model's input width must match the imagery: read the band count from the
# third axis of the first training raster's shape.
_, _, n_channels = training_datasets[0].scene.raster_source.shape
model = get_model(config, n_channels=n_channels)

# For development and debugging, swap the merged datasets below for a single
# one (e.g. training_datasets[0] / training_datasets[1]) to speed things up.
learner = construct_semantic_segmentation_learner(
    model=model,
    training_ds=train_dataset_merged,
    validation_ds=val_dataset_merged,
    batch_size=config.batch_size,
    learning_rate=config.learning_rate,
    experiment_dir=config.output_dir,
    # class_loss_weights=[1., 10.],
)
learner.log_data_stats()

## Check GPU Activity

You can continuously monitor GPU activity by running the following command in a terminal:


`watch -d -n 0.5 nvidia-smi`

In [None]:
# Number of epochs to train in this cell.
# NOTE(review): re-running this cell presumably resumes from the last logged
# epoch and trains N_EPOCHS *more* (check learner.get_start_epoch()), so it is
# not idempotent.
N_EPOCHS = 3
learner.train(epochs=N_EPOCHS)
learner.save_model_bundle()

## Evaluate 

In [None]:
from utils.rastervision_pipeline import scene_to_prediction_ds

# Build one SlidingWindowGeoDataset per validation scene for prediction.
# It reads the same underlying data as the matching validation dataset;
# only the sliding-window configuration differs.
prediction_ds = [scene_to_prediction_ds(config, scene) for scene in validation_scenes]

In [None]:
from utils.rastervision_pipeline import get_predictions_for_site
from utils.visualizing import show_rgb_labels_preds

# Log one RGB / labels / predictions figure per validation site to TensorBoard.

# Fail before the (expensive) predictions if there is nowhere to log to.
# NOTE(review): tb_writer is presumably None when TensorBoard logging is disabled.
if learner.tb_writer is None:
    raise RuntimeError(
        "learner.tb_writer is None; enable TensorBoard logging to record prediction figures."
    )

# Loop-invariant tag. Raster Vision's Learner.get_start_epoch() returns the last
# completed (0-indexed) epoch + 1, which equals the number of completed epochs.
# NOTE(review): confirm against the installed rastervision version.
n_trained_epochs = learner.get_start_epoch()
figure_tag = f"Predictions after {n_trained_epochs} epochs"

for idx, ds in enumerate(prediction_ds):
    print(ds.scene.id)
    predicted_mine_probability = get_predictions_for_site(learner, ds, crop_sz=10)
    fig = show_rgb_labels_preds(
        ds.scene.raster_source[:],
        ds.scene.label_source[:],
        predicted_mine_probability,
        ds.scene.id,
        show=False,
    )
    # idx is the global step, one per site; close=True frees the figure once logged.
    learner.tb_writer.add_figure(figure_tag, fig, idx, close=True)

# Push the figures to disk so a running TensorBoard picks them up immediately.
learner.tb_writer.flush()

# Tensorboard

To launch TensorBoard, first activate your conda environment in a terminal:

`conda activate rastervision` (or whatever your conda environment is called)

Then start TensorBoard with `tensorboard --logdir sandmining-watch/out` (make sure the `--logdir` path is correct relative to your terminal's current working directory).

Because the notebook runs on a remote server, you will also need to forward TensorBoard's port (6006); on your local machine, run:

`ssh -N -f -L localhost:6006:localhost:6006 <USERNAME>@fati.ischool.berkeley.edu`

Note: port forwarding is not needed if you are running from within VS Code — it should automatically offer to open TensorBoard in the browser.