In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"
import tensorflow as tf
from tensorflow import keras
from keras import layers
import pandas as pd
import numpy as np
import random
import math
import matplotlib.pyplot as plt

In [None]:
# Configure GPU usage: expose only the first physical GPU to TensorFlow and
# turn on memory growth for the detected GPUs.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Only gpus[0] is made visible; the remaining GPUs are left unused.
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        # Report how many physical GPUs were found vs. how many logical GPUs
        # TensorFlow exposes after the visibility restriction above.
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)

In [None]:
# directories
CHECKPOINT_DIR = './checkpoints/final_version_3'  # where tf.train.CheckpointManager stores checkpoints
SAMPLE_DIR = './samples/final_version'            # image grids written by DiffusionModel.plot_images
INFERENCE_DIR = './inference/final_version'       # per-caption images written by DiffusionModel.test_step

# data
DATASET_REPETITIONS = 5     # how many times the training set is repeated inside one "epoch"
NUM_EPOCHS = 300  
IMAGE_SIZE = 64             # side length (pixels) of the square training / generated images
MAX_TEXT_LENGTH  = 77       # token sequences are right-padded to this length
PLOT_DIFFUSION_STEPS = 50   # reverse-diffusion steps used by plot_images and test_step

# sampling
# Bounds of the signal rate used by DiffusionModel.diffusion_schedule.
MIN_SIGNAL_RATE = 0.02
MAX_SIGNAL_RATE = 0.95

# architecture
EMBEDDING_DIMENSIONS = 32           # channels of the sinusoidal noise-level embedding
EMBEDDING_MAX_FREQUENCY = 1000.0    # highest frequency of that embedding
WIDTHS = [64, 96, 128, 160]         # channel widths; the last entry is the bottleneck width
BLOCK_DEPTH = 2                     # residual blocks per down/up stage
HEAD = 2                            # attention heads in each TransformerBlock

# optimization
BATCH_SIZE = 64
EMA = 0.999             # decay factor of the exponential moving average of the weights
WEIGHT_DECAY = 1e-4     # AdamW weight decay
START_EMA = 2000        # before this global step the EMA network simply copies the weights

In [None]:
# Create the output directories used for sample grids and inference images.
# `exist_ok=True` makes the cell idempotent and avoids the check-then-create
# race of testing `os.path.exists` first.
for output_dir in (SAMPLE_DIR, INFERENCE_DIR):
    os.makedirs(output_dir, exist_ok=True)

## Preprocess Text

In [None]:
# Load the dataset's own vocabulary and the two lookup tables that map between
# words and ids.
dictionary_path = './dictionary'
vocab = np.load(dictionary_path + '/vocab.npy')
print('there are {} vocabularies in total'.format(len(vocab)))

# The .npy files are converted to dicts with dict(...), so each is presumably
# an array of (key, value) pairs — TODO confirm against the files.
# id2word_dict is indexed with string ids below (e.g. '1').
word2Id_dict = dict(np.load(dictionary_path + '/word2Id.npy'))
id2word_dict = dict(np.load(dictionary_path + '/id2Word.npy'))
print('Word to id mapping, for example: %s -> %s' % ('flower', word2Id_dict['flower']))
print('Id to word mapping, for example: %s -> %s' % ('1', id2word_dict['1']))
print('Tokens: <PAD>: %s; <RARE>: %s' % (word2Id_dict['<PAD>'], word2Id_dict['<RARE>']))

In [None]:
# Load the pickled training DataFrame only to report how many rows it holds.
# NOTE: pd.read_pickle executes arbitrary code from the file; only load
# pickles from a trusted source.
data_path = './dataset'
df = pd.read_pickle(data_path + '/text2ImgData.pkl')
num_training_sample = len(df)
n_images_train = num_training_sample  # same value under a second name
print('There are %d image in training data' % (n_images_train))

## Create the Dataset with the tf.data API
* Tokenizer: https://github.com/openai/CLIP

In [None]:
# In this competition the generated images must have size 64x64x3.
# The tokenizer instance below is shared by the training, sampling and
# inference caption pipelines in this notebook.
from src.clip_tokenizer import SimpleTokenizer
tokenizer = SimpleTokenizer()

def training_data_generator(captions, image_path):
    """Load one training image and pair it with a randomly chosen caption.

    Args:
        captions: integer tensor holding the tokenized captions of this image,
            one caption per row.
        image_path: scalar string tensor with the path of the image file.

    Returns:
        A tuple ``(img, caption)``: ``img`` is a float32 tensor of shape
        ``(IMAGE_SIZE, IMAGE_SIZE, 3)`` with values in ``[0, 1]``; ``caption``
        is one randomly selected row of ``captions``.
    """
    # Load and preprocess image.
    # expand_animations=False guarantees a rank-3 result for every file type
    # (GIFs would otherwise decode to rank 4 and break set_shape below).
    img = tf.io.read_file(image_path)
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.cast(img, tf.float32)
    img.set_shape([None, None, 3])

    # Center-crop to a square, randomly mirror, then resize to the model size.
    height, width = tf.shape(img)[0], tf.shape(img)[1]
    crop_size = tf.minimum(height, width)
    img = tf.image.crop_to_bounding_box(
        img,
        (height - crop_size) // 2,
        (width - crop_size) // 2,
        crop_size,
        crop_size,
    )
    img = tf.image.random_flip_left_right(img)
    img = tf.image.resize(img, size=[IMAGE_SIZE, IMAGE_SIZE], antialias=True)
    img = tf.clip_by_value(img / 255.0, 0.0, 1.0)

    # Select a random caption among the rows actually present, rather than
    # assuming there are exactly 10 of them.
    num_captions = tf.shape(captions)[0]
    idx = tf.random.uniform(shape=(1,), minval=0, maxval=num_captions, dtype=tf.int32)
    caption = tf.gather(captions, idx)[0]

    return img, caption

