# 简介
- [self2self](https://openaccess.thecvf.com/content_CVPR_2020/papers/Quan_Self2Self_With_Dropout_Learning_Self-Supervised_Denoising_From_Single_Image_CVPR_2020_paper.pdf)是一种仅利用单张图片就可以进行去噪的方法
- 利用伯努利采样得到的图片作为去噪网络的输入，伯努利采样的补集作为预测目标
- 仅在伯努利采样补集上计算损失
- 在伯努利采样概率0.3下运行100000次效果比较好
- 论文中的网络结构使用了partial convolution和dropout，dropout比较关键
- 掩码必须每个channel都不同效果才最好
- 图片输入时调用toTensor即可，不要进行标准化
- 不要使用randomFlip，不要使用多张图片，否则会破坏分布
- 如何腐蚀noise图像是一个需要探讨的话题
- L1 loss没有L2 loss抗噪，很容易出现彩块
- 学习率不能过大，否则会很快学到噪声，并且adamw会梯度爆炸
- 从雪坑的质感看，网络学到的近似于溶解
- 加式残差学习在监督学习中效果很好，在无监督学习中起负面作用，直接加法的特点是高频特征学得很快，建议用拼接代替加法实现残差。
- 神经网络对噪声的阻抗来自于伯努利抽样、池化层、cat操作
- 采样概率过小或者过大都会加剧模糊
- 第一层和最后一层的通道数对结果影响非常大，建议输入通道数至少达到64
- 不用sigmoid函数图像会暗很多，清晰度也下降，但不管怎样，去噪后图片亮度都会偏暗
- 网络越深色偏越大

# 环境

In [None]:
import torch
import torchmetrics
import torchvision
import os
import gc
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
from tqdm.notebook import tnrange

# Report the PyTorch build in use.
print(torch.__version__)

# Select the compute device: CUDA when available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# On a GPU run, collect Python garbage first and then release PyTorch's
# cached CUDA memory before any model is built.
if device == 'cuda':
    gc.collect()
    torch.cuda.empty_cache()

# Relative paths used later ('data/...', 'result/...') resolve against this.
print(os.getcwd())


# 模型

In [None]:

class HourGlassCNNBlock(nn.Module):
    """Basic stage of the hourglass network: Conv1d -> Dropout1d -> SiLU.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by the convolution.
        p: probability handed to ``nn.Dropout1d``.
        device: device on which the convolution parameters are created.
    """

    def __init__(self, in_channels, out_channels, p=0.3, device='cuda'):
        super().__init__()
        # kernel_size=3 with padding=1 keeps the length axis unchanged.
        conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
            device=device,
        )
        self.main = nn.Sequential(conv, nn.Dropout1d(p), nn.SiLU())

    def forward(self, x):
        """Run ``x`` through the stage and return the activated result."""
        # No additive residual here: input and output channel counts may differ.
        return self.main(x)

