In [1]:
## Generating 1000 samples per class

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image, ImageEnhance, ImageFilter
import os
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from scipy.spatial.distance import jensenshannon
from scipy import stats
import kagglehub

# Hyperparameters
RESIZE = 128              # side length, in pixels, of the square images
intermediate_dim = 512    # width of the fully connected hidden layers
latent_dim = 256          # dimensionality of the latent vector z
num_classes = 5           # number of vehicle classes used for conditioning

# Device selection: prefer the second GPU when more than one is visible,
# otherwise the default GPU, and fall back to the CPU when CUDA is unavailable
# (the previous expression selected 'cuda' even on CPU-only machines).
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda')
else:
    device = torch.device('cpu')

# Print GPU information
print(f"Using device: {device}")
if device.type == 'cuda':
    # Query the device that was actually selected instead of always GPU 0.
    print(f"GPU Name: {torch.cuda.get_device_name(device)}")
    print(f"GPU Memory Allocated: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MB")

# Custom Dataset for Kaggle Vehicle Type Image Dataset (for reference images)
class VehicleTypeDataset(Dataset):
    """Image-folder dataset for the Kaggle vehicle-type images.

    Every directory below ``root_dir`` that directly contains at least one
    ``.jpg`` / ``.png`` / ``.jpeg`` file is treated as a class whose name is the
    directory's basename. Files that fail ``PIL.Image.verify`` are skipped with
    a message.

    Args:
        root_dir: Directory that is walked recursively.
        transform: Optional callable applied to the RGB ``PIL.Image`` of each
            sample.

    Raises:
        ValueError: If the number of discovered classes differs from the
            module-level ``num_classes``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = []
        self.class_to_idx = {}
        self.images = []
        self.labels = []

        for root, dirs, files in os.walk(root_dir):
            # os.walk yields entries in an OS-dependent order. Sorting in place
            # (case-insensitively) makes the class indices and the sample order
            # identical on every platform.
            dirs.sort(key=str.lower)
            image_files = sorted(
                (f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))),
                key=str.lower,
            )
            if not image_files:
                continue

            class_name = os.path.basename(root)
            if class_name not in self.class_to_idx:
                self.classes.append(class_name)
                self.class_to_idx[class_name] = len(self.classes) - 1

            for img_file in image_files:
                img_path = os.path.join(root, img_file)
                try:
                    # The context manager closes the file handle after the check.
                    with Image.open(img_path) as img:
                        img.verify()
                except Exception:
                    # `Exception` (not a bare except) so Ctrl-C still interrupts.
                    print(f"Skipping corrupted image: {img_path}")
                    continue
                self.images.append(img_path)
                self.labels.append(self.class_to_idx[class_name])

        if len(self.classes) != num_classes:
            raise ValueError(f"Expected {num_classes} classes, found {len(self.classes)}")
        print(f"Found {len(self.images)} images across {len(self.classes)} classes.")
        print(f"Classes: {self.classes}")

    def __len__(self):
        """Return the number of images that passed verification."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'label': int}`` for the sample at ``idx``.

        The image is loaded as RGB; when a transform was supplied, ``'image'``
        holds the transform's output instead of the PIL image.
        """
        img_path = self.images[idx]
        # convert() returns a fully loaded copy, so the file can be closed here.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample

# Enhanced Image Grid for Plotting
def image_grid(imgs, rows, cols, class_label, height=RESIZE, width=RESIZE):
    """Lay out ``imgs`` row by row on a white canvas and return the canvas.

    Args:
        imgs: Sequence of PIL images; at most ``rows * cols`` entries.
        rows: Number of grid rows.
        cols: Number of grid columns.
        class_label: Accepted for call-site compatibility; not used here.
        height: Height of one grid cell in pixels.
        width: Width of one grid cell in pixels.

    Returns:
        An RGB ``PIL.Image`` of size ``(cols * width, rows * height)``.
    """
    capacity = rows * cols
    assert len(imgs) <= capacity
    canvas = Image.new('RGB', size=(cols * width, rows * height), color=(255, 255, 255))
    for position, tile in enumerate(imgs):
        row, col = divmod(position, cols)
        # Each tile is shifted by 2 px from the top-left corner of its cell.
        canvas.paste(tile, box=(col * width + 2, row * height + 2))
    return canvas

# Weight initialization function
def init_weights(m):
    """Initialise the parameters of a single module; meant for ``Module.apply``.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: Kaiming-normal weights
      (``fan_out``, ReLU gain), zero bias.
    * ``nn.Linear``: Xavier-normal weights, zero bias.
    * Any other module is left untouched.

    Layers created with ``bias=False`` are supported in both branches.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        # Guard added: a Linear built with bias=False has m.bias is None and
        # would otherwise crash inside nn.init.constant_.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# Encoder class
