In [None]:
# %% [code] Cell: Visualize Augmentations and Bounding Boxes from SAMDataset3

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import torch
from datasets import load_from_disk
from transformers import SamProcessor
import sys
import os
# Add the directory containing lit_sam_model.py to the Python path
sys.path.append(os.path.abspath("../"))
from model.samDataset import SAMDataset3
import pickle
import yaml
import os
from pathlib import Path
from model.inputTypes import InputTypes

# 1. Locate the directory of this script/notebook (e.g. src/training/).
#    Jupyter kernels do not define `__file__`, so fall back to the directory the
#    kernel was started in (normally the notebook's own folder). The fallback is
#    cached in `_SCRIPT_DIR` so that re-running this cell after the `os.chdir`
#    below does not silently resolve paths against the project root instead.
try:
    _SCRIPT_DIR = Path(__file__).resolve().parent
except NameError:
    _SCRIPT_DIR = globals().get("_SCRIPT_DIR", Path.cwd().resolve())

# 2. Go up one level to 'src', then into 'config'
config_path = _SCRIPT_DIR.parent / "config" / "config_general.yaml"

# 3. Load the YAML
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

# 4. Resolve the root of the project (one level above 'src') and make it the
#    working directory, so relative paths in the YAML (e.g. "./data") are
#    interpreted relative to the project root.
PROJECT_ROOT = _SCRIPT_DIR.parent.parent
os.chdir(PROJECT_ROOT)

# Extract paths from YAML
DATA_DIR = config['paths']['data']
CHECKPOINT_DIR = config['paths']['checkpoints']
SAM_CHECKPOINT = config['paths']['sam_checkpoint']

# Load the validation dataset from disk (same path as in the training script).
val_dataset = load_from_disk(os.path.join(DATA_DIR, 'datasetValFinal'))

# NOTE: pickle.load can execute arbitrary code; only load index files produced
# by this project.
with open(os.path.join(DATA_DIR, 'val_indices.pkl'), "rb") as f:
    val_indices = pickle.load(f)

# Keep every row whose index is NOT listed in val_indices.pkl.
# NOTE(review): despite the file name, the pickled indices are *excluded* here;
# confirm this is the intended subset.
all_indices = set(range(len(val_dataset)))
exclude_indices = set(val_indices)
# Sorted so the selection is deterministic and follows on-disk row order
# (iteration order of a set is an implementation detail).
val_indices = sorted(all_indices - exclude_indices)

val_dataset = val_dataset.select(val_indices)

# Initialize the processor
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Wrap the dataset with augmentations enabled, so the output can be inspected
# after augmentation and bounding-box calculation.
sam_dataset = SAMDataset3(dataset=val_dataset, processor=processor, augment=True, type=InputTypes.Normal)

from torch.utils.data._utils.collate import default_collate
from torch.utils.data import DataLoader

def custom_collate(batch):
    """Collate a list of sample dicts into one batch dict.

    Every key is batched with PyTorch's ``default_collate`` except
    ``"input_boxes"``. That key is returned as a plain list with one entry per
    sample, because samples may carry different numbers of boxes and so cannot
    be stacked into one tensor.

    Args:
        batch: non-empty list of dicts; the keys of ``batch[0]`` drive the output.

    Returns:
        dict with the same keys, in the same order, as ``batch[0]``.
    """
    def gather(key):
        # Values of `key` across the batch, in sample order.
        return [sample[key] for sample in batch]

    return {
        key: gather(key) if key == "input_boxes" else default_collate(gather(key))
        for key in batch[0]
    }

# DataLoader over the augmented *validation* samples, used for visualisation.
# A seeded generator makes the shuffle order reproducible on Restart & Run All.
# The shared generator advances, so successive `iter(dataloader)` calls still
# draw different batches. Any randomness inside SAMDataset3's augmentations may
# come from other RNGs and is not seeded here.
SHUFFLE_SEED = 0
dataloader = DataLoader(
    sam_dataset,
    batch_size=5,
    shuffle=True,
    drop_last=True,
    collate_fn=custom_collate,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)


