<a href="https://colab.research.google.com/github/HANKSOONG/image-repair-and-restoration-/blob/main/U_Net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive (Colab only) so the dataset archive and the model
# checkpoints stored under /content/drive/MyDrive are accessible.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Extract the dataset archive from Drive into /content/datasets; the dataset
# cell expects it to provide the denoised/ and sharp_resized/ folders.
!unzip "/content/drive/MyDrive/image_output.zip" -d "/content/datasets"

In [None]:
import os
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image

class DeblurDataset(Dataset):
    """Paired images for deblurring, yielding ``(denoised, sharp, filename)``.

    Pairs are matched by identical filename: every file in ``denoised_dir`` is
    expected to exist in ``sharp_resized_dir`` under the same name.

    Args:
        denoised_dir: directory holding the model inputs.
        sharp_resized_dir: directory holding the ground-truth targets.
        transform: optional callable applied independently to both images
            (both are PIL RGB images before the transform).
    """

    def __init__(self, denoised_dir, sharp_resized_dir, transform=None):
        self.denoised_dir = denoised_dir
        self.sharp_resized_dir = sharp_resized_dir
        self.transform = transform

        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so the index -> file mapping (which random_split relies on) is
        # reproducible. Sub-directories are skipped since they are not images.
        self.filenames = sorted(
            name for name in os.listdir(denoised_dir)
            if os.path.isfile(os.path.join(denoised_dir, name))
        )

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filename = self.filenames[idx]

        denoised_img = self._load_rgb(os.path.join(self.denoised_dir, filename))
        sharp_resized_img = self._load_rgb(os.path.join(self.sharp_resized_dir, filename))

        if self.transform:
            denoised_img = self.transform(denoised_img)
            sharp_resized_img = self.transform(sharp_resized_img)

        return denoised_img, sharp_resized_img, filename

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB PIL image and close the underlying file."""
        # Image.open is lazy and keeps the file handle open; convert() loads the
        # pixels into a new image, so the handle can be closed right after.
        with Image.open(path) as img:
            return img.convert('RGB')

# Normalize with ImageNet mean/std (inv_normalize later undoes this).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create the dataset
dataset = DeblurDataset(denoised_dir='/content/datasets/denoised',
                        sharp_resized_dir='/content/datasets/sharp_resized',
                        transform=transform)

# 80/20 train/validation split. A fixed seed makes the split reproducible, so
# cells that reload a saved checkpoint in a new session evaluate on the same
# held-out images rather than on a fresh random split that overlaps training.
split_seed = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed))

# Define the dataloaders. Worker count is capped at the CPU count because the
# runtime may have fewer than 8 cores (DataLoader warns when oversubscribed).
num_workers = min(8, os.cpu_count() or 1)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=num_workers)


In [None]:
import torch.nn as nn
import torch.nn.functional as F

class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (halves H and W), then a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The attribute name and the layer order define the state_dict keys
        # ("maxpool_conv.1.double_conv...") that saved checkpoints rely on.
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        out = self.maxpool_conv(x)
        return out

class Up(nn.Module):
    """Decoder stage: upsample ``x1`` by 2, pad it to ``x2``'s spatial size,
    concatenate ``[x2, x1]`` along channels, then apply a DoubleConv.

    ``in_channels`` is the channel count *after* concatenation.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if not bilinear:
            half = in_channels // 2
            self.up = nn.ConvTranspose2d(half, half, kernel_size=2, stride=2)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # The upsampled map can be smaller than the skip connection when an
        # encoder input had an odd size (MaxPool2d floors), so pad it to match.
        pad_h = x2.size(2) - upsampled.size(2)
        pad_w = x2.size(3) - upsampled.size(3)
        top = pad_h // 2
        left = pad_w // 2
        # F.pad order for the last two dims: (left, right, top, bottom).
        upsampled = F.pad(upsampled, [left, pad_w - left, top, pad_h - top])

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)

