In [1]:
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from models import *
from dataset import prepare_data, Dataset
from utils import *
import torch
import numpy as np
from torchvision.transforms import transforms
import random
from skimage.util import random_noise

In [2]:
def add_noise(img_tensor):
    # Ensure img_tensor is float32 and on the correct device
    img_tensor = img_tensor.to(torch.float32)
    device = img_tensor.device

    noise_t = random.randint(0, 6)
    if noise_t == 0:
        # Gaussian noise
        gaussian_noise = (
            torch.tensor(
                random_noise(
                    img_tensor.cpu().numpy(),
                    mode="gaussian",
                    mean=0,
                    var=0.05,
                    clip=True,
                )
            ).to(device, dtype=torch.float32)
            - img_tensor
        )
        noise_tensor = gaussian_noise

    elif noise_t == 1:
        # Gaussian + Salt noise
        img_gaussian = torch.tensor(
            random_noise(
                img_tensor.cpu().numpy(), mode="gaussian", mean=0, var=0.05, clip=True
            )
        ).to(device, dtype=torch.float32)
        img_gaussian_salt = torch.tensor(
            random_noise(
                img_gaussian.cpu().numpy(), mode="salt", amount=0.05, clip=True
            )
        ).to(device, dtype=torch.float32)
        noise_tensor = img_gaussian_salt - img_tensor

    elif noise_t == 2:
        # Speckle noise
        speckle_noise = torch.normal(
            mean=0, std=0.05**0.5, size=img_tensor.size(), device=device
        )
        noise_tensor = img_tensor * speckle_noise

    elif noise_t == 3:
        # Poisson noise
        poisson_noise = (
            torch.tensor(
                random_noise(img_tensor.cpu().numpy(), mode="poisson", clip=True)
            ).to(device, dtype=torch.float32)
            - img_tensor
        )
        noise_tensor = poisson_noise

    elif noise_t == 4:
        # Salt noise
        salt_noise = (
            torch.tensor(
                random_noise(img_tensor.cpu().numpy(), mode="salt", amount=0.05)
            ).to(device, dtype=torch.float32)
            - img_tensor
        )
        noise_tensor = salt_noise

    elif noise_t == 5:
        # Speckle noise (Gaussian multiplicative)
        speckle_noise = (
            torch.tensor(
                random_noise(
                    img_tensor.cpu().numpy(),
                    mode="speckle",
                    mean=0,
                    var=0.05,
                    clip=True,
                )
            ).to(device, dtype=torch.float32)
            - img_tensor
        )
        noise_tensor = speckle_noise

    elif noise_t == 6:
        # Speckle + Salt noise
        speckle_noise = torch.normal(
            mean=0, std=0.01**0.5, size=img_tensor.size(), device=device
        )
        speckle_component = img_tensor * speckle_noise
        img_with_speckle = img_tensor + speckle_component
        noisy_img_np = random_noise(
            img_with_speckle.cpu().numpy(), mode="salt", amount=0.05
        )
        noise_tensor = (
            torch.tensor(noisy_img_np).to(device, dtype=torch.float32) - img_tensor
        )

    return noise_tensor

In [3]:
# Define each parameter as a regular variable

preprocess = True  # or False, depending on your needs
batchSize = 224  # Batch size for training
num_of_layers = 18  # Number of layers in your model
epochs = 10  # Number of epochs for training
milestone = 2  # Epoch milestone for something specific, if needed
lr = 0.005  # Learning rate
outf = "output_folder/"  # Output folder path
mode = "train"  # Mode of operation, e.g., 'train' or 'test'

print(f"Preprocess: {preprocess}")
print(f"Batch Size: {batchSize}")
print(f"Number of Layers: {num_of_layers}")
print(f"Epochs: {epochs}")
print(f"Milestone: {milestone}")
print(f"Learning Rate: {lr}")
print(f"Output Folder: {outf}")
print(f"Mode: {mode}")

Preprocess: True
Batch Size: 224
Number of Layers: 18
Epochs: 10
Milestone: 2
Learning Rate: 0.005
Output Folder: output_folder/
Mode: train


In [4]:
# prepare_data(data_path="data", patch_size=50, stride=10, aug_times=0)

In [5]:
print("Loading dataset ...\n")
dataset_train = Dataset(train=True)
dataset_val = Dataset(train=False)
loader_train = DataLoader(
    dataset=dataset_train, num_workers=2, batch_size=batchSize, shuffle=True
)
print("# of training samples: %d\n" % int(len(dataset_train)))
print("# of validation samples: %d\n" % int(len(dataset_val)))

Loading dataset ...

# of training samples: 959660

# of validation samples: 21



In [6]:
# Build model
# net = DnCNN(channels=1, num_of_layers=num_of_layers)
net = DenoiseCNN()
net.apply(weights_init_kaiming)

  nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')


DenoiseCNN(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
  )
  (decoder): Sequential(
    (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): Sigmoid()
  )
)

