# Mapping mining areas with DINO & SAM

This notebook aims to detect and map mining areas using the Grounding DINO model for object detection and the Segment Anything Model (SAM) for segmentation.

In [None]:
# for development: automatically reload edited project modules (e.g. the `src` package)
# so code changes are picked up without restarting the kernel
%load_ext autoreload
# mode 2: reload imported modules before executing each cell
%autoreload 2

In [None]:
# Imports: stdlib -> third-party -> project (project imports follow the path setup below).
import os
import sys
import random

import numpy as np
import pandas as pd
import torch
import rasterio
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

# check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Set the working directory to the repository root, assumed to be two levels above this
# notebook (the directory that contains the `src` package imported below).
# Guarded so that re-running this cell does not climb another two directories each time.
if not os.path.isdir("src"):
    os.chdir("../../")
root = os.getcwd()
# root = root + "/workspaces/mine-segmentation" # for lightning studios
print(f"Root directory: {root}")

# Make the project package importable through an absolute path: it stays valid after the
# chdir above and is not appended again when the cell is re-run.
if root not in sys.path:
    sys.path.append(root)

from src.models.samgeo.model import MineSamGeo
from src.utils import geotiff_to_PIL
from src.visualization.visualization_funcs import plot_pred_vs_true_mask, plot_predictions

## Setup

In [None]:
# Data and output locations, built as absolute paths under the repository root.
# (Earlier training-split locations, kept for reference:
#  data/processed/chips/train/chips, data/processed/chips/train/labels, data/output)
TEST_CHIP_DIR = f"{root}/data/processed/chips/tif/2048/test/chips/"
TEST_LABEL_DIR = f"{root}/data/processed/chips/tif/2048/test/labels/"
VAL_CHIP_DIR = f"{root}/data/processed/chips/tif/2048/val/chips/"
VAL_LABEL_DIR = f"{root}/data/processed/chips/tif/2048/val/labels/"
OUTPUT_DIR = f"{root}/reports/samgeo/"

print(VAL_CHIP_DIR)

## Run the SamGeo model with text prompts on mining areas

In [None]:
# for development
%load_ext autoreload
%autoreload 2

In [None]:
# Model used by the exploratory cells below, pointed at the test split
# (chips, their label masks, and the output directory).
msg = MineSamGeo(
    output_dir=OUTPUT_DIR,
    mask_dir=TEST_LABEL_DIR,
    chips_dir=TEST_CHIP_DIR,
    model_type="vit_b",
)

# number of chips indexable via msg.get_chip_path
msg.num_chips

In [None]:
# Run text-prompted prediction on the first chip of the split.
chip_path = msg.get_chip_path(0)
msg.predict(
    text_threshold=0.01,
    box_threshold=0.15,
    text_prompt="extractive mine",
    chip_path=chip_path,
)

# Inspect the raw detector outputs, then the evaluation metrics for this chip.
for model_attr in ("logits", "boxes", "phrases"):
    print(getattr(msg.model, model_attr))
print(msg.calculate_metrics())

In [None]:
# Chip image converted to a NumPy array
chip = np.array(geotiff_to_PIL(chip_path))

# Ground-truth mask: band 1 of the label raster matching this chip
mask_path = msg.get_mask_path(chip_path)
with rasterio.open(mask_path) as mask_src:
    mask = mask_src.read(1)

# Prediction left on the model by the most recent msg.predict call
pred = msg.model.prediction
plot_predictions(chip, mask, pred)

In [None]:
# Bounding boxes only (masks hidden), boxes drawn in red
fig = msg.show_anns(
    title="Bounding Boxes",
    add_masks=False,
    add_boxes=True,
    box_color="red",
    cmap="Blues",
    blend=True,
    alpha=0.1,
)

In [None]:
# Predicted mask only (bounding boxes hidden)
fig = msg.show_anns(
    title="Prediction Mask",
    add_masks=True,
    add_boxes=False,
    cmap="Blues",
    blend=True,
    alpha=0.5,
)

In [None]:
# Ground-truth mask only. The title now says "Ground Truth Mask": this figure shows the
# label from mask_dir, not the model prediction (previously it was mislabelled).
fig = msg.show_true_mask(
    cmap="Blues",
    title="Ground Truth Mask",
    blend=True,
    add_boxes=False,
    add_masks=True,
    alpha=0.5
)

In [None]:
# Compare the ground-truth and predicted masks (bounding boxes hidden)
fig = msg.show_pred_vs_true_mask(
    add_masks=True,
    add_boxes=False,
    blend=True,
    alpha=0.4,
)

