In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import pandas as pd
from datasets import Dataset, load_from_disk, concatenate_datasets
import pytorch_lightning as pl
import sys
import os
# Add the parent directory to the Python path so the local `model` package can be imported
sys.path.append(os.path.abspath("../"))
from model.minor_models.selfSupervisedModel import selfSupSamModel, SAMDataset

In [None]:
import yaml
import os
from pathlib import Path

# 1. Get the path of the script.
# `__file__` only exists when this code runs as a script; inside a Jupyter kernel
# it raises NameError. In that case fall back to the kernel's working directory,
# which is normally the notebook's own directory when the kernel starts.
try:
    current_file = Path(__file__).resolve()  # src/training/your_script.py
except NameError:
    # Remember the start-up directory: the os.chdir() below changes the working
    # directory, so re-running this cell must not re-derive it from Path.cwd().
    if "_NOTEBOOK_DIR" not in globals():
        _NOTEBOOK_DIR = Path.cwd().resolve()
    # Placeholder file name; only its parent directories are used below.
    current_file = _NOTEBOOK_DIR / "notebook.ipynb"

# 2. Go up one level to 'src', then into 'config'
config_path = current_file.parent.parent / "config" / "config_general.yaml"

# 3. Load the YAML
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

# 4. Resolve the root of the project (one level above 'src')
# This ensures that "./data" in the YAML is interpreted relative to the Project_Root
PROJECT_ROOT = current_file.parent.parent.parent
os.chdir(PROJECT_ROOT)

# Extract paths from YAML
DATA_DIR = config['paths']['data']
CHECKPOINT_DIR = config['paths']['checkpoints']
SAM_CHECKPOINT = config['paths']['sam_checkpoint']

In [None]:
# Load the test dataset from the configured data directory.
# The dataset folder is given relative to DATA_DIR: os.path.join() discards every
# earlier component as soon as a later one is absolute, so an absolute
# "/home/..." argument would silently bypass DATA_DIR.
# NOTE(review): assumes the dataset folder lives directly under DATA_DIR — confirm.
TEST_DATASET_NAME = "datasetTestSelfSupSlope"
test_dataset = load_from_disk(os.path.join(DATA_DIR, TEST_DATASET_NAME))


In [None]:
# Combine two separately loaded test splits into a single dataset.
# `test_dataset1` / `test_dataset2` are not created anywhere in the visible
# notebook, so the merge only runs when both have been defined; otherwise the
# `test_dataset` loaded earlier is kept unchanged instead of raising NameError.
_split_names = ("test_dataset1", "test_dataset2")
if all(name in globals() for name in _split_names):
    test_dataset = concatenate_datasets([test_dataset1, test_dataset2])

In [None]:
from transformers import SamModel, SamConfig, SamProcessor
import torch

# Checkpoint of the self-supervised SAM model, resolved inside CHECKPOINT_DIR.
# The file name must not start with "/": os.path.join() would treat it as an
# absolute path and drop CHECKPOINT_DIR entirely.
sam_checkpoint = os.path.join(CHECKPOINT_DIR, "sam-float-selfsup_slope-model-epoch=146-val_loss=0.0013.ckpt")

# Restore the Lightning module from the checkpoint and load the matching processor
model = selfSupSamModel.load_from_checkpoint(sam_checkpoint, model_name="facebook/sam-vit-base", normalize = True)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# set the device to cuda if available, otherwise use cpu
device = "cuda" if torch.cuda.is_available() else "cpu"
#model.to(device)

In [None]:
# Inspect sample 0: convert its image entry to a NumPy array and show the shape.
idx = 0
first_sample = test_dataset[idx]
image = np.array(first_sample["image"])
print(image.shape)


In [None]:
# Report the value range of the image loaded above.
print(f"Max:{np.max(image)}")
print(f"Min:{np.min(image)}")

In [None]:
# Keep the raw (unconverted) image entry of sample `idx` for inspection.
array = test_dataset[idx]["image"]

