## INTEGRALES DISCRETAS

Este notebook sirve como guía para construir una calculadora de transformadas discretas. Las transformadas que se implementarán son las siguientes:
- Transformada de Ondícula Discreta (Discrete Wavelet Transform)
- Transformada Discreta de Hartley
- Transformada Discreta de Legendre
- Transformada Discreta de Fourier
- Transformada Discreta del Coseno

### CARGA DE LIBRERÍAS

In [5]:
import numpy as np

### TRANSFORMADA DE ONDÍCULA DISCRETA

In [13]:
# Small 2x2 matrix used to sanity-check NumPy's two-dimensional FFT.
# NOTE(review): this cell sits under the wavelet-transform heading but only
# exercises np.fft.fft2.
A = [[2, 5], [4, 7]]

# Two-dimensional discrete Fourier transform of A (complex-valued result).
fft = np.fft.fft2(np.asarray(A))

# Bare expression: in a notebook this displays the transform as cell output.
fft
## %timeit fft

array([[18.+0.j, -6.+0.j],
       [-4.+0.j,  0.+0.j]])

### DEFINICIÓN DE LA CAPA DE FFT

In [1]:
def fft_forward(image, w, b, hparameters1):
    '''
    FUNCTION DESCRIPTION:
    Convolve a square, single-channel image with a 3x3 filter in the spectral
    domain: both operands are transformed with a 2-D FFT, multiplied
    element-wise, and transformed back. The real part of the result is then
    re-arranged quadrant by quadrant, cropped, and offset by the bias.

    PARAMETER DESCRIPTION:
    - image: NumPy array whose axis 1 gives the side length ``length`` and
      which can be reshaped to (length, length)
    - w: NumPy array with the nine filter weights (reshaped to 3x3)
    - b: bias added to every element of the convolved image
    - hparameters1: not read by this function; kept so existing callers that
      pass it keep working

    RETURNS:
    - 2-D NumPy array of shape (length - 2, length - 2), where ``length`` is
      the side of the input image
    '''

    # Side length of the (square) image and the zero-padding that grows the
    # 3x3 filter to the (odd) working size.
    length = image.shape[1]
    padding = (length - 3) // 2
    kernel = np.pad(w.reshape(3, 3), (padding, padding))
    image = image.reshape(length, length)

    if length % 2 == 0:
        # Even side: the padded filter has side length - 1, so drop the last
        # row and column of the image to match it.
        length -= 1
        image = image[:length, :length]
        # The even case keeps the final row/column of the re-arranged result.
        stop = length
    else:
        # The odd case additionally drops the final row/column.
        stop = length - 1
    limit = length // 2 + 1

    # Convolution theorem: a product in the spectral domain is a (circular)
    # convolution in the spatial domain.
    conv_image = np.fft.ifft2(np.fft.fft2(image) * np.fft.fft2(kernel))
    spatial = conv_image.real

    # Swap the quadrants of the result. Row 0 and column 0 of ``y`` are never
    # written and therefore stay zero; they are removed by the crop below.
    y = np.zeros((length, length))
    y[1:limit, 1:limit] = spatial[limit:, limit:]
    y[1:limit, limit:] = spatial[limit:, 1:limit]
    y[limit:, 1:limit] = spatial[1:limit, limit:]
    y[limit:, limit:] = spatial[1:limit, 1:limit]

    # Crop the border and add the bias.
    return y[1:stop, 1:stop] + b

### DEFINICIÓN DE LA CAPA DE SPECTRAL POOLING

In [20]:
import torch
import torch.nn as nn
from torch.autograd import Function
import math
from torch.nn.modules.utils import _pair

