Implementation using Theano and Lasagne. It is planned to be replaced with a TensorFlow/Keras implementation.

In [None]:
import theano
import theano.tensor as T
import lasagne

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import time
import random 
from PIL import Image
import sys, os
import joblib
%matplotlib inline

# Define Main Functions

In [None]:
relu        = lasagne.nonlinearities.rectify
lrelu       = lasagne.nonlinearities.LeakyRectify(0.2)
tanh        = lasagne.nonlinearities.tanh
sigmoid     = lasagne.nonlinearities.sigmoid


def conv(incoming, num_filters, filter_size, W, b, nonlinearity):
    """Build a stride-2, 'same'-padded Conv2DLayer that reuses the given W and b.

    W / b are passed straight to Lasagne (the discriminator passes shared
    variables for W and None for b).
    """
    return lasagne.layers.Conv2DLayer(incoming, num_filters, filter_size, stride=(2, 2), pad='same',
                                      W=W, b=b, flip_filters=True, nonlinearity=nonlinearity)


def tconv(incoming, num_filters, filter_size, W, nonlinearity):
    """Build a stride-2 TransposedConv2DLayer with crop='same' that reuses the given W.

    No bias argument is forwarded, so Lasagne's default bias is used.
    """
    return lasagne.layers.TransposedConv2DLayer(incoming, num_filters, filter_size, stride=(2, 2),
                                                crop='same', W=W, nonlinearity=nonlinearity)


batchnorm   = lasagne.layers.batch_norm

# bias and weight initializations
w_init      = lasagne.init.Normal(std=0.02)           # conv / transposed-conv weights
b_init      = lasagne.init.Constant(val=0.0)          # batch-norm beta
g_init      = lasagne.init.Normal(mean=1.,std=0.02)   # batch-norm gamma (also seeds the wave parameters)

def sharedX(X, dtype=theano.config.floatX, name=None):
    """Wrap `X` in a Theano shared variable after converting it to `dtype` (floatX by default)."""
    value = np.asarray(X, dtype=dtype)
    return theano.shared(value, name=name)

from theano.tensor.shared_randomstreams import RandomStreams
# Module-level symbolic RNG (fixed seed); PeriodicLayer draws its random phase offsets from it.
srng = RandomStreams(seed=234)

In [None]:
def upsample(zx, depth):
    """Return the spatial size reached from `zx` after `depth` doublings of the form n -> 2*n - 1.

    Closed form: (zx - 1) * 2**depth + 1. Config uses it for the generator
    output size `npx`.
    """
    scale = pow(2, depth)
    return 1 + (zx - 1) * scale

def sample_z(config, batch_size, zx):
    """Draw a latent tensor Z of shape (batch_size, config.nz, zx, zx).

    Channel layout:
      - [0, nz_global)                    : uniform(-1, 1), one value per sample,
                                            broadcast over all spatial positions
      - [nz_global, nz_global + nz_local) : uniform(-1, 1), independent per position
      - last 2 * nz_periodic channels     : coordinate ramps, in pairs. Channel -2*i
                                            varies along the last axis and channel
                                            -2*i + 1 along axis 2. PeriodicLayer
                                            multiplies them by wave numbers; left at
                                            zero they would give a spatially constant
                                            sin(phase).
    """
    Z = np.zeros( (batch_size, config.nz, zx, zx) )
    Z[:, config.nz_global:config.nz_global+config.nz_local] = np.random.uniform(-1.,1., (batch_size, config.nz_local, zx, zx) )
    Z[:,:config.nz_global] = np.random.uniform(-1.,1., (batch_size, config.nz_global, 1, 1) )

    n_periodic = getattr(config, "nz_periodic", 0)
    if n_periodic > 0:
        coords = np.arange(zx)
        for i in range(1, n_periodic + 1):
            # each pair gets a different base frequency in [pi/2, pi]
            band = np.pi * (0.5 * i / float(n_periodic) + 0.5)
            Z[:, -2 * i] = coords[np.newaxis, :] * band       # ramp along the last axis
            Z[:, -2 * i + 1] = coords[:, np.newaxis] * band   # ramp along axis 2
    return Z

