In [1]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from pytorch_wavelets import DWTForward, DWTInverse
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from torch.utils.data import DataLoader
from tqdm import tqdm

from BSD import BSDDataset
from diffusion.diffusion import Diffusion
from diffusion.wavelet_diffusion import WaveDiffusion

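**Baseline**: train the `Diffusion` model on the BSD dataset, then evaluate its denoising quality with PSNR/SSIM.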

In [2]:
# Run configuration: compute device, RNG seeding, training hyperparameters and the baseline model.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed the global torch and numpy RNGs before building the model so weight
# initialisation and DataLoader shuffling are reproducible across kernel restarts.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

image_size = 256   # passed to Diffusion(image_size=...)
epochs = 50
batch_size = 4
time_range = 1000  # passed to Diffusion(time_range=...); NOTE(review): presumably the number of diffusion steps — confirm
lr = 1e-4          # Adam learning rate

model = Diffusion(image_size=image_size, image_channels=3, time_range=time_range, device=device).to(device)
model_name = "Diffuser Baseline"  # also used as the results sub-directory name

In [4]:
# Data: BSD train/test splits, their loaders, and the Adam optimizer for `model`.
base_dir = ""  # root passed to BSDDataset; results are also written under this directory

train_set = BSDDataset(base_dir=base_dir, split="train")
test_set = BSDDataset(base_dir=base_dir, split="test")

num_workers = 4
# Use the configured `batch_size` instead of a duplicated literal, so changing it
# in the config cell actually takes effect.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)


In [5]:
# Training: run `epochs` passes over the training set, track the mean train/test
# loss per epoch, then save the model, the loss curves and a loss plot.
def _run_epoch(loader, train):
    """Run one pass of `model` over `loader` and return the sample-weighted mean loss.

    When `train` is True the model is put in training mode and `optimizer` is
    stepped after every batch; otherwise the model is put in eval mode and
    gradients are disabled.
    """
    model.train(train)
    total_loss = 0.0
    n_samples = 0
    for images, _ in loader:
        # The loader yields CPU tensors while the model lives on `device`.
        images = images.to(device)
        with torch.set_grad_enabled(train):
            loss = model.loss(images)
        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # NOTE(review): assumes model.loss returns the mean over the batch — confirm.
        # Weighting by batch size keeps the epoch average exact when the last batch is smaller.
        batch_n = images.size(0)
        total_loss += loss.item() * batch_n
        n_samples += batch_n
    return total_loss / max(n_samples, 1)


train_loss_all = []
test_loss_all = []

for epoch in tqdm(range(epochs)):
    train_loss = _run_epoch(train_loader, train=True)
    print("Epoch", epoch + 1, "training loss:", format(train_loss, ".3f"))
    train_loss_all.append(train_loss)

    test_loss = _run_epoch(test_loader, train=False)
    print("Epoch", epoch + 1, "testing loss:", format(test_loss, ".3f"))
    test_loss_all.append(test_loss)

# save model and results
output_dir = os.path.join(base_dir, "results_BSD", model_name)
os.makedirs(output_dir, exist_ok=True)
# Pickles the whole module: reloading requires the Diffusion class to be importable.
torch.save(model, os.path.join(output_dir, "model.pt"))
np.save(os.path.join(output_dir, "train_loss.npy"), np.array(train_loss_all))
np.save(os.path.join(output_dir, "test_loss.npy"), np.array(test_loss_all))

fig, ax = plt.subplots()
epoch_axis = range(1, len(train_loss_all) + 1)
ax.plot(epoch_axis, train_loss_all, label="training loss")
ax.plot(epoch_axis, test_loss_all, label="testing loss")
ax.set(xlabel="Epoch", ylabel="Loss", title=model_name)
ax.legend()
fig.savefig(os.path.join(output_dir, "loss.png"), format="png")
plt.show()

  0%|          | 0/50 [00:00<?, ?it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x10b77b240>
Traceback (most recent call last):
  File "/Users/jiseshen/Documents/Code/Wavelet_proj/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/Users/jiseshen/Documents/Code/Wavelet_proj/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1568, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/p

KeyboardInterrupt: 

In [8]:
# Denoising evaluation: corrupt test images with Gaussian noise, denoise them with
# `model.generational_denoise`, save original/noisy/denoised PNGs, then report
# PSNR/SSIM of the noisy and denoised images against the originals.
test_set = BSDDataset(base_dir=base_dir, split="test")
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)

noise_level = 10  # std of the added Gaussian noise on the 0-255 intensity scale
time = 10         # second argument of model.generational_denoise; NOTE(review): presumably the start timestep — confirm
max_images = 100  # evaluate exactly this many test images
noise_seed = 0    # fixed so the noisy inputs (and therefore the metrics) are reproducible

output_dir = os.path.join(base_dir, "results_BSD", model_name + str(noise_level))
original_path = os.path.join(output_dir, "original_images")
noisy_path = os.path.join(output_dir, "noisy_images")
denoised_path = os.path.join(output_dir, "denoised_images")
for path in (original_path, noisy_path, denoised_path):
    os.makedirs(path, exist_ok=True)


def _to_uint8_hwc(chw):
    """Convert a CHW float array (expected range [0, 1]) to a rounded, clipped, contiguous HWC uint8 image."""
    hwc = np.transpose(chw, (1, 2, 0))
    return np.ascontiguousarray(np.clip(np.rint(255 * hwc), 0, 255).astype(np.uint8))


def _save_rgb(directory, index, chw):
    """Write a CHW RGB float array as `<index>.png` in `directory` (OpenCV writes BGR)."""
    bgr = cv2.cvtColor(_to_uint8_hwc(chw), cv2.COLOR_RGB2BGR)
    cv2.imwrite(os.path.join(directory, f"{index}.png"), bgr)


model.eval()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device)
model.to(device)

noise_generator = torch.Generator().manual_seed(noise_seed)
counter = 0
with torch.no_grad():
    for images, _ in test_loader:
        noise = torch.randn(images.shape, generator=noise_generator)
        # Tensor.clamp keeps this a torch tensor; np.clip on a tensor relies on fragile NumPy/torch interop.
        noisy_images = (images + (noise_level / 255) * noise).clamp(0, 1)
        outputs = model.generational_denoise(noisy_images.to(device), time)
        outputs = outputs.detach().cpu().numpy()
        for image, noisy_image, output in zip(images.numpy(), noisy_images.numpy(), outputs):
            if counter >= max_images:
                break
            _save_rgb(original_path, counter, image)
            _save_rgb(noisy_path, counter, noisy_image)
            _save_rgb(denoised_path, counter, output)
            counter += 1
        if counter >= max_images:
            break

assert counter > 0, "no test images were evaluated"

PSNR_noisy = []
SSIM_noisy = []
PSNR_denoised = []
SSIM_denoised = []
# Score only the images written above (not os.listdir), so stale files left by an
# earlier, larger run in the same directories are ignored. Images are read back as
# 8-bit grayscale, so the metrics reflect quantized grayscale intensity only.
for i in range(counter):
    name = f"{i}.png"
    image = cv2.imread(os.path.join(original_path, name), cv2.IMREAD_GRAYSCALE)
    noisy_image = cv2.imread(os.path.join(noisy_path, name), cv2.IMREAD_GRAYSCALE)
    denoised_image = cv2.imread(os.path.join(denoised_path, name), cv2.IMREAD_GRAYSCALE)
    PSNR_noisy.append(peak_signal_noise_ratio(image, noisy_image))
    SSIM_noisy.append(structural_similarity(image, noisy_image))
    PSNR_denoised.append(peak_signal_noise_ratio(image, denoised_image))
    SSIM_denoised.append(structural_similarity(image, denoised_image))
print("PSNR noisy:", format(np.mean(PSNR_noisy), ".2f"), "+-", format(np.std(PSNR_noisy), ".2f"))
print("SSIM noisy:", format(np.mean(SSIM_noisy), ".3f"), "+-", format(np.std(SSIM_noisy), ".3f"))
print("PSNR denoised:", format(np.mean(PSNR_denoised), ".2f"), "+-", format(np.std(PSNR_denoised), ".2f"))
print("SSIM denoised:", format(np.mean(SSIM_denoised), ".3f"), "+-", format(np.std(SSIM_denoised), ".3f"))

device: cpu


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x10b77b240>
Traceback (most recent call last):
  File "/Users/jiseshen/Documents/Code/Wavelet_proj/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/Users/jiseshen/Documents/Code/Wavelet_proj/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1568, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/connection.p

KeyboardInterrupt: 