In [None]:
# Show bounding boxes and ground-truth vs. predicted masks for a random sample of chips.
# A dedicated seeded RNG makes the sample reproducible across re-runs without touching the
# global `random` state; the sample size is capped so small splits do not raise ValueError.
N_PREVIEW_CHIPS = 7
PREVIEW_SEED = 42
preview_rng = random.Random(PREVIEW_SEED)
preview_indices = preview_rng.sample(range(msg.num_chips), min(N_PREVIEW_CHIPS, msg.num_chips))

for i in preview_indices:
    chip_path = msg.get_chip_path(i)

    msg.predict(
        chip_path=chip_path,
        text_prompt="extractive site",
        box_threshold=0.08,
        text_threshold=0.01,
        box_size_threshold=0.4
    )

    # print(msg.calculate_metrics())

    # Only show bounding boxes
    fig = msg.show_anns(
        cmap="Blues",
        box_color="red",
        title="Bounding Boxes",
        blend=True,
        add_boxes=True,
        add_masks=False,
        alpha=0.1
    )

    # Ground truth vs. prediction for the same chip
    fig = msg.show_pred_vs_true_mask(
        title="Prediction Mask",
        blend=True,
        add_boxes=False,
        add_masks=True,
        alpha=0.4
    )

## Test best Hyperparameters

### Prompts
`prompt`: The prompt is an integral part of object detection with Grounding DINO. If multiple objects should be detected, the prompts can be separated with a `.`, like so: `mine . city .`.

`negative_prompt`: **TODO**. Grounding DINO often detects cities or other objects that are not mines as mines. It may be possible to use negative prompts to detect negative classes (cities, forest, industrial areas, deforestation) and then remove the bounding boxes of positive classes (mines) wherever these negative classes are detected.


### Thresholds
Part of the model prediction is setting appropriate thresholds for object detection and for associating the text prompt with the detected objects. These threshold values range from 0 to 1 and are passed as arguments to the `predict` method of the `MineSamGeo` class.

`box_threshold`: This value is used for object detection in the image. A higher value makes the model more selective, identifying only the most confident object instances, leading to fewer overall detections. A lower value, conversely, makes the model more tolerant, leading to increased detections, including potentially less confident ones.

`text_threshold`: This value is used to associate the detected objects with the provided text prompt. A higher value requires a stronger association between the object and the text prompt, leading to more precise but potentially fewer associations. A lower value allows for looser associations, which could increase the number of associations but also introduce less precise matches.

`box_size_threshold`: This is a custom threshold used to discard bounding boxes that are too large. It represents the maximum size of the bounding box as a fraction of the image. For example, 0.5 discards all bounding boxes covering more than half the image. 

Remember to test different threshold values on your specific data. The optimal threshold can vary depending on the quality and nature of your images, as well as the specificity of your text prompts. Make sure to choose a balance that suits your requirements, whether that's precision or recall.

### Perform hyperparameter grid search

In [None]:
# Grid search: evaluate every hyperparameter combination on the validation split and record,
# per combination, the mean of each metric over the evaluated chips.
msg = MineSamGeo(
    model_type="vit_b",
    chips_dir=VAL_CHIP_DIR, # VAL_CHIP_DIR or TEST_CHIP_DIR
    mask_dir=VAL_LABEL_DIR, # VAL_LABEL_DIR or TEST_LABEL_DIR
    output_dir=OUTPUT_DIR,
)

# Define the range of values for the hyperparameters
text_prompts = ["mine", "extractive site", "mine site", "extractive mine"]
box_thresholds = [.06, .08, .1]
text_thresholds = [0.01]
box_size_thresholds = [0.3, 0.4, 0.5]

# Take best hyperparameters and run inference on test set
# text_prompts = ["extractive site"]
# box_thresholds = [.08]
# text_thresholds = [0.01]
# box_size_thresholds = [0.4]

# Chips to evaluate. For a quicker run on a random subset use e.g.
# sample_indices = random.sample(range(msg.num_chips), 116)
sample_indices = range(msg.num_chips)
# Derived from the indices actually used, so the averages below always match the sample
n_chips = len(sample_indices)
if n_chips == 0:
    raise ValueError("No chips selected for the grid search; check chips_dir.")

METRIC_NAMES = ['iou', 'f1_score', 'accuracy', 'precision', 'recall']
HYPERPARAM_NAMES = ['prompt', 'box_threshold', 'text_threshold', 'box_size_threshold']

# Collect one record per combination and build the DataFrame once at the end
# (growing a DataFrame with pd.concat inside the loop is quadratic).
records = []
total_iterations = len(text_prompts) * len(box_thresholds) * len(text_thresholds) * len(box_size_thresholds) * n_chips

