In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, losses, optimizers, applications
import os, random, pathlib
import kagglehub
import matplotlib.pyplot as plt
import numpy as np
import datetime


# --- Input pipeline configuration -------------------------------------------
IMG_RES = 256                 # side length (pixels) every image is resized to
BATCH = 16                    # number of images per batch
MAX_IMAGES = None             # optional cap on the number of files; None keeps them all
AUTOTUNE = tf.data.AUTOTUNE

# Fetch the dataset through kagglehub; the returned value is used as the
# root directory that is searched for image files below.
img_root = kagglehub.dataset_download("sivarazadi/wikiart-art-movementsstyles")

# Recursively collect every file matching one of the glob patterns,
# pattern by pattern (all *.jpg first, then *.jpeg, then *.png).
patterns = ["*.jpg", "*.jpeg", "*.png"]
files = [
    found
    for pat in patterns
    for found in pathlib.Path(img_root).rglob(pat)
]

if len(files) == 0:
    raise RuntimeError(f"No images found under {img_root}")

# Randomise the file order, then optionally keep only the first MAX_IMAGES.
random.shuffle(files)
files = files if MAX_IMAGES is None else files[:MAX_IMAGES]

print(f"Using {len(files):,} total images.")

# Dataset of path strings; decoding happens later in the pipeline.
paths_ds = tf.data.Dataset.from_tensor_slices([str(p) for p in files])

def decode(path):
    """Read one image file and return it as a float tensor.

    The file is decoded to 3 channels, randomly mirrored left/right,
    resized to IMG_RES x IMG_RES and rescaled with ``x / 127.5 - 1``
    (which maps 8-bit pixel values to [-1, 1]).
    """
    raw_bytes = tf.io.read_file(path)
    pixels = tf.image.decode_image(raw_bytes, channels=3, expand_animations=False)
    mirrored = tf.image.random_flip_left_right(pixels)
    resized = tf.image.resize(mirrored, [IMG_RES, IMG_RES])
    return (resized / 127.5) - 1.0

# Shuffle the cheap path strings *before* decoding.  Shuffling after `map`
# keeps thousands of decoded IMG_RES x IMG_RES x 3 float32 images alive in the
# shuffle buffer (several GB at 8,000 entries); shuffling paths gives a
# full-dataset shuffle at negligible memory cost.  Because `repeat()` comes
# after `shuffle()`, the order is re-drawn on every pass over the data.
dataset = (paths_ds
           .shuffle(len(files), reshuffle_each_iteration=True)
           .map(decode, num_parallel_calls=AUTOTUNE)
           .apply(tf.data.experimental.ignore_errors())  # skip files that fail to decode
           .batch(BATCH, drop_remainder=True)
           .repeat()
           .prefetch(AUTOTUNE))

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, losses, optimizers, applications
import os, random, pathlib
import kagglehub
import matplotlib.pyplot as plt
import numpy as np
import datetime

# Model and loss hyper-parameters used by the builders and the training step.
IMG = 256              # input resolution of the encoder and the discriminator
BATCH = 16             # batch size, used when deriving the steps per epoch
LATENT = 2048          # length of the texture code vector
STR_CH = 16            # channel count of the spatial structure code
LAMBDA_L1 = 5.0        # weight of the L1 reconstruction term in the generator loss
LAMBDA_PERCEP = 1.0    # weight of the VGG feature (perceptual) term
LAMBDA_STYLE = 0.1     # weight of the Gram-matrix (style) term

# VGG model for perceptual loss
def build_vgg_features():
    """Return a frozen VGG19 model exposing five intermediate activations.

    The outputs are the ``conv1`` activations of blocks 1 to 5, in that order.
    """
    tap_names = [f'block{idx}_conv1' for idx in range(1, 6)]

    backbone = applications.VGG19(include_top=False, weights='imagenet')
    backbone.trainable = False

    taps = [backbone.get_layer(layer_name).output for layer_name in tap_names]
    return models.Model(backbone.input, taps)

