In [7]:
import torch
import numpy as np
from datasets import Dataset, Image
# from torch.utils.data import Dataset, DataLoader, random_split
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor, TrainingArguments, Trainer
from PIL import Image as PILImage
from sklearn.model_selection import train_test_split
import evaluate
import glob
import torch.nn as nn

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_type = "cuda" if torch.cuda.is_available() else "cpu"

IMAGE_SIZE = (512, 512)  # Resize images to this size
BATCH_SIZE = 4
NUM_EPOCHS = 50
LEARNING_RATE = 5e-5
VAL_SPLIT = 0.125

id2label = {0: 'background', 1: 'water'}
label2id = {label: id for id, label in id2label.items()}
NUM_CLASSES = len(id2label)

MODEL_CHECKPOINT = 'nvidia/mit-b4'

# Check GPU availability
model = SegformerForSemanticSegmentation.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=NUM_CLASSES,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# Load Pretrained SegFormer with 2 Classes

processor = SegformerImageProcessor.from_pretrained(MODEL_CHECKPOINT)
# model = SegformerForSemanticSegmentation.from_pretrained(model_name, ignore_mismatched_sizes=True)

# Modify the classifier head
# model.config.num_labels = NUM_CLASSES
# model.decode_head.classifier = torch.nn.Conv2d(256, NUM_CLASSES, kernel_size=1)

# Move model to GPU
model.to(device)



Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b4 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(128, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(320, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)

In [9]:
train_image_dir = "./sar_images/images/train/*.png"
train_mask_dir = "./sar_images/masks/train/*.png"
test_image_dir = "./sar_images/images/test"
test_mask_dir = "./sar_images/masks/test"

images = list(glob.glob(train_image_dir))
# images = [str(path) for path in images]
masks = [path.replace('/images', '/masks') for path in images]

print(images)
print(masks)

print(f'{len(images)} images detected.')

train_images, val_images, train_masks, val_masks = train_test_split(
    images, masks, test_size=VAL_SPLIT, random_state=0, shuffle=True)

print(f'Train images: {len(train_images)}\nValidation images: {len(val_images)}')

['./sar_images/images/train\\0.png', './sar_images/images/train\\1.png', './sar_images/images/train\\100.png', './sar_images/images/train\\1000.png', './sar_images/images/train\\1001.png', './sar_images/images/train\\1002.png', './sar_images/images/train\\1003.png', './sar_images/images/train\\1004.png', './sar_images/images/train\\1005.png', './sar_images/images/train\\1006.png', './sar_images/images/train\\1007.png', './sar_images/images/train\\1009.png', './sar_images/images/train\\101.png', './sar_images/images/train\\1011.png', './sar_images/images/train\\1013.png', './sar_images/images/train\\1014.png', './sar_images/images/train\\1015.png', './sar_images/images/train\\1016.png', './sar_images/images/train\\1017.png', './sar_images/images/train\\1018.png', './sar_images/images/train\\1019.png', './sar_images/images/train\\102.png', './sar_images/images/train\\1021.png', './sar_images/images/train\\1022.png', './sar_images/images/train\\1023.png', './sar_images/images/train\\1024.

In [10]:
def load_image_as_rgb(image_path):
    # Open image
    img = PILImage.open(image_path)
    
    # If the image is grayscale (mode 'L'), convert it to RGB
    if img.mode == 'L':
        img = img.convert('RGB')  # Convert grayscale to RGB
    return img

def load_mask_as_binary(mask_path):
    # Open mask image (keep it in grayscale)
    mask = PILImage.open(mask_path)

    # Convert to grayscale (if not already in mode 'L')
    if mask.mode != 'L':
        mask = mask.convert('L')
    
    # Convert mask values from 0-255 to 0-1 (binary)
    mask = np.array(mask)  # Convert to NumPy array
    mask[mask == 255] = 1   # Replace 255 with 1
    mask[mask == 0] = 0     # Ensure 0 stays as 0
    
    # Convert back to PIL Image for compatibility
    mask = PILImage.fromarray(mask)
    
    return mask

def create_dataset(image_paths, mask_paths):
    # Apply the custom loader for RGB images and grayscale masks
    image_paths_rgb = [load_image_as_rgb(img_path) for img_path in image_paths]
    mask_paths_gray = [load_mask_as_binary(mask_path) for mask_path in mask_paths]
    
    # Create dataset from image paths and mask paths
    dataset = Dataset.from_dict({'pixel_values': image_paths_rgb,
                                 'label': mask_paths_gray})
    
    # Ensure images are loaded as Image() format (from PIL images)
    dataset = dataset.cast_column('pixel_values', Image())
    dataset = dataset.cast_column('label', Image())  # Keep masks as grayscale images
    return dataset

# Usage
ds_train = create_dataset(train_images, train_masks)
ds_valid = create_dataset(val_images, val_masks)

def apply_transforms(batch):
    images = [x for x in batch['pixel_values']]
    labels = [x for x in batch['label']]
    
    # Convert PIL images to NumPy arrays for processing
    images = [np.array(image) for image in images]
    labels = [np.array(label) for label in labels]
    
    # print(labels)    

    inputs = processor(images, labels)
    return inputs

ds_train.set_transform(apply_transforms)
ds_valid.set_transform(apply_transforms)


In [11]:
metric = evaluate.load('mean_iou')

def compute_metrics(eval_pred):
    with torch.no_grad():
        logits, labels = eval_pred
        logits_tensor = torch.from_numpy(logits)
        # scale the logits to the size of the label
        logits_tensor = nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode='bilinear',
            align_corners=False,
        ).argmax(dim=1)

        pred_labels = logits_tensor.detach().cpu().numpy()
        # currently using _compute instead of compute
        # see this issue for more info: https://github.com/huggingface/evaluate/pull/328#issuecomment-1286866576
        metrics = metric._compute(
                predictions=pred_labels,
                references=labels,
                num_labels=len(id2label),
                ignore_index=None,
                reduce_labels=processor.do_reduce_labels,
            )

        # add per category metrics as individual key-value pairs
        per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
        per_category_iou = metrics.pop("per_category_iou").tolist()

        metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
        metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})

        return metrics

