In [12]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Model
from tensorflow.keras.applications.vgg16 import preprocess_input
from CyclicGen_model import Voxel_flow_model

In [13]:
def load_triplets(triplets_txt_path):
    """Read frame-triplet paths from a whitespace-separated text file.

    Each non-blank line is split on whitespace and returned as a list of
    tokens, in file order. The training loop in this notebook unpacks every
    entry as ``(f1, f2, f3)``, so each line is expected to hold exactly three
    image paths (paths containing spaces are therefore not supported).

    Args:
        triplets_txt_path: Path to the text file listing one triplet per line.

    Returns:
        A list of token lists, one per non-blank line.
    """
    with open(triplets_txt_path, 'r') as f:
        # Skip blank / whitespace-only lines (e.g. a trailing newline at EOF):
        # they would otherwise yield empty lists that break triplet unpacking.
        # str.split() with no argument already discards surrounding whitespace.
        return [line.split() for line in f if line.strip()]

In [14]:
def load_and_preprocess_image(path, size=(256, 256)):
    """Load an image file as a float32 RGB tensor scaled to [0, 1].

    Args:
        path: Path to the image file (Python string or scalar string tensor).
        size: Target ``(height, width)`` to resize to. Defaults to ``(256, 256)``.

    Returns:
        A ``(height, width, 3)`` float32 tensor with values in ``[0, 1]``.
    """
    raw_bytes = tf.io.read_file(path)
    # decode_image dispatches on the file format (JPEG, PNG, BMP, GIF) instead of
    # assuming JPEG; expand_animations=False keeps the result 3-D (H, W, C) so it
    # can be resized and stacked exactly like a decoded JPEG frame.
    image = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    image = tf.image.resize(image, list(size))
    # Scale 0-255 pixel values to [0, 1]; callers multiply by 255 again before VGG.
    return tf.cast(image, tf.float32) / 255.0

In [15]:
def get_vgg_feature_extractor(layer_name='block4_conv3', input_shape=(256, 256, 3)):
    """Build a model that maps images to an intermediate VGG16 activation.

    The base network is VGG16 with ImageNet weights and without the
    classification head.

    Args:
        layer_name: Name of the VGG16 layer whose output the model returns.
            Defaults to ``'block4_conv3'``.
        input_shape: ``(height, width, channels)`` of the model input.
            Defaults to ``(256, 256, 3)``.

    Returns:
        A Keras ``Model`` from the VGG16 input to ``layer_name``'s output.
    """
    base_model = VGG16(weights='imagenet', include_top=False, input_shape=input_shape)
    feature_output = base_model.get_layer(layer_name).output
    return Model(inputs=base_model.input, outputs=feature_output)

In [16]:
def _load_triplet_batch(batch):
    """Load one batch of (f1, f2, f3) path triplets as three stacked image tensors."""
    f1_images, f2_images, f3_images = [], [], []
    # Unpacking raises ValueError if a triplet line did not contain exactly 3 paths.
    for f1_path, f2_path, f3_path in batch:
        f1_images.append(load_and_preprocess_image(f1_path))
        f2_images.append(load_and_preprocess_image(f2_path))
        f3_images.append(load_and_preprocess_image(f3_path))
    return tf.stack(f1_images), tf.stack(f2_images), tf.stack(f3_images)


def _build_model_input(vgg_model, f1_batch, f3_batch):
    """Concatenate both context frames with their sigmoid-squashed VGG feature maps."""
    # Frames are in [0, 1]; VGG's preprocess_input expects 0-255 pixel values.
    feat_1 = tf.nn.sigmoid(vgg_model(preprocess_input(f1_batch * 255.0)))
    feat_3 = tf.nn.sigmoid(vgg_model(preprocess_input(f3_batch * 255.0)))

    # Resize the feature maps to the frame resolution so they can be
    # concatenated with the frames along the channel axis.
    frame_hw = tf.shape(f1_batch)[1:3]
    feat_1 = tf.image.resize(feat_1, frame_hw)
    feat_3 = tf.image.resize(feat_3, frame_hw)

    return tf.concat([f1_batch, f3_batch, feat_1, feat_3], axis=-1)


