# First try to implement a GAN in python lightning
---

## Imports and stuff

In [1]:
import os

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import Normalize
from torch.utils.data import DataLoader, random_split
from monai.data import (CacheDataset, DataLoader, ImageDataset, PersistentDataset,
                        pad_list_data_collate)
from monai.transforms import (Compose, EnsureChannelFirst, Resize, ScaleIntensity, ToTensor,
                              Orientation, ScaleIntensityRange)

from src.handlers import Handler, OpHandler, TciaHandler

BATCH_SIZE = 256 if torch.cuda.is_available() else 64
NUM_WORKERS = int(os.cpu_count() / 2)

NUM_WORKERS

6

In [2]:
import sys

try:
    from google.colab import drive
    drive.mount('/content/drive')
    sys.path.append('/content/drive/MyDrive/School/NTU/training')
    is_colab = True
except:
    print('Not a google drive environment')
    is_colab = False

Not a google drive environment


In [3]:
if is_colab:
  BASE_PATH = '/content/drive/MyDrive/School/NTU/training/Data/'
else:
  BASE_PATH = 'Data/'
# ...
TCIA_IMG_SUFFIX = '_PV.nii.gz'
TCIA_LOCATION = BASE_PATH + 'TCIA/'
TCIA_EXCEL_NAME = 'HCC-TACE-Seg_clinical_data-V2.xlsx'
# ...
OP_LOCATION = BASE_PATH + 'OP/'
NIFTI_PATH = 'OP_C+P_nifti'
NNU_NET_PATH = 'OP_C+P_nnUnet'
OP_EXCEL = 'OP_申請建模_1121110_20231223.xlsx'
OP_IMG_SUFFIX = '_VENOUS_PHASE.nii.gz'
OP_MASK_SUFFIX = '_VENOUS_PHASE_seg.nii.gz'
OP_ID_COL_NAME = 'OP_C+P_Tumor識別碼'

## Data module

In [14]:
class ImgDataModule(pl.LightningDataModule):
    def __init__(
        self,
        batch_size: int = BATCH_SIZE,
        num_workers: int = NUM_WORKERS,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = Compose([
            EnsureChannelFirst(),
            Resize((512, 512, 20)),
            ScaleIntensity(),
            ToTensor(),
            Normalize((0.1307,), (0.3081,))
        ])

        self.dims = (1, 512, 512)
        self.num_classes = 10

    def prepare_data(self):
        global_handler = Handler()

        tcia = TciaHandler(TCIA_LOCATION, TCIA_IMG_SUFFIX, TCIA_EXCEL_NAME)
        global_handler.add_source(tcia)

        op = OpHandler(OP_LOCATION, NIFTI_PATH, NNU_NET_PATH, OP_IMG_SUFFIX, OP_MASK_SUFFIX, OP_EXCEL, OP_ID_COL_NAME)
        global_handler.add_source(op)

        self.data = global_handler.df

    def setup(self, stage=None):
        imgs = self.data['img'].tolist()
        classes = self.data['class'].tolist()
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            # Define the sizes for the train and test sets
            train_size = int(0.8 * len(self.data))  # 80% for training
            test_size = len(self.data) - train_size  # Remaining 20% for testing
            _full = ImageDataset(
                image_files=imgs,
                labels=classes,
                transform=self.transform,
                # cache_rate=1.0,
                # num_workers=num_workers,
                # cache_dir=BASE_PATH + 'cache'
            )
            self.train_ds, self.val_ds = random_split(_full, [train_size, test_size])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.test_ds = ImageDataset(
                image_files=imgs,
                labels=classes,
                transform=self.transform,
                # cache_rate=1.0,
                # num_workers=num_workers,
                # cache_dir=BASE_PATH + 'cache'
            )

    def __default_dl__(self, dataset):
        return DataLoader(
            dataset,
            batch_size=1,
            num_workers=self.num_workers,
            pin_memory=torch.cuda.is_available(),
            collate_fn=pad_list_data_collate
        )

    def train_dataloader(self):
        return self.__default_dl__(self.train_ds)

    def val_dataloader(self):
        return self.__default_dl__(self.val_ds)

    def test_dataloader(self):
        return self.__default_dl__(self.test_ds)

## Generator

In [15]:
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape, input_dim=None, output_dim=None):
        super().__init__()
        self.img_shape = img_shape

        '''
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # layers.append(nn.BatchNorm1d(out_feat, 0.8))
                layers.append(nn.InstanceNorm1d(out_feat, affine=True))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),
        )
        '''

        self.model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
            nn.Tanh()
        )

    def forward(self, z):
        print('Generator.forward')
        print('z: ')
        print(z.shape)
        print(z.size())
        print(z.dtype)
        print(z.numel())
        #img = self.model(z)
        #img = img.view(img.size(0), *self.img_shape)
        print('finished Generator.forward')
        #return img
        return self.model(z)

## Discriminator

In [16]:
class Discriminator(nn.Module):
    def __init__(self, img_shape, input_dim=None):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        '''
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
        '''

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        # return validity
        return self.model(img)

## GAN

