# 环境

In [1]:
import torch
import torchmetrics
import torchvision
import os
import gc
import torch.nn as nn
from torchvision import transforms
print(torch.__version__)
device='cuda' if torch.cuda.is_available() else 'cpu'
print(device)
if device=='cuda':
    gc.collect()
    torch.cuda.empty_cache()
print(os.getcwd())


1.11.0+cu113
cuda
/root/autodl-tmp/deep-learning/unsupervised denoise


# 模型

In [2]:
class HourGlassCNNBlock(nn.Module):
    def __init__(self,in_channels,out_channels,size=128,layernorm=False,batchnorm=False,activation=True):
        super(HourGlassCNNBlock,self).__init__()
        layers=[]
        layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1))
        if layernorm:layers.append(nn.LayerNorm([out_channels,size,size]))
        if batchnorm:layers.append(nn.BatchNorm2d(out_channels))
        if activation:layers.append(nn.SiLU())
        self.main=nn.Sequential(*layers)
    def forward(self,x):
        return self.main(x)#channel不同不能进行残差学习

class HourGlassCNN(nn.Module):
    def __init__(self,size=128,layernorm=False,batchnorm=False,max_channels=512,residual=True):
        super(HourGlassCNN,self).__init__()
        self.blocks=nn.ModuleList()
        self.blocks.append( HourGlassCNNBlock(3,64,size,layernorm=False,batchnorm=False)),#0
        channel=128
        while channel<=max_channels:
            self.blocks.append(HourGlassCNNBlock(channel//2,channel,size,layernorm,batchnorm))
            channel*=2
        channel//=2
        while channel>64:
            self.blocks.append(HourGlassCNNBlock(channel,channel//2,size,layernorm,batchnorm))
            channel//=2
        self.blocks.append(HourGlassCNNBlock(64,3,size,layernorm=False,batchnorm=False,activation=False))
        self.residual=residual
    def forward(self,x):
        blocks_len=len(self.blocks)
        y=[]
        y.append(self.blocks[0](x))
        for i in range(1,blocks_len//2):y.append(self.blocks[i](y[i-1]))
        for i in range(blocks_len//2,blocks_len-1):y.append(self.blocks[i](y[i-1])+y[blocks_len-i-2]if self.residual else self.blocks[i](y[i-1]))
        return self.blocks[blocks_len-1](y[-1])

class NoiseExtractor(nn.Module):
    def __init__(self,size=128,layernorm=False,batchnorm=False):
        super(NoiseExtractor,self).__init__()
        """
        main->dependent noise
        main->independent noise
        """
        self.main=HourGlassCNN(size,layernorm,batchnorm)
        self.dependent=HourGlassCNN(size,layernorm,batchnorm,256)
        self.independent=HourGlassCNN(size,layernorm,batchnorm,256)
    def forward(self,x):
       main=self.main(x)+x
       return self.dependent(main)+main,self.independent(main)+main

class Denoiser(nn.Module):
    def __init__(self,size=128,layernorm=False,batchnorm=False):
        super(Denoiser,self).__init__()
        self.main=HourGlassCNN(size,layernorm,batchnorm)
    def forward(self,x):
        return self.main(x)+x

class CVF_SID(nn.Module):
    def __init__(self,size=128,layernorm=False,batchnorm=False) -> None:
        super(CVF_SID,self).__init__()
        self.denoiser=Denoiser(size,layernorm,batchnorm)
        self.noise_extrator=NoiseExtractor(size,layernorm,batchnorm)
    def forward(self,x):
        clean=self.denoiser(x)
        dependent,independent=self.noise_extrator(x-clean)
        return clean,dependent,independent

# 分析

In [3]:
import torch.utils.tensorboard
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm.notebook import tqdm
from torchvision.utils import make_grid
from matplotlib import pyplot as plt
from torchmetrics import PeakSignalNoiseRatio

In [4]:
def convert_to_rgb255(image:torch.Tensor):
    image = (image + 1) / 2
    image[image < 0] = 0
    image[image > 1] = 1
    return image
def show_image(image:torch.Tensor):
    image=convert_to_rgb255(image)
    plt.imshow(transforms.ToPILImage()(image))
    plt.show()

writer=SummaryWriter()
psnr=PeakSignalNoiseRatio().to(device=device)

# 数据

In [5]:
img_size=128
batch_size=8

In [6]:

from torch.utils.data import Dataset 
from torch.utils.data import DataLoader

from PIL import Image
# https://discuss.pytorch.org/t/torchvision-transfors-how-to-perform-identical-transform-on-both-image-and-target/10606/7

class BSDSPairsDataSet(Dataset):
    def __init__(self,imgs_dir=None,img_size=128,sigma=60) -> None:
        super(BSDSPairsDataSet,self).__init__()
        if not imgs_dir:imgs_dir=f'./data/BSDS300/{img_size}x{img_size}/base/train'
        self.img_size=img_size
        self.files=os.listdir(imgs_dir)
        self.sigma=sigma
        self.imgs_dir=imgs_dir
    def __getitem__(self,idx:int):
        img_path = os.path.join(self.imgs_dir, self.files[idx])
        clean = Image.open(img_path).convert('RGB')#比torchvision的好
        transform = transforms.Compose([
            # hwc->chw 并 归一化到[0,1]
            transforms.ToTensor(),
            # [−1,1]
            transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
            ])
        clean = transform(clean)
        noisy = clean + self.sigma * torch.randn(clean.shape)/255
        return noisy, clean
    def __len__(self):
        return len(self.files)
class BSDSNoisyPairsDataSet(Dataset):
    def __init__(self,imgs_dir=None,img_size=128,sigma=60,sigma_plus=30) -> None:
        super(BSDSPairsDataSet,self).__init__()
        if not imgs_dir:imgs_dir=f'./data/BSDS300/{img_size}x{img_size}/base/train'
        self.img_size=img_size
        self.files=os.listdir(imgs_dir)
        self.sigma=sigma
        self.sigma_plus=sigma_plus
        self.imgs_dir=imgs_dir
    def __getitem__(self,idx:int):
        img_path = os.path.join(self.imgs_dir, self.files[idx])
        clean = Image.open(img_path).convert('RGB')#比torchvision的好
        transform = transforms.Compose([
            # hwc->chw 并 归一化到[0,1]
            transforms.ToTensor(),
            # [−1,1]
            transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
            ])
        noisy1 = transform(clean) + self.sigma * torch.randn(clean.shape)/255
        noisy2 = noisy2+self.sigma_plus* torch.randn(clean.shape)/255
        return noisy2,noisy1
    def __len__(self):
        return len(self.files)

pre_train_set=BSDSNoisyPairsDataSet()
pre_train_loader=DataLoader(pre_train_set,batch_size=batch_size,shuffle=True)

train_set=BSDSPairsDataSet()
train_loader=DataLoader(train_set,batch_size=batch_size,shuffle=True)


TypeError: super(type, obj): obj must be an instance or subtype of type

# 训练

In [None]:
import itertools
from torch.optim import Adam

model = CVF_SID().to(device)
optimizer = Adam(model.parameters() , lr=1e-3,weight_decay=0.1)

In [None]:
# noise as clean：noisy-P'->noisy' 如果clean-P->noisy中的P近似于P'那么(noisy',noisy)训练可以近似复原噪声图片
from torch import log, var, zeros_like,zeros,abs,std,mean,square
from torch.nn import SmoothL1Loss
from torch.nn.functional import relu,avg_pool2d
smooth_l1_loss=SmoothL1Loss(beta=0.1)

def train(noisy):
    optimizer.zero_grad()
    losses=[]
    clean,dependent,independent=model(noisy)
    #consistency
    losses.append(smooth_l1_loss(clean+clean*dependent+independent,noisy))
    #identity
    clean1,dependent1,independent1=model(clean)
    losses.append(smooth_l1_loss(clean1,clean))
    clean2,dependent2,independent2=model(clean+clean*dependent)
    losses.append(smooth_l1_loss(dependent2,dependent))
    clean3,dependent3,independent3=model(independent)
    losses.append(smooth_l1_loss(independent3,independent))
    losses.append(smooth_l1_loss(clean2,clean))
    #zeros
    losses.append(smooth_l1_loss(clean3,zeros_like(clean3)))
    losses.append(smooth_l1_loss(dependent1,zeros_like(dependent1)))
    losses.append(smooth_l1_loss(independent1,zeros_like(independent1)))
    losses.append(smooth_l1_loss(independent2,zeros_like(independent2)))

    loss=losses[0]
    for i in range(1,len(losses)):loss+=losses[i]
    optimizer.step()
    total_loss=0.0
    with torch.no_grad():
        total_loss=loss.sum()
    return total_loss
def draw(mean_loss_noisy,mean_loss_clean,epoch):
    noisys,cleans=next(iter(train_loader))
    noisys=noisys.to(device=device)
    cleans=cleans.to(device=device)
    outs=model(noisys)
    writer.add_scalar('noisy_loss',mean_loss_noisy,epoch)
    writer.add_scalar('clean_loss',mean_loss_clean,epoch)
    writer.add_scalar('psnr',psnr(outs[0],cleans),epoch)
    writer.add_image('noisy',make_grid(convert_to_rgb255(noisys)),epoch)
    writer.add_image('clean',make_grid(convert_to_rgb255(cleans)),epoch)
    writer.add_image('out',make_grid(convert_to_rgb255(outs[0])),epoch)
    writer.add_image('dependent',make_grid(convert_to_rgb255(outs[1])),epoch)
    writer.add_image('independent',make_grid(convert_to_rgb255(outs[2])),epoch)

def main(): 
    for epoch in tqdm(range(5000)):
        total_loss_noisy,total_loss_clean=0,0
        for noisys,cleans in train_loader:
            noisys=noisys.to(device=device)
            cleans=cleans.to(device=device)
            total_loss_noisy+=train(noisys)
            with torch.no_grad():
                total_loss_clean+=smooth_l1_loss(model(noisys)[0],cleans).sum()
        with torch.no_grad():
            draw(total_loss_noisy,total_loss_clean,epoch)
main()

  0%|          | 0/5000 [00:00<?, ?it/s]

KeyboardInterrupt: 