In [None]:
!pip install torchmetrics

In [None]:
import os
import zipfile
import random
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from PIL import Image

# Archive to unpack and the directory it is unpacked into.
zip_file_path = '/content/drive/MyDrive/Denoising_Dataset_train_val.zip'
extract_path = '/content/Denoising_Dataset_train_val'

# Step 1: Unpack the archive only when the target path does not exist yet,
# so re-running the cell does not extract everything again.
dataset_already_extracted = os.path.exists(extract_path)
if not dataset_already_extracted:
    with zipfile.ZipFile(zip_file_path, mode='r') as archive:
        archive.extractall(path=extract_path)

# Step 2: Custom Dataset to handle nested folder structure and mask naming convention
class DenoisingDataset(Dataset):
    """Dataset of (clean, degraded, defect-mask) image triplets from a nested folder tree.

    ``__init__`` walks ``root_dir`` expecting::

        <category>/<data_type>/GT_clean_image/<subfolder>/<image>
        <category>/<data_type>/Degraded_image/<subfolder>/<image>
        <category>/<data_type>/Defect_mask/<subfolder>/<image stem>_mask.png

    Only the ``GT_clean_image`` tree is listed. The degraded and mask paths are
    derived from each clean image's name and are not checked for existence
    until the sample is loaded.
    """

    def __init__(self, root_dir, data_type='Train', transform=None):
        """
        root_dir : Directory where the dataset is located
        data_type : 'Train' or 'Val' to specify which set to load
        transform : Any transformations to be applied on the images
        """
        self.root_dir = root_dir
        self.data_type = data_type
        self.transform = transform

        # Parallel lists: index i of each list belongs to the same sample.
        self.clean_images = []
        self.degraded_images = []
        self.mask_images = []

        # os.listdir returns entries in arbitrary order, so every level is
        # sorted to make the index -> sample mapping reproducible.
        for category in sorted(os.listdir(root_dir)):  # First level (e.g., 'bottle', 'cable')
            category_path = os.path.join(root_dir, category, self.data_type)  # Choose Train or Val

            gt_clean_image_path = os.path.join(category_path, 'GT_clean_image')
            degraded_image_path = os.path.join(category_path, 'Degraded_image')
            defect_mask_path = os.path.join(category_path, 'Defect_mask')

            # Stray entries in the root (archive metadata folders, README
            # files, ...) have no GT_clean_image tree; skip them instead of
            # crashing on the listdir below.
            if not os.path.isdir(gt_clean_image_path):
                continue

            # Subfolders within 'GT_clean_image', 'Degraded_image', 'Defect_mask'
            # (e.g., 'broken_large', 'broken_small')
            for subfolder in sorted(os.listdir(gt_clean_image_path)):
                clean_subfolder = os.path.join(gt_clean_image_path, subfolder)
                if not os.path.isdir(clean_subfolder):
                    continue
                degraded_subfolder = os.path.join(degraded_image_path, subfolder)
                mask_subfolder = os.path.join(defect_mask_path, subfolder)

                for img_name in sorted(os.listdir(clean_subfolder)):
                    clean_img = os.path.join(clean_subfolder, img_name)
                    if not os.path.isfile(clean_img):
                        continue
                    degraded_img = os.path.join(degraded_subfolder, img_name)

                    # Mask naming convention: '000.png' -> '000_mask.png'
                    img_base_name = os.path.splitext(img_name)[0]
                    mask_img = os.path.join(mask_subfolder, f'{img_base_name}_mask.png')

                    self.clean_images.append(clean_img)
                    self.degraded_images.append(degraded_img)
                    self.mask_images.append(mask_img)

    def __len__(self):
        """Number of samples, i.e. clean images found under ``root_dir``."""
        return len(self.clean_images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(clean, degraded, mask)``.

        Clean and degraded images are opened as RGB, the mask as single-channel
        grayscale. When a transform was given, it is applied to all three.
        """
        clean_img = Image.open(self.clean_images[idx]).convert('RGB')
        degraded_img = Image.open(self.degraded_images[idx]).convert('RGB')
        mask_img = Image.open(self.mask_images[idx]).convert('L')  # Mask as grayscale

        if self.transform:
            clean_img = self.transform(clean_img)
            degraded_img = self.transform(degraded_img)
            mask_img = self.transform(mask_img)

        return clean_img, degraded_img, mask_img

# Step 4: Define transformations (resizing to a uniform size)
# NOTE(review): DenoisingDataset applies this same transform to the defect
# masks, so masks are resized with Resize's default interpolation as well.
# Confirm that is acceptable if exact binary masks are needed.
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize all images to the same size
    transforms.ToTensor()           # Convert images to PyTorch tensors
])

# Step 5: Load training and validation datasets
data_path = os.path.join(extract_path, 'Denoising_Dataset_train_val')
train_dataset = DenoisingDataset(root_dir=data_path, data_type='Train', transform=transform)
val_dataset = DenoisingDataset(root_dir=data_path, data_type='Val', transform=transform)

print(f"Number of images in the training set: {len(train_dataset)}")
print(f"Number of images in the validation set: {len(val_dataset)}")

# Step 6: Create DataLoader for training and validation datasets
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
# Validation data is not shuffled: order has no effect on evaluation, and a
# fixed order keeps batch composition (and per-batch metrics) reproducible.
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

# Step 7: Function to plot random images (clean, degraded, mask)
def plot_random_images(dataset, n_images=5):
    """Show ``n_images`` randomly drawn samples, one per row.

    Each row shows the clean image, the degraded image and the mask side by
    side. Samples are drawn with replacement, so the same one may appear twice.

    dataset : indexable collection whose items are ``(clean, degraded, mask)``
              tensors; clean/degraded are permuted from channel-first to
              channel-last, and the mask's leading dimension is squeezed away
    n_images : number of rows to draw
    """
    # squeeze=False keeps `axes` two-dimensional even for a single row, so
    # axes[i, j] indexing below also works when n_images == 1. The figure
    # height scales with the row count (5 rows -> the original 15x15).
    fig, axes = plt.subplots(n_images, 3, figsize=(15, 3 * n_images), squeeze=False)
    for i in range(n_images):
        clean_img, degraded_img, mask_img = random.choice(dataset)
        clean_img_np = clean_img.permute(1, 2, 0).numpy()
        degraded_img_np = degraded_img.permute(1, 2, 0).numpy()
        mask_img_np = mask_img.squeeze(0).numpy()  # Grayscale, so remove channel dimension

        panels = (
            (clean_img_np, 'Clean Image', None),
            (degraded_img_np, 'Degraded Image', None),
            (mask_img_np, 'Mask Image', 'gray'),
        )
        for col, (image, title, cmap) in enumerate(panels):
            if cmap is None:
                axes[i, col].imshow(image)
            else:
                axes[i, col].imshow(image, cmap=cmap)
            axes[i, col].set_title(title)
            axes[i, col].axis('off')

    plt.tight_layout()
    plt.show()

# Preview a handful of random samples from each split.
for heading, split_dataset in (
    ("Random images from the training set:", train_dataset),
    ("Random images from the validation set:", val_dataset),
):
    print(heading)
    plot_random_images(split_dataset)


# Training and Validation Dataloader

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset

class DenoisingDataset(Dataset):
    """Dataset of (degraded, clean) image pairs from a nested folder tree.

    ``__init__`` walks ``root_dir`` expecting::

        <category>/<data_type>/GT_clean_image/<subfolder>/<image>
        <category>/<data_type>/Degraded_image/<subfolder>/<image>

    Only the ``GT_clean_image`` tree is listed. Each degraded path is derived
    from the clean image's name and is not checked for existence until the
    sample is loaded.
    """

    def __init__(self, root_dir, data_type='Train', transform=None):
        """
        root_dir : Directory where the dataset is located
        data_type : 'Train' or 'Val' to specify which set to load
        transform : Any transformations to be applied on the images
        """
        self.root_dir = root_dir
        self.data_type = data_type
        self.transform = transform

        # Parallel lists: index i of both lists belongs to the same sample.
        self.clean_images = []
        self.degraded_images = []

        # os.listdir returns entries in arbitrary order, so every level is
        # sorted to make the index -> sample mapping reproducible.
        for category in sorted(os.listdir(root_dir)):  # First level (e.g., 'bottle', 'cable')
            category_path = os.path.join(root_dir, category, self.data_type)  # Choose Train or Val

            gt_clean_image_path = os.path.join(category_path, 'GT_clean_image')
            degraded_image_path = os.path.join(category_path, 'Degraded_image')

            # Stray entries in the root (archive metadata folders, README
            # files, ...) have no GT_clean_image tree; skip them instead of
            # crashing on the listdir below.
            if not os.path.isdir(gt_clean_image_path):
                continue

            # Subfolders within 'GT_clean_image' and 'Degraded_image'
            # (e.g., 'broken_large', 'broken_small')
            for subfolder in sorted(os.listdir(gt_clean_image_path)):
                clean_subfolder = os.path.join(gt_clean_image_path, subfolder)
                if not os.path.isdir(clean_subfolder):
                    continue
                degraded_subfolder = os.path.join(degraded_image_path, subfolder)

                for img_name in sorted(os.listdir(clean_subfolder)):
                    clean_img = os.path.join(clean_subfolder, img_name)
                    if not os.path.isfile(clean_img):
                        continue
                    degraded_img = os.path.join(degraded_subfolder, img_name)

                    self.clean_images.append(clean_img)
                    self.degraded_images.append(degraded_img)

    def __len__(self):
        """Number of samples, i.e. clean images found under ``root_dir``."""
        return len(self.clean_images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(degraded, clean)``, in that order.

        Both images are opened as RGB. When a transform was given, it is
        applied to each of them.
        """
        clean_img = Image.open(self.clean_images[idx]).convert('RGB')
        degraded_img = Image.open(self.degraded_images[idx]).convert('RGB')

        if self.transform:
            clean_img = self.transform(clean_img)
            degraded_img = self.transform(degraded_img)

        # Model input first, target second.
        return degraded_img, clean_img

# Step 4: Define transformations (resizing to a uniform size)
transform = transforms.Compose([
    transforms.Resize((512, 512)),  # Resize all images to the same size
    transforms.ToTensor()           # Convert images to PyTorch tensors
])

# Step 5: Load training and validation datasets
data_path = os.path.join(extract_path, 'Denoising_Dataset_train_val')
train_dataset = DenoisingDataset(root_dir=data_path, data_type='Train', transform=transform)
val_dataset = DenoisingDataset(root_dir=data_path, data_type='Val', transform=transform)

print(f"Number of images in the training set: {len(train_dataset)}")
print(f"Number of images in the validation set: {len(val_dataset)}")

# Step 6: Create DataLoader for training and validation datasets
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
# Validation data is not shuffled: order has no effect on evaluation, and a
# fixed order keeps batch composition (and per-batch metrics) reproducible.
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)


