In [1]:
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import joblib
import pickle

from sklearn.model_selection import train_test_split

!pip install pytorch-msssim
from pytorch_msssim import ssim
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset, Dataset, random_split

import torchvision.transforms as transforms
import torchvision.datasets as datasets

Collecting pytorch-msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Downloading pytorch_msssim-1.0.0-py3-none-any.whl (7.7 kB)
Installing collected packages: pytorch-msssim
Successfully installed pytorch-msssim-1.0.0


In [None]:
def salt_and_pepper(image_tensor, noise_level=0.05):
    """
    Apply salt and pepper noise to an image tensor.

    Half of the corrupted elements are set to 1.0 (salt) and the rest to
    0.0 (pepper). Salt and pepper positions are drawn from a single random
    permutation, so they never overlap and the number of corrupted elements
    is exactly ``int(noise_level * numel)``.

    Parameters:
    - image_tensor (torch.Tensor): A torch tensor representing the image.
      It is not modified. Noise is applied per element (i.e. per channel
      value), not per spatial pixel.
    - noise_level (float): Fraction of tensor elements to corrupt. Values
      outside [0, 1] are clamped to that range.

    Returns:
    - torch.Tensor: A copy of the input with salt and pepper noise added.
    """
    # Force a contiguous copy: a plain clone() preserves the input strides,
    # and view(-1) would then raise for permuted / sliced inputs.
    noisy_image = image_tensor.clone(memory_format=torch.contiguous_format)
    flat = noisy_image.view(-1)  # shares storage with noisy_image

    total = flat.numel()
    num_noisy_pixels = min(max(int(noise_level * total), 0), total)
    num_salt = num_noisy_pixels // 2

    # One permutation for both kinds of noise keeps the index sets disjoint.
    chosen = torch.randperm(total)[:num_noisy_pixels]
    flat[chosen[:num_salt]] = 1.0
    flat[chosen[num_salt:]] = 0.0
    return noisy_image

def load_images(directory, target_size=(256, 256), t = ".png"):
    """
    Recursively load every image with a given extension from a directory.

    Parameters:
    - directory (str): root directory that is walked recursively
    - target_size (tuple): (height, width) every image is resized to, so
      that all tensors can be stacked
    - t (str): file-name suffix (case-sensitive) an image must end with

    Returns:
    - data (torch.Tensor): float tensor of shape
      (num_images, 3, *target_size) with values in [0, 1]

    Raises:
    - RuntimeError: if no matching image is found under ``directory``
    """
    # Build the pipeline once instead of once per file. ToTensor already
    # scales 8-bit RGB PIL images to [0, 1], so no extra /255 is needed.
    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ])

    images = []
    for root, dirs, files in os.walk(directory):
        # os.walk order is filesystem dependent; sort for a reproducible
        # image order across machines and runs.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith(t):
                continue
            image_path = os.path.join(root, file)
            # Context manager closes the underlying file handle promptly.
            with Image.open(image_path) as image:
                images.append(transform(image.convert('RGB')))

    if not images:
        raise RuntimeError(f"No '{t}' images found under {directory!r}")
    return torch.stack(images)

def tensor_to_image(tensor):
    """
    Convert a channel-first image tensor into an array matplotlib can show.

    Parameters:
    - tensor (torch.Tensor): image tensor laid out as (channels, H, W).
      If its maximum exceeds 1 it is divided by 255.0 first.

    Returns:
    - numpy (np.array): the image as an (H, W, channels) array.
    """
    exceeds_unit_range = tensor.max() > 1
    scaled = tensor / 255.0 if exceeds_unit_range else tensor

    # Drop autograd tracking and move to host memory before leaving torch.
    channel_first = scaled.detach().cpu().numpy()
    return channel_first.transpose(1, 2, 0)

class PairedDataset(Dataset):
    """
    Dataset of aligned (noisy, clean) image pairs.

    Parameters:
    - data_clean: indexable collection of clean images
    - data_noisy: indexable collection of noisy images, where item ``i`` is
      the corrupted version of ``data_clean[i]``

    Raises:
    - ValueError: if the two collections differ in length
    """

    def __init__(self, data_clean, data_noisy):
        # Fail fast: a length mismatch would otherwise surface later as an
        # IndexError mid-epoch or as silently unused samples.
        if len(data_clean) != len(data_noisy):
            raise ValueError(
                "data_clean and data_noisy must have the same length, got "
                f"{len(data_clean)} and {len(data_noisy)}"
            )
        self.data_clean = data_clean
        self.data_noisy = data_noisy

    def __len__(self):
        """Return the number of pairs."""
        return len(self.data_clean)

    def __getitem__(self, idx):
        """Return the pair at ``idx`` as ``(noisy_image, clean_image)``."""
        return self.data_noisy[idx], self.data_clean[idx]

