In [None]:
import re
import sys
sys.path.append('/tf/data')

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
import torchvision.datasets as dset

import numpy as np
import imageio
import matplotlib.pyplot as plt
import tensorflow_docs.vis.embed as embed

from tqdm import tqdm
from pathlib import Path

from MedImageGanFiles.utils import to_gpu, loss_plot, image_grid
from MedImageGanFiles.metrics import FID_plot
from MedImageGanFiles.dcgan import weights_init, Generator, Discriminator_SN, training_loop

In [None]:
# Run configuration: every path and hyper-parameter used by the cells below.
params = {
    'path': '/tf/data/MedImageGanModels/netG_lr_100_2.zip',   # generator state_dict loaded before training
    'images_root_path': '/tf/data/augmented_1slice_64_3ch/',  # DatasetFolder root holding .npy samples
    'out': '/tf/data/MedImageGanImages/',                      # output root; the run name is appended to it
    'run': 'Single_slice_lr_100_b_128_smooth',                 # per-run sub-folder for images and checkpoints
    'model_save_path': '/tf/data/MedImageGanModels_MyModels/',
    'model_save_freq': 500,  # forwarded to training_loop; unit (iterations vs epochs) is defined there — TODO confirm

    # NOTE(review): in this notebook image_size is only used as the NUMBER of latent vectors
    # drawn for previews / fixed noise, not as an image dimension.
    'image_size': 64,

    'seed': 42,
    'n_gpu': 1,

    'num_epochs': 15,
    'learning_rate': 0.002,
    'beta_adam': 0.5,      # Adam beta1; beta2 is fixed at 0.999 where the optimizers are built
    'batch_size': 128,
    'label_smooth': True,  # forwarded to training_loop as `smooth`

    'latent_vector': 256,  # length of the generator's input noise vector (nz)

    'loader_workers': 2,
    'number_channels': 3,
    'gen_feature_maps': 64,
    'dis_feature_maps': 64
    }
nz = params['latent_vector']

# Apply the configured seed so torch-based randomness (noise draws, DataLoader shuffle
# order, torch-based weight init) and NumPy repeat across runs. torch.manual_seed seeds
# the CPU and all CUDA generators; cuDNN kernels may still be non-deterministic on GPU.
torch.manual_seed(params['seed'])
np.random.seed(params['seed'])

if torch.cuda.is_available():
    device = to_gpu(ngpu=params['n_gpu'])
    print('Cuda available')
else:
    device = torch.device('cpu')
    print('Cuda NOT available')