In [12]:
training_args = TrainingArguments(
    output_dir='segformer_water_finetuned',
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    save_total_limit=3,
    # eval_strategy='steps',
    save_strategy='epoch',
    eval_strategy='epoch',
    logging_strategy='epoch',
    # save_steps=20,
    # eval_steps=20,
    logging_steps=1,
    eval_accumulation_steps=5,
    load_best_model_at_end=True,
    push_to_hub=False,
    report_to='none'
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_train,
    eval_dataset=ds_valid,
    compute_metrics=compute_metrics,
)

trainer.train()

Epoch,Training Loss,Validation Loss,Mean Iou,Mean Accuracy,Overall Accuracy,Accuracy Background,Accuracy Water,Iou Background,Iou Water
1,0.2169,0.143472,0.878603,0.925655,0.955402,0.97907,0.872239,0.944719,0.812487
2,0.158,0.135215,0.885657,0.932543,0.957933,0.978134,0.886953,0.947645,0.823669
3,0.1469,0.135052,0.887779,0.9389,0.958305,0.973744,0.904056,0.947862,0.827697
4,0.141,0.136396,0.882175,0.925281,0.957063,0.982351,0.868211,0.946837,0.817513
5,0.1355,0.126064,0.888053,0.933473,0.958924,0.979175,0.88777,0.948868,0.827239
6,0.1207,0.127867,0.889007,0.941328,0.958643,0.972419,0.910236,0.948196,0.829818
7,0.1127,0.142177,0.890695,0.937335,0.959758,0.977598,0.897073,0.949776,0.831613
8,0.1004,0.128205,0.8823,0.9288,0.956773,0.979029,0.87857,0.946325,0.818276
9,0.089,0.131612,0.889935,0.939193,0.959252,0.975211,0.903176,0.949059,0.830811
10,0.0827,0.136646,0.890086,0.93518,0.959678,0.97917,0.89119,0.949759,0.830413


TrainOutput(global_step=11050, training_loss=0.06722348838909718, metrics={'train_runtime': 7736.713, 'train_samples_per_second': 5.707, 'train_steps_per_second': 1.428, 'total_flos': 1.3331720679299482e+19, 'train_loss': 0.06722348838909718, 'epoch': 50.0})

In [13]:
model.save_pretrained('segformer_water')

In [14]:
# class SegmentationDataset(Dataset):
#     def __init__(self, image_dir, mask_dir, processor):
#         self.image_dir = image_dir
#         self.mask_dir = mask_dir
#         self.processor = processor
#         self.image_filenames = sorted(os.listdir(image_dir))
#         self.mask_filenames = sorted(os.listdir(mask_dir))
        
#         print(self.image_filenames)
#         print(self.mask_filenames)

#     def __len__(self):
#         return len(self.image_filenames)