# Model Architecture

In [None]:
import torch
import torch.nn as nn

# Proposed U-Net

class double_conv(nn.Module):
    """Two 3x3 convolution + ReLU layers wrapped by a residual shortcut.

    The shortcut is the input itself when ``in_channels == out_channels``;
    otherwise the input is first projected with a 1x1 convolution. Padding of
    1 keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.conv_block = nn.Sequential(*layers)

        # A projection is only needed when the channel count changes.
        if in_channels == out_channels:
            self.residual_conv = None
        else:
            self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        features = self.conv_block(x)
        if self.residual_conv is None:
            shortcut = x
        else:
            shortcut = self.residual_conv(x)
        return features + shortcut

class SCA(nn.Module):
    """Attention gate that rescales features with two sigmoid gates.

    * ``spatial_attention``: a 1x1 convolution followed by a sigmoid, giving
      one gate value per channel and per pixel.
    * ``channel_attention``: global average pooling followed by a two-layer
      1x1-conv bottleneck and a sigmoid, giving one gate value per channel.

    The output is ``x * spatial_gate * channel_gate`` and has the shape of ``x``.

    in_channels : number of channels of the input feature map
    reduction : channel reduction ratio of the bottleneck (default 16)
    """

    def __init__(self, in_channels, reduction=16):
        super(SCA, self).__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")

        # Keep at least one hidden channel: with fewer than `reduction` input
        # channels the integer division would otherwise yield a zero-width layer.
        hidden_channels = max(in_channels // reduction, 1)

        # Pointwise convolution producing the per-pixel gate logits.
        self.spatial_attention = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Per-pixel gate in (0, 1).
        spatial_out = torch.sigmoid(self.spatial_attention(x))

        # Per-channel gate in (0, 1); shape (B, C, 1, 1) broadcasts over H, W.
        channel_out = self.channel_attention(x)

        # Element-wise multiplication for attention fusion
        return x * spatial_out * channel_out

class NonLocalBlock(nn.Module):
    """Non-local block: every position aggregates features from all positions.

    Pairwise affinities are dot products of the ``theta`` and ``phi``
    embeddings, normalised by the number of positions ``H * W``. The
    aggregated ``g`` features are projected back to ``in_channels`` by ``W``
    and added to the input (residual connection), so the output has the shape
    of the input.

    Note: the affinity matrix has ``(H * W) ** 2`` entries per sample, so
    memory grows quadratically with the spatial size of the input.
    """

    def __init__(self, in_channels):
        super(NonLocalBlock, self).__init__()
        self.theta = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)
        self.phi = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)
        self.g = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)
        self.W = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)

    def forward(self, x):
        batch_size, _, height, width = x.size()

        # Embeddings flattened over space: (B, C', N) with N = H * W.
        theta_x = self.theta(x).flatten(2)
        phi_x = self.phi(x).flatten(2)
        g_x = self.g(x).flatten(2)

        # Affinity between every pair of positions: (B, N, N).
        f = torch.matmul(theta_x.transpose(1, 2), phi_x)
        f_div_n = f / (height * width)  # Normalize by the number of positions

        # Aggregate g over all positions: (B, N, C').
        y = torch.matmul(f_div_n, g_x.transpose(1, 2))

        # Bring channels back in front of the positions *before* unflattening.
        # Reshaping the (B, N, C') tensor directly would interleave channel
        # and spatial entries.
        y = y.transpose(1, 2).reshape(batch_size, -1, height, width)

        # Combine with original features (residual connection).
        return x + self.W(y)

class UNet(nn.Module):
    """U-Net style network with attention-gated skips and a non-local bottleneck.

    The encoder has three pooled stages (16, 32, 64 channels) followed by a
    128-channel stage. Each encoder output passes through an ``SCA`` gate
    before being used as a skip connection. The bottleneck applies another
    ``double_conv`` and a ``NonLocalBlock``. The decoder upsamples three times
    and concatenates the gated skips. Finally the network input is added to
    the 1x1-conv output (global residual).

    n_class : number of output channels of the final 1x1 convolution
    """

    def __init__(self, n_class):
        super().__init__()

        # Encoder stages, each paired with its attention gate.
        self.dconv_down1 = double_conv(3, 16)
        self.att1 = SCA(16)

        self.dconv_down2 = double_conv(16, 32)
        self.att2 = SCA(32)

        self.dconv_down3 = double_conv(32, 64)
        self.att3 = SCA(64)

        self.dconv_down4 = double_conv(64, 128)
        self.att4 = SCA(128)

        # Shared, parameter-free resampling layers.
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Bottleneck: extra convolution followed by the non-local block.
        self.dconv_bottleneck = double_conv(128, 128)
        self.nl_block = NonLocalBlock(128)

        # Decoder stages; input width = skip channels + upsampled channels.
        self.dconv_up3 = double_conv(64 + 128, 64)
        self.dconv_up2 = double_conv(32 + 64, 32)
        self.dconv_up1 = double_conv(16 + 32, 16)

        self.conv_last = nn.Conv2d(16, n_class, 1)

    def forward(self, x):
        network_input = x

        # Encoder: the gated features feed the skip connections, while the
        # un-gated features are pooled and passed to the next stage.
        enc1 = self.dconv_down1(x)
        skip1 = self.att1(enc1)

        enc2 = self.dconv_down2(self.maxpool(enc1))
        skip2 = self.att2(enc2)

        enc3 = self.dconv_down3(self.maxpool(enc2))
        skip3 = self.att3(enc3)

        # Deepest stage: its attention gate is applied in-line.
        features = self.att4(self.dconv_down4(self.maxpool(enc3)))

        # Bottleneck with Non-Local Block
        features = self.nl_block(self.dconv_bottleneck(features))

        # Decoder: upsample, concatenate the matching skip, convolve.
        decoder_stages = (
            (skip3, self.dconv_up3),
            (skip2, self.dconv_up2),
            (skip1, self.dconv_up1),
        )
        for skip, stage in decoder_stages:
            features = self.upsample(features)
            features = stage(torch.cat([features, skip], dim=1))

        # Global residual: the head's output is added to the network input.
        return self.conv_last(features) + network_input

# Example usage: build the network and place it on the GPU when one is available.
n_classes = 3
model = UNet(n_class=n_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Smoke test with a random batch (batch size of 1, 3 channels, 256x256 image).
x = torch.randn(1, 3, 256, 256).to(device)
output = model(x)
print(output.shape)  # Should output: torch.Size([1, 3, 256, 256])

# Parameter counts: all parameters, and those that receive gradients.
total_params = 0
trainable_params = 0
for parameter in model.parameters():
    count = parameter.numel()
    total_params += count
    if parameter.requires_grad:
        trainable_params += count

print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")

# Loss functions

In [None]:
import torch
import torch.nn as nn
from torchmetrics import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio

class PSNRLoss(nn.Module):
    """Negative PSNR, so that minimising the loss maximises PSNR.

    data_range : value range passed on to ``PeakSignalNoiseRatio``
    """

    def __init__(self, data_range=1.0):
        super().__init__()
        self.psnr = PeakSignalNoiseRatio(data_range=data_range)

    def forward(self, x, y):
        """Return ``-PSNR(x, y)``.

        A plain Python list is converted to a tensor on the other argument's
        device; tensors are passed through untouched.
        """
        if isinstance(x, list):
            x = torch.tensor(x, device=y.device)
        if isinstance(y, list):
            y = torch.tensor(y, device=x.device)

        # Higher PSNR is better, so its negation serves as the loss.
        return -self.psnr(x, y)

class SSIMLoss(nn.Module):
    """``1 - SSIM``, so that minimising the loss maximises structural similarity.

    data_range : value range passed on to ``StructuralSimilarityIndexMeasure``
    """

    def __init__(self, data_range=1.0):
        super().__init__()
        self.ssim = StructuralSimilarityIndexMeasure(data_range=data_range)

    def forward(self, x, y):
        """Return ``1 - SSIM(x, y)``.

        A plain Python list is converted to a tensor on the other argument's
        device; tensors are passed through untouched.
        """
        if isinstance(x, list):
            x = torch.tensor(x, device=y.device)
        if isinstance(y, list):
            y = torch.tensor(y, device=x.device)

        # Higher SSIM is better; subtracting from 1 turns it into a loss.
        return 1 - self.ssim(x, y)

# Example of usage:
# ssim_loss = SSIMLoss()
# loss_value = ssim_loss(image1, image2)



class CharbonnierLoss(nn.Module):
    """Charbonnier loss: mean of ``sqrt(diff^2 + eps^2)``, a smooth variant of L1.

    eps : constant added (squared) under the square root
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def forward(self, x, y):
        """Return the mean Charbonnier penalty between ``x`` and ``y``.

        A plain Python list is converted to a tensor on the other argument's
        device; tensors are passed through untouched.
        """
        if isinstance(x, list):
            x = torch.tensor(x, device=y.device)
        if isinstance(y, list):
            y = torch.tensor(y, device=x.device)

        residual = x - y
        penalty = torch.sqrt((residual * residual) + (self.eps * self.eps))
        return torch.mean(penalty)