def dataset_generator(filenames, batch_size, data_generator):
    """Build the batched, shuffled training ``tf.data.Dataset``.

    Args:
        filenames: path of a pickled DataFrame with a ``'Captions'`` column
            (per image, a sequence of captions given as word ids) and an
            ``'ImagePath'`` column.
        batch_size: number of examples per batch; also scales the shuffle
            buffer (``10 * batch_size``).
        data_generator: function mapping ``(captions, image_path)`` to one
            training example, e.g. ``training_data_generator``.

    Returns:
        A ``tf.data.Dataset`` yielding full batches (the remainder is dropped)
        of whatever ``data_generator`` returns, with the underlying data
        repeated ``DATASET_REPETITIONS`` times.
    """
    captions_per_image = 10  # every image is padded / truncated to this many captions

    # NOTE: pd.read_pickle executes arbitrary code; only load trusted files.
    data = pd.read_pickle(filenames)
    raw_captions = data['Captions'].values
    caption = []

    for img_raw_caption_id in raw_captions:
        img_caption = []
        for cap_id in img_raw_caption_id:
            # Map dataset word ids back to words, skipping id '5427'
            # (presumably the dataset's <PAD> id — TODO confirm), then
            # re-tokenize the sentence with the CLIP tokenizer.
            img_raw_words = [id2word_dict[str(c_id.astype(np.int32))] for c_id in cap_id if c_id != '5427']
            img_raw_caption = " ".join(img_raw_words)
            img_caption_id = tokenizer.encode(img_raw_caption)
            # Right-pad with 49407, the padding id used throughout this notebook.
            phrase = img_caption_id + [49407] * (MAX_TEXT_LENGTH - len(img_caption_id))
            img_caption.append(phrase)
        # Give every image the same number of captions so the result is a
        # rectangular array: duplicate random captions when there are too few,
        # drop the surplus when there are too many.
        while len(img_caption) < captions_per_image:
            img_caption.append(random.choice(img_caption))
        caption.append(img_caption[:captions_per_image])

    caption = np.asarray(caption)
    caption = caption.astype(np.int32)
    image_path = data['ImagePath'].values

    # Each row of `caption` must correspond to the same row of `image_path`.
    assert caption.shape[0] == image_path.shape[0]

    # No `.cache()` after the map: `data_generator` applies random
    # augmentation (flip, caption choice), and caching its output would freeze
    # those random draws for the whole training run.
    dataset = (
        tf.data.Dataset.from_tensor_slices((caption, image_path))
        .map(data_generator, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .repeat(DATASET_REPETITIONS)
        .shuffle(10 * batch_size)
        .batch(batch_size, drop_remainder=True)
        .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    )

    return dataset

In [None]:
from src.clip_encoder import CLIPTextTransformer

def get_text_encoder():
    """Build the frozen CLIP text encoder and load its pretrained weights.

    The model takes ``[token_ids, position_ids]``, both int32 tensors of
    length ``MAX_TEXT_LENGTH``, and returns the output of
    ``CLIPTextTransformer``. The weight file is fetched with
    ``keras.utils.get_file`` and checked against the given hash.
    """
    token_ids_in = layers.Input(shape=(MAX_TEXT_LENGTH,), dtype="int32")
    position_ids_in = layers.Input(shape=(MAX_TEXT_LENGTH,), dtype="int32")
    encoder_inputs = [token_ids_in, position_ids_in]

    encoder = keras.models.Model(encoder_inputs, CLIPTextTransformer()(encoder_inputs))
    # The encoder is used as a fixed feature extractor; it is never trained.
    encoder.trainable = False

    weights_path = keras.utils.get_file(
        origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/text_encoder.h5",
        file_hash="d7805118aeb156fc1d39e38a9a082b05501e2af8c8fbdc1753c9cb85212d6619",
    )
    encoder.load_weights(weights_path)
    return encoder


text_encoder = get_text_encoder()
# text_encoder.summary()

In [None]:
# Position ids 0..MAX_TEXT_LENGTH-1 as a single int32 row, then tiled for a
# training batch and for the 10 sample sentences.
pos_ids_one = np.arange(MAX_TEXT_LENGTH).astype("int32")[np.newaxis, :]
pos_ids = np.tile(pos_ids_one, (BATCH_SIZE, 1))
test_pos_ids = np.tile(pos_ids_one, (10, 1))

# "Empty" caption used for the unconditional branch: 49406 followed by 49407
# padding up to MAX_TEXT_LENGTH.
uncondition_cap_one = np.array([[49406] + [49407] * (MAX_TEXT_LENGTH - 1)])
uncondition_caps = np.tile(uncondition_cap_one, (BATCH_SIZE, 1))
test_uncondition_caps = np.tile(uncondition_cap_one, (10, 1))

# Embedding of the empty caption for a full batch, computed once up front.
uncondition_caps_emb = text_encoder([uncondition_caps, pos_ids])

train_dataset = dataset_generator(
    data_path + '/text2ImgData.pkl',
    BATCH_SIZE,
    training_data_generator,
)


## Model (DDIM)

* Reference: https://keras.io/examples/generative/ddim/

In [None]:
def encode_sentences(sentences):
    """Tokenize sentences and right-pad each id list to MAX_TEXT_LENGTH.

    Args:
        sentences: iterable of caption strings.

    Returns:
        A NumPy array with one row of token ids per sentence, padded with
        49407 (the padding id used throughout this notebook).
    """
    padded_ids = []
    for sentence in sentences:
        token_ids = tokenizer.encode(sentence)
        padded_ids.append(token_ids + [49407] * (MAX_TEXT_LENGTH - len(token_ids)))
    return np.array(padded_ids)


# Fixed captions used to monitor sample quality during training.
sample_sentences = [
    'this white and purple flower has fragile petals and soft stamens',
    'this flower has four large wide pink petals with white centers and vein like markings',
    'a flower with broad white and pink ribbed petals and yellow stamen',
    'one prominet pistil with alarger stigam and many stamens with anthers',
    'leaves are green in color petals are light pink in color',
    'this flower is bright pink with overlapping petals and a lime green pistil',
    'this flower is white and yellow in color with petals that are multi colored',
    'this flower has 4 leaves three are purple and yellow with lines and one is solid purple',
    'the pretty flower has dark and white petals on it',
    'this flower has petals that are white with yellow stamen'
]
# The position ids are prepared for exactly len(test_pos_ids) sentences.
assert len(sample_sentences) == len(test_pos_ids)

sample_sentences_ids = encode_sentences(sample_sentences)
sample_sentences_emb = text_encoder([sample_sentences_ids, test_pos_ids])

# Embedding of the empty caption, used as the unconditional branch when sampling.
un_sample_sentences_emb = text_encoder([test_uncondition_caps, test_pos_ids])

# Second fixed caption set, plotted under the '/train_ex_' file prefix.
train_sample_sentences = [
    'the flower has bright purple petals and its pistils are dark purple',
    'this flower has petals that are yellow and very stingy',
    'this flower has layers of light yellow sepals holding larger layers of bright pink petals',
    'the flower petals are rounded in shape and are bright yellow in clor',
    'this flower is bright yellow in color and has petals that are very skinny and long',
    'this flower has a lower row of pointed white petals and an upper row of long thin purple petals',
    'this flower has petals that are purple with yellow stame',
    'this flower is pink and white in color with petals that are very small',
    'this flower is white in color and has petals that are ruffled',
    'a flower with a singular conical pink petal with black stripes and orange dotting on the interior of the cone'
]
assert len(train_sample_sentences) == len(test_pos_ids)

train_sample_sentences_ids = encode_sentences(train_sample_sentences)
train_sample_sentences_emb = text_encoder([train_sample_sentences_ids, test_pos_ids])

print(sample_sentences_emb.shape)
print(train_sample_sentences_emb.shape)
print(un_sample_sentences_emb.shape)

In [None]:
class DiffusionModel(keras.Model):
    """Caption-conditioned diffusion model sampled with classifier-free guidance.

    The model owns a trainable denoising network (`network`), a second copy
    of it updated as an exponential moving average (`EMA_network`, used
    whenever `denoise` is called with `training=False`) and a
    `Normalization` layer for the image statistics.

    It relies on names defined elsewhere in the notebook: `get_network`, the
    hyper-parameter constants, the precomputed caption embeddings used as
    default arguments, and the `global_steps` variable read by `train_step`.
    """

    def __init__(self, image_size, WIDTHS, BLOCK_DEPTH, **kwargs):
        """Create the normalizer, the denoising network and its EMA clone.

        Args:
            image_size: side length of the square input images.
            WIDTHS: channel widths forwarded to `get_network`.
            BLOCK_DEPTH: blocks per stage, forwarded to `get_network`.
            **kwargs: forwarded to `keras.Model.__init__` (e.g. `name`).
        """
        # Forward kwargs instead of silently discarding them.
        super().__init__(**kwargs)

        self.normalizer = layers.Normalization()
        self.network = get_network(image_size, WIDTHS, BLOCK_DEPTH)
        self.EMA_network = keras.models.clone_model(self.network)

    def compile(self, **kwargs):
        """Compile the model and create the two loss trackers."""
        super().compile(**kwargs)

        self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
        self.image_loss_tracker = keras.metrics.Mean(name="i_loss")

    @property
    def metrics(self):
        """Metrics reported by `train_step`: noise loss and image loss."""
        return [self.noise_loss_tracker, self.image_loss_tracker]

    def denormalize(self, images):
        """Undo the normalizer's standardization and clip to [0, 1]."""
        images = self.normalizer.mean + images * self.normalizer.variance**0.5
        return tf.clip_by_value(images, 0.0, 1.0)

    def diffusion_schedule(self, diffusion_times):
        """Map diffusion times to `(noise_rates, signal_rates)`.

        The angle is interpolated linearly between `acos(MAX_SIGNAL_RATE)`
        (time 0) and `acos(MIN_SIGNAL_RATE)` (time 1); the signal rate is its
        cosine and the noise rate its sine.
        """
        # diffusion times -> angles
        start_angle = tf.acos(MAX_SIGNAL_RATE)
        end_angle = tf.acos(MIN_SIGNAL_RATE)

        diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)

        signal_rates = tf.cos(diffusion_angles)
        noise_rates = tf.sin(diffusion_angles)

        return noise_rates, signal_rates

    @tf.function
    def denoise(self, noisy_images, caption, noise_rates, signal_rates, training):
        """Predict the noise and the clean image for a batch of noisy images.

        Uses `network` when `training` is true and `EMA_network` otherwise.

        Returns:
            `(pred_noises, pred_images)`.
        """
        # Executes only while the function is being traced.
        print("In denoise")
        if training:
            network = self.network
        else:
            network = self.EMA_network
        # The network is conditioned on the noise variance (squared rate).
        pred_noises = network([noisy_images, noise_rates**2, caption], training=training)
        pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates

        return pred_noises, pred_images

    def reverse_diffusion(self, initial_noise, diffusion_steps, caption_emb, un_caption_emb, cfg_scale):
        """Iteratively denoise `initial_noise` with classifier-free guidance.

        Args:
            initial_noise: batch of noise images to start from.
            diffusion_steps: number of sampling steps (must be >= 1).
            caption_emb: caption embeddings for the conditional prediction.
            un_caption_emb: embeddings for the unconditional prediction.
            cfg_scale: guidance scale applied to the difference between the
                conditional and unconditional noise predictions.

        Returns:
            The predicted clean images of the last step (still normalized).
        """
        num_images = initial_noise.shape[0]
        step_size = 1.0 / diffusion_steps

        next_noisy_images = initial_noise
        for step in range(diffusion_steps):
            noisy_images = next_noisy_images

            # Time runs from 1 (pure noise) towards 0.
            diffusion_times = tf.ones((num_images, 1, 1, 1)) - step * step_size
            noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)

            un_pred_noises, un_pred_images = self.denoise(
                noisy_images, un_caption_emb, noise_rates, signal_rates, training=False
            )
            pred_noises, pred_images = self.denoise(
                noisy_images, caption_emb, noise_rates, signal_rates, training=False
            )

            # Classifier-free guidance: extrapolate from the unconditional
            # prediction towards the conditional one, then recompute the image
            # estimate from the guided noise.
            pred_noises = un_pred_noises + cfg_scale *(pred_noises-un_pred_noises)
            pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates

            # Re-mix predicted image and noise at the next (lower) noise level.
            next_diffusion_times = diffusion_times - step_size
            next_noise_rates, next_signal_rates = self.diffusion_schedule(
                next_diffusion_times
            )
            next_noisy_images = (
                next_signal_rates * pred_images + next_noise_rates * pred_noises
            )

        return pred_images

    def generate(self, num_images, diffusion_steps, caption_emb, un_caption_emb, cfg_scale, seed=None):
        """Sample `num_images` images in [0, 1] from random noise."""
        initial_noise = tf.random.normal(shape=(num_images, IMAGE_SIZE, IMAGE_SIZE, 3), seed=seed)
        generated_images = self.reverse_diffusion(initial_noise, diffusion_steps, caption_emb, un_caption_emb, cfg_scale)
        generated_images = self.denormalize(generated_images)
        return generated_images

    @tf.function
    def train_step(self, images, cap):
        """Run one optimization step on a full batch.

        Args:
            images: batch of `BATCH_SIZE` images (the noise tensors are
                created with that fixed batch size).
            cap: caption embeddings for the batch.

        Returns:
            Dict mapping metric names (`n_loss`, `i_loss`) to their values.
        """
        print("Tracing...(This line should be printed only once if @tf.function is enabled.)")
        images = self.normalizer(images, training=True)
        noises = tf.random.normal(shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3))

        diffusion_times = tf.random.uniform(
            shape=(BATCH_SIZE, 1, 1, 1), minval=0.0, maxval=1.0
        )
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        noisy_images = signal_rates * images + noise_rates * noises

        with tf.GradientTape() as tape:
            pred_noises, pred_images = self.denoise(
                noisy_images, cap, noise_rates, signal_rates, training=True
            )

            noise_loss = self.loss(noises, pred_noises)  # used for training
            image_loss = self.loss(images, pred_images)  # only used as metric

        gradients = tape.gradient(noise_loss, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        self.noise_loss_tracker.update_state(noise_loss)
        self.image_loss_tracker.update_state(image_loss)

        # `global_steps` is the notebook-level step counter. Until START_EMA
        # the EMA network just mirrors the trained weights; afterwards it
        # tracks them as an exponential moving average.
        if global_steps < START_EMA:
            for weight, EMA_weight in zip(self.network.weights, self.EMA_network.weights):
                EMA_weight.assign(weight)
        else:
            for weight, EMA_weight in zip(self.network.weights, self.EMA_network.weights):
                EMA_weight.assign(EMA * EMA_weight + (1 - EMA) * weight)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, idx, caps, cfg_scale=7.5, un_caps=uncondition_caps_emb, batch_size=BATCH_SIZE, seed=None):
        """Generate one image per caption and save each to INFERENCE_DIR.

        Args:
            idx: tensor of integer image ids, used in the output file names.
            caps: caption embeddings, one per image.
            cfg_scale: classifier-free guidance scale.
            un_caps: unconditional embeddings matching `caps` in batch size.
            batch_size: number of images to generate and save.
            seed: optional seed for the initial noise.
        """
        generated_images = self.generate(
            num_images=batch_size,
            diffusion_steps=PLOT_DIFFUSION_STEPS,
            caption_emb=caps,
            un_caption_emb=un_caps,
            cfg_scale=cfg_scale,
            seed=seed
        )

        # Iterate over the images actually generated (`batch_size`), not the
        # global BATCH_SIZE, so a non-default batch size neither skips images
        # nor indexes out of range.
        for i in range(batch_size):
            img = tf.image.resize(generated_images[i], size=[64, 64], antialias=True)
            img = tf.clip_by_value(img, 0.0, 1.0)
            plt.imsave(INFERENCE_DIR+'/inference_{:04d}.jpg'.format(idx[i].numpy()), img.numpy())
        return

    def merge(self, images, size):
        """Tile a batch of images into one grid image.

        Args:
            images: batch of images with shape `(n, h, w, 3)`.
            size: `[rows, cols]` of the grid; images fill it row by row.

        Returns:
            NumPy array of shape `(h * rows, w * cols, 3)`.
        """
        h, w = images.shape[1], images.shape[2]
        img = np.zeros((h * size[0], w * size[1], 3))
        for idx, image in enumerate(images):
            i = idx % size[1]   # column
            j = idx // size[1]  # row
            img[j*h:j*h+h, i*w:i*w+w, :] = image
        return img

    def imsave(self, images, size, path):
        """Merge `images` into a grid and write it to `path`."""
        return plt.imsave(path, self.merge(images, size))

    def save_images(self, images, size, image_path):
        """Alias of `imsave`."""
        return self.imsave(images, size, image_path)

    def plot_images(self, epoch=None, cfg_scale=7.5, caps=sample_sentences_emb, un_caps=un_sample_sentences_emb, num_rows=2, num_cols=5, img_name='/train_'):
        """Generate a grid of samples, display it and optionally save it.

        Args:
            epoch: epoch number used in the saved file name. When `None` the
                grid is only displayed and nothing is written to disk.
            cfg_scale: classifier-free guidance scale.
            caps: caption embeddings, `num_rows * num_cols` of them.
            un_caps: unconditional embeddings of the same batch size.
            num_rows: grid rows.
            num_cols: grid columns.
            img_name: file-name prefix (including the leading '/') inside
                SAMPLE_DIR.
        """
        generated_images = self.generate(
            num_images=num_rows * num_cols,
            diffusion_steps=PLOT_DIFFUSION_STEPS,
            caption_emb=caps,
            un_caption_emb=un_caps,
            cfg_scale=cfg_scale,
            seed=87
        )

        plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
        for row in range(num_rows):
            for col in range(num_cols):
                index = row * num_cols + col
                plt.subplot(num_rows, num_cols, index + 1)
                plt.imshow(generated_images[index])
                plt.axis("off")
        plt.tight_layout()
        plt.show()
        plt.close()

        # The file name needs an integer epoch; formatting the default `None`
        # with '{:02d}' would raise a TypeError.
        if epoch is not None:
            self.save_images(generated_images, [num_rows, num_cols], SAMPLE_DIR + img_name + '{:02d}.jpg'.format(epoch))

