In [1]:
import sys
import os
# Put the parent directory of the working directory first on the module search path.
# NOTE(review): presumably this makes project packages such as `project_config`,
# `utils`, `models` and `ml` importable from this notebook — confirm against the repo layout.
sys.path.insert(0, os.path.abspath('..'))
from os.path import expanduser

%load_ext autoreload
%autoreload 2

from dotenv import load_dotenv
# Load environment variables from a .env file (if one is found) before any
# configuration is imported.
load_dotenv()

from google.cloud import storage
from project_config import GCP_PROJECT_NAME, DATASET_JSON_PATH

# Google Cloud Storage client bound to the project named in project_config.
gcp_client = storage.Client(project=GCP_PROJECT_NAME)



### Configuration

In [2]:
# Import the experiment configuration explicitly instead of `import *`, so the
# notebook namespace is not polluted and the origin of the name is obvious.
# To run a different experiment, import another config from
# `experiment_configs.configs` and assign it to `config`.
from experiment_configs.configs import satmae_large_config

# Experiment configuration used by every subsequent cell.
config = satmae_large_config

### Create Rastervision datasets

In [3]:
from torch.utils.data import ConcatDataset
import json
from utils.rastervision_pipeline import observation_to_scene, scene_to_training_ds, scene_to_validation_ds, warn_if_nan_in_raw_raster
from utils.data_management import observation_factory

# Resolve the dataset JSON file relative to the notebook's working directory.
root_dir = os.getcwd()
json_rel_path = '../' + DATASET_JSON_PATH
json_abs_path = os.path.join(root_dir, json_rel_path)

# Read the dataset description inside a context manager so the file handle is
# closed deterministically instead of being left open until garbage collection.
with open(json_abs_path, 'r') as dataset_file:
    dataset_json = json.load(dataset_file)
all_observations = observation_factory(dataset_json)

training_scenes = []
validation_scenes = []

for observation in all_observations:
    scene = observation_to_scene(config, observation)

    # Statically assign cluster zero to the validation set; every other
    # cluster is used for training.
    if observation.cluster_id == 0:
        validation_scenes.append(scene)
    else:
        training_scenes.append(scene)


# One dataset per scene.
training_datasets = [
    scene_to_training_ds(config, scene) for scene in training_scenes  # random window sampling happens here
]
validation_datasets = [
    scene_to_validation_ds(config, scene) for scene in validation_scenes
]

# Merge the per-scene datasets into a single dataset for each split.
train_dataset_merged = ConcatDataset(training_datasets)
val_dataset_merged = ConcatDataset(validation_datasets)


2023-11-17 10:36:30:rastervision.pipeline.file_system.utils: INFO - Using cached file /data/tmp/cache/http/storage.googleapis.com/sand_mining_median/labels/Kathajodi_Cuttack_85-85_20-44_median/s2/Kathajodi_Cuttack_85-85_20-44_2022-05-01_s2.tif.
2023-11-17 10:36:30:rastervision.pipeline.file_system.utils: INFO - Using cached file /data/tmp/cache/http/storage.googleapis.com/sand_mining_median/labels/Kathajodi_Cuttack_85-85_20-44_median/annotations/Kathajodi_Cuttack_85-85_20-44_2022-05-01_annotations.geojson.
2023-11-17 10:36:30:rastervision.pipeline.file_system.utils: INFO - Using cached file /data/tmp/cache/http/storage.googleapis.com/sand_mining_median/labels/Kathajodi_Cuttack_85-85_20-44_median/rivers/Kathajodi_Cuttack_85-85_20-44_rivers_1000m.geojson.
2023-11-17 10:36:30:rastervision.core.data.vector_source.geojson_vector_source: INFO - Ignoring CRS ({'type': 'name', 'properties': {'name': 'urn:ogc:def:crs:OGC:1.3:CRS84'}}) specified in https://storage.googleapis.com/sand_mining_medi

## Visualize the datasets

In [7]:
from utils.visualizing import visualize_dataset

# Show every training dataset first, then every validation dataset.
for ds in training_datasets + validation_datasets:
    visualize_dataset(ds)

# Train

In [8]:
from models.model_factory import model_factory
from ml.optimizer_factory import optimizer_factory
from ml.learner import BinarySegmentationLearner

# Infer the number of input channels from the first training scene.
# The shape is read by index rather than unpacked into `_`: in IPython/Jupyter
# `_` holds the previous cell output, and assigning to it in user code stops
# IPython from updating it for the rest of the session.
raster_shape = training_datasets[0].scene.raster_source.shape
if len(raster_shape) != 3:
    # Same failure mode (ValueError) as the former three-way tuple unpacking.
    raise ValueError(
        f"Expected a 3-dimensional raster shape with channels last, got {raster_shape}"
    )
n_channels = raster_shape[-1]

model = model_factory(
    config,
    n_channels=n_channels,
)

optimizer = optimizer_factory(config, model)

learner = BinarySegmentationLearner(
    config=config,
    model=model,
    optimizer=optimizer,
    train_ds=train_dataset_merged,  # for development and debugging, use training_datasets[0] or similar to speed up
    valid_ds=val_dataset_merged,  # for development and debugging, use training_datasets[1] or similar to speed up
    output_dir=expanduser("~/sandmining-watch/out/OUTPUT_DIR"),
)

# Log the sizes of the training and validation datasets.
learner.log_data_stats()

