# SRGAN

## 预处理阶段

In [1]:
import argparse
import os
import numpy as np
import math
import itertools
import sys
import time
import glob
from PIL import Image

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torchvision import datasets
from torchvision.models import vgg19
from torchvision.models import VGG19_Weights

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [2]:
# 计划使用GPU运行
# 检查是否有可用的GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# 超参数设置
epoch = 0           # 起始轮次
n_epochs = 200      # 训练集总训练轮次
batch_size = 4
lr = 0.0002           # 学习率
b1 = 0.5 
b2 = 0.999
decay_epoch = 100    # 从第几轮开始，lr开始衰减
hr_height = 256      # 高分辨图像高
hr_width = 256       # 高分辨图像宽
channels = 3        # 图片通道数
sample_interval = 100    # 生成图片间隔
checkpoint_interval = -1  # 模型检查点之间的间隔
sample_dir = 'demo1'    # 生成图片保存路径

hr_shape = (hr_height, hr_width)

In [4]:
# 建立文件夹用于存放训练过程的图像，如果文件夹不存在，就创建一个
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

## 数据集

In [5]:
# 数据集预训练得到的平均值和标准差
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

class ImageDataset(Dataset):
    def __init__(self, root, hr_shape):
        hr_height, hr_width = hr_shape
        # 低分辨图像和高分辨图像的transforms
        self.lr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height // 4, hr_width // 4), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.hr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height, hr_width), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        
        self.files = sorted(glob.glob(root + "/*.*"))
        
    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)])
        img_lr = self.lr_transform(img)
        img_hr = self.hr_transform(img)
        return {"lr": img_lr, "hr": img_hr}
    
    def __len__(self):
        return len(self.files)

In [6]:
# 数据集加载器
data_loader = DataLoader(
    # ImageDataset(root='/home/cxmd/文档/data_for_AI_train/img_align_celeba', hr_shape=hr_shape),
    ImageDataset(root='/home/cxmd/文档/data_for_AI_train/Particles/hr', hr_shape=hr_shape),
    batch_size=batch_size,
    shuffle=True,
)

## 网络构建

![](https://raw.githubusercontent.com/XM-Chen/figuremap/main/my_notes/202311200921526.png)

In [7]:
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        # 创建一个预训练的VGG-19模型
        vgg19_model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)
        # 使用VGG-19模型的前18层作为特征提取器
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18])

    def forward(self, img):
        # 提取图像的特征并返回
        return self.feature_extractor(img)


class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        # 定义一个卷积块，包含两个卷积层，每个卷积层后面都跟着一个批量归一化层和一个PReLU激活函数
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, 0.8),
            nn.PReLU(),
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, 0.8),
        )

    def forward(self, x):
        # 将输入通过卷积块处理得到的结果与原始输入相加，实现残差连接
        return x + self.conv_block(x)

In [8]:
# 构建生成器
class GeneratorResNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):   # 原文建议使用16个残差块
        super(GeneratorResNet, self).__init__()
        
        # 第一层
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
        )
        
        # 残差块
        res_blocks = []
        for _ in range(n_residual_blocks):
            res_blocks.append(ResidualBlock(64))
        self.res_block = nn.Sequential(
            *res_blocks
        )
        
        # 残差块后的第二个卷积层
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
        )
        
        # 上采样层
        upsampling = []
        for out_features in range(2):
            upsampling += [
                nn.Conv2d(64, 256, 3, 1, 1),
                nn.BatchNorm2d(256),
                nn.PixelShuffle(upscale_factor=2),
                nn.PReLU()
            ]
        self.upsampling = nn.Sequential(
            *upsampling
        )
        
        # 最后输出层
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, out_channels=out_channels, kernel_size=9, stride=1, padding=4),
            nn.Tanh(),
        )
        
    def forward(self, x):
        out1 = self.conv1(x)
        out = self.res_block(out1)
        out2 = self.conv2(out)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv3(out)
        return out

In [9]:
# 构建判别器
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        
        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        patch_h, patch_w = int(in_height / 2 ** 4), int(in_width / 2 ** 4)
        self.output_shape = (1, patch_h, patch_w)
        
        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 128, 256, 512]):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters
        
        # layers.append(nn.AdaptiveAvgPool2d(1))
        # layers.append(nn.Flatten())
        # layers.append(nn.Linear(512, 1024))
        # layers.append(nn.LeakyReLU(0.2))
        # layers.append(nn.Linear(1024, 1))
        # layers.append(nn.Sigmoid())
        layers.append(nn.Conv2d(out_filters, 1, kernel_size=3, stride=1, padding=1))
        
        self.model = nn.Sequential(*layers)
        
    def forward(self, img):
        return self.model(img)

In [10]:
# 初始化生成器和判别器
generator = GeneratorResNet().to(device)
discriminator = Discriminator(input_shape=(channels, *hr_shape)).to(device)
feature_extractor = FeatureExtractor().to(device)

# 将feature_extractor设为评估模式，不需要训练
feature_extractor.eval()