In [None]:
def show_examples(netG, noise=None, n_show=8):
    """Plot the first ``n_show`` samples produced by ``netG`` in a single row.

    Args:
        netG: generator module; called on a latent batch of shape
            (N, params['latent_vector'], 1, 1).
        noise: optional latent batch to reuse. Passing the same tensor to several
            calls (e.g. before and after loading weights) makes the previews
            directly comparable. When None, ``params['image_size']`` fresh vectors
            are drawn on ``device``, exactly as the original preview did.
        n_show: number of samples from the batch to display (default 8).

    NOTE(review): if netG contains BatchNorm layers and is in train mode, this
    forward pass still updates their running statistics even under no_grad.
    """
    if noise is None:
        noise = torch.randn(params['image_size'], params['latent_vector'], 1, 1,
                            device=device)
    with torch.no_grad():
        # Outputs computed under no_grad carry no autograd history, so no .detach() is needed.
        samples = netG(noise).cpu()
    grid = vutils.make_grid(samples[:n_show], padding=2, normalize=True)
    fig, ax = plt.subplots(figsize=(16, 2))
    ax.imshow(grid.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
    ax.axis("off")
    plt.show()

In [None]:
def dataLoader(path, batch_size, workers):
    """Build a shuffling DataLoader over ``.npy`` samples under ``path``.

    ``path`` is the root handed to ``torchvision.datasets.DatasetFolder``, so each
    sub-directory is treated as one class and items come back as
    ``(sample, class_index)``. Each ``.npy`` file is read with ``np.load`` and its
    axes are permuted with ``(2, 0, 1)`` (last axis moved first).

    Args:
        path: dataset root directory.
        batch_size: samples per batch.
        workers: number of DataLoader worker processes.
    """
    def load_npy_as_chw(file_path):
        # assumes arrays are stored channels-last (H, W, C) — TODO confirm against the export.
        # NOTE(review): a nested function is not picklable, so workers > 0 only works with the
        # 'fork' multiprocessing start method.
        return torch.from_numpy(np.load(file_path)).permute(2, 0, 1)

    folder = dset.DatasetFolder(root=path, loader=load_npy_as_chw, extensions=['.npy'])
    return torch.utils.data.DataLoader(
        folder,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
    )

dataloader = dataLoader(
    path=params['images_root_path'],
    batch_size=params['batch_size'],
    workers=params['loader_workers'],
)

In [None]:
# Create models: build G and D, initialise both, then overwrite G with the pretrained weights.
netG = Generator(ngpu=params['n_gpu'], nz=params['latent_vector'],
                    ngf=params['gen_feature_maps'], nc=params['number_channels']).to(device)

netD = Discriminator_SN(params['n_gpu'], nc=params['number_channels'],
                        ndf=params['dis_feature_maps']).to(device)

netG.apply(weights_init)
netD.apply(weights_init)
print('Untrained Generator')
show_examples(netG)
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint on a
# CPU-only machine) onto `device` instead of failing to deserialize.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
netG.load_state_dict(torch.load(params['path'], map_location=device))
print('Loaded Generator')
show_examples(netG)
print(netG)

In [None]:
# Loss handed to training_loop (binary cross-entropy on discriminator probabilities).
criterion = nn.BCELoss()

# Fixed latent batch passed to training_loop as `fixed_noise`, so progress snapshots
# are generated from the same inputs. Its size comes from params['image_size'].
fixed_noise = torch.randn(
    params['image_size'], params['latent_vector'], 1, 1,
    device=device,
)

# Both nets use Adam with the configured beta1 and beta2 = 0.999.
optimizerD = optim.Adam(
    netD.parameters(),
    lr=params['learning_rate'],
    betas=(params['beta_adam'], 0.999),
)
optimizerG = optim.Adam(
    netG.parameters(),
    lr=params['learning_rate'],
    betas=(params['beta_adam'], 0.999),
)

G_losses, D_losses, img_list, img_list_only = training_loop(
    num_epochs=params['num_epochs'],
    dataloader=dataloader,
    netG=netG,
    netD=netD,
    device=device,
    criterion=criterion,
    nz=params['latent_vector'],
    smooth=params['label_smooth'],
    model_save_path=params['model_save_path'] + params['run'] + '/',
    model_save_freq=params['model_save_freq'],
    optimizerG=optimizerG,
    optimizerD=optimizerD,
    fixed_noise=fixed_noise,
    out=params['out'] + params['run'],
)

In [None]:
# Post-training diagnostics; each helper receives this run's output folder as `out`.
loss_plot(
    G_losses=G_losses,
    D_losses=D_losses,
    out=params['out'] + params['run'] + '/',
)
FID_plot(
    real_dataloader=dataloader,
    generated_img_list=img_list_only,
    out=params['out'] + params['run'] + '/',
)
image_grid(
    dataloader=dataloader,
    img_list=img_list,
    device=device,
    out=params['out'] + params['run'] + '/',
)

In [None]:
# Create GIF of generated image progression
anim_file = params['out'] + params['run'] + '/Gan.gif'
image_path = Path(params['out'] + params['run'] + '/')


def _natural_key(path):
    """Sort key that compares embedded integers numerically ('2_img' < '10_img')."""
    parts = re.split(r'(\d+)', path.name)
    # With a capturing group, re.split puts the digit runs at the odd indices,
    # so every key alternates str/int at the same positions and stays comparable.
    return [int(part) if i % 2 else part for i, part in enumerate(parts)]


# Plain lexicographic sorting would place '10_img.png' before '2_img.png' when frame
# numbers are not zero-padded; the natural key keeps snapshots in numeric order and
# matches the lexicographic order when they are padded.
filenames = sorted(image_path.glob('*img.png'), key=_natural_key)

def add_progress_bar(image, progress_percentage, bar_height=10):
    """Return ``image`` with a progress-bar strip appended underneath it.

    The strip is ``bar_height`` rows tall and as wide as the image. Its left
    ``progress_percentage`` percent is filled with 255 and the rest is 0.

    Args:
        image: array of shape (H, W) or (H, W, C); presumably uint8 as returned by
            ``imageio.imread`` for PNG files, since 255 is used as the fill value.
        progress_percentage: completion in percent; values outside [0, 100] are clamped.
        bar_height: height of the strip in rows (default 10).

    Returns:
        A new array of shape (H + bar_height, W[, C]) with the same dtype as ``image``.
        NOTE(review): for RGBA frames the unfilled part of the strip has alpha 0.
    """
    image = np.asarray(image)
    percentage = min(max(progress_percentage, 0), 100)
    filled_width = int(image.shape[1] * percentage / 100)

    # Allocate only a thin strip: an image-sized pad would double every frame's
    # height while leaving all but the top rows empty.
    bar = np.zeros((bar_height,) + image.shape[1:], dtype=image.dtype)
    bar[:, :filled_width] = 255

    # axis=0 stacks the strip vertically, below the image.
    return np.concatenate([image, bar], axis=0)

# Write every snapshot to the GIF, each repeated FRAMES_PER_IMAGE times so it stays on
# screen longer, with a bar showing how far through the sequence it is.
FRAMES_PER_IMAGE = 10

if not filenames:
    raise FileNotFoundError(f"No '*img.png' snapshots found in {image_path}; nothing to animate.")

# Context managers close both the GIF writer and the tqdm bar even if a frame fails.
with imageio.get_writer(anim_file, mode='I') as writer:
    with tqdm(total=len(filenames), desc='Creating GIF', unit='image') as progress_bar:
        for idx, filename in enumerate(filenames):
            image = imageio.imread(filename)

            progress_percentage = int((idx + 1) / len(filenames) * 100)
            image_with_progress = add_progress_bar(image, progress_percentage)

            for _ in range(FRAMES_PER_IMAGE):
                writer.append_data(image_with_progress)

            progress_bar.set_postfix({'Progress': progress_percentage})
            progress_bar.update(1)

print('Gif saved')

In [None]:
# Show gif: embed the animation file at `anim_file` in the notebook output.
embed.embed_file(anim_file)