def get_data(dir_data: str, batch_size: int, crop_size: int):
    """ endless generator of random square crops taken from a single image file
    inputs:
           - dir_data (str) : path to the training image file
           - batch_size (int) : number of crops per yielded batch
           - crop_size (int) : side length of each square crop
    yields:
           - float array of shape (batch_size, 3, crop_size, crop_size) scaled from
             [0, 255] to [-1, 1], the same range as the generator's tanh output
    raises:
           - ValueError if the image is smaller than crop_size in either dimension
    """
    # close the file handle; convert so grayscale / RGBA inputs yield exactly 3 channels
    with Image.open(dir_data) as im:
        img = np.asarray(im.convert("RGB"))
    n_rows, n_cols = img.shape[:2]
    if n_rows < crop_size or n_cols < crop_size:
        raise ValueError(f"image of size {n_rows}x{n_cols} is smaller than crop_size {crop_size}")

    while True:
        batch = np.zeros( (batch_size, 3, crop_size, crop_size))
        for i in range(batch_size):
            # randint is inclusive, so the crop touching the bottom/right edge is also possible
            r = random.randint(0, n_rows - crop_size)
            c = random.randint(0, n_cols - crop_size)
            crop = img[r:r + crop_size, c:c + crop_size]
            batch[i] = crop.transpose(2, 0, 1)   # HWC -> CHW
        yield batch / 127.5 - 1.0

In [None]:
class Config(object):
    """Hyper-parameters for noise sampling, network architecture, training and file paths."""

    # optimisation (class attributes shared by all instances)
    lr          = 0.0002                # Adam learning rate
    b1          = 0.5                   # Adam beta1
    l2_fac      = 1e-8                  # weight of the L2 penalty added to both losses
    epoch_count = 100
    k           = 1                     # number of D updates vs G updates
    batch_size  = 25
    epoch_iters = batch_size * 500      # samples per epoch; iterations = epoch_iters / batch_size

    def __init__(self):
        # sampling
        self.nz_local = 30                  # spatially independent noise channels
        self.nz_global = 15                 # one value per sample, broadcast over space
        self.nz_periodic = 0                # number of periodic channel pairs; 0 disables them
        self.nz_periodic_MLPnodes = 50      # hidden width of the global-noise -> wave-number MLP
        self.nz          = self.nz_local+self.nz_global+self.nz_periodic*2
        self.periodic_affine = False        # if True every sinusoid mixes both coordinate channels
        self.zx          = 6                    # number of spatial dimensions in Z
        self.zx_sample   = 32                   # size of the spatial dimension in Z for producing the samples

        # network
        self.nc          = 3                     # number of channels in input X (i.e. r,g,b)
        self.gen_ks      = ([(5,5)] * 5)[::-1]   # kernel sizes for each layer for G
        self.dis_ks      = [(5,5)] * 5           # kernel sizes on each layer for D
        self.gen_ls      = len(self.gen_ks)      # number of layers for G
        self.dis_ls      = len(self.dis_ks)      # number of layers for D
        # Output channels of each G layer: 512, 256, 128, 64, then nc for the image.
        self.gen_fn      = [self.nc]+[2**(n+6) for n in range(self.gen_ls-1)]
        self.gen_fn      = self.gen_fn[::-1]
        # Output channels of each D layer (dis_fn[l] is the filter count of conv l; conv 0
        # reads the nc-channel image): 64, 128, 256, 512, then one real/fake map.
        self.dis_fn      = [2**(n+6) for n in range(self.dis_ls-1)] + [1]
        self.npx         = upsample(self.zx, self.gen_ls) # side length of generated / cropped images

        ## directory
        self.save_name   = "trained network"
        self.load_name   = None # if None, train from scratch
        self.dir_data    = "/path/to/directory/img.jpg"

    def data_iter(self):
        """Endless iterator of training batches of npx x npx crops from dir_data."""
        return get_data(self.dir_data, self.batch_size, self.npx)

    def __str__(self): # print_info
        return_str = f"Learning and generating samples from zx {self.zx}, which yields images of size npx {self.npx}\n"
        return_str += f"Generator: {self.gen_fn}\n"
        return_str += f"Discriminator: {self.dis_fn}\n"
        return_str += f"Saving samples and model data to file {self.save_name}"
        return return_str

