In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import pandas as pd
from datasets import Dataset, load_from_disk
import pytorch_lightning as pl
import sys
import os
# Add the directory containing lit_sam_model.py to the Python path
sys.path.append(os.path.abspath("../"))
from model.minor_models.autoSamModel import LitSamModel
from model.samDataset import SAMDataset, SAMDataset3
from utils.statistics import calculate_correlation
from helperFunctions import *

In [None]:
import yaml
import os
from pathlib import Path

# 1. Locate the running script. A notebook kernel does not define __file__
#    (Path(__file__) raises NameError there), so fall back to the kernel's
#    working directory -- assumes the kernel was started in the notebook's own
#    folder, TODO confirm. The location is computed only once: step 4 changes
#    the working directory, so recomputing it when this cell is re-run would
#    resolve every path against the wrong folder.
if "current_file" not in globals():
    try:
        current_file = Path(__file__).resolve()  # src/training/your_script.py
    except NameError:
        # Only the parent directories of this path are used below.
        current_file = Path.cwd().resolve() / "<notebook>"

# 2. Go up one level to 'src', then into 'config'
config_path = current_file.parent.parent / "config" / "config_general.yaml"

# 3. Load the YAML (safe_load: plain data only, no arbitrary object construction)
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

# 4. Resolve the root of the project (one level above 'src')
# This ensures that "./data" in the YAML is interpreted relative to the Project_Root
PROJECT_ROOT = current_file.parent.parent.parent
os.chdir(PROJECT_ROOT)

# Extract paths from YAML
DATA_DIR = config['paths']['data']
CHECKPOINT_DIR = config['paths']['checkpoints']
SAM_CHECKPOINT = config['paths']['sam_checkpoint']

In [None]:
# Load the test dataset from DATA_DIR.
# NOTE(review): the directory name was an unquoted, undefined identifier and the
# call was missing its closing parenthesis; "testDatasetFinal" is assumed to be
# the on-disk folder name -- confirm.
test_dataset = load_from_disk(os.path.join(DATA_DIR, "testDatasetFinal"))

In [None]:
from transformers import SamModel, SamConfig, SamProcessor
import torch

# Build the checkpoint path from relative components only: os.path.join drops
# every earlier component when a later one starts with "/", which would
# silently ignore CHECKPOINT_DIR.
sam_checkpoint = os.path.join(
    CHECKPOINT_DIR,
    "autosam",
    "sam-auto-bestbase-smallnet-model-epoch=105-val_loss=0.2268-val_iou=0.590.ckpt",
)

# Restore the Lightning module from the checkpoint; the extra keyword arguments
# are forwarded to LitSamModel (defined in model.minor_models.autoSamModel).
model = LitSamModel.load_from_checkpoint(sam_checkpoint, model_name="vit-b", normalize = True, adapt = True)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# set the device to cuda if available, otherwise use cpu
device = "cuda" if torch.cuda.is_available() else "cpu"
#model.to(device)

In [None]:
def calculate_iou(mask1, mask2):
    """Compute the intersection-over-union of two masks.

    Each mask may be a NumPy array (converted to a tensor) or a torch tensor.
    Any element greater than zero counts as foreground.

    Returns:
        (iou, intersection, union): the IoU as a 0-dim float tensor, followed by
        the boolean intersection and union tensors. When neither mask has any
        foreground the union is empty and the 0/0 division yields NaN.
    """
    def _binarize(mask):
        # NumPy input is wrapped in a tensor first; anything else is used as given.
        tensor = torch.tensor(mask) if isinstance(mask, np.ndarray) else mask
        return tensor > 0

    foreground1 = _binarize(mask1)
    foreground2 = _binarize(mask2)

    intersection = torch.logical_and(foreground1, foreground2)
    union = torch.logical_or(foreground1, foreground2)

    # Pixel counts are cast to float before dividing.
    iou = intersection.sum().float() / union.sum().float()
    return iou, intersection, union

In [None]:
from torch.utils.data import DataLoader

# Wrap the test split in the SAM dataset adapter (no augmentation, test mode)
test_dataset_sam = SAMDataset3(
    dataset=test_dataset,
    processor=processor,
    augment=False,
    test=True,
)

# Batched loader over the test dataset; order is kept (shuffle=False)
test_dataloader = DataLoader(
    test_dataset_sam,
    batch_size=5,
    shuffle=False,
    num_workers=8,
)

