In [2]:
import numpy as np
import pandas as pd
from keras.utils import image_dataset_from_directory

In [3]:
df = "/kaggle/input/persian-handwritten-digits/Train"  # root folder scanned by image_dataset_from_directory below

In [4]:
# Echo the dataset root path so the cell output records which folder is loaded.
df

'/kaggle/input/persian-handwritten-digits/Train'

In [5]:
# Unlabeled image batches (label_mode=None yields images only), resized to
# 256x256 and grouped 32 per batch; the GAN below needs no class labels.
train_data = image_dataset_from_directory(
    directory=df,
    label_mode=None,
    batch_size=32,
    image_size=(256, 256),
)

Found 100000 files.


# # GAN

In [11]:
from keras.layers import Dense,Flatten,BatchNormalization,Dropout,ReLU,Reshape,Conv2DTranspose,Conv2D,MaxPooling2D,Input,Rescaling
from keras.models import Model
from keras.initializers import HeNormal
from keras.regularizers import l2
from keras.callbacks import EarlyStopping

In [46]:
def build_gen(latent_dim=100):
    """Build the generator network: noise vector -> 256x256x3 image.

    A Dense projection is reshaped to a 16x16x256 feature map and upsampled by
    four stride-2 transposed convolutions (16 -> 32 -> 64 -> 128 -> 256).

    Args:
        latent_dim: length of the input noise vector (default 100).

    Returns:
        An uncompiled keras ``Model`` mapping ``(batch, latent_dim)`` noise to
        ``(batch, 256, 256, 3)`` images with values in [0, 255].
    """
    inputs = Input(shape=(latent_dim,))
    C = Dense(16 * 16 * 256)(inputs)
    C = Reshape((16, 16, 256))(C)
    C = Conv2DTranspose(256, (3, 3), strides=(2, 2), padding="same")(C)
    C = ReLU()(C)
    C = Dropout(0.5)(C)
    C = BatchNormalization()(C)

    C = Conv2DTranspose(128, (3, 3), strides=(2, 2), padding="same",
                        kernel_initializer="HeNormal", kernel_regularizer=l2(0.01))(C)
    C = ReLU()(C)
    C = Dropout(0.5)(C)
    C = BatchNormalization()(C)

    C = Conv2DTranspose(64, (3, 3), strides=(2, 2), activation="relu", padding="same")(C)
    C = Dropout(0.5)(C)
    C = BatchNormalization()(C)

    C = Conv2DTranspose(32, (3, 3), strides=(2, 2), padding="same",
                        kernel_initializer="HeNormal", kernel_regularizer=l2(0.01))(C)
    C = ReLU()(C)
    C = Dropout(0.5)(C)
    C = BatchNormalization()(C)

    # Bounded output: tanh gives [-1, 1], then rescale to [0, 255] to match the
    # raw (unscaled) pixel range that image_dataset_from_directory yields.
    # A ReLU here is unbounded above and has zero gradient for every negative
    # pre-activation, so it cannot reliably learn the real pixel distribution.
    C = Conv2D(3, (3, 3), activation="tanh", padding="same")(C)
    C = Rescaling(127.5, offset=127.5)(C)
    model = Model(inputs, C)
    return model
gen=build_gen()
gen.summary()

In [47]:
def build_dis():
    """Build the discriminator network: 256x256x3 image -> score in (0, 1).

    Four down-sampling stages (conv -> dropout -> batch-norm -> 2x2 max-pool)
    shrink the image 256 -> 16. Stages alternate between a plain ReLU-activated
    convolution and a HeNormal/L2-regularized convolution followed by a ReLU
    layer. A small dense head ends in a single sigmoid unit.

    Returns:
        An uncompiled keras ``Model``.
    """
    def _stage(tensor, filters, regularized):
        # One down-sampling stage; `regularized` selects the HeNormal + l2 conv
        # with a separate ReLU layer instead of a conv with built-in relu.
        if regularized:
            tensor = Conv2D(filters, (3, 3), padding="same",
                            kernel_initializer="HeNormal",
                            kernel_regularizer=l2(0.01))(tensor)
            tensor = ReLU()(tensor)
        else:
            tensor = Conv2D(filters, (3, 3), activation="relu", padding="same")(tensor)
        tensor = Dropout(0.5)(tensor)
        tensor = BatchNormalization()(tensor)
        return MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(tensor)

    image = Input(shape=(256, 256, 3))
    features = image
    for filters, regularized in ((32, False), (64, True), (128, False), (256, True)):
        features = _stage(features, filters, regularized)

    flat = Flatten()(features)
    hidden = Dense(32, activation="relu")(flat)
    score = Dense(1, activation="sigmoid")(hidden)
    return Model(image, score)
