In [1]:
import io
import os
import numpy as np

from PIL import Image
from typing import Dict, List, Optional, OrderedDict, Tuple

from datasets import load_dataset, Features, ClassLabel, Array3D

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities.seed import seed_everything

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms
from torchvision.utils import save_image

from fid_score import calculate_fretchet
from inception import InceptionV3


In [2]:
def preprocess_data(examples):
    images = [np.array(Image.open(io.BytesIO(example))) for example in examples["image"]]
    labels = examples["label"]
    return {"image": images, "label": labels}

In [3]:
ALL_IMAGES = []

img_size = 640
batch_size = 128
normalize = [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)]
latent_size = 512
data_dir = "ChainYo/rvl-cdip-only-invoices"

transforms = transforms.Compose(
    [
        transforms.Resize(img_size), 
        transforms.CenterCrop(img_size), 
        transforms.ToTensor(), 
        transforms.Normalize(*normalize)
    ]
)
features = Features({
    "image": Array3D(shape=(3, 640, 640), dtype="int64"),
    "label": ClassLabel(num_classes=1, names="invoice"),
})
dataset = load_dataset(data_dir)
print(dataset)
# pil_dataset = dataset.map(preprocess_data, remove_columns=dataset.column_names, features=features, batched=True, batch_size=batch_size)
# transformed_dataset = pil_dataset.with_transform(transforms)
# transformed_dataset.set_format(type="torch", device="cuda")

# train_dataloader = torch.utils.data.DataLoader(
#     transformed_dataset["train"], batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True
# )

ValueError: ClassLabel number of names do not match the defined num_classes. Got 7 names VS 1 num_classes

In [None]:
def denormalize(input_image_tensors: torch.Tensor) -> torch.Tensor:
    """
    Denormalizes the input image tensors.

    Parameters
    ----------
    input_image_tensors : torch.Tensor
        The input image tensors.
    
    Returns
    -------
    torch.Tensor
        The denormalized image tensors.
    """
    return input_image_tensors.mul(normalize[1][0]).add(normalize[0][0])


def save_samples(index: int, sample_images: torch.Tensor) -> None:
    """
    Saves the generated samples.

    Parameters
    ----------
    index : int
        The index of the sample.
    sample_images : torch.Tensor
        The generated sample images.
    """
    fake_name = f"generated-images-{index}.png"
    save_image(denormalize(sample_images[-64:]), os.path.join("generated", fake_name), nrow=8)

In [None]:
class Discriminator(nn.Module):
    def __init__(
        self,
        hidden_size: Optional[int] = 64,
        channels: Optional[int] = 3,
        kernel_size: Optional[int] = 4,
        stride: Optional[int] = 2,
        padding: Optional[int] = 1,
        negative_slope: Optional[float] = 0.2,
        bias: Optional[bool] = False,
    ):
        """
        Initializes the discriminator.

        Parameters
        ----------
        hidden_size : int, optional
            The input size. (the default is 64)
        channels : int, optional
            The number of channels. (default: 3)
        kernel_size : int, optional
            The kernal size. (default: 4)
        stride : int, optional
            The stride. (default: 2)
        padding : int, optional
            The padding. (default: 1)
        negative_slope : float, optional
            The negative slope. (default: 0.2)
        bias : bool, optional
            Whether to use bias. (default: False)
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.negative_slope = negative_slope
        self.bias = bias

        self.model = nn.Sequential(
            # input size: (channels, 64, 64)
            nn.Conv2d(
                self.channels, self.hidden_size, 
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
            ),
            nn.BatchNorm2d(self.hidden_size),
            nn.LeakyReLU(self.negative_slope, inplace=True),

            # input size: (hidden_size, 32, 32)
            nn.Conv2d(
                hidden_size, hidden_size * 2, 
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
            ),
            nn.BatchNorm2d(hidden_size * 2),
            nn.LeakyReLU(self.negative_slope, inplace=True),

            # input size: (hidden_size * 2, 16, 16)
            nn.Conv2d(
                hidden_size * 2, hidden_size * 4,
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
            ),
            nn.BatchNorm2d(hidden_size * 4),
            nn.LeakyReLU(self.negative_slope, inplace=True),

            # input size: (hidden_size * 4, 8, 8)
            nn.Conv2d(
                hidden_size * 4, hidden_size * 8, 
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
            ),
            nn.BatchNorm2d(hidden_size * 8),
            nn.LeakyReLU(self.negative_slope, inplace=True),

            # input size: (hidden_size * 8, 4, 4)
            nn.Conv2d(hidden_size * 8, 1, kernel_size=4, stride=1, padding=0, bias=self.bias), # output size: (1, 1, 1)
            nn.Flatten(),
            nn.Sigmoid()
        )

    
    def forward(self, input_img: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation.

        Parameters
        ----------
        input_img : torch.Tensor
            The input image.

        Returns
        -------
        torch.Tensor
            The output.
        """
        logits = self.model(input_img)
        return logits

