In [None]:
import gc
import os
import time

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from tensorflow.keras import layers, models, applications, callbacks
from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr, mean_squared_error as mse

In [None]:
# Ask TensorFlow's native logging to drop INFO and WARNING lines.
# NOTE(review): TensorFlow is already imported by the cell above; confirm that
# setting this variable after the import still has the intended effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Single place for every value a reader might want to tune.
CONFIG = {
    "TOTAL_IMAGES": 2000,       # examples taken from the dataset
    "NUM_TRAIN": 1800,          # leading examples used for training
    "NUM_TEST": 200,            # remaining examples held out
    "IMG_SIZE": 128,            # images are resized to IMG_SIZE x IMG_SIZE
    "BATCH_SIZE": 64,
    "EPOCHS_WARMUP": 10,        # epochs for the first fit() call
    "EPOCHS_FINETUNE": 30,      # upper bound on epochs for the second fit() call
    "LEARNING_RATE": 1e-4,      # Adam learning rate for the warm-up stage
    "PROJECT_DIR": "./working_vgg_attention_ae_32x32",
    "USE_AUGMENTATION": True,
    "USE_ATTENTION": True,
}

# The split below is take(NUM_TRAIN) / skip(NUM_TRAIN) on TOTAL_IMAGES examples,
# so the three numbers must agree for NUM_TEST to describe the held-out set.
assert CONFIG["NUM_TRAIN"] + CONFIG["NUM_TEST"] == CONFIG["TOTAL_IMAGES"], \
    "NUM_TRAIN + NUM_TEST must equal TOTAL_IMAGES"

# exist_ok=True makes the cell idempotent and avoids the check-then-create race.
os.makedirs(CONFIG["PROJECT_DIR"], exist_ok=True)

In [None]:
# Keep only the first TOTAL_IMAGES (image, label) pairs of the 'train' split.
ds_full = tfds.load('cats_vs_dogs', split='train', as_supervised=True).take(CONFIG['TOTAL_IMAGES'])

# The leading NUM_TRAIN examples are used for training; everything after them
# in ds_full is held out.
ds_train, ds_val = ds_full.take(CONFIG['NUM_TRAIN']), ds_full.skip(CONFIG['NUM_TRAIN'])

def preprocess(image, label):
    """Turn a dataset example into an autoencoder (input, target) pair.

    The image is resized to IMG_SIZE x IMG_SIZE, cast to float32 and divided
    by 255. The incoming ``label`` is discarded: the scaled image is returned
    twice, once as the input and once as the reconstruction target.
    """
    side = CONFIG['IMG_SIZE']
    scaled = tf.cast(tf.image.resize(image, [side, side]), tf.float32) / 255.0
    return scaled, scaled

def augment(image, label):
    """Randomly perturb an image and return it as both input and target.

    Applies a random horizontal flip followed by small random brightness,
    contrast and saturation changes, then clips the result to [0, 1].

    Args:
        image: image tensor to augment (the pipeline feeds the output of
            ``preprocess``, i.e. float values scaled by 1/255).
        label: unused. In this notebook it is the un-augmented copy of the
            image produced by ``preprocess``; returning it as the target would
            ask the model to reconstruct an image that can be mirrored and
            colour-shifted relative to its input.

    Returns:
        ``(augmented, augmented)`` so the reconstruction target always matches
        the augmented input.
    """
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
    image = tf.image.random_saturation(image, lower=0.9, upper=1.1)
    # The photometric ops can push values outside the [0, 1] range.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, image

In [None]:
# Materialise the held-out split as NumPy arrays for the scikit-image metrics
# and the plotting cells: resized images plus their class labels.
val_images, val_labels = [], []
target_hw = [CONFIG['IMG_SIZE'], CONFIG['IMG_SIZE']]
for raw_image, class_id in ds_val:
    val_images.append(tf.image.resize(raw_image, target_hw).numpy())
    val_labels.append(class_id.numpy())

X_test = np.array(val_images)
# Column vector of labels, one row per image.
y_test = np.array(val_labels).reshape(-1, 1)
# Same 1/255 scaling that preprocess() applies to the tf.data pipelines.
X_test_normalized = X_test.astype('float32') / 255.0

# Training pipeline. The deterministic work (resize + scale) is cached; the
# random augmentation is mapped AFTER the cache so that every epoch draws fresh
# flips and colour jitter. Caching after augment() would store the first
# epoch's augmented images and replay them unchanged on every later epoch.
train_ds = ds_train.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE).cache()
if CONFIG['USE_AUGMENTATION']:
    train_ds = train_ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.shuffle(1000).batch(CONFIG["BATCH_SIZE"]).prefetch(tf.data.AUTOTUNE)

