In [1]:
# Notebook bootstrap: make the parent directory importable, enable autoreload,
# load environment variables and create the Google Cloud Storage client.
import sys
import os
# Put the parent directory first on the import path.
# NOTE(review): presumably this is what makes project_config, utils, ml, ...
# importable from this notebook's folder - confirm against the repo layout.
sys.path.insert(0, os.path.abspath('..'))
from os.path import expanduser

# Re-import edited modules automatically before each cell execution.
%load_ext autoreload
%autoreload 2

# Load variables from a .env file into the process environment.
from dotenv import load_dotenv
load_dotenv()

from google.cloud import storage
from project_config import GCP_PROJECT_NAME, DATASET_JSON_PATH

# Storage client bound to the project named in project_config; reused by later cells.
gcp_client = storage.Client(project=GCP_PROJECT_NAME)



In [2]:
import os
import torch

# Optional allocator tweak, kept disabled:
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32" #to prevent cuda out of memory error

# Release cached GPU memory held by PyTorch's allocator before starting.
torch.cuda.empty_cache()

# Fix the PyTorch RNG seed for reproducibility (the returned Generator is
# displayed as the cell output).
torch.manual_seed(13)

<torch._C.Generator at 0x7f44980ba870>

### Configuration

In [3]:
# Select the experiment configuration used for dataset creation and training.
# NOTE(review): the star import hides where names come from. Besides
# satmae_large_config_lora_methodB, later cells use `lora_config` and
# `satmae_large_inf_config`, which presumably also come from this module -
# confirm, and prefer explicit imports once verified.
from experiment_configs.configs import *
config = satmae_large_config_lora_methodB

### Create Rastervision datasets

In [4]:
# Build the training / validation datasets from the dataset JSON.
# Observations whose cluster_id equals `val_split` become validation scenes;
# all others become training scenes.
from torch.utils.data import ConcatDataset
import json
from utils.rastervision_pipeline import observation_to_scene, scene_to_training_ds, scene_to_validation_ds, scene_to_inference_ds
from utils.data_management import observation_factory, characterize_dataset
import random

from utils.rastervision_pipeline import GoogleCloudFileSystem
# Share the GCS client created in the setup cell via the class attribute.
GoogleCloudFileSystem.storage_client = gcp_client

# Seed Python's RNG so the run is repeatable.
random.seed(13)

# Resolve the dataset JSON relative to the current working directory.
root_dir = os.getcwd()
json_rel_path = '../' + DATASET_JSON_PATH
json_abs_path = os.path.join(root_dir, json_rel_path)

# Use a context manager so the file handle is closed as soon as parsing is done.
with open(json_abs_path, 'r') as dataset_file:
    dataset_json = json.load(dataset_file)
all_observations = observation_factory(dataset_json)

# Highest cluster id present in the dataset (only needed for the random split below).
max_cluster_id = max(observation['cluster_id'] for observation in dataset_json)

# The validation cluster is fixed; the random alternative is kept for reference.
# val_split = random.randint(0, max_cluster_id+1)
val_split = 0

training_scenes = []
validation_scenes = []

for observation in all_observations:
    scene = observation_to_scene(config, observation)
    if observation.cluster_id == val_split:
        validation_scenes.append(scene)
    else:
        training_scenes.append(scene)

# This is only for testing: with no validation scenes, fall back to the
# training scenes so the rest of the notebook still runs.
if not validation_scenes:
    validation_scenes = training_scenes

training_datasets = [
    scene_to_training_ds(config, scene) for scene in training_scenes #random window sampling happens here
]
validation_datasets = [
    # scene_to_validation_ds(config, scene) for scene in validation_scenes
    scene_to_inference_ds(config, scene, full_image=False, stride=int(config.tile_size/2)) for scene in validation_scenes # better performance with this
]

train_dataset_merged = ConcatDataset(training_datasets)
val_dataset_merged = ConcatDataset(validation_datasets)

print('Validation split cluster_id:', val_split)
print ('Training dataset size: {:4d} images | Number of observations: {:}'.format(len(train_dataset_merged), len(training_scenes)))
print ('Testing dataset size: {:4d}  images | Number of observations: {:}'.format(len(val_dataset_merged), len(validation_scenes)))

# Prints the mining-area statistics; the returned value is referenced by the
# (commented-out) loss-weight formula in the next cell.
mine_percentage_aoi = characterize_dataset(training_scenes, validation_scenes)

