In [None]:
import torch
import numpy as np
from datasets import Dataset, Image
# from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoImageProcessor, MaskFormerForInstanceSegmentation, TrainingArguments, Trainer, MaskFormerConfig, MaskFormerModel, MaskFormerImageProcessor, Mask2FormerForUniversalSegmentation, Mask2FormerConfig, Mask2FormerModel
from PIL import Image as PILImage
from sklearn.model_selection import train_test_split
import glob

from typing import Dict, List, Mapping
from transformers.trainer import EvalPrediction
# from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchmetrics import JaccardIndex, Accuracy
from dataclasses import dataclass

import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import os
import cv2

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_type = device.type  # "cuda" or "cpu", same string the device was built from

IMAGE_SIZE = (512, 512)  # Resize images to this size

# Two-class segmentation: class id <-> human-readable label.
id2label = {0: 'background', 1: 'water'}
label2id = {label: class_id for class_id, label in id2label.items()}
NUM_CLASSES = len(id2label)

# The model weights and the image processor are loaded from two different
# checkpoints, so each gets its own name instead of rebinding one constant.
# NOTE(review): "mask2former_water_new" is presumably a local directory holding
# the fine-tuned weights -- confirm.
FINETUNED_CHECKPOINT = "mask2former_water_new"
PROCESSOR_CHECKPOINT = "facebook/mask2former-swin-large-cityscapes-semantic"
# Kept for backward compatibility: this name previously ended the cell bound to
# the processor checkpoint.
MODEL_CHECKPOINT = PROCESSOR_CHECKPOINT

model = Mask2FormerForUniversalSegmentation.from_pretrained(FINETUNED_CHECKPOINT)
processor = AutoImageProcessor.from_pretrained(PROCESSOR_CHECKPOINT)

# Assign the result so the cell does not dump the full module repr as output;
# eval() makes the inference-only intent explicit.
model = model.to(device).eval()

In [None]:
# Glob pattern matching every test image.
test_image_dir = "./sar_images/images/test/*.png"

# glob.glob returns files in arbitrary, filesystem-dependent order; sorting
# keeps the sample indices and saved outputs reproducible between runs.
test_images_paths = sorted(glob.glob(test_image_dir))
# Each mask has the same file name as its image, under the parallel
# "/masks" directory tree.
test_masks_paths = [path.replace('/images', '/masks') for path in test_images_paths]


In [None]:
# Albumentations pipeline: resize to IMAGE_SIZE, then convert to a torch tensor.
# NOTE(review): not referenced by any other cell visible here -- confirm it is
# still needed.
alb_transform = A.Compose(
    [
        A.Resize(height=IMAGE_SIZE[0], width=IMAGE_SIZE[1]),
        ToTensorV2(),
    ]
)

def create_dataset(image_paths, mask_paths):
    """Load image/mask pairs from disk into memory.

    Args:
        image_paths: Paths of the input images; each one is opened with PIL
            and converted to RGB.
        mask_paths: Paths of the matching masks, in the same order as
            ``image_paths``; each one is converted to single-channel
            grayscale.

    Returns:
        A tuple ``(images, masks)``: a list of RGB PIL images and a list of
        ``uint8`` numpy arrays in which pixel value 255 has been remapped to 1
        (all other values are left unchanged).

    Raises:
        ValueError: If the two path sequences differ in length.
    """
    image_paths = list(image_paths)
    mask_paths = list(mask_paths)
    # zip() would silently drop the unmatched tail; a length mismatch means the
    # image/mask pairing is wrong, so fail loudly instead.
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            f"Got {len(image_paths)} image paths but {len(mask_paths)} mask paths"
        )

    images = []
    masks = []

    for img_path, mask_path in zip(image_paths, mask_paths):
        # PIL opens files lazily and keeps the handle open; convert() returns
        # an independent, fully loaded copy, so the file can be closed right
        # away instead of leaking one descriptor per sample.
        with PILImage.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with PILImage.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.uint8)  # grayscale mask
        mask[mask == 255] = 1  # foreground stored as 255 becomes class id 1

        images.append(image)
        masks.append(mask)

    return images, masks

# Load every test image and its mask into memory up front.
test_images, test_masks = create_dataset(test_images_paths, test_masks_paths)

In [None]:
def visualize(image, mask, pred, idx, name,
              output_dir="./GEE_Output/Maskformer/Visualizations"):
    """Plot image, ground truth and prediction side by side, save and show it.

    Args:
        image: Image data accepted by ``Axes.imshow`` for the first panel.
        mask: Ground-truth mask shown in grayscale in the second panel.
        pred: Predicted mask shown in grayscale in the third panel.
        idx: Sample index, used only in the figure title.
        name: File name (including extension) of the saved figure.
        output_dir: Directory the figure is written to; created if missing.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(f'Sample {idx}', fontsize=16)

    # Masks hold class ids 0/1. Pin the colour range: with auto-scaling a
    # constant mask (all 0 or all 1) has vmin == vmax and renders black either
    # way, so all-water and all-background would look identical.
    panels = (
        (image, 'Original Image', {}),
        (mask, 'Ground Truth', {'cmap': 'gray', 'vmin': 0, 'vmax': 1}),
        (pred, 'Prediction', {'cmap': 'gray', 'vmin': 0, 'vmax': 1}),
    )
    for ax, (data, title, imshow_kwargs) in zip(axes, panels):
        ax.imshow(data, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')

    # savefig does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, name))
    plt.show()
    # Release the figure explicitly; this function is called once per sample
    # and open figures otherwise accumulate in memory.
    plt.close(fig)
    
output_dir = "./GEE_Output/Maskformer/Outputs"
# cv2.imwrite does not create directories and only returns False when the
# target directory is missing, so make sure it exists before the loop.
os.makedirs(output_dir, exist_ok=True)

# Run the model on every test sample, save the predicted mask as an image and
# plot it next to the ground truth.
for i, (image, mask, image_path) in enumerate(
    zip(test_images, test_masks, test_images_paths)
):
    name = os.path.basename(image_path)

    inputs = processor(image, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports size as (width, height); the post-processor expects
    # (height, width), hence the reversal.
    predicted_semantic_map = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    pred_mask = predicted_semantic_map.cpu().numpy()

    # Scale class ids 0/1 to 0/255 so the saved image is visibly black/white.
    pred_norm = (pred_mask * 255).astype(np.uint8)
    pred_path = os.path.join(output_dir, f"pred_{name}")
    # imwrite signals failure through its return value, not an exception;
    # surface it instead of silently losing the prediction.
    if not cv2.imwrite(pred_path, pred_norm):
        raise OSError(f"cv2.imwrite failed to write {pred_path}")

    visualize(np.array(image), mask, pred_mask, i, name)
