# 简介
- [self2self](https://openaccess.thecvf.com/content_CVPR_2020/papers/Quan_Self2Self_With_Dropout_Learning_Self-Supervised_Denoising_From_Single_Image_CVPR_2020_paper.pdf)是一种仅利用单张图片即可完成去噪的自监督方法
- 利用伯努利采样得到的图片作为去噪网络输入，伯努利采样的补集作为预测目标
- 仅在伯努利采样补集上计算损失
- 在伯努利采样概率0.3下运行100000次效果比较好
- 论文中的网络结构使用了partial convolution和dropout，dropout比较关键
- 掩码必须每个channel都不同效果才最好
- 图片输入时调用toTensor即可，不要进行标准化
- 不要使用randomFlip，不要使用多张图片，否则会破坏分布
- 如何腐蚀noise图像是一个需要探讨的话题
- L1 loss没有L2 loss抗噪，很容易出现彩块
- 学习率不能过大，否则会很快学到噪声，并且adamw会梯度爆炸
- 从雪坑的质感看，网络学到的近似于溶解
- 加式残差学习在监督学习中效果很好，在无监督学习中起负面作用，直接加法的特点是高频特征学得很快，建议用拼接代替加法实现残差。
- 神经网络对噪声的阻抗来自于伯努利抽样、池化层、cat操作
- 采样概率过小或者过大都会加剧模糊
- 第一层和最后一层的通道数对结果影响非常大，建议输入通道数至少达到64
- 不用sigmoid函数图像会暗很多，清晰度也下降；但不管怎样，去噪后图片亮度都会偏暗
- 网络越深色偏越大

# 环境

In [1]:
import torch
import torchvision
import os
import gc
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
from tqdm.notebook import tnrange

# Report the PyTorch build, pick the compute device and show the working dir.
print(torch.__version__)
if torch.cuda.is_available():
    device = 'cuda'
    print(device)
    # Drop unreachable Python objects first, then release cached GPU memory.
    gc.collect()
    torch.cuda.empty_cache()
else:
    device = 'cpu'
    print(device)
print(os.getcwd())


1.11.0+cu113
cuda
/root/autodl-tmp/facial landmarks denoise


# 数据

In [2]:
import json
from torch import tensor
# Load the recorded samples from JSON straight into a tensor on `device`.
# NOTE(review): the indexing below requires `data` to be 3-D with at least
# three entries on the last axis; presumably (frames, landmarks, xyz) - confirm.
with open('data/4.txt', 'r') as fp:
    data = tensor(json.load(fp), device=device)

# First two components of every point, merged into one axis and then moved
# to the front: result is (dim1 * 2, dim0).
noisyxy = data[:, :, :2].flatten(1, 2).transpose(0, 1)
# Third component only, with the two remaining axes swapped: (dim1, dim0).
noisyz = data[:, :, 2].transpose(0, 1)

# 模型

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim

from torchvision.utils import save_image


from PIL import Image
from tqdm.notebook import trange


class HourGlassCNNBlock(nn.Module):
    """A 1-D convolution (kernel 3, padding 1) with an optional activation.

    With the default stride of 1, the kernel/padding combination keeps the
    length of the last axis unchanged.

    Args:
        in_channels: Number of channels expected by the convolution.
        out_channels: Number of channels produced by the convolution.
        activation: Module appended after the convolution, or ``None`` for a
            purely linear block. The default ``LeakyReLU(0.1)`` instance is
            created once when the class is defined and is shared by every
            block that relies on the default.
        device: Device on which the convolution parameters are created.
            ``None`` (default) selects ``'cuda'`` when it is available and
            falls back to ``'cpu'`` otherwise.
    """

    def __init__(self, in_channels, out_channels, activation=nn.LeakyReLU(0.1), device=None):
        super().__init__()
        if device is None:
            # Previously 'cuda' was hard-coded here, which made the block
            # impossible to construct on a CPU-only machine.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        layers = [
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
                device=device,
            )
        ]
        # Explicit None check: do not rely on the truthiness of a module.
        if activation is not None:
            layers.append(activation)
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolution and, if configured, the activation to ``x``."""
        # Input and output channel counts differ, so an additive residual
        # connection is not possible here.
        return self.main(x)

class XYDenoiser(nn.Module):
    """Six-block convolutional network over 478*2 input channels.

    The channel count widens (956 -> 1024 -> 2048) and narrows again; the
    outputs of earlier stages are re-injected by concatenation along dim 0.

    NOTE(review): concatenating along dim 0 only lines up with the channel
    widths below when the input is unbatched, i.e. shaped (channels, length)
    - confirm against callers.

    Args:
        noisyxy: Stored on the instance as ``self.noisyxy``; it is not used
            by ``forward``.
    """

    def __init__(self, noisyxy):
        super().__init__()
        io_width = 478 * 2
        mid_width = 2 ** 10
        wide_width = 2 ** 11

        # Widening half.
        self.b1 = HourGlassCNNBlock(io_width, mid_width)
        self.b2 = HourGlassCNNBlock(mid_width, wide_width)

        # Narrowing half; b4 and b6 take concatenated inputs, hence the
        # doubled input widths.
        self.b3 = HourGlassCNNBlock(wide_width, mid_width)
        self.b4 = HourGlassCNNBlock(2 * mid_width, mid_width)
        self.b5 = HourGlassCNNBlock(mid_width, io_width, None)
        self.b6 = HourGlassCNNBlock(2 * io_width, io_width, nn.Sigmoid())

        self.noisyxy = noisyxy

    def forward(self, x):
        """Run ``x`` through the six blocks and return the sigmoid output."""
        wide1 = self.b1(x)
        wide2 = self.b2(wide1)
        narrow = self.b3(wide2)
        # Skip connection: join with the first block's output along dim 0.
        narrow = self.b4(torch.cat((narrow, wide1), dim=0))
        narrow = self.b5(narrow)
        # Second skip connection: join with the raw input along dim 0.
        result = self.b6(torch.cat((narrow, x), dim=0))
        # Drops dim 0 only when it has size 1; otherwise a no-op.
        return result.squeeze(0)
class ZDenoiser(nn.Module):
    """Six-block convolutional network over 478 input channels.

    Same layout as the x/y network but half as wide, with no final
    activation, and with the raw input also fed into the fourth block.

    NOTE(review): the concatenations along dim 0 only match the channel
    widths below for unbatched input shaped (channels, length) - confirm
    against callers.
    """

    def __init__(self):
        super().__init__()
        io_width = 478
        hidden = 2 ** 9

        # Widening half.
        self.b1 = HourGlassCNNBlock(io_width, hidden)
        self.b2 = HourGlassCNNBlock(hidden, 2 * hidden)

        # Narrowing half; b4 receives b3, b1 and the raw input concatenated,
        # b6 receives b5 and the raw input concatenated.
        self.b3 = HourGlassCNNBlock(2 * hidden, hidden)
        self.b4 = HourGlassCNNBlock(2 * hidden + io_width, hidden)
        self.b5 = HourGlassCNNBlock(hidden, io_width, None)
        self.b6 = HourGlassCNNBlock(2 * io_width, io_width, None)

    def forward(self, x):
        """Run ``x`` through the six blocks and return the linear output."""
        wide1 = self.b1(x)
        wide2 = self.b2(wide1)
        narrow = self.b3(wide2)
        stacked = torch.cat((narrow, wide1, x), dim=0)
        narrow = self.b5(self.b4(stacked))
        result = self.b6(torch.cat((narrow, x), dim=0))
        # Drops dim 0 only when it has size 1; otherwise a no-op.
        return result.squeeze(0)
import math
from torch import tensor, cat, zeros, ones,  randperm, bernoulli, full

# percent: fraction of zeros in the mask (only used when mode='percent')
def create_mask(channels=478*2, length=30*30,  percent=0.2, probability=0.25, mode='bernoulli', device='cuda'):
    """Return a random 0/1 mask of shape ``(channels, length)``.

    Args:
        channels: Size of the first mask dimension.
        length: Size of the second mask dimension.
        percent: Fraction of entries set to zero in ``'percent'`` mode. The
            number of zeros is ``floor(channels * length * percent)``, so it
            is exact rather than random.
        probability: Chance that an entry is zero in ``'bernoulli'`` mode;
            each entry is drawn independently and is 1 with probability
            ``1 - probability``.
        mode: ``'bernoulli'`` or ``'percent'``.
        device: Device on which the mask is created.

    Returns:
        A float tensor containing only zeros and ones.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    if mode == 'percent':
        total = length * channels
        num_zeros = math.floor(total * percent)
        # Build the exact number of zeros and ones, then shuffle them.
        flat = cat((zeros(num_zeros, device=device), ones(total - num_zeros, device=device)))
        return flat[randperm(total)].view((channels, length))
    if mode == 'bernoulli':
        return bernoulli(full((channels, length), 1 - probability, device=device))
    # Previously an unknown mode silently returned None and only failed later
    # at the first arithmetic use of the "mask".
    raise ValueError(f"unknown mask mode {mode!r}; expected 'percent' or 'bernoulli'")


