In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive; later cells read the dataset zip from and
# write checkpoints to /content/drive/MyDrive.
drive.mount('/content/drive')

In [None]:
!unzip "/content/drive/MyDrive/image_output.zip" -d "/content/datasets"

In [None]:
import os
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image

class SuperResolutionDataset(Dataset):
    """Paired (low-res, high-res) image dataset read from two directories.

    A sample exists for every file name in ``lr_dir`` that also names a regular
    file in ``hr_dir``. Both images are loaded as RGB and passed through the
    optional per-resolution transforms.

    Args:
        hr_dir: directory containing the high-resolution (target) images.
        lr_dir: directory containing the low-resolution (input) images.
        hr_transform: optional callable applied to the HR PIL image.
        lr_transform: optional callable applied to the LR PIL image.

    Items are returned as ``(lr_image, hr_image)``.
    """

    def __init__(self, hr_dir, lr_dir, hr_transform=None, lr_transform=None):
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.hr_transform = hr_transform
        self.lr_transform = lr_transform
        # Sorted because os.listdir order is arbitrary; a stable index -> file
        # mapping is needed for a reproducible (seeded) train/val split.
        self.filenames = sorted(
            f for f in os.listdir(lr_dir)
            if os.path.isfile(os.path.join(hr_dir, f))
        )

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB PIL image and close the underlying file handle."""
        # convert() returns a fully loaded, independent image, so it stays valid
        # after the context manager closes the file.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        name = self.filenames[idx]
        hr_image = self._load_rgb(os.path.join(self.hr_dir, name))
        lr_image = self._load_rgb(os.path.join(self.lr_dir, name))

        if self.hr_transform:
            hr_image = self.hr_transform(hr_image)
        if self.lr_transform:
            lr_image = self.lr_transform(lr_image)

        return lr_image, hr_image

# Image transforms: convert to tensors and apply the standard ImageNet
# mean/std normalization to both the high- and low-resolution images.
hr_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

lr_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create the dataset of (LR, HR) image pairs
dataset = SuperResolutionDataset(
    hr_dir='/content/datasets/sharp_original',
    lr_dir='/content/datasets/deblured',
    hr_transform=hr_transform,
    lr_transform=lr_transform
)
if len(dataset) == 0:
    raise RuntimeError("No matching HR/LR image pairs found; check the dataset paths.")

# 80/20 train/validation split. The fixed-seed generator makes the split of
# dataset indices reproducible across kernel restarts (given a stable file
# order), so a checkpoint reloaded in a later session is evaluated on the same
# held-out images rather than on images it was trained on.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# Define data loaders. Worker processes are capped at the available CPU count
# (hosted runtimes may have fewer than 8 cores).
NUM_WORKERS = min(8, os.cpu_count() or 1)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
import torch
import torch.nn as nn

# Residual Block
class ResidualBlock(nn.Module):
    """conv3x3 -> ReLU -> conv3x3, added back onto the block input.

    Channel count and spatial size are preserved (channels -> channels, padding=1),
    so the identity skip can be summed directly.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        # The in-place ReLU acts on conv1's fresh output, never on x itself.
        branch = self.conv2(self.relu(self.conv1(x)))
        return x + branch

