In [24]:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import Concatenate, Lambda # batch discrimination
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D, Conv2DTranspose
from keras.models import Sequential, Model
from keras.optimizers import Adam

import keras.backend as K

import matplotlib.pyplot as plt

import numpy as np

import skimage.measure

import matplotlib.pyplot as plt

# Build One-Class GAN

In [9]:
def add_common_layers(y, set_alpha=0.3):
    """Apply the shared normalisation + activation tail used after each conv.

    Parameters
    ----------
    y : Keras tensor
        Output of the preceding layer.
    set_alpha : float, optional
        Negative-slope coefficient handed to ``LeakyReLU`` (default 0.3);
        passing ``0`` turns it into a plain ReLU.

    Returns
    -------
    Keras tensor
        ``LeakyReLU(alpha=set_alpha)(BatchNormalization()(y))``.
    """
    normalized = BatchNormalization()(y)
    return LeakyReLU(alpha=set_alpha)(normalized)

In [16]:
# Load the lesion dataset at notebook level for inspection.
# The machine-specific default path can be overridden by setting the
# LESION_DATA_PATH environment variable instead of editing this cell.
import os

DATA_PATH = os.environ.get("LESION_DATA_PATH",
                           "/Users/alecx/Downloads/AWS-LESIONDATA-2019_v2.npz")

# np.load on an .npz archive returns an NpzFile that keeps the file open;
# the context manager closes it once every array below has been read into
# memory (indexing an NpzFile materialises the array).
with np.load(DATA_PATH) as data:
    X_train = data["imageList"]
    y_train = data["targetList"]
    imageValList = data["imageValList"]
    targetValList = data["targetValList"]
    testList = data["testList"]
    targetTestList = data["targetTestList"]

In [19]:
from skimage.transform import resize