# 分析

# 训练

In [4]:
from torch import optim,zeros
# Instantiate both denoisers and move them to the selected device.
xymodel = XYDenoiser(noisyxy=noisyxy).to(device=device)
zmodel = ZDenoiser().to(device=device)

# One RAdam optimizer per model; the z model uses half the x/y learning rate.
xyoptimizer = optim.RAdam(params=xymodel.parameters(), lr=5*1e-5)
zoptimizer = optim.RAdam(params=zmodel.parameters(), lr=2.5*1e-5)

# Summed (not averaged) squared error; callers normalise it themselves.
mse = nn.MSELoss(reduction='sum')

In [5]:

# Train both denoisers on randomly masked copies of the noisy data and, every
# 1000 iterations, write the current denoised estimate to result/<itr>.txt.

# The snapshot directory may not exist yet; open() would fail without it.
os.makedirs("result", exist_ok=True)

for itr in tnrange(500000):
    # ---- z model: predict the hidden entries from the visible ones ----
    zmodel.train()
    # Use the notebook's device instead of create_mask's 'cuda' default.
    mask = create_mask(channels=478, device=device)
    mask_inv = 1 - mask
    out = zmodel(noisyz * mask)
    # Squared error restricted to the masked-out entries, divided by their count.
    loss = mse(out * mask_inv, noisyz * mask_inv) / mask_inv.sum()
    zoptimizer.zero_grad()
    loss.backward()
    zoptimizer.step()

    # ---- x/y model ----
    xymodel.train()
    mask = create_mask(device=device)
    masked_xy = noisyxy * mask  # renamed from `input`, which shadowed the builtin
    out = xymodel(masked_xy)
    # NOTE(review): unlike the z branch, this loss compares the output with
    # the *masked input* over all entries instead of with the original data
    # on the masked-out entries only. Left unchanged because it may be a
    # deliberate experiment - confirm.
    loss = mse(out, masked_xy) / mask.sum()
    xyoptimizer.zero_grad()
    loss.backward()
    xyoptimizer.step()

    if (itr + 1) % 1000 == 0:
        with torch.no_grad():
            xymodel.eval()
            zmodel.eval()
            # x/y estimate: a single pass over the unmasked data.
            outxy = xymodel(noisyxy).permute([1, 0]).unflatten(1, (478, 2))
            # z estimate: average of 100 passes, each with a fresh random mask.
            outz = zeros((30*30, 478, 1), device=device)
            for _ in range(100):
                mask = create_mask(channels=478, device=device)
                outz += zmodel(noisyz * mask).permute([1, 0]).unflatten(1, (478, 1))
            outz /= 100
            out = torch.cat((outxy, outz), dim=2)
            # `loss` here is the x/y loss from this iteration (the last one computed).
            print("iteration %d, loss = %.4f" % (itr+1, loss.item()*100000))
            with open(f"result/{str(itr+1)}.txt", 'w') as fp:
                # json.dump streams to the file; writelines() on a string
                # wrote it one character at a time.
                json.dump(out.tolist(), fp)

  0%|          | 0/500000 [00:00<?, ?it/s]

