In [1]:
import tensorflow as tf

import os
import pathlib
import time
import datetime

from matplotlib import pyplot as plt
from IPython import display
import numpy as np
from skimage.util import img_as_float

In [12]:
# Colab-only: mount Google Drive under /content/drive so files stored there can be read.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [13]:
class ConvBlock(tf.keras.layers.Layer):
    """
    Down-sampling block used by the UNET / encoder-decoder generators and the discriminators.

    Applies, in order: a 4x4 Conv2D with stride 2 and "same" padding, optional
    BatchNormalization, optional Dropout(0.5), and finally LeakyReLU with slope 0.2.

    Args:
        numFilters: number of output channels of the convolution.
        BN: if truthy, batch normalization is applied to the convolution output.
        Dropout: if truthy, 50% dropout is applied (after batch norm, before the activation).
    """
    def __init__(self, numFilters, BN, Dropout):
        super().__init__()
        self.kernel_size = (4,4)
        self.BN = BN
        self.Dropout = Dropout
        self.stride= (2,2)
        self.numFilters = numFilters

        # "same" padding makes every block divide the spatial size by the stride (rounding up).
        self.padding = "same"
        kernelInitializer = tf.keras.initializers.RandomNormal(mean=0, stddev = .02)
        self.conv = tf.keras.layers.Conv2D(numFilters, self.kernel_size, self.stride, padding = self.padding, kernel_initializer = kernelInitializer)
        # Both layers are always created; `call` decides whether they are applied.
        self.batchNorm = tf.keras.layers.BatchNormalization()
        self.dropout = tf.keras.layers.Dropout(.5)
        self.relu = tf.keras.layers.LeakyReLU(.2)

    def call(self, input):
        """
        Runs conv -> (batch norm) -> (dropout) -> LeakyReLU on a rank-4 input.

        Returns the activated tensor with `numFilters` channels.
        """
        batchSize, height, width, _ = input.shape
        convOutput = self.conv(input)
        newHeight, newWidth = self.calcShape(height, width)

        # Sanity check of the static shape; unknown dimensions are None on both sides.
        assert(convOutput.shape == (batchSize, newHeight, newWidth, self.numFilters))
        if(self.BN):
            convOutput = self.batchNorm(convOutput)
        if(self.Dropout):
            convOutput = self.dropout(convOutput)
        activated = self.relu(convOutput)
        return activated

    def calcShape(self, height, width):
        """
        Calculates the spatial shape of the output of this layer given the input shape.

        Args:
            height: static input height, or None when it is unknown.
            width: static input width, or None when it is unknown.

        Returns:
            (outputHeight, outputWidth); an entry is None when the matching input is None.
        """
        fh, fw = self.kernel_size            ## filter height & width
        sh, sw = self.stride       ## filter stride
        return self._outputLength(height, fh, sh), self._outputLength(width, fw, sw)

    def _outputLength(self, length, kernel, stride):
        """Output length along one axis for this block's padding mode (None stays None)."""
        if length is None:
            # Dynamic dimension: nothing can be computed statically.
            return None
        if(self.padding == "same"):
            # "same" pads just enough that the output is ceil(length / stride).
            return -(-length // stride)
        # No padding: count the positions at which the whole kernel fits.
        return (length - kernel)//stride + 1
class ConvTBlock(ConvBlock):
    """
    Up-sampling counterpart of ConvBlock.

    Everything is inherited from ConvBlock (batch norm / dropout / LeakyReLU handling and
    `call`); only the convolution is replaced by a 4x4, stride-2 Conv2DTranspose and the
    expected output shape is adjusted accordingly.
    """
    def __init__(self, numFilters, BN, Dropout):
        super().__init__(numFilters, BN, Dropout)
        # Replace the Conv2D created by the parent constructor with its transposed version.
        # No output padding is needed to get the right output sizes.
        kernelInitializer = tf.keras.initializers.RandomNormal(mean = 0, stddev = .02)
        self.conv = tf.keras.layers.Conv2DTranspose(numFilters, self.kernel_size, self.stride, padding = self.padding,  kernel_initializer = kernelInitializer)

    def calcShape(self, height, width):
        """
        Calculates the spatial shape of the layer output; unlike ConvBlock, the size grows.

        The formula is the one for "same" padding: each dimension is multiplied by its
        stride. A dimension that is statically unknown (None) stays None.
        """
        strideHeight, strideWidth = self.stride
        outputHeight = None if height is None else strideHeight*height
        outputWidth = None if width is None else strideWidth*width
        return (outputHeight, outputWidth)

In [14]:
class UNet(tf.keras.layers.Layer):
    """
    U-Net generator body: an encoder-decoder with skip connections.

    The activations of encoder layer i are concatenated (along the channel axis) with the
    input of decoder layer n-i. Ck denotes a Convolution-BatchNorm-LeakyReLU block with k
    filters; a D marks 50% dropout. All convolutions use 4x4 kernels with stride 2, so the
    encoder downsamples by a factor of 2 per block and the decoder upsamples by 2.

    Encoder: C64-C128-C256-C512-C512-C512-C512-C512 (the first block has no batch norm)
    Decoder: CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
    """
    # (numFilters, use batch norm) for encblock1..encblock8; the encoder never uses dropout.
    _ENCODER_LAYOUT = ((64, False), (128, True), (256, True), (512, True),
                       (512, True), (512, True), (512, True), (512, True))
    # (numFilters, use dropout) for decblock1..decblock8; the decoder always uses batch norm.
    _DECODER_LAYOUT = ((512, True), (1024, True), (1024, True), (1024, False),
                       (1024, False), (512, False), (256, False), (128, False))

    def __init__(self):
        super().__init__()
        # Blocks are stored as encblock1..8 / decblock1..8, created in that order.
        for stage, (filters, useBatchNorm) in enumerate(self._ENCODER_LAYOUT, start=1):
            setattr(self, "encblock" + str(stage), ConvBlock(filters, useBatchNorm, False))
        for stage, (filters, useDropout) in enumerate(self._DECODER_LAYOUT, start=1):
            setattr(self, "decblock" + str(stage), ConvTBlock(filters, True, useDropout))

    def call(self, input):
        """
        Runs the eight encoder blocks, then the eight decoder blocks.

        Every decoder block after the first receives the previous decoder output
        concatenated with the mirrored encoder output (encoder 7 feeds decoder 2,
        encoder 6 feeds decoder 3, ..., encoder 1 feeds decoder 8). Returns the output of
        the last decoder block; the final image projection is applied by the caller.
        """
        encoderOutputs = []
        features = input
        for stage in range(1, len(self._ENCODER_LAYOUT) + 1):
            features = getattr(self, "encblock" + str(stage))(features)
            encoderOutputs.append(features)

        # The innermost encoder output is the bottleneck and has no skip partner.
        bottleneck = encoderOutputs.pop()
        features = self.decblock1(bottleneck)

        for stage in range(2, len(self._DECODER_LAYOUT) + 1):
            # Popping walks the remaining encoder outputs from deepest to shallowest.
            skip = encoderOutputs.pop()
            merged = tf.concat([skip, features], axis=-1)
            features = getattr(self, "decblock" + str(stage))(merged)

        return features

In [15]:
class EncDec(tf.keras.layers.Layer):
    """
    Plain encoder-decoder generator body (the U-Net layout without skip connections).

    Eight stride-2 ConvBlocks are followed by seven stride-2 ConvTBlocks, so the output
    has half the spatial size of the input and 64 channels; the caller applies the last
    up-sampling step.
    """
    # (numFilters, use batch norm) for encblock1..encblock8; the encoder never uses dropout.
    _ENCODER_LAYOUT = ((64, False), (128, True), (256, True), (512, True),
                       (512, True), (512, True), (512, True), (512, True))
    # (numFilters, use dropout) for decblock1..decblock7; the decoder always uses batch norm.
    _DECODER_LAYOUT = ((512, True), (512, True), (512, True), (512, False),
                       (512, False), (256, False), (64, False))

    def __init__(self):
        super().__init__()
        # Blocks are stored as encblock1..8 / decblock1..7, created in that order.
        for stage, (filters, useBatchNorm) in enumerate(self._ENCODER_LAYOUT, start=1):
            setattr(self, "encblock" + str(stage), ConvBlock(filters, useBatchNorm, False))
        for stage, (filters, useDropout) in enumerate(self._DECODER_LAYOUT, start=1):
            setattr(self, "decblock" + str(stage), ConvTBlock(filters, True, useDropout))

    def call(self, input):
        """Feeds the input through all encoder blocks and then all decoder blocks, in order."""
        features = input
        for stage in range(1, len(self._ENCODER_LAYOUT) + 1):
            features = getattr(self, "encblock" + str(stage))(features)
        for stage in range(1, len(self._DECODER_LAYOUT) + 1):
            features = getattr(self, "decblock" + str(stage))(features)
        return features

In [16]:
class Generator(tf.keras.Model):
    """
    Generator for the GAN: maps a conditioning image to a 3-channel image in [-1, 1].

    We want to find a way to increase the stochasticity of outputs. The paper used dropout
    to simulate randomness, but it isn't very random. However, they found that the network
    learns to IGNORE the input noise, so we need a way to make the input noise more
    prominent.

    Args:
        reg_coeff: weight of the L1 reconstruction term (only used when `self.l1` is True).
    """
    def __init__(self, reg_coeff):
        super().__init__()
        # Switches: `u` selects U-Net (True) or plain encoder-decoder (False);
        # `l1` enables the L1 reconstruction term in compute_loss.
        self.u = True
        self.l1 = False
        # Regularization coefficient for L1 loss.
        self.reg_coeff = reg_coeff
        if self.u:
            self.UNet = UNet()
            # The U-Net already outputs full resolution: project to RGB with stride 1.
            self.lastConvolution = tf.keras.layers.Conv2D(3, (4,4), (1,1), padding = "same", activation = "tanh")
        else:
            self.UNet = EncDec()
            # The encoder-decoder stops at half resolution: the projection also upsamples by 2.
            self.lastConvolution = tf.keras.layers.Conv2DTranspose(3, (4,4), (2,2), padding = "same", activation = "tanh")
        # range -1 to 1.
        self.tanh = tf.keras.activations.tanh

    def call(self, data, training=None):
        """
        Generates one image per conditioning input.

        Args:
            data: rank-4 tensor (batch, height, width, channels) of conditioning images.
            training: Keras training flag; optional so the model can also be called
                without it, in which case Keras supplies the value from the call context.

        Returns:
            Tensor of shape (batch, height, width, 3) with tanh-activated values.
        """
        batchSize, height, width, numChannels = data.shape
        uNetOutput = self.UNet(data)
        if self.u:
            assert(uNetOutput.shape == (batchSize, height, width, 128))
        else:
            assert(uNetOutput.shape == (batchSize, height // 2, width // 2, 64))
        generated = self.lastConvolution(uNetOutput)
        # The generated image must match the input in batch and spatial size.
        assert(generated.shape == (batchSize, height, width, 3))
        return generated

    def compute_loss(self, combined,  genPred, genReal, sample_weight=None):
        """
        Adversarial loss of the generator, optionally with an L1 reconstruction term.

        Args:
            combined: pair (generated images, real images).
            genPred: discriminator predictions for the generated images.
            genReal: unused; kept so the signature parallels Discriminator.compute_loss.
            sample_weight: optional weights forwarded to the compiled loss.

        Returns:
            Scalar loss: compiled loss against all-ones labels, plus
            reg_coeff * mean(|generated - real|) when `self.l1` is enabled.
        """
        # The generator wants every prediction to be classified as real (label 1).
        realY = tf.ones_like(genPred, dtype = tf.int32)
        # calls the loss function passed into the compiler.
        lossDefault = self.compiled_loss(realY, genPred, sample_weight)
        if not self.l1:
            return lossDefault
        generated = combined[0]
        x = combined[1]
        # Mean absolute error gives a scalar penalty, so the total loss stays a scalar
        # instead of becoming a batch-sized vector.
        l1 = tf.reduce_mean(tf.abs(generated - x))
        return lossDefault + self.reg_coeff*l1

In [17]:
class PatchGAN(tf.keras.layers.Layer):
    """
    Patch-based discriminator body, called from the Discriminator model.

    Instead of a single real/fake score it outputs a grid of sigmoid scores, one per
    image patch, so the loss is computed over that grid rather than over one value.

    Args:
        size: nominal patch size; accepted but not used by this implementation.
    """
    def __init__(self, size):
        super().__init__()
        # Receptive field of a stack of L convolutions with kernel k_l and stride s_l:
        #   r = 1 + sum_{l=1..L} (k_l - 1) * prod_{i=1..l-1} s_i
        # The three 4x4, stride-2 convolutions below give r = 1 + 3*(1 + 2 + 4) = 22.
        self.conv1 = ConvBlock(64, False, False)
        self.conv2 = ConvBlock(128, True, False)
        # For a 70x70 PatchGAN, ConvBlock(256, True, False) and ConvBlock(512, True, False)
        # stages (conv3 / conv4) would be added here and applied in `call`.
        scoreInitializer = tf.keras.initializers.RandomNormal(mean= 0, stddev = .02)
        self.lastConvLayer = tf.keras.layers.Conv2D(
            1,
            kernel_size = (4,4),
            strides = (2,2),
            padding = "same",  # needed so the score grid is exactly input size / 8
            activation = "sigmoid",
            kernel_initializer = scoreInitializer,
        )

    def call(self, input):
        """Returns the score grid of shape batchSize x numPatches x numPatches x 1."""
        output = input
        for stage in (self.conv1, self.conv2, self.lastConvLayer):
            output = stage(output)
        return output

In [18]:
class NotAPatchGAN(tf.keras.layers.Layer):
    """
    Whole-image discriminator body, the alternative to PatchGAN.

    Uses the same three stride-2 convolutions as PatchGAN, but the resulting grid is
    flattened and reduced by a Dense layer to a single sigmoid real/fake score per image.

    Args:
        size: accepted for signature parity with PatchGAN; not used by this implementation.
    """
    def __init__(self, size):
        super().__init__()
        gridInitializer = tf.keras.initializers.RandomNormal(mean= 0, stddev = .02)
        self.conv1 = ConvBlock(64, False, False)
        self.conv2 = ConvBlock(128, True, False)
        # No activation here: the sigmoid is applied by the Dense layer below.
        self.lastConvLayer = tf.keras.layers.Conv2D(
            1,
            kernel_size = (4,4),
            strides = (2,2),
            padding = "same",
            kernel_initializer = gridInitializer,
        )
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(1, activation = "sigmoid")
        # Deeper variants would add ConvBlock(256, True, False) / ConvBlock(512, True, False)
        # stages (conv3 / conv4) and apply them in `call`.

    def call(self, input):
        """Returns one sigmoid score per example, shape batchSize x 1."""
        output = input
        for stage in (self.conv1, self.conv2, self.lastConvLayer, self.flatten, self.dense):
            output = stage(output)
        return output

In [19]:
class Discriminator(tf.keras.Model):
    """
    Conditional discriminator: scores an image together with the condition it belongs to.

    `self.patch` selects the PatchGAN body (grid of scores) or the NotAPatchGAN body
    (one score per image).
    """
    # The discriminator bodies stack three stride-2 "same" convolutions, i.e. 2**3.
    _DOWNSAMPLE_FACTOR = 8

    def __init__(self):
        super().__init__()
        self.patch = True
        if self.patch:
            self.patchGAN = PatchGAN(16)
        else:
            self.patchGAN = NotAPatchGAN(16)

    def call(self, inputs, training=None):
        """
        Integrated conditionality via concatenation.

        Args:
            inputs: pair (data, condition) of rank-4 tensors with equal batch/spatial size.
            training: Keras training flag; optional so callers that omit it still work,
                in which case Keras supplies the value from the call context.

        Returns:
            The predictions of the selected discriminator body.
        """
        data = inputs[0]
        condition = inputs[1]
        #Concatenate the data to the condition to take both into account in the network.
        concatenated = tf.concat([data, condition], axis=-1)
        return self.patchGAN(concatenated)

    def _expectedPredictionShape(self, imageShape):
        """Static prediction shape expected for images of static shape `imageShape`."""
        batchSize = imageShape[0]
        if not self.patch:
            return (batchSize, 1)

        def downsampled(length):
            # Unknown dimensions stay unknown; known ones are divided, rounding up.
            if length is None:
                return None
            return -(-length // self._DOWNSAMPLE_FACTOR)

        return (batchSize, downsampled(imageShape[1]), downsampled(imageShape[2]), 1)

    def compute_loss(self, combined, predGen, predReal, sample_weights=None):
        """
        Generates the labels and computes the loss.
        Computes the two losses separately and then sums them together.

        Args:
            combined: pair (generated images, real images); only the static shape of the
                generated images is used, for a sanity check.
            predGen: predictions for the generated images (target label 0).
            predReal: predictions for the real images (target label 1).
            sample_weights: optional weights forwarded to the compiled loss.

        Returns:
            Compiled loss on the real predictions plus compiled loss on the generated ones.
        """
        generated = combined[0]
        # Derive the expected grid from the image size instead of assuming 256x256 inputs
        # (which give 32x32).
        assert(predGen.shape == self._expectedPredictionShape(generated.shape))
        # Build labels directly; deriving them from `0*pred` breaks on NaN/Inf predictions.
        realLabels = tf.ones_like(predReal, dtype = tf.int32)
        genLabels = tf.zeros_like(predGen, dtype = tf.int32)

        realLoss = self.compiled_loss(realLabels, predReal, sample_weights)
        genLoss = self.compiled_loss(genLabels, predGen, sample_weights)
        return realLoss+genLoss

In [20]:
class GAN(tf.keras.Model):
    """
    Conditional GAN that trains a Generator and a Discriminator against each other.

    Inputs are pairs (X, Y): X are the real images, Y the conditioning images (outlines)
    of those same examples. The model owns six running-mean loss metrics (train and
    validation versions of discriminator, generator and summed loss).
    """
    def __init__(self):
        super().__init__()
        reg_coeff = 100
        self.generator = Generator(reg_coeff = reg_coeff)
        self.discriminator = Discriminator()

    def call(self, data, training=None):
        """
        Generates images from the conditions and scores them with the discriminator.

        Args:
            data: pair (X, Y); X may be None when only the generated branch is needed.
            training: Keras training flag, forwarded explicitly to both sub-models.

        Returns:
            (generated images, predictions for generated images, predictions for the real
            images or None when X is None).
        """
        X = data[0]
        Y = data[1]
        #generate data
        generated = self.generator(Y, training = training)
        #run discriminator on generated examples.
        genPred = self.discriminator((generated, Y), training = training)
        #in case where want gradients of generator, don't need this.
        realPred = None
        #in case where you're calculating gradients of discriminator, run the real predictions.
        if(X is not None):
            #run discriminator on real examples.
            realPred = self.discriminator((X, Y), training = training)
        #return generated examples, predictions of generated examples, and predictions of real examples
        return generated, genPred, realPred

    def batch_step(self, data, training):
        """
        Called from both train step and test step, makes the methods simpler by keeping all the code in one place.
        X - real examples of images.
        Y - outlines of those specific real examples. This is the conditional input.

        The discriminator is evaluated (and, when training, updated) first; the generator
        is then evaluated against the updated discriminator. Returns the dict of running
        metric values for the current mode.
        """
        X = data[0]
        Y = data[1]
        #calculate discriminator gradients first.
        with tf.GradientTape() as disTape:
            #forward pass.
            generated, predGen, predReal = self(data, training = training)
            discriminatorLoss = self.discriminator.compute_loss((generated, X), predGen, predReal)
        if(training):
            #if training, calculate gradients and update weights.
            disGrad = disTape.gradient(discriminatorLoss, self.discriminator.trainable_variables)
            self.discriminator.optimizer.apply_gradients(zip(disGrad, self.discriminator.trainable_variables))
        #now once those are updated, calculate generator weights.
        with tf.GradientTape() as genTape:
            # Only the generated branch is needed here, so the sub-models are called
            # directly instead of routing a None placeholder through `self(...)`.
            generated = self.generator(Y, training = training)
            genPredict = self.discriminator((generated, Y), training = training)
            #calculate the loss of the generator. Don't need real predictions.
            generatorLoss = self.generator.compute_loss((generated, X), genPredict, None)
        if(training):
            genGrad = genTape.gradient(generatorLoss, self.generator.trainable_variables)
            self.generator.optimizer.apply_gradients(zip(genGrad, self.generator.trainable_variables))
        # Metrics are updated in exactly one place so validation batches can never leak
        # into the training metrics.
        self.updateStates(not training, generatorLoss, discriminatorLoss)
        return self.evalMetrics(training)

    def compile(self, optimizerGen, optimizerDis, lossFxnGen, lossFxnDis, metrics = None, steps_per_execution = 1):
        """
        Compiles the GAN and its two sub-models.

        Args:
            optimizerGen / optimizerDis: optimizers for the generator / discriminator.
            lossFxnGen / lossFxnDis: loss functions for the generator / discriminator.
            metrics: forwarded to the parent compile (e.g. the result of createMetrics()).
            steps_per_execution: forwarded to the parent compile.
        """
        super().compile(steps_per_execution = steps_per_execution, metrics = metrics)
        self.generator.compile(optimizerGen, lossFxnGen)
        self.discriminator.compile(optimizerDis, lossFxnDis)
        # batch_step needs the loss trackers; create them if the caller has not done so.
        if getattr(self, "listMetrics", None) is None:
            self.createMetrics()

    @tf.function
    def train_step(self, data):
        """One optimization step on a batch; returns the training metric values."""
        return self.batch_step(data, True)
    @tf.function
    def test_step(self, data):
        """One evaluation step on a batch; returns the validation metric values."""
        return self.batch_step(data, False)

    def generateImages(self, Y):
        """
        Generate some images from the conditions.
        """
        generated = self.generator(Y, training = False)
        return generated
    def createMetrics(self):
        """Creates fresh running-mean loss trackers and returns all of them as a list."""
        self.dLoss = tf.keras.metrics.Mean(name = "dLoss")
        self.gLoss = tf.keras.metrics.Mean(name = "gLoss")
        self.sumLoss = tf.keras.metrics.Mean(name = "sumLoss")
        self.dValLoss = tf.keras.metrics.Mean(name = "valDLoss")
        self.gValLoss = tf.keras.metrics.Mean(name = "valGLoss")
        self.valSumLoss = tf.keras.metrics.Mean(name = "valSumLoss")

        self.listMetricsTrain = [self.dLoss, self.gLoss, self.sumLoss]
        self.listMetricsTest = [self.dValLoss, self.gValLoss, self.valSumLoss]
        self.listMetrics = self.listMetricsTrain + self.listMetricsTest
        return self.listMetrics
    def updateStates(self, val, gLoss, dLoss):
        """Adds one batch of losses to the validation (val truthy) or training trackers."""
        if val:
            self.dValLoss.update_state(dLoss)
            self.gValLoss.update_state(gLoss)
            self.valSumLoss.update_state(dLoss+gLoss)

        else:
            self.dLoss.update_state(dLoss)
            self.gLoss.update_state(gLoss)
            self.sumLoss.update_state(dLoss + gLoss)
    def resetStates(self):
        """Resets every loss tracker."""
        for metric in self.listMetrics:
            metric.reset_state()
    def evalMetrics(self, training):
        """Returns the current training or validation metric values as a dict."""
        if training:
            return self.evalMetricsTrain()
        else:
            return self.evalMetricsTest()
    def evalMetricsTest(self):
        """Current validation metric values keyed by metric name."""
        return {metric.name:metric.result() for metric in self.listMetricsTest}
    def evalMetricsTrain(self):
        """Current training metric values keyed by metric name."""
        return {metric.name:metric.result() for metric in self.listMetricsTrain}

In [21]:
class displayImages(tf.keras.callbacks.Callback):
    """
    Callback that, at the end of every epoch, plots a few examples: the real image, the
    edge image, the image generated from the edge image and, for a patch discriminator,
    the discriminator score maps for the real and the generated image.

    Args:
        trainData / trainEdges: training images and their edge images (stored only).
        testData / testEdges: held-out images and edge images; the plotted examples are
            taken from the start of these.
    """
    def __init__(self, trainData,trainEdges, testData, testEdges):
        # Let the Keras base class initialise its own bookkeeping attributes.
        super().__init__()
        self.trainData = trainData
        self.trainEdges = trainEdges
        self.testData = testData
        self.testEdges = testEdges

    def on_epoch_end(self, batch, logs=None):
        # Keras passes the epoch index as the first argument.
        self.displayImages(batch, logs)
        return

    def displayImages(self, batch, logs):
        """
        Draws one row per example with up to five columns:
        real image | edges | generated | real score map | generated score map.
        The two score-map columns are only drawn when the predictions are spatial maps.
        """
        maxExamples = 5
        images = self.testData[0:maxExamples]
        edges = self.testEdges[0:maxExamples]
        # Fewer examples than requested may be available.
        numExamples = int(images.shape[0])
        if numExamples == 0:
            return

        generated, genPredictions, realPredictions = self.model((images, edges), training = False)
        # Rank <= 2 means one score per image (no map to show).
        showScoreMaps = len(realPredictions.shape) > 2
        if showScoreMaps:
            # Repeat the single score channel three times so imshow renders it as grey RGB.
            realPredictions = tf.tile(realPredictions, multiples = [1, 1, 1, 3])
            genPredictions = tf.tile(genPredictions, multiples = [1, 1, 1, 3])

        # Edge images are tiled to three channels so they can be stacked with RGB images.
        edgesTiled = tf.tile(edges, multiples = [1, 1, 1, 3])
        stackedImages = tf.stack([images, edgesTiled, generated], axis=0)
        assert(stackedImages.shape[0:2] == (3, numExamples))

        rows = numExamples
        columns = 5
        fig = plt.figure(figsize=(10, 10))
        for i in range(numExamples):
            firstIndex = i*columns + 1
            # Columns 1-3: real image, edges, generated image.
            for j in range(3):
                axis = fig.add_subplot(rows, columns, firstIndex + j)
                axis.imshow(stackedImages[j, i, :], aspect = 'auto')
                axis.axis('off')

            if showScoreMaps:
                # Columns 4-5: discriminator score maps for the real / generated image.
                for offset, scoreMap in ((3, realPredictions[i]), (4, genPredictions[i])):
                    axis = fig.add_subplot(rows, columns, firstIndex + offset)
                    axis.imshow(scoreMap, aspect = 'auto')
                    axis.axis('off')

        fig.subplots_adjust(hspace=0, wspace = 0)
        plt.show()
        # Release the figure so one figure per epoch does not accumulate in memory.
        plt.close(fig)

In [22]:
# runModel.py

def showImages(images):
    """Shows every image of `images` in its own matplotlib figure, one after another."""
    for picture in images:
        plt.imshow(picture); plt.show()

def split(images, sketches, percentTrain = .8, seed = None):
    """
    Randomly splits paired images and sketches into a training and a test part.

    Both arrays are shuffled with the same permutation, so images[i] and sketches[i]
    stay paired.

    Args:
        images: array whose first axis indexes the examples.
        sketches: array with the same number of examples as `images`.
        percentTrain: fraction (0..1) of the examples that goes into the training part.
        seed: optional integer seed for a reproducible split; when None the global NumPy
            random state is used.

    Returns:
        (trainImages, trainSketches, testImages, testSketches)

    Raises:
        ValueError: if the two arrays differ in length or percentTrain is outside [0, 1].
    """
    numImages = images.shape[0]
    if sketches.shape[0] != numImages:
        # A length mismatch would otherwise silently pair images with the wrong sketches.
        raise ValueError("images and sketches must contain the same number of examples, got %d and %d" % (numImages, sketches.shape[0]))
    if not 0 <= percentTrain <= 1:
        raise ValueError("percentTrain must lie in [0, 1], got %r" % (percentTrain,))

    numTrain = int(percentTrain*numImages)
    if seed is None:
        indices = np.random.permutation(numImages)
    else:
        indices = np.random.RandomState(seed).permutation(numImages)

    trainIndices, testIndices = indices[:numTrain], indices[numTrain:]
    return images[trainIndices], sketches[trainIndices], images[testIndices], sketches[testIndices]


def runModel(images, sketches):
    """
    Splits the data, builds and compiles the GAN, trains it and shows ten generated images.

    Args:
        images: array of real images, first axis indexes the examples.
        sketches: array of the matching conditioning sketches.

    Returns:
        The trained GAN model (Keras also stores the fit history on `model.history`).
    """
    #maybe use a faster library method for this instead.
    trainImages, trainSketches, testImages, testSketches = split(images, sketches)
    learningRate = .0002
    b1 = .5
    b2 = .999
    #is giving half the learning rate the same as dividing the objective by 2?
    optimizerDis = tf.keras.optimizers.Adam(learning_rate = learningRate, beta_1 = b1, beta_2 = b2)
    optimizerGen = tf.keras.optimizers.Adam(learning_rate = learningRate, beta_1 = b1, beta_2 = b2)

    batchSize = 4
    epochs = 40
    lossFxn = tf.keras.losses.BinaryCrossentropy()
    model = GAN()
    startCompAndBuild = time.time()
    stepsPerExecution = 1
    model.compile(optimizerGen, optimizerDis, lossFxn, lossFxn, metrics = model.createMetrics(), steps_per_execution = stepsPerExecution)
    #need this for eager execution, without this it is automatically not eager.
    #model.run_eagerly = True
    model.build(input_shape = [(None, 256, 256, 3), (None, 256, 256, 1)])
    endCompAndBuild = time.time()
    compAndBuild = endCompAndBuild - startCompAndBuild
    print("comp and build time: ", compAndBuild)
    model.summary()

    # Train and validate on fixed-size subsets held as float32 tensors.
    smallerTrainImages = tf.constant(trainImages[:1000], dtype = tf.float32)
    smallerTrainSketches = tf.constant(trainSketches[:1000], dtype = tf.float32)
    smallerTestImages = tf.constant(testImages[:500], dtype = tf.float32)
    smallerTestSketches = tf.constant(testSketches[:500], dtype = tf.float32)
    saveFreq = 10
    # NOTE(review): this checkpoint callback is created but not registered in `callbacks`,
    # so no weights are saved during training; add it to the list to enable checkpointing.
    modelCheckpoint = tf.keras.callbacks.ModelCheckpoint("checkpoints/{epoch}weights", monitor = "sumLoss",save_best_only = True,  mode = "min", save_weights_only = True, save_freq = saveFreq)
    callbacks = [displayImages(smallerTrainImages, smallerTrainSketches, smallerTestImages, smallerTestSketches)]

    model.fit(smallerTrainImages, smallerTrainSketches, batch_size = batchSize, epochs = epochs, validation_data = (smallerTestImages, smallerTestSketches), callbacks = callbacks)

    generatedImages = model.generateImages(trainSketches[0:10])
    showImages(generatedImages)
    # Hand the trained model back so that `model = runModel(...)` yields something usable.
    return model

In [23]:
def randomJitter(images):
    """
    Applies random jitter to a batch of images: every image is resized to 286x286 and an
    independently chosen random 256x256 window is cropped out of it.

    Returns the crops as a uint8 NumPy array with the examples stacked along axis 0.
    """
    print("images shape: ", images.shape)
    enlarged = tf.image.resize(images, size = (286,286))
    print("resized shape: ", enlarged.shape)
    # Keep all channels; only the spatial extent is cropped.
    windowShape = (256,256, enlarged.shape[-1])
    # One random_crop call per image so each example gets its own offset.
    crops = [tf.image.random_crop(enlarged[index], windowShape) for index in range(images.shape[0])]
    stackedCrops = tf.stack(crops, axis=0)
    # NOTE(review): the uint8 cast presumably expects 0-255 pixel values; inputs scaled
    # to [0, 1] would be truncated to 0/1 here - confirm against the caller.
    return stackedCrops.numpy().astype(np.uint8)

In [8]:
# Folder on the mounted Google Drive that holds the combined image files read by the loading cell.
PATH = '/content/drive/My Drive/Architectural-Illustrator/DataCombined/'

In [9]:
def load(image_file):
  """
  Reads one combined JPEG and splits it down the middle into its two halves.

  Args:
    image_file: path of the JPEG file.

  Returns:
    (input_image, real_image) as float arrays produced by img_as_float:
    input_image is the right half with an extra axis inserted at position 2, i.e.
    shape (height, width // 2 or so, 1, 3); real_image is the left half, shape
    (height, width // 2, 3).
  """
  encoded = tf.io.read_file(image_file)
  # Force three channels so grayscale-encoded JPEGs also yield the RGB layout that the
  # channel indexing done by callers relies on.
  image = tf.io.decode_jpeg(encoded, channels=3).numpy()

  # The file stores both pictures side by side: left half real image, right half input.
  w = image.shape[1] // 2
  input_image = image[:, w:, :]
  input_image = np.expand_dims(input_image,2)
  real_image = image[:, :w, :]

  input_image = img_as_float(input_image)
  real_image = img_as_float(real_image)

  return input_image, real_image

In [24]:
# Load the paired samples combine69.jpg ... combine599.jpg from PATH into two arrays.
input_images = []
real_images = []
for i in range(69,600):
  input_image, real_image = load(PATH + 'combine' + str(i) + '.jpg')
  # load() returns the input half with an extra axis at position 2; selecting index 1 of
  # the last axis keeps a single colour channel, leaving a (height, width, 1) array.
  input_images.append(input_image[:,:,:,1])
  real_images.append(real_image)
# Stack the per-file arrays along a new leading example axis.
input_images_ar = np.array(input_images)
real_images_ar = np.array(real_images)

In [19]:
# Sanity check: number of loaded input images and their per-image shape.
print(input_images_ar.shape)

(531, 256, 256, 1)


In [1]:
# Split the data, build the GAN and train it; real images go first, input sketches second.
model = runModel(real_images_ar, input_images_ar)

NameError: ignored

In [None]:
# Same training call as the previous cell (that run shows a NameError in its output); needs the loading cells to have run first.
model = runModel(real_images_ar, input_images_ar)