In [1]:
# Project-local helpers (convert_image is not imported anywhere else, so it
# presumably comes from here). Kept first so the explicit imports below take
# precedence over anything the star import happens to re-export.
from utils import *
from datasets import SRDataset

import os
import random

import pandas as pd
import torch
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from torch.utils.data import Subset
from torchvision import transforms
from tqdm import tqdm

In [2]:
# Define the path to the test set folder and output folder
original_test_folder = "dota/test/images_large"
resized_test_folder = "dota/test/images"

# Only regular files with these extensions are treated as images. Anything else
# in the folder (sub-directories, .DS_Store, annotation files, ...) would make
# Image.open raise and abort the whole loop.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}

# Create the output folder if it doesn't exist
os.makedirs(resized_test_folder, exist_ok=True)

# Resize every image to exactly 1024x1024. A (h, w) tuple forces that exact size,
# so the aspect ratio is NOT preserved. 1024 is divisible by the x4 scaling
# factor used in the evaluation below.
resize_transform = transforms.Compose([
    transforms.Resize((1024, 1024))
])

# Sorted for a deterministic processing order across platforms.
image_names = sorted(
    name for name in os.listdir(original_test_folder)
    if os.path.isfile(os.path.join(original_test_folder, name))
    and os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
)

print("Resizing test set images...")
for image_name in tqdm(image_names):
    original_image_path = os.path.join(original_test_folder, image_name)

    # Context manager closes the source file handle after resizing
    with Image.open(original_image_path) as img:
        resized_img = resize_transform(img)

        # Save under the same file name in the resized folder
        resized_image_path = os.path.join(resized_test_folder, image_name)
        resized_img.save(resized_image_path)

print(f"Resized images are saved to {resized_test_folder}")


In [2]:
# Compute device: prefer the GPU, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Evaluation data passed to SRDataset
data_folder = "./"
test_data_names = [
    "dota/test/images",  # resized DOTA test images; other options: "Set5", "Set14", "BSDS100"
]

# Trained model checkpoints
srresnet_checkpoint = "./checkpoint_srresnet.pth.tar"
srgan_checkpoint = "./checkpoint_srgan.pth.tar"

# Evaluate SRGAN on test set

In [3]:
# Load the SRGAN generator from its checkpoint.
# map_location remaps tensors onto `device`, so a checkpoint saved on a GPU can be
# loaded on a CPU-only machine instead of failing during deserialization.
# NOTE(review): weights_only=False unpickles arbitrary objects (the checkpoint
# stores the whole module), so only load checkpoints from a trusted source.
srgan_generator = torch.load(srgan_checkpoint, map_location=device, weights_only=False)['generator'].to(device)
srgan_generator.eval()  # inference mode for layers such as batch-norm
model = srgan_generator

In [37]:
# Evaluate
def _eval_dataset(model, test_data_name, num_samples, seed):
    """Score `model` on one test dataset; return per-image (psnr_list, ssim_list)."""
    print("\nFor %s:\n" % test_data_name)

    # Custom dataloader
    test_dataset = SRDataset(data_folder,
                             split='test',
                             crop_size=0,
                             scaling_factor=4,
                             lr_img_type='imagenet-norm',
                             hr_img_type='[-1, 1]',
                             test_data_name=test_data_name)

    # Draw the evaluation subset. A dedicated, seeded Random instance gives every
    # model the *same* images, which the model comparison table depends on. It also
    # leaves the global `random` state untouched. Capping at the dataset size avoids
    # random.sample's ValueError on small test sets.
    total_samples = len(test_dataset)
    if num_samples is None or num_samples >= total_samples:
        selected_indices = list(range(total_samples))
    else:
        selected_indices = random.Random(seed).sample(range(total_samples), num_samples)
    test_dataset_subset = Subset(test_dataset, selected_indices)

    # pin_memory only helps (and only avoids a warning) when copying to a GPU
    test_loader = torch.utils.data.DataLoader(test_dataset_subset, batch_size=1, shuffle=False,
                                              num_workers=1, pin_memory=(device.type == "cuda"))
    print(f"Test set size: {len(test_dataset_subset)}")

    psnr_list = []
    ssim_list = []

    # Prohibit gradient computation explicitly
    with torch.no_grad():
        for lr_imgs, hr_imgs in test_loader:
            # Move to default device
            lr_imgs = lr_imgs.to(device)  # (batch_size (1), 3, w / 4, h / 4), imagenet-normed
            hr_imgs = hr_imgs.to(device)  # (batch_size (1), 3, w, h), in [-1, 1]

            # Forward prop.
            sr_imgs = model(lr_imgs)  # (1, 3, w, h), in [-1, 1]

            # PSNR / SSIM are computed on the y-channel images produced by convert_image
            sr_y = convert_image(sr_imgs, source='[-1, 1]', target='y-channel').squeeze(0).cpu().numpy()
            hr_y = convert_image(hr_imgs, source='[-1, 1]', target='y-channel').squeeze(0).cpu().numpy()

            psnr_list.append(peak_signal_noise_ratio(hr_y, sr_y, data_range=255.))
            ssim_list.append(structural_similarity(hr_y, sr_y, data_range=255.))

    return psnr_list, ssim_list


