In [7]:
# Part 2: Evaluation of ResNet18 and MobileNetV2 with Wavelet Compression
import os
import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import pywt
from torchvision.models import mobilenet_v2

# Set device: use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Where the fine-tuned checkpoints and the evaluation images live.
# The defaults are the original machine-specific locations; set the
# WAVELET_MODEL_DIR / WAVELET_DATA_DIR environment variables to run this
# notebook on another machine without editing the cell.
model_dir = os.environ.get("WAVELET_MODEL_DIR", r"C:\Users\HUAWEI\Downloads")
data_dir = os.environ.get("WAVELET_DATA_DIR", r"C:\Users\HUAWEI\Downloads\mini-imagenet\test")  # Adjust to your dataset path

In [9]:
# Step 2: Load Test Dataset
# Preprocessing for 3-channel RGB images: resize to 84x84, convert to a tensor,
# then normalize each channel as (x - 0.5) / 0.5.
_preprocessing_steps = [
    transforms.Resize((84, 84)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
]
transform = transforms.Compose(_preprocessing_steps)

# Standard classification dataset for ResNet18 and MobileNetV2
class ClassificationDataset(Dataset):
    """Image-folder dataset with one sub-directory of ``root`` per class.

    Attributes set by ``__init__``:
        root: The directory that was scanned.
        transform: Optional callable applied to each loaded RGB image.
        classes: Sorted names of the sub-directories of ``root``.
        data: Image file paths, grouped by class and sorted within each class.
        labels: For each entry of ``data``, the index of its class in ``classes``.

    NOTE(review): label indices are positions in the sorted class-directory
    list; they only mean something to a model that was trained with the same
    class ordering -- confirm against the training notebook.
    """

    # Lower-case file extensions treated as images (matched case-insensitively).
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root, transform=None):
        """Scan ``root`` and index every image found in its class directories.

        Args:
            root: Directory containing one sub-directory per class.
            transform: Optional callable applied to each image in ``__getitem__``.
        """
        self.root = root
        self.transform = transform
        # Only directories are classes. Counting stray files (e.g. a README or
        # .DS_Store) as classes would shift every label index that follows.
        self.classes = sorted(
            entry for entry in os.listdir(root)
            if os.path.isdir(os.path.join(root, entry))
        )
        self.data = []
        self.labels = []
        for label, cls in enumerate(self.classes):
            cls_path = os.path.join(root, cls)
            # os.listdir order is arbitrary; sort so sample order is reproducible.
            images = sorted(
                os.path.join(cls_path, name)
                for name in os.listdir(cls_path)
                if name.lower().endswith(self.IMAGE_EXTENSIONS)
            )
            self.data.extend(images)
            self.labels.extend([label] * len(images))

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is loaded as RGB and passed through ``transform`` when one
        was given; otherwise the PIL image itself is returned.
        """
        img_path = self.data[index]
        label = self.labels[index]
        # convert() returns a new, fully loaded image, so the file handle can
        # be released as soon as the conversion is done.
        with Image.open(img_path) as img_file:
            img = img_file.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Number of indexed images."""
        return len(self.data)

# Load dataset
# Index the evaluation images, then iterate them in a fixed order, 32 per batch.
classification_dataset = ClassificationDataset(data_dir, transform=transform)
classification_loader = DataLoader(
    dataset=classification_dataset,
    shuffle=False,
    batch_size=32,
)

In [11]:
# Step 3: Load Models
def _load_finetuned(model, checkpoint_name, label):
    """Load fine-tuned weights into ``model`` and prepare it for inference.

    Args:
        model: Network whose architecture already matches the checkpoint.
        checkpoint_name: File name of the state dict inside ``model_dir``.
        label: Human-readable model name used in the status messages.

    Returns:
        The same model, moved to ``device`` and switched to eval mode.

    Raises:
        Exception: Whatever ``torch.load`` / ``load_state_dict`` raised; the
            error is printed first and then re-raised unchanged.
    """
    checkpoint_path = os.path.join(model_dir, checkpoint_name)
    try:
        # weights_only=True restricts unpickling to tensors and plain containers.
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        print(f"{label} loaded successfully.")
    except Exception as e:
        print(f"Error loading {label}: {e}")
        raise
    return model.to(device).eval()


# ResNet18 (weights=None: random init, everything comes from the checkpoint)
resnet18 = models.resnet18(weights=None)
resnet18.fc = nn.Linear(resnet18.fc.in_features, 100)  # Adjust for 100 classes
resnet18 = _load_finetuned(resnet18, "resnet18_finetuned.pth", "ResNet18")