In [None]:
# Create a trainer. Use the GPU when one is visible and fall back to the CPU
# otherwise, matching the cuda/cpu fallback used for `device` when the model is
# loaded; a hard-coded 'gpu' accelerator fails on CPU-only machines.
trainer = pl.Trainer(
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
)

# Run the evaluation
trainer.test(model, dataloaders=test_dataloader)

In [None]:
# Pull the collections stored on the model in `test_results` (read after the
# test run above) into individual variables.
test_results = model.test_results
ground_truth_masks, predicted_masks, individual_ious, bboxes = (
    test_results[key]
    for key in ('ground_truth_masks', 'predicted_masks', 'individual_ious', 'bboxes')
)

In [None]:
# Merge the stored sequences (presumably one entry per test batch -- confirm)
# into a single array per quantity along axis 0.
# The inputs are read from test_results rather than from the variables being
# assigned, so re-running this cell cannot concatenate an already-merged array
# a second time (which would silently fold its first axis into the next one).
predicted_masks = np.concatenate(test_results['predicted_masks'], axis=0)
ground_truth_masks = np.concatenate(test_results['ground_truth_masks'], axis=0)
bounding_boxes = np.concatenate(test_results['bboxes'], axis=0)

In [None]:
# Build the list of per-sample result records from the dataset, the merged
# masks and the boxes; calculate_iou is passed in as the scoring callback.
# Later cells read the keys 'bbox', 'iou', 'mask' and 'calculated_mask' from
# these records.
# NOTE(review): generate_results is not defined in this notebook -- presumably
# it comes from `from helperFunctions import *`; confirm the record schema there.
results_zero_shot = generate_results(test_dataset, ground_truth_masks, predicted_masks, bounding_boxes, calculate_iou)

In [None]:
# Score two groups of samples separately: first those whose box is NOT
# [0, 0, 1024, 1024], then those whose box IS [0, 0, 1024, 1024]
# (presumably the "whole image" prompt -- confirm).
has_full_box = lambda res: (res['bbox'] == [0, 0, 1024, 1024]).all()

results_zero_shot_copy = results_zero_shot.copy()
results_zero_shot = [res for res in results_zero_shot if not has_full_box(res)]
results_zero_shot_general = [res for res in results_zero_shot_copy if has_full_box(res)]

# Same metric call and printout for each group, in the order listed above
for subset in (results_zero_shot, results_zero_shot_general):
    results, metrics = calculate_pixel_based_metrics(subset)

    print(metrics)

# Put the complete, unfiltered list back for the cells that follow
results_zero_shot = results_zero_shot_copy

In [None]:
# Create a DataFrame from the list of result dictionaries: one row per record,
# with the dictionary keys as columns.
df_finetune = pd.DataFrame(results_zero_shot)

# Save the DataFrame to a file (optional)
#df_finetune.to_pickle('dataframe_name.pkl')

In [None]:
# Drop the rows that have no IoU value
df_filtered = df_finetune.dropna(subset=['iou'])


def _is_full_box(b):
    """True when every coordinate of `b` equals the box (0, 0, 1024, 1024)."""
    return np.all(np.array(b) == np.array((0, 0, 1024, 1024)))


# Boolean Series, aligned with df_filtered, flagging rows with that box
mask_bbox = df_filtered['bbox'].apply(_is_full_box)

In [None]:
# Keep only the rows whose bbox is NOT (0, 0, 1024, 1024).
# The subset is rebuilt from the NaN-free frame instead of filtering
# df_filtered in place, so this cell gives the same result however often, and
# in whatever order, the subset-selection cells are run (filtering an already
# filtered df_filtered with mask_bbox would leave it empty).

df_filtered = df_finetune.dropna(subset=['iou']).loc[~mask_bbox]

In [None]:
# Alternative subset: keep only the rows whose bbox IS (0, 0, 1024, 1024).
# Rebuilt from the NaN-free frame so it does not depend on whether the cell
# that excludes these rows ran before it (chaining the two filters would leave
# df_filtered empty). Whichever subset cell runs last decides which rows the
# analysis below uses.
df_filtered = df_finetune.dropna(subset=['iou']).loc[mask_bbox]

In [None]:
# Plot the IoU statistics of the currently selected subset.
# NOTE(review): plot_iou_statistics is defined outside this notebook
# (presumably helperFunctions); model_name is presumably used as a plot label.
plot_iou_statistics(df_filtered, model_name='AutoSAM Best Base with SmallNet')

