In [1]:
from tensorflow.keras import Sequential, Model, backend, losses, optimizers
from tensorflow import ones_like, zeros_like, random, GradientTape
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense,BatchNormalization,LeakyReLU,Conv2DTranspose,Conv2D,Dropout,Flatten, Reshape

In [2]:
print(tf.config.list_physical_devices('GPU'))

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [3]:
def generator_model(noise_dim):
    model = Sequential()
    model.add(Dense(8*8*512, use_bias=False, input_shape=(noise_dim,)))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Reshape((8, 8, 512)))

    model.add(Conv2DTranspose(256, (4, 4), strides=(2, 2), padding='same', use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same', use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same', use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Conv2DTranspose(1, (4, 4), strides=(2, 2), padding='same', use_bias=False))

    return model

In [4]:
generator = generator_model(100)
generator.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 32768)             3276800   
                                                                 
 batch_normalization (BatchN  (None, 32768)            131072    
 ormalization)                                                   
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 32768)             0         
                                                                 
 reshape (Reshape)           (None, 8, 8, 512)         0         
                                                                 
 conv2d_transpose (Conv2DTra  (None, 16, 16, 256)      2097152   
 nspose)                                                         
                                                                 
 batch_normalization_1 (Batc  (None, 16, 16, 256)      1

In [5]:
def discriminator_model():

    input_layer = Input(shape=[128, 128, 3])

    conv1 = Conv2D(16, (3, 3), strides=(3,3), padding='same')(input_layer)
    act1 = LeakyReLU()(conv1)
    drop1 = Dropout(0.3)(act1)

    conv2 = Conv2D(32, (3, 3), strides=(3,3), padding='same')(drop1)
    act2 = LeakyReLU()(conv2)
    drop2 = Dropout(0.3)(act2)


    conv3 = Conv2D(64, (3, 3), strides=(3,3), padding='same')(drop2)
    act3 = LeakyReLU()(conv3)
    drop3 = Dropout(0.3)(act3)


    conv4 = Conv2D(128, (3, 3), strides=(3,3), padding='same')(drop3)
    act4 = LeakyReLU()(conv4)
    drop4 = Dropout(0.3)(act4)


    conv5 = Conv2D(128, (3, 3), strides=(3,3), padding='same')(drop4)
    act5 = LeakyReLU()(conv5)
    drop5 = Dropout(0.3)(act5)


    flat = Flatten()(drop5)
    output = Dense(1)(flat)

    model = Model(inputs=input_layer, outputs=output)

    return model

In [6]:
discriminator = discriminator_model()
discriminator.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 128, 128, 3)]     0         
                                                                 
 conv2d (Conv2D)             (None, 43, 43, 16)        448       
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 43, 43, 16)        0         
                                                                 
 dropout (Dropout)           (None, 43, 43, 16)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 15, 15, 32)        4640      
                                                                 
 leaky_re_lu_5 (LeakyReLU)   (None, 15, 15, 32)        0         
                                                                 
 dropout_1 (Dropout)         (None, 15, 15, 32)        0     

In [7]:
loss_func = losses.BinaryCrossentropy(from_logits=True)
def gen_loss_calc(fake):
    return loss_func(ones_like(fake), fake)

def dis_loss_calc(real, fake):
    real_loss = loss_func(ones_like(real), real)
    fake_loss = loss_func(zeros_like(fake), fake)
    total_loss = real_loss + fake_loss
    return total_loss

In [8]:
gen_opt = optimizers.Adam(1e-4)
dis_opt = optimizers.Adam(1e-4)

In [9]:
import numpy as np
import pandas as pd
import time
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

In [10]:
%matplotlib qt
plt.ion()
def plot_image(image):
    # Assuming the image is in the range [0,1] or [-1,1]
    image = (image + 1) / 2  # Rescale image to [0,1] if needed
    plt.imshow(image)
    plt.axis('off')  # Turn off axis for a cleaner display


In [25]:
def train_epoch(images, batch_size, noise_dim,epoch,batch_number):
    if len(images) < batch_size:
        print(f"Warning: Not enough images for a full batch. Using {len(images)} images.")
        batch_size = len(images)
    start = time.time()
    noise = random.normal([batch_size, noise_dim])

    global fig, ax

    with GradientTape() as gen_tape, GradientTape() as disc_tape:
        generated = generator(noise, training=True)
        plot_image(images[0])
        #print(generated.shape,images.shape)
        real = discriminator(images, training=True)
        fake = discriminator(generated, training=True)

        gen_loss = gen_loss_calc(fake)
        dis_loss = dis_loss_calc(real, fake)

        real_acc = np.mean(real > 0.5)
        fake_acc = np.mean(fake < 0.5)
        disc_acc = 0.5 * (real_acc + fake_acc)
        gen_acc = np.mean(fake > 0.5)

        gen_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
        dis_gradients = disc_tape.gradient(dis_loss, discriminator.trainable_variables)

        gen_opt.apply_gradients(zip(gen_gradients, generator.trainable_variables))
        dis_opt.apply_gradients(zip(dis_gradients, discriminator.trainable_variables))
        clear_output(wait=True)
        display(f"Epoch {epoch + 1}/50 batch {batch_number}")
        display(f"gen_loss: {gen_loss:.4f}, gen_acc: {gen_acc:.4f}")
        display(f"disc_loss: {dis_loss:.4f}, disc_acc: {disc_acc:.4f}")
        display(f"time: {time.time() - start:.2f}s")    
        if batch_number % 20 == 0:
            plt.figure(figsize=(5, 5))
            plot_image(generated[0])
            plt.savefig(f'{epoch}_{batch_number}.png')
            plt.show()
            time.sleep(4)
        if batch_number == 141:
            generator.save("generator.keras")
            discriminator.save("discriminator.keras")


def train(dataset, epochs):
    for epoch in range(epochs):
        k = 1
        for batch in dataset:
            print()
            k+=1
            train_epoch(batch,32,100,epoch,k)


In [17]:
from PIL import Image
import io
i = 0
def preprocess_image(img):
    global i
    img = Image.open(io.BytesIO(img["bytes"]))
    img = img.resize((128, 128))
    img = img.convert("L")
    img = np.asarray(img, dtype=np.float16)
    img_array = img/255
    i+=1
    print(f"\r{i}/{len(df)}                      ",end = '')
    return img_array

In [13]:
import pickle

In [22]:
#df = pd.read_parquet("hf://datasets/Wejh/celeb-a-hq___0-to-4999___FLUX.1-dev_training_faces/base_transforms/train-00000-of-00001.parquet")[["image"]]

# with open("data.pkl", "wb") as file:
#     # Use pickle.dump() to serialize the object
#     pickle.dump(df, file)

with open("data.pkl", "rb") as file:
    # Use pickle.load() to deserialize the object
    df = pickle.load(file)

In [24]:
dataset = tf.data.Dataset.from_tensor_slices(np.array(df['image'].apply(preprocess_image).tolist(), dtype=np.float16).reshape(-1, 128, 128,1))
del df
dataset = dataset.shuffle(buffer_size=1024).batch(32)

8924/4462                      

In [28]:
train(dataset, 150)




ValueError: Exception encountered when calling layer "model" "                 f"(type Functional).

Input 0 of layer "conv2d" is incompatible with the layer: expected axis -1 of input shape to have value 3, but received input with shape (32, 128, 128, 1)

Call arguments received by layer "model" "                 f"(type Functional):
  • inputs=tf.Tensor(shape=(32, 128, 128, 1), dtype=float16)
  • training=True
  • mask=None