In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import pandas as pd
from datasets import Dataset, load_from_disk
import pytorch_lightning as pl
import sys
import os
# Add the parent directory to the Python path so the `model` and `utils` packages imported below can be found
sys.path.append(os.path.abspath("../"))
from model.minor_models.sfgModelST import SumSamModel, SAMDataset3
from utils.statistics import calculate_correlation
from helperFunctions import *

In [None]:
import yaml
import os
from pathlib import Path

# 1. Locate the directory this code lives in (expected: <project>/src/<subdir>/).
#    Jupyter kernels do not define __file__, so fall back to the working directory the
#    kernel started in. That directory is remembered in _NOTEBOOK_START_DIR on the first
#    run, because step 4 changes the working directory and a re-run of this cell would
#    otherwise resolve every path relative to the project root instead.
try:
    current_dir = Path(__file__).resolve().parent  # e.g. src/training/
except NameError:
    if "_NOTEBOOK_START_DIR" not in globals():
        _NOTEBOOK_START_DIR = Path.cwd().resolve()
    current_dir = _NOTEBOOK_START_DIR

# 2. Go up one level to 'src', then into 'config'
config_path = current_dir.parent / "config" / "config_general.yaml"

# 3. Load the YAML (safe_load only builds plain Python types)
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

# 4. Resolve the root of the project (one level above 'src') and make it the working
#    directory, so that relative paths in the YAML (e.g. "./data") resolve against it.
PROJECT_ROOT = current_dir.parent.parent
os.chdir(PROJECT_ROOT)

# Extract paths from YAML
DATA_DIR = config['paths']['data']
CHECKPOINT_DIR = config['paths']['checkpoints']
SAM_CHECKPOINT = config['paths']['sam_checkpoint']

In [None]:
# Load the test split stored on disk at DATA_DIR/datasetTestFinal.
test_dataset_path = os.path.join(DATA_DIR, 'datasetTestFinal')
test_dataset = load_from_disk(test_dataset_path)


In [None]:
from transformers import SamModel, SamConfig, SamProcessor
import torch

# Fine-tuned checkpoint to evaluate. The file name must be relative: os.path.join
# discards every earlier component (here CHECKPOINT_DIR) when a later one starts with "/".
sam_checkpoint = os.path.join(
    CHECKPOINT_DIR,
    "sam-adapters-simultaneousencoders-loss-epoch=146-val_loss=0.203-val_iou=0.604.ckpt",
)

# Restore the Lightning module (architecture + fine-tuned weights) from the checkpoint.
# Extra keyword arguments supply/override the init hyperparameters stored in the checkpoint.
model = SumSamModel.load_from_checkpoint(sam_checkpoint, model_name="sam-vit-b", normalize=True, adapt=True)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

In [None]:
def calculate_iou(mask1, mask2):
    """Compute the Intersection-over-Union (IoU) of two masks.

    Every element > 0 is treated as foreground.

    Args:
        mask1, mask2: masks given as torch tensors, NumPy arrays or (nested) lists.
            They should have the same shape; broadcast-compatible shapes are combined
            element-wise by torch without any check.

    Returns:
        tuple ``(iou, intersection, union)``: ``iou`` is a 0-d float tensor;
        ``intersection`` and ``union`` are boolean tensors. If both masks are empty the
        union is empty and ``iou`` is NaN (0 / 0).
    """
    def _binarize(mask):
        # Tensors are thresholded directly. Anything else is thresholded in NumPy, which
        # yields a fresh contiguous bool array that torch.from_numpy wraps without a copy
        # (this also covers lists, negative-stride views and dtypes torch may not support).
        if isinstance(mask, torch.Tensor):
            return mask > 0
        return torch.from_numpy(np.asarray(np.asarray(mask) > 0))

    mask1 = _binarize(mask1)
    mask2 = _binarize(mask2)

    intersection = torch.logical_and(mask1, mask2)
    union = torch.logical_or(mask1, mask2)

    iou = intersection.sum().float() / union.sum().float()
    return iou, intersection, union

In [None]:
from torch.utils.data import DataLoader

# Wrap the raw test split with the project's SAM preprocessing
# (no augmentation, target_size=1024, test mode).
test_dataset_sam = SAMDataset3(dataset=test_dataset, processor=processor, augment=False, target_size=1024, test=True)

