# 环境

In [1]:
import torch
import torchmetrics
import torchvision
import os
import gc
import torch.nn as nn
from torchvision import transforms
# Report the PyTorch build, choose the compute device and show the working directory.
print(torch.__version__)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)
if device == 'cuda':
    # Collect Python garbage first, then ask PyTorch to release its cached CUDA memory.
    gc.collect()
    torch.cuda.empty_cache()
print(os.getcwd())


1.11.0+cu113
cuda
/root/autodl-tmp/deep-learning/unsupervised denoise


# 模型

In [2]:
class HourGlassCNNBlock(nn.Module):
    """A 3x3 convolution optionally followed by LayerNorm, BatchNorm2d and SiLU.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        size: (height, width) used for the LayerNorm normalized shape; only
            read when ``layernorm`` is true.
        layernorm: append ``nn.LayerNorm([out_channels, height, width])``.
        batchnorm: append ``nn.BatchNorm2d(out_channels)`` (placed after the
            LayerNorm when both flags are set).
        activation: append ``nn.SiLU()`` as the final layer.
    """

    def __init__(self, in_channels, out_channels, size=(128, 128), layernorm=False, batchnorm=False, activation=True):
        super().__init__()
        # kernel_size=3 with padding=1 keeps the spatial size of the input.
        stages = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)]
        if layernorm:
            stages.append(nn.LayerNorm([out_channels, size[0], size[1]]))
        if batchnorm:
            stages.append(nn.BatchNorm2d(out_channels))
        if activation:
            stages.append(nn.SiLU())
        # The attribute name `main` is part of the state-dict keys; keep it.
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        # Input and output channel counts differ, so no residual connection is applied here.
        return self.main(x)

class HourGlassCNN(nn.Module):
    """Ten-block convolutional network whose channel count widens then narrows.

    Channels go 3 -> 64 -> 128 -> 256 -> 512 -> 1024 and back down to 3.
    The outputs of blocks 0-3 are added to the outputs of blocks 8-5
    respectively, and the network input is added to the final output.

    Args:
        size: (height, width) forwarded to every block for its LayerNorm shape.
        layernorm: enable LayerNorm in blocks 1-8 (blocks 0 and 9 never normalize).
        batchnorm: enable BatchNorm2d in blocks 1-8 (blocks 0 and 9 never normalize).
    """

    def __init__(self, size=(128, 128), layernorm=False, batchnorm=False):
        super().__init__()
        widths = [64, 128, 256, 512, 1024, 512, 256, 128, 64]
        # Block 0: input stem without normalization.
        stages = [HourGlassCNNBlock(3, 64, size, layernorm=False, batchnorm=False)]
        # Blocks 1-8: consecutive channel pairs taken from `widths`.
        for c_in, c_out in zip(widths, widths[1:]):
            stages.append(HourGlassCNNBlock(c_in, c_out, size, layernorm, batchnorm))
        # Block 9: output head without normalization or activation.
        stages.append(HourGlassCNNBlock(64, 3, size, layernorm=False, batchnorm=False, activation=False))
        self.blocks = nn.ModuleList(stages)

    def forward(self, x):
        skips = []
        feat = x
        # Widening half: remember each output for the matching narrowing block.
        for idx in range(4):
            feat = self.blocks[idx](feat)
            skips.append(feat)

        # Widest block (1024 channels), no skip connection.
        feat = self.blocks[4](feat)

        # Narrowing half: add the stored outputs in reverse order (residual connections).
        for idx in range(5, 9):
            feat = self.blocks[idx](feat) + skips.pop()

        # Global residual: the head output is added to the network input.
        return self.blocks[9](feat) + x

    def loss(self, input, target):
        """Smooth L1 loss between ``input`` and ``target`` with ``beta=0.1``."""
        return nn.functional.smooth_l1_loss(input, target, beta=0.1)

# 分析

In [3]:
import torch.utils.tensorboard
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm.notebook import tqdm
from torchvision.utils import make_grid
from matplotlib import pyplot as plt
from torchmetrics import PeakSignalNoiseRatio

In [4]:
def convert_to_rgb255(image:torch.Tensor):
    """Rescale a tensor from the [-1, 1] range to [0, 1] and clip the result.

    Despite the name, values are not scaled to 0-255. A new tensor is
    returned; ``image`` itself is not modified.
    """
    rescaled = (image + 1) / 2
    return rescaled.clamp(min=0, max=1)
def show_image(image:torch.Tensor):
    """Display one image tensor with matplotlib after rescaling it via convert_to_rgb255."""
    to_pil = transforms.ToPILImage()
    rescaled = convert_to_rgb255(image)
    plt.imshow(to_pil(rescaled))
    plt.show()

# TensorBoard writer created with default arguments; draw() logs scalars and image grids through it.
writer=SummaryWriter()
# PSNR metric constructed with default arguments and moved to the selected device.
psnr=PeakSignalNoiseRatio().to(device=device)