class OutConv(nn.Module):
    """Final 1x1 convolution projecting feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        projected = self.conv(x)
        return projected


In [None]:
# Define the U-Net building blocks and model
class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages; padding=1 keeps H and W.

    ``mid_channels`` sets the width between the two convs and falls back to
    ``out_channels`` when not given (any falsy value).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid = mid_channels or out_channels
        layers = [
            nn.Conv2d(in_channels, mid, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Sequential indices 0..5 form the state_dict keys of saved checkpoints.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.double_conv(x)
        return out

class UNet(nn.Module):
    """Four-level U-Net with skip connections.

    Output has ``n_classes`` channels and the same H and W as the input, and no
    final activation is applied.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoder: 64 -> 128 -> 256 -> 512 -> 512 channels.
        # Attribute names/creation order are kept so checkpoints still load.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # Decoder: each Up's in_channels = upsampled channels + skip channels.
        self.up1 = Up(1024, 256)  # 512 + 512
        self.up2 = Up(512, 128)   # 256 + 256
        self.up3 = Up(256, 64)    # 128 + 128
        self.up4 = Up(128, 64)    # 64 + 64
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # skips ends up as [x1, x2, x3, x4, x5] (encoder outputs, shallow to deep).
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        out = skips.pop()  # bottleneck x5
        for up in (self.up1, self.up2, self.up3, self.up4):
            out = up(out, skips.pop())  # pairs with x4, x3, x2, x1 in turn
        return self.outc(out)

# Instantiate a U-Net with 3 input and 3 output channels
model = UNet(n_classes=3, n_channels=3)

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
import torch.distributed as distance
from torch.cuda.amp import GradScaler, autocast

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")  # first GPU if present

# Perceptual loss computed on frozen MobileNetV3-Small feature maps
class PerceptualLoss(nn.Module):
    """MSE between backbone feature maps of a prediction and its target.

    The backbone is ``mobilenet_v3_small(...).features`` with torchvision's
    default weights, frozen (``requires_grad=False``) and kept in eval mode.
    Inputs are passed to it unchanged, so they must already be normalized the
    way the backbone expects. Gradients flow only into ``input``; ``target`` is
    treated as a fixed reference.
    """

    def __init__(self):
        super(PerceptualLoss, self).__init__()
        self.mobilenet = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT).features
        self.mobilenet.eval()
        for param in self.mobilenet.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """Set this module's mode like ``nn.Module.train`` but keep the backbone in eval."""
        super().train(mode)
        # Otherwise a .train() call would let any BatchNorm layers in the frozen
        # backbone switch to batch statistics and update their running stats.
        self.mobilenet.eval()
        return self

    def forward(self, input, target):
        input_features = self.mobilenet(input)
        # The target is a fixed reference, so no autograd graph is needed for
        # it; skipping it saves activation memory.
        with torch.no_grad():
            target_features = self.mobilenet(target)
        # Resize if necessary
        if input_features.shape[2:] != target_features.shape[2:]:
            input_features = F.interpolate(input_features, size=target_features.shape[2:],
                                           mode='bilinear', align_corners=False)

        return F.mse_loss(input_features, target_features)

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
# Create the model, the two loss criteria and the Adam optimizer
debluring_model = UNet(n_channels=3, n_classes=3).to(device)
mse_criterion = nn.MSELoss()
perceptual_criterion = PerceptualLoss().to(device)
optimizer = torch.optim.Adam(debluring_model.parameters(), lr=1e-4)

# Multiply the LR by 0.1 when the monitored (validation) loss stops improving
# for more than `patience` consecutive scheduler steps.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — consider
# logging scheduler.get_last_lr() instead.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

In [None]:
# Loss weights: total = mse_weight * MSE + perceptual_weight * perceptual loss
mse_weight, perceptual_weight = 1.0, 0.1

