In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from typing import Dict, Union
import os
import mlflow
import pytorch_lightning as pl
import torch
from torch import optim
import numpy as np
from PIL import Image
from torch import nn, optim
from torch.nn import functional as F
from torchvision.models import vgg19, vgg19_bn, VGG19_Weights, VGG19_BN_Weights

class TableNetModule(pl.LightningModule):
    """
    Pytorch Lightning Module for TableNet.
    """

    def __init__(
        self,
        optimizer: Union[optim.SGD, optim.Adam],
        optimizer_params: Dict,
        scheduler: Union[
            optim.lr_scheduler.OneCycleLR, optim.lr_scheduler.ReduceLROnPlateau
        ],
        scheduler_params: Dict,
        scheduler_interval: str,
        num_class: int = 1,
        batch_norm: bool = False,
    ):
        """
        Initialize TableNet Module.

        Args:
            optimizer
            optimizer_params
            scheduler
            scheduler_params
            scheduler_interval (str):
            num_class (int): Number of classes per point.
            batch_norm (bool): Select VGG with or without batch normalization.
        """
        super().__init__()
        self.save_hyperparameters()

        self.model = TableNet(num_class, batch_norm)
        self.num_class = num_class
        self.dice_loss = DiceLoss()

        self.optimizer = optimizer
        self.optimizer_params = optimizer_params
        self.scheduler = scheduler
        self.scheduler_params = scheduler_params
        self.scheduler_interval = scheduler_interval

    def forward(self, batch):
        """
        Perform forward-pass.

        Args:
            batch (tensor): Batch of images to perform forward-pass.

        Returns (Tuple[tensor, tensor]): Table, Column prediction.
        """
        return self.model(batch)

    def training_step(self, batch, batch_idx):
        """
        Training step.

        Args:
            batch (List[Tensor]): Data for training.
            batch_idx (int): batch index.

        Returns: Tensor
        """
        samples, labels_table, labels_column = batch
        output_table, output_column = self.forward(samples)

        loss_table = self.dice_loss(output_table, labels_table)
        loss_column = self.dice_loss(output_column, labels_column)

        return loss_table + loss_column

    def validation_step(self, batch, batch_idx):
        """
        Validation step.

        Args:
            batch (List[Tensor]): Data for training.
            batch_idx (int): batch index.

        Returns: Tensor
        """
        samples, labels_table, labels_column = batch
        output_table, output_column = self.forward(samples)

        loss_table = self.dice_loss(output_table, labels_table)
        loss_column = self.dice_loss(output_column, labels_column)

        return loss_table + loss_column

    def test_step(self, batch, batch_idx):
        """
        Test step.

        Args:
            batch (List[Tensor]): Data for training.
            batch_idx (int): batch index.

        Returns: Tensor
        """
        samples, labels_table, labels_column = batch
        output_table, output_column = self.forward(samples)

        loss_table = self.dice_loss(output_table, labels_table)
        loss_column = self.dice_loss(output_column, labels_column)

        return loss_table + loss_column

    def configure_optimizers(self):
        """
        Configure optimizer for pytorch lighting.

        Returns: optimizer and scheduler for pytorch lighting.

        """
        optimizer = self.optimizer(self.parameters(), **self.optimizer_params)
        scheduler = self.scheduler(optimizer, **self.scheduler_params)
        scheduler = {
            "scheduler": scheduler,
            "monitor": "validation_loss",
            "interval": self.scheduler_interval,
        }

        return [optimizer], [scheduler]

In [23]:
from torch import nn