In [None]:
batch = next(iter(dataloader))

# Index of the sample to inspect within the batch (0 = first sample).
SAMPLE_IDX = 0
image = batch["pixel_values"][SAMPLE_IDX]
mask = batch["ground_truth_mask"][SAMPLE_IDX]
# `input_boxes` is collated as a list (one entry per sample), so this is that sample's boxes.
boxes = batch["input_boxes"][SAMPLE_IDX]

print(image.shape)
print(mask.shape)
print(boxes.shape)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

def _to_display_rgb(chw, rescale=True):
    """Convert a (C, H, W) array into an (H, W, 3) image suitable for ``imshow``.

    Exactly-3-channel input is transposed to channels-last. Any other channel
    count is shown as the composite (channel 2, channel 3, channel 2), which the
    original notebook labels VV0, VV1, VV0. When ``rescale`` is true, values are
    min-max scaled to [0, 1], because matplotlib clips float RGB data outside
    that range. A constant image becomes all zeros.
    """
    if chw.shape[0] != 3:
        rgb = np.stack([chw[2], chw[3], chw[2]], axis=-1)
    else:
        rgb = np.transpose(chw, (1, 2, 0))
    if not rescale:
        return rgb
    rgb = rgb.astype(np.float32)
    lo, hi = float(rgb.min()), float(rgb.max())
    if hi <= lo:
        # Constant image: avoid division by zero.
        return np.zeros_like(rgb)
    return (rgb - lo) / (hi - lo)


def _box_patch(box):
    """Unfilled red Rectangle for a box given as (x0, y0, x1, y1)."""
    return patches.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1],
                             linewidth=1, edgecolor='r', facecolor='none')


def show_batch(batch, mask_scale=2, verbose=True, rescale=True):
    """Plot every sample of a collated batch: ground-truth mask and image, with boxes.

    For each sample ``i`` a 1x2 figure is drawn:
      * left: ``batch["ground_truth_mask"][i]`` in grayscale, with each box
        divided by ``mask_scale``;
      * right: ``batch["pixel_values"][i]`` as RGB (see ``_to_display_rgb``),
        with the boxes unscaled.

    Args:
        batch: dict as produced by ``custom_collate`` with keys "pixel_values",
            "ground_truth_mask" and "input_boxes". Boxes are assumed to be
            tensors/arrays supporting ``box / mask_scale``.
        mask_scale: divisor that maps box coordinates onto the mask. The default
            of 2 keeps the previous hard-coded value. NOTE(review): confirm it
            matches the image/mask resolution ratio.
        verbose: print each box (previous behaviour).
        rescale: min-max scale the image to [0, 1] for display.
    """
    for i in range(len(batch["pixel_values"])):
        fig, (ax_mask, ax_img) = plt.subplots(1, 2, figsize=(15, 5))

        ax_mask.imshow(batch["ground_truth_mask"][i].numpy(), cmap='gray')
        ax_mask.set_title('GT')

        ax_img.imshow(_to_display_rgb(batch["pixel_values"][i].numpy(), rescale=rescale))
        ax_img.set_title('Image')

        for box in batch["input_boxes"][i]:
            if verbose:
                print("box:", box)
            ax_img.add_patch(_box_patch(box))
            ax_mask.add_patch(_box_patch(box / mask_scale))

        plt.tight_layout()
        plt.show()
        # Release the figure; this function can create many of them.
        plt.close(fig)

In [None]:
# Pull a fresh shuffled batch and plot every sample in it.
batch = next(iter(dataloader))
show_batch(batch=batch)

In [None]:
# Test split previously written with `Dataset.save_to_disk` under DATA_DIR.
TEST_DATASET_NAME = 'datasetTestDEMFloat'
test_dataset_test = load_from_disk(os.path.join(DATA_DIR, TEST_DATASET_NAME))

