In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

# !pip install --upgrade tensorflow
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the "../input/" directory.
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively print the full path of every file under /kaggle/input so the
# available datasets are visible in the notebook output.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# Any results you write to the current directory are saved as output.

In [None]:
# 1 model encoder
# 1.1 models 2 encoders
# 2 separate decoders, one for each genre
# discriminator model, latent space
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D,ReLU,BatchNormalization,Conv2DTranspose,MaxPool2D,Dense,Flatten,Layer
print(tf.__version__)

# Scale factor applied to the Wasserstein critic loss in Model.loss_disc.
disclosswt = 1e-4

class MaxPool2D(Layer):
    """2-D max pooling that also returns the argmax index of every window.

    NOTE: this intentionally shadows the ``MaxPool2D`` imported from
    ``tensorflow.keras.layers``; the encoder needs the extra argmax output so
    that ``MaxUnpool2D`` can put values back at their original positions.

    Args:
        ksize: (height, width) of the pooling window.
        strides: (height, width) step between windows.
        padding: ``'same'`` or ``'valid'`` (case-insensitive), passed on to
            ``tf.nn.max_pool_with_argmax``.
        **kwargs: forwarded to ``Layer`` (e.g. ``name``); previously dropped.
    """

    def __init__(
            self,
            ksize=(2, 2),
            strides=(2, 2),
            padding='same',
            **kwargs):
        # autocast=False keeps inputs out of mixed-precision autocasting, as
        # in the original; extra Layer kwargs are now honoured too.
        super(MaxPool2D, self).__init__(autocast=False, **kwargs)
        self.padding = padding
        self.pool_size = ksize
        self.strides = strides

    def call(self, inputs, **kwargs):
        """Return ``[pooled, argmax]`` for an NHWC ``inputs`` tensor.

        ``argmax`` is produced by ``tf.nn.max_pool_with_argmax`` with its
        default ``include_batch_in_index=False`` (flat index within each batch
        item) and is cast to float64.
        """
        ksize = [1, self.pool_size[0], self.pool_size[1], 1]
        strides = [1, self.strides[0], self.strides[1], 1]
        output, argmax = tf.nn.max_pool_with_argmax(
                inputs,
                ksize=ksize,
                strides=strides,
                padding=self.padding.upper())
        # float64 represents every integer index below 2**53 exactly;
        # MaxUnpool2D casts it back to int32.
        argmax = tf.cast(argmax, tf.float64)
        return [output, argmax]

    def compute_output_shape(self, input_shape):
        """Return the (identical) static shapes of ``[pooled, argmax]``.

        Uses the configured window, strides and padding instead of assuming
        a fixed 2x reduction. Unknown (``None``) dimensions stay ``None``;
        batch and channel dimensions pass through unchanged.
        """
        same = self.padding.upper() == 'SAME'
        windows = (1, self.pool_size[0], self.pool_size[1], 1)
        steps = (1, self.strides[0], self.strides[1], 1)
        output_shape = []
        for idx, dim in enumerate(input_shape):
            if dim is None or idx in (0, 3):
                output_shape.append(dim)
            elif same:
                # SAME padding: ceil(dim / stride).
                output_shape.append(-(-dim // steps[idx]))
            else:
                # VALID padding: floor((dim - window) / stride) + 1.
                output_shape.append((dim - windows[idx]) // steps[idx] + 1)
        output_shape = tuple(output_shape)
        return [output_shape, output_shape]

    def compute_mask(self, inputs, mask=None):
        """Neither output carries a Keras mask."""
        return 2 * [None]


class MaxUnpool2D(Layer):
    """Inverse of ``MaxPool2D``: scatter pooled values back to their argmax slots.

    Positions that did not hold a window maximum are filled with zeros
    (``tf.scatter_nd`` into an all-zero tensor).

    Args:
        ksize: (height, width) upsampling factor; should match the pooling
            window/stride that produced the indices.
        **kwargs: forwarded to ``Layer``.
    """

    def __init__(self, ksize=(2, 2), **kwargs):
        super(MaxUnpool2D, self).__init__(autocast=False, **kwargs)
        self.size = ksize

    def call(self, inputs, output_shape=None):
        """Unpool ``inputs = [updates, mask]``.

        Args:
            inputs: ``[updates, mask]``; ``updates`` is an NHWC tensor and
                ``mask`` holds, per batch item, the flat argmax index
                ``(y * W_out + x) * C + c`` produced by ``MaxPool2D``.
            output_shape: optional NHWC target shape; defaults to the input
                height/width multiplied by ``self.size``.

        Returns:
            The unpooled tensor of shape ``output_shape``.
        """
        updates, mask = inputs[0], inputs[1]
        mask = tf.cast(mask, 'int32')
        input_shape = tf.shape(updates, out_type='int32')
        if output_shape is None:
            output_shape = (
                    input_shape[0],
                    input_shape[1]*self.size[0],
                    input_shape[2]*self.size[1],
                    input_shape[3])
        # Kept for backward compatibility with code that inspects it.
        self.output_shape1 = output_shape

        # Recover (batch, y, x, feature) coordinates for every update value.
        one_like_mask = tf.ones_like(mask, dtype='int32')
        batch_shape = tf.concat(
                [[input_shape[0]], [1], [1], [1]],
                axis=0)
        batch_range = tf.reshape(
                tf.range(output_shape[0], dtype='int32'),
                shape=batch_shape)
        b = one_like_mask * batch_range
        y = mask // (output_shape[2] * output_shape[3])
        x = (mask // output_shape[3]) % output_shape[2]
        feature_range = tf.range(output_shape[3], dtype='int32')
        f = one_like_mask * feature_range

        # One (b, y, x, f) row per update value, then scatter into zeros.
        updates_size = tf.size(updates)
        indices = tf.transpose(tf.reshape(
            tf.stack([b, y, x, f]),
            [4, updates_size]))
        values = tf.reshape(updates, [updates_size])
        ret = tf.scatter_nd(indices, values, output_shape)
        return ret

    def compute_output_shape(self, input_shape):
        """Static output shape derived from the mask shape.

        Unknown (``None``) spatial dimensions stay ``None`` instead of raising
        ``TypeError`` on ``None * factor``.
        """
        mask_shape = input_shape[1]

        def scaled(dim, factor):
            return None if dim is None else dim * factor

        return (
                mask_shape[0],
                scaled(mask_shape[1], self.size[0]),
                scaled(mask_shape[2], self.size[1]),
                mask_shape[3]
                )
class segnet(tf.keras.Sequential):
    """Shared base for the SegNet-style encoder and decoder.

    Its only contribution is ``conv_layer``, the factory for the
    Conv-BN-ReLU unit used throughout both halves of the network.
    """

    def __init__(self):
        super(segnet, self).__init__()

    def conv_layer(self, channel):
        """Build a fresh 3x3 'same' Conv2D -> BatchNorm -> ReLU block.

        Args:
            channel: number of convolution filters.

        Returns:
            A ``tf.keras.Sequential`` holding the three layers.
        """
        stages = [
            Conv2D(filters=channel, kernel_size=3, padding="same",
                   kernel_initializer='glorot_normal'),
            BatchNormalization(axis=-1),
            ReLU(),
        ]
        return tf.keras.Sequential(stages)

class encoder(segnet):
    """SegNet encoder: five Conv-BN-ReLU stages, each followed by argmax pooling.

    After ``call`` the bottleneck tensor is stored in ``self.encout`` and the
    per-stage pooling indices (shallowest first) in ``self.indices``; both are
    also returned so the standard ``encoder(x)`` call is usable.

    Args:
        channels: unused; kept for backward compatibility with existing callers.
    """

    # (filters, number of Conv-BN-ReLU units) for each of the five stages,
    # in the same order the original hand-written loop created them.
    _STAGES = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

    def __init__(self, channels=3):
        super(encoder, self).__init__()
        self.conv_block_enc = [
            Sequential([self.conv_layer(filters) for _ in range(depth)])
            for filters, depth in self._STAGES
        ]
        self.down_sampling = MaxPool2D(ksize=(2, 2), padding='same')

    def call(self, x):
        """Encode ``x``.

        Args:
            x: NHWC input batch.

        Returns:
            ``(encout, indices)``: the bottleneck tensor and the list of
            pooling indices, shallowest stage first. Both are also stored on
            ``self`` for callers that read the attributes.
        """
        features = x
        indices = []
        for block in self.conv_block_enc:
            features = block(features)
            features, index = self.down_sampling.call(features)
            indices.append(index)
        self.encout = features
        self.indices = indices
        return features, indices

class decoder(segnet):
    """SegNet decoder mirroring ``encoder``: five unpool + Conv-BN-ReLU stages.

    The final stage narrows to a single channel through Conv-BN-ReLU
    (author's note: best results were observed with sigmoid, batch_norm, relu).

    Args:
        channels: unused; kept for backward compatibility with existing callers.
    """

    def __init__(self, channels=3):
        super(decoder, self).__init__()
        # Filter count of every Conv-BN-ReLU unit, stage by stage (deepest
        # first); the last unit of each stage sets the width fed onward.
        stage_widths = (
            (512, 512, 512),
            (512, 512, 256),
            (256, 256, 128),
            (128, 64),
            (64, 1),
        )
        self.conv_block_dec = [
            Sequential([self.conv_layer(width) for width in widths])
            for widths in stage_widths
        ]
        self.up_sampling = MaxUnpool2D(ksize=(2, 2))

    def forward(self, X, indices):
        """Decode ``X`` with the encoder's pooling ``indices``.

        ``indices`` is expected shallowest-first (the order ``encoder.call``
        stores them); they are consumed deepest-first, one per stage, and each
        stage unpools before applying its convolutions.
        """
        for step, stage in enumerate(self.conv_block_dec):
            X = self.up_sampling.call([X, indices[-(step + 1)]])
            X = stage(X)
        return X
# class unet(tf.keras.Sequential):
#     def __init__(self):
#         super(unet,self).__init__()
    
#     def contracting_block(self,in_channels,out_channels,kernel_size=(3,3)):
#         layers = [Conv2D(filters=out_channels,kernel_size=kernel_size,padding="same",activation="relu"),
#                     BatchNormalization(axis=-1),
#                     Conv2D(filters=out_channels,kernel_size=kernel_size,padding="same",activation="relu"),
#                     BatchNormalization(axis=-1)]
#         return tf.keras.Sequential(layers)
    
#     def expansive_block(self,in_channels,mid_channels,out_channels,kernel_size=(3,3)):
#         layers = [Conv2D(filters=mid_channels,kernel_size=kernel_size,padding="same",activation="relu"),
#                     BatchNormalization(axis=-1),
#                     Conv2D(filters=mid_channels,kernel_size=kernel_size,padding="same",activation="relu"),
#                     BatchNormalization(axis=-1),
#                     Conv2DTranspose(filters=out_channels,kernel_size=kernel_size,strides=2,output_padding=1,padding="same")
#                  ]
#         return tf.keras.Sequential(layers)
    
#     def final_block(self,in_channels,mid_channels,out_channels,kernel_size=3):
#         layers = [Conv2D(filters=mid_channels,kernel_size=kernel_size,padding="same",activation="relu"),
#                     BatchNormalization(axis=-1),
#                     Conv2D(filters=mid_channels,kernel_size=kernel_size,padding="same",activation="relu"),
#                     BatchNormalization(axis=-1),
#                     Conv2D(filters=out_channels,kernel_size=kernel_size,padding="same",activation="relu"),
#                     #BatchNormalization(axis=-1)
#                  ]
#         return tf.keras.Sequential(layers)

    
# class encoder(unet):
#     def __init__(self,imsize):
#         super(encoder,self).__init__()
#         self.conv_encode1 = self.contracting_block(in_channels=1, out_channels=64)
#         self.conv_maxpool1 = MaxPool2D(pool_size=2)
#         self.conv_encode2 = self.contracting_block(64, 128)
#         self.conv_maxpool2 = MaxPool2D(pool_size=2)
#         self.conv_encode3 = self.contracting_block(128, 256)
#         self.conv_maxpool3 = MaxPool2D(pool_size=3)
#         # Bottleneck
#         self.bottleneck1 = tf.keras.Sequential(
#                             [
#                             Conv2D(kernel_size=(3,3), filters=512,activation="relu",padding="same"),
#                             BatchNormalization(axis=-1),
#                             Conv2D(kernel_size=(3,3), filters=512,activation="relu",padding="same"),
#                             BatchNormalization(axis=-1)])
#         self.bottleneck2 = tf.keras.Sequential([Conv2DTranspose(filters=256,kernel_size=(2,2),strides=3,output_padding=1,padding="valid",activation="relu")])
        
# #     @tf.function
#     def call(self,x):
#         self.encode_block1 = self.conv_encode1(x)
#         encode_pool1 = self.conv_maxpool1(self.encode_block1)
#         self.encode_block2 = self.conv_encode2(encode_pool1)
#         encode_pool2 = self.conv_maxpool2(self.encode_block2)
#         self.encode_block3 = self.conv_encode3(encode_pool2)
#         encode_pool3 = self.conv_maxpool3(self.encode_block3)
#         self.discin = self.bottleneck1(encode_pool3)
#         self.encout = self.bottleneck2(self.discin)
        
# class decoder(unet):
#     def crop_and_concat(self, upsampled, bypass, crop=False):
#         if crop:
#             c = (tf.shape(bypass)[2] - upsampled.shape[2]) // 2
#             bypass = tf.keras.layers.Cropping2D(cropping=(c,c))(bypass)
#         return tf.concat((upsampled, bypass),axis=-1)
    
#     def __init__(self,imsize,out_channel=1):
#         super(decoder,self).__init__()
# #         self.con
# #         self.convtr = Conv2DTranspose(filters=512,kernel_size=(3,3),strides=2,output_padding=1,padding="same")
#         self.conv_decode3 = self.expansive_block(512, 256, 128)
#         self.conv_decode2 = self.expansive_block(256, 128, 64)
#         self.final_layer = self.final_block(128, 64, out_channel)
    
# #     @tf.function
#     def forward(self,bottleneck1,encode_block1,encode_block2,encode_block3):
#         decode_block3 = bottleneck1
#         cat_layer2 = self.conv_decode3(decode_block3)
#         decode_block2 = cat_layer2
#         cat_layer1 = self.conv_decode2(decode_block2)
#         decode_block1 = cat_layer1
#         final_layer = self.final_layer(decode_block1)
#         return  final_layer


In [None]:
class Model(tf.keras.Sequential):
    """Shared-encoder / two-decoder translator with a latent-space critic.

    ``decoders[0]`` renders the X domain and ``decoders[1]`` the Y domain.
    The discriminator scores the encoder bottleneck (``encoder.encout``) so the
    encoder is trained adversarially against it.

    Args:
        encoders: the shared ``encoder`` instance.
        decoders: ``[x_decoder, y_decoder]``.
    """

    def __init__(self, encoders, decoders):
        super(Model, self).__init__()
        self.encoder = encoders
        self.decoders = decoders
        self.discriminator = tf.keras.Sequential([
            Conv2D(kernel_size=(8, 8), strides=1, filters=1024, padding='valid'),
            Flatten(),
            Dense(1024, activation="relu"),
            Dense(1024, activation="relu"),
            Dense(1, activation='linear'),
        ])
        self.TrainableVarsSet = False

    def setTrainableVars(self):
        """Cache the generator (encoder + both decoders) and critic variable lists."""
        self.genTrainableVariables = (
            self.encoder.trainable_variables
            + self.decoders[0].trainable_variables
            + self.decoders[1].trainable_variables)
        self.discTrainableVariables = self.discriminator.trainable_variables
        self.TrainableVarsSet = True

    def _translate(self, X, decoder_index, training):
        """Encode ``X`` and decode it with ``decoders[decoder_index]``.

        Returns the decoded tensor, plus the critic score of the bottleneck
        when ``training`` is truthy.
        """
        # NOTE(review): encoder/decoder layers are called without a
        # ``training`` flag, so BatchNormalization presumably runs in
        # inference mode even during training — confirm this is intended.
        self.encoder.call(X)
        decoded_out = self.decoders[decoder_index].forward(
            self.encoder.encout, self.encoder.indices)
        if training:
            return decoded_out, self.discriminator(self.encoder.encout)
        return decoded_out

    def forwardX2Y(self, X, training=False):
        """Translate X-domain input into the Y domain (uses ``decoders[1]``)."""
        return self._translate(X, 1, training)

    def forwardY2X(self, X, training=False):
        """Translate Y-domain input into the X domain (uses ``decoders[0]``)."""
        return self._translate(X, 0, training)

    def forwardX2X(self, X, training=False):
        """Reconstruct X-domain input (uses ``decoders[0]``)."""
        return self._translate(X, 0, training)

    def forwardY2Y(self, X, training=False):
        """Reconstruct Y-domain input (uses ``decoders[1]``)."""
        return self._translate(X, 1, training)

    def build_disc(self, X):
        """Create the discriminator's variables by running it once on ``encoder(X)``.

        Intended to be called once, before training.
        """
        self.encoder.call(X)
        self.discriminator(self.encoder.encout)

    def reconstruction_loss(self, X, Y):
        """Mean absolute (L1) pixel-wise difference between ``X`` and ``Y``."""
        return tf.math.reduce_mean(tf.math.abs(X - Y))

    def loss_classification(self, X, labels):
        """Binary cross-entropy of probabilities ``X`` against 0/1 ``labels``.

        The 1e-5 offsets guard against ``log(0)``. Not used by ``train_on_batch``.
        """
        return (-1 * tf.reduce_mean(labels * (tf.math.log(X + 1e-5))
                                    + (1 - labels) * (tf.math.log(1 - X + 1e-5))))

    def loss_disc(self, X, labels):
        """Wasserstein-style critic loss, scaled by the module-level ``disclosswt``.

        Label 1 maps to +1 and label 0 to -1; the critic minimises
        ``-mean(score * sign)``.
        """
        return -1 * tf.reduce_mean(X * (2 * labels - 1)) * disclosswt

    @tf.function
    def train_on_batch(self, X, labels, optimizerGen, optimizerDisc):
        """Run one adversarial training step on a mixed batch.

        The first ``n // 2`` rows of ``X`` are treated as X-domain samples and
        the remaining rows as Y-domain samples; ``labels`` holds one critic
        target per row in the same order.

        Returns:
            ``(loss_reconstruction, loss_cycle, loss_disc, discfeats)``.
            ``loss_cycle`` is reported for monitoring only (its weight in the
            generator loss is zero); ``discfeats`` are the raw critic scores.
        """
        if not self.TrainableVarsSet:
            self.setTrainableVars()
        half = X.shape[0] // 2
        # Plain slicing keeps every row on odd batch sizes, so the critic
        # scores stay aligned with ``labels`` (fixed-size tf.slice dropped one).
        X1 = X[:half]
        X2 = X[half:]
        with tf.GradientTape(persistent=True) as tape:
            transformedX2Y, disc1 = self.forwardX2Y(X1, training=True)
            transformedY2X, disc2 = self.forwardY2X(X2, training=True)

            discfeats = tf.concat([disc1, disc2], axis=0)
            loss_disc = self.loss_disc(discfeats, labels)

            transformedX2X = self.forwardX2X(X1)
            transformedY2Y = self.forwardY2Y(X2)
            loss_reconstruction = (self.reconstruction_loss(X1, transformedX2X)
                                   + self.reconstruction_loss(X2, transformedY2Y))
            # Distance between each input and its translation; tracked only.
            loss_cycle = (self.reconstruction_loss(X1, transformedX2Y)
                          + self.reconstruction_loss(X2, transformedY2X))

            # The generator tries to fool the critic, hence the minus sign.
            loss_gen = loss_reconstruction - loss_disc

        grads_gen = tape.gradient(loss_gen, self.genTrainableVariables)
        grads_disc = tape.gradient(loss_disc, self.discTrainableVariables)
        del tape
        optimizerDisc.apply_gradients(zip(grads_disc, self.discTrainableVariables))
        optimizerGen.apply_gradients(zip(grads_gen, self.genTrainableVariables))
        return loss_reconstruction, loss_cycle, loss_disc, discfeats




In [None]:
def read_and_decode(filename, epochs, size, *, num_shards=4, shard_index=0,
                    batch_size=5, shuffle_buffer=16):
    """Return an iterator of log-scaled image batches read from a TFRecord file.

    Each record must contain an ``img_raw`` bytes feature holding
    ``size * size`` raw float32 values. These are reshaped to
    ``(size, size, 1)`` and transformed with ``log(1 + x)``.

    Args:
        filename: path to the ``.tfrecords`` file.
        epochs: number of passes over the (sharded) records.
        size: side length of each square image.
        num_shards: keep only every ``num_shards``-th record (default 4 keeps
            roughly a quarter of the file).
        shard_index: which of the ``num_shards`` shards to keep.
        batch_size: images per yielded batch.
        shuffle_buffer: size of the record shuffle buffer.

    Returns:
        A Python iterator over ``(batch, size, size, 1)`` float32 tensors.
    """
    features = {'img_raw': tf.io.FixedLenFeature([], tf.string)}

    def _parse_image_function(example_proto):
        # Parse one tf.Example and decode its raw float32 payload.
        feat = tf.io.parse_single_example(example_proto, features)
        img = tf.io.decode_raw(feat['img_raw'], tf.float32)
        img = tf.reshape(img, [size, size, 1])
        # log1p is the numerically accurate form of log(1 + img).
        return tf.math.log1p(img)

    dataset = tf.data.TFRecordDataset([filename])
    dataset = dataset.shard(num_shards, shard_index)
    dataset = dataset.repeat(count=epochs)
    dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.map(_parse_image_function)
    dataset = dataset.batch(batch_size=batch_size)
    # Prefetch last so whole batches are prepared while the model trains;
    # two batches of the default size buffer the same 10 images as before.
    dataset = dataset.prefetch(buffer_size=2)
    return iter(dataset)
    
    
    



# print(next(imgclassical))




In [None]:
import pickle
size=256  # side length of the square images; read_and_decode reshapes each record to [size, size, 1]
def train(model, epochs, optimizerGen, optimizerDisc, imgclassical, imgjazz):
    """Train ``model`` on paired jazz/classical batches until either stream ends.

    Each iteration does the following:
    - runs ``model.train_on_batch``;
    - clips the critic weights to [-0.1, 0.1];
    - logs the losses to ``log.txt``;
    - pickles the current generator and critic variables.

    Every 20 iterations it also pickles the inputs and their translations. It
    then prints the mean difference between the two decoders' variables.

    Args:
        model: a ``Model`` instance.
        epochs: unused; the number of steps is set by the input iterators.
        optimizerGen: optimizer for the encoder and decoders.
        optimizerDisc: optimizer for the critic.
        imgclassical: iterator of classical image batches.
        imgjazz: iterator of jazz image batches.
    """
    def _dump(obj, path):
        # Close every pickle file promptly instead of leaking the handle.
        with open(path, 'wb') as fh:
            pickle.dump(obj, fh)

    with open('log.txt', 'w') as log:
        for i, (batch_jazz, batch_class) in enumerate(zip(imgjazz, imgclassical)):
            n_jazz = batch_jazz.shape[0]
            X = tf.concat((batch_jazz, batch_class), axis=0)
            # Critic targets: 0 for every jazz row, 1 for every classical row.
            Y = tf.expand_dims(
                tf.concat((tf.zeros(n_jazz), tf.ones(batch_class.shape[0])), axis=0), -1)
            l1, l2, l3, l4 = model.train_on_batch(X, Y, optimizerGen, optimizerDisc)
            scores = l4.numpy()
            # Critic score of the first jazz and the first classical sample
            # (the old hard-coded index 5 broke on a short final batch).
            print(scores[0], scores[n_jazz])

            # Make the critic (approximately) Lipschitz by weight clipping.
            for weight in model.discriminator.weights:
                weight.assign(tf.clip_by_value(weight, -0.1, 0.1))

            print(i,"reconstruction-{},cycle-{},disc {} ".format(l1,l2,l3))
            log.write("iteration-{},reconstruction-{},cycle-{},disc {} \n".format(i,l1,l2,l3))
            _dump(model.genTrainableVariables, "GenVars.pkl")
            _dump(model.discTrainableVariables, "DiscVars.pkl")
            if i % 20 == 0:
                X1 = batch_jazz
                Y1 = batch_class
                Y2 = model.forwardX2Y(X1)
                X2 = model.forwardY2X(Y1)
                _dump(Y2, "ClassicalTrans-iter-{}".format(i))
                _dump(X2, "JazzTrans-iter-{}".format(i))
                _dump(Y1, "Classical-iter-{}".format(i))
                _dump(X1, "Jazz-iter-{}".format(i))
                for var_x, var_y in zip(model.decoders[0].trainable_variables,
                                        model.decoders[1].trainable_variables):
                    print(np.mean(var_x - var_y))

# One shared encoder plus one decoder per genre; the list argument is accepted
# but not used by either constructor.
encs = encoder([size, size, 1])
decs = [decoder([size, size, 1]) for _ in range(2)]
model = Model(encs, decs)

optimizerGen = tf.keras.optimizers.Adam(learning_rate=1e-5)
optimizerDisc = tf.keras.optimizers.Adam(learning_rate=1e-5)

imgclassical = read_and_decode(
    "../input/styletransfer-revised-256/classical.tfrecords", epochs=20, size=size)
imgjazz = read_and_decode(
    "../input/styletransfer-revised-256/jazz.tfrecords", epochs=20, size=size)

# Run every sub-network once so its variables exist before train_on_batch
# collects them. NOTE: this consumes one jazz batch and two classical batches.
model.forwardX2Y(next(imgjazz))
model.forwardY2X(next(imgclassical))
model.build_disc(next(imgclassical))
print("Model Built")

train(model, 10,
      imgclassical=imgclassical, imgjazz=imgjazz,
      optimizerGen=optimizerGen, optimizerDisc=optimizerDisc)