class DenoisingAutoencoder(nn.Module):
    """
    Five-level convolutional encoder/decoder with skip connections.

    Every encoder stage halves the spatial size and every decoder stage
    doubles it; the outputs of enc1..enc4 are concatenated onto the matching
    decoder activations. Because of the five halvings and the
    concatenations, input height and width must be multiples of 32.

    Parameters:
    - final_activation (nn.Module, optional): activation applied to the last
      layer. Defaults to ``nn.Sigmoid()`` so outputs lie in [0, 1], the same
      range as ``ToTensor`` images and the ``data_range=1`` used by the SSIM
      loss. Pass ``nn.Tanh()`` for data normalised to [-1, 1].

    Layer names and ordering are kept stable so ``state_dict`` keys do not
    change.
    """

    def __init__(self, final_activation=None):
        super().__init__()
        if final_activation is None:
            final_activation = nn.Sigmoid()

        # Encoder: 3 -> 64 -> 128 -> 256 -> 512 -> 512 channels
        self.enc1 = self._down(3, 64)
        self.enc2 = self._down(64, 128)
        self.enc3 = self._down(128, 256)
        self.enc4 = self._down(256, 512)
        self.enc5 = self._down(512, 512)

        # Decoder: input channels of dec2..dec5 are doubled by the
        # concatenated skip connection from the encoder.
        self.dec1 = self._up(512, 512, dropout=0.5)
        self.dec2 = self._up(1024, 256)  # 512 (dec1) + 512 (enc4)
        self.dec3 = self._up(512, 128)   # 256 (dec2) + 256 (enc3)
        self.dec4 = self._up(256, 64)    # 128 (dec3) + 128 (enc2)
        self.dec5 = nn.Sequential(       # 64 (dec4) + 64 (enc1)
            nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1),
            final_activation)

    @staticmethod
    def _down(in_channels, out_channels):
        """Stride-2 conv + batch norm + ReLU; halves the spatial size."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True))

    @staticmethod
    def _up(in_channels, out_channels, dropout=None):
        """Stride-2 transposed conv + batch norm + ReLU (+ optional dropout)."""
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        ]
        if dropout is not None:
            layers.append(nn.Dropout(dropout))
        return nn.Sequential(*layers)

    def forward(self, x):
        """
        Denoise a batch of images.

        Parameters:
        - x (torch.Tensor): batch of shape (N, 3, H, W), H and W multiples of 32

        Returns:
        - torch.Tensor: reconstruction with the same shape as ``x``
        """
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)

        # Decoder with skip connections (concatenated along channels)
        d1 = torch.cat((self.dec1(e5), e4), dim=1)
        d2 = torch.cat((self.dec2(d1), e3), dim=1)
        d3 = torch.cat((self.dec3(d2), e2), dim=1)
        d4 = torch.cat((self.dec4(d3), e1), dim=1)
        return self.dec5(d4)

In [None]:
# Seed torch so the salt-and-pepper corruption (and everything random that
# follows) is reproducible on a Restart & Run All.
torch.manual_seed(42)

train_norm = load_images("/kaggle/input/traffic-signs/GTSRB/train/00001")

# Fraction of tensor elements corrupted in every training image.
noise_level = 0.5
train_snp = torch.stack([
    salt_and_pepper(image, noise_level=noise_level) for image in train_norm
])

# Sanity check: first clean image next to its noisy counterpart.
fig, (ax_clean, ax_noisy) = plt.subplots(1, 2, figsize=(12, 4))
ax_clean.imshow(tensor_to_image(train_norm[0]))
ax_clean.set_title(f'Data {train_norm[0].dtype}')
ax_noisy.imshow(tensor_to_image(train_snp[0]))
ax_noisy.set_title(f'SNP {train_snp[0].dtype}')
fig.tight_layout()
plt.show()

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

paired_dataset = PairedDataset(train_norm, train_snp)
paired_loader = DataLoader(paired_dataset, batch_size=32, shuffle=True)

def ssim_loss(y_true, y_pred):
    """
    Structural dissimilarity between two image batches.

    Parameters:
    - y_true (torch.Tensor): reference images, values expected in [0, 1]
    - y_pred (torch.Tensor): predicted images, same shape as ``y_true``

    Returns:
    - torch.Tensor: scalar ``1 - SSIM`` averaged over the batch
      (0 means identical images)
    """
    return 1 - ssim(y_true, y_pred, data_range=1, size_average=True)

model = DenoisingAutoencoder().to(device)
criterion = ssim_loss
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training Loop
num_epochs = 50

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for noisy_images, clean_images in paired_loader:
        noisy_images, clean_images = noisy_images.to(device), clean_images.to(device)

        optimizer.zero_grad()

        # Forward pass; arguments follow ssim_loss(y_true, y_pred)
        outputs = model(noisy_images)
        loss = criterion(clean_images, outputs)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses for this epoch
    epoch_loss = running_loss / len(paired_loader)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}')

# Report diverged weights once instead of once per affected parameter.
if any(not torch.isfinite(param).all() for param in model.parameters()):
    print("Model contains NaN or inf values")

# Save under /kaggle/working, the directory the loading cell reads from.
# NOTE(review): this pickles the whole module object; saving
# model.state_dict() with torch.save is the more portable option if the
# loading code is changed to match.
model_path = f'/kaggle/working/model_{noise_level}.pkl'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
joblib.dump(model, model_path)

In [None]:
# The checkpoint is written with joblib.dump, so read it back with the
# matching joblib.load rather than the raw pickle module.
# SECURITY: joblib/pickle files can execute arbitrary code while loading --
# only load checkpoints produced by this notebook or another trusted source.
model = joblib.load('/kaggle/working/model_0.5.pkl')

In [None]:
# Recomputed here so this cell also works when only the model-loading cell
# (and not the training cell) was run in this kernel.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

val_norm = load_images("/kaggle/input/traffic-signs/GTSRB/test/00001")

# Evaluate at the corruption level the loaded model (model_0.5) was trained
# on. A level of 1 asks for every element to be replaced by noise, leaving
# nothing for the network to reconstruct.
eval_noise_level = 0.5
val_snp = torch.stack([
    salt_and_pepper(image, noise_level=eval_noise_level) for image in val_norm
])

test_loader = DataLoader(val_snp, batch_size=32, shuffle=False)
model.eval()
model.to(device)

# Only the first batch is visualised, so only the first batch is run.
with torch.no_grad():
    inputs = next(iter(test_loader)).to(device)
    outputs = model(inputs)

inputs_cpu = inputs.cpu()
# Clamp to the displayable range so imshow does not have to clip.
outputs_cpu = outputs.cpu().clamp(0, 1)

# Show up to three examples; squeeze=False keeps axs 2-D for any count.
num_shown = min(3, inputs_cpu.shape[0])
fig, axs = plt.subplots(3, num_shown, figsize=(15, 6), squeeze=False)
for i in range(num_shown):
    # Clean reference images
    axs[0, i].imshow(val_norm[i].permute(1, 2, 0).squeeze())
    axs[0, i].title.set_text('Input Image ' + str(i+1))
    axs[0, i].axis('off')

    # Noisy images fed to the model
    axs[1, i].imshow(inputs_cpu[i].permute(1, 2, 0).squeeze())
    axs[1, i].title.set_text('Noisy Image ' + str(i+1))
    axs[1, i].axis('off')

    # Model reconstructions
    axs[2, i].imshow(outputs_cpu[i].permute(1, 2, 0).squeeze())
    axs[2, i].title.set_text('Output Image ' + str(i+1))
    axs[2, i].axis('off')

plt.show()


In [None]:
# NOTE(review): this cell repeats the previous evaluation on the already
# loaded `val_norm` with a fresh noise draw; consider extracting a shared
# helper function if both cells are kept.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Evaluate at the corruption level the loaded model (model_0.5) was trained
# on. A level of 1 asks for every element to be replaced by noise, leaving
# nothing for the network to reconstruct.
eval_noise_level = 0.5
val_snp = torch.stack([
    salt_and_pepper(image, noise_level=eval_noise_level) for image in val_norm
])

test_loader = DataLoader(val_snp, batch_size=32, shuffle=False)
model.eval()
model.to(device)

# Only the first batch is visualised, so only the first batch is run.
with torch.no_grad():
    inputs = next(iter(test_loader)).to(device)
    outputs = model(inputs)

inputs_cpu = inputs.cpu()
# Clamp to the displayable range so imshow does not have to clip.
outputs_cpu = outputs.cpu().clamp(0, 1)

# Show up to three examples; squeeze=False keeps axs 2-D for any count.
num_shown = min(3, inputs_cpu.shape[0])
fig, axs = plt.subplots(3, num_shown, figsize=(15, 6), squeeze=False)
for i in range(num_shown):
    # Clean reference images
    axs[0, i].imshow(val_norm[i].permute(1, 2, 0).squeeze())
    axs[0, i].title.set_text('Input Image ' + str(i+1))
    axs[0, i].axis('off')

    # Noisy images fed to the model
    axs[1, i].imshow(inputs_cpu[i].permute(1, 2, 0).squeeze())
    axs[1, i].title.set_text('Noisy Image ' + str(i+1))
    axs[1, i].axis('off')

    # Model reconstructions
    axs[2, i].imshow(outputs_cpu[i].permute(1, 2, 0).squeeze())
    axs[2, i].title.set_text('Output Image ' + str(i+1))
    axs[2, i].axis('off')

plt.show()