# Held-out pipeline: same preprocessing, no augmentation, no shuffling.
test_ds = ds_val.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
test_ds = test_ds.batch(CONFIG["BATCH_SIZE"]).prefetch(tf.data.AUTOTUNE)

In [None]:
class SpatialAttention(layers.Layer):
    """Self-attention across the spatial positions of a feature map.

    Query/key/value are produced by 1x1 convolutions (query and key use
    ``channels // reduction`` filters, value uses ``channels``). Every position
    attends to every other position, and the attended values are added back to
    the input scaled by a learnable scalar ``gamma``. ``gamma`` starts at zero,
    so the layer initially returns its input unchanged.

    Args:
        channels: number of channels of the value projection; the residual
            addition in ``call`` adds it to the input directly.
        reduction: divisor applied to ``channels`` for the query/key width.
    """

    def __init__(self, channels, reduction=8, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.reduction = reduction
        projected_width = channels // reduction
        self.query_conv = layers.Conv2D(projected_width, 1, name='query')
        self.key_conv = layers.Conv2D(projected_width, 1, name='key')
        self.value_conv = layers.Conv2D(channels, 1, name='value')
        self.gamma = self.add_weight(name='gamma', shape=(), initializer='zeros', trainable=True)

    def call(self, x):
        """Return ``x + gamma * attention(x)`` with the same shape as ``x``."""
        dims = tf.shape(x)
        batch_size, height, width = dims[0], dims[1], dims[2]
        positions = height * width
        projected_width = self.channels // self.reduction

        # Project, then flatten the spatial grid into a sequence of positions.
        query = tf.reshape(self.query_conv(x), [batch_size, positions, projected_width])
        key = tf.reshape(self.key_conv(x), [batch_size, positions, projected_width])
        value = tf.reshape(self.value_conv(x), [batch_size, positions, self.channels])

        # Scaled dot-product attention: a (positions x positions) matrix per sample.
        logits = tf.matmul(query, key, transpose_b=True)
        logits = logits / tf.sqrt(tf.cast(projected_width, tf.float32))
        weights = tf.nn.softmax(logits, axis=-1)

        attended = tf.reshape(tf.matmul(weights, value), [batch_size, height, width, self.channels])
        return x + self.gamma * attended

    def get_config(self):
        """Serialise the constructor arguments alongside the base layer config."""
        base_config = super().get_config()
        base_config.update({'channels': self.channels, 'reduction': self.reduction})
        return base_config


class SSIMPSNRCallback(callbacks.Callback):
    """Track mean SSIM and PSNR of reconstructions at the end of every epoch.

    Args:
        validation_data: indexable collection of images that is fed to the
            model and also used as the reference for both metrics
            (metrics are computed with ``data_range=1.0``).
        sample_size: number of images scored per epoch, capped at the number
            of images available.

    Attributes:
        ssim_history: per-epoch mean SSIM values.
        psnr_history: per-epoch mean PSNR values (dB).
    """

    def __init__(self, validation_data, sample_size=200):
        super().__init__()
        self.validation_data = validation_data
        self.sample_size = min(sample_size, len(validation_data))
        self.ssim_history = []
        self.psnr_history = []

    def on_epoch_end(self, epoch, logs=None):
        """Score a random sample of images and append the means to the histories."""
        # Draw a fresh sample (without replacement) each epoch.
        chosen = np.random.choice(len(self.validation_data), self.sample_size, replace=False)
        originals = self.validation_data[chosen]
        reconstructions = self.model.predict(originals, verbose=0)

        ssim_scores = [
            ssim(originals[k], reconstructions[k], channel_axis=2, data_range=1.0)
            for k in range(len(originals))
        ]
        psnr_scores = [
            psnr(originals[k], reconstructions[k], data_range=1.0)
            for k in range(len(originals))
        ]

        avg_ssim = np.mean(ssim_scores)
        avg_psnr = np.mean(psnr_scores)
        self.ssim_history.append(avg_ssim)
        self.psnr_history.append(avg_psnr)

        # No trailing newline so the text continues the current output line.
        print(f" - SSIM: {avg_ssim:.4f}, PSNR: {avg_psnr:.2f} dB", end='')

In [None]:
def simple_loss(y_true, y_pred):
    """Weighted reconstruction loss: 0.7 * MSE + 0.2 * MAE + 0.1 * (1 - mean SSIM).

    SSIM is computed with ``max_val=1.0``, i.e. for images on a [0, 1] scale.
    """
    residual = y_true - y_pred
    squared_term = tf.reduce_mean(tf.square(residual))
    absolute_term = tf.reduce_mean(tf.abs(residual))
    structural_term = 1 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1.0))
    return 0.7 * squared_term + 0.2 * absolute_term + 0.1 * structural_term

