In [1]:
import sys
import os
sys.path.insert(0, os.path.abspath('..'))
from os.path import expanduser

%load_ext autoreload
%autoreload 2

from dotenv import load_dotenv
# Load variables from a local .env file into the process environment (python-dotenv).
load_dotenv()

from google.cloud import storage
from project_config import GCP_PROJECT_NAME, DATASET_JSON_PATH

# Google Cloud Storage client bound to the project named in project_config.
# NOTE(review): `gcp_client` is not referenced by the other cells shown in this
# notebook -- confirm whether it is still needed.
gcp_client = storage.Client(project=GCP_PROJECT_NAME)



### Configuration

In [2]:
# Import the configs module explicitly instead of using a wildcard import, so it is
# obvious where `config` comes from and the notebook namespace is not polluted.
# To run a different experiment, change only the attribute selected below.
from experiment_configs import configs

config = configs.satmae_large_config

### Create Rastervision datasets

In [6]:
from torch.utils.data import ConcatDataset
import json
from utils.rastervision_pipeline import observation_to_scene, scene_to_training_ds, scene_to_validation_ds, warn_if_nan_in_raw_raster
from utils.data_management import observation_factory, characterize_dataset
import random

# Seed Python's `random` module so the validation cluster drawn below is repeatable.
random.seed(13)

# get the current working directory
root_dir = os.getcwd()

# define the relative path to the dataset JSON file
json_rel_path = '../' + DATASET_JSON_PATH

# combine the root directory with the relative path
json_abs_path = os.path.join(root_dir, json_rel_path)

# Read the dataset description; the context manager closes the file handle promptly.
with open(json_abs_path, 'r') as dataset_file:
    dataset_json = json.load(dataset_file)
all_observations = observation_factory(dataset_json)

# Cluster ids that actually occur in the dataset (sorted for a deterministic draw).
cluster_ids = sorted({observation['cluster_id'] for observation in dataset_json})

#find the highest cluster id
max_cluster_id = max(cluster_ids)

# Randomly pick one existing cluster as the validation split.
# The previous `random.randint(0, max_cluster_id+1)` was inclusive on both ends, so it
# could return `max_cluster_id + 1` (a cluster that does not exist) and leave the
# validation set empty. Drawing from the ids that are present also tolerates gaps in
# the id range. NOTE: for the same seed this may select a different cluster than the
# one recorded in previously saved outputs.
val_split = random.choice(cluster_ids)

training_scenes = []
validation_scenes = []

# Every observation of the held-out cluster goes to validation, all others to training.
for observation in all_observations:
    scene = observation_to_scene(config, observation)
    if observation.cluster_id == val_split:
        validation_scenes.append(scene)
    else:
        training_scenes.append(scene)

# Fail early with a readable message instead of an opaque error further down.
assert validation_scenes, 'No observations found for validation cluster_id {}'.format(val_split)
assert training_scenes, 'No observations left for training after holding out cluster_id {}'.format(val_split)

training_datasets = [
    scene_to_training_ds(config, scene) for scene in training_scenes #random window sampling happens here
]
validation_datasets = [
    scene_to_validation_ds(config, scene) for scene in validation_scenes
]

train_dataset_merged = ConcatDataset(training_datasets)
val_dataset_merged = ConcatDataset(validation_datasets)

print('Validation split cluster_id:', val_split)
print ('Training dataset size: {:4d} images | Number of observations: {:}'.format(len(train_dataset_merged), len(training_scenes)))
print ('Testing dataset size: {:4d}  images | Number of observations: {:}'.format(len(val_dataset_merged), len(validation_scenes)))

# The returned value is used by a later cell to derive the mine-class loss weight.
mine_percentage_aoi = characterize_dataset(training_scenes, validation_scenes)

2023-11-30 18:20:37:rastervision.pipeline.file_system.utils: INFO - Using cached file /data/tmp/cache/http/storage.googleapis.com/sand_mining_median/labels/Kathajodi_Cuttack_85-85_20-44_median/s2/Kathajodi_Cuttack_85-85_20-44_2022-05-01_s2.tif.