# 数据

In [5]:
# Side length of the square training images; it has to agree with the dataset
# directory name and with the LayerNorm size of the models (both default to 128).
img_size=128
# Mini-batch size used by the DataLoader; train() splits every batch into two halves.
batch_size=16

In [6]:

from torch.utils.data import Dataset 
from torch.utils.data import DataLoader

from PIL import Image
# https://discuss.pytorch.org/t/torchvision-transfors-how-to-perform-identical-transform-on-both-image-and-target/10606/7

class BSDSPairsDataSet(Dataset):
    """Dataset of (noisy, clean) image pairs generated on the fly from a folder.

    Each item is an image loaded as RGB, converted to a CHW float tensor in
    [-1, 1], plus a copy with fresh Gaussian noise added on every access.

    Args:
        imgs_dir: folder containing the images. When empty/None, defaults to
            ``./data/BSDS300/{img_size}x{img_size}/base/train``.
        img_size: side length used only to build the default folder name.
        sigma: noise level; the added noise is ``sigma * N(0, 1) / 255``.
    """

    def __init__(self,imgs_dir=None,img_size=128,sigma=30) -> None:
        super().__init__()
        if not imgs_dir:imgs_dir=f'./data/BSDS300/{img_size}x{img_size}/base/train'
        self.img_size=img_size
        self.imgs_dir=imgs_dir
        self.files=self._list_files(imgs_dir)
        self.sigma=sigma
        # Built once instead of on every __getitem__ call.
        self.transform=transforms.Compose([
            # HWC -> CHW and scale to [0, 1]
            transforms.ToTensor(),
            # [0, 1] -> [-1, 1]
            transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
        ])

    @staticmethod
    def _list_files(imgs_dir):
        """Return the regular files of ``imgs_dir`` in sorted order.

        ``os.listdir`` yields entries in arbitrary order, so sorting keeps the
        index -> file mapping reproducible. Sub-directories (for example a
        Jupyter ``.ipynb_checkpoints`` folder) are skipped because they cannot
        be opened as images.
        """
        return sorted(
            name for name in os.listdir(imgs_dir)
            if os.path.isfile(os.path.join(imgs_dir, name))
        )

    def __getitem__(self,idx:int):
        """Return ``(noisy, clean)`` tensors for the image at position ``idx``."""
        img_path = os.path.join(self.imgs_dir, self.files[idx])
        # Author's note: PIL loading is preferred over torchvision's reader.
        # The context manager closes the file handle right after decoding.
        with Image.open(img_path) as img:
            clean = img.convert('RGB')
        clean = self.transform(clean)
        # NOTE(review): `clean` spans [-1, 1] (a range of 2), so sigma/255 here
        # corresponds to sigma/2 on the 0-255 pixel scale - confirm this is the
        # intended noise level.
        noisy = clean + self.sigma * torch.randn_like(clean) / 255
        return noisy, clean

    def __len__(self):
        """Number of image files found in the folder."""
        return len(self.files)

# Build the training set for the configured image size (default folder and sigma)
# and serve it in shuffled mini-batches of `batch_size`.
train_set=BSDSPairsDataSet(img_size=img_size)
train_loader=DataLoader(train_set,batch_size=batch_size,shuffle=True)


# 训练

In [7]:
import itertools
from torch.optim import Adam
# Two networks with the same architecture: one estimates the noise, the other restores images.
# The LayerNorm shape is tied to the configured image size instead of the class default.
noise_extractor = HourGlassCNN(size=(img_size,img_size),layernorm=True,batchnorm=True).to(device)
restomer = HourGlassCNN(size=(img_size,img_size),layernorm=True,batchnorm=True).to(device)
# One Adam optimizer per network, both with learning rate 1e-3.
noise_extractor_optimizer = Adam(noise_extractor.parameters() , lr=1e-3)
restomer_optimizer=Adam(restomer.parameters() , lr=1e-3)

In [8]:
# noise as clean：noisy-P'->noisy' 如果clean-P->noisy中的P近似于P'那么(noisy',noisy)训练可以近似复原噪声图片
from random import randint
from torch import zeros_like

