# 简介
- [self2self](https://openaccess.thecvf.com/content_CVPR_2020/papers/Quan_Self2Self_With_Dropout_Learning_Self-Supervised_Denoising_From_Single_Image_CVPR_2020_paper.pdf)是利用单张图片就可以进行去噪
- 利用伯努利采样得到的图片作为去噪网络输入伯努利采样的补集作为预测目标
- 仅在伯努利采样补集上计算损失
- 在伯努利采样概率0.3下运行100000次效果比较好
- 论文中的网络结构使用了partial convolution和dropout，dropout比较关键

# 环境

In [1]:
import torch
import torchmetrics
import torchvision
import os
import gc
import torch.nn as nn
from torchvision import transforms

print(torch.__version__)
device='cuda' if torch.cuda.is_available() else 'cpu'
print(device)
if device=='cuda':
    gc.collect()
    torch.cuda.empty_cache()
print(os.getcwd())


1.11.0+cu113
cuda
/root/autodl-tmp/deep-learning/unsupervised denoise


# 模型

In [2]:
class HourGlassCNNBlock(nn.Module):
    def __init__(self,in_channels,out_channels,size=128,layernorm=False,batchnorm=False,dropout=False,activation=True,dilation=1,padding=1,p=0.3):
        super(HourGlassCNNBlock,self).__init__()
        layers=[]
        layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=padding,dilation=dilation))
        if layernorm:layers.append(nn.LayerNorm([out_channels,size,size]))
        if batchnorm:layers.append(nn.BatchNorm2d(out_channels))
        if activation:layers.append(nn.SiLU())
        if dropout:layers.append(nn.Dropout2d(0.3))
        self.main=nn.Sequential(*layers)
    def forward(self,x):
        return self.main(x)#channel不同不能进行残差学习

