In [None]:
from datasets import load_from_disk
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from transformers import SamProcessor
from torch.utils.data import DataLoader
import os
import sys
# Add the parent directory to sys.path so the local `model` package (model.minor_models.prefixModel) can be imported
sys.path.append(os.path.abspath("../"))
from model.minor_models.prefixModel import LitSamModel
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import numpy as np
from PIL import Image

In [None]:
import yaml
import os
from pathlib import Path

# 1. Locate the directory this code lives in. A notebook kernel defines no
#    `__file__`, so fall back to the working directory the kernel started in
#    (in classic Jupyter, the notebook's own folder). On re-runs the earlier
#    value is reused: `os.chdir` below changes the working directory, and
#    resolving from the new cwd a second time would point at the wrong folder.
try:
    SCRIPT_DIR = Path(__file__).resolve().parent  # e.g. src/training/
except NameError:
    SCRIPT_DIR = globals().get("SCRIPT_DIR") or Path.cwd().resolve()

# 2. Go up one level (to 'src'), then into 'config'.
config_path = SCRIPT_DIR.parent / "config" / "config_general.yaml"

# 3. Load the YAML (safe_load: no arbitrary Python object construction).
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

# 4. The project root is one level above 'src'. Changing into it makes
#    relative entries in the YAML (e.g. "./data") resolve against it.
PROJECT_ROOT = SCRIPT_DIR.parent.parent
os.chdir(PROJECT_ROOT)

# Extract paths from YAML
DATA_DIR = config['paths']['data']
CHECKPOINT_DIR = config['paths']['checkpoints']
SAM_CHECKPOINT = config['paths']['sam_checkpoint']

In [None]:
# Held-out split saved with `datasets.Dataset.save_to_disk` under DATA_DIR.
test_dataset_path = os.path.join(DATA_DIR, 'datasetTestFinal')
test_dataset = load_from_disk(test_dataset_path)

In [None]:
from transformers import SamModel, SamConfig, SamProcessor
import torch

# Lightning checkpoint of the fine-tuned prefix model, stored under CHECKPOINT_DIR.
# The sub-path must NOT start with "/": os.path.join drops every earlier
# component when it meets an absolute one, which would resolve this to
# "/prefix/..." at the filesystem root instead of inside CHECKPOINT_DIR.
sam_checkpoint = os.path.join(
    CHECKPOINT_DIR,
    "prefix",
    "sam-float-adapter-smallprefix-model-epoch=74-val_loss=0.212-val_iou=0.606.ckpt",
)

# Rebuild the LitSamModel architecture with these hyperparameters and load the weights.
model = LitSamModel.load_from_checkpoint(
    sam_checkpoint,
    model_name="vit-b",
    normalize=True,
    adapt=True,
    freeze=False,
    num_classes=1,
)
# This notebook only runs inference: switch dropout / batch-norm layers (if any)
# to evaluation behavior.
model.eval()

In [None]:
# Pull the prefix sub-module out of the wrapped model; the cells below feed it
# the stacked input bands directly.
inner_model = model.model
prefix = inner_model.prefix

In [None]:
def show_raw_input(array, channel_names=("VH0", "VH1", "VV0", "VV1", "DEM", "SLOPE")):
    """Show each input band of a stacked sample as its own grayscale panel.

    Args:
        array: array indexed as ``array[:, :, c]``. In this notebook it is built
            with ``np.stack([VH0, VH1, VV0, VV1, dem, slope], axis=-1)``.
        channel_names: one panel title per channel, in stacking order. The
            first two are "VH0"/"VH1" to match the sample keys those channels
            are read from.
    """
    plt.figure(figsize=(10, 10))

    # Two panels per row and as many rows as needed (3 for the default 6 bands).
    n_rows = (len(channel_names) + 1) // 2
    for idx, name in enumerate(channel_names):
        plt.subplot(n_rows, 2, idx + 1)
        plt.imshow(array[:, :, idx], cmap='gray')
        plt.title(name)

    plt.show()