In [None]:
def sinusoidal_embedding(x):
    """Embed `x` with sines and cosines at log-spaced frequencies.

    Frequencies are spaced geometrically between 1.0 and
    EMBEDDING_MAX_FREQUENCY, with EMBEDDING_DIMENSIONS // 2 of them. The sine
    and cosine features are concatenated along axis 3.
    """
    min_frequency = 1.0
    log_min = tf.math.log(min_frequency)
    log_max = tf.math.log(EMBEDDING_MAX_FREQUENCY)
    num_frequencies = EMBEDDING_DIMENSIONS // 2

    # Linear spacing in log-space gives geometric spacing of the frequencies.
    frequencies = tf.exp(tf.linspace(log_min, log_max, num_frequencies))
    phases = (2.0 * math.pi * frequencies) * x

    return tf.concat([tf.sin(phases), tf.cos(phases)], axis=3)

def td_dot(a, b):
    """Apply `keras.backend.batch_dot` to two rank-4 tensors.

    Axes 0 and 1 of each input are flattened into one batch axis, the batched
    product is computed, and the result is reshaped so that axis 1 of `a` is
    restored.
    """
    lhs = tf.reshape(a, (-1, a.shape[2], a.shape[3]))
    rhs = tf.reshape(b, (-1, b.shape[2], b.shape[3]))
    product = keras.backend.batch_dot(lhs, rhs)
    restored_shape = (-1, a.shape[1], product.shape[1], product.shape[2])
    return tf.reshape(product, restored_shape)

