In [1]:
pip install PyWavelets pytorch_wavelets scikit-image opencv-python-headless==4.5.3.56

Note: you may need to restart the kernel to use updated packages.


In [2]:
from pytorch_wavelets import DWTForward, DWTInverse
import torch
from diffusion.diffusion import Diffusion
from diffusion.wavelet_diffusion import WaveDiffusion
from torch.utils.data import DataLoader
from BSD import BSDDataset
from tqdm import tqdm
import numpy as np
import os
import matplotlib.pyplot as plt
import cv2
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from utils import train, test

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hyper-parameters shared by every model trained/evaluated in this notebook.
image_size = 256
epochs = 100
batch_size = 4
time_range = 1000  # forwarded to the models' `time_range` argument
lr = 1e-4

# Seed the global RNGs so that DataLoader shuffling, weight initialisation of the models
# created below and any sampling done through torch/numpy repeat across
# "Restart & Run All" sessions. torch.manual_seed also seeds CUDA devices.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

base_dir = ""

train_set = BSDDataset(base_dir=base_dir, split="train")
test_set = BSDDataset(base_dir=base_dir, split="test")

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)


In [4]:
def compute_loss(model, batch):
    """Loss hook passed to ``train``: delegate to the model's own ``loss`` method.

    Args:
        model: object exposing a ``loss(batch)`` method.
        batch: one batch drawn from the training DataLoader.

    Returns:
        Whatever ``model.loss`` returns (presumably a scalar loss tensor — confirm).
    """
    loss_fn = model.loss
    return loss_fn(batch)

def denoise(model, noisy_img, t):
    """Denoising hook passed to ``test``: run the model's ``generational_denoise``.

    Args:
        model: object exposing ``generational_denoise(noisy_img, t)``.
        noisy_img: noisy input supplied by ``test``.
        t: timestep value, forwarded unchanged.

    Returns:
        The output of ``model.generational_denoise``.
    """
    restore = model.generational_denoise
    return restore(noisy_img, t)

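## Baseline

Standard `Diffusion` model: loaded from `results_BSD/` if a checkpoint exists (otherwise trained), then evaluated on the BSD test split.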

In [12]:
# NOTE(review): the name mentions "time 100" while time_range is 1000. The name is also the
# checkpoint/results directory (see the os.path.join below), so renaming it would orphan
# existing results.
model_name = "Diffuser Baseline-time 100"
model = Diffusion(
    image_size=image_size,
    image_channels=3,
    time_range=time_range,
    device=device,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

In [13]:
# Reuse a previously trained checkpoint if present; otherwise train from scratch.
# `train` receives the same `base_dir` used to build this path (as the WaveDiffusion cell
# does), so the checkpoint it writes is the one looked up here on the next run.
# NOTE(review): assumes train() saves to <base_dir>/results_BSD/<model_name>/model.pt — confirm in utils.
model_dir = os.path.join(base_dir, "results_BSD", model_name, "model.pt")
if os.path.exists(model_dir):
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    # torch.load unpickles arbitrary objects: only load checkpoints you produced yourself.
    checkpoint = torch.load(model_dir, map_location=device)
    # The file may hold a whole pickled module or only its state_dict; accept both.
    state_dict = checkpoint.state_dict() if isinstance(checkpoint, torch.nn.Module) else checkpoint
    model.load_state_dict(state_dict)
    # Drop the temporary copy so it does not keep a second set of weights alive in the kernel.
    del checkpoint, state_dict
else:
    train(model, optimizer, epochs, train_loader, test_loader, model_name, compute_loss=compute_loss, base_dir=base_dir)

In [17]:
# Evaluate at three noise levels, each paired with the timestep `t` forwarded to `denoise`.
# `base_dir` is passed explicitly, matching the WaveDiffusion evaluation, so both models
# resolve their files from the same root. The loop variable is not named `time`, to avoid
# shadowing the stdlib module.
baseline_eval_schedule = [(10, 5), (25, 12), (50, 20)]  # (noise_level, t)
for noise_level, t_step in baseline_eval_schedule:
    test(model, test_loader, model_name, noise_level, denoise=denoise, base_dir=base_dir, t=t_step)

noise_level: 10
device: cuda:0


25it [00:35,  1.40s/it]


PSNR noisy: 31.70 +- 0.16
SSIM noisy: 0.839 +- 0.066
PSNR denoised: 33.86 +- 0.95
SSIM denoised: 0.910 +- 0.024
noise_level: 25
device: cuda:0


25it [01:13,  2.94s/it]


PSNR noisy: 23.91 +- 0.26
SSIM noisy: 0.557 +- 0.116
PSNR denoised: 28.47 +- 0.87
SSIM denoised: 0.766 +- 0.050
noise_level: 50
device: cuda:0


25it [01:57,  4.70s/it]


PSNR noisy: 18.31 +- 0.29
SSIM noisy: 0.322 +- 0.102
PSNR denoised: 23.60 +- 0.56
SSIM denoised: 0.547 +- 0.084


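## WaveDiffusion

`WaveDiffusion` model, built and trained with the same data and training hyper-parameters as the baseline.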

In [5]:
# The name also selects the checkpoint/results directory used by the next cell.
model_name = "WaveDiffusion"
model = WaveDiffusion(
    image_size=image_size,
    image_channels=3,
    time_range=time_range,
    device=device,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

In [6]:
# Reuse a previously trained checkpoint if present; otherwise train from scratch.
model_dir = os.path.join(base_dir, "results_BSD", model_name, "model.pt")
if os.path.exists(model_dir):
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    # torch.load unpickles arbitrary objects: only load checkpoints you produced yourself.
    checkpoint = torch.load(model_dir, map_location=device)
    # The file may hold a whole pickled module or only its state_dict; accept both.
    state_dict = checkpoint.state_dict() if isinstance(checkpoint, torch.nn.Module) else checkpoint
    model.load_state_dict(state_dict)
    # Drop the temporary copy so it does not keep a second set of weights alive in the kernel.
    del checkpoint, state_dict
else:
    train(model, optimizer, epochs, train_loader, test_loader, model_name, compute_loss=compute_loss, base_dir=base_dir)

In [9]:
# noise_level -> timestep `t` forwarded to `denoise` for this model.
wave_eval_schedule = {10: 2, 25: 10, 50: 10}
for noise_level, t_step in wave_eval_schedule.items():
    test(model, test_loader, model_name, noise_level, denoise=denoise, base_dir=base_dir, t=t_step)

noise_level: 10


25it [00:33,  1.33s/it]


PSNR noisy: 31.71 +- 0.16
SSIM noisy: 0.839 +- 0.066
PSNR denoised: 33.61 +- 0.54
SSIM denoised: 0.904 +- 0.035
noise_level: 25


25it [01:05,  2.63s/it]


PSNR noisy: 23.92 +- 0.26
SSIM noisy: 0.557 +- 0.116
PSNR denoised: 27.08 +- 0.39
SSIM denoised: 0.705 +- 0.087
noise_level: 50


25it [02:13,  5.32s/it]


PSNR noisy: 18.31 +- 0.28
SSIM noisy: 0.322 +- 0.103
PSNR denoised: 22.25 +- 0.30
SSIM denoised: 0.489 +- 0.104


## WaveDiffusion on Wavelet Space