class Encoder(nn.Module):
    """Convolutional encoder of the conditional VAE.

    Seven Conv2d -> BatchNorm2d -> ReLU stages (six with stride 2, one with
    stride 1) are followed by a fully connected layer that also receives the
    class vector ``y``, and by two heads producing the latent mean and
    log-variance.

    Args:
        input_size: Accepted for interface compatibility; not used.
        hidden_dim: Width of the fully connected hidden layer.
        latent_dim: Size of the latent mean / log-variance vectors.
        num_classes: Length of the conditioning vector concatenated to the
            flattened convolutional features.
    """

    def __init__(self, input_size, hidden_dim, latent_dim, num_classes):
        super().__init__()
        # (in_channels, out_channels, stride) of each convolutional stage.
        stage_specs = (
            (3, 64, 2),
            (64, 128, 2),
            (128, 256, 2),
            (256, 512, 2),
            (512, 512, 2),
            (512, 1024, 2),
            (1024, 1024, 1),
        )
        stages = []
        for c_in, c_out, stride in stage_specs:
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=5, stride=stride, padding=2),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ])
        self.encoder_cnn = nn.Sequential(*stages)

        # The flattened feature size is fixed at 1024 * 2 * 2; six stride-2
        # stages reduce a 128x128 input to 2x2.
        flat_features = 1024 * 2 * 2
        self.fc_hidden = nn.Linear(flat_features + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.apply(init_weights)

    def forward(self, x, y):
        """Encode ``x`` conditioned on ``y``.

        Returns:
            ``(z_mean, z_logvar)``; the log-variance is clamped to [-10, 10].
        """
        features = self.encoder_cnn(x)
        flat = features.view(features.size(0), -1)
        hidden = F.relu(self.fc_hidden(torch.cat([flat, y], dim=-1)))
        hidden = self.dropout(hidden)
        z_mean = self.fc_mean(hidden)
        z_logvar = torch.clamp(self.fc_logvar(hidden), min=-10, max=10)
        return z_mean, z_logvar

# Decoder class
class Decoder(nn.Module):
    """Transposed-convolution decoder of the conditional VAE.

    The latent vector concatenated with the class vector passes through two
    fully connected layers, is reshaped to ``(N, 1024, 2, 2)`` and is then
    upsampled by one stride-1 and six stride-2 transposed convolutions. The
    output goes through ``tanh`` and is clamped to [-1, 1].

    Args:
        latent_dim: Size of the latent vector ``z``.
        hidden_dim: Width of the first fully connected layer.
        output_size: Accepted for interface compatibility; not used.
        num_classes: Length of the conditioning vector concatenated to ``z``.
    """

    def __init__(self, latent_dim, hidden_dim, output_size, num_classes):
        super().__init__()
        self.fc = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_to_cnn = nn.Linear(hidden_dim, 1024 * 2 * 2)

        # Stage 0 keeps the spatial size (stride 1).
        layers = [
            nn.ConvTranspose2d(1024, 1024, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
        ]
        # Each following stage doubles the spatial size (stride 2, output_padding 1).
        for c_in, c_out in ((1024, 512), (512, 512), (512, 256), (256, 128), (128, 64)):
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, kernel_size=5, stride=2, padding=2, output_padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ])
        # Output stage: 3 channels, tanh instead of BatchNorm + ReLU.
        layers.extend([
            nn.ConvTranspose2d(64, 3, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.Tanh(),
        ])
        self.decoder_cnn = nn.Sequential(*layers)
        self.apply(init_weights)

    def forward(self, z, y):
        """Decode latent ``z`` conditioned on ``y`` into an image tensor in [-1, 1]."""
        hidden = F.relu(self.fc(torch.cat([z, y], dim=-1)))
        hidden = self.dropout(hidden)
        seed_map = F.relu(self.fc_to_cnn(hidden)).view(-1, 1024, 2, 2)
        return torch.clamp(self.decoder_cnn(seed_map), min=-1, max=1)

# Conditional VAE class
class ConditionalVAE(nn.Module):
    """Conditional VAE that chains an encoder, latent sampling and a decoder.

    Args:
        encoder: Module called as ``encoder(data, y)`` returning
            ``(z_mean, z_logvar)``.
        decoder: Module called as ``decoder(z, y)`` returning the reconstruction.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        """Draw ``z = z_mean + eps * exp(0.5 * z_logvar)`` with ``eps ~ N(0, I)``."""
        sigma = torch.exp(0.5 * z_logvar)
        noise = torch.randn_like(sigma)
        return z_mean + noise * sigma

    def forward(self, data, y):
        """Encode ``data``, sample a latent vector and decode it.

        Returns:
            ``(x_reconstructed, z_mean, z_logvar)``.
        """
        z_mean, z_logvar = self.encoder(data, y)
        latent = self.reparameterize(z_mean, z_logvar)
        return self.decoder(latent, y), z_mean, z_logvar

# Function to enhance image quality
def enhance_image(pil_image):
    """Sharpen, boost contrast/brightness and resample a PIL image.

    Steps, in order: unsharp mask, sharpness x2.0, contrast x1.3,
    brightness x1.1, LANCZOS upscale to ``2 * RESIZE`` and LANCZOS downscale
    back to ``RESIZE``.

    Returns:
        A new ``RESIZE`` x ``RESIZE`` PIL image.
    """
    result = pil_image.filter(ImageFilter.UnsharpMask(radius=2, percent=150, threshold=3))
    adjustments = (
        (ImageEnhance.Sharpness, 2.0),
        (ImageEnhance.Contrast, 1.3),
        (ImageEnhance.Brightness, 1.1),
    )
    for enhancer_cls, factor in adjustments:
        result = enhancer_cls(result).enhance(factor)
    doubled = result.resize((RESIZE * 2, RESIZE * 2), Image.Resampling.LANCZOS)
    return doubled.resize((RESIZE, RESIZE), Image.Resampling.LANCZOS)

# Function to compute PSNR, SSIM, and JS divergence
def compute_metrics_for_image(orig, gen):
    """Compare a generated image tensor with a reference image tensor.

    Both tensors are moved to the CPU, mapped with ``x * 0.5 + 0.5``,
    transposed with ``(1, 2, 0)`` (channels last) and clipped to [0, 1] before
    the metrics are computed.

    Args:
        orig: Reference image tensor.
        gen: Generated image tensor with the same shape as ``orig``.
        (Both are presumably (C, H, W) tensors normalised to [-1, 1], as
        produced by the notebook's ``transform`` -- confirm for other callers.)

    Returns:
        ``(psnr_score, ssim_score, js_score)`` where ``js_score`` is the mean
        over channels of ``scipy.spatial.distance.jensenshannon`` applied to
        the flattened, sum-normalised pixel intensities. Note that SciPy's
        function returns the Jensen-Shannon *distance* (square root of the
        divergence). A NaN channel score is replaced by 1.0.
    """
    def _to_unit_hwc(tensor):
        # [-1, 1] -> [0, 1], channels last, clipped against rounding overshoot.
        array = tensor.cpu().numpy() * 0.5 + 0.5
        return np.clip(np.transpose(array, (1, 2, 0)), 0, 1)

    orig_np = _to_unit_hwc(orig)
    gen_np = _to_unit_hwc(gen)

    psnr_score = psnr(orig_np, gen_np, data_range=1.0)
    # `channel_axis` alone selects the colour axis; the deprecated
    # `multichannel=True` flag that used to be passed alongside it is dropped.
    ssim_score = ssim(orig_np, gen_np, data_range=1.0, channel_axis=2)

    eps = 1e-10  # keeps every probability strictly positive
    js_scores = []
    for c in range(orig_np.shape[2]):
        orig_channel = orig_np[:, :, c].flatten()
        gen_channel = gen_np[:, :, c].flatten()
        # Normalise each channel to sum to 1. The epsilon term in the
        # denominator uses the real pixel count rather than RESIZE * RESIZE so
        # the normalisation is exact for any image size.
        orig_channel = (orig_channel + eps) / (np.sum(orig_channel) + eps * orig_channel.size)
        gen_channel = (gen_channel + eps) / (np.sum(gen_channel) + eps * gen_channel.size)
        js_score = jensenshannon(orig_channel, gen_channel)
        if np.isnan(js_score):
            js_score = 1.0
        js_scores.append(js_score)
    js_score = np.mean(js_scores)

    return psnr_score, ssim_score, js_score

# Instantiate the model
input_size = (8, 3, RESIZE, RESIZE)
encoder = Encoder(input_size, intermediate_dim, latent_dim, num_classes).to(device)
decoder = Decoder(latent_dim, intermediate_dim, input_size, num_classes).to(device)
cvae = ConditionalVAE(encoder, decoder).to(device)

# Load the saved model weights
model_path = "cvae_vehicle_final.pth"
try:
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs, and avoids the torch.load FutureWarning.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    cvae.load_state_dict(state_dict)
    print(f"Loaded CVAE model from {model_path}")
except Exception as e:
    print(f"Error loading model: {str(e)}")
    # Re-raise rather than call exit(): in a notebook exit() shuts the kernel
    # down, whereas an exception simply stops execution at this cell.
    raise

# Load reference images for metrics computation
transform = transforms.Compose([
    transforms.Resize((RESIZE, RESIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
dataset = VehicleTypeDataset(root_dir=path, transform=transform)

# Collect reference images (5 per class)
REFS_PER_CLASS = 5
reference_images = {i: [] for i in range(num_classes)}
# Walk the label list instead of the dataset itself so that only the images
# that are actually kept get decoded and transformed. The selection (first
# REFS_PER_CLASS samples of each class in dataset order) is unchanged.
for sample_idx, label in enumerate(dataset.labels):
    if len(reference_images[label]) >= REFS_PER_CLASS:
        continue
    reference_images[label].append(dataset[sample_idx]['image'])
    if all(len(refs) >= REFS_PER_CLASS for refs in reference_images.values()):
        break

# Define class names
class_names = ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']
if list(dataset.classes) != class_names:
    # The labels used below index into class_names; flag a mismatch early.
    print(f"Warning: dataset class order {dataset.classes} differs from class_names {class_names}")

# Modified generate_samples_labelwise function to generate 1000 samples and compute statistics
def _decoded_tensor_to_pil(sample):
    """Convert one decoder output tensor into an RGB PIL image."""
    array = sample.cpu().detach().numpy()
    array = array * 0.5 + 0.5
    array = np.nan_to_num(array, nan=0.0, posinf=0.0, neginf=0.0)
    array = (255 * array).astype(np.uint8)
    array = np.transpose(array, (1, 2, 0))
    return Image.fromarray(array).convert('RGB')


def _print_metric_statistics(class_label, class_metrics):
    """Print mean, median and binned mode of PSNR / SSIM / JS for one class."""
    if not class_metrics:
        # Nothing was generated for this class; statistics would be undefined.
        print(f"\nClass {class_label} ({class_names[class_label]}): no samples, statistics skipped.")
        return

    psnr_values = np.array([metric['psnr'] for metric in class_metrics])
    ssim_values = np.array([metric['ssim'] for metric in class_metrics])
    js_values = np.array([metric['js'] for metric in class_metrics])

    # Mode is approximated by rounding the continuous values into bins.
    psnr_mode = stats.mode(np.round(psnr_values, 1), keepdims=False)[0]
    ssim_mode = stats.mode(np.round(ssim_values, 3), keepdims=False)[0]
    js_mode = stats.mode(np.round(js_values, 3), keepdims=False)[0]

    # The sample count is reported from the data instead of a literal 1000.
    print(f"\nClass {class_label} ({class_names[class_label]}) Metrics Statistics ({len(class_metrics)} samples):")
    print(f"PSNR - Mean: {np.mean(psnr_values):.2f}, Median: {np.median(psnr_values):.2f}, Mode: {psnr_mode:.2f}")
    print(f"SSIM - Mean: {np.mean(ssim_values):.4f}, Median: {np.median(ssim_values):.4f}, Mode: {ssim_mode:.4f}")
    print(f"JS   - Mean: {np.mean(js_values):.4f}, Median: {np.median(js_values):.4f}, Mode: {js_mode:.4f}")


def generate_samples_labelwise(cvae, num_samples_per_class, base_dir, latent_dim, device):
    """Generate, enhance, save and score samples for every class.

    For each class the decoder is fed truncated Gaussian latents (scaled by
    0.7 and clamped to [-2, 2]) in batches of 50. Every decoded image is
    enhanced with ``enhance_image``, written to
    ``<base_dir>/<class_label>/sample_<i>.png`` and compared with the class's
    entries in the module-level ``reference_images``; the per-reference PSNR,
    SSIM and JS scores are averaged. A 4 x 8 grid of the first 32 images is
    saved per class, and summary statistics are printed at the end.

    Relies on the module-level ``num_classes``, ``class_names``,
    ``reference_images`` and ``transform``.

    Args:
        cvae: Model exposing a ``decoder(z, one_hot_labels)`` attribute.
        num_samples_per_class: Number of images to generate for each class.
        base_dir: Output directory; created if missing.
        latent_dim: Size of the latent vectors that are sampled.
        device: Torch device on which latents and labels are placed.

    Returns:
        ``{class_label: [{'image', 'psnr', 'ssim', 'js'}, ...]}`` with one
        entry per generated image.
    """
    cvae.eval()
    os.makedirs(base_dir, exist_ok=True)

    # Initialize dictionary to store metrics
    all_metrics = {class_label: [] for class_label in range(num_classes)}

    batch_size = 50  # small batches keep GPU memory use low
    grid_rows, grid_cols = 4, 8
    grid_capacity = grid_rows * grid_cols

    with torch.no_grad():
        for class_label in range(num_classes):
            print(f"Generating samples for class {class_label} ({class_names[class_label]})...")
            label_tensor = torch.tensor([class_label]).repeat(num_samples_per_class).to(device)
            one_hot_labels = F.one_hot(label_tensor, num_classes=num_classes).float().to(device)

            class_dir = os.path.join(base_dir, str(class_label))
            os.makedirs(class_dir, exist_ok=True)
            ref_images = reference_images[class_label]

            # Must live outside the batch loop: when it was reset for every
            # batch, only the last batch's (empty) list survived and the grid
            # was never written.
            images_for_grid = []

            for start_idx in range(0, num_samples_per_class, batch_size):
                end_idx = min(start_idx + batch_size, num_samples_per_class)
                current_batch_size = end_idx - start_idx

                # Use truncation trick
                z = torch.randn(current_batch_size, latent_dim).to(device) * 0.7
                z = torch.clamp(z, -2.0, 2.0)
                generated_samples = cvae.decoder(z, one_hot_labels[start_idx:end_idx])

                for batch_idx, sample in enumerate(generated_samples):
                    global_idx = start_idx + batch_idx
                    file_name = f"sample_{global_idx}.png"

                    pil_image = enhance_image(_decoded_tensor_to_pil(sample))
                    pil_image.save(os.path.join(class_dir, file_name))
                    if global_idx < grid_capacity:
                        images_for_grid.append(pil_image)

                    # Metrics are computed on the enhanced image. The tensor
                    # stays on the CPU because compute_metrics_for_image
                    # converts both inputs to NumPy on the CPU anyway.
                    gen_tensor = transform(pil_image)

                    psnr_scores, ssim_scores, js_scores = [], [], []
                    for ref_tensor in ref_images:
                        psnr_score, ssim_score, js_score = compute_metrics_for_image(ref_tensor, gen_tensor)
                        psnr_scores.append(psnr_score)
                        ssim_scores.append(ssim_score)
                        js_scores.append(js_score)

                    all_metrics[class_label].append({
                        'image': file_name,
                        'psnr': np.mean(psnr_scores),
                        'ssim': np.mean(ssim_scores),
                        'js': np.mean(js_scores)
                    })

            # Save a grid of the first 32 images
            if images_for_grid:
                grid = image_grid(images_for_grid, rows=grid_rows, cols=grid_cols,
                                  class_label=class_names[class_label])
                grid.save(os.path.join(class_dir, f"grid_{class_names[class_label]}.png"))
            print(f"Generated {num_samples_per_class} samples for class {class_label}")

    # Compute and print mean, median, mode for each class
    for class_label in range(num_classes):
        _print_metric_statistics(class_label, all_metrics[class_label])

    return all_metrics

# Generate 1000 samples per class and compute metrics
NUM_SAMPLES_PER_CLASS = 1000  # single source for the call and the message below
base_dir = "generated_samples-arch-A5-1000-enhanced"
try:
    all_metrics = generate_samples_labelwise(
        cvae, num_samples_per_class=NUM_SAMPLES_PER_CLASS, base_dir=base_dir,
        latent_dim=latent_dim, device=device
    )
    print(f"Finished generating {NUM_SAMPLES_PER_CLASS} enhanced samples per class and computing metrics")
except Exception as e:
    # Best-effort: report the failure and let the notebook continue.
    print(f"Error during sample generation: {str(e)}")
finally:
    # Clear GPU memory. `finally` also covers a manual interrupt
    # (KeyboardInterrupt is not an Exception and would skip the cleanup).
    if device.type == 'cuda':
        torch.cuda.empty_cache()

Using device: cuda:1
GPU Name: NVIDIA RTX A5000
GPU Memory Allocated: 0.00 MB


  cvae.load_state_dict(torch.load(model_path, map_location=device))


Loaded CVAE model from cvae_vehicle_final.pth
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']
Generating samples for class 0 (Hatchback)...
Generated 1000 samples for class 0
Generating samples for class 1 (Other)...
Generated 1000 samples for class 1
Generating samples for class 2 (Pickup)...
Generated 1000 samples for class 2
Generating samples for class 3 (Seden)...
Generated 1000 samples for class 3
Generating samples for class 4 (SUV)...
Generated 1000 samples for class 4

Class 0 (Hatchback) Metrics Statistics (1000 samples):
PSNR - Mean: 8.43, Median: 8.48, Mode: 8.60
SSIM - Mean: 0.0585, Median: 0.0537, Mode: 0.0460
JS   - Mean: 0.3738, Median: 0.3724, Mode: 0.3620

Class 1 (Other) Metrics Statistics (1000 samples):
PSNR - Mean: 9.80, Median: 9.85, Mode: 10.10
SSIM - Mean: 0.0756, Median: 0.0772, Mode: 0.0750
JS   - Mean: 0.3327, Median: 0.3327, Mode: 0.3340

Class 2 (Pickup) Metrics Statistics (1000 samples):
PSNR - Mean: 8.36, Med

In [3]:
    # Summarise max, min, mean, median and (binned) mode of every metric, class by class.
    # Each layout entry: (row label, key in the metric dicts, decimals used to
    # bin the values for the mode, format spec of the printed numbers).
    metric_layout = (
        ('PSNR', 'psnr', 1, '.2f'),
        ('SSIM', 'ssim', 3, '.4f'),
        ('JS  ', 'js', 3, '.4f'),
    )
    for class_label in range(num_classes):
        class_records = all_metrics[class_label]

        # Build every row before printing anything for this class.
        report_rows = []
        for row_label, key, mode_decimals, spec in metric_layout:
            values = np.array([record[key] for record in class_records])
            # Mode of continuous data is approximated by rounding into bins.
            binned_mode = stats.mode(np.round(values, mode_decimals), keepdims=False)[0]
            report_rows.append(
                f"{row_label} - Max: {np.max(values):{spec}}, Min: {np.min(values):{spec}}, "
                f"Mean: {np.mean(values):{spec}}, Median: {np.median(values):{spec}}, "
                f"Mode: {binned_mode:{spec}}"
            )

        print(f"\nClass {class_label} ({class_names[class_label]}) Metrics Statistics (1000 samples):")
        for row in report_rows:
            print(row)


Class 0 (Hatchback) Metrics Statistics (1000 samples):
PSNR - Max: 10.50, Min: 5.85, Mean: 8.43, Median: 8.48, Mode: 8.60
SSIM - Max: 0.1431, Min: 0.0125, Mean: 0.0585, Median: 0.0537, Mode: 0.0460
JS   - Max: 0.4676, Min: 0.2906, Mean: 0.3738, Median: 0.3724, Mode: 0.3620

Class 1 (Other) Metrics Statistics (1000 samples):
PSNR - Max: 11.23, Min: 6.93, Mean: 9.80, Median: 9.85, Mode: 10.10
SSIM - Max: 0.1404, Min: 0.0175, Mean: 0.0756, Median: 0.0772, Mode: 0.0750
JS   - Max: 0.4303, Min: 0.2495, Mean: 0.3327, Median: 0.3327, Mode: 0.3340

Class 2 (Pickup) Metrics Statistics (1000 samples):
PSNR - Max: 10.37, Min: 6.22, Mean: 8.36, Median: 8.35, Mode: 8.60
SSIM - Max: 0.1554, Min: 0.0027, Mean: 0.0418, Median: 0.0374, Mode: 0.0370
JS   - Max: 0.4608, Min: 0.2907, Mean: 0.3719, Median: 0.3703, Mode: 0.3570

Class 3 (Seden) Metrics Statistics (1000 samples):
PSNR - Max: 11.30, Min: 6.51, Mean: 8.79, Median: 8.45, Mode: 7.80
SSIM - Max: 0.1789, Min: 0.0117, Mean: 0.0788, Median: 0.0665,

In [4]:
## Selecting the top 500 images per class

In [5]:
import os
import numpy as np
from PIL import Image

# Assuming all_metrics is available from the previous code
# If not, you'd need to recompute it by running the metrics computation part of the previous code
# all_metrics is a dictionary: {class_label: [{'image': 'sample_X.png', 'psnr': ..., 'ssim': ..., 'js': ...}, ...]}

# Define directories
base_dir = "generated_samples-arch-A5-1000-enhanced"
good_images_dir = "top_500_images_per_class"
os.makedirs(good_images_dir, exist_ok=True)

# Class names (for printing)
class_names = ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']

# Define class-specific thresholds based on the provided statistics
thresholds = {
    0: {'psnr': 8.5, 'ssim': 0.06, 'js': 0.37},  # Hatchback
    1: {'psnr': 9.9, 'ssim': 0.08, 'js': 0.33},  # Other
    2: {'psnr': 8.4, 'ssim': 0.045, 'js': 0.37}, # Pickup
    3: {'psnr': 8.8, 'ssim': 0.08, 'js': 0.33},  # Seden
    4: {'psnr': 7.7, 'ssim': 0.04, 'js': 0.41}   # SUV
}

TOP_K = 500      # number of images kept per class
PREVIEW_N = 10   # number of best / worst rows printed per class


def _combined_score(metric):
    """Ranking score: PSNR + 10 * SSIM - 10 * JS (higher is better)."""
    return metric['psnr'] + metric['ssim'] * 10 - metric['js'] * 10


def _format_row(img_data):
    """Render one scored image as a single report line."""
    return (f"Image: {img_data['image']}, PSNR: {img_data['psnr']:.2f}, SSIM: {img_data['ssim']:.4f}, "
            f"JS: {img_data['js']:.4f}, Score: {img_data['score']:.2f}")


# Select top 500 images per class.
# NOTE: `num_classes` and `all_metrics` come from the generation cell above.
top_images = {class_label: [] for class_label in range(num_classes)}
combined_scores = {class_label: [] for class_label in range(num_classes)}

for class_label in range(num_classes):
    class_dir = os.path.join(base_dir, str(class_label))
    good_class_dir = os.path.join(good_images_dir, str(class_label))
    os.makedirs(good_class_dir, exist_ok=True)
    thresh = thresholds[class_label]

    # Images that beat every class-specific threshold (JS: lower is better).
    good_candidates = [
        {'image': metric['image'], 'psnr': metric['psnr'], 'ssim': metric['ssim'], 'js': metric['js']}
        for metric in all_metrics[class_label]
        if metric['psnr'] > thresh['psnr'] and metric['ssim'] > thresh['ssim'] and metric['js'] < thresh['js']
    ]

    # Score every image once and rank in descending order. The score formula
    # lives in _combined_score so both rankings below cannot drift apart.
    combined_scores[class_label] = sorted(
        (
            {
                'image': metric['image'],
                'psnr': metric['psnr'],
                'ssim': metric['ssim'],
                'js': metric['js'],
                'score': _combined_score(metric)
            }
            for metric in all_metrics[class_label]
        ),
        key=lambda entry: entry['score'],
        reverse=True,
    )

    # If we have enough good candidates, select the top 500 by score
    if len(good_candidates) >= TOP_K:
        good_candidates.sort(key=_combined_score, reverse=True)
        top_images[class_label] = good_candidates[:TOP_K]
    else:
        # If fewer than 500 meet the criteria, take the top 500 by combined score
        top_images[class_label] = combined_scores[class_label][:TOP_K]
        print(f"Class {class_label} ({class_names[class_label]}): Only {len(good_candidates)} images met the criteria. Selecting top {TOP_K} by combined score.")

    # Save the top 500 images
    for img_data in top_images[class_label]:
        src_path = os.path.join(class_dir, img_data['image'])
        dst_path = os.path.join(good_class_dir, img_data['image'])
        # The context manager closes each source file; previously the handles
        # stayed open for every copied image.
        with Image.open(src_path) as pil_image:
            pil_image.save(dst_path)

    # Print the number of selected images
    print(f"Class {class_label} ({class_names[class_label]}): Selected {len(top_images[class_label])} top images saved to {good_class_dir}")

    # Print top 10 and bottom 10 images based on combined score
    print(f"\nClass {class_label} ({class_names[class_label]}) Top 10 Images:")
    for img_data in combined_scores[class_label][:PREVIEW_N]:
        print(_format_row(img_data))

    print(f"\nClass {class_label} ({class_names[class_label]}) Bottom 10 Images:")
    for img_data in combined_scores[class_label][-PREVIEW_N:]:
        print(_format_row(img_data))

Class 0 (Hatchback): Only 254 images met the criteria. Selecting top 500 by combined score.
Class 0 (Hatchback): Selected 500 top images saved to top_500_images_per_class\0

Class 0 (Hatchback) Top 10 Images:
Image: sample_975.png, PSNR: 9.77, SSIM: 0.1405, JS: 0.3013, Score: 8.17
Image: sample_230.png, PSNR: 10.34, SSIM: 0.1093, JS: 0.3274, Score: 8.15
Image: sample_50.png, PSNR: 10.50, SSIM: 0.1198, JS: 0.3556, Score: 8.14
Image: sample_401.png, PSNR: 10.36, SSIM: 0.1121, JS: 0.3361, Score: 8.12
Image: sample_431.png, PSNR: 10.35, SSIM: 0.1174, JS: 0.3474, Score: 8.05
Image: sample_624.png, PSNR: 9.57, SSIM: 0.1431, JS: 0.3099, Score: 7.90
Image: sample_619.png, PSNR: 9.95, SSIM: 0.1121, JS: 0.3266, Score: 7.81
Image: sample_547.png, PSNR: 9.53, SSIM: 0.1315, JS: 0.3138, Score: 7.71
Image: sample_327.png, PSNR: 10.20, SSIM: 0.0925, JS: 0.3410, Score: 7.71
Image: sample_28.png, PSNR: 10.01, SSIM: 0.1068, JS: 0.3393, Score: 7.69

Class 0 (Hatchback) Bottom 10 Images:
Image: sample_992.