# Sample Code for training `invViT`

### Step 1: Import Required Libraries and set configurations

Preparations:

- Dependencies listed in `requirements.txt` (`pip install -r requirements.txt`)
- ViT-Large model (set the folder in the `model_name` variable)
- OpenImages dataset (set the folder in the `dataset_dir` variable; the image files should sit directly in that folder, not in subfolders)

In [None]:
import os
import time
import math
import torch
import lpips
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import ViTForImageClassification, ViTImageProcessor
from collections import deque

# --- Configuration (Adjust as needed) ---
# `model_name` is the local folder the pretrained ViT and its image processor
# are loaded from; `dataset_dir` is the folder scanned for training images.
model_name = './TransformInverse'
dataset_dir = './TransformInverse/dataset_OID2018'
batch_size = 16 # Reduce if memory errors occur with decoder
# NOTE(review): `learning_rate_q` and `epochs_q` are not referenced by the code
# visible in this notebook; they look like leftovers of an earlier training phase.
learning_rate_q = 1e-4
learning_rate_q_prime = 1e-4 # May need adjustment for pixel reconstruction
epochs_q = 5
epochs_q_prime = 20 # Might need more epochs for pixel reconstruction
num_workers = 2
seed = 42
val_split = 0.15
test_split = 0.15
print_every = 500   # Print visualization results every N batches while training

# Prefer the GPU when one is available and seed every RNG that PyTorch uses.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

print("Loading forward model...")
try:
    forward_model = ViTForImageClassification.from_pretrained(model_name)
    processor = ViTImageProcessor.from_pretrained(model_name)
    config = forward_model.config # Get config for decoder
except Exception as e:
    # Report the friendly message, then re-raise: `exit()` would shut down a
    # notebook kernel and discard the traceback needed to diagnose the failure.
    print(f"Error loading model/processor from {model_name}: {e}")
    raise

# The forward ViT is only used as a fixed feature extractor, so keep it in
# evaluation mode on the selected device.
forward_model.to(device)
forward_model.eval()

### Step 2: Define network architecture

Self-attention is used by default. To switch to cross-attention, uncomment the `context`-based lines shown below (and comment out the corresponding `x`-based lines):
```python
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
# q = self.q(context).reshape(B_ctx, N_ctx, self.num_heads, C_ctx // self.num_heads).permute(0, 2, 1, 3)
k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
# k = self.k(context).reshape(B_ctx, N_ctx, self.num_heads, C_ctx // self.num_heads).permute(0, 2, 1, 3)
v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
# v = self.v(context).reshape(B_ctx, N_ctx, self.num_heads, C_ctx // self.num_heads).permute(0, 2, 1, 3)
```

In [None]:
class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Width of the input features.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when omitted (or falsy).
        out_features: Width of the output; falls back to ``in_features``
            when omitted (or falsy).
        act_layer: Callable returning the activation module.
        bias: Whether both linear layers carry a bias term.
        drop: Dropout probability applied after each linear stage.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features
        # Submodules are registered in the order they are applied.
        self.fc1 = nn.Linear(in_features, hidden_width, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_width, output_width, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        """Run ``x`` through the two linear stages and return the result."""
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))