# cbl with residual connection option
def cbl(x, f, k=3, s=1, use_residual=False, name=None):
    """Apply Conv2D -> BatchNormalization -> LeakyReLU, optionally adding a skip.

    Args:
        x: input tensor.
        f: number of convolution filters.
        k: kernel size.
        s: stride.
        use_residual: request that the block input be added to the block output.
        name: prefix for the layer names (``<name>_conv``, ``_bn``, ``_lrelu``, ``_add``).

    Returns:
        The activated tensor, plus the block input when the residual is applied.
    """
    block_input = x

    out = layers.Conv2D(f, k, s, 'same', use_bias=False, name=f'{name}_conv')(block_input)
    out = layers.BatchNormalization(momentum=0.9, name=f'{name}_bn')(out)
    out = layers.LeakyReLU(0.2, name=f'{name}_lrelu')(out)

    # The skip is only added when it was requested and the block keeps both
    # the channel count and the spatial size (stride 1) of its input.
    if not use_residual:
        return out
    if block_input.shape[-1] != f or s != 1:
        return out
    return layers.Add(name=f'{name}_add')([out, block_input])

# Enhanced AdaIN
def adain(x, w, name=None):
    """Adaptive instance normalisation of `x` driven by the vector `w`.

    `x` is normalised with its mean and variance over axes 1 and 2, then
    scaled and shifted by per-channel gamma/beta values predicted from `w`
    by a small MLP (Dense 512 -> LeakyReLU -> Dense 2*C).
    """
    channels = x.shape[-1]

    # Normalise each feature map using statistics over axes 1 and 2.
    mean, variance = tf.nn.moments(x, [1, 2], keepdims=True)
    normalised = (x - mean) / tf.sqrt(variance + 1e-5)

    # Predict one scale and one shift per channel from the style vector.
    hidden = layers.Dense(512, name=f'{name}_dense1')(w)
    hidden = layers.LeakyReLU(0.2)(hidden)
    affine = layers.Dense(2 * channels, name=f'{name}_dense2')(hidden)

    gamma, beta = tf.split(affine, 2, axis=-1)
    broadcast_shape = [-1, 1, 1, channels]
    gamma = tf.reshape(gamma, broadcast_shape)
    beta = tf.reshape(beta, broadcast_shape)

    return gamma * normalised + beta

def build_encoder():
    """Build the encoder: image -> [structure map, texture vector].

    The structure output is a STR_CH-channel feature map taken after three
    stride-2 stages; the texture output is a LATENT-dimensional vector
    computed from globally average-pooled features.
    """
    inp = layers.Input([IMG, IMG, 3])

    # Stem at full resolution.
    x = cbl(inp, 64, name='enc_c1')

    # Three stride-2 stages, each followed by a residual block at the same width.
    for stage, width in enumerate((128, 256, 512), start=2):
        x = cbl(x, width, s=2, name=f'enc_c{stage}')
        x = cbl(x, width, use_residual=True, name=f'enc_r{stage}')

    # One extra residual block at the final resolution.
    x = cbl(x, 512, use_residual=True, name='enc_r5')

    # Structure code: 1x1 convolution down to STR_CH channels.
    structure = layers.Conv2D(STR_CH, 1, padding='same', name='structure')(x)

    # Texture code: global average pool followed by a two-layer MLP.
    pooled = layers.GlobalAveragePooling2D(name='gap')(x)
    hidden = layers.Dense(1024, activation='relu', name='texture_fc1')(pooled)
    texture = layers.Dense(LATENT, name='texture')(hidden)

    return models.Model(inp, [structure, texture], name='Encoder')

