In [1]:
import os
os.makedirs('./images_pseudo')
!unzip ../input/monet-effi/images.zip -d ./images_pseudo

len(os.listdir('./images_pseudo'))

Archive:  ../input/monet-effi/images.zip
  inflating: ./images_pseudo/6790.jpg  
  inflating: ./images_pseudo/4644.jpg  
  inflating: ./images_pseudo/1835.jpg  
  inflating: ./images_pseudo/1102.jpg  
  inflating: ./images_pseudo/6040.jpg  
  inflating: ./images_pseudo/4423.jpg  
  inflating: ./images_pseudo/1806.jpg  
  inflating: ./images_pseudo/4760.jpg  
  inflating: ./images_pseudo/217.jpg  
  inflating: ./images_pseudo/5790.jpg  
  inflating: ./images_pseudo/195.jpg  
  inflating: ./images_pseudo/3694.jpg  
  inflating: ./images_pseudo/3868.jpg  
  inflating: ./images_pseudo/224.jpg  
  inflating: ./images_pseudo/5352.jpg  
  inflating: ./images_pseudo/3947.jpg  
  inflating: ./images_pseudo/5759.jpg  
  inflating: ./images_pseudo/2071.jpg  
  inflating: ./images_pseudo/5662.jpg  
  inflating: ./images_pseudo/4020.jpg  
  inflating: ./images_pseudo/3274.jpg  
  inflating: ./images_pseudo/1424.jpg  
  inflating: ./images_pseudo/2629.jpg  
  inflating: ./ima

7038

In [2]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

import pandas as pd
from tqdm import tqdm
import time

from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np
import sys
from PIL import Image
import glob


AUTO = tf.data.experimental.AUTOTUNE
print(tf.__version__)

2.4.1


In [3]:
GCS_PATH = KaggleDatasets().get_gcs_path('gan-getting-started')

GCS_PATH

'gs://kds-1e682722432ed3ff288b28a0dfe4078e73c23fe5ad1524bf80aa0081'

In [4]:
import re
MONET_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/monet_tfrec/*.tfrec'))
print('Monet TFRecord Files:', len(MONET_FILENAMES))

PHOTO_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/photo_tfrec/*.tfrec'))
print('Photo TFRecord Files:', len(PHOTO_FILENAMES))

def count_data_items(filenames):
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return np.sum(n)

n_monet_samples = count_data_items(MONET_FILENAMES)
n_photo_samples = count_data_items(PHOTO_FILENAMES)

BATCH_SIZE =  4
EPOCHS_NUM = 30

print(f'Monet TFRecord files: {len(MONET_FILENAMES)}')
print(f'Monet image files: {n_monet_samples}')
print(f'Photo TFRecord files: {len(PHOTO_FILENAMES)}')
print(f'Photo image files: {n_photo_samples}')
print(f"Batch_size: {BATCH_SIZE}")
print(f"Epochs number: {EPOCHS_NUM}")

Monet TFRecord Files: 5
Photo TFRecord Files: 20
Monet TFRecord files: 5
Monet image files: 300
Photo TFRecord files: 20
Photo image files: 7038
Batch_size: 4
Epochs number: 30


In [5]:
IMAGE_SIZE = [256, 256]

def decode_image(image):
    image = tf.image.decode_jpeg(image, channels=3)
    image = (tf.cast(image, tf.float32) / 127.5) - 1
    image = tf.reshape(image, [*IMAGE_SIZE, 3])
    return image

def read_tfrecord(example):
    tfrecord_format = {
        "image_name": tf.io.FixedLenFeature([], tf.string),
        "image": tf.io.FixedLenFeature([], tf.string),
        "target": tf.io.FixedLenFeature([], tf.string)
    }
    example = tf.io.parse_single_example(example, tfrecord_format)
    image = decode_image(example['image'])
    return image

In [6]:
def data_augment(image):
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    if p_crop > .5:
        image = tf.image.resize(image, [286, 286])
        image = tf.image.random_crop(image, size=[256, 256, 3])
        if p_crop > .9:
            image = tf.image.resize(image, [300, 300])
            image = tf.image.random_crop(image, size=[256, 256, 3])
    
    if p_rotate > .9:
        image = tf.image.rot90(image, k=3) # rotate 270º
    elif p_rotate > .7:
        image = tf.image.rot90(image, k=2) # rotate 180º
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=1) # rotate 90º
        
    if p_spatial > .6:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        if p_spatial > .9:
            image = tf.image.transpose(image)
    
    return image

def load_dataset(filenames):
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTO)
    return dataset

def get_gan_dataset(monet_files, photo_files, augment=None, repeat=True, shuffle=True, batch_size=1):

    monet_ds = load_dataset(monet_files)
    photo_ds = load_dataset(photo_files)
    
    if augment:
        monet_ds = monet_ds.map(augment, num_parallel_calls=AUTO)
        photo_ds = photo_ds.map(augment, num_parallel_calls=AUTO)
        
    if repeat:
        monet_ds = monet_ds.repeat()
        photo_ds = photo_ds.repeat()
    if shuffle:
        monet_ds = monet_ds.shuffle(2048)
        photo_ds = photo_ds.shuffle(2048)
        
    monet_ds = monet_ds.batch(batch_size, drop_remainder=True)
    photo_ds = photo_ds.batch(batch_size, drop_remainder=True)
    monet_ds = monet_ds.cache()
    photo_ds = photo_ds.cache()
    monet_ds = monet_ds.prefetch(AUTO)
    photo_ds = photo_ds.prefetch(AUTO)
    
    gan_ds = tf.data.Dataset.zip((monet_ds, photo_ds))
    
    return gan_ds

In [7]:
!mkdir ./images_pseudo_tfrec

In [8]:
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
        
def image_example(imageString, label, name):
    imageShape = tf.image.decode_jpeg(imageString).shape

    feature = {
        'image_name':_bytes_feature(name),                 #이미지 이름
        'image': _bytes_feature(imageString), #이미지는 0~255의 3차원값들
        'target': _bytes_feature(label),     #랜드마크 분류값
    }

    return tf.train.Example(features=tf.train.Features(feature=feature))

def convert_func(paths):
    i=0
    for path in tqdm(paths):
        i += 1
        recordFile = './images_pseudo_tfrec/' + str(i) + '.tfrecords' #file name
        with tf.io.TFRecordWriter(recordFile) as writer:
            imageString = open(path, 'rb').read()
            tfExample = image_example(imageString, 'None'.encode(), 'None'.encode())
            writer.write(tfExample.SerializeToString())

In [9]:
def convert_to_tfrecord(dataset_name, data_directory, segments=23, files='/*.jpg'):
    """Convert the dataset into TFRecords on disk
    
    Args:
        dataset_name:   The name/folder of the dataset
        data_directory: The directory where records will be stored
        segments:       The number of files on disk to separate records into
        directories_as_labels: Whether the directory name should be used as it's label (used for test directory)
        files:          Which files to find in the data directory
    """
    
    filenames = glob.glob(dataset_name + files)
    num_examples = len(filenames)
    dataset = filenames
    samples_per_segment = 300
    print(f"Have {samples_per_segment} per record file")
    
    for segment_index in range(segments):
        start_index = segment_index * samples_per_segment
        end_index = (segment_index + 1) * samples_per_segment
        
        sub_dataset = dataset[start_index:end_index]
        record_filename = os.path.join(data_directory, f"{segment_index}.tfrecords")

        with tf.io.TFRecordWriter(record_filename) as writer:
            print(f"Writing {record_filename}")

            for index, sample in enumerate(sub_dataset):
#                 sys.stdout.write(f"\rProcessing sample {start_index+index+1} of {num_examples}")
#                 sys.stdout.flush()

                file_path = sample
                imageString = open(file_path, 'rb').read()
#                 image_raw = np.array(imageString).tostring()
                
                features = {
                    'image_name': _bytes_feature('None'.encode()),
                    'image': _bytes_feature(imageString),
                    'target': _bytes_feature('None'.encode())
                }
                example = tf.train.Example(features=tf.train.Features(feature=features))
                writer.write(example.SerializeToString())  
                

In [10]:
convert_to_tfrecord(dataset_name = 'images_pseudo', data_directory = 'images_pseudo_tfrec', files='*/*.jpg')

