<a href="https://colab.research.google.com/github/abhishtagatya/paintgan-model-comparison/blob/main/google_colab_template.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Model Name

Model Description

Paper

### Dataset

```
PaintGAN 80K
A subset of Places365 (content images) and WikiArt (style images), each containing roughly 80,000 images.
```







### Preliminary Setup

Set up the environment and load the PaintGAN 80K dataset with a TensorFlow `tf.data` input pipeline.

In [None]:
# Mount Google Drive; the zipped dataset is read from MyDrive/Dataset below.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
!unzip drive/MyDrive/Dataset/paintgan-80k.zip

In [None]:
import os
import glob
import imageio
import numpy as np
from tqdm import tqdm
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

#@title ## Model Training Parameters { display-mode: "both" }
#@markdown Change training parameters based on experimentation

# Defining Training Parameters
# MODEL_NAME prefixes the checkpoint and CSV-log filenames written later on.
MODEL_NAME = "" #@param{type:"string"}
IMAGE_SIZE = (256, 256) #@param{type:"raw"}
BATCH_SIZE = 8 #@param{type:"integer"}
EPOCHS = 30 #@param{type:"integer"}
CHECKPOINT_PER_EPOCH = 5 #@param{type:"integer"}

AUTOTUNE = tf.data.AUTOTUNE

WIKIART_BASEPATH = "paintgan-dataset/wikiart"
P365_BASEPATH = "paintgan-dataset/places365"

def _list_image_paths(basepath):
  """
    List the files directly inside `basepath` as joined paths, in sorted order.

    os.listdir returns entries in arbitrary order, so sorting keeps the
    contiguous train/val/test slicing in this notebook reproducible across
    runs and machines. Sub-directories are skipped because they cannot be
    read as images.

    Args:
      basepath: Directory to list.

    Returns:
      Sorted list of file paths.
  """
  candidates = (os.path.join(basepath, name) for name in sorted(os.listdir(basepath)))
  return [path for path in candidates if os.path.isfile(path)]

# Loading Image path
content_images = _list_image_paths(P365_BASEPATH)
style_images = _list_image_paths(WIKIART_BASEPATH)

In [None]:
# Removing corrupted JPEGs on Unzip process (if any)

from struct import unpack

# Human-readable names of the JPEG markers handled below
# (reference only; the parser does not read this table).
marker_mapping = {
    0xffd8: "Start of Image",
    0xffe0: "Application Default Header",
    0xffdb: "Quantization Table",
    0xffc0: "Start of Frame",
    0xffc4: "Define Huffman Table",
    0xffda: "Start of Scan",
    0xffd9: "End of Image"
}

class JPEG:
    """
    Minimal JPEG marker walker used to detect truncated or corrupted files.

    It does not decode pixel data. It only checks that the marker segments
    can be walked and that the file ends with an End of Image marker once
    the Start of Scan is reached.
    """

    def __init__(self, image_file):
        # Load the whole file as raw bytes.
        with open(image_file, 'rb') as f:
            self.img_data = f.read()

    def decode(self):
        """
        Walk the marker segments of the loaded file.

        Every iteration either returns or consumes at least two bytes, so the
        walk always terminates.

        Returns:
          None when the marker structure is valid.

        Raises:
          ValueError: no End of Image marker is reached.
          struct.error: fewer bytes remain than a marker/length field needs.
        """
        data = self.img_data
        while data:
            marker, = unpack(">H", data[0:2])
            if marker == 0xffd8:
                data = data[2:]
            elif marker == 0xffd9:
                return
            elif marker == 0xffda:
                # Entropy-coded scan data follows the Start of Scan. Instead of
                # parsing it, require the file to end with End of Image.
                # NOTE: files with trailing bytes after EOI are rejected too.
                if data[-2:] != b"\xff\xd9":
                    raise ValueError("missing End of Image marker after Start of Scan")
                return
            else:
                # The segment length counts its own two length bytes, so skip
                # the 2-byte marker plus `lenchunk`.
                lenchunk, = unpack(">H", data[2:4])
                data = data[2 + lenchunk:]
        raise ValueError("reached end of data without an End of Image marker")


