In [2]:
import torch
import math
from pytorch_wavelets import DWTForward, DWTInverse

class ImageAttentionBlock(torch.nn.Module):
    """Self-attention followed by a feed-forward layer over the pixels of a feature map.

    Every spatial position is treated as one token of width ``hidden_dim``.
    Both sub-layers are applied with a residual connection; the attention
    input is layer-normalised first, and the feed-forward stack starts with
    its own LayerNorm.

    Args:
        hidden_dim: Number of channels of the incoming feature map.
        image_size: Side length of the (square) feature map. Only used as a
            fallback when the input is not a 4-D ``[B, hidden_dim, H, W]``
            tensor; otherwise the spatial size is read from the input.
        num_head: Number of attention heads.
    """

    def __init__(self, hidden_dim, image_size, num_head=4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.image_size = image_size
        self.layer_norm = torch.nn.LayerNorm([hidden_dim])
        self.att = torch.nn.MultiheadAttention(hidden_dim, num_head, batch_first=True)
        self.feed_forward = torch.nn.Sequential(
            torch.nn.LayerNorm([hidden_dim]),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, x):  # x: [B, C, H, W]
        """Apply attention + feed-forward and return a tensor shaped like the input map."""
        if x.dim() == 4 and x.shape[1] == self.hidden_dim:
            # Use the real spatial size. Reshaping with the configured
            # image_size would fail for other sizes or, worse, silently fold
            # spatial positions into the batch dimension.
            height, width = x.shape[2], x.shape[3]
        else:
            # Legacy behaviour: trust the size given at construction time.
            height = width = self.image_size
        tokens = x.reshape(-1, self.hidden_dim, height * width).transpose(1, 2)  # [B, H*W, C]
        normed = self.layer_norm(tokens)
        attention_value, _ = self.att(normed, normed, normed)
        tokens = tokens + attention_value
        tokens = tokens + self.feed_forward(tokens)
        return tokens.transpose(1, 2).reshape(-1, self.hidden_dim, height, width)

class BottleNeckBlock(torch.nn.Module):
    """Pointwise (1x1) convolution -> GroupNorm -> GELU.

    Changes the channel count from ``in_channels`` to ``out_channels``
    without touching the spatial size. ``groups`` is used both as the
    convolution group count and as the number of GroupNorm groups.
    """

    def __init__(self, in_channels, out_channels, groups=1):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=groups)
        self.norm = torch.nn.GroupNorm(groups, out_channels)
        self.act = torch.nn.GELU()

    def forward(self, x):
        """Run the three layers in sequence and return the activated map."""
        return self.act(self.norm(self.conv(x)))

