### Libraries

In [10]:
# General Imports
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import sys
import yaml
import numpy as np
from glob import glob
import time
import cv2

# Tensorflow Imports
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense, Lambda, Input, Dense
from tensorflow.keras.layers import Flatten, Reshape, Concatenate
from tensorflow.keras.layers import Conv2DTranspose

# Local Module Import
sys.path.append("../src")  # adds source code directory
from utils import frame_to_label, frames_to_video
from polygon_handle import masks_to_polygons
from log_setup import logger
from cvae_model import CVAEDataGenerator, CVAEComponents
from cvae_model import ReduceLROnPlateauSteps, EarlyStoppingSteps, HistoryLogger

### Global Variables

In [11]:
""" 
DATA: "full" (full dataset), "sampled" (distance sampled dataset) 
        or "unet" (unet generated dataset)
MODE: "interpol" (interpolation) or "extrapol" (extrapolation)
MODEL: "CVAE"
PERCENTAGE: percentage of training data to be used for training
LAST_FRAME: last frame number of the video
"""

DATA = "unet"
MODE = "extrapol"
MODEL = "CVAE"
PERCENTAGE = 30
LAST_FRAME = 22500

### Directories

In [12]:
# Project root; assumes the notebook runs from a direct subfolder of the repo
# (cf. sys.path.append("../src") in the imports) — TODO confirm.
BASE_DIR = os.path.dirname(os.getcwd())
dataset_dir = os.path.join(BASE_DIR, "dataset")
data_dir = os.path.join(BASE_DIR, "data")
config_file = os.path.join(BASE_DIR, "config.yml")

# Output directory: extrapolation runs are additionally split by PERCENTAGE.
if MODE == "extrapol":
    output_dir = os.path.join(BASE_DIR, "outputs", MODEL, MODE, str(PERCENTAGE), DATA)
    logger.info(
        f"Data: {DATA}, Mode: {MODE}, Model: {MODEL} Percentage: {PERCENTAGE}%,\nOutput directory: {output_dir}"
    )
elif MODE == "interpol":
    output_dir = os.path.join(BASE_DIR, "outputs", MODEL, MODE, DATA)
    logger.info(
        f"\nData: {DATA}, Mode: {MODE}, Model: {MODEL}\nOutput directory: {output_dir}"
    )
else:
    # Without this, `output_dir` would be silently undefined for later cells.
    raise ValueError(f"Unknown MODE {MODE!r}; expected 'extrapol' or 'interpol'")

INFO - Data: unet, Mode: extrapol, Model: CVAE Percentage: 30%,
Output directory: /home/tiagociic/Projectos/spatiotemporal-vae-reconstruction/outputs/CVAE/extrapol/30/unet


### Config file

In [13]:
# Parse the project YAML configuration (safe_load: plain data types only).
with open(config_file, encoding="utf-8", mode="r") as config_stream:
    config = yaml.safe_load(config_stream)

### Data loading

In [14]:
def _label_from_path(path, scale=1, offset=0):
    """Return ``index * scale + offset`` for the index encoded in a mask filename.

    The index is the text between the first "_" and the following "." of the
    basename, e.g. "mask_0042.png" -> 42.
    """
    index = int(os.path.basename(path).split("_")[1].split(".")[0])
    return index * scale + offset


# Training data
if DATA == "full":
    train_dir = os.path.join(BASE_DIR, config["data"]["full"]["train_dir"], "masks")
    # sort the paths
    train_paths = sorted(glob(os.path.join(train_dir, "*.png")))
    # Filename index * 100 (presumably the full set stores every 100th frame — TODO confirm).
    train_labels = [_label_from_path(p, scale=100) for p in train_paths]
    epochs = config["CVAE"]["epochs"]

elif DATA == "sampled":
    sampled_masks_txt_path = os.path.join(
        BASE_DIR, config["data"]["wkt"]["sampled_masks_txt"]
    )
    with open(sampled_masks_txt_path, "r", encoding="utf-8") as f:
        polygons = f.readlines()
    # Each line starts with the sample index, followed by a comma.
    indexes = [int(polygon.split(",")[0]) for polygon in polygons]
    train_dir = os.path.join(BASE_DIR, config["data"]["sampled"]["train_dir"], "masks")
    train_paths = sorted(glob(os.path.join(train_dir, "*.png")))
    train_labels = [100 * i for i in indexes]
    epochs = config["CVAE"]["epochs"]

