# Fetch Codebase and Models

In [12]:
# Work from /content (presumably the Colab workspace root) so the repository is cloned there.
import os
os.chdir('/content')
CODE_DIR = 'Controllable_MedGAN'
# Clone the project; the next cell imports dnnlib, config and training, presumably from this checkout.
# If the directory already exists (e.g. on re-run), git clone prints an error but the `!` line does not raise.
!git clone https://github.com/WhitneyLab/Controllable_MedGAN.git $CODE_DIR
# Later cells use paths relative to the repository root (e.g. ./pretrained_models/).
os.chdir(f'./{CODE_DIR}')
# Pretrained weights. Only BCN is downloaded by this cell; the other models listed in the
# "Select a Model" form are not fetched here (uncomment a line below to pre-download one).
!wget https://www.dropbox.com/s/7d7fzxxjzi7t61m/BCN.pkl?dl=1 -O pretrained_models/BCN.pkl --quiet
# !wget https://www.dropbox.com/s/ckp4r2ubqsvqdne/CT.pkl?dl=1 -O pretrained_models/CT.pkl --quiet
# !wget https://www.dropbox.com/s/fdbj4zb4mfkyeps/HAM.pkl?dl=1 -O pretrained_models/HAM.pkl --quiet
# !wget https://www.dropbox.com/s/buxj4ypxnk8w0v6/MRI_knee.pkl?dl=1 -O pretrained_models/MRI_knee.pkl --quiet
# !wget https://www.dropbox.com/s/2q5cn45d5s0mi0a/SSIM.pkl?dl=1 -O pretrained_models/SSIM.pkl --quiet

# Define Utility Functions

In [13]:
import io
import IPython.display
import numpy as np
import cv2
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import config

import torch

from training import misc

# Keyword arguments forwarded to Gs.run: convert the network output to uint8 and
# reorder it from NCHW to NHWC; minibatch_size is presumably the per-run batch size.
synthesis_kwargs = {
    'output_transform': {'func': tflib.convert_images_to_uint8, 'nchw_to_nhwc': True},
    'minibatch_size': 8,
}
# Generators already loaded in this kernel, keyed by the path/URL they were read from.
_Gs_cache = {}

def load_Gs(url):
    """Return the Gs network stored at `url`, reading the file at most once per kernel.

    `misc.load_pkl(url)` must yield a three-item (G, D, Gs) sequence; only the
    third item is cached and returned.

    NOTE(review): misc.load_pkl presumably unpickles the file, which can execute
    arbitrary code -- only load weights from sources you trust.
    """
    if url in _Gs_cache:
        return _Gs_cache[url]
    _G, _D, Gs = misc.load_pkl(url)
    _Gs_cache[url] = Gs
    return Gs

def generate_figures(Gs, nums, seed):
    """Sample `nums` images from generator `Gs` using a reproducible latent seed.

    Args:
        Gs: Generator returned by load_Gs; must provide `input_shape` and `run`.
        nums: Number of images to generate.
        seed: Seed for the NumPy RandomState that draws the latent vectors, so the
            same (nums, seed) pair reproduces the same latents.

    Returns:
        A list of `nums` PIL images in 'RGB' mode, in latent order.
    """
    latents = np.random.RandomState(seed).randn(nums, Gs.input_shape[1])
    # synthesis_kwargs converts the output to uint8 NHWC, i.e. [sample, y, x, rgb].
    images = Gs.run(latents, None, **synthesis_kwargs)
    # Plain comprehension: the conversion is cheap, and tqdm (used here before) is
    # never imported in this notebook, so it raised NameError.
    return [PIL.Image.fromarray(image, 'RGB') for image in images]