In [None]:
# Train model
def train_model(model, train_loader, val_loader, mse_criterion, perceptual_criterion, optimizer,
                num_epochs=100, early_stopping_tolerance=8, restore_best_weights=True):
    """Train ``model`` with mixed precision, a weighted MSE + perceptual loss and early stopping.

    Each loader batch is expected to be ``(input, target, filename)`` as yielded
    by ``DeblurDataset``; the filename element is ignored here.

    Reads the notebook globals ``device`` (where batches are moved),
    ``scheduler`` (stepped once per epoch with the validation loss), and
    ``mse_weight`` / ``perceptual_weight`` (loss-term weights).

    Args:
        model: network to train; updated in place.
        train_loader: DataLoader for the training split.
        val_loader: DataLoader for the validation split.
        mse_criterion: pixel loss, called as ``mse_criterion(output, target)``.
        perceptual_criterion: feature loss, same call signature.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: maximum number of epochs.
        early_stopping_tolerance: stop after this many consecutive epochs whose
            validation loss does not beat the best seen so far.
        restore_best_weights: if True, load the weights from the epoch with the
            lowest validation loss back into ``model`` before returning, instead
            of keeping the last epoch's weights (after early stopping those are
            ``early_stopping_tolerance`` epochs past the best).

    Returns:
        ``model`` (the same object).
    """
    best_val_loss = float('inf')
    best_state = None  # CPU snapshot of the best weights seen so far
    best_epoch = 0
    no_improvement_count = 0  # Early stopping counter

    scaler = GradScaler()
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        # Batches are (input, target, filename); the filename is unused here.
        for denoised_img, sharp_resized_img, _ in train_loader:
            denoised_img = denoised_img.to(device)
            sharp_resized_img = sharp_resized_img.to(device)

            optimizer.zero_grad()

            # Forward pass under autocast (mixed precision)
            with autocast():
                outputs = model(denoised_img)
                mse_loss = mse_criterion(outputs, sharp_resized_img)
                perceptual_loss = perceptual_criterion(outputs, sharp_resized_img)
                total_loss = mse_weight * mse_loss + perceptual_weight * perceptual_loss

            # Backward pass and optimizer step through the GradScaler
            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Weight by batch size so the epoch figure is a per-sample average.
            running_loss += total_loss.item() * denoised_img.size(0)
        train_loss = running_loss / len(train_loader.dataset)

        # Validation pass (no gradients)
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for denoised_img, sharp_resized_img, _ in val_loader:
                denoised_img = denoised_img.to(device)
                sharp_resized_img = sharp_resized_img.to(device)
                outputs = model(denoised_img)
                mse_loss = mse_criterion(outputs, sharp_resized_img)
                perceptual_loss = perceptual_criterion(outputs, sharp_resized_img)
                total_loss = mse_weight * mse_loss + perceptual_weight * perceptual_loss
                val_loss += total_loss.item() * denoised_img.size(0)

        val_loss /= len(val_loader.dataset)
        print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

        scheduler.step(val_loss)
        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch + 1
            no_improvement_count = 0
            if restore_best_weights:
                # state_dict() returns live references to the parameters; copy
                # them to CPU so later updates don't overwrite the snapshot and
                # GPU memory isn't doubled.
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            no_improvement_count += 1
            if no_improvement_count >= early_stopping_tolerance:
                print("Stopping early due to no improvement in validation loss")
                break

    if restore_best_weights and best_state is not None:
        model.load_state_dict(best_state)
        print(f"Restored weights from epoch {best_epoch} (validation loss {best_val_loss:.4f})")

    return model

train_model(debluring_model, train_loader, val_loader, mse_criterion, perceptual_criterion, optimizer, num_epochs=100, early_stopping_tolerance=8)

In [None]:
# Save model
# Saved regardless of CUDA availability so CPU-only runs also persist the
# trained weights (this notebook runs as a single process).
model_path = '/content/drive/MyDrive/model/UNet/deblured_UNet.pth'
resolved_path = os.path.expanduser(model_path)

# exist_ok avoids the check-then-create race and is a no-op if the folder exists.
os.makedirs(os.path.dirname(resolved_path), exist_ok=True)

torch.save(debluring_model.state_dict(), resolved_path)
print(f"Model saved to {model_path}.")

In [None]:
# Inverse of the dataset's ImageNet normalization. Normalize computes
# (y - mean') / std'; with mean' = -m/s and std' = 1/s that equals y*s + m,
# which undoes the forward (x - m) / s.
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))],
    std=[1 / s for s in (0.229, 0.224, 0.225)],
)

In [None]:
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
from skimage import img_as_float
import torch

# Evaluate a saved checkpoint on the validation split: average PSNR/SSIM of the
# deblurred output against the sharp ground truth and against the denoised input.

# Initialize the sum of PSNR and SSIM
total_psnr_sharp = 0
total_ssim_sharp = 0
total_psnr_denoised = 0
total_ssim_denoised = 0
num_images = 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): the save cell writes 'deblured_UNet.pth' but this loads
# 'deblured_UNet4.pth' — confirm which checkpoint should be evaluated.
model_path = '/content/drive/MyDrive/model/UNet/deblured_UNet4.pth'

