In [None]:
import torch
import numpy as np
from datasets import Dataset, Image
# from torch.utils.data import Dataset, DataLoader, random_split
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor, TrainingArguments, Trainer
from PIL import Image as PILImage
from sklearn.model_selection import train_test_split
import evaluate
import glob
import torch.nn as nn

In [None]:
# --- Runtime device ----------------------------------------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_type)

# --- Reproducibility ---------------------------------------------------------
# Seed before the model is built: from_pretrained() randomly initialises any
# weights that are missing from (or mismatched with) the checkpoint.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- Hyper-parameters --------------------------------------------------------
IMAGE_SIZE = (512, 512)  # (height, width) the image processor resizes inputs to
BATCH_SIZE = 4
NUM_EPOCHS = 50
LEARNING_RATE = 5e-5
VAL_SPLIT = 0.125  # fraction of the training files held out for validation

# --- Label space -------------------------------------------------------------
id2label = {0: 'background', 1: 'water'}
label2id = {name: idx for idx, name in id2label.items()}
NUM_CLASSES = len(id2label)

MODEL_CHECKPOINT = 'nvidia/mit-b4'

# --- Model -------------------------------------------------------------------
# Load the pretrained checkpoint configured for our two classes.
# ignore_mismatched_sizes lets layers whose shape differs from the checkpoint
# be re-initialised instead of raising.
model = SegformerForSemanticSegmentation.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=NUM_CLASSES,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# --- Image processor ---------------------------------------------------------
# Override the checkpoint's resize target so IMAGE_SIZE is the single source of
# truth for the network input resolution.
processor = SegformerImageProcessor.from_pretrained(
    MODEL_CHECKPOINT,
    size={'height': IMAGE_SIZE[0], 'width': IMAGE_SIZE[1]},
)

# Move the model to the selected device. Assigning the result keeps the cell
# from echoing the full module repr as its output.
model = model.to(device)



In [None]:
# Locations of the data. Despite the *_dir names, the two train entries are
# glob patterns (they end in *.png); the two test entries are directories.
train_image_dir = "./sar_images/images/train/*.png"
train_mask_dir = "./sar_images/masks/train/*.png"
test_image_dir = "./sar_images/images/test"
test_mask_dir = "./sar_images/masks/test"

# glob returns matches in arbitrary, filesystem-dependent order. Sort them so
# the seeded split below picks the same files on every machine and every run.
images = sorted(glob.glob(train_image_dir))
if not images:
    # Fail here with a clear message rather than inside train_test_split.
    raise FileNotFoundError(f'No training images match {train_image_dir!r}')

# Each mask path is derived from its image path by swapping the directory
# component /images for /masks; the file name stays the same.
masks = [path.replace('/images', '/masks') for path in images]

# Show a short preview of the pairing instead of dumping every path.
for image_path, mask_path in zip(images[:3], masks[:3]):
    print(f'{image_path} -> {mask_path}')

print(f'{len(images)} images detected.')

train_images, val_images, train_masks, val_masks = train_test_split(
    images, masks, test_size=VAL_SPLIT, random_state=0, shuffle=True)

print(f'Train images: {len(train_images)}\nValidation images: {len(val_images)}')

In [None]:
def load_image_as_rgb(image_path):
    """Load an image file and return it as a 3-channel RGB PIL image.

    Every source mode is converted, not only single-channel 'L': palette,
    alpha-carrying or other non-RGB files would otherwise reach the image
    processor with a channel layout other than three channels.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A PIL image in mode 'RGB' whose pixel data is fully loaded, so it
        stays usable after the underlying file has been closed.
    """
    # The context manager releases the file handle; convert() always returns
    # a new, fully loaded image (a copy when the file is already RGB).
    with PILImage.open(image_path) as img:
        return img.convert('RGB')

def load_mask_as_binary(mask_path):
    """Load a mask file and return it as a single-channel {0, 1} PIL image.

    The mask is read as 8-bit grayscale and every non-zero pixel becomes
    class 1; zero pixels stay class 0. Thresholding at "> 0" (rather than
    matching 255 exactly) keeps both 0/255 and 0/1 encoded masks correct and
    prevents stray intermediate grey levels from surviving as label values
    outside the two-class label space.

    Args:
        mask_path: Path of the mask file to open.

    Returns:
        A mode 'L' PIL image whose pixel values are only 0 or 1.
    """
    # Read the pixels into an array inside the context manager so the file
    # handle is released as soon as the data has been copied out.
    with PILImage.open(mask_path) as mask:
        grey_levels = np.array(mask.convert('L'))

    binary = (grey_levels > 0).astype(np.uint8)

    # Back to a PIL image so the dataset can store it as an image column.
    return PILImage.fromarray(binary)

