In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from skimage.metrics import peak_signal_noise_ratio as compare_psnr

class UNet(nn.Module):
    """Four-level U-Net mapping an ``input_depth``-channel tensor to an image.

    Encoder: four double-conv blocks (64, 128, 256, 512 channels), each followed
    by 2x2 max-pooling, then a 1024-channel bottleneck. Decoder: four 2x
    transposed-conv upsamplings, each concatenated with the matching encoder
    feature map (skip connection) and refined by a double-conv block. A final
    1x1 convolution projects to ``output_depth`` channels.

    Args:
        input_depth: number of channels of the input tensor.
        output_depth: number of output channels (default 3, i.e. RGB).

    Spatial sizes that are not multiples of 16 are supported: each upsampled
    map is zero-padded on the bottom/right to match its skip connection before
    concatenation.
    """

    def __init__(self, input_depth, output_depth=3):
        super(UNet, self).__init__()
        # Encoder
        self.encoder1 = self.conv_block(input_depth, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = self.conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = self.conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.encoder4 = self.conv_block(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.bottleneck = self.conv_block(512, 1024)

        # Decoder: each decoder block sees upsampled + skip channels (2x its output).
        self.up_conv1 = self.up_conv(1024, 512)
        self.decoder1 = self.conv_block(1024, 512)

        self.up_conv2 = self.up_conv(512, 256)
        self.decoder2 = self.conv_block(512, 256)

        self.up_conv3 = self.up_conv(256, 128)
        self.decoder3 = self.conv_block(256, 128)

        self.up_conv4 = self.up_conv(128, 64)
        self.decoder4 = self.conv_block(128, 64)

        # 1x1 projection to the requested number of output channels (3 by default).
        self.final_conv = nn.Conv2d(64, output_depth, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 conv -> BatchNorm -> ReLU layers; padding=1 keeps H and W unchanged."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def up_conv(self, in_channels, out_channels):
        """Upsample with a transposed conv; kernel 2 / stride 2 exactly doubles H and W."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    @staticmethod
    def _match_size(upsampled, skip):
        """Zero-pad ``upsampled`` on the bottom/right so its H and W equal ``skip``'s.

        MaxPool2d(2) floors odd sizes while the transposed conv exactly doubles,
        so the upsampled map is at most one pixel smaller per axis. When the sizes
        already agree the tensor is returned unchanged.
        """
        dh = skip.shape[-2] - upsampled.shape[-2]
        dw = skip.shape[-1] - upsampled.shape[-1]
        if dh or dw:
            upsampled = nn.functional.pad(upsampled, (0, dw, 0, dh))
        return upsampled

    def forward(self, x):
        """Run the network.

        Args:
            x: tensor of shape (N, input_depth, H, W). H and W must be at least 16
               so the bottleneck (after four 2x poolings) is non-empty.

        Returns:
            Tensor of shape (N, output_depth, H, W); no output activation is applied.
        """
        # Encoder: keep each level's features for the skip connections.
        x1 = self.encoder1(x)
        p1 = self.pool1(x1)

        x2 = self.encoder2(p1)
        p2 = self.pool2(x2)

        x3 = self.encoder3(p2)
        p3 = self.pool3(x3)

        x4 = self.encoder4(p3)
        p4 = self.pool4(x4)

        x5 = self.bottleneck(p4)

        # Decoder: upsample, align to the skip map, concatenate on channels, refine.
        d1 = self._match_size(self.up_conv1(x5), x4)
        d1 = self.decoder1(torch.cat([d1, x4], dim=1))

        d2 = self._match_size(self.up_conv2(d1), x3)
        d2 = self.decoder2(torch.cat([d2, x3], dim=1))

        d3 = self._match_size(self.up_conv3(d2), x2)
        d3 = self.decoder3(torch.cat([d3, x2], dim=1))

        d4 = self._match_size(self.up_conv4(d3), x1)
        d4 = self.decoder4(torch.cat([d4, x1], dim=1))

        return self.final_conv(d4)


# Load images and add synthetic noise
def load_image(image_path, size=(256, 256), mode='RGB'):
    """Load an image file as a float tensor with a leading batch dimension.

    Args:
        image_path: path to the image file.
        size: target (height, width) passed to ``transforms.Resize``.
        mode: PIL mode to convert to before tensorizing. Defaults to ``'RGB'`` so
            grayscale, palette and RGBA files all yield 3 channels, matching the
            3-channel output of the UNet. Pass ``None`` to keep the file's native mode.

    Returns:
        Tensor of shape (1, C, H, W) produced by ``transforms.ToTensor`` (C == 3
        for ``mode='RGB'``, with values scaled to [0, 1]).
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor()
    ])
    # Context manager closes the underlying file once the pixels are converted.
    with Image.open(image_path) as raw:
        pil_img = raw.convert(mode) if mode is not None else raw
        tensor = transform(pil_img)
    return tensor.unsqueeze(0)  # add batch dimension

def add_noise(img, noise_factor=0.1):
    """Corrupt ``img`` with additive zero-mean Gaussian noise and clip to [0, 1].

    Args:
        img: image tensor (any shape), on any device.
        noise_factor: standard deviation of the Gaussian noise.

    Returns:
        New tensor with the same shape and device as ``img``, clamped to [0, 1].
    """
    # Draw the noise on img's device so CUDA inputs do not hit a device mismatch.
    noise = torch.randn(img.shape, device=img.device)
    return torch.clamp(img + noise_factor * noise, 0., 1.)

# Display images
def show_images(original, noisy, denoised):
    """Plot the original, noisy and denoised images side by side.

    Each argument is an image tensor of shape (1, C, H, W); the batch dimension is
    squeezed away. Tensors may live on any device and may still require grad.
    Values are clipped to [0, 1] for display. Three-channel images are shown as
    RGB; single-channel images are shown in grayscale.
    """
    panels = (
        (original, 'Original Image'),
        (noisy, 'Noisy Image'),
        (denoised, 'Denoised Image'),
    )
    fig, axs = plt.subplots(1, len(panels), figsize=(15, 5))
    for ax, (img, title) in zip(axs, panels):
        # (1, C, H, W) -> (H, W, C) NumPy array; detach so tensors that still
        # require grad can be converted.
        arr = img.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()
        # imshow expects float images in [0, 1].
        arr = np.clip(arr, 0, 1)
        if arr.shape[-1] == 1:
            # imshow rejects (H, W, 1) arrays; draw single-channel images as grayscale.
            ax.imshow(arr[..., 0], cmap='gray', vmin=0, vmax=1)
        else:
            ax.imshow(arr)
        ax.set_title(title)
        ax.axis('off')

    plt.show()
    # Release the figure so repeated calls (e.g. inside a training loop) do not
    # accumulate open figures.
    plt.close(fig)

In [None]:
# Use the first GPU when available; otherwise fall back to the CPU so the notebook
# still runs (slowly) on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

def generate_gaussian_noise(input_depth, spatial_size, var=1./10, generator=None):
    """
    Generates a Gaussian noise tensor (e.g. a fixed random code fed to the network).

    Parameters:
    - input_depth: The number of channels in the tensor
    - spatial_size: An int (square output) or a (height, width) pair
    - var: Scale applied to standard-normal samples. Despite the name, this is the
      standard deviation of the noise, not its variance (the variance is var ** 2).
    - generator: Optional torch.Generator for reproducible draws

    Returns:
    - A float tensor of shape (1, input_depth, height, width) on torch's default device
    """
    if isinstance(spatial_size, int):
        height = width = spatial_size
    else:
        height, width = spatial_size[0], spatial_size[1]

    shape = (1, input_depth, height, width)
    return torch.randn(shape, generator=generator) * var

In [None]:
# Deep-image-prior style denoising: fit the network, driven by a fixed random code,
# to the *noisy* image and watch intermediate outputs; the fixed step budget acts
# as early stopping.

torch.manual_seed(0)  # reproducible noise code, weight init and injected noise

reg_noise_std = 1./30. # set to 1./20. for sigma=50

input_depth = 32
from utils import count_parameters

# Load and prepare the image
# image_path = './data/denoising/F16_GT.png'
image_path = './data/denoising/snail.jpg'

original_img = load_image(image_path)  # tensor of shape [1, C, 256, 256]; C is 3 for RGB images
noisy_img = add_noise(original_img).to(device)

# Fixed random network input with the image's spatial size. It lives on the device
# so the loop does not copy it host->device on every step.
net_input_saved = generate_gaussian_noise(input_depth=input_depth,
                                          spatial_size=original_img.shape[-2:]).to(device)
noise = net_input_saved.detach().clone()  # scratch buffer, re-sampled in place each step
net_input = net_input_saved

net = UNet(input_depth=input_depth).to(device)

print(f"Model Parameters: {count_parameters(net)}")

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3)

