In [1]:
# builtin 
import glob
import random
import warnings
from pathlib import Path
warnings.filterwarnings("ignore")

# all imports
import torch 
import numpy as np 
import torch.nn as nn
from tqdm.auto import tqdm
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid
from sklearn.model_selection import train_test_split



# our modules
from src.config import cfg, root_path
from src.utils import MeanSTDFinder
from src.data_loaders import SuperResolutionDataLoader
from src.models.srgan import Generator, Discriminator, VggFeatureExtractor


# create path for models checkpoint
Path(root_path).joinpath("saved_models/srgan").mkdir(exist_ok=True, parents=True)
Path(root_path).joinpath("saved_models/srgan/images").mkdir(exist_ok=True, parents=True)



In [2]:
import torch
import multiprocessing
from pathlib import Path
from easydict import EasyDict as edict

c = edict()
root_pth = "/home/rjn/Documents/GitHub/Satellite-Super-Resolution"
c.dataset = edict()
c.dataset.images_dir = str(Path(root_pth).joinpath("data/Train/images_png"))


# config for dataset scaling
c.images = edict()
c.images.channels = 3
c.images.scale_factor = 4
c.images.high_resolution_height = 512
c.images.high_resolution_width = 512


# dataloder
c.dataloader = edict()
c.dataloader.batch_size = 32
c.dataloader.num_workers = 16


c.device = edict()
c.device.device = "cuda" if torch.cuda.is_available() else "cpu"

# clip
c.train = edict()
c.train.n_epochs = 100
c.train.batch_size = 4
c.train.learning_rate = 0.00008
c.train.n_cpu = multiprocessing.cpu_count() // 2
c.train.b1 = 0.5  # adam: decay of first order momentum of gradient
c.train.b2 = 0.999  # adam: decay of second order momentum of gradient
c.train.decay_epoch = 100  # epoch from which to start lr decay

cfg = c

In [3]:
# get the images dataset path 
images_pth = cfg.dataset.images_dir

train_paths, test_paths = train_test_split(
    sorted(glob.glob(images_pth + "/*.*"))[:500],
    test_size=0.2,
    random_state=42,
)

# get the mean and std of the dataset 
# mean_std = MeanSTDFinder(images_dir=images_pth)()
mean_std = {'mean': [0.2903465 , 0.31224626, 0.29810828],
 'std': [0.1457739 , 0.13011318, 0.12317199]}

In [4]:
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset


class SuperResolutionDataLoader(Dataset):

    def __init__(self, paths, mean, std) -> None:
        super().__init__()

        self.items = paths

        # transforms for low resolution
        self.low_res_transforms = transforms.Compose(
            [
                transforms.Resize(
                    (
                        cfg.images.high_resolution_height // 4,
                        cfg.images.high_resolution_width // 4,
                    ),
                    Image.BICUBIC,
                ),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
        )

        # transforms for high resolution
        self.high_res_transforms = transforms.Compose(
            [
                transforms.Resize(
                    (
                        cfg.images.high_resolution_height,
                        cfg.images.high_resolution_width,
                    ),
                    Image.BICUBIC,
                ),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
        )

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):

        img = Image.open(self.items[index % len(self.items)]).convert("RGB")

        img_lr = self.low_res_transforms(img)

        img_hr = self.high_res_transforms(img)

        return img_hr, img_lr


In [5]:
images_pth = cfg.dataset.images_dir

train_paths, test_paths = train_test_split(
    sorted(glob.glob(images_pth + "/*.*")),
    test_size=0.2,
    random_state=42,
)

# get the mean and std of the dataset 
# mean_std = MeanSTDFinder(images_dir=images_pth)()
mean_std = {'mean': [0.2903465 , 0.31224626, 0.29810828],
 'std': [0.1457739 , 0.13011318, 0.12317199]}


In [6]:
train_iter = DataLoader(
    SuperResolutionDataLoader(train_paths,**mean_std),
    batch_size=cfg.train.batch_size,
    shuffle=True,
    num_workers=cfg.train.n_cpu,
)
val_iter = DataLoader(
    SuperResolutionDataLoader(test_paths,**mean_std),
    batch_size=int(cfg.train.batch_size * 0.75),
    shuffle=True,
    num_workers=cfg.train.n_cpu,
)


In [7]:
def calculate_psnr(img1, img2):
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))