def test_eval(model, num_samples=469, seed=0):
    """Evaluate a super-resolution model on every dataset in ``test_data_names``.

    For each dataset, a reproducible random subset of up to ``num_samples``
    images is super-resolved by ``model``. PSNR and SSIM are computed on the
    y-channel against the high-resolution image, and per-dataset averages are
    printed.

    Args:
        model: network mapping imagenet-normalised LR images to SR images in
            [-1, 1]. It is expected to already be on ``device`` and in eval mode.
        num_samples: maximum number of images evaluated per dataset. Pass
            ``None`` to use every image.
        seed: seed for the subset sampler. Keeping it fixed guarantees all
            models are scored on the identical images. ``None`` gives an
            unseeded (non-reproducible) subset.

    Returns:
        ``(avg_psnr, avg_ssim)`` averaged over every evaluated image of every
        dataset. With a single dataset this is exactly that dataset's average.

    Raises:
        ValueError: if a dataset yields no images or ``test_data_names`` is empty.
    """
    all_psnr = []
    all_ssim = []
    for test_data_name in test_data_names:
        psnr_list, ssim_list = _eval_dataset(model, test_data_name, num_samples, seed)
        if not psnr_list:
            raise ValueError(f"No test images were evaluated for {test_data_name!r}")

        # Print average PSNR and SSIM for this dataset
        avg_psnr = sum(psnr_list) / len(psnr_list)
        avg_ssim = sum(ssim_list) / len(ssim_list)
        print(f"\nAverage PSNR: {avg_psnr:.3f}")
        print(f"Average SSIM: {avg_ssim:.3f}")

        all_psnr.extend(psnr_list)
        all_ssim.extend(ssim_list)

    # Return only after *all* datasets are evaluated. Returning inside the
    # loop silently skipped every dataset after the first.
    if not all_psnr:
        raise ValueError("test_data_names is empty; nothing was evaluated")
    return sum(all_psnr) / len(all_psnr), sum(all_ssim) / len(all_ssim)


In [38]:
# Score the SRGAN generator (bound to `model` when it was loaded)
avg_psnr_srgan, avg_ssim_srgan = test_eval(model=model)


For dota/test/images:

Test set size: 469

Average PSNR: 28.840
Average SSIM: 0.786


# Evaluate SRResNet on test set

In [39]:
# Load the SRResNet model from its checkpoint.
# map_location remaps tensors onto `device`, so a checkpoint saved on a GPU can be
# loaded on a CPU-only machine instead of failing during deserialization.
# NOTE(review): weights_only=False unpickles arbitrary objects, so only load
# trusted checkpoints.
srresnet = torch.load(srresnet_checkpoint, map_location=device, weights_only=False)['model'].to(device)
srresnet.eval()  # inference mode for layers such as batch-norm
print("Evaluating SRResNet")

Evaluating SRResNet


In [40]:
# Score the SRResNet model on the same evaluation set
avg_psnr_srresnet, avg_ssim_srresnet = test_eval(model=srresnet)


For dota/test/images:

Test set size: 469

Average PSNR: 31.880
Average SSIM: 0.865


# Create results table

In [41]:
# Gather both models' averages into one summary table (one row per model)
model_names = ["SRGAN", "SRResNet"]
results = {
    "Model": model_names,
    "Average PSNR": [avg_psnr_srgan, avg_psnr_srresnet],
    "Average SSIM": [avg_ssim_srgan, avg_ssim_srresnet],
}

df = pd.DataFrame.from_dict(results)

In [42]:
# Row styler for the results table
def highlight_rows(row):
    """Return one CSS declaration string per cell in ``row``.

    Rows whose "Model" is "SRGAN" get a light-blue background; every other row
    gets light green. All cells use an 18px bold font.
    """
    background = "lightgreen"
    if row["Model"] == "SRGAN":
        background = "lightblue"
    css = f"background-color: {background}; font-size: 18px; font-weight: bold;"
    return [css for _ in range(len(row))]

# Colour each row via highlight_rows, and show the averages with the same
# 3-decimal precision used in the printed evaluation summaries
styled_df = (
    df.style
    .apply(highlight_rows, axis=1)
    .format("{:.3f}", subset=["Average PSNR", "Average SSIM"])
)

# Global text styles and borders for header and data cells
styled_df = styled_df.set_table_styles([
    {"selector": "th", "props": [("font-size", "20px"), ("font-weight", "bold"), ("text-align", "center"), ("border", "2px solid black")]},
    {"selector": "td", "props": [("font-size", "18px"), ("font-weight", "bold"), ("text-align", "center"), ("border", "2px solid black")]},
])

# Border on the <table> element itself. Styler prefixes every table-style selector
# with the table's own id ("#T_<uuid> <selector>"), so the former "table" selector
# only matched tables nested *inside* this one (none). The border/border-collapse
# were therefore never applied. Setting them as attributes of the <table> tag fixes that.
styled_df = styled_df.set_table_attributes('style="border: 2px solid black; border-collapse: collapse;"')

# Display the styled dataframe (works in Jupyter Notebook or HTML-supported environments)
styled_df

Unnamed: 0,Model,Average PSNR,Average SSIM
0,SRGAN,28.840011,0.785837
1,SRResNet,31.879947,0.865299
