In [1]:
import os
import albumentations as album
import torch
import pytorch_lightning as pl
from albumentations.pytorch.transforms import ToTensorV2
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

# from tablenet import MarmotDataModule
from tablenet import TableNetModule

RuntimeError: KeyboardInterrupt: 

In [None]:
from pathlib import Path
from typing import List

import numpy as np
import pytorch_lightning as pl
from albumentations import Compose
from PIL import Image
from torch.utils.data import Dataset, DataLoader


class O7Dataset(Dataset):
    """O7 Dataset.

    Each sample is a document image plus two masks (table regions and column
    regions). The masks are stored as separate image files in their own
    directories and share the image's file name.
    """

    # Mask directories the notebook was originally run with; kept as defaults for
    # backward compatibility, override them via the constructor on other machines.
    DEFAULT_TABLE_MASK_DIR = "/home/ubuntu/storage/Doc2Answer/handigit/OCR_tablenet/masks/table/"
    DEFAULT_COLUMN_MASK_DIR = "/home/ubuntu/storage/Doc2Answer/handigit/OCR_tablenet/masks/columns/"

    def __init__(self, data: List[Path], transforms: Compose = None,
                 table_mask_dir: str = DEFAULT_TABLE_MASK_DIR,
                 column_mask_dir: str = DEFAULT_COLUMN_MASK_DIR) -> None:
        """O7 Dataset initialization.

        Args:
            data (List[Path]): Paths of the input images.
            transforms (Optional[Compose]): Compose object from albumentations, applied jointly
             to the image and the stacked (table, column) mask.
            table_mask_dir (str | Path): Directory containing the table masks; each mask has the
             same file name as its image.
            column_mask_dir (str | Path): Directory containing the column masks; same naming rule.
        """
        self.data = data
        self.transforms = transforms
        self.table_mask_dir = Path(table_mask_dir)
        self.column_mask_dir = Path(column_mask_dir)

    def __len__(self):
        """Dataset Length."""
        return len(self.data)

    @staticmethod
    def _load_array(path, mode: str) -> np.ndarray:
        """Read an image file fully into memory as a numpy array in the given PIL mode."""
        # `with` guarantees the file handle is closed even if decoding fails.
        with Image.open(path) as img:
            return np.array(img.convert(mode))

    def __getitem__(self, item):
        """Get sample data.

        Args:
            item (int): sample id.

        Returns (Tuple): Image, Table Mask (1, H, W), Column Mask (1, H, W).
            The image is whatever ``transforms`` produce (a CHW tensor when the pipeline ends
            with ToTensorV2) or the raw (H, W, 3) uint8 array when no transforms are set.
            Masks are always torch tensors; their values were rescaled from 0..255 to 0..1
            before the transforms ran.
        """
        image_path = Path(self.data[item])
        # Masks are looked up by the image's file name (extension included).
        table_path = self.table_mask_dir / image_path.name
        column_path = self.column_mask_dir / image_path.name

        # Force 3 channels so grayscale/CMYK JPEGs do not break Normalize() or the encoder.
        image = self._load_array(image_path, "RGB")
        # Force 8-bit single-channel masks so each becomes (H, W, 1) after expand_dims,
        # whether the file is stored as 1-bit, palette, grayscale or RGB.
        table_mask = np.expand_dims(self._load_array(table_path, "L"), axis=2)
        column_mask = np.expand_dims(self._load_array(column_path, "L"), axis=2)
        # Stack into an (H, W, 2) array and rescale 0..255 -> 0..1.
        mask = np.concatenate([table_mask, column_mask], axis=2) / 255

        sample = {"image": image, "mask": mask}
        if self.transforms:
            sample = self.transforms(image=image, mask=mask)

        image = sample["image"]
        # With ToTensorV2 (default transpose_mask=False) the mask is a channels-last tensor;
        # without transforms (or without ToTensorV2) it is still a numpy array, which has no
        # `.unsqueeze`. as_tensor handles both and returns an existing tensor unchanged.
        mask = torch.as_tensor(sample["mask"])
        mask_table = mask[:, :, 0].unsqueeze(0)
        mask_column = mask[:, :, 1].unsqueeze(0)
        return image, mask_table, mask_column