FeatureExtractor(
  (feature_extractor): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3

In [11]:
# 设置损失函数和优化器
criterion_GAN = torch.nn.MSELoss()
criterion_content = torch.nn.L1Loss()

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

In [12]:
# 检测是否有预训练模型
# 如果开始训练的轮数不等于0，那么加载预训练的模型
if epoch != 0:
    # 从文件中加载生成器模型的参数，并将这些参数加载到生成器模型中
    generator.load_state_dict(torch.load(f"demo1/generator_{epoch}.pth"))
    # 从文件中加载判别器模型的参数，并将这些参数加载到判别器模型中
    discriminator.load_state_dict(torch.load(f"saved_models/discriminator_{epoch}.pth"))

## 训练

In [13]:
# 记录开始训练的时间
start_time = time.time()

for ind_epoch in range(epoch, n_epochs):
    for i, imgs in enumerate(data_loader):
        
        # 固定输入大小
        imgs_lr = imgs['lr'].to(device)
        imgs_hr = imgs['hr'].to(device)
        
        valid = torch.ones((imgs_lr.size(0), *discriminator.output_shape), device=device, requires_grad=False)
        fake = torch.zeros((imgs_lr.size(0), *discriminator.output_shape), device=device, requires_grad=False)
        
        # ----------------
        #  训练生成器
        # ----------------
        
        optimizer_G.zero_grad()
        
        # 由低分辨图像输入产生高分辨图像
        gen_hr = generator(imgs_lr)
        
        # 判别器损失
        loss_GAN = criterion_GAN(discriminator(gen_hr), valid)
        
        # 内容损失
        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(imgs_hr)
        loss_content = criterion_content(gen_features, real_features.detach())
        
        # 总损失
        loss_G = loss_content + 1e-3 * loss_GAN
        
        loss_G.backward()
        optimizer_G.step()
        
        # ----------------
        #  训练判别器
        # ----------------
        
        optimizer_D.zero_grad()
        
        # 真假图片损失
        loss_real = criterion_GAN(discriminator(imgs_hr), valid)
        loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)
        
        # 总损失
        loss_D = (loss_real + loss_fake) / 2
        
        loss_D.backward()
        optimizer_D.step()
        
        # ----------------
        #  输出
        # ----------------
        if (i + 1) % 100 == 0:      # 每一百批次输入一次训练程度
            # 计算过去的时间
            elapsed_time = time.time() - start_time
            print('Elapsed time: {:.4f} seconds, Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'.format(elapsed_time, ind_epoch + 1, n_epochs, i + 1, len(data_loader), loss_D.item(), loss_G.item()))

Elapsed time: 27.4935 seconds, Epoch [1/200], Step [100/853], d_loss: 0.0070, g_loss: 0.6543
Elapsed time: 53.7790 seconds, Epoch [1/200], Step [200/853], d_loss: 0.0058, g_loss: 0.5323
Elapsed time: 80.3820 seconds, Epoch [1/200], Step [300/853], d_loss: 0.0014, g_loss: 0.5476
Elapsed time: 106.7564 seconds, Epoch [1/200], Step [400/853], d_loss: 0.0225, g_loss: 0.5261
Elapsed time: 133.0567 seconds, Epoch [1/200], Step [500/853], d_loss: 0.0072, g_loss: 0.5852
Elapsed time: 159.6567 seconds, Epoch [1/200], Step [600/853], d_loss: 0.0101, g_loss: 0.4247
Elapsed time: 186.1487 seconds, Epoch [1/200], Step [700/853], d_loss: 0.0012, g_loss: 0.4128
Elapsed time: 212.4837 seconds, Epoch [1/200], Step [800/853], d_loss: 0.0016, g_loss: 0.8750
Elapsed time: 253.0195 seconds, Epoch [2/200], Step [100/853], d_loss: 0.0012, g_loss: 0.5471
Elapsed time: 279.3575 seconds, Epoch [2/200], Step [200/853], d_loss: 0.0010, g_loss: 0.5554
Elapsed time: 305.5912 seconds, Epoch [2/200], Step [300/853], 

## 保存模型

In [14]:
# 保存生成器模型
torch.save(generator.state_dict(), './demo1/generator.pth')

# 保存判别器模型
torch.save(discriminator.state_dict(), './demo1/discriminator.pth')

## 测试模型

In [31]:
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# 加载低分辨率图像
image = Image.open('/home/cxmd/文档/data_for_AI_train/Particles/hr/L2_ffe7b1e4a0dbed455356a5a62d04894a.jpg')

# 将图像转换为Tensor，并添加一个批次维度
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
image = transform(image).unsqueeze(0).to(device)

# 使用生成器生成高分辨率图像
with torch.no_grad():
    output = generator(image)

# 将输出转换回图像
output = output.squeeze(0).cpu().detach()
output = output * std.view() + mean.view()  # 反归一化
output = output.clamp(0, 1)  # 将像素值限制在[0, 1]范围内
output = transforms.ToPILImage()(output)

# 显示图像
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# ax[0].imshow(image.squeeze(0).cpu().detach())
ax[1].imshow(output)
plt.show()

ValueError: operands could not be broadcast together with shapes (3,3072,4096) (3,) 

In [29]:
mean.view()

array([0.485, 0.456, 0.406])