In [1]:
import lightning as L
import numpy as np
import torch
from minerva.data.datasets.supervised_dataset import SupervisedReconstructionDataset
from minerva.data.readers.png_reader import PNGReader
from minerva.data.readers.tiff_reader import TiffReader
from minerva.models.nets.unet import UNet
from minerva.transforms.transform import _Transform
from torch.utils.data import DataLoader
from torchmetrics import JaccardIndex
from matplotlib import pyplot as plt

In [2]:
class Padding(_Transform):
    """Pad an array up to a minimum spatial size and return it channel-first.

    Height and width are each padded (bottom and right edges only, using
    reflected values) until they reach ``target_size``. Dimensions that are
    already at least ``target_size`` are left as they are; nothing is cropped.
    The padded array is converted to a ``float32`` tensor laid out ``(C, H, W)``.
    """

    def __init__(self, target_size: int):
        """Remember the minimum height/width each output must have."""
        self.target_size = target_size

    def __call__(self, x: np.ndarray) -> torch.Tensor:
        """Pad ``x`` and convert it to a channel-first float tensor.

        Parameters
        ----------
        x : np.ndarray
            Array shaped ``(H, W)`` or ``(H, W, C)``.

        Returns
        -------
        torch.Tensor
            ``float32`` tensor shaped ``(C, H', W')`` where
            ``H' = max(H, target_size)`` and ``W' = max(W, target_size)``.
            A 2-D input yields ``C == 1``.
        """
        h, w = x.shape[:2]
        pad_h = max(0, self.target_size - h)
        pad_w = max(0, self.target_size - w)

        if x.ndim == 2:
            # Give single-channel input an explicit trailing channel axis so
            # both layouts share the same padding and permutation code.
            x = np.expand_dims(x, axis=2)

        padded = np.pad(x, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")

        # NOTE(review): the cast to float is unconditional, so integer inputs
        # (e.g. label maps) also come out as float -- confirm that downstream
        # losses/metrics accept that.
        tensor = torch.from_numpy(padded).float()

        # Reorder (H, W, C) -> (C, H, W) with the tensor's own permute instead
        # of np.transpose, which is not meant to operate on torch tensors.
        return tensor.permute(2, 0, 1)


# Shared transform: pad height and width up to at least 768 pixels.
transform = Padding(target_size=768)

In [3]:
# Each split pairs a TIFF reader (inputs under f3/images) with a PNG reader
# (annotations under f3/annotations); the pair feeds one dataset and one
# single-sample DataLoader.
# NOTE(review): the one `transform` is passed for the whole reader pair, so the
# annotations presumably receive the same float conversion as the inputs. The
# training run further down stops with "bincount only supports 1-d
# non-negative integral inputs" -- confirm whether the annotations need their
# own integer-preserving transform.

# Training split (the only loader that shuffles).
train_att_reader = TiffReader("f3/images/train")
train_lbl_reader = PNGReader("f3/annotations/train")
train_dataset = SupervisedReconstructionDataset(
    [train_att_reader, train_lbl_reader],
    transform,
)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=True)

# Validation split.
val_att_reader = TiffReader("f3/images/val")
val_lbl_reader = PNGReader("f3/annotations/val")
val_dataset = SupervisedReconstructionDataset(
    [val_att_reader, val_lbl_reader],
    transform,
)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False)

# Test split.
test_att_reader = TiffReader("f3/images/test")
test_lbl_reader = PNGReader("f3/annotations/test")
test_dataset = SupervisedReconstructionDataset(
    [test_att_reader, test_lbl_reader],
    transform,
)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

In [4]:
class F3DataModule(L.LightningDataModule):
    """Data module that hands back three dataloaders built by the caller."""

    def __init__(
        self,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        test_dataloader: DataLoader,
    ):
        """Keep references to the given loaders; nothing is constructed here."""
        super().__init__()
        self.train_dl, self.val_dl, self.test_dl = (
            train_dataloader,
            val_dataloader,
            test_dataloader,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the loader that was supplied for training."""
        return self.train_dl

    def val_dataloader(self) -> DataLoader:
        """Return the loader that was supplied for validation."""
        return self.val_dl

    def test_dataloader(self) -> DataLoader:
        """Return the loader that was supplied for testing."""
        return self.test_dl


# Bundle the three loaders so the trainer can request each one by stage.
data_module = F3DataModule(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=test_dataloader,
)

In [5]:
# UNet over 3-channel inputs, trained with cross-entropy. Each stage
# (train/val/test) receives its own, freshly built 6-class IoU metric, passed
# as the `train_metrics` / `val_metrics` / `test_metrics` keyword arguments.
# NOTE(review): the metrics assume 6 classes but no class count is passed to
# UNet here -- confirm the network's output channels match.
model = UNet(
    n_channels=3,
    loss_fn=torch.nn.CrossEntropyLoss(),
    **{
        f"{stage}_metrics": {
            "IoU": JaccardIndex(task="multiclass", num_classes=6)
        }
        for stage in ("train", "val", "test")
    },
)

In [6]:
# Smoke-test configuration: with fast_dev_run=2 the logged output reports that
# only 2 batches are run and that logging/checkpointing is suppressed.
trainer = L.Trainer(fast_dev_run=2, max_epochs=200)
trainer.fit(model=model, datamodule=data_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 2 batch(es). Logging and checkpointing is suppressed.
You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type             | Params
----------------------------------------------
0 | backbone | _UNet            | 31.0 M
1 | fc       | Identity         | 0     
2 | loss_fn  | CrossEntropyLoss | 0     
----------------------------------------------
31.0 M    Trainable params
0         Non-trainable params
31.0 

Training: |          | 0/? [00:00<?, ?it/s]

RuntimeError: bincount only supports 1-d non-negative integral inputs.

In [None]:
# Evaluate the model on the test loader taken from the data module.
trainer.test(model, dataloaders=data_module.test_dataloader())

In [None]:
# Collect the trainer's predictions over the test loader.
preds = trainer.predict(model, dataloaders=data_module.test_dataloader())

# Display how many entries the prediction run returned.
len(preds)

In [None]:
# Collapse dimension 1 of the selected prediction to the index of its largest
# value, keeping that dimension with size 1.
# NOTE(review): `preds[1][3]` picks element 3 of the second prediction entry;
# what that element is depends on what the model's predict step returns --
# confirm this indexing is intended.
image = preds[1][3].argmax(dim=1, keepdim=True)

In [None]:
# Take the 2-D slice at index 0 of the first two dimensions and render it.
class_map = image[0, 0, :, :]
plt.imshow(class_map)
plt.show()