In [1]:
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split, TensorDataset

import os
import time
import cv2
import random
import skimage
from skimage.util import random_noise
import numpy as np
from PIL import Image
from PIL import ImageFile


In [2]:
n_blocks = 5 
n_epochs = 100 
batch_size = 16
train_path = 'data/Train/images_png/' 
val_path = 'data/Train/val_images/' 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
ImageFile.LOAD_TRUNCATED_IMAGES = True 
randomcrop = transforms.RandomCrop(128) 

In [3]:
def addGaussNoise(data, sigma):
    sigma2 = sigma**2 / (255 ** 2)
    noise = random_noise(data, mode='gaussian', var=sigma2, clip=True)
    return noise

In [4]:
class MyDataset(Dataset):
    def __init__(self, path, transform, sigma=30, ex=1):
        self.transform = transform
        self.sigma = sigma

        for _, _, files in os.walk(path):
            self.imgs = [path + file for file in files if Image.open(path + file).size >= (96,96)] * ex
        np.random.shuffle(self.imgs)

    def __getitem__(self, index):
        tempImg = self.imgs[index]
        tempImg = Image.open(tempImg).convert('RGB') #数据集中有部分图片为灰度图，将所有图片转换为RGB格式
        Img = np.array(self.transform(tempImg))/255 #像素归一化至[0,1]
        nImg = addGaussNoise(Img, self.sigma) #添加高斯噪声
        Img = torch.tensor(Img.transpose(2,0,1)) #由于Image.open加载的图片是H*W*C的格式，因此转换成C*H*W的格式
        nImg = torch.tensor(nImg.transpose(2,0,1))
        return Img, nImg

    def __len__(self):
        return len(self.imgs)

In [5]:
def get_data(batch_size, train_path, val_path, transform, sigma, ex=1):
    train_dataset = MyDataset(train_path, transform, sigma, ex)
    val_dataset = MyDataset(val_path, transform, sigma, ex)
    train_iter = DataLoader(train_dataset, batch_size, drop_last=True, num_workers=6)
    val_iter = DataLoader(val_dataset, batch_size, drop_last=True, num_workers=6)
    return train_iter, val_iter

In [6]:
train_iter, val_iter = get_data(batch_size, train_path, val_path, randomcrop, 30, ex=1)

In [7]:
def calculate_psnr(img1, img2):
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))

class ResBlock(nn.Module):
    def __init__(self, inC, outC):
        super(ResBlock, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(inC, outC, kernel_size=3, stride=1, padding=1, bias=False), 
                                    nn.BatchNorm2d(outC), 
                                    nn.PReLU())

        self.layer2 = nn.Sequential(nn.Conv2d(outC, outC, kernel_size=3, stride=1, padding=1, bias=False), 
                                    nn.BatchNorm2d(outC))

    def forward(self, x):
        resudial = x

        out = self.layer1(x)
        out = self.layer2(out)
        out = out + resudial

        return out
    


class Generator(nn.Module):
    def __init__(self, n_blocks):
        super(Generator, self).__init__()
        self.convlayer1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=9, stride=1, padding=4, bias=False),
                                        nn.PReLU())

        self.ResBlocks = nn.ModuleList([ResBlock(64, 64) for _ in range(n_blocks)]) #叠加n_blocks个残差块

        self.convlayer2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False), 
                                        nn.BatchNorm2d(64))

        self.convout = nn.Conv2d(64, 3, kernel_size=9, stride=1, padding=4, bias=False)

    def forward(self, x):
        out = self.convlayer1(x)
        residual = out

        for block in self.ResBlocks:
            out = block(out)

        out = self.convlayer2(out)
        out = out + residual

        out = self.convout(out)

        return out
    


class DownSample(nn.Module):
    def __init__(self, input_channel, output_channel,  stride, kernel_size=3, padding=1):
        super(DownSample, self).__init__()
        self.layer = nn.Sequential(nn.Conv2d(input_channel, output_channel, kernel_size, stride, padding),
                                   nn.BatchNorm2d(output_channel),
                                   nn.LeakyReLU(inplace=True))

    def forward(self, x):
        x = self.layer(x)
        return x