def _spectral_crop(input, oheight, owidth):
    """Keep the four corner blocks of a 4-D spectrum and stitch them together.

    Along each of the last two axes the first ``ceil(o / 2)`` entries and the
    last ``o - ceil(o / 2)`` entries are kept, so the result has ``oheight``
    rows and ``owidth`` columns whenever the input is at least that large.

    Args:
        input: 4-D tensor indexed as ``input[:, :, row, col]``.
        oheight: number of rows to keep.
        owidth: number of columns to keep.

    Returns:
        Tensor holding the top-left, top-right, bottom-left and bottom-right
        blocks of ``input`` concatenated in that layout.
    """
    low_h = math.ceil(oheight / 2)
    low_w = math.ceil(owidth / 2)
    high_h = oheight - low_h
    high_w = owidth - low_w

    # Use explicit start offsets instead of negative slices: ``-0:`` would
    # select the whole axis when only one row/column is requested. Clamping
    # at 0 mirrors how an oversized negative slice selects the whole axis.
    row_start = max(input.size(-2) - high_h, 0)
    col_start = max(input.size(-1) - high_w, 0)

    rows = torch.cat((input[:, :, :low_h], input[:, :, row_start:]), dim=-2)
    return torch.cat((rows[:, :, :, :low_w], rows[:, :, :, col_start:]), dim=-1)

def _spectral_pad(input, output, oheight, owidth):
    """Scatter the corner blocks of ``output`` into a zero tensor like ``input``.

    This is the counterpart of ``_spectral_crop``: the first ``ceil(o / 2)``
    and last ``o - ceil(o / 2)`` rows/columns of ``output`` are written to the
    same leading/trailing positions of a tensor shaped like ``input``; every
    other entry stays zero.

    Args:
        input: 4-D tensor that only provides the shape/dtype/device of the result.
        output: 4-D tensor whose corner blocks are copied.
        oheight: number of rows that were kept by the crop.
        owidth: number of columns that were kept by the crop.

    Returns:
        Tensor with the shape of ``input`` containing the copied blocks.
    """
    low_h = math.ceil(oheight / 2)
    low_w = math.ceil(owidth / 2)
    high_h = oheight - low_h
    high_w = owidth - low_w

    pad = torch.zeros_like(input)

    def _edges(size, low, high):
        # Leading and trailing index ranges along one axis. An explicit start
        # avoids ``-0:`` (whole axis) when ``high`` is zero.
        return slice(0, low), slice(max(size - high, 0), size)

    pad_rows = _edges(pad.size(-2), low_h, high_h)
    pad_cols = _edges(pad.size(-1), low_w, high_w)
    out_rows = _edges(output.size(-2), low_h, high_h)
    out_cols = _edges(output.size(-1), low_w, high_w)

    # Copy order: top-left, top-right, bottom-left, bottom-right.
    for pad_r, out_r in zip(pad_rows, out_rows):
        for pad_c, out_c in zip(pad_cols, out_cols):
            pad[:, :, pad_r, pad_c] = output[:, :, out_r, out_c]

    return pad

def DiscreteHartleyTransform(input):
    """Return the orthonormal 2-D discrete Hartley transform of a tensor.

    The transform is taken over the last two dimensions and computed from the
    complex FFT as ``Re(F) - Im(F)``. With the orthonormal scaling used here,
    applying the function twice returns the original tensor, so the same
    function serves as forward and inverse transform.

    Args:
        input: real floating-point torch tensor with at least two dimensions.

    Returns:
        Real tensor with the same shape as ``input``.
    """
    # ``norm="ortho"`` replaces ``normalized=True`` of the removed
    # ``torch.rfft`` API this function was originally written against.
    spectrum = torch.fft.fft2(input, dim=(-2, -1), norm="ortho")
    return spectrum.real - spectrum.imag