class EdgeLoss(nn.Module):
    """Charbonnier loss between the Laplacian residuals of two image batches.

    Each image is blurred with a fixed 5x5 kernel, subsampled by 2, zero-filled
    back to full size and blurred again. The difference between the image and
    this reconstruction is compared between prediction and target.

    channels : number of image channels the blur kernel is built for
               (default 3)
    """

    def __init__(self, channels=3):
        super(EdgeLoss, self).__init__()
        k = torch.tensor([[.05, .25, .4, .25, .05]])
        # Outer product gives the separable 5x5 kernel; one copy per channel
        # for the depthwise (grouped) convolution in conv_gauss.
        kernel = torch.matmul(k.t(), k).unsqueeze(0).repeat(channels, 1, 1, 1)
        # Registered as a non-persistent buffer: it follows .to()/.cuda()
        # calls on the module but adds no entry to the state_dict.
        self.register_buffer('kernel', kernel, persistent=False)
        self.loss = CharbonnierLoss()

    def conv_gauss(self, img):
        """Blur ``img`` (B, C, H, W) with the fixed kernel, keeping its size."""
        # Match the input's device and dtype so the loss also works when the
        # module itself was never moved.
        kernel = self.kernel.to(device=img.device, dtype=img.dtype)
        n_channels, _, kh, kw = kernel.shape
        # F.pad order is (left, right, top, bottom).
        # `nn.functional` is used because this notebook never imports `F`.
        img = nn.functional.pad(img, (kw // 2, kw // 2, kh // 2, kh // 2), mode='replicate')
        return nn.functional.conv2d(img, kernel, groups=n_channels)

    def laplacian_kernel(self, current):
        """Return ``current`` minus its blurred, down-/up-sampled reconstruction."""
        # Ensure current has four dimensions: [batch, channels, height, width]
        if current.dim() == 3:
            current = current.unsqueeze(0)

        filtered = self.conv_gauss(current)  # Apply Gaussian filter

        down = filtered[:, :, ::2, ::2]  # Downsample
        new_filter = torch.zeros_like(filtered)
        # Zero-filled upsampling; the factor 4 compensates for keeping only
        # one pixel out of every 2x2 block.
        new_filter[:, :, ::2, ::2] = down * 4

        filtered = self.conv_gauss(new_filter)  # Apply Gaussian filter again
        return current - filtered

    def forward(self, x, y):
        """Return the Charbonnier loss between the Laplacian residuals of ``x`` and ``y``."""
        return self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))


