# 简介
- [self2self](https://openaccess.thecvf.com/content_CVPR_2020/papers/Quan_Self2Self_With_Dropout_Learning_Self-Supervised_Denoising_From_Single_Image_CVPR_2020_paper.pdf)是一种仅利用单张噪声图片即可完成自监督去噪的方法
- 以伯努利采样（随机掩码）后的图片作为去噪网络的输入，以伯努利采样的补集（被掩掉的像素）作为预测目标
- 仅在伯努利采样补集上计算损失
- 在伯努利采样概率0.3下运行100000次效果比较好
- 论文中的网络结构使用了partial convolution和dropout，dropout比较关键
- 每个通道（channel）必须使用不同的掩码，效果才最好
- 图片输入时调用toTensor即可，不要进行标准化
- 不要使用randomFlip，不要使用多张图片，否则会破坏分布
- 如何腐蚀noise图像是一个需要探讨的话题
- L1 loss没有L2 loss抗噪，很容易出现彩块
- 学习率不能过大，否则网络会很快拟合到噪声，而且使用AdamW时会出现梯度爆炸
- 从雪坑的质感看，网络学到的近似于溶解
- 加式残差学习在监督学习中效果很好，在无监督学习中起负面作用，直接加法的特点是高频特征学得很快，建议用拼接代替加法实现残差。
- 神经网络对噪声的阻抗来自于伯努利抽样、池化层、cat操作
- 采样概率过小或者过大都会加剧模糊
- 第一层和最后一层的通道数对结果影响非常大，建议输入通道数至少达到64
- 不使用sigmoid函数时图像会暗很多，清晰度也会下降；但不管怎样，去噪后的图片亮度都会偏暗
- 倒数第二层通常是输入和输出通道数相同或使用cat作为输入的缓冲层，不使用激活函数，避免破坏输出分布，最后一层通常用sigmoid或者不使用激活函数。

# 环境

In [1]:
import torch
import torchmetrics
import torchvision
import os
import gc
import torch.nn as nn
from torchvision import transforms
from torch import optim

# Report the torch version, pick the compute device and, on GPU, release
# cached allocator memory left over from earlier runs of this notebook.
print(torch.__version__)
if torch.cuda.is_available():
    device = 'cuda'
    print(device)
    gc.collect()
    torch.cuda.empty_cache()
else:
    device = 'cpu'
    print(device)
print(os.getcwd())


1.13.1
cuda
d:\project\deep-learning\unsupervised image denoise


# 模型

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

from torchvision.utils import save_image

import torchvision.transforms as T

from PIL import Image
from tqdm.notebook import trange


class HourGlassCNNBlock(nn.Module):
    """A 3x3 convolution (padding 1, spatial size preserved), optionally
    followed by an activation and a channel-wise dropout layer.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        p: drop probability for ``nn.Dropout2d``; only used when ``dropout`` is True.
        dropout: whether to append ``nn.Dropout2d(p)`` after the activation.
        activation: module applied after the convolution, or ``None`` for a
            purely linear block. The default ``LeakyReLU(0.1)`` instance is
            shared by every block built with the default; this is harmless
            because that module has no parameters or buffers.
        device: device on which the convolution parameters are created.
            ``None`` (default) picks ``'cuda'`` when available and ``'cpu'``
            otherwise, so the block can be built on CPU-only machines.
    """

    def __init__(self,in_channels,out_channels,p=0.3,dropout=False,activation=nn.LeakyReLU(0.1),device=None):
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, padding=1, device=device),
        ]
        # Explicit None check: nn.Module truthiness is unreliable (container
        # modules such as an empty nn.Sequential define __len__).
        if activation is not None:
            layers.append(activation)
        if dropout:
            layers.append(nn.Dropout2d(p))
        self.main = nn.Sequential(*layers)

    def forward(self,x):
        # No residual connection here: in_channels and out_channels generally differ.
        return self.main(x)