In [None]:
class PeriodicLayer(lasagne.layers.Layer):
    """Replace the last 2 * config.nz_periodic channels of Z with sinusoids.

    Those channels (presumably spatial coordinate ramps, in pairs) are multiplied
    by wave numbers, shifted by a random phase per sample and channel, and passed
    through sin(). The wave numbers come from a two-layer MLP on the global channels
    when config.nz_global > 0, and otherwise from a single parameter vector. The
    channel count is unchanged. With config.nz_periodic == 0 the layer is the identity.
    """

    def __init__(self,incoming,config,wave_params):
        # Base init sets input_layer, input_shape, name, params and get_output_kwargs.
        super(PeriodicLayer, self).__init__(incoming)
        self.config = config
        self.wave_params = wave_params
        for p in wave_params:
            # Tags must be a set of tag strings: set('trainable') is a set of single
            # characters, and get_all_params(trainable=True) would never select p.
            self.params[p] = {'trainable'}

    def _wave_calculation(self,Z):
        if self.config.nz_periodic ==0:
            return Z
        nPeriodic = self.config.nz_periodic

        if self.config.nz_global > 0:  # #MLP or directly a weight vector in case of no Global dims
            h = T.tensordot(Z[:, :self.config.nz_global], self.wave_params[0], [1, 0]).dimshuffle(0, 3, 1, 2) + self.wave_params[1].dimshuffle('x', 0, 'x', 'x')
            band0 = (T.tensordot(relu(h),self.wave_params[2], [1, 0]).dimshuffle(0, 3, 1, 2)) + self.wave_params[3].dimshuffle('x', 0, 'x', 'x')  # #moved relu inside
        else:
            band0 = self.wave_params[0].dimshuffle('x', 0, 'x', 'x')

        # Z[:, -2P::2] and Z[:, -2P+1::2] are the two members of each periodic pair;
        # band0 holds 4P wave-number entries.
        if self.config.periodic_affine:
            band1 = Z[:, -nPeriodic * 2::2] * band0[:, :nPeriodic] + Z[:, -nPeriodic * 2 + 1::2] * band0[:, nPeriodic:2 * nPeriodic]
            band2 = Z[:, -nPeriodic * 2::2] * band0[:, 2 * nPeriodic:3 * nPeriodic] + Z[:, -nPeriodic * 2 + 1::2] * band0[:, 3 * nPeriodic:]
        else:
            band1 = Z[:, -nPeriodic * 2::2] * band0[:, :nPeriodic]
            band2 = Z[:, -nPeriodic * 2 + 1::2] * band0[:, 3 * nPeriodic:]
        band = T.concatenate([band1 , band2], axis=1)

        # random phase in [0, 2*pi) per sample and periodic channel
        band += srng.uniform((Z.shape[0],nPeriodic * 2)).dimshuffle(0,1, 'x', 'x') *np.pi*2
        return T.concatenate([Z[:, :-2 * nPeriodic], T.sin(band)], axis=1)

    def get_output_for(self, input, **kwargs):
        return self._wave_calculation(input)

    def get_output_shape_for(self, input_shape):
        # The periodic channels are replaced in place, so the shape is unchanged.
        return tuple(input_shape)


periodic = PeriodicLayer