dis = build_dis()
dis.summary()

In [48]:
def build_gan(generator,discriminator,input_shape):
    """Chain generator -> discriminator into the combined adversarial model.

    Args:
        generator: model mapping latent vectors to images.
        discriminator: model mapping images to a single sigmoid score.
        input_shape: length of the latent noise vector (an int, despite the name).

    Returns:
        An uncompiled ``Model`` from noise to the discriminator's score of the
        generated image.
    """
    # Freeze the discriminator that was passed in (not a global), so the
    # combined model only exposes the generator's weights as trainable.
    discriminator.trainable = False
    inputs = Input(shape=(input_shape,))
    fake_img = generator(inputs)
    outputs = discriminator(fake_img)
    model = Model(inputs, outputs)
    return model
gan=build_gan(generator=gen,discriminator=dis,input_shape=100)
gan.summary()

In [53]:
def train_gan(gan,generator,discriminator,epochs,batch_size,latent_dim,dataset):
    """Train a GAN with alternating discriminator / generator updates.

    For each batch of real images:
      1. train the discriminator on the real batch with target 1,
      2. train it on a generated batch with target 0,
      3. train the generator through the combined ``gan`` model with target 1,
         while the discriminator is frozen.

    Args:
        gan: combined model ``discriminator(generator(noise))``.
        generator: model mapping noise of shape ``(n, latent_dim)`` to images.
        discriminator: model mapping images to a single sigmoid score.
        epochs: number of passes over ``dataset``.
        batch_size: fallback batch size, used only when a batch's leading
            dimension is unknown (``None``).
        latent_dim: length of each noise vector.
        dataset: iterable of real-image batches.

    Returns:
        The value returned by the last ``gan.train_on_batch`` call (loss plus
        the compiled accuracy metric), or ``None`` if no batch was processed.
    """
    # Compile the discriminator while it is trainable, so its own training
    # step updates its weights.
    discriminator.trainable = True
    discriminator.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    # Freeze it before compiling the combined model. Otherwise generator steps
    # would also update the discriminator toward calling fakes "real".
    discriminator.trainable = False
    gan.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

    g_g = None  # stays None when epochs == 0 or dataset yields no batches
    for _ in range(epochs):
        for real_img in dataset:
            # Size targets and noise from the actual batch: a dataset's final
            # batch can be smaller than `batch_size`.
            n = real_img.shape[0]
            if n is None:
                n = batch_size
            real_label = np.ones((n, 1))
            fake_label = np.zeros((n, 1))  # generated images are the negative class
            noise = np.random.normal(0, 1, (n, latent_dim))
            fake_img = generator(noise, training=False)

            # Toggle the flag around each step, so that whether each model's
            # first train call is traced with or without the discriminator's
            # weights matches the intended setup.
            discriminator.trainable = True
            discriminator.train_on_batch(real_img, real_label)
            discriminator.train_on_batch(fake_img, fake_label)
            discriminator.trainable = False
            # Generator step: ask the frozen discriminator to call fakes real.
            g_g = gan.train_on_batch(noise, real_label)
    return g_g


In [None]:
# Keep the result under its own name: rebinding `train_gan` to the returned
# metrics would make this cell fail with TypeError when re-run.
gan_metrics = train_gan(gan, gen, dis, epochs=3, batch_size=32, latent_dim=100, dataset=train_data)