class SpectralPoolingFunction(Function):
    """Autograd function that pools a tensor by cropping its Hartley spectrum.

    Forward: Hartley transform -> keep the corner blocks of the spectrum ->
    Hartley transform again. Backward applies the mirrored steps to the
    incoming gradient: transform, scatter into a zero spectrum shaped like
    the original input, transform again.
    """

    @staticmethod
    def forward(ctx, input, oheight, owidth):
        """Return ``input`` reduced to ``oheight`` x ``owidth`` in the spectral domain."""
        # The target size and the input (for its shape) are needed in backward.
        ctx.oh = oheight
        ctx.ow = owidth
        ctx.save_for_backward(input)

        # Hartley transform of the input
        dht = DiscreteHartleyTransform(input)

        # frequency cropping
        all_together = _spectral_crop(dht, oheight, owidth)

        # transform the cropped spectrum back with the same function
        dht = DiscreteHartleyTransform(all_together)
        return dht

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``input``; the two size arguments get ``None``."""
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is deprecated.
        input, = ctx.saved_tensors

        # Hartley transform of the incoming gradient
        dht = DiscreteHartleyTransform(grad_output)
        # frequency padding back to the input's shape
        grad_input = _spectral_pad(input, dht, ctx.oh, ctx.ow)
        # transform the padded spectrum back with the same function
        grad_input = DiscreteHartleyTransform(grad_input)
        return grad_input, None, None

class SpectralPool2d(nn.Module):
    """Module wrapper around ``SpectralPoolingFunction``.

    ``scale_factor`` is expanded to a (height, width) pair; the output size
    is the input's last two dimensions multiplied by that pair, rounded up.
    """

    def __init__(self, scale_factor):
        super().__init__()
        self.scale_factor = _pair(scale_factor)

    def forward(self, input):
        """Pool ``input`` to ``ceil(H * scale_h)`` x ``ceil(W * scale_w)``."""
        in_height = input.size(-2)
        in_width = input.size(-1)
        out_height = math.ceil(in_height * self.scale_factor[0])
        out_width = math.ceil(in_width * self.scale_factor[1])
        return SpectralPoolingFunction.apply(input, out_height, out_width)

ModuleNotFoundError: No module named 'torch'

In [None]:
def fft_model(image, weights, hparameters1, hparameters2):
    """Run three FFT-convolution/ReLU/pooling stages and two dense layers.

    Args:
        image: input passed to the first ``fft_forward`` call.
        weights: indexable collection; entries 0-5 are the (filter, bias)
            pairs of the three convolution stages, 6-7 the weights and bias of
            the hidden dense layer, 8-9 those of the output layer.
        hparameters1: forwarded to every ``fft_forward`` call.
        hparameters2: forwarded to every ``pool_forward`` call.

    Returns:
        Tuple ``(ans, fft, pool, fft1, pool1, fft2, pool2)`` with the final
        sigmoid output followed by every intermediate activation.

    ``relu``, ``sigmoid`` and ``pool_forward`` must be defined by the caller's
    environment; only element 0 of ``pool_forward``'s return value is used.
    """
    # Stage 1: use the ``image`` argument (previously a global ``img`` was read).
    fft = relu(fft_forward(image, weights[0], weights[1], hparameters1))
    # Derive the side length from the activation instead of hard-coding 223.
    pool = pool_forward(fft.reshape(1, fft.shape[1], fft.shape[1], 1), hparameters2)[0]

    # Stage 2
    fft1 = relu(fft_forward(pool, weights[2], weights[3], hparameters1))
    pool1 = pool_forward(fft1.reshape(1, fft1.shape[1], fft1.shape[1], 1), hparameters2)[0]

    # Stage 3
    fft2 = relu(fft_forward(pool1, weights[4], weights[5], hparameters1))
    pool2 = pool_forward(fft2.reshape(1, fft2.shape[1], fft2.shape[1], 1), hparameters2)[0]

    # Dense head: flatten, hidden ReLU layer, sigmoid output.
    flatten = pool2.reshape(1, pool2.shape[1] * pool2.shape[2] * pool2.shape[3])
    neural_net = relu(np.dot(flatten, weights[6]) + weights[7])
    ans = sigmoid(np.dot(neural_net, weights[8]) + weights[9])

    tup = (ans, fft, pool, fft1, pool1, fft2, pool2)
    return tup

In [None]:
# Settings handed to the FFT convolution stages (via fft_model -> fft_forward).
hparameters1 = dict(stride=1, pad=0)

# Settings handed to the pooling stages (via fft_model -> pool_forward).
hparameters2 = dict(stride=2, f=2)

# Trained parameters pulled from an existing model object.
# NOTE(review): `model` is not defined anywhere in this notebook export.
weights = model.get_weights()

In [None]:
# Predicted class (0 or 1) for every test image, in order.
fft_out = []

start = time.time()
for sample in x_test:
    # Resize to the 225x225 single-channel layout used as model input.
    img = cv2.resize(sample, (225, 225)).reshape(1, 225, 225, 1)
    output1 = fft_model(img, weights, hparameters1, hparameters2)
    # Element 0 of the model output is the sigmoid score; threshold at 0.5.
    fft_out.append(1 if output1[0] >= 0.5 else 0)

end = time.time()
# Report the number of images actually processed and the elapsed wall time.
print('Time taken by FFT to predict class for %d images is %f seconds' % (len(fft_out), end - start))

In [7]:
class Activation:
    """Callable activation selector.

    ``Activation(name)`` builds an object that applies the named activation
    when called, i.e. ``Activation('relu')(x)``. The individual activations
    are also available as static methods.

    Args:
        function_name: one of ``'leaky_relu'``, ``'lrelu'``, ``'relu'`` or
            ``'sigmoid'``.

    Raises:
        ValueError: if ``function_name`` is not one of the supported names.
    """

    # Maps accepted activation names to the static method implementing them.
    _DISPATCH = {
        'leaky_relu': 'lrelu',
        'lrelu': 'lrelu',
        'relu': 'relu',
        'sigmoid': 'sigmoid',
    }

    def __init__(self, function_name='leaky_relu'):
        if function_name not in self._DISPATCH:
            raise ValueError('Unknown activation: %r' % (function_name,))
        self.function_name = function_name

    def __call__(self, x):
        """Apply the selected activation to ``x``."""
        return getattr(self, self._DISPATCH[self.function_name])(x)

    @staticmethod
    def lrelu(x, leak=0.2, name='leaky_relu', alt_relu_impl=False):
        """Leaky ReLU built from TensorFlow ops inside variable scope ``name``.

        NOTE(review): relies on a module-level ``tf`` exposing the TF1-style
        ``tf.variable_scope``; confirm the TensorFlow version in use.
        """
        with tf.variable_scope(name):
            if alt_relu_impl:
                f1 = 0.5 * (1 + leak)
                f2 = 0.5 * (1 - leak)
                # lrelu = 1/2 * (1 + leak) * x + 1/2 * (1 - leak) * |x|
                return f1 * x + f2 * abs(x)
            return tf.maximum(x, leak * x)

    @staticmethod
    def relu(x, function_name=None):
        """Element-wise ``max(0, x)``; ``function_name`` is accepted but unused."""
        return np.maximum(0, x)

    @staticmethod
    def sigmoid(x, function_name=None):
        """Element-wise logistic function; ``function_name`` is accepted but unused."""
        return 1 / (1 + np.exp(-x))


    
def InstanceNormalization(x):
    """Normalize ``x`` over axes 1 and 2, then apply a learned scale and offset.

    Mean and variance are computed per sample over axes ``[1, 2]``; the
    normalized tensor is multiplied by a ``scale`` variable and shifted by an
    ``offset`` variable, each sized by the last dimension of ``x``. Both
    variables live in the ``"instance_norm"`` variable scope.

    NOTE(review): uses TF1 graph APIs (``tf.variable_scope``,
    ``tf.get_variable``, ``tf.div``, ``keep_dims``); presumably ``x`` is laid
    out as (batch, height, width, channels) - confirm against callers.
    """
    eps = 1e-5

    with tf.variable_scope("instance_norm"):
        mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)

        n_channels = x.get_shape()[-1]
        scale = tf.get_variable(
            'scale',
            [n_channels],
            initializer=tf.truncated_normal_initializer(mean=1.0, stddev=0.02))
        offset = tf.get_variable(
            'offset',
            [n_channels],
            initializer=tf.constant_initializer(0.0))

        normalized = tf.div(x - mean, tf.sqrt(var + eps))
        return scale * normalized + offset

### COLOR TRANSFORMATION

In [4]:
def lab2rgb(self, L, AB):
    """Turn the first Lab image of a batch into an RGB NumPy image.

    Parameters:
        self: unused; kept because the function is written with a method signature.
        L  (1-channel tensor array): L channel images (range: [-1, 1], torch tensor array)
        AB (2-channel tensor array): ab channel images (range: [-1, 1], torch tensor array)
    Returns:
        rgb (RGB numpy image): rgb output image (range: [0, 255], numpy array)
    """
    # Undo the [-1, 1] scaling: L -> [0, 100], ab -> [-110, 110].
    lightness = (L + 1.0) * 50.0
    chroma = AB * 110.0

    # Stack channels, take the first batch item and move channels last.
    lab_batch = torch.cat([lightness, chroma], dim=1)
    first_item = lab_batch[0].data.cpu().float().numpy()
    lab_image = np.transpose(first_item.astype(np.float64), (1, 2, 0))

    # ``color`` is presumably ``skimage.color``; its lab2rgb output is scaled by 255 here.
    return color.lab2rgb(lab_image) * 255

    
def rgb2lab(inputColor):
    """Convert one 8-bit RGB colour to CIE L*a*b* (reference white D65, 2 degrees).

    Args:
        inputColor: iterable of up to three channel values in [0, 255]; a
            fourth value raises ``IndexError``.

    Returns:
        List ``[L, a, b]`` with every component rounded to four decimals.
    """
    # Gamma-expand each channel and rescale it to the 0-100 range.
    linear = [0, 0, 0]
    for position, channel in enumerate(inputColor):
        level = float(channel) / 255
        if level > 0.04045:
            level = ((level + 0.055) / 1.055) ** 2.4
        else:
            level = level / 12.92
        linear[position] = level * 100

    red, green, blue = linear

    # Linear RGB -> XYZ, rounded to four decimals before normalisation.
    xyz = [
        round(red * 0.4124 + green * 0.3576 + blue * 0.1805, 4),
        round(red * 0.2126 + green * 0.7152 + blue * 0.0722, 4),
        round(red * 0.0193 + green * 0.1192 + blue * 0.9505, 4),
    ]

    # Normalise by the reference white: X = 95.047, Y = 100.000, Z = 108.883.
    reference_white = (95.047, 100.0, 108.883)
    ratios = [float(component) / white
              for component, white in zip(xyz, reference_white)]

    # Cube root above the threshold, linear segment below it.
    mapped = []
    for ratio in ratios:
        if ratio > 0.008856:
            mapped.append(ratio ** (0.3333333333333333))
        else:
            mapped.append((7.787 * ratio) + (16 / 116))
    fx, fy, fz = mapped

    return [
        round((116 * fy) - 16, 4),
        round(500 * (fx - fy), 4),
        round(200 * (fy - fz), 4),
    ]

In [5]:
from skimage import io, color
# Round trip through scikit-image: load an image, convert RGB -> Lab, then
# convert the Lab result back to RGB (the final value bound to `rgb`).
# NOTE(review): `img` must already be bound to something io.imread accepts;
# this cell does not define it.
lab = color.rgb2lab(io.imread(img))
rgb = color.lab2rgb(lab)

NameError: name 'filename' is not defined

## MODELO

In [28]:
# Setting Learning Rate for different number of Epochs
def lr_schedule(epoch):
    """Return the learning rate for ``epoch`` and print it.

    The base rate is 1e-3; it is multiplied by 1e-1 after epoch 80, 1e-2
    after 120, 1e-3 after 160 and 0.5e-3 after 180.

    Args:
        epoch: epoch number the rate is requested for.

    Returns:
        The learning rate as a float.
    """
    lr = 1e-3
    if epoch > 180:
        lr *= 0.5e-3
    elif epoch > 160:
        lr *= 1e-3
    elif epoch > 120:
        lr *= 1e-2
    elif epoch > 80:
        lr *= 1e-1
    # %g keeps small rates readable; %f printed them as 0.000000.
    print('The learning rate is set to %g given the number of epochs.' % (lr))
    return lr


# Image preprocessing
img_height = 225
img_width = 225
img_size = img_height * img_width
image = cv2.resize(x_test[15], (img_height, img_width)).reshape(1, img_height, img_width, 1)
# NOTE(review): `im` is an array slice of `image`, not a file path, and has a
# single channel; confirm that io.imread / color.rgb2lab are really meant to
# be applied here. The loop also prints the whole `image`, not `lab`.
for im in image:
    rgb = io.imread(im)
    lab = color.rgb2lab(rgb)
    print(image)

# Defining model hyperparameters
batch_size = 50
pool_size = 100
epoch = 1000
gen_features = 32
disc_features = 64
# lr_schedule is defined above so this cell also runs in a fresh kernel.
learning_rate = lr_schedule(epoch)

The learning rate is set to 0.000000 given the number of epochs.


In [30]:
# Basic ResNet Building Block
def FastConvolutionLayer(inputs,
                 num_filters = 16,
                 kernel_size = 3,
                 strides = 1,
                 activation ='relu',
                 batch_normalization = False,
                 conv_first=False):
    """Apply convolution, normalization and activation to ``inputs``.

    With ``conv_first`` the order is convolution -> optional batch
    normalization -> optional activation. Otherwise it is normalization
    (batch normalization if requested, else ``InstanceNormalization``) ->
    optional activation -> convolution.

    Args:
        inputs: value fed to the first step.
        num_filters: passed as the first argument of ``fft_forward``.
        kernel_size: accepted but not read in the body.
        strides: accepted but not read in the body.
        activation: name handed to ``Activation``; ``None`` skips the step.
        batch_normalization: whether to apply ``BatchNormalization``.
        conv_first: selects the step order described above.

    Returns:
        The result of the last step.

    NOTE(review): ``w``, ``b`` and ``hparameters`` are not defined in this
    function, and the value returned by ``fft_forward`` is later called like
    a layer (``conv(x)``). This does not look runnable as written - confirm
    the intended convolution object.
    """
    
    # NOTE(review): num_filters is passed where fft_forward takes its image.
    conv = fft_forward(num_filters,w,b,hparameters)

    x = inputs
    if conv_first:
        x = conv(x)
        if batch_normalization:
            x = BatchNormalization()(x)
        if activation is not None:
            x = Activation(activation)(x)
    else:
        if batch_normalization:
            x = BatchNormalization()(x)
        else:
            # Instance normalization is only used on this (conv-last) path.
            x = InstanceNormalization(x)
        if activation is not None:
            x = Activation(activation)(x)
        x = conv(x)
    return x

In [10]:
# ResNet V2 architecture
def resnet_v2(input_shape, depth, num_classes = 10):
    """Assemble a three-stage bottleneck ResNet (v2 layout).

    Each stage stacks ``(depth - 2) / 9`` residual units made of three
    ``FastConvolutionLayer`` calls (kernel sizes 1, 3 (default), 1). The
    first unit of every stage also projects the shortcut so its filter count
    matches, and the first unit of stages 1 and 2 uses stride 2.

    Args:
        input_shape: shape handed to ``Input``.
        depth: network depth; must satisfy ``depth = 9n + 2``.
        num_classes: size of the final softmax layer.

    Returns:
        The ``Model`` built from the input and the softmax output.

    Raises:
        ValueError: if ``depth`` is not of the form ``9n + 2``.
    """
    if (depth - 2) % 9 != 0:
        # The message now matches the check (it previously said 6n + 2).
        raise ValueError('Argument Error: \'depth\' argument should be 9n + 2 (eg 20, 56, 110)')
    # Start model definition.
    num_filters_in = 16
    num_res_blocks = int((depth - 2) / 9)

    inputs = Input(shape = input_shape)
    # v2 performs Conv2D with BN-ReLU on input before splitting into 2 paths
    x = FastConvolutionLayer(inputs = inputs,
                    num_filters = num_filters_in,
                    conv_first = True)

    # Instantiate the stack of residual units
    for stage in range(3):
        for res_block in range(num_res_blocks):
            activation = 'relu'
            batch_normalization = True
            strides = 1
            if stage == 0:
                num_filters_out = num_filters_in * 4
                if res_block == 0: # first layer and first stage
                    activation = None
                    batch_normalization = False
            else:
                num_filters_out = num_filters_in * 2
                if res_block == 0: # first layer but not first stage
                    strides = 2 # downsample

            # bottleneck residual unit
            y = FastConvolutionLayer(inputs = x,
                            num_filters = num_filters_in,
                            kernel_size = 1,
                            strides = strides,
                            activation = activation,
                            batch_normalization = batch_normalization,
                            conv_first = False)
            y = FastConvolutionLayer(inputs = y,
                            num_filters = num_filters_in,
                            conv_first = False)
            y = FastConvolutionLayer(inputs = y,
                            num_filters = num_filters_out,
                            kernel_size = 1,
                            conv_first = False)
            if res_block == 0:
                # Linear projection residual shortcut connection to match
                # Changed dims
                x = FastConvolutionLayer(inputs = x,
                                        num_filters = num_filters_out,
                                        kernel_size = 1,
                                        strides = strides,
                                        activation = None,
                                        batch_normalization = False)
            # Residual connection: shortcut plus bottleneck output.
            x = x + y

        # The next stage starts from this stage's output width.
        num_filters_in = num_filters_out

    # Add classifier on top.
    # v2 has BN-ReLU before Pooling
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = AveragePooling2D(pool_size = 8)(x)
    y = Flatten()(x)
    outputs = Dense(num_classes,
                    activation ='softmax',
                    kernel_initializer ='he_normal')(y)

    # Instantiate model.
    model = Model(inputs = inputs, outputs = outputs)
    return model


In [11]:
# ResNet V1 architecture
def resnet_v1(input_shape, depth, num_classes = 10):
    """Assemble a three-stack ResNet (v1 layout).

    Each stack holds ``(depth - 2) / 6`` residual units of two
    ``resnet_layer`` calls. From the second stack on, the first unit uses
    stride 2 and projects the shortcut with a 1x1 layer; the filter count
    doubles after every stack.

    Args:
        input_shape: shape handed to ``Input``.
        depth: network depth; must satisfy ``depth = 6n + 2``.
        num_classes: size of the final softmax layer.

    Returns:
        The ``Model`` built from the input and the softmax output.

    Raises:
        ValueError: if ``depth`` is not of the form ``6n + 2``.

    NOTE(review): ``resnet_layer`` is not defined in this notebook export;
    confirm whether ``FastConvolutionLayer`` is meant to be used instead.
    """

    if (depth - 2) % 6 != 0:
        raise ValueError('Argument Error: \'depth\' argument should be 6n + 2 (eg 20, 32, 44 in [a])')
    # Start model definition.
    num_filters = 16
    num_res_blocks = int((depth - 2) / 6)

    inputs = Input(shape = input_shape)
    x = resnet_layer(inputs = inputs)
    # Instantiate the stack of residual units
    for stack in range(3):
        for res_block in range(num_res_blocks):
            strides = 1
            if stack > 0 and res_block == 0: # first layer but not first stack
                strides = 2 # downsample
            y = resnet_layer(inputs = x,
                            num_filters = num_filters,
                            strides = strides)
            y = resnet_layer(inputs = y,
                            num_filters = num_filters,
                            activation = None)
            if stack > 0 and res_block == 0: # first layer but not first stack
                # linear projection residual shortcut connection to match
                # changed dims
                x = resnet_layer(inputs = x,
                                num_filters = num_filters,
                                kernel_size = 1,
                                strides = strides,
                                activation = None,
                                batch_normalization = False)
            x = keras.layers.add([x, y])
            x = Activation('relu')(x)
        num_filters *= 2

    # Add classifier on top.
    # v1 does not use BN after last shortcut connection-ReLU
    x = AveragePooling2D(pool_size = 8)(x)
    y = Flatten()(x)
    outputs = Dense(num_classes,
                    activation ='softmax',
                    kernel_initializer ='he_normal')(y)

    # Instantiate model.
    model = Model(inputs = inputs, outputs = outputs)
    return model


class ResNet():
    """Namespace exposing the v1 builder as ``ResNet.resnet_v1``.

    The class header originally had no body, which made the file unparsable;
    the builder stays a module-level function and is also reachable here.
    """

    resnet_v1 = staticmethod(resnet_v1)

In [None]:
class GAN():
    """Fully connected GAN for 28x28 single-channel images.

    Holds a discriminator, a generator, and a combined model (generator
    followed by the frozen discriminator) that is used to train the generator.
    The Keras names used here (``Adam``, ``Input``, ``Model``, ``Sequential``,
    ``Dense``, ...), ``mnist`` and ``plt`` must be imported by the notebook.
    """

    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model  (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):
        """Return a model mapping a latent vector to an image of ``img_shape``."""
        model = Sequential()

        # NOTE(review): the method header and its opening lines had been
        # overwritten by pasted draft code. The first block below (Dense 256
        # on the latent input, LeakyReLU, BatchNormalization) is a
        # reconstruction following the pattern of the surviving layers -
        # confirm against the intended architecture.
        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        # tanh output matches the [-1, 1] rescaling applied to the training data.
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        """Return a model mapping an image to a single sigmoid validity score."""
        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Alternate discriminator and generator updates for ``epochs`` steps.

        Args:
            epochs: number of training iterations (one random batch each).
            batch_size: number of real and of generated images per iteration.
            sample_interval: a sample grid is saved every this many iterations.
        """
        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            # Average the [loss, accuracy] pairs of the real and fake batches.
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch, output_dir='/Users/aleja/Downloads/gan_imgs/'):
        """Save a 5x5 grid of generated images as ``<output_dir>/<epoch>.png``.

        Args:
            epoch: iteration number used as the file name.
            output_dir: destination folder; created if it does not exist.
                Defaults to the path that was previously hard-coded.
        """
        import os

        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1

        # savefig fails when the folder is missing, so create it first.
        os.makedirs(output_dir, exist_ok=True)
        fig.savefig(os.path.join(output_dir, "%d.png" % epoch))
        plt.close(fig)