def build_model(mode='warmup', img_size=None):
    """Build the VGG16-encoder / convolutional-decoder autoencoder.

    The encoder is VGG16 (ImageNet weights) up to ``block2_pool`` followed by a
    1x1 bottleneck convolution, batch normalisation and ReLU. The decoder
    upsamples twice by a factor of 2, optionally applying ``SpatialAttention``
    before each upsampling, and ends in a 3-channel sigmoid convolution.

    Args:
        mode: ``'warmup'`` freezes the whole encoder. Any other value makes the
            encoder trainable except for VGG layers whose name does not
            contain ``'block2'``.
        img_size: side length of the square input images. Defaults to
            ``CONFIG['IMG_SIZE']`` so the model always matches the images the
            data pipeline produces (the shapes were previously hard-coded to
            32x32 input / 8x8 latent regardless of the configured size).

    Returns:
        The assembled, uncompiled autoencoder model.

    Raises:
        ValueError: if ``img_size`` is smaller than 32 or not a multiple of 4.
    """
    if img_size is None:
        img_size = CONFIG['IMG_SIZE']
    # block2_pool halves the resolution twice and the decoder doubles it twice,
    # so the output only matches the input when the side is a multiple of 4.
    # 32 is the smallest input side the Keras VGG16 application accepts.
    if img_size < 32 or img_size % 4 != 0:
        raise ValueError(
            f"img_size must be a multiple of 4 and at least 32, got {img_size}"
        )
    latent_size = img_size // 4

    vgg = applications.VGG16(weights='imagenet', include_top=False, input_shape=(img_size, img_size, 3))
    encoder_output = vgg.get_layer('block2_pool').output

    x = layers.Conv2D(32, 1, padding='same', name='bottleneck')(encoder_output)
    x = layers.BatchNormalization()(x)
    latent = layers.Activation('relu')(x)
    encoder = models.Model(vgg.input, latent, name='encoder')

    decoder_input = layers.Input(shape=(latent_size, latent_size, 32))
    x = decoder_input

    # NOTE: each SpatialAttention layer builds an (H*W) x (H*W) attention
    # matrix per sample, so its memory cost grows quickly with img_size.
    x = layers.Conv2D(64, 3, padding='same', activation='relu', name='dec_conv1')(x)
    if CONFIG['USE_ATTENTION']:
        x = SpatialAttention(64, name='attention1')(x)
    x = layers.UpSampling2D(2, name='up1')(x)

    x = layers.Conv2D(64, 3, padding='same', activation='relu', name='dec_conv2')(x)
    if CONFIG['USE_ATTENTION']:
        x = SpatialAttention(64, name='attention2')(x)
    x = layers.UpSampling2D(2, name='up2')(x)

    x = layers.Conv2D(32, 3, padding='same', activation='relu')(x)
    # Sigmoid keeps reconstructions in [0, 1], the range of the scaled inputs.
    decoder_output = layers.Conv2D(3, 3, activation='sigmoid', padding='same')(x)

    decoder = models.Model(decoder_input, decoder_output, name='decoder')
    autoencoder = models.Model(encoder.input, decoder(encoder.output), name='autoencoder')

    if mode == 'warmup':
        # Only the decoder learns during warm-up.
        encoder.trainable = False
    else:
        encoder.trainable = True
        # Within VGG, leave only the block2 layers trainable.
        for layer in vgg.layers:
            if 'block2' not in layer.name:
                layer.trainable = False

    return autoencoder

In [None]:
# `gc` and `time` are imported in the notebook's import cell.
start_time = time.time()

# --- Stage 1: warm-up ------------------------------------------------------
# build_model(mode='warmup') freezes the encoder, so only the decoder trains.
model = build_model(mode='warmup')
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=CONFIG['LEARNING_RATE']), loss=simple_loss)
history_warmup = model.fit(train_ds, epochs=CONFIG['EPOCHS_WARMUP'], validation_data=test_ds, verbose=1)