def build_generator():
    """Build the generator: [structure map, texture vector] -> RGB image.

    The structure map is processed by conv blocks whose activations are
    modulated by the texture vector through `adain`; three bilinear
    upsampling stages follow, and a tanh convolution produces 3 channels.
    """
    s_in = layers.Input([None, None, STR_CH])
    t_in = layers.Input([LATENT])

    # Stage 0: two conv blocks at the structure map's own resolution.
    x = cbl(s_in, 512, k=3, name='gen_c0')
    x = adain(x, t_in, 'gen_adain0')
    x = cbl(x, 512, use_residual=True, name='gen_r0')
    x = adain(x, t_in, 'gen_adain0b')

    # Stages 1-3: upsample, then conv + residual block, each texture-modulated.
    for stage, width in enumerate((256, 128, 64), start=1):
        x = layers.UpSampling2D(interpolation='bilinear', name=f'up{stage - 1}')(x)
        x = cbl(x, width, k=3, name=f'gen_c{stage}')
        x = adain(x, t_in, f'gen_adain{stage}a')
        x = cbl(x, width, use_residual=True, name=f'gen_r{stage}')
        x = adain(x, t_in, f'gen_adain{stage}b')

    # Single RGB head; tanh keeps the output in [-1, 1].
    out = layers.Conv2D(3, 3, padding='same', activation='tanh', name='toRGB')(x)

    return models.Model([s_in, t_in], out, name='Generator')

def build_discriminator():
    """Build the discriminator: image -> one unbounded score (logit).

    A stack of conv/residual blocks with three stride-2 stages is followed
    by a flatten, a 512-unit dense layer and a single linear output unit.
    """
    inp = layers.Input([IMG, IMG, 3])

    # Full-resolution stem with one residual block.
    x = cbl(inp, 64, name='d_c1')
    x = cbl(x, 64, use_residual=True, name='d_r1')

    # Three stride-2 stages, each followed by a residual block at the same width.
    for stage, width in enumerate((128, 256, 512), start=2):
        x = cbl(x, width, s=2, name=f'd_c{stage}')
        x = cbl(x, width, use_residual=True, name=f'd_r{stage}')

    # Dense head; the final layer has no activation, so it emits a logit.
    x = layers.Flatten()(x)
    x = layers.Dense(512, activation='leaky_relu', name='d_fc')(x)
    out = layers.Dense(1, name='d_out')(x)

    return models.Model(inp, out, name='Discriminator')

