In [None]:
import sys
sys.path.append('/tf/data')

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.utils as vutils

import numpy as np
import matplotlib.pyplot as plt

from MedImageGanFiles.utils import to_gpu, loss_plot, image_grid, create_gif, show_gif
from MedImageGanFiles.metrics import FID_plot
from MedImageGanFiles.dcgan import weights_init, Generator, training_loop, Discriminator_SN_wide3

In [None]:
# Run configuration: every knob a reader might tune lives in this dict.
params = {
    'path': '/tf/data/MedImageGanModels/Single slice simple gan.zip',
    'images_root_path': '/tf/data/augmented_64_3ch/',
    'out': '/tf/data/MedImageGanImages/',
    'run': 'MIG3_Wide_img_lr_0_0001_b_16_smooth_5_B_0_9',
    'model_save_path': '/tf/data/MedImageGanModels_MyModels/',
    'model_save_freq': 500,
    'save_discr': True,

    'image_size': 64,
    'wide_images': True,

    'seed': 42,
    'n_gpu': 1,

    'num_epochs': 30,
    'learning_rate': 0.0001,
    'beta_adam': 0.9,
    'batch_size': 16,
    'label_smooth': True,

    'latent_vector': 256,

    'loader_workers': 2,
    'number_channels': 3,
    'gen_feature_maps': 64,
    'dis_feature_maps': 64
    }

# Apply the configured seed to the RNGs this notebook draws from (weight init,
# noise vectors, DataLoader shuffling) so Restart & Run All is reproducible.
# torch.manual_seed also seeds every CUDA device.
torch.manual_seed(params['seed'])
np.random.seed(params['seed'])

nz = params['latent_vector']
# Per-run output directory for images/plots.
img_path = f"{params['out']}{params['run']}/"

if torch.cuda.is_available():
    device = to_gpu(ngpu=params['n_gpu'])
    print('Cuda available')
else:
    device = torch.device('cpu')
    print('Cuda NOT available')

In [None]:
def show_examples(netG, wide=False, n_samples=None):
    """Plot a few images sampled from ``netG`` using fresh random noise.

    Args:
        netG: Generator to sample from. It is fed noise of shape
            ``(n_samples, params['latent_vector'], 1, 1)`` created on ``device``.
        wide: If False, lay the samples out in grid rows of 8. If True, stack
            them vertically one per row (``nrow=1``), which suits the wide
            generator output.
        n_samples: Number of images to draw. Defaults to 8, or 3 when ``wide``.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.

    Uses the notebook globals ``params`` and ``device``.

    NOTE(review): ``netG`` is called in whatever mode it is currently in; in
    train mode any BatchNorm layers still update their running statistics even
    under ``torch.no_grad()``.
    """
    if n_samples is None:
        n_samples = 3 if wide else 8
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    # Sample only as many images as will actually be displayed.
    noise = torch.randn(n_samples, params['latent_vector'], 1, 1, device=device)
    with torch.no_grad():
        fake = netG(noise).detach().cpu()

    grid = vutils.make_grid(fake, padding=2, normalize=True, nrow=1 if wide else 8)
    plt.figure(figsize=(16, 2))
    plt.imshow(np.transpose(grid, (1, 2, 0)))
    plt.axis("off")
    plt.show()


In [None]:
def _npy_loader(file_path):
    """Load one ``.npy`` sample as a tensor with its last axis moved to the front.

    NOTE(review): assumes each file is stored as (H, W, C), so the result is
    (C, H, W) as ``make_grid`` and the networks expect — confirm against the
    preprocessing that produced ``images_root_path``. The dtype is whatever was
    saved; ``np.load`` does not convert it.
    """
    sample = torch.from_numpy(np.load(file_path))
    return sample.permute(2, 0, 1)


def dataLoader(path, batch_size, workers):
    """Build a shuffling DataLoader over a folder of ``.npy`` images.

    ``path`` must follow the ``DatasetFolder`` layout: one sub-directory per
    class, each containing ``.npy`` files. Each batch is ``[images, class_indices]``.

    Args:
        path: Root directory of the dataset.
        batch_size: Number of samples per batch.
        workers: Number of DataLoader worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True``.
    """
    # `extensions` is a tuple: older torchvision passes it straight to
    # `str.endswith`, which accepts a str or a tuple but not a list.
    # The loader is a top-level function (not a closure) so it can be pickled.
    dataset = dset.DatasetFolder(root=path, loader=_npy_loader, extensions=('.npy',))
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                       shuffle=True, num_workers=workers)


dataloader = dataLoader(
    path=params['images_root_path'],
    batch_size=params['batch_size'],
    workers=params['loader_workers'],
)

In [None]:
#Show examples of real data
num_images = 5
# Only the first batch is needed; each DatasetFolder batch is [images, labels].
batch = next(iter(dataloader), None)
if batch is not None:
    print(len(batch))
    input_data = batch[0]
    print("Shape of input data:", input_data.shape)
    input_data = input_data[:num_images]
    # One image per grid row so each wide strip spans the full figure width.
    grid_image = vutils.make_grid(input_data, nrow=1, padding=2, normalize=True)
    fig, ax = plt.subplots(figsize=(20, num_images))
    ax.imshow(np.transpose(grid_image, (1, 2, 0)))
    ax.axis('off')
    ax.set_title('Navel     <<          Rectum          >>       Legs')
    plt.show()