elif DATA == "unet":
    train_dir = os.path.join(BASE_DIR, config["data"]["unet"]["train_dir"], "masks")
    train_paths = sorted(glob(os.path.join(train_dir, "*.png")))
    # U-Net masks are labelled with the raw filename index (no scaling).
    train_labels = [_label_from_path(p) for p in train_paths]
    epochs = 2  # NOTE(review): hard-coded, unlike the config-driven branches above

else:
    raise ValueError(f"Unknown DATA {DATA!r}; expected 'full', 'sampled' or 'unet'")

# Paths and labels are handed to the data generator side by side and truncated
# by index below, so they must pair up one-to-one.
if len(train_paths) != len(train_labels):
    raise ValueError(
        f"Found {len(train_paths)} training masks in {train_dir} "
        f"but {len(train_labels)} training labels"
    )

# Test data: filename index * 100, shifted by 20250 (== 0.9 * LAST_FRAME;
# presumably the test split covers the last 10% of the video — TODO confirm).
test_dir = os.path.join(BASE_DIR, config["data"]["test"]["test_dir"], "masks")
test_paths = sorted(glob(os.path.join(test_dir, "*.png")))
test_labels = [_label_from_path(p, scale=100, offset=20250) for p in test_paths]

if MODE == "extrapol":
    # Keep only the first PERCENTAGE% of the (filename-sorted) training samples.
    n_train = int(len(train_paths) * PERCENTAGE / 100)
    train_paths = train_paths[:n_train]
    train_labels = train_labels[:n_train]
    logger.info(
        f"No. train. samples: {len(train_paths)} out of {LAST_FRAME} ({PERCENTAGE}%) | No. test samples: {len(test_paths)}"
    )
elif MODE == "interpol":
    logger.info(
        f"No. train. samples: {len(train_paths)} out of {LAST_FRAME} | No. test samples: {len(test_paths)}"
    )

INFO - No. train. samples: 6759 out of 22500 (30%) | No. test samples: 23


In [15]:
# Resolve the sampled-masks WKT path when the config provides it. Only the
# "sampled" data branch actually needs it, so a missing key is just reported.
try:
    _sampled_rel_path = config["data"]["wkt"]["sampled_masks_txt"]
except KeyError:
    print("Key 'sampled_masks_txt' not found in the config data.")
else:
    sampled_masks_txt_path = os.path.join(BASE_DIR, _sampled_rel_path)

### Data Generators

In [16]:
input_shape = config["CVAE"]["input_shape"]


def _make_generator(paths, labels):
    """Build a batch-size-1 CVAEDataGenerator over the given mask paths and labels."""
    return CVAEDataGenerator(
        data_paths=paths,
        labels=labels,
        batch_size=1,
        input_shape=input_shape[:2],  # spatial dims only (H, W)
        last_frame=LAST_FRAME,
    )


# Training and testing generators differ only in the data they iterate over.
train_data_gen = _make_generator(train_paths, train_labels)
test_data_gen = _make_generator(test_paths, test_labels)

### C-VAE definition

In [18]:
H, W, C = config["CVAE"]["input_shape"]
filters = int(config["CVAE"]["ref_filters"])
cvae_comp = CVAEComponents()

# Frequently reused hyper-parameters, read once from the config.
latent_dim = config["CVAE"]["latent_dim"]
w_init = config["CVAE"]["w_init"]

# ------------------------------------------------------------
# Encoder: (H, W, C) image -> [z_mean, z_log_var, z]
# ------------------------------------------------------------
encoder_inputs = Input(shape=(H, W, C))

x = cvae_comp.conv_block(input=encoder_inputs, filters=filters * 2, f_init=w_init)
x = cvae_comp.conv_block(input=x, filters=filters, f_init=w_init)
x = Flatten()(x)
x = Dense(64, activation="leaky_relu")(x)

# Parameters of the approximate posterior over the latent space.
z_mean = Dense(latent_dim, activation="leaky_relu", name="z_mean")(x)
z_log_var = Dense(latent_dim, activation="leaky_relu", name="z_log_var")(x)

# Draw z from (z_mean, z_log_var); cvae_comp.sampler presumably implements the
# reparameterisation trick — TODO confirm in cvae_model.
z = Lambda(cvae_comp.sampler, output_shape=(latent_dim,), name="z")(
    [z_mean, z_log_var]
)