# MobileNetV2
# Build through `models.` rather than the imported `mobilenet_v2` function:
# the name `mobilenet_v2` is rebound to the model instance below (later cells
# rely on that), so calling it directly would fail when this cell is re-run.
mobilenet_v2 = models.mobilenet_v2(weights=None)
mobilenet_v2.classifier[1] = nn.Linear(mobilenet_v2.classifier[1].in_features, 100)  # Adjust for 100 classes
mobilenet_v2 = _load_finetuned(mobilenet_v2, "mobilenet_v2_finetuned.pth", "MobileNetV2")



ResNet18 loaded successfully.
MobileNetV2 loaded successfully.


In [13]:
# Step 4: Wavelet-Based Image Compression
import numpy as np
import pywt
import torch

def apply_wavelet_compression(images, retain, data_range=1.0):
    """
    Apply wavelet compression to a batch of images.

    Every channel of every image is decomposed with a 2-level 'db1' 2-D
    wavelet transform. Coefficients whose magnitude is below that channel's
    (1 - retain) quantile are zeroed, and the channel is reconstructed from
    what is left.

    Args:
        images (torch.Tensor): Input images [N, C, H, W].
        retain (float): Fraction of wavelet coefficients to keep per channel,
            between 0 and 1 (e.g., 0.1 keeps roughly the largest 10%).
        data_range (float): Peak-to-peak signal range used as the PSNR
            reference. The default of 1.0 reproduces the original formula;
            pass 2.0 for inputs normalized to [-1, 1].

    Returns:
        torch.Tensor: Compressed images, float32, same shape and device as `images`.
        float: PSNR (dB) over the whole batch; 100 when the reconstruction is exact.
        float: Compression ratio = total coefficients / non-zero coefficients
            kept (1 when nothing was kept).
    """
    images_np = images.detach().cpu().numpy()  # [N, C, H, W]
    compressed = np.zeros_like(images_np)
    total_coeffs = 0
    retained_coeffs = 0

    for n, img in enumerate(images_np):
        height, width = img.shape[-2:]
        for c in range(img.shape[0]):  # Process every channel, not just R, G, B
            # 2-level wavelet decomposition, flattened into a single array
            coeffs = pywt.wavedec2(img[c], 'db1', level=2)
            coeff_arr, coeff_slices = pywt.coeffs_to_array(coeffs)
            total_coeffs += coeff_arr.size

            # Zero everything below the magnitude threshold for this channel
            thresh = np.percentile(np.abs(coeff_arr), 100 * (1 - retain))
            coeff_arr[np.abs(coeff_arr) < thresh] = 0
            retained_coeffs += int(np.count_nonzero(coeff_arr))

            # Reconstruct the channel. For odd-sized inputs the inverse
            # transform returns one extra row/column, so crop to the input size.
            coeffs_recon = pywt.array_to_coeffs(coeff_arr, coeff_slices, output_format='wavedec2')
            recon = pywt.waverec2(coeffs_recon, 'db1')
            compressed[n, c] = recon[:height, :width]

    # Compute compression ratio
    comp_ratio = total_coeffs / retained_coeffs if retained_coeffs > 0 else 1

    # Compute PSNR from the mean squared error over the whole batch
    diff = images_np.astype(np.float64) - compressed
    mse = float(np.mean(diff ** 2)) if diff.size else 0.0
    psnr = float(10 * np.log10(data_range ** 2 / mse)) if mse > 0 else 100

    # Convert compressed images to a tensor on the caller's device
    compressed = torch.from_numpy(compressed).to(images.device).float()

    return compressed, psnr, comp_ratio

