# 环境

In [12]:
import torch
import torchmetrics
import torchvision
import os
import gc
# Record the PyTorch build in the cell output so the run environment is documented.
print(torch.__version__)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# On GPU: collect Python garbage first, then ask PyTorch to release its cached CUDA memory.
if device != 'cpu':
    gc.collect()
    torch.cuda.empty_cache()

# Relative data paths used later in the notebook are resolved against this directory.
print(os.getcwd())


1.11.0+cu113
cuda
/root/autodl-tmp/deep-learning/DNCNN


In [13]:
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.utils.tensorboard
from torch.utils.tensorboard.writer import SummaryWriter
from torchvision.utils import make_grid
from tqdm.notebook import tqdm
from torchmetrics import PeakSignalNoiseRatio
import torch.nn.functional
from torch.optim import Adam

# 分析

In [14]:
def convert_to_rgb255(image:torch.Tensor):
    """Rescale a tensor from the [-1, 1] range to [0, 1] and clip anything outside.

    Despite the name, the returned values lie in [0, 1], not [0, 255].
    The input tensor is left untouched; a new tensor is returned.
    """
    rescaled = image.add(1).div(2)
    # Values that overshoot the valid range are saturated to 0 or 1.
    return rescaled.clamp(min=0, max=1)
def show_image(image:torch.Tensor):
    """Display an image tensor with matplotlib.

    The tensor is first mapped to [0, 1] by ``convert_to_rgb255`` and then
    converted to a PIL image for plotting.
    NOTE(review): presumably expects a single channel-first (C, H, W) image
    normalised to [-1, 1] -- confirm against callers.
    """
    to_pil = transforms.ToPILImage()
    plt.imshow(to_pil(convert_to_rgb255(image)))
    plt.show()

# TensorBoard writer; created without arguments, so it logs to SummaryWriter's default directory.
writer = SummaryWriter()
# PSNR metric placed on the same device as the model/data.
# NOTE(review): no explicit data_range is given although images are normalised to [-1, 1] -- confirm
# that the metric's default range handling is what is wanted.
psnr = PeakSignalNoiseRatio().to(device)

# 数据
标准差为 30 的高斯噪声；180×180 裁剪（左上角或随机位置）。注：当前代码实现默认使用 128×128 的中心裁剪。

In [15]:
class BSDSPairsDataset(Dataset):
    """Dataset of (noisy, clean) image pairs built from a folder of images.

    Every regular file in ``root_dir/mode`` is loaded as RGB, center-cropped
    to ``image_size``, scaled to [-1, 1] and paired with a copy corrupted by
    additive Gaussian noise. The noise is re-sampled on every access.

    Args:
        root_dir: directory that contains one sub-directory per split.
        mode: name of the split sub-directory to read (default ``'train'``).
        image_size: crop size passed to ``CenterCrop``.
        sigma: noise standard deviation expressed on the 0-255 intensity scale.
    """

    def __init__(self, root_dir, mode='train', image_size=(128, 128), sigma=30):
        super(BSDSPairsDataset, self).__init__()
        self.mode = mode
        self.image_size = image_size
        self.sigma = sigma
        self.images_dir = os.path.join(root_dir, mode)
        # Sort for a reproducible index -> file mapping (os.listdir order is
        # arbitrary) and skip directories such as '.ipynb_checkpoints', which
        # cannot be opened as images.
        self.files = sorted(
            name for name in os.listdir(self.images_dir)
            if os.path.isfile(os.path.join(self.images_dir, name))
        )

    def __len__(self):
        """Return the number of image files in the split."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(noisy, clean)`` tensors for the image at position ``idx``."""
        img_path = os.path.join(self.images_dir, self.files[idx])
        # Author's note: PIL is preferred here over torchvision's own image reader.
        # The context manager closes the underlying file once the RGB copy exists.
        with Image.open(img_path) as img:
            clean = img.convert('RGB')
        transform = transforms.Compose([
            # Honour the configured crop size instead of a hard-coded 128.
            transforms.CenterCrop(self.image_size),
            # HWC -> CHW and scale to [0, 1]
            transforms.ToTensor(),
            # [0, 1] -> [-1, 1]
            transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
            ])
        clean = transform(clean)

        # After normalisation one 8-bit grey level spans 2/255, so a standard
        # deviation of `sigma` grey levels corresponds to 2*sigma/255 here.
        noise_std = 2 * self.sigma / 255
        noisy = clean + noise_std * torch.randn_like(clean)
        return noisy, clean

