In [1]:
pip install PyWavelets pytorch_wavelets scikit-image opencv-python-headless==4.5.3.56

Note: you may need to restart the kernel to use updated packages.


In [2]:
from pytorch_wavelets import DWTForward, DWTInverse
import torch
from diffusion.diffusion import Diffusion
from diffusion.wavelet_diffusion import WaveDiffusion
from torch.utils.data import DataLoader
from BSD import BSDDataset
from tqdm import tqdm
import numpy as np
import os
import matplotlib.pyplot as plt

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

image_size = 256
epochs = 50
batch_size = 4
time_range = 1000
lr = 1e-4

base_dir=""

train_set = BSDDataset(base_dir=base_dir, split="train")
test_set = BSDDataset(base_dir=base_dir, split="test")

train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=4, shuffle=False, num_workers=4)


**Baseline**

In [7]:

model = Diffusion(image_size=image_size, image_channels=3, time_range=time_range, device=device).to(device)
model_name = "Diffuser Baseline"

optimizer = torch.optim.Adam(model.parameters(), lr=lr)


In [5]:
train_loss_all = []
test_loss_all = []

bar = tqdm(range(epochs))

output_dir = os.path.join(base_dir, "results_BSD", model_name)
model_dir = os.path.join(output_dir, "model.pt")
if os.path.exists(model_dir):
    model.load_state_dict(torch.load(model_dir).state_dict())
else:
    for epoch in bar:
        # training
        train_loss = 0
        model.train()
        for i, data in enumerate(train_loader):
            images, _ = data
            optimizer.zero_grad()
            loss = model.loss(images)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            bar.set_postfix({"Step": str(epoch * len(train_loader) + i + 1) + "/" + str(epochs * len(train_loader)), "training loss": format(train_loss, ".3f")})
        train_loss_all.append(train_loss/len(train_set))
        # testing
        test_loss = 0
        model.eval()
        for data in test_loader:
            images, _ = data
            with torch.no_grad():
                loss = model.loss(images)
            test_loss += loss.item()
        bar.set_postfix({"Epoch": epoch+1, "testing loss": format(test_loss, ".3f")})
        test_loss_all.append(test_loss/len(test_set))
    
    # save model and results
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model, os.path.join(output_dir, "model.pt"))
    np.save(os.path.join(output_dir, "train_loss.npy"), np.array(train_loss_all))
    np.save(os.path.join(output_dir, "test_loss.npy"), np.array(test_loss_all))
    
    plt.plot(train_loss_all, label="training loss")
    plt.plot(test_loss_all, label="testing loss")
    plt.xlabel("Epoch")
    plt.legend()
    plt.savefig(os.path.join(output_dir, "loss.png"), format="png")

  0%|          | 0/50 [00:00<?, ?it/s]

In [6]:
import cv2
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

test_set = BSDDataset(base_dir=base_dir, split="test")
test_loader = DataLoader(test_set, batch_size=4, shuffle=False, num_workers=4)