# Training

In [None]:
import os
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm  # For progress bars
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure  # Import metrics

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model, loss terms and optimizer. Everything is moved to `device` so that
# model, criteria and data live on the same device.
model = UNet(n_class=n_classes).to(device)
criterion_char = CharbonnierLoss().to(device)
criterion_ssim = SSIMLoss().to(device)
criterion_edge = EdgeLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Metrics used for reporting only (separate from the SSIM loss term).
psnr_metric = PeakSignalNoiseRatio(data_range=1.0).to(device)
ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

# Number of epochs and checkpoint location
num_epochs = 250
# Start below any attainable score so the first epoch always produces a
# checkpoint, even if validation SSIM is zero or negative.
best_ssim = float("-inf")
best_model_path = 'best_model.pth'

# Per-epoch history of losses and metrics
train_losses = []
val_losses = []
train_psnr_scores = []
train_ssim_scores = []
val_psnr_scores = []
val_ssim_scores = []

# Training loop: one optimisation pass over train_loader and one evaluation
# pass over val_loader per epoch. The weights with the best validation SSIM
# are checkpointed.
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_train_loss = 0.0
    running_train_psnr = 0.0
    running_train_ssim = 0.0

    # Training phase
    for degraded_imgs, clean_imgs in tqdm(train_loader, desc=f'Training Epoch {epoch + 1}/{num_epochs}', leave=False):
        # Move the batch to the model's device. `.to(device)` also works on
        # CPU-only machines and matches the validation phase.
        degraded_imgs, clean_imgs = degraded_imgs.to(device), clean_imgs.to(device)

        optimizer.zero_grad()  # Clear gradients
        outputs = model(degraded_imgs)  # Forward pass

        # Weighted sum of the three loss terms.
        loss_char = criterion_char(outputs, clean_imgs)  # Outputs shape: [batch_size, channels, height, width]
        loss_edge = criterion_edge(outputs, clean_imgs)
        loss_ssim = criterion_ssim(outputs, clean_imgs)
        loss = loss_char + (0.1 * loss_edge) + (0.3*loss_ssim)

        loss.backward()  # Backward pass
        optimizer.step()  # Update weights

        running_train_loss += loss.item()

        # Metrics are for reporting only: detach the outputs so that no
        # autograd graph is built or kept alive for them.
        outputs_for_metrics = outputs.detach()
        psnr = psnr_metric(outputs_for_metrics, clean_imgs)
        ssim = ssim_metric(outputs_for_metrics, clean_imgs)
        running_train_psnr += psnr.item()
        running_train_ssim += ssim.item()

    # Epoch averages are means of the per-batch values.
    avg_train_loss = running_train_loss / len(train_loader)
    avg_train_psnr = running_train_psnr / len(train_loader)
    avg_train_ssim = running_train_ssim / len(train_loader)
    train_losses.append(avg_train_loss)
    train_psnr_scores.append(avg_train_psnr)
    train_ssim_scores.append(avg_train_ssim)

    # Validation phase
    model.eval()  # Set the model to evaluation mode
    running_val_loss = 0.0
    running_val_psnr = 0.0
    running_val_ssim = 0.0

    with torch.no_grad():  # No gradients required during validation
        for degraded_imgs, clean_imgs in tqdm(val_loader, desc=f'Validating Epoch {epoch + 1}/{num_epochs}', leave=False):
            degraded_imgs, clean_imgs = degraded_imgs.to(device), clean_imgs.to(device)
            outputs = model(degraded_imgs)

            # Same loss composition as in training.
            loss_char = criterion_char(outputs, clean_imgs)  # Outputs shape: [batch_size, channels, height, width]
            loss_edge = criterion_edge(outputs, clean_imgs)
            loss_ssim = criterion_ssim(outputs, clean_imgs)
            loss = loss_char + (0.1 * loss_edge) + (0.3*loss_ssim)
            running_val_loss += loss.item()

            # Calculate metrics
            psnr = psnr_metric(outputs, clean_imgs)
            ssim = ssim_metric(outputs, clean_imgs)
            running_val_psnr += psnr.item()
            running_val_ssim += ssim.item()

    avg_val_loss = running_val_loss / len(val_loader)
    avg_val_psnr = running_val_psnr / len(val_loader)
    avg_val_ssim = running_val_ssim / len(val_loader)
    val_losses.append(avg_val_loss)
    val_psnr_scores.append(avg_val_psnr)
    val_ssim_scores.append(avg_val_ssim)

    # Print the average losses and metrics for this epoch
    print(f"Epoch [{epoch + 1}/{num_epochs}] - Train Loss: {avg_train_loss:.4f}, "
          f"Validation Loss: {avg_val_loss:.4f}, Train PSNR: {avg_train_psnr:.4f}, "
          f"Validation PSNR: {avg_val_psnr:.4f}, Train SSIM: {avg_train_ssim:.4f}, "
          f"Validation SSIM: {avg_val_ssim:.4f}")

    # Save the best model based on validation SSIM (higher is better)
    if avg_val_ssim > best_ssim:
        best_ssim = avg_val_ssim
        torch.save(model.state_dict(), best_model_path)  # Save the model

