In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

In [2]:
from model import GeneratorUSRGAN  # Ensure this matches your model definition

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the generator, then move its parameters onto the selected device.
# (nn.Module.to returns the module itself, so rebinding keeps the same object.)
model = GeneratorUSRGAN()
model = model.to(device)

In [3]:
# Step 3: Use Fine-Tuned Model for Inference
from torchvision.utils import save_image

# Load the fine-tuned weights into the generator created in the previous cell.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code if the
# checkpoint comes from an untrusted source. On torch >= 1.13 prefer passing
# `weights_only=True` here (it is the default from torch 2.6 onward).
model.load_state_dict(torch.load("srgan_finetuned_ssim.pth", map_location=device))
model.to(device)
model.eval()  # inference mode (affects layers such as dropout / batch-norm, if present)

# Input transform: converts a PIL image to a float tensor (C, H, W) scaled to [0, 1].
# No resize and no normalization are applied. The commented-out Resize is kept for
# reference in case inputs must match the training resolution.
# NOTE(review): the output decoding in the inference cell assumes the generator emits
# values in [-1, 1], while inputs here stay in [0, 1]. Confirm this matches training.
transform = transforms.Compose([
    # transforms.Resize((256, 256)),  # Match training resolution
    transforms.ToTensor()
])

# Collect the test images.
# - Sorted, so the processing order is deterministic across filesystems and runs.
# - The extension check is case-insensitive, so files such as "IMG_01.JPG" are not
#   silently skipped.
test_images_path = "images"  # Update this if needed
test_images = sorted(
    f for f in os.listdir(test_images_path)
    if f.lower().endswith((".png", ".jpg", ".jpeg"))
)


In [4]:
output_dir = "enhanced_images_all"
os.makedirs(output_dir, exist_ok=True)  # Ensure output directory exists

# Run every test image through the generator and write the result to output_dir,
# keeping the original file name.
for img_name in test_images:
    img_path = os.path.join(test_images_path, img_name)

    # The context manager closes the underlying file handle. convert() loads the pixels
    # and returns a separate image, so `img` stays valid after the file is closed.
    with Image.open(img_path) as src:
        img = src.convert("RGB")

    # Apply transformation and add a batch dimension
    lr_image = transform(img).unsqueeze(0).to(device)

    # Enhance image using the SRGAN model
    with torch.no_grad():
        sr_image = model(lr_image)

    # Map the generator output from [-1, 1] to uint8 [0, 255] in HWC layout.
    # Assumes a [-1, 1] output range (e.g. a tanh head); TODO confirm against the model.
    # Clamping first is essential: casting out-of-range floats to uint8 wraps around
    # and shows up as bright/dark speckle artifacts. Round rather than truncate.
    sr_array = (
        sr_image.squeeze(0)
        .clamp(-1.0, 1.0)
        .add(1.0)
        .mul(127.5)  # (x + 1) / 2 * 255
        .round()
        .to(torch.uint8)
        .permute(1, 2, 0)  # CHW -> HWC
        .contiguous()  # PIL expects a C-contiguous buffer
        .cpu()
        .numpy()
    )
    sr_image = Image.fromarray(sr_array)

    # Save enhanced image
    sr_image.save(os.path.join(output_dir, img_name))
    print(f"Saved enhanced image: {img_name}")


Saved enhanced image: FishDataset119_png.rf.5400284ea33a58a61de4f0ae9c380c13.jpg
Saved enhanced image: FishDataset119_png.rf.qwvSqHU30Lxz48hqM6yg.jpg
Saved enhanced image: FishDataset12_png.rf.12a2d25995702339cc4f2bf69d58cfaf.jpg
Saved enhanced image: FishDataset12_png.rf.AFP6H846vTfHF9mRdf6u.jpg
Saved enhanced image: FishDataset162_png.rf.e9a5507588f1af930a5e3d195026616e.jpg
Saved enhanced image: FishDataset162_png.rf.wX5frtfXCMiSyLNuhif5.jpg
Saved enhanced image: FishDataset18_png.rf.39fe8e48b9075f6edbeb115f654f883a.jpg
Saved enhanced image: FishDataset18_png.rf.CBw2vCaKeDUgCsuX8iTY.jpg
Saved enhanced image: FishDataset191_png.rf.1893f84914c84986fc8defa499245119.jpg
Saved enhanced image: FishDataset191_png.rf.gP6I7RhYtZXSPK7gNUS3.jpg
Saved enhanced image: FishDataset198_png.rf.7392fdef0ea1ac052bece31cea4ca72f.jpg
Saved enhanced image: FishDataset198_png.rf.nIf9Uqk0shTBrEcDxJzZ.jpg
Saved enhanced image: FishDataset210_png.rf.aefa3d22bf0bea68a67bf92be1bfe525.jpg
Saved enhanced image: F