In [None]:
#Fileter for full image bboxes only
full_image_indices = []
for i in range(len(test_dataset_test)):
    box = test_dataset_test[i]['box']
    if box == [0, 0, 512, 512]:
        full_image_indices.append(i)
test_dataset_test = test_dataset_test.select(full_image_indices)

In [None]:
# Summary of the filtered test split.
print(str(test_dataset_test))

In [None]:
# Preview a fixed window of the filtered test images, one figure per image.
PREVIEW_START, PREVIEW_STOP = 20, 31  # half-open: indices 20..30 inclusive
for index in range(PREVIEW_START, min(PREVIEW_STOP, len(test_dataset_test))):
    sample = test_dataset_test[index]

    fig, ax = plt.subplots()
    ax.imshow(sample['image'])
    ax.set_title(f"Image {index}")  # identify which sample is shown
    ax.axis('off')
    plt.show()
    plt.close(fig)

In [None]:
# Explicit import of the box helpers used below. Unlike `import *`, this makes
# their origin clear and cannot silently shadow notebook names such as `np` or `image`.
from creationOfDataframe import find_bounding_boxes, increase_bounds, merge_overlapping_boxes

In [None]:
# Hand-picked test sample whose boxes are visualised at each stage below.
SAMPLE_INDEX = 22
sample = test_dataset_test[SAMPLE_INDEX]

label_array = np.array(sample['label'])
bounding_box_base = find_bounding_boxes(label_array)
# Grow each box by 20; the 512, 512 arguments are presumably the image bounds
# used for clamping (TODO confirm against increase_bounds).
bounding_box_increase = increase_bounds(bounding_box_base, 512, 512, increase_by=20)

# Merge overlapping boxes repeatedly until one pass leaves the box count
# unchanged. merge_overlapping_boxes is always called at least twice.
bounding_box_merged = merge_overlapping_boxes(bounding_box_increase.copy())
num_of_boxes = -1
while len(bounding_box_merged) != num_of_boxes:
    num_of_boxes = len(bounding_box_merged)
    bounding_box_merged = merge_overlapping_boxes(bounding_box_merged)



In [None]:
# Show the selected sample with its final (merged) bounding boxes.
image = np.array(sample['image'])
image[:, :, 2] = image[:, :, 0]  # overwrite channel 2 with channel 0 for display

fig, ax = plt.subplots()
ax.imshow(image)
for box in bounding_box_merged:
    # NOTE(review): drawn as (x, y, width, height); show_batch treats its boxes
    # as (x0, y0, x1, y1). Confirm the format returned by the box helpers.
    ax.add_patch(patches.Rectangle((box[0], box[1]), box[2], box[3],
                                   linewidth=2, edgecolor='g', facecolor='none'))
ax.set_title("Image with Bounding Boxes")
ax.axis('off')
plt.show()

In [None]:
# Visualize the image with bounding boxes at each stage of the pipeline, side by side.
stages = [
    ("Before Increase", bounding_box_base),
    ("After Increase", bounding_box_increase),
    ("After Merge", bounding_box_merged),
]

image = np.array(sample['image'])
image[:, :, 2] = image[:, :, 0]  # overwrite channel 2 with channel 0 for display

# Scope the larger font to this figure only. `plt.rcParams.update` would leak
# the setting into every later plot in the kernel session.
with plt.rc_context({'font.size': 14}):
    fig, stage_axes = plt.subplots(1, len(stages), figsize=(15, 5))
    for ax, (stage_title, stage_boxes) in zip(stage_axes, stages):
        ax.imshow(image)
        ax.set_title(stage_title)
        ax.axis('off')
        for box in stage_boxes:
            # NOTE(review): drawn as (x, y, width, height); confirm box format.
            ax.add_patch(patches.Rectangle((box[0], box[1]), box[2], box[3],
                                           linewidth=2, edgecolor='g', facecolor='none'))
    plt.tight_layout()
    plt.show()