# ---------------------------------------------------------------------------
# Draft material that had been pasted into the middle of the class body.
# It is kept below, commented out, because it cut the class in two and the
# draft function reads names (pool, pool1, pool2, ans) whose assignments are
# disabled; left live it would also shadow the earlier fft_model definition.
# NOTE(review): decide whether to finish or drop this draft.
#
# Sketch of an FFT-based layer stack:
#     Entradas = Input(shape=(filas,columnas,canales))
#     x = lrelu(fft_forward(Entradas))
#     x = lrelu(fft_forward(x))
#     x = SpectralPool2D(x)
#     x = instance_norm(x)
#
# def fft_model(img, weights, hparameters1, hparameters2):
#     fft = relu(fft_forward(img, weights[0], weights[1], hparameters1))
#     fft1 = relu(fft_forward(fft, weights[0], weights[1], hparameters1))
#
#     ##pool = pool_forward(fft.reshape(1,223,223,1), hparameters2)[0]
#     fft1 = relu(fft_forward(pool, weights[2], weights[3], hparameters1))
#     ##pool1 = pool_forward(fft1.reshape(1,fft1.shape[1],fft1.shape[1],1), hparameters2)[0]
#     fft2 = relu(fft_forward(pool1, weights[4], weights[5], hparameters1))
#     ##pool2 = pool_forward(fft2.reshape(1,fft2.shape[1],fft2.shape[1],1), hparameters2)[0]
#     ##flatten = pool2.reshape(1,pool2.shape[1]*pool2.shape[2]*pool2.shape[3])
#     ##neural_net = relu(np.dot(flatten, weights[6]) + weights[7])
#     ##ans = sigmoid(np.dot(neural_net, weights[8]) + weights[9])
#
#     tup = (ans, fft, pool, fft1, pool1, fft2, pool2)
#     return tup
# ---------------------------------------------------------------------------


# Script entry point: build the GAN and train it for 1000 iterations with
# batches of 32, saving a sample grid every 100 iterations.
if __name__ == '__main__':
    gan = GAN()
    gan.train(
        epochs=1000,
        batch_size=32,
        sample_interval=100,
    )