Have 300 per record file
Writing images_pseudo_tfrec/0.tfrecords
Writing images_pseudo_tfrec/1.tfrecords
Writing images_pseudo_tfrec/2.tfrecords
Writing images_pseudo_tfrec/3.tfrecords
Writing images_pseudo_tfrec/4.tfrecords
Writing images_pseudo_tfrec/5.tfrecords
Writing images_pseudo_tfrec/6.tfrecords
Writing images_pseudo_tfrec/7.tfrecords
Writing images_pseudo_tfrec/8.tfrecords
Writing images_pseudo_tfrec/9.tfrecords
Writing images_pseudo_tfrec/10.tfrecords
Writing images_pseudo_tfrec/11.tfrecords
Writing images_pseudo_tfrec/12.tfrecords
Writing images_pseudo_tfrec/13.tfrecords
Writing images_pseudo_tfrec/14.tfrecords
Writing images_pseudo_tfrec/15.tfrecords
Writing images_pseudo_tfrec/16.tfrecords
Writing images_pseudo_tfrec/17.tfrecords
Writing images_pseudo_tfrec/18.tfrecords
Writing images_pseudo_tfrec/19.tfrecords
Writing images_pseudo_tfrec/20.tfrecords
Writing images_pseudo_tfrec/21.tfrecords
Writing images_pseudo_tfrec/22.tfrecords


In [11]:
example = load_dataset('images_pseudo_tfrec/9.tfrecords')

len(list(example))

300

In [12]:
PSEUDO_FILENAMES = tf.io.gfile.glob('images_pseudo_tfrec/*.tfrecords')
print('PSEUDO TFRecord Files:', len(PSEUDO_FILENAMES))

PSEUDO TFRecord Files: 23


In [13]:
MONET_FILENAMES

['gs://kds-1e682722432ed3ff288b28a0dfe4078e73c23fe5ad1524bf80aa0081/monet_tfrec/monet00-60.tfrec',
 'gs://kds-1e682722432ed3ff288b28a0dfe4078e73c23fe5ad1524bf80aa0081/monet_tfrec/monet04-60.tfrec',
 'gs://kds-1e682722432ed3ff288b28a0dfe4078e73c23fe5ad1524bf80aa0081/monet_tfrec/monet08-60.tfrec',
 'gs://kds-1e682722432ed3ff288b28a0dfe4078e73c23fe5ad1524bf80aa0081/monet_tfrec/monet12-60.tfrec',
 'gs://kds-1e682722432ed3ff288b28a0dfe4078e73c23fe5ad1524bf80aa0081/monet_tfrec/monet16-60.tfrec']

In [14]:
!rm -rf images_pseudo