In [7]:
class GAN(pl.LightningModule):
    def __init__(
        self,
        channels,
        width,
        height,
        latent_dim: int = 100,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = BATCH_SIZE,
        **kwargs,
    ):
        super().__init__()
        print('GAN.init')
        self.save_hyperparameters()
        self.automatic_optimization = False

        # networks
        data_shape = (channels, width, height)
        print(data_shape)
        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)

        self.validation_z = torch.randn(8, self.hparams.latent_dim)

        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

    def forward(self, z):
        print('GAN.forward')
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch):
        print('training step...')
        imgs, _ = batch

        optimizer_g, optimizer_d = self.optimizers()

        # sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)

        # train generator
        # generate images
        self.toggle_optimizer(optimizer_g)
        self.generated_imgs = self(z)

        # log sampled images
        sample_imgs = self.generated_imgs[:6]
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image("train/generated_images", grid, self.current_epoch)

        # ground truth result (ie: all fake)
        # put on GPU because we created this tensor inside training_loop
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        # adversarial loss is binary cross-entropy
        g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

        # train discriminator
        # Measure discriminator's ability to classify real from generated samples
        self.toggle_optimizer(optimizer_d)

        # how well can it label as real?
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

        # how well can it label as fake?
        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs)

        fake_loss = self.adversarial_loss(self.discriminator(self.generated_imgs.detach()), fake)

        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

    def validation_step(self, batch, batch_idx):
        print('validation skipped')
        pass

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

    def on_validation_epoch_end(self):
        z = self.validation_z.type_as(self.generator.model[0].weight)

        # log sampled images
        sample_imgs = self(z)
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image("validation/generated_images", grid, self.current_epoch)

## Training

In [None]:
from monai.utils import first

dm = ImgDataModule()
model = GAN(*dm.dims)
trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=1,
)

# print(dm)
# print(model)
# summary(model)
print('-' * 10)
print(dm.dims)
dm.prepare_data()
dm.setup()

aux = first(dm.train_dataloader())
print(len(aux))
print(aux[0].size())
print(aux[1].size())

trainer.fit(model, dm)

## Other attempt

In [17]:
# Define the CGAN model
class CGAN(pl.LightningModule):
    def __init__(self, generator, discriminator, latent_dim, lr):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim
        self.lr = lr

    def forward(self, z, labels):
        return self.generator(None, None, torch.cat([z, labels], dim=1))

    def generator_step(self, real_images, labels):
        z = torch.randn(real_images.size(0), self.latent_dim)
        fake_images = self(z, labels)
        fake_preds = self.discriminator(None, torch.cat([fake_images, labels], dim=1))
        g_loss = nn.BCELoss()(fake_preds, torch.ones_like(fake_preds))
        return g_loss

    def discriminator_step(self, real_images, labels):
        z = torch.randn(real_images.size(0), self.latent_dim)
        fake_images = self(z, labels)
        real_preds = self.discriminator(None, torch.cat([real_images, labels], dim=1))
        fake_preds = self.discriminator(None, torch.cat([fake_images, labels], dim=1))
        real_loss = nn.BCELoss()(real_preds, torch.ones_like(real_preds))
        fake_loss = nn.BCELoss()(fake_preds, torch.zeros_like(fake_preds))
        d_loss = (real_loss + fake_loss) / 2
        return d_loss

    def training_step(self, batch, batch_idx, optimizer_idx):
        real_images, labels = batch
        if optimizer_idx == 0:
            loss = self.generator_step(real_images, labels)
        else:
            loss = self.discriminator_step(real_images, labels)
        return loss

    def configure_optimizers(self):
        g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.lr)
        d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr)
        return [g_optimizer, d_optimizer], []

In [18]:
# Prepare the dataset
dm = ImgDataModule()
dm.prepare_data()
dm.setup()

dataloader = dm.train_dataloader()

# Initialize the model
latent_dim = 100
generator = Generator(None, None, input_dim=latent_dim + 1, output_dim=256*256)
discriminator = Discriminator(None, input_dim=256*256 + 1)
model = CGAN(generator, discriminator, latent_dim, lr=0.0002)

# Train the model
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, dataloader)

DEBUG: reading file...
INFO: 105 rows in the excel file
INFO: Removed 3 stage-d elements
DEBUG: Classifying...
DEBUG: Looking for paths against contents
DEBUG: File not found: Data/TCIA/TCIA_image_PV/HCC_011_PV.nii.gz
DEBUG: File not found: Data/TCIA/TCIA_image_PV/HCC_031_PV.nii.gz
DEBUG: File not found: Data/TCIA/TCIA_image_PV/HCC_082_PV.nii.gz
DEBUG: None
DEBUG: reading file Data/OP/OP_申請建模_1121110_20231223.xlsx
INFO: 200 rows in the excel file
INFO: Removed 55 stage-d elements
DEBUG: Classifying...
DEBUG: Looking for paths against contents
DEBUG: Searching for mismatch on files vs excel data...
DEBUG: Returning new dataframe
DEBUG: None
DEBUG: required package for reader nrrdreader is not installed, or the version doesn't match requirement.
DEBUG: required package for reader nrrdreader is not installed, or the version doesn't match requirement.


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 99 entries, 0 to 98
Data columns (total 3 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   class   99 non-null     object
 1   img     99 non-null     object
 2   mask    99 non-null     object
dtypes: object(3)
memory usage: 2.4+ KB
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 244 entries, 0 to 243
Data columns (total 3 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   class   244 non-null    object
 1   img     244 non-null    object
 2   mask    244 non-null    object
dtypes: object(3)
memory usage: 5.8+ KB


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


RuntimeError: Training with multiple optimizers is only supported with manual optimization. Remove the `optimizer_idx` argument from `training_step`, set `self.automatic_optimization = False` and access your optimizers in `training_step` with `opt1, opt2, ... = self.optimizers()`.