# LSGAN

Least Squares Generative Adversial Network pro generování obrázků disků.

<img src="lsgan_paper.jpg" alt="lsgan example" style="width: 700px;"/>

References:
- [Least Squares Generative Adversarial Networks](https://arxiv.org/pdf/1611.04076.pdf). X. Mao et al., 20156.

## Reálné obrázky

Reálné obrázky disků jsou uloženy v .tfrecords databázi vytvořené pomocí tfrec2gan skriptu.
Pro názornost uvadíme příklad uložených fotografií:
<table><tr>
    <td><img src="ExtractedImages\img_3.jpg" alt="img_3" style="width: 64px;"/></td>
    <td><img src="ExtractedImages\img_4.jpg" alt="img_4" style="width: 64px;"/></td>
    <td><img src="ExtractedImages\img_5.jpg" alt="img_5" style="width: 64px;"/></td>
    <td><img src="ExtractedImages\img_6.jpg" alt="img_6" style="width: 64px;"/></td>
    <td><img src="ExtractedImages\img_7.jpg" alt="img_7" style="width: 64px;"/></td>
</tr></table>

In [1]:
import tensorflow as tf
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental.preprocessing import Resizing
import time

from PIL import Image

from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.applications.inception_v3 import preprocess_input
from tensorflow.keras.datasets.mnist import load_data
from skimage.transform import resize
from scipy.linalg import sqrtm

In [2]:
FIDmodel = InceptionV3(include_top=False, pooling='avg', input_shape=(299,299,3))

BUFFER_SIZE = 60000
BATCH_SIZE = 32
FID_BATCH = 1000

In [3]:
def scale_images(images, new_shape):
    images_list = list()
    for image in images:
        # resize with nearest neighbor interpolation
        new_image = resize(image, new_shape, 0).astype('float16')
        # store        
        images_list.append(new_image)
    return np.asarray(images_list)

def calculate_fid(model, act1, images2):
    # calculate activations
    #act1 = model.predict(images1)
    act2 = model.predict(images2)
    # calculate mean and covariance statistics
    mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
    mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
    # calculate sum squared difference between means
    ssdiff = np.sum((mu1 - mu2)**2.0)
    # calculate sqrt of product between cov
    covmean = sqrtm(sigma1.dot(sigma2))
    # check and correct imaginary numbers from sqrt
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    # calculate score
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid

In [4]:
def _extract_fn(tfrecord):
  ximg = 64
  yimg = 48
  # Extract features
  features = {
    'fpath': tf.io.FixedLenFeature([1], tf.string),
    'image': tf.io.FixedLenFeature([ximg * yimg], tf.int64),
    'label': tf.io.FixedLenFeature([6], tf.float32)
  }

  # Extract the data record
  sample = tf.io.parse_single_example(tfrecord, features)
  fpath = sample['fpath']
  image = sample['image']
  label = sample['label']

  fpath = tf.cast(fpath, tf.string)

  image = tf.reshape(image, [yimg, ximg, 1])  
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1  

  coords = tf.cast(label, 'float32')
  attrs = coords

  return image, attrs

In [5]:
xres = 64; yres=48; dstype = 'train'
tfrecord_file = f"E:\\NCK\\gan_{xres}{yres}{dstype}.tfrecord"
dataset = tf.data.TFRecordDataset(tfrecord_file)
dataset = dataset.map(_extract_fn)
dataset = dataset.repeat()
#dataset = dataset.shuffle(BUFFER_SIZE)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(4)
train_dataset = dataset

fiddataset = tf.data.TFRecordDataset(tfrecord_file)
fiddataset = fiddataset.map(_extract_fn)
fiddataset = fiddataset.repeat()
fiddataset = fiddataset.batch(FID_BATCH)

for item in fiddataset.take(1):
  images_real, _ = item
images_real = (images_real + 1.) * 127.5
images_real = tf.concat((images_real,)*3, axis=3)
images_real = scale_images(images_real.numpy(), (299,299,3))
#images_real = preprocess_input(images_real).astype('float16')
images_real = preprocess_input(images_real)
act_real = FIDmodel.predict(images_real)
del images_real
print("Done reals.")

Done reals.


In [6]:
# The 64x48 Generator
def make_generator_model48():
  gf_dim = 64
  model = tf.keras.Sequential()
  model.add(layers.Dense(gf_dim * 8 * 3 * 4, input_shape=(100,)))
  model.add(layers.Reshape((3, 4, gf_dim * 8)))  # (4, 3, 512)
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())

  model.add(layers.Reshape((3, 4, gf_dim * 8)))  # (4, 4, 512)
  assert model.output_shape == (None, 3, 4, 512)  # Note: None is the batch size
  
  model.add(layers.Conv2DTranspose(gf_dim * 4, (5, 5), strides=(2, 2), padding='SAME'))
  assert model.output_shape == (None, 6, 8, 256)
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())

  model.add(layers.Conv2DTranspose(gf_dim * 2, (5, 5), strides=(2, 2), padding='SAME'))
  assert model.output_shape == (None, 12, 16, 128)
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())

  model.add(layers.Conv2DTranspose(gf_dim * 1, (5, 5), strides=(2, 2), padding='SAME'))
  assert model.output_shape == (None, 24, 32, 64)
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())

  model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='SAME', activation='tanh'))
  assert model.output_shape == (None, 48, 64, 1)

  return model

