In [97]:
import torch
import numpy as np
from datasets import Dataset, Image
# from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoImageProcessor, MaskFormerForInstanceSegmentation, TrainingArguments, Trainer, MaskFormerConfig, MaskFormerModel, MaskFormerImageProcessor
from PIL import Image as PILImage
from sklearn.model_selection import train_test_split
import glob
import torch.nn as nn
from torchvision import transforms

from typing import Dict, List, Mapping
from transformers.trainer import EvalPrediction
# from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchmetrics import JaccardIndex, Accuracy
from dataclasses import dataclass

import albumentations as A
from albumentations.pytorch import ToTensorV2

In [98]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_type = "cuda" if torch.cuda.is_available() else "cpu"

IMAGE_SIZE = (512, 512)  # Resize images to this size
BATCH_SIZE = 4
NUM_EPOCHS = 50
LEARNING_RATE = 5e-5
VAL_SPLIT = 0.125

id2label = {0: 'background', 1: 'water'}
label2id = {label: id for id, label in id2label.items()}
NUM_CLASSES = len(id2label)

MODEL_CHECKPOINT = "facebook/maskformer-resnet50-coco-stuff"

config = MaskFormerConfig.from_pretrained(MODEL_CHECKPOINT)
config.id2label = id2label
config.label2id = label2id
config.num_labels = NUM_CLASSES

# Use the config object to initialize a MaskFormer model with randomized weights
model = MaskFormerForInstanceSegmentation(config)

base_model = MaskFormerModel.from_pretrained(MODEL_CHECKPOINT)
model.model = base_model

# processor = AutoImageProcessor.from_pretrained(MODEL_CHECKPOINT)
processor = MaskFormerImageProcessor.from_pretrained(MODEL_CHECKPOINT)


model.to(device)

  return func(*args, **kwargs)