# DataLoader for the test set (not validation). drop_last must be False for evaluation:
# with drop_last=True the final incomplete batch (len(test_dataset) % 5 samples) would
# silently never be evaluated.
test_dataloader = DataLoader(test_dataset_sam, batch_size=5, shuffle=False, num_workers=8, drop_last=False)

In [None]:
# Create a trainer
trainer = pl.Trainer(accelerator='gpu', devices=1)

# Run the evaluation
trainer.test(model, dataloaders=test_dataloader)

In [None]:
# Access the stored results
test_results = model.test_results
ground_truth_masks = test_results['ground_truth_masks']
predicted_masks = test_results['predicted_masks']
individual_ious = test_results['individual_ious']
bboxes = test_results['bboxes']

In [None]:
# Stack the per-batch arrays along the sample axis. Read from `test_results` rather than
# from the variables being reassigned, so re-running this cell does not concatenate the
# already-stacked arrays a second time (which would corrupt their shape).
predicted_masks = np.concatenate(test_results['predicted_masks'], axis=0)
ground_truth_masks = np.concatenate(test_results['ground_truth_masks'], axis=0)
bounding_boxes = np.concatenate(test_results['bboxes'], axis=0)

In [None]:
# Take exactly as many leading test samples as there are predictions, instead of a
# hard-coded count. The DataLoader uses shuffle=False, so prediction i corresponds to
# sample i. NOTE(review): assumes generate_results pairs rows with predictions by position.
n_predictions = len(predicted_masks)
assert n_predictions <= len(test_dataset), "more predictions than test samples"
subset_indices = list(range(n_predictions))
test_dataset_subset = test_dataset.select(subset_indices)

In [None]:
# Build the list of per-sample result dicts (turned into a DataFrame below) from the test
# samples, ground-truth masks, predicted masks and prompt boxes; see `generate_results`
# in helperFunctions for the exact fields.
# NOTE(review): despite the name, these are results of the fine-tuned checkpoint loaded
# above, not of a zero-shot model.
results_zero_shot = generate_results(test_dataset_subset, ground_truth_masks, predicted_masks, bounding_boxes)

In [None]:
# Create a DataFrame from the list of dictionaries
df_finetune = pd.DataFrame(results_zero_shot)

# Save the DataFrame to a file (optional)
#df_finetune.to_pickle('dataframe_finetune_20epochs.pkl')

In [None]:
# Optionally replace the freshly computed results with a previously pickled frame.
# Off by default: the file name suggests a different (20-epoch) run, so loading it would
# silently analyse results other than those of the checkpoint evaluated above.
# Only unpickle files you created yourself: unpickling can execute arbitrary code.
LOAD_CACHED_RESULTS = False
if LOAD_CACHED_RESULTS:
    df_finetune = pd.read_pickle('dataframe_finetune_20epochs.pkl')

In [None]:
# Filter out rows with None IoU values
df_filtered = df_finetune.dropna(subset=['iou'])


# Create a boolean mask for rows with bbox (0,0,512,512) using apply
mask_bbox = df_filtered['bbox'].apply(
    lambda b: np.all(np.array(b) == np.array((0, 0, 1024, 1024)))
)

In [None]:
# Subset A: exclude samples prompted with the full-image box (0, 0, 1024, 1024).
# Rebuilt from `df_finetune` with the same dropna used for `mask_bbox` (so the indices
# align) instead of re-filtering `df_filtered`. This keeps the cell idempotent and
# independent of the alternative "keep only full-image boxes" cell. `.copy()` lets later
# column assignments act on an independent frame.
df_filtered = df_finetune.dropna(subset=['iou'])[~mask_bbox].copy()

In [None]:
# Subset B (alternative to the "exclude full-image boxes" cell; run one or the other):
# keep only samples prompted with the full-image box (0, 0, 1024, 1024).
# Rebuilt from `df_finetune` so it never re-filters an already-filtered frame, which
# would leave an empty DataFrame.
df_filtered = df_finetune.dropna(subset=['iou'])[mask_bbox].copy()

In [None]:
# IoU statistics plot for the currently selected subset `df_filtered`; the helper comes
# from the star-imported helperFunctions and "smallprefix" is the model label passed to it.
plot_iou_statistics(df_filtered, model_name="smallprefix")