class TableNet(nn.Module):
    """
    TableNet.
    """

    def __init__(self, num_class: int, batch_norm: bool = False):
        """
        Initialize TableNet.

        Args:
            num_class (int): Number of classes per point.
            batch_norm (bool): Select VGG with or without batch normalization.
        """
        super().__init__()

        self.vgg = (
            vgg19(weights=VGG19_Weights.DEFAULT).features
            if not batch_norm
            else vgg19_bn(weights=VGG19_BN_Weights.DEFAULT).features
        )
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.feature_maps_ids = [18, 27] if not batch_norm else [26, 39]
        self.table_decoder = TableDecoder(num_class)
        self.column_decoder = ColumnDecoder(num_class)

    def forward(self, x):
        """
        Forward pass.

        Args:
            x (tensor): Batch of images to perform forward-pass.

        Returns (Tuple[torch.Tensor, torch.Tensor]): Table, Column prediction.
        """
        feature_maps = []
        with torch.no_grad():
            for i, layer in enumerate(self.vgg):
                x = layer(x)
                if i in self.feature_maps_ids:
                    feature_maps.append(x)

        x_table = self.table_decoder(x, feature_maps)
        table_output = torch.sigmoid(x_table)

        x_column = self.column_decoder(x, feature_maps)
        column_output = torch.sigmoid(x_column)
        return table_output, column_output


class ColumnDecoder(nn.Module):
    """
    Column Decoder.
    """

    def __init__(self, num_classes: int):
        """
        Initialize Column Decoder.

        Args:
            num_classes (int): Number of classes per point.
        """
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Dropout(0.8),
            nn.Conv2d(512, 512, kernel_size=1),
            nn.ReLU(inplace=True),
        )
        self.layer = nn.ConvTranspose2d(
            1280, num_classes, kernel_size=2, stride=2, dilation=1
        )

    def forward(self, x, pools):
        """
        Forward pass.

        Args:
            x (tensor): Batch of images to perform forward-pass.
            pools (Tuple[tensor, tensor]): The 3 and 4 pooling layer
                from VGG-19.

        Returns (tensor): Forward-pass result tensor.

        """
        pool_3, pool_4 = pools
        x = self.decoder(x)
        x = F.interpolate(x, scale_factor=2)
        x = torch.cat([x, pool_4], dim=1)
        x = F.interpolate(x, scale_factor=2)
        x = torch.cat([x, pool_3], dim=1)
        x = F.interpolate(x, scale_factor=2)
        x = F.interpolate(x, scale_factor=2)
        return self.layer(x)


class TableDecoder(ColumnDecoder):
    """
    Table Decoder.
    """

    def __init__(self, num_classes):
        """
        Initialize Table decoder.

        Args:
            num_classes (int): Number of classes per point.
        """
        super().__init__(num_classes)
        self.decoder = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=1),
            nn.ReLU(inplace=True),
        )

In [20]:

class ColumnDecoder(nn.Module):
    """
    Column Decoder.
    """

    def __init__(self, num_classes: int):
        """
        Initialize Column Decoder.

        Args:
            num_classes (int): Number of classes per point.
        """
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Dropout(0.8),
            nn.Conv2d(512, 512, kernel_size=1),
            nn.ReLU(inplace=True),
        )
        self.conv_1 = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        self.conv_2 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.layer = nn.ConvTranspose2d(
            64, num_classes, kernel_size=2, stride=2, dilation=1
        )

    def forward(self, x, pools):
        """
        Forward pass.

        Args:
            x (tensor): Batch of images to perform forward-pass.
            pools (Tuple[tensor, tensor]): The 3 and 4 pooling layer
                from VGG-19.

        Returns (tensor): Forward-pass result tensor.

        """
        pool_3, pool_4 = pools
        x = self.decoder(x)
        x = F.interpolate(x, scale_factor=2)
        x = torch.cat([x, pool_4], dim=1)
        x = F.interpolate(x, scale_factor=2)
        x = self.conv_1(x)
        x = torch.cat([x, pool_3], dim=1)
        x = F.interpolate(x, scale_factor=2)
        x = F.interpolate(x, scale_factor=2)
        x = self.conv_2(x)
        return self.layer(x)


class TableDecoder(ColumnDecoder):
    """
    Table Decoder.
    """

    def __init__(self, num_classes):
        """
        Initialize Table decoder.

        Args:
            num_classes (int): Number of classes per point.
        """
        super().__init__(num_classes)
        self.decoder = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=1),
            nn.ReLU(inplace=True),
        )

In [24]:
class DiceLoss(nn.Module):
    """
    Dice loss.
    """

    def __init__(self):
        """
        Dice Loss.
        """
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """
        Calculate loss.

        Args:
            inputs (tensor): Output from the forward pass.
            targets (tensor): Labels.
            smooth (float): Value to smooth the loss.

        Returns (tensor): Dice loss.

        """
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice = (2.0 * intersection + smooth) / (
            inputs.sum() + targets.sum() + smooth
        )

        return 1 - dice