SatMae: Loading encoder weights from /data/sand_mining/checkpoints/satmae_orig/pretrain-vit-large-e199.pth
Position interpolate from 12x12 to 20x20
_IncompatibleKeys(missing_keys=['channel_cls_embed', 'head.weight', 'head.bias'], unexpected_keys=['mask_token', 'decoder_pos_embed', 'decoder_channel_embed', 'decoder_embed.weight', 'decoder_embed.bias', 'decoder_blocks.0.norm1.weight', 'decoder_blocks.0.norm1.bias', 'decoder_blocks.0.attn.qkv.weight', 'decoder_blocks.0.attn.qkv.bias', 'decoder_blocks.0.attn.proj.weight', 'decoder_blocks.0.attn.proj.bias', 'decoder_blocks.0.norm2.weight', 'decoder_blocks.0.norm2.bias', 'decoder_blocks.0.mlp.fc1.weight', 'decoder_blocks.0.mlp.fc1.bias', 'decoder_blocks.0.mlp.fc2.weight', 'decoder_blocks.0.mlp.fc2.bias', 'decoder_blocks.1.norm1.weight', 'decoder_blocks.1.norm1.bias', 'decoder_blocks.1.attn.qkv.weight', 'decoder_blocks.1.attn.qkv.bias', 'decoder_blocks.1.attn.proj.weight', 'decoder_blocks.1.attn.proj.bias', 'decoder_blocks.1.norm2.weight', 'd

2023-11-17 10:54:40:rastervision: INFO - train_ds: 3060 items
2023-11-17 10:54:40:rastervision: INFO - valid_ds: 408 items


#### Check GPU Activity

You can continuously monitor GPU activity by running the following command in a terminal:


`watch -d -n 0.5 nvidia-smi`

In [None]:
# Optional: run this cell to log the run to Weights & Biases (W&B).
# You might need to authenticate to W&B first.
# The W&B logging cell further down asserts that `wandb.run` is not None, so it
# presumably depends on this cell having been executed — confirm in `ml.learner`.
learner.initialize_wandb_run()

In [9]:
# Train for 20 epochs. This is a long-running cell: in the recorded run below,
# epoch 0 took roughly 5.5 minutes (training plus validation).
learner.train(epochs=20)

2023-11-17 10:54:45:rastervision: INFO - epoch: 0


Training:   0%|          | 0/48 [00:00<?, ?it/s]

Validating:   0%|          | 0/7 [00:00<?, ?it/s]

2023-11-17 11:00:14:rastervision: INFO - metrics:
{'avg_f1': 0.9372423887252808,
 'avg_precision': 0.9520689249038696,
 'avg_recall': 0.9228705167770386,
 'epoch': 0,
 'other_f1': 0.9587695002555847,
 'other_precision': 0.9813424348831177,
 'other_recall': 0.9372116327285767,
 'sandmine_average_precision': tensor(0.4106, device='cuda:0'),
 'sandmine_f1': 0.4035343527793884,
 'sandmine_precision': 0.3027803599834442,
 'sandmine_recall': 0.6047838926315308,
 'train_bce_loss': 0.008986550998064427,
 'train_dice_loss': 0.013842218685773462,
 'train_time': datetime.timedelta(seconds=287, microseconds=850162),
 'val_bce_loss': tensor(0.0061, device='cuda:0'),
 'val_dice_loss': tensor(0.0153, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=41, microseconds=349427)}


## Evaluate 

Initialize `evaluation_datasets` and the predictor.
`evaluation_datasets` and `validation_datasets` are built from the same scenes, but use different sliding-window configurations.

In [10]:
from ml.learner import BinarySegmentationPredictor
from utils.rastervision_pipeline import scene_to_inference_ds

# Build one inference dataset per validation scene (with full_image=True).
evaluation_datasets = [
    scene_to_inference_ds(config, validation_scene, full_image=True)
    for validation_scene in validation_scenes
]

# Predictor that uses the model object currently held in memory.
predictor = BinarySegmentationPredictor(config, model)

# # Alternatively: specify path to trained weights
# path_to_weights = expanduser("~/sandmining-watch/out/1102-satmae-1/last-model.pth")
# predictor = BinarySegmentationPredictor(config, model, path_to_weights)

In [11]:
from ml.eval_utils import evaluate_predicitions, make_wandb_segmentation_masks, make_wandb_predicted_probs_images
from utils.visualizing import raster_source_to_rgb

# Collect, for each evaluation site, the predicted probabilities together with
# the ground-truth labels, an RGB rendering of the raster, and the scene id.
prediction_results_list = []

for ds in evaluation_datasets:
    predictions = predictor.predict_mine_probability_for_site(ds)
    rgb_img = raster_source_to_rgb(ds.scene.raster_source)

    site_result = dict(
        predictions=predictions,
        ground_truth=ds.scene.label_source.get_label_arr(),
        rgb_img=rgb_img,
        name=ds.scene.id,
    )
    prediction_results_list.append(site_result)

# Evaluate the collected per-site results.
evaluation_results_dict = evaluate_predicitions(prediction_results_list)

Log results to Weights & Biases

In [None]:
import wandb

# Logging needs an active W&B run; fail with an actionable message otherwise.
assert wandb.run is not None, (
    "No active W&B run: execute the `learner.initialize_wandb_run()` cell before logging."
)

# Add lists of W&B images to dict.
# NOTE(review): the key spellings are kept exactly as they were so the logged
# W&B keys do not change.
evaluation_results_dict.update({
    'Segmenation masks': make_wandb_segmentation_masks(prediction_results_list),
    'Predicted probabilites': make_wandb_predicted_probs_images(prediction_results_list),
})

# Log to W&B
wandb.log(evaluation_results_dict)

In [None]:
# Mark the active W&B run as finished and upload any remaining data.
wandb.finish()