# Persist the warm-up weights so they can be loaded into the fine-tune model.
temp_weights = os.path.join(CONFIG["PROJECT_DIR"], "warmup_weights.weights.h5")
model.save_weights(temp_weights)

# Drop the warm-up model and reset Keras' global state before rebuilding.
del model
gc.collect()
tf.keras.backend.clear_session()

# --- Stage 2: fine-tuning --------------------------------------------------
model = build_model(mode='finetune')
model.load_weights(temp_weights)
# Fine-tuning learning rate; override with CONFIG['FINETUNE_LEARNING_RATE'].
finetune_lr = CONFIG.get('FINETUNE_LEARNING_RATE', 1e-5)
optimizer = tf.keras.optimizers.Adam(learning_rate=finetune_lr)
model.compile(optimizer=optimizer, loss=simple_loss)

ckpt_path = os.path.join(CONFIG["PROJECT_DIR"], "best_model_attention.keras")
ssim_psnr_callback = SSIMPSNRCallback(X_test_normalized, sample_size=200)
callbacks_list = [
    # Per-epoch SSIM/PSNR on the held-out images.
    ssim_psnr_callback,
    # Keep the model with the lowest val_loss on disk.
    callbacks.ModelCheckpoint(ckpt_path, monitor='val_loss', save_best_only=True, mode='min', verbose=1),
    # Stop after 7 epochs without val_loss improvement.
    callbacks.EarlyStopping(monitor='val_loss', patience=7, restore_best_weights=True, verbose=1),
    # Halve the learning rate after 3 epochs without val_loss improvement.
    callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-7, verbose=1)
]

history = model.fit(train_ds, epochs=CONFIG['EPOCHS_FINETUNE'], validation_data=test_ds, callbacks=callbacks_list, verbose=1)

# Evaluate with the best checkpoint; the custom loss and layer must be
# supplied for deserialisation.
autoencoder = keras.models.load_model(ckpt_path, custom_objects={'simple_loss': simple_loss, 'SpatialAttention': SpatialAttention})
# Wall-clock minutes for both stages, including the checkpoint reload.
elapsed = (time.time() - start_time) / 60

In [None]:
# Reconstruct every held-out image with the best checkpoint.
test_predictions = autoencoder.predict(X_test_normalized, batch_size=16, verbose=0)

# Per-image quality metrics, in the same order as X_test_normalized.
n_eval = len(X_test_normalized)
ssim_all = np.array([
    ssim(X_test_normalized[k], test_predictions[k], channel_axis=2, data_range=1.0)
    for k in range(n_eval)
])
psnr_all = np.array([
    psnr(X_test_normalized[k], test_predictions[k], data_range=1.0)
    for k in range(n_eval)
])
mse_all = np.array([
    mse(X_test_normalized[k], test_predictions[k])
    for k in range(n_eval)
])
class_names = ['Cat', 'Dog']

# Split each metric by class label (indices 0 and 1 of class_names).
labels_flat = y_test.flatten()
class_masks = {class_idx: (labels_flat == class_idx) for class_idx in range(2)}
ssim_by_class = {class_idx: ssim_all[mask] for class_idx, mask in class_masks.items()}
psnr_by_class = {class_idx: psnr_all[mask] for class_idx, mask in class_masks.items()}
mse_by_class = {class_idx: mse_all[mask] for class_idx, mask in class_masks.items()}

# Counts of images clearing the "excellent" thresholds, and the Pearson
# correlation between the two metrics.
excellent_ssim = (ssim_all >= 0.85).sum()
excellent_psnr = (psnr_all >= 30).sum()
correlation = np.corrcoef(ssim_all, psnr_all)[0, 1]

print(f"Average SSIM: {np.mean(ssim_all):.4f}")
print(f"Average PSNR: {np.mean(psnr_all):.2f} dB")
print(f"Average MSE: {np.mean(mse_all):.6f}")
print(f"Training time: {elapsed:.1f} minutes")

In [None]:
# --- Figure 1: fine-tuning loss curves and per-epoch SSIM / PSNR -------------
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

loss_ax = axes[0]
for history_key, curve_label in (('loss', 'Train Loss'), ('val_loss', 'Val Loss')):
    loss_ax.plot(history.history[history_key], label=curve_label, linewidth=2)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Loss')
loss_ax.legend()
loss_ax.grid(True, alpha=0.3)