def TransformerBlock(dim, n_heads, d_head):
    """Return a closure applying attention plus a feed-forward step to `x`.

    Args:
        dim: output width of the feed-forward Dense layers.
        n_heads: number of attention heads.
        d_head: `key_dim` of each attention head.

    The attention layer is created once here; the normalization and Dense
    layers are created each time the returned `apply` is called.
    """
    multi_head_attention = layers.MultiHeadAttention(num_heads=n_heads, key_dim=d_head)

    def feed_forward_network(x):
        """Project to 2*dim features, split in two halves and add them."""
        xp = layers.Dense(dim * 2)(x)
        x, gate = xp[..., :dim], xp[..., dim:]  
        # NOTE(review): the halves are summed, so no non-linearity is applied
        # here; a gated unit would multiply by an activation of `gate`.
        # Confirm this is intended.
        return x + gate

    def apply(x, context=None):
        """Attend from `x` to `context` (or to `x` itself) and add residuals."""
        # Store the original shape of x
        original_shape = tf.shape(x)
        # NOTE(review): `features` is read from axis 2 of `x`. DownBlock and
        # UpBlock appear to pass 4-D feature maps, for which axis 2 would be
        # the width rather than the channel axis (axis 3) — confirm. Changing
        # it would alter the attention weight shapes and therefore invalidate
        # existing checkpoints. `seq_length` is not used below.
        batch_size, seq_length, features = original_shape[0], original_shape[1], original_shape[2]

        # Layer normalization and reshaping to rank 3
        x_norm = layers.LayerNormalization(epsilon=1e-5)(x)
        x_norm = tf.reshape(x_norm, [batch_size, -1, features])

        # Prepare context if provided
        if context is not None:
            context = layers.LayerNormalization(epsilon=1e-5)(context)
            context = tf.reshape(context, [batch_size, -1, context.shape[-1]])  # Reshape to rank 3

        # Apply MultiHeadAttention
        if context is not None:
            # Cross-attention: queries from x, keys/values from the context.
            attention_output = multi_head_attention(x_norm, context)
        else:
            # Self-attention.
            attention_output = multi_head_attention(x_norm, x_norm)

        # Reshape attention_output to match the original shape of x
        attention_output = tf.reshape(attention_output, original_shape)

        # Add the attention output to the original x
        x = attention_output + x

        # Apply feedforward network
        x = layers.Dense(dim)(feed_forward_network(layers.LayerNormalization(epsilon=1e-5)(x))) + x

        return x

    return apply