class SAE(models.Model):
    """Autoencoder that splits an image into a structure map and a texture vector.

    Bundles the encoder (E), generator (G), discriminator (D) and a frozen
    VGG19 feature extractor, and implements one adversarial training step
    combining GAN, L1, perceptual (VGG feature) and style (Gram matrix) losses.
    """

    def __init__(self):
        super().__init__()
        self.E = build_encoder()
        self.G = build_generator()
        self.D = build_discriminator()
        self.vgg = build_vgg_features()

        # D outputs raw logits, hence from_logits=True.
        self.bce = losses.BinaryCrossentropy(from_logits=True)
        self.l1 = losses.MeanAbsoluteError()
        self.l2 = losses.MeanSquaredError()

    def compile(self, g_opt, d_opt):
        """Store the optimizers: `g_opt` updates E and G, `d_opt` updates D."""
        super().compile()
        self.g_opt = g_opt
        self.d_opt = d_opt

    def _vgg_features(self, images):
        """Return the list of VGG19 activations for images given in [-1, 1]."""
        # Back to the 0..255 range first, then the VGG19-specific preprocessing.
        images = (images + 1) * 127.5
        return self.vgg(applications.vgg19.preprocess_input(images))

    def _content_term(self, real_features, fake_features):
        """Mean over layers of the MSE between corresponding activations."""
        total = 0
        for r, f in zip(real_features, fake_features):
            total += self.l2(r, f)
        return total / len(real_features)

    def _style_term(self, real_features, fake_features):
        """Mean over layers of the MSE between corresponding Gram matrices."""
        total = 0
        for r, f in zip(real_features, fake_features):
            total += self.l2(self._gram_matrix(r), self._gram_matrix(f))
        return total / len(real_features)

    def perceptual_loss(self, real, fake):
        """Feature-reconstruction loss between `real` and `fake` (both in [-1, 1])."""
        return self._content_term(self._vgg_features(real), self._vgg_features(fake))

    def style_loss(self, real, fake):
        """Gram-matrix loss between `real` and `fake` (both in [-1, 1])."""
        return self._style_term(self._vgg_features(real), self._vgg_features(fake))

    def _gram_matrix(self, x):
        """Return the channel-by-channel Gram matrix of `x`, divided by H*W."""
        # Reshape to [batch_size, height*width, channels]
        features = tf.reshape(x, (tf.shape(x)[0], -1, tf.shape(x)[-1]))

        # features^T @ features -> [batch_size, channels, channels]
        gram = tf.matmul(features, features, transpose_a=True)

        # Normalize by the number of spatial positions
        n = tf.cast(tf.shape(features)[1], tf.float32)
        return gram / n

    def train_step(self, batch):
        """Run one optimisation step for D and for E+G on a single batch.

        The batch is split into two halves `x` and `y`.  `x` is reconstructed
        from its own codes and also decoded with the texture code of `y`
        (the "swap").  Returns a dict with the scalar losses of this step.
        """
        x, y = tf.split(batch, 2)               # x=first half, y=second half

        with tf.GradientTape(persistent=True) as tape:
            # training=True is passed explicitly: without it the sub-models are
            # called in inference mode, so their BatchNormalization layers would
            # use (and never update) their moving statistics during training.
            sx, tx = self.E(x, training=True)
            _sy, ty = self.E(y, training=True)

            x_rec = self.G([sx, tx], training=True)
            x_swap = self.G([sx, ty], training=True)

            real = tf.concat([x, y], 0)
            fake = tf.concat([x_rec, x_swap], 0)

            # Discriminator outputs
            d_real = self.D(real, training=True)
            d_fake = self.D(fake, training=True)

            # Discriminator loss: real -> 1, fake -> 0
            d_loss = self.bce(tf.ones_like(d_real), d_real) + \
                     self.bce(tf.zeros_like(d_fake), d_fake)

            # Generator losses
            g_adv = self.bce(tf.ones_like(d_fake), d_fake)
            g_l1 = self.l1(x, x_rec)

            # Perceptual and style losses share one VGG pass per image set.
            real_features = self._vgg_features(x)
            fake_features = self._vgg_features(x_rec)
            p_loss = self._content_term(real_features, fake_features)
            s_loss = self._style_term(real_features, fake_features)

            # Combined generator loss
            g_loss = g_adv + LAMBDA_L1 * g_l1 + LAMBDA_PERCEP * p_loss + LAMBDA_STYLE * s_loss

        # Compute both gradient sets from the same forward pass, then release
        # the persistent tape so it does not keep its resources alive.
        d_vars = self.D.trainable_variables
        eg_vars = self.E.trainable_variables + self.G.trainable_variables
        d_grads = tape.gradient(d_loss, d_vars)
        eg_grads = tape.gradient(g_loss, eg_vars)
        del tape

        # Update discriminator
        self.d_opt.apply_gradients(zip(d_grads, d_vars))

        # Update encoder and generator
        self.g_opt.apply_gradients(zip(eg_grads, eg_vars))

        return {
            "d_loss": d_loss,
            "g_loss": g_loss,
            "recon": g_l1,
            "perceptual": p_loss,
            "style": s_loss
        }

def load_custom(path):
    """Read the image at `path` and return it as a batch of one.

    The image is decoded to 3 channels, resized to IMG x IMG and rescaled
    with ``x / 127.5 - 1``; a leading batch axis is added to the result.
    """
    encoded = tf.io.read_file(str(path))
    decoded = tf.image.decode_image(encoded, channels=3, expand_animations=False)
    resized = tf.image.resize(decoded, [IMG, IMG])
    scaled = (resized / 127.5) - 1.0
    return scaled[tf.newaxis, ...]