In [16]:
# Datasets for the two splits; both rely on the dataset's default crop size and noise level.
train_set = BSDSPairsDataset('data/BSDS300/images/')
# NOTE(review): the test split is read from a different root ('128x128') than the training split -- confirm this is intended.
test_set = BSDSPairsDataset('data/BSDS300/128x128/', mode='test')

# Shuffle only the training data; the evaluation loader keeps a fixed order so
# that per-batch results and logged images are comparable between runs.
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
test_loader = DataLoader(test_set, batch_size=8, shuffle=False)

# 模型

In [17]:
class DnCNN(nn.Module):
    """Residual denoising CNN built from 3x3 convolutions.

    Layout: Conv+SiLU, then ``deep - 2`` blocks of Conv+BatchNorm+SiLU, then a
    final Conv back to 3 channels (``deep`` layers in total when ``deep >= 2``).
    The stack's output is added to the input (residual learning).

    Args:
        deep: nominal number of layers.
        channel: number of feature channels in the hidden layers.
    """
    def __init__(self, deep, channel=64):
        super(DnCNN,self).__init__()
        self.deep=deep
        self.channel=channel

        self.layers = nn.ModuleList()
        # Input layer: no normalisation.
        self.layers.append(nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=channel, kernel_size=3, padding=1),
            nn.SiLU()
        ))
        # Hidden layers.
        for _ in range(deep-2):
            self.layers.append(nn.Sequential(
                nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=3, padding=1),
                nn.BatchNorm2d(channel),
                nn.SiLU()
            ))
        # Output layer: plain convolution, no activation.
        self.layers.append(nn.Conv2d(in_channels=channel, out_channels=3, kernel_size=3, padding=1))
    def forward(self, x):
        """Return the denoised image: the input plus the correction predicted by the layer stack."""
        y = x
        for layer in self.layers:
            y = layer(y)
        # Residual learning: predicting the noise is easier than predicting the clean image directly.
        return y + x
    def loss(self,input,target,alpha=0.95,beta=0.05)->torch.Tensor:
        """Weighted sum of a reconstruction loss and a layer-wise feature-similarity loss.

        Args:
            input: noisy batch fed to the network.
            target: clean batch.
            alpha: weight of the reconstruction term (smooth L1 between output and target).
            beta: weight of the perceptual term, which is averaged over ``self.deep``.

        Returns:
            Scalar loss tensor.
        """
        # Perceptual (feature-similarity) loss. It is only a correction / regulariser,
        # not the objective: different inputs cannot produce identical activations at
        # every layer, so its weight should be kept as small as possible.
        # NOTE(review): the target is also pushed through the layers, so in training
        # mode the BatchNorm layers see clean-image activations as well -- confirm this is intended.
        perceptual_loss = None
        x, y = input, target
        for layer in self.layers:
            x = layer(x)
            y = layer(y)
            layer_loss = nn.functional.smooth_l1_loss(x, y, beta=0.1)
            # Identity check: `== None` on a tensor is the wrong test for "not set yet".
            if perceptual_loss is None:
                perceptual_loss = layer_loss
            else:
                perceptual_loss = perceptual_loss + layer_loss
        # After the loop `x` is the predicted residual, so `x + input` equals forward(input).
        reconstruction_loss = nn.functional.smooth_l1_loss(x + input, target, beta=0.1)
        return alpha * reconstruction_loss + beta * perceptual_loss / self.deep