def ResidualBlock(width):
    """Return a closure applying BatchNorm -> Conv(swish) -> Conv with a skip.

    When the input channel count differs from `width`, the skip path is
    projected with a 1x1 convolution so both branches can be added.
    """
    def apply(x):
        needs_projection = x.shape[3] != width
        # The projection (if any) is created before the main-path layers.
        shortcut = layers.Conv2D(width, kernel_size=1)(x) if needs_projection else x

        hidden = layers.BatchNormalization(center=False, scale=False)(x)
        hidden = layers.Conv2D(
            width,
            kernel_size=3,
            padding="same",
            activation=keras.activations.swish,
        )(hidden)
        hidden = layers.Conv2D(width, kernel_size=3, padding="same")(hidden)

        return layers.Add()([hidden, shortcut])

    return apply

def DownBlock(width, BLOCK_DEPTH):
    """Return a closure implementing one down-sampling stage.

    The closure takes `[x, t, content, skips]`: the feature map, the noise
    embedding, the caption context and the shared list of skip connections.
    It appends `1 + BLOCK_DEPTH` tensors to `skips` and returns the
    down-sampled feature map.
    """
    def apply(x):
        features, time_emb, content, skips = x

        def residual_with_time(feature_map):
            # Tile the noise embedding to the spatial size of the feature
            # map, concatenate it along channels and run a residual block.
            scale = int(feature_map.shape[1] / time_emb.shape[1])
            tiled_time = layers.UpSampling2D(size=scale, interpolation="nearest")(time_emb)
            merged = layers.Concatenate()([feature_map, tiled_time])
            return ResidualBlock(width)(merged)

        # Full-resolution block; its output is kept as a skip, then pooled.
        features = residual_with_time(features)
        skips.append(features)
        features = layers.AveragePooling2D(pool_size=2)(features)

        for _ in range(BLOCK_DEPTH):
            features = residual_with_time(features)

            # TransformerBlock
            features = TransformerBlock(width, HEAD, int(width / HEAD))(features, context=content)

            skips.append(features)
        return features

    return apply


