# Second attempt at implementing a GAN in PyTorch Lightning
---

## Imports and configuration

In [1]:
import logging

# Notebook-wide logger; DEBUG level so every trace emitted below is shown.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Timestamped format applied to every record printed by the console handler.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Re-running this cell must not stack a second console handler (each extra
# handler prints every record one more time), so reuse the existing one.
if logger.handlers:
    handler = logger.handlers[0]
else:
    handler = logging.StreamHandler()
    logger.addHandler(handler)
handler.setFormatter(formatter)

# Show only warnings and above from the 'src.handlers' logger.
logging.getLogger('src.handlers').setLevel(logging.WARNING)

In [10]:
import os

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import Normalize
from torch.utils.data import DataLoader, random_split
from monai.data import (CacheDataset, DataLoader, ImageDataset, PersistentDataset,
                        pad_list_data_collate)
from monai.transforms import (Compose, EnsureChannelFirst, Resize, ScaleIntensity, ToTensor,
                              Orientation, ScaleIntensityRange)

from src.handlers import Handler, OpHandler, TciaHandler

# Larger batches when CUDA is available, smaller ones on CPU.
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
# Half of the reported CPUs. os.cpu_count() can return None when the count
# cannot be determined, so fall back to a single core instead of raising.
NUM_WORKERS = (os.cpu_count() or 1) // 2

NUM_WORKERS

6

In [3]:
import sys

# Detect Google Colab by trying to import its helper module. Only a failed
# import means "not Colab"; a bare `except:` would also hide Drive mount
# errors, KeyboardInterrupt and SystemExit.
try:
    from google.colab import drive
except ImportError:
    print('Not a google drive environment')
    is_colab = False
else:
    drive.mount('/content/drive')
    # Make the project folder on Drive importable, without adding a duplicate
    # entry when this cell is run again.
    if '/content/drive/MyDrive/School/NTU/training' not in sys.path:
        sys.path.append('/content/drive/MyDrive/School/NTU/training')
    is_colab = True

Not a google drive environment


In [4]:
# Root folder of all datasets: the Drive-mounted folder on Colab, a relative
# folder otherwise.
BASE_PATH = '/content/drive/MyDrive/School/NTU/training/Data/' if is_colab else 'Data/'

# --- TCIA source ---
TCIA_IMG_SUFFIX = '_PV.nii.gz'
TCIA_LOCATION = f'{BASE_PATH}TCIA/'
TCIA_EXCEL_NAME = 'HCC-TACE-Seg_clinical_data-V2.xlsx'

# --- OP source ---
OP_LOCATION = f'{BASE_PATH}OP/'
NIFTI_PATH = 'OP_C+P_nifti'
NNU_NET_PATH = 'OP_C+P_nnUnet'
OP_EXCEL = 'OP_申請建模_1121110_20231223.xlsx'
OP_IMG_SUFFIX = '_VENOUS_PHASE.nii.gz'
OP_MASK_SUFFIX = '_VENOUS_PHASE_seg.nii.gz'
OP_ID_COL_NAME = 'OP_C+P_Tumor識別碼'

## Data module

