In [1]:
from albumentations import pytorch
import torch
from torch.utils.data import DataLoader, Dataset, random_split
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.transforms import CenterCrop

import numpy as np
from cv2 import cv2, transform
import matplotlib.pyplot as plt
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2

import pytorch_lightning as pl
from pathlib import Path

from UNet import UNet_2, AttentionUNet

from torchsummary import summary

In [23]:
def print_summary(model, input_size=(3, 128, 128)):
    """Run torchsummary's ``summary`` on ``model`` and hand back its result.

    Args:
        model: Network to summarise.
        input_size: Forwarded to ``summary`` as ``input_data``; defaults to
            ``(3, 128, 128)``.

    Returns:
        Whatever ``summary`` returns for the given model and input.
    """
    reported_columns = ["input_size", "output_size", "num_params", "kernel_size"]
    # The summary is always requested for the CPU with 20-character columns.
    report_options = dict(device='cpu', col_names=reported_columns, col_width=20)
    return summary(model=model, input_data=input_size, **report_options)

In [24]:
# presumably (in_channels, out_channels) — confirm against UNet.py
net = AttentionUNet(3, 1)
# The trailing semicolon stops the notebook from echoing the object returned by
# print_summary, which would otherwise display the same table a second time.
print_summary(net, input_size=(3, 1024, 1024));

Layer (type:depth-idx)                   Input Shape          Output Shape         Param #              Kernel Shape
├─DoubleConvSame: 1-1                    [-1, 3, 1024, 1024]  [-1, 64, 1024, 1024] --                   --
|    └─Sequential: 2-1                   [-1, 3, 1024, 1024]  [-1, 64, 1024, 1024] --                   --
|    |    └─Conv2d: 3-1                  [-1, 3, 1024, 1024]  [-1, 64, 1024, 1024] 1,792                [3, 64, 3, 3]
|    |    └─ReLU: 3-2                    [-1, 64, 1024, 1024] [-1, 64, 1024, 1024] --                   --
|    |    └─Conv2d: 3-3                  [-1, 64, 1024, 1024] [-1, 64, 1024, 1024] 36,928               [64, 64, 3, 3]
|    |    └─ReLU: 3-4                    [-1, 64, 1024, 1024] [-1, 64, 1024, 1024] --                   --
├─MaxPool2d: 1-2                         [-1, 64, 1024, 1024] [-1, 64, 512, 512]   --                   --
├─Encoder: 1-3                           [-1, 64, 512, 512]   [-1, 128, 512, 512]  --                   --
|   

Layer (type:depth-idx)                   Input Shape          Output Shape         Param #              Kernel Shape
├─DoubleConvSame: 1-1                    [-1, 3, 1024, 1024]  [-1, 64, 1024, 1024] --                   --
|    └─Sequential: 2-1                   [-1, 3, 1024, 1024]  [-1, 64, 1024, 1024] --                   --
|    |    └─Conv2d: 3-1                  [-1, 3, 1024, 1024]  [-1, 64, 1024, 1024] 1,792                [3, 64, 3, 3]
|    |    └─ReLU: 3-2                    [-1, 64, 1024, 1024] [-1, 64, 1024, 1024] --                   --
|    |    └─Conv2d: 3-3                  [-1, 64, 1024, 1024] [-1, 64, 1024, 1024] 36,928               [64, 64, 3, 3]
|    |    └─ReLU: 3-4                    [-1, 64, 1024, 1024] [-1, 64, 1024, 1024] --                   --
├─MaxPool2d: 1-2                         [-1, 64, 1024, 1024] [-1, 64, 512, 512]   --                   --
├─Encoder: 1-3                           [-1, 64, 512, 512]   [-1, 128, 512, 512]  --                   --
|   

In [None]:
# Batch size shared by the training and validation DataLoaders.
BATCH_SIZE = 8