class InverseAttention(nn.Module):
    """Multi-head attention layer used inside the inverse transformer blocks.

    In its default form the layer attends over ``x`` only; ``context`` is
    checked for a compatible shape but otherwise unused. The commented-out
    lines in ``forward`` are the alternative projections that read from
    ``context`` instead.

    Args:
        dim: Embedding width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the q/k/v projections carry a bias term.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability on the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        # Standard 1/sqrt(head_dim) scaling of the dot products.
        self.scale = (dim // num_heads) ** -0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, context):
        """Attend over ``x`` ([B, N, C]) and return a tensor of the same shape.

        ``context`` must be [B, N_ctx, C] with the same batch size and width
        as ``x``.
        """
        B, N, C = x.shape
        B_ctx, N_ctx, C_ctx = context.shape
        assert B == B_ctx and C == C_ctx

        heads = self.num_heads

        def split_heads(tokens, length):
            # [B, length, C] -> [B, heads, length, C // heads]
            return tokens.reshape(B, length, heads, C // heads).permute(0, 2, 1, 3)

        q = split_heads(self.q(x), N)
        # q = self.q(context).reshape(B_ctx, N_ctx, self.num_heads, C_ctx // self.num_heads).permute(0, 2, 1, 3)
        k = split_heads(self.k(x), N)
        # k = self.k(context).reshape(B_ctx, N_ctx, self.num_heads, C_ctx // self.num_heads).permute(0, 2, 1, 3)
        v = split_heads(self.v(x), N)
        # v = self.v(context).reshape(B_ctx, N_ctx, self.num_heads, C_ctx // self.num_heads).permute(0, 2, 1, 3)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(torch.softmax(scores, dim=-1))
        # Merge the heads back into a single [B, N, C] tensor.
        merged = torch.matmul(weights, v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(merged))

class InverseBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual add.

    Args:
        dim: Embedding width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Forwarded to the attention layer's q/k/v projections.
        drop: Dropout probability for the attention output projection and
            for the MLP.
        attn_drop: Dropout probability on the attention weights.
        act_layer: Callable returning the MLP activation module.
        norm_layer: Callable returning a normalization module for ``dim``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = InverseAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, context):
        """Apply the attention and MLP sub-layers; ``context`` goes to the attention layer."""
        attended = self.attn(self.norm1(x), context)
        x = x + attended
        return x + self.mlp(self.norm2(x))

class PixelDecoder(nn.Module):
    """Project patch embeddings to pixels and reassemble them into an image.

    Each token is linearly mapped to ``num_channels * patch_h * patch_w``
    values, and the non-overlapping patches are tiled back into a full image
    with ``nn.Fold``.

    Args:
        config: Object exposing ``hidden_size``, ``patch_size``,
            ``num_channels`` and ``image_size``. ``image_size`` and
            ``patch_size`` may each be a single int (square) or an
            ``(height, width)`` pair.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.patch_size = config.patch_size
        self.num_channels = config.num_channels
        # Accept either a single int (square) or an explicit (H, W) pair.
        self.image_size_h, self.image_size_w = self._as_pair(config.image_size)
        patch_h, patch_w = self._as_pair(config.patch_size)
        self.num_patches = (self.image_size_h // patch_h) * (self.image_size_w // patch_w)
        self.proj = nn.Linear(self.hidden_size, self.num_channels * patch_h * patch_w)
        # Stride equal to the kernel size places patches side by side without overlap.
        self.unpatchify = nn.Fold(output_size=(self.image_size_h, self.image_size_w),
                                  kernel_size=(patch_h, patch_w),
                                  stride=(patch_h, patch_w))

    @staticmethod
    def _as_pair(value):
        """Return ``value`` as an ``(h, w)`` pair, duplicating a scalar."""
        if isinstance(value, (tuple, list)):
            height, width = value
            return int(height), int(width)
        return value, value

    def forward(self, x):
        """Decode embeddings into an image.

        Args:
            x (torch.Tensor): Reconstructed embeddings. Shape [B, N, D].
                              N = num_patches, or num_patches + 1 when a
                              leading CLS token is included (it is dropped).

        Returns:
            torch.Tensor: Image tensor of shape [B, C, H, W].

        Raises:
            ValueError: If N is neither num_patches nor num_patches + 1.
        """
        # Assume the first token is CLS if N = num_patches + 1, discard it
        if x.shape[1] == self.num_patches + 1:
            x = x[:, 1:, :] # Shape: [B, num_patches, D]
        elif x.shape[1] != self.num_patches:
            raise ValueError(f"Input embedding sequence length ({x.shape[1]}) doesn't match expected number of patches ({self.num_patches}) or patches+1.")

        x = self.proj(x) # Shape: [B, num_patches, C * patch_h * patch_w]

        # Reshape for Fold: Fold expects [B, C * kH * kW, L] where L is num_patches
        x = x.transpose(1, 2) # Shape: [B, C * patch_h * patch_w, num_patches]

        # Unpatchify
        reconstructed_image = self.unpatchify(x) # Shape: [B, C, H, W]

        return reconstructed_image


class InverseViT(nn.Module):
    """Inverse Vision Transformer Network with Pixel Decoder.

    A stack of ``InverseBlock`` layers, sized from the forward model's config
    (layer count, hidden size, head count, MLP ratio), followed by a
    ``PixelDecoder`` that turns the resulting token embeddings into an image.

    Args:
        forward_model_config: Config of the forward ViT. Must expose
            ``num_hidden_layers``, ``hidden_size`` and ``num_attention_heads``
            plus the attributes read by ``PixelDecoder``;
            ``intermediate_size`` is optional and defaults to
            ``4 * hidden_size``.
        norm_layer: Callable returning a normalization module.
        act_layer: Callable returning the MLP activation module.
    """
    def __init__(self, forward_model_config, norm_layer=nn.LayerNorm, act_layer=nn.GELU):
        super().__init__()
        self.config = forward_model_config # Store config for decoder
        self.num_layers = forward_model_config.num_hidden_layers
        self.hidden_dim = forward_model_config.hidden_size
        self.num_heads = forward_model_config.num_attention_heads
        # Express the MLP width as a ratio because InverseBlock takes a ratio.
        self.mlp_ratio = getattr(forward_model_config, 'intermediate_size', self.hidden_dim * 4) / self.hidden_dim
        self.norm_start = norm_layer(self.hidden_dim)
        self.blocks = nn.ModuleList([
            InverseBlock(
                dim=self.hidden_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio,
                qkv_bias=True, attn_drop=0., drop=0., norm_layer=norm_layer, act_layer=act_layer
            ) for _ in range(self.num_layers)])
        self.norm_end = norm_layer(self.hidden_dim)
        self.decoder = PixelDecoder(self.config)
        # Re-initialise every Linear / LayerNorm in the network, decoder included.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform weights and zero bias for Linear; identity LayerNorm."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None: nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, final_features):
        """Reconstruct an image from the forward model's final features.

        Args:
            final_features (torch.Tensor): Output of the last encoder layer of the forward ViT. [B, N, D]

        Returns:
            Reconstructed image tensor [B, C, H, W].
        """
        # The un-normalised features are handed to every block as `context`.
        context = final_features
        x = self.norm_start(final_features)

        for block in self.blocks:
            x = block(x, context=context)

        # Output after blocks is the reconstructed embedding
        reconstructed_embedding = self.norm_end(x) # Shape: [B, N, D]

        return self.decoder(reconstructed_embedding) # Shape: [B, C, H, W]
    
class ReconstructionLoss(nn.Module):
    """Pixel reconstruction loss: weighted L1 + MSE + weighted LPIPS.

    ``total = lambda_l1 * L1 + MSE + lambda_perceptual * LPIPS``

    Args:
        lambda_l1: Weight of the L1 term.
        lambda_perceptual: Weight of the LPIPS (VGG) perceptual term.

    NOTE(review): the LPIPS network is called on the tensors as given; check
    that the image normalization matches the input range LPIPS expects.
    """

    def __init__(self, lambda_l1=1.0, lambda_perceptual=0.1):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()
        self.perceptual_loss = lpips.LPIPS(net='vgg')
        self.lambda_l1 = lambda_l1
        self.lambda_perceptual = lambda_perceptual

    def forward(self, reconstructed, original):
        """Compute the loss terms for a batch.

        Args:
            reconstructed: Predicted images.
            original: Target images with the same shape as ``reconstructed``.

        Returns:
            dict: ``'total'`` (the weighted sum to back-propagate) plus the
            unweighted ``'l1'``, ``'mse'`` and ``'perceptual'`` terms.
        """
        l1 = self.l1_loss(reconstructed, original)
        mse = self.mse_loss(reconstructed, original)
        # LPIPS returns one value per sample; average them over the batch.
        perceptual = self.perceptual_loss(reconstructed, original).mean()
        # `lambda_l1` was previously stored but never applied to the L1 term.
        total_loss = self.lambda_l1 * l1 + mse + self.lambda_perceptual * perceptual

        return {
            'total': total_loss,
            'l1': l1,
            'mse': mse,
            'perceptual': perceptual
        }

### Step 3: Data preparation

In [None]:
@torch.no_grad()
def get_forward_outputs(forward_model, inputs, device):
    """Run the forward model and collect its per-layer hidden states.

    Args:
        forward_model: Model called as ``model(inputs, output_hidden_states=True)``
            whose result exposes ``hidden_states`` and whose ``config`` exposes
            ``num_hidden_layers``.
        inputs: Input batch; moved to ``device`` before the call.
        device: Device the inputs are moved to.

    Returns:
        tuple: ``(first hidden state, tuple of all hidden states, last hidden
        state)``, every tensor detached.

    Raises:
        ValueError: If ``hidden_states`` is not a tuple with
            ``num_hidden_layers + 1`` entries.
    """
    forward_model.eval()
    model_out = forward_model(inputs.to(device), output_hidden_states=True)
    states = model_out.hidden_states

    # One entry for the embedding output plus one per encoder layer.
    expected_count = forward_model.config.num_hidden_layers + 1
    if not (isinstance(states, tuple) and len(states) == expected_count):
        raise ValueError("Failed to retrieve hidden states correctly.")

    first_state = states[0].detach()
    all_states = tuple(state.detach() for state in states)
    last_state = states[-1].detach()
    return first_state, all_states, last_state

In [None]:
# The forward model's config gives a single side length, used for both axes.
image_height = image_width = config.image_size

# Preprocessing: resize to the model resolution, convert to a tensor, then
# normalize per channel.
# NOTE(review): mean/std are read from `config`, falling back to 0.5 when the
# config does not define `image_mean` / `image_std` — confirm this matches the
# values the forward model was trained with.
transform = transforms.Compose(
    [
        transforms.Resize(size=(image_height, image_width)),
        transforms.ToTensor(),
        transforms.Normalize(
            getattr(config, 'image_mean', [0.5, 0.5, 0.5]),
            getattr(config, 'image_std', [0.5, 0.5, 0.5]),
        ),
    ]
)

class ImageDataset(Dataset):
    """Dataset over the image files stored directly inside one folder.

    Files are selected by extension (case-insensitive); sub-directories are
    ignored. The file list is sorted so that indices, and therefore any
    seeded ``random_split``, are reproducible regardless of the order in
    which the filesystem lists the directory.

    Args:
        root_dir: Folder containing the images.
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        supported_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff')
        try:
            entries = os.listdir(root_dir)
        except FileNotFoundError:
            # A missing folder yields an empty dataset rather than an exception.
            print(f"Error: Dataset directory not found: {root_dir}")
            self.image_files = []
        else:
            # os.listdir returns entries in arbitrary order; sort for determinism.
            self.image_files = sorted(
                f for f in entries
                if os.path.isfile(os.path.join(root_dir, f)) and f.lower().endswith(supported_extensions)
            )
            if not self.image_files:
                print(f"Warning: No images found in {root_dir}")
        print(f"Found {len(self.image_files)} images in {root_dir}")

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply the transform.

        Returns:
            The transformed image, or an all-zero tensor of shape
            ``[C, H, W]`` if the file cannot be loaded or transformed.

        Raises:
            IndexError: If ``idx`` is past the end of the file list.
        """
        if idx >= len(self.image_files):
            raise IndexError("Index out of bounds")
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        try:
            # The context manager closes the underlying file handle promptly;
            # `convert` returns an independent, fully loaded image.
            with Image.open(img_name) as img:
                image = img.convert('RGB')
            if self.transform is not None:
                image = self.transform(image)
            return image # Return the transformed image tensor
        except Exception as e:
            # Deliberate best-effort: one unreadable file must not abort training.
            print(f"Warning: Error loading image {img_name}: {e}. Returning dummy.")
            # Ensure dummy has correct shape [C, H, W]
            return torch.zeros((config.num_channels, image_height, image_width))

# --- Build the dataset and fail fast when nothing was found ---
dataset = ImageDataset(root_dir=dataset_dir, transform=transform)
if len(dataset) == 0:
    raise ValueError(f"Dataset empty. Check path: {dataset_dir}")

# --- Split sizes: validation and test are fractions, training gets the rest ---
val_count = int(val_split * len(dataset))
test_count = int(test_split * len(dataset))
train_count = len(dataset) - val_count - test_count
if train_count <= 0 or val_count <= 0:
    # Too few images for the configured fractions: keep at least one
    # validation sample and drop the test split altogether.
    print("Warning: Dataset small, adjusting splits.")
    val_count = max(1, int(0.1 * len(dataset)))
    train_count = len(dataset) - val_count
    test_count = 0

print(f"Splits - Train: {train_count}, Val: {val_count}, Test: {test_count}")
# A dedicated seeded generator keeps the split independent of the global RNG state.
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_count, val_count, test_count],
    generator=torch.Generator().manual_seed(seed),
)