In [None]:
# Mask-area vs. IoU plot for `df_filtered` (see `plot_mask_area_vs_iou` in helperFunctions).
plot_mask_area_vs_iou(df_filtered, model_name='smallprefix')

In [None]:
# Some boxes are stored as 2-D arrays (e.g. shape (1, 4)); squeeze them to flat boxes.
# `assign` returns a new frame, avoiding SettingWithCopyWarning when `df_filtered` is a
# slice of another DataFrame.
df_filtered = df_filtered.assign(
    bbox=df_filtered['bbox'].apply(lambda b: np.squeeze(b) if np.ndim(b) == 2 else b)
)

In [None]:
# IoU vs. area-ratio plot for `df_filtered` (see helperFunctions).
# NOTE(review): the preceding cell flattens `bbox`; this helper likely relies on that — confirm.
plot_iou_vs_area_ratio(df_filtered, model_name='smallprefix')

In [None]:
# IoU vs. number-of-avalanches plot for `df_filtered` (see helperFunctions).
plot_iou_vs_num_avalanches(df_filtered, model_name='smallprefix')

In [None]:
# IoU-by-mask-area plot for `df_filtered` (see `plot_iou_for_mask_area` in helperFunctions).
plot_iou_for_mask_area(df_filtered, model_name='smallprefix')

In [None]:
# Mask-area / IoU correlation for `df_filtered`, using `calculate_correlation`
# from utils.statistics as the correlation function.
compute_mask_area_iou_correlation(df_filtered, calculate_correlation, model_name='smallprefix')

In [None]:
# Same correlation as above with scale='log' (how the log scale is applied is defined
# by the helper in helperFunctions).
compute_mask_area_iou_correlation(df_filtered, calculate_correlation, model_name='smallprefix', scale='log')

In [None]:
# Number-of-avalanches / IoU correlation for `df_filtered`, using `calculate_correlation`.
compute_num_avalanche_iou_correlation(df_filtered, calculate_correlation, model_name='smallprefix')

In [None]:
# Area-ratio / IoU correlation for `df_filtered`, using `calculate_correlation`.
compute_area_ratio_iou_correlation(df_filtered, calculate_correlation, model_name='smallprefix')

In [None]:
# Same area-ratio / IoU correlation with scale='log' (semantics defined by the helper).
compute_area_ratio_iou_correlation(df_filtered, calculate_correlation, model_name='smallprefix', scale='log')

In [None]:
from helperFunctions import *

In [None]:
# Enrich each result record via `compute_error_percentages`; presumably this adds the
# 'false_negatives' / 'false_positives' boxes and 'percentage_*' fields read by the
# visualisation cells below — confirm in helperFunctions.
# NOTE(review): this rebinds `results_zero_shot` to the helper's output, so re-running the
# cell feeds already-processed records back in; confirm the helper is idempotent.
results_zero_shot = compute_error_percentages(results_zero_shot)

In [None]:
# Split the results by prompt box: boxes equal to the full 1024x1024 image,
# (0, 0, 1024, 1024), go to `results_zero_shot_general`; all others stay in
# `results_zero_shot`. `results_zero_shot_copy` keeps the unfiltered list (shallow copy).
# np.asarray makes the comparison work for list, tuple and array boxes (incl. shape (1, 4)).
results_zero_shot_copy = results_zero_shot.copy()
is_full_image_bbox = [
    bool(np.all(np.asarray(res['bbox']) == (0, 0, 1024, 1024))) for res in results_zero_shot_copy
]
results_zero_shot = [res for res, full in zip(results_zero_shot_copy, is_full_image_bbox) if not full]
results_zero_shot_general = [res for res, full in zip(results_zero_shot_copy, is_full_image_bbox) if full]

In [None]:
# Restore the full, unfiltered result list saved as `results_zero_shot_copy` by the split
# cell above, undoing its exclusion of full-image boxes. Skip this cell to analyse only
# the samples whose prompt box is not the full image.
results_zero_shot = results_zero_shot_copy

In [None]:
# False-positive / false-negative percentage plot for `results_zero_shot` (all samples if
# the restore cell above was run). `treshold` (sic) is the helper's own keyword name.
plot_false_pos_neg_percentages(results_zero_shot,treshold= 0.3, model_name = "smallprefix")