In [25]:
import yaml
import sys
sys.path.append("../src/extraction/")
import albumentations as album
from pathlib import Path
import torch
import gc
import mlflow
import pytorch_lightning as pl
from albumentations.pytorch.transforms import ToTensorV2
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
)

from tablenet.marmot import MarmotDataModule
from optimizers import optimizers
from schedulers import schedulers


# Parameters
with open("../config/tablenet_config/tablenet_config_cycle.yaml", "r") as stream:
    config = yaml.safe_load(stream)
batch_size = config["batch_size"]
max_epochs = config["max_epochs"]
num_sanity_val_steps = config["num_sanity_val_steps"]
patience = config["patience"]
batch_norm = config["batch_norm"]
fp_data = config["fp_data"]

optimizer_params = config["optimizer_params"]
optimizer = optimizer_params.pop("optimizer")
optimizer = optimizers[optimizer]

scheduler_params = config["scheduler_params"]
scheduler = scheduler_params.pop("scheduler")
scheduler_interval = scheduler_params.pop("interval")
scheduler = schedulers[scheduler]

torch.cuda.empty_cache()
gc.collect()

image_size = (896, 896)
transforms_augmentation = album.Compose(
    [
        album.Resize(1024, 1024, always_apply=True),
        album.RandomResizedCrop(
            *image_size, scale=(0.7, 1.0), ratio=(0.7, 1)
        ),
        album.HorizontalFlip(),
        album.VerticalFlip(),
        album.Normalize(),
        ToTensorV2(),
    ]
)

transforms_preprocessing = album.Compose(
    [
        album.Resize(*image_size, always_apply=True),
        album.Normalize(),
        ToTensorV2(),
    ]
)

# Data for the training pipeline
# Clean up this code
data_dir = "../data/marmot_data"
siren_test = [
    "305756413",
    "324084698",
    "326300159",
    "331154765",
    "333916385",
    "334303823",
    "344066733",
    "393525852",
    "393712286",
    "411787567",
    "414728337",
    "552065187",
    "552081317",
    "702012956",
    "797080850",
]
test_data = [
    Path(data_dir).joinpath(siren + ".bmp") for siren in siren_test
]

train_data = [
    path
    for path in (
        list(Path(data_dir).glob("*.png"))
        + list(Path(data_dir).glob("*.bmp"))
    )
    if path not in test_data
]

if not fp_data:
    train_data = [path for path in train_data if len(path.name) > 13]

# Data module
data_module = MarmotDataModule(
    train_data=train_data,
    test_data=test_data,
    transforms_preprocessing=transforms_preprocessing,
    transforms_augmentation=transforms_augmentation,
    batch_size=batch_size,
    num_workers=0,
)  # type: ignore

model = TableNetModule(
    batch_norm=batch_norm,
    optimizer=optimizer,
    optimizer_params=optimizer_params,
    scheduler=scheduler,
    scheduler_params=scheduler_params,
    scheduler_interval=scheduler_interval,
)

checkpoint_callback = ModelCheckpoint(
    monitor="validation_loss", save_top_k=1, save_last=True, mode="min"
)
early_stop_callback = EarlyStopping(
    monitor="validation_loss", mode="min", patience=patience
)
lr_monitor = LearningRateMonitor(logging_interval="step")


trainer = pl.Trainer(
    accelerator="auto",
    strategy="auto",
    callbacks=[lr_monitor, checkpoint_callback, early_stop_callback],
    max_epochs=max_epochs,
    num_sanity_val_steps=num_sanity_val_steps,
    accumulate_grad_batches=2,
    precision=16,
)

trainer.fit(model, datamodule=data_module)
trainer.test(datamodule=data_module)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type     | Params
---------------------------------------
0 | model     | TableNet | 20.8 M
1 | dice_loss | DiceLoss | 0     
---------------------------------------
798 K     Trainable params
20.0 M    Non-trainable params
20.8 M    Total params
83.290    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

ValueError: `.test(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.