# Only the training loader shuffles and drops the last incomplete batch.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
)
# --- Initialize Inverse Model (includes PixelDecoder) ---
print("Initializing inverse model with pixel decoder...")
inverse_model = InverseViT(config) # Pass config now
inverse_model.to(device)
print(f"Inverse model params: {sum(p.numel() for p in inverse_model.parameters() if p.requires_grad)}")

### Step 4: Helper functions for training and evaluation

In [None]:
def denormalize(tensor, mean, std):
    """Undo a per-channel ``(x - mean) / std`` normalization.

    Args:
        tensor (torch.Tensor): Batch of normalized images, shape (N, C, H, W).
        mean (list or tuple): Per-channel mean that was subtracted.
        std (list or tuple): Per-channel standard deviation that was divided by.

    Returns:
        torch.Tensor: ``tensor * std + mean``, broadcast over the channel axis.
    """
    target_device = tensor.device
    # Shape (1, C, 1, 1) so the statistics broadcast along the channel axis.
    scale = torch.tensor(std, device=target_device).reshape(1, -1, 1, 1)
    offset = torch.tensor(mean, device=target_device).reshape(1, -1, 1, 1)
    restored = tensor * scale
    return restored + offset

def visualize_reconstruction(
    forward_model,
    inverse_model,
    test_loader,
    device,
    pixel_loss_history,
    current_batch_idx,
    mean,
    std,
    num_images=5,
    loss_history_length=1000
    ):
    """Plot originals, their reconstructions, and the recent loss curve.

    Takes the first batch from ``test_loader``, reconstructs its first
    ``num_images`` images through ``forward_model`` -> ``inverse_model``, and
    shows a three-row figure: originals, reconstructions, and a plot of
    ``pixel_loss_history``. Both models are put in eval mode for the duration
    and restored to their previous train/eval mode afterwards, including on
    error. All errors are reported by printing; nothing is raised.

    Args:
        forward_model: Forward ViT passed to ``get_forward_outputs``.
        inverse_model: Model mapping the final features to images.
        test_loader: Iterable yielding image batches (or ``(image, ...)``
            sequences, of which the first element is used).
        device: Device the images are moved to.
        pixel_loss_history: Iterable of recent per-batch loss values, oldest first.
        current_batch_idx: Index of the batch that produced the newest loss;
            used for the titles and the x-axis of the loss plot.
        mean: Per-channel mean used to denormalize images for display.
        std: Per-channel std used to denormalize images for display.
        num_images: Maximum number of images to show.
        loss_history_length: Nominal history length; only used in the plot title.
    """
    print(f"Generating visualization for batch {current_batch_idx}...")

    # --- Store original modes and set to eval ---
    inv_mode_orig = inverse_model.training
    fwd_mode_orig = forward_model.training
    inverse_model.eval()
    forward_model.eval()

    try:
        try:
            test_images = next(iter(test_loader))
            if isinstance(test_images, (list, tuple)):
                 test_images = test_images[0] # If dataset returns (image, label)

            if test_images.shape[0] < num_images:
                 print(f"Warning: Test batch size ({test_images.shape[0]}) is smaller than num_images ({num_images}).")
                 num_images = test_images.shape[0]

            test_images = test_images[:num_images].to(device) # Select first N images

        except StopIteration:
            # Empty loader: nothing to show (the `finally` below still runs).
            print("Error: Cannot get data from test_loader")
            return
        except Exception as e:
            print(f"Error getting data from test_loader: {e}")
            return

        with torch.no_grad():
            _, _, final_features = get_forward_outputs(forward_model, test_images, device)
            reconstructed_images = inverse_model(final_features)

        original_images_denorm = denormalize(test_images, mean, std)
        reconstructed_images_denorm = denormalize(reconstructed_images, mean, std)

        # Clamp to the displayable [0, 1] range before handing over to matplotlib.
        original_images_np = original_images_denorm.clamp(0, 1).cpu().numpy()
        reconstructed_images_np = reconstructed_images_denorm.clamp(0, 1).cpu().numpy()

        # Transpose from (N, C, H, W) to (N, H, W, C) for matplotlib
        if original_images_np.shape[1] == 3: # RGB
            original_images_np = np.transpose(original_images_np, (0, 2, 3, 1))
            reconstructed_images_np = np.transpose(reconstructed_images_np, (0, 2, 3, 1))
        elif original_images_np.shape[1] == 1: # Grayscale
            original_images_np = np.squeeze(original_images_np, axis=1)
            reconstructed_images_np = np.squeeze(reconstructed_images_np, axis=1)

        fig = plt.figure(figsize=(max(15, num_images * 3), 11)) # Adjust width based on num_images
        gs = gridspec.GridSpec(3, num_images, height_ratios=[1, 1, 1.5], wspace=0.1, hspace=0.3)

        for i in range(num_images):
            ax_orig = fig.add_subplot(gs[0, i])
            ax_orig.imshow(original_images_np[i], cmap='gray' if len(original_images_np[i].shape) == 2 else None)
            ax_orig.set_title(f'Original {i+1}')
            ax_orig.axis('off')

            ax_rec = fig.add_subplot(gs[1, i])
            ax_rec.imshow(reconstructed_images_np[i], cmap='gray' if len(reconstructed_images_np[i].shape) == 2 else None)
            ax_rec.set_title(f'Reconstructed {i+1}')
            ax_rec.axis('off')

        ax_loss = fig.add_subplot(gs[2, :]) # Span across the bottom row

        loss_list = list(pixel_loss_history)
        num_losses = len(loss_list)

        if num_losses > 0:
             # Label the newest loss with `current_batch_idx` and count backwards.
             # If the history is longer than the batch index allows (the labels
             # would go negative), fall back to plain 0..num_losses-1 positions.
             first_batch_idx = current_batch_idx - num_losses + 1
             if first_batch_idx >= 0:
                  batch_indices = list(range(first_batch_idx, current_batch_idx + 1))
             else:
                  print(f"Warning: Adjusting batch indices for plotting. Current batch: {current_batch_idx}, Losses count: {num_losses}")
                  batch_indices = list(range(num_losses))

             ax_loss.plot(batch_indices, loss_list, label='Loss', color='orange')
             ax_loss.set_xlabel('Batch Index')
             ax_loss.set_ylabel('Loss')
             ax_loss.set_title(f'Recent Pixel Loss History (Last {min(num_losses, loss_history_length)} Batches up to Batch {current_batch_idx})')
             ax_loss.legend()
             ax_loss.grid(True, alpha = 0.4)
        else:
            ax_loss.set_title('Loss History (No batch data yet)')
            ax_loss.text(0.5, 0.5, 'Waiting for training data...', ha='center', va='center', transform=ax_loss.transAxes)

        # Leave head-room at the top of the figure for the suptitle.
        plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout
        fig.suptitle(f'Reconstruction Visualization at Batch {current_batch_idx} ({time.strftime("%Y-%m-%d %H:%M:%S")})', fontsize=16)
        plt.show()

    except Exception as e:
        # Visualization is auxiliary: report the problem but never abort training.
        print(f"An error occurred during visualization: {e}")
        import traceback
        traceback.print_exc()

    finally:
        # Restore whatever train/eval mode the models were in on entry.
        inverse_model.train(inv_mode_orig)
        forward_model.train(fwd_mode_orig)