# Convert from [-1,1] to [0,1] range
def to01(x):
    """Map values from [-1, 1] to [0, 1], clipping anything outside that range."""
    rescaled = (x + 1) * 0.5
    return tf.clip_by_value(rescaled, 0, 1)

# Create mosaic visualization
def make_mosaic(imgs, rows=3, cols=2):
    """Tile the first ``rows * cols`` images of `imgs` into one RGB canvas.

    Each image is converted with `to01` and squeezed; tiles are placed in
    row-major order.  All tiles are expected to share the first tile's size.
    """
    tiles = [to01(item).numpy().squeeze() for item in imgs]
    tile_h, tile_w, _ = tiles[0].shape

    canvas = np.zeros((rows * tile_h, cols * tile_w, 3), np.float32)
    for index in range(rows * cols):
        row, col = divmod(index, cols)
        top, left = row * tile_h, col * tile_w
        canvas[top:top + tile_h, left:left + tile_w] = tiles[index]
    return canvas

class TensorBoardCallback(tf.keras.callbacks.Callback):
    """Log losses, reconstruction/swap mosaics and SSIM to TensorBoard every epoch.

    Args:
        sample_ds: dataset from which one batch is drawn once, at construction;
            its first two images are visualised after every epoch.
        reference_images: optional sequence of two image paths that are
            visualised in the same way as the batch samples.
        log_dir: directory for the summaries; defaults to a timestamped
            folder under ``logs/fit``.
    """

    def __init__(self, sample_ds, reference_images=None, log_dir=None):
        super().__init__()
        self.sample_batch = next(iter(sample_ds.take(1)))
        self.history = []

        # Keep the resolved directory on the instance and create the writer once;
        # the same writer is reused for scalars and images on every epoch.
        self.log_dir = log_dir or f"logs/fit/{datetime.datetime.now().strftime('%Y%m%d-%H%M%S')}"
        self.tb_writer = tf.summary.create_file_writer(self.log_dir)

        # Load reference images if provided
        self.has_references = False
        if reference_images:
            self.img_a = load_custom(reference_images[0])
            self.img_b = load_custom(reference_images[1])
            self.has_references = True

    @staticmethod
    def _to01(x):
        """Map values from [-1, 1] to [0, 1], clipping anything outside."""
        return tf.clip_by_value((x + 1) * 0.5, 0, 1)

    def _log_pair(self, image_tag, ssim_tags, img_a, img_b, epoch):
        """Log a 3x2 mosaic (inputs / reconstructions / swaps) and two SSIM scalars.

        `img_a` and `img_b` are single-image batches in [-1, 1]; `ssim_tags`
        holds the scalar names for the A and B reconstruction SSIM values.
        """
        s_a, t_a = self.model.E(img_a, training=False)
        s_b, t_b = self.model.E(img_b, training=False)

        recon_a = self.model.G([s_a, t_a], training=False)
        recon_b = self.model.G([s_b, t_b], training=False)
        swap_ab = self.model.G([s_a, t_b], training=False)   # A structure, B texture
        swap_ba = self.model.G([s_b, t_a], training=False)   # B structure, A texture

        # Concatenate along width (axis 2) to form rows, then along height (axis 1).
        row1 = tf.concat([img_a, img_b], axis=2)
        row2 = tf.concat([recon_a, recon_b], axis=2)
        row3 = tf.concat([swap_ab, swap_ba], axis=2)
        mosaic = self._to01(tf.concat([row1, row2, row3], axis=1))

        with self.tb_writer.as_default():
            tf.summary.image(image_tag, mosaic, step=epoch)
            for tag, original, recon in zip(ssim_tags, (img_a, img_b), (recon_a, recon_b)):
                score = tf.image.ssim(self._to01(original), self._to01(recon), max_val=1.0).numpy()[0]
                tf.summary.scalar(tag, score, step=epoch)

    def on_epoch_end(self, epoch, logs=None):
        """Write the epoch's metrics and the visualisations for this epoch."""
        logs = logs or {}

        # Track losses for plotting
        self.history.append(logs)

        with self.tb_writer.as_default():
            for name, value in logs.items():
                tf.summary.scalar(name, value, step=epoch)

        # Visualisation is best-effort: a failure is reported but must not
        # interrupt training.
        try:
            self._log_pair("Batch_Samples", ("SSIM_A", "SSIM_B"),
                           self.sample_batch[:1], self.sample_batch[1:2], epoch)
            print(f"Epoch {epoch}: Saved batch sample images to TensorBoard")

            if self.has_references:
                self._log_pair("Reference_Images", ("Ref_SSIM_A", "Ref_SSIM_B"),
                               self.img_a, self.img_b, epoch)
                print(f"Epoch {epoch}: Saved reference images to TensorBoard")

            self.tb_writer.flush()

        except Exception as e:
            print(f"Error generating TensorBoard images: {e}")
            import traceback
            traceback.print_exc()