In [8]:
class ResBlock(nn.Module):
    def __init__(self, inC, outC):
        super(ResBlock, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(inC, outC, kernel_size=3, stride=1, padding=1, bias=False), 
                                    nn.BatchNorm2d(outC), 
                                    nn.PReLU())

        self.layer2 = nn.Sequential(nn.Conv2d(outC, outC, kernel_size=3, stride=1, padding=1, bias=False), 
                                    nn.BatchNorm2d(outC))

    def forward(self, x):
        resudial = x

        out = self.layer1(x)
        out = self.layer2(out)
        out = out + resudial

        return out
    

class Generator(nn.Module):
    def __init__(self, n_blocks):
        super(Generator, self).__init__()
        self.convlayer1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=9, stride=1, padding=4, bias=False),
                                        nn.PReLU())

        self.ResBlocks = nn.ModuleList([ResBlock(64, 64) for _ in range(n_blocks)]) #叠加n_blocks个残差块

        self.convlayer2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False), 
                                        nn.BatchNorm2d(64))

        self.convout = nn.Conv2d(64, 3, kernel_size=9, stride=1, padding=4, bias=False)

    def forward(self, x):
        out = self.convlayer1(x)
        residual = out

        for block in self.ResBlocks:
            out = block(out)

        out = self.convlayer2(out)
        out = out + residual

        out = self.convout(out)

        return out
    

class DownSample(nn.Module):
    def __init__(self, input_channel, output_channel,  stride, kernel_size=3, padding=1):
        super(DownSample, self).__init__()
        self.layer = nn.Sequential(nn.Conv2d(input_channel, output_channel, kernel_size, stride, padding),
                                   nn.BatchNorm2d(output_channel),
                                   nn.LeakyReLU(inplace=True))

    def forward(self, x):
        x = self.layer(x)
        return x

#判别器
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(3, 64, 3, stride=1, padding=1),
                                   nn.LeakyReLU(inplace=True))

        self.down = nn.Sequential(DownSample(64, 64, stride=2, padding=1),
                                  DownSample(64, 128, stride=1, padding=1),
                                  DownSample(128, 128, stride=2, padding=1),
                                  DownSample(128, 256, stride=1, padding=1),
                                  DownSample(256, 256, stride=2, padding=1),
                                  DownSample(256, 512, stride=1, padding=1),
                                  DownSample(512, 512, stride=2, padding=1))

        self.dense = nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                   nn.Conv2d(512, 1024, 1),
                                   nn.LeakyReLU(inplace=True),
                                   nn.Conv2d(1024, 1, 1),
                                   nn.Sigmoid()) #Loss为nn.BCELoss则加Sigmoid，若为nn.BCEWithLogitsLoss则不加，因为此Loss里包括了Sigmoid

    def forward(self, x):
        x = self.conv1(x)
        x = self.down(x)
        x = self.dense(x)
        return x

In [9]:
import torch
import torch.nn as nn
import torchvision.models as models

#SRGAN使用预训练好的VGG19，用生成器的结果以及原始图像通过VGG后分别得到的特征图计算MSE，具体解释推荐看SRGAN的相关资料
class VGG(nn.Module):
    def __init__(self, device):
        super(VGG, self).__init__()
        vgg = models.vgg19(True)
        for pa in vgg.parameters():
            pa.requires_grad = False
        self.vgg = vgg.features[:16]
        self.vgg = self.vgg.to(device)

    def forward(self, x):
        out = self.vgg(x)
        return out