class HourGlassCNN(nn.Module):
    """Small encoder/decoder CNN with concatenation skip connections.

    Every block is a 3x3 conv with padding 1, so only the channel count
    changes: 3 -> 32 -> 64 -> 128 -> 64, then concat with the 64-channel
    encoder feature -> 64 -> 32, concat with the 32-channel encoder feature
    -> 32 -> 3, and finally a concat with the raw input feeding a sigmoid
    output block.

    Args:
        p: dropout probability forwarded to blocks b4-b9.
    """

    def __init__(self, p=0.3):
        super().__init__()
        # Encoder (creation order b1..b9 kept so parameter order is unchanged).
        self.b1 = HourGlassCNNBlock(3, 32)
        self.b2 = HourGlassCNNBlock(32, 64)
        self.b3 = HourGlassCNNBlock(64, 128)

        # Decoder. NOTE(review): p is passed but `dropout` stays at its default
        # (False), so none of these blocks contains a Dropout2d layer — confirm
        # whether dropout was meant to be enabled here.
        self.b4 = HourGlassCNNBlock(128, 64, p)
        self.b5 = HourGlassCNNBlock(64 + 64, 64, p)   # input: cat(b4, b2)
        self.b6 = HourGlassCNNBlock(64, 32, p)
        self.b7 = HourGlassCNNBlock(32 + 32, 32, p)   # input: cat(b6, b1)
        # Linear buffer block (no activation), then a sigmoid output block that
        # also sees the raw 3-channel input.
        self.b8 = HourGlassCNNBlock(32, 3, p, False, None)
        self.b9 = HourGlassCNNBlock(3 + 3, 3, p, False, nn.Sigmoid())

    def forward(self, x):
        enc1 = self.b1(x)
        enc2 = self.b2(enc1)
        dec = self.b4(self.b3(enc2))
        dec = self.b6(self.b5(torch.cat((dec, enc2), dim=1)))
        dec = self.b8(self.b7(torch.cat((dec, enc1), dim=1)))
        return self.b9(torch.cat((dec, x), dim=1))

class Denoiser(nn.Module):
    """Thin wrapper exposing ``HourGlassCNN`` as the denoising model.

    The wrapped network is stored as ``self.main`` (so state-dict keys carry
    the ``main.`` prefix); ``p`` is passed straight through to it.
    """

    def __init__(self, p=0.3):
        super().__init__()
        self.main = HourGlassCNN(p=p)

    def forward(self, x):
        net = self.main
        return net(x)

import math
from torch import tensor, cat, zeros, ones,  randperm, bernoulli, full
from torch.nn.functional import conv2d