encoder = Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")

# ------------------------------------------------------------
# Decoder: [z, label] -> reconstructed image
# ------------------------------------------------------------
latent_inputs = Input(shape=(latent_dim,), name="z_sampling")
label_size = 1  # a single scalar conditioning label per sample
label_inputs = Input(shape=(label_size,), name="label")
decoder_inputs = Concatenate()([latent_inputs, label_inputs])

# 64 * 64 * 64 == 128 * 128 * 16, so the dense output reshapes exactly into a
# 128x128x16 feature map that the deconvolution blocks then upsample.
x = Dense(64 * 64 * 64, activation="leaky_relu")(decoder_inputs)
x = Reshape((128, 128, 16))(x)
x = cvae_comp.deconv_block(input=x, filters=filters * 2, f_init=w_init)
x = cvae_comp.deconv_block(input=x, filters=filters * 4, f_init=w_init)
decoder_output = Conv2DTranspose(1, 3, activation="tanh", padding="same")(x)

decoder = Model([latent_inputs, label_inputs], decoder_output, name="decoder")

# ------------------------------------------------------------
# Conditional VAE: image + label -> decoder(z, label)
# ------------------------------------------------------------
outputs = decoder([encoder(encoder_inputs)[2], label_inputs])
cvae = Model([encoder_inputs, label_inputs], outputs, name="cvae")
cvae.summary()

Model: "cvae"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_2 (InputLayer)        [(None, 512, 512, 1)]        0         []                            
                                                                                                  
 encoder (Functional)        [(None, 64),                 3357281   ['input_2[0][0]']             
                              (None, 64),                 6                                       
                              (None, 64)]                                                         
                                                                                                  
 label (InputLayer)          [(None, 1)]                  0         []                            
                                                                                               

### Callbacks

In [20]:
# Halve the learning rate when the training loss plateaus; patience is presumably
# counted in training steps rather than epochs (ReduceLROnPlateauSteps is project code).
reduce_lr = ReduceLROnPlateauSteps(
    monitor="loss", factor=0.5, mode="min", patience=5000, verbose=1, min_lr=1e-8
)

# Stop when the loss stops improving and roll back to the best weights;
# patience is presumably in steps as well (EarlyStoppingSteps is project code).
early_stopping = EarlyStoppingSteps(
    monitor="loss",
    min_delta=0,
    patience=10000,
    verbose=1,
    mode="auto",
    restore_best_weights=True,
)

# Checkpoint file name encodes the run configuration; PERCENTAGE only matters
# for extrapolation.
checkpoint_dir = os.path.join(BASE_DIR, config["data"]["checkpoint_dir"])
if MODE == "extrapol":
    checkpoint_path = os.path.join(
        checkpoint_dir, f"cvae_{DATA}_{MODE}_{PERCENTAGE}.h5"
    )
elif MODE == "interpol":
    checkpoint_path = os.path.join(checkpoint_dir, f"cvae_{DATA}_{MODE}.h5")
else:
    # Without this, `checkpoint_path` would be silently undefined below.
    raise ValueError(f"Unknown MODE {MODE!r}; expected 'extrapol' or 'interpol'")

# Writing the .h5 file into a missing directory may fail; create it up front
# (no-op if it already exists).
os.makedirs(checkpoint_dir, exist_ok=True)

# Save the best weights (by training loss). This only writes checkpoint_path
# if the callback is passed to cvae.fit(callbacks=[...]).
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_best_only=True,
    monitor="loss",
    mode="auto",
    verbose=1,
    save_weights_only=True,
)

history_logger = HistoryLogger(log_interval=500)

### Model compilation

In [21]:
# Legacy Adam at the configured learning rate, combined with the project's loss
# (cvae_comp.mse_kl_loss — presumably reconstruction MSE plus a KL term).
adam_optimizer = tf.keras.optimizers.legacy.Adam(
    learning_rate=config["CVAE"]["learning_rate"]
)
cvae.compile(loss=cvae_comp.mse_kl_loss, optimizer=adam_optimizer)

### Training

In [22]:
# Reset the learning rate to the configured value (presumably to undo a reduction
# by reduce_lr when this cell is re-run in the same kernel).
cvae.optimizer.lr = config["CVAE"]["learning_rate"]

