In [1]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import os

In [2]:
# Nurodom nuotrauku ismatavimus
img_shape = (64, 64, 3)
# Nurodo kelia iki nuotrauku
image_dir = '../51_paskaita/autoplius/automobiliai'


def build_generator(img_shape):
    model = tf.keras.Sequential()
    model.add(layers.Dense(256, input_dim=100, activation='relu'))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.Dense(512, activation='relu'))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.Dense(1024, activation='relu'))
    model.add(layers.BatchNormalization(momentum=0.8))
    model.add(layers.Dense(np.prod(img_shape), activation='tanh'))
    model.add(layers.Reshape((img_shape)))
    return model

def build_discriminator(img_shape):
    model = tf.keras.Sequential()
    model.add(layers.Flatten(input_shape=(img_shape)))
    model.add(layers.Dense(512, activation='relu'))
    model.add(layers.Dense(256, activation='relu'))
    model.add(layers.Dense(1, activation='sigmoid'))
    return model

def build_gan(generator, discriminator):
    discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    # Nurodo, kad discriminator modelis nesimokins. tai padeda nuo overfitinimo.
    discriminator.trainable = False
    gan_input = tf.keras.Input(shape=(100,))
    img = generator(gan_input)
    gan_output = discriminator(img)
    gan = tf.keras.Model(gan_input, gan_output)
    gan.compile(loss='binary_crossentropy', optimizer='adam')
    return gan


def load_images(image_dir, img_shape):
    images_paths = [os.path.join(image_dir, image) for image in os.listdir(image_dir) if image.lower().endswith('jpg')]
    images = []
    for image_path in images_paths:
        try: 
            img = tf.keras.preprocessing.image.load_img(image_path, target_size=img_shape[:2])
            img = tf.keras.preprocessing.image.img_to_array(img)
            images.append(img)
        except Exception as e:
            print(f'Klaida: {e}')
    images = np.array(images)
    images = (images.astype('float32') -127) / 127.5
    return images

x_train = load_images(image_dir, img_shape)

generator = build_generator(img_shape)
discriminator = build_discriminator(img_shape)
gan = build_gan(generator, discriminator)

def train(epochs, batch_size=128, save_interval=50):
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # atsitiktinai pasirenkam indeksus is mokymu duomenu rinkinio
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        imgs = x_train[idx]

        noise = np.random.normal(0,1,(batch_size,100))
        gen_imgs = generator.predict(noise)

        d_loss_real = discriminator.train_on_batch(imgs, valid)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        
        g_loss = gan.train_on_batch(noise, valid)
        if epoch % save_interval == 0:
            print(f'{epoch} D loss: {d_loss[0]} | G loss {g_loss} | accuracy ? ')
            save_img(epoch)

def save_img(epoch):
    # rows and columns
    r,c = 5,5
    noise = np.random.normal(0,1, (r * c, 100))
    gen_imgs = generator.predict(noise)
    # atstatome intensyvumo intervala i [0,1]
    gen_imgs = 0.5 * gen_imgs + 0.5
    fig, axs = plt.subplots(r,c)
    counter = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[counter, :, :, 0], cmap="gray")
            axs[i,j].axis("off")
            counter +=1
    fig.savefig(f"gan_images/epocha_{epoch}.png")
    plt.close()


train(epochs=30000, batch_size=64, save_interval=1000)


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(**kwargs)


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step  




0 D loss: 0.6540790796279907 | G loss [array(0.6741333, dtype=float32), array(0.6741333, dtype=float32), array(0.4140625, dtype=float32)] | accuracy ? 
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 70ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0

KeyboardInterrupt: 