### Step 5: Training loop

Uncomment the line in the following cell to restore weights from a previous run.

In [None]:
# inverse_model.load_state_dict(torch.load('TransformInverse/model_0414b.pth'))

In [None]:
loss_fn_pixel = ReconstructionLoss(lambda_l1=1.0, lambda_perceptual=0.15).to(device)  # For Phase 2

optimizer = optim.AdamW(inverse_model.parameters(), lr=learning_rate_q_prime)
# The scheduler's `verbose` flag is deprecated in recent PyTorch releases, so
# learning-rate reductions are reported explicitly after each scheduler step.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", factor=0.2, patience=3)
# Rolling window of the most recent per-batch losses; it spans epoch boundaries.
loss_records = deque(maxlen=1000)

for epoch in range(epochs_q_prime):
    epoch_start_time = time.time()
    inverse_model.train()
    total_train_loss_q_prime = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs_q_prime}", leave=False)
    # Training
    for batch_idx, original_images in enumerate(progress_bar):
        original_images = original_images.to(device)

        # Features come from the forward ViT; no gradients flow into it.
        _, _, final_features = get_forward_outputs(forward_model, original_images, device)

        optimizer.zero_grad()
        reconstructed_images = inverse_model(final_features)

        loss = loss_fn_pixel(reconstructed_images, original_images)['total']

        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        total_train_loss_q_prime += batch_loss
        progress_bar.set_postfix(
            {"Loss(pixel)": f"{batch_loss:.4f}", "LR": f"{optimizer.param_groups[0]['lr']:.1e}"}
        )
        loss_records.append(batch_loss)
        if batch_idx % print_every == 0:
            # `loss_records` keeps growing across epochs, so give the plot a
            # batch counter that does too; a per-epoch index restarts at 0 and
            # no longer lines up with the recorded history.
            global_batch_idx = epoch * len(train_loader) + batch_idx
            visualize_reconstruction(
                forward_model=forward_model,
                inverse_model=inverse_model,
                test_loader=test_loader,
                device=device,
                pixel_loss_history=loss_records,
                current_batch_idx=global_batch_idx,
                # Same lookup as the input transform so display matches preprocessing.
                mean=getattr(config, 'image_mean', [0.5, 0.5, 0.5]),
                std=getattr(config, 'image_std', [0.5, 0.5, 0.5]),
                num_images=5,
                loss_history_length=1000,
            )

    avg_train_loss_q_prime = total_train_loss_q_prime / len(train_loader) if len(train_loader) > 0 else 0

    # Validation
    inverse_model.eval()
    total_val_loss_q_prime = 0.0
    val_progress_bar = tqdm(val_loader, desc="Validation", leave=False)
    with torch.no_grad():
        for original_images in val_progress_bar:
            original_images = original_images.to(device)
            _, _, final_features = get_forward_outputs(forward_model, original_images, device)
            reconstructed_images = inverse_model(final_features)
            # Use pixel loss for validation
            total_val_loss_q_prime += loss_fn_pixel(reconstructed_images, original_images)['total'].item()

    avg_val_loss_q_prime = total_val_loss_q_prime / len(val_loader) if len(val_loader) > 0 else 0
    epoch_time = time.time() - epoch_start_time
    print(
        f"Epoch {epoch+1}/{epochs_q_prime} - Time: {epoch_time:.2f}s - Train Loss: {avg_train_loss_q_prime:.4f} - Val Loss: {avg_val_loss_q_prime:.4f}"
    )
    # Step the plateau scheduler on the validation loss and report any LR change.
    lr_before_step = optimizer.param_groups[0]['lr']
    scheduler.step(avg_val_loss_q_prime)
    lr_after_step = optimizer.param_groups[0]['lr']
    if lr_after_step != lr_before_step:
        print(f"Learning rate reduced: {lr_before_step:.1e} -> {lr_after_step:.1e}")

