In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset
import numpy as np
import os
from PIL import Image
import visdom

In [None]:
# Visdom client bound to the 'deepVRX' environment. The 'pix_loss' window is
# created up front with a two-point placeholder line so that the title and
# legend options are set before the training loop starts plotting into it.
vis = visdom.Visdom(env='deepVRX')
vis.line(
    Y=[1., 2.],
    X=[1., 2.],
    win='pix_loss',
    name='pix_loss',
    opts=dict(title='pix_loss', legend=['pix_loss']),
)

In [None]:
def showImageVis(img):
    """Denormalise a channel-first image tensor and show it in Visdom window 'img'.

    Args:
        img: tensor with three axes in (channel, height, width) order that was
            normalised with mean 0.5 and std 0.5 per channel.
    """
    # Move the channel axis last so the per-channel constants broadcast.
    height_width_channel = img.cpu().detach().numpy().transpose((1, 2, 0))
    channel_mean = np.array([0.5, 0.5, 0.5])
    channel_std = np.array([0.5, 0.5, 0.5])
    restored = channel_std * height_width_channel + channel_mean
    # Visdom is given the image channel-first again.
    channel_first = restored.transpose((2, 0, 1))
    vis.image(channel_first, win='img')

In [None]:
nf = 16  # base number of feature channels
nh = 16  # channels of the innermost (third) encoder stage


class Generator(nn.Module):
    """Convolutional encoder/decoder mapping a 3-channel image to a 3-channel image.

    Three convolution stages shrink the input and four transposed-convolution
    stages grow it back. The outputs of ``conv2`` and ``conv1`` are concatenated
    (channel axis) onto the inputs of ``convTrans2`` and ``convTrans3``, so the
    spatial sizes on both sides must match; e.g. 32x32 and 360x480 inputs do.
    The final activation is Tanh, so outputs lie in [-1, 1].
    """

    @staticmethod
    def _stage(layer_cls, in_ch, out_ch, kernel, stride, activation):
        """Return ``layer_cls`` -> BatchNorm2d -> ``activation`` as a Sequential."""
        return nn.Sequential(
            layer_cls(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=0),
            nn.BatchNorm2d(out_ch),
            activation(),
        )

    def __init__(self):
        super().__init__()
        down, up = nn.Conv2d, nn.ConvTranspose2d

        # Encoder
        self.conv1 = self._stage(down, 3, nf, 8, 1, nn.ReLU)
        self.conv2 = self._stage(down, nf, nf * 2, 5, 2, nn.ReLU)
        self.conv3 = self._stage(down, nf * 2, nh, 5, 2, nn.ReLU)

        # Decoder; convTrans2/convTrans3 take doubled channel counts because
        # their inputs are concatenated with encoder features in forward().
        self.convTrans1 = self._stage(up, nh, nf * 2, 5, 2, nn.ReLU)
        self.convTrans2 = self._stage(up, nf * 4, nf, 5, 2, nn.ReLU)
        self.convTrans3 = self._stage(up, nf * 2, 3, 5, 1, nn.Tanh)
        self.convTrans4 = self._stage(up, 3, 3, 4, 1, nn.Tanh)

    def forward(self, x):
        """Run the encoder, then the decoder with encoder features concatenated in."""
        enc1 = self.conv1(x)
        enc2 = self.conv2(enc1)
        inner = self.conv3(enc2)

        dec2 = self.convTrans1(inner)
        dec1 = self.convTrans2(torch.cat((enc2, dec2), dim=1))
        coarse = self.convTrans3(torch.cat((enc1, dec1), dim=1))
        return self.convTrans4(coarse)

In [None]:
class Discr(nn.Module):
    """Convolutional discriminator producing one sigmoid score per image.

    The stack is five Conv2d -> BatchNorm2d -> LeakyReLU(0.2) groups with two
    max-pooling layers in between, followed by Flatten -> Linear -> Sigmoid.
    The Linear layer expects 128*3*4 features, which fixes the accepted input
    size (a 360x480 image, for example, reduces to a 3x4 map).
    """

    def __init__(self):
        super().__init__()

        def conv_group(in_ch, out_ch, kernel, stride):
            # One unpadded convolution with batch norm and leaky ReLU.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=0),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        layers = []
        layers += conv_group(3, 16, 5, 1)
        layers += conv_group(16, 32, 5, 2)
        layers.append(nn.MaxPool2d(kernel_size=3, stride=2))
        layers += conv_group(32, 64, 3, 2)
        layers += conv_group(64, 64, 3, 2)
        layers.append(nn.MaxPool2d(kernel_size=3, stride=3))
        layers += conv_group(64, 128, 3, 2)
        layers += [
            nn.Flatten(1),
            nn.Linear(128 * 3 * 4, 1),
            nn.Sigmoid(),
        ]

        # Layer order (and therefore the 'main.N' parameter names) is unchanged.
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid score, shape (batch, 1)."""
        return self.main(x)


In [None]:
# Build the generator, move it to the GPU, then restore the weights stored
# in ./models/G.pth.
G = Generator()
G.cuda()
G.load_state_dict(torch.load('./models/G.pth'))

In [None]:
# Build the discriminator, move it to the GPU, then restore the weights
# stored in ./models/D.pth.
D = Discr()
D.cuda()
D.load_state_dict(torch.load('./models/D.pth'))

In [None]:
class p2pDataset(Dataset):
    """Paired-image dataset: same-named files read from two sibling folders.

    NOTE(review): the attribute names are inverted relative to the directories
    they point at -- ``noise_folder`` is ``pics2/clear1`` and ``clear_folder``
    is ``pics2/noise1``. ``__getitem__`` therefore returns
    ``(image from clear1, image from noise1)``. The training loop in this
    notebook unpacks the pair in exactly that order, so the names and the
    return order are kept as they are; change both places together or not at
    all.

    Args:
        data_dir: root directory that contains the ``pics2`` folder.
        transform: optional callable applied to each of the two PIL images.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.noise_folder = os.path.join(data_dir, './pics2/clear1/')
        self.clear_folder = os.path.join(data_dir, './pics2/noise1/')
        # os.listdir returns entries in arbitrary order and also lists
        # directories (e.g. Jupyter's .ipynb_checkpoints). Keep regular files
        # only and sort them so that index -> file is reproducible.
        self.image_list = sorted(
            name
            for name in os.listdir(self.noise_folder)
            if os.path.isfile(os.path.join(self.noise_folder, name))
        )

    def __len__(self):
        """Number of files found in ``noise_folder``."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return the pair of RGB images that share the ``idx``-th file name.

        Raises whatever ``Image.open`` raises if the partner file is missing
        from ``clear_folder``.
        """
        img_name = self.image_list[idx]

        noise_path = os.path.join(self.noise_folder, img_name)
        clear_path = os.path.join(self.clear_folder, img_name)

        # convert() yields an independent image, so the file can be closed
        # as soon as the with-block ends.
        with Image.open(noise_path) as noise_file:
            noise_img = noise_file.convert('RGB')
        with Image.open(clear_path) as clear_file:
            clear_img = clear_file.convert('RGB')

        if self.transform:
            noise_img = self.transform(noise_img)
            clear_img = self.transform(clear_img)

        return noise_img, clear_img