def UpBlock(width, BLOCK_DEPTH):
    """Return a closure implementing one up-sampling stage.

    The closure takes `[x, t, content, skips]` and pops `BLOCK_DEPTH + 1`
    tensors from `skips`: one per attention-equipped block at the current
    resolution and one more after the bilinear up-sampling.
    """
    def apply(x):
        features, time_emb, content, skips = x

        def fuse_with_skip(feature_map):
            # Tile the noise embedding, pop the matching skip connection,
            # concatenate all three along channels and run a residual block.
            scale = int(feature_map.shape[1] / time_emb.shape[1])
            tiled_time = layers.UpSampling2D(size=scale, interpolation="nearest")(time_emb)
            skip = skips.pop()
            merged = layers.Concatenate()([feature_map, skip, tiled_time])
            return ResidualBlock(width)(merged)

        for _ in range(BLOCK_DEPTH):
            features = fuse_with_skip(features)

            # TransformerBlock
            features = TransformerBlock(width, HEAD, int(width / HEAD))(features, context=content)

        features = layers.UpSampling2D(size=2, interpolation="bilinear")(features)
        return fuse_with_skip(features)

    return apply

def get_network(image_size, WIDTHS, BLOCK_DEPTH):
    """Build the caption-conditioned residual U-Net used as the denoiser.

    Inputs of the returned model, in order: noisy images of shape
    `(image_size, image_size, 3)`, noise variances of shape `(1, 1, 1)` and
    caption embeddings of shape `(77, 768)`. The output has 3 channels and
    comes from a zero-initialized 1x1 convolution.
    """
    image_in = layers.Input(shape=(image_size, image_size, 3))
    variance_in = layers.Input(shape=(1, 1, 1))
    caption_in = layers.Input(shape=(77, 768))

    noise_emb = layers.Lambda(sinusoidal_embedding)(variance_in)
    features = layers.Conv2D(32, kernel_size=1)(image_in)

    encoder_widths = WIDTHS[:-1]
    bottleneck_width = WIDTHS[-1]

    # Encoder: every stage pushes its skip connections onto `skips`.
    skips = []
    for width in encoder_widths:
        features = DownBlock(width, BLOCK_DEPTH)([features, noise_emb, caption_in, skips])

    # Bottleneck: residual blocks conditioned on the tiled noise embedding.
    for _ in range(BLOCK_DEPTH):
        scale = int(features.shape[1] / noise_emb.shape[1])
        tiled_emb = layers.UpSampling2D(size=scale, interpolation="nearest")(noise_emb)
        features = ResidualBlock(bottleneck_width)(layers.Concatenate()([features, tiled_emb]))

    # Decoder: stages consume the skips in reverse order.
    for width in reversed(encoder_widths):
        features = UpBlock(width, BLOCK_DEPTH)([features, noise_emb, caption_in, skips])

    outputs = layers.Conv2D(3, kernel_size=1, kernel_initializer="zeros")(features)

    return keras.Model([image_in, variance_in, caption_in], outputs, name="residual_unet")