2023-11-30 18:20:37:rastervision.pipeline.file_system.utils: INFO - Using cached file /data/tmp/cache/http/storage.googleapis.com/sand_mining_median/labels/Kathajodi_Cuttack_85-85_20-44_median/annotations/Kathajodi_Cuttack_85-85_20-44_2022-05-01_annotations.geojson.
2023-11-30 18:20:37:rastervision.pipeline.file_system.utils: INFO - Using cached file /data/tmp/cache/http/storage.googleapis.com/sand_mining_median/labels/Kathajodi_Cuttack_85-85_20-44_median/rivers/Kathajodi_Cuttack_85-85_20-44_rivers_1000m.geojson.
2023-11-30 18:20:37:rastervision.core.data.vector_source.geojson_vector_source: INFO - Ignoring CRS ({'type': 'name', 'properties': {'name': 'urn:ogc:def:crs:OGC:1.3:CRS84'}}) specified in https://storage.googleapis.com/sand_mining_median/labels/Kathajodi_Cuttack_85-85_20-44_median/rivers/Kathajodi_Cuttack_85-85_20-44_rivers_1000m.geojson and assuming EPSG:4326 instead.
2023-11-30 18:20:37:rastervision.pipeline.file_system.utils: INFO - Using cached file /data/tmp/cache/http/s

Validation split cluster_id: 4
Training dataset size: 3088 images | Number of observations: 72
Testing dataset size:  378  images | Number of observations: 8
Total dataset has 2.93%  mining area.
Training dataset has 3.05%  mining area.
Validation dataset has 2.15%  mining area.
Within AOIs, total dataset has 6.72%  mining area.
Outside AOIs, total dataset has 0.06%  mining area.

The median percentage of mine in an observation is 3.26%
The median number of mine pixels in an observation is 38059

The median number pixels in an observation is 1301008


## Update the loss weights

In [8]:
# Weight the mine class by the ratio of the non-mine percentage to the mine
# percentage returned by characterize_dataset, i.e. (100 - p) / p.
non_mine_percentage_aoi = 100 - mine_percentage_aoi
config.mine_class_loss_weight = non_mine_percentage_aoi / mine_percentage_aoi

# Display the updated config as the cell output.
config

SupervisedFinetuningCofig(model_type=<ModelChoice.SatmaeLargeDoubleUpsampling: 'satmae-large-double-upsampling'>, optimizer=<OptimizerChoice.AdamW: 'adamw'>, tile_size=160, s2_channels=[1, 2, 3, 4, 5, 6, 7, 8, 10, 11], s2_normalization=<NormalizationS2Choice.ChannelWise: 'channelwise'>, loss_fn=<BackpropLossChoice.BCE: 'BCE'>, batch_size=128, learning_rate=0.001, datasets=<DatasetChoice.S2: 's2'>, mine_class_loss_weight=13.880675775871904, finetuning_strategy=<FinetuningStratagyChoice.LinearProbing: 'linear-probing'>, encoder_weights_path='/data/sand_mining/checkpoints/satmae_orig/pretrain-vit-large-e199.pth')

## Visualize the datasets

In [4]:
from utils.visualizing import visualize_dataset

# Print a heading for each group, then visualize every dataset in that group.
dataset_groups = (
    ('Training Dataset', training_datasets),
    ("\n\n\n Test Dataset", validation_datasets),
)

for heading, datasets in dataset_groups:
    print (heading)
    for ds in datasets:
        visualize_dataset(ds)

Training Dataset


# Train

In [11]:
from models.model_factory import model_factory
from ml.optimizer_factory import optimizer_factory
from ml.learner import BinarySegmentationLearner

# Build the model, optimizer and learner from the shared `config`.
# The third entry of the first training dataset's raster-source shape is handed to the
# model factory as the channel count.
_, _, n_channels = training_datasets[0].scene.raster_source.shape
model = model_factory(
    config,
    n_channels=n_channels,
)

optimizer = optimizer_factory(config, model)

# The learner trains on the merged training windows and validates on the merged
# validation windows; its output directory is resolved under the user's home directory.
learner = BinarySegmentationLearner(
    config=config,
    model=model,
    optimizer=optimizer,
    train_ds=train_dataset_merged,  # for development and debugging, use training_datasets[0] or similar to speed up
    valid_ds=val_dataset_merged,  # for development and debugging, use training_datasets[1] or similar to speed up
    output_dir=expanduser("~/sandmining-watch/out/OUTPUT_DIR"),
)

# Log statistics about the training and validation datasets.
learner.log_data_stats()

