# Generative Adversarial Network

In this notebook we'll use a Generative Adversarial Network (GAN) to generate synthetic images that augment the dataset.

In [33]:
import torch
import torch.nn as nn
from torch.utils.data import random_split
from torchvision import transforms
from monai.data import (CacheDataset, DataLoader, Dataset, PersistentDataset,
                        pad_list_data_collate)
from monai.transforms import (Compose, EnsureChannelFirstd, LoadImaged,
                              ScaleIntensityd, ToTensord)

from src.handlers import Handler, OpHandler, TciaHandler

## GAN preparation

First, we'll define the two networks that make up the GAN itself: the generator and the discriminator.

### Generator

In [2]:
class Generator(nn.Module):
    """DCGAN-style 3D generator built from transposed convolutions.

    Every block but the last is ConvTranspose3d -> BatchNorm3d -> ReLU; the
    last block is ConvTranspose3d -> Tanh, so outputs lie in [-1, 1].
    For a latent input of spatial size 1x1x1 the output volume is 64x64x64
    (4 -> 8 -> 16 -> 32 -> 64).

    Args:
        nz: Number of channels of the latent input tensor.
        ngf: Base feature-map width. The first block emits ``ngf * 8``
            channels and each following block halves that. The default (64)
            gives the 512/256/128/64 layout.
        nc: Number of channels of the generated volume (default 1).
    """

    def __init__(self, nz, ngf=64, nc=1):
        super().__init__()
        self.main = nn.Sequential(
            # Project the latent vector: spatial 1 -> 4 (kernel 4, stride 1, no padding).
            nn.ConvTranspose3d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm3d(ngf * 8),
            nn.ReLU(True),
            # Each kernel-4 / stride-2 / padding-1 block doubles the spatial size.
            nn.ConvTranspose3d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm3d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose3d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm3d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose3d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm3d(ngf),
            nn.ReLU(True),
            # Final upsampling straight to the output channels; no BatchNorm here.
            nn.ConvTranspose3d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        """Map a latent batch of shape (N, nz, D, H, W) to generated volumes."""
        return self.main(input)

### Discriminator

In [4]:
class Discriminator(nn.Module):
    """DCGAN-style 3D discriminator built from strided convolutions.

    Four kernel-4 / stride-2 / padding-1 convolutions each halve the spatial
    size, then a kernel-4 / stride-1 / no-padding convolution followed by a
    Sigmoid produces the score. For a 64x64x64 input the output has shape
    (N, 1, 1, 1, 1) with values in [0, 1].

    Args:
        ndf: Base feature-map width. The first block emits ``ndf`` channels
            and each following block doubles that. The default (64) gives the
            64/128/256/512 layout.
        nc: Number of channels of the input volume (default 1).
    """

    def __init__(self, ndf=64, nc=1):
        super().__init__()
        self.main = nn.Sequential(
            # First block has no BatchNorm.
            nn.Conv3d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm3d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm3d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm3d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # Collapse the remaining 4x4x4 map into a single score per sample.
            nn.Conv3d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Score a batch of volumes of shape (N, nc, D, H, W)."""
        return self.main(input)

## Dataset loading

In [5]:
# Root folder that holds every data source used in this notebook.
BASE_PATH = 'Data/'

# --- TCIA source ---------------------------------------------------------
TCIA_LOCATION = f'{BASE_PATH}TCIA/'
TCIA_IMG_SUFFIX = '_PV.nii.gz'
TCIA_EXCEL_NAME = 'HCC-TACE-Seg_clinical_data-V2.xlsx'

# --- OP source -----------------------------------------------------------
OP_LOCATION = f'{BASE_PATH}OP/'
NIFTI_PATH = 'OP_C+P_nifti'
NNU_NET_PATH = 'OP_C+P_nnUnet'
OP_EXCEL = 'OP_申請建模_1121110_20231223.xlsx'
OP_ID_COL_NAME = 'OP_C+P_Tumor識別碼'
# File-name suffixes of the OP images and of their segmentation masks.
OP_IMG_SUFFIX = '_VENOUS_PHASE.nii.gz'
OP_MASK_SUFFIX = '_VENOUS_PHASE_seg.nii.gz'

In [6]:
# Handler that each data source below is registered with via add_source().
global_handler = Handler()

# TCIA source: built from its folder, image suffix and clinical Excel file.
tcia = TciaHandler(TCIA_LOCATION, TCIA_IMG_SUFFIX, TCIA_EXCEL_NAME)
global_handler.add_source(tcia)

# OP source: built from its folder, the NIfTI / nnU-Net sub-folders, the
# image and mask suffixes, the Excel file and the name of its ID column.
op = OpHandler(
    OP_LOCATION,
    NIFTI_PATH,
    NNU_NET_PATH,
    OP_IMG_SUFFIX,
    OP_MASK_SUFFIX,
    OP_EXCEL,
    OP_ID_COL_NAME,
)
global_handler.add_source(op)

DEBUG: reading file...
INFO: 105 rows in the excel file
INFO: Removed 3 stage-d elements
DEBUG: Classifying...
DEBUG: Looking for paths against contents
DEBUG: File not found: Data/TCIA/TCIA_image_PV/HCC_011_PV.nii.gz
DEBUG: File not found: Data/TCIA/TCIA_image_PV/HCC_031_PV.nii.gz
DEBUG: File not found: Data/TCIA/TCIA_image_PV/HCC_082_PV.nii.gz
DEBUG: None
DEBUG: reading file Data/OP/OP_申請建模_1121110_20231223.xlsx
INFO: 200 rows in the excel file
INFO: Removed 55 stage-d elements
DEBUG: Classifying...
DEBUG: Looking for paths against contents
DEBUG: Searching for mismatch on files vs excel data...
DEBUG: Returning new dataframe
DEBUG: None


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 99 entries, 0 to 98
Data columns (total 3 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   class   99 non-null     object
 1   img     99 non-null     object
 2   mask    99 non-null     object
dtypes: object(3)
memory usage: 2.4+ KB
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 244 entries, 0 to 243
Data columns (total 3 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   class   244 non-null    object
 1   img     244 non-null    object
 2   mask    244 non-null    object
dtypes: object(3)
memory usage: 5.8+ KB


In [7]:
# DataFrame exposed by the handler (columns in the preview: class, img, mask).
df = global_handler.df

# Preview the first five rows.
df.head(n=5)

Unnamed: 0,class,img,mask
0,0,Data/TCIA/TCIA_image_PV/HCC_001_PV.nii.gz,Data/TCIA/TCIA_results_phase_PV/HCC_001_PV.nii.gz
1,2,Data/TCIA/TCIA_image_PV/HCC_002_PV.nii.gz,Data/TCIA/TCIA_results_phase_PV/HCC_002_PV.nii.gz
2,2,Data/TCIA/TCIA_image_PV/HCC_003_PV.nii.gz,Data/TCIA/TCIA_results_phase_PV/HCC_003_PV.nii.gz
3,1,Data/TCIA/TCIA_image_PV/HCC_004_PV.nii.gz,Data/TCIA/TCIA_results_phase_PV/HCC_004_PV.nii.gz
4,2,Data/TCIA/TCIA_image_PV/HCC_005_PV.nii.gz,Data/TCIA/TCIA_results_phase_PV/HCC_005_PV.nii.gz


## Model training

In [35]:
import os

# Worker processes for the data loaders: one per CPU reported by the OS.
# os.cpu_count() returns None when the count cannot be determined; fall back
# to 0 (load in the main process) so the DataLoader never receives None.
num_workers = os.cpu_count() or 0
num_workers

12

In [17]:
# Keys handed to the MONAI dictionary transforms. They must name columns that
# hold file paths: 'class' is a label, so listing it here would make
# LoadImaged try to open the label value as an image file.
# NOTE(review): 'img' is included because the notebook's goal is generating
# images - drop it if only the masks are meant to be loaded.
data_keys = ['img', 'mask']

# One dict per sample. The label is kept in each record (untouched by the
# transforms) next to the paths that will be loaded.
data_dict = df[['class'] + data_keys].to_dict(orient='records')

In [32]:
# Per-sample preprocessing, applied in order to every key in data_keys:
#   LoadImaged          - read the file found at each path
#   EnsureChannelFirstd - put the channel dimension first
#   ScaleIntensityd     - rescale intensities (MONAI default range is [0, 1])
#   ToTensord           - convert the arrays to torch tensors
# These are MONAI dictionary transforms, so they are chained with MONAI's own
# Compose rather than torchvision's.
# NOTE(review): Generator ends in Tanh (outputs in [-1, 1]) while the real
# samples are scaled with ScaleIntensityd's defaults - confirm the two ranges
# are reconciled before training.
_transforms = Compose([
    LoadImaged(keys=data_keys),
    EnsureChannelFirstd(keys=data_keys),
    ScaleIntensityd(keys=data_keys),
    ToTensord(keys=data_keys),
])

# Plain (uncached) dataset: the transform chain runs each time an item is read.
ds = Dataset(data=data_dict, transform=_transforms)

DEBUG: required package for reader nrrdreader is not installed, or the version doesn't match requirement.


In [27]:
# 80/20 train/test partition of the samples.
# The training share is truncated to an integer; the test set takes whatever
# is left, so the two sizes always add up to len(df).
train_size = int(len(df) * 0.8)
test_size = len(df) - train_size

In [34]:
# Seed for the train/test partition. Without an explicit generator,
# random_split uses the global RNG and the split changes on every run.
SPLIT_SEED = 42

train_ds, test_ds = random_split(
    ds,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [37]:
# Settings shared by both loaders.
#   batch_size=1     - one sample per batch
#   pin_memory       - only enabled when CUDA is available
#   pad_list_data_collate - MONAI collate function that pads samples of a
#                           batch to a common shape
_loader_kwargs = dict(
    batch_size=1,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
    collate_fn=pad_list_data_collate,
)

# Training samples are reshuffled at the start of every epoch.
train_dl = DataLoader(train_ds, shuffle=True, **_loader_kwargs)

# The test loader keeps a fixed order.
test_dl = DataLoader(test_ds, shuffle=False, **_loader_kwargs)

In [8]:
# Training hyperparameters.
nz = 100  # Size of z latent vector (i.e. size of generator input)
batch_size = 16  # NOTE(review): the loaders in this notebook use batch_size=1 - confirm which is intended
num_epochs = 100
lr = 2e-4
beta1 = 0.5  # presumably the Adam beta1 coefficient - optimizer is not visible here

In [11]:
# Run on the GPU when CUDA is usable, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

device

device(type='cpu')