In [25]:
import torch
import numpy as np
from datasets import Dataset, Image
# from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoImageProcessor, MaskFormerForInstanceSegmentation, TrainingArguments, Trainer, MaskFormerConfig, MaskFormerModel
from PIL import Image as PILImage
from sklearn.model_selection import train_test_split
import evaluate
import glob
import torch.nn as nn
from torchvision import transforms

from typing import Dict, List, Mapping
from transformers.trainer import EvalPrediction
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from dataclasses import dataclass

In [26]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_type = "cuda" if torch.cuda.is_available() else "cpu"

IMAGE_SIZE = (512, 512)  # Resize images to this size
BATCH_SIZE = 4
NUM_EPOCHS = 50
LEARNING_RATE = 5e-5
VAL_SPLIT = 0.125

id2label = {0: 'background', 1: 'water'}
label2id = {label: id for id, label in id2label.items()}
NUM_CLASSES = len(id2label)

MODEL_CHECKPOINT = "facebook/maskformer-resnet50-coco-stuff"

config = MaskFormerConfig.from_pretrained(MODEL_CHECKPOINT)
config.id2label = id2label
config.label2id = label2id
config.num_labels = NUM_CLASSES

# Use the config object to initialize a MaskFormer model with randomized weights
model = MaskFormerForInstanceSegmentation(config)

base_model = MaskFormerModel.from_pretrained(MODEL_CHECKPOINT)
model.model = base_model

processor = AutoImageProcessor.from_pretrained(MODEL_CHECKPOINT)

model.to(device)

  return func(*args, **kwargs)