import struct

def _find_corrupted(image_paths, desc=None):
  """
    Return the paths whose JPEG marker structure fails `JPEG.decode()`.

    Files that cannot be opened (missing, a directory, no permission) are
    reported as corrupted as well. KeyboardInterrupt and unexpected errors
    are not swallowed.

    Args:
      image_paths: Iterable of image file paths.
      desc: Optional tqdm progress-bar label.

    Returns:
      List of rejected paths, in input order.
  """
  bad_paths = []
  for path in tqdm(image_paths, desc=desc):
    try:
      JPEG(path).decode()
    except (OSError, ValueError, struct.error):
      bad_paths.append(path)
  return bad_paths

# Both style and content images are later decoded with tf.image.decode_jpeg,
# so screen both lists.
corrupted = (
  _find_corrupted(style_images, desc="style")
  + _find_corrupted(content_images, desc="content")
)

# Set-based filtering is O(n) instead of repeated list.remove calls.
corrupted_set = set(corrupted)
style_images = [path for path in style_images if path not in corrupted_set]
content_images = [path for path in content_images if path not in corrupted_set]

In [None]:
# Image Preprocessing
def preprocess_image(image_path, size=IMAGE_SIZE):
  """
    Read a JPEG file and return it as a resized float32 RGB tensor.

    Args:
      image_path: Path (string or scalar string tensor) of the image file.
      size: Target (height, width) passed to tf.image.resize.

    Returns:
      image: float32 tensor of shape (height, width, 3). convert_image_dtype
        rescales the uint8 pixels to [0, 1] before resizing.
  """
  raw_bytes = tf.io.read_file(image_path)
  rgb = tf.image.decode_jpeg(raw_bytes, channels=3)
  rgb_float = tf.image.convert_image_dtype(rgb, dtype='float32')
  return tf.image.resize(rgb_float, size)

# Train Test Splits
def _split_paths(paths, train_end=0.8, val_end=0.9):
  """
    Slice an ordered list of paths into contiguous train/val/test parts.

    Args:
      paths: Ordered list of file paths.
      train_end: Fraction of `paths` at which the training slice ends.
      val_end: Fraction of `paths` at which the validation slice ends.

    Returns:
      (train, val, test) lists.
  """
  total = len(paths)
  train_stop = int(train_end * total)
  val_stop = int(val_end * total)
  return paths[:train_stop], paths[train_stop:val_stop], paths[val_stop:]


def _image_ds(paths):
  """Build an endlessly repeating tf.data stream of preprocessed images from `paths`."""
  return (
      tf.data.Dataset.from_tensor_slices(paths)
      .map(preprocess_image, num_parallel_calls=AUTOTUNE)
      .repeat()
  )


def _paired_ds(style_ds, content_ds):
  """Zip style/content streams into shuffled, batched, prefetched (style, content) pairs."""
  return (
      tf.data.Dataset.zip((style_ds, content_ds))
      .shuffle(BATCH_SIZE * 2)
      .batch(BATCH_SIZE)
      .prefetch(AUTOTUNE)
  )


total_content_images = len(content_images)
train_content, val_content, test_content = _split_paths(content_images)

total_style_images = len(style_images)
train_style, val_style, test_style = _split_paths(style_images)

# An empty split would otherwise surface as an opaque tf.data error; fail early.
for split_name, split_paths in [
    ("train_content", train_content), ("val_content", val_content), ("test_content", test_content),
    ("train_style", train_style), ("val_style", val_style), ("test_style", test_style),
]:
  assert split_paths, f"{split_name} split is empty; check the dataset directories"

# Build the style and content tf.data datasets
train_style_ds = _image_ds(train_style)
val_style_ds = _image_ds(val_style)
test_style_ds = _image_ds(test_style)