In [None]:
# Same plot restricted to samples prompted with the full-image box (0, 0, 1024, 1024).
# `treshold` (sic) is the helper's own keyword name.
plot_false_pos_neg_percentages(results_zero_shot_general,treshold= 0.3, model_name = "smallprefix")

In [None]:
# New cell: Visualization with mask background and bbox outlines
# For the first `num_results_to_display` samples, show the ground-truth mask with its
# false-negative boxes (red) next to the predicted mask with its false-positive boxes (blue).

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2
import numpy as np

num_results_to_display = 15

for res in results_zero_shot[:num_results_to_display]:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    # (axis, mask, boxes to outline, edge colour, title, label for the printed count)
    panels = (
        (ax1, res['mask'], res.get('false_negatives', []), 'r',
         "Ground Truth with False Negatives", "false negatives"),
        (ax2, res['calculated_mask'], res.get('false_positives', []), 'b',
         "Prediction with False Positives", "false positives"),
    )
    for ax, panel_mask, boxes, edge_color, title, label in panels:
        if panel_mask.ndim == 2:
            # NOTE(review): assumes mask values in [0, 1]; uint8 masks already in 0..255
            # would wrap around when multiplied by 255.
            panel_img = cv2.cvtColor((panel_mask * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB)
        else:
            panel_img = panel_mask.copy()
        ax.imshow(panel_img)
        ax.set_title(title)
        print(f"Length of {label}:", len(boxes))
        for bbox in boxes:
            # bbox is expected as [x_min, y_min, x_max, y_max]; Rectangle takes the
            # corner plus width and height.
            ax.add_patch(patches.Rectangle((bbox[0], bbox[1]),
                                           bbox[2] - bbox[0],
                                           bbox[3] - bbox[1],
                                           linewidth=2,
                                           edgecolor=edge_color,
                                           facecolor='none'))
        ax.axis('off')
    plt.show()
    plt.close(fig)  # release the figure; one is created per sample

In [None]:
# New cell: Visualization of samples with high error percentages
# Show ground truth (false-negative boxes in red) next to the prediction (false-positive
# boxes in blue) for samples whose error percentage exceeds `threshold`.

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2
import numpy as np

# Threshold on 'percentage_false_negatives' / 'percentage_false_positives'.
# NOTE(review): whether these values are fractions (0-1) or percents (0-100) is defined by
# compute_error_percentages — confirm before interpreting 0.8.
threshold = 0.8
# Maximum number of high-error samples to plot.
max_samples_to_display = 10

# Keep samples where either error percentage exceeds the threshold (missing keys count as 0).
high_error_results = [
    res for res in results_zero_shot
    if res.get('percentage_false_negatives', 0) > threshold or res.get('percentage_false_positives', 0) > threshold
]

print(f"Number of high error samples: {len(high_error_results)}")

for res in high_error_results[:max_samples_to_display]:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    # (axis, mask, boxes to outline, edge colour, title) for the two panels
    panels = (
        (ax1, res['mask'], res.get('false_negatives', []), 'r', "Ground Truth with False Negatives"),
        (ax2, res['calculated_mask'], res.get('false_positives', []), 'b', "Prediction with False Positives"),
    )
    for ax, panel_mask, boxes, edge_color, title in panels:
        if panel_mask.ndim == 2:
            # NOTE(review): assumes mask values in [0, 1]; uint8 masks already in 0..255
            # would wrap around when multiplied by 255.
            panel_img = cv2.cvtColor((panel_mask * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB)
        else:
            panel_img = panel_mask.copy()
        ax.imshow(panel_img)
        ax.set_title(title)
        for bbox in boxes:
            # bbox is [x_min, y_min, x_max, y_max]; Rectangle takes the corner plus
            # width and height.
            ax.add_patch(patches.Rectangle((bbox[0], bbox[1]),
                                           bbox[2] - bbox[0],
                                           bbox[3] - bbox[1],
                                           linewidth=2,
                                           edgecolor=edge_color,
                                           facecolor='none'))
        ax.axis('off')  # hide ticks, matching the other visualisation cell
    plt.show()
    plt.close(fig)  # release the figure; one is created per sample