<a href="https://colab.research.google.com/github/LarsAmker/ExplainGAN/blob/master/Pneumonia_workinprogress.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import and access the data

In [0]:
# Colab library to upload files to notebook
from google.colab import files

# Install Kaggle library
!pip install -q kaggle

In [0]:
#!pwd
!mkdir -p ~/.kaggle
#!cp /content/.kaggle/kaggle.json ~/.kaggle/kaggle.json

The kaggle.json file holds the API credentials needed to download the data directly from Kaggle. Copy it over from Google Drive.
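
For readers who prefer to do this step from Python instead of shell commands, the sketch below copies a credentials file into place and restricts its permissions. The function name `install_kaggle_credentials` and its arguments are hypothetical. The only assumption about Kaggle is the one this notebook already makes, namely that the key is expected at `~/.kaggle/kaggle.json`.

```python
import os
import shutil
import stat
from pathlib import Path


def install_kaggle_credentials(source, kaggle_dir):
    """Copy a kaggle.json file into kaggle_dir and make it owner-only (mode 600)."""
    source = Path(source)
    kaggle_dir = Path(kaggle_dir)
    if not source.is_file():
        raise FileNotFoundError("credentials file not found: {}".format(source))
    # Unlike a bare `mkdir`, this does not fail when the directory already exists.
    kaggle_dir.mkdir(parents=True, exist_ok=True)
    target = kaggle_dir / "kaggle.json"
    shutil.copyfile(str(source), str(target))
    # Equivalent to `chmod 600`: read/write for the owner, nothing for anyone else.
    os.chmod(str(target), stat.S_IRUSR | stat.S_IWUSR)
    return target
```

Once Google Drive is mounted, a call could look like `install_kaggle_credentials("/content/gdrive/My Drive/kaggle.json", Path.home() / ".kaggle")`.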

In [0]:
from google.colab import drive
drive.mount('/content/gdrive')

In [0]:
# copy the json file from google drive into the kaggle directory
!cp "/content/gdrive/My Drive/kaggle.json" ~/.kaggle

In [0]:
# Download pneumonia data (the chmod makes the Kaggle API key readable and writable by its owner only)
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d paultimothymooney/chest-xray-pneumonia
# The zip file is listed in the left-hand pane after the download. We need to unzip it and create paths

In [0]:
!unzip chest-xray-pneumonia.zip

In [7]:
import os
import glob
import h5py
import shutil
import imgaug as aug
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.image as mimg
import imgaug.augmenters as iaa
from os import listdir, makedirs, getcwd, remove
from os.path import isfile, join, abspath, exists, isdir, expanduser
from PIL import Image
from pathlib import Path
from skimage.io import imread
from skimage.transform import resize
from keras.models import Sequential, Model
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.preprocessing.image import ImageDataGenerator,load_img, img_to_array
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Dense, Dropout, Input, Flatten, SeparableConv2D
from keras.layers import GlobalMaxPooling2D
from keras.layers.normalization import BatchNormalization
from keras.layers.merge import Concatenate
from keras.models import Model
from keras.optimizers import Adam, SGD, RMSprop
from keras.callbacks import ModelCheckpoint, Callback, EarlyStopping
from keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from mlxtend.plotting import plot_confusion_matrix
from sklearn.metrics import confusion_matrix
import cv2
from keras import backend as K
color = sns.color_palette()
%matplotlib inline

# added tensorflow and stuff from the MNIST classifier:
import tensorflow as tf
from __future__ import absolute_import, division, print_function, unicode_literals
import PIL
import time
from IPython import display



In [0]:
# My directory is called "content" instead of "input" as in the original code
#os.listdir("../content/chest_xray/chest_xray") 

## Make the data accessible

NAIN's original kernel had a code cell here that fixed several random seeds. It is not compatible with TF2, so it was removed from this notebook and saved in testPneumonia/reproducability.

The dataset is divided into three sets: 1) train set 2) validation set and 3) test set. Let's grab the dataset

In [0]:
# Define path to the data directory: "content" instead of "input/chest-xray-pneumonia" as in the original code
data_dir = Path('../content/chest_xray/chest_xray')

# Path to train directory (Fancy pathlib...no more os.path!!)
train_dir = data_dir / 'train'

# Path to validation directory
val_dir = data_dir / 'val'

# Path to test directory
test_dir = data_dir / 'test'

Each of the above directories contains two sub-directories:

NORMAL: These are the samples that describe the normal (no pneumonia) case.

PNEUMONIA: This directory contains those samples that are the pneumonia cases.
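
A quick way to check this layout, and to see how imbalanced the classes are, is to count the image files per split and class. The helper below is a hypothetical sketch (the name `count_images` is not part of the original notebook). It assumes the `train`/`val`/`test` and `NORMAL`/`PNEUMONIA` folder names described above and the `*.jpeg` extension used in the following cells.

```python
from pathlib import Path


def count_images(data_dir, splits=("train", "val", "test"),
                 classes=("NORMAL", "PNEUMONIA"), pattern="*.jpeg"):
    """Return {(split, class_name): number of matching image files}."""
    counts = {}
    for split in splits:
        for class_name in classes:
            folder = Path(data_dir) / split / class_name
            # glob is lazy, so count the matches without building a list
            counts[(split, class_name)] = sum(1 for _ in folder.glob(pattern))
    return counts
```

Calling `count_images(data_dir)` with the `data_dir` defined above returns one count per (split, class) pair.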

### Training data

In [0]:
# Get the path to the normal and pneumonia sub-directories
normal_cases_dir = train_dir / 'NORMAL'
pneumonia_cases_dir = train_dir / 'PNEUMONIA'

# Get the list of all the images
normal_cases = normal_cases_dir.glob('*.jpeg')
pneumonia_cases = pneumonia_cases_dir.glob('*.jpeg')

# An empty list. We will insert the data into this list in (img_path, label) format
train_data = []

# Go through all the normal cases. The label for these cases will be 0
for img in normal_cases:
    train_data.append((img,0))

# Go through all the pneumonia cases. The label for these cases will be 1
for img in pneumonia_cases:
    train_data.append((img, 1))

# Get a pandas dataframe from the data we have in our list 
train_data = pd.DataFrame(train_data, columns=['image', 'label'],index=None)

# Shuffle the data 
train_data = train_data.sample(frac=1.).reset_index(drop=True)

# What does the dataframe look like?
#train_data.head()

### Validation data
A generator for the training dataset is defined later in the notebook. The validation set is small, however, so its images can be read and loaded into memory directly without a generator. That is what the code cell below does.

In [0]:
# Why do we make artificial RGB pictures? Because the weights from imageNet that we import are for RGB pictures
# Get the path to the sub-directories
normal_cases_dir = val_dir / 'NORMAL'
pneumonia_cases_dir = val_dir / 'PNEUMONIA'

# Get the list of all the images
normal_cases = normal_cases_dir.glob('*.jpeg')
pneumonia_cases = pneumonia_cases_dir.glob('*.jpeg')

# List that are going to contain validation images data and the corresponding labels
valid_data = []
valid_labels = []

# Some images are grayscale while the majority have 3 channels.
# If an image is grayscale, convert it into an image with 3 channels.
# Normalize the pixel values and resize all images to 224x224.

# Normal cases
for img in normal_cases:
    img = cv2.imread(str(img))
    img = cv2.resize(img, (224,224))
    if img.shape[2] ==1:
        img = np.dstack([img, img, img]) # add a third dimension for RGB channels and fill it with 3 copies of the original
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32)/255.
    label = to_categorical(0, num_classes=2)
    valid_data.append(img)
    valid_labels.append(label)
                      
# Pneumonia cases        
for img in pneumonia_cases:
    img = cv2.imread(str(img))
    img = cv2.resize(img, (224,224))
    if img.shape[2] ==1:
        img = np.dstack([img, img, img])
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32)/255.
    label = to_categorical(1, num_classes=2)
    valid_data.append(img)
    valid_labels.append(label)
    
# Convert the list into numpy arrays
valid_data = np.array(valid_data)
valid_labels = np.array(valid_labels)

#print("Total number of validation examples: ", valid_data.shape)
#print("Total number of labels:", valid_labels.shape)
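
The per-image preprocessing in the loops above (grayscale to 3 channels, then scaling to [0, 1]) could be factored into one helper. The sketch below is an assumption-laden illustration, not the notebook's code. `to_rgb_float` is a hypothetical name, and it works on plain NumPy arrays so that it does not depend on OpenCV. Note that `cv2.imread` with its default flag already returns a 3-channel array, so the grayscale branch mainly matters for images loaded another way. A truly 2-D array would make a check like `img.shape[2] == 1` raise an IndexError, which is why the 2-D case is handled separately here.

```python
import numpy as np


def to_rgb_float(img):
    """Return an (H, W, 3) float32 array in [0, 1] from a uint8 grayscale or 3-channel image."""
    img = np.asarray(img)
    if img.ndim == 2:
        # (H, W) grayscale: stack three copies along a new channel axis
        img = np.dstack([img, img, img])
    elif img.shape[2] == 1:
        # (H, W, 1) grayscale: repeat the single channel three times
        img = np.concatenate([img, img, img], axis=2)
    return img.astype(np.float32) / 255.0
```

The BGR to RGB conversion and the resize to 224x224 are left to `cv2`, as in the cells above.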

### Test data

In [0]:
# Preparing test data
normal_cases_dir = test_dir / 'NORMAL'
pneumonia_cases_dir = test_dir / 'PNEUMONIA'

normal_cases = normal_cases_dir.glob('*.jpeg')
pneumonia_cases = pneumonia_cases_dir.glob('*.jpeg')

test_data = []
test_labels = []

for img in normal_cases:
    img = cv2.imread(str(img))
    img = cv2.resize(img, (224,224))
    if img.shape[2] ==1:
        img = np.dstack([img, img, img])
    else:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32)/255.
    label = to_categorical(0, num_classes=2)
    test_data.append(img)
    test_labels.append(label)
                      
for img in pneumonia_cases:
    img = cv2.imread(str(img))
    img = cv2.resize(img, (224,224))
    if img.shape[2] ==1:
        img = np.dstack([img, img, img])
    else:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32)/255.
    label = to_categorical(1, num_classes=2)
    test_data.append(img)
    test_labels.append(label)

test_data = np.array(test_data)
test_labels = np.array(test_labels)

#print("Total number of test examples: ", test_data.shape)
#print("Total number of labels:", test_labels.shape)

### Training data generator (stop ExplainGAN preprocessing before this?)
Data augmentation is a powerful technique that improves the robustness of a model in almost every case. It is even more helpful when the dataset is imbalanced: you can generate additional samples of the undersampled class to balance the overall distribution.
I like imgaug a lot. It comes with a very clean API and supports a large number of augmentations. It's worth exploring! The next code cell defines an augmentation sequence. Note the use of `OneOf`, which does exactly what the name says: at each call it picks one of the three augmentation techniques and applies it to the sample.
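
To make the "pick exactly one" behaviour concrete, here is a small NumPy-only sketch of the same idea. It does not use imgaug. All function names are hypothetical, and `rotate_90` is only a stand-in for a real 20-degree affine rotation, which plain NumPy cannot do.

```python
import random
import numpy as np


def flip_lr(img):
    """Mirror the image horizontally."""
    return img[:, ::-1]


def rotate_90(img):
    """Stand-in for a 20-degree rotation: rotate by 90 degrees instead."""
    return np.rot90(img)


def brighten(img, factor=1.35):
    """Scale pixel values and clip to the valid uint8 range."""
    return np.clip(img.astype(np.float32) * factor, 0, 255).astype(np.uint8)


def one_of(augmenters, img, rng):
    """Apply exactly one randomly chosen augmenter to img."""
    return rng.choice(augmenters)(img)


rng = random.Random(0)
image = np.arange(12, dtype=np.uint8).reshape(2, 2, 3)  # tiny fake "image"
augmented = one_of([flip_lr, rotate_90, brighten], image, rng)
```

Each call returns the image with one of the three transformations applied, never a combination of them.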

In [0]:
# Augmentation sequence 
seq = iaa.OneOf([
    iaa.Fliplr(), # horizontal flips
    iaa.Affine(rotate=20), # rotation by 20 degrees
    iaa.Multiply((1.2, 1.5))]) # random brightness

In [0]:
def data_gen(data, batch_size):
    # Get total number of samples in the data
    n = len(data)
    steps = n//batch_size
    
    # Define two numpy arrays for containing batch data and labels
    batch_data = np.zeros((batch_size, 224, 224, 3), dtype=np.float32)
    batch_labels = np.zeros((batch_size,2), dtype=np.float32)

    # Get a numpy array of all the indices of the input data
    indices = np.arange(n)
    
    # Initialize a counter
    i =0
    while True:
        np.random.shuffle(indices)
        # Get the next batch 
        count = 0
        next_batch = indices[(i*batch_size):(i+1)*batch_size]
        for j, idx in enumerate(next_batch):
            img_name = data.iloc[idx]['image']
            label = data.iloc[idx]['label']
            
            # one hot encoding
            encoded_label = to_categorical(label, num_classes=2)
            # read the image and resize
            img = cv2.imread(str(img_name))
            img = cv2.resize(img, (224,224))
            
            # check if it's grayscale
            if img.shape[2]==1:
                img = np.dstack([img, img, img])
            
            # cv2 reads in BGR mode by default
            orig_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # normalize the image pixels (use the RGB-converted image, not the BGR original)
            orig_img = orig_img.astype(np.float32)/255.
            
            batch_data[count] = orig_img
            batch_labels[count] = encoded_label
            
            # generating more samples of the undersampled class
            if label==0 and count < batch_size-2:
                aug_img1 = seq.augment_image(img)
                aug_img2 = seq.augment_image(img)
                aug_img1 = cv2.cvtColor(aug_img1, cv2.COLOR_BGR2RGB)
                aug_img2 = cv2.cvtColor(aug_img2, cv2.COLOR_BGR2RGB)
                aug_img1 = aug_img1.astype(np.float32)/255.
                aug_img2 = aug_img2.astype(np.float32)/255.

                batch_data[count+1] = aug_img1
                batch_labels[count+1] = encoded_label
                batch_data[count+2] = aug_img2
                batch_labels[count+2] = encoded_label
                # three slots were filled (original + two augmentations);
                # advancing by only 2 would let the next image overwrite aug_img2
                count += 3

            else:
                count += 1

            # stop as soon as every slot of the batch is filled
            if count >= batch_size:
                break
            
        i+=1
        yield batch_data, batch_labels
            
        if i>=steps:
            i=0
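
One detail of `data_gen` worth knowing: the indices are reshuffled on every step, not once per pass, so within one pass over the data some images can appear more than once while others are skipped. If one shuffle per pass is preferred, the index handling could look like the sketch below. `batch_indices` is a hypothetical helper that only produces index lists; image loading and augmentation would stay as in `data_gen`.

```python
import random


def batch_indices(n, batch_size, rng):
    """Endlessly yield lists of batch_size indices, reshuffling once per pass over the data."""
    steps = n // batch_size
    indices = list(range(n))
    while True:
        rng.shuffle(indices)  # one shuffle per pass
        for i in range(steps):
            yield indices[i * batch_size:(i + 1) * batch_size]


gen = batch_indices(10, 3, random.Random(0))
first_pass = [next(gen) for _ in range(3)]  # three disjoint batches of three indices
```

Within one pass, no index is used twice. The last `n % batch_size` indices of each shuffle are dropped, which matches the `steps = n//batch_size` logic above.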

# Classifier (NAIN's depthwise kernel)

Build and compile an empty model, then load its weights from Google Drive. These weights were created and saved in the notebook "PneumoniaClassifier.ipynb".

In [0]:
# Changed the format of this function for compatibility with TensorFlow 2
# The original version from NAIN used functional-style "x=...(x)" steps; my theory is that this x is the placeholder that led to errors
from tensorflow.keras import layers

def build_model():
  model = tf.keras.Sequential()
  model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same', name='Conv1_1', input_shape=[224, 224, 3]))
  model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same', name='Conv1_2', input_shape=[224, 224, 3]))
  model.add(layers.MaxPooling2D((2,2), name='pool'))

  model.add(layers.SeparableConv2D(128, (3,3), activation='relu', padding='same', name='Conv2_1'))
  model.add(layers.SeparableConv2D(128, (3,3), activation='relu', padding='same', name='Conv2_2'))
  model.add(layers.MaxPooling2D((2,2), name='pool2'))

  model.add(layers.SeparableConv2D(256, (3,3), activation='relu', padding='same', name='Conv3_1'))
  model.add(layers.BatchNormalization(name='bn1'))
  model.add(layers.SeparableConv2D(256, (3,3), activation='relu', padding='same', name='Conv3_2'))
  model.add(layers.BatchNormalization(name='bn2'))
  model.add(layers.SeparableConv2D(256, (3,3), activation='relu', padding='same', name='Conv3_3'))
  model.add(layers.MaxPooling2D((2,2), name='pool3'))

  model.add(layers.SeparableConv2D(512, (3,3), activation='relu', padding='same', name='Conv4_1'))
  model.add(layers.BatchNormalization(name='bn3'))
  model.add(layers.SeparableConv2D(512, (3,3), activation='relu', padding='same', name='Conv4_2'))
  model.add(layers.BatchNormalization(name='bn4'))
  model.add(layers.SeparableConv2D(512, (3,3), activation='relu', padding='same', name='Conv4_3'))
  model.add(layers.MaxPooling2D((2,2), name='pool4'))

  model.add(layers.Flatten(name='flatten'))
  model.add(layers.Dense(1024, activation='relu', name='fc1'))
  model.add(layers.Dropout(0.7, name='dropout1'))
  model.add(layers.Dense(512, activation='relu', name='fc2'))
  model.add(layers.Dropout(0.5, name='dropout2'))
  model.add(layers.Dense(2, activation='softmax', name='fc3'))
  
  return model
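
The reason for using `SeparableConv2D` in the deeper blocks becomes clear when counting weights. A depthwise separable convolution applies one k*k filter per input channel, followed by a 1*1 convolution that mixes channels. The two helper functions below are hypothetical names for this arithmetic. They assume a depth multiplier of 1 and one bias per output channel, which are the Keras defaults.

```python
def conv2d_params(k, c_in, c_out, bias=True):
    """Weights of a regular k*k convolution."""
    return k * k * c_in * c_out + (c_out if bias else 0)


def separable_conv2d_params(k, c_in, c_out, bias=True):
    """Weights of a depthwise separable k*k convolution (depth multiplier 1)."""
    depthwise = k * k * c_in   # one k*k filter per input channel
    pointwise = c_in * c_out   # 1*1 convolution mixing the channels
    return depthwise + pointwise + (c_out if bias else 0)


# Layer 'Conv2_1' above: 3*3 kernel, 64 input channels, 128 output channels
print(conv2d_params(3, 64, 128))            # 73856
print(separable_conv2d_params(3, 64, 128))  # 8896
```

For this layer, the separable version needs roughly an eighth of the weights of a regular convolution with the same shape.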

In [0]:
classifier =  build_model()
#classifier.summary()

In [0]:
# opt = RMSprop(lr=0.0001, decay=1e-6)
opt = tf.keras.optimizers.Adam(lr=0.0001, decay=1e-5)
es = EarlyStopping(patience=5)
chkpt = ModelCheckpoint(filepath='best_model_todate', save_best_only=True, save_weights_only=True)
classifier.compile(loss='binary_crossentropy', metrics=['accuracy'],optimizer=opt)

In [0]:
classifier.load_weights('/content/gdrive/My Drive/Colab Notebooks/DepthwiseWeights')

### Testing classifier

In [0]:
# Evaluation on test dataset
test_loss, test_score = classifier.evaluate(test_data, test_labels, batch_size=16)
print("Loss on test set: ", test_loss)
print("Accuracy on test set: ", test_score)

In [0]:
# Get predictions
preds = classifier.predict(test_data, batch_size=16)
preds = np.argmax(preds, axis=-1)

# Original labels
orig_test_labels = np.argmax(test_labels, axis=-1)

#print(orig_test_labels.shape)
#print(preds.shape)

In [0]:
# Get the confusion matrix
cm = confusion_matrix(orig_test_labels, preds)
plt.figure()
#plot_confusion_matrix(cm,figsize=(12,8), hide_ticks=True, alpha=0.7,cmap=plt.cm.Blues)
plot_confusion_matrix(cm,figsize=(12,8), hide_ticks=True,cmap=plt.cm.Blues)
plt.xticks(range(2), ['Normal', 'Pneumonia'], fontsize=16)
plt.yticks(range(2), ['Normal', 'Pneumonia'], fontsize=16)
#plt.show()

In [0]:
# Calculate Precision and Recall
tn, fp, fn, tp = cm.ravel()

precision = tp/(tp+fp)
recall = tp/(tp+fn)
#print("Recall of the classifier is {:.2f}".format(recall))
#print("Precision of the classifier is {:.2f}".format(precision))
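
As a worked example of the two formulas, the sketch below computes precision and recall from the four cells of a binary confusion matrix. The counts are made up for illustration and are not results from this classifier. The function name is hypothetical, and the zero-division guard is an addition.

```python
def precision_recall(tn, fp, fn, tp):
    """Precision and recall for the positive (pneumonia) class."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall


# Hypothetical counts, in the same order as cm.ravel(): tn, fp, fn, tp
precision_example, recall_example = precision_recall(tn=50, fp=10, fn=5, tp=95)
print("Precision: {:.2f}".format(precision_example))  # Precision: 0.90
print("Recall: {:.2f}".format(recall_example))        # Recall: 0.95
```

In a screening setting, recall is usually the more important of the two, because a false negative means a missed pneumonia case.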

Running everything up to this point takes 3 and a half minutes.

# ExplainGAN

In [0]:
print(train_data.shape) # this is a data frame of paths to the images and their labels.
# The train data generator gets the images from the paths. This saves memory compared to saving all training images
print(valid_data.shape)
print(test_data.shape)

## Encoder

In [0]:
# Copied from https://blog.keras.io/building-autoencoders-in-keras.html, added flattening at the end
# Also changed the activation function from relu to sigmoid to not get exploding latent variables
# Maybe change this back if I train the encoder separately before training the rest of ExplainGAN. Then, also move the flattening

# Change input shape, but use only one of the three RGB channels
def make_encoder_model(activ_fct):
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(16, (3, 3), activation=activ_fct, padding='same', use_bias=True,
                                     input_shape=[224, 224, 3]))
    model.add(layers.MaxPooling2D((2,2), padding='same'))

    model.add(layers.Conv2D(8, (3, 3), activation=activ_fct, padding='same', use_bias=True))
    model.add(layers.MaxPooling2D((2,2), padding='same'))

    model.add(layers.Conv2D(8, (3, 3), activation=activ_fct, padding='same', use_bias=True))
    model.add(layers.MaxPooling2D((2,2), padding='same'))
    model.add(layers.Conv2D(8, (3, 3), activation=activ_fct, padding='same', use_bias=True))
    model.add(layers.MaxPooling2D((2,2), padding='same'))
    model.add(layers.Flatten())
    
    return model
# At this point the representation is (14, 14, 8) before flattening, i.e. 1568-dimensional after flattening
# The convolutions here don't decrease the 2D dimension because they use padding='same' with the default (1,1) strides. The MaxPooling does
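
The size of the encoded vector follows directly from the pooling layers: each `MaxPooling2D((2,2), padding='same')` halves the height and width, rounding up. The short sketch below redoes that arithmetic. `pooled_size` is a hypothetical helper, not a Keras function.

```python
import math


def pooled_size(size, n_pools, pool=2):
    """Spatial size after n_pools poolings with padding='same' (rounds up)."""
    for _ in range(n_pools):
        size = math.ceil(size / pool)
    return size


side = pooled_size(224, 4)    # 224 -> 112 -> 56 -> 28 -> 14
latent_dim = side * side * 8  # 8 channels in the last Conv2D layer
print(side, latent_dim)       # 14 1568
```

This is the 1568 that the first dense layer of the generator expects as its input size. For comparison, a 28*28 input with three poolings gives `pooled_size(28, 3) == 4`, i.e. a (4, 4, 8) representation with 128 values.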

In [0]:
encoder0 = make_encoder_model(activ_fct='relu')
encoder1 = make_encoder_model(activ_fct='relu')
#encoder0.summary()

In [0]:
encoded = encoder0(test_data[0:1,:,:,:])

## Generators

generator_start is the biggest part of ExplainGAN, i.e. the one with the most weights. I already ran into one OOM error with this version of the code after building ~5 models with 123 million weights. The vast majority of the weights sit in the very first (dense) layer.

Make this smaller somehow. Let's try to create fewer 7*7 images (in MNIST, we created 128 of these from an encoded vector of length 128). Try 128 for now. Then we have only about 10 million weights in generator_start.
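
The "about 10 million" figure can be checked by hand. The dense layer maps the 1568-dimensional encoding to 7*7*128 values without a bias, and the transposed convolution after it has 5*5*128*64 kernel weights. The sketch below is only this arithmetic. `dense_weights` is a hypothetical helper, and the small batch normalization parameter counts are ignored.

```python
def dense_weights(n_in, n_out, bias=False):
    """Number of weights in a fully connected layer."""
    return n_in * n_out + (n_out if bias else 0)


latent_dim = 14 * 14 * 8  # 1568, the flattened encoder output
dense = dense_weights(latent_dim, 7 * 7 * 128)
conv_transpose = 5 * 5 * 128 * 64  # kernel of the first Conv2DTranspose, no bias

print(dense)                   # 9834496
print(conv_transpose)          # 204800
print(dense + conv_transpose)  # 10039296
```

The dense layer scales linearly with the number of 7*7 channels, so halving the 128 channels would also roughly halve the size of generator_start.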

In [0]:
# This generator is based on the DCGAN generator. But now, we need 3 outputs and not just one!
# Therefore tear it apart in the middle. We will use the second part three times to get recon and trafo and mask 
def make_generator_model_start():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*128, use_bias=False, input_shape=(1568,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    # Create images of size 7*7 with 128 channels, all connected to the 1568 nodes of the encoded original image
    # The Reshape layer below actually brings the tensor into that channel form

    model.add(layers.Reshape((7, 7, 128)))
    assert model.output_shape == (None, 7, 7, 128) # Note: None is the batch size

    # now move 5*5*128(channels) filters over the 7*7*128(channels) images in 1,1 strides. Do this for each of the 64 output channels.
    # 64 is the number of different filters we apply. Because of the big number of input channels, each filter is already huge
    # The number of parameters here is 5*5*128 (filter weights) *64 (number of filters)
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    return model
    
 
# Now the second part. Apply this three times to get reconstruction and transformation and mask
# Because of the splitting, I also needed to add the input shape in the first layer of this second part.
def make_generator_model_end_sigmoid(): # second version to get a mask with values in [0,1] instead of [-1,1] with tanh
    # By taking strides of 2, the size of the image gets doubled in length and width.
    # This is the case, because we do a backwards convolution. If we get a 7*7 image by taking (2,2)-strides, we must have started with 14*14
    model = tf.keras.Sequential()
    model.add(layers.Conv2DTranspose(32, (5, 5), strides=(2, 2), padding='same', use_bias=False, input_shape=(7,7,64)))
    assert model.output_shape == (None, 14, 14, 32)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(16, (5, 5), strides=(2, 2), padding='same', use_bias=False, input_shape=(14,14,32)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(8, (5, 5), strides=(2, 2), padding='same', use_bias=False, input_shape=(28,28,16)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(8, (5, 5), strides=(2, 2), padding='same', use_bias=False, input_shape=(56,56,8)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='sigmoid'))
    assert model.output_shape == (None, 224, 224, 1)
    return model
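
With `padding='same'`, a `Conv2DTranspose` layer multiplies the spatial size by its stride, so the chain of five stride-2 layers above takes a 7*7 feature map to the full 224*224 image. The sketch below only traces these sizes. `upsampled_sizes` is a hypothetical helper.

```python
def upsampled_sizes(start, strides):
    """Spatial size after each transposed convolution with padding='same'."""
    sizes = [start]
    for stride in strides:
        sizes.append(sizes[-1] * stride)
    return sizes


print(upsampled_sizes(7, [2, 2, 2, 2, 2]))  # [7, 14, 28, 56, 112, 224]
```

This is why five upsampling layers are needed here, where a 28*28 MNIST image needs only two.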

In [0]:
generator_start = make_generator_model_start()
#generator_start.summary()
reconstructor = make_generator_model_end_sigmoid()
#reconstructor.summary()

In [0]:
gen_started = generator_start(encoded)
#gen_started # shape (1,7,7,64)

In [0]:
reconstructed = reconstructor(gen_started)
#reconstructed # shape (1,224,224,1) like a grayscale x-ray scan

## Loss function

In [0]:
# inputs of loss_recon are two images, the original one and its reconstruction 
def loss_recon(x, reconstruction):
  difference = tf.subtract(x,reconstruction)
  # reshape each 224*224 image into one vector of length 50176 in order to apply the l2 norm to it
  # The -1 in the first dimension makes sure that the number of images stays the same; images shall not be combined
  difference = tf.reshape(difference, [-1,50176,1])
  return tf.math.square(tf.norm(difference, ord=2, axis=1))
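
For intuition, the squared l2 norm computed above is the sum of squared pixel differences, taken separately for each image. The NumPy sketch below mirrors that computation outside of TensorFlow. The function name is hypothetical, and NumPy is used only as an illustration.

```python
import numpy as np


def squared_l2_per_image(x, reconstruction):
    """Sum of squared pixel differences per image, returned with shape (n_images, 1)."""
    diff = np.asarray(x, dtype=np.float32) - np.asarray(reconstruction, dtype=np.float32)
    diff = diff.reshape(diff.shape[0], -1)  # one row per image
    return np.sum(diff ** 2, axis=1, keepdims=True)


originals = np.zeros((2, 2, 2, 1), dtype=np.float32)
recons = np.ones((2, 2, 2, 1), dtype=np.float32)
loss = squared_l2_per_image(originals, recons)  # each image: 4 pixels, each off by 1
```

Because the loss is a plain sum over pixels, its scale grows with the image size, which is one reason a small weighting factor is useful when it is combined with other terms.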

In [0]:
# Testing
#loss_recon(valid_data[:,:,:,0], valid_data[:,:,:,0])

## Experimentation

In [0]:
images = valid_data
predicted_classes = classifier.predict(images) # model takes (224,224,3) as input
predicted_classes = np.argmax(predicted_classes, axis=1) # now we have the actual predictions
predicted_classes = tf.reshape(predicted_classes, [-1,1]) # make it compatible with tensors below
predicted_classes = tf.cast(predicted_classes, tf.float32) # change type to float for multiplications
predicted_classes

images_grayscale = images[:,:,:,0:1]
#images_grayscale.shape

In [0]:
enc0_optimizer = tf.keras.optimizers.Adam(1e-4)
enc1_optimizer = tf.keras.optimizers.Adam(1e-4)
gen_start_optimizer = tf.keras.optimizers.Adam(1e-4)
recon_optimizer = tf.keras.optimizers.Adam(1e-4)

In [35]:
with tf.GradientTape() as enc0_tape, tf.GradientTape() as enc1_tape, tf.GradientTape() as gen_start_tape, tf.GradientTape() as recon_tape:
  z_as0 = encoder0(images)
  z_as1 = encoder1(images)
  gen_from_pred0 = generator_start(z_as0)
  gen_from_pred1 = generator_start(z_as1)
  recon_from_pred0 = reconstructor(gen_from_pred0)
  recon_from_pred1 = reconstructor(gen_from_pred1)
  print("recon0 shape:", recon_from_pred0.shape)
  print("recon1 shape:", recon_from_pred1.shape)
  print("images_grayscale shape:", images_grayscale.shape)

  loss_recon0 = loss_recon(images_grayscale, recon_from_pred0)
  loss_recon1 = loss_recon(images_grayscale, recon_from_pred1)
  print("loss_recon0 shape:", loss_recon0.shape)
  loss_r0 = tf.math.multiply(1-predicted_classes, loss_recon0) # set the loss for the wrong recons to 0
  loss_r1 = tf.math.multiply(predicted_classes, loss_recon1) # set the loss for the wrong recons to 0
  loss_r = (loss_r0 + loss_r1)*0.0005
  loss_summed = loss_r
  print("loss_summed shape:", loss_summed.shape)

  gradients_of_enc0 = enc0_tape.gradient(loss_summed, encoder0.trainable_variables)
  gradients_of_enc1 = enc1_tape.gradient(loss_summed, encoder1.trainable_variables)
  gradients_of_gen_start = gen_start_tape.gradient(loss_summed, generator_start.trainable_variables)
  gradients_of_recon = recon_tape.gradient(loss_summed, reconstructor.trainable_variables)

  enc0_optimizer.apply_gradients(zip(gradients_of_enc0, encoder0.trainable_variables))
  enc1_optimizer.apply_gradients(zip(gradients_of_enc1, encoder1.trainable_variables))
  gen_start_optimizer.apply_gradients(zip(gradients_of_gen_start, generator_start.trainable_variables))
  recon_optimizer.apply_gradients(zip(gradients_of_recon, reconstructor.trainable_variables))

  print('loss_r: ', tf.reduce_max(loss_r), tf.reduce_min(loss_r), tf.reduce_mean(loss_r))

recon0 shape: (16, 224, 224, 1)
recon1 shape: (16, 224, 224, 1)
images_grayscale shape: (16, 224, 224, 1)
loss_recon0 shape: (16, 1)
loss_summed shape: (16, 1)
loss_r:  tf.Tensor(2.433419, shape=(), dtype=float32) tf.Tensor(1.2088357, shape=(), dtype=float32) tf.Tensor(1.649849, shape=(), dtype=float32)
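
The class-dependent selection in the cell above works by multiplying with a 0/1 mask. For an image predicted as class 0, only the reconstruction through encoder0 contributes to the loss; for class 1, only the one through encoder1. The NumPy sketch below shows the same masking on made-up numbers. The function name and the example values are hypothetical.

```python
import numpy as np


def select_loss_by_class(predicted_classes, loss0, loss1, weight=0.0005):
    """Keep loss0 where the prediction is 0 and loss1 where it is 1, then scale."""
    pred = np.asarray(predicted_classes, dtype=np.float32).reshape(-1, 1)
    return ((1.0 - pred) * loss0 + pred * loss1) * weight


pred = [0, 1]                       # first image predicted normal, second pneumonia
loss0 = np.array([[2.0], [4.0]])    # made-up losses via encoder0
loss1 = np.array([[10.0], [20.0]])  # made-up losses via encoder1
selected = select_loss_by_class(pred, loss0, loss1, weight=1.0)
```

With these numbers, `selected` contains 2.0 for the first image and 20.0 for the second. The entries 4.0 and 10.0 are zeroed out by the mask and produce no gradient.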


In [0]:
# Changes compared to the DCGAN version: test_input is the first real image(s) instead of a random seed
def generate_and_save_images(epoch, test_input, index):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  original = test_input[index:index+1,:,:,:]
  prediction = classifier.predict(original)
  prediction = np.argmax(prediction, axis=1)
  if prediction == 0:
    z = encoder0(original, training=False)
  if prediction == 1:
    z = encoder1(original, training=False)
  gen_from_pred = generator_start(z, training=False)
  recon_from_pred = reconstructor(gen_from_pred, training=False)
  #trafo_from_pred = transformator(gen_from_pred, training=False)
  #mask_from_pred = mask(gen_from_pred, training=False)
  #comp_from_pred = create_composite(original, trafo_from_pred, mask_from_pred)
  
  fig = plt.figure(figsize=(10,10))
  plt.subplot(1, 2, 1)
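  # inputs and sigmoid outputs are already in [0, 1], so scale by 255 (the * 127.5 + 127.5 form is for tanh outputs in [-1, 1])
  plt.imshow(original[0, :, :, 0] * 255., cmap='gray')
  plt.axis('off')
  plt.subplot(1, 2, 2)
  plt.imshow(recon_from_pred[0, :, :, 0] * 255., cmap='gray')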
  plt.axis('off')
  #plt.subplot(1, 5, 3)
  #plt.imshow(trafo_from_pred[0, :, :, 0] * 127.5 + 127.5, cmap='gray')
  #plt.axis('off')
  #plt.subplot(1, 5, 4)
  #plt.imshow(mask_from_pred[0, :, :, 0] * 127.5 + 127.5, cmap='gray')
  #plt.axis('off')
  #plt.subplot(1, 5, 5)
  #plt.imshow(comp_from_pred[0, :, :, 0] * 127.5 + 127.5, cmap='gray')
  #plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

In [0]:
image_set = test_data # not called "set", which would shadow the Python builtin
#generate_and_save_images(1,image_set,index=20)
#generate_and_save_images(1,image_set,index=21)
#generate_and_save_images(1,image_set,index=22)

## Training

In [0]:
def train_step(images):
  predicted_classes = classifier.predict(images) # model takes (224,224,3) as input
  predicted_classes = np.argmax(predicted_classes, axis=1) # now we have the actual predictions
  predicted_classes = tf.reshape(predicted_classes, [-1,1]) # make it compatible with tensors below
  predicted_classes = tf.cast(predicted_classes, tf.float32) # change type to float for multiplications
  
  images_grayscale = images[:,:,:,0:1] #needed for loss_computations

  with tf.GradientTape() as enc0_tape, tf.GradientTape() as enc1_tape, tf.GradientTape() as gen_start_tape, tf.GradientTape() as recon_tape:
    z_as0 = encoder0(images)
    z_as1 = encoder1(images)
    gen_from_pred0 = generator_start(z_as0)
    gen_from_pred1 = generator_start(z_as1)
    recon_from_pred0 = reconstructor(gen_from_pred0)
    recon_from_pred1 = reconstructor(gen_from_pred1)
    print("recon0 shape:", recon_from_pred0.shape)
    print("recon1 shape:", recon_from_pred1.shape)
    print("images_grayscale shape:", images_grayscale.shape)

    loss_recon0 = loss_recon(images_grayscale, recon_from_pred0)
    loss_recon1 = loss_recon(images_grayscale, recon_from_pred1)
    print("loss_recon0 shape:", loss_recon0.shape)
    loss_r0 = tf.math.multiply(1-predicted_classes, loss_recon0) # set the loss for the wrong recons to 0
    loss_r1 = tf.math.multiply(predicted_classes, loss_recon1) # set the loss for the wrong recons to 0
    loss_r = (loss_r0 + loss_r1)*0.0005
    loss_summed = loss_r
    print("loss_summed shape:", loss_summed.shape)

    gradients_of_enc0 = enc0_tape.gradient(loss_summed, encoder0.trainable_variables)
    gradients_of_enc1 = enc1_tape.gradient(loss_summed, encoder1.trainable_variables)
    gradients_of_gen_start = gen_start_tape.gradient(loss_summed, generator_start.trainable_variables)
    gradients_of_recon = recon_tape.gradient(loss_summed, reconstructor.trainable_variables)

    enc0_optimizer.apply_gradients(zip(gradients_of_enc0, encoder0.trainable_variables))
    enc1_optimizer.apply_gradients(zip(gradients_of_enc1, encoder1.trainable_variables))
    gen_start_optimizer.apply_gradients(zip(gradients_of_gen_start, generator_start.trainable_variables))
    recon_optimizer.apply_gradients(zip(gradients_of_recon, reconstructor.trainable_variables))

    print('loss_r: ', tf.reduce_max(loss_r), tf.reduce_min(loss_r), tf.reduce_mean(loss_r))

In [40]:
# Checking, seems to work just fine. Loss changes from step to step
#train_step(test_data) # produces OOM error
train_step(valid_data) # still works

recon0 shape: (16, 224, 224, 1)
recon1 shape: (16, 224, 224, 1)
images_grayscale shape: (16, 224, 224, 1)
loss_recon0 shape: (16, 1)
loss_summed shape: (16, 1)
loss_r:  tf.Tensor(2.4333916, shape=(), dtype=float32) tf.Tensor(1.2088183, shape=(), dtype=float32) tf.Tensor(1.6498413, shape=(), dtype=float32)


In [0]:
# Copied from the classifier's training. NAIN then just calls model.fit with train_data_gen as input.
# This won't work 1 to 1 here, since ExplainGAN is not one sequential model
# Nevertheless, I want to use the generator to avoid OOM errors and save time when training
batch_size = 16
nb_epochs = 10

# Get a train data generator
train_data_gen = data_gen(data=train_data, batch_size=batch_size)

# Define the number of training steps
nb_train_steps = train_data.shape[0]//batch_size

print("Number of training steps: {}, number of validation images: {}".format(nb_train_steps, len(valid_data)))

In [0]:
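# data_gen is an endless generator that yields (images, labels) tuples. Passing the generator itself
# to train_step most likely makes classifier.predict iterate forever (it ran over 1 hour without finishing).
# Pull one batch of images from the generator and train on that instead.
image_batch, _ = next(train_data_gen)
train_step(image_batch)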

In [0]:
def train(dataset, epochs, weight_g, weight_c, weight_r, weight_cs, weight_ct, weight_sm, weight_en, kappa, pretraining_flag):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      # Here we use the batching of the dataset below
      # call the function train_step defined in the box above this one
      train_step(image_batch, weight_g, weight_c, weight_r, weight_cs, weight_ct, weight_sm, weight_en, kappa, pretraining_flag)

    # Produce images for the GIF as we go (from DCGAN)
    #display.clear_output(wait=True)
    generate_and_save_images(epoch + 1, train_images, index=0) # still input train_images here. Would be nice to use dataset instead!!
    
    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  #display.clear_output(wait=True)
  generate_and_save_images(epochs, train_images, index=0)
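
The `train` function above still carries the signature of an earlier version: `train_step` as defined in this notebook takes only the images, and `train_images` is not defined here. Since `data_gen` never stops yielding, a loop like `for image_batch in dataset` would also never end. The sketch below shows one possible shape for an epoch loop over such an endless generator. `run_epochs` and its parameters are hypothetical, and the training step is passed in as a plain callable so that the loop stays independent of TensorFlow.

```python
import itertools
import time


def run_epochs(batch_source, steps_per_epoch, epochs, step_fn, on_epoch_end=None):
    """Draw steps_per_epoch batches per epoch from an endless (images, labels) generator."""
    for epoch in range(1, epochs + 1):
        start = time.time()
        # islice stops after steps_per_epoch batches, so an endless generator is safe here
        for images, _labels in itertools.islice(batch_source, steps_per_epoch):
            step_fn(images)
        if on_epoch_end is not None:
            on_epoch_end(epoch)
        print('Time for epoch {} is {:.2f} sec'.format(epoch, time.time() - start))
```

With the names from this notebook, a call might look like `run_epochs(train_data_gen, nb_train_steps, nb_epochs, train_step, lambda e: generate_and_save_images(e, valid_data, index=0))`.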

In [0]:
fig = plt.figure(figsize=(18,6))
plt.subplot(1, 3, 1)
plt.imshow(valid_data[1, :, :, 0], cmap='gray')
plt.axis('off')
plt.subplot(1, 3, 2)
plt.imshow(valid_data[2, :, :, 1], cmap='gray')
plt.axis('off')
plt.subplot(1, 3, 3)
plt.imshow(valid_data[3, :, :, 0], cmap='gray')
plt.axis('off')