train_content_ds = _image_ds(train_content)
val_content_ds = _image_ds(val_content)
test_content_ds = _image_ds(test_content)

# Zipping the datasets
train_ds = _paired_ds(train_style_ds, train_content_ds)
val_ds = _paired_ds(val_style_ds, val_content_ds)
test_ds = _paired_ds(test_style_ds, test_content_ds)

### Monitors and Callbacks

Custom Keras callbacks that run at the end of every epoch (`on_epoch_end`): a visual monitor, periodic weight checkpoints, and a CSV training log.

In [None]:
# One fixed test batch, reused for the per-epoch visual check.
# NOTE(review): this rebinds `test_style`, which earlier held the test-split file paths.
test_style, test_content = next(iter(test_ds))

class DisplayMonitor(tf.keras.callbacks.Callback):
  """
    At the end of each epoch, plot the first test sample's style image,
    content image and stylized output side by side.

    Args:
      stylize_fn: Optional callable `stylize_fn(model, style_batch, content_batch)`
        returning a batch of stylized images. When omitted, the third panel
        shows a placeholder instead of raising.
  """

  def __init__(self, stylize_fn=None):
    super().__init__()
    self.stylize_fn = stylize_fn

  def on_epoch_end(self, epoch, logs=None):
    # Run the model before creating the figure, so a failure here does not leak an open figure.
    test_recon_image = None
    if self.stylize_fn is not None:
      test_recon_image = self.stylize_fn(self.model, test_style, test_content)

    # Plot the style, content, image
    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20, 5))
    ax[0].imshow(tf.keras.preprocessing.image.array_to_img(test_style[0]))
    ax[0].set_title(f"Style: {epoch+1:03d}")

    ax[1].imshow(tf.keras.preprocessing.image.array_to_img(test_content[0]))
    ax[1].set_title(f"Content: {epoch+1:03d}")

    if test_recon_image is None:
      ax[2].text(0.5, 0.5, "stylize_fn not set", ha="center", va="center")
      ax[2].axis("off")
    else:
      ax[2].imshow(tf.keras.preprocessing.image.array_to_img(test_recon_image[0]))
    ax[2].set_title(f"{MODEL_NAME}: {epoch+1:03d}")

    plt.show()
    plt.close(fig)

In [None]:
# Create the checkpoint directory; exist_ok makes re-running this cell safe.
os.makedirs('model_checkpoints', exist_ok=True)

In [None]:
class CheckpointMonitor(tf.keras.callbacks.Callback):
  """
    Save the model's weights every `every` epochs as
    `<directory>/<MODEL_NAME>_<epoch>.ckpt`.

    Args:
      every: Save interval in epochs. Defaults to CHECKPOINT_PER_EPOCH, read
        when each epoch ends. Non-positive values disable saving.
      directory: Output directory, created if it does not exist.
  """

  def __init__(self, every=None, directory='model_checkpoints'):
    super().__init__()
    self.every = every
    self.directory = directory

  def on_epoch_end(self, epoch, logs=None):
    every = CHECKPOINT_PER_EPOCH if self.every is None else self.every
    # A zero interval would otherwise raise ZeroDivisionError in the modulo.
    if every <= 0 or (epoch + 1) % every != 0:
      return
    # Saving model checkpoint
    os.makedirs(self.directory, exist_ok=True)
    self.model.save_weights(os.path.join(self.directory, f'{MODEL_NAME}_{epoch+1}.ckpt'))

In [None]:
# Append per-epoch metrics to a semicolon-separated log named after the run settings.
csv_logger = keras.callbacks.CSVLogger(
    filename=f'{MODEL_NAME}_p365-{EPOCHS}-{BATCH_SIZE}.csv',
    separator=';',
    append=True,
)

In [None]:
# Training callbacks: visual monitor, periodic weight checkpoints, CSV metrics log.
callbacks = [DisplayMonitor(), CheckpointMonitor(), csv_logger]

### Model Implementation

### Result