In [15]:
# Step 5: Define Evaluation Function
def apply_wavelet_compression(images, retain, data_range=1.0):
    """Compress a batch of images by zeroing small wavelet coefficients.

    NOTE: this redefinition shadows the same-named function from the Step 4
    cell; it is the version the evaluation below actually calls.

    Every channel of every image is decomposed with a 2-level 'db1' 2-D
    wavelet transform. Coefficients whose magnitude is below that channel's
    (1 - retain) quantile are zeroed, and the channel is reconstructed.

    Args:
        images (torch.Tensor): Input images [N, C, H, W].
        retain (float): Fraction of coefficients to keep per channel, between
            0 and 1 (e.g., 0.1 keeps roughly the largest 10%).
        data_range (float): Peak-to-peak signal range used as the PSNR
            reference. The default of 1.0 reproduces the original formula;
            pass 2.0 for inputs normalized to [-1, 1].

    Returns:
        torch.Tensor: Compressed images, float32, same shape and device as `images`.
        float: PSNR (dB) over the whole batch; 100 when the reconstruction is exact.
        float: Compression ratio = total coefficients / non-zero coefficients
            kept (1 when nothing was kept).
    """
    images_np = images.detach().cpu().numpy()  # [N, C, H, W]
    compressed = np.zeros_like(images_np)
    total_coeffs = 0
    retained_coeffs = 0
    for n, img in enumerate(images_np):
        height, width = img.shape[-2:]
        for c in range(img.shape[0]):
            coeffs = pywt.wavedec2(img[c], 'db1', level=2)
            coeff_arr, coeff_slices = pywt.coeffs_to_array(coeffs)
            total_coeffs += coeff_arr.size
            thresh = np.percentile(np.abs(coeff_arr), 100 * (1 - retain))
            coeff_arr[np.abs(coeff_arr) < thresh] = 0
            retained_coeffs += int(np.count_nonzero(coeff_arr))
            coeffs_recon = pywt.array_to_coeffs(coeff_arr, coeff_slices, output_format='wavedec2')
            # Odd-sized inputs come back one row/column larger; crop to the input size.
            recon = pywt.waverec2(coeffs_recon, 'db1')
            compressed[n, c] = recon[:height, :width]
    comp_ratio = total_coeffs / retained_coeffs if retained_coeffs > 0 else 1
    diff = images_np.astype(np.float64) - compressed
    mse = float(np.mean(diff ** 2)) if diff.size else 0.0
    psnr = float(10 * np.log10(data_range ** 2 / mse)) if mse > 0 else 100
    # Return on the input's own device instead of relying on a notebook global.
    return torch.from_numpy(compressed).to(images.device).float(), psnr, comp_ratio

