## Stage 1: Fine-tune a Vision Transformer (ViT) Model

### Step 1: Import libraries and define constants

- Replace `PRETRAINED_PATH` with the path to your pre-trained ViT model.

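- Replace `IMG_DIR` with the path to the directory containing the CelebA images (`img_celeba`).

- The attribute and partition labels are read from `./TransformInverse/ImgCeleba/labels.pkl`; adjust that path as well if your data lives elsewhere.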

In [None]:
import os
import copy
import time
import math
import random
import pickle
from PIL import Image
from tqdm import tqdm
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import ViTForImageClassification, ViTImageProcessor, ViTConfig, ViTModel
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity


# --- Paths (edit for your setup) ---------------------------------------------
# Directory holding the CelebA image files referenced by the label table.
IMG_DIR = './TransformInverse/ImgCeleba/img_celeba'
# Directory loaded with Hugging Face `from_pretrained` (ViT weights + processor).
PRETRAINED_PATH = "./TransformInverse"
# State dict of the trained inversion network (loaded into InverseViT).
INVERSE_VIT_PATH = 'TransformInverse/model_0414c(Blackbox-Self).pth'

# --- Task / optimisation settings --------------------------------------------
NUM_LABELS = 40                   # one logit per binary attribute
BATCH_SIZE = 32
LEARNING_RATE_VIT = 3e-5          # backbone (fine-tuning) learning rate
LEARNING_RATE_CLASSIFIER = 1e-4   # learning rate for newly added head layers
NUM_EPOCHS = 5

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Step 2: Adjust hyperparameters for fine-tuning

- Label smoothing improves robustness against model inversion attacks (MIA) only when the smoothing factor $\alpha$ is **negative**; with a positive factor it can instead make the attack easier.

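    $$\mathcal{L}^{LS}(\mathbf{y}, \mathbf{p}, \alpha) = (1-\alpha) \cdot \mathcal{L}_{CE}(\mathbf{y}, \mathbf{p}) + \frac{\alpha}{C} \cdot \mathcal{L}_{CE} (\mathbf{1}, \mathbf{p})$$

    Here $\alpha$ is the smoothing factor, $C$ is the number of classes, and $\mathbf{1}$ is the all-ones target vector.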

    > Be Careful What You Smooth For: Label Smoothing Can Be a Privacy Shield but Also a Catalyst for Model Inversion Attacks

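- The other defenses are of limited value when `RESET_ALL_PARAMETERS` is `True`. In that case every module is re-initialized via `reset_parameters()` and trained from scratch rather than fine-tuned.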

In [None]:
# Each defense is a (switch, strength) pair.

# Label smoothing; a negative factor gives "negative label smoothing".
ENABLE_LABEL_SMOOTHING, LABEL_SMOOTHING_FACTOR = False, -0.1

# L1 penalty on all trainable weights, scaled by L1_LAMBDA.
ENABLE_L1_REGULARIZATION, L1_LAMBDA = False, 5e-6

# L2 regularization through the optimizer's weight decay.
ENABLE_L2_REGULARIZATION, weight_decay = False, 1e-4

# Variational information bottleneck; `beta` weights its KL term.
ENABLE_VIB, beta = True, 1e-5

# Re-initialize every module that has `reset_parameters()` before training.
RESET_ALL_PARAMETERS = False

### Step 3: Enhance original ViT with Information Bottleneck (if used)

The classifier is replaced with a linear layer that outputs 40 logits (one per CelebA attribute) for the downstream task.

In [None]:
class ViTVIBForImageClassification(nn.Module):
    """ViT backbone followed by a variational information bottleneck (VIB) and a linear head.

    Every token of the backbone's ``last_hidden_state`` is mapped to a Gaussian
    ``N(mu, exp(logvar))``. A sample ``z`` is drawn with the reparameterization
    trick, and the first token position ``z[:, 0]`` is fed to the classifier.
    The KL divergence of that Gaussian to ``N(0, I)`` is summed over tokens and
    features, averaged over the batch, and added to the loss with weight ``beta``.

    Note: ``z`` is sampled on every call, including in ``eval()`` mode, so the
    logits are stochastic.

    Args:
        pretrained_model_name_or_path: Hugging Face model id or local directory
            holding either a ``ViTModel`` or a ``ViTForImageClassification``
            checkpoint (the ``.vit`` backbone is extracted from the latter).
        num_labels: Number of output logits.
        beta: Weight of the KL term in the returned ``loss``.
        latent_dim_extension_factor: Currently unused; kept for API compatibility.
    """

    def __init__(self, pretrained_model_name_or_path, num_labels, beta=1e-3, latent_dim_extension_factor=1):
        super().__init__()
        self.beta = beta
        self.num_labels = num_labels
        try:
            self.vit_backbone = ViTModel.from_pretrained(pretrained_model_name_or_path)
            print(f"Successfully loaded ViTModel from {pretrained_model_name_or_path}")
        except OSError:
            print(f"Could not load ViTModel directly. Attempting to load ViTForImageClassification and extract backbone from {pretrained_model_name_or_path}...")
            temp_full_model = ViTForImageClassification.from_pretrained(pretrained_model_name_or_path)
            self.vit_backbone = temp_full_model.vit
            print("Successfully extracted ViTModel (vit backbone).")

        self.config = self.vit_backbone.config
        self.hidden_size = self.config.hidden_size

        # Per-token posterior parameters; the latent has the same width as the backbone.
        self.fc_mu = nn.Linear(self.hidden_size, self.hidden_size)
        self.fc_logvar = nn.Linear(self.hidden_size, self.hidden_size)

        self.classifier = nn.Linear(self.hidden_size, num_labels)

        for layer in (self.fc_mu, self.fc_logvar, self.classifier):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, pixel_values, labels=None):
        """Run backbone, bottleneck and classifier; compute losses when ``labels`` is given.

        Args:
            pixel_values: Image batch accepted by the ViT backbone.
            labels: Optional targets. Integer class indices of shape ``(batch,)``
                use a single-label cross-entropy. Floating-point targets of shape
                ``(batch, num_labels)``, such as multi-hot attribute vectors, use a
                multi-label binary cross-entropy. This matches the multi-label
                branch of Hugging Face's ``ViTForImageClassification``.

        Returns:
            dict with ``loss`` (classification + beta * KL), ``logits``, ``mu``,
            ``logvar``, ``z_sampled``, ``classification_loss`` and ``kld_loss``.
            The three loss entries are ``None`` when ``labels`` is ``None``.
        """
        outputs = self.vit_backbone(pixel_values=pixel_values)
        last_hidden_state = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)

        mu = self.fc_mu(last_hidden_state)          # (batch_size, seq_len, hidden_size)
        logvar = self.fc_logvar(last_hidden_state)  # (batch_size, seq_len, hidden_size)

        z = self.reparameterize(mu, logvar)         # (batch_size, seq_len, hidden_size)

        cls_token_representation = z[:, 0]                  # (batch_size, hidden_size)
        logits = self.classifier(cls_token_representation)  # (batch_size, num_labels)

        loss = None
        classification_loss = None
        kld_loss_total = None
        if labels is not None:
            if labels.dtype.is_floating_point:
                # Multi-label targets (e.g. the 40 binary CelebA attributes).
                # F.cross_entropy would read them as one probability distribution
                # over classes, which is the wrong objective.
                classification_loss = F.binary_cross_entropy_with_logits(logits, labels.to(logits.dtype))
            else:
                classification_loss = F.cross_entropy(logits, labels)

            # Closed-form KL( N(mu, exp(logvar)) || N(0, I) ), element-wise.
            kld_loss_element_wise = 0.5 * (mu.pow(2) + logvar.exp() - 1 - logvar)
            kld_loss_per_sample = torch.sum(kld_loss_element_wise, dim=[1, 2])  # sum over seq_len and hidden_size
            kld_loss_total = torch.mean(kld_loss_per_sample)  # average over batch

            loss = classification_loss + self.beta * kld_loss_total

        return {
            "loss": loss,
            "logits": logits,
            "mu": mu,
            "logvar": logvar,
            "z_sampled": z,
            "classification_loss": classification_loss,
            "kld_loss": kld_loss_total,
        }

In [None]:
# Preprocessing follows the processor stored with the pre-trained model. If none is
# stored locally, fall back to the `google/vit-base-patch16-224-in21k` processor.
try:
    processor = ViTImageProcessor.from_pretrained(PRETRAINED_PATH)
except OSError:
    print(
        f"Warning: Could not load ViTImageProcessor from {PRETRAINED_PATH}. "
        "Attempting to load a default ViT processor. Make sure this is intended."
    )
    processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")


# Resize to the processor's input size, then normalize with its mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((processor.size["height"], processor.size["width"])),
        transforms.ToTensor(),
        transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
    ]
)

# Use NUM_LABELS instead of a hard-coded 40 so the VIB head stays in sync with
# the dataset configuration.
if ENABLE_VIB:
    vit_model = ViTVIBForImageClassification(PRETRAINED_PATH, num_labels=NUM_LABELS, beta=beta)
else:
    vit_model = ViTForImageClassification.from_pretrained(PRETRAINED_PATH)

In [None]:
# Load the pickled CelebA metadata tables.
# NOTE(review): unpickling can execute arbitrary code; only load a trusted local file.
with open('./TransformInverse/ImgCeleba/labels.pkl', "rb") as f:
    labels = pickle.load(f)

# Attribute table with -1 entries mapped to 0 (presumably {-1, 1} -> {0, 1}),
# turned into float32 multi-hot targets for BCE.
attr_df = labels['list_attr_celebA'].copy().replace(-1, 0)
attr_labels = torch.tensor(attr_df.values, dtype=torch.float32)

# Split ids come from column 1 of the partition table. Rows are assumed to be in
# the same order as attr_df (files are taken from attr_df's index) -- TODO confirm.
partition_df = labels['list_eval_partition']
partition = partition_df[1].values
img_filenames = attr_df.index.tolist()

### Step 3b: Prepare the CelebA dataset and the optimizer with regularization

In [None]:
class CelebADataset(Dataset):
    """Map-style dataset yielding ``(image_tensor, label_vector)`` pairs.

    Images that cannot be loaded are replaced by an all-zero image and an
    all-zero label, with a printed warning. This is a deliberate best-effort
    policy, so that one bad file does not abort an epoch.

    Args:
        filenames: Image file names, relative to ``img_dir``.
        labels: Indexable label vectors; ``labels[i]`` belongs to ``filenames[i]``
            (in this notebook, rows of a float32 tensor).
        img_dir: Directory containing the image files.
        transform: Callable applied to the RGB ``PIL.Image``.
        dummy_image_shape: Shape of the placeholder image returned on load
            failure. It should match the output shape of ``transform``.
    """

    def __init__(self, filenames, labels, img_dir, transform, dummy_image_shape=(3, 224, 224)):
        self.filenames = filenames
        self.labels = labels
        self.img_dir = img_dir
        self.transform = transform
        self.dummy_image_shape = tuple(dummy_image_shape)

    def __len__(self):
        return len(self.filenames)

    def _dummy_sample(self, idx):
        """Return an all-zero image and a zero label shaped/typed like ``labels[idx]``."""
        label = torch.zeros_like(torch.as_tensor(self.labels[idx]))
        return torch.zeros(self.dummy_image_shape), label

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.filenames[idx])
        try:
            # The context manager closes the file handle PIL opens lazily.
            with Image.open(img_path) as img:
                image = self.transform(img.convert("RGB"))
        except FileNotFoundError:
            print(f"Warning: File not found {img_path}. Returning dummy data.")
            return self._dummy_sample(idx)
        except Exception as e:
            print(f"Warning: Error loading {img_path}: {e}. Returning dummy data.")
            return self._dummy_sample(idx)

        return image, self.labels[idx]

def _split_positions(split_id):
    """Return the row positions whose `partition` entry equals `split_id` (0=train, 1=val, 2=test)."""
    return [pos for pos, part in enumerate(partition) if part == split_id]


def _make_dataset(positions):
    """Build a CelebADataset over the given metadata rows."""
    return CelebADataset([img_filenames[pos] for pos in positions], attr_labels[positions], IMG_DIR, transform)


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader with this notebook's batching settings."""
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=4, pin_memory=True)


train_indices, val_indices, test_indices = (_split_positions(split_id) for split_id in (0, 1, 2))
train_dataset, val_dataset, test_dataset = (
    _make_dataset(positions) for positions in (train_indices, val_indices, test_indices)
)

train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

print("Data loaded:")
for split_name, split_dataset in (("Train", train_dataset), ("Validation", val_dataset), ("Test", test_dataset)):
    print(f"{split_name} samples: {len(split_dataset)}")

In [None]:
# Replace the classification head so it emits one logit per attribute.
num_ftrs = vit_model.classifier.in_features
vit_model.classifier = nn.Linear(num_ftrs, NUM_LABELS)
vit_model = vit_model.to(DEVICE)

# Multi-label objective: an independent sigmoid + binary cross-entropy per attribute.
criterion = nn.BCEWithLogitsLoss()

# The optimizer's weight decay is the "L2 regularization" switch. It must be zero
# whenever ENABLE_L2_REGULARIZATION is off, for the VIB model as well (the VIB
# branch previously used the raw `weight_decay` unconditionally).
true_weight_decay = weight_decay if ENABLE_L2_REGULARIZATION else 0.0

if RESET_ALL_PARAMETERS:
    print("Resetting all model parameters to their initial state.")
    vit_model.apply(lambda m: m.reset_parameters() if hasattr(m, 'reset_parameters') else None)

if ENABLE_VIB:
    # The backbone uses the fine-tuning LR; bottleneck and head layers use the head LR.
    optimizer = optim.AdamW([
        {'params': vit_model.vit_backbone.parameters(), 'lr': LEARNING_RATE_VIT, 'weight_decay': true_weight_decay},
        {'params': vit_model.fc_mu.parameters(), 'lr': LEARNING_RATE_CLASSIFIER, 'weight_decay': true_weight_decay},
        {'params': vit_model.fc_logvar.parameters(), 'lr': LEARNING_RATE_CLASSIFIER, 'weight_decay': true_weight_decay},
        {'params': vit_model.classifier.parameters(), 'lr': LEARNING_RATE_CLASSIFIER, 'weight_decay': true_weight_decay}
    ])
else:
    optimizer = optim.AdamW([
        {'params': vit_model.vit.parameters(), 'lr': LEARNING_RATE_VIT, 'weight_decay': true_weight_decay},
        {'params': vit_model.classifier.parameters(), 'lr': LEARNING_RATE_CLASSIFIER, 'weight_decay': true_weight_decay}
    ])

def label_smoothing_loss(outputs, targets, smoothing=-0.1, loss_fn=None):
    """Label-smoothed binary cross-entropy for multi-label logits.

    Each attribute is treated as its own binary problem (C = 2), and the
    label-smoothing objective ``(1 - a) * L(y, p) + a / C * L(1, p)`` is applied
    per attribute. ``L(1, p)`` is the loss against both classes. BCE-with-logits
    is linear in its target, so this equals the plain loss against the smoothed
    targets ``(1 - a) * y + a / 2``, which is what is computed.

    The previous version passed ``torch.ones_like(outputs)`` as the *prediction*
    to ``criterion``. That made the smoothing term constant in the logits, so it
    had no effect on training.

    Args:
        outputs: Raw logits, same shape as ``targets``.
        targets: Multi-hot float targets (0/1 in this notebook).
        smoothing: Smoothing factor ``a``. Negative values give negative label
            smoothing; the smoothed targets then lie slightly outside [0, 1].
        loss_fn: BCE-with-logits criterion. Defaults to the module-level ``criterion``.

    Returns:
        Scalar loss tensor (for the default mean reduction).
    """
    if loss_fn is None:
        loss_fn = criterion
    smoothed_targets = (1 - smoothing) * targets + smoothing / 2
    return loss_fn(outputs, smoothed_targets)


# Best validation loss seen so far and a deep copy of the matching weights.
best_val_loss = float('inf')
best_model_state = None

### Step 4: Train the model

In [None]:
# Cumulative number of correct attribute predictions, recorded after every training batch.
train_corrects_record = []


def _l1_penalty():
    """Return the sum of absolute values of all trainable parameters of `vit_model`."""
    penalty = torch.tensor(0.0, device=DEVICE)
    for param in vit_model.parameters():
        if param.requires_grad:
            penalty = penalty + torch.sum(torch.abs(param))
    return penalty


def _batch_loss(raw_outputs, labels_batch, l1_penalty=None):
    """Return the regularized objective for one batch.

    The task loss is BCE, or label-smoothed BCE, on ``raw_outputs["logits"]``.
    Added to it are the optional L1 penalty and, when ``ENABLE_VIB``, ``beta``
    times the KL term the model returned for this batch. ``l1_penalty`` may be
    precomputed by the caller when the parameters are fixed.
    """
    outputs = raw_outputs["logits"]
    if ENABLE_LABEL_SMOOTHING:
        loss = label_smoothing_loss(outputs, labels_batch, LABEL_SMOOTHING_FACTOR)
    else:
        loss = criterion(outputs, labels_batch)

    if ENABLE_L1_REGULARIZATION and L1_LAMBDA > 0:
        if l1_penalty is None:
            l1_penalty = _l1_penalty()
        loss = loss + L1_LAMBDA * l1_penalty

    if ENABLE_VIB:
        loss = loss + raw_outputs["kld_loss"] * beta
    return loss


for epoch in range(NUM_EPOCHS):
    start_time = time.time()

    # ------------------------------ training ------------------------------
    vit_model.train()
    train_loss = 0.0
    train_corrects = 0
    train_total = 0

    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]")
    for inputs, labels_batch in train_pbar:
        inputs, labels_batch = inputs.to(DEVICE), labels_batch.to(DEVICE)

        optimizer.zero_grad()

        raw_outputs = vit_model(inputs, labels=labels_batch)
        outputs = raw_outputs["logits"]
        loss = _batch_loss(raw_outputs, labels_batch)

        loss.backward()
        optimizer.step()

        train_loss += loss.item() * inputs.size(0)
        preds = torch.sigmoid(outputs) > 0.5
        train_corrects += torch.sum(preds == labels_batch.byte()).item()
        train_total += labels_batch.numel()
        train_corrects_record.append(train_corrects)

        postfix = {"loss": loss.item(), "acc": train_corrects / train_total}
        if ENABLE_VIB:
            postfix["mu"] = raw_outputs["mu"].mean().item()
            postfix["logvar"] = raw_outputs["logvar"].mean().item()
        train_pbar.set_postfix(postfix)

    epoch_train_loss = train_loss / len(train_loader.dataset)
    epoch_train_acc = train_corrects / train_total

    # ----------------------------- validation -----------------------------
    vit_model.eval()
    val_loss = 0.0
    val_corrects = 0
    val_total = 0

    val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Val]")
    with torch.no_grad():
        # Parameters do not change during validation, so compute the L1 term once.
        val_l1_penalty = _l1_penalty() if (ENABLE_L1_REGULARIZATION and L1_LAMBDA > 0) else None

        # All per-batch accounting must live inside this loop; previously it was
        # dedented and only the last batch was counted.
        for inputs, labels_batch in val_pbar:
            inputs, labels_batch = inputs.to(DEVICE), labels_batch.to(DEVICE)

            # Pass labels (as in training) so the VIB model returns this batch's KL
            # term. Dict-style ["logits"] works for both the VIB dict and HF outputs.
            raw_outputs = vit_model(inputs, labels=labels_batch)
            outputs = raw_outputs["logits"]
            loss = _batch_loss(raw_outputs, labels_batch, val_l1_penalty)

            val_loss += loss.item() * inputs.size(0)
            preds = torch.sigmoid(outputs) > 0.5
            val_corrects += torch.sum(preds == labels_batch.byte()).item()
            val_total += labels_batch.numel()

            val_pbar.set_postfix({"loss": loss.item(), "acc": val_corrects / val_total})

    epoch_val_loss = val_loss / len(val_loader.dataset)
    epoch_val_acc = val_corrects / val_total
    end_time = time.time()
    epoch_duration = end_time - start_time

    print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Duration: {epoch_duration:.2f}s")
    print(f"  Train Loss: {epoch_train_loss:.4f} | Train Acc: {epoch_train_acc:.4f}")
    print(f"  Val Loss:   {epoch_val_loss:.4f} | Val Acc:   {epoch_val_acc:.4f}")

    # Model selection on the (regularized) validation loss.
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        best_model_state = copy.deepcopy(vit_model.state_dict())
        print(f"  New best model saved with validation loss: {best_val_loss:.4f}")

if best_model_state is not None:
    vit_model.load_state_dict(best_model_state)
    print("\nLoaded best model weights based on validation loss.")

In [None]:
print("\nStarting testing phase...")
vit_model.eval()
test_loss, test_corrects, test_total = 0.0, 0, 0

test_pbar = tqdm(test_loader, desc="Testing")
with torch.no_grad():
    for inputs, labels_batch in test_pbar:
        inputs = inputs.to(DEVICE)
        labels_batch = labels_batch.to(DEVICE)

        # Plain BCE only: no L1 / KL / smoothing terms in the reported test loss.
        outputs = vit_model(inputs)['logits']
        loss = criterion(outputs, labels_batch)
        batch_loss = loss.item()

        test_loss += batch_loss * inputs.size(0)
        # Accuracy is counted per attribute entry, thresholding sigmoid at 0.5.
        hits = (torch.sigmoid(outputs) > 0.5) == labels_batch.byte()
        test_corrects += hits.sum().item()
        test_total += labels_batch.numel()

        test_pbar.set_postfix({'loss': batch_loss})

final_test_loss = test_loss / len(test_loader.dataset)
final_test_acc = test_corrects / test_total

print("\nTest Results:")
print(f"  Test Loss (without regularization): {final_test_loss:.4f}")
print(f"  Test Accuracy: {final_test_acc:.4f}")

## Stage 2: Test inversion on the fine-tuned model

### Step 1: Define inversion structure

The structure is the same as the `Self-Attention` network in `invVIT-notebook.ipynb`.

In [None]:
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Dropout -> Linear -> Dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    they are not given (or are falsy).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
        super().__init__()
        width_hidden = hidden_features or in_features
        width_out = out_features or in_features
        # Attribute names are kept stable: they are the state_dict keys of saved checkpoints.
        self.fc1 = nn.Linear(in_features, width_hidden, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(width_hidden, width_out, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))

class InverseAttention(nn.Module):
    """Multi-head attention layer of the inversion network.

    ``forward`` receives a ``context`` tensor, but only checks that it has the
    same batch size and embedding width as ``x``. Queries, keys and values are
    all projected from ``x``, so this is plain self-attention. The loaded
    inversion checkpoint presumably depends on this structure; keep it unless
    retraining.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _heads(self, tokens):
        """Reshape ``(B, N, C)`` into ``(B, num_heads, N, C // num_heads)``."""
        batch, length, width = tokens.shape
        return tokens.reshape(batch, length, self.num_heads, width // self.num_heads).permute(0, 2, 1, 3)

    def forward(self, x, context):
        batch, length, width = x.shape
        ctx_batch, _, ctx_width = context.shape
        assert batch == ctx_batch and width == ctx_width

        query, key, value = (self._heads(projection(x)) for projection in (self.q, self.k, self.v))
        scores = (query @ key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))
        merged = (weights @ value).transpose(1, 2).reshape(batch, length, width)
        return self.proj_drop(self.proj(merged))

class InverseBlock(nn.Module):
    """Pre-norm transformer block: ``x + attn(norm1(x), context)``, then ``x + mlp(norm2(x))``.

    ``context`` is passed through unchanged to the attention layer.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        # Registration order and attribute names match the saved checkpoint layout.
        self.norm1 = norm_layer(dim)
        self.attn = InverseAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                     attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop)

    def forward(self, x, context):
        after_attention = x + self.attn(self.norm1(x), context)
        return after_attention + self.mlp(self.norm2(after_attention))

class PixelDecoder(nn.Module):
    """Project patch embeddings back to pixels and fold the patches into an image.

    Accepts ``(B, num_patches, hidden)`` or ``(B, num_patches + 1, hidden)``; in
    the latter case the first (CLS) token is dropped. Returns ``(B, C, H, W)``.

    ``config.image_size`` and ``config.patch_size`` may be ints (square) or
    ``(height, width)`` pairs, as Hugging Face ViT configs allow. Behavior for
    int sizes is unchanged.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.patch_size = config.patch_size
        self.num_channels = config.num_channels
        self.image_size_h, self.image_size_w = self._pair(config.image_size)
        patch_h, patch_w = self._pair(config.patch_size)
        self.num_patches = (self.image_size_h // patch_h) * (self.image_size_w // patch_w)
        # One linear map per token -> flattened (C, patch_h, patch_w) pixel patch.
        self.proj = nn.Linear(self.hidden_size, self.num_channels * patch_h * patch_w)
        self.unpatchify = nn.Fold(output_size=(self.image_size_h, self.image_size_w),
                                  kernel_size=(patch_h, patch_w),
                                  stride=(patch_h, patch_w))

    @staticmethod
    def _pair(value):
        """Return ``value`` as ``(h, w)``; an int is used for both sides."""
        if isinstance(value, (tuple, list)):
            height, width = value
            return int(height), int(width)
        return value, value

    def forward(self, x):
        if x.shape[1] == self.num_patches + 1:
            x = x[:, 1:, :]  # drop the CLS token
        elif x.shape[1] != self.num_patches:
            raise ValueError(f"Input embedding sequence length ({x.shape[1]}) doesn't match expected number of patches ({self.num_patches}) or patches+1.")
        x = self.proj(x)       # (B, N, C * ph * pw)
        x = x.transpose(1, 2)  # (B, C * ph * pw, N), the layout nn.Fold expects
        reconstructed_image = self.unpatchify(x)

        return reconstructed_image


class InverseViT(nn.Module):
    """Inversion network mapping ViT token features back to an image.

    It mirrors the forward model's depth, width and head count: ``norm_start``,
    then ``num_hidden_layers`` ``InverseBlock`` layers, then ``norm_end``, then a
    ``PixelDecoder`` that un-patchifies the tokens.

    Args:
        forward_model_config: Config of the ViT being inverted. It must provide
            ``num_hidden_layers``, ``hidden_size`` and ``num_attention_heads``;
            ``intermediate_size`` is optional (default ``4 * hidden_size``).
        norm_layer: Normalization layer class.
        act_layer: Activation class used inside the MLPs.
    """

    def __init__(self, forward_model_config, norm_layer=nn.LayerNorm, act_layer=nn.GELU):
        super().__init__()
        self.config = forward_model_config  # stored for the decoder
        self.num_layers = forward_model_config.num_hidden_layers
        self.hidden_dim = forward_model_config.hidden_size
        self.num_heads = forward_model_config.num_attention_heads
        self.mlp_ratio = getattr(forward_model_config, 'intermediate_size', self.hidden_dim * 4) / self.hidden_dim

        self.norm_start = norm_layer(self.hidden_dim)
        self.blocks = nn.ModuleList([
            InverseBlock(
                dim=self.hidden_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio,
                qkv_bias=True, attn_drop=0., drop=0., norm_layer=norm_layer, act_layer=act_layer
            ) for _ in range(self.num_layers)])
        self.norm_end = norm_layer(self.hidden_dim)
        self.decoder = PixelDecoder(self.config)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform for Linear weights, zero biases; identity LayerNorm."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, final_features, intermediate_forward_outputs=None):
        """Reconstruct images from ``final_features`` (``[B, N, D]`` token features).

        ``intermediate_forward_outputs`` is accepted for API compatibility but unused.
        Every block receives the raw input features as ``context``.
        """
        context = final_features
        x = self.norm_start(final_features)

        for block in self.blocks:
            x = block(x, context=context)

        reconstructed_embedding = self.norm_end(x)  # Shape: [B, N, D]
        return self.decoder(reconstructed_embedding)  # Shape: [B, C, H, W]

Set `INVERSE_VIT_PATH` to the checkpoint of the inverse model trained in `invViT-notebook.ipynb`.

In [None]:
# Build the inversion network for the ViT configuration and load its trained weights.
inverse_model = InverseViT(vit_model.config)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
inverse_model.load_state_dict(torch.load(INVERSE_VIT_PATH, map_location=DEVICE))
inverse_model.to(DEVICE)
inverse_model.eval()

# Un-fine-tuned reference model, used to compare features and reconstructions.
vit_model_original = ViTForImageClassification.from_pretrained(PRETRAINED_PATH).to(DEVICE)
vit_model_original.eval()

### Step 2: Test visualization

Here we explore in detail the inverse effect of a single image.

The heatmap below shows the absolute difference between the `last_hidden_state` of the original model and that of the fine-tuned model. It confirms that fine-tuning changed the features numerically.

In [None]:
# The context manager closes the file handle; convert() returns an independent copy.
with Image.open('./TransformInverse/custom_test_imges/cat.jpg') as raw_img:
    test_img = raw_img.convert('RGB')

test_img_resized = test_img.resize((224, 224))
test_img_transformed = transform(test_img_resized).unsqueeze(0).to(DEVICE)

# Inference only: no autograd graph is needed to compare the features.
with torch.no_grad():
    original_features = vit_model_original.vit(test_img_transformed).last_hidden_state
    finetuned_backbone = vit_model.vit_backbone if ENABLE_VIB else vit_model.vit
    finetuned_features = finetuned_backbone(test_img_transformed).last_hidden_state

plt.figure(figsize=(10, 3), dpi=300)
# Absolute per-token, per-feature difference for the single image; colour scale fixed to [0, 1].
sns.heatmap(abs((original_features - finetuned_features)[0].cpu().detach().numpy()), cmap="Blues", vmin=0, vmax=1)

### Step 3: Visualize

In [None]:
# Fixed per-channel statistics used to un-normalize the decoder output.
mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.5, 0.5, 0.5])


def _to_display_image(features):
    """Invert `features` with `inverse_model`; return an HxWxC array clipped to [0, 1]."""
    reconstruction = inverse_model(features)
    hwc = reconstruction.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()
    return np.clip(hwc * std + mean, 0, 1)


reconstructed_original_img_clipped = _to_display_image(original_features)
reconstructed_finetuned_img_clipped = _to_display_image(finetuned_features)

panels = (
    (test_img_resized, "Input"),
    (reconstructed_original_img_clipped, "Reconstructed(Original)"),
    (reconstructed_finetuned_img_clipped, "Reconstructed(Fine-tuned)"),
)
plt.figure(figsize=(10, 3), dpi=300)
for position, (picture, caption) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(picture)
    plt.title(caption)
    plt.axis('off')
plt.show()

In [None]:
# Reference image as HxWxC float32 in [0, 1], matching the clipped reconstructions.
original_img_np = np.asarray(test_img_resized, dtype=np.float32) / 255.0


def _reconstruction_metrics(reference, candidate):
    """Return (MSE, PSNR, SSIM) of `candidate` against `reference`, both in [0, 1]."""
    mse = mean_squared_error(reference, candidate)
    psnr = peak_signal_noise_ratio(reference, candidate, data_range=1.0)
    ssim = structural_similarity(reference, candidate, data_range=1.0, channel_axis=-1, win_size=7)
    return mse, psnr, ssim


mse_original, psnr_original, ssim_original = _reconstruction_metrics(
    original_img_np, reconstructed_original_img_clipped
)
mse_finetuned, psnr_finetuned, ssim_finetuned = _reconstruction_metrics(
    original_img_np, reconstructed_finetuned_img_clipped
)

print("--- Metrics vs Original Image ---")
for heading, (mse_value, psnr_value, ssim_value) in (
    ("\nReconstruction (Original Model):", (mse_original, psnr_original, ssim_original)),
    ("\nReconstruction (Fine-tuned Model):", (mse_finetuned, psnr_finetuned, ssim_finetuned)),
):
    print(heading)
    print(f"  MSE:  {mse_value:.4f}")
    print(f"  PSNR: {psnr_value:.2f} dB")
    print(f"  SSIM: {ssim_value:.4f}")