# EDSR Model, use scale_factor to choose scale
class EDSR(nn.Module):
    """EDSR-style single-image super-resolution network.

    Pipeline: 9x9 head conv -> ``num_residual_blocks`` ResidualBlocks -> 3x3 conv,
    plus a global skip from the head output -> sub-pixel upsampling
    (conv + PixelShuffle + ReLU) by ``scale_factor`` -> 9x9 output conv.

    Args:
        scale_factor: positive integer upscaling factor applied to height and width.
        num_channels: number of image channels in both input and output.
        num_residual_blocks: number of ResidualBlock modules in the body.
        num_features: width of the internal feature maps (default 64).

    Raises:
        ValueError: if ``scale_factor`` is not a positive integer.
    """

    def __init__(self, scale_factor=2, num_channels=3, num_residual_blocks=16, num_features=64):
        super(EDSR, self).__init__()
        if not isinstance(scale_factor, int) or scale_factor < 1:
            raise ValueError(f"scale_factor must be a positive integer, got {scale_factor!r}")
        self.num_channels = num_channels
        self.scale_factor = scale_factor

        # First layer
        self.conv1 = nn.Conv2d(num_channels, num_features, kernel_size=9, padding=4)

        # Residual blocks
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(num_features) for _ in range(num_residual_blocks)]
        )

        # Second conv layer post residual blocks
        self.conv2 = nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)

        # Upsampling: PixelShuffle(r) turns C*r^2 channels into C channels at r-times
        # the resolution, so the conv must emit num_features * r^2 channels for
        # conv3 to receive num_features (256 for the default r=2, 64 features).
        self.upsampling = nn.Sequential(
            nn.Conv2d(num_features, num_features * scale_factor ** 2, kernel_size=3, padding=1),
            nn.PixelShuffle(upscale_factor=scale_factor),
            nn.ReLU(inplace=True)
        )

        # Output layer
        self.conv3 = nn.Conv2d(num_features, num_channels, kernel_size=9, padding=4)

    def forward(self, x):
        head = self.conv1(x)
        body = self.conv2(self.residual_blocks(head))
        out = self.upsampling(body + head)  # global residual (element-wise sum)
        return self.conv3(out)

# Build an EDSR instance: 2x upscaling, 3-channel images, 16 residual blocks.
model = EDSR(num_residual_blocks=16, num_channels=3, scale_factor=2)

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
import torch.distributed as distance
from torch.cuda.amp import GradScaler, autocast

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

#def Perceptual Loss by MobileNet
class PerceptualLoss(nn.Module):
    """Feature-space MSE between two images using a frozen MobileNetV3-Small backbone.

    Both images go through ``mobilenet_v3_small(weights=DEFAULT).features``. The
    backbone's parameters are frozen and it is kept in eval mode. The loss is the
    mean squared error between the two feature maps. Inputs are presumably expected
    to be normalized like the backbone's pretraining data (ImageNet mean/std);
    confirm against the caller.
    """

    def __init__(self):
        super(PerceptualLoss, self).__init__()
        self.mobilenet = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT).features
        self.mobilenet.eval()
        for param in self.mobilenet.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """Set this module's mode but always keep the frozen backbone in eval mode.

        Without this override, ``.train()`` on the loss (or on a parent module that
        contains it) would switch normalization layers such as BatchNorm in the
        backbone to batch statistics and let them update their running averages.
        """
        super().train(mode)
        self.mobilenet.eval()
        return self

    def forward(self, input, target):
        """Return the MSE between backbone features of ``input`` and ``target``."""
        input_features = self.mobilenet(input)
        target_features = self.mobilenet(target)
        # Resize if the two inputs produced feature maps of different spatial size.
        if input_features.shape[2:] != target_features.shape[2:]:
            input_features = F.interpolate(
                input_features, size=target_features.shape[2:], mode='bilinear', align_corners=False
            )
        return F.mse_loss(input_features, target_features)

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Super-resolution network (2x, RGB, 16 residual blocks) placed on the target device.
SuperResolution_model = EDSR(scale_factor=2, num_channels=3, num_residual_blocks=16)
SuperResolution_model = SuperResolution_model.to(device)

# Loss criteria: pixel-wise MSE plus the MobileNet feature-space loss.
mse_criterion = nn.MSELoss()
perceptual_criterion = PerceptualLoss().to(device)

# Adam with a fixed initial learning rate of 1e-4.
optimizer = torch.optim.Adam(params=SuperResolution_model.parameters(), lr=1e-4)

# Cut the learning rate 10x after 3 epochs without a decrease in the monitored loss.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

In [None]:
# Relative weights of the two terms in the combined loss:
# total = mse_weight * MSE + perceptual_weight * perceptual
mse_weight, perceptual_weight = 1.0, 0.1

In [None]:
# Train model
def _weighted_loss(outputs, hr_image, mse_criterion, perceptual_criterion):
    """Return ``mse_weight * MSE + perceptual_weight * perceptual`` for one batch.

    ``mse_weight`` and ``perceptual_weight`` are read from notebook globals.
    """
    mse_loss = mse_criterion(outputs, hr_image)
    perceptual_loss = perceptual_criterion(outputs, hr_image)
    return mse_weight * mse_loss + perceptual_weight * perceptual_loss


