In [None]:
import torch
import pytorch_lightning
import lightning_utilities
import torchmetrics

# Report the versions of the core libraries this notebook depends on.
for library_label, library in (
    ("PyTorch", torch),
    ("PyTorch Lightning", pytorch_lightning),
    ("Lightning Utilities", lightning_utilities),
    ("Torchmetrics", torchmetrics),
):
    print(f"{library_label} Version: {library.__version__}")

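## Reference

Example `LightningDataModule` for MNIST, kept here as a structural template for the SAR data module implemented in the next section.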

```python
class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

        if stage == "predict":
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)

```

## Implementation

In [17]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pytorch_lightning import LightningDataModule

from sklearn.model_selection import train_test_split
from pytorch_lightning import LightningModule, Trainer
import torch.nn as nn
from torch.optim import Adam

class SARDataModule(LightningDataModule):
    """LightningDataModule serving SAR image / label-mask pairs from disk.

    Expected layout under ``data_dir``::

        train/images/<name>.<ext>    train/labels_1D/<name>.png
        test/images/<name>.<ext>     test/labels_1D/<name>.png

    The ``train`` folder is split into training and validation subsets with a
    fixed random seed; the ``test`` folder is used as-is.

    Args:
        data_dir: Root folder containing the ``train`` and ``test`` sub-folders.
        batch_size: Batch size used by every dataloader.
        val_split: Fraction of the ``train`` folder held out for validation.
        num_workers: Worker processes per dataloader (0 loads in the main process).
    """

    def __init__(self, data_dir: str = "./", batch_size: int = 8, val_split: float = 0.2,
                 num_workers: int = 0):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.val_split = val_split
        self.num_workers = num_workers

        # Images: ToTensor yields a float CHW tensor (8-bit inputs are scaled to [0, 1]).
        self.image_transform = transforms.Compose([
            transforms.ToTensor(),
            # TODO: augmentation
            # TODO(review): the pretrained backbone used for training presumably expects
            # ImageNet-normalised inputs; consider adding transforms.Normalize - confirm.
        ])

        # Masks: PILToTensor keeps the raw integer class ids. ToTensor must NOT be used
        # here: it rescales 8-bit images to [0, 1], and the dataset's later cast to long
        # would then truncate the label ids to 0.
        self.mask_transform = transforms.Compose([
            transforms.PILToTensor(),
        ])

    def prepare_data(self) -> None:
        """No-op: the dataset is already on disk, nothing to download or label."""
        pass

    @staticmethod
    def _collect_pairs(images_dir, masks_dir):
        """Return (image_paths, mask_paths) for every file in ``images_dir``.

        Files are taken in sorted order; each mask path is the image's base name
        with a ``.png`` extension inside ``masks_dir``.
        """
        filenames = sorted(
            f for f in os.listdir(images_dir)
            if os.path.isfile(os.path.join(images_dir, f))
        )
        image_paths = [os.path.join(images_dir, f) for f in filenames]
        mask_paths = [os.path.join(masks_dir, os.path.splitext(f)[0] + '.png') for f in filenames]
        return image_paths, mask_paths

    def _make_dataset(self, image_paths, mask_paths):
        """Wrap path lists in a SARImageDataset using this module's transforms."""
        return SARImageDataset(
            image_paths, mask_paths,
            image_transform=self.image_transform, mask_transform=self.mask_transform
        )

    def setup(self, stage: str = None) -> None:
        """Build the datasets required for ``stage``.

        Only the folders needed for the requested stage are read, so fitting
        does not require the ``test`` folder to exist (and vice versa).

        Args:
            stage: ``"fit"`` or ``"validate"`` builds the train/val datasets,
                ``"test"`` builds the test dataset, ``None`` builds all of them.
        """
        if stage in ("fit", "validate") or stage is None:
            images, masks = self._collect_pairs(
                os.path.join(self.data_dir, 'train/images'),
                os.path.join(self.data_dir, 'train/labels_1D'),
            )
            # Fixed random_state keeps the train/val partition identical across runs.
            train_images, val_images, train_masks, val_masks = train_test_split(
                images, masks, test_size=self.val_split, random_state=42)
            self.train_dataset = self._make_dataset(train_images, train_masks)
            self.val_dataset = self._make_dataset(val_images, val_masks)

        if stage == "test" or stage is None:
            test_images, test_masks = self._collect_pairs(
                os.path.join(self.data_dir, 'test/images'),
                os.path.join(self.data_dir, 'test/labels_1D'),
            )
            self.test_dataset = self._make_dataset(test_images, test_masks)

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        """Sequential loader over the validation subset."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Sequential loader over the test set."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)

class SARImageDataset(Dataset):
    """Dataset of (image, mask) pairs read from parallel lists of file paths.

    Args:
        images_paths: Paths to the input images.
        masks_paths: Paths to the label masks; ``masks_paths[i]`` belongs to
            ``images_paths[i]``.
        image_transform: Optional callable applied to the PIL image. Defaults
            to ``transforms.ToTensor()``.
        mask_transform: Optional callable applied to the PIL mask. It must keep
            the integer label values intact (e.g. ``transforms.PILToTensor()``);
            ``ToTensor`` rescales 8-bit images to [0, 1], which the cast to
            ``long`` below would truncate to 0. Defaults to ``PILToTensor()``.

    Raises:
        ValueError: If the two path lists differ in length.
    """

    def __init__(self, images_paths, masks_paths, image_transform=None, mask_transform=None):
        if len(images_paths) != len(masks_paths):
            raise ValueError(
                f"Got {len(images_paths)} image paths but {len(masks_paths)} mask paths; "
                "each image needs exactly one mask."
            )
        self.images_paths = images_paths
        self.masks_paths = masks_paths
        self.image_transform = image_transform
        self.mask_transform = mask_transform

    def __len__(self):
        """Number of (image, mask) pairs."""
        return len(self.images_paths)

    def __getitem__(self, idx):
        """Load and transform the pair at ``idx``.

        Returns:
            Tuple ``(img, mask)``: the transformed image and the mask with its
            leading channel axis squeezed out, cast to ``torch.long``.

        Raises:
            FileNotFoundError: If the mask file for this index does not exist.
        """
        image_path = self.images_paths[idx]
        mask_path = self.masks_paths[idx]

        # Fail fast on a missing mask before spending time decoding the image.
        if not os.path.exists(mask_path):
            raise FileNotFoundError(f"Mask file not found: {mask_path}")

        with Image.open(image_path) as img:
            if self.image_transform is not None:
                img = self.image_transform(img)
            else:
                img = transforms.ToTensor()(img)

        with Image.open(mask_path) as mask:
            if self.mask_transform is not None:
                mask = self.mask_transform(mask)
            else:
                # PILToTensor keeps the raw pixel values (class ids) without rescaling.
                mask = transforms.PILToTensor()(mask)
            # Drop the channel axis of size 1 and use integer class ids as targets.
            mask = mask.squeeze(0).long()

        return img, mask

## Test

In [4]:
# Smoke test: build the data module for fitting and print the shape of every training batch.
data_module = SARDataModule(data_dir="../dataset", batch_size=8, val_split=0.2)
data_module.prepare_data()
data_module.setup('fit')

train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()

for batch_number, (images, masks) in enumerate(train_loader, start=1):
    print(f"Batch {batch_number}")
    print(f"Images batch shape: {images.shape}")

    print(f"Masks batch shape: {masks.shape}")
    print("---")

Batch 1
Images batch shape: torch.Size([8, 3, 650, 1250])
Masks batch shape: torch.Size([8, 650, 1250])
---
Batch 2
Images batch shape: torch.Size([8, 3, 650, 1250])
Masks batch shape: torch.Size([8, 650, 1250])
---
Batch 3
Images batch shape: torch.Size([8, 3, 650, 1250])
Masks batch shape: torch.Size([8, 650, 1250])
---
Batch 4
Images batch shape: torch.Size([8, 3, 650, 1250])
Masks batch shape: torch.Size([8, 650, 1250])
---
Batch 5
Images batch shape: torch.Size([8, 3, 650, 1250])
Masks batch shape: torch.Size([8, 650, 1250])
---
Batch 6
Images batch shape: torch.Size([8, 3, 650, 1250])
Masks batch shape: torch.Size([8, 650, 1250])
---
Batch 7
Images batch shape: torch.Size([8, 3, 650, 1250])
Masks batch shape: torch.Size([8, 650, 1250])
---
Batch 8
Images batch shape: torch.Size([8, 3, 650, 1250])
Masks batch shape: torch.Size([8, 650, 1250])
---
Batch 9
Images batch shape: torch.Size([8, 3, 650, 1250])
Masks batch shape: torch.Size([8, 650, 1250])
---
Batch 10
Images batch shape:

In [14]:
import pytorch_lightning
import lightning_utilities
#import torch
# Report the installed Lightning stack versions.
for package_label, package in (
    ("PyTorch Lightning", pytorch_lightning),
    ("Lightning Utilities", lightning_utilities),
):
    print(f"{package_label} Version: {package.__version__}")
# Optional (needs `import torch`): torch.set_float32_matmul_precision("medium")  # or "high"

PyTorch Lightning Version: 2.4.0
Lightning Utilities Version: 0.11.8


In [None]:
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large, DeepLabV3_MobileNet_V3_Large_Weights
from pytorch_lightning import LightningModule, Trainer
from torch.optim import Adam
import torch.nn as nn
import torchmetrics

class SARSegmentationModel(LightningModule):
    """DeepLabV3 (MobileNetV3-Large backbone) fine-tuned for semantic segmentation.

    Args:
        learning_rate: Learning rate for the Adam optimizer.
        num_classes: Number of segmentation classes; sizes the final classifier
            layer and the IoU metrics.
    """

    def __init__(self, learning_rate=5e-5, num_classes=5):
        super().__init__()
        self.save_hyperparameters()

        self.model = deeplabv3_mobilenet_v3_large(weights=DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT)
        # Replace the last layer of the pretrained head with one producing `num_classes`
        # channels; the input width is read from the layer being replaced rather than
        # hard-coded.
        head_in_channels = self.model.classifier[4].in_channels
        self.model.classifier[4] = nn.Conv2d(head_in_channels, num_classes, kernel_size=1)

        self.criterion = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate

        # Separate metric objects so training and validation state never mix.
        self.train_iou = torchmetrics.JaccardIndex(num_classes=num_classes, task="multiclass")
        self.val_iou = torchmetrics.JaccardIndex(num_classes=num_classes, task="multiclass")

    def forward(self, x):
        """Return the per-pixel class logits (the model's ``"out"`` entry)."""
        return self.model(x)["out"]

    def _shared_step(self, batch, metric):
        """Compute the loss for one batch and accumulate ``metric`` on its predictions."""
        images, masks = batch
        masks = masks.long()  # CrossEntropyLoss expects integer class-index targets
        logits = self(images)
        loss = self.criterion(logits, masks)
        metric.update(logits.argmax(dim=1), masks)
        return loss

    def training_step(self, batch, batch_idx):
        """Run one optimisation step's forward pass; logs loss and epoch-level IoU."""
        loss = self._shared_step(batch, self.train_iou)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        # Logging the metric object lets Lightning call compute()/reset() at epoch end,
        # giving the IoU over the whole epoch instead of a mean of per-batch IoUs.
        self.log("train_iou", self.train_iou, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch; logs loss and epoch-level IoU."""
        loss = self._shared_step(batch, self.val_iou)

        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        self.log("val_iou", self.val_iou, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return Adam(self.parameters(), lr=self.learning_rate)

# Train the segmentation model on the SAR dataset (SARDataModule is defined in the
# Implementation section of this notebook).
from pytorch_lightning import seed_everything

# Seed the Python, NumPy and torch RNGs (and dataloader workers) so weight
# initialisation and batch shuffling are repeatable across runs.
seed_everything(42, workers=True)

data_module = SARDataModule(data_dir="../dataset", batch_size=8, val_split=0.2)

model = SARSegmentationModel(learning_rate=5e-5, num_classes=5)

# "auto" selects a GPU when one is available and falls back to CPU otherwise,
# so the cell also runs on machines without an accelerator.
trainer = Trainer(max_epochs=10, devices=1, accelerator="auto")
trainer.fit(model, datamodule=data_module)