In [None]:
# Plot mask area against IoU for the selected subset. model_name matches the
# checkpoint evaluated in this notebook (sam-auto-bestbase-smallnet) and the
# label used by the other plots, so figures from one run carry one model name.
plot_mask_area_vs_iou(df_filtered, model_name='AutoSAM Best Base with SmallNet')

In [None]:
import matplotlib.pyplot as plt

# Calculate the area of each mask as the sum of its values. assign() returns a
# new frame, so the column is not written into a filtered slice of another
# DataFrame (which can trigger pandas' SettingWithCopyWarning).
df_filtered = df_filtered.assign(
    mask_area=df_filtered['mask'].apply(lambda mask: np.sum(mask))
)

# Keep only the rows whose mask area is at most 1000 (small masks)
df_filtered2 = df_filtered[df_filtered['mask_area'] <= 1000]

# Create a scatter plot of IoU against mask area for those small masks
plt.figure(figsize=(10, 6))
plt.scatter(df_filtered2['mask_area'], df_filtered2['iou'], marker='o', color='blue')
plt.xlabel('Mask Area')
plt.ylabel('IoU')
plt.title('IoU vs. Mask Area')
plt.grid(True)
plt.show()

In [None]:
# Flatten boxes stored as 2-D arrays (e.g. shape (1, 4)) with np.squeeze; boxes
# of any other rank are left untouched. assign() builds a new frame instead of
# writing the column into a filtered slice of another DataFrame.
df_filtered = df_filtered.assign(
    bbox=df_filtered['bbox'].apply(lambda b: np.squeeze(b) if np.array(b).ndim == 2 else b)
)

In [None]:
import matplotlib.pyplot as plt

# Scatter of IoU against mask area, with the area axis on a logarithmic scale
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(df_filtered['mask_area'], df_filtered['iou'], marker='o', color='blue')
ax.set_xscale('log')
ax.set_xlabel('Mask Area (log scale)')
ax.set_ylabel('IoU')
ax.set_title('IoU vs. Mask Area')
ax.grid(True)
plt.show()

In [None]:
# Plot IoU against the area ratio for the selected subset. model_name matches
# the checkpoint evaluated in this notebook (sam-auto-bestbase-smallnet), so
# figures from one run carry one model name.
plot_iou_vs_area_ratio(df_filtered, model_name='AutoSAM Best Base with SmallNet')

In [None]:
# Plot IoU against the number of avalanches for the selected subset.
# model_name matches the checkpoint evaluated in this notebook
# (sam-auto-bestbase-smallnet), so figures from one run carry one model name.
plot_iou_vs_num_avalanches(df_filtered, model_name='AutoSAM Best Base with SmallNet')

In [None]:
# Plot IoU per mask-area group for the selected subset. model_name matches the
# checkpoint evaluated in this notebook (sam-auto-bestbase-smallnet), so
# figures from one run carry one model name.
plot_iou_for_mask_area(df_filtered, model_name='AutoSAM Best Base with SmallNet')

In [None]:
# Recompute the mask area from a float32 copy of each mask, so the column holds
# plain numeric sums whatever type the stored masks have. assign() builds a new
# frame instead of writing the column into a filtered slice of another DataFrame.
df_filtered = df_filtered.assign(
    mask_area=df_filtered['mask'].apply(lambda mask: np.sum(np.array(mask, dtype=np.float32)))
)

In [None]:
# Correlate mask area with IoU for the selected subset, using
# calculate_correlation as the statistic. model_name matches the checkpoint
# evaluated in this notebook (sam-auto-bestbase-smallnet).
compute_mask_area_iou_correlation(df_filtered, calculate_correlation, model_name='AutoSAM Best Base with SmallNet')

In [None]:
# A tiny constant is added before taking the logarithm so that a zero area
# does not produce log(0).
epsilon = 1e-8

# Both columns as float32 arrays
mask_area_numeric, iou_numeric = (
    np.asarray(df_filtered[column].values, dtype=np.float32)
    for column in ('mask_area', 'iou')
)
mask_area_log = np.log(mask_area_numeric + epsilon)

# Correlation between log(area) and IoU
corr, p_value = calculate_correlation(mask_area_log, iou_numeric)
print(f"Correlation on log-transformed values: {corr}, p-value: {p_value}")

In [None]:
# Correlate the number of avalanches with IoU for the selected subset, using
# calculate_correlation as the statistic. model_name matches the checkpoint
# evaluated in this notebook (sam-auto-bestbase-smallnet).
compute_num_avalanche_iou_correlation(df_filtered, calculate_correlation, model_name='AutoSAM Best Base with SmallNet')