In [None]:
class Generator(nn.Module):
    def __init__(
        self,
        hidden_size: Optional[int] = 64,
        latent_size: Optional[int] = 128,
        channels: Optional[int] = 3,
        kernel_size: Optional[int] = 4,
        stride: Optional[int] = 2,
        padding: Optional[int] = 1,
        bias: Optional[bool] = False,
    ):
        """
        Initializes the generator.

        Parameters
        ----------
        hidden_size : int, optional
            The hidden size. (default: 64)
        latent_size : int, optional
            The latent size. (default: 128)
        channels : int, optional
            The number of channels. (default: 3)
        kernel_size : int, optional
            The kernel size. (default: 4)
        stride : int, optional
            The stride. (default: 2)
        padding : int, optional
            The padding. (default: 1)
        bias : bool, optional
            Whether to use bias. (default: False)
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.bias = bias

        self.model = nn.Sequential(
            # input size: (latent_size=128, 1, 1)
            nn.ConvTranspose2d(
                self.latent_size, self.hidden_size * 8, 
                kernel_size=self.kernel_size, stride=1, padding=0, bias=self.bias
            ),
            nn.BatchNorm2d(self.hidden_size * 8),
            nn.ReLU(inplace=True),

            # input size: (hidden_size * 8, 4, 4)
            nn.ConvTranspose2d(
                self.hidden_size * 8, self.hidden_size * 4, 
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
            ),
            nn.BatchNorm2d(self.hidden_size * 4),
            nn.ReLU(inplace=True),

            # input size: (hidden_size * 4, 8, 8)
            nn.ConvTranspose2d(
                self.hidden_size * 4, self.hidden_size * 2, 
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
            ),
            nn.BatchNorm2d(self.hidden_size * 2),
            nn.ReLU(inplace=True),

            # input size: (self.hidden_size * 2, 16, 16)
            nn.ConvTranspose2d(
                self.hidden_size * 4, self.hidden_size, 
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
            ),
            nn.BatchNorm2d(self.hidden_size),
            nn.ReLU(inplace=True),

            # input size: (64, 32, 32)
            nn.ConvTranspose2d(
                self.hidden_size, self.channels, 
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
            ),
            nn.Tanh() # output size: (channels, 64, 64)
        )

    
    def forward(self, input_img: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation.

        Parameters
        ----------
        input_img : torch.Tensor
            The input image.

        Returns
        -------
        torch.Tensor
            The output.
        """
        return self.model(input_img)