def _train_one_epoch(model, loader, mse_criterion, perceptual_criterion, optimizer, scaler):
    """Run one mixed-precision training pass over ``loader``; return the per-sample mean loss."""
    model.train()
    running_loss = 0.0
    for lr_image, hr_image in loader:
        lr_image = lr_image.to(device)
        hr_image = hr_image.to(device)

        optimizer.zero_grad()

        # Forward pass under autocast; backward and optimizer step go through GradScaler.
        with autocast():
            outputs = model(lr_image)
            total_loss = _weighted_loss(outputs, hr_image, mse_criterion, perceptual_criterion)

        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight by batch size so the epoch mean is per sample, not per batch.
        running_loss += total_loss.item() * lr_image.size(0)
    return running_loss / len(loader.dataset)


def _evaluate(model, loader, mse_criterion, perceptual_criterion):
    """Return the per-sample mean loss of ``model`` on ``loader`` without gradients."""
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for lr_image, hr_image in loader:
            lr_image = lr_image.to(device)
            hr_image = hr_image.to(device)
            outputs = model(lr_image)
            total_loss = _weighted_loss(outputs, hr_image, mse_criterion, perceptual_criterion)
            loss_sum += total_loss.item() * lr_image.size(0)
    return loss_sum / len(loader.dataset)


def train_model(model, train_loader, val_loader, mse_criterion, perceptual_criterion, optimizer,
                num_epochs=100, early_stopping_tolerance=8, restore_best_weights=True):
    """Train ``model`` with early stopping on the validation loss.

    Each epoch runs a mixed-precision training pass, then a validation pass. It
    prints both losses and steps the notebook-global ``scheduler`` with the
    validation loss. Training stops after ``early_stopping_tolerance`` consecutive
    epochs without improvement. Also reads the globals ``device``, ``mse_weight``
    and ``perceptual_weight``.

    Args:
        model: network mapping LR batches to HR-sized outputs; trained in place.
        train_loader / val_loader: loaders yielding ``(lr_image, hr_image)`` batches.
        mse_criterion / perceptual_criterion: the two loss terms.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: maximum number of epochs.
        early_stopping_tolerance: epochs without improvement before stopping.
        restore_best_weights: if True, load the weights from the epoch with the
            lowest validation loss back into ``model`` before returning, instead
            of keeping the last epoch's weights.

    Returns:
        ``model`` (the same object, trained).

    Raises:
        ValueError: if either loader's dataset is empty.
    """
    if len(train_loader.dataset) == 0 or len(val_loader.dataset) == 0:
        raise ValueError("train_loader and val_loader must both contain at least one sample")

    best_val_loss = float('inf')
    best_state = None
    no_improvement_count = 0  # Early stopping counter

    scaler = GradScaler()
    for epoch in range(num_epochs):
        train_loss = _train_one_epoch(model, train_loader, mse_criterion, perceptual_criterion, optimizer, scaler)
        val_loss = _evaluate(model, val_loader, mse_criterion, perceptual_criterion)
        print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

        scheduler.step(val_loss)
        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            no_improvement_count = 0
            if restore_best_weights:
                # Detached copies, so later optimizer steps don't alter the snapshot.
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            no_improvement_count += 1
            if no_improvement_count >= early_stopping_tolerance:
                print("Stopping early due to no improvement in validation loss")
                break

    if restore_best_weights and best_state is not None:
        model.load_state_dict(best_state)
    return model

train_model(
    model=SuperResolution_model,
    train_loader=train_loader,
    val_loader=val_loader,
    mse_criterion=mse_criterion,
    perceptual_criterion=perceptual_criterion,
    optimizer=optimizer,
    num_epochs=100,
    early_stopping_tolerance=8,
)