MaskFormerForInstanceSegmentation(
  (model): MaskFormerModel(
    (pixel_level_module): MaskFormerPixelLevelModule(
      (encoder): ResNetBackbone(
        (embedder): ResNetEmbeddings(
          (embedder): ResNetConvLayer(
            (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (activation): ReLU()
          )
          (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        )
        (encoder): ResNetEncoder(
          (stages): ModuleList(
            (0): ResNetStage(
              (layers): Sequential(
                (0): ResNetBottleNeckLayer(
                  (shortcut): ResNetShortCut(
                    (convolution): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
                    (normalization): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True

In [27]:
train_image_dir = "./sar_images/images/train/*.png"
train_mask_dir = "./sar_images/masks/train/*.png"
test_image_dir = "./sar_images/images/test"
test_mask_dir = "./sar_images/masks/test"

images = list(glob.glob(train_image_dir))
# images = [str(path) for path in images]
masks = [path.replace('/images', '/masks') for path in images]

# print(images)
# print(masks)

# print(f'{len(images)} images detected.')

train_images, val_images, train_masks, val_masks = train_test_split(
    images, masks, test_size=VAL_SPLIT, random_state=0, shuffle=True)

print(f'Train images: {len(train_images)}\nValidation images: {len(val_images)}')

Train images: 883
Validation images: 127


In [28]:
def load_image_as_rgb(image_path):
    # Open image
    img = PILImage.open(image_path)
    
    # If the image is grayscale (mode 'L'), convert it to RGB
    if img.mode == 'L':
        img = img.convert('RGB')  # Convert grayscale to RGB
    return img

def load_mask_as_binary(mask_path):
    # Open mask image (keep it in grayscale)
    mask = PILImage.open(mask_path)

    # Convert to grayscale (if not already in mode 'L')
    if mask.mode != 'L':
        mask = mask.convert('L')
    
    # Convert mask values from 0-255 to 0-1 (binary)
    mask = np.array(mask)  # Convert to NumPy array
    mask[mask == 255] = 1   # Replace 255 with 1
    # mask[mask == 0] = 0     # Ensure 0 stays as 0
    
    # Convert back to PIL Image for compatibility
    mask = PILImage.fromarray(mask)
    
    # mask = mask[np.newaxis, :, :]
    
    return mask

def create_dataset(image_paths, mask_paths):
    """Creates a dataset storing file paths as individual strings, not lists"""
    return Dataset.from_dict({"pixel_values": image_paths, "label": mask_paths})

# Create dataset from file paths
ds_train = create_dataset(train_images, train_masks)
ds_valid = create_dataset(val_images, val_masks)

def transform(example):
    """Loads images/masks from file paths and applies transformations."""
    
    # Apply transformation to each item in the batch
    images = []
    masks = []
    
    # print("batch")
    
    for img_path, mask_path in zip(example["pixel_values"], example["label"]):
        # Open images and masks
        image = PILImage.open(img_path).convert("RGB")
        mask = PILImage.open(mask_path).convert("L")  # Convert mask to grayscale
        
        # Transform image
        image = transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.ToTensor()
        ])(image)
        
        mask = np.array(mask)  # Convert to NumPy array
        mask[mask == 255] = 1   # Replace 255 with 1
        # print(mask)
        # mask[mask == 0] = 0     # Ensure 0 stays as 0
        
        # # Convert back to PIL Image for compatibility
        mask = PILImage.fromarray(mask)

        # # Transform mask (nearest neighbor interpolation for segmentation)
        mask = transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.ToTensor()
        ])(mask).long().squeeze(0)  # Convert to tensor and remove extra channel

        images.append(image)
        masks.append(mask)
        
        # print(f"Image shape: {image.shape}, Mask shape: {mask.shape}")
    
    inputs = processor(images, segmentation_maps=masks, return_tensors='pt')
    return inputs

# Apply transformation correctly
ds_train.set_transform(transform)
ds_valid.set_transform(transform)


In [29]:
# metric = evaluate.load('mean_iou')

# def compute_metrics(eval_pred):
#     with torch.no_grad():
#         logits, labels = eval_pred
        
#         if isinstance(logits, tuple):
#             for i, logit in enumerate(logits):
#                 print(f"Logits[{i}] shape: {logit.shape}")
        
#         if isinstance(labels, tuple):
#             for i, label in enumerate(labels):
#                 print(f"Labels[{i}] shape: {label.shape}")

#         logits_tensor = torch.from_numpy(logits[1])
#         # logits_tensor = torch.from_numpy(logits)
#         # scale the logits to the size of the label
        
#         logits_tensor = nn.functional.interpolate(
#             logits_tensor,
#             size=IMAGE_SIZE,
#             mode='bilinear',
#             align_corners=False,
#         ).argmax(dim=1)

#         pred_labels = logits_tensor.detach().cpu().numpy()
        
#         labels_tensor = torch.from_numpy(labels[0])
#         labels_tensor = nn.functional.interpolate(
#             labels_tensor.float(), size=IMAGE_SIZE, mode="nearest"
#         ).long()
        
#         labels_resized = labels_tensor.squeeze(1).detach().cpu().numpy()
        
#         # currently using _compute instead of compute
#         # see this issue for more info: https://github.com/huggingface/evaluate/pull/328#issuecomment-1286866576
#         metrics = metric._compute(
#             predictions=pred_labels,
#             references=labels_resized,
#             num_labels=len(id2label),
#             ignore_index=None,
#             reduce_labels=processor.do_reduce_labels,
#         )

#         # add per category metrics as individual key-value pairs
#         per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
#         per_category_iou = metrics.pop("per_category_iou").tolist()

#         metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
#         metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})

#         return metrics

@dataclass
class ModelOutput:
    class_queries_logits: torch.Tensor
    masks_queries_logits: torch.Tensor

def nested_cpu(tensors):
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_cpu(t) for t in tensors)
    elif isinstance(tensors, Mapping):
        return type(tensors)({k: nested_cpu(t) for k, t in tensors.items()})
    elif isinstance(tensors, torch.Tensor):
        return tensors.cpu().detach()
    else:
        return tensors

class Evaluator:
    """
    Compute metrics for the instance segmentation task.
    """

    def __init__(
        self,
        image_processor: AutoImageProcessor,
        id2label: Mapping[int, str],
        threshold: float = 0.0,
    ):
        """
        Initialize evaluator with image processor, id2label mapping and threshold for filtering predictions.

        Args:
            image_processor (AutoImageProcessor): Image processor for
                `post_process_instance_segmentation` method.
            id2label (Mapping[int, str]): Mapping from class id to class name.
            threshold (float): Threshold to filter predicted boxes by confidence. Defaults to 0.0.
        """
        self.image_processor = image_processor
        self.id2label = id2label
        self.threshold = threshold
        self.metric = self.get_metric()

    def get_metric(self):
        metric = MeanAveragePrecision(iou_type="segm", class_metrics=True)
        return metric

    def reset_metric(self):
        self.metric.reset()

    def postprocess_target_batch(self, target_batch) -> List[Dict[str, torch.Tensor]]:
        """Collect targets in a form of list of dictionaries with keys "masks", "labels"."""
        batch_masks = target_batch[0]
        batch_labels = target_batch[1]
        post_processed_targets = []
        for masks, labels in zip(batch_masks, batch_labels):
            if not isinstance(masks, torch.Tensor):
                masks = torch.tensor(masks)
    
            post_processed_targets.append(
                {
                    "masks": masks.to(dtype=torch.bool),
                    "labels": labels,
                }
            )
        return post_processed_targets

    def get_target_sizes(self, post_processed_targets) -> List[List[int]]:
        target_sizes = []
        for target in post_processed_targets:
            target_sizes.append(target["masks"].shape[-2:])
        return target_sizes

    def postprocess_prediction_batch(self, prediction_batch, target_sizes) -> List[Dict[str, torch.Tensor]]:
        """Collect predictions in a form of list of dictionaries with keys "masks", "labels", "scores"."""
        
        class_queries_logits = prediction_batch[0]
        masks_queries_logits = prediction_batch[1]
        if not isinstance(class_queries_logits, torch.Tensor):
            class_queries_logits = torch.tensor(class_queries_logits)
        
        if not isinstance(masks_queries_logits, torch.Tensor):
            masks_queries_logits = torch.tensor(masks_queries_logits)
        
        print("Here 1")

        model_output = ModelOutput(class_queries_logits=class_queries_logits, masks_queries_logits=masks_queries_logits)
        post_processed_output = self.image_processor.post_process_instance_segmentation(
            model_output,
            threshold=self.threshold,
            target_sizes=target_sizes,
            return_binary_maps=True,
        )
        
        print("Here 2")

        post_processed_predictions = []
        for image_predictions, target_size in zip(post_processed_output, target_sizes):
            if image_predictions["segments_info"]:
                if not isinstance(image_predictions["segmentation"], torch.Tensor):
                    image_predictions["segmentation"] = torch.tensor(image_predictions["segmentation"])
                    
                post_processed_image_prediction = {
                    "masks": image_predictions["segmentation"].to(dtype=torch.bool),
                    "labels": torch.tensor([x["label_id"] for x in image_predictions["segments_info"]]),
                    "scores": torch.tensor([x["score"] for x in image_predictions["segments_info"]]),
                }
                
            else:
                # for void predictions, we need to provide empty tensors
                post_processed_image_prediction = {
                    "masks": torch.zeros([0, *target_size], dtype=torch.bool),
                    "labels": torch.tensor([]),
                    "scores": torch.tensor([]),
                }
            post_processed_predictions.append(post_processed_image_prediction)

        return post_processed_predictions

    @torch.no_grad()
    def __call__(self, evaluation_results: EvalPrediction, compute_result: bool = True) -> Mapping[str, float]:
        """
        Update metrics with current evaluation results and return metrics if `compute_result` is True.

        Args:
            evaluation_results (EvalPrediction): Predictions and targets from evaluation.
            compute_result (bool): Whether to compute and return metrics.

        Returns:
            Mapping[str, float]: Metrics in a form of dictionary {<metric_name>: <metric_value>}
        """
        prediction_batch = nested_cpu(evaluation_results.predictions)
        target_batch = nested_cpu(evaluation_results.label_ids)
        print("Checkpoint 1")

        # For metric computation we need to provide:
        #  - targets in a form of list of dictionaries with keys "masks", "labels"
        #  - predictions in a form of list of dictionaries with keys "masks", "labels", "scores"
        post_processed_targets = self.postprocess_target_batch(target_batch)
        print("Checkpoint 2")
        target_sizes = self.get_target_sizes(post_processed_targets)
        print("Checkpoint 3")
        post_processed_predictions = self.postprocess_prediction_batch(prediction_batch, target_sizes)

        print("Checkpoint 4")
        
        # Compute metrics
        
        if not isinstance(post_processed_targets, torch.Tensor):
            post_processed_targets = torch.tensor(post_processed_targets)
        
        if not isinstance(post_processed_predictions, torch.Tensor):
            post_processed_predictions = torch.tensor(post_processed_predictions)
        
        self.metric.update(post_processed_predictions, post_processed_targets)

        print("Checkpoint 5")
        if not compute_result:
            return

        metrics = self.metric.compute()
        
        print("Checkpoint 6")

        # Replace list of per class metrics with separate metric for each class
        classes = metrics.pop("classes")
        map_per_class = metrics.pop("map_per_class")
        mar_100_per_class = metrics.pop("mar_100_per_class")
        for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
            class_name = self.id2label[class_id.item()] if self.id2label is not None else class_id.item()
            metrics[f"map_{class_name}"] = class_map
            metrics[f"mar_100_{class_name}"] = class_mar

        metrics = {k: round(v.item(), 4) for k, v in metrics.items()}

        # Reset metric for next evaluation
        self.reset_metric()
        
        print("Checkpoint 7")

        return metrics

    

In [None]:
compute_metrics = Evaluator(image_processor=processor, id2label=id2label, threshold=0.0)

training_args = TrainingArguments(
    output_dir="maskformer_water_finetuned",
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_EPOCHS,
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=2,
    fp16=False,
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    push_to_hub=False,
    lr_scheduler_type="constant",
    # dataloader_num_workers=8,
    # dataloader_persistent_workers=True,
    # dataloader_prefetch_factor=4,
    load_best_model_at_end=True,
    eval_accumulation_steps=5,
    report_to=None
)

# Trainer setup
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_train,
    eval_dataset=ds_valid,
    compute_metrics=compute_metrics
)

trainer.train()

Epoch,Training Loss,Validation Loss


In [None]:
model.save_pretrained('maskformer_water')

In [None]:
# class SegmentationDataset(Dataset):
#     def __init__(self, image_dir, mask_dir, processor):
#         self.image_dir = image_dir
#         self.mask_dir = mask_dir
#         self.processor = processor
#         self.image_filenames = sorted(os.listdir(image_dir))
#         self.mask_filenames = sorted(os.listdir(mask_dir))
        
#         print(self.image_filenames)
#         print(self.mask_filenames)

#     def __len__(self):
#         return len(self.image_filenames)

#     def __getitem__(self, idx):
#         img_path = os.path.join(self.image_dir, self.image_filenames[idx])
#         mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])

#         # Load and preprocess image
#         image = Image.open(img_path).convert("RGB").resize(IMAGE_SIZE)
#         image = np.array(image) / 255.0  # Normalize

#         # Load and preprocess mask
#         mask = Image.open(mask_path).resize(IMAGE_SIZE)  # Nearest-neighbor for masks
#         mask = np.array(mask) / 255
        
#         # Ensure mask is single channel
#         if len(mask.shape) == 3:
#             mask = mask[:, :, 0]

#         # Convert image to model format
#         inputs = self.processor(image, return_tensors="pt")
#         pixel_values = inputs["pixel_values"].squeeze(0)  # Remove batch dimension

#         # Convert mask to tensor (0 and 1 for binary classification)
#         mask = torch.tensor(mask, dtype=torch.long)  # Shape: (512, 512)

#         return pixel_values, mask

# class TestDataset(Dataset):
#     def __init__(self, image_dir, processor):
#         self.image_dir = image_dir
#         self.processor = processor
#         self.image_filenames = sorted(os.listdir(image_dir))

#     def __len__(self):
#         return len(self.image_filenames)

#     def __getitem__(self, idx):
#         img_path = os.path.join(self.image_dir, self.image_filenames[idx])

#         # Load and preprocess image
#         image = Image.open(img_path).convert("RGB").resize(IMAGE_SIZE)
#         image_array = np.array(image) / 255.0  # Normalize

#         # Convert image to model format
#         inputs = self.processor(image_array, return_tensors="pt")
#         pixel_values = inputs["pixel_values"].squeeze(0)  # Remove batch dimension

#         return pixel_values, self.image_filenames[idx]  # Return filename to save output later


In [None]:
# full_dataset = SegmentationDataset(train_image_dir, train_mask_dir, processor)

# # Split into Train and Validation
# train_size = int((1 - VAL_SPLIT) * len(full_dataset))
# val_size = len(full_dataset) - train_size
# train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# # Create DataLoaders
# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# criterion = torch.nn.CrossEntropyLoss()

# for epoch in range(NUM_EPOCHS):
#     model.train()
#     total_train_loss = 0

#     for step, (images, masks) in enumerate(train_loader):
#         images, masks = images.to(device), masks.to(device)

#         optimizer.zero_grad()

#         # No mixed precision (removed torch.amp.autocast and GradScaler)
#         outputs = model(pixel_values=images).logits  # Shape: (B, C, H, W)
#         outputs = F.interpolate(outputs, size=IMAGE_SIZE, mode="bilinear", align_corners=False)  # Resize to match masks
#         loss = criterion(outputs, masks)

#         loss.backward()
#         optimizer.step()

#         total_train_loss += loss.item()

#     # Validation Loop
#     model.eval()
#     total_val_loss = 0
#     with torch.no_grad():
#         for images, masks in val_loader:
#             images, masks = images.to(device), masks.to(device)

#             outputs = model(pixel_values=images).logits
#             outputs = F.interpolate(outputs, size=IMAGE_SIZE, mode="bilinear", align_corners=False)
#             loss = criterion(outputs, masks)

#             total_val_loss += loss.item()

#     avg_train_loss = total_train_loss / len(train_loader)
#     avg_val_loss = total_val_loss / len(val_loader)
#     print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

# # Save Model
# torch.save(model.state_dict(), "segformer_binary.pth")

In [None]:
# test_dataset = TestDataset(test_image_dir, processor)
# test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# model.load_state_dict(torch.load("segformer_binary.pth"))
# model.eval()

# output_dir = "./predicted_masks_segformer"
# os.makedirs(output_dir, exist_ok=True)  # Create directory to save masks

# for images, filenames in test_loader:
#     images = images.to(device)

#     # Inference
#     with torch.no_grad():
#         outputs = model(pixel_values=images).logits  # (B, 2, H, W)
#         outputs = F.interpolate(outputs, size=IMAGE_SIZE, mode="bilinear", align_corners=False)
#         predicted_masks = torch.argmax(outputs, dim=1).cpu().numpy()  # Convert to numpy array

#     # Save or Display Results
#     for i in range(len(filenames)):
#         mask = Image.fromarray((predicted_masks[i] * 255).astype(np.uint8))  # Convert to image format
#         mask.save(os.path.join(output_dir, filenames[i].replace(".png", "_mask.png")))

# print(f"Predicted masks saved to {output_dir}")
