# Generative Adversarial Network

In this notebook we'll generate images to augment the dataset using a GAM

In [19]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import random_split
from torchvision import transforms
from monai.transforms import LoadImaged, EnsureChannelFirstd, ScaleIntensityd, ToTensord
from monai.data import (CacheDataset, DataLoader, Dataset, PersistentDataset,
                        pad_list_data_collate)


from src.handlers import Handler, OpHandler, TciaHandler

## Gam preparation

First, we'll set the required code for the gam itself

### Generator

In [2]:
class Generator(nn.Module):
    def __init__(self, nz):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose3d(nz, 512, 4, 1, 0, bias=False),
            nn.BatchNorm3d(512),
            nn.ReLU(True),
            nn.ConvTranspose3d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm3d(256),
            nn.ReLU(True),
            nn.ConvTranspose3d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm3d(128),
            nn.ReLU(True),
            nn.ConvTranspose3d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(True),
            nn.ConvTranspose3d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)

### Discriminator

In [3]:
# Define the Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv3d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm3d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm3d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm3d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

## Dataset loading

In [4]:
BASE_PATH = 'Data/'
# ...
TCIA_IMG_SUFFIX = '_PV.nii.gz'
TCIA_LOCATION = BASE_PATH + 'TCIA/'
TCIA_EXCEL_NAME = 'HCC-TACE-Seg_clinical_data-V2.xlsx'
# ...
OP_LOCATION = BASE_PATH + 'OP/'
NIFTI_PATH = 'OP_C+P_nifti'
NNU_NET_PATH = 'OP_C+P_nnUnet'
OP_EXCEL = 'OP_申請建模_1121110_20231223.xlsx'
OP_IMG_SUFFIX = '_VENOUS_PHASE.nii.gz'
OP_MASK_SUFFIX = '_VENOUS_PHASE_seg.nii.gz'
OP_ID_COL_NAME = 'OP_C+P_Tumor識別碼'

In [6]:
# => fails, necessary on Fedora 27, ipython3 6.2.1
%config Application.log_level='INFO'

import logging

logging.basicConfig(level=logging.INFO)

# Get the root logger and set its level
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [7]:
global_handler = Handler()

tcia = TciaHandler(TCIA_LOCATION, TCIA_IMG_SUFFIX, TCIA_EXCEL_NAME)
global_handler.add_source(tcia)

op = OpHandler(OP_LOCATION, NIFTI_PATH, NNU_NET_PATH, OP_IMG_SUFFIX, OP_MASK_SUFFIX, OP_EXCEL, OP_ID_COL_NAME)
global_handler.add_source(op)