def create_dataset(image_paths, mask_paths):
    """Build a dataset of decoded image/mask pairs from two path lists.

    All files are loaded up front: images via ``load_image_as_rgb`` and masks
    via ``load_mask_as_binary``. The loaded objects (not the paths) are stored
    in the columns 'pixel_values' and 'label', and both columns are cast to
    the ``Image()`` feature.

    Args:
        image_paths: Paths of the input images.
        mask_paths: Paths of the masks, in the same order as ``image_paths``.

    Returns:
        A ``Dataset`` with the columns 'pixel_values' and 'label'.
    """
    columns = {
        'pixel_values': list(map(load_image_as_rgb, image_paths)),
        'label': list(map(load_mask_as_binary, mask_paths)),
    }

    paired = Dataset.from_dict(columns)

    # Declare both columns as image features.
    for column_name in ('pixel_values', 'label'):
        paired = paired.cast_column(column_name, Image())
    return paired

# Build the training and validation datasets from their path lists.
ds_train, ds_valid = (
    create_dataset(image_subset, mask_subset)
    for image_subset, mask_subset in (
        (train_images, train_masks),
        (val_images, val_masks),
    )
)

def apply_transforms(batch):
    """Run a batch of images and masks through the module-level ``processor``.

    Args:
        batch: Mapping with the columns 'pixel_values' (images) and 'label'
            (masks), each an iterable of array-convertible images.

    Returns:
        Whatever ``processor`` returns when called with the list of image
        arrays and the list of mask arrays.
    """
    # The processor is handed NumPy arrays rather than PIL images.
    image_arrays = list(map(np.array, batch['pixel_values']))
    mask_arrays = list(map(np.array, batch['label']))

    return processor(image_arrays, mask_arrays)

# Register the same preprocessing function on both splits.
for split_dataset in (ds_train, ds_valid):
    split_dataset.set_transform(apply_transforms)


In [None]:
# Load the 'mean_iou' metric from the `evaluate` library; compute_metrics below calls it.
metric = evaluate.load('mean_iou')

def compute_metrics(eval_pred):
    """Compute mean-IoU metrics for one evaluation pass.

    The logits are bilinearly upsampled to the spatial size of the labels,
    reduced to class ids with an argmax over dim 1, and scored with the
    module-level ``metric``. Per-category results are flattened into
    individual scalar entries keyed by class name.

    Args:
        eval_pred: A ``(logits, labels)`` pair of NumPy arrays. The code
            treats dim 1 of ``logits`` as the class dimension and the last
            two dims of ``labels`` as height and width.

    Returns:
        dict: The metric's results, with 'per_category_accuracy' and
        'per_category_iou' replaced by ``accuracy_<class>`` and
        ``iou_<class>`` entries.
    """
    logits, labels = eval_pred

    with torch.no_grad():
        # Scale the logits to the size of the labels before taking the argmax.
        upsampled_logits = nn.functional.interpolate(
            torch.from_numpy(logits),
            size=labels.shape[-2:],
            mode='bilinear',
            align_corners=False,
        )
        pred_labels = upsampled_logits.argmax(dim=1).detach().cpu().numpy()

    # currently using _compute instead of compute
    # see this issue for more info: https://github.com/huggingface/evaluate/pull/328#issuecomment-1286866576
    metrics = metric._compute(
        predictions=pred_labels,
        references=labels,
        num_labels=len(id2label),
        ignore_index=None,
        reduce_labels=processor.do_reduce_labels,
    )

    # Add per-category metrics as individual key-value pairs.
    per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
    per_category_iou = metrics.pop("per_category_iou").tolist()

    metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
    metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})

    return metrics

In [None]:
training_args = TrainingArguments(
    output_dir='segformer_water_finetuned',
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # Evaluate, checkpoint and log once per epoch. Evaluation and saving share
    # a schedule so load_best_model_at_end has a checkpoint for every
    # evaluated state.
    eval_strategy='epoch',
    save_strategy='epoch',
    logging_strategy='epoch',
    save_total_limit=3,
    # Move accumulated evaluation outputs off the accelerator every 5 steps.
    eval_accumulation_steps=5,
    # No metric_for_best_model is given, so the Trainer's default criterion
    # (evaluation loss) decides which checkpoint is "best".
    load_best_model_at_end=True,
    # The datasets produce their model inputs through an on-the-fly transform,
    # so the Trainer must not prune the raw dataset columns.
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='none',
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_train,
    eval_dataset=ds_valid,
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
# Save the fine-tuned weights together with the image-processor configuration,
# so the directory can be reloaded for inference with the same preprocessing
# that was used during training.
model.save_pretrained('segformer_water')
processor.save_pretrained('segformer_water')