In [1]:
pip install PyWavelets pytorch_wavelets



In [4]:
import math
import torch
from pytorch_wavelets import DWTForward ,DWTInverse
class Basicconv(torch.nn.Module):
  """3x3 convolution -> GroupNorm -> PReLU, preserving spatial size (padding=1).

  `groups` is used both as the number of convolution groups and as the number
  of GroupNorm groups, so `inchannel` and `outchannel` must be divisible by it.
  """

  def __init__(self, inchannel, outchannel, groups=1):
    super().__init__()
    # The attribute names (conv / act / nom) become the state_dict keys;
    # renaming them would break loading previously saved weights.
    self.conv = torch.nn.Conv2d(
        in_channels=inchannel,
        out_channels=outchannel,
        kernel_size=3,
        padding=1,
        groups=groups,
    )
    self.act = torch.nn.PReLU()
    self.nom = torch.nn.GroupNorm(num_groups=groups, num_channels=outchannel)

  def forward(self, x):
    # Applied as conv -> norm -> activation (act is registered before nom,
    # but runs after it).
    return self.act(self.nom(self.conv(x)))

class Bottleneck(torch.nn.Module):
  """Pointwise (1x1) convolution: remaps channels without changing spatial size."""

  def __init__(self, inchannel, outchannel):
    super().__init__()
    # `conv` is the state_dict key prefix; keep the name stable.
    self.conv = torch.nn.Conv2d(in_channels=inchannel, out_channels=outchannel, kernel_size=1)

  def forward(self, x):
    return self.conv(x)

class DCRblock(torch.nn.Module):
  """Densely connected residual block.

  Three Basicconv layers, each fed a concatenation of earlier feature maps,
  followed by an identity skip connection. Output shape == input shape.
  """

  def __init__(self, inchannel):
    super().__init__()
    half = inchannel // 2
    self.conv1 = Basicconv(inchannel, half)
    # Input of conv2 is cat([x, x1]).
    self.conv2 = Basicconv(inchannel + half, half)
    # Input of conv3 is cat([x1, x2, x]) = half + half + inchannel channels.
    # This equals 2*inchannel only for even inchannel; counting the
    # concatenated channels exactly keeps odd channel counts working.
    self.conv3 = Basicconv(2 * half + inchannel, inchannel)

  def forward(self, x):
    x1 = self.conv1(x)
    x2 = self.conv2(torch.cat([x, x1], dim=1))
    x3 = self.conv3(torch.cat([x1, x2, x], dim=1))
    # Residual connection around the whole block.
    return x3 + x

class finalneck(torch.nn.Module):
  """Output projection: a single 3x3, padding-1 convolution to `outchannel` channels."""

  def __init__(self, inchannel, outchannel):
    super().__init__()
    # `conv` is the state_dict key prefix; keep the name stable.
    self.conv = torch.nn.Conv2d(
        in_channels=inchannel, out_channels=outchannel, kernel_size=3, padding=1
    )

  def forward(self, x):
    return self.conv(x)

