In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import glob
import numpy as np
from torchvision import models
from torch.cuda.amp import autocast, GradScaler
from torchvision.models import vgg16

# RRDBNet Components
class DenseBlock(nn.Module):
    """Densely connected five-conv block with a 0.2-scaled residual connection.

    conv1..conv4 each see the input concatenated with every earlier conv output
    and contribute ``growth_channels`` new maps. conv5 fuses the full stack back
    to ``in_channels`` so it can be added onto the input.
    """

    def __init__(self, in_channels, growth_channels=32):
        super(DenseBlock, self).__init__()
        # conv{k} (k = 1..4) receives in_channels + (k - 1) * growth_channels maps.
        for k in range(1, 5):
            setattr(
                self,
                f"conv{k}",
                nn.Conv2d(in_channels + (k - 1) * growth_channels, growth_channels, 3, padding=1),
            )
        self.conv5 = nn.Conv2d(in_channels + 4 * growth_channels, in_channels, 3, padding=1)
        self.lrelu = nn.LeakyReLU(0.2)

    def forward(self, x):
        feats = [x]
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(layer(torch.cat(feats, 1))))
        # No activation on the fusion conv; its output is scaled before the skip add.
        fused = self.conv5(torch.cat(feats, 1))
        return x + 0.2 * fused

class RRDB(nn.Module):
    """Three DenseBlocks applied in sequence, wrapped in a 0.2-scaled residual skip."""

    def __init__(self, in_channels, growth_channels=32):
        super(RRDB, self).__init__()
        # Registered as dense1, dense2, dense3 (in that order) so state_dict keys stay stable.
        for i in (1, 2, 3):
            setattr(self, f"dense{i}", DenseBlock(in_channels, growth_channels))

    def forward(self, x):
        y = x
        for block in (self.dense1, self.dense2, self.dense3):
            y = block(y)
        # Outer residual: shrink the trunk's contribution before adding the input back.
        return x + 0.2 * y