In [None]:
# Convert each PIL image to a tensor, then normalise every channel with
# mean 0.5 and std 0.5 (showImageVis applies the inverse for display).
data_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# An empty data_dir makes the dataset paths relative to the working directory.
dataset = p2pDataset(data_dir='', transform=data_transform)

In [None]:
# --- run configuration -------------------------------------------------------
batch_size = 20
epochs = 500
printStep = 10   # not referenced by the training cell as written
showStep = 100   # not referenced by the training cell as written

# --- data --------------------------------------------------------------------
# drop_last=True keeps every batch at exactly batch_size samples.
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# --- losses ------------------------------------------------------------------
criterionG = nn.MSELoss()
criterionD = nn.MSELoss()

# --- optimisers --------------------------------------------------------------
optG = torch.optim.Adam(G.parameters(), lr=0.008)
optD = torch.optim.RMSprop(D.parameters(), lr=0.0005)

# History of per-step pixel losses.
losslist = []

In [None]:
# Constant discriminator targets. Their fixed first dimension matches every
# batch because the loader is built with drop_last=True.
real_labels = torch.ones(batch_size, 1).cuda()
fake_labels = torch.zeros(batch_size, 1).cuda()
losslist = []
for epoch in range(epochs):
    # The dataset yields (image from pics2/clear1, image from pics2/noise1).
    for step, (img, img2) in enumerate(data_loader):
        clear_img = img.cuda()
        noise_img = img2.cuda()

        # ---- train D ----
        real_outputs = D(clear_img)
        real_loss = criterionD(real_outputs, real_labels)

        # The generator is only sampled here, so no graph is needed
        # (the output already carries no gradient history).
        with torch.no_grad():
            fake_img = G(noise_img)
        fake_outputs = D(fake_img)
        fake_loss = criterionD(fake_outputs, fake_labels)

        d_loss = real_loss + fake_loss
        # Zero right before backward: the generator's backward pass of the
        # previous step also wrote gradients into D's parameters, and they
        # must not leak into this discriminator update.
        optD.zero_grad()
        d_loss.backward()
        optD.step()

        # ---- train G ----
        fake_img = G(noise_img)
        fake_outputs = D(fake_img)

        # Adversarial term (target 0.9, weight 0.05) plus pixel-wise MSE
        # against the clean image (weight 10).
        g_gan_loss = criterionD(fake_outputs, real_labels * 0.9) * 0.05
        g_pix_loss = criterionG(fake_img, clear_img) * 10
        losslist.append(g_pix_loss.item())
        g_loss = g_gan_loss + g_pix_loss
        optG.zero_grad()
        g_loss.backward()
        optG.step()

        if step % 5 == 0:
            vis.line(Y=losslist, name='pix_loss', win='pix_loss')

    # Once per epoch: log the last step's losses and show
    # input | generated | target side by side (concatenated along width).
    print('Epoch: {}, Step: {}, D_loss: {:.5f}, G_loss: gan {:.5f} + pix {:.5f}'.format(epoch, step, d_loss.item(), g_gan_loss.item(), g_pix_loss.item()))
    showImageVis(torch.cat((noise_img[0], fake_img[0], clear_img[0]), dim=2))

In [None]:
# fake_img is in the generator's Tanh range [-1, 1]. Map it back to [0, 1]
# (inverse of the mean 0.5 / std 0.5 normalisation) before handing it to
# Visdom, otherwise negative values are corrupted by the uint8 conversion.
vis.image((fake_img[1].detach().cpu() * 0.5 + 0.5).clamp(0, 1))

In [None]:
# Save state dicts (not the bound `parameters` method) so the files can be
# restored with `load_state_dict`, the same way G.pth / D.pth are loaded above.
torch.save(G.state_dict(), './models/G2.pth')
torch.save(D.state_dict(), './models/D2.pth')