In [18]:
class StackDnCNN(nn.Module):
    """A cascade of ``size`` DnCNN denoisers applied one after another.

    Each sub-network gets its own Adam optimizer (learning rate 1e-3), kept in
    ``self.optimizers`` in the same order as ``self.dncnns``.

    Args:
        size: number of DnCNN stages.
        deep: depth passed to every DnCNN.
        channel: hidden channel count passed to every DnCNN.
    """
    def __init__(self,size,deep,channel=64):
        super(StackDnCNN,self).__init__()
        self.deep=deep
        self.channel=channel
        self.size=size
        self.dncnns=nn.ModuleList()
        for _ in range(size):
            self.dncnns.append(DnCNN(deep,channel))
        # One optimizer per stage.
        self.optimizers=[Adam(dncnn.parameters(),lr=1e-3) for dncnn in self.dncnns]
    def forward(self,x):
        """Feed ``x`` through every stage in order and return the last stage's output.

        With zero stages the input is returned unchanged.
        """
        y = x
        # Chain the stages: each one refines the previous stage's output.
        # (Previously every stage was applied to the raw input, its result was
        # discarded, and the untouched input was returned.)
        for dncnn in self.dncnns:
            y = dncnn(y)
        return y

In [19]:
class HourGlassCNNBlock(nn.Module):
    """3x3 same-padding convolution followed by optional LayerNorm, BatchNorm2d and SiLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        size: (height, width) of the feature map; only read when ``layernorm`` is set.
        layernorm: append ``LayerNorm([out_channels, height, width])`` after the convolution.
        batchnorm: append ``BatchNorm2d(out_channels)``.
        activation: append ``SiLU``.
    """
    def __init__(self,in_channels,out_channels,size=(128,128),layernorm=False,batchnorm=False,activation=True):
        super(HourGlassCNNBlock,self).__init__()
        # Optional stages in application order; factories keep construction lazy so
        # that `size` is only touched when layer normalisation is requested.
        optional_stages = (
            (layernorm, lambda: nn.LayerNorm([out_channels, size[0], size[1]])),
            (batchnorm, lambda: nn.BatchNorm2d(out_channels)),
            (activation, lambda: nn.SiLU()),
        )
        stages = [nn.Conv2d(in_channels, out_channels, 3, padding=1)]
        stages.extend(build() for enabled, build in optional_stages if enabled)
        self.main = nn.Sequential(*stages)
    def forward(self,x):
        """Apply the block. No residual connection here: input and output channel counts differ."""
        return self.main(x)

class HourGlassCNN(nn.Module):
    """Ten-block convolutional denoiser with mirrored skip connections.

    Channel widths grow 3 -> 64 -> 128 -> 256 -> 512 -> 1024 and shrink back to 3.
    The output of each of the first four blocks is added to the output of the
    matching block on the way back, and the network input is added to the
    final output.

    Args:
        size: (height, width) forwarded to every block.
        layernorm: enable LayerNorm in the inner blocks (first and last block never use it).
        batchnorm: enable BatchNorm2d in the inner blocks (first and last block never use it).
    """
    def __init__(self,size=(128,128),layernorm=False,batchnorm=False):
        super(HourGlassCNN,self).__init__()
        widths = (3, 64, 128, 256, 512, 1024, 512, 256, 128, 64, 3)
        final_index = len(widths) - 2
        stages = []
        for index, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            if index == 0:
                # Input block: no normalisation.
                stage = HourGlassCNNBlock(c_in, c_out, size, layernorm=False, batchnorm=False)
            elif index == final_index:
                # Output block: bare convolution, no normalisation and no activation.
                stage = HourGlassCNNBlock(c_in, c_out, size, layernorm=False, batchnorm=False, activation=False)
            else:
                stage = HourGlassCNNBlock(c_in, c_out, size, layernorm, batchnorm)
            stages.append(stage)
        self.blocks = nn.ModuleList(stages)
    def forward(self,x):
        """Run the network and return a tensor shaped like ``x``."""
        blocks = list(self.blocks)
        # Stack of tensors awaiting their skip connection; the input itself is
        # consumed last, by the output block.
        pending = [x]
        h = x
        for block in blocks[:4]:
            h = block(h)
            pending.append(h)
        # Widest block (index 4) has no skip connection.
        h = blocks[4](h)
        # Blocks 5..9 each add the most recently stored, not yet used tensor (residual).
        for block in blocks[5:]:
            h = block(h) + pending.pop()
        return h
    def loss(self,input,target):
        """Smooth L1 loss (beta=0.1) between the network output for ``input`` and ``target``."""
        prediction = self(input)
        return nn.functional.smooth_l1_loss(prediction, target, beta=0.1)

# 训练

## DnCNN

## StackDnCNN

## HourGlassCNN