# NumPy (C, H, W) copies of the reference images, used for PSNR reporting.
original_np = original_img[0].cpu().numpy()
noisy_np = noisy_img[0].cpu().numpy()

# Training loop
num_steps = 1000
for step in tqdm(range(num_steps)):
    optimizer.zero_grad()

    if reg_noise_std > 0:
        # Perturb the fixed code slightly at every step.
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    output = net(net_input)
    loss = criterion(output, noisy_img)
    loss.backward()
    optimizer.step()

    if (step + 1) % 100 == 0:
        # PSNR against the noisy target (fit quality) and the clean image (denoising quality).
        out_np = output.detach()[0].clamp(0, 1).cpu().numpy()
        psnr_noisy = compare_psnr(noisy_np, out_np, data_range=1.0)
        psnr_gt = compare_psnr(original_np, out_np, data_range=1.0)
        print(f"Step [{step+1}/{num_steps}], Loss: {loss.item():.4f}")
        print(f"  PSNR vs noisy: {psnr_noisy:.2f} dB, PSNR vs clean: {psnr_gt:.2f} dB")
        show_images(original_img, noisy_img, output.detach())

# Show results
denoised_img = output.detach()
show_images(original_img, noisy_img, denoised_img)


In [None]:
# Supervised experiment: the F16 test image, read from a local path.
image_path = './F16_GT.png'
original_img = load_image(image_path)  # clean image, used as the training target
noisy_img = add_noise(original_img)    # synthetic corruption, used as the network input

In [None]:
# UNet requires input_depth; the supervised model takes the 3-channel noisy image
# and (via its 3-channel final conv) predicts the 3-channel clean image.
net = UNet(input_depth=3)
optimizer = optim.Adam(net.parameters(), lr=0.001)
criterion = nn.MSELoss()


In [None]:
# Supervised fit: learn to map the noisy image back to the clean one.
num_steps = 10000
for step in range(num_steps):
    optimizer.zero_grad()
    output = net(noisy_img)
    loss = criterion(output, original_img)
    loss.backward()
    optimizer.step()

    # Progress report every 500 steps (step 0 included).
    if not step % 500:
        print(f'Step {step}, Loss: {loss.item()}')


In [None]:
# Inference only: skip building an autograd graph for the final prediction.
# The net is left in whatever mode training used (no .eval() call here).
with torch.no_grad():
    denoised_img = net(noisy_img)
show_images(original_img, noisy_img, denoised_img)