#     def __getitem__(self, idx):
#         img_path = os.path.join(self.image_dir, self.image_filenames[idx])
#         mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])

#         # Load and preprocess image
#         image = Image.open(img_path).convert("RGB").resize(IMAGE_SIZE)
#         image = np.array(image) / 255.0  # Normalize

#         # Load and preprocess mask
#         mask = Image.open(mask_path).resize(IMAGE_SIZE)  # Nearest-neighbor for masks
#         mask = np.array(mask) / 255
        
#         # Ensure mask is single channel
#         if len(mask.shape) == 3:
#             mask = mask[:, :, 0]

#         # Convert image to model format
#         inputs = self.processor(image, return_tensors="pt")
#         pixel_values = inputs["pixel_values"].squeeze(0)  # Remove batch dimension

#         # Convert mask to tensor (0 and 1 for binary classification)
#         mask = torch.tensor(mask, dtype=torch.long)  # Shape: (512, 512)

#         return pixel_values, mask

# class TestDataset(Dataset):
#     def __init__(self, image_dir, processor):
#         self.image_dir = image_dir
#         self.processor = processor
#         self.image_filenames = sorted(os.listdir(image_dir))

#     def __len__(self):
#         return len(self.image_filenames)

#     def __getitem__(self, idx):
#         img_path = os.path.join(self.image_dir, self.image_filenames[idx])

#         # Load and preprocess image
#         image = Image.open(img_path).convert("RGB").resize(IMAGE_SIZE)
#         image_array = np.array(image) / 255.0  # Normalize

#         # Convert image to model format
#         inputs = self.processor(image_array, return_tensors="pt")
#         pixel_values = inputs["pixel_values"].squeeze(0)  # Remove batch dimension

#         return pixel_values, self.image_filenames[idx]  # Return filename to save output later


In [15]:
# full_dataset = SegmentationDataset(train_image_dir, train_mask_dir, processor)

# # Split into Train and Validation
# train_size = int((1 - VAL_SPLIT) * len(full_dataset))
# val_size = len(full_dataset) - train_size
# train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# # Create DataLoaders
# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# criterion = torch.nn.CrossEntropyLoss()

# for epoch in range(NUM_EPOCHS):
#     model.train()
#     total_train_loss = 0

#     for step, (images, masks) in enumerate(train_loader):
#         images, masks = images.to(device), masks.to(device)

#         optimizer.zero_grad()

#         # No mixed precision (removed torch.amp.autocast and GradScaler)
#         outputs = model(pixel_values=images).logits  # Shape: (B, C, H, W)
#         outputs = F.interpolate(outputs, size=IMAGE_SIZE, mode="bilinear", align_corners=False)  # Resize to match masks
#         loss = criterion(outputs, masks)

#         loss.backward()
#         optimizer.step()

#         total_train_loss += loss.item()

#     # Validation Loop
#     model.eval()
#     total_val_loss = 0
#     with torch.no_grad():
#         for images, masks in val_loader:
#             images, masks = images.to(device), masks.to(device)

#             outputs = model(pixel_values=images).logits
#             outputs = F.interpolate(outputs, size=IMAGE_SIZE, mode="bilinear", align_corners=False)
#             loss = criterion(outputs, masks)

#             total_val_loss += loss.item()

#     avg_train_loss = total_train_loss / len(train_loader)
#     avg_val_loss = total_val_loss / len(val_loader)
#     print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

# # Save Model
# torch.save(model.state_dict(), "segformer_binary.pth")

In [16]:
# test_dataset = TestDataset(test_image_dir, processor)
# test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# model.load_state_dict(torch.load("segformer_binary.pth"))
# model.eval()

# output_dir = "./predicted_masks_segformer"
# os.makedirs(output_dir, exist_ok=True)  # Create directory to save masks

# for images, filenames in test_loader:
#     images = images.to(device)

#     # Inference
#     with torch.no_grad():
#         outputs = model(pixel_values=images).logits  # (B, 2, H, W)
#         outputs = F.interpolate(outputs, size=IMAGE_SIZE, mode="bilinear", align_corners=False)
#         predicted_masks = torch.argmax(outputs, dim=1).cpu().numpy()  # Convert to numpy array

#     # Save or Display Results
#     for i in range(len(filenames)):
#         mask = Image.fromarray((predicted_masks[i] * 255).astype(np.uint8))  # Convert to image format
#         mask.save(os.path.join(output_dir, filenames[i].replace(".png", "_mask.png")))

# print(f"Predicted masks saved to {output_dir}")