class O7DataModule(pl.LightningDataModule):
    """Pytorch Lightning Data Module for O7.

    Images found under ``data_dir`` are sorted by path and split by position:
    the first ``train_end_fraction`` of the list is training data, the part up to
    ``val_end_fraction`` is validation data and the remainder is test data.
    """

    def __init__(self, data_dir: str = "./data", transforms_preprocessing: Compose = None,
                 transforms_augmentation: Compose = None, batch_size: int = 8, num_workers: int = 4,
                 train_end_fraction: float = 0.8, val_end_fraction: float = 0.9,
                 image_glob: str = "*.jpg"):
        """O7 Data Module initialization.

        Args:
            data_dir (str): Dataset directory (searched recursively).
            transforms_preprocessing (Optional[Compose]): Compose object from albumentations applied
             on validation and test dataset.
            transforms_augmentation (Optional[Compose]): Compose object from albumentations applied
             on training dataset.
            batch_size (int): Define batch size.
            num_workers (int): Define number of workers to process data.
            train_end_fraction (float): Fraction of the sorted images used for training.
            val_end_fraction (float): Cumulative fraction at which the validation split ends;
             images after this point form the test split.
            image_glob (str): Pattern passed to ``Path.rglob`` to find the input images.

        Raises:
            ValueError: If the fractions do not satisfy
             0 <= train_end_fraction <= val_end_fraction <= 1.
        """
        super().__init__()
        if not 0.0 <= train_end_fraction <= val_end_fraction <= 1.0:
            raise ValueError(
                "Expected 0 <= train_end_fraction <= val_end_fraction <= 1, got "
                f"train_end_fraction={train_end_fraction}, val_end_fraction={val_end_fraction}"
            )
        self.data = list(Path(data_dir).rglob(image_glob))
        if not self.data:
            import warnings
            # An empty dataset otherwise only surfaces later as an obscure sampler error.
            warnings.warn(f"No images matching {image_glob!r} found under {data_dir!r}")
        self.transforms_preprocessing = transforms_preprocessing
        self.transforms_augmentation = transforms_augmentation
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_end_fraction = train_end_fraction
        self.val_end_fraction = val_end_fraction

        # Build the datasets eagerly so they are available right after construction.
        self.setup()

    def setup(self, stage: str = None) -> None:
        """Start training, validation and test datasets.

        Args:
            stage (Optional[str]): Used to separate setup logic for trainer.fit and trainer.test.
             Ignored here: all three splits are always (re)built.
        """
        n_samples = len(self.data)
        # rglob order depends on the filesystem; sorting makes the split reproducible.
        self.data.sort()
        train_end = int(n_samples * self.train_end_fraction)
        val_end = int(n_samples * self.val_end_fraction)
        train_slice = slice(0, train_end)
        val_slice = slice(train_end, val_end)
        test_slice = slice(val_end, n_samples)

        self.complaint_train = O7Dataset(self.data[train_slice], transforms=self.transforms_augmentation)
        self.complaint_val = O7Dataset(self.data[val_slice], transforms=self.transforms_preprocessing)
        self.complaint_test = O7Dataset(self.data[test_slice], transforms=self.transforms_preprocessing)

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        """Create the shuffled training Dataloader.

        Returns: DataLoader
        """
        return DataLoader(self.complaint_train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self, *args, **kwargs) -> DataLoader:
        """Create the validation Dataloader (not shuffled).

        Returns: DataLoader
        """
        return DataLoader(self.complaint_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self, *args, **kwargs) -> DataLoader:
        """Create the test Dataloader (not shuffled).

        Returns: DataLoader
        """
        return DataLoader(self.complaint_test, batch_size=self.batch_size, num_workers=self.num_workers)


In [None]:
# Data / augmentation configuration. The dataset root defaults to the original machine's
# absolute path but can be overridden with the O7_DATA_DIR environment variable.
DATA_DIR = os.environ.get(
    "O7_DATA_DIR", "/home/ubuntu/storage/Doc2Answer/download_from_drive/data/ProcessedO7/"
)
SEED = 42
BATCH_SIZE = 50
image_size = (896, 896)       # final (height, width) fed to the model
PRE_CROP_SIZE = (1024, 1024)  # training images are resized to this before the random crop

# Seed Python, NumPy and torch RNGs so augmentation and the subsequent weight
# initialization are reproducible across runs.
pl.seed_everything(SEED)

# Training: resize, random crop back down to image_size, random flips, then normalize.
transforms_augmentation = album.Compose([
    album.Resize(*PRE_CROP_SIZE, always_apply=True),
    album.RandomResizedCrop(*image_size, scale=(0.7, 1.0), ratio=(0.7, 1)),
    album.HorizontalFlip(),
    album.VerticalFlip(),
    album.Normalize(),
    ToTensorV2()
])

# Validation / test: deterministic resize + normalize only.
transforms_preprocessing = album.Compose([
    album.Resize(*image_size, always_apply=True),
    album.Normalize(),
    ToTensorV2()
])

complaint_dataset = O7DataModule(
    data_dir=DATA_DIR,
    transforms_preprocessing=transforms_preprocessing,
    transforms_augmentation=transforms_augmentation,
    batch_size=BATCH_SIZE
)

In [None]:
model = TableNetModule(batch_norm=False)

# TensorBoard runs are grouped under tb_logs/<model class name>/.
EXPERIMENT_NAME = type(model).__name__
logger = TensorBoardLogger('tb_logs', name=EXPERIMENT_NAME)

# Both callbacks watch 'validation_loss'; the LightningModule must log a metric under
# exactly that key, otherwise checkpointing / early stopping cannot find it.
checkpoint_callback = ModelCheckpoint(mode="min", monitor='validation_loss', save_last=True, save_top_k=5)
early_stop_callback = EarlyStopping(mode="min", monitor='validation_loss', patience=10)
# Record the learning rate at every optimizer step.
lr_monitor = LearningRateMonitor(logging_interval='step')

In [None]:
# Train on one GPU when CUDA is available, otherwise on CPU (gpus=None).
# NOTE(review): `gpus=` is the pre-2.0 Lightning API; newer releases use accelerator/devices.
# NOTE(review): early stopping can only fire if its patience is smaller than max_epochs.
trainer = pl.Trainer(
    max_epochs=5,
    gpus=(1 if torch.cuda.is_available() else None),
    logger=logger,
    callbacks=[lr_monitor, checkpoint_callback, early_stop_callback],
)
trainer.fit(model, datamodule=complaint_dataset)

In [None]:
# Evaluate on the datamodule's held-out test split with the best checkpoint tracked by the
# ModelCheckpoint callback during fit. Passing both explicitly avoids relying on implicit
# defaults that changed between Lightning versions.
trainer.test(datamodule=complaint_dataset, ckpt_path="best")