# Plotting the training and validation losses and metrics


def _epoch_axis(history):
    """1-based epoch numbers for a recorded history list."""
    return range(1, len(history) + 1)


# The x-axis comes from what was actually recorded, so the curves can still
# be drawn when training was interrupted before `num_epochs` epochs.
fig, (loss_ax, psnr_ax) = plt.subplots(1, 2, figsize=(15, 5))

loss_ax.plot(_epoch_axis(train_losses), train_losses, label='Training Loss', color='blue')
loss_ax.plot(_epoch_axis(val_losses), val_losses, label='Validation Loss', color='orange')
loss_ax.set_title('Training and Validation Loss over Epochs')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.grid()

psnr_ax.plot(_epoch_axis(train_psnr_scores), train_psnr_scores, label='Training PSNR', color='blue')
psnr_ax.plot(_epoch_axis(val_psnr_scores), val_psnr_scores, label='Validation PSNR', color='orange')
psnr_ax.set_title('Training and Validation PSNR/SSIM over Epochs')
psnr_ax.set_xlabel('Epochs')
psnr_ax.set_ylabel('PSNR')
psnr_ax.grid()

# SSIM is at most 1 while PSNR is measured in dB, so SSIM gets its own
# y-axis; on a shared axis the SSIM curves would be squashed flat.
ssim_ax = psnr_ax.twinx()
ssim_ax.plot(_epoch_axis(train_ssim_scores), train_ssim_scores, label='Training SSIM', color='green')
ssim_ax.plot(_epoch_axis(val_ssim_scores), val_ssim_scores, label='Validation SSIM', color='red')
ssim_ax.set_ylabel('SSIM')

# One combined legend for the curves of both y-axes.
metric_lines = psnr_ax.get_lines() + ssim_ax.get_lines()
psnr_ax.legend(metric_lines, [line.get_label() for line in metric_lines])

fig.tight_layout()
fig.savefig('training_validation_metrics_curve.png')  # Save the figure
plt.show()
