In [1]:
import pytorch_lightning as pl

from torch.utils.data import DataLoader
from pytorch_lightning.loggers import WandbLogger

from src.data import RandomPatchesDataset
from src.models import UNetLit

In [2]:
batch_size = 50


def _load_split(img_dir, mask_dir):
    """Create the patch dataset for one split and wrap it in a DataLoader.

    Args:
        img_dir: directory passed to RandomPatchesDataset as the image folder.
        mask_dir: directory passed to RandomPatchesDataset as the mask folder.

    Returns:
        A ``(dataset, data_loader)`` tuple; the loader uses the notebook-level
        ``batch_size`` and otherwise default DataLoader settings.
    """
    dataset = RandomPatchesDataset(img_dir, mask_dir)
    return dataset, DataLoader(dataset, batch_size=batch_size)


# Splits are built in train -> val -> test order, one dataset/loader pair each.
train_dataset, train_data_loader = _load_split(
    '../data/processed/train/img',
    '../data/processed/train/mask',
)
val_dataset, val_data_loader = _load_split(
    '../data/processed/val/img',
    '../data/processed/val/mask',
)
test_dataset, test_data_loader = _load_split(
    '../data/processed/test/img',
    '../data/processed/test/mask',
)

Succesfully loaded 38 images
Succesfully loaded 4 images
Succesfully loaded 37 images


In [3]:
from backbones_unet.model.unet import Unet
from backbones_unet.model.losses import DiceLoss
from backbones_unet.utils.trainer import Trainer


In [4]:
import pytorch_lightning as pl
from typing import Dict, Any, Optional
from pytorch_toolbelt.losses import BinaryFocalLoss
from ternausnet.models import UNet11
from torch.optim import Adam, lr_scheduler

class UNetLit(pl.LightningModule):
    """PyTorch Lightning wrapper around a ``backbones_unet`` U-Net for binary segmentation.

    The wrapped network is ``Unet`` with a ``convnext_base`` backbone, one input
    channel and one output channel. It is trained with ``BinaryFocalLoss`` using an
    Adam optimizer and a ``StepLR`` learning-rate schedule.

    NOTE(review): this notebook also imports ``UNetLit`` from ``src.models``;
    this definition shadows that import from here on.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Initializes the model with hyperparameters.

        Args:
            config:
                A dict of model hyperparameters. Recognised fields:
                lr - learning rate of Adam optimizer
                eps - term added to denominator to improve numerical stability in Adam optimizer
                step_size - period of learning rate decay in scheduler
                gamma - multiplicative factor of learning rate decay in scheduler
                Any field that is missing (or the whole dict, when ``config`` is
                ``None``/empty) falls back to a default, so the module can always
                build its optimizer.
        """
        super().__init__()
        # Defaults keep `configure_optimizers` working when no config is supplied;
        # previously the attributes were simply never set in that case.
        defaults = {"lr": 1.0e-03, "eps": 1.0e-08, "step_size": 4, "gamma": 0.1}
        hparams = {**defaults, **(config or {})}
        self.lr = hparams["lr"]
        self.eps = hparams["eps"]
        self.step_size = hparams["step_size"]
        self.gamma = hparams["gamma"]

        self.model = Unet(
            backbone='convnext_base', # backbone network name
            in_channels=1,            # input channels (1 for gray-scale images, 3 for RGB, etc.)
            num_classes=1,            # output channels (number of classes in your dataset)
        )
        self.loss_fn = BinaryFocalLoss()

    def forward(self, x):
        """Runs the wrapped U-Net on ``x`` and returns its raw output."""
        return self.model(x)

    def configure_optimizers(self):
        """Builds the Adam optimizer and the StepLR scheduler from the stored hyperparameters.

        Returns:
            A ``([optimizer], [scheduler])`` pair in the list form Lightning accepts.
        """
        optimizer = Adam(self.model.parameters(), lr=self.lr, eps=self.eps)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.gamma)
        return [optimizer], [scheduler]

    def _shared_step(self, batch, stage):
        """Computes and logs loss and pixel accuracy for one batch.

        Args:
            batch: an ``(inputs, targets)`` pair as yielded by the data loader.
            stage: metric-name prefix, one of ``"train"``, ``"val"`` or ``"test"``.

        Returns:
            The loss tensor for the batch.
        """
        x, y = batch
        logits = self(x.float())
        loss = self.loss_fn(logits, y)
        # BinaryFocalLoss works on raw logits, so the network output is not a
        # probability. Apply a sigmoid before thresholding at 0.5 (equivalent to
        # logits > 0); thresholding the logits themselves at 0.5 skews accuracy.
        preds = (logits.sigmoid() > 0.5).float()
        accuracy = (preds == y).float().mean()
        self.log(f"{stage}_loss", loss)
        # Accuracy is only reported as an epoch-level aggregate.
        self.log(f"{stage}_acc", accuracy, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: logs ``train_loss``/``train_acc`` and returns the loss."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Lightning hook: logs ``val_loss``/``val_acc`` and returns the loss."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Lightning hook: logs ``test_loss``/``test_acc`` and returns the loss."""
        return self._shared_step(batch, "test")


In [5]:
# --- Run settings -----------------------------------------------------------
project = 'cell-nuclei-segmentation'  # only used by the (currently disabled) W&B logger
checkpoints_dir_path = './models'
num_epochs = 10
gpus = 0  # NOTE(review): defined but never passed to the Trainer below

# Hyperparameters read by UNetLit for its Adam optimizer and StepLR scheduler.
config = {"lr": 0.01, "eps": 1.0e-08, "step_size": 4, "gamma": 0.1}

model = UNetLit(config)

# Keep only the single checkpoint with the highest epoch-level `val_acc`.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=checkpoints_dir_path,
    filename='model-{epoch:02d}-{val_acc:.2f}',
    monitor='val_acc',
    mode='max',
    save_top_k=1,
)

# W&B logging is switched off; to re-enable it, pass
#     logger=WandbLogger(save_dir="logs/", project=project)
# to the Trainer.
trainer = pl.Trainer(callbacks=[checkpoint_callback], max_epochs=num_epochs)

trainer.fit(model, train_data_loader, val_data_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type            | Params
--------------------------------------------
0 | model   | Unet            | 92.6 M
1 | loss_fn | BinaryFocalLoss | 0     
--------------------------------------------
92.6 M    Trainable params
0         Non-trainable params
92.6 M    Total params
370.589   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