# percent:percent of zeros in mask
def create_mask(channels=3, height=512, width=512, percent=0.2, probability=0.25, mode='bernoulli', device=None):
    """Build a random binary (0./1.) float mask of shape ``(channels, height, width)``.

    Args:
        channels, height, width: mask dimensions.
        percent: for ``mode='percent'``, fraction of entries set to 0; exactly
            ``floor(channels*height*width*percent)`` zeros are placed at random.
        probability: for ``mode='bernoulli'``, independent per-entry probability
            that an entry is 0 (entries are 1 with probability ``1-probability``).
        mode: ``'percent'`` or ``'bernoulli'``.
        device: device for the mask. ``None`` (default) picks ``'cuda'`` when
            available and ``'cpu'`` otherwise.

    Returns:
        Float tensor of 0s and 1s on ``device``.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    shape = (channels, height, width)
    if mode == 'percent':
        num = channels * height * width
        num_zeros = math.floor(num * percent)
        x = cat((zeros(num_zeros, device=device),
                 ones(num - num_zeros, device=device)))
        # Shuffle so the exact number of zeros lands at random positions.
        return x[randperm(num, device=device)].view(shape)
    if mode == 'bernoulli':
        return bernoulli(full(shape, 1 - probability, device=device))
    raise ValueError(f"unknown mask mode {mode!r}; expected 'percent' or 'bernoulli'")


def image_gradient(img, mode='scharr', device=None, eps=0.0):
    """Return the gradient magnitude ``sqrt(dx**2 + dy**2 + eps)`` of ``img``.

    Args:
        img: image tensor whose channel axis is the third-from-last dimension,
            i.e. ``(N, C, H, W)`` (or unbatched ``(C, H, W)``).
        mode: ``'weak'`` (forward differences), ``'sobel'`` (kernels scaled by
            1/8) or ``'scharr'`` (kernels scaled by 1/16).
        device: device for the kernels. ``None`` (default) uses ``img.device``.
        eps: added under the square root. The derivative of sqrt is infinite at
            0, so backpropagating through pixels with zero gradient gives NaN
            when ``eps == 0``; a small positive value (e.g. 1e-12) avoids that.
            The default 0.0 reproduces the plain magnitude exactly.

    Returns:
        Tensor of the same shape as ``img``. Each output channel holds the sum of
        the responses over all input channels, so all output channels are equal.

    Raises:
        ValueError: if ``mode`` is not recognised.
    """
    # mode -> (kernel_x, kernel_y, normalisation divisor)
    kernels = {
        'weak': ([[0., 0., 0.], [0., -1., 1.], [0., 0., 0.]],
                 [[0., 0., 0.], [0., -1., 0.], [0., 1., 0.]],
                 1.),
        'sobel': ([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]],
                  [[1., 2., 1.], [0., 0., 0.], [-1., -2., -1.]],
                  8.),
        'scharr': ([[-3., 0., 3.], [-10., 0., 10.], [-3., 0., 3.]],
                   [[-3., -10., -3.], [0., 0., 0.], [3., 10., 3.]],
                   16.),
    }
    if mode not in kernels:
        raise ValueError(f"unknown gradient mode {mode!r}; expected one of {sorted(kernels)}")
    kx, ky, divisor = kernels[mode]
    if device is None:
        device = img.device
    channels = img.shape[-3]

    def _weight(k):
        # (channels, channels, 3, 3): every output channel sums all input channels.
        return tensor(data=k, dtype=img.dtype, device=device).repeat(channels, channels, 1, 1) / divisor

    dx = conv2d(img, weight=_weight(kx), padding=1)
    dy = conv2d(img, weight=_weight(ky), padding=1)
    return (dx**2 + dy**2 + eps)**0.5


# 分析

# 数据

# 训练

In [3]:
from torchvision.transforms import ToTensor
from tqdm.notebook import trange
from torchvision.utils import save_image
from PIL import Image
def image_loader(image, device):
    """Convert ``image`` with ``ToTensor`` and return it as a batch of one on ``device``.

    A leading batch axis is added in front of ``ToTensor``'s output; no
    normalisation beyond ``ToTensor``'s own conversion is applied.
    """
    batched = ToTensor()(image)[None]
    return batched.to(device)

if __name__ == "__main__":
    # Self2Self training on a single noisy image: feed a masked copy of the
    # image and compute the loss only on the masked-out pixels.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print('using device:', device)

    model = Denoiser().to(device=device)
    # convert('RGB'): the network's first conv expects exactly 3 channels, so an
    # alpha channel is dropped / grayscale expanded. The with-block closes the file.
    with Image.open("data/self2self_pytorch/examples/noisy.png") as img:
        noisy = image_loader(img.convert('RGB'), device)
    # Masks must match the image's (C, H, W) and live on the same device,
    # instead of relying on create_mask's 3x512x512 defaults.
    mask_shape = tuple(noisy.shape[1:])

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    mse = nn.MSELoss(reduction='sum')
    os.makedirs("images", exist_ok=True)  # save_image does not create directories

    num_iters = 500000
    eval_every = 1000
    num_eval_samples = 100
    for itr in trange(num_iters):
        model.train()
        mask = create_mask(*mask_shape, device=device)
        mask_inv = 1 - mask
        out = model(noisy * mask)

        # Sum of squared errors on hidden pixels, averaged over their count.
        loss = mse(out * mask_inv, noisy * mask_inv) / mask_inv.sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (itr + 1) % eval_every == 0:
            with torch.no_grad():
                model.eval()
                # Average predictions over several random masks.
                out = torch.zeros_like(noisy)
                for _ in range(num_eval_samples):
                    mask = create_mask(*mask_shape, device=device)
                    out += model(noisy * mask)
                out /= num_eval_samples
                out = out.clip(0, 1)
                print("iteration %d, loss = %.4f" % (itr + 1, loss.item() * 100))
                save_image(out, "images/self2self-" + str(itr + 1) + ".png")

using device: cuda


  0%|          | 0/500000 [00:00<?, ?it/s]