class RRDBNet(nn.Module):
    """RRDB generator: feature conv, RRDB trunk, then sub-pixel upsampling.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output image.
        num_rrdb_blocks: number of RRDB blocks in the trunk.
        num_features: width of the trunk feature maps (default 64, as before).
        growth_channels: growth channels inside every DenseBlock (default 32).
        upscale: spatial factor of the PixelShuffle layer (default 2, as before;
            1 keeps the input resolution).

    With the defaults the module tree and state_dict keys/shapes are unchanged.
    The output passes through tanh, so values lie in [-1, 1].

    Raises:
        ValueError: if ``upscale`` is smaller than 1.
    """

    def __init__(self, in_channels, out_channels, num_rrdb_blocks=10,
                 num_features=64, growth_channels=32, upscale=2):
        super(RRDBNet, self).__init__()
        if upscale < 1:
            raise ValueError(f"upscale must be >= 1, got {upscale}")
        self.conv_first = nn.Conv2d(in_channels, num_features, 3, padding=1)
        self.rrdb_blocks = nn.Sequential(
            *[RRDB(num_features, growth_channels) for _ in range(num_rrdb_blocks)]
        )
        self.conv_second = nn.Conv2d(num_features, num_features, 3, padding=1)
        # PixelShuffle(r) turns C*r*r channels into C channels at r-times the
        # resolution, so expand to num_features * r**2 first (64 -> 256 for r=2).
        self.conv_upscale1 = nn.Conv2d(num_features, num_features * upscale ** 2, 3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale)
        self.conv_upscale2 = nn.Conv2d(num_features, out_channels, 3, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        out = self.conv_first(x)
        out = self.rrdb_blocks(out)
        out = self.conv_second(out)
        out = self.pixel_shuffle(self.conv_upscale1(out))  # upscale, back to num_features channels
        out = self.conv_upscale2(out)
        return self.tanh(out)

# Two-Stage Generator Model
class TwoStageGenerator(nn.Module):
    """Two RRDBNet generators chained back to back.

    ``forward(x)`` returns ``(intermediate_output, final_output)``:
    - ``intermediate_output`` is stage 1 applied to ``x``.
    - ``final_output`` is stage 2 applied to that intermediate result.

    Both stages use RRDBNet's default upsampling. With RRDBNet's PixelShuffle(2),
    each stage doubles H and W.

    NOTE(review): ImageDataset yields ``low_res`` and ``high_res`` at the same
    size, so ``final_output`` is larger than its target. generator_loss only
    hides this by resizing both to 224x224; confirm this is intended.
    """

    def __init__(self):
        super(TwoStageGenerator, self).__init__()

        def build_stage():
            return RRDBNet(in_channels=3, out_channels=3, num_rrdb_blocks=10)

        self.generator_stage_1 = build_stage()  # coarse pass
        self.generator_stage_2 = build_stage()  # refinement pass

    def forward(self, x):
        stage_outputs = []
        current = x
        for stage in (self.generator_stage_1, self.generator_stage_2):
            current = stage(current)
            stage_outputs.append(current)
        return tuple(stage_outputs)

# Discriminator: strided conv stages + MLP head producing a real/fake probability
class Discriminator(nn.Module):
    """GAN discriminator: three strided conv stages, global average pool, MLP head.

    Each stage is Conv(k=4, s=2, p=1) -> BatchNorm -> LeakyReLU(0.2), which
    halves H and W (rounded down). The adaptive pool collapses the spatial size
    to 1x1, so any input resolution works. The output is a sigmoid probability
    of shape (N, 1).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # conv{i}/bn{i}/lrelu{i} names and registration order are kept so saved
        # state_dicts load unchanged.
        widths = (3, 64, 128, 256)
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"conv{stage}", nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            setattr(self, f"bn{stage}", nn.BatchNorm2d(c_out))
            setattr(self, f"lrelu{stage}", nn.LeakyReLU(0.2))

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dense1 = nn.Linear(256, 1024)
        self.lrelu4 = nn.LeakyReLU(0.2)
        self.dense2 = nn.Linear(1024, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        h = x
        for stage in (1, 2, 3):
            h = getattr(self, f"conv{stage}")(h)
            h = getattr(self, f"bn{stage}")(h)
            h = getattr(self, f"lrelu{stage}")(h)
        # (N, 256, 1, 1) -> (N, 256)
        h = self.avg_pool(h).view(h.size(0), -1)
        h = self.lrelu4(self.dense1(h))
        return self.sigmoid(self.dense2(h))

# Custom Dataset class for image data
class ImageDataset(Dataset):
    """Paired (degraded, clean) image tensors built from a folder of images.

    Every image is resized to ``target_size`` to form the high-res target. The
    low-res input is the same image shrunk by ``scale_factor`` and resized back
    to ``target_size``, so both tensors have the same spatial size. Both are
    float32 (3, H, W) tensors scaled to [-1, 1].

    Args:
        directory: folder scanned (non-recursively) for .png/.jpg/.jpeg files;
            the extension match is case-insensitive.
        target_size: (width, height) passed to PIL's ``resize``.
        scale_factor: downscale factor used to synthesize the low-res input.
    """

    _EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, directory, target_size=(64, 64), scale_factor=4):
        # Sorted so sample indices (and validation output names) are reproducible.
        # Filtering on the lower-cased suffix also picks up ".PNG"/".JPG" on
        # case-sensitive filesystems without duplicating them on Windows.
        self.image_paths = sorted(
            path
            for path in glob.glob(os.path.join(directory, "*"))
            if os.path.isfile(path) and os.path.splitext(path)[1].lower() in self._EXTENSIONS
        )
        self.target_size = target_size
        self.scale_factor = scale_factor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        # The context manager closes the file handle immediately instead of
        # leaving it open until garbage collection.
        with Image.open(image_path) as src:
            img = src.convert('RGB')
        high_res = img.resize(self.target_size, Image.LANCZOS)
        low_size = (self.target_size[0] // self.scale_factor,
                    self.target_size[1] // self.scale_factor)
        low_res = img.resize(low_size, Image.LANCZOS).resize(self.target_size, Image.LANCZOS)
        return self._to_tensor(low_res), self._to_tensor(high_res)

    @staticmethod
    def _to_tensor(pil_img):
        """Convert an RGB PIL image to a float32 (3, H, W) tensor in [-1, 1]."""
        arr = np.asarray(pil_img, dtype=np.float32) / 127.5 - 1
        return torch.from_numpy(arr).permute(2, 0, 1)

# VGG feature extractor for perceptual loss
class VGGFeatureExtractor(nn.Module):
    """Frozen slice of a pretrained VGG16, used as a fixed perceptual-loss feature map.

    Keeps the first 16 modules of ``vgg16().features`` loaded with ImageNet
    weights. All of their parameters have ``requires_grad=False``. ``forward``
    applies those layers directly; no input normalization is done here.
    """

    def __init__(self):
        super(VGGFeatureExtractor, self).__init__()
        try:
            from torchvision.models import VGG16_Weights
        except ImportError:
            # torchvision < 0.13 only has the legacy boolean flag.
            model = vgg16(pretrained=True)
        else:
            # `pretrained=` is deprecated since torchvision 0.13; IMAGENET1K_V1 is
            # the checkpoint `pretrained=True` maps to.
            model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
        self.features = nn.Sequential(*list(model.features)[:16])
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        return self.features(x)

# Loss functions
def generator_loss(fake_output, fake_images_stage2, real_images, vgg, content_weight=100):
    """Adversarial + VGG perceptual loss for the generator.

    Args:
        fake_output: discriminator probabilities (sigmoid outputs) for the
            generated batch.
        fake_images_stage2: generated images, 4-D (N, 3, H, W), in [-1, 1]
            (the generator ends in tanh).
        real_images: target images, 4-D (N, 3, H', W'), in [-1, 1] as produced
            by ImageDataset.
        vgg: feature extractor applied to both batches after resizing to
            224x224 and ImageNet normalization.
        content_weight: multiplier on the perceptual term (default 100, as before).

    Returns:
        Scalar tensor ``BCE(fake_output, 1) + content_weight * MSE(vgg(fake), vgg(real))``.
    """
    bce = nn.BCELoss()
    # BCELoss raises inside a CUDA autocast region ("unsafe to autocast"), so
    # compute it in float32 with autocast explicitly disabled.
    with torch.autocast(device_type="cuda", enabled=False):
        probs = fake_output.float()
        adversarial_loss = bce(probs, torch.ones_like(probs))

    # Pretrained torchvision VGG expects RGB in [0, 1] normalized with the
    # ImageNet mean/std; the images here are in [-1, 1].
    mean = torch.tensor([0.485, 0.456, 0.406], device=real_images.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=real_images.device).view(1, 3, 1, 1)

    def to_vgg_input(images):
        images = nn.functional.interpolate(
            images, size=(224, 224), mode='bilinear', align_corners=False
        )
        return ((images + 1) / 2 - mean) / std

    fake_features = vgg(to_vgg_input(fake_images_stage2))
    real_features = vgg(to_vgg_input(real_images))
    content_loss = nn.functional.mse_loss(fake_features, real_features)

    return adversarial_loss + content_weight * content_loss

def discriminator_loss(real_output, fake_output):
    """Binary cross-entropy GAN loss for the discriminator.

    Args:
        real_output: discriminator probabilities for real images (sigmoid outputs).
        fake_output: discriminator probabilities for generated images.

    Returns:
        Scalar tensor ``BCE(real_output, 1) + BCE(fake_output, 0)``.
    """
    bce = nn.BCELoss()
    # BCELoss raises inside a CUDA autocast region ("unsafe to autocast"), so
    # compute it in float32 with autocast explicitly disabled.
    with torch.autocast(device_type="cuda", enabled=False):
        real = real_output.float()
        fake = fake_output.float()
        real_loss = bce(real, torch.ones_like(real))
        fake_loss = bce(fake, torch.zeros_like(fake))
    return real_loss + fake_loss

# Training loop
def train_model(train_dir, val_dir, output_dir, epochs=100, device='cpu'):
    """Adversarially train TwoStageGenerator and save per-epoch validation samples.

    Each epoch does the following:
    - runs one pass over ``train_dir``, taking a discriminator step and then a
      generator step per batch;
    - prints the mean losses;
    - writes the stage-2 output for every ``val_dir`` image to ``output_dir``
      as ``epoch_{e}_sample_{i}_{j}.png``.

    After the last epoch both models' state dicts are saved to ``output_dir``.

    Args:
        train_dir: directory of training images (loaded with ImageDataset).
        val_dir: directory of validation images.
        output_dir: destination directory, created if missing.
        epochs: number of training epochs.
        device: torch device string such as 'cpu' or 'cuda'. Mixed precision
            (autocast + GradScaler) and pinned memory are enabled only on CUDA.

    Raises:
        ValueError: if no training images are found in ``train_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)
    # torch.cuda.amp only acts on CUDA; elsewhere it would just warn and disable
    # itself, so turn AMP (and pinned host memory) on only for CUDA devices.
    on_cuda = torch.device(device).type == 'cuda'

    generator = TwoStageGenerator().to(device)
    discriminator = Discriminator().to(device)
    vgg = VGGFeatureExtractor().to(device)

    g_optimizer = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

    train_dataset = ImageDataset(train_dir)
    if len(train_dataset) == 0:
        raise ValueError(f"No training images found in {train_dir!r}")
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, pin_memory=on_cuda)
    val_loader = DataLoader(ImageDataset(val_dir), batch_size=8, shuffle=False, pin_memory=on_cuda)

    scaler = GradScaler(enabled=on_cuda)  # Gradient scaler for mixed precision

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        mean_g_loss, mean_d_loss = _train_one_epoch(
            generator, discriminator, vgg, g_optimizer, d_optimizer,
            scaler, train_loader, device, on_cuda,
        )
        print(f"Generator Loss: {mean_g_loss:.4f}")
        print(f"Discriminator Loss: {mean_d_loss:.4f}")
        _save_validation_samples(generator, val_loader, output_dir, epoch, device)

    torch.save(generator.state_dict(), os.path.join(output_dir, 'generator_final.pth'))
    torch.save(discriminator.state_dict(), os.path.join(output_dir, 'discriminator_final.pth'))