In [None]:
class GAN(object):
    """Spatial GAN (optionally with periodic latent channels) assembled from Lasagne layers.

    All weights are explicit lists of Theano shared variables. The training-mode
    generator, the deterministic generator and the real/fake discriminator graphs
    therefore share them, and save() can serialise them with joblib.
    """

    def __init__(self, load=None):
        """
        :param load: (str) path to a file written by save(). If None, a fresh model is
                     initialised from Config.
                     NOTE: joblib.load unpickles the file; only load files you trust.
        """
        vals = None
        if load is not None:
            print( f"loading trained model from {load}")
            vals = joblib.load(load)
            self._restore_params(vals)
        else:
            self.config = Config()
            self._setup_gen_params(self.config.gen_ks, self.config.gen_fn)
            self._setup_dis_params(self.config.dis_ks, self.config.dis_fn)
            self._sample_initials()
            self._setup_wave_params()

        self._build_sgan()

        if vals is not None:
            # The BN statistic variables are created by _build_sgan, so fill them in afterwards.
            self._restore_bn_stats(vals)

    def _restore_params(self, vals):
        """Recreate the shared parameters from `vals` and derive the layer geometry from weight shapes."""
        self.config = vals["config"]
        self.dis_W = [sharedX(p) for p in vals["dis_W"]]
        self.dis_g = [sharedX(p) for p in vals["dis_g"]]
        self.dis_b = [sharedX(p) for p in vals["dis_b"]]
        self.gen_W = [sharedX(p) for p in vals["gen_W"]]
        self.gen_g = [sharedX(p) for p in vals["gen_g"]]
        self.gen_b = [sharedX(p) for p in vals["gen_b"]]
        self.wave_params = [sharedX(p) for p in vals["wave_params"]]

        # Generator weights are (in_channels, out_channels, kh, kw), as built in _sample_initials.
        gen_W = vals["gen_W"]
        self.config.nz = gen_W[0].shape[0]
        self.config.gen_fn = [W.shape[1] for W in gen_W]
        self.config.gen_ks = [(W.shape[2], W.shape[3]) for W in gen_W]
        self.config.nc = gen_W[-1].shape[1]

        # Discriminator weights are (out_channels, in_channels, kh, kw).
        dis_W = vals["dis_W"]
        self.config.dis_fn = [W.shape[0] for W in dis_W]
        self.config.dis_ks = [(W.shape[2], W.shape[3]) for W in dis_W]

        self._setup_gen_params(self.config.gen_ks, self.config.gen_fn)
        self._setup_dis_params(self.config.dis_ks, self.config.dis_fn)

        # Files without a stored output bias fall back to a zero bias.
        if "gen_out_b" in vals:
            self.gen_out_b = sharedX(vals["gen_out_b"])
        else:
            self.gen_out_b = sharedX(np.zeros(self.gen_fn[-1]))

    def _restore_bn_stats(self, vals):
        """Copy saved generator BN mean / inv_std values into the built graph, when present."""
        for var, value in zip(self.bn_mean, vals.get("m", [])):
            var.set_value(value)
        for var, value in zip(self.bn_inv_std, vals.get("istd", [])):
            var.set_value(value)

    def save(self, dir_save=None):
        """Serialise config, all weights and the generator BN statistics with joblib.

        :param dir_save: (str) destination file; defaults to self.config.save_name.
        """
        if dir_save is None:
            dir_save = self.config.save_name
        print( f"saving trained model at {dir_save}" )

        vals = {}
        vals["config"] = self.config
        vals["dis_W"] = [p.get_value() for p in self.dis_W]
        vals["dis_g"] = [p.get_value() for p in self.dis_g]
        vals["dis_b"] = [p.get_value() for p in self.dis_b]
        vals["gen_W"] = [p.get_value() for p in self.gen_W]
        vals["gen_g"] = [p.get_value() for p in self.gen_g]
        vals["gen_b"] = [p.get_value() for p in self.gen_b]
        vals["gen_out_b"] = self.gen_out_b.get_value()
        vals["wave_params"] = [p.get_value() for p in self.wave_params]
        vals["m"] = [p.get_value() for p in self.bn_mean]
        vals["istd"] = [p.get_value() for p in self.bn_inv_std]
        joblib.dump(vals, dir_save, True)

    def _setup_wave_params(self):
        """Create the parameters PeriodicLayer uses for its wave numbers (empty list if nz_periodic == 0).

        With global noise and a non-zero hidden width this is [lin1, bias1, lin2, bias2],
        an MLP from the global channels to 4 * nz_periodic values. Otherwise it is a single
        vector [bias2]. The last parameter is initialised to 1 on entries [:P] and [3P:]
        and to 0 elsewhere.
        """
        if not self.config.nz_periodic:
            self.wave_params = []
            return

        nPeriodic = self.config.nz_periodic
        nperiodK = self.config.nz_periodic_MLPnodes
        n_wave = nPeriodic * 2 * 2
        if self.config.nz_global > 0 and nperiodK > 0:
            lin1 = sharedX( g_init.sample( (self.config.nz_global, nperiodK)))
            bias1 = sharedX( g_init.sample( (nperiodK)))
            lin2 = sharedX( g_init.sample( (nperiodK, n_wave)))
            bias2 = sharedX( g_init.sample( (n_wave)))
            self.wave_params = [lin1, bias1, lin2, bias2]
        else:
            # No global dimensions: learn global wave numbers directly.
            # NOTE(review): PeriodicLayer only uses this single-vector form when nz_global == 0;
            # nz_global > 0 with nz_periodic_MLPnodes == 0 looks unsupported.
            self.wave_params = [sharedX( g_init.sample( (n_wave)))]
        a = np.zeros(n_wave)
        a[:nPeriodic] = 1        # x
        a[3 * nPeriodic:] = 1    # y
        self.wave_params[-1].set_value(np.float32(a))

    def _setup_gen_params(self, gen_ks, gen_fn):
        """Store generator kernel sizes and per-layer output channels (defaults used when None)."""
        self.gen_ks = [(5,5)] * 5 if gen_ks is None else gen_ks
        self.gen_depth = len(self.gen_ks)

        if gen_fn is not None:
            assert len(gen_fn)==len(self.gen_ks), 'Layer number of filter numbers and sizes does not match.'
            self.gen_fn = gen_fn
        else:
            self.gen_fn = [64] * self.gen_depth

    def _setup_dis_params(self, dis_ks, dis_fn):
        """Store discriminator kernel sizes and per-layer output channels (defaults used when None)."""
        self.dis_ks = [(5,5)] * 5 if dis_ks is None else dis_ks
        self.dis_depth = len(self.dis_ks)

        if dis_fn is not None:
            assert len(dis_fn)==len(self.dis_ks), 'Layer number of filter numbers and sizes does not match.'
            self.dis_fn = dis_fn
        else:
            self.dis_fn = [64] * self.dis_depth

    def _sample_initials(self):
        """Randomly initialise all generator / discriminator shared parameters."""
        # Discriminator conv weights: (out_channels, in_channels, kh, kw).
        self.dis_W, self.dis_b, self.dis_g = [], [], []
        self.dis_W.append( sharedX( w_init.sample( (self.dis_fn[0], self.config.nc, self.dis_ks[0][0], self.dis_ks[0][1]) )) )
        for l in range(self.dis_depth-1):
            self.dis_W.append( sharedX( w_init.sample( (self.dis_fn[l+1], self.dis_fn[l], self.dis_ks[l+1][0], self.dis_ks[l+1][1]) ) ) )
            self.dis_b.append( sharedX( b_init.sample( (self.dis_fn[l+1]) ) ) )
            self.dis_g.append( sharedX( g_init.sample( (self.dis_fn[l+1]) ) ) )

        # Generator BN beta / gamma for every layer except the output layer.
        self.gen_b, self.gen_g = [], []
        for l in range(self.gen_depth-1):
            self.gen_b += [sharedX( b_init.sample( (self.gen_fn[l]) ) ) ]
            self.gen_g += [sharedX( g_init.sample( (self.gen_fn[l]) ) ) ]

        # Generator transposed-conv weights: (in_channels, out_channels, kh, kw).
        self.gen_W = []
        last = self.config.nz
        for l in range(self.gen_depth-1):
            self.gen_W +=[sharedX( w_init.sample((last,self.gen_fn[l], self.gen_ks[l][0],self.gen_ks[l][1])))]
            last=self.gen_fn[l]
        self.gen_W +=[sharedX( w_init.sample((last,self.gen_fn[-1], self.gen_ks[-1][0],self.gen_ks[-1][1])))]

        # Bias of the tanh output layer, shared by both generator graphs and saved.
        self.gen_out_b = sharedX( b_init.sample( (self.gen_fn[-1],) ) )

    def _output_tconv(self, incoming):
        """Final tanh transposed conv; W and the bias gen_out_b are shared across generator graphs."""
        return lasagne.layers.TransposedConv2DLayer(incoming, self.gen_fn[-1], self.gen_ks[-1], stride=(2, 2),
                                                    crop='same', W=self.gen_W[-1], b=self.gen_out_b,
                                                    nonlinearity=tanh)

    def _spatial_generator(self, inlayer):
        """Training-mode generator.

        :returns: (output layer, list of BN mean variables, list of BN inv_std variables)
        """
        layer = periodic(inlayer, self.config, self.wave_params)
        bn_mean, bn_inv_std = [], []
        for l in range(self.gen_depth-1):
            layer = batchnorm(tconv(layer, self.gen_fn[l], self.gen_ks[l], self.gen_W[l], nonlinearity=relu),
                              gamma=self.gen_g[l], beta=self.gen_b[l], alpha=1.0)
            # batchnorm() returns a layer whose .input_layer holds mean / inv_std
            bn_mean.append(layer.input_layer.mean)
            bn_inv_std.append(layer.input_layer.inv_std)
        return self._output_tconv(layer), bn_mean, bn_inv_std

    def _spatial_generator_det(self, inlayer):
        """Inference generator: same weights, BN statistics tied to the training-mode generator's."""
        layer = periodic(inlayer, self.config, self.wave_params)
        for l in range(self.gen_depth-1):
            layer = batchnorm(tconv(layer, self.gen_fn[l], self.gen_ks[l], self.gen_W[l], nonlinearity=relu),
                              gamma=self.gen_g[l], beta=self.gen_b[l],
                              mean=self.bn_mean[l], inv_std=self.bn_inv_std[l])
        return self._output_tconv(layer)

    def _spatial_discriminator(self, inlayer):
        """Stride-2 conv stack without biases: lrelu, then lrelu+BN layers, then a sigmoid output."""
        layer = conv(inlayer, self.dis_fn[0], self.dis_ks[0], self.dis_W[0], None, nonlinearity=lrelu)
        for l in range(1, self.dis_depth-1):
            layer = batchnorm(conv(layer, self.dis_fn[l], self.dis_ks[l], self.dis_W[l], None, nonlinearity=lrelu),
                              gamma=self.dis_g[l-1], beta=self.dis_b[l-1])
        return conv(layer, self.dis_fn[-1], self.dis_ks[-1], self.dis_W[-1], None, nonlinearity=sigmoid)

    def _build_sgan(self):
        """Build all graphs, the GAN losses and Adam updates, and compile the Theano functions."""
        Z               = lasagne.layers.InputLayer((None,self.config.nz,None,None))
        X               = lasagne.layers.InputLayer((self.config.batch_size,self.config.nc,self.config.npx,self.config.npx))
        self.forDebug = Z

        gen_X, self.bn_mean, self.bn_inv_std = self._spatial_generator(Z)
        self.im = self.bn_mean + self.bn_inv_std   # kept for backward compatibility (means, then inv_stds)

        gen_X_det       = self._spatial_generator_det(Z)
        d_real          = self._spatial_discriminator(X)
        d_fake          = self._spatial_discriminator(gen_X)

        prediction_gen  = lasagne.layers.get_output(gen_X)
        prediction_gen_det  = lasagne.layers.get_output(gen_X_det,deterministic=True)
        prediction_real = lasagne.layers.get_output(d_real)
        prediction_fake = lasagne.layers.get_output(d_fake)

        params_g        = lasagne.layers.get_all_params(gen_X, trainable=True)
        params_d        = lasagne.layers.get_all_params(d_real, trainable=True)

        l2_gen          = lasagne.regularization.regularize_network_params(gen_X, lasagne.regularization.l2)
        l2_dis          = lasagne.regularization.regularize_network_params(d_real, lasagne.regularization.l2)

        # standard (non-saturating) GAN objectives plus a small L2 penalty
        obj_d= -T.mean(T.log(1-prediction_fake)) - T.mean( T.log(prediction_real)) + self.config.l2_fac * l2_dis
        obj_g= -T.mean(T.log(prediction_fake)) + self.config.l2_fac * l2_gen

        updates_d       = lasagne.updates.adam(obj_d, params_d, self.config.lr, self.config.b1)
        updates_g       = lasagne.updates.adam(obj_g, params_g, self.config.lr, self.config.b1)

        st = time.time()
        self.train_d    = theano.function([X.input_var, Z.input_var], obj_d, updates=updates_d, allow_input_downcast=True)
        print( f"Compiling Discriminator {round( time.time() - st, 4 )}s" )

        st = time.time()
        self.train_g    = theano.function([Z.input_var], obj_g, updates=updates_g, allow_input_downcast=True)
        print( f"Compiling Generator {round( time.time() - st, 4 )}s" )

        st = time.time()
        self.generate   = theano.function([Z.input_var], prediction_gen, allow_input_downcast=True)
        self.generate_det   = theano.function([Z.input_var], prediction_gen_det, allow_input_downcast=True)
        print( f"Compiling rest {round( time.time() - st, 4 )}s" )

