# Medical Data

Armaiz Adenwala


## Overview

Lorem ipsum dolor sit amet

## Creating Artificial X-Ray Scans Using a GAN

### Setting Up The Environment

In [None]:
import os

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.utils import save_image
from torchvision.utils import make_grid

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
%matplotlib inline

### Preparing the Dataset

#### Downloading the Dataset

The dataset is hosted on Kaggle; we will use the Kaggle API to download it.

In [None]:
 ! pip install -q kaggle

Please upload your `kaggle.json` file here:

In [None]:
from google.colab import files
files.upload()

We will then need to set up the Kaggle API config manually in the home directory:

In [None]:
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json

This will download the chest-xray-pneumonia dataset from Kaggle:

In [None]:
! kaggle datasets download -d paultimothymooney/chest-xray-pneumonia

#### Organize the data

We will now need to unzip the data into a dataset folder. We will then create two dataset folders. One for xray scans of patients that are healthy, and one for xray scans of patients with pneumonia. This is simply to keep everything organized without having it deeply nested in the dataset folder. Additionally, the dataloader pulls from the subfolders, so we will need to split NORMAL and PNEUMONIA.

In [None]:
! mkdir dataset
! unzip chest-xray-pneumonia.zip -d dataset

We will then recursively copy the image files into the new separate folders so that `ImageFolder` can load specifically normal or specifically pneumonia X-rays.

In [None]:
! mkdir ./dataset/chest_xray/chest_xray_normal
! mkdir ./dataset/chest_xray/chest_xray_pneumonia
! cp -r ./dataset/chest_xray/chest_xray/train/NORMAL ./dataset/chest_xray/chest_xray_normal/train
! cp -r ./dataset/chest_xray/chest_xray/train/PNEUMONIA ./dataset/chest_xray/chest_xray_pneumonia/train

As we prepare the data, we need to specify a few key variables.

`batch_size` refers to the number of images the `DataLoader` places in each training batch.


`image_size` refers to the size we want our images to be. Due to limited GPU resources, 64x64 is common; however, we will use 128x128 to retain as much detail as we can.

`stats` holds the mean and standard deviation (both 0.5) used to normalize the images.

In [None]:
# Settings shared by the data pipeline and the preview helpers below.
batch_size = 128       # images per DataLoader batch
image_size = 128       # target side length (pixels) for Resize / CenterCrop
stats = (0.5, 0.5)     # (mean, std) given to T.Normalize and reversed by denorm
images_count = 64      # default number of images drawn in a preview grid
images_row_count = 8   # images per row in a preview grid

We use `ImageFolder` to load the datasets and apply the following transformations:
* resize the images so that their shorter side is 128px (aspect ratio is preserved)
* center-crop the images to exactly 128x128
* reduce images to 1 channel / grayscale
* convert to tensor
* normalization

We then load the images in dataloader. The batch size from above is passed here.

In [None]:
def _make_xray_loader(root):
    """Build the (dataset, dataloader) pair for the image folder at `root`.

    Every image is resized, centre-cropped to `image_size` x `image_size`,
    reduced to one grayscale channel, converted to a tensor and normalized
    with `stats`.
    """
    pipeline = T.Compose([
        T.Resize(image_size),
        T.CenterCrop(image_size),
        T.Grayscale(num_output_channels=1),
        T.ToTensor(),
        T.Normalize(*stats),
    ])
    dataset = ImageFolder(root, transform=pipeline)
    loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=2, pin_memory=True)
    return dataset, loader


train_ds_norm, train_dl_norm = _make_xray_loader('./dataset/chest_xray/chest_xray_normal/')

train_ds_pneum, train_dl_pneum = _make_xray_loader('./dataset/chest_xray/chest_xray_pneumonia/')

We now create some functions to help visualize our data. These will let us see our progress as images are being generated, as well as view a large batch of images at once.

In [None]:
def denorm(tensors):
    """Reverse the `stats` normalization: return `tensors * std + mean`."""
    mean, std = stats[0], stats[1]
    return tensors * std + mean

def show_images(images, nmax=images_count):
    """Draw up to `nmax` images from a batch as a single tick-less grid."""
    subset = images.cpu().detach()[:nmax]
    grid = make_grid(denorm(subset), nrow=images_row_count)

    _, axes = plt.subplots(figsize=(8, 8))
    axes.set_xticks([])
    axes.set_yticks([])
    # Move the first axis of the grid to the end before handing it to imshow.
    axes.imshow(grid.permute(1, 2, 0))