2024-04-16 08:26:17:rastervision.pipeline.file_system.utils: INFO - Using cached file /data/tmp/cache/http/storage.googleapis.com/sand_mining_median/labels/Tawa_Hoshangabad_77-80_22-74_median/s2/Tawa_Hoshangabad_77-80_22-74_2022-04-01_s2.tif.
2024-04-16 08:26:17:rastervision.pipeline.file_system.utils: INFO - Using cached file /data/tmp/cache/http/storage.googleapis.com/sand_mining_median/labels/Tawa_Hoshangabad_77-80_22-74_median/annotations/Tawa_Hoshangabad_77-80_22-74_2022-04-01_annotations_3class.geojson.
2024-04-16 08:26:17:rastervision.pipeline.file_system.utils: INFO - Using cached file /data/tmp/cache/http/storage.googleapis.com/sand_mining_median/labels/Tawa_Hoshangabad_77-80_22-74_median/rivers/Tawa_Hoshangabad_77-80_22-74_rivers_1000m.geojson.
2024-04-16 08:26:17:rastervision.core.data.vector_source.geojson_vector_source: INFO - Ignoring CRS ({'type': 'name', 'properties': {'name': 'urn:ogc:def:crs:OGC:1.3:CRS84'}}) specified in https://storage.googleapis.com/sand_mining_med

Validation split cluster_id: 0
Training dataset size: 3254 images | Number of observations: 57
Testing dataset size:  510  images | Number of observations: 2
Total dataset has 1.79%  mining area.
Training dataset has 1.71%  mining area.
Validation dataset has 2.92%  mining area.
Within AOIs, total dataset has 4.78%  mining area.
Outside AOIs, total dataset has 0.03%  mining area.

The median percentage of mine in an observation is 2.01%
The median number of mine pixels in an observation is 34934

The median number pixels in an observation is 1951732


## Update the loss weights to account for the imbalanced dataset

In [5]:
# Alternative settings for the mine-class loss weight, currently disabled
# (the first would derive it from mine_percentage_aoi, the second would set it to 1):
# config.mine_class_loss_weight = (100 - mine_percentage_aoi) / mine_percentage_aoi
# config.mine_class_loss_weight = 1
# With both lines disabled, this cell only displays the weight already set on the config.
config.mine_class_loss_weight

6.0

## Visualize the datasets

In [None]:
# Visual sanity check of the validation datasets built above.
from utils.visualizing import visualize_dataset

# Training-set visualization is disabled; uncomment to inspect it as well.
# print ('Training Dataset')
# for ds in training_datasets:
#     visualize_dataset(ds)

print("\n\n\n Val Dataset")

# One visualize_dataset call per validation dataset.
for ds in validation_datasets:
    visualize_dataset(ds)

# Train

In [6]:
# Build the model, optimizer and learner for the selected config.
from models.model_factory import model_factory, print_trainable_parameters
from ml.optimizer_factory import optimizer_factory
from ml.learner_factory import learner_factory
from experiment_configs.schemas import ThreeClassVariants

# Single switch for the dataset choice.
#   True  -> development/debugging: train AND validate on training_datasets[0]
#            to speed things up. Validation metrics are then computed on data
#            the model was trained on and say nothing about generalization.
#   False -> full run on the merged datasets built in the dataset cell.
USE_DEBUG_SUBSET = True

if USE_DEBUG_SUBSET:
    learner_train_ds = training_datasets[0]
    learner_valid_ds = training_datasets[0]
else:
    learner_train_ds = train_dataset_merged
    learner_valid_ds = val_dataset_merged

# The last entry of the raster source shape is used as the channel count.
_, _, n_channels = training_datasets[0].scene.raster_source.shape
# NOTE(review): `lora_config` is not defined in this notebook; presumably it
# comes from the star import of experiment_configs.configs - confirm.
model = model_factory(
    config,
    n_channels=n_channels,
    config_lora=lora_config
)

optimizer = optimizer_factory(config, model)

learner = learner_factory(
    config=config,
    model=model,
    optimizer=optimizer,
    train_ds=learner_train_ds,
    valid_ds=learner_valid_ds,
    output_dir=expanduser("~/sandmining-watch/out/OUTPUT_DIR")
)
print_trainable_parameters(learner.model)

SatMae: Loading encoder weights from /data/sand_mining/checkpoints/satmae_orig/pretrain-vit-large-e199.pth
Position interpolate from 12x12 to 20x20
['channel_cls_embed', 'head.weight', 'head.bias']
SatMaePretrained: Freezing encoder weights
Number of parameters loaded: 298
Applying LoRA ...


2024-04-16 08:27:09:rastervision: INFO - train_ds: 52 items
2024-04-16 08:27:09:rastervision: INFO - valid_ds: 52 items


trainable params: 2.565123M || all params: 306.839304M || trainable%: 0.84


#### Check GPU Activity

You can continuously monitor GPU activity by running the following command in a terminal:


`watch -d -n 0.5 nvidia-smi`

In [7]:
# Run this cell if you want to log the run to W&B. You might need to authenticate to W&B.
# learner.initialize_wandb_run()