MaskFormerForInstanceSegmentation(
  (model): MaskFormerModel(
    (pixel_level_module): MaskFormerPixelLevelModule(
      (encoder): ResNetBackbone(
        (embedder): ResNetEmbeddings(
          (embedder): ResNetConvLayer(
            (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (activation): ReLU()
          )
          (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        )
        (encoder): ResNetEncoder(
          (stages): ModuleList(
            (0): ResNetStage(
              (layers): Sequential(
                (0): ResNetBottleNeckLayer(
                  (shortcut): ResNetShortCut(
                    (convolution): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
                    (normalization): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True

In [99]:
train_image_dir = "./sar_images/images/train/*.png"
train_mask_dir = "./sar_images/masks/train/*.png"
test_image_dir = "./sar_images/images/test"
test_mask_dir = "./sar_images/masks/test"

images = list(glob.glob(train_image_dir))
# images = [str(path) for path in images]
masks = [path.replace('/images', '/masks') for path in images]

# print(images)
# print(masks)

# print(f'{len(images)} images detected.')

train_images, val_images, train_masks, val_masks = train_test_split(
    images, masks, test_size=VAL_SPLIT, random_state=0, shuffle=True)

print(f'Train images: {len(train_images)}\nValidation images: {len(val_images)}')

Train images: 883
Validation images: 127


In [100]:
def create_dataset(image_paths, mask_paths):
    """Creates a dataset storing file paths as individual strings, not lists"""
    return Dataset.from_dict({"pixel_values": image_paths, "label": mask_paths})

# Create dataset from file paths
ds_train = create_dataset(train_images, train_masks)
ds_valid = create_dataset(val_images, val_masks)

alb_transform = A.Compose([
    A.Resize(IMAGE_SIZE[0], IMAGE_SIZE[1]),
    ToTensorV2()
])

def transform(example):
    """Loads images/masks from file paths and applies transformations using Albumentations."""
    
    images = []
    masks = []
    
    batch = {
        "pixel_values": [],
        "mask_labels": [],
        "class_labels": [],
    }
    
    for img_path, mask_path in zip(example["pixel_values"], example["label"]):
        
        # print(img_path, mask_path)
        
        # Open images and masks
        image = np.array(PILImage.open(img_path).convert("RGB"))
        mask = np.array(PILImage.open(mask_path).convert("L"), dtype=np.uint8)  # Convert mask to grayscale

        mask[mask == 255] = 1  # Convert 255 to 1
        # Apply Albumentations transform
        transformed = alb_transform(image=image, mask=mask)
        
        # Extract transformed image and mask
        image = transformed["image"]
        mask = transformed["mask"].long()  # Ensure mask is long tensor
        
        images.append(image)
        masks.append(mask)
        
        # print(mask.min(), mask.max())

    # Process inputs using the Hugging Face processor
    model_inputs = processor(images, segmentation_maps=masks, return_tensors='pt')
    
    batch["pixel_values"].append(model_inputs.pixel_values[0])
    batch["mask_labels"].append(model_inputs.mask_labels[0])
    batch["class_labels"].append(model_inputs.class_labels[0])
       
    return batch


# def collate_fn(examples):
#     batch = {}
#     batch["pixel_values"] = torch.stack([example["pixel_values"] for example in examples])
#     batch["class_labels"] = [example["class_labels"] for example in examples]
#     batch["mask_labels"] = [example["mask_labels"] for example in examples]
#     if "pixel_mask" in examples[0]:
#         batch["pixel_mask"] = torch.stack([example["pixel_mask"] for example in examples])
#     return batch

# Apply transformation correctly
ds_train.set_transform(transform)
ds_valid.set_transform(transform)


In [101]:
@dataclass
class ModelOutput:
    class_queries_logits: torch.Tensor
    masks_queries_logits: torch.Tensor

def nested_cpu(tensors):
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_cpu(t) for t in tensors)
    elif isinstance(tensors, Mapping):
        return type(tensors)({k: nested_cpu(t) for k, t in tensors.items()})
    elif isinstance(tensors, torch.Tensor):
        return tensors.cpu().detach()
    elif isinstance(tensors, np.ndarray):
        return torch.from_numpy(tensors)
    else:
        return tensors

class Evaluator:
    """
    Compute metrics for the instance segmentation task.
    """

    def __init__(
        self,
        image_processor: AutoImageProcessor,
        id2label: Mapping[int, str],
        num_classes
    ):
        """
        Initialize evaluator with image processor, id2label mapping and threshold for filtering predictions.

        Args:
            image_processor (AutoImageProcessor): Image processor for
                `post_process_instance_segmentation` method.
            id2label (Mapping[int, str]): Mapping from class id to class name.
            threshold (float): Threshold to filter predicted boxes by confidence. Defaults to 0.0.
        """
        self.image_processor = image_processor
        self.id2label = id2label
        # self.threshold = threshold
        self.iou_metric = JaccardIndex(task="multiclass", num_classes=num_classes, average=None)
        self.acc_metric = Accuracy(task="multiclass", num_classes=num_classes, average=None)
        self.num_classes = num_classes

    def reset_metric(self):
        self.iou_metric.reset()
        self.acc_metric.reset()

    def postprocess_target_batch(self, target_batch) -> List[Dict[str, torch.Tensor]]:
        """Process targets into dict format."""
        # print(type(target_batch))
        # print(len(target_batch))
        batch_masks, batch_labels = target_batch[0], target_batch[1]
        
        # print(type(batch_masks))
        # print(type(batch_labels))
        # print(len(batch_masks))
        # print(len(batch_labels))
        
        # print(type(batch_masks[0]))
        # print(batch_masks[0].shape)
        # print(type(batch_labels[0]))
        # print(batch_labels[0].shape)
        # print(len(batch_masks))
        # print(len(batch_labels))
        
        # print(batch_masks.shape)
        # print(batch_labels.shape)
        # print(torch.tensor(masks).shape, torch.tensor(labels).shape)
        
        post_processed_targets = [
            {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}
            for masks, labels in zip(batch_masks, batch_labels)
        ]
        
        return post_processed_targets
    
    def get_target_sizes(self, post_processed_targets) -> List[List[int]]:
        """Get sizes of target masks."""
        return [target["masks"].shape[-2:] for target in post_processed_targets]

    def postprocess_prediction_batch(self, prediction_batch, target_sizes) -> List[Dict[str, torch.Tensor]]:
        """Collect predictions in a form of list of dictionaries with keys "masks", "labels", "scores"."""
        
        # print("Postprocess Prediction")
        
        class_queries_logits = prediction_batch[0]
        masks_queries_logits = prediction_batch[1]
        
        # print("Here 1")

        model_output = ModelOutput(class_queries_logits=class_queries_logits, masks_queries_logits=masks_queries_logits)
        post_processed_output = self.image_processor.post_process_semantic_segmentation(
            model_output, target_sizes=target_sizes
        )

        # print("Here 2")

        post_processed_predictions = []
        for image_prediction in post_processed_output:  # No need for segments_info handling
            post_processed_image_prediction = {
                "masks": image_prediction.to(dtype=torch.long),  # Ensure it's a tensor of labels
            }
            post_processed_predictions.append(post_processed_image_prediction)

        return post_processed_predictions

    @torch.no_grad()
    def __call__(self, evaluation_results: EvalPrediction, compute_result: bool = True) -> Mapping[str, float]:
        """
        Update metrics with current evaluation results and return metrics if `compute_result` is True.

        Args:
            evaluation_results (EvalPrediction): Predictions and targets from evaluation.
            compute_result (bool): Whether to compute and return metrics.

        Returns:
            Mapping[str, float]: Metrics in a form of dictionary {<metric_name>: <metric_value>}
        """

        prediction_batch = nested_cpu(evaluation_results.predictions)
        target_batch = nested_cpu(evaluation_results.label_ids)

        # print("Checkpoint 1")
        post_processed_targets = self.postprocess_target_batch(target_batch)
        # print(len(post_processed_targets))
        # print("Checkpoint 2")
        target_sizes = self.get_target_sizes(post_processed_targets)
        # print(len(target_sizes))
        # print("Checkpoint 3")
        post_processed_predictions = self.postprocess_prediction_batch(prediction_batch, target_sizes)
        
        # print("Checkpoint 4")
        
        post_processed_predictions = [nested_cpu(item) for item in post_processed_predictions]
        post_processed_targets = [nested_cpu(item) for item in post_processed_targets]

        # print("Checkpoint 5")
        if not compute_result:
            return
            
        pred_masks = torch.stack([p["masks"] for p in post_processed_predictions])
        
        for t in post_processed_targets:
            t["masks"] = torch.argmax(t["masks"], dim=0, keepdim=False)
        
        target_masks = torch.stack([t["masks"] for t in post_processed_targets]) 
        
        # print("Checkpoint 6")
        # Compute metrics
        
        print(pred_masks.min(), pred_masks.max())
        print(target_masks.min(), target_masks.max())
        
        # self.iou_metric.update(pred_masks, target_masks)
        # self.acc_metric.update(pred_masks, target_masks)
        # iou_metrics = self.iou_metric.compute()
        # acc_metrics = self.acc_metric.compute()
        
        iou_per_class = self.iou_metric(pred_masks, target_masks)
        acc_per_class = self.acc_metric(pred_masks, target_masks)
        
        # print("Checkpoint 7")

        # Compute mean IoU & accuracy
        mean_iou = iou_per_class.mean().item()
        mean_acc = acc_per_class.mean().item()
        overall_acc = (pred_masks == target_masks).float().mean().item()

        # Get specific class IoU and accuracy
        iou_background = iou_per_class[0].item()
        iou_water = iou_per_class[1].item() if self.num_classes > 1 else -1
        acc_background = acc_per_class[0].item()
        acc_water = acc_per_class[1].item() if self.num_classes > 1 else -1
        
        # print("Checkpoint 8")

        metrics = {
            "mean_iou": round(mean_iou, 4),
            "mean_accuracy": round(mean_acc, 4),
            "overall_accuracy": round(overall_acc, 4),
            "accuracy_background": round(acc_background, 4),
            "accuracy_water": round(acc_water, 4),
            "iou_background": round(iou_background, 4),
            "iou_water": round(iou_water, 4),
        }

        self.reset_metric()
        # print("Checkpoint 9")
        return metrics

In [102]:
compute_metrics = Evaluator(image_processor=processor, id2label=id2label, num_classes=NUM_CLASSES)

training_args = TrainingArguments(
    output_dir="maskformer_water_finetuned",
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_EPOCHS,
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=2,
    fp16=False,
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    push_to_hub=False,
    lr_scheduler_type="constant",
    # eval_do_concat_batches=False,
    # dataloader_num_workers=8,
    # dataloader_persistent_workers=True,
    # dataloader_prefetch_factor=4,
    load_best_model_at_end=True,
    eval_accumulation_steps=5,
    report_to=None
)

# Trainer setup
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=ds_train,
#     eval_dataset=ds_valid,
#     data_collator=collate_fn,
#     compute_metrics=compute_metrics
# )

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_train,
    eval_dataset=ds_valid,
    compute_metrics=compute_metrics
)

trainer.train()

Epoch,Training Loss,Validation Loss,Mean Iou,Mean Accuracy,Overall Accuracy,Accuracy Background,Accuracy Water,Iou Background,Iou Water
1,3.4284,12.590865,0.1266,0.5002,0.2529,0.0006,0.9998,0.0006,0.2526
2,1.8753,1.574219,0.4616,0.5948,0.713,0.8337,0.3558,0.6847,0.2384
3,1.2919,2.267871,0.6264,0.7341,0.8325,0.9328,0.5354,0.8063,0.4465
4,2.0819,1.32768,0.6167,0.837,0.7804,0.7227,0.9513,0.711,0.5225
5,1.0611,2.302904,0.7835,0.8505,0.9117,0.9741,0.7269,0.8918,0.6752
6,1.1839,3.193964,0.6644,0.8309,0.826,0.8211,0.8407,0.7792,0.5496
7,0.8363,1.418788,0.7476,0.8139,0.8988,0.9855,0.6424,0.8792,0.6159
8,0.9391,3.418244,0.6068,0.6943,0.8431,0.9949,0.3936,0.8258,0.3878
9,0.8752,2.796201,0.7898,0.8468,0.9166,0.9877,0.7059,0.8985,0.6812
10,1.0363,2.010036,0.595,0.7056,0.8167,0.9299,0.4814,0.7913,0.3987


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


  {"masks": torch.tensor(masks, dtype=torch.long), "labels": torch.tensor(labels)}


tensor(0) tensor(1)
tensor(0) tensor(1)


TrainOutput(global_step=5500, training_loss=0.7003448947559704, metrics={'train_runtime': 5036.8528, 'train_samples_per_second': 8.765, 'train_steps_per_second': 1.092, 'total_flos': 5.20318690418304e+18, 'train_loss': 0.7003448947559704, 'epoch': 49.55203619909502})

In [103]:
model.save_pretrained('maskformer_water')