In [7]:
generator = make_generator_model48()
noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)
#plt.imshow(generated_image[0, :, :, 0], cmap='gray')
print(generator.summary())

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 6144)              620544    
_________________________________________________________________
reshape (Reshape)            (None, 3, 4, 512)         0         
_________________________________________________________________
batch_normalization_94 (Batc (None, 3, 4, 512)         2048      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 3, 4, 512)         0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 3, 4, 512)         0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 6, 8, 256)         3277056   
_________________________________________________________________
batch_normalization_95 (Batc (None, 6, 8, 256)         1

In [8]:
# The 64x48 Discriminator
def make_discriminator_model48():
  model = tf.keras.Sequential()
  model.add(layers.InputLayer(input_shape=(48, 64, 1)))
  model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='SAME'))
  model.add(layers.LeakyReLU())

  model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='SAME'))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())

  model.add(layers.Conv2D(256, (5, 5), strides=(2, 2), padding='SAME'))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())

  model.add(layers.Conv2D(512, (5, 5), strides=(2, 2), padding='SAME'))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())

  model.add(layers.Flatten())
  model.add(layers.Dense(1))

  return model

In [9]:
discriminator = make_discriminator_model48()
decision = discriminator(generated_image)
print(decision)
print(discriminator.summary())

tf.Tensor([[0.00029361]], shape=(1, 1), dtype=float32)
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_94 (Conv2D)           (None, 24, 32, 64)        1664      
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 24, 32, 64)        0         
_________________________________________________________________
conv2d_95 (Conv2D)           (None, 12, 16, 128)       204928    
_________________________________________________________________
batch_normalization_98 (Batc (None, 12, 16, 128)       512       
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 12, 16, 128)       0         
_________________________________________________________________
conv2d_96 (Conv2D)           (None, 6, 8, 256)         819456    
_________________________________________________________________

In [10]:
# Loss - see train_step

In [11]:
# Optimizers
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

In [12]:
# Training loop
EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

seed = np.random.normal(size=[num_examples_to_generate, noise_dim]).astype(np.float32)

In [13]:
# Loss computed directly in train_step
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    #noise = np.random.uniform(-1, 1, size=[BATCH_SIZE, noise_dim]).astype(np.float32)

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      d_real_logits = discriminator(images, training=True)
      d_fake_logits = discriminator(generated_images, training=True)

      gen_loss = tf.reduce_mean(tf.nn.l2_loss(d_fake_logits - tf.ones_like(d_fake_logits)))
      d_loss_real = tf.reduce_mean(tf.nn.l2_loss(d_real_logits - tf.ones_like(d_real_logits)))
      d_loss_fake = tf.reduce_mean(tf.nn.l2_loss(d_fake_logits - tf.zeros_like(d_real_logits)))
      disc_loss = d_loss_real + d_loss_fake

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss

