In [1]:
import os
import random
import torch
import wandb
from torch import nn
from PIL import Image
from glob import glob
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import collections.abc
from IPython.display import clear_output            

collections.Iterable = collections.abc.Iterable

In [2]:
DATA_PATH = 'data'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 32
wandb.init(project='imej-denoiser', name='DAE')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mnovebritito[0m. Use [1m`wandb login --relogin`[0m to force relogin


# LOAD IMAGE

In [3]:
# No class for this dataset
class NonClassDataLoader(Dataset):
    def __init__(self, root_dir, transforms=None):
        self.root_dir = root_dir
        self.transforms = transforms
        self.image_paths = glob(os.path.join(root_dir, '*.jpg')) + \
                   glob(os.path.join(root_dir, '*.jpeg')) + \
                   glob(os.path.join(root_dir, '*.png'))
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transforms:
            image = self.transforms(image)
        return image

In [4]:
# Define the named functions
def add_gaussian_noise(x):
    return x + 0.08 * torch.randn_like(x)

def normalize_tensor(x):
    return torch.clamp(x, 0., 1.)

def resize_tensor(x):
    return F.interpolate(x.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False).squeeze(0)

# Update the transform with named functions
transform = transforms.Compose([
    transforms.Resize((256, 256)), # resize image
    transforms.ToTensor(), # convert image to tensor
    transforms.Lambda(add_gaussian_noise), # add gaussian noise
    transforms.Lambda(normalize_tensor), # normalize tensors to [0, 1]
])

# Load data from folder
noised_data = NonClassDataLoader(root_dir=DATA_PATH, transforms=transform)
clean_data = NonClassDataLoader(root_dir=DATA_PATH, transforms=transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()]))

train_data = DataLoader(
    dataset=noised_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=True,
)

clean_data = DataLoader(
    dataset=clean_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=True,
)

# Modelling

## DAE (Denoising AutoEncoder)

In [5]:
class DenoisingAutoencoder(nn.Module):
    def __init__(self, input_channels=3):
        super(DenoisingAutoencoder, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1),  #  256x256 -> 128x128
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),              # 128x128 -> 64x64
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),             # 64x64 -> 32x32
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),            # 32x32 -> 16x16
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),  # 16x16 -> 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),   # 32x32 -> 64x64
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # 64x64 -> 128x128
            nn.Sigmoid(),  # Normalize output between 0 and 1
            nn.ConvTranspose2d(32, input_channels, kernel_size=3, stride=2, padding=1, output_padding=1),  # 128x128 -> 256x256
        )
    
    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

# Train

In [6]:
def denormalize(tensor):
    """Function to denormalize tensor values back to [0, 1]"""
    return tensor.clamp(0., 1.)

def plot_images(noisy, denoised, epoch, loss):
    """Plot images for each epoch"""
    noisy = denormalize(noisy).cpu().numpy().transpose(0, 2, 3, 1)
    denoised = denormalize(denoised).cpu().numpy().transpose(0, 2, 3, 1)

    # Clear the figure to avoid overlapping
    plt.clf()
    
    # Create a plot to show noisy, denoised, and original images
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    
    # title for the plot
    plt.suptitle(f"Epoch: {epoch},  loss: {loss.item():.4f}")
    
    # Noisy image   
    axes[0].imshow(noisy[0])  
    axes[0].set_title("Noisy Image")
    
    # Denoised image (model output)
    axes[1].imshow(denoised[0])  
    axes[1].set_title("Denoised Image")

    plt.tight_layout()
    plt.show()
    clear_output(wait=True)  

In [7]:
epochs = 300
dae_model = DenoisingAutoencoder(input_channels=3).to(DEVICE)
optimizer = torch.optim.Adam(dae_model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
criterion = nn.MSELoss()
dae_model.train()

DenoisingAutoencoder(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (10): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (1): ReLU()
    (2): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padd

In [8]:
print('START TRAINING')
for epoch in range(epochs):
    for noisy, clean in zip(train_data, clean_data):
        noisy = noisy.to(DEVICE)
        clean = clean.to(DEVICE)
        
        denoised_img = dae_model(noisy)
        loss = criterion(denoised_img, clean)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    scheduler.step()
    
    with torch.no_grad():
        # randomly select a noisy image
        noisy_img = random.choice(noised_data)
        noisy_img = noisy_img.unsqueeze(0).to(DEVICE)
        
        # denoise the noisy image
        denoised_img = dae_model(noisy_img)
        
        # plot the noisy and denoised image
        plot_images(noisy_img, denoised_img, epoch, loss)        
        wandb.log({
            'loss': loss.item(),
            'epoch': epoch,
            'denoised_img': wandb.Image(denoised_img[0].cpu().numpy().transpose(1, 2, 0)),
        })
        
wandb.finish()

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▇▇▇██
loss,▇█▇▇▆▅▄▄▄▃▃▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,299.0
loss,0.00199


In [10]:
# Save the model
torch.save(dae_model.state_dict(), 'vae_model.pth')