In [8]:
import torch
import math
from pytorch_wavelets import DWTForward, DWTInverse

class ImageAttentionBlock(torch.nn.Module):
    def __init__(self, hidden_dim, image_size, num_head=4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.image_size = image_size
        self.layer_norm = torch.nn.LayerNorm([hidden_dim])
        self.att = torch.nn.MultiheadAttention(hidden_dim, num_head, batch_first=True)
        self.feed_forward = torch.nn.Sequential(
            torch.nn.LayerNorm([hidden_dim]),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, x):  # x: [B, C, H, W]
        x = x.reshape(-1, self.hidden_dim, self.image_size * self.image_size).transpose(1, 2)  # [B, H*W, C]
        x_norm = self.layer_norm(x)
        attention_value, _ = self.att(x_norm, x_norm, x_norm)
        x = x + attention_value
        x = x + self.feed_forward(x)
        x = x.transpose(1, 2).reshape(-1, self.hidden_dim, self.image_size, self.image_size)
        return x

class BottleNeckBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.conv(x)
        return x

class BaseConvBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, groups=1):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False, groups=groups)
        self.norm = torch.nn.GroupNorm(groups, in_channels)
        self.act = torch.nn.GELU()
    
    def forward(self, x):
        x = self.norm(x)
        x = self.act(x)
        x = self.conv(x)
        return x

class ResidualBlock(torch.nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.l1 = BaseConvBlock(in_channels, in_channels)
        self.l2 = BaseConvBlock(in_channels, in_channels)

    def forward(self, x):
        res = x
        x = self.l1(x)
        x = self.l2(x)
        return (x + res) / math.sqrt(2)

class WaveShort(torch.nn.Module):
    def __init__(self, in_channels, out_channels, wave='haar'):
        super().__init__()
        self.dwt = DWTForward(J=1, wave=wave, mode='symmetric')
        self.conv = torch.nn.Conv2d(4 * in_channels, out_channels, kernel_size=3, padding=1, bias=False, groups=4)

    def forward(self, x):
        xl, xh = self.dwt(x)
        b, c, _, h, w = xh[0].shape
        xh = xh[0].reshape(b, 3 * c, h, w)
        x = torch.cat([xl, xh], dim=1)
        x = self.conv(x)
        return x

class WaveBottleNeckBlock(torch.nn.Module):
    def __init__(self, in_channels, wave='haar'):
        super().__init__()
        self.dwt = DWTForward(J=1, wave=wave, mode='symmetric')
        self.idwt = DWTInverse(wave=wave, mode='symmetric')
        self.res = ResidualBlock(in_channels)

    def forward(self, x):
        xl, xh = self.dwt(x)
        xl = self.res(xl / 2) * 2
        x = self.idwt((xl, xh))
        return x

class WaveDownSampleBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, wave='haar'):
        super().__init__()
        self.dwt = DWTForward(J=1, wave=wave, mode='symmetric')
        self.conv1 = BaseConvBlock(in_channels, out_channels)
        self.conv2 = BaseConvBlock(out_channels, out_channels)
        self.short = BottleNeckBlock(in_channels, out_channels)

    def forward(self, x, t_emb):
        hs = self.conv1(x)
        x = self.short(x)
        hl, hh = self.dwt(hs)
        b, c, _, w, h = hh[0].shape
        hh = hh[0].reshape(b, 3 * c, w, h)
        xl, xh = self.dwt(x)
        hl = hl / 2.
        x = xl / 2.
        hs = hl + t_emb
        hs = self.conv2(hl)
        return (x + hs) / math.sqrt(2), hh
    

class WaveUpSampleBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, wave='haar'):
        super().__init__()
        self.idwt = DWTInverse(wave=wave, mode='symmetric')
        self.conv1 = BaseConvBlock(in_channels, out_channels)
        self.conv2 = BaseConvBlock(out_channels, out_channels)
        self.short = BottleNeckBlock(in_channels, out_channels)
        self.skip_conv = BaseConvBlock(in_channels * 3, out_channels * 3, groups=3)

    def forward(self, x, skip, t_emb):  # Skip Connection
        hs = self.conv1(x)
        x = self.short(x)
        b, c, w, h = x.shape
        skip = self.skip_conv(skip / 2.) * 2.
        skip = skip.reshape(b, c, 3, w, h)
        hs = self.idwt((2. * hs, [skip]))
        x = self.idwt((2. * x, [skip]))
        hs = hs + t_emb
        hs = self.conv2(hs)
        return (x + h) / math.sqrt(2)
    