In [None]:
def visualize_prefix_output(prefix, sample, normalize=False):
    """Run one sample through the prefix module and display its output as an image.

    Args:
        prefix: torch module called as ``prefix(x)`` with ``x`` of shape
            (1, 6, H, W). Its first batch item ``prefix(x)[0]`` is treated as
            (C, H, W); C is presumably 3 (RGB) for this model, so TODO confirm.
        sample: mapping with 2-D array-likes under "VH0", "VH1", "VV0", "VV1",
            "dem" and "slope", stacked in that order as input channels.
        normalize: if True, min-max scale the output to [0, 1] before display
            (a constant output becomes all zeros). The default False shows raw
            values; note that matplotlib clips float RGB data outside [0, 1].
    """
    import torch
    import numpy as np
    import matplotlib.pyplot as plt

    bands = [np.array(sample[key]) for key in ("VH0", "VH1", "VV0", "VV1", "dem", "slope")]
    image = np.stack(bands, axis=-1)  # channels last: (H, W, 6)

    # Add a batch axis and move channels first: (1, H, W, 6) -> (1, 6, H, W).
    image_tensor = torch.from_numpy(np.expand_dims(image, axis=0)).permute(0, 3, 1, 2).float()

    # Run on the prefix's device. A parameter-free module has nothing to ask,
    # so fall back to the CPU instead of letting next() raise StopIteration.
    first_param = next(prefix.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    image_tensor = image_tensor.to(device)

    with torch.no_grad():
        prefix_output = prefix(image_tensor)

    output_image = prefix_output[0].cpu()  # first batch item: (C, H, W)

    if normalize:
        lo, hi = output_image.min(), output_image.max()
        span = hi - lo
        output_image = (output_image - lo) / span if span > 0 else torch.zeros_like(output_image)

    # (C, H, W) -> (H, W, C). A single channel is squeezed to (H, W) because
    # imshow rejects (H, W, 1) arrays.
    output_image_np = output_image.permute(1, 2, 0).numpy()
    if output_image_np.shape[-1] == 1:
        plt.imshow(output_image_np[..., 0], cmap='gray')
    else:
        plt.imshow(output_image_np)
    plt.title("Prefix Module Output as RGB")
    plt.axis('off')
    plt.show()

In [None]:
def visualize_prefix_channels(prefix, sample):
    """Plot every output channel of ``prefix`` for one sample, side by side.

    The bands "VH0", "VH1", "VV0", "VV1", "dem" and "slope" are read from
    ``sample`` and stacked channels-last. The stack gets a batch axis and is
    moved channels-first before the forward pass, which runs on the device of
    the module's first parameter. Each channel of the first output item is then
    drawn in its own panel with the 'viridis' colormap.
    """
    import torch
    import numpy as np
    import matplotlib.pyplot as plt

    band_keys = ("VH0", "VH1", "VV0", "VV1", "dem", "slope")
    stacked = np.stack([np.array(sample[key]) for key in band_keys], axis=-1)

    # (H, W, 6) -> (1, H, W, 6) -> (1, 6, H, W), then onto the module's device.
    batch = torch.from_numpy(stacked[np.newaxis, ...]).permute(0, 3, 1, 2).float()
    batch = batch.to(next(prefix.parameters()).device)

    with torch.no_grad():
        channels = prefix(batch)[0].cpu()  # first batch item: (C, H, W)

    count = channels.shape[0]
    _fig, axes = plt.subplots(1, count, figsize=(4 * count, 4))
    # With a single column, subplots returns one Axes rather than a sequence.
    panels = [axes] if count == 1 else axes
    for idx in range(count):
        ax = panels[idx]
        ax.imshow(channels[idx].numpy(), cmap='viridis')
        ax.set_title(f"Channel {idx}")
        ax.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
def show_normal_image(array):
    """Display ``array`` with imshow's default settings in a 5x5-inch figure."""
    fig_size = (5, 5)
    plt.figure(figsize=fig_size)
    plt.imshow(array)
    plt.show()

In [None]:
def visualize_mask(sample, key='label'):
    """Display one field of a sample (by default its 'label' mask) in grayscale.

    Args:
        sample: mapping holding the mask as a 2-D array-like.
        key: which field of ``sample`` to show. It defaults to 'label', the
            field this notebook reads ground-truth masks from.
    """
    mask = np.array(sample[key])

    plt.imshow(mask, cmap='gray')
    plt.title("Mask")
    plt.axis('off')
    plt.show()

In [None]:
# Pick one test sample and build its channels-last input stack for the plots below.
item = test_dataset[20]
ground_truth_mask = np.array(item["label"])
VH0, VH1, VV0, VV1, dem, slope = (
    np.array(item[key]) for key in ("VH0", "VH1", "VV0", "VV1", "dem", "slope")
)
image = np.stack([VH0, VH1, VV0, VV1, dem, slope], axis=-1)

In [None]:
# Plot each output channel of the prefix module separately for `item`.
visualize_prefix_channels(prefix, item)

In [None]:
# For the same sample, show the six raw input bands, then the prefix output
# rendered as one image, then the ground-truth 'label' mask.
show_raw_input(image)
visualize_prefix_output(prefix, item)
visualize_mask(item)

In [None]:
def show_n_complete_samples(dataset1, n=23, start=22, full_box=(0, 0, 512, 512)):
    """Visualize dataset samples whose "box" equals ``full_box``.

    Walks ``dataset1`` in order and counts the samples whose "box" equals
    ``full_box``. It stops once ``n`` such samples have been counted. Only
    matches whose 1-based count is ``>= start`` are plotted, so the defaults
    plot the 22nd and 23rd. For each plotted sample, the function shows the
    prefix output (using the notebook-level ``prefix`` module), then full-bleed
    DEM and slope images.

    Args:
        dataset1: indexable dataset supporting ``len()``. Its items carry
            "box", "VH0", "VH1", "VV0", "VV1", "dem" and "slope" fields.
        n: number of matching samples to count before stopping.
        start: 1-based count of the first matching sample to plot. Defaults to
            22 to keep the previous hard-coded behavior; pass 1 to plot all n.
        full_box: box value that marks a sample as "complete". It is presumably
            the whole 512x512 tile, so TODO confirm.
    """
    target_box = list(full_box)  # sample["box"] is compared as a list
    counter = 0
    for i in range(len(dataset1)):
        if counter >= n:
            break
        sample = dataset1[i]
        if sample["box"] != target_box:
            continue
        counter += 1
        if counter < start:
            continue

        visualize_prefix_output(prefix, sample)
        _show_full_bleed(np.array(sample["dem"]))
        _show_full_bleed(np.array(sample["slope"]))


def _show_full_bleed(array):
    """Show a 2-D array in grayscale, filling a 10x10-inch figure with no axes or margins."""
    fig = plt.figure(figsize=(10, 10))  # large figure for a higher-resolution view
    plt.imshow(array, cmap='gray')
    plt.axis('off')  # hide ticks / pixel coordinates
    plt.gca().set_position([0, 0, 1, 1])  # axes span the whole figure: no margins
    plt.show()
    plt.close(fig)  # release the figure; the caller's loop can create many of them

In [None]:
# Plot the selected full-box test samples (see show_n_complete_samples defaults).
show_n_complete_samples(test_dataset)

In [None]:
# NOTE(review): identical to the previous cell and re-renders the same figures;
# this is likely a leftover, so consider deleting this cell.
show_n_complete_samples(test_dataset)