In [7]:
import os
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

# Your existing function to load images and masks
def load_images_and_masks(image_dir, mask_dir, image_filenames, mask_filenames):
    images = []
    masks = []
    instance_class_mapping = []

    for img_file, mask_file in zip(image_filenames, mask_filenames):
        # Load image and mask
        img = Image.open(os.path.join(image_dir, img_file))
        mask = Image.open(os.path.join(mask_dir, mask_file))
        
        # Convert to numpy arrays
        img = np.array(img)
        mask = np.array(mask)
        
        # Extract class labels from the mask (assuming the class is encoded as wound_class * 15)
        mask_class = mask // 15  # This extracts the class label for each pixel

        # Create an instance mask (simply for this example, assuming each unique value is an instance)
        instance_ids = mask  # Use the raw mask values as instance IDs (assuming each unique value is an instance)

        # Find all unique instance IDs (ignoring background class 0)
        unique_instance_ids = np.unique(instance_ids[instance_ids != 0])

        # Create a dictionary mapping each instance ID to its class
        class_mapping = {}
        for instance_id in unique_instance_ids:
            # Map the instance ID to the class by taking the first pixel of that instance
            class_mapping[instance_id] = mask_class[mask == instance_id][0]

        # Append the results
        images.append(img)
        masks.append(mask_class)  # Use the class-based mask
        instance_class_mapping.append(class_mapping)  # Map instance IDs to their classes
    
    return np.array(images), np.array(masks), instance_class_mapping


image_dir = r'C:/users/comi/Desktop/Wound_segmentation_III/Data/new_images_640_1280'
mask_dir = r'C:/users/comi/Desktop/Wound_segmentation_III/Data/new_masks_640_1280'
image_filenames = os.listdir(image_dir)
mask_filenames = os.listdir(mask_dir)

images, masks, instance_class_mapping = load_images_and_masks(image_dir, mask_dir, image_filenames, mask_filenames)

# Split data into training and validation sets
X_train, X_val, y_train, y_val, mapping_train, mapping_val = train_test_split(
    images, masks, instance_class_mapping, test_size=0.2, random_state=42
)

# Output the number of training and validation images
print(f"Training images: {len(X_train)}, Validation images: {len(X_val)}")
print(f"Instance-Class Mapping for the first training image: {mapping_train[0]}")


Training images: 1600, Validation images: 401
Instance-Class Mapping for the first training image: {15: 1}


In [11]:
import torch
from torch.utils.data import Dataset

class SegmentationDataset(Dataset):
    def __init__(self, images, masks, instance_class_mapping, transform=None):
        self.images = images
        self.masks = masks
        self.instance_class_mapping = instance_class_mapping
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]
        instance_mapping = self.instance_class_mapping[idx]

        # Apply any transformations (optional)
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        # Convert the images and masks to tensors
        image_tensor = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)  # HWC -> CHW
        mask_tensor = torch.tensor(mask, dtype=torch.long)  # Masks should be long type for labels

        # Extract class labels from the mask (unique values representing object classes)
        class_labels = torch.tensor(np.unique(mask_tensor[mask_tensor > 0]))  # Exclude background (class 0)

        # Create a pixel mask (assumes the entire image is valid, no padding)
        pixel_mask = torch.ones(image_tensor.shape[-2:], dtype=torch.bool)  # Shape: (height, width)

        return {
            "pixel_values": image_tensor,             # Input image
            "pixel_mask": pixel_mask,                 # Pixel mask (binary mask of valid image areas)
            "class_labels": class_labels,             # Unique class labels
            "mask_labels": mask_tensor,               # Full segmentation mask (pixel-wise class labels)
            "instance_class_mapping": instance_mapping # Mapping of instance IDs to their classes
        }


In [12]:
from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation

# Define the configuration for the model
config = MaskFormerConfig(
    num_labels=14,  # Number of classes in your dataset
    ignore_mismatched_sizes=True
)

# Initialize MaskFormer model from scratch
model = MaskFormerForInstanceSegmentation(config)


In [14]:
class CustomDataCollator:
    def __call__(self, features):
        pixel_values = torch.stack([f["pixel_values"] for f in features])
        pixel_mask = torch.stack([f["pixel_mask"] for f in features])
        mask_labels = torch.stack([f["mask_labels"] for f in features])

        # Pad class_labels to the maximum length in the batch
        max_len = max([len(f["class_labels"]) for f in features])
        padded_class_labels = [torch.cat([f["class_labels"], torch.zeros(max_len - len(f["class_labels"]), dtype=torch.long)]) for f in features]
        class_labels = torch.stack(padded_class_labels)
        
        return {
            "pixel_values": pixel_values,
            "pixel_mask": pixel_mask,
            "mask_labels": mask_labels,
            "class_labels": class_labels,  # Now padded to have the same size
        }


In [15]:
from transformers import Trainer, TrainingArguments

# Define the training arguments
training_args = TrainingArguments(
    output_dir="./results",          # Output directory for model checkpoints
    evaluation_strategy="epoch",     # Evaluate the model at the end of every epoch
    save_strategy="epoch",           # Save checkpoints every epoch
    learning_rate=5e-5,              # Learning rate
    per_device_train_batch_size=4,   # Batch size for training
    per_device_eval_batch_size=4,    # Batch size for evaluation
    num_train_epochs=10,             # Number of epochs to train
    weight_decay=0.01,               # Weight decay (L2 regularization)
    logging_dir='./logs',            # Directory for logging
    logging_steps=100,               # Log every 100 steps
)

# Create the training and validation datasets
train_dataset = SegmentationDataset(X_train, y_train, mapping_train)
val_dataset = SegmentationDataset(X_val, y_val, mapping_val)

data_collator = CustomDataCollator()

# Initialize the trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator  # Pass the custom data collator here
)

# Start training
trainer.train()



ValueError: Input and output must have the same number of spatial dimensions, but got input with spatial dimensions of [640] and output size of torch.Size([320, 160]). Please provide input tensor in (N, C, d1, d2, ...,dK) format and output size in (o1, o2, ...,oK) format.