#内容损失
class ContentLoss(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.mse = nn.MSELoss()
        self.vgg19 = VGG(device)

    def forward(self, fake, real):
        feature_fake = self.vgg19(fake)
        feature_real = self.vgg19(real)
        loss = self.mse(feature_fake, feature_real)
        return loss

#对抗损失
class AdversarialLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        loss = torch.sum(-torch.log(x))
        return loss

#感知损失
class PerceptualLoss(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.vgg_loss = ContentLoss(device)
        self.adversarial = AdversarialLoss()

    def forward(self, fake, real, x):
        vgg_loss = self.vgg_loss(fake, real)
        adversarial_loss = self.adversarial(x)
        return vgg_loss + 1e-3*adversarial_loss

#正则项，需要说明的是，在SRGAN的后续版本的论文中，这个正则项被删除了
class RegularizationLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        a = torch.square(
            x[:, :, :x.shape[2]-1, :x.shape[3]-1] - x[:, :, 1:x.shape[2], :x.shape[3]-1]
        )
        b = torch.square(
            x[:, :, :x.shape[2]-1, :x.shape[3]-1] - x[:, :, :x.shape[2]-1, 1:x.shape[3]]
        )
        loss = torch.sum(torch.pow(a+b, 1.25))
        return loss
    

class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0):

        self.patience = patience #等待多少个epoch之后停止
        self.verbose = verbose #是否显示日志
        self.counter = 0 #计步器
        self.best_score = None #记录最好性能
        self.early_stop = False #早停触发
        self.val_psnr_min = 0 #记录最小的验证PSNR
        self.delta = delta #可以给最好性能加上的小偏置
        self.checkpoint_perf = [] #记录检查点的性能

    def __call__(self, g, d, train_psnr, val_psnr):

        score = val_psnr
        self.early_stop = False

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(g, d, val_psnr)
        elif score < self.best_score + self.delta: #PSNR越大越好，因此这里是小于，若使用loss做指标，这里应改成大于
            self.counter += 1 #若当前性能不超过前一个epoch的性能则计步器+1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience: #计步器累计到达极限，出发早停
                self.early_stop = True
                self.counter = 0
                self.best_score = None
                self.val_psnr_min = 0
        else: #当前性能优于或等于前一个epoch的性能，则更新最佳性能记录
            self.best_score = score
            self.save_checkpoint(g, d, val_psnr) #保存检查点
            self.counter = 0 #计步器重置
            self.checkpoint_perf = [train_psnr, val_psnr] #记录检查点性能数据
        return self.checkpoint_perf

    def save_checkpoint(self, g, d, val_psnr): #保存检查点
        self.val_psnr_min = val_psnr
        if self.verbose:
            print(f'Validation PSNR increased ({self.val_psnr_min:.6f} --> {val_psnr:.6f}).  Saving model ...')
            torch.save(g.state_dict(), 'Generator.pth')
            torch.save(d.state_dict(), 'Discriminator.pth')
        else:
            torch.save(g.state_dict(), 'Generator.pth')
            torch.save(d.state_dict(), 'Discriminator.pth')

In [10]:
lr = 0.001
n_blocks = 5
G = Generator(n_blocks)
D = Discriminator()

G_loss = PerceptualLoss(cfg.device.device) #自定义的loss函数
Regulaztion = RegularizationLoss().to(cfg.device.device) #自定义的loss函数
D_loss = nn.BCELoss().to(cfg.device.device)

optimizer_g = torch.optim.Adam(G.parameters(), lr=lr*0.1) #先训练判别器，后训练生成器，因此生成器的学习率比判别器小
optimizer_d = torch.optim.Adam(D.parameters(), lr=lr)

real_label = torch.ones([cfg.train.batch_size, 1, 1, 1]).to(cfg.device.device)
fake_label = torch.zeros([cfg.train.batch_size, 1, 1, 1]).to(cfg.device.device)

early_stopping = EarlyStopping(10, verbose=True)

#数据记录用
train_loss_g = []
train_loss_d = []
train_psnr = []
val_loss = []
val_psnr = []

cfg.train.n_epochs

100

In [11]:
import time
def train(generator, discriminator, train_iter, val_iter, n_epochs, optimizer_g, optimizer_d, loss_g, loss_d, Regulaztion, device):
    print('train on',device)
    generator.to(device)
    discriminator.to(device)
    cuda = next(generator.parameters()).device
    for epoch in range(n_epochs):
        train_epoch_loss_g = [] #数据记录用
        train_epoch_loss_d = []
        train_epoch_psnr = []
        val_epoch_loss = []
        val_epoch_psnr = []
        start = time.time() #开始时间
        generator.train() #设置为训练模式
        discriminator.train()
        for i, (img, nimg) in enumerate(train_iter):
            img, nimg = img.to(cuda).float(), nimg.to(cuda).float()
            fakeimg = generator(nimg) #生成器生成“假”图片，即降噪后的图片
            
            optimizer_d.zero_grad()
            realOut = discriminator(img) #判别器对“真”图片，即原始图片的判断，1为真，0为假
            fakeOut = discriminator(fakeimg.detach()) #判别器对“假”图片，即生成器生成的图片的判断，1为真，0为假
            loss_d = D_loss(realOut, real_label) + D_loss(fakeOut, fake_label) #判别器的损失
            loss_d.backward()
            optimizer_d.step()
            
            optimizer_g.zero_grad()
            loss_g = G_loss(fakeimg, img, D(fakeimg)) + 2e-8*Regulaztion(fakeimg) #生成器的损失，这里加了正则项
            loss_g.backward()
            optimizer_g.step()
            
            train_epoch_loss_d.append(loss_d.item()) #记录判别器损失
            train_epoch_loss_g.append(loss_g.item()) #记录生成器损失
            train_epoch_psnr.append(calculate_psnr(fakeimg, img).item()) #记录PSNR
        train_epoch_avg_loss_g = np.mean(train_epoch_loss_g) #计算一个epoch的平均损失
        train_epoch_avg_loss_d = np.mean(train_epoch_loss_d)
        train_epoch_avg_psnr = np.mean(train_epoch_psnr) #计算一个epoch的平均PSNR
        train_loss_g.append(train_epoch_avg_loss_g) #记录生成器的一个epoch的平均损失
        train_loss_d.append(train_epoch_avg_loss_d) #记录判别器的一个epoch的平均损失
        train_psnr.append(train_epoch_avg_psnr) #记录一个epoch的平均PSNR
        print(f'Epoch {epoch + 1}, Generator Train Loss: {train_epoch_avg_loss_g:.4f}, '
              f'Discriminator Train Loss: {train_epoch_avg_loss_d:.4f}, PSNR: {train_epoch_avg_psnr:.4f}') #打印epoch训练结果
        generator.eval() #设置为验证模式
        discriminator.eval()
        with torch.no_grad(): #不需要计算梯度
            for i, (img, nimg) in enumerate(val_iter): #验证就是简化版的训练，对照着看下，不赘述了
                img, nimg = img.to(cuda).float(), nimg.to(cuda).float()
                fakeimg = generator(nimg)
                loss_g = G_loss(fakeimg, img, D(fakeimg)) + 2e-8*Regulaztion(fakeimg)
                val_epoch_loss.append(loss_g.item())
                val_epoch_psnr.append(calculate_psnr(fakeimg, img).item())
            val_epoch_avg_loss = np.mean(val_epoch_loss)
            val_epoch_avg_psnr = np.mean(val_epoch_psnr)
            val_loss.append(val_epoch_avg_loss)
            val_psnr.append(val_epoch_avg_psnr)
            print(f'Generator Val Loss: {val_epoch_avg_loss:.4f}, PSNR: {val_epoch_avg_psnr:.4f}, Cost: {(time.time()-start):.4f}s')
            checkpoint_perf = early_stopping(generator, discriminator, train_epoch_avg_psnr, val_epoch_avg_psnr) #应用早停法，选出PSNR最高的那个
            if early_stopping.early_stop:
                print("Early stopping")
                print('Final model performance:')
                print(f'Train PSNR: {checkpoint_perf[0]}, Val PSNR: {checkpoint_perf[1]}')
                break
        torch.cuda.empty_cache() #清空显存缓存，可以不加这个

In [12]:
train(G, D, train_iter, val_iter, cfg.train.n_epochs, optimizer_g, optimizer_d, G_loss, D_loss, Regulaztion, cfg.device.device)

train on cuda


RuntimeError: The size of tensor a (32) must match the size of tensor b (128) at non-singleton dimension 3