In [None]:
# Standalone copy of the U-Net, built only so its summary can be inspected.
# DiffusionModel builds its own network in __init__ and does not use this one.
network = get_network(IMAGE_SIZE, WIDTHS, BLOCK_DEPTH)
# network.summary()

In [None]:
# create and compile the model
model = DiffusionModel(IMAGE_SIZE, WIDTHS, BLOCK_DEPTH)

# Learning-rate schedule: starts at 1e-6 with a 10000-step warmup towards 1e-4.
# NOTE(review): with decay_steps=1 and alpha=1.0 this presumably acts as
# "linear warmup, then constant 1e-4" rather than an actual cosine decay —
# confirm against the CosineDecay documentation.
cosine_decay_scheduler = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate = 1e-6,
    decay_steps = 1, 
    warmup_steps = 10000, 
    warmup_target = 1e-4,
    alpha=1.0,
)

# The loss is mean absolute error; DiffusionModel.train_step applies it to the
# predicted noise (optimized) and to the predicted image (metric only).
model.compile(
    optimizer=tf.keras.optimizers.AdamW(
        learning_rate=cosine_decay_scheduler,
        weight_decay=WEIGHT_DECAY
    ),
    loss=tf.keras.losses.MeanAbsoluteError()
)

In [None]:
# Number of batches in one pass over train_dataset (which already contains
# DATASET_REPETITIONS repetitions of the data).
steps_per_epoch = len(train_dataset)
total_steps = steps_per_epoch*NUM_EPOCHS

# ckp = tf.train.latest_checkpoint(CHECKPOINT_DIR)
# The checkpoint to resume from is hard-coded, so `ckp` is always truthy and
# the "start from scratch" branch below is never taken as written. Switch to
# the commented line above (or set ckp = None) to change that.
ckp = CHECKPOINT_DIR + '/ddpm-71'

if ckp:
    # The checkpoint suffix is the last finished epoch; continue with the next.
    init_epoch = (int(ckp.split('-')[-1]))+1
    init_step = steps_per_epoch*(init_epoch-1)+1
    # Step counter read by DiffusionModel.train_step to decide when EMA starts.
    global_steps = tf.Variable(init_step, trainable=False, dtype=tf.int64)
    ckpt = tf.train.Checkpoint(epoch=tf.Variable(init_epoch),
                               step=tf.Variable(init_step), 
                               net=model)
    ckpt.restore(ckp)
    print(f'Resume training from global_epoch {init_epoch-1}, global_steps {init_step-1}')
else:
    init_epoch = 1
    init_step = 1
    global_steps = tf.Variable(init_step, trainable=False, dtype=tf.int64)
    ckpt = tf.train.Checkpoint(epoch=tf.Variable(0), 
                               step=tf.Variable(0), 
                               net=model)
    print(f'Start training from global_epoch {init_epoch}, global_steps {init_step}')

# Saves checkpoints as CHECKPOINT_DIR/ddpm-<n>, keeping at most 400 of them.
manager = tf.train.CheckpointManager(ckpt, CHECKPOINT_DIR, max_to_keep=400, 
                                     checkpoint_name='ddpm')


# Fit the image normalizer on the images only (captions are dropped).
# NOTE(review): this runs after ckpt.restore, so it presumably recomputes the
# normalizer statistics that the checkpoint restored — confirm this is intended.
feature_ds = train_dataset.map(lambda x, y: x)
model.normalizer.adapt(feature_ds)