DEBUG: reading file...
INFO: 105 rows in the excel file
INFO: Removed 3 stage-d elements
DEBUG: Classifying...
DEBUG: Looking for paths against contents
DEBUG: File not found: Data/TCIA/TCIA_image_PV/HCC_011_PV.nii.gz
DEBUG: File not found: Data/TCIA/TCIA_image_PV/HCC_031_PV.nii.gz
DEBUG: File not found: Data/TCIA/TCIA_image_PV/HCC_082_PV.nii.gz
DEBUG: None
DEBUG: reading file Data/OP/OP_申請建模_1121110_20231223.xlsx
INFO: 200 rows in the excel file
INFO: Removed 55 stage-d elements
DEBUG: Classifying...
DEBUG: Looking for paths against contents
DEBUG: Searching for mismatch on files vs excel data...
DEBUG: Returning new dataframe
DEBUG: None


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 99 entries, 0 to 98
Data columns (total 3 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   class   99 non-null     object
 1   img     99 non-null     object
 2   mask    99 non-null     object
dtypes: object(3)
memory usage: 2.4+ KB
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 244 entries, 0 to 243
Data columns (total 3 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   class   244 non-null    object
 1   img     244 non-null    object
 2   mask    244 non-null    object
dtypes: object(3)
memory usage: 5.8+ KB


In [8]:
df = global_handler.df

df.head()

Unnamed: 0,class,img,mask
0,0,Data/TCIA/TCIA_image_PV/HCC_001_PV.nii.gz,Data/TCIA/TCIA_results_phase_PV/HCC_001_PV.nii.gz
1,2,Data/TCIA/TCIA_image_PV/HCC_002_PV.nii.gz,Data/TCIA/TCIA_results_phase_PV/HCC_002_PV.nii.gz
2,2,Data/TCIA/TCIA_image_PV/HCC_003_PV.nii.gz,Data/TCIA/TCIA_results_phase_PV/HCC_003_PV.nii.gz
3,1,Data/TCIA/TCIA_image_PV/HCC_004_PV.nii.gz,Data/TCIA/TCIA_results_phase_PV/HCC_004_PV.nii.gz
4,2,Data/TCIA/TCIA_image_PV/HCC_005_PV.nii.gz,Data/TCIA/TCIA_results_phase_PV/HCC_005_PV.nii.gz


## Model training

In [84]:
import os
num_workers = os.cpu_count()
num_workers

4

In [85]:
imgs = df['img'].tolist()
classes = df['class'].tolist()

In [86]:
_transforms = transforms.Compose([
    EnsureChannelFirst(),
    Resize((512, 512, 20)),
    ScaleIntensity(),
    ToTensor(),
])

ds = ImageDataset(
    image_files=imgs,
    labels=classes,
    transform=_transforms,
    # cache_rate=1.0,
    # num_workers=num_workers,
    # cache_dir=BASE_PATH + 'cache'
)

In [87]:
# Define the sizes for the train and test sets
train_size = int(0.8 * len(df))  # 80% for training
test_size = len(df) - train_size  # Remaining 20% for testing

In [88]:
train_ds, test_ds = random_split(ds, [train_size, test_size])

train_ds

<torch.utils.data.dataset.Subset at 0x7f8ec2e17500>

In [89]:
train_dl = DataLoader(
    train_ds,
    batch_size=1,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
    collate_fn=pad_list_data_collate
)

test_dl = DataLoader(
    test_ds,
    batch_size=1,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
    collate_fn=pad_list_data_collate
)

In [91]:
from monai.utils.misc import first

im, label = first(train_dl)
print(type(im), im.shape, label, label.shape)

<class 'monai.data.meta_tensor.MetaTensor'> torch.Size([1, 1, 512, 512, 20]) tensor([0.], dtype=torch.float64) torch.Size([1])


In [92]:
batch_size = 16
nz = 100  # Size of z latent vector (i.e. size of generator input)
num_epochs = 100
lr = 0.0002
beta1 = 0.5

In [93]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device

device(type='cpu')

In [94]:
# Initialize models
gen = Generator(nz)
disc = Discriminator()
if torch.cuda.is_available():
    gen = gen.cuda()
    disc = disc.cuda()

In [96]:
# Loss function and optimizers
criterion = nn.BCELoss()
optim_gen = optim.Adam(gen.parameters(), lr=lr, betas=(beta1, 0.999))
optim_dis = optim.Adam(disc.parameters(), lr=lr, betas=(beta1, 0.999))

# Training loop
for epoch in range(num_epochs):
    print('Running epoch')
    for i, data in enumerate(train_dl, 0):
        print('Running time: ', i)
        # Update Discriminator
        disc.zero_grad()
        real_images = data[0].cuda()
        batch_size = real_images.size(0)
        labels = torch.full((batch_size,), 1, dtype=torch.float, device='cuda')
        output = netD(real_images).view(-1)
        errD_real = criterion(output, labels)
        errD_real.backward()
        noise = torch.randn(batch_size, nz, 1, 1, 1, device='cuda')
        fake_images = netG(noise)
        labels.fill_(0)
        output = netD(fake_images.detach()).view(-1)
        errD_fake = criterion(output, labels)
        errD_fake.backward()
        optimizerD.step()

        # Update Generator
        netG.zero_grad()
        labels.fill_(1)
        output = netD(fake_images).view(-1)
        errG = criterion(output, labels)
        errG.backward()
        optimizerG.step()

        if i % 50 == 0:
            print(f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] Loss_D: {errD_real.item() + errD_fake.item()} Loss_G: {errG.item()}')

print("Training finished!")


Running epoch
Running time:  0


RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

[IPKernelApp] Exception in execute request:
[0;31m---------------------------------------------------------------------------[0m
[0;31mRuntimeError[0m                              Traceback (most recent call last)
Cell [0;32mIn[96], line 13[0m
[1;32m     11[0m [38;5;66;03m# Update Discriminator[39;00m
[1;32m     12[0m disc[38;5;241m.[39mzero_grad()
[0;32m---> 13[0m real_images [38;5;241m=[39m [43mdata[49m[43m[[49m[38;5;241;43m0[39;49m[43m][49m[38;5;241;43m.[39;49m[43mcuda[49m[43m([49m[43m)[49m
[1;32m     14[0m batch_size [38;5;241m=[39m real_images[38;5;241m.[39msize([38;5;241m0[39m)
[1;32m     15[0m labels [38;5;241m=[39m torch[38;5;241m.[39mfull((batch_size,), [38;5;241m1[39m, dtype[38;5;241m=[39mtorch[38;5;241m.[39mfloat, device[38;5;241m=[39m[38;5;124m'[39m[38;5;124mcuda[39m[38;5;124m'[39m)

File [0;32m~/.conda/envs/monai-conda/lib/python3.12/site-packages/monai/data/meta_tensor.py:282[0m, in [0;36mMetaTensor.__torch_