In [5]:
class ImgDataModule(pl.LightningDataModule):
    """Lightning data module serving the TCIA and OP image volumes.

    Image paths and labels are read from the ``img`` and ``class`` columns of
    the dataframe exposed by ``Handler`` once both sources have been added.
    Every image goes through ``self.transform`` before it is batched.

    Args:
        batch_size: Stored on the instance.
            NOTE(review): the loaders below use a fixed batch size of 1 and
            never read this value — confirm whether that is intentional.
        num_workers: Number of worker processes passed to every loader.
    """

    def __init__(
        self,
        batch_size: int = BATCH_SIZE,
        num_workers: int = NUM_WORKERS,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Filled by prepare_data(), or lazily by setup() when prepare_data()
        # was not executed in the current process.
        self.data = None

        self.transform = Compose([
            EnsureChannelFirst(),
            Resize((512, 512, 20)),
            ScaleIntensity(),
            ToTensor(),
            # NOTE(review): these constants look like the MNIST mean/std —
            # confirm they are meant for these volumes.
            Normalize((0.1307,), (0.3081,))
        ])

        self.dims = (1, 512, 512, 20)  # channel plus the Resize target above
        self.num_classes = 3

    def _load_dataframe(self):
        """Register the TCIA and OP sources and return the handler's dataframe."""
        global_handler = Handler()

        tcia = TciaHandler(TCIA_LOCATION, TCIA_IMG_SUFFIX, TCIA_EXCEL_NAME)
        global_handler.add_source(tcia)

        op = OpHandler(OP_LOCATION, NIFTI_PATH, NNU_NET_PATH, OP_IMG_SUFFIX,
                       OP_MASK_SUFFIX, OP_EXCEL, OP_ID_COL_NAME)
        global_handler.add_source(op)

        return global_handler.df

    def _make_dataset(self, imgs, classes):
        """Wrap image paths and labels in an ``ImageDataset`` using ``self.transform``."""
        return ImageDataset(
            image_files=imgs,
            labels=classes,
            transform=self.transform,
        )

    def prepare_data(self):
        """Build the dataframe of samples and keep it on ``self.data``."""
        self.data = self._load_dataframe()

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ("fit", "test" or None for both).

        Lightning runs ``prepare_data()`` on a single process only, so state
        assigned there may be missing here; the dataframe is therefore loaded
        on demand when it is not present yet.
        """
        if self.data is None:
            self.data = self._load_dataframe()

        imgs = self.data['img'].tolist()
        classes = self.data['class'].tolist()

        # Train/validation datasets: 80% / 20% random split of all samples.
        if stage in ("fit", None):
            train_size = int(0.8 * len(self.data))
            val_size = len(self.data) - train_size
            full_ds = self._make_dataset(imgs, classes)
            self.train_ds, self.val_ds = random_split(full_ds, [train_size, val_size])

        # Test dataset. It is built from every sample, so it overlaps with the
        # training split above.
        if stage in ("test", None):
            self.test_ds = self._make_dataset(imgs, classes)

    def __default_dl__(self, dataset, shuffle: bool = False):
        """Build a loader over ``dataset`` with the settings shared by all splits.

        Args:
            dataset: Dataset to iterate over.
            shuffle: Whether to reshuffle the samples every epoch.
        """
        return DataLoader(
            dataset,
            batch_size=1,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=torch.cuda.is_available(),
            collate_fn=pad_list_data_collate
        )

    def train_dataloader(self):
        """Loader over the training split, reshuffled every epoch."""
        return self.__default_dl__(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Loader over the validation split, in fixed order."""
        return self.__default_dl__(self.val_ds)

    def test_dataloader(self):
        """Loader over the test dataset, in fixed order."""
        return self.__default_dl__(self.test_ds)

## Generator

In [6]:
class Generator(nn.Module):
    """Fully connected generator turning latent vectors into image tensors.

    Args:
        latent_dim: Size of the last dimension of the input noise.
        img_shape: Shape of one generated sample, without the batch axis.
    """

    def __init__(self,  latent_dim, img_shape):
        super().__init__()
        logger.debug(f'Generator with input_dim: {latent_dim} and output_dim: {img_shape} ')
        self.img_shape = img_shape

        # One output unit per element of a generated sample.
        n_outputs = int(np.prod(img_shape))
        layers = [
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, n_outputs),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Generate images from ``z``.

        Returns:
            Tensor of shape ``(z.size(0), *img_shape)``.
        """
        logger.debug('*******generator.forward*****************')
        logger.debug(f'z size: {z.size()}, type: {z.dtype}')
        logger.debug('Calling model sequential...')
        flat = self.model(z)
        n_samples = flat.size(0)
        # Unflatten each row back to the configured sample shape.
        return flat.view(n_samples, *self.img_shape)

## Discriminator

In [8]:
class Discriminator(nn.Module):
    """Fully connected discriminator scoring flattened images.

    Args:
        img_shape: Shape of one input sample, without the batch axis.
    """

    def __init__(self, img_shape):
        super().__init__()

        # One input unit per element of a sample.
        n_inputs = int(np.prod(img_shape))
        layers = [
            nn.Linear(n_inputs, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Score a batch of images.

        Returns:
            Tensor of shape ``(img.size(0), 1)`` holding the sigmoid outputs.
        """
        n_samples = img.size(0)
        # Collapse everything but the batch axis before the linear layers.
        return self.model(img.view(n_samples, -1))

## GAN Lightning module (second attempt)

In [9]:
import torch.nn.functional as F
# Define the GAN model
class GAN(pl.LightningModule):
    """GAN trained with manual optimization, one network per training step.

    Args:
        generator: Module mapping latent vectors to images.
        discriminator: Module mapping images to values in [0, 1], as required
            by the binary cross-entropy used here.
        latent_dim: Size of the latent vectors fed to the generator.
        lr: Learning rate used by both Adam optimizers.
    """

    def __init__(self, generator, discriminator, latent_dim, lr):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim
        self.lr = lr

        # Both optimizers are stepped by hand in training_step().
        self.automatic_optimization = False

    def forward(self, z):
        """Generate images from the latent vectors ``z``."""
        logger.debug('*******CGAN.forward************')
        return self.generator(z)

    def _sample_noise(self, batch_size):
        """Draw ``batch_size`` latent vectors on this module's device."""
        z = torch.randn(batch_size, self.latent_dim, device=self.device)
        logger.debug('z datatype: %s' % z.dtype)
        return z

    def generator_step(self, real_images):
        """Return the generator loss for a batch.

        ``real_images`` is only used for its batch size. The loss is the BCE
        between the discriminator's scores on fresh fakes and a target of 1.
        """
        logger.debug('************CGAN.generator_step********')
        z = self._sample_noise(real_images.size(0))
        fake_images = self(z)
        fake_preds = self.discriminator(fake_images)
        g_loss = nn.BCELoss()(fake_preds, torch.ones_like(fake_preds))
        return g_loss

    def discriminator_step(self, real_images):
        """Return the discriminator loss for a batch.

        The loss is the mean of the BCE on real images (target 1) and on
        generated images (target 0).
        """
        logger.debug('************CGAN.discriminator_step********')
        # Noise must live on the module's device, as in generator_step().
        z = self._sample_noise(real_images.size(0))
        # Detach so this update does not back-propagate through the generator
        # or accumulate gradients in its parameters.
        fake_images = self(z).detach()
        real_preds = self.discriminator(real_images)
        fake_preds = self.discriminator(fake_images)
        real_loss = nn.BCELoss()(real_preds, torch.ones_like(real_preds))
        fake_loss = nn.BCELoss()(fake_preds, torch.zeros_like(fake_preds))
        d_loss = (real_loss + fake_loss) / 2
        return d_loss

    def training_step(self, batch, batch_idx):
        """Update one of the two networks, chosen by the parity of ``global_step``.

        Even steps update the generator and log ``g_loss``; odd steps update
        the discriminator and log ``d_loss``.

        Args:
            batch: Pair ``(real_images, labels)``; the labels are ignored.
            batch_idx: Index of the batch (unused).

        Returns:
            The loss that was optimized in this step.
        """
        logger.debug('********CGAN.training_step****************')
        real_images, _ = batch

        # Same order as returned by configure_optimizers().
        opt_g, opt_d = self.optimizers()

        if self.global_step % 2 == 0:
            logger.debug('Training generator....')
            loss = self.generator_step(real_images)
            optimizer, loss_name = opt_g, 'g_loss'
        else:
            logger.debug('Training discriminator...')
            loss = self.discriminator_step(real_images)
            optimizer, loss_name = opt_d, 'd_loss'

        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()
        self.log(loss_name, loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return the Adam optimizers as ``[generator, discriminator]``."""
        g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.lr)
        d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr)
        return [g_optimizer, d_optimizer]

In [11]:
from monai.utils.misc import first

# Build the data module and create its datasets
dm = ImgDataModule()
dm.prepare_data()
dm.setup()

dataloader = dm.train_dataloader()

# Look at one batch to see what the loader yields
data = first(dataloader)
logger.debug('----inspecting first element---')
logger.debug(f'type: {type(data)}, length: {len(data)}')
for _idx in (0, 1):
    logger.debug(f'--element at index {_idx}--')
    logger.debug(data[_idx].size())

# Build the networks and wrap them in the Lightning module
latent_dim, img_shape = 100, (1, 512, 512, 20)
generator = Generator(latent_dim, img_shape)
discriminator = Discriminator(img_shape)
model = GAN(
    generator=generator,
    discriminator=discriminator,
    latent_dim=latent_dim,
    lr=0.0002,
)

# Run a single training epoch on the training loader
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, dataloader)

3 files not found


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 99 entries, 0 to 98
Data columns (total 3 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   class   99 non-null     object
 1   img     99 non-null     object
 2   mask    99 non-null     object
dtypes: object(3)
memory usage: 2.4+ KB
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 244 entries, 0 to 243
Data columns (total 3 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   class   244 non-null    object
 1   img     244 non-null    object
 2   mask    244 non-null    object
dtypes: object(3)
memory usage: 5.8+ KB


2024-09-05 02:17:23,704 - __main__ - DEBUG - ----inspecting first element---
2024-09-05 02:17:23,705 - __main__ - DEBUG - type: <class 'list'>, length: 2
2024-09-05 02:17:23,705 - __main__ - DEBUG - --element at index 0--
2024-09-05 02:17:23,706 - __main__ - DEBUG - torch.Size([1, 1, 512, 512, 20])
2024-09-05 02:17:23,707 - __main__ - DEBUG - --element at index 1--
2024-09-05 02:17:23,707 - __main__ - DEBUG - torch.Size([1])
2024-09-05 02:17:23,708 - __main__ - DEBUG - Generator with input_dim: 100 and output_dim: (1, 512, 512, 20) 
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name          | Type          | Params | Mode 
--------------------------------------------------------
0 | generator     | Generator     | 1.3 B  | train
1 | discriminator | Discriminator | 1.3 B  | train
--------------------------------------------------------
2.7 B     Trainable params
0         Non-trainable params
2.7 B     Total params
1

Training: |          | 0/? [00:00<?, ?it/s]

2024-09-05 02:17:40,031 - __main__ - DEBUG - ********CGAN.training_step****************
2024-09-05 02:17:40,063 - __main__ - DEBUG - z datatype: torch.float32
2024-09-05 02:17:40,064 - __main__ - DEBUG - Training generator....
2024-09-05 02:17:40,064 - __main__ - DEBUG - *******CGAN.forward************
2024-09-05 02:17:40,064 - __main__ - DEBUG - *******generator.forward*****************
2024-09-05 02:17:40,064 - __main__ - DEBUG - z size: torch.Size([1, 100]), type: torch.float32
2024-09-05 02:17:40,065 - __main__ - DEBUG - Calling model sequential...


: 