#判别器
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(3, 64, 3, stride=1, padding=1),
                                   nn.LeakyReLU(inplace=True))

        self.down = nn.Sequential(DownSample(64, 64, stride=2, padding=1),
                                  DownSample(64, 128, stride=1, padding=1),
                                  DownSample(128, 128, stride=2, padding=1),
                                  DownSample(128, 256, stride=1, padding=1),
                                  DownSample(256, 256, stride=2, padding=1),
                                  DownSample(256, 512, stride=1, padding=1),
                                  DownSample(512, 512, stride=2, padding=1))

        self.dense = nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                   nn.Conv2d(512, 1024, 1),
                                   nn.LeakyReLU(inplace=True),
                                   nn.Conv2d(1024, 1, 1),
                                   nn.Sigmoid()) #Loss为nn.BCELoss则加Sigmoid，若为nn.BCEWithLogitsLoss则不加，因为此Loss里包括了Sigmoid

    def forward(self, x):
        x = self.conv1(x)
        x = self.down(x)
        x = self.dense(x)
        return x


In [8]:
import torch
import torch.nn as nn
import torchvision.models as models

#SRGAN使用预训练好的VGG19，用生成器的结果以及原始图像通过VGG后分别得到的特征图计算MSE，具体解释推荐看SRGAN的相关资料
class VGG(nn.Module):
    def __init__(self, device):
        super(VGG, self).__init__()
        vgg = models.vgg19(True)
        for pa in vgg.parameters():
            pa.requires_grad = False
        self.vgg = vgg.features[:16]
        self.vgg = self.vgg.to(device)

    def forward(self, x):
        out = self.vgg(x)
        return out

