Code für ein Bild im PNG-Format

In [4]:
#imports

import torch
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

import seaborn as sns
import lpips


In [None]:

# 1. Generator laden
class Generator(torch.nn.Module):
    """Placeholder for the SRGAN generator architecture.

    NOTE(review): the actual layers were elided here ("#..."). They must mirror
    the trained model exactly, because ``load_state_dict`` is strict by default
    and raises a RuntimeError on missing or unexpected keys.
    """

    def __init__(self):
        super().__init__()
        # TODO: define the generator layers (must match the saved checkpoint).

    def forward(self, x):
        """Map a low-resolution batch ``x`` to its super-resolved counterpart."""
        # Fail loudly instead of silently returning None, which would only
        # surface later as a confusing AttributeError on the output tensor.
        raise NotImplementedError(
            "Generator.forward is a placeholder; insert the trained SRGAN architecture."
        )

# --- Configuration ----------------------------------------------------------
MODEL_PATH = "srgan_generator.pth"
LR_PATH = "example_lr.png"
HR_PATH = "example_hr.png"
# Optional fixed LR input size (H, W); None feeds the LR image at its native size.
# The HR reference is never resized: the metrics must compare against the
# untouched ground truth, and PSNR/SSIM/LPIPS only need SR and HR to have the
# *same* size as each other, not a fixed size.
LR_SIZE = None

# 1. Load the generator weights on the CPU.
generator = Generator()
# weights_only=True restricts unpickling to tensors/containers, which is all a
# state_dict needs (avoids executing arbitrary pickled code).
generator.load_state_dict(torch.load(MODEL_PATH, map_location="cpu", weights_only=True))
generator.eval()

# 2. Prepare the images.
# Grayscale replicated to 3 identical channels, then scaled to [0, 1] by ToTensor.
# NOTE(review): no Normalize step, so this assumes the generator was trained on
# [0, 1] inputs. TODO confirm.
to_3ch_tensor = [transforms.Grayscale(num_output_channels=3), transforms.ToTensor()]
lr_resize = [transforms.Resize(LR_SIZE)] if LR_SIZE is not None else []
transform = transforms.Compose(lr_resize + to_3ch_tensor)  # LR preprocessing
hr_transform = transforms.Compose(to_3ch_tensor)           # HR: never resized

# Load as 8-bit grayscale ("L"); the context managers close the file handles.
with Image.open(LR_PATH) as lr_file:
    lr_img = lr_file.convert("L")
with Image.open(HR_PATH) as hr_file:
    hr_img = hr_file.convert("L")

lr_tensor = transform(lr_img).unsqueeze(0)     # (1, 3, H_lr, W_lr)
hr_tensor = hr_transform(hr_img).unsqueeze(0)  # (1, 3, H_hr, W_hr)

# 3. Inference: apply the generator to the LR image to produce the SR image.
with torch.no_grad():
    sr_tensor = generator(lr_tensor)

# All metrics compare pixel by pixel, so SR and HR must have identical spatial size.
if sr_tensor.shape[-2:] != hr_tensor.shape[-2:]:
    raise ValueError(
        f"SR output size {tuple(sr_tensor.shape[-2:])} does not match HR size "
        f"{tuple(hr_tensor.shape[-2:])}; use an HR image that is exactly the "
        f"generator's upscaling factor times the LR input size."
    )

# 4. Convert to (H, W, C) numpy arrays clipped to [0, 1] for the metrics.
# [0] selects the single batch item without dropping a size-1 channel axis.
# NOTE(review): assumes the generator outputs values in [0, 1]. If its last
# layer is tanh ([-1, 1]), rescale with (x + 1) / 2 before clipping. TODO confirm.
sr_np = np.clip(sr_tensor[0].permute(1, 2, 0).cpu().numpy(), 0, 1)
hr_np = np.clip(hr_tensor[0].permute(1, 2, 0).cpu().numpy(), 0, 1)
lr_np = lr_tensor[0].permute(1, 2, 0).cpu().numpy()  # display only

psnr_value = psnr(hr_np, sr_np, data_range=1.0)
# channel_axis replaces scikit-image's deprecated `multichannel` flag.
ssim_value = ssim(hr_np, sr_np, channel_axis=-1, data_range=1.0)

print(f"PSNR: {psnr_value:.2f} dB")
print(f"SSIM: {ssim_value:.4f}")