In [None]:
class NucleiData(Dataset):
    """Dataset pairing each nucleus image with the union of its mask files.

    ``data_dir`` is expected to hold one folder per sample, laid out as
    ``<sample>/images/*.png`` and ``<sample>/masks/*.*``.

    Args:
        data_dir: Root folder containing the per-sample directories.
        transforms: Optional callable invoked as ``transforms(image=image)``
            that returns a mapping with an ``"image"`` entry (for example an
            albumentations pipeline). When ``None`` the image is only
            converted with ``ToTensorV2``.
        mask_size: ``(height, width)`` every mask file is resized to before
            the masks of one sample are merged.
    """

    def __init__(
        self,
        data_dir="./data-science-bowl-2018/stage1_train",
        transforms=None,
        mask_size=(128, 128),
    ):
        train_dir = Path(data_dir)
        # Sorted so that an index refers to the same file on every run,
        # independent of the order in which the file system lists entries.
        self.images = sorted(train_dir.glob("*/images/*.png"))
        self.masks = sorted(train_dir.glob("*/masks/*.*"))
        self.transforms = transforms
        self.mask_size = tuple(mask_size)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        The mask is the merged uint8 mask from ``get_mask`` converted with
        ``ToTensorV2``; its intensities are not rescaled.

        Raises:
            OSError: If the image or one of its masks cannot be decoded.
        """
        image_path = self.images[idx]
        image = cv2.imread(image_path.as_posix(), cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image file: {image_path}")
        mask = self.get_mask(image_path.parent.parent.glob("masks/*.*"))

        if self.transforms is not None:
            transformed_image = self.transforms(image=image)["image"]
        else:
            # Without a pipeline, still return a tensor instead of failing on
            # an unassigned local.
            transformed_image = ToTensorV2()(image=image)["image"]
        transformed_mask = ToTensorV2()(image=mask)["image"]

        return transformed_image, transformed_mask

    def get_mask(self, masks_gen):
        """Merge all mask files of one sample into a single array.

        Args:
            masks_gen: Iterable of paths to the sample's mask files.

        Returns:
            ``np.uint8`` array of shape ``(height, width, 1)`` holding the
            per-pixel maximum over all resized masks (all zeros when
            ``masks_gen`` is empty).

        Raises:
            OSError: If a mask file cannot be decoded.
        """
        height, width = self.mask_size
        # One resize op is enough for every mask of the sample.
        resize = A.Resize(height=height, width=width)
        target_mask = np.zeros((height, width, 1), dtype=np.uint8)
        for mask_path in masks_gen:
            curr_mask = cv2.imread(mask_path.as_posix(), cv2.IMREAD_GRAYSCALE)
            if curr_mask is None:
                raise OSError(f"Could not read mask file: {mask_path}")
            resized = resize(image=curr_mask)["image"]
            # Add the trailing channel axis so shapes match target_mask.
            target_mask = np.maximum(target_mask, resized[..., np.newaxis])
        return target_mask

    def __len__(self):
        """Number of samples, i.e. the number of image files found."""
        return len(self.images)


class NucleiDataModule(pl.LightningDataModule):
    """Lightning data module serving an 80/20 train/validation split of NucleiData."""

    def __init__(self):
        super().__init__()
        # Image pipeline: resize to 128x128, normalise with albumentations'
        # default statistics, then convert to a tensor.
        self.image_transforms = A.Compose(
            [A.Resize(128, 128), A.Normalize(), ToTensorV2()]
        )
        self.dims = (3, 128, 128)

    def setup(self, stage=None) -> None:
        """Build the train/validation subsets for the ``"fit"`` stage.

        Args:
            stage: Stage name; the split is created when it is ``"fit"`` or
                ``None``. Other stages are a no-op.
        """
        if stage == "fit" or stage is None:
            data = NucleiData(transforms=self.image_transforms)
            n_train = int(len(data) * 0.8)
            # Validation takes the remainder so the two lengths always sum to
            # len(data); truncating both 80% and 20% can lose a sample, and
            # random_split rejects lengths that do not add up.
            lengths = [n_train, len(data) - n_train]
            self.train_data, self.val_data = random_split(data, lengths)

    def train_dataloader(self):
        """DataLoader over the training subset."""
        return DataLoader(self.train_data, batch_size=BATCH_SIZE, num_workers=8)

    def val_dataloader(self):
        """DataLoader over the validation subset."""
        return DataLoader(self.val_data, batch_size=BATCH_SIZE, num_workers=8)


class LitNuclei(pl.LightningModule):
    """Lightning module training ``UNet_2`` with a binary cross-entropy-with-logits loss."""

    def __init__(self):
        super().__init__()
        # presumably (in_channels, out_channels) — confirm against UNet.py
        self.model = UNet_2(3, 1)
        self.loss = nn.BCEWithLogitsLoss()

    def configure_optimizers(self):
        """Adam over the model parameters with the optimizer's default settings."""
        return optim.Adam(self.model.parameters())

    def forward(self, x):
        """Return the raw model output (logits) for ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Compute the loss for one ``(image, mask)`` batch."""
        image, mask = batch
        preds = self(image)
        # NucleiData yields uint8 masks with 0-255 intensities, while the loss
        # expects targets in [0, 1]; unscaled targets make it unbounded below.
        # Masks of any other dtype are assumed to be in range already.
        if mask.dtype == torch.uint8:
            target = mask.float() / 255.0
        else:
            target = mask.float()
        return self.loss(preds, target)

    def training_step(self, batch, batch_idx):
        """Log and return the training loss for ``batch``."""
        loss = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return the validation loss for ``batch``."""
        loss = self._shared_step(batch)
        self.log("val_loss", loss)
        return loss

In [None]:
# The data module and the model do not depend on each other; build both up front.
dm = NucleiDataModule()
model = LitNuclei()

# Single epoch on one GPU, with logging and checkpointing switched on.
trainer = pl.Trainer(gpus=1, max_epochs=1, logger=True, checkpoint_callback=True)

trainer.fit(model, datamodule=dm)