In [7]:
criterion = nn.MSELoss(reduction="sum")
# Move to GPU
device_ids = list(
    range(torch.cuda.device_count())
)  # Automatically get all available GPUs
model = nn.DataParallel(net, device_ids=device_ids).cuda()  # Use all available GPUs
criterion.cuda()
# Optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)

In [None]:
# TensorBoard writer
writer = SummaryWriter(outf)
step = 0
for epoch in range(epochs):
    if epoch < milestone:
        current_lr = lr
    else:
        current_lr = lr / 10.0
    # Set learning rate
    for param_group in optimizer.param_groups:
        param_group["lr"] = current_lr
    print("learning rate %f" % current_lr)
    # Train
    for i, data in enumerate(loader_train, 0):
        # Training step
        model.train()
        model.zero_grad()
        optimizer.zero_grad()
        img_train = data
        noise = add_noise(img_train)
        imgn_train = img_train + noise
        img_train, imgn_train = img_train.cuda(), imgn_train.cuda()
        noise = noise.cuda()

        # Forward pass
        out_train = model(imgn_train)
        loss = criterion(out_train, noise) / (imgn_train.size()[0] * 2)
        loss.backward()
        optimizer.step()
        # Evaluation step with no gradient tracking
        with torch.no_grad():
            model.eval()
            out_train = torch.clamp(imgn_train - model(imgn_train), 0.0, 1.0)
            psnr_train = batch_PSNR(out_train, img_train, 1.0)
            print(
                "[epoch %d][%d/%d] loss: %.4f PSNR_train: %.4f"
                % (epoch + 1, i + 1, len(loader_train), loss.item(), psnr_train)
            )
        # Logging
        if step % 10 == 0:
            writer.add_scalar("loss", loss.item(), step)
            writer.add_scalar("PSNR on training data", psnr_train, step)
        step += 1
    # End of each epoch
    # Save model
    torch.save(model.state_dict(), os.path.join(outf, "baselinenet.pth"))
    model.eval()
    # Validate
    psnr_val = 0
    with torch.no_grad():  # Prevents gradient calculations
        for k in range(len(dataset_val)):
            img_val = torch.unsqueeze(dataset_val[k], 0)
            noise = add_noise(img_val)
            imgn_val = img_val + noise
            img_val, imgn_val = img_val.cuda(), imgn_val.cuda()
            out_val = torch.clamp(imgn_val - model(imgn_val), 0.0, 1.0)
            psnr_val += batch_PSNR(out_val, img_val, 1.0)
    psnr_val /= len(dataset_val)
    print("\n[epoch %d] PSNR_val: %.4f" % (epoch + 1, psnr_val))
    writer.add_scalar("PSNR on validation data", psnr_val, epoch)
    # Log the images
    out_train = torch.clamp(imgn_train - model(imgn_train), 0.0, 1.0)
    Img = utils.make_grid(img_train.data, nrow=8, normalize=True, scale_each=True)
    Imgn = utils.make_grid(imgn_train.data, nrow=8, normalize=True, scale_each=True)
    Irecon = utils.make_grid(out_train.data, nrow=8, normalize=True, scale_each=True)
    writer.add_image("clean image", Img, epoch)
    writer.add_image("noisy image", Imgn, epoch)
    writer.add_image("reconstructed image", Irecon, epoch)

learning rate 0.005000




[epoch 1][1/4285] loss: 379.6023 PSNR_train: 18.4019
[epoch 1][2/4285] loss: 27.2277 PSNR_train: 17.1485
[epoch 1][3/4285] loss: 62.9901 PSNR_train: 13.0183
[epoch 1][4/4285] loss: 63.7857 PSNR_train: 12.9595
[epoch 1][5/4285] loss: 25.7776 PSNR_train: 17.0870
[epoch 1][6/4285] loss: 62.4399 PSNR_train: 13.0573
[epoch 1][7/4285] loss: 16.7061 PSNR_train: 19.0590
[epoch 1][8/4285] loss: 16.6731 PSNR_train: 19.0018
[epoch 1][9/4285] loss: 62.0018 PSNR_train: 13.0883
[epoch 1][10/4285] loss: 42.3144 PSNR_train: 14.7141
[epoch 1][11/4285] loss: 16.8781 PSNR_train: 19.0030
[epoch 1][12/4285] loss: 2.0780 PSNR_train: 27.9475
[epoch 1][13/4285] loss: 23.5361 PSNR_train: 17.6590
[epoch 1][14/4285] loss: 2.0899 PSNR_train: 27.8573
[epoch 1][15/4285] loss: 24.7193 PSNR_train: 19.1686
[epoch 1][16/4285] loss: 42.3278 PSNR_train: 14.7131
[epoch 1][17/4285] loss: 25.3815 PSNR_train: 19.1422
[epoch 1][18/4285] loss: 16.6413 PSNR_train: 19.2268
[epoch 1][19/4285] loss: 61.9820 PSNR_train: 13.1021
[ep