class BaseConvBlock(torch.nn.Module):
    """Two bias-free 3x3 convolutions, each followed by GroupNorm.

    Layout: conv -> norm -> GELU -> conv -> norm. When ``residual`` is true
    the input is added to the result and a final GELU is applied.

    Args:
        in_channels: Channels of the input map.
        out_channels: Channels of the output map.
        mid_channels: Channels between the two convolutions; falls back to
            ``out_channels`` when falsy.
        residual: Whether to add the input back before a closing GELU.
        groups: Group count for both the convolutions and the GroupNorms.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False, groups=1):
        super().__init__()
        mid_channels = mid_channels or out_channels
        self.residual = residual
        self.conv1 = torch.nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False, groups=groups)
        self.norm1 = torch.nn.GroupNorm(groups, mid_channels)
        self.act1 = torch.nn.GELU()
        self.conv2 = torch.nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False, groups=groups)
        self.norm2 = torch.nn.GroupNorm(groups, out_channels)

    def forward(self, x):
        """Return the convolved map, with the skip path applied if enabled."""
        shortcut = x
        hidden = self.act1(self.norm1(self.conv1(x)))
        hidden = self.norm2(self.conv2(hidden))
        if not self.residual:
            return hidden
        # The addition requires the output to have the same shape as the input.
        return torch.nn.functional.gelu(hidden + shortcut)


class WaveDownSampleBlock(torch.nn.Module):
    """Down-sample by a single-level 2-D wavelet transform, then convolve.

    The low-pass band and the three high-pass bands are stacked along the
    channel axis (4 * in_channels), squeezed back to ``in_channels`` by a
    bottleneck, and processed by two conv blocks ending at ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, wave='haar', groups=1):
        super().__init__()
        self.dwt = DWTForward(J=1, wave=wave, mode='symmetric')
        stages = [
            BottleNeckBlock(in_channels * 4, in_channels, groups),  # 4C -> C
            BaseConvBlock(in_channels, in_channels, residual=True, groups=groups),
            BaseConvBlock(in_channels, out_channels, groups=groups),
        ]
        self.conv = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Return a map with ``out_channels`` channels at half the input height/width."""
        _, _, in_h, in_w = x.shape
        half_h, half_w = in_h // 2, in_w // 2
        low, high_levels = self.dwt(x)
        # Crop both bands to exactly half the input size.
        high = high_levels[0][:, :, :, :half_h, :half_w]
        low = low[:, :, :half_h, :half_w]
        batch, channels, _, band_h, band_w = high.shape
        # Fold the band axis into the channel axis: [B, C, 3, h, w] -> [B, 3C, h, w].
        high = high.reshape(batch, 3 * channels, band_h, band_w)
        stacked = torch.cat([low, high], dim=1)
        return self.conv(stacked)
    

class WaveUpSampleBlock(torch.nn.Module):
    """Up-sample with an inverse wavelet transform and fuse a skip connection.

    The input is expected to carry ``in_channels // 2`` channels. A bottleneck
    expands it fourfold so it can be split into one low-pass band and three
    high-pass bands, which the inverse DWT turns into a map of twice the
    spatial size. That map is concatenated with ``skip`` (giving
    ``in_channels`` channels) and passed through two conv blocks.

    Note that ``groups`` is only forwarded to the bottleneck; the two conv
    blocks are built with their default group count.
    """

    def __init__(self, in_channels, out_channels, wave='haar', groups=1):
        super().__init__()
        self.bottle = BottleNeckBlock(in_channels // 2, 2 * in_channels, groups)
        self.idwt = DWTInverse(wave=wave, mode='symmetric')
        self.conv = torch.nn.Sequential(
            BaseConvBlock(in_channels, in_channels, residual=True),
            BaseConvBlock(in_channels, out_channels, in_channels // 2)
        )

    def forward(self, x, skip):  # Skip Connection
        """Return the fused map at twice the spatial size of ``x``."""
        b, c, h, w = x.shape
        x = self.bottle(x)
        # First c channels form the low-pass band, the remaining 3c the high-pass bands.
        xl = x[:, :c]
        xh = x[:, c:].reshape(b, c, 3, h, w)
        target_h, target_w = 2 * h, 2 * w
        x = self.idwt((xl, [xh]))
        # F.pad consumes padding from the LAST dimension backwards:
        # (W_left, W_right, H_top, H_bottom). The width deficit therefore goes
        # first and the height deficit second.
        pad_w = target_w - x.shape[3]
        pad_h = target_h - x.shape[2]
        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), 'reflect')
        x = torch.cat([x, skip], dim=1)
        x = self.conv(x)
        return x
    

class WaveUnet(torch.nn.Module):
    """U-Net whose down/up-sampling steps are wavelet transforms.

    Three wavelet down-sampling stages are mirrored by three wavelet
    up-sampling stages with skip connections; attention blocks sit at the
    1/4 and 1/8 resolutions. The output has ``image_channels`` channels.

    Args:
        image_size: Side length of the input images; used to size the
            attention blocks (``image_size // 4`` and ``image_size // 8``).
        image_channels: Channels of both the input and the output.
        wave: Wavelet name handed to every down/up-sampling block.
        beta_range: Stored on the instance; not used inside this class.
        groups: Group count handed to every down/up-sampling block.
        device: Stored on the instance; not used inside this class.
    """

    def __init__(self, image_size, image_channels, wave='haar', beta_range=(1e-4, 0.02), groups=1, device="cuda"):
        super().__init__()
        self.beta_range = beta_range
        self.image_size = image_size
        self.device = device

        quarter, eighth = image_size // 4, image_size // 8
        wavelet_args = dict(wave=wave, groups=groups)

        # Encoder
        self.in_conv = BaseConvBlock(image_channels, 64)
        self.down1 = WaveDownSampleBlock(64, 128, **wavelet_args)
        self.down2 = WaveDownSampleBlock(128, 256, **wavelet_args)
        self.att1 = ImageAttentionBlock(256, quarter)
        self.down3 = WaveDownSampleBlock(256, 256, **wavelet_args)
        self.att2 = ImageAttentionBlock(256, eighth)
        # Decoder
        self.up1 = WaveUpSampleBlock(512, 128, **wavelet_args)
        self.att3 = ImageAttentionBlock(128, quarter)
        self.up2 = WaveUpSampleBlock(256, 64, **wavelet_args)
        self.up3 = WaveUpSampleBlock(128, 64, **wavelet_args)
        self.out_conv = torch.nn.Conv2d(64, image_channels, kernel_size=1)

    def forward(self, x):
        """Map an image batch to an output batch with the same channel count."""
        # Encoder: every stage's output is kept for the matching decoder stage.
        feat_full = self.in_conv(x)
        feat_half = self.down1(feat_full)
        feat_quarter = self.att1(self.down2(feat_half))
        bottleneck = self.att2(self.down3(feat_quarter))
        # Decoder: up-sample and merge with the stored encoder features.
        merged = self.att3(self.up1(bottleneck, feat_quarter))
        merged = self.up2(merged, feat_half)
        merged = self.up3(merged, feat_full)
        return self.out_conv(merged)

In [2]:
from utils import train, test
from BSD import BSDDataset
import numpy as np

# Fix the random seeds (CPU and CUDA generators) before anything stochastic runs.
torch.manual_seed(4623)
torch.cuda.manual_seed(4623)

# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Training hyper-parameters shared by every experiment below.
image_size, epochs, batch_size = 256, 100, 4
lr = 1e-4
criterion = torch.nn.MSELoss()

# Dataset root handed to BSDDataset, and the two splits built from it.
base_dir = ""
train_set = BSDDataset(base_dir=base_dir, split="train")
test_set = BSDDataset(base_dir=base_dir, split="test")


def rgb_to_ycbcr(img):
    """Convert a ``[B, 3, H, W]`` RGB batch to YCbCr.

    Each pixel is multiplied by a fixed 3x3 matrix and 0.5 is added to the
    two chroma channels. The result keeps the dtype and device of ``img``.
    """
    tensor_kwargs = dict(dtype=img.dtype, device=img.device)
    weights = torch.tensor(
        [
            [0.2990, 0.5870, 0.1140],          # Y
            [-0.168736, -0.331264, 0.5],       # Cb
            [0.5, -0.418688, -0.081312],       # Cr
        ],
        **tensor_kwargs,
    )
    offset = torch.tensor([0.0, 0.5, 0.5], **tensor_kwargs).view(3, 1, 1)
    # Contract the input channel axis with the matrix columns, pixel by pixel.
    mixed = torch.einsum('bchw,mc->bmhw', img, weights)
    return mixed + offset

def ycbcr_to_rgb(ycbcr):
    """Convert a ``[B, 3, H, W]`` YCbCr batch back to RGB.

    0.5 is subtracted from the two chroma channels, then each pixel is
    multiplied by a fixed 3x3 matrix. The result keeps the dtype and device
    of ``ycbcr`` and is not clamped.
    """
    tensor_kwargs = dict(dtype=ycbcr.dtype, device=ycbcr.device)
    weights = torch.tensor(
        [
            [1.0, 0.0, 1.402],                 # R
            [1.0, -0.344136, -0.714136],       # G
            [1.0, 1.772, 0.0],                 # B
        ],
        **tensor_kwargs,
    )
    offset = torch.tensor([0.0, -0.5, -0.5], **tensor_kwargs).view(3, 1, 1)
    centred = ycbcr + offset
    return torch.einsum('bchw,mc->bmhw', centred, weights)

def compute_loss(model, images, noise_range=(1, 60), residual=False, luminance=False):
    """Corrupt ``images`` with Gaussian noise and return the training loss.

    One integer noise level is drawn per call from
    ``[noise_range[0], noise_range[1])`` (the upper bound is exclusive) and
    used, divided by 255, as the noise standard deviation for the whole
    batch. The noisy batch is clamped to [0, 1] before being fed to the model.

    Args:
        model: Network called on the noisy batch.
        images: Clean image batch.
        noise_range: ``(low, high)`` bounds for the sampled noise level.
        residual: If true the target is ``noisy - clean``; otherwise the
            clean batch itself.
        luminance: If true both batches are converted to YCbCr and only the
            first channel is kept.

    Returns:
        The value of the module-level ``criterion`` for output vs. target.
    """
    sigma = torch.randint(noise_range[0], noise_range[1], (1,)).item() / 255.0
    # Noise is generated where `images` lives, then both batches move to `device`.
    noisy = torch.clamp(images + sigma * torch.randn_like(images), 0, 1).to(device)
    clean = images.to(device)

    if luminance:
        clean = rgb_to_ycbcr(clean)[:, :1]
        noisy = rgb_to_ycbcr(noisy)[:, :1]

    prediction = model(noisy)
    target = noisy - clean if residual else clean
    return criterion(prediction, target)


def denoise(model, noisy_img, residual=False, luminance=False):
    """Run ``model`` on a noisy batch and return the restored batch in [0, 1].

    Args:
        model: Network to apply.
        noisy_img: Noisy input batch.
        residual: If true the model output is subtracted from its input;
            otherwise the output is used directly.
        luminance: If true the input is converted to YCbCr, only the first
            channel goes through the model, and the untouched remaining
            channels are re-attached before converting back to RGB.
    """
    net_input = noisy_img
    chroma = None
    if luminance:
        ycbcr = rgb_to_ycbcr(noisy_img)
        net_input, chroma = ycbcr[:, :1], ycbcr[:, 1:]

    prediction = model(net_input)
    restored = net_input - prediction if residual else prediction

    if luminance:
        restored = ycbcr_to_rgb(torch.cat([restored, chroma], dim=1))
    return torch.clamp(restored, 0, 1)

In [3]:
def experiment(model_name, residual=False, luminance=False, groups=False):
    """Train one WaveUnet variant and evaluate it at three noise levels.

    Args:
        model_name: Name handed to ``train`` and ``test``.
        residual: Forwarded to ``train``/``test``.
        luminance: If true the model is built with 1 image channel instead
            of 3; also forwarded to ``train``/``test``.
        groups: If true the model is built with ``groups=4``, else
            ``groups=1``.
    """
    image_channels = 1 if luminance else 3
    num_groups = 4 if groups else 1
    # Build the model from the notebook-level `image_size` rather than a
    # second hard-coded 256, so the two cannot drift apart.
    model = WaveUnet(image_size, image_channels, groups=num_groups).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    train(model, optimizer, epochs, train_set, test_set, batch_size, model_name, compute_loss=compute_loss, residual=residual, luminance=luminance)
    for noise_level in [10, 25, 50]:
        test(model, test_set, batch_size, model_name, noise_level, denoise=denoise, residual=residual, luminance=luminance)

In [None]:
pip install PyWavelets pytorch_wavelets scikit-image opencv-python-headless==4.5.3.56

Note: you may need to restart the kernel to use updated packages.


All components combined (residual learning, luminance-only input, grouped convolutions)

In [None]:
# Full configuration: residual target, luminance-only input, grouped convolutions.
model_name = "WaveUnet"
residual, luminance, groups = True, True, True

experiment(model_name, residual=residual, luminance=luminance, groups=groups)

WaveUnet


  0%|          | 0/100 [00:53<?, ?it/s, Step=5/10000, training loss=0.289]

Ablation (one component disabled per run)

In [None]:
# Ablation: residual target switched off, everything else as in the full model.
model_name = "WaveUnet-no-residual"
residual, luminance, groups = False, True, True

experiment(model_name, residual=residual, luminance=luminance, groups=groups)

 80%|████████  | 40/50 [1:12:31<17:46, 106.66s/it, Step=4085/5000, training loss=0.002]

In [None]:
# Ablation: luminance-only input switched off (the model sees all 3 channels).
model_name = "WaveUnet-no-luminance"
residual, luminance, groups = True, False, True

experiment(model_name, residual=residual, luminance=luminance, groups=groups)

In [None]:
# Ablation: grouped convolutions switched off (groups=1 throughout).
model_name = "WaveUnet-no-groups"
residual, luminance, groups = True, True, False

experiment(model_name, residual=residual, luminance=luminance, groups=groups)