def train(input):
    """Run one two-phase optimisation step on a batch of noisy images.

    The batch is split along dim 0 into two equally sized halves.

    Phase 1 (restomer only): each half is perturbed with the noise estimated
    from the other half (detached), and restomer is trained to map the result
    back to the original noisy half.

    Phase 2 (both networks): restomer is trained to reproduce its input on the
    estimated clean images, and noise_extractor is trained to output zeros on
    them.

    Args:
        input: batch of noisy images with the batch dimension first.

    Returns:
        0-dim detached tensor holding ``loss1 + loss2``; zero when the batch
        has fewer than two images and is therefore skipped.
    """
    # Split by the actual batch size: the last batch of an epoch can be smaller
    # than the global `batch_size`, which would leave the second half empty or
    # shorter than the first and break the additions below.
    half=input.shape[0]//2
    if half==0:
        return input.new_zeros(())
    noisys1=input[:half]
    noisys2=input[half:2*half]

    # ---- Phase 1: restomer(noisy + noise from the other half) == noisy; trains restomer only.
    restomer_optimizer.zero_grad()

    noise1=noise_extractor(noisys1)-noisys1
    noise2=noise_extractor(noisys2)-noisys2
    fake_noisys1=restomer(noisys1+noise2.detach())
    fake_noisys2=restomer(noisys2+noise1.detach())
    loss1=restomer.loss(fake_noisys2,noisys2)+restomer.loss(fake_noisys1,noisys1)
    loss1.backward()

    # noise_extractor receives no gradient in this phase (its outputs are
    # detached), so its optimizer is deliberately not stepped: with zero-filled
    # grads Adam would still apply a momentum-only update, changing in place the
    # weights that the graph of `noise1` needs for the backward pass of phase 2.
    restomer_optimizer.step()  # parameters change only on step(); backward() just computes gradients

    # ---- Phase 2: restomer(clean) == clean and noise_extractor(clean) == 0.
    noise_extractor_optimizer.zero_grad()
    restomer_optimizer.zero_grad()

    cleans1=noisys1-noise1
    cleans2=restomer(noisys2)
    # NOTE(review): the targets cleans1/cleans2 are not detached, so gradients
    # also flow through them - confirm this is intended.
    # NOTE(review): noise_extractor's raw output is pushed towards zero here,
    # while phase 1 defines the noise as output minus input - confirm.
    loss2=(
        restomer.loss(restomer(cleans1),cleans1)
        +restomer.loss(restomer(cleans2),cleans2)
        +noise_extractor.loss(noise_extractor(cleans1),zeros_like(cleans1))
        +noise_extractor.loss(noise_extractor(cleans2),zeros_like(cleans2))
    )
    loss2.backward()
    noise_extractor_optimizer.step()
    restomer_optimizer.step()

    # Idea left by the author, not implemented: restomer(fake_noisys) == restomer(noisys)
    #loss3=restomer.loss(fake_noisys1,noisys1)+restomer.loss(restomer(fake_noisys2),cleans2)

    return (loss1+loss2).detach()
def draw(mean_loss_noisy,mean_loss_clean,noisys,cleans,epoch):
    """Log loss scalars, PSNR and image grids for one epoch to TensorBoard.

    Args:
        mean_loss_noisy: value logged under the ``noisy_loss`` tag.
        mean_loss_clean: value logged under the ``clean_loss`` tag.
        noisys: batch of noisy images fed to both networks.
        cleans: matching clean images, used as the PSNR target.
        epoch: global step used for every log entry.
    """
    outs=restomer(noisys)
    scalars=(
        ('noisy_loss',mean_loss_noisy),
        ('clean_loss',mean_loss_clean),
        ('psnr',psnr(outs,cleans)),
    )
    for tag,value in scalars:
        writer.add_scalar(tag,value,epoch)

    # NOTE(review): the 'noise' grid shows the raw noise_extractor output, while
    # train() computes the noise as that output minus the input - confirm which is wanted.
    grids=(
        ('noisy',noisys),
        ('clean',cleans),
        ('out',outs),
        ('noise',noise_extractor(noisys)),
    )
    for tag,batch in grids:
        writer.add_image(tag,make_grid(convert_to_rgb255(batch)),epoch)
def main(num_epochs=5000):
    """Train both networks and log per-epoch statistics.

    For every epoch, each batch from ``train_loader`` is moved to ``device``,
    passed to ``train`` and scored against the clean images. After the epoch,
    the per-batch mean losses and the last batch are handed to ``draw``.

    Args:
        num_epochs: number of passes over ``train_loader`` (default 5000).

    Raises:
        RuntimeError: if ``train_loader`` yields no batches.
    """
    for epoch in tqdm(range(num_epochs)):
        total_loss_noisy,total_loss_clean=0,0
        num_batches=0
        for noisys,cleans in train_loader:
            noisys=noisys.to(device=device)
            cleans=cleans.to(device=device)
            total_loss_noisy+=train(noisys)
            with torch.no_grad():
                # Monitoring only: the restoration loss against the true clean images.
                total_loss_clean+=restomer.loss(restomer(noisys),cleans).sum()
            num_batches+=1
        if num_batches==0:
            # Without any batch there is nothing to log and `noisys` would be undefined.
            raise RuntimeError('train_loader produced no batches')
        with torch.no_grad():
            # draw() expects mean losses, so divide the running sums by the batch count.
            draw(total_loss_noisy/num_batches,total_loss_clean/num_batches,noisys,cleans,epoch)
# Run the training loop.
main()

  0%|          | 0/5000 [00:00<?, ?it/s]

KeyboardInterrupt: 