In [14]:
def train(dataset, epochs):  
  epoch = 0
  while epoch < EPOCHS:
    epoch += 1
    epoch_start = time.time()

    for image_batch in dataset.take(1000):
      g_loss, d_loss = train_step(image_batch[0])

    print(g_loss.numpy(), d_loss.numpy())

    generate_and_save_images(generator, epoch, seed)

    # Compute FID for 10000 example
    if epoch % 5 == 0 or epoch == 1:
    #if False:
      print("Calculating FID ... ", end="", flush=True)
      fid_noise = np.random.normal(size=[FID_BATCH, noise_dim]).astype(np.float32)
      fid_fake=np.empty([FID_BATCH, 48, 64, 1])
      dn = 100
      for i in range(FID_BATCH//dn):
        fid_batch = fid_noise[i*dn:(i+1)*dn]
        #print(i, fid_batch.shape)
        g = generator(fid_batch, training=False)
        g = (g + 1.) * 127.5
        fid_fake[i*dn:(i+1)*dn] = g.numpy()
      fid_fake = np.concatenate((fid_fake,)*3, axis=3)
      fid_fake = np.rint(fid_fake).clip(0, 255).astype(np.uint8)
      fid_fake = fid_fake.astype('float32')
      fid_fake = scale_images(fid_fake, (299, 299, 3))
      images_fake = preprocess_input(fid_fake)

      fid_start = time.time()
      fid = calculate_fid(FIDmodel, act_real, images_fake)
      fid_end = time.time()
      msg = f"{fid:6.2f}, time {fid_end - fid_start:.2f} sec"
      print(msg)
      with open("lsgan_images/fid.txt", "a+") as fidfile:
        fidfile.writelines([f"{epoch:03}_FID {msg}\n"])


    print (f"Time for epoch {epoch} is {time.time()-epoch_start:.2f} sec")

  generate_and_save_images(generator, epochs, seed)

In [15]:
def generate_and_save_images(model, epoch, test_input):
  ximg = 64
  yimg = 48
  
  predictions = model(test_input, training=False)

  g = predictions
  g = (g + 1.) * 127.5
  canvas = np.empty((yimg * 4, ximg * 4, 1))
  c = 0
  for i in range(4):
    for j in range(4):
      c = c + 1
      canvas[j * yimg:(j + 1) * yimg, i * ximg:(i + 1) * ximg] = g[c-1]

  image = np.rint(canvas).clip(0, 255).astype(np.uint8)
  image = np.squeeze(image)
  image = Image.fromarray(image)
  image.save(f"lsgan_images/fake_{epoch*1000}.jpg")

In [16]:
def save_real_images(dataset):
  ximg = 64
  yimg = 48

  for images in dataset.take(1):
    g = (images[0][:16] + 1.) * 127.5

  canvas = np.empty((yimg * 4, ximg * 4, 1))
  c = 0
  for i in range(4):
    for j in range(4):
      c = c + 1
      canvas[j * yimg:(j + 1) * yimg, i * ximg:(i + 1) * ximg] = g[c-1]

  image = np.rint(canvas).clip(0, 255).astype(np.uint8)
  image = np.squeeze(image)
  image = Image.fromarray(image)
  image.save("lsgan_images/reals.jpg")

In [17]:
#===
print("Start training ... ")
f = open("lsgan_images/fid.txt", "w"); f.close()
save_real_images(train_dataset)
train(train_dataset, EPOCHS)
#===

Start training ... 
11.802664 1.9922061
Calculating FID ... 388.52, time 10.05 sec
Time for epoch 1 is 39.19 sec
17.617922 1.7316492
Time for epoch 2 is 20.02 sec
13.723892 1.9549859
Time for epoch 3 is 20.18 sec
15.106436 0.734089
Time for epoch 4 is 20.25 sec
16.643465 1.3533058
Calculating FID ... 415.71, time 7.26 sec
Time for epoch 5 is 35.43 sec
9.169331 1.8442205
Time for epoch 6 is 20.14 sec
9.457615 1.4756974
Time for epoch 7 is 20.26 sec
19.395771 1.7670639
Time for epoch 8 is 20.22 sec
13.306315 0.8929231
Time for epoch 9 is 20.29 sec
14.108548 1.9773176
Calculating FID ... 409.71, time 8.24 sec
Time for epoch 10 is 36.40 sec
19.172937 0.81413305
Time for epoch 11 is 20.08 sec
7.90804 2.7169616
Time for epoch 12 is 20.23 sec
13.389352 2.1111858
Time for epoch 13 is 20.22 sec
9.584326 2.4783592
Time for epoch 14 is 20.21 sec
12.175301 0.5488465
Calculating FID ... 414.05, time 8.04 sec
Time for epoch 15 is 35.79 sec
15.426506 0.3543371
Time for epoch 16 is 20.03 sec
18.241991

<table><tr>
    <td><img src="lsgan_images\fake_1000.jpg"  alt="fake_1000"  style="width: 256px;"/></td>
    <td><img src="lsgan_images\fake_5000.jpg"  alt="fake_5000"  style="width: 256px;"/></td>
    <td><img src="lsgan_images\fake_10000.jpg" alt="fake_10000" style="width: 256px;"/></td>
    <td><img src="lsgan_images\fake_15000.jpg" alt="fake_15000" style="width: 256px;"/></td>    
</tr></table>

<table><tr>
    <td><img src="lsgan_images\fake_20000.jpg" alt="fake_20000" style="width: 256px;"/></td>
    <td><img src="lsgan_images\fake_30000.jpg" alt="fake_30000" style="width: 256px;"/></td>
    <td><img src="lsgan_images\fake_40000.jpg" alt="fake_40000" style="width: 256px;"/></td>
    <td><img src="lsgan_images\fake_50000.jpg" alt="fake_50000" style="width: 256px;"/></td>
</tr></table>    