In [None]:
# Visualise sample `idx`: the image, its label mask, and both with the
# sample's box(es) drawn on top.
idx = 4
image = np.array(test_dataset[idx]["image"], np.float32)
mask = np.array(test_dataset[idx]['label'], np.float32)

# Draw on copies so that `image` and `mask` themselves stay untouched.
image_copy, mask_copy = image.copy(), mask.copy()

# Each entry of the sliced 'box' column is unpacked as (x, y, w, h) and drawn
# as a green rectangle of thickness 2 on both copies.
box_color, box_thickness = (0, 255, 0), 2
for (x, y, w, h) in test_dataset[idx:idx+1]['box']:
    top_left, bottom_right = (x, y), (x + w, y + h)
    for canvas in (image_copy, mask_copy):
        cv2.rectangle(canvas, top_left, bottom_right, box_color, box_thickness)

# Add an alpha channel (RGB -> RGBA) to the arrays that are displayed
image_copy_rgb = cv2.cvtColor(image_copy, cv2.COLOR_RGB2RGBA)
mask_copy_rgb = cv2.cvtColor(mask_copy, cv2.COLOR_RGB2RGBA)
image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)

# One row of three panels: original image, boxed mask, boxed image
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    (image_rgb, 'Original Image'),
    (mask_copy_rgb, 'Mask with Bounding Boxes'),
    (image_copy_rgb, 'Image with Bounding Boxes'),
]
for ax, (picture, title) in zip(axes, panels):
    ax.imshow(picture)
    ax.set_title(title)
    ax.axis('off')

plt.show()

In [None]:
from torch.utils.data import DataLoader

# Wrap the test dataset in SAMDataset, handing it the SAM processor and
# switching augmentation off.
test_dataset_sam = SAMDataset(
    dataset=test_dataset,
    processor=processor,
    augment=False,
)