metric_panels = (
    (axes[1], ssim_psnr_callback.ssim_history, '#3498db', 'SSIM', 'SSIM over Epochs (Fine-tuning)'),
    (axes[2], ssim_psnr_callback.psnr_history, '#e74c3c', 'PSNR (dB)', 'PSNR over Epochs (Fine-tuning)'),
)
for panel, series, line_color, y_label, panel_title in metric_panels:
    panel.plot(series, linewidth=2, color=line_color, marker='o')
    panel.set_xlabel('Epoch')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# --- Figure 2: distribution of per-image scores, mean marked in red ----------
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
histogram_panels = (
    (axes[0], ssim_all, 'SSIM Score', 'SSIM Distribution'),
    (axes[1], psnr_all, 'PSNR (dB)', 'PSNR Distribution'),
)
for panel, scores, x_label, panel_title in histogram_panels:
    panel.hist(scores, bins=50, alpha=0.7, edgecolor='black')
    panel.axvline(np.mean(scores), color='red', linestyle='--', linewidth=2)
    panel.set_xlabel(x_label)
    panel.set_title(panel_title)
plt.tight_layout()
plt.show()

# --- Figure 3: five best and five worst reconstructions by SSIM --------------
ssim_order = np.argsort(ssim_all)
best_idx = ssim_order[-5:][::-1]   # highest SSIM first
worst_idx = ssim_order[:5]         # lowest SSIM first

fig, axes = plt.subplots(4, 5, figsize=(15, 12))
# Rows 0/1 hold best originals/reconstructions, rows 2/3 the worst ones.
for top_row, picked, title_color in ((0, best_idx, 'green'), (2, worst_idx, 'red')):
    for i, idx in enumerate(picked):
        axes[top_row, i].imshow(X_test_normalized[idx])
        axes[top_row, i].axis('off')
        axes[top_row + 1, i].imshow(test_predictions[idx])
        axes[top_row + 1, i].set_title(f'SSIM: {ssim_all[idx]:.3f}', color=title_color)
        axes[top_row + 1, i].axis('off')

plt.suptitle(f'Best and Worst Reconstructions (Avg SSIM: {np.mean(ssim_all):.4f})')
plt.tight_layout()
plt.show()

In [None]:
# Final summary of the evaluation metrics computed above.
print(f"\n{'='*70}")
print("\nFINAL RESULTS:")
print(f"   Average SSIM: {np.mean(ssim_all):.4f}")
print(f"   Average PSNR: {np.mean(psnr_all):.2f} dB")
print(f"   Average MSE: {np.mean(mse_all):.6f}")
# ASCII ">=" replaces a mis-encoded "greater-or-equal" character in the label.
print(f"   Excellent images (>=0.85): {excellent_ssim} ({excellent_ssim/len(ssim_all)*100:.1f}%)")
# `elapsed` is in minutes; dividing by 60 gives hours.
print(f"   Training time: {elapsed:.1f} minutes ({elapsed/60:.2f} hours)")

In [None]:
# 2x2 dashboard: per-class means (top row), SSIM-vs-PSNR scatter and SSIM
# quality buckets (bottom row).
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

cat_mask = (y_test.flatten() == 0)
dog_mask = (y_test.flatten() == 1)

# Top row: class means as bars, overall mean as a dashed reference line.
bar_panels = (
    (axes[0, 0], ssim_by_class, ssim_all, 'Average SSIM', 'SSIM Comparison: Cat vs Dog'),
    (axes[0, 1], psnr_by_class, psnr_all, 'Average PSNR (dB)', 'PSNR Comparison: Cat vs Dog'),
)
for panel, per_class, overall, y_label, panel_title in bar_panels:
    class_means = [np.mean(per_class[0]), np.mean(per_class[1])]
    panel.bar(['Cat', 'Dog'], class_means,
              color=['#3498db', '#e74c3c'], alpha=0.7, edgecolor='black')
    panel.axhline(np.mean(overall), color='red', linestyle='--', linewidth=2, label='Overall')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(True, alpha=0.3, axis='y')

# Bottom-left: one point per image, coloured by class.
scatter_panel = axes[1, 0]
for class_mask, class_label, point_color in ((cat_mask, 'Cat', '#3498db'), (dog_mask, 'Dog', '#e74c3c')):
    scatter_panel.scatter(ssim_all[class_mask], psnr_all[class_mask],
                          alpha=0.5, s=30, label=class_label, color=point_color)
