In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import pandas as pd
from datasets import Dataset, load_from_disk
import pytorch_lightning as pl
import sys
import os
# Add the directory containing lit_sam_model.py to the Python path
sys.path.append(os.path.abspath("../"))
from model.adapterModel import LitSamModel
from model.samDataset import SAMDataset, SAMDataset3
from utils.statistics import calculate_correlation
from helperFunctions import *
from model.inputTypes import InputTypes
from torch.utils.data import DataLoader

In [None]:
import yaml
import os
from pathlib import Path

# 1. Locate this file. Plain scripts define __file__ (e.g. src/training/your_script.py), but Jupyter
#    kernels do not, so a notebook falls back to the kernel's working directory. That directory is
#    remembered in NOTEBOOK_DIR so re-running this cell after the os.chdir below resolves the same paths.
if "__file__" in globals():
    current_file = Path(__file__).resolve()
else:
    NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", Path.cwd().resolve())
    current_file = NOTEBOOK_DIR / "notebook.ipynb"  # placeholder name: only its parent dirs are used

# 2. Go up one level to 'src', then into 'config'
config_path = current_file.parent.parent / "config" / "config_general.yaml"

# 3. Load the YAML (safe_load only builds plain data types)
with open(config_path, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# 4. Resolve the root of the project (one level above 'src')
# This ensures that "./data" in the YAML is interpreted relative to the project root
PROJECT_ROOT = current_file.parent.parent.parent
os.chdir(PROJECT_ROOT)

# Extract paths from YAML
DATA_DIR = config['paths']['data']
CHECKPOINT_DIR = config['paths']['checkpoints']
SAM_CHECKPOINT = config['paths']['sam_checkpoint']

In [None]:
# Load the saved test split (Hugging Face `datasets` on-disk format) from the data directory.
test_dataset_path = os.path.join(DATA_DIR, "datasetTestFinal")
test_dataset = load_from_disk(test_dataset_path)

In [None]:
from transformers import SamModel, SamConfig, SamProcessor
import torch

# VV adapter checkpoint. Path parts must not start with "/": os.path.join discards every earlier
# component when it meets an absolute one, so "/firstencoder/..." silently dropped CHECKPOINT_DIR.
sam_checkpoint = os.path.join(
    CHECKPOINT_DIR, "firstencoder", "sam-adapter-VV-epoch=89-val_loss=0.222-val_iou=0.599.ckpt"
)

# Rebuild the LitSamModel from the checkpoint with the VV input configuration.
model1 = LitSamModel.load_from_checkpoint(
    sam_checkpoint,
    model_name="facebook/sam-vit-base",
    normalize=True,
    adapt=True,
    input_type=InputTypes.VV,
    HQ=False,
)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

In [None]:
from model.minor_models.selfSupervisedTwoEncodersModel import selfSupSamModel

# Self-supervised two-encoder model; only its image encoder is reused for model2 below.
# Path parts must not start with "/": os.path.join discards every earlier component when it meets an
# absolute one, which silently dropped CHECKPOINT_DIR.
# Earlier run: mmsamsecondencoder/sam-selfsup-1-MMSAM-epoch=63-val_loss=0.001-val_iou=0.000.ckpt
sam_checkpoint = os.path.join(
    CHECKPOINT_DIR, "mmsamsecondencoder", "sam-selfsup-3-MMSAM-epoch=62-val_loss=0.001-val_iou=0.000.ckpt"
)

# Two separate, identically configured VH encoder instances, passed to
# selfSupSamModel.load_from_checkpoint as image_encoder / target_image_encoder.
vh_encoder_kwargs = dict(
    model_name="vit_b",
    normalize=True,
    learning_rate=1e-5,
    adapt_patch_embed=False,
    input_type=InputTypes.VH,
    adapt=True,
)
image_encoder = LitSamModel(**vh_encoder_kwargs).model.image_encoder
target_image_encoder = LitSamModel(**vh_encoder_kwargs).model.image_encoder

selfsupmodel = selfSupSamModel.load_from_checkpoint(
    sam_checkpoint,
    model_name="vit-b",
    normalize=True,
    adapt_patch_embed=False,
    image_encoder=image_encoder,
    target_image_encoder=target_image_encoder,
)

# Combine the self-supervised image encoder with model1's mask decoder.
encoder = selfsupmodel.model.image_encoder
decoder = model1.model.mask_decoder

# NOTE(review): the next cell rebinds model2 to a checkpoint-loaded VH model; run only the variant
# you intend to evaluate.
model2 = LitSamModel(
    model_name="facebook/sam-vit-base",
    normalize=True,
    input_type=InputTypes.VH,
    adapt=True,
    encoder=encoder,
    decoder=decoder,
)

In [None]:
# Standard SAM adapter trained on VH with a frozen decoder.
# NOTE(review): this rebinds model2 created by the previous cell; run only the variant you want.
# Path parts must not start with "/": os.path.join would discard CHECKPOINT_DIR.
sam_checkpoint = os.path.join(
    CHECKPOINT_DIR, "ourssecondencoder", "sam-adapter-complementaryVH-1-epoch=50-val_loss=0.245-val_iou=0.575.ckpt"
)

model2 = LitSamModel.load_from_checkpoint(
    sam_checkpoint,
    model_name="facebook/sam-vit-base",
    normalize=True,
    input_type=InputTypes.VH,
    adapt=True,
)

In [None]:
def calculate_iou(mask1, mask2, empty_value=float('nan')):
    """Compute the intersection-over-union of two masks.

    Any element greater than zero counts as foreground, so binary, soft and 0/255 masks are all
    accepted. The two masks must be broadcast-compatible.

    Args:
        mask1: First mask (torch tensor, NumPy array, or nested sequence).
        mask2: Second mask, same kinds as ``mask1``.
        empty_value: IoU reported when both masks are empty (union of zero pixels). Defaults to
            NaN, which is what the plain 0/0 division produced before.

    Returns:
        Tuple ``(iou, intersection, union)``: ``iou`` is a 0-d float tensor, ``intersection`` and
        ``union`` are the boolean foreground maps.
    """
    # Non-tensor inputs are copied into tensors (np.asarray also accepts plain Python sequences).
    if not isinstance(mask1, torch.Tensor):
        mask1 = torch.tensor(np.asarray(mask1))
    if not isinstance(mask2, torch.Tensor):
        mask2 = torch.tensor(np.asarray(mask2))

    # Binarize: anything > 0 is foreground.
    fg1 = mask1 > 0
    fg2 = mask2 > 0

    intersection = torch.logical_and(fg1, fg2)
    union = torch.logical_or(fg1, fg2)

    union_count = torch.sum(union).float()
    if union_count == 0:
        # 0/0 is undefined; return the caller-chosen value (same dtype/device as a normal result).
        return union_count.new_tensor(empty_value), intersection, union

    iou = torch.sum(intersection).float() / union_count
    return iou, intersection, union

In [None]:
from model.sumModel import SumSamModel, SAMDataset as SumDataset, SAMDataset3 as SumDataset3
from torch.utils.data import DataLoader

# SUM model over model1 and model2, freshly constructed (not loaded from a checkpoint), HQ variant.
# NOTE(review): the next cell rebinds sumModel, test_dataset_sam and test_sum_dataloader (loading a
# checkpoint with HQ=False), so these objects are not the ones evaluated later.
sumModel = SumSamModel(
    models=[model1, model2],
    model_name="vit-b",
    normalize=True,
    HQ=True,
)

# Test-mode dataset without augmentation, fed in order in batches of 5.
test_dataset_sam = SumDataset3(
    dataset=test_dataset,
    processor=processor,
    augment=False,
    test=True,
)
test_sum_dataloader = DataLoader(
    test_dataset_sam,
    batch_size=5,
    shuffle=False,
    num_workers=8,
    drop_last=True,
)

In [None]:
from model.sumModel import SumSamModel, SAMDataset as SumDataset, SAMDataset3 as SumDataset3
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model1.to(device)
model2.to(device)

# SUM model combining model1 and model2, restored from the "sfg" run.
# Path parts must not start with "/": os.path.join discards every earlier component when it meets an
# absolute one, which silently dropped CHECKPOINT_DIR.
# Alternative run: sfg/sam-sfg-2-epoch=48-val_loss=0.204-val_iou=0.626.ckpt
sam_checkpoint = os.path.join(CHECKPOINT_DIR, "sfg", "sam-sfg-1-epoch=06-val_loss=0.201-val_iou=0.624.ckpt")
sumModel = SumSamModel.load_from_checkpoint(
    sam_checkpoint, models=[model1, model2], model_name="vit-b", normalize=True, HQ=False
)

sumModel.to(device)

# Test-mode dataset without augmentation, fed in order in batches of 5.
test_dataset_sam = SumDataset3(dataset=test_dataset, processor=processor, augment=False, test=True)

# NOTE(review): drop_last=True discards the final partial batch, so up to batch_size - 1 test samples
# are never evaluated; downstream cells should size their loops from the predictions, not the dataset.
test_sum_dataloader = DataLoader(test_dataset_sam, batch_size=5, shuffle=False, num_workers=8, drop_last=True)

In [None]:
# Evaluate the SUM model on the test loader. The accelerator is chosen the same way `device` was,
# so the cell also runs on CPU-only machines instead of failing at Trainer construction.
trainer3 = pl.Trainer(accelerator="gpu" if torch.cuda.is_available() else "cpu", devices=1)

trainer3.test(sumModel, dataloaders=test_sum_dataloader)

In [None]:
test_results3 = sumModel.test_results

# Pull out the stored outputs; the next cell concatenates them along axis 0
# (presumably one entry per batch).
ground_truth_masks3 = test_results3["ground_truth_masks"]
predicted_masks3 = test_results3["predicted_masks"]
individual_ious3 = test_results3["individual_ious"]
bboxes3 = test_results3["bboxes"]

In [None]:
# Stack the per-batch outputs into per-sample arrays along axis 0. The inputs are read from
# test_results3, not from the names this cell rebinds. That way a re-run cannot concatenate
# already-stacked arrays a second time, which would change their shapes or fail.
predicted_masks3 = np.concatenate(test_results3['predicted_masks'], axis=0)
ground_truth_masks3 = np.concatenate(test_results3['ground_truth_masks'], axis=0)
bounding_boxes3 = np.concatenate(test_results3['bboxes'], axis=0)

In [None]:
# Select exactly the samples that received a prediction, instead of a hard-coded count. The test
# DataLoader uses shuffle=False and drop_last=True, so predictions cover a prefix of the dataset
# that can be shorter than it. Assumes SumDataset3 maps dataset rows 1:1 (TODO confirm).
subset_indices = list(range(len(predicted_masks3)))
test_dataset_subset = test_dataset.select(subset_indices)

In [None]:
# Build per-sample result records via generate_results (from the helperFunctions star import).
results_zero_shot3 = generate_results(
    test_dataset_subset,
    ground_truth_masks3,
    predicted_masks3,
    bounding_boxes3,
)

In [None]:
# Split results into samples prompted with a specific box and samples prompted with the full-image
# box (0, 0, 1024, 1024), then report pixel-based metrics for each group.
FULL_IMAGE_BBOX = np.array([0, 0, 1024, 1024])

results_zero_shot_copy3 = results_zero_shot3.copy()
# np.asarray accepts list or array bboxes; broadcasting also matches bboxes stored as (1, 4) arrays.
is_full_image3 = [bool(np.all(np.asarray(res['bbox']) == FULL_IMAGE_BBOX)) for res in results_zero_shot_copy3]
results_zero_shot3 = [res for res, full in zip(results_zero_shot_copy3, is_full_image3) if not full]
results_zero_shot_general3 = [res for res, full in zip(results_zero_shot_copy3, is_full_image3) if full]

results, metrics = calculate_pixel_based_metrics(results_zero_shot3)
print(metrics)

results, metrics = calculate_pixel_based_metrics(results_zero_shot_general3)
print(metrics)

# Restore the unfiltered list. The old code assigned to `results_zero_shot` (a name nothing reads),
# so later cells got only the specific-box subset and their full-image splits came out empty.
# Restoring also makes this cell idempotent.
results_zero_shot3 = results_zero_shot_copy3
results_zero_shot = results_zero_shot3  # old name kept for backward compatibility

In [None]:
# One row per result record; the record keys become the columns.
df_finetune3 = pd.DataFrame(data=results_zero_shot3)

In [None]:
# Drop rows whose IoU is missing.
df_filtered3 = df_finetune3[df_finetune3['iou'].notna()]

# True for rows prompted with the full-image box (0, 0, 1024, 1024) rather than a specific box.
mask_bbox3 = df_filtered3['bbox'].apply(
    lambda box: np.all(np.array(box) == np.array((0, 0, 1024, 1024)))
)

In [None]:
# Specific-box prompts only. Rebuilt from df_finetune3 (the same dropna as the cell above, so the
# index aligns with mask_bbox3). This keeps the cell idempotent and independent of whether the
# full-image cell below ran first.
df_filtered3 = df_finetune3.dropna(subset=['iou']).loc[~mask_bbox3]

In [None]:
# Full-image (0, 0, 1024, 1024) prompts only. Rebuilt from df_finetune3 rather than the current
# df_filtered3: after the specific-box cell has run, filtering df_filtered3 again would give an empty frame.
df_filtered3 = df_finetune3.dropna(subset=['iou']).loc[mask_bbox3]

In [None]:
# IoU statistics of df_filtered3; model_name is presumably only the label shown in the output.
# NOTE(review): several cells plot the same frame under different labels; keep the matching one.
plot_iou_statistics(
    df_filtered3,
    model_name='SUM model sfg HQ',
)

In [None]:
# IoU statistics of df_filtered3 under the "sfg MM" label (label presumably only affects the output).
plot_iou_statistics(
    df_filtered3,
    model_name='SUM model sfg MM',
)

In [None]:
# IoU statistics labelled for the full-image segmentation subset (df_filtered3 holds whichever
# subset the last filter cell selected).
plot_iou_statistics(
    df_filtered3,
    model_name='SUM model sfg full segmentation',
)

In [None]:
# IoU statistics of df_filtered3 under the "sfg" label.
plot_iou_statistics(
    df_filtered3,
    model_name='SUM model sfg',
)

In [None]:
# IoU statistics of df_filtered3 under the "MM" label.
# NOTE(review): the label does not change the data; confirm it matches the checkpoint loaded above.
plot_iou_statistics(
    df_filtered3,
    model_name='SUM model MM',
)

In [None]:
# IoU statistics of df_filtered3 under the generic "SUM model" label.
plot_iou_statistics(
    df_filtered3,
    model_name='SUM model',
)

In [None]:
# IoU plotted against ground-truth mask area.
plot_mask_area_vs_iou(
    df_filtered3,
    model_name='SUM model sfg',
)

In [None]:
# IoU plotted against the number of avalanches per sample.
plot_iou_vs_num_avalanches(
    df_filtered3,
    model_name='SUM model sfg',
)

In [None]:
# Flatten bboxes stored as 2-D arrays (presumably shape (1, 4)) into plain 4-vectors.
# .assign builds a new frame instead of writing a column into what pandas may treat as a slice of
# df_finetune3, which can emit SettingWithCopyWarning.
df_filtered3 = df_filtered3.assign(
    bbox=df_filtered3['bbox'].apply(lambda b: np.squeeze(b) if np.array(b).ndim == 2 else b)
)

In [None]:
# IoU plotted against the area ratio (computed from the flattened bboxes of the previous cell).
plot_iou_vs_area_ratio(
    df_filtered3,
    model_name='SUM model sfg',
)

In [None]:
# Correlation of IoU with mask area, avalanche count and area ratio, linear and log scale.
# NOTE(review): the linear-scale labels carry a trailing space ('SUM model sfg '); kept verbatim in
# case the helpers key output on it.
compute_mask_area_iou_correlation(
    df_filtered3, calculate_correlation,
    model_name='SUM model sfg ',
)
compute_mask_area_iou_correlation(
    df_filtered3, calculate_correlation,
    scale='log', model_name='SUM model sfg',
)
compute_num_avalanche_iou_correlation(
    df_filtered3, calculate_correlation,
    model_name='SUM model sfg ',
)
compute_area_ratio_iou_correlation(
    df_filtered3, calculate_correlation,
    model_name='SUM model sfg ',
)
compute_area_ratio_iou_correlation(
    df_filtered3, calculate_correlation,
    scale='log', model_name='SUM model sfg',
)


In [None]:
# Function to find bounding boxes for each group of disconnected foreground pixels
def find_bounding_boxes(mask):
    """Return one axis-aligned box per outer contour of the foreground in a 2-D mask.

    Args:
        mask: 2-D NumPy array. uint8 masks are used as-is. Other dtypes are converted to uint8:
            masks whose max is <= 1 are scaled by 255, all others are used directly. Values are
            clipped to [0, 255] first so negatives and values above 255 cannot wrap around.

    Returns:
        List of ``[x_min, y_min, x_max, y_max]`` boxes (max bounds exclusive, usable as slice
        ends), one per external contour. Empty list for a zero-size mask.
    """
    if mask.size == 0:
        # mask.max() raises on zero-size arrays, and there is nothing to find anyway.
        return []

    # OpenCV's findContours needs a single-channel 8-bit image.
    if mask.dtype != "uint8":
        if mask.max() <= 1:
            # Binary / probability mask: scale to 0-255 (clip guards against negative inputs).
            mask_uint8 = np.clip(mask * 255, 0, 255).astype('uint8')
        else:
            # A plain astype wraps (e.g. 256 -> 0, -1 -> 255); clip first so foreground survives.
            mask_uint8 = np.clip(mask, 0, 255).astype('uint8')
    else:
        mask_uint8 = mask

    # RETR_EXTERNAL keeps only outer contours, so holes inside a region produce no extra boxes.
    contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Convert OpenCV's (x, y, w, h) to (x_min, y_min, x_max, y_max).
    bounding_boxes = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        bounding_boxes.append([x, y, x + w, y + h])
    return bounding_boxes

In [None]:
# Object-level matching between ground-truth and predicted components (one component per bounding
# box from find_bounding_boxes). A ground-truth component is a false negative when no predicted
# component covers more than `treshold` of its pixels. A predicted component is a false positive
# when no ground-truth component covers more than `treshold` of its pixels.
treshold = 0.3


def _unmatched_boxes(src_mask, src_boxes, other_mask, other_boxes, threshold):
    """Return the boxes in ``src_boxes`` that no box in ``other_boxes`` sufficiently overlaps.

    For a (src, other) pair, overlap is the number of pixels that are non-zero in ``src_mask``
    inside the src box AND non-zero in ``other_mask`` inside the other box. That count is divided
    by the number of non-zero ``src_mask`` pixels inside the src box (0 if that count is 0).
    A src box is matched as soon as one pair exceeds ``threshold``.
    """
    unmatched = []
    for box in src_boxes:
        sx0, sy0, sx1, sy1 = box
        # Count pixels (not summed values) so the ratio is pixels/pixels, matching the numerator.
        src_area = np.count_nonzero(src_mask[sy0:sy1, sx0:sx1])
        matched = False
        for ox0, oy0, ox1, oy1 in other_boxes:
            # Each cropped mask is zero outside its own box, so the two can only overlap inside the
            # intersection rectangle of the boxes; no full-image temporaries are needed.
            x0, y0 = max(sx0, ox0), max(sy0, oy0)
            x1, y1 = min(sx1, ox1), min(sy1, oy1)
            if x1 > x0 and y1 > y0:
                overlap_area = np.count_nonzero(
                    np.logical_and(src_mask[y0:y1, x0:x1], other_mask[y0:y1, x0:x1])
                )
            else:
                overlap_area = 0
            overlap = overlap_area / src_area if src_area > 0 else 0
            if overlap > threshold:
                matched = True
                break
        if not matched:
            unmatched.append(box)
    return unmatched


for res in results_zero_shot3:
    # The predicted and ground truth avalanche masks
    pred_mask = res['calculated_mask']
    true_mask = res['mask']

    original_bbox = find_bounding_boxes(true_mask)
    pred_bbox = find_bounding_boxes(pred_mask)

    # Missed ground-truth components, relative to each ground-truth component's area.
    false_negatives = _unmatched_boxes(true_mask, original_bbox, pred_mask, pred_bbox, treshold)
    # Spurious predicted components, relative to each predicted component's area.
    false_positives = _unmatched_boxes(pred_mask, pred_bbox, true_mask, original_bbox, treshold)

    # Add the results into the current dictionary element
    res['false_negatives'] = false_negatives
    res['false_positives'] = false_positives
    res['percentage_false_negatives'] = len(false_negatives) / len(original_bbox) if len(original_bbox) > 0 else 0
    res['percentage_false_positives'] = len(false_positives) / len(pred_bbox) if len(pred_bbox) > 0 else 0


In [None]:
# helperFunctions is already star-imported in the first cell. Re-importing it here would only
# re-bind those of its names that this notebook has since redefined (if it exports any, e.g. a
# calculate_iou), silently replacing the notebook's versions.
results_zero_shot3 = compute_error_percentages_iou(results_zero_shot3, iou_threshold=0.5)

In [None]:
# Split into specific-box prompts and full-image prompts (bbox == [0, 0, 1024, 1024]).
FULL_IMAGE_BBOX = np.array([0, 0, 1024, 1024])

results_zero_shot_copy = results_zero_shot3.copy()
# np.asarray accepts list or array bboxes; broadcasting also matches bboxes stored as (1, 4) arrays.
is_full_image = [bool(np.all(np.asarray(res['bbox']) == FULL_IMAGE_BBOX)) for res in results_zero_shot_copy]
results_zero_shot3 = [res for res, full in zip(results_zero_shot_copy, is_full_image) if not full]
results_zero_shot_general = [res for res, full in zip(results_zero_shot_copy, is_full_image) if full]

In [None]:
# Restore the unfiltered result list (full-image prompts included) saved by the split cell above,
# so the next plot covers every sample.
results_zero_shot3 = results_zero_shot_copy

In [None]:
# Plot each sample's stored false-negative / false-positive percentages with dashed mean lines.
# NOTE(review): compute_error_percentages_iou (iou_threshold=0.5) ran after the overlap matching; if
# it rewrites these keys, the `treshold` shown in the title is not the one that applied.
false_negatives_pct = [res.get('percentage_false_negatives', 0) for res in results_zero_shot3]
false_positives_pct = [res.get('percentage_false_positives', 0) for res in results_zero_shot3]

if not false_negatives_pct:
    # np.mean([]) would return NaN (with a RuntimeWarning) and draw an empty chart.
    print("No results to plot.")
else:
    indices = np.arange(len(false_negatives_pct))
    avg_false_negatives = np.mean(false_negatives_pct)
    avg_false_positives = np.mean(false_positives_pct)

    fig, ax = plt.subplots(figsize=(12, 6))
    # Side-by-side bars per sample, offset by half the bar width.
    ax.bar(indices - 0.15, false_negatives_pct, width=0.3, color='red', label='False Negatives')
    ax.bar(indices + 0.15, false_positives_pct, width=0.3, color='blue', label='False Positives')

    ax.axhline(avg_false_negatives, color='darkred', linestyle='--',
               label=f'Avg False Negatives: {avg_false_negatives:.2f}')
    ax.axhline(avg_false_positives, color='darkblue', linestyle='--',
               label=f'Avg False Positives: {avg_false_positives:.2f}')

    ax.set_xlabel('Sample Index')
    ax.set_ylabel('Percentage')
    ax.set_title(f'Percentage of False Negatives and False Positives per Sample (Baseline, Threshold = {treshold})')
    ax.legend()
    plt.show()

In [None]:
# Same plot for the full-image (0, 0, 1024, 1024) prompts only.
false_negatives_pct = [res.get('percentage_false_negatives', 0) for res in results_zero_shot_general]
false_positives_pct = [res.get('percentage_false_positives', 0) for res in results_zero_shot_general]

if not false_negatives_pct:
    # The full-image subset can be empty; np.mean([]) would give NaN and an empty chart.
    print("No full-image samples to plot.")
else:
    indices = np.arange(len(false_negatives_pct))
    avg_false_negatives = np.mean(false_negatives_pct)
    avg_false_positives = np.mean(false_positives_pct)

    fig, ax = plt.subplots(figsize=(12, 6))
    # Side-by-side bars per sample, offset by half the bar width.
    ax.bar(indices - 0.15, false_negatives_pct, width=0.3, color='red', label='False Negatives')
    ax.bar(indices + 0.15, false_positives_pct, width=0.3, color='blue', label='False Positives')

    ax.axhline(avg_false_negatives, color='darkred', linestyle='--',
               label=f'Avg False Negatives: {avg_false_negatives:.2f}')
    ax.axhline(avg_false_positives, color='darkblue', linestyle='--',
               label=f'Avg False Positives: {avg_false_positives:.2f}')

    ax.set_xlabel('Sample Index')
    ax.set_ylabel('Percentage')
    ax.set_title('Percentage of False Negatives and False Positives per Sample')
    ax.legend()
    plt.show()

In [None]:
# Oracle threshold sweep. For each sample prompted with a specific box, binarize the sigmoid of the
# stored prediction at `num_thresholds` evenly spaced thresholds in [0, 1] and keep the threshold
# with the best IoU against the ground truth.
# NOTE: the threshold is chosen per sample using the test labels, so `optimal_ious` is an upper
# bound on achievable IoU, not a deployable score.
FULL_IMAGE_BBOX = np.array([0, 0, 1024, 1024])

num_thresholds = 50
thresholds = np.linspace(0, 1, num_thresholds)
optimal_thresholds = []
gt_areas = []
optimal_ious = []

# One record per (sample, threshold) pair, for inspecting full IoU-vs-threshold curves.
results_threshold = []

# Iterate over the predictions actually produced (drop_last=True in the test DataLoader can make
# this fewer than the dataset size) instead of a hard-coded sample count.
for index in range(len(predicted_masks3)):

    # Skip full-image prompts. The comparison broadcasts like the earlier filter cells, so a bbox
    # stored as a (1, 4) array is recognized too (np.array_equal is shape-strict and missed those).
    if np.all(np.asarray(bounding_boxes3[index]) == FULL_IMAGE_BBOX):
        continue

    mask = ground_truth_masks3[index]
    sam_seg = predicted_masks3[index]

    # Treat the stored predictions as logits: the sigmoid maps them to probabilities.
    sam_seg_prob = torch.sigmoid(torch.tensor(sam_seg))
    sam_seg_prob_np = sam_seg_prob.cpu().numpy().squeeze()

    best_iou = -1
    best_thr = None

    for thr in thresholds:
        pred_mask = (sam_seg_prob_np > thr).astype(np.uint8)

        # calculate_iou returns (iou, intersection, union)
        iou, inter, union = calculate_iou(mask, pred_mask)
        iou_val = iou.cpu().item() if isinstance(iou, torch.Tensor) else iou

        results_threshold.append({'sample': index, 'threshold': thr, 'iou': iou_val})

        # A NaN IoU (both masks empty) never compares greater, so it cannot become the best.
        if iou_val > best_iou:
            best_iou = iou_val
            best_thr = thr

    optimal_thresholds.append(best_thr)
    optimal_ious.append(best_iou)
    # Ground-truth area = number of foreground (non-zero) pixels.
    area = np.sum(mask > 0)
    gt_areas.append(area)

In [None]:
# Show the per-sample optimal IoUs from the threshold sweep as a bar chart and as a line plot,
# each annotated with the average.
if not optimal_ious:
    # sum(...) / len(...) would raise ZeroDivisionError when every sample was skipped.
    print("No samples with a specific bounding box; nothing to plot.")
else:
    average_optimal_iou = sum(optimal_ious) / len(optimal_ious)
    sample_indices = range(len(optimal_ious))

    for kind in ('bar', 'line'):
        fig, ax = plt.subplots(figsize=(10, 6))
        if kind == 'bar':
            ax.bar(sample_indices, optimal_ious, color='green')
        else:
            ax.plot(sample_indices, optimal_ious, marker='o', linestyle='-', color='green')
        ax.set_xlabel('Sample Index')
        ax.set_ylabel('Optimal IoU')
        ax.set_title('Optimal IoU Values for Different Samples')
        # Average printed in axes coordinates (top centre), independent of the data range.
        ax.text(0.5, 0.95, f'Average Optimal IoU: {average_optimal_iou:.4f}',
                ha='center', va='center', transform=ax.transAxes,
                fontsize=12, bbox=dict(facecolor='white', alpha=0.5))
        plt.show()