In [36]:
class OCC():
    """One-class GAN on 32x32 RGB lesion images.

    Builds a DCGAN-style generator (latent vector -> 32x32x3 ``tanh`` image),
    a convolutional discriminator augmented with batch-diversity feature
    maps, and the stacked ``combined`` model used to train the generator.
    """

    # Machine-specific default location of the dataset archive; pass
    # ``data_path`` to ``train`` to use a different file.
    DEFAULT_DATA_PATH = "/Users/alecx/Downloads/AWS-LESIONDATA-2019_v2.npz"

    def __init__(self):
        # Input shape
        self.img_rows = 32
        self.img_cols = 32
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.num_classes = 3  # NOTE(review): not read by any method here
        self.latent_dim = 128

        optimizer = Adam(0.0002, 0.5)
        losses = ['binary_crossentropy']

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss=losses,
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates the image
        noise = Input(shape=(self.latent_dim,))
        img = self.generator([noise])

        # For the combined model we only train the generator. The
        # discriminator was compiled above while still trainable, so its own
        # train_on_batch keeps updating it; the flag only affects models
        # compiled afterwards (Keras' "Discrepancy between trainable weights"
        # warning is therefore expected).
        self.discriminator.trainable = False

        # The discriminator takes the generated image and determines validity
        valid = self.discriminator(img)

        # The combined model (stacked generator and discriminator)
        # trains the generator to fool the discriminator
        self.combined = Model([noise], [valid])
        self.combined.compile(loss=losses,
            optimizer=optimizer)

    def build_generator(self):
        """Build the generator: latent vector -> image with values in [-1, 1].

        The latent vector is reshaped to a 1x1 map, expanded to 4x4 by a
        ``valid`` 4x4 transposed convolution, upsampled three times
        (4 -> 8 -> 16 -> 32) with conv + BN + ReLU stages, and finally
        projected to ``self.channels`` channels with ``tanh``.
        """
        noise = Input(shape=(self.latent_dim,))

        # Use latent_dim rather than a literal 128 so the two cannot drift.
        i = Reshape((1, 1, self.latent_dim))(noise)

        i = Conv2DTranspose(128, kernel_size=4, padding="valid")(i)
        i = add_common_layers(i)

        for filters in (64, 32, 16):
            i = UpSampling2D(size=(2, 2))(i)
            i = Conv2D(filters, kernel_size=3, strides=1, padding="same")(i)
            i = add_common_layers(i, set_alpha=0)  # alpha=0 -> plain ReLU

        i = Conv2D(self.channels, kernel_size=3, strides=1, padding="same")(i)
        img = Activation("tanh")(i)

        model = Model([noise], img)
        model.summary()

        return model

    @staticmethod
    def _append_batch_diversity(x):
        """Concatenate one batch-diversity channel onto feature map ``x``.

        For every spatial position, the extra channel holds the mean (over
        channels) absolute deviation of ``x`` from its mean over the batch
        axis, so the discriminator can see how varied a batch is
        ("batch discrimination").
        """
        batch_div = Lambda(
            lambda t: K.mean(K.abs(t - K.mean(t, axis=0)), axis=-1, keepdims=True)
        )(x)
        return Concatenate()([x, batch_div])

    def build_discriminator(self):
        """Build the discriminator: image -> probability that it is real.

        Three conv + BN + LeakyReLU stages (16/32/64 filters, strides 1/2/2),
        each followed by a batch-diversity channel, then a final stride-2
        128-filter conv, flatten and a sigmoid unit.
        """
        img = Input(shape=self.img_shape)

        # ``input_shape`` is not passed to the first Conv2D: it is ignored
        # when a layer is called on an existing tensor.
        i = img
        for filters, strides in ((16, 1), (32, 2), (64, 2)):
            i = Conv2D(filters, kernel_size=3, strides=strides, padding="same")(i)
            i = add_common_layers(i)
            i = self._append_batch_diversity(i)

        i = Conv2D(128, kernel_size=3, strides=2, padding="same")(i)
        i = Flatten()(i)

        # Determine validity of the image
        validity = Dense(1, activation="sigmoid")(i)

        model = Model(img, [validity])
        model.summary()

        return model

    def _load_training_images(self, data_path):
        """Return ``imageList`` from ``data_path`` resized to ``self.img_shape``
        and scaled to [-1, 1] (the generator's ``tanh`` range).

        Raises ValueError if the archive contains no images.
        """
        # The context manager closes the NpzFile once the array is read.
        with np.load(data_path) as data:
            images = data["imageList"]

        n_images = images.shape[0]
        if n_images == 0:
            raise ValueError("no images found in %r" % (data_path,))

        # Resize first (keeping the original 0-255 intensity range), then
        # normalise. Both steps are linear, so the order does not change the
        # result, but this avoids materialising a float copy of the whole
        # full-resolution dataset. mode="reflect" is given explicitly because
        # skimage warns that its default will change, and it keeps borders
        # from being blended with zero padding.
        small = np.empty((n_images,) + tuple(self.img_shape), dtype=np.float32)
        for k in range(n_images):
            small[k] = resize(images[k], output_shape=self.img_shape,
                              mode="reflect", preserve_range=True)
        return (small - 127.5) / 127.5

    def train(self, epochs, batch_size=128, sample_interval=50, data_path=None):
        """Alternately train the discriminator and the generator.

        Parameters
        ----------
        epochs : int
            Number of training iterations; each draws one random batch
            (with replacement) rather than making a full pass over the data.
        batch_size : int
            Number of real and of generated images per iteration.
        sample_interval : int
            Save a grid of generated samples every ``sample_interval``
            iterations (starting at iteration 0); a falsy value disables it.
        data_path : str, optional
            ``.npz`` archive containing ``imageList``; defaults to
            ``DEFAULT_DATA_PATH``.
        """
        if data_path is None:
            data_path = self.DEFAULT_DATA_PATH

        X_train = self._load_training_images(data_path)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # Sample noise as generator input
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a full batch of new images (same size as the real one)
            gen_imgs = self.generator.predict([noise])

            # Train the discriminator; each call returns [loss, accuracy],
            # which are averaged over the real and fake batches.
            d_loss_real = self.discriminator.train_on_batch(imgs, [valid])
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, [fake])
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Same noise labelled "valid", through the combined model whose
            # discriminator part is frozen.
            g_loss = self.combined.train_on_batch([noise], [valid])

            # Plot the progress
            print(epoch, d_loss, g_loss)

            # If at save interval => save generated image samples
            if sample_interval and epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save a 10x3 grid of generated samples to ``images/<epoch>.png``."""
        import os

        r, c = 10, 3
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict([noise])
        # Rescale images from the tanh range [-1, 1] to [0, 1] for imshow
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c, figsize=(10,30))
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt,:,:,:])
                axs[i,j].axis('off')
                cnt += 1
        fig.tight_layout()
        # savefig fails if the target directory is missing; create it.
        if not os.path.isdir("images"):
            os.makedirs("images")
        fig.savefig("images/%d.png" % epoch)
        plt.close(fig)

    def save_model(self, epoch):
        """Write the generator's and discriminator's JSON architecture and
        HDF5 weights to ``saved_model/``, tagged with ``epoch``."""
        import os

        def save(model, model_name, epoch):
            model_path = "saved_model/%s_%d.json" % (model_name, epoch)
            weights_path = "saved_model/%s_weights_%d.hdf5" % (model_name, epoch)
            # Context manager so the architecture file is flushed and closed.
            with open(model_path, 'w') as arch_file:
                arch_file.write(model.to_json())
            model.save_weights(weights_path)

        if not os.path.isdir("saved_model"):
            os.makedirs("saved_model")
        save(self.generator, "generator", epoch)
        save(self.discriminator, "discriminator", epoch)


In [37]:
# Build the models (printing both summaries) and run the adversarial
# training loop, saving a sample grid to images/<epoch>.png every 1000
# iterations.
occ = OCC()
occ.train(epochs=100000,
          batch_size=64,
          sample_interval=1000)

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_19 (InputLayer)           (None, 32, 32, 3)    0                                            
__________________________________________________________________________________________________
conv2d_53 (Conv2D)              (None, 32, 32, 16)   448         input_19[0][0]                   
__________________________________________________________________________________________________
batch_normalization_46 (BatchNo (None, 32, 32, 16)   64          conv2d_53[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_46 (LeakyReLU)      (None, 32, 32, 16)   0           batch_normalization_46[0][0]     
__________________________________________________________________________________________________
lambda_22 

  warn("The default mode, 'constant', will be changed to 'reflect' in "
  'Discrepancy between trainable weights and collected trainable'


0 [0.9212393 0.359375 ] 0.6719533
1 [0.68224585 0.65625   ] 0.8296698
2 [0.61113024 0.6640625 ] 0.96828926
3 [0.5218483 0.78125  ] 1.179438
4 [0.45566744 0.859375  ] 1.3091862
5 [0.43140233 0.8515625 ] 1.31672
6 [0.36857897 0.921875  ] 1.5564983
7 [0.3866384 0.875    ] 1.3892317
8 [0.3432604 0.8984375] 1.4150665
9 [0.3878529 0.8828125] 1.420162
10 [0.34411582 0.890625  ] 1.5858313
11 [0.3701056 0.890625 ] 1.4807813
12 [0.41221732 0.859375  ] 1.6644444
13 [0.3466434 0.8671875] 1.5556135
14 [0.38239533 0.8828125 ] 1.6251826
15 [0.3296238 0.921875 ] 1.5792263
16 [0.332711  0.9140625] 1.5656238
17 [0.31870067 0.921875  ] 1.548876
18 [0.23804212 0.96875   ] 1.5868287
19 [0.2875631 0.953125 ] 1.4047683
20 [0.35725984 0.8984375 ] 1.1994662
21 [0.27530676 0.921875  ] 1.2269994
22 [0.307367 0.921875] 1.0449368
23 [0.20231041 0.9609375 ] 1.0767581
24 [0.2525626 0.921875 ] 0.93676054
25 [0.31510296 0.8984375 ] 0.8761375
26 [0.25830448 0.9296875 ] 0.8306165
27 [0.23616098 0.9375    ] 0.72782564
28

KeyboardInterrupt: 