In [None]:
#Create generator
ngf = params['gen_feature_maps']
nc = params['number_channels']

netG = Generator(ngpu=params['n_gpu'], nz=params['latent_vector'],
                    ngf=ngf, nc=nc).to(device)
netG.apply(weights_init)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved on
# a GPU also loads on a CPU-only machine.
netG.load_state_dict(torch.load(params['path'], map_location=device))
print('Loaded Generator single channel')
# Shape probe only; no_grad avoids building an autograd graph.
input_tensor = torch.randn(1, params['latent_vector'], 1, 1).to(device)
with torch.no_grad():
    print('Output tensor shape:', netG(input_tensor).shape)
show_examples(netG)

#Modify generator to wide images
# Replace modules 12-13 and append 14-19 so the output becomes (nc, 64, 1472),
# i.e. 23 tiles of 64 px side by side. Assuming module 11 emits (ngf, 32, 32) as
# in a standard DCGAN generator, the (H, W) sizes are:
#   [12] 32x124 -> [15] 32x492 -> [18] 64x1472
# NOTE(review): in_channels of module 12 must equal module 11's output channels;
# using `ngf` here assumes Generator's standard layout — confirm.
netG.main[12] = nn.ConvTranspose2d(ngf, ngf, kernel_size=(3, 4), stride=(1, 4), padding=(1, 2), bias=False).apply(weights_init).to(device)
netG.main[13] = nn.BatchNorm2d(ngf).apply(weights_init).to(device)
netG.main.add_module('14', nn.ReLU(inplace=True))

netG.main.add_module('15', nn.ConvTranspose2d(in_channels=ngf, out_channels=ngf, kernel_size=(5, 4), stride=(1, 4), padding=(2, 2), bias=False).apply(weights_init).to(device))
# Initialised with weights_init like the BatchNorm at index 13.
netG.main.add_module('16', nn.BatchNorm2d(ngf).apply(weights_init).to(device))
netG.main.add_module('17', nn.ReLU(inplace=True))

netG.main.add_module('18', nn.ConvTranspose2d(in_channels=ngf, out_channels=nc, kernel_size=(4, 3), stride=(2, 3), padding=(1, 2), bias=False).apply(weights_init).to(device))
netG.main.add_module('19', nn.Tanh())

print('Loaded Generator 23 channels')
input_tensor = torch.randn(1, params['latent_vector'], 1, 1).to(device)
with torch.no_grad():
    wide_output_shape = netG(input_tensor).shape
print('Output tensor shape:', wide_output_shape)
# The discriminator and training loop expect 64 x 1472 images; fail fast otherwise.
assert tuple(wide_output_shape[1:]) == (nc, 64, 1472), \
    f"unexpected wide generator output shape {tuple(wide_output_shape)}"
show_examples(netG, wide=True)
print(netG)

In [None]:
#Create discriminator
netD = Discriminator_SN_wide3(params['n_gpu'], nc=params['number_channels'],
                        ndf=params['dis_feature_maps']).to(device)
# NOTE(review): if Discriminator_SN_wide3 wraps its convolutions with
# spectral_norm, the effective weight is recomputed from `weight_orig`, so
# weights_init may not change it — confirm against the class definition.
netD.apply(weights_init)
print(netD)

# Shape probe with one dummy image the size of the widened generator output
# (height 64, width 1472 = 23 * 64) and the configured channel count.
# no_grad avoids building an autograd graph for this check.
input_tensor = torch.randn(1, params['number_channels'], 64, 1472).to(device)
with torch.no_grad():
    print(netD(input_tensor).shape)

In [None]:

# Loss and optimizers: both networks use Adam with the configured learning rate
# and betas (beta_adam, 0.999).
criterion = nn.BCELoss()

adam_settings = {'lr': params['learning_rate'], 'betas': (params['beta_adam'], 0.999)}
optimizerD = optim.Adam(netD.parameters(), **adam_settings)
optimizerG = optim.Adam(netG.parameters(), **adam_settings)

# One fixed batch of params['image_size'] noise vectors handed to training_loop,
# presumably for comparable progress snapshots across epochs.
fixed_noise = torch.randn(params['image_size'], params['latent_vector'], 1, 1, device=device)

G_losses, D_losses, img_list, img_list_only = training_loop(
    num_epochs=params['num_epochs'],
    dataloader=dataloader,
    wide_images=params['wide_images'],
    netG=netG,
    netD=netD,
    save_discr=params['save_discr'],
    device=device,
    criterion=criterion,
    nz=params['latent_vector'],
    smooth=params['label_smooth'],
    model_save_path=f"{params['model_save_path']}{params['run']}/",
    model_save_freq=params['model_save_freq'],
    optimizerG=optimizerG,
    optimizerD=optimizerD,
    fixed_noise=fixed_noise,
    out=img_path,
)

In [None]:
# Post-training reports; each helper is given img_path as its output location.
loss_plot(G_losses=G_losses, D_losses=D_losses, out=img_path)
FID_plot(real_dataloader=dataloader,
         generated_img_list=img_list_only,
         out=img_path)
image_grid(
    dataloader=dataloader,
    img_list=img_list,
    wide_images=params['wide_images'],
    device=device,
    out=img_path,
)
gif_path = create_gif(image_path=img_path)

In [None]:
# Display the GIF whose path create_gif returned in the previous cell.
show_gif(gif_path)