In [None]:
# Correlate the area ratio with IoU for the selected subset, using
# calculate_correlation as the statistic. model_name matches the checkpoint
# evaluated in this notebook (sam-auto-bestbase-smallnet).
compute_area_ratio_iou_correlation(df_filtered, calculate_correlation, model_name='AutoSAM Best Base with SmallNet')

In [None]:
# Same area-ratio/IoU correlation, requested with scale='log'. model_name
# matches the checkpoint evaluated in this notebook (sam-auto-bestbase-smallnet).
compute_area_ratio_iou_correlation(df_filtered, calculate_correlation, model_name='AutoSAM Best Base with SmallNet', scale='log')

In [None]:
# Save the filtered DataFrame to a pickle file in the current working
# directory; the next cell reads this same file back.
df_filtered.to_pickle('dataframe_bestadaptersmodel.pkl')

In [None]:
# Reload the frame written by the cell above. Note that df_finetune then holds
# the *filtered* rows that were saved, not the complete result table.
# NOTE(review): unpickling can execute arbitrary code -- only load pickle files
# produced by this project.
df_finetune = pd.read_pickle('dataframe_bestadaptersmodel.pkl')

In [None]:
# Bounding boxes of the separate foreground regions of a mask
def find_bounding_boxes(mask):
    """Return one [x_min, y_min, x_max, y_max] box per external contour of `mask`.

    The mask is brought to 8-bit first: a uint8 mask is used unchanged, a mask
    whose maximum is at most 1 is scaled by 255, and any other mask is cast to
    uint8 directly. The contours come from cv2.findContours with
    RETR_EXTERNAL, so only outer contours are returned.
    """
    if mask.dtype == "uint8":
        mask_uint8 = mask
    elif mask.max() <= 1:
        # Values in the 0-1 range: stretch to 0-255 before the cast
        mask_uint8 = (mask * 255).astype('uint8')
    else:
        mask_uint8 = mask.astype('uint8')

    contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    def _corners(contour):
        # cv2.boundingRect gives (x, y, w, h); convert to corner coordinates
        x, y, w, h = cv2.boundingRect(contour)
        return [x, y, x + w, y + h]

    return [_corners(contour) for contour in contours]

In [None]:


def _overlap_fraction(box, other):
    """Fraction of `box`'s own area covered by its intersection with `other`.

    Both boxes are [x_min, y_min, x_max, y_max]. The measure is asymmetric: it
    is normalised by the area of `box` only, not by the union of the two boxes.
    """
    x_left = max(box[0], other[0])
    y_top = max(box[1], other[1])
    x_right = min(box[2], other[2])
    y_bottom = min(box[3], other[3])

    # The boxes do not intersect at all
    if x_right < x_left or y_bottom < y_top:
        return 0.0

    overlap_area = (x_right - x_left) * (y_bottom - y_top)
    return overlap_area / ((box[2] - box[0]) * (box[3] - box[1]))


def _unmatched_boxes(boxes, candidates, threshold=0.3):
    """Boxes from `boxes` that no box in `candidates` covers by more than `threshold`."""
    return [
        box for box in boxes
        if not any(_overlap_fraction(box, candidate) > threshold for candidate in candidates)
    ]


for res in results_zero_shot:
    # The predicted and ground truth avalanche masks
    pred_mask = res['calculated_mask']
    true_mask = res['mask']

    # One box per connected region in each mask
    original_bbox = find_bounding_boxes(true_mask)
    pred_bbox = find_bounding_boxes(pred_mask)

    # A ground-truth box not sufficiently covered by any prediction is a false
    # negative; a predicted box not sufficiently covered by any ground-truth
    # box is a false positive.
    false_negatives = _unmatched_boxes(original_bbox, pred_bbox)
    false_positives = _unmatched_boxes(pred_bbox, original_bbox)

    # Add the results into the current dictionary element
    res['false_negatives'] = false_negatives
    res['false_positives'] = false_positives
    # Stored as fractions in [0, 1]; 0 when there are no boxes to compare against
    res['percentage_false_negatives'] = len(false_negatives) / len(original_bbox) if len(original_bbox) > 0 else 0
    res['percentage_false_positives'] = len(false_positives) / len(pred_bbox) if len(pred_bbox) > 0 else 0
        