from absl import logging
logging.set_verbosity(logging.ERROR)

# Directory that receives the periodic weight checkpoints.
ckpt_dir = pathlib.Path("checkpoints")
ckpt_dir.mkdir(exist_ok=True)

# Root for TensorBoard runs.  An empty string places each run directory in the
# current working directory; set it to e.g. "/root/tf-logs" to group runs.
LOG_ROOT = ""
if LOG_ROOT:
    os.makedirs(LOG_ROOT, exist_ok=True)

run_id  = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = os.path.join(LOG_ROOT, run_id)

# Standard Keras TensorBoard logging; built once and reused in `callbacks`.
tb_cb = tf.keras.callbacks.TensorBoard(
    log_dir      = log_dir,
    update_freq  = "batch",   # log for every batch
    histogram_freq = 0
)

# Steps per epoch: one pass over the discovered files, or 100 when `files`
# is not defined in this kernel.
num_steps = len(files) // BATCH if 'files' in globals() else 100

# Training
num_epochs=60
sae = SAE()
sae.compile(
    g_opt=tf.keras.optimizers.Adam(1e-4, beta_1=0.5),  # Lower learning rate for stability
    d_opt=tf.keras.optimizers.Adam(1e-4, beta_1=0.5)
)

reference_images = ["A.png", "B.png"]

sample_ds = dataset.take(1)
callbacks = [
    # Custom image/SSIM logging, including the two reference images
    TensorBoardCallback(
        sample_ds=sample_ds,
        reference_images=reference_images,
        log_dir=log_dir
    ),
    tf.keras.callbacks.ModelCheckpoint(
        filepath=str(ckpt_dir / "improved_sae_{epoch:02d}.ckpt"),
        save_weights_only=True,
        save_freq=5 * num_steps,  # save_freq counts batches: every 5 epochs
        verbose=1
    ),
    tb_cb
]

In [None]:
# Train for `num_epochs` epochs of `num_steps` batches each, with the
# callbacks configured above.
fit_options = {
    "epochs": num_epochs,
    "steps_per_epoch": num_steps,
    "callbacks": callbacks,
}
sae.fit(dataset, **fit_options)


In [None]:
def load_custom(path):
    """Load one image file as a model-ready batch.

    Decodes to 3 channels, resizes to IMG x IMG, rescales with
    ``x / 127.5 - 1`` and prepends a batch axis, giving shape (1, H, W, C).
    """
    file_bytes = tf.io.read_file(str(path))
    image = tf.image.decode_image(file_bytes, channels=3, expand_animations=False)
    image = tf.image.resize(image, [IMG, IMG]) / 127.5 - 1.0
    return image[None, ...]

def to01(x):
    """Rescale from the [-1, 1] range to [0, 1] and clip to that interval."""
    shifted = x + 1
    return tf.clip_by_value(shifted * 0.5, 0, 1)