scatter_panel.set_xlabel('SSIM')
scatter_panel.set_ylabel('PSNR (dB)')
scatter_panel.set_title(f'SSIM vs PSNR by Class (corr={correlation:.3f})')
scatter_panel.legend()
scatter_panel.grid(True, alpha=0.3)

# Bottom-right: number of images in each SSIM quality bucket.
quality_bins = ['Poor\n(<0.75)', 'Fair\n(0.75-0.80)', 'Good\n(0.80-0.85)', 'Excellent\n(>=0.85)']
quality_counts = [
    np.sum(ssim_all < 0.75),
    np.sum((ssim_all >= 0.75) & (ssim_all < 0.80)),
    np.sum((ssim_all >= 0.80) & (ssim_all < 0.85)),
    np.sum(ssim_all >= 0.85)
]
colors = ['#e74c3c', '#f39c12', '#2ecc71', '#27ae60']
bucket_panel = axes[1, 1]
bucket_panel.bar(quality_bins, quality_counts, color=colors, alpha=0.7, edgecolor='black')
bucket_panel.set_ylabel('Number of Images')
bucket_panel.set_title('Image Quality Distribution')
bucket_panel.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

In [None]:
# Show n_samples random held-out images with their reconstructions. The grid
# has two "bands" of two rows each: originals on top, reconstructions below.
n_samples = 16
random_indices = np.random.choice(len(X_test_normalized), n_samples, replace=False)

per_band = n_samples // 2
fig, axes = plt.subplots(4, per_band, figsize=(20, 10))

for i, idx in enumerate(random_indices):
    band, col = divmod(i, per_band)
    row = band * 2
    original_ax = axes[row, col]
    recon_ax = axes[row + 1, col]

    original_ax.imshow(X_test_normalized[idx])
    original_ax.set_title(f'{class_names[y_test[idx][0]]}', fontsize=9)
    original_ax.axis('off')

    recon_ax.imshow(test_predictions[idx])
    ssim_val = ssim_all[idx]
    psnr_val = psnr_all[idx]
    # Title colour encodes the SSIM band: >=0.80 green, >=0.70 orange, else red.
    if ssim_val >= 0.80:
        color = 'green'
    elif ssim_val >= 0.70:
        color = 'orange'
    else:
        color = 'red'
    recon_ax.set_title(f'SSIM:{ssim_val:.3f}\nPSNR:{psnr_val:.1f}dB',
                       fontsize=8, color=color)
    recon_ax.axis('off')

plt.suptitle(f'Random Sample Reconstructions (n={n_samples})', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Originals, reconstructions and per-pixel reconstruction-error maps for a few
# random held-out images. The third row is the mean absolute difference
# between original and reconstruction, not an attention map, so it is labelled
# as reconstruction error. The USE_ATTENTION gate is kept from the original
# cell even though nothing plotted here reads the attention layers.
if CONFIG['USE_ATTENTION']:
    n_samples = 8
    random_indices = np.random.choice(len(X_test_normalized), n_samples, replace=False)

    # Constrained layout (instead of tight_layout) because the figure has a
    # colorbar shared by a whole row of axes.
    fig, axes = plt.subplots(3, n_samples, figsize=(20, 8), constrained_layout=True)

    for i, idx in enumerate(random_indices):
        axes[0, i].imshow(X_test_normalized[idx])
        axes[0, i].set_title(f'{class_names[y_test[idx][0]]}', fontsize=9)

        axes[1, i].imshow(test_predictions[idx])
        axes[1, i].set_title(f'SSIM:{ssim_all[idx]:.3f}', fontsize=8)

        # Mean absolute error across the colour channels, one value per pixel.
        diff = np.abs(X_test_normalized[idx] - test_predictions[idx])
        diff_gray = np.mean(diff, axis=-1)
        im = axes[2, i].imshow(diff_gray, cmap='hot', vmin=0, vmax=0.3)
        axes[2, i].set_title('Error Map', fontsize=8)

    # Hide ticks and frames explicitly rather than with axis('off'), which
    # would also hide the row labels set below.
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    row_labels = ('Original', 'Reconstructed', 'Reconstruction\nError')
    for row, row_label in enumerate(row_labels):
        axes[row, 0].set_ylabel(row_label, fontsize=11, fontweight='bold',
                                rotation=0, ha='right', va='center')

    fig.colorbar(im, ax=axes[2, :], orientation='horizontal', fraction=0.05, pad=0.05, label='Error Magnitude')
    fig.suptitle('Reconstruction Error Maps (Brighter = Higher Error)', fontsize=14, fontweight='bold')
    plt.show()