# Side-by-side view. Every panel is a 3-channel [0, 1] array, so imshow shows it
# as an image. A single-channel PIL "L" image would instead get the default
# (viridis) colormap, i.e. false colors.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = [(lr_np, "Low Resolution"), (sr_np, "Super-Resolved"), (hr_np, "High Resolution")]
for ax, (image, title) in zip(axes, panels):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis("off")
fig.suptitle(f"PSNR: {psnr_value:.2f} dB, SSIM: {ssim_value:.4f}")
plt.show()


FileNotFoundError: [Errno 2] No such file or directory: 'srgan_generator.pth'

In [6]:
# Pixel-wise error metrics on the clipped [0, 1] images (lower is better for both):
#   MAE = mean |HR - SR|,  MSE = mean (HR - SR)^2
mae_value = np.abs(hr_np - sr_np).mean()
mse_value = np.square(hr_np - sr_np).mean()

print(f"MAE: {mae_value:.4f}")
print(f"MSE: {mse_value:.4f}")

NameError: name 'hr_np' is not defined

In [7]:
# Error map: per-pixel absolute difference between HR and SR intensities.
# The HR channels are identical copies of the grayscale image (Grayscale with
# 3 output channels). The SR channels may differ slightly, so the channel mean
# is used as the intensity of each image.
hr_gray = hr_np.mean(axis=2)
sr_gray = sr_np.mean(axis=2)
error_map = np.abs(hr_gray - sr_gray)

fig, ax = plt.subplots(figsize=(6, 5))
# square=True keeps pixels square, so the map is not stretched to the figure's aspect ratio.
sns.heatmap(error_map, cmap="viridis", cbar=True, square=True, ax=ax)
ax.set_title("Error Heatmap (|HR - SR|)")
ax.axis("off")
plt.show()


NameError: name 'hr_np' is not defined

In [8]:
# Grey-value histogram of HR vs. SR.
# Both histograms share the same 50 bins over [0, 1] (the images are clipped to
# that range). Otherwise each call derives its own edges from its data's min/max,
# and the overlaid bars would not be comparable.
hist_bin_edges = np.linspace(0.0, 1.0, 51)

fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(hr_gray.ravel(), bins=hist_bin_edges, alpha=0.5, label="HR", color='red')
ax.hist(sr_gray.ravel(), bins=hist_bin_edges, alpha=0.5, label="SR", color='blue')
ax.set_title("Histogramm der Grauwertverteilung")
ax.set_xlabel("Intensität")
ax.set_ylabel("Pixelanzahl")
ax.legend()
ax.grid(True)
plt.show()

NameError: name 'hr_gray' is not defined

<Figure size 800x500 with 0 Axes>

In [10]:
# LPIPS compares images in the feature space of a pretrained network rather
# than pixel by pixel. The backbone weights are downloaded on first use.
LPIPS_BACKBONE = "vgg"  # alternatives: "alex", "squeeze"
lpips_fn = lpips.LPIPS(net=LPIPS_BACKBONE)

def to_lpips_tensor(img_np):
    """Convert an image array in [0, 1] into an LPIPS input tensor.

    LPIPS expects float tensors of shape (1, 3, H, W) scaled to [-1, 1].

    Parameters
    ----------
    img_np : numpy.ndarray
        Image of shape (H, W, C), or a 2-D grayscale image of shape (H, W),
        which is replicated to 3 channels. Values are assumed to lie in [0, 1].

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (1, C, H, W) (C == 3 for grayscale input)
        with values mapped from [0, 1] to [-1, 1]. The input array is not modified.
    """
    img = np.asarray(img_np, dtype=np.float32)
    if img.ndim == 2:
        # Grayscale (H, W): replicate into the 3 channels the LPIPS backbones expect.
        img = np.repeat(img[:, :, None], 3, axis=2)
    img_tensor = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).unsqueeze(0)
    # The arithmetic creates a new tensor, so the result never aliases img_np.
    return img_tensor * 2 - 1  # [0, 1] -> [-1, 1]

# Convert both images to LPIPS input tensors and compute the perceptual distance.
sr_lpips = to_lpips_tensor(sr_np)
hr_lpips = to_lpips_tensor(hr_np)
assert sr_lpips.shape == hr_lpips.shape, "LPIPS needs SR and HR tensors of identical shape"

# Inference only: no_grad avoids building an autograd graph through the backbone.
with torch.no_grad():
    lpips_distance = lpips_fn(sr_lpips, hr_lpips).item()
print(f"LPIPS Distance: {lpips_distance:.4f}")


Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to C:\Users\LaraR/.cache\torch\hub\checkpoints\vgg16-397923af.pth
3.4%


KeyboardInterrupt: 