# Create a fresh model; its weights are then restored from a checkpoint.
sae = SAE()
sae.compile(
    g_opt=tf.keras.optimizers.Adam(1e-4, beta_1=0.5),
    d_opt=tf.keras.optimizers.Adam(1e-4, beta_1=0.5)
)

checkpoint_dir = "./checkpoints/"
checkpoint_name = "improved_sae_115"
checkpoint_path = os.path.join(checkpoint_dir, f"{checkpoint_name}.ckpt")

# The presence of a "<prefix>.index" file is what this cell uses to detect
# that a checkpoint exists.
if os.path.exists(f"{checkpoint_path}.index"):
    sae.load_weights(checkpoint_path)
    print(f"Model loaded from checkpoint: {checkpoint_path}")
else:
    print(f"Checkpoint not found at {checkpoint_path}")

    # List the directory once, tolerate a missing directory, and order the
    # entries oldest -> newest by modification time (os.listdir order is
    # arbitrary, so its last entry is not necessarily the latest checkpoint).
    available_checkpoints = []
    if os.path.isdir(checkpoint_dir):
        available_checkpoints = sorted(
            (entry for entry in os.listdir(checkpoint_dir) if entry.endswith(".index")),
            key=lambda entry: (os.path.getmtime(os.path.join(checkpoint_dir, entry)), entry),
        )

    print(f"Available checkpoints in {checkpoint_dir}:")
    for index_file in available_checkpoints:
        # Strip only the trailing ".index"; str.replace would also remove
        # any occurrence inside the name.
        print(f"  - {index_file[:-len('.index')]}")

    if available_checkpoints:
        latest_ckpt = os.path.join(checkpoint_dir, available_checkpoints[-1][:-len(".index")])
        print(f"\nAttempting to load the latest available checkpoint: {latest_ckpt}")
        sae.load_weights(latest_ckpt)
        print(f"Successfully loaded: {latest_ckpt}")
    else:
        print("No checkpoint could be loaded; the model keeps its freshly initialised (untrained) weights.")

In [None]:
from pathlib import Path

# Load the two input images as single-image batches.
a_path, b_path = Path("A.png"), Path("B.png")
img_a = load_custom(a_path)
img_b = load_custom(b_path)

# Encode each image into its structure map (s_*) and texture vector (t_*).
s_a, t_a = sae.E(img_a)
s_b, t_b = sae.E(img_b)

# Decode with matching codes (reconstruction) and with exchanged texture codes (swap).
recon_a = sae.G([s_a, t_a])
recon_b = sae.G([s_b, t_b])
swap_ab = sae.G([s_a, t_b])   # A structure, B texture
swap_ba = sae.G([s_b, t_a])   # B structure, A texture

# Output file name -> generated tensor.
images_to_save = dict([
    ("recon_a.png", recon_a),
    ("recon_b.png", recon_b),
    ("swap_ab.png", swap_ab),
    ("swap_ba.png", swap_ba),
])

output_dir = Path(".")
print(f"Saving images to {output_dir}...")
for filename, tensor in images_to_save.items():
    # Saving is best-effort per file: report a failure and continue with the rest.
    try:
        plt.imsave(output_dir / filename, to01(tensor).numpy().squeeze())
    except Exception as e:
        print(f"Error saving {filename}: {e}")

print("Finished saving images.")
def make_mosaic(imgs, rows=3, cols=2):
    """Arrange ``rows * cols`` images from `imgs` on a single float32 RGB canvas.

    Images are passed through `to01`, squeezed, and laid out row by row;
    the first image determines the tile height and width.
    """
    panels = [to01(entry).numpy().squeeze() for entry in imgs]
    height, width, _ = panels[0].shape
    canvas = np.zeros((rows * height, cols * width, 3), np.float32)

    cell_coords = [(r, c) for r in range(rows) for c in range(cols)]
    for k, (r, c) in enumerate(cell_coords):
        canvas[r * height:(r + 1) * height, c * width:(c + 1) * width] = panels[k]
    return canvas