# DataLoader over the test set: batches of 5, original order, 4 worker processes
test_dataloader = DataLoader(
    test_dataset_sam,
    batch_size=5,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Pull a single batch from the test dataloader
batch = next(iter(test_dataloader))

# Pick the sample at index 1 (the second one in the batch) and its mask
image, mask = batch["pixel_values"][1], batch["ground_truth_mask"][1]

# Report both shapes and the image dtype, one per line
for detail in (image.shape, mask.shape, image.dtype):
    print(detail)

In [None]:
# Value range (minimum, maximum) of the selected ground-truth mask.
print(mask.min(), mask.max())

In [None]:
import matplotlib.patches as patches

def display_image_mask_boxes(image, mask):
    """Show an image tensor and its mask tensor side by side.

    Both inputs are permuted with ``permute(1, 2, 0)`` (first axis moved to the
    end), moved to the CPU and converted to NumPy arrays before plotting.
    Despite the function name, no bounding boxes are drawn.

    Args:
        image (torch.Tensor): 3-D image tensor.
        mask (torch.Tensor): 3-D mask tensor.
    """
    # Convert both tensors first so a bad input fails before a figure is created
    arrays = [tensor.permute(1, 2, 0).cpu().numpy() for tensor in (image, mask)]

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    for axis, array, title in zip(ax, arrays, ("Image", "Mask")):
        axis.imshow(array)
        axis.set_title(title)
        axis.axis("off")

    plt.show()

# Display the selected image next to its mask (no bounding boxes are drawn)
display_image_mask_boxes(image, mask)

In [None]:
# Show the first sample of each of the first five batches
for i, batch in enumerate(test_dataloader):
    if i == 5:  # five batches have been displayed
        break

    # Sample 0 of the current batch and its mask
    image, mask = batch["pixel_values"][0], batch["ground_truth_mask"][0]

    # Plot the image next to its mask
    display_image_mask_boxes(image, mask)

In [None]:
import torch
import matplotlib.pyplot as plt

def denormalize(image, pixel_mean, pixel_std):
    """
    Undo a mean/std normalization and clamp the result to [0, 1].

    Args:
        image (torch.Tensor): Normalized tensor of shape (C, H, W).
        pixel_mean (torch.Tensor): Per-channel mean, shape (C, 1, 1).
        pixel_std (torch.Tensor): Per-channel standard deviation, shape (C, 1, 1).

    Returns:
        torch.Tensor: ``image * pixel_std + pixel_mean`` limited to the range [0, 1].
    """
    restored = image * pixel_std + pixel_mean
    return restored.clamp(0, 1)

def display_prediction(image_norm, prediction, ground_truth, model):
    """
    Plot the input image, the ground-truth mask and the model prediction in one row.

    A fourth panel showing ``model.model.image_encoder.last_mask`` is added when
    that attribute exists and is not None. The input image is plotted as given;
    no denormalization is applied in this function.

    Args:
        image_norm (torch.Tensor): Input image tensor (C, H, W).
        prediction (torch.Tensor): Model output; a leading axis of size 1 is squeezed
            away. A 2-D result is drawn in grayscale, anything else is transposed
            from (C, H, W) to (H, W, C).
        ground_truth (torch.Tensor): Ground truth mask, permuted like the image.
        model: Object whose ``model.image_encoder`` may carry a ``last_mask`` tensor.
    """
    # Tensors (C, H, W) -> NumPy arrays (H, W, C)
    img_np = image_norm.permute(1, 2, 0).cpu().numpy()
    gt_np = ground_truth.permute(1, 2, 0).cpu().numpy()

    # Single-channel predictions are shown in grayscale, others as-is
    pred_np = prediction.squeeze(0).cpu().numpy()
    single_channel = pred_np.ndim == 2
    if not single_channel:
        pred_np = pred_np.transpose(1, 2, 0)  # from (C,H,W) to (H,W,C)
    pred_cmap = "gray" if single_channel else None

    # The optional self-supervised mask decides between three and four panels
    selfsup_mask = getattr(model.model.image_encoder, "last_mask", None)
    ncols = 3 if selfsup_mask is None else 4

    panels = [
        (img_np, "Input Image", None),
        (gt_np, "Ground Truth Mask", None),
        (pred_np, "Prediction", pred_cmap),
    ]
    if selfsup_mask is not None:
        panels.append((selfsup_mask.cpu().numpy(), "SelfSup Mask", "gray"))

    fig, axs = plt.subplots(1, ncols, figsize=(6 * ncols, 6))
    for axis, (data, title, cmap) in zip(axs, panels):
        axis.imshow(data, cmap=cmap)
        axis.set_title(title)
        axis.axis("off")

    plt.show()

In [None]:
# Run the model on the first five batches and plot the first sample of each
for i, batch in enumerate(test_dataloader):
    if i >= 5:  # Display 5 different batches
        break

    pixel_values = batch["pixel_values"]
    image_norm = pixel_values[0]  # input image of the first sample
    gt_mask = batch["ground_truth_mask"][0]  # ground truth mask of the first sample
    target_size = image_norm.shape[-2:]

    # Forward pass without gradient tracking, then resize to the input resolution
    with torch.no_grad():
        pred = model(pixel_values.to(model.device))
        pred = torch.nn.functional.interpolate(pred, size=target_size, mode='bilinear', align_corners=False)
    pred = pred[0]  # take first sample

    # Plot input, ground truth and prediction side by side
    display_prediction(image_norm, pred, gt_mask, model)

In [None]:
import numpy as np

def analyze_mask(mask):
    """
    Counts the masked (> 0) and unmasked pixels of a mask and their percentages.

    Args:
        mask (array-like): Mask values; every element greater than 0 counts as
            masked. Anything accepted by ``np.asarray`` works (ndarray, nested
            lists, CPU tensors).

    Returns:
        dict: ``masked_pixels`` and ``unmasked_pixels`` (int counts) plus
        ``perc_masked`` and ``perc_unmasked`` (float percentages of all pixels).
        For an empty mask all four values are 0.
    """
    mask_arr = np.asarray(mask)
    total_pixels = int(mask_arr.size)
    masked_pixels = int(np.count_nonzero(mask_arr > 0))
    unmasked_pixels = total_pixels - masked_pixels

    if total_pixels == 0:
        # Avoid dividing by zero: an empty mask has no pixels of either kind
        perc_masked = perc_unmasked = 0.0
    else:
        perc_masked = (masked_pixels / total_pixels) * 100
        perc_unmasked = (unmasked_pixels / total_pixels) * 100

    return {
        "masked_pixels": masked_pixels,
        "unmasked_pixels": unmasked_pixels,
        "perc_masked": perc_masked,
        "perc_unmasked": perc_unmasked
    }

# Create a random mask of size 64x64 that masks ~30% of the pixels.
# A seeded generator keeps the demo output identical across kernel restarts.
MASK_DEMO_SEED = 0
demo_rng = np.random.default_rng(MASK_DEMO_SEED)
mask = (demo_rng.random((64, 64)) < 0.3).astype(np.uint8)

# Analyze the mask:
results = analyze_mask(mask)
print("Masked pixels:", results["masked_pixels"])
print("Unmasked pixels:", results["unmasked_pixels"])
print("Percentage masked: {:.2f}%".format(results["perc_masked"]))
print("Percentage unmasked: {:.2f}%".format(results["perc_unmasked"]))

# Optionally, visualize the mask:
import matplotlib.pyplot as plt
plt.imshow(mask, cmap="gray")
plt.title("Random Mask (approx. 30% masked)")
plt.axis("off")
plt.show()

In [None]:
# Create a trainer.
# accelerator='auto' lets Lightning use a GPU when one is available and fall
# back to the CPU otherwise, matching the cuda/cpu fallback used when the
# model was loaded (a hard-coded 'gpu' fails on CPU-only machines).
trainer = pl.Trainer(accelerator='auto', devices=1)

# Run the evaluation
trainer.test(model, dataloaders=test_dataloader)

In [None]:
# Outputs stored on the module as `test_outputs`
# (presumably filled during trainer.test — confirm in the model code).
outputs = model.test_outputs

In [None]:
# Shape of the 'test_iou' entry of the output at index 4.
outputs[4]['test_iou'].shape

In [None]:
# Pull the stored evaluation results off the model and unpack the three
# entries used by the cells below.
test_results = model.test_results
ground_truth_masks, predicted_masks, individual_ious = (
    test_results[key]
    for key in ('ground_truth_masks', 'predicted_masks', 'individual_ious')
)

In [None]:
# Look at the ground-truth entry at index 2 and check whether it contains the value 1.
sam_seg = ground_truth_masks[2]
print(sam_seg)
has_foreground = 1 in sam_seg
print(has_foreground)

In [None]:
# Merge the stored arrays along axis 0 into one array per quantity
# (presumably one array per batch — confirm in the model code).
# Read from `test_results` rather than from the variables being assigned, so
# re-running this cell does not concatenate already-merged arrays a second time.
predicted_masks = np.concatenate(test_results['predicted_masks'], axis=0)
ground_truth_masks = np.concatenate(test_results['ground_truth_masks'], axis=0)

In [None]:
# Build one record per test sample: ground-truth mask, thresholded prediction,
# intersection, union and IoU.
# NOTE(review): `calculate_iou` is not defined or imported in the visible
# notebook cells — confirm it is available before running this cell.
results_zero_shot = []

for index in range(len(test_dataset)):
    mask = ground_truth_masks[index]

    # Sigmoid turns the raw prediction into probabilities ...
    probabilities = torch.sigmoid(torch.tensor(predicted_masks[index]))
    sam_seg_prob = probabilities.cpu().numpy().squeeze()
    # ... and thresholding at 0.5 converts the soft mask to a hard mask
    sam_seg = (sam_seg_prob > 0.5).astype(np.uint8)

    # Calculate IoU
    iou, intersection, union = calculate_iou(mask, sam_seg)

    record = {
        'mask': mask,
        'calculated_mask': sam_seg,
        'intersection': intersection.cpu().squeeze().numpy(),
        'union': union.cpu().squeeze().numpy(),
        'iou': iou.cpu().numpy(),
        'empty': False,
    }
    results_zero_shot.append(record)

In [None]:
# Create a DataFrame from the list of per-sample result dictionaries
# (one row per sample, one column per dictionary key)
df_finetune = pd.DataFrame(results_zero_shot)

# Save the DataFrame to a file (optional; disabled by default)
#df_finetune.to_pickle('dataframe_finetune_20epochs.pkl')

In [None]:
# Load previously saved results when the pickle exists; otherwise keep the
# DataFrame computed earlier in this notebook instead of failing on a fresh run
# (the cell that would write this file has its save line commented out).
# NOTE: pickle files can execute arbitrary code on load — only open trusted files.
FINETUNE_RESULTS_PKL = 'dataframe_finetune_20epochs.pkl'
if os.path.exists(FINETUNE_RESULTS_PKL):
    df_finetune = pd.read_pickle(FINETUNE_RESULTS_PKL)
else:
    print(f"{FINETUNE_RESULTS_PKL} not found; using the results computed in this session.")

In [None]:
# 0-based position, among the samples with a large mask, that the viewer cell shows next.
counter = 0

In [None]:
# IoU stored in the row labelled `counter`, unwrapped to a plain Python scalar.
df_finetune.at[counter, 'iou'].item()

In [None]:
# Show the `counter`-th sample (0-based) whose ground-truth mask sums to more
# than MIN_MASK_AREA, then advance `counter` so that re-running this cell moves
# on to the next such sample.
MIN_MASK_AREA = 5000

# Areas are computed directly from the stored masks so this cell does not
# depend on the 'mask_area' column, which is only added further down.
large_mask_indices = [
    index
    for index in range(len(df_finetune))
    if np.sum(df_finetune.at[index, 'mask']) > MIN_MASK_AREA
]
if not large_mask_indices:
    raise ValueError(f"No mask with an area above {MIN_MASK_AREA} was found.")

# Wrap around after the last large mask instead of selecting an unrelated row
count = large_mask_indices[counter % len(large_mask_indices)]
counter += 1

mask1_np = df_finetune.at[count, 'mask']
combined_mask_np = df_finetune.at[count, 'calculated_mask']
intersection_np = df_finetune.at[count, 'intersection']
union_np = df_finetune.at[count, 'union']

# One row of four grayscale panels
fig, axes = plt.subplots(1, 4, figsize=(20, 5))
panels = [
    (mask1_np, 'Mask'),
    (combined_mask_np, 'Calculated Mask'),
    (intersection_np, 'Intersection'),
    (union_np, 'Union'),
]
for ax, (panel, title) in zip(axes, panels):
    ax.imshow(panel, cmap='gray')
    ax.set_title(title)
    ax.axis('off')

# Add IoU text to the plot. The value is unwrapped to a float first, and the
# f-string uses double quotes outside so the single-quoted key inside it also
# parses on Python versions before 3.12.
iou_value = float(np.asarray(df_finetune.at[count, 'iou']).item())
fig.suptitle(f"IoU: {iou_value:.4f}", fontsize=16)

plt.show()

In [None]:
# Area (sum of mask values) of the sample currently shown by the viewer.
# Computed from the stored mask so this cell does not rely on the 'mask_area'
# column, which is only created in a later cell.
print(np.sum(df_finetune.at[count, 'mask']).item())

In [None]:
def plot_iou_per_sample(iou_values, title='IoU Values for Different Samples'):
    """Draw a bar chart and a line plot of per-sample IoU values.

    Args:
        iou_values: Iterable of IoU values; each entry may be a Python number,
            a NumPy scalar or a size-1 array.
        title (str): Title used for both figures.

    Returns:
        float: The average IoU, or NaN when `iou_values` is empty (in which
        case nothing is plotted).
    """
    values = np.asarray([float(np.asarray(value).item()) for value in iou_values], dtype=float)
    if values.size == 0:
        # Nothing to average or plot; avoid dividing by zero
        print("No IoU values to plot.")
        return float("nan")

    average = float(values.mean())
    positions = range(len(values))
    label = f'Average IoU: {average:.4f}'

    for kind in ('bar', 'line'):
        plt.figure(figsize=(10, 6))
        if kind == 'bar':
            plt.bar(positions, values, color='blue')
        else:
            plt.plot(positions, values, marker='o', linestyle='-', color='blue')
        plt.xlabel('Sample Index')
        plt.ylabel('IoU')
        plt.title(title)
        plt.text(0.5, 0.95, label, ha='center', va='center', transform=plt.gca().transAxes, fontsize=12, bbox=dict(facecolor='white', alpha=0.5))
        plt.show()

    return average


# Filter out rows with None IoU values
df_filtered = df_finetune.dropna(subset=['iou'])

# Extract IoU values from the filtered DataFrame
iou_values = df_filtered['iou'].tolist()
#iou_values = individual_ious

# Plot the values and keep the average IoU
average_iou = plot_iou_per_sample(iou_values)

In [None]:
# Drop rows without an IoU value and collect the remaining IoUs as a list
df_filtered = df_finetune.dropna(subset=['iou'])
iou_values = df_filtered['iou'].tolist()
#iou_values = individual_ious

# Mean IoU over the remaining samples
average_iou = sum(iou_values) / len(iou_values)

sample_positions = range(len(iou_values))
chart_title = 'IoU Values for Different Samples (only gaussian noise)'
average_label = f'Average IoU: {average_iou:.4f}'

# Same data twice: first as a bar chart, then as a line plot
for chart_kind in ('bar', 'line'):
    plt.figure(figsize=(10, 6))
    if chart_kind == 'bar':
        plt.bar(sample_positions, iou_values, color='blue')
    else:
        plt.plot(sample_positions, iou_values, marker='o', linestyle='-', color='blue')
    plt.xlabel('Sample Index')
    plt.ylabel('IoU')
    plt.title(chart_title)
    plt.text(0.5, 0.95, average_label, ha='center', va='center', transform=plt.gca().transAxes, fontsize=12, bbox=dict(facecolor='white', alpha=0.5))
    plt.show()

In [None]:
import matplotlib.pyplot as plt

# Add a 'mask_area' column holding the sum over each stored mask
df_finetune['mask_area'] = [np.sum(mask_values) for mask_values in df_finetune['mask']]

# Keep only the rows whose mask area is at most 500.
# NOTE(review): the large-mask viewer in this notebook uses 5000 as its
# threshold — confirm that 500 is the intended cut-off here.
small_mask_rows = df_finetune['mask_area'] <= 500
df_filtered2 = df_finetune.loc[small_mask_rows]

# Scatter plot of IoU against mask area for the retained rows
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(df_filtered2['mask_area'], df_filtered2['iou'], marker='o', color='blue')
ax.set_xlabel('Mask Area')
ax.set_ylabel('IoU')
ax.set_title('IoU vs. Mask Area')
ax.grid(True)
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Scatter plot of IoU against mask area for all rows, with the area axis on a
# logarithmic scale
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(df_finetune['mask_area'], df_finetune['iou'], marker='o', color='blue')
ax.set_xscale('log')
ax.set_xlabel('Mask Area (log scale)')
ax.set_ylabel('IoU')
ax.set_title('IoU vs. Mask Area')
ax.grid(True)
plt.show()