# Train model

In [None]:
gan = GAN()
print( gan.config )
c = gan.config

epoch = 0
total_iters = 0
c.epoch_count = 10   # notebook override of Config.epoch_count for a shorter run
n_iters = c.epoch_iters // c.batch_size   # combined G + D steps per epoch
while epoch < c.epoch_count:
    epoch += 1
    print( f"epoch {epoch}" )

    Gcost = []
    Dcost = []
    for i_iters, samples in enumerate( tqdm(c.data_iter(), total=n_iters) ):
        if i_iters >= n_iters:
            break
        total_iters += 1

        zs = sample_z(c, c.batch_size, c.zx)
        # every (k+1)-th step updates G, the other k steps update D
        if total_iters % (c.k+1) == 0:
            Gcost.append(gan.train_g(zs))
        else:
            Dcost.append(gan.train_d(samples, zs))

    print( f"Epoch {epoch} - G ({np.mean(Gcost)}), D ({np.mean(Dcost)})" )

    # Preview three samples from fresh noise; map tanh output [-1, 1] to [0, 1] for imshow.
    outs = gan.generate(sample_z(c, 3, c.zx))
    outs = np.clip((outs.transpose( (0,2,3,1) ) + 1.0) / 2.0, 0.0, 1.0)

    plt.figure(figsize=(10,5))
    for j, preview in enumerate(outs):
        plt.subplot( 1, 3, j + 1 )
        plt.imshow( preview )
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
gan.save(gan.config.save_name) # save trained model to the configured file