class FreqUnet(torch.nn.Module):
    def __init__(self, image_size, image_channels, wave='haar', device="cuda"):
        super().__init__()
        self.image_size = image_size
        self.device = device

        self.short1 = WaveShort(image_channels, 128, wave=wave)
        self.short2 = WaveShort(128, 256, wave=wave)
        
        self.in_conv = torch.nn.Conv2d(image_channels, 64, kernel_size=3, padding=1, bias=False)
        self.res1 = ResidualBlock(64)
        self.down1 = WaveDownSampleBlock(64, 128, wave=wave)
        self.res2 = ResidualBlock(128)
        self.down2 = WaveDownSampleBlock(128, 256, wave=wave)
        self.bottle1 = WaveBottleNeckBlock(256, wave=wave)
        self.att = ImageAttentionBlock(256, image_size // 4)
        self.bottle2 = WaveBottleNeckBlock(256, wave=wave)
        self.res3 = ResidualBlock(256)
        self.up1 = WaveUpSampleBlock(256, 128, wave=wave)
        self.res4 = ResidualBlock(128)
        self.up2 = WaveUpSampleBlock(128, 64, wave=wave)
        self.out_conv = BaseConvBlock(64, image_channels)

    
    def forward(self, x):
        s1 = self.short1(x)
        s2 = self.short2(s1)
        
        h = self.in_conv(x)
        h = self.res1(h)
        h, skip1 = self.down1(h, 0)
        h = (h + s1) / math.sqrt(2)
        h = self.res2(h)
        h, skip2 = self.down2(h, 0)
        h = (h + s2) / math.sqrt(2)
        h = self.bottle1(h)
        h = self.att(h)
        h = self.bottle2(h)
        h = self.res3(h)
        h = self.up1(h, skip2, 0)
        h = self.res4(h)
        h = self.up2(h, skip1, 0)
        x = self.out_conv(h)
        return x

In [9]:
from utils import train, test
from BSD import BSDDataset
import numpy as np
torch.manual_seed(4623)
torch.cuda.manual_seed(4623)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

image_size = 256
epochs = 50
batch_size = 4
lr = 1e-4
noise_level = 10
criterion = torch.nn.MSELoss()

base_dir=""

train_set = BSDDataset(base_dir=base_dir, split="train")
test_set = BSDDataset(base_dir=base_dir, split="test")


def compute_loss(model, images, noise_level):
    noisy_images = images + (noise_level/255)*torch.randn(*images.shape)
    images = images.to(device) # move to GPU
    noisy_images = np.clip(noisy_images, 0, 1)
    noisy_images = noisy_images.to(device)
    outputs = model(noisy_images) # forward
    outputs = outputs.to(device)
    loss = criterion(outputs, images)
    return loss

def denoise(model, noisy_img):
    outputs = model(noisy_img) # forward
    return outputs

In [None]:
model = FreqUnet(256, 3, wave="Haar").to(device)
noise_level = 10
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
model_name = "WaveUnet"+str(noise_level)

train(model, optimizer, epochs, train_set, test_set, batch_size, model_name, compute_loss=compute_loss, noise_level=noise_level)
test(model, test_set, batch_size, model_name, noise_level, denoise=denoise)

  0%|          | 0/50 [00:11<?, ?it/s, Step=1/5000, training loss=0.343]

In [4]:
model = FreqUnet(256, 3, wave="Haar").to(device)
noise_level = 25
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
model_name = "WaveUnet"+str(noise_level)

train(model, optimizer, epochs, train_set, test_set, batch_size, model_name, compute_loss=compute_loss, noise_level=noise_level)
test(model, test_set, batch_size, model_name, noise_level, denoise=denoise)

TypeError: FreqUnet.__init__() missing 1 required positional argument: 'time_range'

In [None]:
model = FreqUnet(256, 3, wave="Haar").to(device)
noise_level = 50
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
model_name = "WaveUnet"+str(noise_level)

train(model, optimizer, epochs, train_set, test_set, batch_size, model_name, compute_loss=compute_loss, noise_level=noise_level)
test(model, test_set, batch_size, model_name, noise_level, denoise=denoise)