class finalcombine(torch.nn.Module):
  """Wavelet encoder/decoder denoiser (trained below under the "DSWN*" names).

  Encoder: the input goes through three single-level 2D DWTs. At each level
  the low-pass band and the three detail bands are stacked on the channel
  axis (4x channels, spatial size floored to half).
  Decoder: each stage emits 4*k channels that are split into k low-pass and
  3k detail channels, merged with the inverse DWT, padded to the spatial size
  of the matching encoder features and concatenated with them.
  Two full-resolution heads follow: the `c*` head's output goes through tanh,
  the `d*` head's output is subtracted from the input (presumably a noise
  estimate). The result is the mean of the two.

  Args:
    inchannel: number of image channels (3 for RGB). Previously ignored and
      hard-coded to 3; the value 3 reproduces the original architecture.
    wave: wavelet name given to pytorch_wavelets (e.g. "Haar", "db2").
    groups: if True, the first conv after each DWT level uses 4 groups.
  """

  def __init__(self, inchannel, wave="Haar", groups=False):
    super().__init__()
    self.groups = groups
    first_groups = 4 if groups else 1
    # NOTE(review): the DWT stack orders channels as [LL of every input
    # channel, then the 3 detail bands of channel 0, channel 1, ...]. For RGB
    # the 4 groups are therefore [LL(RGB), details(R), details(G), details(B)]
    # rather than one group per sub-band. Confirm this is the intended split.
    # Module creation order and attribute names are kept as in the original
    # so RNG-driven init and state_dict keys are unchanged.
    self.c00 = Basicconv(inchannel = inchannel,  outchannel = 160)
    self.DWT = DWTForward(1, wave, "symmetric")
    self.IDWT = DWTInverse(wave, "symmetric")
    self.c1 = Bottleneck(inchannel = 320, outchannel = 320)
    self.d1 = Bottleneck(inchannel = 320, outchannel = 320)
    self.c2 = DCRblock(inchannel = 320)
    self.d2 = DCRblock(inchannel = 320)
    self.c3 = DCRblock(inchannel = 320)
    self.d3 = DCRblock(inchannel = 320)
    self.c4 = finalneck(inchannel = 320, outchannel = inchannel)
    self.d4 = finalneck(inchannel = 320, outchannel = inchannel)
    self.c11 = Basicconv(inchannel = inchannel*4,  outchannel = 256, groups = first_groups)
    self.c12 = Bottleneck(inchannel = 512, outchannel = 512)
    self.c13 = DCRblock(inchannel = 512)
    self.c14 = finalneck(inchannel = 512, outchannel = 640)   # 4 * 160 (matches c00)
    self.c21 = Basicconv(inchannel = inchannel*4*4,  outchannel = 256, groups = first_groups)
    self.c22 = Bottleneck(inchannel = 512, outchannel = 512)
    self.c23 = DCRblock(inchannel = 512)
    self.c24 = finalneck(inchannel = 512, outchannel = 1024)  # 4 * 256 (matches c11)
    self.c31 = Basicconv(inchannel = inchannel*4*4*4,  outchannel = 256, groups = first_groups)
    self.c32 = DCRblock(inchannel = 256)
    self.c33 = finalneck(inchannel = 256, outchannel = 1024)  # 4 * 256 (matches c21)

  def _dwt_stack(self, x):
    """One DWT level: returns cat([LL, detail bands]) with 4x channels and floor(size/2) spatial dims."""
    height, width = x.shape[-2] // 2, x.shape[-1] // 2
    low, highs = self.DWT(x)
    # The DWT can return more than floor(size/2) samples (odd sizes, longer
    # filters), so crop both outputs to exactly that size.
    low = low[:, :, :height, :width]
    high = highs[0][:, :, :, :height, :width]
    b, c, _, h, w = high.shape
    return torch.cat([low, high.reshape(b, 3 * c, h, w)], dim=1)

  def _idwt_merge(self, feats, skip):
    """Inverse DWT of `feats` (k low-pass + 3k detail channels), padded to `skip`'s size, then cat([skip, up])."""
    k = feats.shape[1] // 4
    low = feats[:, :k]
    b, _, h, w = low.shape
    high = feats[:, k:].reshape(b, k, 3, h, w)
    up = self.IDWT((low, [high]))
    # Pad to the skip tensor's own size (not 2x the coarse size), so odd and
    # non-square inputs line up. F.pad takes (left, right, top, bottom): the
    # LAST dim comes first. The old code had these two axes swapped.
    pad_h = skip.shape[-2] - up.shape[-2]
    pad_w = skip.shape[-1] - up.shape[-1]
    if pad_h or pad_w:
      up = torch.nn.functional.pad(up, (0, pad_w, 0, pad_h), mode="reflect")
    return torch.cat([skip, up], dim=1)

  def forward(self, x):
    # --- encoder: full-res features plus three DWT levels ---
    x00 = self.c00(x)
    x1 = self._dwt_stack(x)
    x11 = self.c11(x1)
    x2 = self._dwt_stack(x1)
    x21 = self.c21(x2)
    x3 = self._dwt_stack(x2)

    # --- coarsest level ---
    x33 = self.c33(self.c32(self.c31(x3)))

    # --- decoder: IDWT upsampling fused with the encoder features ---
    x24 = self.c24(self.c23(self.c22(self._idwt_merge(x33, x21))))
    x14 = self.c14(self.c13(self.c12(self._idwt_merge(x24, x11))))
    x01 = self._idwt_merge(x14, x00)

    # --- head "c": tanh-bounded image estimate ---
    xc1 = self.c1(x01)
    xc2 = self.c2(xc1) + xc1
    xc3 = self.c3(xc2) + xc2
    xc4 = torch.tanh(self.c4(xc3))

    # --- head "d": its output is subtracted from the input ---
    xd1 = self.d1(x01)
    xd2 = self.d2(xd1) + xd1
    xd3 = self.d3(xd2) + xd2
    xd4 = x - self.d4(xd3)

    return (xd4 + xc4) / 2