mosaic = make_mosaic([
    img_a, img_b,
    recon_a, recon_b,
    swap_ab, swap_ba
])

# SSIM between each original and its reconstruction (images mapped to [0, 1]).
ssim_a = tf.image.ssim(to01(img_a), to01(recon_a), max_val=1.0).numpy()[0]
ssim_b = tf.image.ssim(to01(img_b), to01(recon_b), max_val=1.0).numpy()[0]

# One label per tile, in the same row-major order as the mosaic.
labels = ["Original A", "Original B",
          f"Reconstructed A (SSIM {ssim_a:.3f})", f"Reconstructed B (SSIM {ssim_b:.3f})",
          "A structure + B texture", "B structure + A texture"]

# Label anchors as (x, y) fractions of the mosaic width/height, measured
# from the top-left corner (near the top edge of each tile).
positions = [(0.25, 0.05), (0.75, 0.05),
             (0.25, 0.38), (0.75, 0.38),
             (0.25, 0.71), (0.75, 0.71)]

plt.figure(figsize=(12, 18))
plt.imshow(mosaic)
plt.axis("off")

# imshow puts the origin at the top-left with y increasing downwards, so the
# fractional positions convert directly to pixel (data) coordinates.
mosaic_h, mosaic_w = mosaic.shape[:2]
for label, (frac_x, frac_y) in zip(labels, positions):
    plt.text(frac_x * mosaic_w, frac_y * mosaic_h, label,
             ha="center", va="center", color="white", fontsize=14,
             bbox=dict(facecolor="black", alpha=0.6, edgecolor="none"))

plt.tight_layout()
plt.show()

In [None]:

# Interpolate between the texture codes of images A and B while keeping
# the structure of A fixed.
steps = 4
alphas = np.linspace(0, 1, steps)

# alpha = 0 uses texture A unchanged, alpha = 1 uses texture B.
results = [
    sae.G([s_a, t_a * (1 - alpha) + t_b * alpha])
    for alpha in alphas
]

# Show the generated images side by side, one panel per alpha.
fig = plt.figure(figsize=(steps*3, 4))
for panel, (alpha, generated) in enumerate(zip(alphas, results), start=1):
    plt.subplot(1, steps, panel)
    plt.imshow(to01(generated)[0])
    plt.title(f"α={alpha:.2f}")
    plt.axis('off')

plt.suptitle("Texture Interpolation (Structure A with textures from A→B)", fontsize=16)
plt.tight_layout()
plt.show()
fig.savefig("interpolation.png")

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import os

# Load the structure image (A) and the texture image (B).
a_path, b_path = Path("A.png"), Path("B.png")
img_a = load_custom(a_path)
img_b = load_custom(b_path)

# Encode both images, then decode A's structure with each texture code.
s_a, t_a = sae.E(img_a)
s_b, t_b = sae.E(img_b)
recon_a = sae.G([s_a, t_a])   # A structure, A texture (reconstruction of A)
swap_ab = sae.G([s_a, t_b])   # A structure, B texture
print("Generation complete.")

# Panels in display order, with their titles.
images_to_plot = [img_a, img_b, recon_a, swap_ab]
titles = [
    "Structure Image (A)",
    "Texture Image (B)",
    "Reconstructed Structure (A)",
    "Style Transferred (A struct, B tex)"
]

fig, axes = plt.subplots(2, 2, figsize=(10, 10))
axes = axes.flatten()

print("Generating 2x2 plot...")
panel_count = len(images_to_plot)
for i, ax in enumerate(axes):
    # Axes beyond the available images stay empty; every axis is hidden.
    if i < panel_count:
        ax.imshow(to01(images_to_plot[i]).numpy().squeeze())
        ax.set_title(titles[i], fontsize=12)
    ax.axis("off")

plt.tight_layout(pad=1.5)

# Show the plot
plt.show()
print("Plot displayed.")

fig.savefig("results.png")