## GAN Development

In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing on the first `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Set to True to run the optional `if test:` smoke-test cells further down.
test = False

In [2]:
# sample latent noise vectors from a standard normal distribution
def get_noise(n_samples, z_dim, device, im_chan=1):
    """Return an (n_samples, z_dim * im_chan) tensor of N(0, 1) noise on `device`."""
    latent_width = z_dim * im_chan
    noise_shape = (n_samples, latent_width)
    return torch.randn(noise_shape, device=device)

Generator steps:
- Create noise vector of shape: (# samples, # elements)
- Reshape to (# samples, # elements, 1, 1); i.e. each element in the noise vector is converted into a 1x1 matrix
- Transpose convolve the 1x1 matrix to desired shape

In [3]:
def show_scan_slices(sample_element, scan_type='', x=100, cmap='gray'):
    """Plot three orthogonal slices of a 3D volume side by side.

    Args:
        sample_element: volume indexable along three axes (it is indexed as
            ``[x]``, ``[:, x]`` and ``[:, :, x]``).
        scan_type: label printed before the plot.
        x: index used for the slice along every axis.
        cmap: matplotlib colour map passed to ``imshow``.
    """
    print("{} Scan".format(scan_type))
    # One figure holds all three panels; an extra plt.figure() call would only
    # leave an empty figure behind.
    fig, axarr = plt.subplots(1, 3)
    # Use a figure-level title: plt.title() targets the last subplot and would
    # be overwritten by that subplot's own title below.
    fig.suptitle("Brain Scan Slices")

    panels = (
        (sample_element[x], "Slice with fixed X"),
        (sample_element[:, x], "Slice with fixed Y"),
        (sample_element[:, :, x], "Slice with fixed Z"),
    )
    for ax, (image, title) in zip(axarr, panels):
        ax.imshow(image, cmap=cmap)
        ax.set_title(title)

    plt.show()

The output dimension of a transpose convolution along any axis can be found with the formula:

$X_{out}[i] = (X_{in}[i]-1)*S[i]-2*P[i]+D[i]*(K[i]-1)+1$

where S, P, D, K are the vectors representing the Stride, Padding, Dilation, and Kernel size along each axis, respectively.
For simplicity in our 3D transpose convolutions, let S=(2,2,2), P=(0,0,0), D=(1,1,1) in all layers and adjust the Kernel to reach the desired output shape.

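With this simplification each layer reduces to $X_{out}[i] = 2(X_{in}[i]-1) + K[i]$, so the output shape over n composed transpose convolutions is greatly reduced in complexity. Writing a single layer as

$f(X_{in}[i]) = X_{out}[i]$

and starting from a 1x1x1 input ($X_{in}[i] = 1$), n composed layers with kernels $K_1, \dots, K_n$ give

$f^n(1) = \sum_{j=1}^{n-1} 2^{n-j}(K_j[i]-1) + K_n[i]$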

With this one of the possible arrangements of the kernels $K_j$ to reach the target shape is found in the following cell.

In [4]:
def compute_new_size(input_size, kernel, stride=2, padding=0, dilation=1):
    """Return the output length of a transpose convolution along one axis.

    Evaluates (input_size - 1) * stride - 2 * padding
    + dilation * (kernel - 1) + 1.
    """
    upsampled = stride * (input_size - 1)
    kernel_extent = dilation * (kernel - 1) + 1
    return upsampled - 2 * padding + kernel_extent



def get_dim_s2(k):
    """Return the output length after chaining stride-2 transpose convolutions.

    Args:
        k: sequence of kernel sizes, one per layer, in application order.

    Returns:
        sum over every layer but the last of 2**(n - i) * (k[i] - 1), where
        n = len(k) - 1, plus the last kernel size k[-1].
    """
    last = len(k) - 1
    # Use the builtin sum() on a generator instead of a local named `sum`,
    # which would shadow the builtin.
    scaled = sum(2 ** (last - i) * (size - 1) for i, size in enumerate(k[:-1]))
    return scaled + k[-1]

print(get_dim_s2([6,5,5,7,5,2]))


# verify solution: push a length-1 input through every layer in turn
x0, y0 = 1, 1
x_kernels = [6, 5, 5, 7, 5, 2]
y_kernels = [5, 3, 3, 5, 5, 3]
x6, y6 = x0, y0
for kx, ky in zip(x_kernels, y_kernels):
    x6, y6 = compute_new_size(x6, kx), compute_new_size(y6, ky)
print("Output dim: ({},{},{})".format(x6,x6,y6))

290
Output dim: (290,290,203)


In [None]:
# define generator
class Generator(nn.Module):
    """Generator built from six stride-2 ConvTranspose3d blocks.

    The noise vector is viewed as (batch, z_dim, 1, 1, 1) and passed through
    the blocks in `self.gen`; the last block ends in Tanh.
    """

    def __init__(self, z_dim=10, im_chan=1, hidden_dim=32):
        super().__init__()
        self.z_dim = z_dim
        h = hidden_dim
        # (input channels, output channels, kernel size) of each hidden block
        hidden_specs = [
            (z_dim, h * 2, (6, 6, 5)),
            (h * 2, h * 2, (5, 5, 3)),
            (h * 2, h * 4, (5, 5, 3)),
            (h * 4, h * 2, (7, 7, 5)),
            (h * 2, h, 5),
        ]
        blocks = [
            self.make_gen_block(c_in, c_out, kernel_size=kernel)
            for c_in, c_out, kernel in hidden_specs
        ]
        blocks.append(
            self.make_gen_block(h, im_chan, kernel_size=(2, 2, 3), final_layer=True)
        )
        self.gen = nn.Sequential(*blocks)

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        """Return one block: ConvTranspose3d followed by BatchNorm3d + ReLU, or by Tanh when final."""
        layers = [nn.ConvTranspose3d(input_channels, output_channels, kernel_size, stride)]
        if final_layer:
            layers.append(nn.Tanh())
        else:
            layers.append(nn.BatchNorm3d(output_channels))
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def unsqueeze_noise(self, noise):
        """View `noise` as (batch, z_dim, 1, 1, 1)."""
        batch = noise.size(0)
        return noise.view(batch, self.z_dim, 1, 1, 1)

    def forward(self, noise):
        """Reshape the noise and run it through the transpose-convolution stack."""
        return self.gen(self.unsqueeze_noise(noise))



In [None]:
# test generator
if test:
    with torch.no_grad():
        # Follow the module-level `device` instead of hard-coding CUDA so the
        # model and the noise are always on the same device.
        gen = Generator(z_dim=100, im_chan=1).to(device)
        z = gen.unsqueeze_noise(get_noise(2, 100, device))
        print(z.shape)
        y = gen(z)
        print(y.shape)
        show_scan_slices(y[0][0].cpu())


#### Discriminator Steps:
* Input batch of images with shape (# samples, 1, 290, 290, 203) -- outputs from the generator and real image batches
* 3D convolve samples
* Target size for convolutions: (# samples, 1)

This process can be done with the sequence of Kernel vectors from the generator in reverse.


In [None]:
# define discriminator
class Discriminator(nn.Module):
    """Critic built from six stride-2 Conv3d blocks.

    The kernel sizes are the generator's in reverse order; the final block is
    a bare convolution with one output channel.

    NOTE(review): the hidden blocks use BatchNorm3d while training applies a
    per-sample gradient penalty; the WGAN-GP authors advise against batch
    normalisation in the critic -- confirm this is intended.
    """

    def __init__(self, im_chan=1, hidden_dim=8):
        super().__init__()
        h = hidden_dim
        # (input channels, output channels, kernel size) of each hidden block
        hidden_specs = [
            (im_chan, h, (2, 2, 3)),
            (h, h * 2, 5),
            (h * 2, h * 4, (7, 7, 5)),
            (h * 4, h * 2, (5, 5, 3)),
            (h * 2, h * 2, (5, 5, 3)),
        ]
        blocks = [
            self.make_disc_block(c_in, c_out, kernel_size=kernel)
            for c_in, c_out, kernel in hidden_specs
        ]
        blocks.append(
            self.make_disc_block(h * 2, 1, kernel_size=(6, 6, 5), final_layer=True)
        )
        self.disc = nn.Sequential(*blocks)

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        """Return one block: Conv3d, followed by BatchNorm3d + LeakyReLU(0.2) unless final."""
        layers = [nn.Conv3d(input_channels, output_channels, kernel_size, stride)]
        if not final_layer:
            layers.append(nn.BatchNorm3d(output_channels))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)

    def forward(self, image):
        """Score `image` and flatten the result to (batch, -1)."""
        scores = self.disc(image)
        return scores.view(scores.size(0), -1)

In [None]:
if test:
    # Follow the module-level `device` instead of hard-coding CUDA.
    # `y` is the generator output produced by the generator smoke test.
    disc = Discriminator(im_chan=1).to(device)
    pred = disc(y)
    print(pred.shape)

In [None]:
# define hyperparameters
z_dim = 256  # length of the latent noise vector fed to the generator
display_step = 5  # print losses and plot sample slices every this many steps
batch_size = 2  # batch size handed to the DataLoader
lr = 0.002  # learning rate shared by both Adam optimizers
epochs = 10  # number of passes over the dataloader
c_lambda = 10 #coefficient of gradient penalty
disc_repeats = 5  # discriminator updates per batch (reassigned inside the training loop)
gen_repeats = 5  # generator updates per batch (reassigned inside the training loop)

# optimizer momentum parameters
beta_1 = 0.5 
beta_2 = 0.999


# initialize gen, disc, and optimizer 
# each network gets its own Adam optimizer with the same lr and betas
gen = Generator(z_dim, im_chan=1, hidden_dim=16).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc = Discriminator(im_chan=1, hidden_dim=8).to(device) 
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))

# initialize network weights -- W ~ N(0,.02^2)
def weights_init(m):
    """Re-initialise module `m` in place; intended for use with `Module.apply`.

    Conv3d / ConvTranspose3d weights are drawn from N(0, 0.02^2). BatchNorm3d
    weights are drawn from N(0, 0.02^2) and its bias is set to 0. Every other
    module type is left untouched.

    NOTE(review): the PyTorch DCGAN tutorial centres BatchNorm scales on 1.0
    rather than 0.0 -- confirm the near-zero scale is intended.
    """
    conv_types = (nn.Conv3d, nn.ConvTranspose3d)
    if isinstance(m, conv_types):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    if isinstance(m, nn.BatchNorm3d):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        nn.init.constant_(m.bias, val=0)
# run the initialiser over both networks; apply() returns the module itself
gen, disc = gen.apply(weights_init), disc.apply(weights_init)

The loss of this model will be defined as:

$\underset{g}{min}$ $\underset{d}{max}$ $E[d(x)]-E[d(g(z))]-\lambda(||\nabla d(\hat{x})||_2 -1)^2$

$\text{where g is the generator and d is the discriminator}$

This loss is known as Wasserstein loss with gradient penalty:
* The generator is motivated to maximize the degree to which the discriminator believes generated samples are real
    * $\underset{g}{min}$ $-E[d(g(z))]$
* The discriminator is motivated to maximize the degree to which real images are believed to be real and is punished for believing generated images to be real
    * $\underset{d}{max}$ $E[d(x)]-E[d(g(z))]$
* The discriminator is also punished when the L2 norm of its gradient deviates from 1 (the W-loss requires a 1-Lipschitz discriminator)
    * The gradient is computed at an intermediate point between real $x$ and fake $g(z)$ known as $\hat{x}$ (approximated by tensor interpolation)
    * $\underset{d}{min}$ $\lambda(||\nabla d(\hat{x})||_2 -1)^2$


In [None]:

# poll gradient with interpolated sample
def get_gradient(disc, real, fake, epsilon):
    """Return the gradient of the discriminator scores w.r.t. a real/fake mix.

    Args:
        disc: callable mapping a batch of images to a tensor of scores.
        real: batch of real images.
        fake: batch of generated images, broadcast-compatible with `real`.
        epsilon: interpolation weights, broadcast-compatible with `real`.

    Returns:
        Tensor shaped like the mixed batch holding d(scores)/d(mixed). The
        graph is kept (create_graph=True) so a penalty computed from it can be
        back-propagated.
    """
    # interpolate real and fake samples
    mixed_images = real * epsilon + fake * (1 - epsilon)
    # autograd.grad needs the differentiated tensor to require grad. When none
    # of real/fake/epsilon does, the mix is a plain leaf tensor, so enable
    # tracking here instead of relying on the caller to do it via epsilon.
    if not mixed_images.requires_grad:
        mixed_images.requires_grad_(True)
    mixed_scores = disc(mixed_images)

    gradient = torch.autograd.grad(
        inputs=mixed_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient

# compute the gradient penalty from per-sample L2 gradient norms
def gradient_penalty(gradient):
    """Return mean((||g_i||_2 - 1)^2) over the samples in `gradient`.

    Args:
        gradient: tensor whose first axis indexes samples; all remaining axes
            are flattened into one vector per sample.

    Returns:
        Scalar tensor holding the penalty.
    """
    # reshape (not view) so non-contiguous gradients are accepted as well
    flat = gradient.reshape(len(gradient), -1)
    # L2 magnitude of every row
    gradient_norm = flat.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

# define gen w-loss
def get_gen_loss(disc_fake_pred):
    """Return the generator loss: the negated mean score of the generated batch."""
    return disc_fake_pred.mean().neg()

# define disc w-loss (with gradient penalty)
def get_disc_loss(disc_fake_pred, disc_real_pred, gp, c_lambda):
    """Return mean(fake scores) + c_lambda * mean(gp) - mean(real scores)."""
    fake_score = disc_fake_pred.mean()
    weighted_penalty = gp.mean() * c_lambda
    real_score = disc_real_pred.mean()
    return fake_score + weighted_penalty - real_score

In [None]:
from torch.utils.data import Dataset
import nibabel as nib
import numpy as np
import os    

# define brain dataset
class BrainDataset(Dataset):
    """Dataset of brain scans loaded with nibabel from one directory.

    Files are read from ``<parent of cwd>/research_dataset/ds/<scan_type>``.
    """

    def __init__(self, scan_type="T2w"):
        self.scan_type = scan_type
        self.path = os.path.join(os.path.dirname(os.getcwd()), "research_dataset", "ds", self.scan_type)
        # os.listdir returns entries in arbitrary order; sort so that an index
        # always maps to the same file across runs and machines.
        self.file_names = sorted(os.listdir(self.path))
        self.length = len(self.file_names)

    def reshape_samples(self, x):
        """Crop the last axis and zero-pad the first axis of an irregular scan.

        The last axis is cropped to ``[67:-20]`` and 36 + 37 zero slabs of
        shape (290, 203) are added before/after along axis 0.
        Assumes the input is (217, 290, 290) so the result is (290, 290, 203)
        -- TODO confirm against the dataset.
        """
        x = x[:, :, 67:-20]  # slice z down to 203
        p1, p2 = np.zeros((36, 290, 203)), np.zeros((37, 290, 203))
        # pad x up to 290 by concatenating zero slabs on both sides
        x = np.concatenate((p1, x, p2), axis=0)
        return x

    def path_to_array(self, path):
        """Load the file at `path` with nibabel and return its data as a numpy array."""
        image_array = nib.load(path)
        image_array = image_array.get_fdata()
        return np.array(image_array)

    # get scan as numpy array -- reshape irregular images according to preprocess.ipynb
    def __getitem__(self, index):
        """Return the scan at `index`, reshaped if irregular and then rescaled."""
        item_path = os.path.join(self.path, self.file_names[index])
        item = self.path_to_array(item_path)
        if item.shape != (290, 290, 203):
            item = self.reshape_samples(item)
        # presumably per-modality intensity scales -- verify against the data
        if self.scan_type == "T2w":
            item = item / 88
        else:
            item = item / 3
        return item

    def __len__(self):
        """Return the number of files found in the dataset directory."""
        return self.length

# create dataloader: shuffled batches over the default (T2w) brain dataset
brain_scans = BrainDataset()
dataloader = DataLoader(dataset=brain_scans, batch_size=batch_size, shuffle=True)

In [None]:
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0

# training loop
for epoch in range(epochs):
    for batch in tqdm(dataloader):
        cur_batch_size = len(batch)
        # add the channel axis expected by the 3D convolutions
        real = batch.view((cur_batch_size, 1, 290, 290, 203)).to(device, dtype=torch.float)

        for _ in range(disc_repeats):
            ## Update discriminator
            disc_opt.zero_grad()
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            # The generator is not updated in this step, so skip building its
            # autograd graph; the resulting tensor is already detached.
            with torch.no_grad():
                fake = gen(fake_noise)
            disc_fake_pred = disc(fake)
            disc_real_pred = disc(real)

            # interpolate real and fake with random factor for GP
            epsilon = torch.rand(len(real), 1, 1, 1, 1, device=device, requires_grad=True)
            grad = get_gradient(disc, real, fake, epsilon)
            gp = gradient_penalty(grad)
            disc_loss = get_disc_loss(disc_fake_pred, disc_real_pred, gp, c_lambda)

            # Every iteration builds a fresh graph, so nothing has to be
            # retained after the backward pass.
            disc_loss.backward()
            disc_opt.step()

        ## Update generator
        for _ in range(gen_repeats):
            gen_opt.zero_grad()
            fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
            fake_2 = gen(fake_noise_2)
            disc_fake_pred = disc(fake_2)

            gen_loss = get_gen_loss(disc_fake_pred)

            gen_loss.backward()
            gen_opt.step()

        # Keep track of the average disc and gen loss
        # NOTE(review): these accumulators are never reset, so they are running
        # sums (scaled by 1/display_step) over the whole run rather than
        # per-window means -- confirm this is intended.
        disc_loss_v, gen_loss_v = disc_loss.item(), gen_loss.item()
        mean_discriminator_loss += disc_loss_v / display_step
        mean_generator_loss += gen_loss_v / display_step

        # dynamically adjust train cycles to gen or disc domination
        if mean_discriminator_loss > mean_generator_loss:
            disc_repeats = 5
            gen_repeats = 1
        else:
            disc_repeats = 1
            gen_repeats = 5

        if cur_step % display_step == 0:
            with torch.no_grad():
                # each label is paired with its own loss value
                print(f"Step {cur_step}: Generator loss: {gen_loss_v}, discriminator loss: {disc_loss_v}")
                fake_noise = get_noise(1, z_dim, device=device)
                fake = gen(fake_noise)[0][0].cpu()
                show_scan_slices(fake, scan_type="T2w")
                show_scan_slices(real[0][0].cpu(), scan_type="T2w")
        cur_step += 1