def _train_one_epoch(generator, discriminator, vgg, g_optimizer, d_optimizer,
                     scaler, loader, device, use_amp):
    """Run one adversarial pass over ``loader``; return (mean g_loss, mean d_loss)."""
    generator.train()
    discriminator.train()
    total_g_loss = 0.0
    total_d_loss = 0.0
    for low_res, high_res in loader:
        low_res, high_res = low_res.to(device), high_res.to(device)

        # Generate once per batch. D trains on a detached copy; G backpropagates
        # through the original. Only D's weights change in between.
        with autocast(enabled=use_amp):
            _, fake_images = generator(low_res)

        # Discriminator step
        d_optimizer.zero_grad()
        with autocast(enabled=use_amp):
            real_output = discriminator(high_res)
            fake_output = discriminator(fake_images.detach())
        # Losses run outside autocast on float32 inputs: BCELoss raises inside a
        # CUDA autocast region.
        d_loss = discriminator_loss(real_output.float(), fake_output.float())
        scaler.scale(d_loss).backward()
        scaler.step(d_optimizer)

        # Generator step (D is re-evaluated with its freshly updated weights)
        g_optimizer.zero_grad()
        with autocast(enabled=use_amp):
            fake_output = discriminator(fake_images)
        g_loss = generator_loss(fake_output.float(), fake_images.float(), high_res, vgg)
        scaler.scale(g_loss).backward()
        scaler.step(g_optimizer)

        # A single update per iteration, after every optimizer has stepped.
        scaler.update()

        total_g_loss += g_loss.item()
        total_d_loss += d_loss.item()
    return total_g_loss / len(loader), total_d_loss / len(loader)


