# Mask2Former Fine-tuning for Contrail Segmentation

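This notebook fine-tunes a COCO-pretrained Mask2Former (Swin backbone) for instance segmentation of contrails. It uses a COCO-format annotation file, a seeded 80/20 train/validation split and the Hugging Face `Trainer`. The checkpoint with the lowest validation loss is saved together with its image processor.

> **Note:** the saved cell outputs below come from an earlier run (2 epochs, out-of-order execution counts). Restart the kernel and run all cells to refresh them.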

In [None]:
import math
import os
import random
import shutil
from pathlib import Path

# Must be set before torch/transformers are imported so CUDA only sees GPU 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import albumentations as A
import numpy as np
import torch
from PIL import Image
from pycocotools.coco import COCO
from torch.utils.data import Dataset
from transformers import (
    EarlyStoppingCallback,
    Mask2FormerForUniversalSegmentation,
    Mask2FormerImageProcessor,
    Trainer,
    TrainingArguments,
)

# Configuration
run_id = "contrail_segmentation"  # name of the final model directory under model_dir
task = "instance"  # Instance segmentation for contrail detection
batch_size = 2  # Per-device batch size; adjust based on GPU memory
num_train_epochs = 5
model_size = "base"  # Swin backbone size, used to pick the pretrained checkpoint

# Dataset / output paths. The defaults are the original cluster locations; set the
# CONTRAIL_DATASET_DIR / CONTRAIL_MODEL_DIR environment variables to run elsewhere
# without editing the notebook.
dataset_dir = Path(
    os.environ.get(
        "CONTRAIL_DATASET_DIR",
        "/data/common/STEREOSTUDYIPSL/Codebase/FineTuning/dataset/D-imageWithAnnotation/D-imageWithAnnotation",
    )
)
model_dir = Path(
    os.environ.get("CONTRAIL_MODEL_DIR", "/data/common/STEREOSTUDYIPSL/Codebase/FineTuning")
)
base_model = f"facebook/mask2former-swin-{model_size}-coco-{task}"

print(f"Base model: {base_model}")
print(f"Dataset directory: {dataset_dir}")
print(f"Output directory: {model_dir}")

2026-01-13 16:16:37.846629: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-13 16:16:37.926103: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-01-13 16:16:39.783303: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


Base model: facebook/mask2former-swin-base-coco-instance
Dataset directory: /data/common/STEREOSTUDYIPSL/Codebase/FineTuning/dataset/D-imageWithAnnotation/D-imageWithAnnotation
Output directory: /data/common/STEREOSTUDYIPSL/Codebase/FineTuning


In [2]:
# Load the COCO annotation file and derive the label mappings from its categories
coco = COCO(str(dataset_dir / "annotations.coco.json"))

cat_ids = coco.getCatIds()
categories = coco.loadCats(cat_ids)

# Mask2Former expects contiguous 0-based labels, so a label is the category's position
# in `categories` (i.e. in getCatIds() order), not its raw COCO id.
id2label = dict(enumerate(cat["name"] for cat in categories))
label2id = {name: label for label, name in id2label.items()}

print(f"Categories found: {id2label}")
print(f"Number of categories: {len(categories)}")
print(f"Number of images: {len(coco.getImgIds())}")

loading annotations into memory...
Done (t=0.03s)
creating index...
index created!
Categories found: {0: 'contrail', 1: 'contrail maybe', 2: 'contrail old', 3: 'contrail veryold', 4: 'contrail young', 5: 'parasite', 6: 'sun', 7: 'unknow'}
Number of categories: 8
Number of images: 1600


In [22]:
# Split dataset into train and validation sets
def split_coco_dataset(coco, img_dir, train_size=0.8, seed=42):
    """Split COCO image ids into train and validation subsets.

    Only images whose file exists under ``img_dir`` are kept. The surviving ids are
    shuffled with a private RNG seeded by ``seed`` and cut at ``train_size``.

    Args:
        coco: A loaded ``pycocotools.coco.COCO`` index (uses ``getImgIds``/``loadImgs``).
        img_dir: Directory containing the image files (``str`` or ``Path``).
        train_size: Fraction of the valid images assigned to training, in [0, 1].
        seed: Shuffle seed; the same seed and inputs always give the same split.

    Returns:
        ``(train_ids, val_ids)``: two disjoint lists of COCO image ids.

    Raises:
        ValueError: If ``train_size`` is outside [0, 1].
    """
    if not 0.0 <= train_size <= 1.0:
        raise ValueError(f"train_size must be in [0, 1], got {train_size!r}")

    img_dir = Path(img_dir)
    img_ids = coco.getImgIds()

    # The annotation file can reference images that are not on disk; drop those.
    valid_img_ids = [
        img_id
        for img_id in img_ids
        if (img_dir / coco.loadImgs(img_id)[0]["file_name"]).exists()
    ]

    print(f"Found {len(valid_img_ids)} valid images out of {len(img_ids)} total")

    # A private Random instance yields the same permutation as random.seed(seed) followed
    # by random.shuffle, but does not reset the global RNG for the rest of the notebook.
    rng = random.Random(seed)
    rng.shuffle(valid_img_ids)

    split_idx = int(len(valid_img_ids) * train_size)
    return valid_img_ids[:split_idx], valid_img_ids[split_idx:]

# Fixed-seed 80/20 split so the validation set is stable across runs
split_kwargs = dict(train_size=0.8, seed=42)
train_img_ids, val_img_ids = split_coco_dataset(coco, dataset_dir, **split_kwargs)
for split_name, split_ids in (("Training", train_img_ids), ("Validation", val_img_ids)):
    print(f"{split_name} images: {len(split_ids)}")

Found 1568 valid images out of 1600 total
Training images: 1254
Validation images: 314


In [18]:
# Define custom dataset class for instance segmentation
class InstanceSegmentationDataset(Dataset):
    """COCO-backed instance-segmentation dataset producing Mask2Former inputs.

    Each item is a dict with:
        pixel_values: processed image tensor from ``processor`` (batch dim removed).
        pixel_mask:   processor pixel mask (batch dim removed).
        class_labels: ``torch.long`` tensor of contiguous 0-based labels, one per instance.
        mask_labels:  float tensor ``(num_instances, H, W)`` of binary instance masks at
                      the (possibly augmented) source-image resolution.
    """

    def __init__(self, coco, img_ids, img_dir, processor, transform=None, cat_id_to_label=None):
        """
        Args:
            coco: Loaded ``pycocotools.coco.COCO`` index with images and annotations.
            img_ids: Sequence of COCO image ids making up this split.
            img_dir: Directory holding the files referenced by each image's ``file_name``.
            processor: Image processor called on each image (e.g. ``Mask2FormerImageProcessor``).
            transform: Optional albumentations-style callable taking ``image=`` (and
                ``masks=``) keyword arguments and returning a dict with the same keys.
            cat_id_to_label: Optional mapping from COCO ``category_id`` to 0-based label.
                Defaults to each id's position in ``coco.getCatIds()``. That is the same
                order you get by enumerating ``coco.loadCats(coco.getCatIds())`` to build
                ``id2label``.
        """
        self.coco = coco
        self.img_ids = img_ids
        self.img_dir = Path(img_dir)
        self.processor = processor
        self.transform = transform
        if cat_id_to_label is None:
            cat_id_to_label = {cat_id: label for label, cat_id in enumerate(coco.getCatIds())}
        self.cat_id_to_label = dict(cat_id_to_label)

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Load, augment and process one image plus its instance targets.

        Raises:
            KeyError: If an annotation's ``category_id`` is missing from ``cat_id_to_label``.
        """
        img_id = self.img_ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]

        # Load as an HWC uint8 RGB array. The context manager closes the file right away
        # instead of leaving it to the garbage collector (matters with many DataLoader workers).
        img_path = self.img_dir / img_info["file_name"]
        with Image.open(img_path) as pil_image:
            image = np.array(pil_image.convert("RGB"), dtype=np.uint8)

        # One binary mask + 0-based class label per non-empty annotation
        anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
        masks = []
        class_labels = []
        for ann in anns:
            mask = self.coco.annToMask(ann)
            if mask.sum() > 0:  # skip annotations that rasterise to an empty mask
                masks.append(mask.astype(np.uint8))
                # Raw COCO category ids need not be 0..num_labels-1. Used directly they can
                # produce out-of-range labels or hit Mask2Former's extra "no object" class
                # (index num_labels), so map them to the contiguous label space explicitly.
                class_labels.append(self.cat_id_to_label[ann["category_id"]])

        # Apply augmentations; images without annotations are augmented too so every
        # training image goes through the same augmentation pipeline.
        if self.transform is not None:
            if masks:
                transformed = self.transform(image=image, masks=masks)
                image = transformed["image"]
                masks = list(transformed["masks"])
            else:
                image = self.transform(image=image)["image"]

        if len(masks) == 0:
            # NOTE(review): an all-zero mask labelled 0 (the first category) trains the model
            # to predict class 0 with an empty mask on background-only images. Empty targets
            # ((0, H, W) masks, empty label tensor) are likely preferable; confirm the installed
            # transformers/torch versions handle them before switching.
            masks = [np.zeros(image.shape[:2], dtype=np.uint8)]
            class_labels = [0]

        # Stack masks into (num_instances, H, W) format
        instance_masks = np.stack(masks, axis=0).astype(np.uint8)
        class_labels_tensor = torch.tensor(class_labels, dtype=torch.long)

        # Process image with the Mask2Former processor; input is HWC.
        inputs = self.processor(
            images=image,
            return_tensors="pt",
            input_data_format="channels_last",
        )

        # NOTE(review): masks are NOT resized to the processor's output size. This presumably
        # relies on the Mask2Former loss sampling masks at normalised coordinates. It stays
        # aligned only while all images in a batch share one size; verify if image sizes vary.
        mask_labels = torch.from_numpy(instance_masks).float()

        return {
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "pixel_mask": inputs["pixel_mask"].squeeze(0),
            "class_labels": class_labels_tensor,
            "mask_labels": mask_labels,
        }

In [12]:
# Image processor: fixed 512x512 resize with rescaling and normalisation enabled;
# 255 marks ignored pixels.
processor_kwargs = {
    "do_resize": True,
    "size": {"height": 512, "width": 512},
    "do_rescale": True,
    "do_normalize": True,
    "ignore_index": 255,
}
processor = Mask2FormerImageProcessor.from_pretrained(base_model, **processor_kwargs)

# The COCO checkpoint's class head has a different number of outputs than this dataset's
# label set, so mismatched weights are re-initialised instead of raising.
model = Mask2FormerForUniversalSegmentation.from_pretrained(
    base_model,
    ignore_mismatched_sizes=True,
    id2label=id2label,
    label2id=label2id,
)

n_params = model.num_parameters()
print(f"Model loaded: {base_model}")
print(f"Number of parameters: {n_params:,}")

  image_processor = cls(**image_processor_dict)
Some weights of Mask2FormerForUniversalSegmentation were not initialized from the model checkpoint at facebook/mask2former-swin-base-coco-instance and are newly initialized because the shapes did not match:
- class_predictor.bias: found shape torch.Size([81]) in the checkpoint and torch.Size([9]) in the model instantiated
- class_predictor.weight: found shape torch.Size([81, 256]) in the checkpoint and torch.Size([9, 256]) in the model instantiated
- criterion.empty_weight: found shape torch.Size([81]) in the checkpoint and torch.Size([9]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded: facebook/mask2former-swin-base-coco-instance
Number of parameters: 106,885,697


In [23]:
# Geometric augmentations only (flips and 90-degree rotations); albumentations applies
# the same sampled transform to the image and to all of its masks.
transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
)

shared_dataset_kwargs = dict(coco=coco, img_dir=dataset_dir, processor=processor)
train_dataset = InstanceSegmentationDataset(
    img_ids=train_img_ids, transform=transform, **shared_dataset_kwargs
)
# Validation images are never augmented
val_dataset = InstanceSegmentationDataset(
    img_ids=val_img_ids, transform=None, **shared_dataset_kwargs
)

for split_name, split_dataset in (("Train", train_dataset), ("Validation", val_dataset)):
    print(f"{split_name} dataset size: {len(split_dataset)}")

Train dataset size: 1254
Validation dataset size: 314


In [14]:
# Custom collate function for batching
def collate_fn(batch):
    """Combine dataset items into one Mask2Former batch.

    ``pixel_values`` and ``pixel_mask`` are stacked into batch tensors (so every item
    must share their shapes). ``class_labels`` and ``mask_labels`` stay as per-sample
    lists because each image can carry a different number of instances.
    """
    stacked = {
        key: torch.stack([sample[key] for sample in batch])
        for key in ("pixel_values", "pixel_mask")
    }
    per_sample = {
        key: [sample[key] for sample in batch]
        for key in ("class_labels", "mask_labels")
    }
    return {**stacked, **per_sample}

In [15]:
# Training configuration
# Evaluation and checkpointing share one cadence: load_best_model_at_end needs every
# evaluation step to also be a save step.
eval_every_n_steps = 500

training_args = TrainingArguments(
    output_dir=str(model_dir / "checkpoints"),
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    eval_strategy="steps",
    eval_steps=eval_every_n_steps,
    logging_steps=100,
    save_steps=eval_every_n_steps,
    save_total_limit=2,
    num_train_epochs=num_train_epochs,
    # Linear LR scaling from a reference of 1e-4 at batch size 16
    learning_rate=(batch_size / 16.0) * 1e-4,
    # fp16 mixed precision needs a CUDA device; fall back to fp32 so the cell also runs without one
    fp16=torch.cuda.is_available(),
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # Keep every key produced by the dataset / collate_fn
    remove_unused_columns=False,
    dataloader_num_workers=4,
)

# Stop when eval_loss has not improved for 3 consecutive evaluations
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

print("Training configuration:")
print(f"  - Batch size: {batch_size}")
print(f"  - Learning rate: {training_args.learning_rate}")
print(f"  - Epochs: {num_train_epochs}")
print(f"  - Mixed precision (fp16): {training_args.fp16}")
print(f"  - Output dir: {training_args.output_dir}")

Training configuration:
  - Batch size: 2
  - Learning rate: 1.25e-05
  - Epochs: 2
  - Output dir: /data/common/STEREOSTUDYIPSL/Codebase/FineTuning/checkpoints


In [24]:
# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    callbacks=[early_stopping],
)

# Expected number of optimizer steps. The train DataLoader keeps the last partial batch
# unless dataloader_drop_last is set (so use ceil, not floor division), and gradient
# accumulation folds several batches into one step. This approximates the Trainer's own
# computation; its progress bar is authoritative.
batches_per_epoch = len(train_dataset) / training_args.train_batch_size
batches_per_epoch = (
    math.floor(batches_per_epoch) if training_args.dataloader_drop_last else math.ceil(batches_per_epoch)
)
steps_per_epoch = max(batches_per_epoch // training_args.gradient_accumulation_steps, 1)
total_training_steps = math.ceil(steps_per_epoch * training_args.num_train_epochs)

print("Trainer initialized. Starting training...")
print(f"Total training steps: {total_training_steps}")

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Trainer initialized. Starting training...
Total training steps: 1254


In [25]:
# Run training; the returned TrainOutput (global step, training loss, metrics) is the cell's displayed result
train_result = trainer.train()
train_result

Step,Training Loss,Validation Loss
500,46.3112,46.055614
1000,42.7594,44.42664


TrainOutput(global_step=1254, training_loss=47.37378630569677, metrics={'train_runtime': 253.2772, 'train_samples_per_second': 9.902, 'train_steps_per_second': 4.951, 'total_flos': 1.2642947857058365e+18, 'train_loss': 47.37378630569677, 'epoch': 2.0})

In [26]:
# Save the best model
# With load_best_model_at_end=True the Trainer re-loads its best checkpoint into
# `trainer.model` when training ends, so save_model() writes exactly those weights.
# Only the model files are written: checkpoint directories also hold optimizer, scheduler
# and RNG state that the final model does not need.
best_ckpt = trainer.state.best_model_checkpoint
best_path = model_dir / run_id

if best_ckpt is None:
    # No evaluation produced a "best" checkpoint (e.g. training stopped before the
    # first eval step), so the weights the Trainer currently holds are saved instead.
    print("No best checkpoint recorded; saving the final model weights instead.")
else:
    print(f"Best checkpoint: {Path(best_ckpt)}")

# Start from an empty directory so no files from a previous run survive
if best_path.exists():
    shutil.rmtree(best_path)

trainer.save_model(str(best_path))
print(f"Model saved to: {best_path}")

# Also save the processor so the directory can be reloaded with from_pretrained
processor.save_pretrained(best_path)
print(f"Processor saved to: {best_path}")

Best checkpoint: /data/common/STEREOSTUDYIPSL/Codebase/FineTuning/checkpoints/checkpoint-1000
Model saved to: /data/common/STEREOSTUDYIPSL/Codebase/FineTuning/contrail_segmentation
Processor saved to: /data/common/STEREOSTUDYIPSL/Codebase/FineTuning/contrail_segmentation


In [None]:
# Optional cleanup: delete the intermediate Trainer checkpoints. The final model lives in
# `best_path` (model_dir / run_id), which is outside this directory.
checkpoints_dir = model_dir / "checkpoints"
if checkpoints_dir.exists():
    shutil.rmtree(checkpoints_dir)
    print(f"Cleaned up checkpoints directory: {checkpoints_dir}")

summary_lines = [
    "\n" + "=" * 50,
    "Training complete!",
    f"Final model saved to: {best_path}",
    "=" * 50,
]
print("\n".join(summary_lines))