print("\nTraining finished.")

### Step 6: Save weights or test the model

In [None]:
# torch.save(inverse_model.state_dict(), 'TransformInverse/model_0414c(Blackbox-Self).pth')

In [None]:
from sklearn.metrics import mean_squared_error, mean_absolute_error
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim

# Evaluate reconstruction quality on the held-out test split. Metrics are
# computed per image on denormalized, [0, 1]-clamped pixels, then averaged.
inverse_model.eval()
forward_model.eval()

mse_list = []
mae_list = []
psnr_list = []
ssim_list = []


# Per-channel statistics shaped (1, C, 1, 1) so they broadcast over a batch.
mean = torch.tensor(getattr(config, 'image_mean', [0.5, 0.5, 0.5]), device=device).view(1, -1, 1, 1)
std = torch.tensor(getattr(config, 'image_std', [0.5, 0.5, 0.5]), device=device).view(1, -1, 1, 1)

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating on test set"):
        original_images = batch.to(device)
        _, _, final_features = get_forward_outputs(forward_model, original_images, device)
        reconstructed_images = inverse_model(final_features)

        # Denormalize
        original_images_denorm = original_images * std + mean
        reconstructed_images_denorm = reconstructed_images * std + mean

        original_images_denorm = original_images_denorm.clamp(0, 1).cpu().numpy()
        reconstructed_images_denorm = reconstructed_images_denorm.clamp(0, 1).cpu().numpy()

        for orig, rec in zip(original_images_denorm, reconstructed_images_denorm):
            # (C, H, W) -> (H, W, C)
            orig_img = np.transpose(orig, (1, 2, 0))
            rec_img = np.transpose(rec, (1, 2, 0))

            mse = mean_squared_error(orig_img.flatten(), rec_img.flatten())
            mae = mean_absolute_error(orig_img.flatten(), rec_img.flatten())
            psnr = compare_psnr(orig_img, rec_img, data_range=1.0)
            # `channel_axis` replaces the `multichannel` flag that current
            # scikit-image releases no longer accept; channels are last here.
            ssim = compare_ssim(orig_img, rec_img, channel_axis=-1, data_range=1.0)

            mse_list.append(mse)
            mae_list.append(mae)
            psnr_list.append(psnr)
            ssim_list.append(ssim)

if mse_list:
    print(f"Test set results:")
    print(f"  MSE:  {np.mean(mse_list):.6f}")
    print(f"  MAE:  {np.mean(mae_list):.6f}")
    print(f"  PSNR: {np.mean(psnr_list):.2f}")
    print(f"  SSIM: {np.mean(ssim_list):.4f}")
else:
    # The split logic can produce an empty test set for very small datasets;
    # averaging empty lists would only print NaN together with a warning.
    print("Test set is empty; no metrics were computed.")