In [4]:
import scipy
print(scipy.__version__)

1.15.3


In [2]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from scipy.stats import jensenshannon
import kagglehub

# Verify scipy version
import scipy
print(f"Using scipy version: {scipy.__version__}")
if int(scipy.__version__.split('.')[1]) < 7:
    raise ImportError("scipy version must be >= 1.7.0 to use jensenshannon. Please upgrade scipy using 'pip install --upgrade scipy'.")

# Define paths
generated_base_dir = r"C:\Users\hp\Desktop\fedavg\generated_samples-arch-A5"
good_images_dir = r"C:\Users\hp\Desktop\fedavg\good_images-arch-A5"
os.makedirs(good_images_dir, exist_ok=True)

# Hyperparameters
RESIZE = 128
num_classes = 5
original_dim = RESIZE * RESIZE * 3

# Custom Dataset for Kaggle Vehicle Type Image Dataset
class VehicleTypeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = []
        self.class_to_idx = {}
        self.images = []
        self.labels = []

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.classes.append(class_name)
                    self.class_to_idx[class_name] = len(self.classes) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    try:
                        Image.open(img_path).verify()
                        self.images.append(img_path)
                        self.labels.append(self.class_to_idx[class_name])
                    except:
                        print(f"Skipping corrupted image: {img_path}")

        if len(self.classes) != num_classes:
            raise ValueError(f"Expected {num_classes} classes, found {len(self.classes)}")
        print(f"Found {len(self.images)} images across {len(self.classes)} classes.")
        print(f"Classes: {self.classes}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample

# Load a subset of real images as references (5 per class)
transform = transforms.Compose([
    transforms.Resize((RESIZE, RESIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
dataset = VehicleTypeDataset(root_dir=path, transform=transform)

# Collect reference images (5 per class)
reference_images = {i: [] for i in range(num_classes)}
for sample in dataset:
    label = sample['label']
    if len(reference_images[label]) < 5:
        reference_images[label].append(sample['image'])
    if all(len(refs) >= 5 for refs in reference_images.values()):
        break

# Function to compute PSNR, SSIM, and JS divergence
def compute_metrics_for_image(orig, gen):
    # Convert tensors to numpy arrays and rescale to [0, 1]
    orig_np = orig.cpu().numpy() * 0.5 + 0.5  # Denormalize
    gen_np = gen.cpu().numpy() * 0.5 + 0.5    # Denormalize
    orig_np = np.transpose(orig_np, (1, 2, 0))  # (C, H, W) -> (H, W, C)
    gen_np = np.transpose(gen_np, (1, 2, 0))    # (C, H, W) -> (H, W, C)
    orig_np = np.clip(orig_np, 0, 1)
    gen_np = np.clip(gen_np, 0, 1)

    # Compute PSNR
    psnr_score = psnr(orig_np, gen_np, data_range=1.0)

    # Compute SSIM (multichannel=True for RGB images)
    ssim_score = ssim(orig_np, gen_np, multichannel=True, data_range=1.0, channel_axis=2)

    # Compute JS divergence per channel
    js_scores = []
    for c in range(3):  # For each RGB channel
        orig_channel = orig_np[:, :, c].flatten()
        gen_channel = gen_np[:, :, c].flatten()
        # Normalize to probability distributions
        orig_channel = (orig_channel + 1e-10) / (np.sum(orig_channel) + 1e-10 * RESIZE * RESIZE)
        gen_channel = (gen_channel + 1e-10) / (np.sum(gen_channel) + 1e-10 * RESIZE * RESIZE)
        js_score = jensenshannon(orig_channel, gen_channel)
        if np.isnan(js_score):
            js_score = 1.0  # Worst case if computation fails
        js_scores.append(js_score)
    js_score = np.mean(js_scores)

    return psnr_score, ssim_score, js_score

# Process generated images and compute metrics
metrics_dict = {i: [] for i in range(num_classes)}  # Store metrics for each class
good_images = {i: [] for i in range(num_classes)}   # Store good images for each class

for class_label in range(num_classes):
    class_dir = os.path.join(generated_base_dir, str(class_label))
    good_class_dir = os.path.join(good_images_dir, str(class_label))
    os.makedirs(good_class_dir, exist_ok=True)

    if not os.path.exists(class_dir):
        print(f"Class directory {class_dir} not found, skipping class {class_label}.")
        continue

    # Reference images for this class
    ref_images = reference_images[class_label]  # List of 5 reference tensors

    # Process each generated image
    idx = 0
    while idx < 100:  # Attempt to process up to 100 samples per class
        img_path = os.path.join(class_dir, f"sample_{idx}.png")
        if not os.path.exists(img_path):
            print(f"Image {img_path} not found, stopping at index {idx} for class {class_label}.")
            break

        # Load and preprocess the generated image
        try:
            gen_image = Image.open(img_path).convert('RGB')
            gen_tensor = transform(gen_image)
        except Exception as e:
            print(f"Error loading image {img_path}: {str(e)}")
            idx += 1
            continue

        # Compute metrics against each reference image and average
        psnr_scores, ssim_scores, js_scores = [], [], []
        for ref_tensor in ref_images:
            try:
                psnr_score, ssim_score, js_score = compute_metrics_for_image(ref_tensor, gen_tensor)
                psnr_scores.append(psnr_score)
                ssim_scores.append(ssim_score)
                js_scores.append(js_score)
            except Exception as e:
                print(f"Error computing metrics for image {img_path}: {str(e)}")
                psnr_scores.append(0.0)  # Worst case
                ssim_scores.append(0.0)  # Worst case
                js_scores.append(1.0)    # Worst case

        avg_psnr = np.mean(psnr_scores)
        avg_ssim = np.mean(ssim_scores)
        avg_js = np.mean(js_scores)

        metrics_dict[class_label].append({
            'image': f"sample_{idx}.png",
            'psnr': avg_psnr,
            'ssim': avg_ssim,
            'js': avg_js
        })

        # Check if the image meets the "good" criteria
        if avg_psnr > 30 and avg_ssim > 0.7 and avg_js < 0.3:
            # Save the good image to the new folder
            good_img_path = os.path.join(good_class_dir, f"sample_{idx}.png")
            gen_image.save(good_img_path)
            good_images[class_label].append(f"sample_{idx}.png")

        idx += 1

    print(f"\nClass {class_label} Metrics:")
    for metric in metrics_dict[class_label]:
        print(f"Image: {metric['image']}, PSNR: {metric['psnr']:.2f}, SSIM: {metric['ssim']:.4f}, JS: {metric['js']:.4f}")
    print(f"Good images for class {class_label}: {len(good_images[class_label])} images saved to {good_class_dir}")

# Summary of good images
total_good = sum(len(good_images[cls]) for cls in good_images)
print(f"\nTotal good images across all classes: {total_good}")

ImportError: cannot import name 'jensenshannon' from 'scipy.stats' (C:\Users\hp\Desktop\New folder\Lib\site-packages\scipy\stats\__init__.py)

In [3]:
pip install --upgrade scipy




In [5]:
import scipy
print(scipy.__version__)

1.15.3


In [7]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
# Correct the import path for jensenshannon
from scipy.spatial.distance import jensenshannon
import kagglehub

# Verify scipy version - keep this check for good practice, though importing from spatial.distance
# is more robust across versions where the function might be in both or just spatial.distance.
import scipy
print(f"Using scipy version: {scipy.__version__}")
# The check for > 1.7 is relevant if you strictly wanted it in scipy.stats.
# Since we import from spatial.distance, any version >= 1.6.0 should work.
# Let's adjust the check to reflect the actual requirement for scipy.spatial.distance.jensenshannon
from pkg_resources import parse_version
if parse_version(scipy.__version__) < parse_version("1.6.0"):
     raise ImportError(f"scipy version must be >= 1.6.0 to use jensenshannon from scipy.spatial.distance. Please upgrade scipy using 'pip install --upgrade scipy'. Current version: {scipy.__version__}")


# Define paths
generated_base_dir = r"C:\Users\hp\Desktop\fedavg\generated_samples-arch-A5"
good_images_dir = r"C:\Users\hp\Desktop\fedavg\good_images-arch-A5"
os.makedirs(good_images_dir, exist_ok=True)

# Hyperparameters
RESIZE = 128
num_classes = 5
original_dim = RESIZE * RESIZE * 3

# Custom Dataset for Kaggle Vehicle Type Image Dataset
class VehicleTypeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = []
        self.class_to_idx = {}
        self.images = []
        self.labels = []

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.classes.append(class_name)
                    self.class_to_idx[class_name] = len(self.classes) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    try:
                        Image.open(img_path).verify()
                        self.images.append(img_path)
                        self.labels.append(self.class_to_idx[class_name])
                    except:
                        print(f"Skipping corrupted image: {img_path}")

        if len(self.classes) != num_classes:
            raise ValueError(f"Expected {num_classes} classes, found {len(self.classes)}")
        print(f"Found {len(self.images)} images across {len(self.classes)} classes.")
        print(f"Classes: {self.classes}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample

# Load a subset of real images as references (5 per class)
transform = transforms.Compose([
    transforms.Resize((RESIZE, RESIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
dataset = VehicleTypeDataset(root_dir=path, transform=transform)

# Collect reference images (5 per class)
reference_images = {i: [] for i in range(num_classes)}
for sample in dataset:
    label = sample['label']
    if len(reference_images[label]) < 5:
        reference_images[label].append(sample['image'])
    if all(len(refs) >= 5 for refs in reference_images.values()):
        break

# Function to compute PSNR, SSIM, and JS divergence
def compute_metrics_for_image(orig, gen):
    # Convert tensors to numpy arrays and rescale to [0, 1]
    orig_np = orig.cpu().numpy() * 0.5 + 0.5  # Denormalize
    gen_np = gen.cpu().numpy() * 0.5 + 0.5    # Denormalize
    orig_np = np.transpose(orig_np, (1, 2, 0))  # (C, H, W) -> (H, W, C)
    gen_np = np.transpose(gen_np, (1, 2, 0))    # (C, H, W) -> (H, W, C)
    orig_np = np.clip(orig_np, 0, 1)
    gen_np = np.clip(gen_np, 0, 1)

    # Compute PSNR
    psnr_score = psnr(orig_np, gen_np, data_range=1.0)

    # Compute SSIM (multichannel=True for RGB images)
    ssim_score = ssim(orig_np, gen_np, multichannel=True, data_range=1.0, channel_axis=2)

    # Compute JS divergence per channel
    js_scores = []
    for c in range(3):  # For each RGB channel
        orig_channel = orig_np[:, :, c].flatten()
        gen_channel = gen_np[:, :, c].flatten()
        # Normalize to probability distributions
        orig_channel = (orig_channel + 1e-10) / (np.sum(orig_channel) + 1e-10 * RESIZE * RESIZE)
        gen_channel = (gen_channel + 1e-10) / (np.sum(gen_channel) + 1e-10 * RESIZE * RESIZE)
        js_score = jensenshannon(orig_channel, gen_channel)
        if np.isnan(js_score):
            js_score = 1.0  # Worst case if computation fails
        js_scores.append(js_score)
    js_score = np.mean(js_scores)

    return psnr_score, ssim_score, js_score

# Process generated images and compute metrics
metrics_dict = {i: [] for i in range(num_classes)}  # Store metrics for each class
good_images = {i: [] for i in range(num_classes)}   # Store good images for each class

for class_label in range(num_classes):
    class_dir = os.path.join(generated_base_dir, str(class_label))
    good_class_dir = os.path.join(good_images_dir, str(class_label))
    os.makedirs(good_class_dir, exist_ok=True)

    if not os.path.exists(class_dir):
        print(f"Class directory {class_dir} not found, skipping class {class_label}.")
        continue

    # Reference images for this class
    ref_images = reference_images[class_label]  # List of 5 reference tensors

    # Process each generated image
    idx = 0
    while idx < 100:  # Attempt to process up to 100 samples per class
        img_path = os.path.join(class_dir, f"sample_{idx}.png")
        if not os.path.exists(img_path):
            print(f"Image {img_path} not found, stopping at index {idx} for class {class_label}.")
            break

        # Load and preprocess the generated image
        try:
            gen_image = Image.open(img_path).convert('RGB')
            gen_tensor = transform(gen_image)
        except Exception as e:
            print(f"Error loading image {img_path}: {str(e)}")
            idx += 1
            continue

        # Compute metrics against each reference image and average
        psnr_scores, ssim_scores, js_scores = [], [], []
        for ref_tensor in ref_images:
            try:
                psnr_score, ssim_score, js_score = compute_metrics_for_image(ref_tensor, gen_tensor)
                psnr_scores.append(psnr_score)
                ssim_scores.append(ssim_score)
                js_scores.append(js_score)
            except Exception as e:
                print(f"Error computing metrics for image {img_path}: {str(e)}")
                psnr_scores.append(0.0)  # Worst case
                ssim_scores.append(0.0)  # Worst case
                js_scores.append(1.0)    # Worst case

        avg_psnr = np.mean(psnr_scores)
        avg_ssim = np.mean(ssim_scores)
        avg_js = np.mean(js_scores)

        metrics_dict[class_label].append({
            'image': f"sample_{idx}.png",
            'psnr': avg_psnr,
            'ssim': avg_ssim,
            'js': avg_js
        })

        # Check if the image meets the "good" criteria
        if avg_psnr > 11.8 and avg_ssim > 0.19 and avg_js < 0.24:
            # Save the good image to the new folder
            good_img_path = os.path.join(good_class_dir, f"sample_{idx}.png")
            gen_image.save(good_img_path)
            good_images[class_label].append(f"sample_{idx}.png")

        idx += 1

    print(f"\nClass {class_label} Metrics:")
    for metric in metrics_dict[class_label]:
        print(f"Image: {metric['image']}, PSNR: {metric['psnr']:.2f}, SSIM: {metric['ssim']:.4f}, JS: {metric['js']:.4f}")
    print(f"Good images for class {class_label}: {len(good_images[class_label])} images saved to {good_class_dir}")

# Summary of good images
total_good = sum(len(good_images[cls]) for cls in good_images)
print(f"\nTotal good images across all classes: {total_good}")

Using scipy version: 1.15.3
Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']

Class 0 Metrics:
Image: sample_0.png, PSNR: 11.93, SSIM: 0.1640, JS: 0.2541
Image: sample_1.png, PSNR: 10.80, SSIM: 0.1612, JS: 0.2956
Image: sample_2.png, PSNR: 10.92, SSIM: 0.1493, JS: 0.2460
Image: sample_3.png, PSNR: 11.50, SSIM: 0.2208, JS: 0.2346
Image: sample_4.png, PSNR: 12.32, SSIM: 0.1646, JS: 0.2506
Image: sample_5.png, PSNR: 12.13, SSIM: 0.1911, JS: 0.2549
Image: sample_6.png, PSNR: 11.24, SSIM: 0.1623, JS: 0.2438
Image: sample_7.png, PSNR: 12.06, SSIM: 0.2141, JS: 0.2492
Image: sample_8.png, PSNR: 10.21, SSIM: 0.1236, JS: 0.3296
Image: sample_9.png, PSNR: 9.25, SSIM: 0.1064, JS: 0.2737
Image: sample_10.png, PSNR: 10.65, SSIM: 0.1652, JS: 0.2691
Image: sample_11.png, PSNR: 11.34, SSIM: 0.1956, JS: 0.2737
Image: sample_12.png, PSNR: 10.19, SS

In [8]:
##3000epoch image of mine

In [10]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from scipy.spatial.distance import jensenshannon
import kagglehub

# Define paths
combined_image_path = r"C:\Users\hp\Desktop\fedavg\FL_VEHICLE_CVAE_IMPROVED\sample_epoch_3000.png"  # Replace with the actual path to your combined image
good_images_dir = r"C:\Users\hp\Desktop\fedavg\good_images_3000_mine"
os.makedirs(good_images_dir, exist_ok=True)

# Hyperparameters
RESIZE = 128
num_classes = 5
original_dim = RESIZE * RESIZE * 3

# Custom Dataset for Kaggle Vehicle Type Image Dataset
class VehicleTypeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = []
        self.class_to_idx = {}
        self.images = []
        self.labels = []

        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.classes.append(class_name)
                    self.class_to_idx[class_name] = len(self.classes) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    try:
                        Image.open(img_path).verify()
                        self.images.append(img_path)
                        self.labels.append(self.class_to_idx[class_name])
                    except:
                        print(f"Skipping corrupted image: {img_path}")

        if len(self.classes) != num_classes:
            raise ValueError(f"Expected {num_classes} classes, found {len(self.classes)}")
        print(f"Found {len(self.images)} images across {len(self.classes)} classes.")
        print(f"Classes: {self.classes}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample

# Load a subset of real images as references (5 per class)
transform = transforms.Compose([
    transforms.Resize((RESIZE, RESIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
dataset = VehicleTypeDataset(root_dir=path, transform=transform)

# Collect reference images (5 per class)
reference_images = {i: [] for i in range(num_classes)}
for sample in dataset:
    label = sample['label']
    if len(reference_images[label]) < 5:
        reference_images[label].append(sample['image'])
    if all(len(refs) >= 5 for refs in reference_images.values()):
        break

# Function to compute PSNR, SSIM, and JS divergence
def compute_metrics_for_image(orig, gen):
    orig_np = orig.cpu().numpy() * 0.5 + 0.5
    gen_np = gen.cpu().numpy() * 0.5 + 0.5
    orig_np = np.transpose(orig_np, (1, 2, 0))
    gen_np = np.transpose(gen_np, (1, 2, 0))
    orig_np = np.clip(orig_np, 0, 1)
    gen_np = np.clip(gen_np, 0, 1)

    psnr_score = psnr(orig_np, gen_np, data_range=1.0)
    ssim_score = ssim(orig_np, gen_np, multichannel=True, data_range=1.0, channel_axis=2)

    js_scores = []
    for c in range(3):
        orig_channel = orig_np[:, :, c].flatten()
        gen_channel = gen_np[:, :, c].flatten()
        orig_channel = (orig_channel + 1e-10) / (np.sum(orig_channel) + 1e-10 * RESIZE * RESIZE)
        gen_channel = (gen_channel + 1e-10) / (np.sum(gen_channel) + 1e-10 * RESIZE * RESIZE)
        js_score = jensenshannon(orig_channel, gen_channel)
        if np.isnan(js_score):
            js_score = 1.0
        js_scores.append(js_score)
    js_score = np.mean(js_scores)

    return psnr_score, ssim_score, js_score

# Load and split the combined image
combined_image = Image.open(combined_image_path).convert('RGB')
combined_width, combined_height = combined_image.size

# Assuming a 2x5 grid (2 rows, 5 columns), adjust based on your image layout
rows, cols = 2, 5
subimage_width = combined_width // cols  # Width of each subimage
subimage_height = combined_height // rows  # Height of each subimage

# Check if the subimages match the expected size (128x128)
if subimage_width != RESIZE or subimage_height != RESIZE:
    print(f"Warning: Subimage size is {subimage_width}x{subimage_height}, expected {RESIZE}x{RESIZE}. Resizing subimages.")
    subimage_transform = transforms.Compose([
        transforms.Resize((RESIZE, RESIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
else:
    subimage_transform = transform

# Split the combined image into 10 subimages
subimages = []
for i in range(rows):
    for j in range(cols):
        left = j * subimage_width
        upper = i * subimage_height
        right = (j + 1) * subimage_width
        lower = (i + 1) * subimage_height
        subimage = combined_image.crop((left, upper, right, lower))
        subimages.append(subimage)

# Process each subimage and compute metrics
class_label = 0  # Assuming all subimages belong to Class 0, adjust if needed
good_class_dir = os.path.join(good_images_dir, str(class_label))
os.makedirs(good_class_dir, exist_ok=True)

metrics_dict = []
good_images = []

# Reference images for this class
ref_images = reference_images[class_label]  # List of 5 reference tensors

# Process each subimage
for idx, subimage in enumerate(subimages):
    # Preprocess the subimage
    gen_tensor = subimage_transform(subimage)

    # Compute metrics against each reference image and average
    psnr_scores, ssim_scores, js_scores = [], [], []
    for ref_tensor in ref_images:
        try:
            psnr_score, ssim_score, js_score = compute_metrics_for_image(ref_tensor, gen_tensor)
            psnr_scores.append(psnr_score)
            ssim_scores.append(ssim_score)
            js_scores.append(js_score)
        except Exception as e:
            print(f"Error computing metrics for subimage {idx}: {str(e)}")
            psnr_scores.append(0.0)
            ssim_scores.append(0.0)
            js_scores.append(1.0)

    avg_psnr = np.mean(psnr_scores)
    avg_ssim = np.mean(ssim_scores)
    avg_js = np.mean(js_scores)

    metrics_dict.append({
        'image': f"subimage_{idx}.png",
        'psnr': avg_psnr,
        'ssim': avg_ssim,
        'js': avg_js
    })

    # Check if the subimage meets the "good" criteria (using adjusted thresholds from previous response)
    if avg_psnr > 11.8 and avg_ssim > 0.19 and avg_js < 0.24:
        good_img_path = os.path.join(good_class_dir, f"subimage_{idx}.png")
        subimage.save(good_img_path)
        good_images.append(f"subimage_{idx}.png")

# Print metrics for the subimages
print(f"\nClass {class_label} Subimage Metrics:")
for metric in metrics_dict:
    print(f"Image: {metric['image']}, PSNR: {metric['psnr']:.2f}, SSIM: {metric['ssim']:.4f}, JS: {metric['js']:.4f}")
print(f"Good subimages for class {class_label}: {len(good_images)} images saved to {good_class_dir}")

Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']

Class 0 Subimage Metrics:
Image: subimage_0.png, PSNR: 10.24, SSIM: 0.1015, JS: 0.3255
Image: subimage_1.png, PSNR: 11.08, SSIM: 0.0900, JS: 0.2812
Image: subimage_2.png, PSNR: 8.47, SSIM: 0.0787, JS: 0.3119
Image: subimage_3.png, PSNR: 11.58, SSIM: 0.1174, JS: 0.2795
Image: subimage_4.png, PSNR: 10.04, SSIM: 0.0954, JS: 0.3197
Image: subimage_5.png, PSNR: 11.00, SSIM: 0.1210, JS: 0.2870
Image: subimage_6.png, PSNR: 7.59, SSIM: 0.0181, JS: 0.6601
Image: subimage_7.png, PSNR: 7.53, SSIM: 0.0024, JS: 0.2439
Image: subimage_8.png, PSNR: 7.53, SSIM: 0.0024, JS: 0.2439
Image: subimage_9.png, PSNR: 7.53, SSIM: 0.0024, JS: 0.2439
Good subimages for class 0: 0 images saved to C:\Users\hp\Desktop\fedavg\good_images_3000_mine\0


In [11]:
##generating_500 samples insted of 100 of sneha code

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os
import numpy as np

# Hyperparameters (same as in your original code)
RESIZE = 128
intermediate_dim = 512
latent_dim = 256
num_classes = 5
device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda')

# Print GPU information
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory Allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")

# Enhanced Image Grid for Plotting (same as in your original code)
def image_grid(imgs, rows, cols, class_label, height=RESIZE, width=RESIZE):
    assert len(imgs) <= rows * cols
    grid = Image.new('RGB', size=(cols * width, rows * height), color=(255, 255, 255))
    for i, img in enumerate(imgs):
        x = (i % cols) * width + 2
        y = (i // cols) * height + 2
        grid.paste(img, box=(x, y))
    return grid

# Weight initialization function (same as in your original code)
def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)

# Encoder class (same as in your original code)
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_dim, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 1024, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(1024, 1024, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU()
        )
        self.fc_hidden = nn.Linear(1024 * 2 * 2 + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.apply(init_weights)

    def forward(self, x, y):
        x = self.encoder_cnn(x)
        x = x.view(x.size(0), -1)
        x_with_y = torch.cat([x, y], dim=-1)
        h = F.relu(self.fc_hidden(x_with_y))
        h = self.dropout(h)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        z_logvar = torch.clamp(z_logvar, min=-10, max=10)
        return z_mean, z_logvar

# Decoder class (same as in your original code)
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_size, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_to_cnn = nn.Linear(hidden_dim, 1024 * 2 * 2)
        self.decoder_cnn = nn.Sequential(
            nn.ConvTranspose2d(1024, 1024, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )
        self.apply(init_weights)

    def forward(self, z, y):
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = self.dropout(h)
        h = F.relu(self.fc_to_cnn(h))
        h = h.view(-1, 1024, 2, 2)
        x_reconstructed = self.decoder_cnn(h)
        x_reconstructed = torch.clamp(x_reconstructed, min=-1, max=1)
        return x_reconstructed

# Conditional VAE class (same as in your original code)
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, data, y):
        z_mean, z_logvar = self.encoder(data, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y)
        return x_reconstructed, z_mean, z_logvar

# Instantiate the model
input_size = (8, 3, RESIZE, RESIZE)  # Batch size is irrelevant for generation, but needed for model definition
encoder = Encoder(input_size, intermediate_dim, latent_dim, num_classes).to(device)
decoder = Decoder(latent_dim, intermediate_dim, input_size, num_classes).to(device)
cvae = ConditionalVAE(encoder, decoder).to(device)

# Load the saved model weights
model_path = "cvae_vehicle_final.pth"
try:
    cvae.load_state_dict(torch.load(model_path, map_location=device))
    print(f"Loaded CVAE model from {model_path}")
except Exception as e:
    print(f"Error loading model: {str(e)}")
    exit()

# Define class names (you can update this list based on your dataset)
class_names = ['Class0', 'Class1', 'Class2', 'Class3', 'Class4']  # Update with actual class names if known

# Modified generate_samples_labelwise function to generate 500 samples per class
def generate_samples_labelwise(cvae, num_samples_per_class, base_dir, latent_dim, device):
    cvae.eval()
    os.makedirs(base_dir, exist_ok=True)
    with torch.no_grad():
        for class_label in range(num_classes):
            print(f"Generating samples for class {class_label}...")
            label_tensor = torch.tensor([class_label]).repeat(num_samples_per_class).to(device)
            one_hot_labels = F.one_hot(label_tensor, num_classes=num_classes).float().to(device)
            
            # Generate samples in smaller batches to avoid memory issues
            batch_size = 50  # Adjust based on your GPU memory
            images = []
            for start_idx in range(0, num_samples_per_class, batch_size):
                end_idx = min(start_idx + batch_size, num_samples_per_class)
                current_batch_size = end_idx - start_idx
                z = torch.randn(current_batch_size, latent_dim).to(device)
                generated_samples = cvae.decoder(z, one_hot_labels[start_idx:end_idx])

                # Process and save each sample
                class_dir = os.path.join(base_dir, str(class_label))
                os.makedirs(class_dir, exist_ok=True)
                for idx, sample in enumerate(generated_samples):
                    global_idx = start_idx + idx
                    sample = sample.cpu().detach().numpy()
                    sample = sample * 0.5 + 0.5  # Denormalize
                    sample = np.nan_to_num(sample, nan=0.0, posinf=0.0, neginf=0.0)
                    sample = (255 * sample).astype(np.uint8)
                    sample = np.transpose(sample, (1, 2, 0))
                    pil_image = Image.fromarray(sample).convert('RGB')
                    pil_image.save(os.path.join(class_dir, f"sample_{global_idx}.png"))
                    if global_idx < 32:  # Save first 32 for the grid
                        images.append(pil_image)

            # Save a grid of the first 32 images
            if images:
                grid = image_grid(images, rows=4, cols=8, class_label=class_names[class_label])
                grid.save(os.path.join(class_dir, f"grid_{class_names[class_label]}.png"))
            print(f"Generated {num_samples_per_class} samples for class {class_label}")

# Generate 500 samples per class
base_dir = "generated_samples-arch-A5-500"  # New directory to avoid overwriting previous samples
try:
    generate_samples_labelwise(cvae, num_samples_per_class=500, base_dir=base_dir, latent_dim=latent_dim, device=device)
    print("Finished generating 500 samples per class")
except Exception as e:
    print(f"Error during sample generation: {str(e)}")

# Clear GPU memory
if device.type == 'cuda':
    torch.cuda.empty_cache()

Using device: cuda:1
GPU Name: NVIDIA RTX A5000
GPU Memory Allocated: 0.00 MB


  cvae.load_state_dict(torch.load(model_path, map_location=device))


Loaded CVAE model from cvae_vehicle_final.pth
Generating samples for class 0...
Generated 500 samples for class 0
Generating samples for class 1...
Generated 500 samples for class 1
Generating samples for class 2...
Generated 500 samples for class 2
Generating samples for class 3...
Generated 500 samples for class 3
Generating samples for class 4...
Generated 500 samples for class 4
Finished generating 500 samples per class


In [16]:
##generating 500 Images and increase quality and sorting of sneha code

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image, ImageEnhance, ImageFilter
import os
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from scipy.spatial.distance import jensenshannon
import kagglehub

# Hyperparameters (same as in your original code)
RESIZE = 128
intermediate_dim = 512
latent_dim = 256
num_classes = 5
device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda')

# Print GPU information
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory Allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")

# Custom Dataset for Kaggle Vehicle Type Image Dataset (for reference images)
class VehicleTypeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = []
        self.class_to_idx = {}
        self.images = []
        self.labels = []

        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.classes.append(class_name)
                    self.class_to_idx[class_name] = len(self.classes) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    try:
                        Image.open(img_path).verify()
                        self.images.append(img_path)
                        self.labels.append(self.class_to_idx[class_name])
                    except:
                        print(f"Skipping corrupted image: {img_path}")

        if len(self.classes) != num_classes:
            raise ValueError(f"Expected {num_classes} classes, found {len(self.classes)}")
        print(f"Found {len(self.images)} images across {len(self.classes)} classes.")
        print(f"Classes: {self.classes}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample

# Enhanced Image Grid for Plotting
def image_grid(imgs, rows, cols, class_label, height=RESIZE, width=RESIZE):
    assert len(imgs) <= rows * cols
    grid = Image.new('RGB', size=(cols * width, rows * height), color=(255, 255, 255))
    for i, img in enumerate(imgs):
        x = (i % cols) * width + 2
        y = (i // cols) * height + 2
        grid.paste(img, box=(x, y))
    return grid

# Weight initialization function
def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)

# Encoder class (same as in your original code)
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_dim, latent_dim, num_classes):
        super(Encoder, self).__init__()
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 1024, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(1024, 1024, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU()
        )
        self.fc_hidden = nn.Linear(1024 * 2 * 2 + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.apply(init_weights)

    def forward(self, x, y):
        x = self.encoder_cnn(x)
        x = x.view(x.size(0), -1)
        x_with_y = torch.cat([x, y], dim=-1)
        h = F.relu(self.fc_hidden(x_with_y))
        h = self.dropout(h)
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        z_logvar = torch.clamp(z_logvar, min=-10, max=10)
        return z_mean, z_logvar

# Decoder class (same as in your original code)
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_size, num_classes):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_to_cnn = nn.Linear(hidden_dim, 1024 * 2 * 2)
        self.decoder_cnn = nn.Sequential(
            nn.ConvTranspose2d(1024, 1024, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )
        self.apply(init_weights)

    def forward(self, z, y):
        z_with_y = torch.cat([z, y], dim=-1)
        h = F.relu(self.fc(z_with_y))
        h = self.dropout(h)
        h = F.relu(self.fc_to_cnn(h))
        h = h.view(-1, 1024, 2, 2)
        x_reconstructed = self.decoder_cnn(h)
        x_reconstructed = torch.clamp(x_reconstructed, min=-1, max=1)
        return x_reconstructed

# Conditional VAE class
class ConditionalVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, z_mean, z_logvar):
        std = torch.exp(0.5 * z_logvar)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def forward(self, data, y):
        z_mean, z_logvar = self.encoder(data, y)
        z = self.reparameterize(z_mean, z_logvar)
        x_reconstructed = self.decoder(z, y)
        return x_reconstructed, z_mean, z_logvar

# Function to enhance image quality (sharpen, adjust contrast, brightness, etc.)
def enhance_image(pil_image):
    # Apply Unsharp Masking for additional sharpness
    pil_image = pil_image.filter(ImageFilter.UnsharpMask(radius=2, percent=150, threshold=3))
    # Enhance sharpness
    pil_image = ImageEnhance.Sharpness(pil_image).enhance(2.0)  # Increase sharpness by a factor of 2
    # Enhance contrast
    pil_image = ImageEnhance.Contrast(pil_image).enhance(1.3)  # Increase contrast by 30%
    # Enhance brightness
    pil_image = ImageEnhance.Brightness(pil_image).enhance(1.1)  # Slightly increase brightness by 10%
    # Optional: Super-resolution-like effect by upscaling and downscaling
    pil_image = pil_image.resize((RESIZE * 2, RESIZE * 2), Image.Resampling.LANCZOS)
    pil_image = pil_image.resize((RESIZE, RESIZE), Image.Resampling.LANCZOS)
    return pil_image

# Function to compute PSNR, SSIM, and JS divergence
def compute_metrics_for_image(orig, gen):
    orig_np = orig.cpu().numpy() * 0.5 + 0.5
    gen_np = gen.cpu().numpy() * 0.5 + 0.5
    orig_np = np.transpose(orig_np, (1, 2, 0))
    gen_np = np.transpose(gen_np, (1, 2, 0))
    orig_np = np.clip(orig_np, 0, 1)
    gen_np = np.clip(gen_np, 0, 1)

    psnr_score = psnr(orig_np, gen_np, data_range=1.0)
    ssim_score = ssim(orig_np, gen_np, multichannel=True, data_range=1.0, channel_axis=2)

    js_scores = []
    for c in range(3):
        orig_channel = orig_np[:, :, c].flatten()
        gen_channel = gen_np[:, :, c].flatten()
        orig_channel = (orig_channel + 1e-10) / (np.sum(orig_channel) + 1e-10 * RESIZE * RESIZE)
        gen_channel = (gen_channel + 1e-10) / (np.sum(gen_channel) + 1e-10 * RESIZE * RESIZE)
        js_score = jensenshannon(orig_channel, gen_channel)
        if np.isnan(js_score):
            js_score = 1.0
        js_scores.append(js_score)
    js_score = np.mean(js_scores)

    return psnr_score, ssim_score, js_score

# Instantiate the model
input_size = (8, 3, RESIZE, RESIZE)  # Batch size is irrelevant for generation
encoder = Encoder(input_size, intermediate_dim, latent_dim, num_classes).to(device)
decoder = Decoder(latent_dim, intermediate_dim, input_size, num_classes).to(device)
cvae = ConditionalVAE(encoder, decoder).to(device)

# Load the saved model weights
model_path = "cvae_vehicle_final.pth"
try:
    cvae.load_state_dict(torch.load(model_path, map_location=device))
    print(f"Loaded CVAE model from {model_path}")
except Exception as e:
    print(f"Error loading model: {str(e)}")
    exit()

# Load reference images for metrics computation
transform = transforms.Compose([
    transforms.Resize((RESIZE, RESIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
dataset = VehicleTypeDataset(root_dir=path, transform=transform)

# Collect reference images (5 per class)
reference_images = {i: [] for i in range(num_classes)}
for sample in dataset:
    label = sample['label']
    if len(reference_images[label]) < 5:
        reference_images[label].append(sample['image'])
    if all(len(refs) >= 5 for refs in reference_images.values()):
        break

# Define class names (update with actual class names if known)
class_names = ['Class0', 'Class1', 'Class2', 'Class3', 'Class4']

# Modified generate_samples_labelwise function with enhancements
def generate_samples_labelwise(cvae, num_samples_per_class, base_dir, good_images_dir, latent_dim, device):
    cvae.eval()
    os.makedirs(base_dir, exist_ok=True)
    os.makedirs(good_images_dir, exist_ok=True)

    # Initialize dictionaries to store metrics and good images
    all_metrics = {class_label: [] for class_label in range(num_classes)}
    good_images = {class_label: [] for class_label in range(num_classes)}

    with torch.no_grad():
        for class_label in range(num_classes):
            print(f"Generating and processing samples for class {class_label}...")
            label_tensor = torch.tensor([class_label]).repeat(num_samples_per_class).to(device)
            one_hot_labels = F.one_hot(label_tensor, num_classes=num_classes).float().to(device)

            # Generate samples in smaller batches to avoid memory issues
            batch_size = 50
            idx = 0
            for start_idx in range(0, num_samples_per_class, batch_size):
                end_idx = min(start_idx + batch_size, num_samples_per_class)
                current_batch_size = end_idx - start_idx

                # Use truncation trick: Sample from a normal distribution with reduced variance
                z = torch.randn(current_batch_size, latent_dim).to(device) * 0.7  # Reduce variance
                z = torch.clamp(z, -2.0, 2.0)  # Truncation trick to avoid extreme values
                generated_samples = cvae.decoder(z, one_hot_labels[start_idx:end_idx])

                # Process and save each sample
                class_dir = os.path.join(base_dir, str(class_label))
                os.makedirs(class_dir, exist_ok=True)
                good_class_dir = os.path.join(good_images_dir, str(class_label))
                os.makedirs(good_class_dir, exist_ok=True)

                images_for_grid = []
                for batch_idx, sample in enumerate(generated_samples):
                    global_idx = start_idx + batch_idx
                    # Convert tensor to numpy array
                    sample = sample.cpu().detach().numpy()
                    sample = sample * 0.5 + 0.5  # Denormalize
                    sample = np.nan_to_num(sample, nan=0.0, posinf=0.0, neginf=0.0)
                    sample = (255 * sample).astype(np.uint8)
                    sample = np.transpose(sample, (1, 2, 0))
                    pil_image = Image.fromarray(sample).convert('RGB')

                    # Enhance the image
                    pil_image = enhance_image(pil_image)

                    # Save the enhanced image
                    pil_image.save(os.path.join(class_dir, f"sample_{global_idx}.png"))
                    if global_idx < 32:
                        images_for_grid.append(pil_image)

                    # Convert back to tensor for metrics computation
                    gen_tensor = transform(pil_image).to(device)

                    # Compute metrics against reference images
                    ref_images = reference_images[class_label]
                    psnr_scores, ssim_scores, js_scores = [], [], []
                    for ref_tensor in ref_images:
                        psnr_score, ssim_score, js_score = compute_metrics_for_image(ref_tensor, gen_tensor)
                        psnr_scores.append(psnr_score)
                        ssim_scores.append(ssim_score)
                        js_scores.append(js_score)

                    avg_psnr = np.mean(psnr_scores)
                    avg_ssim = np.mean(ssim_scores)
                    avg_js = np.mean(js_scores)

                    all_metrics[class_label].append({
                        'image': f"sample_{global_idx}.png",
                        'psnr': avg_psnr,
                        'ssim': avg_ssim,
                        'js': avg_js
                    })

                    # Check if the image meets the "good" criteria (initial thresholds, will adjust later)
                    if avg_psnr > 11.8 and avg_ssim > 0.19 and avg_js < 0.24:
                        good_img_path = os.path.join(good_class_dir, f"sample_{global_idx}.png")
                        pil_image.save(good_img_path)
                        good_images[class_label].append(f"sample_{global_idx}.png")

                idx += current_batch_size

            # Save a grid of the first 32 images
            if images_for_grid:
                grid = image_grid(images_for_grid, rows=4, cols=8, class_label=class_names[class_label])
                grid.save(os.path.join(class_dir, f"grid_{class_names[class_label]}.png"))
            print(f"Generated {num_samples_per_class} samples for class {class_label}")

    return all_metrics, good_images

# Generate 500 samples per class and compute metrics
base_dir = "generated_samples-arch-A5-500-enhanced"
good_images_dir = "good_images-arch-A5-enhanced"
try:
    all_metrics, good_images = generate_samples_labelwise(
        cvae, num_samples_per_class=500, base_dir=base_dir, 
        good_images_dir=good_images_dir, latent_dim=latent_dim, device=device
    )
    print("Finished generating 500 enhanced samples per class and computing metrics")
except Exception as e:
    print(f"Error during sample generation: {str(e)}")

# Step 3: Analyze Metrics and Adjust Thresholds

# Print metrics for each class
for class_label in range(num_classes):
    print(f"\nClass {class_label} Metrics:")
    for metric in all_metrics[class_label]:
        print(f"Image: {metric['image']}, PSNR: {metric['psnr']:.2f}, SSIM: {metric['ssim']:.4f}, JS: {metric['js']:.4f}")
    print(f"Initially selected good images for class {class_label}: {len(good_images[class_label])}")

# Collect all metrics to determine new thresholds
all_psnr = [metric['psnr'] for class_label in range(num_classes) for metric in all_metrics[class_label]]
all_ssim = [metric['ssim'] for class_label in range(num_classes) for metric in all_metrics[class_label]]
all_js = [metric['js'] for class_label in range(num_classes) for metric in all_metrics[class_label]]

# Compute percentiles for new thresholds (aim for top 20% for PSNR and SSIM, bottom 20% for JS)
psnr_threshold = np.percentile(all_psnr, 80)  # 80th percentile (top 20%)
ssim_threshold = np.percentile(all_ssim, 80)  # 80th percentile (top 20%)
js_threshold = np.percentile(all_js, 20)      # 20th percentile (bottom 20%)

print(f"\nNew Thresholds Based on Percentiles:")
print(f"PSNR > {psnr_threshold:.2f}")
print(f"SSIM > {ssim_threshold:.4f}")
print(f"JS < {js_threshold:.4f}")

# Step 4: Re-filter Good Images with New Thresholds
good_images_new = {class_label: [] for class_label in range(num_classes)}
for class_label in range(num_classes):
    class_dir = os.path.join(base_dir, str(class_label))
    good_class_dir = os.path.join(good_images_dir, str(class_label))
    os.makedirs(good_class_dir, exist_ok=True)

    for metric in all_metrics[class_label]:
        image_name = metric['image']
        avg_psnr = metric['psnr']
        avg_ssim = metric['ssim']
        avg_js = metric['js']

        # Check if the image meets the new criteria
        if avg_psnr > psnr_threshold and avg_ssim > ssim_threshold and avg_js < js_threshold:
            # Copy the image to the good images directory (if not already there)
            src_path = os.path.join(class_dir, image_name)
            dst_path = os.path.join(good_class_dir, image_name)
            if not os.path.exists(dst_path):
                pil_image = Image.open(src_path)
                pil_image.save(dst_path)
            good_images_new[class_label].append(image_name)

# Print the final number of good images
for class_label in range(num_classes):
    print(f"Final good images for class {class_label}: {len(good_images_new[class_label])} images saved to {os.path.join(good_images_dir, str(class_label))}")

# Clear GPU memory
if device.type == 'cuda':
    torch.cuda.empty_cache()

Using device: cuda:1
GPU Name: NVIDIA RTX A5000
GPU Memory Allocated: 0.00 MB


  cvae.load_state_dict(torch.load(model_path, map_location=device))


Loaded CVAE model from cvae_vehicle_final.pth
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']
Generating and processing samples for class 0...
Generated 500 samples for class 0
Generating and processing samples for class 1...
Generated 500 samples for class 1
Generating and processing samples for class 2...
Generated 500 samples for class 2
Generating and processing samples for class 3...
Generated 500 samples for class 3
Generating and processing samples for class 4...
Generated 500 samples for class 4
Finished generating 500 enhanced samples per class and computing metrics

Class 0 Metrics:
Image: sample_0.png, PSNR: 7.54, SSIM: 0.0163, JS: 0.3810
Image: sample_1.png, PSNR: 6.92, SSIM: 0.0207, JS: 0.3628
Image: sample_2.png, PSNR: 9.76, SSIM: 0.0919, JS: 0.3765
Image: sample_3.png, PSNR: 9.35, SSIM: 0.0765, JS: 0.3468
Image: sample_4.png, PSNR: 7.78, SSIM: 0.0290, JS: 0.4008
Image: sample_5.png, PSNR: 6.81, SSIM: 0.0345, JS: 0.3577
Image:

In [17]:
##SORTING GOOG IMAGES OF 500SAMPLES_SNEHA_CODE

In [18]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
# Correct the import path for jensenshannon
from scipy.spatial.distance import jensenshannon
import kagglehub

# Verify scipy version - keep this check for good practice, though importing from spatial.distance
# is more robust across versions where the function might be in both or just spatial.distance.
import scipy
print(f"Using scipy version: {scipy.__version__}")
# The check for > 1.7 is relevant if you strictly wanted it in scipy.stats.
# Since we import from spatial.distance, any version >= 1.6.0 should work.
# Let's adjust the check to reflect the actual requirement for scipy.spatial.distance.jensenshannon
from pkg_resources import parse_version
if parse_version(scipy.__version__) < parse_version("1.6.0"):
     raise ImportError(f"scipy version must be >= 1.6.0 to use jensenshannon from scipy.spatial.distance. Please upgrade scipy using 'pip install --upgrade scipy'. Current version: {scipy.__version__}")


# Define paths
generated_base_dir = r"C:\Users\hp\Desktop\fedavg\generated_samples-arch-A5-500"
good_images_dir = r"C:\Users\hp\Desktop\fedavg\good_images-arch-A5-500"
os.makedirs(good_images_dir, exist_ok=True)

# Hyperparameters
RESIZE = 128
num_classes = 5
original_dim = RESIZE * RESIZE * 3

# Custom Dataset for Kaggle Vehicle Type Image Dataset
class VehicleTypeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = []
        self.class_to_idx = {}
        self.images = []
        self.labels = []

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.classes.append(class_name)
                    self.class_to_idx[class_name] = len(self.classes) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    try:
                        Image.open(img_path).verify()
                        self.images.append(img_path)
                        self.labels.append(self.class_to_idx[class_name])
                    except:
                        print(f"Skipping corrupted image: {img_path}")

        if len(self.classes) != num_classes:
            raise ValueError(f"Expected {num_classes} classes, found {len(self.classes)}")
        print(f"Found {len(self.images)} images across {len(self.classes)} classes.")
        print(f"Classes: {self.classes}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample

# Load a subset of real images as references (5 per class)
transform = transforms.Compose([
    transforms.Resize((RESIZE, RESIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
dataset = VehicleTypeDataset(root_dir=path, transform=transform)

# Collect reference images (5 per class)
reference_images = {i: [] for i in range(num_classes)}
for sample in dataset:
    label = sample['label']
    if len(reference_images[label]) < 5:
        reference_images[label].append(sample['image'])
    if all(len(refs) >= 5 for refs in reference_images.values()):
        break

# Function to compute PSNR, SSIM, and JS divergence
def compute_metrics_for_image(orig, gen):
    # Convert tensors to numpy arrays and rescale to [0, 1]
    orig_np = orig.cpu().numpy() * 0.5 + 0.5  # Denormalize
    gen_np = gen.cpu().numpy() * 0.5 + 0.5    # Denormalize
    orig_np = np.transpose(orig_np, (1, 2, 0))  # (C, H, W) -> (H, W, C)
    gen_np = np.transpose(gen_np, (1, 2, 0))    # (C, H, W) -> (H, W, C)
    orig_np = np.clip(orig_np, 0, 1)
    gen_np = np.clip(gen_np, 0, 1)

    # Compute PSNR
    psnr_score = psnr(orig_np, gen_np, data_range=1.0)

    # Compute SSIM (multichannel=True for RGB images)
    ssim_score = ssim(orig_np, gen_np, multichannel=True, data_range=1.0, channel_axis=2)

    # Compute JS divergence per channel
    js_scores = []
    for c in range(3):  # For each RGB channel
        orig_channel = orig_np[:, :, c].flatten()
        gen_channel = gen_np[:, :, c].flatten()
        # Normalize to probability distributions
        orig_channel = (orig_channel + 1e-10) / (np.sum(orig_channel) + 1e-10 * RESIZE * RESIZE)
        gen_channel = (gen_channel + 1e-10) / (np.sum(gen_channel) + 1e-10 * RESIZE * RESIZE)
        js_score = jensenshannon(orig_channel, gen_channel)
        if np.isnan(js_score):
            js_score = 1.0  # Worst case if computation fails
        js_scores.append(js_score)
    js_score = np.mean(js_scores)

    return psnr_score, ssim_score, js_score

# Process generated images and compute metrics
metrics_dict = {i: [] for i in range(num_classes)}  # Store metrics for each class
good_images = {i: [] for i in range(num_classes)}   # Store good images for each class

for class_label in range(num_classes):
    class_dir = os.path.join(generated_base_dir, str(class_label))
    good_class_dir = os.path.join(good_images_dir, str(class_label))
    os.makedirs(good_class_dir, exist_ok=True)

    if not os.path.exists(class_dir):
        print(f"Class directory {class_dir} not found, skipping class {class_label}.")
        continue

    # Reference images for this class
    ref_images = reference_images[class_label]  # List of 5 reference tensors

    # Process each generated image
    idx = 0
    while idx < 500:  # Attempt to process up to 100 samples per class
        img_path = os.path.join(class_dir, f"sample_{idx}.png")
        if not os.path.exists(img_path):
            print(f"Image {img_path} not found, stopping at index {idx} for class {class_label}.")
            break

        # Load and preprocess the generated image
        try:
            gen_image = Image.open(img_path).convert('RGB')
            gen_tensor = transform(gen_image)
        except Exception as e:
            print(f"Error loading image {img_path}: {str(e)}")
            idx += 1
            continue

        # Compute metrics against each reference image and average
        psnr_scores, ssim_scores, js_scores = [], [], []
        for ref_tensor in ref_images:
            try:
                psnr_score, ssim_score, js_score = compute_metrics_for_image(ref_tensor, gen_tensor)
                psnr_scores.append(psnr_score)
                ssim_scores.append(ssim_score)
                js_scores.append(js_score)
            except Exception as e:
                print(f"Error computing metrics for image {img_path}: {str(e)}")
                psnr_scores.append(0.0)  # Worst case
                ssim_scores.append(0.0)  # Worst case
                js_scores.append(1.0)    # Worst case

        avg_psnr = np.mean(psnr_scores)
        avg_ssim = np.mean(ssim_scores)
        avg_js = np.mean(js_scores)

        metrics_dict[class_label].append({
            'image': f"sample_{idx}.png",
            'psnr': avg_psnr,
            'ssim': avg_ssim,
            'js': avg_js
        })

        # Check if the image meets the "good" criteria
        if avg_psnr > 11.8 and avg_ssim > 0.19 and avg_js < 0.24:
            # Save the good image to the new folder
            good_img_path = os.path.join(good_class_dir, f"sample_{idx}.png")
            gen_image.save(good_img_path)
            good_images[class_label].append(f"sample_{idx}.png")

        idx += 1

    print(f"\nClass {class_label} Metrics:")
    for metric in metrics_dict[class_label]:
        print(f"Image: {metric['image']}, PSNR: {metric['psnr']:.2f}, SSIM: {metric['ssim']:.4f}, JS: {metric['js']:.4f}")
    print(f"Good images for class {class_label}: {len(good_images[class_label])} images saved to {good_class_dir}")

# Summary of good images
total_good = sum(len(good_images[cls]) for cls in good_images)
print(f"\nTotal good images across all classes: {total_good}")

Using scipy version: 1.15.3
Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']

Class 0 Metrics:
Image: sample_0.png, PSNR: 11.75, SSIM: 0.1817, JS: 0.2421
Image: sample_1.png, PSNR: 11.37, SSIM: 0.1130, JS: 0.2580
Image: sample_2.png, PSNR: 12.03, SSIM: 0.2295, JS: 0.2540
Image: sample_3.png, PSNR: 11.88, SSIM: 0.1560, JS: 0.2490
Image: sample_4.png, PSNR: 11.25, SSIM: 0.1965, JS: 0.2783
Image: sample_5.png, PSNR: 10.42, SSIM: 0.1273, JS: 0.2967
Image: sample_6.png, PSNR: 10.91, SSIM: 0.1223, JS: 0.2909
Image: sample_7.png, PSNR: 11.46, SSIM: 0.1890, JS: 0.2681
Image: sample_8.png, PSNR: 12.32, SSIM: 0.1999, JS: 0.2613
Image: sample_9.png, PSNR: 9.44, SSIM: 0.1148, JS: 0.3252
Image: sample_10.png, PSNR: 10.61, SSIM: 0.1197, JS: 0.2362
Image: sample_11.png, PSNR: 12.53, SSIM: 0.1955, JS: 0.2415
Image: sample_12.png, PSNR: 10.79, SS

In [20]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image, ImageEnhance, ImageFilter
import os
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from scipy.spatial.distance import jensenshannon
import kagglehub
import scipy
print(f"Using scipy version: {scipy.__version__}")
from pkg_resources import parse_version
if parse_version(scipy.__version__) < parse_version("1.6.0"):
    raise ImportError(f"scipy version must be >= 1.6.0 for jensenshannon. Current version: {scipy.__version__}")

# Define paths
generated_base_dir = r"C:\Users\hp\Desktop\fedavg\generated_samples-arch-A5-500"
good_images_dir = r"C:\Users\hp\Desktop\fedavg\good_images-arch-A5-500"
os.makedirs(good_images_dir, exist_ok=True)

# Hyperparameters
RESIZE = 128
num_classes = 5
original_dim = RESIZE * RESIZE * 3

# Custom Dataset for Kaggle Vehicle Type Image Dataset
class VehicleTypeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = []
        self.class_to_idx = {}
        self.images = []
        self.labels = []

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.classes.append(class_name)
                    self.class_to_idx[class_name] = len(self.classes) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    try:
                        Image.open(img_path).verify()
                        self.images.append(img_path)
                        self.labels.append(self.class_to_idx[class_name])
                    except:
                        print(f"Skipping corrupted image: {img_path}")

        if len(self.classes) != num_classes:
            raise ValueError(f"Expected {num_classes} classes, found {len(self.classes)}")
        print(f"Found {len(self.images)} images across {len(self.classes)} classes.")
        print(f"Classes: {self.classes}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample

# Load a subset of real images as references (5 per class)
transform = transforms.Compose([
    transforms.Resize((RESIZE, RESIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
dataset = VehicleTypeDataset(root_dir=path, transform=transform)

# Collect reference images (5 per class)
reference_images = {i: [] for i in range(num_classes)}
for sample in dataset:
    label = sample['label']
    if len(reference_images[label]) < 5:
        reference_images[label].append(sample['image'])
    if all(len(refs) >= 5 for refs in reference_images.values()):
        break

# Enhanced image processing function with balanced adjustments
def enhance_image(pil_image):
    # Apply a light denoising filter to reduce noise (improves PSNR)
    pil_image = pil_image.filter(ImageFilter.MedianFilter(size=3))
    
    # Apply moderate Unsharp Masking for sharpness (reduced to avoid artifacts)
    pil_image = pil_image.filter(ImageFilter.UnsharpMask(radius=1.5, percent=100, threshold=3))
    
    # Enhance sharpness moderately (reduced factor to avoid over-sharpening)
    pil_image = ImageEnhance.Sharpness(pil_image).enhance(1.5)  # Reduced from 2.0 to 1.5
    
    # Enhance contrast slightly (helps with SSIM)
    pil_image = ImageEnhance.Contrast(pil_image).enhance(1.2)  # Reduced from 1.3 to 1.2
    
    # Enhance brightness very slightly (avoid over-brightening)
    pil_image = ImageEnhance.Brightness(pil_image).enhance(1.05)  # Reduced from 1.1 to 1.05
    
    # Enhance color to improve fidelity (helps with SSIM and JS)
    pil_image = ImageEnhance.Color(pil_image).enhance(1.1)
    
    # Optional: Super-resolution-like effect with careful scaling
    pil_image = pil_image.resize((RESIZE * 2, RESIZE * 2), Image.Resampling.LANCZOS)
    pil_image = pil_image.resize((RESIZE, RESIZE), Image.Resampling.LANCZOS)
    
    return pil_image

# Function to compute PSNR, SSIM, and JS divergence
def compute_metrics_for_image(orig, gen):
    orig_np = orig.cpu().numpy() * 0.5 + 0.5
    gen_np = gen.cpu().numpy() * 0.5 + 0.5
    orig_np = np.transpose(orig_np, (1, 2, 0))
    gen_np = np.transpose(gen_np, (1, 2, 0))
    orig_np = np.clip(orig_np, 0, 1)
    gen_np = np.clip(gen_np, 0, 1)

    psnr_score = psnr(orig_np, gen_np, data_range=1.0)
    ssim_score = ssim(orig_np, gen_np, multichannel=True, data_range=1.0, channel_axis=2)

    js_scores = []
    for c in range(3):
        orig_channel = orig_np[:, :, c].flatten()
        gen_channel = gen_np[:, :, c].flatten()
        orig_channel = (orig_channel + 1e-10) / (np.sum(orig_channel) + 1e-10 * RESIZE * RESIZE)
        gen_channel = (gen_channel + 1e-10) / (np.sum(gen_channel) + 1e-10 * RESIZE * RESIZE)
        js_score = jensenshannon(orig_channel, gen_channel)
        if np.isnan(js_score):
            js_score = 1.0
        js_scores.append(js_score)
    js_score = np.mean(js_scores)

    return psnr_score, ssim_score, js_score

# Process generated images, enhance them, and compute metrics
metrics_dict = {i: [] for i in range(num_classes)}
all_images = {i: [] for i in range(num_classes)}  # Store all images with metrics
good_images = {i: [] for i in range(num_classes)}  # Store good images

for class_label in range(num_classes):
    class_dir = os.path.join(generated_base_dir, str(class_label))
    good_class_dir = os.path.join(good_images_dir, str(class_label))
    os.makedirs(good_class_dir, exist_ok=True)

    if not os.path.exists(class_dir):
        print(f"Class directory {class_dir} not found, skipping class {class_label}.")
        continue

    ref_images = reference_images[class_label]
    idx = 0
    while idx < 500:
        img_path = os.path.join(class_dir, f"sample_{idx}.png")
        if not os.path.exists(img_path):
            print(f"Image {img_path} not found, stopping at index {idx} for class {class_label}.")
            break

        try:
            gen_image = Image.open(img_path).convert('RGB')
            # Enhance the image
            gen_image = enhance_image(gen_image)
            gen_tensor = transform(gen_image)
        except Exception as e:
            print(f"Error processing image {img_path}: {str(e)}")
            idx += 1
            continue

        # Compute metrics
        psnr_scores, ssim_scores, js_scores = [], [], []
        for ref_tensor in ref_images:
            try:
                psnr_score, ssim_score, js_score = compute_metrics_for_image(ref_tensor, gen_tensor)
                psnr_scores.append(psnr_score)
                ssim_scores.append(ssim_score)
                js_scores.append(js_score)
            except Exception as e:
                print(f"Error computing metrics for image {img_path}: {str(e)}")
                psnr_scores.append(0.0)
                ssim_scores.append(0.0)
                js_scores.append(1.0)

        avg_psnr = np.mean(psnr_scores)
        avg_ssim = np.mean(ssim_scores)
        avg_js = np.mean(js_scores)

        all_images[class_label].append({
            'image': f"sample_{idx}.png",
            'psnr': avg_psnr,
            'ssim': avg_ssim,
            'js': avg_js,
            'pil_image': gen_image  # Store the enhanced image for saving if needed
        })

        idx += 1

    # Sort images by a combined score to select the top 100
    # Combined score: Higher PSNR and SSIM, lower JS
    all_images[class_label].sort(key=lambda x: (x['psnr'] + x['ssim'] * 10 - x['js'] * 10), reverse=True)

    # Select top 100 images and save those that meet the criteria
    selected_images = all_images[class_label][:100]  # Take top 100
    for metric in selected_images:
        avg_psnr = metric['psnr']
        avg_ssim = metric['ssim']
        avg_js = metric['js']
        img_name = metric['image']
        gen_image = metric['pil_image']

        # Save metrics for display
        metrics_dict[class_label].append({
            'image': img_name,
            'psnr': avg_psnr,
            'ssim': avg_ssim,
            'js': avg_js
        })

        # Check if the image meets the "good" criteria
        if avg_psnr > 11.8 and avg_ssim > 0.19 and avg_js < 0.24:
            good_img_path = os.path.join(good_class_dir, img_name)
            gen_image.save(good_img_path)
            good_images[class_label].append(img_name)

    print(f"\nClass {class_label} Metrics (Top 100):")
    for metric in metrics_dict[class_label][:100]:  # Display top 100
        print(f"Image: {metric['image']}, PSNR: {metric['psnr']:.2f}, SSIM: {metric['ssim']:.4f}, JS: {metric['js']:.4f}")
    print(f"Good images for class {class_label}: {len(good_images[class_label])} images saved to {good_class_dir}")

# Summary of good images
total_good = sum(len(good_images[cls]) for cls in good_images)
print(f"\nTotal good images across all classes: {total_good}")

Using scipy version: 1.15.3
Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']

Class 0 Metrics (Top 100):
Image: sample_420.png, PSNR: 11.33, SSIM: 0.1461, JS: 0.2814
Image: sample_180.png, PSNR: 11.39, SSIM: 0.1491, JS: 0.3037
Image: sample_13.png, PSNR: 11.26, SSIM: 0.1609, JS: 0.3033
Image: sample_42.png, PSNR: 10.73, SSIM: 0.1888, JS: 0.2793
Image: sample_24.png, PSNR: 10.75, SSIM: 0.1778, JS: 0.2847
Image: sample_236.png, PSNR: 11.29, SSIM: 0.1298, JS: 0.2926
Image: sample_11.png, PSNR: 11.27, SSIM: 0.1341, JS: 0.2980
Image: sample_179.png, PSNR: 10.79, SSIM: 0.1511, JS: 0.2698
Image: sample_100.png, PSNR: 10.77, SSIM: 0.1607, JS: 0.2778
Image: sample_229.png, PSNR: 11.16, SSIM: 0.1498, JS: 0.3077
Image: sample_417.png, PSNR: 10.98, SSIM: 0.1588, JS: 0.3004
Image: sample_499.png, PSNR: 11.15, SSIM: 0.1299, JS: 0.2899
Image: s

In [21]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image, ImageEnhance, ImageFilter
import os
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from scipy.spatial.distance import jensenshannon
import kagglehub

# Define paths
generated_base_dir = r"C:\Users\hp\Desktop\fedavg\generated_samples-arch-A5-500"
good_images_dir = r"C:\Users\hp\Desktop\fedavg\good_images-arch-A5-500"
os.makedirs(good_images_dir, exist_ok=True)

# Hyperparameters
RESIZE = 128
num_classes = 5
original_dim = RESIZE * RESIZE * 3

# Custom Dataset for Kaggle Vehicle Type Image Dataset
class VehicleTypeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = []
        self.class_to_idx = {}
        self.images = []
        self.labels = []

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.classes.append(class_name)
                    self.class_to_idx[class_name] = len(self.classes) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    try:
                        Image.open(img_path).verify()
                        self.images.append(img_path)
                        self.labels.append(self.class_to_idx[class_name])
                    except:
                        print(f"Skipping corrupted image: {img_path}")

        if len(self.classes) != num_classes:
            raise ValueError(f"Expected {num_classes} classes, found {len(self.classes)}")
        print(f"Found {len(self.images)} images across {len(self.classes)} classes.")
        print(f"Classes: {self.classes}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample

# Load a subset of real images as references (5 per class)
transform = transforms.Compose([
    transforms.Resize((RESIZE, RESIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
dataset = VehicleTypeDataset(root_dir=path, transform=transform)

# Collect reference images (5 per class)
reference_images = {i: [] for i in range(num_classes)}
for sample in dataset:
    label = sample['label']
    if len(reference_images[label]) < 5:
        reference_images[label].append(sample['image'])
    if all(len(refs) >= 5 for refs in reference_images.values()):
        break

# Enhanced image processing function with focus on SSIM
def enhance_image(pil_image):
    # Reduce denoising to preserve structural details (improves SSIM)
    pil_image = pil_image.filter(ImageFilter.MedianFilter(size=1))  # Reduced from size=3
    
    # Apply edge-preserving sharpening
    pil_image = pil_image.filter(ImageFilter.UnsharpMask(radius=1.5, percent=120, threshold=3))  # Increased percent to 120
    
    # Enhance sharpness moderately
    pil_image = ImageEnhance.Sharpness(pil_image).enhance(1.5)
    
    # Increase contrast to emphasize structure (improves SSIM)
    pil_image = ImageEnhance.Contrast(pil_image).enhance(1.3)  # Increased from 1.2 to 1.3
    
    # Keep brightness adjustment minimal
    pil_image = ImageEnhance.Brightness(pil_image).enhance(1.05)
    
    # Enhance color to improve fidelity
    pil_image = ImageEnhance.Color(pil_image).enhance(1.1)
    
    # Super-resolution-like effect
    pil_image = pil_image.resize((RESIZE * 2, RESIZE * 2), Image.Resampling.LANCZOS)
    pil_image = pil_image.resize((RESIZE, RESIZE), Image.Resampling.LANCZOS)
    
    return pil_image

# Function to compute PSNR, SSIM, and JS divergence
def compute_metrics_for_image(orig, gen):
    orig_np = orig.cpu().numpy() * 0.5 + 0.5
    gen_np = gen.cpu().numpy() * 0.5 + 0.5
    orig_np = np.transpose(orig_np, (1, 2, 0))
    gen_np = np.transpose(gen_np, (1, 2, 0))
    orig_np = np.clip(orig_np, 0, 1)
    gen_np = np.clip(gen_np, 0, 1)

    psnr_score = psnr(orig_np, gen_np, data_range=1.0)
    ssim_score = ssim(orig_np, gen_np, multichannel=True, data_range=1.0, channel_axis=2)

    js_scores = []
    for c in range(3):
        orig_channel = orig_np[:, :, c].flatten()
        gen_channel = gen_np[:, :, c].flatten()
        orig_channel = (orig_channel + 1e-10) / (np.sum(orig_channel) + 1e-10 * RESIZE * RESIZE)
        gen_channel = (gen_channel + 1e-10) / (np.sum(gen_channel) + 1e-10 * RESIZE * RESIZE)
        js_score = jensenshannon(orig_channel, gen_channel)
        if np.isnan(js_score):
            js_score = 1.0
        js_scores.append(js_score)
    js_score = np.mean(js_scores)

    return psnr_score, ssim_score, js_score

# Process generated images, enhance them, and compute metrics
metrics_dict = {i: [] for i in range(num_classes)}
all_images = {i: [] for i in range(num_classes)}  # Store all images with metrics
good_images = {i: [] for i in range(num_classes)}  # Store good images

for class_label in range(num_classes):
    class_dir = os.path.join(generated_base_dir, str(class_label))
    good_class_dir = os.path.join(good_images_dir, str(class_label))
    os.makedirs(good_class_dir, exist_ok=True)

    if not os.path.exists(class_dir):
        print(f"Class directory {class_dir} not found, skipping class {class_label}.")
        continue

    ref_images = reference_images[class_label]
    idx = 0
    while idx < 500:
        img_path = os.path.join(class_dir, f"sample_{idx}.png")
        if not os.path.exists(img_path):
            print(f"Image {img_path} not found, stopping at index {idx} for class {class_label}.")
            break

        try:
            gen_image = Image.open(img_path).convert('RGB')
            # Enhance the image
            gen_image = enhance_image(gen_image)
            gen_tensor = transform(gen_image)
        except Exception as e:
            print(f"Error processing image {img_path}: {str(e)}")
            idx += 1
            continue

        # Compute metrics
        psnr_scores, ssim_scores, js_scores = [], [], []
        for ref_tensor in ref_images:
            try:
                psnr_score, ssim_score, js_score = compute_metrics_for_image(ref_tensor, gen_tensor)
                psnr_scores.append(psnr_score)
                ssim_scores.append(ssim_score)
                js_scores.append(js_score)
            except Exception as e:
                print(f"Error computing metrics for image {img_path}: {str(e)}")
                psnr_scores.append(0.0)
                ssim_scores.append(0.0)
                js_scores.append(1.0)

        avg_psnr = np.mean(psnr_scores)
        avg_ssim = np.mean(ssim_scores)
        avg_js = np.mean(js_scores)

        all_images[class_label].append({
            'image': f"sample_{idx}.png",
            'psnr': avg_psnr,
            'ssim': avg_ssim,
            'js': avg_js,
            'pil_image': gen_image
        })

        idx += 1

    # Sort images by a combined score to select the top 100
    all_images[class_label].sort(key=lambda x: (x['psnr'] + x['ssim'] * 10 - x['js'] * 10), reverse=True)

    # Select top 100 images and save those that meet the criteria
    selected_images = all_images[class_label][:100]
    for metric in selected_images:
        avg_psnr = metric['psnr']
        avg_ssim = metric['ssim']
        avg_js = metric['js']
        img_name = metric['image']
        gen_image = metric['pil_image']

        metrics_dict[class_label].append({
            'image': img_name,
            'psnr': avg_psnr,
            'ssim': avg_ssim,
            'js': avg_js
        })

        # Check if the image meets the "good" criteria
        if avg_psnr > 11.8 and avg_ssim > 0.19 and avg_js < 0.24:
            good_img_path = os.path.join(good_class_dir, img_name)
            gen_image.save(good_img_path)
            good_images[class_label].append(img_name)

    print(f"\nClass {class_label} Metrics (Top 100):")
    for metric in metrics_dict[class_label][:100]:
        print(f"Image: {metric['image']}, PSNR: {metric['psnr']:.2f}, SSIM: {metric['ssim']:.4f}, JS: {metric['js']:.4f}")
    print(f"Good images for class {class_label}: {len(good_images[class_label])} images saved to {good_class_dir}")

# Summary of good images
total_good = sum(len(good_images[cls]) for cls in good_images)
print(f"\nTotal good images across all classes: {total_good}")

Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']

Class 0 Metrics (Top 100):
Image: sample_42.png, PSNR: 10.13, SSIM: 0.1515, JS: 0.2996
Image: sample_180.png, PSNR: 10.67, SSIM: 0.1137, JS: 0.3332
Image: sample_420.png, PSNR: 10.53, SSIM: 0.1075, JS: 0.3144
Image: sample_24.png, PSNR: 10.12, SSIM: 0.1410, JS: 0.3074
Image: sample_13.png, PSNR: 10.58, SSIM: 0.1225, JS: 0.3377
Image: sample_100.png, PSNR: 10.17, SSIM: 0.1252, JS: 0.3011
Image: sample_417.png, PSNR: 10.42, SSIM: 0.1275, JS: 0.3320
Image: sample_179.png, PSNR: 10.16, SSIM: 0.1159, JS: 0.2962
Image: sample_264.png, PSNR: 10.60, SSIM: 0.1003, JS: 0.3270
Image: sample_32.png, PSNR: 10.47, SSIM: 0.1014, JS: 0.3181
Image: sample_11.png, PSNR: 10.55, SSIM: 0.1012, JS: 0.3275
Image: sample_499.png, PSNR: 10.47, SSIM: 0.0990, JS: 0.3202
Image: sample_229.png, PSNR: 10.47, S

In [23]:
##original_images_parameters

In [22]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import os
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from scipy.spatial.distance import jensenshannon
import kagglehub

# Define paths
path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")

# Hyperparameters
RESIZE = 128
num_classes = 5

# Custom Dataset for Kaggle Vehicle Type Image Dataset
class VehicleTypeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = []
        self.class_to_idx = {}
        self.images = []
        self.labels = []

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.classes.append(class_name)
                    self.class_to_idx[class_name] = len(self.classes) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    try:
                        Image.open(img_path).verify()
                        self.images.append(img_path)
                        self.labels.append(self.class_to_idx[class_name])
                    except:
                        print(f"Skipping corrupted image: {img_path}")

        if len(self.classes) != num_classes:
            raise ValueError(f"Expected {num_classes} classes, found {len(self.classes)}")
        print(f"Found {len(self.images)} images across {len(self.classes)} classes.")
        print(f"Classes: {self.classes}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        sample = {'image': image, 'label': label, 'path': img_path}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample

# Transform for resizing and normalizing images
transform = transforms.Compose([
    transforms.Resize((RESIZE, RESIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load dataset
dataset = VehicleTypeDataset(root_dir=path, transform=transform)

# Organize images by class
images_by_class = {i: [] for i in range(num_classes)}
for idx in range(len(dataset)):
    sample = dataset[idx]
    label = sample['label']
    images_by_class[label].append((sample['image'], sample['path']))

# Function to compute PSNR, SSIM, and JS divergence between two images
def compute_metrics_for_image(orig, gen):
    orig_np = orig.cpu().numpy() * 0.5 + 0.5
    gen_np = gen.cpu().numpy() * 0.5 + 0.5
    orig_np = np.transpose(orig_np, (1, 2, 0))
    gen_np = np.transpose(gen_np, (1, 2, 0))
    orig_np = np.clip(orig_np, 0, 1)
    gen_np = np.clip(gen_np, 0, 1)

    psnr_score = psnr(orig_np, gen_np, data_range=1.0)
    ssim_score = ssim(orig_np, gen_np, multichannel=True, data_range=1.0, channel_axis=2)

    js_scores = []
    for c in range(3):
        orig_channel = orig_np[:, :, c].flatten()
        gen_channel = gen_np[:, :, c].flatten()
        orig_channel = (orig_channel + 1e-10) / (np.sum(orig_channel) + 1e-10 * RESIZE * RESIZE)
        gen_channel = (gen_channel + 1e-10) / (np.sum(gen_channel) + 1e-10 * RESIZE * RESIZE)
        js_score = jensenshannon(orig_channel, gen_channel)
        if np.isnan(js_score):
            js_score = 1.0
        js_scores.append(js_score)
    js_score = np.mean(js_scores)

    return psnr_score, ssim_score, js_score

# Compute metrics for all images in each class
metrics_by_class = {i: [] for i in range(num_classes)}

for class_label in range(num_classes):
    class_images = images_by_class[class_label]
    if len(class_images) < 2:
        print(f"Class {class_label} has {len(class_images)} images, not enough for comparison. Skipping.")
        continue

    # Use the first image in the class as the reference
    ref_image, ref_path = class_images[0]
    print(f"\nClass {class_label} ({dataset.classes[class_label]}): Reference Image = {os.path.basename(ref_path)}, Total Images = {len(class_images)}")

    for img_tensor, img_path in class_images:
        if img_tensor is ref_image:  # Skip the reference image itself to avoid perfect scores
            continue
        try:
            psnr_score, ssim_score, js_score = compute_metrics_for_image(ref_image, img_tensor)
            metrics_by_class[class_label].append({
                'image': os.path.basename(img_path),
                'psnr': psnr_score,
                'ssim': ssim_score,
                'js': js_score
            })
        except Exception as e:
            print(f"Error computing metrics for image {img_path}: {str(e)}")

    # Compute max, min, and mean
    if metrics_by_class[class_label]:
        psnr_values = [m['psnr'] for m in metrics_by_class[class_label]]
        ssim_values = [m['ssim'] for m in metrics_by_class[class_label]]
        js_values = [m['js'] for m in metrics_by_class[class_label]]

        print(f"Class {class_label} ({dataset.classes[class_label]}) Statistics:")
        print(f"PSNR - Max: {max(psnr_values):.2f}, Min: {min(psnr_values):.2f}, Mean: {np.mean(psnr_values):.2f}")
        print(f"SSIM - Max: {max(ssim_values):.4f}, Min: {min(ssim_values):.4f}, Mean: {np.mean(ssim_values):.4f}")
        print(f"JS   - Max: {max(js_values):.4f}, Min: {min(js_values):.4f}, Mean: {np.mean(js_values):.4f}")
    else:
        print(f"No metrics computed for class {class_label}.")

Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']

Class 0 (Hatchback): Reference Image = PHOTO_0.jpg, Total Images = 602
Class 0 (Hatchback) Statistics:
PSNR - Max: 39.06, Min: 6.29, Mean: 10.91
SSIM - Max: 0.9899, Min: 0.0267, Mean: 0.1556
JS   - Max: 0.3581, Min: 0.0114, Mean: 0.2619

Class 1 (Other): Reference Image = PHOTO_0.jpg, Total Images = 600
Class 1 (Other) Statistics:
PSNR - Max: 37.46, Min: 7.67, Mean: 11.44
SSIM - Max: 0.9915, Min: 0.0513, Mean: 0.1294
JS   - Max: 0.3745, Min: 0.0151, Mean: 0.2822

Class 2 (Pickup): Reference Image = PHOTO_0.jpg, Total Images = 1689
Class 2 (Pickup) Statistics:
PSNR - Max: 49.59, Min: 6.77, Mean: 8.27
SSIM - Max: 0.9986, Min: 0.0069, Mean: 0.0669
JS   - Max: 0.3758, Min: 0.0029, Mean: 0.3134

Class 3 (Seden): Reference Image = PHOTO_0.jpg, Total Images = 1222
Class 3 (Seden) Statist

In [26]:
##adding other para and checking parameters of dataset

In [25]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import os
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from scipy.spatial.distance import jensenshannon
from scipy import stats
import kagglehub

# Define paths
path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")

# Hyperparameters
RESIZE = 128
num_classes = 5

# Custom Dataset for Kaggle Vehicle Type Image Dataset
class VehicleTypeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = []
        self.class_to_idx = {}
        self.images = []
        self.labels = []

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            image_files = [f for f in files if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            if image_files:
                class_name = os.path.basename(root)
                if class_name not in self.class_to_idx:
                    self.classes.append(class_name)
                    self.class_to_idx[class_name] = len(self.classes) - 1
                for img_file in image_files:
                    img_path = os.path.join(root, img_file)
                    try:
                        Image.open(img_path).verify()
                        self.images.append(img_path)
                        self.labels.append(self.class_to_idx[class_name])
                    except:
                        print(f"Skipping corrupted image: {img_path}")

        if len(self.classes) != num_classes:
            raise ValueError(f"Expected {num_classes} classes, found {len(self.classes)}")
        print(f"Found {len(self.images)} images across {len(self.classes)} classes.")
        print(f"Classes: {self.classes}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        sample = {'image': image, 'label': label, 'path': img_path}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample

# Transform for resizing and normalizing images
transform = transforms.Compose([
    transforms.Resize((RESIZE, RESIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load dataset
dataset = VehicleTypeDataset(root_dir=path, transform=transform)

# Organize images by class
images_by_class = {i: [] for i in range(num_classes)}
for idx in range(len(dataset)):
    sample = dataset[idx]
    label = sample['label']
    images_by_class[label].append((sample['image'], sample['path']))

# Function to compute PSNR, SSIM, and JS divergence between two images
def compute_metrics_for_image(orig, gen):
    orig_np = orig.cpu().numpy() * 0.5 + 0.5
    gen_np = gen.cpu().numpy() * 0.5 + 0.5
    orig_np = np.transpose(orig_np, (1, 2, 0))
    gen_np = np.transpose(gen_np, (1, 2, 0))
    orig_np = np.clip(orig_np, 0, 1)
    gen_np = np.clip(gen_np, 0, 1)

    psnr_score = psnr(orig_np, gen_np, data_range=1.0)
    ssim_score = ssim(orig_np, gen_np, multichannel=True, data_range=1.0, channel_axis=2)

    js_scores = []
    for c in range(3):
        orig_channel = orig_np[:, :, c].flatten()
        gen_channel = gen_np[:, :, c].flatten()
        orig_channel = (orig_channel + 1e-10) / (np.sum(orig_channel) + 1e-10 * RESIZE * RESIZE)
        gen_channel = (gen_channel + 1e-10) / (np.sum(gen_channel) + 1e-10 * RESIZE * RESIZE)
        js_score = jensenshannon(orig_channel, gen_channel)
        if np.isnan(js_score):
            js_score = 1.0
        js_scores.append(js_score)
    js_score = np.mean(js_scores)

    return psnr_score, ssim_score, js_score

# Compute metrics for all images in each class
metrics_by_class = {i: [] for i in range(num_classes)}

for class_label in range(num_classes):
    class_images = images_by_class[class_label]
    if len(class_images) < 2:
        print(f"Class {class_label} has {len(class_images)} images, not enough for comparison. Skipping.")
        continue

    # Use the first image in the class as the reference
    ref_image, ref_path = class_images[0]
    print(f"\nClass {class_label} ({dataset.classes[class_label]}): Reference Image = {os.path.basename(ref_path)}, Total Images = {len(class_images)}")

    for img_tensor, img_path in class_images:
        if img_tensor is ref_image:  # Skip the reference image itself to avoid perfect scores
            continue
        try:
            psnr_score, ssim_score, js_score = compute_metrics_for_image(ref_image, img_tensor)
            metrics_by_class[class_label].append({
                'image': os.path.basename(img_path),
                'psnr': psnr_score,
                'ssim': ssim_score,
                'js': js_score
            })
        except Exception as e:
            print(f"Error computing metrics for image {img_path}: {str(e)}")

    # Compute detailed statistics
    if metrics_by_class[class_label]:
        psnr_values = np.array([m['psnr'] for m in metrics_by_class[class_label]])
        ssim_values = np.array([m['ssim'] for m in metrics_by_class[class_label]])
        js_values = np.array([m['js'] for m in metrics_by_class[class_label]])

        # Compute median
        psnr_median = np.median(psnr_values)
        ssim_median = np.median(ssim_values)
        js_median = np.median(js_values)

        # Compute mode (approximate by binning for continuous data)
        psnr_mode = stats.mode(np.round(psnr_values, 1), keepdims=False)[0]
        ssim_mode = stats.mode(np.round(ssim_values, 3), keepdims=False)[0]
        js_mode = stats.mode(np.round(js_values, 3), keepdims=False)[0]

        # Compute standard deviation
        psnr_std = np.std(psnr_values)
        ssim_std = np.std(ssim_values)
        js_std = np.std(js_values)

        # Compute range
        psnr_range = np.max(psnr_values) - np.min(psnr_values)
        ssim_range = np.max(ssim_values) - np.min(ssim_values)
        js_range = np.max(js_values) - np.min(js_values)

        # Compute IQR (Interquartile Range)
        psnr_q1, psnr_q3 = np.percentile(psnr_values, [25, 75])
        ssim_q1, ssim_q3 = np.percentile(ssim_values, [25, 75])
        js_q1, js_q3 = np.percentile(js_values, [25, 75])
        psnr_iqr = psnr_q3 - psnr_q1
        ssim_iqr = ssim_q3 - ssim_q1
        js_iqr = js_q3 - js_q1

        # Detect outliers using IQR method
        psnr_outliers = np.sum((psnr_values < (psnr_q1 - 1.5 * psnr_iqr)) | (psnr_values > (psnr_q3 + 1.5 * psnr_iqr)))
        ssim_outliers = np.sum((ssim_values < (ssim_q1 - 1.5 * ssim_iqr)) | (ssim_values > (ssim_q3 + 1.5 * ssim_iqr)))
        js_outliers = np.sum((js_values < (js_q1 - 1.5 * js_iqr)) | (js_values > (js_q3 + 1.5 * js_iqr)))

        print(f"Class {class_label} ({dataset.classes[class_label]}) Statistics:")
        print(f"PSNR - Max: {np.max(psnr_values):.2f}, Min: {np.min(psnr_values):.2f}, Mean: {np.mean(psnr_values):.2f}, Median: {psnr_median:.2f}, Mode: {psnr_mode:.2f}, Std: {psnr_std:.2f}, Range: {psnr_range:.2f}, IQR: {psnr_iqr:.2f}, Outliers: {psnr_outliers}")
        print(f"SSIM - Max: {np.max(ssim_values):.4f}, Min: {np.min(ssim_values):.4f}, Mean: {np.mean(ssim_values):.4f}, Median: {ssim_median:.4f}, Mode: {ssim_mode:.4f}, Std: {ssim_std:.4f}, Range: {ssim_range:.4f}, IQR: {ssim_iqr:.4f}, Outliers: {ssim_outliers}")
        print(f"JS   - Max: {np.max(js_values):.4f}, Min: {np.min(js_values):.4f}, Mean: {np.mean(js_values):.4f}, Median: {js_median:.4f}, Mode: {js_mode:.4f}, Std: {js_std:.4f}, Range: {js_range:.4f}, IQR: {js_iqr:.4f}, Outliers: {js_outliers}")
    else:
        print(f"No metrics computed for class {class_label}.")

Searching for images in C:\Users\hp\.cache\kagglehub\datasets\sujaykapadnis\vehicle-type-image-dataset\versions\1
Found 4793 images across 5 classes.
Classes: ['Hatchback', 'Other', 'Pickup', 'Seden', 'SUV']

Class 0 (Hatchback): Reference Image = PHOTO_0.jpg, Total Images = 602
Class 0 (Hatchback) Statistics:
PSNR - Max: 39.06, Min: 6.29, Mean: 10.91, Median: 10.70, Mode: 10.60, Std: 2.20, Range: 32.77, IQR: 1.91, Outliers: 9
SSIM - Max: 0.9899, Min: 0.0267, Mean: 0.1556, Median: 0.1430, Mode: 0.1230, Std: 0.0853, Range: 0.9632, IQR: 0.0713, Outliers: 29
JS   - Max: 0.3581, Min: 0.0114, Mean: 0.2619, Median: 0.2649, Mode: 0.2720, Std: 0.0357, Range: 0.3467, IQR: 0.0419, Outliers: 10

Class 1 (Other): Reference Image = PHOTO_0.jpg, Total Images = 600
Class 1 (Other) Statistics:
PSNR - Max: 37.46, Min: 7.67, Mean: 11.44, Median: 11.53, Mode: 11.60, Std: 1.36, Range: 29.79, IQR: 1.01, Outliers: 21
SSIM - Max: 0.9915, Min: 0.0513, Mean: 0.1294, Median: 0.1290, Mode: 0.1230, Std: 0.0433, R