In [5]:
from utils import train, test
from BSD import BSDDataset
import numpy as np
# Reproducibility: seed the torch RNGs (CPU and CUDA) before building models or data.
torch.manual_seed(4623)
torch.cuda.manual_seed(4623)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Training hyper-parameters.
# NOTE(review): image_size and time_range are not referenced in the cells
# shown here; confirm whether utils/BSD use them before removing.
image_size, epochs, batch_size = 256, 50, 4
time_range = 1000
lr = 1e-4
# Gaussian noise std on the 0-255 intensity scale (compute_loss divides by 255).
noise_level = 10
criterion = torch.nn.MSELoss()

# Dataset splits; base_dir is forwarded unchanged to BSDDataset.
base_dir = ""
train_set, test_set = (BSDDataset(base_dir=base_dir, split=name) for name in ("train", "test"))


def compute_loss(model, images, noise_level, loss_fn=None, target_device=None):
    """Add Gaussian noise to `images`, denoise with `model`, and return the loss vs. the clean images.

    The noise std is noise_level / 255 (noise_level is on the 0-255 scale).
    The noisy batch is clipped to [0, 1].

    Args:
        model: callable mapping a noisy batch to a denoised batch.
        images: clean image batch tensor (presumably in [0, 1]; TODO confirm in BSDDataset).
        noise_level: noise standard deviation on the 0-255 scale.
        loss_fn: loss(outputs, targets); defaults to the global `criterion`.
        target_device: device for the forward pass; defaults to the global `device`.

    Returns:
        The loss tensor returned by loss_fn(outputs, clean_images).
    """
    if loss_fn is None:
        loss_fn = criterion
    if target_device is None:
        target_device = device
    # Sample noise on the batch's own device/dtype. For float32 CPU batches
    # this draws exactly the values torch.randn(*images.shape) did, and it
    # also works when the batch is already on the GPU.
    noise = torch.randn_like(images) * (noise_level / 255)
    # torch clamp instead of np.clip: stays a tensor and works on any device.
    noisy_images = (images + noise).clamp(0, 1).to(target_device)
    clean_images = images.to(target_device)
    outputs = model(noisy_images)
    return loss_fn(outputs, clean_images)

def denoise(model, noisy_img):
    """Return the model's prediction for a noisy batch (a single forward pass, nothing else)."""
    return model(noisy_img)

No grouping: baseline DSWN (Haar wavelet, ungrouped convolutions) trained at noise levels 10, 25 and 50

In [6]:
# Baseline DSWN: Haar wavelet, no grouped convs, noise level 10.
# Re-seed (same seed as the config cell) so weight init and any torch-RNG
# driven shuffling do not depend on which cells ran earlier in the session.
torch.manual_seed(4623)
noise_level = 10
model = finalcombine(3, "Haar").to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
model_name = f"DSWN{noise_level}"