SatMae: Loading encoder weights from /data/sand_mining/checkpoints/satmae_orig/pretrain-vit-large-e199.pth
Position interpolate from 12x12 to 20x20
_IncompatibleKeys(missing_keys=['channel_cls_embed', 'head.weight', 'head.bias'], unexpected_keys=['mask_token', 'decoder_pos_embed', 'decoder_channel_embed', 'decoder_embed.weight', 'decoder_embed.bias', 'decoder_blocks.0.norm1.weight', 'decoder_blocks.0.norm1.bias', 'decoder_blocks.0.attn.qkv.weight', 'decoder_blocks.0.attn.qkv.bias', 'decoder_blocks.0.attn.proj.weight', 'decoder_blocks.0.attn.proj.bias', 'decoder_blocks.0.norm2.weight', 'decoder_blocks.0.norm2.bias', 'decoder_blocks.0.mlp.fc1.weight', 'decoder_blocks.0.mlp.fc1.bias', 'decoder_blocks.0.mlp.fc2.weight', 'decoder_blocks.0.mlp.fc2.bias', 'decoder_blocks.1.norm1.weight', 'decoder_blocks.1.norm1.bias', 'decoder_blocks.1.attn.qkv.weight', 'decoder_blocks.1.attn.qkv.bias', 'decoder_blocks.1.attn.proj.weight', 'decoder_blocks.1.attn.proj.bias', 'decoder_blocks.1.norm2.weight', 'd

2023-11-30 18:29:40:rastervision: INFO - train_ds: 3088 items
2023-11-30 18:29:40:rastervision: INFO - valid_ds: 378 items


#### Check GPU Activity

You can continuously monitor GPU activity by running the following command in a terminal:


`watch -d -n 0.5 nvidia-smi`

In [12]:
# Run this cell if you want to log the run to W&B. You might need to authenticate to W&B.
# The "Log results to Weights & Biases" cell further down asserts that a W&B run is
# active, so run this cell first if you intend to use that one.
learner.initialize_wandb_run()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mandoshah[0m ([33msandmining-watch[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


In [7]:
# Train the learner for the requested number of epochs.
# NOTE(review): the saved output below logs epochs 0-19, so it does not appear to have
# been produced by this `epochs=1` call -- re-run the cell or update the epoch count so
# that code and output agree.
learner.train(epochs=1)

2023-11-30 07:25:29:rastervision: INFO - epoch: 0


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 07:31:16:rastervision: INFO - metrics:
{'avg_f1': 0.902629554271698,
 'avg_precision': 0.971852719783783,
 'avg_recall': 0.8426119685173035,
 'epoch': 0,
 'other_f1': 0.9115577936172485,
 'other_precision': 0.9988391399383545,
 'other_recall': 0.8383044004440308,
 'sandmine_average_precision': tensor(0.6282, device='cuda:0'),
 'sandmine_f1': 0.28603678941726685,
 'sandmine_precision': 0.16772311925888062,
 'sandmine_recall': 0.9709680676460266,
 'train_bce_loss': 0.0035413133048022966,
 'train_dice_loss': 0.006418412213498446,
 'train_time': datetime.timedelta(seconds=304, microseconds=6428),
 'val_bce_loss': tensor(0.0033, device='cuda:0'),
 'val_dice_loss': tensor(0.0070, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=549796)}
2023-11-30 07:31:18:rastervision: INFO - epoch: 1


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 07:36:54:rastervision: INFO - metrics:
{'avg_f1': 0.9301215410232544,
 'avg_precision': 0.9726501703262329,
 'avg_recall': 0.8911561965942383,
 'epoch': 1,
 'other_f1': 0.9405212998390198,
 'other_precision': 0.99782794713974,
 'other_recall': 0.8894395232200623,
 'sandmine_average_precision': tensor(0.6519, device='cuda:0'),
 'sandmine_f1': 0.35988304018974304,
 'sandmine_precision': 0.2224131077528,
 'sandmine_recall': 0.9423085451126099,
 'train_bce_loss': 0.003252634112699044,
 'train_dice_loss': 0.006419925492044558,
 'train_time': datetime.timedelta(seconds=292, microseconds=765892),
 'val_bce_loss': tensor(0.0027, device='cuda:0'),
 'val_dice_loss': tensor(0.0069, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=556207)}
2023-11-30 07:36:56:rastervision: INFO - epoch: 2


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 07:42:33:rastervision: INFO - metrics:
{'avg_f1': 0.9383898377418518,
 'avg_precision': 0.9737926721572876,
 'avg_recall': 0.9054709076881409,
 'epoch': 2,
 'other_f1': 0.9487325549125671,
 'other_precision': 0.9981120228767395,
 'other_recall': 0.9040085673332214,
 'sandmine_average_precision': tensor(0.6633, device='cuda:0'),
 'sandmine_f1': 0.39466598629951477,
 'sandmine_precision': 0.2491350769996643,
 'sandmine_recall': 0.9490461945533752,
 'train_bce_loss': 0.003279897524285193,
 'train_dice_loss': 0.006372903912796258,
 'train_time': datetime.timedelta(seconds=293, microseconds=242233),
 'val_bce_loss': tensor(0.0025, device='cuda:0'),
 'val_dice_loss': tensor(0.0068, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=433517)}
2023-11-30 07:42:35:rastervision: INFO - epoch: 3


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 07:48:12:rastervision: INFO - metrics:
{'avg_f1': 0.9046080708503723,
 'avg_precision': 0.9719427824020386,
 'avg_recall': 0.8459985852241516,
 'epoch': 3,
 'other_f1': 0.9136262536048889,
 'other_precision': 0.998830258846283,
 'other_recall': 0.8418161869049072,
 'sandmine_average_precision': tensor(0.6489, device='cuda:0'),
 'sandmine_f1': 0.290426641702652,
 'sandmine_precision': 0.17076048254966736,
 'sandmine_recall': 0.9706243872642517,
 'train_bce_loss': 0.003548857152770838,
 'train_dice_loss': 0.006371684642653391,
 'train_time': datetime.timedelta(seconds=293, microseconds=885203),
 'val_bce_loss': tensor(0.0033, device='cuda:0'),
 'val_dice_loss': tensor(0.0069, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=257483)}
2023-11-30 07:48:13:rastervision: INFO - epoch: 4


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 07:53:51:rastervision: INFO - metrics:
{'avg_f1': 0.9316933751106262,
 'avg_precision': 0.9724385738372803,
 'avg_recall': 0.8942252993583679,
 'epoch': 4,
 'other_f1': 0.942314624786377,
 'other_precision': 0.9974810481071472,
 'other_recall': 0.8929304480552673,
 'sandmine_average_precision': tensor(0.6419, device='cuda:0'),
 'sandmine_f1': 0.36414870619773865,
 'sandmine_precision': 0.22623255848884583,
 'sandmine_recall': 0.9328083992004395,
 'train_bce_loss': 0.00326537746221908,
 'train_dice_loss': 0.0064581626437488614,
 'train_time': datetime.timedelta(seconds=293, microseconds=350634),
 'val_bce_loss': tensor(0.0026, device='cuda:0'),
 'val_dice_loss': tensor(0.0069, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=791997)}
2023-11-30 07:53:52:rastervision: INFO - epoch: 5


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 07:59:29:rastervision: INFO - metrics:
{'avg_f1': 0.9202529191970825,
 'avg_precision': 0.9729779958724976,
 'avg_recall': 0.8729484677314758,
 'epoch': 5,
 'other_f1': 0.9297994375228882,
 'other_precision': 0.9989151954650879,
 'other_recall': 0.8696290850639343,
 'sandmine_average_precision': tensor(0.6540, device='cuda:0'),
 'sandmine_f1': 0.3318849205970764,
 'sandmine_precision': 0.20011088252067566,
 'sandmine_recall': 0.9718592166900635,
 'train_bce_loss': 0.0032760948714814653,
 'train_dice_loss': 0.006367186808215522,
 'train_time': datetime.timedelta(seconds=293, microseconds=249040),
 'val_bce_loss': tensor(0.0028, device='cuda:0'),
 'val_dice_loss': tensor(0.0069, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=277926)}
2023-11-30 07:59:31:rastervision: INFO - epoch: 6


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 08:05:08:rastervision: INFO - metrics:
{'avg_f1': 0.9358887076377869,
 'avg_precision': 0.9729173183441162,
 'avg_recall': 0.9015753269195557,
 'epoch': 6,
 'other_f1': 0.9465358257293701,
 'other_precision': 0.9975312948226929,
 'other_recall': 0.9005007743835449,
 'sandmine_average_precision': tensor(0.6127, device='cuda:0'),
 'sandmine_f1': 0.38118064403533936,
 'sandmine_precision': 0.23947924375534058,
 'sandmine_recall': 0.9335945248603821,
 'train_bce_loss': 0.003343521622178468,
 'train_dice_loss': 0.006475090362865072,
 'train_time': datetime.timedelta(seconds=294, microseconds=144992),
 'val_bce_loss': tensor(0.0026, device='cuda:0'),
 'val_dice_loss': tensor(0.0068, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=315386)}
2023-11-30 08:05:10:rastervision: INFO - epoch: 7


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 08:10:47:rastervision: INFO - metrics:
{'avg_f1': 0.9250031113624573,
 'avg_precision': 0.9730350971221924,
 'avg_recall': 0.8814901113510132,
 'epoch': 7,
 'other_f1': 0.9348450303077698,
 'other_precision': 0.9986233115196228,
 'other_recall': 0.8787243366241455,
 'sandmine_average_precision': tensor(0.6325, device='cuda:0'),
 'sandmine_f1': 0.34563156962394714,
 'sandmine_precision': 0.21056802570819855,
 'sandmine_recall': 0.9639026522636414,
 'train_bce_loss': 0.0033008888595462463,
 'train_dice_loss': 0.006348453655144094,
 'train_time': datetime.timedelta(seconds=293, microseconds=835639),
 'val_bce_loss': tensor(0.0028, device='cuda:0'),
 'val_dice_loss': tensor(0.0069, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=663490)}
2023-11-30 08:10:49:rastervision: INFO - epoch: 8


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 08:16:26:rastervision: INFO - metrics:
{'avg_f1': 0.9451583623886108,
 'avg_precision': 0.9726808071136475,
 'avg_recall': 0.9191505312919617,
 'epoch': 8,
 'other_f1': 0.9565572738647461,
 'other_precision': 0.9961683750152588,
 'other_recall': 0.9199758172035217,
 'sandmine_average_precision': tensor(0.6200, device='cuda:0'),
 'sandmine_f1': 0.4181073307991028,
 'sandmine_precision': 0.27280721068382263,
 'sandmine_recall': 0.8945596218109131,
 'train_bce_loss': 0.0031392240771357877,
 'train_dice_loss': 0.006447636401715057,
 'train_time': datetime.timedelta(seconds=294, microseconds=20867),
 'val_bce_loss': tensor(0.0025, device='cuda:0'),
 'val_dice_loss': tensor(0.0068, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=286159)}
2023-11-30 08:16:28:rastervision: INFO - epoch: 9


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 08:22:05:rastervision: INFO - metrics:
{'avg_f1': 0.925705075263977,
 'avg_precision': 0.9720682501792908,
 'avg_recall': 0.8835631608963013,
 'epoch': 9,
 'other_f1': 0.9361172914505005,
 'other_precision': 0.9976367354393005,
 'other_recall': 0.8817443251609802,
 'sandmine_average_precision': tensor(0.6265, device='cuda:0'),
 'sandmine_f1': 0.34340769052505493,
 'sandmine_precision': 0.21018953621387482,
 'sandmine_recall': 0.937760591506958,
 'train_bce_loss': 0.0032780439124823852,
 'train_dice_loss': 0.006442899531033373,
 'train_time': datetime.timedelta(seconds=293, microseconds=470351),
 'val_bce_loss': tensor(0.0028, device='cuda:0'),
 'val_dice_loss': tensor(0.0069, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=868133)}
2023-11-30 08:22:07:rastervision: INFO - epoch: 10


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 08:27:44:rastervision: INFO - metrics:
{'avg_f1': 0.9283274412155151,
 'avg_precision': 0.9729678630828857,
 'avg_recall': 0.8876035809516907,
 'epoch': 10,
 'other_f1': 0.9384331703186035,
 'other_precision': 0.9982898235321045,
 'other_recall': 0.8853483200073242,
 'sandmine_average_precision': tensor(0.6526, device='cuda:0'),
 'sandmine_f1': 0.3555312752723694,
 'sandmine_precision': 0.21843352913856506,
 'sandmine_recall': 0.9548067450523376,
 'train_bce_loss': 0.0033403990182234214,
 'train_dice_loss': 0.006411716728012796,
 'train_time': datetime.timedelta(seconds=293, microseconds=656829),
 'val_bce_loss': tensor(0.0026, device='cuda:0'),
 'val_dice_loss': tensor(0.0069, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=640234)}
2023-11-30 08:27:46:rastervision: INFO - epoch: 11


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 08:33:24:rastervision: INFO - metrics:
{'avg_f1': 0.9509456753730774,
 'avg_precision': 0.9738165140151978,
 'avg_recall': 0.9291244745254517,
 'epoch': 11,
 'other_f1': 0.9621134996414185,
 'other_precision': 0.9963752031326294,
 'other_recall': 0.9301297068595886,
 'sandmine_average_precision': tensor(0.6254, device='cuda:0'),
 'sandmine_f1': 0.45171549916267395,
 'sandmine_precision': 0.3016199469566345,
 'sandmine_recall': 0.8991712331771851,
 'train_bce_loss': 0.0031471388327643044,
 'train_dice_loss': 0.006386734661042999,
 'train_time': datetime.timedelta(seconds=294, microseconds=8704),
 'val_bce_loss': tensor(0.0023, device='cuda:0'),
 'val_dice_loss': tensor(0.0068, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=44, microseconds=73892)}
2023-11-30 08:33:26:rastervision: INFO - epoch: 12


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 08:39:02:rastervision: INFO - metrics:
{'avg_f1': 0.9372101426124573,
 'avg_precision': 0.973580002784729,
 'avg_recall': 0.9034597873687744,
 'epoch': 12,
 'other_f1': 0.9475882053375244,
 'other_precision': 0.998034656047821,
 'other_recall': 0.9019961953163147,
 'sandmine_average_precision': tensor(0.6406, device='cuda:0'),
 'sandmine_f1': 0.38915306329727173,
 'sandmine_precision': 0.24488909542560577,
 'sandmine_recall': 0.9470729231834412,
 'train_bce_loss': 0.0032116607063174866,
 'train_dice_loss': 0.006421818016724265,
 'train_time': datetime.timedelta(seconds=293, microseconds=195005),
 'val_bce_loss': tensor(0.0025, device='cuda:0'),
 'val_dice_loss': tensor(0.0068, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=198021)}
2023-11-30 08:39:04:rastervision: INFO - epoch: 13


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 08:44:41:rastervision: INFO - metrics:
{'avg_f1': 0.9321161508560181,
 'avg_precision': 0.9727174639701843,
 'avg_recall': 0.8947683572769165,
 'epoch': 13,
 'other_f1': 0.9426153898239136,
 'other_precision': 0.9977111220359802,
 'other_recall': 0.8932861685752869,
 'sandmine_average_precision': tensor(0.6517, device='cuda:0'),
 'sandmine_f1': 0.3668608069419861,
 'sandmine_precision': 0.2279658168554306,
 'sandmine_recall': 0.9389349818229675,
 'train_bce_loss': 0.003165143450307105,
 'train_dice_loss': 0.006491621541235731,
 'train_time': datetime.timedelta(seconds=293, microseconds=224928),
 'val_bce_loss': tensor(0.0026, device='cuda:0'),
 'val_dice_loss': tensor(0.0069, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=740777)}
2023-11-30 08:44:42:rastervision: INFO - epoch: 14


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 08:50:20:rastervision: INFO - metrics:
{'avg_f1': 0.9420161843299866,
 'avg_precision': 0.9731568098068237,
 'avg_recall': 0.9128066897392273,
 'epoch': 14,
 'other_f1': 0.9529467225074768,
 'other_precision': 0.9970586895942688,
 'other_recall': 0.9125726222991943,
 'sandmine_average_precision': tensor(0.6420, device='cuda:0'),
 'sandmine_f1': 0.40654146671295166,
 'sandmine_precision': 0.2609376609325409,
 'sandmine_recall': 0.9197819232940674,
 'train_bce_loss': 0.003207698387185527,
 'train_dice_loss': 0.006405515991962017,
 'train_time': datetime.timedelta(seconds=294, microseconds=243663),
 'val_bce_loss': tensor(0.0024, device='cuda:0'),
 'val_dice_loss': tensor(0.0068, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=693902)}
2023-11-30 08:50:22:rastervision: INFO - epoch: 15


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 08:56:00:rastervision: INFO - metrics:
{'avg_f1': 0.9322409629821777,
 'avg_precision': 0.9733299612998962,
 'avg_recall': 0.8944805264472961,
 'epoch': 15,
 'other_f1': 0.942417562007904,
 'other_precision': 0.9982921481132507,
 'other_recall': 0.892466127872467,
 'sandmine_average_precision': tensor(0.6633, device='cuda:0'),
 'sandmine_f1': 0.37005195021629333,
 'sandmine_precision': 0.2295166403055191,
 'sandmine_recall': 0.9545043706893921,
 'train_bce_loss': 0.003118120944561736,
 'train_dice_loss': 0.00643309039773101,
 'train_time': datetime.timedelta(seconds=293, microseconds=718531),
 'val_bce_loss': tensor(0.0025, device='cuda:0'),
 'val_dice_loss': tensor(0.0069, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=44, microseconds=93877)}
2023-11-30 08:56:01:rastervision: INFO - epoch: 16


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 09:01:39:rastervision: INFO - metrics:
{'avg_f1': 0.9301341772079468,
 'avg_precision': 0.9732560515403748,
 'avg_recall': 0.8906714916229248,
 'epoch': 16,
 'other_f1': 0.9402074813842773,
 'other_precision': 0.9984123110771179,
 'other_recall': 0.8884152173995972,
 'sandmine_average_precision': tensor(0.6396, device='cuda:0'),
 'sandmine_f1': 0.36264535784721375,
 'sandmine_precision': 0.2236594706773758,
 'sandmine_recall': 0.9579033851623535,
 'train_bce_loss': 0.003273982458164037,
 'train_dice_loss': 0.006372220157959301,
 'train_time': datetime.timedelta(seconds=294, microseconds=40499),
 'val_bce_loss': tensor(0.0026, device='cuda:0'),
 'val_dice_loss': tensor(0.0069, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=902830)}
2023-11-30 09:01:41:rastervision: INFO - epoch: 17


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 09:07:19:rastervision: INFO - metrics:
{'avg_f1': 0.9206599593162537,
 'avg_precision': 0.9723639488220215,
 'avg_recall': 0.8741768598556519,
 'epoch': 17,
 'other_f1': 0.9305656552314758,
 'other_precision': 0.9982956051826477,
 'other_recall': 0.8714421391487122,
 'sandmine_average_precision': tensor(0.6109, device='cuda:0'),
 'sandmine_f1': 0.33031538128852844,
 'sandmine_precision': 0.1996634155511856,
 'sandmine_recall': 0.9556660056114197,
 'train_bce_loss': 0.003126473006806843,
 'train_dice_loss': 0.00637609538636677,
 'train_time': datetime.timedelta(seconds=293, microseconds=857288),
 'val_bce_loss': tensor(0.0029, device='cuda:0'),
 'val_dice_loss': tensor(0.0069, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=824546)}
2023-11-30 09:07:20:rastervision: INFO - epoch: 18


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 09:12:58:rastervision: INFO - metrics:
{'avg_f1': 0.9307929873466492,
 'avg_precision': 0.972937285900116,
 'avg_recall': 0.8921481966972351,
 'epoch': 18,
 'other_f1': 0.9410833716392517,
 'other_precision': 0.9980452656745911,
 'other_recall': 0.8902723789215088,
 'sandmine_average_precision': tensor(0.6387, device='cuda:0'),
 'sandmine_f1': 0.36339759826660156,
 'sandmine_precision': 0.22477920353412628,
 'sandmine_recall': 0.9480436444282532,
 'train_bce_loss': 0.0031256558363919433,
 'train_dice_loss': 0.006405734027605601,
 'train_time': datetime.timedelta(seconds=294, microseconds=97215),
 'val_bce_loss': tensor(0.0026, device='cuda:0'),
 'val_dice_loss': tensor(0.0069, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=449964)}
2023-11-30 09:13:00:rastervision: INFO - epoch: 19


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2023-11-30 09:18:37:rastervision: INFO - metrics:
{'avg_f1': 0.9428982138633728,
 'avg_precision': 0.9737053513526917,
 'avg_recall': 0.9139807820320129,
 'epoch': 19,
 'other_f1': 0.9535909295082092,
 'other_precision': 0.9974827766418457,
 'other_recall': 0.9133989810943604,
 'sandmine_average_precision': tensor(0.6285, device='cuda:0'),
 'sandmine_f1': 0.41283372044563293,
 'sandmine_precision': 0.2651945948600769,
 'sandmine_recall': 0.9313157796859741,
 'train_bce_loss': 0.0031257676337049416,
 'train_dice_loss': 0.006446888409747978,
 'train_time': datetime.timedelta(seconds=293, microseconds=668967),
 'val_bce_loss': tensor(0.0024, device='cuda:0'),
 'val_dice_loss': tensor(0.0068, device='cuda:0'),
 'valid_time': datetime.timedelta(seconds=43, microseconds=343775)}


## Evaluate 

Initialize `evaluation_datasets` and `predictor`.
`evaluation_datasets` and `validation_datasets` are built from the same scenes but use different sliding-window configurations.

In [13]:
from ml.learner import BinarySegmentationPredictor
from utils.rastervision_pipeline import scene_to_inference_ds

# One inference dataset per validation scene, requested with full_image=True.
evaluation_datasets = [
    scene_to_inference_ds(config, validation_scene, full_image=True)
    for validation_scene in validation_scenes
]

# Predictor wrapping the model object that is currently in memory.
predictor = BinarySegmentationPredictor(config, model)

## Alternatively: specify path to trained weights
# path_to_weights = expanduser("~/sandmining-watch/out/OUTPUT_DIR/SatMAE-L_w15.47_b128.pth")
# predictor = BinarySegmentationPredictor(
#     config,
#     model,
#     path_to_weights,
# )

Loading weights from /home/ando/sandmining-watch/out/OUTPUT_DIR/SatMAE-L_w15.47_b128.pth


In [13]:
from ml.eval_utils import evaluate_predicitions, make_wandb_segmentation_masks, make_wandb_predicted_probs_images
from utils.visualizing import raster_source_to_rgb


def _collect_prediction_result(dataset):
    """Run the predictor on one evaluation dataset and bundle the result.

    Returns a dict with the predictor output ("predictions"), the scene's label
    array ("ground_truth"), an RGB rendering of the raster source ("rgb_img") and
    the scene id ("name").
    """
    site_predictions = predictor.predict_mine_probability_for_site(dataset)
    site_rgb = raster_source_to_rgb(dataset.scene.raster_source)
    return {
        "predictions": site_predictions,
        "ground_truth": dataset.scene.label_source.get_label_arr(),
        "rgb_img": site_rgb,
        "name": dataset.scene.id
    }


# One result dict per evaluation dataset, in dataset order.
prediction_results_list = [
    _collect_prediction_result(ds) for ds in evaluation_datasets
]

evaluation_results_dict = evaluate_predicitions(prediction_results_list)

Log results to Weights & Biases

In [10]:
import wandb

# Logging needs an active W&B run; stop here if none has been initialized.
assert wandb.run is not None

# Build the lists of W&B images first, then merge them into the metrics dict.
# The keys are kept verbatim because they are the names under which W&B stores them.
wandb_media = {
    'Segmenation masks': make_wandb_segmentation_masks(prediction_results_list),
    'Predicted probabilites': make_wandb_predicted_probs_images(prediction_results_list),
}
evaluation_results_dict.update(wandb_media)

# Send metrics and images to W&B in a single call.
wandb.log(evaluation_results_dict)



In [11]:
# Finish the active W&B run; the run summary is printed as this cell's output.
wandb.finish()



0,1
eval/Betwa_Jalaun_79-49_25-84_2022-10-01/average_precision,▁
eval/Betwa_Jalaun_79-49_25-84_2022-10-01/f1_score,▁
eval/Betwa_Jalaun_79-49_25-84_2022-10-01/precision,▁
eval/Betwa_Jalaun_79-49_25-84_2022-10-01/recall,▁
eval/Betwa_Jalaun_79-49_25-84_2023-05-01/average_precision,▁
eval/Betwa_Jalaun_79-49_25-84_2023-05-01/f1_score,▁
eval/Betwa_Jalaun_79-49_25-84_2023-05-01/precision,▁
eval/Betwa_Jalaun_79-49_25-84_2023-05-01/recall,▁
eval/Betwa_Jalaun_79-79_25-89_2022-10-01/average_precision,▁
eval/Betwa_Jalaun_79-79_25-89_2022-10-01/f1_score,▁

0,1
eval/Betwa_Jalaun_79-49_25-84_2022-10-01/average_precision,0.75794
eval/Betwa_Jalaun_79-49_25-84_2022-10-01/f1_score,0.36516
eval/Betwa_Jalaun_79-49_25-84_2022-10-01/precision,0.22577
eval/Betwa_Jalaun_79-49_25-84_2022-10-01/recall,0.9544
eval/Betwa_Jalaun_79-49_25-84_2023-05-01/average_precision,0.63346
eval/Betwa_Jalaun_79-49_25-84_2023-05-01/f1_score,0.19286
eval/Betwa_Jalaun_79-49_25-84_2023-05-01/precision,0.10674
eval/Betwa_Jalaun_79-49_25-84_2023-05-01/recall,0.9987
eval/Betwa_Jalaun_79-79_25-89_2022-10-01/average_precision,0.77294
eval/Betwa_Jalaun_79-79_25-89_2022-10-01/f1_score,0.46233