def imshow(images, col, viz_size=256):
  """Shows images tiled into one grid figure.

  Args:
    images: A (num, height, width, channels) array, or any non-empty sequence of
      equally-sized (height, width, channels) images -- numpy arrays or PIL images,
      e.g. the list returned by generate_figures.
    col: Number of grid columns; must evenly divide the number of images.
    viz_size: Side length in pixels of each square tile; images of another size
      are resized with cv2.resize.

  Returns:
    Whatever IPython.display.display returns; the figure is shown as a side effect.

  Raises:
    ValueError: if `images` is empty or the images do not share one shape.
  """
  # Normalise the input to one ndarray so `.shape` works for lists of PIL images
  # too; a 4-D array round-trips unchanged.
  frames = [np.asarray(image) for image in images]
  if not frames:
    raise ValueError('imshow() needs at least one image')
  images = np.stack(frames)
  num, height, width, channels = images.shape
  assert num % col == 0, f'{num} images cannot fill a grid with {col} columns'
  row = num // col

  fused_image = np.zeros((viz_size * row, viz_size * col, channels), dtype=np.uint8)

  for idx, image in enumerate(images):
    i, j = divmod(idx, col)
    y = i * viz_size
    x = j * viz_size
    if height != viz_size or width != viz_size:
      image = cv2.resize(image, (viz_size, viz_size))
    fused_image[y:y + viz_size, x:x + viz_size] = image

  # fused_image was allocated as uint8, so it can be JPEG-encoded directly.
  data = io.BytesIO()
  PIL.Image.fromarray(fused_image).save(data, 'jpeg')
  disp = IPython.display.display(IPython.display.Image(data.getvalue()))
  return disp

# generate_figures(load_Gs(file_skin_HAM),20000,seed=6)

ModuleNotFoundError: ignored

# Select a Model

In [None]:
#@title { display-mode: "form", run: "auto" }
model_name = "BCN" #@param ['BCN','CT','HAM','MRI_knee','SSIM']

import os
import urllib.request

# Direct-download links for every model offered by the selector above. The setup
# cell only fetches BCN, so any other choice is downloaded here on first use instead
# of handing load_Gs a path that does not exist.
MODEL_URLS = {
    'BCN': 'https://www.dropbox.com/s/7d7fzxxjzi7t61m/BCN.pkl?dl=1',
    'CT': 'https://www.dropbox.com/s/ckp4r2ubqsvqdne/CT.pkl?dl=1',
    'HAM': 'https://www.dropbox.com/s/fdbj4zb4mfkyeps/HAM.pkl?dl=1',
    'MRI_knee': 'https://www.dropbox.com/s/buxj4ypxnk8w0v6/MRI_knee.pkl?dl=1',
    'SSIM': 'https://www.dropbox.com/s/2q5cn45d5s0mi0a/SSIM.pkl?dl=1',
}

model_parameters = "./pretrained_models/" + model_name + ".pkl"
# An empty file counts as missing: `wget -O` creates the output file even when the
# download itself fails.
if not os.path.isfile(model_parameters) or os.path.getsize(model_parameters) == 0:
    os.makedirs(os.path.dirname(model_parameters), exist_ok=True)
    # Download under a temporary name so an interrupted transfer never leaves a
    # truncated .pkl that a later run would mistake for a complete model.
    partial_path = model_parameters + ".part"
    urllib.request.urlretrieve(MODEL_URLS[model_name], partial_path)
    os.replace(partial_path, model_parameters)
Gs = load_Gs(model_parameters)

# Generate Image Samples

In [None]:
#@title { display-mode: "form", run: "auto" }

num_samples = 4 #@param {type:"slider", min:2, max:8, step:2}
noise_seed = 0 #@param {type:"slider", min:0, max:1000, step:1}

images = generate_figures(Gs,num_samples,noise_seed)
# generate_figures returns a list of PIL images, but imshow reads `.shape` from its
# input, so stack them into one (num, height, width, channels) array first.
image_batch = np.stack([np.asarray(image) for image in images])
# The slider step is 2, so num_samples is always even and this always gives two rows.
imshow(image_batch, col=num_samples // 2)