In [None]:
# Recompute the per-sample error percentages through the imported helper, with
# an overlap threshold of 0.3, and rebind results_zero_shot to its return value.
# NOTE(review): the loop in the previous cell already fills
# 'percentage_false_negatives' / 'percentage_false_positives' with the same 0.3
# threshold; presumably this helper computes the same fields -- confirm whether
# both steps are needed.
results_zero_shot = compute_error_percentages(results_zero_shot, threshold=0.3)

In [None]:
# Split the samples by box: results_zero_shot keeps those whose box is NOT
# [0, 0, 1024, 1024], results_zero_shot_general those whose box IS; the
# complete list stays available as results_zero_shot_copy.
has_full_box = lambda res: (res['bbox'] == [0, 0, 1024, 1024]).all()

results_zero_shot_copy = results_zero_shot.copy()
results_zero_shot = [res for res in results_zero_shot if not has_full_box(res)]
results_zero_shot_general = [res for res in results_zero_shot_copy if has_full_box(res)]

In [None]:
# Restore the complete result list, undoing the filtering done in the previous cell.
results_zero_shot = results_zero_shot_copy

In [None]:
# Bar chart of the per-sample false-negative / false-positive percentages,
# with a dashed line at the mean of each series
import matplotlib.pyplot as plt
import numpy as np

# Per-sample values; a sample without the field counts as 0
false_negatives_pct = [res.get('percentage_false_negatives', 0) for res in results_zero_shot]
false_positives_pct = [res.get('percentage_false_positives', 0) for res in results_zero_shot]
indices = np.arange(len(results_zero_shot))

# Mean of each series
avg_false_negatives = np.mean(false_negatives_pct)
avg_false_positives = np.mean(false_positives_pct)

fig, ax = plt.subplots(figsize=(12, 6))

# Two bars per sample, side by side around the sample index
ax.bar(indices - 0.15, false_negatives_pct, width=0.3, color='red', label='False Negatives')
ax.bar(indices + 0.15, false_positives_pct, width=0.3, color='blue', label='False Positives')

# Horizontal lines at the averages
ax.axhline(avg_false_negatives, color='darkred', linestyle='--',
           label=f'Avg False Negatives: {avg_false_negatives:.2f}')
ax.axhline(avg_false_positives, color='darkblue', linestyle='--',
           label=f'Avg False Positives: {avg_false_positives:.2f}')

ax.set_xlabel('Sample Index')
ax.set_ylabel('Percentage')
ax.set_title('Percentage of False Negatives and False Positives per Sample')
ax.legend()
plt.show()

In [None]:
# Visualization of samples with high error percentages

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2
import numpy as np

# A sample is "high error" when its false-negative or false-positive fraction
# exceeds this value (0.8 = 80%)
threshold = 0.8

# Upper bound on the number of figures drawn. The previous `idx >= 10` check
# ran after plotting and therefore produced 11 figures.
max_figures = 10


def _mask_to_rgb(mask):
    """Return an image for imshow: 2-D masks are scaled by 255 and expanded to RGB."""
    if mask.ndim == 2:
        gray = (mask * 255).astype(np.uint8)
        return cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    # Masks of any other rank are shown as they are (on a copy)
    return mask.copy()


def _draw_boxes(ax, boxes, color):
    """Outline each [x_min, y_min, x_max, y_max] box on `ax` in `color`."""
    for bbox in boxes:
        ax.add_patch(patches.Rectangle((bbox[0], bbox[1]),
                                       bbox[2] - bbox[0],
                                       bbox[3] - bbox[1],
                                       linewidth=2,
                                       edgecolor=color,
                                       facecolor='none'))


# Filter samples with either false negative or false positive percentage above the threshold
high_error_results = [
    res for res in results_zero_shot
    if res.get('percentage_false_negatives', 0) > threshold or res.get('percentage_false_positives', 0) > threshold
]

print(f"Number of high error samples: {len(high_error_results)}")

for res in high_error_results[:max_figures]:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

    # --- Left: Ground truth mask with false negative outlines ---
    ax1.imshow(_mask_to_rgb(res['mask']))
    ax1.set_title("Ground Truth with False Negatives")
    _draw_boxes(ax1, res.get('false_negatives', []), 'r')

    # --- Right: Predicted mask with false positive outlines ---
    ax2.imshow(_mask_to_rgb(res['calculated_mask']))
    ax2.set_title("Prediction with False Positives")
    _draw_boxes(ax2, res.get('false_positives', []), 'b')

    plt.show()
    # Release the figure so a long run of samples does not accumulate open figures
    plt.close(fig)