def show_batch(dl, nmax=images_count):
    """Display the first batch yielded by `dl`; do nothing if `dl` yields none."""
    for batch in dl:
        images, _ = batch  # the second element of each batch is ignored
        show_images(images, nmax)
        return

We can visualize our current dataset using the methods above:

In [None]:
# Preview the first batch from each of the two loaders.
show_batch(train_dl_norm)
show_batch(train_dl_pneum)

### Preparing the Generator


In [None]:
# Report the installed PyTorch version.
print(torch.__version__)

In [None]:
def get_default_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
    
def to_device(data, device):
    """Move `data` to `device`, recursing through lists and tuples.

    Lists and tuples are both returned as lists. Any other value must offer
    a `.to(device, non_blocking=True)` method, whose result is returned.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)

    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

class DeviceDataLoader():
    """Wrap an iterable of batches so each batch is moved to `device` on iteration."""

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __len__(self):
        # Length is whatever the wrapped loader reports.
        return len(self.dl)

    def __iter__(self):
        # Batches are transferred lazily, one at a time, as they are requested.
        yield from (to_device(batch, self.device) for batch in self.dl)

We can now verify that we have a gpu available:

In [None]:
# Pick the device used for the rest of the notebook and display it.
device = get_default_device()
device

In [None]:
# Rebind the loaders to wrappers that move every batch to `device`.
train_dl_norm = DeviceDataLoader(train_dl_norm, device)
train_dl_pneum = DeviceDataLoader(train_dl_pneum, device)

In [None]:
def _discriminator_stage(in_channels, out_channels):
    """Conv -> BatchNorm -> LeakyReLU stage; the strided conv halves height and width."""
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2, inplace=True),
    ]


def _build_discriminator():
    """Build one discriminator: 1 x 128 x 128 image in, one sigmoid score out."""
    # Channel widths of the five down-sampling stages:
    # 1 x 128 x 128 -> 128 x 64 x 64 -> 256 x 32 x 32 -> 512 x 16 x 16
    #               -> 1024 x 8 x 8 -> 2048 x 4 x 4
    widths = (1, 128, 256, 512, 1024, 2048)
    layers = []
    for c_in, c_out in zip(widths, widths[1:]):
        layers.extend(_discriminator_stage(c_in, c_out))

    # Collapse the 2048 x 4 x 4 feature map to a single value: 1 x 1 x 1
    layers.append(nn.Conv2d(2048, 1, kernel_size=4, stride=1, padding=0, bias=False))
    layers.append(nn.Flatten())
    layers.append(nn.Sigmoid())
    return nn.Sequential(*layers)


# Two independent networks with the same architecture, one per image class.
discriminator_norm = _build_discriminator()

discriminator_pneum = _build_discriminator()

In [None]:
# Move both discriminators to `device`.
discriminator_norm = to_device(discriminator_norm, device)
discriminator_pneum = to_device(discriminator_pneum, device)

In [None]:
# Number of channels in the latent (noise) tensor fed to the generators.
latent_size = 512

In [None]:
def _generator_stage(in_channels, out_channels, stride, padding):
    """ConvTranspose -> BatchNorm -> in-place ReLU stage."""
    return [
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(True),
    ]


def _build_generator():
    """Build one generator: latent_size x 1 x 1 noise in, 1 x 128 x 128 image out."""
    # Project the latent vector: latent_size x 1 x 1 -> 2048 x 4 x 4
    layers = _generator_stage(latent_size, 2048, stride=1, padding=0)

    # Four up-sampling stages, each doubling height and width:
    # 2048 x 4 x 4 -> 1024 x 8 x 8 -> 512 x 16 x 16 -> 256 x 32 x 32 -> 128 x 64 x 64
    widths = (2048, 1024, 512, 256, 128)
    for c_in, c_out in zip(widths, widths[1:]):
        layers.extend(_generator_stage(c_in, c_out, stride=2, padding=1))

    # Final up-sampling to a single channel: 1 x 128 x 128, squashed by tanh
    layers.append(nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1, bias=False))
    layers.append(nn.Tanh())
    return nn.Sequential(*layers)


# Two independent networks with the same architecture, one per image class.
generator_norm = _build_generator()

generator_pneum = _build_generator()

In [None]:
# Move both generators to `device`.
generator_norm = to_device(generator_norm, device)
generator_pneum = to_device(generator_pneum, device)

We can then generate a batch of latent tensors and pass it through both the normal and pneumonia generators to produce random images. Since the generators are untrained, the outputs will appear as noise for now.

In [None]:
# Sample `images_count` latent tensors on `device` and feed the same batch to
# both generators, then display each result as a grid.
random_latent_tensors = torch.randn(images_count, latent_size, 1, 1, device=device)
imgs_norm = generator_norm(random_latent_tensors)
imgs_pneum = generator_pneum(random_latent_tensors)

show_images(imgs_norm)
show_images(imgs_pneum)

We can now verify the shapes are correct. The generator correctly outputs 128x128 images.

In [None]:
# Print the shape of each generated batch.
print(imgs_norm.shape)
print(imgs_pneum.shape)

In [None]:
def train_discriminator(real_images, opt_d, discriminator, generator):
    """Run one discriminator optimization step on a real and a generated batch.

    Args:
        real_images: batch of real images; its first dimension is the batch size.
        opt_d: optimizer that owns the discriminator's parameters.
        discriminator: model returning one probability per image, shape (N, 1).
        generator: model mapping (N, latent_size, 1, 1) noise to images.

    Returns:
        Tuple `(loss, real_score, fake_score)` of Python floats: the summed
        real + fake binary cross-entropy, and the mean discriminator output
        on the real and on the generated batch.
    """
    opt_d.zero_grad()
    # Create targets and noise on the same device as the incoming batch.
    target_device = real_images.device

    # Real images should be scored as 1.
    real_preds = discriminator(real_images)
    real_targets = torch.ones(real_images.size(0), 1, device=target_device)
    real_loss = F.binary_cross_entropy(real_preds, real_targets)
    real_score = torch.mean(real_preds).item()

    # The generator is not updated in this step, so produce the fake batch
    # without autograd: this avoids back-propagating through the generator
    # and leaving unused gradients on its parameters.
    latent = torch.randn(batch_size, latent_size, 1, 1, device=target_device)
    with torch.no_grad():
        fake_images = generator(latent)

    # Generated images should be scored as 0.
    fake_targets = torch.zeros(fake_images.size(0), 1, device=target_device)
    fake_preds = discriminator(fake_images)
    fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
    fake_score = torch.mean(fake_preds).item()

    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()
    return loss.item(), real_score, fake_score

In [None]:
def train_generator(opt_g, discriminator, generator):
    """Run one generator optimization step and return its loss as a float.

    A batch of `batch_size` images is generated from random noise, scored by
    the discriminator, and the generator is updated so those scores move
    towards 1.
    """
    opt_g.zero_grad()

    noise = torch.randn(batch_size, latent_size, 1, 1, device=device)
    verdicts = discriminator(generator(noise))

    # Target of 1 for every generated image: the loss is low when the
    # discriminator scores the fakes as real.
    wanted = torch.ones(batch_size, 1, device=device)
    g_loss = F.binary_cross_entropy(verdicts, wanted)

    g_loss.backward()
    opt_g.step()

    return g_loss.item()

In [None]:
# Create the output folders (no error if they already exist): `generated/*`
# receives the final images, `epochs/*` the per-epoch sample grids.
for _folder in ('generated/normal', 'generated/pneumonia', 'epochs/normal', 'epochs/pneumonia'):
    os.makedirs(_folder, exist_ok=True)

In [None]:
def save_samples(index, latent_tensors, generator, path, show=True):
    """Generate images from `latent_tensors`, save them as one grid, optionally show it.

    Args:
        index: integer used in the file name, zero-padded to five digits.
        latent_tensors: batch of latent inputs passed to `generator`.
        generator: model producing images normalized with `stats`.
        path: directory in which `xray-<index>.png` is written.
        show: when true, also draw the grid with matplotlib.
    """
    # Inference only: no autograd graph is needed for the generated images.
    with torch.no_grad():
        imgs = denorm(generator(latent_tensors))

    file_name = 'xray-{0:0=5d}.png'.format(index)
    save_image(imgs, os.path.join(path, file_name), nrow=images_row_count)

    if show:
        fig, ax = plt.subplots(figsize=(16, 16))
        ax.set_xticks([]); ax.set_yticks([])
        # Show the denormalized images so the preview matches the saved file
        # (and show_images) instead of plotting raw values in [-1, 1].
        ax.imshow(make_grid(imgs.cpu(), nrow=images_row_count).permute(1, 2, 0))

In [None]:
# One fixed batch of latent tensors, reused for every saved sample grid.
fixed_latent = torch.randn(images_count, latent_size, 1, 1, device=device)

In [None]:
# Save (and show) sample grid number 0 for each generator.
save_samples(0, fixed_latent, generator_norm, 'epochs/normal')
save_samples(0, fixed_latent, generator_pneum, 'epochs/pneumonia')

In [None]:
# Training hyper-parameters passed to fit() below.
lr_g = 0.0005   # generator learning rate (10x the discriminator's)
lr_d = 0.00005  # discriminator learning rate
epochs = 500    # number of passes over the training loader

#### Setup Weights and Biases To Analyze the Model in Real-time

In [None]:
! pip install wandb

In [None]:
!wandb login
import wandb

In [None]:
def fit(epochs, lr_g, lr_d, discriminator, generator, path, train_dl, start_idx=1):
    """Train a generator/discriminator pair and log progress to Weights & Biases.

    Args:
        epochs: number of passes over `train_dl`.
        lr_g: Adam learning rate for the generator.
        lr_d: Adam learning rate for the discriminator.
        discriminator: discriminator model, updated in place.
        generator: generator model, updated in place.
        path: sub-folder of 'epochs/' that receives the per-epoch sample
            grids; also used as the suffix of the W&B project name.
        train_dl: iterable yielding `(images, labels)` batches; labels are ignored.
        start_idx: offset added to the epoch number in sample file names.

    Returns:
        Tuple `(losses_g, losses_d, real_scores, fake_scores, opt_d, opt_g)`.
        Each list holds one value per epoch, taken from that epoch's last batch.
    """
    torch.cuda.empty_cache()

    # Losses & scores
    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []

    # Create optimizers
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.5, 0.999))
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(0.5, 0.999))

    config_defaults = {
        'epochs': epochs,
        'batch_size': batch_size,
        'learning_rate': lr_g,
        'optimizer': 'adam',
        'fc_layer_size': 128,
    }

    wandb.init(project='xray-data-{}'.format(path), config=config_defaults)
    try:
        wandb.watch(generator)
    except Exception:
        # Watching the model is best-effort. Catch Exception rather than
        # everything so Ctrl-C / SystemExit are not swallowed here.
        print("Error watching model, ignoring.")

    config = wandb.config
    config.learning_rate = lr_g

    for epoch in range(epochs):
        for real_images, _ in tqdm(train_dl):
            # Train discriminator
            loss_d, real_score, fake_score = train_discriminator(real_images, opt_d, discriminator, generator)
            # Train generator
            loss_g = train_generator(opt_g, discriminator, generator)

        # Record losses & scores (values from the last batch of the epoch)
        losses_g.append(loss_g)
        losses_d.append(loss_d)
        real_scores.append(real_score)
        fake_scores.append(fake_score)

        # Log losses & scores (last batch)
        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
            epoch+1, epochs, loss_g, loss_d, real_score, fake_score))
        save_samples(epoch+start_idx, fixed_latent, generator, 'epochs/' + path, show=False)

        # The example images are only logged, never back-propagated through,
        # so generate them without building an autograd graph.
        with torch.no_grad():
            imgs = generator(fixed_latent)
        wandb.log({
            "generator loss": loss_g,
            "discriminator loss": loss_d,
            "real score": real_score,
            "fake score": fake_score,
            "examples": [wandb.Image(i) for i in imgs],
        }, step=epoch)

    return losses_g, losses_d, real_scores, fake_scores, opt_d, opt_g

In [None]:
# Train the normal-image GAN; the returned tuple is unpacked further below.
norm_history = fit(epochs, lr_g, lr_d, discriminator_norm, generator_norm, 'normal', train_dl_norm)

In [None]:
# Train the pneumonia-image GAN; the returned tuple is unpacked further below.
pneum_history = fit(epochs, lr_g, lr_d, discriminator_pneum, generator_pneum, 'pneumonia', train_dl_pneum)

In [None]:
os.makedirs('models/', exist_ok=True)
# Write each model's state dict to its own file under models/.
for _file_name, _model in (
    ('models/discriminator_norm.pt', discriminator_norm),
    ('models/discriminator_pnuem.pt', discriminator_pneum),
    ('models/generator_norm.pt', generator_norm),
    ('models/generator_pnuem.pt', generator_pneum),
):
    torch.save(_model.state_dict(), _file_name)

In [None]:
# Unpack the six values returned by each fit() call: four per-epoch lists
# followed by the discriminator and generator optimizers.
losses_g_norm, losses_d_norm, real_scores_norm, fake_scores_norm, opt_d_norm, opt_g_norm = norm_history
losses_g_pneum, losses_d_pneum, real_scores_pneum, fake_scores_pneum, opt_d_pneum, opt_g_pneum = pneum_history

In [None]:
# Save a single checkpoint with every model and optimizer state dict plus the
# number of epochs trained. The optimizers are the `opt_*_pneum` variables
# unpacked from `pneum_history`; the 'pnuem' spelling survives only in the
# dictionary keys, which are kept unchanged for anything that reads this file.
torch.save({
    'discriminator_norm_state_dict': discriminator_norm.state_dict(),
    'discriminator_pnuem_state_dict': discriminator_pneum.state_dict(),
    'generator_norm_state_dict': generator_norm.state_dict(),
    'generator_pnuem_state_dict': generator_pneum.state_dict(),
    'generator_optim_norm_state_dict': opt_g_norm.state_dict(),
    'generator_optim_pnuem_state_dict': opt_g_pneum.state_dict(),
    'discriminator_optim_norm_state_dict': opt_d_norm.state_dict(),
    'discriminator_optim_pnuem_state_dict': opt_d_pneum.state_dict(),
    # The key must be the string 'epoch'; a bare `epoch` is an undefined name here.
    'epoch': epochs,
}, 'models/gan.pt')

In [None]:
# Delete the artifacts folder, then record its path.
# NOTE(review): the path is written out twice (shell command and variable);
# keep the two in sync if it changes.
! rm -rf 'drive/MyDrive/chest_xrays/artifacts/'
artifacts_dir = 'drive/MyDrive/chest_xrays/artifacts/'

In [None]:
# Create the artifacts folder (no error if it already exists).
os.makedirs(artifacts_dir, exist_ok=True)

In [None]:
! cp -r epochs/ drive/MyDrive/chest_xrays/artifacts/epochs/
! cp -r models/ drive/MyDrive/chest_xrays/artifacts/models/

In [None]:
def generate_images(num, generator, path, name):
    """Generate `num` images and save each one as its own PNG file in `path`.

    Files are named `generated-xray-<NNNNN>-<name>.png`, numbered from 1.

    Args:
        num: number of images to write.
        generator: model mapping (N, latent_size, 1, 1) noise to images.
        path: existing directory that receives the files.
        name: label embedded in every file name.
    """
    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        for start in range(0, num, batch_size):
            # A full batch is always sampled, as before, so the generator sees
            # the same batch size on every call. Every image of the batch is
            # kept rather than just the first, so only ceil(num / batch_size)
            # forward passes are needed instead of `num`.
            random_latent_tensors = torch.randn(batch_size, latent_size, 1, 1, device=device)
            imgs = denorm(generator(random_latent_tensors).cpu())

            keep = min(batch_size, num - start)
            for offset in range(keep):
                count = start + offset + 1
                file_name = 'generated-xray-{0:0=5d}-{1}.png'.format(count, name)
                # Slice instead of index so save_image still gets a 4-D batch of one.
                save_image(imgs[offset:offset + 1], os.path.join(path, file_name), nrow=1)
                print('saving {}/{}'.format(count, num))

In [None]:
# Write 5000 images from the normal generator.
generate_images(5000, generator_norm, 'generated/normal/', 'normal')

In [None]:
# Write 5000 images from the pneumonia generator.
generate_images(5000, generator_pneum, 'generated/pneumonia/', 'pneumonia')

In [None]:
! cp -r generated/ drive/MyDrive/chest_xrays/artifacts/generated/

In [None]:
# Sample 128 fresh latent tensors, run both generators on the same batch, and
# display a single image from each (nmax=1).
random_latent_tensors = torch.randn(128, latent_size, 1, 1, device=device)
imgs_norm = generator_norm(random_latent_tensors)
imgs_pneum = generator_pneum(random_latent_tensors)

show_images(imgs_norm, 1)
show_images(imgs_pneum, 1)