iteration 1000, loss = 9193.6320
iteration 2000, loss = 7845.5381
iteration 3000, loss = 6558.2179
iteration 4000, loss = 5617.9442
iteration 5000, loss = 4874.5330
iteration 6000, loss = 4302.8593
iteration 7000, loss = 3802.5104
iteration 8000, loss = 3389.6826
iteration 9000, loss = 3036.1610
iteration 10000, loss = 2723.6650
iteration 11000, loss = 2436.3210
iteration 12000, loss = 2203.0767
iteration 13000, loss = 1992.2141
iteration 14000, loss = 1796.0476
iteration 15000, loss = 1610.9353
iteration 16000, loss = 1465.5154
iteration 17000, loss = 1336.7856
iteration 18000, loss = 1207.4728
iteration 19000, loss = 1123.1512
iteration 20000, loss = 1003.5913
iteration 21000, loss = 913.7438
iteration 22000, loss = 835.4478
iteration 23000, loss = 765.1262
iteration 24000, loss = 713.9108
iteration 25000, loss = 660.2539
iteration 26000, loss = 608.2462
iteration 27000, loss = 567.1561
iteration 28000, loss = 529.0935
iteration 29000, loss = 894.2861
iteration 30000, loss = 459.8160

KeyboardInterrupt: 