for noise_level, time in [(10, 10), (25, 20), (50, 40)]:
    print("noise_level:", noise_level)
    output_dir = os.path.join(base_dir, "results_BSD", model_name + str(noise_level))
    original_path = os.path.join(output_dir, "original_images")
    noisy_path = os.path.join(output_dir, "noisy_images")
    denoised_path = os.path.join(output_dir, "denoised_images")
    os.makedirs(original_path, exist_ok=True)
    os.makedirs(noisy_path, exist_ok=True)
    os.makedirs(denoised_path, exist_ok=True)
    
    
    model.eval()
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("device:", device)
    model.to(device)
    
    counter = 0
    with torch.no_grad():
        for j, data in tqdm(enumerate(test_loader, 0)):
            images, _ = data
            noisy_images = images + (noise_level/255)*torch.randn(*images.shape)
            noisy_images = np.clip(noisy_images, 0, 1)
            images = images.to(device) # move to GPU
            noisy_images = noisy_images.to(device)
            outputs = model.generational_denoise(noisy_images, time) # forward
            images = images.cpu().detach().numpy()
            noisy_images = noisy_images.cpu().detach().numpy()
            outputs = outputs.cpu().detach().numpy()
            for i in range(len(images)):
                image = 255 * np.transpose(images[i], (1,2,0))
                noisy_image = 255 * np.transpose(noisy_images[i], (1,2,0))
                output = 255 * np.transpose(outputs[i], (1,2,0))
                cv2.imwrite(os.path.join(original_path, str(counter)+".png"), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
                cv2.imwrite(os.path.join(noisy_path, str(counter)+".png"), cv2.cvtColor(noisy_image, cv2.COLOR_RGB2BGR))
                cv2.imwrite(os.path.join(denoised_path, str(counter)+".png"), cv2.cvtColor(output, cv2.COLOR_RGB2BGR))
                counter += 1
            if counter > 100:
                break
    
    PSNR_noisy = []
    SSIM_noisy = []
    PSNR_denoised = []
    SSIM_denoised = []
    for i in range(len(os.listdir(original_path))):
        image = cv2.imread(os.path.join(original_path, str(i)+".png"), 0)
        noisy_image = cv2.imread(os.path.join(noisy_path, str(i)+".png"), 0)
        denoised_image = cv2.imread(os.path.join(denoised_path, str(i)+".png"), 0)
        PSNR_noisy.append(peak_signal_noise_ratio(image, noisy_image))
        SSIM_noisy.append(structural_similarity(image, noisy_image))
        PSNR_denoised.append(peak_signal_noise_ratio(image, denoised_image))
        SSIM_denoised.append(structural_similarity(image, denoised_image))
    print("PSNR noisy:", format(np.mean(PSNR_noisy), ".2f"), "+-", format(np.std(PSNR_noisy), ".2f"))
    print("SSIM noisy:", format(np.mean(SSIM_noisy), ".3f"), "+-", format(np.std(SSIM_noisy), ".3f"))
    print("PSNR denoised:", format(np.mean(PSNR_denoised), ".2f"), "+-", format(np.std(PSNR_denoised), ".2f"))
    print("SSIM denoised:", format(np.mean(SSIM_denoised), ".3f"), "+-", format(np.std(SSIM_denoised), ".3f"))

noise_level: 10
device: cuda:0



0it [00:00, ?it/s][A
1it [00:06,  6.31s/it][A
2it [00:09,  4.34s/it][A
3it [00:12,  3.70s/it][A
4it [00:15,  3.41s/it][A
5it [00:18,  3.25s/it][A
6it [00:21,  3.14s/it][A
7it [00:24,  3.08s/it][A
8it [00:26,  3.04s/it][A
9it [00:29,  3.01s/it][A
10it [00:32,  3.00s/it][A
11it [00:35,  2.99s/it][A
12it [00:38,  2.98s/it][A
13it [00:41,  2.97s/it][A
14it [00:44,  2.97s/it][A
15it [00:47,  2.96s/it][A
16it [00:50,  2.97s/it][A
17it [00:53,  2.97s/it][A
18it [00:56,  2.96s/it][A
19it [00:59,  2.96s/it][A
20it [01:02,  2.96s/it][A
21it [01:05,  2.96s/it][A
22it [01:08,  2.96s/it][A
23it [01:11,  2.95s/it][A
24it [01:14,  2.96s/it][A
25it [01:17,  3.09s/it][A


PSNR noisy: 31.71 +- 0.16
SSIM noisy: 0.840 +- 0.066
PSNR denoised: 32.15 +- 1.15
SSIM denoised: 0.886 +- 0.026
noise_level: 25
device: cuda:0



0it [00:00, ?it/s][A
1it [00:08,  8.74s/it][A
2it [00:14,  6.99s/it][A
3it [00:20,  6.43s/it][A
4it [00:26,  6.16s/it][A
5it [00:31,  6.01s/it][A
6it [00:37,  5.92s/it][A
7it [00:43,  5.87s/it][A
8it [00:49,  5.83s/it][A
9it [00:54,  5.80s/it][A
10it [01:00,  5.79s/it][A
11it [01:06,  5.83s/it][A
12it [01:12,  5.80s/it][A
13it [01:17,  5.78s/it][A
14it [01:23,  5.77s/it][A
15it [01:29,  5.76s/it][A
16it [01:35,  5.76s/it][A
17it [01:40,  5.76s/it][A
18it [01:46,  5.76s/it][A
19it [01:52,  5.76s/it][A
20it [01:58,  5.76s/it][A
21it [02:03,  5.75s/it][A
22it [02:09,  5.75s/it][A
23it [02:15,  5.83s/it][A
24it [02:21,  5.80s/it][A
25it [02:27,  5.89s/it][A


PSNR noisy: 23.91 +- 0.26
SSIM noisy: 0.557 +- 0.116
PSNR denoised: 28.06 +- 1.04
SSIM denoised: 0.776 +- 0.041
noise_level: 50
device: cuda:0



0it [00:00, ?it/s][A
1it [00:14, 14.51s/it][A
2it [00:25, 12.63s/it][A
3it [00:37, 12.03s/it][A
4it [00:48, 11.75s/it][A
5it [00:59, 11.64s/it][A
6it [01:11, 11.53s/it][A
7it [01:22, 11.46s/it][A
8it [01:33, 11.43s/it][A
9it [01:45, 11.40s/it][A
10it [01:56, 11.38s/it][A
11it [02:07, 11.36s/it][A
12it [02:19, 11.35s/it][A
13it [02:30, 11.35s/it][A
14it [02:41, 11.34s/it][A
15it [02:53, 11.34s/it][A
16it [03:04, 11.34s/it][A
17it [03:15, 11.33s/it][A
18it [03:27, 11.33s/it][A
19it [03:38, 11.33s/it][A
20it [03:49, 11.33s/it][A
21it [04:01, 11.34s/it][A
22it [04:12, 11.39s/it][A
23it [04:24, 11.37s/it][A
24it [04:35, 11.36s/it][A
25it [04:46, 11.47s/it][A


PSNR noisy: 18.31 +- 0.29
SSIM noisy: 0.323 +- 0.103
PSNR denoised: 24.40 +- 0.98
SSIM denoised: 0.631 +- 0.053


**WaveDiffusion**

In [4]:
model = WaveDiffusion(image_size=image_size, image_channels=3, time_range=time_range, device=device).to(device)
model_name = "WaveDiffusion"
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
train_loss_all = []
test_loss_all = []

bar = tqdm(range(epochs))

for epoch in bar:
    # training
    train_loss = 0
    model.train()
    for i, data in enumerate(train_loader):
        images, _ = data
        optimizer.zero_grad()
        loss = model.loss(images)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        bar.set_postfix({"Step": str(epoch * len(train_loader) + i + 1) + "/" + str(epochs * len(train_loader)), "training loss": format(train_loss, ".3f")})
    train_loss_all.append(train_loss/len(train_set))
    # testing
    test_loss = 0
    model.eval()
    for data in test_loader:
        images, _ = data
        with torch.no_grad():
            loss = model.loss(images)
        test_loss += loss.item()
    bar.set_postfix({"Epoch": epoch+1, "testing loss": format(test_loss, ".3f")})
    test_loss_all.append(test_loss/len(test_set))

# save model and results
output_dir = os.path.join(base_dir, "results_BSD", model_name)
os.makedirs(output_dir, exist_ok=True)
torch.save(model, os.path.join(output_dir, "model.pt"))
np.save(os.path.join(output_dir, "train_loss.npy"), np.array(train_loss_all))
np.save(os.path.join(output_dir, "test_loss.npy"), np.array(test_loss_all))

plt.plot(train_loss_all, label="training loss")
plt.plot(test_loss_all, label="testing loss")
plt.xlabel("Epoch")
plt.legend()
plt.savefig(os.path.join(output_dir, "loss.png"), format="png")

  0%|          | 0/50 [01:21<?, ?it/s, Step=84/5000, training loss=41.980]

In [None]:
import cv2
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

test_set = BSDDataset(base_dir=base_dir, split="test")
test_loader = DataLoader(test_set, batch_size=4, shuffle=False, num_workers=4)

for noise_level, time in [(10, 10), (25, 20), (50, 40)]:
    print("noise_level:", noise_level)
    output_dir = os.path.join(base_dir, "results_BSD", model_name + str(noise_level))
    original_path = os.path.join(output_dir, "original_images")
    noisy_path = os.path.join(output_dir, "noisy_images")
    denoised_path = os.path.join(output_dir, "denoised_images")
    os.makedirs(original_path, exist_ok=True)
    os.makedirs(noisy_path, exist_ok=True)
    os.makedirs(denoised_path, exist_ok=True)
    
    
    model.eval()
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("device:", device)
    model.to(device)
    
    counter = 0
    with torch.no_grad():
        for j, data in tqdm(enumerate(test_loader, 0)):
            images, _ = data
            noisy_images = images + (noise_level/255)*torch.randn(*images.shape)
            noisy_images = np.clip(noisy_images, 0, 1)
            images = images.to(device) # move to GPU
            noisy_images = noisy_images.to(device)
            outputs = model.generational_denoise(noisy_images, time) # forward
            images = images.cpu().detach().numpy()
            noisy_images = noisy_images.cpu().detach().numpy()
            outputs = outputs.cpu().detach().numpy()
            for i in range(len(images)):
                image = 255 * np.transpose(images[i], (1,2,0))
                noisy_image = 255 * np.transpose(noisy_images[i], (1,2,0))
                output = 255 * np.transpose(outputs[i], (1,2,0))
                cv2.imwrite(os.path.join(original_path, str(counter)+".png"), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
                cv2.imwrite(os.path.join(noisy_path, str(counter)+".png"), cv2.cvtColor(noisy_image, cv2.COLOR_RGB2BGR))
                cv2.imwrite(os.path.join(denoised_path, str(counter)+".png"), cv2.cvtColor(output, cv2.COLOR_RGB2BGR))
                counter += 1
            if counter > 100:
                break
    
    PSNR_noisy = []
    SSIM_noisy = []
    PSNR_denoised = []
    SSIM_denoised = []
    for i in range(len(os.listdir(original_path))):
        image = cv2.imread(os.path.join(original_path, str(i)+".png"), 0)
        noisy_image = cv2.imread(os.path.join(noisy_path, str(i)+".png"), 0)
        denoised_image = cv2.imread(os.path.join(denoised_path, str(i)+".png"), 0)
        PSNR_noisy.append(peak_signal_noise_ratio(image, noisy_image))
        SSIM_noisy.append(structural_similarity(image, noisy_image))
        PSNR_denoised.append(peak_signal_noise_ratio(image, denoised_image))
        SSIM_denoised.append(structural_similarity(image, denoised_image))
    print("PSNR noisy:", format(np.mean(PSNR_noisy), ".2f"), "+-", format(np.std(PSNR_noisy), ".2f"))
    print("SSIM noisy:", format(np.mean(SSIM_noisy), ".3f"), "+-", format(np.std(SSIM_noisy), ".3f"))
    print("PSNR denoised:", format(np.mean(PSNR_denoised), ".2f"), "+-", format(np.std(PSNR_denoised), ".2f"))
    print("SSIM denoised:", format(np.mean(SSIM_denoised), ".3f"), "+-", format(np.std(SSIM_denoised), ".3f"))