In [None]:
# Plot (and save as epoch 0) a sample grid before training starts.
model.plot_images(0)
for epoch in range(init_epoch, NUM_EPOCHS+1):
    print(f'Epoch {epoch:>2}/{NUM_EPOCHS}')
    ckpt.epoch.assign_add(1)

    for image_data, cap in train_dataset:
        # With probability 0.1 the whole batch is trained on the empty-caption
        # embedding instead of its real captions, so the network also learns
        # the unconditional prediction used by classifier-free guidance.
        if np.random.random() < 0.1:
            cap_emb = uncondition_caps_emb
        else:
            cap_emb = text_encoder([cap, pos_ids])
            
        # `matrix` holds the metrics dict returned by train_step.
        matrix = model.train_step(image_data, cap_emb)
          
        # Keep the step counter in sync; train_step reads it for the EMA switch.
        global_steps.assign_add(1)
        
        print("=> STEP %d/%d  lr: %f n_loss: %f  i_loss: %f" % (global_steps, total_steps, model.optimizer.lr, matrix['n_loss'].numpy(), matrix['i_loss'].numpy()), end='\r')
    
    print()
    # One checkpoint per epoch.
    save_path = manager.save()
    if save_path:
        print("Saved checkpoint for epoch {}: {}".format(int(ckpt.epoch), save_path)) 
        
    # Every 5 epochs, save sample grids for both fixed caption sets.
    if epoch%5==0:
        model.plot_images(epoch, cfg_scale=3.6)
        model.plot_images(epoch, cfg_scale=3.6, caps=train_sample_sentences_emb, img_name='/train_ex_')
    

## Inference

In [None]:
def testing_data_generator(caption, index):
    """Prepare one test example for the text encoder.

    Args:
        caption: tensor of token ids for one caption.
        index: id of the test example (passed through unchanged).

    Returns:
        `(caption, index)` with `caption` as int32 token ids.
    """
    # Token ids are integers and the text encoder's inputs are declared int32,
    # so keep them integral instead of round-tripping through float32.
    caption = tf.cast(caption, tf.int32)
    return caption, index

def testing_dataset_generator(filenames, batch_size, data_generator):
    """Build the (endlessly repeating) batched test dataset.

    Args:
        filenames: path of a pickled DataFrame with `'Captions'` (one caption
            per row, given as word ids) and `'ID'` columns.
        batch_size: number of examples per batch.
        data_generator: function mapping `(caption, index)` to one example.

    Returns:
        A `tf.data.Dataset` of batches; because of `.repeat()` it never ends,
        so the consumer has to stop iterating by itself.
    """
    data = pd.read_pickle(filenames)
    captions = data['Captions'].values

    def encode_caption(word_ids):
        # Dataset word ids -> words (skipping id '5427') -> CLIP token ids,
        # right-padded with 49407 up to MAX_TEXT_LENGTH.
        words = [id2word_dict[str(w_id.astype(np.int32))] for w_id in word_ids if w_id != '5427']
        token_ids = tokenizer.encode(" ".join(words))
        phrase = token_ids + [49407] * (MAX_TEXT_LENGTH - len(token_ids))
        assert(len(phrase) == MAX_TEXT_LENGTH)
        return phrase

    caption = []
    for word_ids in captions:
        caption.append(encode_caption(word_ids))
    assert len(captions) == len(caption)

    caption = np.asarray(caption).astype(np.int32)
    index = np.asarray(data['ID'].values)

    return (
        tf.data.Dataset.from_tensor_slices((caption, index))
        .map(data_generator, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .repeat()
        .batch(batch_size)
    )

In [None]:
testing_dataset = testing_dataset_generator('./dataset/testData.pkl', BATCH_SIZE, testing_data_generator)

# Re-read the test frame only to count the captions.
data = pd.read_pickle('./dataset/testData.pkl')
captions = data['Captions'].values

NUM_TEST = len(captions)
# Number of complete batches in the test set. `inference` processes batches
# 0..EPOCH_TEST inclusive, so a trailing partial batch is covered too
# (the dataset repeats, so that last batch is always full).
EPOCH_TEST = int(NUM_TEST / BATCH_SIZE)

In [None]:
import time
def inference(dataset, cfg_scale, seed=None):
    """Generate and save images for the batches of `dataset`.

    Batches 0..EPOCH_TEST (inclusive) are processed; each one is embedded with
    the text encoder and passed to `model.test_step`, which writes the images
    to disk. The elapsed time is printed at the end.

    Args:
        dataset: iterable of `(captions, idx)` batches.
        cfg_scale: classifier-free guidance scale.
        seed: optional seed forwarded to `model.test_step`.
    """
    started_at = time.time()

    for step, (captions, idx) in enumerate(dataset):
        print(f"=> {step}/{EPOCH_TEST}", end='\r')
        # The dataset repeats forever, so stop explicitly.
        if step > EPOCH_TEST:
            break

        cap_emb = text_encoder([captions, pos_ids])
        model.test_step(idx, cap_emb, cfg_scale=cfg_scale, seed=seed)

    print('Time for inference is {:.4f} sec'.format(time.time()-started_at))

In [None]:
# Run inference on the test captions with cfg_scale=3.6 and seed=900523.
inference(testing_dataset, 3.6, 900523)

In denoise
Time for inference is 59.1357 sec


In [None]:
import os
os.chdir('./testing')
!python inception_score.py ../inference/final_version ../final_version.csv 39
os.chdir('../')

2 Physical GPUs, 1 Logical GPUs
--------------Evaluation Success-----------------


In [None]:
import os
import pandas as pd
import numpy as np

# Report the mean of the 'score' column written by the evaluation script,
# or a failure message when the file does not exist.
score_path = './final_version.csv'

if not os.path.exists(score_path):
    print('Evaluation Failed!')
else:
    df_score = pd.read_csv(score_path)
    mean_score = np.mean(df_score['score'].values)
    print(f'Mean Score: {mean_score:f}')

Mean Score: 0.440848