# Fit the model.
# - `epochs` is the value chosen in the data-loading section (it was ignored
#   in favour of a hard-coded 2).
# - `model_checkpoint` must be in the callback list, otherwise checkpoint_path is
#   never written and the inference cell's load_weights has nothing to load.
history = cvae.fit(
    train_data_gen,
    steps_per_epoch=len(train_data_gen),
    epochs=epochs,
    validation_data=test_data_gen,
    validation_steps=len(test_data_gen),
    callbacks=[reduce_lr, early_stopping, model_checkpoint, history_logger],
)

Epoch 1/2




### Inference

In [15]:
# Load the best weights saved by the ModelCheckpoint callback at checkpoint_path.
# NOTE(review): this assumes model_checkpoint was passed to cvae.fit; otherwise the
# file may be missing or left over from an earlier run.
cvae.load_weights(checkpoint_path)

In [16]:
def generate_frames(
    decoder, output_dir: str, total_frames: int = 22500, resize_original: bool = False
):
    """
    Generates and saves binarized frames from a trained decoder.

    Every frame is decoded from the same fixed latent vector (all entries 0.5).
    Each frame is conditioned on ``frame_to_label(frame_num)`` and thresholded
    with ``config["CVAE"]["threshold"]``. It is written to
    ``<output_dir>/frame_<frame_num:06d>.png`` as an 8-bit 0/255 mask.

    Parameters:
        decoder (keras.Model): The trained decoder, called as
            ``decoder.predict([z, label])``.
        output_dir (str): The path to the output directory (created if missing).
        total_frames (int): The total number of frames to generate; frames are
            numbered 1..total_frames.
        resize_original (bool): Whether to resize the frames to
            ``config["data"]["original_vid_dims"]``, read as (height, width).

    Notes:
        Relies on the notebook-level ``config`` and ``frame_to_label``.
        Frames that fail to decode or save are reported and skipped, not raised.
    """
    # cv2.imwrite does not create directories (it just returns False), so
    # make sure the target exists before generating anything.
    os.makedirs(output_dir, exist_ok=True)

    # Loop-invariant settings, computed once instead of on every frame.
    threshold = config["CVAE"]["threshold"]
    # Every frame uses the same latent code; only the conditioning label varies.
    z_sample = np.full((1, config["CVAE"]["latent_dim"]), 0.5)
    if resize_original:
        # original_vid_dims is (height, width) (it was passed as tf.image.resize's
        # `size`), while cv2.resize expects dsize as (width, height).
        target_h, target_w = config["data"]["original_vid_dims"]
        dsize = (int(target_w), int(target_h))

    start_total_time = time.time()

    for frame_num in range(1, total_frames + 1):
        # Generate the frame (verbose=0: no per-call progress bar).
        try:
            start_time = time.time()
            reconst = decoder.predict(
                [z_sample, frame_to_label(frame_num)], verbose=0
            )
            reconst_time = (time.time() - start_time) * 1000
            reconst = np.squeeze(reconst, axis=0)
        except Exception as e:
            print(f"Error generating frame {frame_num}: {e}")
            continue

        if resize_original:
            start_time = time.time()
            # Stay in NumPy: cv2.threshold expects an ndarray, not the tf.Tensor
            # that tf.image.resize returns.
            reconst = cv2.resize(reconst, dsize, interpolation=cv2.INTER_LINEAR)
            resize_time = (time.time() - start_time) * 1000
        else:
            resize_time = 0.0  # Not resizing

        # Binarize to a 0/255 mask and store it as 8-bit, which is what PNG holds.
        start_time = time.time()
        _, thresh_img = cv2.threshold(reconst, threshold, 255, cv2.THRESH_BINARY)
        thresh_img = thresh_img.astype(np.uint8)
        threshold_time = (time.time() - start_time) * 1000

        # Save the thresholded image as a grayscale PNG.
        frame_path = os.path.join(output_dir, f"frame_{frame_num:06d}.png")
        try:
            start_time = time.time()
            # imwrite reports most failures via its return value, not an exception.
            if not cv2.imwrite(frame_path, thresh_img):
                raise IOError(f"cv2.imwrite could not write {frame_path}")
            save_time = (time.time() - start_time) * 1000
        except Exception as e:
            print(f"Error saving frame {frame_num}: {e}")
            continue

        # Print progress with time information (overwrites the same line).
        print(
            f"Generated frame {frame_num} of {total_frames} | "
            f"Reconst: {reconst_time:.2f}ms | "
            f"Resize: {resize_time:.2f}ms | "
            f"Threshold: {threshold_time:.2f}ms | "
            f"Save: {save_time:.2f}ms | "
            f"Elapsed Time: {time.time() - start_total_time:.2f}s  ",
            end="\r",
        )
    print()