model = UNet(n_channels=3, n_classes=3)
model.load_state_dict(torch.load(model_path, map_location=device))
# The weights must be on the same device as the input batches moved below.
model.to(device)
model.eval()


def _to_hwc_numpy(img):
    """Undo the dataset normalization of one CHW tensor, clamp to [0, 1], return an HWC numpy array."""
    return inv_normalize(img).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)


with torch.no_grad():
    # Batches are (denoised, sharp, filename); filenames are not needed here.
    for denoised_img, sharp_resized_img, _ in val_loader:
        denoised_img = denoised_img.to(device)
        sharp_resized_img = sharp_resized_img.to(device)

        deblured_img = model(denoised_img)

        # Traverse the batch to calculate per-image PSNR and SSIM
        for i in range(denoised_img.size(0)):
            deblured_np = _to_hwc_numpy(deblured_img[i])
            sharp_np = _to_hwc_numpy(sharp_resized_img[i])
            denoised_np = _to_hwc_numpy(denoised_img[i])

            # Pixels are floats clamped to [0, 1], so data_range is 1.0.
            # channel_axis=-1 marks the trailing RGB axis (the replacement for
            # the deprecated `multichannel=True`).
            psnr_sharp = compare_psnr(deblured_np, sharp_np, data_range=1.0)
            ssim_sharp = compare_ssim(deblured_np, sharp_np, channel_axis=-1, data_range=1.0)
            psnr_denoised = compare_psnr(deblured_np, denoised_np, data_range=1.0)
            ssim_denoised = compare_ssim(deblured_np, denoised_np, channel_axis=-1, data_range=1.0)

            total_psnr_sharp += psnr_sharp
            total_ssim_sharp += ssim_sharp
            total_psnr_denoised += psnr_denoised
            total_ssim_denoised += ssim_denoised
            num_images += 1

if num_images == 0:
    raise RuntimeError("val_loader yielded no images; cannot average PSNR/SSIM")

# Calculate average PSNR and SSIM
avg_psnr_sharp = total_psnr_sharp / num_images
avg_ssim_sharp = total_ssim_sharp / num_images
avg_psnr_denoised = total_psnr_denoised / num_images
avg_ssim_denoised = total_ssim_denoised / num_images

print(f'Average PSNR (Deblured vs Sharp): {avg_psnr_sharp}')
print(f'Average SSIM (Deblured vs Sharp): {avg_ssim_sharp}')
print(f'Average PSNR (Deblured vs Denoised): {avg_psnr_denoised}')
print(f'Average SSIM (Deblured vs Denoised): {avg_ssim_denoised}')

In [None]:
import torch
import matplotlib.pyplot as plt

# NOTE(review): the save cell writes 'deblured_UNet.pth' but this loads
# 'deblured_UNet4.pth' — confirm which checkpoint should be used.
model_path = '/content/drive/MyDrive/model/UNet/deblured_UNet4.pth'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(n_channels=3, n_classes=3)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

# Re-create the dataloaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=8)

output_folder = '/content/drive/MyDrive/image_output/deblured2'

def save_deblurred_images(data_loader, output_folder):
    """Run the global ``model`` over ``data_loader`` and save each output image.

    Each output is de-normalized with ``inv_normalize``, clamped to [0, 1] and
    written to ``output_folder`` under its source filename (so the file
    extension chosen by the dataset determines the saved format).

    Args:
        data_loader: yields ``(denoised, sharp, filenames)`` batches.
        output_folder: destination directory; created if missing.
    """
    # plt.imsave does not create missing parent directories.
    os.makedirs(output_folder, exist_ok=True)
    with torch.no_grad():
        # Only the inputs and filenames are needed; the sharp targets are not.
        for denoised_img, _, filenames in data_loader:
            deblured_img = model(denoised_img.to(device))

            for deblured, filename in zip(deblured_img, filenames):
                deblured_np = inv_normalize(deblured).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
                # Save deblurred image
                plt.imsave(os.path.join(output_folder, filename), deblured_np)


# Process and save images for training and validation datasets
save_deblurred_images(train_loader, output_folder)
save_deblurred_images(val_loader, output_folder)