class HourGlassCNN(nn.Module):
    def __init__(self,size=128,layernorm=False,batchnorm=False,max_channels=512,residual=True,dropout=True,p=0.3):
        super(HourGlassCNN,self).__init__()
        self.blocks=nn.ModuleList()
        self.blocks.append( HourGlassCNNBlock(3,64,size,layernorm=False,batchnorm=False,p=p)),#0
        channel=128
        while channel<=max_channels:
            self.blocks.append(HourGlassCNNBlock(channel//2,channel,size,layernorm,batchnorm,dropout=dropout,p=p))
            channel*=2
        channel//=2
        while channel>64:
            self.blocks.append(HourGlassCNNBlock(channel,channel//2,size,layernorm,batchnorm,dropout=dropout,p=p))
            channel//=2
        self.blocks.append(HourGlassCNNBlock(64,3,size,layernorm=False,batchnorm=False,activation=False,dropout=dropout,p=p))
        self.residual=residual
    def forward(self,x):
        blocks_len=len(self.blocks)
        y=[]#blocks[i]的输出是y[i]
        y.append(self.blocks[0](x))#(3,64) 64,128 128,256 256,128 128,64 64,3
        for i in range(1,blocks_len//2):y.append(self.blocks[i](y[i-1]))
        for i in range(blocks_len//2,blocks_len-1):y.append(self.blocks[i](y[i-1])+y[blocks_len-i-2]if self.residual else self.blocks[i](y[i-1]))
        return self.blocks[-1](y[-1])


class Denoiser(nn.Module):
    def __init__(self,size=128,layernorm=False,batchnorm=False,max_channels=512,p=0.3):
        super(Denoiser,self).__init__()
        self.main=HourGlassCNN(size,layernorm,batchnorm,max_channels,p=p)
    def forward(self,x):
        return self.main(x)+x


# 分析

In [3]:
import torch.utils.tensorboard
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm.notebook import tqdm
from torchvision.utils import make_grid
from matplotlib import pyplot as plt
from torchmetrics import PeakSignalNoiseRatio
from torchmetrics import StructuralSimilarityIndexMeasure

In [4]:
def convert_to_rgb255(image:torch.Tensor):
    image = (image + 1) / 2
    image[image < 0] = 0
    image[image > 1] = 1
    return image
def show_image(image:torch.Tensor):
    image=convert_to_rgb255(image)
    plt.imshow(transforms.ToPILImage()(image))
    plt.show()

#writer=SummaryWriter()
psnr=PeakSignalNoiseRatio().to(device=device)
ssim=StructuralSimilarityIndexMeasure().to(device=device)#负数对应颜色反转


# 数据

In [5]:

from torch.utils.data import Dataset 
from torch.utils.data import DataLoader

from PIL import Image
def load_image():
    img = Image.open('data/self2self_pytorch-main/examples/noisy.png').convert('RGB')#比torchvision的好
    transform = transforms.Compose([
            # hwc->chw 并 归一化到[0,1]
            transforms.ToTensor(),
            # [−1,1]
            transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
    ])
    img = transform(img)
    return img
noisy=load_image().to(device=device)


# 训练

In [6]:
from torch.optim import Adam
from torch.nn import SmoothL1Loss

model = Denoiser().to(device)
optimizer = Adam(model.parameters() , lr=1e-3)
p=0.3
smooth_l1_loss=SmoothL1Loss(beta=0.01,reduction='sum')

In [7]:
import utils
def create_mask():
    return utils.create_mask(width=512,height=512,probability=p,mode='bernoulli',device=device)

In [8]:
# noise as clean：noisy-P'->noisy' 如果clean-P->noisy中的P近似于P'那么(noisy',noisy)训练可以近似复原噪声图片
from torch import log, var, zeros_like,zeros,abs,std,mean,square,div

from torch.nn.functional import relu,avg_pool2d,conv2d
from torchvision.utils import save_image

def train():
    optimizer.zero_grad()
    mask=create_mask()
    mask_inv=1-mask
    #只计算mask区域的loss
    loss=smooth_l1_loss(model(noisy*mask)*mask_inv,noisy*mask_inv)/mask_inv.sum()
    #clean=model(noisy)  
    #梯度有多种使用方法：1.使noise梯度和noisy同再去除低频分量 2.使clean和noise梯度尽可能大
    #loss2=smooth_l1_loss(image_gradient(noisy,device=device),image_gradient(noisy-clean,device=device))
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        return loss.sum()

def main(): 
    for epoch in tqdm(range(200000)):
        print(train())
        if (epoch+1)%1000==0:
            result=model(noisy*create_mask())
            for i in range(99):
                result+=create_mask()
            result/=100
            save_image(tensor=result,fp=f'result/images/result-{(epoch+1)/1000}.png')#必须输入标准化了的图像
main()

  0%|          | 0/200000 [00:00<?, ?it/s]

tensor(0.5192, device='cuda:0')
tensor(0.5148, device='cuda:0')
tensor(0.4475, device='cuda:0')
tensor(1.0333, device='cuda:0')
tensor(0.4470, device='cuda:0')
tensor(0.4103, device='cuda:0')
tensor(0.4810, device='cuda:0')
tensor(0.3809, device='cuda:0')
tensor(0.4873, device='cuda:0')
tensor(0.4361, device='cuda:0')
tensor(0.4247, device='cuda:0')
tensor(2.3756, device='cuda:0')
tensor(0.4294, device='cuda:0')
tensor(0.6068, device='cuda:0')
tensor(0.7890, device='cuda:0')
tensor(1.4488, device='cuda:0')
tensor(0.7028, device='cuda:0')
tensor(0.4259, device='cuda:0')
tensor(0.4928, device='cuda:0')
tensor(1.1319, device='cuda:0')
tensor(3.1626, device='cuda:0')
tensor(0.7069, device='cuda:0')
tensor(1.0641, device='cuda:0')
tensor(1.2858, device='cuda:0')
tensor(0.5327, device='cuda:0')
tensor(0.5520, device='cuda:0')
tensor(1.1524, device='cuda:0')
tensor(2.2862, device='cuda:0')
tensor(0.4459, device='cuda:0')
tensor(2.0673, device='cuda:0')
tensor(0.3979, device='cuda:0')
tensor(0