In [17]:
output_png_dir = os.path.join(output_dir, "PNG")
# cv2.imwrite does not create directories and only signals failure via its
# return value, so ensure the target exists before generating LAST_FRAME frames.
os.makedirs(output_png_dir, exist_ok=True)
generate_frames(decoder, output_png_dir, total_frames=LAST_FRAME)

  updates=self.state_updates,


Generated frame 22500 of 22500 | Reconst: 23.82ms | Resize: 0.00ms | Threshold: 0.17ms | Save: 0.87ms | Elapsed Time: 540.08s  


In [19]:
# Generate a video from the generated PNG frames.
# The title reports `epochs` from the data-loading section rather than
# config['CVAE']['epochs'], which the "unet" branch overrides.
if MODE == "extrapol":
    file_name = f"video_{DATA}_{MODE}_{PERCENTAGE}"
    title = f"CVAE: {MODE}ation - {DATA}, {PERCENTAGE}, {epochs} epochs, 10x speed"
elif MODE == "interpol":
    file_name = f"video_{DATA}_{MODE}"
    title = f"CVAE: {MODE}ation - {DATA}, {epochs} epochs, 10x speed"
else:
    raise ValueError(f"Unknown MODE {MODE!r}; expected 'extrapol' or 'interpol'")

frames_to_video(
    img_list_dir=os.path.join(output_dir, "PNG"),
    output_dir=output_dir,
    output_resolution=config["data"]["original_vid_dims"],
    title=title,
    f_ps=250,  # 10x speed (presumably the source video is 25 fps — TODO confirm)
    file_name=file_name,
    frame_num_text=True,
    font_size=1,
)

INFO - Creating image list...                          
INFO - Writing frames to file 1/22500
INFO - Writing frames to file 1001/22500
INFO - Writing frames to file 2001/22500
INFO - Writing frames to file 3001/22500
INFO - Writing frames to file 4001/22500
INFO - Writing frames to file 5001/22500
INFO - Writing frames to file 6001/22500
INFO - Writing frames to file 7001/22500
INFO - Writing frames to file 8001/22500
INFO - Writing frames to file 9001/22500
INFO - Writing frames to file 10001/22500
INFO - Writing frames to file 11001/22500
INFO - Writing frames to file 12001/22500
INFO - Writing frames to file 13001/22500
INFO - Writing frames to file 14001/22500
INFO - Writing frames to file 15001/22500
INFO - Writing frames to file 16001/22500
INFO - Writing frames to file 17001/22500
INFO - Writing frames to file 18001/22500
INFO - Writing frames to file 19001/22500
INFO - Writing frames to file 20001/22500
INFO - Writing frames to file 21001/22500
INFO - Writing frames to file 220

In [20]:
# Generated frame paths, in frame order (file names are zero-padded).
msks_paths = sorted(glob(os.path.join(output_png_dir, "*.png")))
if not msks_paths:
    raise FileNotFoundError(f"No generated PNG frames found in {output_png_dir}")

# Store the WKT next to the PNG frames. output_dir already encodes
# model/mode/(percentage)/data, so interpolation runs no longer end up under a
# PERCENTAGE sub-folder that only extrapolation uses.
wkt_path = os.path.join(output_dir, "WKT", f"{MODE}_{DATA}.wkt")
os.makedirs(os.path.dirname(wkt_path), exist_ok=True)

# Convert the masks to polygons and save them as a WKT file
masks_to_polygons(
    msks_paths,
    out_dim=tuple(config["data"]["original_vid_dims"]),
    save_path=wkt_path,
)

INFO - Converting masks to polygons...


Processed 22499 masks out of 22500 | Time elapsed: 4408.34s  

INFO - Saved polygons to /home/tiagociic/Projectos/spatiotemporal-vae-reconstruction/outputs/CVAE/extrapol/30/unet/WKT/extrapol_unet.wkt


Processed 22500 masks out of 22500 | Time elapsed: 4408.57s  