class HourGlassCNN(nn.Module):
    """Hourglass-shaped stack of 1-D convolution stages with two concatenation skips.

    Channel widths per stage (the length axis is never changed, because every
    convolution uses kernel_size=3 with padding=1 and there is no pooling):

        b1 : 478*3 -> 2**12
        b2 : 2**12 -> 2**12
        b3 : 2**12 -> 2**13
        b4 : 2**13 -> 2**14
        b5 : 2**14 -> 2**13
        b6 : cat(b5, b3) = 2**14 -> 2**13
        b7 : 2**13 -> 2**12
        b8 : cat(b7, b1) = 2**13 -> 2**12
        b9 : 2**12 -> 2**12
        b10: 2**12 -> 478*3, followed by Dropout1d and Sigmoid

    Args:
        p: dropout probability used by every stage, including the output stage.
        device: device on which all convolution parameters are created.

    NOTE(review): 478*3 presumably means 478 points x 3 coordinates - confirm
    against the data source.
    """

    def __init__(self, p=0.3, device='cuda'):
        super().__init__()
        # Widening path.
        self.b1 = HourGlassCNNBlock(478*3, 2**12, p=p, device=device)
        self.b2 = HourGlassCNNBlock(2**12, 2**12, p=p, device=device)
        self.b3 = HourGlassCNNBlock(2**12, 2**13, p=p, device=device)
        self.b4 = HourGlassCNNBlock(2**13, 2**14, p=p, device=device)
        self.b5 = HourGlassCNNBlock(2**14, 2**13, p=p, device=device)

        # Narrowing path; b6 and b8 take a skip tensor concatenated on the
        # channel axis, hence their doubled input width.
        self.b6 = HourGlassCNNBlock(2**14, 2**13, p=p, device=device)
        self.b7 = HourGlassCNNBlock(2**13, 2**12, p=p, device=device)
        self.b8 = HourGlassCNNBlock(2**13, 2**12, p=p, device=device)
        self.b9 = HourGlassCNNBlock(2**12, 2**12, p=p, device=device)
        # Output stage: the sigmoid bounds the result to (0, 1).
        self.b10 = nn.Sequential(
            nn.Conv1d(2**12, 478*3, kernel_size=3, padding=1, device=device),
            nn.Dropout1d(p),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: tensor laid out as (N, 478*3, L) or, unbatched, (478*3, L).

        Returns:
            Tensor with the same layout as ``x`` and values in (0, 1).
        """
        b1 = self.b1(x)
        b2 = self.b2(b1)
        b3 = self.b3(b2)
        b4 = self.b4(b3)
        b5 = self.b5(b4)
        # dim=-2 is the channel axis for both (N, C, L) and unbatched (C, L)
        # inputs. A fixed dim=1 would concatenate along the length axis when
        # the input has no batch dimension and break the next convolution.
        b6 = self.b6(torch.cat((b5, b3), dim=-2))
        b7 = self.b7(b6)
        b8 = self.b8(torch.cat((b7, b1), dim=-2))
        b9 = self.b9(b8)
        return self.b10(b9)

class Denoiser(nn.Module):
    """Top-level denoising model; delegates all work to an ``HourGlassCNN``.

    Args:
        p: dropout probability forwarded to the backbone.
        device: device forwarded to the backbone for parameter creation.
    """

    def __init__(self, p=0.3, device='cuda'):
        super().__init__()
        backbone = HourGlassCNN(p=p, device=device)
        self.main = backbone

    def forward(self, x):
        """Return the backbone's output for ``x``."""
        denoised = self.main(x)
        return denoised

import math
from torch import tensor, cat, zeros, ones,  randperm, bernoulli, full


# percent: fraction of zeros in the mask (only used when mode == 'percent')
def create_mask(channels=478*3, length=30*10, percent=0.25, probability=0.25, mode='bernoulli', device='cuda'):
    """Create a random 0/1 mask of shape ``(channels, length)``.

    Args:
        channels: size of the first mask axis.
        length: size of the second mask axis.
        percent: fraction of entries set to 0 in ``'percent'`` mode; the number
            of zeros is ``floor(channels * length * percent)``.
        probability: per-entry probability of a 0 in ``'bernoulli'`` mode.
        mode: ``'percent'`` for an exact zero count at shuffled positions, or
            ``'bernoulli'`` for independent per-entry sampling.
        device: device on which the mask is created.

    Returns:
        A float tensor of zeros and ones with shape ``(channels, length)``.

    Raises:
        ValueError: if ``mode`` is neither ``'percent'`` nor ``'bernoulli'``.
    """
    if mode == 'percent':
        num = channels * length
        num_zeros = math.floor(num * percent)
        num_ones = num - num_zeros
        values = cat((zeros(num_zeros, device=device), ones(num_ones, device=device)))
        # Build the permutation on the same device as the values so the
        # shuffle needs no cross-device index.
        shuffled = values[randperm(num, device=device)]
        return shuffled.view((channels, length))
    if mode == 'bernoulli':
        # Each entry is 1 with probability (1 - probability), i.e. kept.
        return bernoulli(full((channels, length), 1 - probability, device=device))
    # Previously an unknown mode fell through and silently returned None.
    raise ValueError(f"unknown mask mode: {mode!r} (expected 'percent' or 'bernoulli')")


# 分析

In [None]:
import torch.utils.tensorboard
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm.notebook import tqdm
from torchvision.utils import make_grid

# 数据

In [None]:
import json
from torch import tensor
# Read the JSON-encoded noisy signal from disk.
with open('data/2.txt','r') as fp:
    raw = json.load(fp)
# Merge axes 1 and 2, then swap the two remaining axes so that the merged
# axis comes first.
noisy = tensor(raw, device=device).flatten(1, 2).permute([1, 0])

# 训练

In [None]:
import json
import os

from torch import optim,zeros

model = Denoiser(device=device).to(device=device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
mse = nn.MSELoss(reduction='sum')

# `noisy` is 2-D after the flatten/permute done at load time. Give it an
# explicit leading batch axis so the network sees the (N, C, L) Conv1d layout;
# the averaged output buffer below already has that leading axis.
n_channels, n_steps = noisy.shape
noisy_batch = noisy.unsqueeze(0)

# Snapshots are written below; make sure the target directory exists.
os.makedirs('result', exist_ok=True)

for itr in tnrange(500000):
    model.train()
    # Size the mask from the data and create it on the active device instead
    # of relying on create_mask's hard-coded defaults.
    mask = create_mask(channels=n_channels, length=n_steps, device=device)
    mask_inv = 1 - mask

    # The network sees only the kept entries; the loss is evaluated only on
    # the dropped entries and normalised by how many of them there are.
    loss = mse(model(noisy_batch * mask) * mask_inv, noisy_batch * mask_inv) / mask_inv.sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (itr + 1) % 1000 == 0:
        with torch.no_grad():
            # eval() switches the dropout layers off, so only the input masks
            # differ between the 100 passes that are averaged here.
            model.eval()
            out = zeros((1, n_channels, n_steps), device=device)
            for j in range(100):
                mask = create_mask(channels=n_channels, length=n_steps, device=device)
                out += model(noisy_batch * mask)
            out /= 100
            print("iteration %d, loss = %.4f" % (itr+1, loss.item()*100))
            # json.dumps(obj, path) raises TypeError and never writes a file;
            # open the target and stream the JSON into it instead.
            with open(f"result/{str(itr+1)}.txt", 'w') as fp:
                json.dump(out.tolist(), fp)