In [None]:
import sys
import os

# Make the parent of the current working directory importable so that the
# packages imported further down this notebook can be resolved.
_project_root = os.path.abspath('..')
# Keep the cell idempotent: re-running it must not pile up duplicate entries.
# Drop any existing occurrence first, then put the directory at the front so
# it keeps the same import priority the plain insert(0, ...) gave it.
sys.path[:] = [entry for entry in sys.path if entry != _project_root]
sys.path.insert(0, _project_root)
from os.path import expanduser

%load_ext autoreload
%autoreload 2

from dotenv import load_dotenv
# Load variables from a .env file into the process environment.
# NOTE(review): this runs before the project_config import below, presumably
# because project_config reads environment variables at import time — confirm
# before reordering these statements.
load_dotenv()

from google.cloud import storage
from project_config import GCP_PROJECT_NAME

# Google Cloud Storage client for the configured project; it is handed to
# observation_factory when the datasets are created.
gcp_client = storage.Client(project=GCP_PROJECT_NAME)

### Configuration

In [None]:
# NOTE(review): the star import places every public name of
# experiment_configs.configs into the notebook namespace; satmae_large_config
# is presumably one of them. Importing the chosen config explicitly would make
# its origin obvious — confirm nothing else relies on the other names first.
from experiment_configs.configs import *
# Experiment configuration used by all later cells (scene/dataset creation,
# model and optimizer factories, learner and predictor).
config = satmae_large_config

### Create Raster Vision datasets

In [None]:
from torch.utils.data import ConcatDataset

from utils.rastervision_pipeline import observation_to_scene, scene_to_training_ds, scene_to_validation_ds, warn_if_nan_in_raw_raster
from utils.data_management import observation_factory

from project_config import is_training, is_validation

# Fetch every known observation and sort the corresponding scenes into a
# training list and a validation list, based on the observation name.
all_observations = observation_factory(gcp_client)
training_scenes = []
validation_scenes = []

for observation in all_observations:
    is_train = is_training(observation.name)
    is_val = is_validation(observation.name)

    # The two splits must be disjoint, otherwise validation metrics would be
    # computed on data the model was trained on.
    assert not (is_train and is_val), "An observation cannot be in both training and validation"

    if not (is_train or is_val):
        # Skip before building the scene: it would only be thrown away.
        print(f"Ignoring observation {observation.name}")
        continue

    # Only build a scene for observations that are actually used.
    scene = observation_to_scene(config, observation)
    if is_train:
        training_scenes.append(scene)
    else:
        validation_scenes.append(scene)


#all_scenes = training_scenes + validation_scenes
#for scene in all_scenes:
#    warn_if_nan_in_raw_raster(scene.raster_source)

# One dataset per scene. Random window sampling happens while the training
# datasets are created.
training_datasets = [
    scene_to_training_ds(config, train_scene)
    for train_scene in training_scenes
]
validation_datasets = [
    scene_to_validation_ds(config, val_scene)
    for val_scene in validation_scenes
]

# Concatenate the per-scene datasets into one dataset per split.
train_dataset_merged = ConcatDataset(training_datasets)
val_dataset_merged = ConcatDataset(validation_datasets)


## Visualize the datasets

In [None]:
from utils.visualizing import visualize_dataset

# Visualize every per-scene dataset: all training datasets first, followed by
# all validation datasets.
for ds in [*training_datasets, *validation_datasets]:
    visualize_dataset(ds)

# Train

In [None]:
from models.model_factory import model_factory
from ml.optimizer_factory import optimizer_factory
from ml.learner import BinarySegmentationLearner

# The raster shape of the first training scene is unpacked as a 3-tuple; its
# last entry is used as the number of input channels for the model.
raster_shape = training_datasets[0].scene.raster_source.shape
_, _, n_channels = raster_shape

model = model_factory(config, n_channels=n_channels)
optimizer = optimizer_factory(config, model)

learner_output_dir = expanduser("~/sandmining-watch/out/OUTPUT_DIR")

# For development and debugging, pass a single dataset instead of the merged
# ones to speed things up (e.g. training_datasets[0] as train_ds and
# training_datasets[1] or similar as valid_ds).
learner = BinarySegmentationLearner(
    config=config,
    model=model,
    optimizer=optimizer,
    train_ds=train_dataset_merged,
    valid_ds=val_dataset_merged,
    output_dir=learner_output_dir,
)

learner.log_data_stats()

#### Check GPU Activity

You can continuously monitor GPU activity by running the following command in a terminal:


`watch -d -n 0.5 nvidia-smi`

In [None]:
# Optional: run this cell to log the training run to Weights & Biases (W&B).
# You might need to authenticate to W&B first.
# NOTE: the "Log results to Weights & Biases" cell further down asserts that a
# W&B run is active, so it will fail if no run has been started.
learner.initialize_wandb_run()

In [None]:
# Train the learner configured above for 20 epochs.
learner.train(epochs=20)

## Evaluate 

Initialize `evaluation_datasets` and the `predictor`.
`evaluation_datasets` and `validation_datasets` are built from the same scenes, but use different sliding-window configurations.

In [None]:
from ml.learner import BinarySegmentationPredictor
from utils.rastervision_pipeline import scene_to_inference_ds

# One full-image inference dataset per validation scene.
evaluation_datasets = [
    scene_to_inference_ds(config, val_scene, full_image=True)
    for val_scene in validation_scenes
]

# Predictor that wraps the in-memory model from the training section.
predictor = BinarySegmentationPredictor(config, model)

# # Alternatively: specify path to trained weights
# path_to_weights = expanduser("~/sandmining-watch/out/1102-satmae-1/last-model.pth")
# predictor = BinarySegmentationPredictor(
#     config,
#     model,
#     path_to_weights,
# )

In [None]:
from ml.eval_utils import evaluate_predicitions, make_wandb_segmentation_masks, make_wandb_predicted_probs_images
from utils.visualizing import raster_source_to_rgb

# Collect, for every evaluation dataset, the predicted mine probabilities
# together with what is needed to score and display them.
prediction_results_list = []

for ds in evaluation_datasets:
    predictions = predictor.predict_mine_probability_for_site(ds)
    rgb_img = raster_source_to_rgb(ds.scene.raster_source)
    ground_truth = ds.scene.label_source.get_label_arr()

    site_result = dict(
        predictions=predictions,
        ground_truth=ground_truth,
        rgb_img=rgb_img,
        name=ds.scene.id,
    )
    prediction_results_list.append(site_result)

# Compute the evaluation results over all sites at once.
evaluation_results_dict = evaluate_predicitions(prediction_results_list)

Log results to Weights & Biases (requires an active W&B run).

In [None]:
import wandb

# Logging needs an active W&B run. Starting one is optional earlier in the
# notebook, so fail with an actionable message instead of a bare
# AssertionError when it was skipped.
assert wandb.run is not None, (
    "No active W&B run: run the cell that calls learner.initialize_wandb_run() "
    "before logging evaluation results."
)

# Add lists of W&B images to the results dict.
# NOTE(review): the keys below contain typos ('Segmenation', 'probabilites').
# They are left unchanged because renaming them would change the names under
# which the images are logged — confirm before fixing.
evaluation_results_dict.update({
    'Segmenation masks': make_wandb_segmentation_masks(prediction_results_list),
    'Predicted probabilites': make_wandb_predicted_probs_images(prediction_results_list),
})

# Log the evaluation results and the images to the active W&B run.
wandb.log(evaluation_results_dict)

In [None]:
# Finish the active W&B run once everything has been logged.
wandb.finish()