In [None]:
class DocuGAN(pl.LightningModule):
    def __init__(
        self,
        hidden_size: Optional[int] = 64,
        latent_size: Optional[int] = 128,
        num_channel: Optional[int] = 3,
        learning_rate: Optional[float] = 0.0002,
        batch_size: Optional[int] = 128,
        bias1: Optional[float] = 0.5,
        bias2: Optional[float] = 0.999,
    ):
        """
        Initializes the LightningGan.

        Parameters
        ----------
        hidden_size : int, optional
            The hidden size. (default: 64)
        latent_size : int, optional
            The latent size. (default: 128)
        num_channel : int, optional
            The number of channels. (default: 3)
        learning_rate : float, optional
            The learning rate. (default: 0.0002)
        batch_size : int, optional
            The batch size. (default: 128)
        bias1 : float, optional
            The bias1. (default: 0.5)
        bias2 : float, optional
            The bias2. (default: 0.999)
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.num_channel = num_channel
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.bias1 = bias1
        self.bias2 = bias2
        self.validation = torch.randn(self.batch_size, self.latent_size, 1, 1)
        self.save_hyperparameters()

        self.generator = Generator(
            latent_size=self.latent_size, channels=self.num_channel, hidden_size=self.hidden_size
        )
        self.generator.apply(self.weights_init)
        
        self.discriminator = Discriminator(channels=self.num_channel, hidden_size=self.hidden_size)
        self.discriminator.apply(self.weights_init)

        self.model = InceptionV3() # For FID metric


    def weights_init(self, m: nn.Module) -> None:
        """
        Initializes the weights.

        Parameters
        ----------
        m : nn.Module
            The module.
        """
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find("BatchNorm") != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)


    def adversarial_loss(self, preds: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Calculates the adversarial loss.

        Parameters
        ----------
        preds : torch.Tensor
            The predictions.
        labels : torch.Tensor
            The labels.

        Returns
        -------
        torch.Tensor
            The adversarial loss.
        """
        return F.binary_cross_entropy(preds, labels)

    
    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List]:
        """
        Configures the optimizers.

        Returns
        -------
        Tuple[List[torch.optim.Optimizer], List]
            The optimizers and the LR schedulers.
        """
        opt_generator = torch.optim.Adam(
            self.generator.parameters(), lr=self.learning_rate, betas=(self.bias1, self.bias2)
        )
        opt_discriminator = torch.optim.Adam(
            self.discriminator.parameters(), lr=self.learning_rate, betas=(self.bias1, self.bias2)
        )
        return [opt_generator, opt_discriminator], []

    
    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation.

        Parameters
        ----------
        z : torch.Tensor
            The latent vector.

        Returns
        -------
        torch.Tensor
            The output.
        """
        return self.generator(z)

    
    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, optimizer_idx: int
    ) -> Dict:
        """
        Training step.

        Parameters
        ----------
        batch : Tuple[torch.Tensor, torch.Tensor]
            The batch.
        batch_idx : int
            The batch index.
        optimizer_idx : int
            The optimizer index.

        Returns
        -------
        Dict
            The training loss.
        """
        real_images, _ = batch

        if optimizer_idx == 0: # Only train the generator
            fake_random_noise = torch.randn(self.batch_size, self.latent_size, 1, 1)
            fake_random_noise = fake_random_noise.type_as(real_images)
            fake_images = self(fake_random_noise)

            # Try to fool the discriminator
            preds = self.discriminator(fake_images)
            targets = torch.ones(self.batch_size, 1)
            targets = targets.type_as(real_images)
            loss = self.adversarial_loss(preds, targets)
            self.log("g_loss", loss, on_step=False, on_epoch=True)

            tqdm_dict = {"g_loss": loss}
            output = OrderedDict({"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict})
            return output

        elif optimizer_idx == 1: # Only train the discriminator
            real_preds = self.discriminator(real_images)
            real_targets = torch.ones(real_images.size(0), 1)
            real_targets = real_targets.type_as(real_images)
            real_loss = self.adversarial_loss(real_preds, real_targets)

            # Generate fake images
            real_random_noise = torch.randn(self.batch_size, self.latent_size, 1, 1)
            real_random_noise = real_random_noise.type_as(real_images)
            fake_images = self(real_random_noise)

            # Pass fake images though discriminator
            fake_targets = torch.zeros(fake_images.size(0), 1)
            fake_targets = fake_targets.type_as(real_images)
            fake_preds = self.discriminator(fake_images)
            fake_loss = self.adversarial_loss(fake_preds, fake_targets)

            # Update discriminator weights
            loss = real_loss + fake_loss
            self.log("d_loss", loss, on_step=False, on_epoch=True)

            tqdm_dict = {"d_loss": loss}
            output = OrderedDict({"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict})
            return output

    
    def on_epoch_end(self):
        """
        Called at the end of an epoch.
        """
        z = self.validation.type_as(self.generator.model[0].weight)
        sample_images = self(z)
        ALL_IMAGES.append(sample_images.detach().cpu())
        save_samples(self.current_epoch, sample_images)
        self.logger[1].log_image(key=f"images-epoch{self.current_epoch}", images=[sample_images])

In [None]:
seed_everything(42)
gpus = 1 if torch.cuda.is_available() else 0

tf_logger = TensorBoardLogger("logs", name="docugan")
wandb_logger = WandbLogger(project="docugan")

model = DocuGAN()

trainer = pl.Trainer(
    gpus=gpus,
    max_epochs=500,
    progress_bar_refresh_rate=25,
    # callbacks=[early_stopping, checkpointer],
    logger=[tf_logger, wandb_logger],
)
trainer.fit(model, train_dataloader)

Global seed set to 42
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchainyo-mleng[0m (use `wandb login --relogin` to force relogin)


  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type          | Params
------------------------------------------------
0 | generator     | Generator     | 3.9 M 
1 | discriminator | Discriminator | 2.8 M 
2 | model         | InceptionV3   | 21.8 M
------------------------------------------------
6.7 M     Trainable params
21.8 M    Non-trainable params
28.5 M    Total params
113.954   Total estimated model params size (MB)


Epoch 0:   0%|          | 0/196 [00:00<?, ?it/s]

AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/chainyo/miniconda3/envs/gan-bird/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/chainyo/miniconda3/envs/gan-bird/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/chainyo/miniconda3/envs/gan-bird/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/chainyo/miniconda3/envs/gan-bird/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 1765, in __getitem__
    return self._getitem(
  File "/home/chainyo/miniconda3/envs/gan-bird/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 1750, in _getitem
    formatted_output = format_table(
  File "/home/chainyo/miniconda3/envs/gan-bird/lib/python3.8/site-packages/datasets/formatting/formatting.py", line 532, in format_table
    return formatter(pa_table, query_type=query_type)
  File "/home/chainyo/miniconda3/envs/gan-bird/lib/python3.8/site-packages/datasets/formatting/formatting.py", line 281, in __call__
    return self.format_row(pa_table)
  File "/home/chainyo/miniconda3/envs/gan-bird/lib/python3.8/site-packages/datasets/formatting/torch_formatter.py", line 58, in format_row
    return self.recursive_tensorize(row)
  File "/home/chainyo/miniconda3/envs/gan-bird/lib/python3.8/site-packages/datasets/formatting/torch_formatter.py", line 54, in recursive_tensorize
    return map_nested(self._recursive_tensorize, data_struct, map_list=False)
  File "/home/chainyo/miniconda3/envs/gan-bird/lib/python3.8/site-packages/datasets/utils/py_utils.py", line 314, in map_nested
    mapped = [
  File "/home/chainyo/miniconda3/envs/gan-bird/lib/python3.8/site-packages/datasets/utils/py_utils.py", line 315, in <listcomp>
    _single_map_nested((function, obj, types, None, True, None))
  File "/home/chainyo/miniconda3/envs/gan-bird/lib/python3.8/site-packages/datasets/utils/py_utils.py", line 267, in _single_map_nested
    return {k: _single_map_nested((function, v, types, None, True, None)) for k, v in pbar}
  File "/home/chainyo/miniconda3/envs/gan-bird/lib/python3.8/site-packages/datasets/utils/py_utils.py", line 267, in <dictcomp>
    return {k: _single_map_nested((function, v, types, None, True, None)) for k, v in pbar}
  File "/home/chainyo/miniconda3/envs/gan-bird/lib/python3.8/site-packages/datasets/utils/py_utils.py", line 251, in _single_map_nested
    return function(data_struct)
  File "/home/chainyo/miniconda3/envs/gan-bird/lib/python3.8/site-packages/datasets/formatting/torch_formatter.py", line 51, in _recursive_tensorize
    return self._tensorize(data_struct)
  File "/home/chainyo/miniconda3/envs/gan-bird/lib/python3.8/site-packages/datasets/formatting/torch_formatter.py", line 38, in _tensorize
    if np.issubdtype(value.dtype, np.integer):
AttributeError: 'bytes' object has no attribute 'dtype'
