In [None]:
import torch
import numpy as np
# from torch.utils.data import Dataset, DataLoader, random_split
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
import torch.nn as nn
import matplotlib.pyplot as plt
import glob
import os
from PIL import Image as PILImage
import cv2

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_type = device.type  # same "cuda" / "cpu" string, derived from `device`

IMAGE_SIZE = (512, 512)  # Resize images to this size

# Two-class label map handed to the model config.
id2label = {0: 'background', 1: 'water'}
label2id = {label: idx for idx, label in id2label.items()}
NUM_CLASSES = len(id2label)

# Two different checkpoints are involved, so each gets its own constant instead
# of re-binding one name halfway through the cell:
#   FINETUNED_CHECKPOINT - weights loaded into the segmentation model
#                          (presumably a local fine-tuned directory; confirm)
#   MODEL_CHECKPOINT     - checkpoint whose preprocessing config the image
#                          processor is loaded from
FINETUNED_CHECKPOINT = 'segformer_water'
MODEL_CHECKPOINT = 'nvidia/mit-b4'

# Load the SegFormer segmentation model with a 2-class head.
# NOTE(review): ignore_mismatched_sizes=True lets loading continue when a
# checkpoint tensor has a different shape than this config expects; confirm the
# checkpoint really carries a NUM_CLASSES-wide head, otherwise that head would
# not come from the checkpoint.
model = SegformerForSemanticSegmentation.from_pretrained(
    FINETUNED_CHECKPOINT,
    num_labels=NUM_CLASSES,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# Image processor (resizing / normalisation settings) from the base checkpoint.
processor = SegformerImageProcessor.from_pretrained(MODEL_CHECKPOINT)

# Move to the target device and make inference mode explicit. Using an
# assignment keeps the cell from echoing the entire model repr as output.
model = model.to(device).eval()

In [None]:
def load_image_as_rgb(image_path):
    """Load an image file and return it as an RGB PIL image.

    The image is converted to mode 'RGB' whatever its source mode is
    (grayscale 'L', palette 'P', 'RGBA', ...), not only when it is 'L', so
    downstream code can rely on getting three channels.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A PIL image in mode 'RGB'.
    """
    # Image.open is lazy and keeps the file open; convert() produces a new
    # decoded image, so the handle can be released as soon as the block exits.
    with PILImage.open(image_path) as img:
        return img.convert('RGB')

def load_mask_as_binary(mask_path):
    """Load a mask file and return it as a strictly binary (0/1) PIL image.

    The mask is read as 8-bit grayscale and every non-zero pixel becomes 1.
    This covers 0/255 masks as well as masks already stored as 0/1, and it
    guarantees that no intermediate value (e.g. 128 or 254) survives.

    Args:
        mask_path: Path of the mask image to read.

    Returns:
        A PIL image built from a uint8 array containing only 0 and 1.
    """
    # Decode inside the context manager so the file handle is released.
    with PILImage.open(mask_path) as mask:
        gray = mask if mask.mode == 'L' else mask.convert('L')
        values = np.array(gray)

    # Threshold rather than remap only the exact value 255.
    binary = (values > 0).astype(np.uint8)

    # Back to a PIL image for compatibility with the image loader.
    return PILImage.fromarray(binary)

def create_dataset(image_paths, mask_paths):
    """Eagerly load every image and every mask from the given paths.

    All images are loaded first (via ``load_image_as_rgb``), then all masks
    (via ``load_mask_as_binary``).

    Args:
        image_paths: Iterable of image file paths.
        mask_paths: Iterable of mask file paths.

    Returns:
        A ``(images, masks)`` tuple of two lists holding the loaded objects in
        the same order as the input paths.
    """
    loaded_images = list(map(load_image_as_rgb, image_paths))
    loaded_masks = list(map(load_mask_as_binary, mask_paths))
    return loaded_images, loaded_masks

# Glob pattern matching every test image.
test_image_dir = "./sar_images/images/test/*.png"

# glob returns matches in a filesystem-dependent order; sorting keeps the
# sample indices (and therefore the "Sample N" figure titles) stable.
test_images_paths = sorted(glob.glob(test_image_dir))
# Report the count instead of dumping the full path list into the output.
print(f"Found {len(test_images_paths)} test images matching {test_image_dir}")
# Mask paths are derived by replacing '/images' with '/masks' in each image path.
test_masks_paths = [path.replace('/images', '/masks') for path in test_images_paths]

test_images, test_masks = create_dataset(test_images_paths, test_masks_paths)

In [None]:
def visualize(image, mask, pred, idx, name,
              save_dir="./GEE_Output/Segformer/Visualizations"):
    """Plot image, ground truth and prediction side by side, save and show.

    Args:
        image: Image data accepted by ``Axes.imshow``.
        mask: Ground-truth label map; drawn on a fixed 0..1 grey scale, so any
            value >= 1 renders white.
        pred: Predicted label map; drawn on the same fixed 0..1 grey scale.
        idx: Sample index, shown in the figure title.
        name: File name (including extension) the figure is saved under.
        save_dir: Directory the figure is written to; created if missing.
            Defaults to the location that was previously hard-coded.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(f'Sample {idx}', fontsize=16)

    # Pin the grey scale to 0..1: with automatic scaling a constant map has
    # vmin == vmax, so an all-water and an all-background prediction would
    # render identically.
    label_style = {'cmap': 'gray', 'vmin': 0, 'vmax': 1}
    panels = (
        (image, 'Original Image', {}),
        (mask, 'Ground Truth', label_style),
        (pred, 'Prediction', label_style),
    )
    for ax, (data, title, style) in zip(axes, panels):
        ax.imshow(data, **style)
        ax.set_title(title)
        ax.axis('off')

    # Saving into a missing directory raises, so make sure it exists first.
    os.makedirs(save_dir, exist_ok=True)
    # Save before showing: showing may release the figure on some backends.
    fig.savefig(os.path.join(save_dir, name))
    plt.show()
    # Explicitly free the figure; this function is called once per sample.
    plt.close(fig)

output_dir = "./GEE_Output/Segformer/Outputs"
# cv2.imwrite reports failure through its return value instead of raising,
# so a missing output folder would otherwise lose every prediction silently.
os.makedirs(output_dir, exist_ok=True)

# Run the model on every test sample, save the predicted mask as a 0/255 image
# and plot it next to the input image and the ground truth.
for i, (test_image, test_mask, image_path) in enumerate(
        zip(test_images, test_masks, test_images_paths)):
    image = np.array(test_image)
    mask = np.array(test_mask)
    name = os.path.basename(image_path)

    inputs = processor(images=image, return_tensors='pt')
    inputs = inputs.to(device)

    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

        # Rescale logits to the original image size. shape[:2] is
        # (height, width) for both (H, W, C) and (H, W) arrays.
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.shape[:2],
            mode='bilinear',
            align_corners=False
        )

    # Apply argmax on the class dimension and drop the batch dimension.
    pred_mask = upsampled_logits.argmax(dim=1)[0]
    pred_mask = pred_mask.cpu().numpy()

    # Class ids 0/1 -> 0/255 so the saved file is viewable as an image.
    pred_norm = (pred_mask * 255).astype(np.uint8)
    pred_path = os.path.join(output_dir, f"pred_{name}")
    if not cv2.imwrite(pred_path, pred_norm):
        print(f"WARNING: failed to write prediction to {pred_path}")

    visualize(image, mask, pred_mask, i, name)