def _save_validation_samples(generator, loader, output_dir, epoch, device):
    """Write the generator's stage-2 output for every image in ``loader`` as PNG."""
    generator.eval()
    with torch.no_grad():
        for i, (low_res, _) in enumerate(loader):
            _, generated = generator(low_res.to(device))
            # [-1, 1] -> [0, 255]; round before the uint8 cast instead of truncating.
            pixels = ((generated + 1) * 127.5).round().clamp(0, 255)
            pixels = pixels.permute(0, 2, 3, 1).byte().cpu().numpy()
            for j, img in enumerate(pixels):
                img_path = os.path.join(output_dir, f'epoch_{epoch+1}_sample_{i}_{j}.png')
                Image.fromarray(img).save(img_path)

# Main execution
if __name__ == "__main__":
    # Hard-coded DIV2K locations on the author's machine; edit before running elsewhere.
    locations = {
        "train_dir": "D:\\DIV2K_train_HR\\DIV2K_train_HR",
        "val_dir": "D:\\DIV2K_valid_HR\\DIV2K_valid_HR",
        "output_dir": "D:\\Images\\Code_output",
    }
    run_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    train_model(**locations, epochs=100, device=run_device)

In [20]:
import os
import os.path as osp
import glob
import cv2
import numpy as np
import torch
import sys
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

sys.path.append('D:\\Gan_code')
import RRDBNet_arch as arch

# Model and device configuration
model_path = 'C:\\Users\\jatin\\Downloads\\generator_2.pth'  # Path to your model weights
device = torch.device('cpu')  # Use 'cuda' if you want to run on GPU

# Test image folder and output directory
test_img_folder = 'D:\\afhq test'  # Path to input images
output_dir = 'D:\\Images\\results'  # Path to save results
os.makedirs(output_dir, exist_ok=True)

# Load the model. The positional args are presumably (in_nc, out_nc, nf, nb)
# in RRDBNet_arch -- verify against that module; they must match the checkpoint.
model = arch.RRDBNet(3, 3, 64, 23, gc=32)
try:
    # weights_only=True unpickles only tensors and plain containers, so a
    # tampered checkpoint cannot execute code. It also silences torch's
    # FutureWarning about the default.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