# NOTE(review): the recorded run of this cell stopped at step 200 with
# "AttributeError: 'finalcombine' object has no attribute 'loss'" while the
# train() progress bar was running. utils.train appears to expect a `loss`
# attribute on the model. Check utils before re-running.
train(model, optimizer, epochs, train_set, test_set, batch_size, model_name,
      compute_loss=compute_loss, noise_level=noise_level)
test(model, test_set, batch_size, model_name, noise_level, denoise=denoise)

  0%|          | 0/50 [07:23<?, ?it/s, Step=200/10000, training loss=0.090]


AttributeError: 'finalcombine' object has no attribute 'loss'

In [None]:
# Baseline DSWN: Haar wavelet, no grouped convs, noise level 25.
# Re-seed (same seed as the config cell) so weight init and any torch-RNG
# driven shuffling do not depend on which cells ran earlier in the session.
torch.manual_seed(4623)
noise_level = 25
model = finalcombine(3, "Haar").to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
model_name = f"DSWN{noise_level}"

train(model, optimizer, epochs, train_set, test_set, batch_size, model_name,
      compute_loss=compute_loss, noise_level=noise_level)
test(model, test_set, batch_size, model_name, noise_level, denoise=denoise)

In [None]:
# Baseline DSWN: Haar wavelet, no grouped convs, noise level 50.
# Re-seed (same seed as the config cell) so weight init and any torch-RNG
# driven shuffling do not depend on which cells ran earlier in the session.
torch.manual_seed(4623)
noise_level = 50
model = finalcombine(3, "Haar").to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
model_name = f"DSWN{noise_level}"

train(model, optimizer, epochs, train_set, test_set, batch_size, model_name,
      compute_loss=compute_loss, noise_level=noise_level)
test(model, test_set, batch_size, model_name, noise_level, denoise=denoise)

Grouped: the first convolution after each DWT level uses 4 groups (Haar wavelet, noise level 50)

In [None]:
# Grouped variant: Haar wavelet, 4-group first conv after each DWT, noise level 50.
# Re-seed (same seed as the config cell) so weight init and any torch-RNG
# driven shuffling do not depend on which cells ran earlier in the session.
torch.manual_seed(4623)
noise_level = 50
model = finalcombine(3, "Haar", groups=True).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
model_name = f"DSWNGroup{noise_level}"

train(model, optimizer, epochs, train_set, test_set, batch_size, model_name,
      compute_loss=compute_loss, noise_level=noise_level)
test(model, test_set, batch_size, model_name, noise_level, denoise=denoise)

Wavelet ablation: Daubechies db2 and db3 instead of Haar (grouped model, noise level 50)

In [None]:
# Wavelet ablation: Daubechies db2 (grouped), noise level 50.
# Re-seed (same seed as the config cell) so weight init and any torch-RNG
# driven shuffling do not depend on which cells ran earlier in the session.
torch.manual_seed(4623)
noise_level = 50
model = finalcombine(3, "db2", groups=True).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
model_name = f"DSWNDB2{noise_level}"

train(model, optimizer, epochs, train_set, test_set, batch_size, model_name,
      compute_loss=compute_loss, noise_level=noise_level)
test(model, test_set, batch_size, model_name, noise_level, denoise=denoise)

In [None]:
# Wavelet ablation: Daubechies db3 (grouped), noise level 50.
# Re-seed (same seed as the config cell) so weight init and any torch-RNG
# driven shuffling do not depend on which cells ran earlier in the session.
torch.manual_seed(4623)
noise_level = 50
model = finalcombine(3, "db3", groups=True).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
model_name = f"DSWNDB3{noise_level}"

train(model, optimizer, epochs, train_set, test_set, batch_size, model_name,
      compute_loss=compute_loss, noise_level=noise_level)
test(model, test_set, batch_size, model_name, noise_level, denoise=denoise)