In [1]:
from torchsummary import summary
from mobileresunet import MobileResUNet
from resunet import ResUNet
from unet import UNet
from datasets import EchoNetDataset
import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

  rank_zero_deprecation(


In [2]:
class LitSeg(pl.LightningModule):
    """LightningModule that trains a binary segmentation backbone on EchoNet-Dynamic.

    The backbone output is scored with ``F.binary_cross_entropy`` (not the
    ``..._with_logits`` variant), so the backbone must already emit values in
    [0, 1], e.g. by ending in a sigmoid. NOTE(review): confirm this for every
    backbone passed in.

    Args:
        backbone: module mapping a batch of frames to per-pixel probabilities.
        learning_rate: learning rate for Adam.
        root_dir: dataset root forwarded to ``EchoNetDataset``.
        batch_size: batch size used by the train/val/test dataloaders.
        num_workers: worker processes used by the train/val/test dataloaders.
    """

    def __init__(self,
                 backbone,
                 learning_rate=1e-4,
                 root_dir="/home/tienyu/data/EchoNet-Dynamic",
                 batch_size=16,
                 num_workers=4):
        super().__init__()
        self.backbone = backbone
        self.learning_rate = learning_rate
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def forward(self, x):
        """Run the wrapped backbone on ``x``."""
        return self.backbone(x)

    def _shared_step(self, batch):
        """Return the BCE loss for one batch (shared by train/val/test steps)."""
        inputs, _, masks, _ = batch
        # NOTE(review): frames skip inputs[0] while masks use every element;
        # presumably inputs[0] has no matching mask -- confirm against EchoNetDataset.
        input_frames = torch.cat(inputs[1:])
        targets = torch.cat(masks).float()

        probs = self(input_frames)
        return F.binary_cross_entropy(probs, targets)

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log it under ``loss``."""
        loss = self._shared_step(batch)
        # Logging goes through self.log; the legacy 'log' return key is not needed.
        self.log('loss', loss, logger=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log it under ``val_loss``."""
        loss = self._shared_step(batch)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log('val_loss', loss, prog_bar=True, logger=True)
        return loss

    def validation_end(self, outputs):
        """Average per-batch validation losses (legacy epoch-end hook).

        NOTE(review): newer Lightning versions appear not to call this hook
        (epoch-level ``val_loss`` is already aggregated by ``self.log`` in
        ``validation_step``) -- confirm for the installed version.

        Accepts outputs that are either loss tensors (what ``validation_step``
        returns) or dicts carrying a ``'val_loss'`` entry.
        """
        losses = [x['val_loss'] if isinstance(x, dict) else x for x in outputs]
        val_loss = torch.stack(losses).mean()
        tensorboard_logs = {'loss': {'val': val_loss}}
        return {"val_loss": val_loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        """Compute the test loss and log it under ``test_loss``."""
        loss = self._shared_step(batch)
        # Logged separately from 'val_loss' so test results are not reported
        # under the validation key.
        self.log('test_loss', loss, logger=True)
        return loss

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    # Data hooks
    def setup(self, stage=None):
        """Build the dataset splits needed for ``stage`` (all splits when None)."""
        # The validation split is needed both while fitting and for a standalone
        # validation run.
        if stage in ('fit', 'validate') or stage is None:
            self.echo_train = EchoNetDataset(root_dir=self.root_dir,
                                             split="train")
            self.echo_val = EchoNetDataset(root_dir=self.root_dir, split="val")

        if stage == 'test' or stage is None:
            self.echo_test = EchoNetDataset(root_dir=self.root_dir,
                                            split="test")

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(self.echo_train,
                          batch_size=self.batch_size,
                          shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return DataLoader(self.echo_val,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Unshuffled loader over the test split."""
        return DataLoader(self.echo_test,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers)

In [None]:
# Train MobileResUNet on EchoNet-Dynamic; TensorBoard logs are written to LOG_DIR.
SEED = 42
MAX_EPOCHS = 30
LOG_DIR = 'logs_mobileresunet/'

# Seed before building the model so weight init and data shuffling are
# reproducible across Restart & Run All.
pl.seed_everything(SEED)

AVAIL_GPUS = min(1, torch.cuda.device_count())  # use at most one GPU; 0 means CPU
backbone = MobileResUNet(1, 1)  # presumably (in_channels, n_classes) -- confirm
model = LitSeg(backbone=backbone)
tb_logger = pl_loggers.TensorBoardLogger(LOG_DIR)
# Checkpoint on the 'val_loss' metric logged in LitSeg.validation_step.
model_checkpoint = ModelCheckpoint(monitor='val_loss', every_n_val_epochs=1)
trainer = pl.Trainer(gpus=AVAIL_GPUS,
                     max_epochs=MAX_EPOCHS,
                     progress_bar_refresh_rate=20,
                     logger=tb_logger,
                     callbacks=[model_checkpoint])
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type          | Params
-------------------------------------------
0 | backbone | MobileResUNet | 4.4 M 
-------------------------------------------
4.4 M     Trainable params
0         Non-trainable params
4.4 M     Total params
17.758    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

In [None]:
# Evaluate the best 'val_loss' checkpoint saved during fit (Lightning's default choice, made explicit).
trainer.test(ckpt_path="best")

In [None]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir logs_cspresxunet/

## Dev (scratch prototypes — commented out and not used by the training run above)

In [None]:
# """
#    Author: Aaron Liu
#    Email: tl254@duke.edu
#    Created on: June 29 2021
#    Code structure reference: https://github.com/milesial/Pytorch-UNet
# """

# import torch
# import torch.nn as nn
# import torch.nn.functional as F


# class LevelBlock(nn.Module):
#     """(BN ==> ReLU ==> Conv) x 2"""
#     def __init__(
#             self,
#             in_channels,
#             out_channels,
#             stride=(1, 1),
#     ):
#         super().__init__()

#         self.stride = stride
#         self.activation = nn.ReLU(inplace=True)

#         self.bn1 = nn.BatchNorm2d(in_channels)
#         self.stacked_blocks1 = nn.Conv2d(in_channels,
#                                          out_channels,
#                                          kernel_size=3,
#                                          padding=1,
#                                          stride=self.stride[0],
#                                          bias=False)
#         self.bn2 = nn.BatchNorm2d(out_channels)
#         self.stacked_blocks2 = nn.Conv2d(out_channels,
#                                          out_channels,
#                                          kernel_size=1,
#                                          stride=self.stride[1],
#                                          bias=False)

#     def forward(self, x):
#         x = self.stacked_blocks1(self.activation(self.bn1(x)))
#         x = self.stacked_blocks2(self.activation(self.bn2(x)))

#         return x


# class UpSamplingConcatenate(nn.Module):
#     """Upscaling"""
#     def __init__(self, in_channels, out_channels):
#         super().__init__()

#         self.up = nn.ConvTranspose2d(in_channels,
#                                      out_channels,
#                                      kernel_size=2,
#                                      stride=2)

#     def forward(self, x1, x2):
#         x1 = self.up(x1)
#         diffY = x2.size()[2] - x1.size()[2]
#         diffX = x2.size()[3] - x1.size()[3]
#         x1 = F.pad(
#             x1,
#             [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])

#         x = torch.cat([x2, x1], dim=1)

#         return x

In [None]:
# """
#    Author: Aaron Liu
#    Email: tl254@duke.edu
#    Created on: June 29 2021
# """

# import torch.nn as nn
# import torch.nn.functional as F

# # from .cspresunet_parts import LevelBlock, UpSamplingConcatenate


# class CSPResUNet(nn.Module):
#     def __init__(self, n_channels, n_classes):
#         super(CSPResUNet, self).__init__()
#         self.n_channels = n_channels
#         self.n_classes = n_classes

#         # Encoding
#         self.level1 = nn.Sequential(
#             nn.Conv2d(n_channels, 64, kernel_size=3, padding=1, bias=False),
#             nn.BatchNorm2d(64),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(64, 64, kernel_size=1, bias=False),
#         )
#         self.level2 = LevelBlock(64, 128, stride=(2, 1))
#         self.level3 = LevelBlock(128, 256, stride=(2, 1))
#         self.level4 = LevelBlock(256, 512, stride=(2, 1))
#         self.level5 = LevelBlock(512, 256, stride=(1, 1))
#         self.level6 = LevelBlock(256, 128, stride=(1, 1))
#         self.level7 = LevelBlock(128, 64, stride=(1, 1))

#         self.up1 = UpSamplingConcatenate(512, 256)
#         self.up2 = UpSamplingConcatenate(256, 128)
#         self.up3 = UpSamplingConcatenate(128, 64)

#         self.shortcut1 = nn.Conv2d(n_channels, 64, kernel_size=1)
#         self.shortcut2 = nn.Conv2d(64, 128, kernel_size=1, stride=2)
#         self.shortcut3 = nn.Conv2d(128, 256, kernel_size=1, stride=2)
#         self.shortcut5 = nn.Conv2d(512, 256, kernel_size=1)
#         self.shortcut6 = nn.Conv2d(256, 128, kernel_size=1)
#         self.shortcut7 = nn.Conv2d(128, 64, kernel_size=1)

#         self.outconv = nn.Conv2d(64, n_classes, kernel_size=1)
#         self.sigmoid = nn.Sigmoid()

#     def forward(self, x):
#         # Encoding
#         x1 = self.level1(x)
        
#         x2_in = x1 + self.shortcut1(x)
#         x2 = self.level2(x2_in)
#         x3_in = x2 + self.shortcut2(x1)
#         x3 = self.level3(x3_in)
#         x4_in = x3 + self.shortcut3(x2)

#         # Bridge
#         x4 = self.level4(x4_in)

#         # Decoding
#         x_cat = self.up1(x4, x4_in)
#         x5 = self.level5(x_cat)
#         x_cat = self.up2(x5 + self.shortcut5(x_cat), x3_in)
#         x6 = self.level6(x_cat)
#         x_cat = self.up3(x6 + self.shortcut6(x_cat), x2_in)
#         x7 = self.level7(x_cat)
#         x = self.outconv(x7 + self.shortcut7(x_cat))
        
#         # Sigmoid
#         x = self.sigmoid(x)

#         return x

In [None]:
# model = UNet(1,1)

In [None]:
# x = torch.rand(1,1,112,112)

In [None]:
# model(x).shape

In [None]:
# summary(model, input_size=(1,112,112), device='cpu')