with tqdm(total=total_iterations, desc="Grid Search Progress") as progress_bar:
    for prompt in text_prompts:
        for box_threshold in box_thresholds:
            for text_threshold in text_thresholds:
                for box_size_threshold in box_size_thresholds:
                    # Running sum of each metric for the current combination
                    metric_sums = dict.fromkeys(METRIC_NAMES, 0)

                    for i in sample_indices:
                        chip_path = msg.get_chip_path(i)

                        msg.predict(
                            chip_path=chip_path,
                            text_prompt=prompt,
                            box_threshold=box_threshold,
                            text_threshold=text_threshold,
                            box_size_threshold=box_size_threshold
                        )

                        metrics_i = msg.calculate_metrics()
                        for name in METRIC_NAMES:
                            metric_sums[name] += metrics_i[name]

                        progress_bar.update(1)

                    records.append({
                        'prompt': prompt,
                        'box_threshold': box_threshold,
                        'text_threshold': text_threshold,
                        'box_size_threshold': box_size_threshold,
                        # Average over all evaluated chips
                        **{name: metric_sums[name] / n_chips for name in METRIC_NAMES},
                    })

metrics_df = pd.DataFrame(records, columns=HYPERPARAM_NAMES + METRIC_NAMES)

# Best combination by mean F1 score (note: the plots below rank combinations by IoU)
best_metrics = metrics_df.loc[metrics_df['f1_score'].idxmax()]
best_prompt = best_metrics['prompt']
best_box_threshold = best_metrics['box_threshold']
best_text_threshold = best_metrics['text_threshold']
best_box_size_threshold = best_metrics['box_size_threshold']
print("Best prompt:", best_prompt)
print("Best box_threshold:", best_box_threshold)
print("Best text_threshold:", best_text_threshold)
print("Best box_size_threshold:", best_box_size_threshold)
print("Best metrics:", best_metrics)

# Save the metrics dataframe to a timestamped CSV so runs do not overwrite each other
timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
gridsearch_csv_path = root + f"/reports/samgeo_gridsearch_{timestamp}.csv"
metrics_df.to_csv(gridsearch_csv_path, index=False)
print(f"Saved grid-search metrics to {gridsearch_csv_path}")

### Plot the hyperparameter grid-search results

In [None]:
# Load a saved grid-search result and plot the five best combinations by IoU.
# To plot a new run, point GRIDSEARCH_CSV at the file written by the grid-search cell.
GRIDSEARCH_CSV = root + "/reports/samgeo_gridsearch_20241003_093349.csv"
metrics_df = pd.read_csv(GRIDSEARCH_CSV).sort_values(by='iou', ascending=False)
top_5_df = metrics_df.head(5).reset_index(drop=True)

hypers_to_plot = ['prompt', 'box_threshold', 'text_threshold', 'box_size_threshold']

# One x-tick label per combination, with the hyperparameter values on separate lines.
# The first (best) bar also names each hyperparameter so the other labels can be decoded.
# Names and values are zipped from hypers_to_plot, so they stay aligned if the list changes,
# and an empty result simply yields no labels instead of a KeyError.
top_5_df['hyperparams_str'] = [
    "\n".join(
        f"{name}: {value}" if rank == 0 else str(value)
        for name, value in zip(hypers_to_plot, values)
    )
    for rank, values in enumerate(top_5_df[hypers_to_plot].itertuples(index=False, name=None))
]

# Plotting
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(top_5_df['hyperparams_str'], top_5_df['iou'])
ax.tick_params(axis='x', labelrotation=45)
ax.set(
    xlabel='Hyperparameter combinations',
    ylabel='IoU',
    title='Top 5 Hyperparameters combinations by IoU',
)
fig.tight_layout()
plt.show()

In [None]:
# Top 5 rows of metrics_df (sorted by IoU in the previous cell), hyperparameters and key metrics only
metrics_df.head(5).loc[:, ["prompt", "box_threshold", "text_threshold", "box_size_threshold", "iou", "f1_score", "precision", "recall"]]

In [None]:
import plotly.express as px

# Decile of the F1 score (0 = lowest bin), used only to colour the lines.
# duplicates="drop" avoids "Bin edges must be unique" when many runs share the same F1
# score; fewer than 10 bins are then produced.
f1_decile = pd.qcut(metrics_df['f1_score'], q=10, labels=False, duplicates='drop')

# Work on a copy so metrics_df (displayed above) is not mutated by this plot
parcoords_df = metrics_df.assign(f1_score_cat=f1_decile)

# Create the parallel coordinates plot
fig = px.parallel_coordinates(
    parcoords_df,
    dimensions=['f1_score', 'iou', 'accuracy', 'precision', 'recall'],
    color='f1_score_cat',
    color_continuous_scale=px.colors.diverging.Tealrose,
    # Centre the diverging colour scale on the middle bin (previously fixed at 2,
    # which is off-centre for decile codes 0-9)
    color_continuous_midpoint=(f1_decile.min() + f1_decile.max()) / 2,
)

# Show the plot
fig.show()