Epoch 1/100, Training Loss: 0.2146, Validation Loss: 0.1678
Epoch 2/100, Training Loss: 0.1578, Validation Loss: 0.1567
Epoch 3/100, Training Loss: 0.1524, Validation Loss: 0.1532
Epoch 4/100, Training Loss: 0.1494, Validation Loss: 0.1516
Epoch 5/100, Training Loss: 0.1479, Validation Loss: 0.1523
Epoch 6/100, Training Loss: 0.1465, Validation Loss: 0.1489
Epoch 7/100, Training Loss: 0.1455, Validation Loss: 0.1466
Epoch 8/100, Training Loss: 0.1446, Validation Loss: 0.1470
Epoch 9/100, Training Loss: 0.1441, Validation Loss: 0.1459
Epoch 10/100, Training Loss: 0.1433, Validation Loss: 0.1458
Epoch 11/100, Training Loss: 0.1427, Validation Loss: 0.1448
Epoch 12/100, Training Loss: 0.1420, Validation Loss: 0.1444
Epoch 13/100, Training Loss: 0.1415, Validation Loss: 0.1439
Epoch 14/100, Training Loss: 0.1411, Validation Loss: 0.1435
Epoch 15/100, Training Loss: 0.1404, Validation Loss: 0.1435
Epoch 16/100, Training Loss: 0.1401, Validation Loss: 0.1427
Epoch 17/100, Training Loss: 0.13

EDSR(
  (conv1): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (residual_blocks): Sequential(
    (0): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (2): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (3): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding

In [None]:
# Save model
import torch.distributed as dist

model_path = '/content/drive/MyDrive/model/EDSR/SuperResolution_EDSR.pth'

# Write the checkpoint from the primary process only. When torch.distributed is
# not initialized (the normal single-process notebook case), this process is the
# primary one whether it trained on GPU or CPU. The previous CUDA-device guard
# silently skipped saving on CPU-only runtimes.
is_primary_process = (
    not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0
)
if is_primary_process:
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(SuperResolution_model.state_dict(), model_path)
    print(f"Model saved to {model_path}.")

Model saved to /content/drive/MyDrive/model/EDSR/SuperResolution_EDSR.pth.


In [None]:
# Inverse of the ImageNet Normalize used on the inputs: Normalize computes
# (x - mean') / std', so mean' = -mean/std and std' = 1/std yields x * std + mean.
_imagenet_mean = (0.485, 0.456, 0.406)
_imagenet_std = (0.229, 0.224, 0.225)
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(_imagenet_mean, _imagenet_std)],
    std=[1 / s for s in _imagenet_std],
)

In [None]:
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
from skimage import img_as_float
import torch

# Evaluate the saved checkpoint on the validation loader: average PSNR and SSIM,
# computed on de-normalized RGB images clamped to [0, 1].
total_psnr = 0
total_ssim = 0
num_images = 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_path = '/content/drive/MyDrive/model/EDSR/SuperResolution_EDSR.pth'

model = EDSR(scale_factor=2, num_channels=3, num_residual_blocks=16)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

with torch.no_grad():
    for lr_image, hr_image in val_loader:
        lr_image = lr_image.to(device)
        hr_image = hr_image.to(device)

        SuperResolution = model(lr_image)

        # Traverse the batch to calculate PSNR and SSIM per image
        for i in range(lr_image.size(0)):
            SuperResolution_img = inv_normalize(SuperResolution[i]).clamp(0, 1)
            hr_img = inv_normalize(hr_image[i]).clamp(0, 1)

            # CHW tensors -> HWC float arrays for scikit-image
            SuperResolution_np = SuperResolution_img.cpu().numpy().transpose(1, 2, 0)
            hr_np = hr_img.cpu().numpy().transpose(1, 2, 0)

            # After the clamp the images are floats in [0, 1], so data_range is 1.0.
            # It must be explicit for SSIM: scikit-image otherwise assumes the float
            # dtype range (2.0), which inflates the score, or raises in newer releases.
            # channel_axis replaces the deprecated multichannel=True.
            psnr = compare_psnr(hr_np, SuperResolution_np, data_range=1.0)
            ssim = compare_ssim(hr_np, SuperResolution_np, data_range=1.0, channel_axis=-1)

            total_psnr += psnr
            total_ssim += ssim
            num_images += 1

if num_images == 0:
    raise RuntimeError("Validation loader yielded no images; cannot compute PSNR/SSIM.")

# Calculate average PSNR and SSIM
avg_psnr = total_psnr / num_images
avg_ssim = total_ssim / num_images

print(f'Average PSNR: {avg_psnr}')
print(f'Average SSIM: {avg_ssim}')

  ssim = compare_ssim(SuperResolution_np, hr_np, multichannel=True)


Average PSNR: 25.400591350396446
Average SSIM: 0.8580290584552883