In [7]:
# Train for a single epoch with the learner built above.
learner.train(epochs=1)

2024-04-16 08:27:11:rastervision: INFO - epoch: 0


Training:   0%|          | 0/7 [00:00<?, ?it/s]

Validating:   0%|          | 0/7 [00:00<?, ?it/s]

2024-04-16 08:27:36:rastervision: INFO - metrics:
{'avg_f1': 0.9648976922035217,
 'avg_precision': 0.9535627961158752,
 'avg_recall': 0.9765053987503052,
 'epoch': 0,
 'lc_f1': 0.0,
 'lc_precision': 0.0,
 'lc_recall': 0.0,
 'other_f1': 0.9881130456924438,
 'other_precision': 0.9765053987503052,
 'other_recall': 1.0,
 'sandmine_average_precision': 0.054424935776535295,
 'sandmine_best_f1_score': 0.12506071424048654,
 'sandmine_best_threshold': 0.84512323,
 'sandmine_f1': 0.0,
 'sandmine_precision': 0.0,
 'sandmine_recall': 0.0,
 'train_ce_loss': 0.08008507581857535,
 'train_time': datetime.timedelta(seconds=15, microseconds=6623),
 'valid_time': datetime.timedelta(seconds=9, microseconds=849935)}


Saving model weights to /home/gautamsai_y/sandmining-watch/out/OUTPUT_DIR/last-model.pth
Saving model weights to /home/gautamsai_y/sandmining-watch/out/OUTPUT_DIR/best-model.pth
Merging LoRa weights ...
Saving model weights to /home/gautamsai_y/sandmining-watch/out/OUTPUT_DIR/last-model.pth


## Evaluate 

Initialize `evaluation_datasets` and the `predictor`.
`evaluation_datasets` and `validation_datasets` are built from the same scenes but use different sliding-window configurations.

In [None]:
# Switch to the inference configuration for the evaluation cells below.
# NOTE(review): this rebinds the `config` name used by the training cells, so
# re-running earlier cells after this one will pick up the inference config
# unless the configuration cell is run again first.
config = satmae_large_inf_config

In [None]:
from ml.learner import BinarySegmentationPredictor
from utils.rastervision_pipeline import scene_to_inference_ds

# One full-image inference dataset per validation scene.
evaluation_datasets = list(
    map(
        lambda scene: scene_to_inference_ds(config, scene, full_image=True),
        validation_scenes,
    )
)

# Predictor wrapping the model from the training section with the current config.
predictor = BinarySegmentationPredictor(config, model)


In [None]:
from utils.visualizing import raster_source_to_rgb
from tqdm.notebook import tqdm
from ml.eval_utils import evaluate_predictions

# Collect one result record per evaluation dataset, then score them all together.
prediction_results_list = []
# Crop size comes from the config; a tile-size based alternative is kept for reference.
# crop_sz = int(config.tile_size // 5) #20% of the tiles at the edges are discarded
crop_sz = config.crop_sz

for ds in tqdm(evaluation_datasets):
    predictions = predictor.predict_mine_probability_for_site(ds, crop_sz)
    rgb_img = raster_source_to_rgb(ds.scene.raster_source)

    site_scene = ds.scene
    site_result = {
        "predictions": predictions,
        "ground_truth": site_scene.label_source.get_label_arr(),
        "rgb_img": rgb_img,
        "name": site_scene.id,
        "crop_sz": crop_sz,
    }
    prediction_results_list.append(site_result)

evaluation_results_dict = evaluate_predictions(prediction_results_list)

## Pick the threshold that maximizes F1 score

In [None]:
# Read the best threshold reported by evaluate_predictions and display it;
# it is reused when building the W&B segmentation masks.
threshold = evaluation_results_dict['eval/total/best_threshold']
threshold

In [None]:
# Display the full dictionary of evaluation results.
evaluation_results_dict


Log results to Weights & Biases

In [None]:
# Log the evaluation results (plus rendered images) to the active W&B run.
import wandb
from ml.eval_utils import make_wandb_segmentation_masks, make_wandb_predicted_probs_images

# Logging needs an active run. The cell that calls
# learner.initialize_wandb_run() is commented out by default, so fail with an
# actionable message instead of a bare AssertionError.
assert wandb.run is not None, (
    "No active W&B run: run learner.initialize_wandb_run() (commented out in "
    "the cell before training) before logging evaluation results."
)

# Add lists of W&B images to dict
evaluation_results_dict.update({
    'Segmentation masks': make_wandb_segmentation_masks(prediction_results_list, threshold),
    'Predicted probabilites': make_wandb_predicted_probs_images(prediction_results_list),
})

# Log to W&B
wandb.log(evaluation_results_dict)

In [None]:
# Close the active W&B run now that the evaluation results have been logged.
wandb.finish()