except TypeError:
    # torch < 1.13 does not accept the weights_only keyword.
    state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict, strict=True)
model.eval()
model = model.to(device)

print(f'Model loaded from {model_path}. \nStarting testing...')

# Per-image metrics, averaged at the end
psnr_scores = []
ssim_scores = []
skipped = 0

# Sorted so processing order and logs are reproducible across runs/filesystems.
image_paths = sorted(glob.glob(osp.join(test_img_folder, '*.*')))  # every file with an extension
for idx, path in enumerate(image_paths, start=1):
    base = osp.splitext(osp.basename(path))[0]
    print(f'Processing {idx}: {base}')

    # 3-channel BGR uint8; cv2 returns None for unreadable / non-image files.
    img_original = cv2.imread(path, cv2.IMREAD_COLOR)
    if img_original is None:
        print(f"Error reading {path}. Skipping...")
        skipped += 1
        continue

    original_height, original_width = img_original.shape[:2]

    # BGR uint8 HWC -> RGB float32 CHW in [0, 1], plus a batch dimension.
    # Fancy indexing copies, so img_original is left untouched for the metrics.
    img = img_original[:, :, [2, 1, 0]] / 255.0
    img_LR = torch.from_numpy(np.transpose(img, (2, 0, 1))).float().unsqueeze(0).to(device)

    # Super-resolution inference
    with torch.no_grad():
        output = model(img_LR).squeeze(0).float().cpu().clamp_(0, 1).numpy()

    # RGB CHW float -> BGR HWC uint8, the layout cv2 expects
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    output = (output * 255.0).round().astype(np.uint8)

    # Resize output to the original dimensions so the metrics compare equal-size images
    if output.shape[:2] != (original_height, original_width):
        shrinking = output.shape[0] > original_height or output.shape[1] > original_width
        # INTER_AREA avoids the aliasing cv2's default (bilinear) causes when shrinking.
        interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_LINEAR
        output = cv2.resize(output, (original_width, original_height), interpolation=interpolation)

    # Save the output image
    save_path = osp.join(output_dir, f'{base}_rlt.png')
    if cv2.imwrite(save_path, output):
        print(f'Saved result to {save_path}')
    else:
        print(f'Failed to write {save_path}')

    # NOTE(review): the reference is the *input* image itself, not a separate
    # high-resolution ground truth. PSNR/SSIM therefore measure how well
    # "upscale, then resize back" reproduces the input, not SR accuracy.
    img_original_gray = cv2.cvtColor(img_original, cv2.COLOR_BGR2GRAY)
    output_gray = cv2.cvtColor(output, cv2.COLOR_BGR2GRAY)

    psnr_scores.append(psnr(img_original_gray, output_gray, data_range=255))
    ssim_scores.append(ssim(img_original_gray, output_gray, data_range=255))

# Report average metrics
if psnr_scores:
    print(f"Average PSNR: {np.mean(psnr_scores):.2f} dB")

if ssim_scores:
    print(f"Average SSIM: {np.mean(ssim_scores):.4f}")

if skipped:
    print(f"Finished; {skipped} file(s) could not be read and were skipped.")
else:
    print("All images processed successfully!")

  model.load_state_dict(torch.load(model_path, map_location=device), strict=True)


Model loaded from C:\Users\jatin\Downloads\generator_2.pth. 
Starting testing...
Processing 1: flickr_cat_000008
Saved result to D:\Images\results\flickr_cat_000008_rlt.png
Processing 2: flickr_cat_000011
Saved result to D:\Images\results\flickr_cat_000011_rlt.png
Processing 3: flickr_cat_000016
Saved result to D:\Images\results\flickr_cat_000016_rlt.png
Processing 4: flickr_cat_000056
Saved result to D:\Images\results\flickr_cat_000056_rlt.png
Processing 5: flickr_cat_000076
Saved result to D:\Images\results\flickr_cat_000076_rlt.png
Processing 6: flickr_cat_000080
Saved result to D:\Images\results\flickr_cat_000080_rlt.png
Processing 7: flickr_cat_000096
Saved result to D:\Images\results\flickr_cat_000096_rlt.png
Processing 8: flickr_cat_000108
Saved result to D:\Images\results\flickr_cat_000108_rlt.png
Processing 9: flickr_cat_000123
Saved result to D:\Images\results\flickr_cat_000123_rlt.png
Processing 10: flickr_cat_000136
Saved result to D:\Images\results\flickr_cat_000136_rlt.pn