#内容损失
class ContentLoss(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.mse = nn.MSELoss()
        self.vgg19 = VGG(device)

    def forward(self, fake, real):
        feature_fake = self.vgg19(fake)
        feature_real = self.vgg19(real)
        loss = self.mse(feature_fake, feature_real)
        return loss

#对抗损失
class AdversarialLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        loss = torch.sum(-torch.log(x))
        return loss

#感知损失
class PerceptualLoss(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.vgg_loss = ContentLoss(device)
        self.adversarial = AdversarialLoss()

    def forward(self, fake, real, x):
        vgg_loss = self.vgg_loss(fake, real)
        adversarial_loss = self.adversarial(x)
        return vgg_loss + 1e-3*adversarial_loss

#正则项，需要说明的是，在SRGAN的后续版本的论文中，这个正则项被删除了
class RegularizationLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        a = torch.square(
            x[:, :, :x.shape[2]-1, :x.shape[3]-1] - x[:, :, 1:x.shape[2], :x.shape[3]-1]
        )
        b = torch.square(
            x[:, :, :x.shape[2]-1, :x.shape[3]-1] - x[:, :, :x.shape[2]-1, 1:x.shape[3]]
        )
        loss = torch.sum(torch.pow(a+b, 1.25))
        return loss

In [9]:
class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0):

        self.patience = patience #等待多少个epoch之后停止
        self.verbose = verbose #是否显示日志
        self.counter = 0 #计步器
        self.best_score = None #记录最好性能
        self.early_stop = False #早停触发
        self.val_psnr_min = 0 #记录最小的验证PSNR
        self.delta = delta #可以给最好性能加上的小偏置
        self.checkpoint_perf = [] #记录检查点的性能

    def __call__(self, g, d, train_psnr, val_psnr):

        score = val_psnr
        self.early_stop = False

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(g, d, val_psnr)
        elif score < self.best_score + self.delta: #PSNR越大越好，因此这里是小于，若使用loss做指标，这里应改成大于
            self.counter += 1 #若当前性能不超过前一个epoch的性能则计步器+1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience: #计步器累计到达极限，出发早停
                self.early_stop = True
                self.counter = 0
                self.best_score = None
                self.val_psnr_min = 0
        else:
            self.best_score = score
            self.save_checkpoint(g, d, val_psnr) #保存检查点
            self.counter = 0 #计步器重置
            self.checkpoint_perf = [train_psnr, val_psnr] #记录检查点性能数据
        return self.checkpoint_perf

    def save_checkpoint(self, g, d, val_psnr): #保存检查点
        self.val_psnr_min = val_psnr
        if self.verbose:
            print(f'Validation PSNR increased ({self.val_psnr_min:.6f} --> {val_psnr:.6f}).  Saving model ...')
            torch.save(g.state_dict(), 'Generator.pth')
            torch.save(d.state_dict(), 'Discriminator.pth')
        else:
            torch.save(g.state_dict(), 'Generator.pth')
            torch.save(d.state_dict(), 'Discriminator.pth')

In [10]:
lr = 0.001

G = Generator(n_blocks)
D = Discriminator()

G_loss = PerceptualLoss(device) 
Regulaztion = RegularizationLoss().to(device)
D_loss = nn.BCELoss().to(device)

optimizer_g = torch.optim.Adam(G.parameters(), lr=lr*0.1) 
optimizer_d = torch.optim.Adam(D.parameters(), lr=lr)

real_label = torch.ones([batch_size, 1, 1, 1]).to(device)
fake_label = torch.zeros([batch_size, 1, 1, 1]).to(device)

early_stopping = EarlyStopping(10, verbose=True)

#数据记录用
train_loss_g = []
train_loss_d = []
train_psnr = []
val_loss = []
val_psnr = []



In [11]:
from torchvision.utils import save_image, make_grid
import random
from tqdm.auto import tqdm

def train(generator, discriminator, train_iter, val_iter, n_epochs, optimizer_g, optimizer_d, loss_g, loss_d, Regulaztion, device):
    print('train on', device)
    generator.to(device)
    discriminator.to(device)
    cuda = next(generator.parameters()).device
    for epoch in range(n_epochs):
        train_epoch_loss_g = []
        train_epoch_loss_d = []
        train_epoch_psnr = []
        val_epoch_loss = []
        val_epoch_psnr = []
        start = time.time()
        generator.train()
        discriminator.train()
        
        train_bar = tqdm(train_iter, desc=f"Training Epoch {epoch+1}/{n_epochs}", postfix={'Loss G': 0.0, 'Loss D': 0.0, 'PSNR': 0.0})
        for i, (img, nimg) in enumerate(train_bar):
            img, nimg = img.to(cuda).float(), nimg.to(cuda).float()
            fakeimg = generator(nimg)
            
            optimizer_d.zero_grad()
            realOut = discriminator(img)
            fakeOut = discriminator(fakeimg.detach())
            loss_d = D_loss(realOut, real_label) + D_loss(fakeOut, fake_label)
            loss_d.backward()
            optimizer_d.step()
            
            optimizer_g.zero_grad()
            loss_g = G_loss(fakeimg, img, D(fakeimg)) + 2e-8*Regulaztion(fakeimg)
            loss_g.backward()
            optimizer_g.step()
            
            train_epoch_loss_d.append(loss_d.item())
            train_epoch_loss_g.append(loss_g.item())
            train_epoch_psnr.append(calculate_psnr(fakeimg, img).item())
            
            train_bar.set_postfix({'Loss G': np.mean(train_epoch_loss_g), 'Loss D': np.mean(train_epoch_loss_d), 'PSNR': np.mean(train_epoch_psnr)})
        
        train_epoch_avg_loss_g = np.mean(train_epoch_loss_g)
        train_epoch_avg_loss_d = np.mean(train_epoch_loss_d)
        train_epoch_avg_psnr = np.mean(train_epoch_psnr)
        train_loss_g.append(train_epoch_avg_loss_g)
        train_loss_d.append(train_epoch_avg_loss_d)
        train_psnr.append(train_epoch_avg_psnr)
        print(f'Epoch {epoch + 1}, Generator Train Loss: {train_epoch_avg_loss_g:.4f}, '
              f'Discriminator Train Loss: {train_epoch_avg_loss_d:.4f}, PSNR: {train_epoch_avg_psnr:.4f}')
        
        generator.eval()
        discriminator.eval()
        
        with torch.no_grad():
            val_bar = tqdm(val_iter, desc=f"Validation Epoch {epoch+1}/{n_epochs}", postfix={'Val Loss': 0.0, 'PSNR': 0.0})
            for i, (img, nimg) in enumerate(val_bar):
                img, nimg = img.to(cuda).float(), nimg.to(cuda).float()
                fakeimg = generator(nimg)
                loss_g = G_loss(fakeimg, img, D(fakeimg)) + 2e-8*Regulaztion(fakeimg)
                val_epoch_loss.append(loss_g.item())
                val_epoch_psnr.append(calculate_psnr(fakeimg, img).item())

                if epoch % 10 == 0:
                    if i % 8 == 0:
                        imgs_hr = make_grid(img, nrow=1, normalize=True)
                        gen_hr = make_grid(fakeimg, nrow=1, normalize=True)
                        imgs_lr = make_grid(nimg, nrow=1, normalize=True)
                        img_grid = torch.cat((imgs_hr, imgs_lr, gen_hr), -1)
                        save_image(img_grid, f"saved_models/srgan_chinese/images/{i}.png", normalize=False)

                val_bar.set_postfix({'Val Loss': np.mean(val_epoch_loss), 'PSNR': np.mean(val_epoch_psnr)})

            val_epoch_avg_loss = np.mean(val_epoch_loss)
            val_epoch_avg_psnr = np.mean(val_epoch_psnr)
            val_loss.append(val_epoch_avg_loss)
            val_psnr.append(val_epoch_avg_psnr)
            print(f'Generator Val Loss: {val_epoch_avg_loss:.4f}, PSNR: {val_epoch_avg_psnr:.4f}, Cost: {(time.time() - start):.4f}s')
            
            checkpoint_perf = early_stopping(generator, discriminator, train_epoch_avg_psnr, val_epoch_avg_psnr)
            if early_stopping.early_stop:
                print("Early stopping")
                print('Final model performance:')
                print(f'Train PSNR: {checkpoint_perf[0]}, Val PSNR: {checkpoint_perf[1]}')
                break
        
        torch.cuda.empty_cache()

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
train(G, D, train_iter, val_iter, n_epochs, optimizer_g, optimizer_d, G_loss, D_loss, Regulaztion, device)

train on cuda


Training Epoch 1/100: 100%|██████████| 103/103 [02:24<00:00,  1.40s/it, Loss G=2.17, Loss D=0.389, PSNR=10.9]


Epoch 1, Generator Train Loss: 2.1737, Discriminator Train Loss: 0.3890, PSNR: 10.9480


Validation Epoch 1/100: 100%|██████████| 54/54 [00:21<00:00,  2.54it/s, Val Loss=1.26, PSNR=13.2]


Generator Val Loss: 1.2627, PSNR: 13.1668, Cost: 165.3336s
Validation PSNR increased (13.166799 --> 13.166799).  Saving model ...


Training Epoch 2/100: 100%|██████████| 103/103 [02:24<00:00,  1.40s/it, Loss G=1.15, Loss D=0.679, PSNR=17]   


Epoch 2, Generator Train Loss: 1.1543, Discriminator Train Loss: 0.6786, PSNR: 16.9847


Validation Epoch 2/100: 100%|██████████| 54/54 [00:19<00:00,  2.77it/s, Val Loss=1.06, PSNR=21.6]


Generator Val Loss: 1.0630, PSNR: 21.5815, Cost: 163.8409s
Validation PSNR increased (21.581490 --> 21.581490).  Saving model ...


Training Epoch 3/100: 100%|██████████| 103/103 [02:23<00:00,  1.39s/it, Loss G=0.925, Loss D=1.12, PSNR=23.1]


Epoch 3, Generator Train Loss: 0.9251, Discriminator Train Loss: 1.1240, PSNR: 23.0920


Validation Epoch 3/100: 100%|██████████| 54/54 [00:19<00:00,  2.77it/s, Val Loss=1, PSNR=23.9]    


Generator Val Loss: 1.0000, PSNR: 23.9280, Cost: 163.1816s
Validation PSNR increased (23.927952 --> 23.927952).  Saving model ...


Training Epoch 4/100:  67%|██████▋   | 69/103 [01:38<00:48,  1.42s/it, Loss G=0.936, Loss D=0.622, PSNR=24.3]


KeyboardInterrupt: 