def evaluate_classification(model, dataloader, device, retain_percentage):
    """Measure classification accuracy on wavelet-compressed inputs.

    Each batch is compressed with ``apply_wavelet_compression`` and then
    classified by ``model``; predictions are the argmax over dimension 1 of
    the model output.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        dataloader: Iterable yielding ``(images, labels)`` tensor batches.
        device: Device the batches are moved to before compression/inference.
        retain_percentage: Fraction of wavelet coefficients to keep, passed
            straight through to ``apply_wavelet_compression``.

    Returns:
        tuple: ``(acc_mean, acc_std, psnr_val, comp_ratio)`` where
            ``acc_mean`` is the accuracy over all samples, ``acc_std`` is the
            standard deviation of the per-batch accuracies, and ``psnr_val`` /
            ``comp_ratio`` are the sample-weighted means of the per-batch
            PSNR and compression ratio.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    accuracies = []
    psnr_values = []
    compression_ratios = []
    batch_sizes = []
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            # Compress images
            images_comp, psnr, comp_ratio = apply_wavelet_compression(images, retain_percentage)
            # The compression round-trips through NumPy; make sure the result
            # is back on the model's device regardless of where it landed.
            images_comp = images_comp.to(device)

            # Forward pass
            outputs = model(images_comp)
            predictions = outputs.argmax(dim=1)

            # Per-batch statistics, plus the batch size used to weight them
            accuracies.append((predictions == labels).float().mean().item())
            psnr_values.append(psnr)
            compression_ratios.append(comp_ratio)
            batch_sizes.append(int(labels.shape[0]))

    if not batch_sizes:
        raise ValueError("dataloader yielded no batches; check the dataset path")

    # Weight by batch size so a smaller final batch does not count as much as
    # a full one; the weighted accuracy equals correct / total over all samples.
    acc_mean = float(np.average(accuracies, weights=batch_sizes))
    acc_std = float(np.std(accuracies))
    psnr_val = float(np.average(psnr_values, weights=batch_sizes))
    comp_ratio = float(np.average(compression_ratios, weights=batch_sizes))
    return acc_mean, acc_std, psnr_val, comp_ratio

In [17]:
# Step 6: Run Evaluations
retain_percentages = [0.1, 0.25, 0.5]
results = {name: {'classification': {}} for name in ('resnet18', 'mobilenet_v2')}

# (results key, printed label, model), evaluated in this order:
# every retain level for ResNet18 first, then every level for MobileNetV2.
_models_to_evaluate = (
    ('resnet18', 'ResNet18', resnet18),
    ('mobilenet_v2', 'MobileNetV2', mobilenet_v2),
)

for _key, _label, _model in _models_to_evaluate:
    for retain in retain_percentages:
        print(f"\nEvaluating {_label} with {retain*100}% coefficients retained:")
        acc_mean, acc_std, psnr_val, comp_ratio = evaluate_classification(
            _model, classification_loader, device, retain
        )
        results[_key]['classification'][retain] = dict(
            accuracy_mean=acc_mean,
            accuracy_std=acc_std,
            psnr=psnr_val,
            compression_ratio=comp_ratio,
        )
        print(f"{_label} - Accuracy: {acc_mean:.4f} ± {acc_std:.4f}, PSNR: {psnr_val:.2f} dB, Compression Ratio: {comp_ratio:.2f}")


Evaluating ResNet18 with 10.0% coefficients retained:
ResNet18 - Accuracy: 0.0123 ± 0.0527, PSNR: 19.93 dB, Compression Ratio: 9.98

Evaluating ResNet18 with 25.0% coefficients retained:
ResNet18 - Accuracy: 0.0149 ± 0.0621, PSNR: 26.93 dB, Compression Ratio: 3.99

Evaluating ResNet18 with 50.0% coefficients retained:
ResNet18 - Accuracy: 0.0144 ± 0.0588, PSNR: 36.15 dB, Compression Ratio: 2.00

Evaluating MobileNetV2 with 10.0% coefficients retained:
MobileNetV2 - Accuracy: 0.0223 ± 0.0921, PSNR: 19.93 dB, Compression Ratio: 9.98

Evaluating MobileNetV2 with 25.0% coefficients retained:
MobileNetV2 - Accuracy: 0.0163 ± 0.0632, PSNR: 26.93 dB, Compression Ratio: 3.99

Evaluating MobileNetV2 with 50.0% coefficients retained:
MobileNetV2 - Accuracy: 0.0139 ± 0.0543, PSNR: 36.15 dB, Compression Ratio: 2.00


In [19]:
# Step 7: Save and Display Results
base_dir = r"C:\Users\HUAWEI\Downloads"

# (results key, heading/label) in the order the sections are written.
_report_models = (('resnet18', 'ResNet18'), ('mobilenet_v2', 'MobileNetV2'))

# The report contains the non-ASCII "±" sign, so pin the encoding instead of
# depending on the platform's locale default.
with open(os.path.join(base_dir, "compression_results.md"), "w", encoding="utf-8") as f:
    f.write("# Wavelet-based Image Compression Results\n\n")

    for _key, _label in _report_models:
        f.write(f"## Standard Classification ({_label})\n")
        for retain in retain_percentages:
            _entry = results[_key]['classification'][retain]
            f.write(f"### Retaining {retain*100}% Coefficients\n")
            f.write("- Accuracy (Mean ± Std): {:.4f} ± {:.4f}\n".format(
                _entry['accuracy_mean'], _entry['accuracy_std']))
            f.write(f"- PSNR: {_entry['psnr']:.2f} dB\n")
            f.write(f"- Compression Ratio: {_entry['compression_ratio']:.2f}\n\n")

print("\nResults saved to compression_results.md")
print("\nSummary:")
for retain in retain_percentages:
    print(f"\nRetaining {retain*100}% Coefficients:")
    # One line per model, in the same order as the report sections.
    for _key, _label in _report_models:
        _entry = results[_key]['classification'][retain]
        print(f"{_label} - Accuracy: {_entry['accuracy_mean']:.4f} ± "
              f"{_entry['accuracy_std']:.4f}, "
              f"PSNR: {_entry['psnr']:.2f} dB, "
              f"Compression Ratio: {_entry['compression_ratio']:.2f}")


Results saved to compression_results.md

Summary:

Retaining 10.0% Coefficients:
ResNet18 - Accuracy: 0.0123 ± 0.0527, PSNR: 19.93 dB, Compression Ratio: 9.98
MobileNetV2 - Accuracy: 0.0223 ± 0.0921, PSNR: 19.93 dB, Compression Ratio: 9.98

Retaining 25.0% Coefficients:
ResNet18 - Accuracy: 0.0149 ± 0.0621, PSNR: 26.93 dB, Compression Ratio: 3.99
MobileNetV2 - Accuracy: 0.0163 ± 0.0632, PSNR: 26.93 dB, Compression Ratio: 3.99

Retaining 50.0% Coefficients:
ResNet18 - Accuracy: 0.0144 ± 0.0588, PSNR: 36.15 dB, Compression Ratio: 2.00
MobileNetV2 - Accuracy: 0.0139 ± 0.0543, PSNR: 36.15 dB, Compression Ratio: 2.00