def _plot_loss_history(loss_history, out_path):
    """Plot the mean loss per epoch, save it to ``out_path``, and display it."""
    fig, ax = plt.subplots()
    ax.plot(range(1, len(loss_history) + 1), loss_history, marker='o')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('L1 Loss')
    ax.set_title('Training Loss per Epoch')
    ax.grid(True)
    fig.savefig(out_path)
    plt.show()


def train_voxel_flow(triplets_txt_path, batch_size, num_epochs, train_dir, seed=None):
    """Train ``Voxel_flow_model`` to predict the middle frame of each triplet.

    For every full batch, frames 1 and 3 plus their sigmoid-squashed VGG
    feature maps are concatenated and fed to ``voxel_model.inference``. The
    first returned value is trained against frame 2 with a mean absolute
    error (L1) loss using Adam (lr=1e-4). Weights are saved after every
    epoch, and a loss curve is written to ``train_dir/loss_plot.png``.

    Args:
        triplets_txt_path: Text file with one ``f1 f2 f3`` path triplet per line.
        batch_size: Number of triplets per step. Trailing triplets that do not
            fill a complete batch are skipped each epoch.
        num_epochs: Number of passes over the shuffled triplets.
        train_dir: Directory for checkpoints and the loss plot (created if missing).
        seed: Optional seed for the per-epoch shuffle, for reproducible runs.

    Returns:
        List of the mean training loss for each epoch.

    Raises:
        ValueError: If ``batch_size`` is not positive or there are fewer
            triplets than ``batch_size`` (no full batch could be trained).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    os.makedirs(train_dir, exist_ok=True)

    triplets = load_triplets(triplets_txt_path)
    # Only complete batches are trained on. Check up front instead of dividing
    # by zero after a full (wasted) epoch.
    num_batches = len(triplets) // batch_size
    if num_batches == 0:
        raise ValueError(
            f"Need at least batch_size={batch_size} triplets, "
            f"found {len(triplets)} in {triplets_txt_path}"
        )

    rng = np.random.default_rng(seed)
    vgg_model = get_vgg_feature_extractor()
    voxel_model = Voxel_flow_model()
    optimizer = tf.keras.optimizers.Adam(1e-4)
    loss_history = []

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        rng.shuffle(triplets)
        total_loss = 0.0

        for batch_idx in range(num_batches):
            batch = triplets[batch_idx * batch_size:(batch_idx + 1) * batch_size]
            f1_batch, f2_batch, f3_batch = _load_triplet_batch(batch)

            # VGG features are computed outside the tape: only the voxel-flow
            # model's variables are optimized.
            input_tensor = _build_model_input(vgg_model, f1_batch, f3_batch)

            with tf.GradientTape() as tape:
                pred_f2, _ = voxel_model.inference(input_tensor)
                loss = tf.reduce_mean(tf.abs(pred_f2 - f2_batch))

            grads = tape.gradient(loss, voxel_model.trainable_variables)
            optimizer.apply_gradients(zip(grads, voxel_model.trainable_variables))

            total_loss += loss.numpy()

        avg_loss = total_loss / num_batches
        loss_history.append(avg_loss)
        print(f"Average Loss: {avg_loss:.5f}")

        # Save model weights every epoch
        voxel_model.save_weights(os.path.join(train_dir, f"model_epoch_{epoch+1}.ckpt"))

    _plot_loss_history(loss_history, os.path.join(train_dir, "loss_plot.png"))
    return loss_history

In [None]:
# Run configuration. The triplet list defaults to the original machine-specific
# location; set the TRIPLETS_TXT_PATH environment variable to point elsewhere
# instead of editing this cell.
TRIPLETS_TXT_PATH = os.environ.get("TRIPLETS_TXT_PATH", "D:/VFI/CyclicGen/triplets.txt")
BATCH_SIZE = 4
NUM_EPOCHS = 10
TRAIN_DIR = "training_logs"  # checkpoints and loss_plot.png are written here

train_voxel_flow(
    triplets_txt_path=TRIPLETS_TXT_PATH,
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS,
    train_dir=TRAIN_DIR,
)