<a href="https://colab.research.google.com/github/Coffinbrain/lessons/blob/main/segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np
from os import listdir
from os.path import join
from torch.utils.data import Dataset
import torch.optim as optim
from torch.utils.data import DataLoader
from math import log10, exp
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
from torch.optim.lr_scheduler import MultiStepLR
import math

# Use the SimHei font for sans-serif text (presumably so that CJK labels
# render in figures; assumes the font is installed -- TODO confirm).
plt.rcParams['font.sans-serif'] = ["SimHei"]
# Turn off Matplotlib's Unicode minus sign so a plain hyphen is drawn instead
# (commonly needed when the selected font lacks the Unicode minus glyph).
plt.rcParams["axes.unicode_minus"] = False

In [None]:
def is_image_file(filename):
    """Return True when *filename* ends with a supported image extension.

    Supported extensions are ``.png``, ``.jpg`` and ``.jpeg``.  The comparison
    is case-insensitive, so files such as ``photo.JPG`` are accepted too.

    Args:
        filename: File name or path as a string.

    Returns:
        bool: whether the name carries one of the supported extensions.
    """
    return filename.lower().endswith((".png", ".jpg", ".jpeg"))


def load_img(filepath):
    """Open the image at *filepath* and return only its luma (Y) band.

    The image is converted to the YCbCr colour space; the two chroma bands
    produced by the split are discarded.
    """
    ycbcr_image = Image.open(filepath).convert('YCbCr')
    luma, _cb, _cr = ycbcr_image.split()
    return luma


# Side length, in pixels, of the square centre crop taken from every image
# (before being rounded down to a multiple of the zoom factor).
CROP_SIZE = 300


class DatasetFromFolder(Dataset):
    """Dataset yielding (degraded, reference) luma tensors from an image folder.

    Every image file found directly inside ``image_dir`` is loaded through
    ``load_img`` and centre-cropped.  The reference is the crop itself; the
    degraded version is the same crop resized down by ``zoom_factor`` and then
    resized back to the crop size with bicubic interpolation, so both tensors
    have the same spatial size.

    Args:
        image_dir: Directory that is scanned (non-recursively) for images.
        zoom_factor: Integer down/up-scaling factor used to build the
            degraded input.
    """

    def __init__(self, image_dir, zoom_factor):
        super(DatasetFromFolder, self).__init__()
        # listdir() returns entries in arbitrary order; sorting keeps the
        # index -> file mapping identical across runs and machines.
        self.image_filenames = [join(image_dir, name)
                                for name in sorted(listdir(image_dir))
                                if is_image_file(name)]
        # Largest size not above CROP_SIZE that is divisible by zoom_factor.
        crop_size = CROP_SIZE - (CROP_SIZE % zoom_factor)
        # Crop from the image centre, shrink by zoom_factor, then enlarge back
        # to the crop size with bicubic interpolation.
        self.input_transform = transforms.Compose([
            transforms.CenterCrop(crop_size),
            transforms.Resize(crop_size // zoom_factor),
            transforms.Resize(crop_size, interpolation=Image.BICUBIC),
            transforms.ToTensor(),
        ])
        # The reference only receives the centre crop.
        self.target_transform = transforms.Compose(
            [transforms.CenterCrop(crop_size), transforms.ToTensor()])

    def __getitem__(self, index):
        """Return the ``(degraded, reference)`` tensor pair for *index*."""
        image = load_img(self.image_filenames[index])
        # Neither transform pipeline modifies the PIL image in place, so the
        # same object can feed both without a defensive copy.
        degraded = self.input_transform(image)
        reference = self.target_transform(image)
        return degraded, reference

    def __len__(self):
        """Return the number of image files discovered in the folder."""
        return len(self.image_filenames)

In [None]:
def psnr(loss):
    """Convert a mean-squared-error value into PSNR, in decibels.

    The formula uses a peak signal value of 1, i.e. ``10 * log10(1 / mse)``.

    Args:
        loss: Single-element tensor (any object exposing ``.item()``) holding
            the mean squared error.

    Returns:
        float: the PSNR in dB, or ``inf`` when the error is exactly zero
        (a perfect reconstruction) instead of raising ``ZeroDivisionError``.
    """
    mse = loss.item()
    if mse == 0:
        return float("inf")
    return 10 * log10(1 / mse)

In [None]:
# Build a 1-D Gaussian distribution vector.
def gaussian(window_size, sigma):
    """Return a normalised 1-D Gaussian of length *window_size*.

    The curve is centred on index ``window_size // 2`` with standard
    deviation *sigma*, and the returned tensor sums to one.
    """
    centre = window_size // 2
    two_sigma_sq = float(2 * sigma ** 2)
    weights = []
    for position in range(window_size):
        weights.append(exp(-(position - centre) ** 2 / two_sigma_sq))
    kernel = torch.Tensor(weights)
    return kernel / kernel.sum()


# Build the Gaussian kernel as the matrix product of two 1-D Gaussian vectors.
# The channel argument replicates the kernel, e.g. to cover 3 channels.
def create_window(window_size, channel=1):
    """Return a ``(channel, 1, window_size, window_size)`` Gaussian kernel.

    The 2-D kernel is the outer product of a 1-D Gaussian (sigma 1.5) with
    itself; it is then repeated once per channel and made contiguous.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)
    kernel_2d = column.mm(column.t()).float()
    kernel_4d = kernel_2d.unsqueeze(0).unsqueeze(0)
    per_channel = kernel_4d.expand(channel, 1, window_size, window_size)
    return per_channel.contiguous()


# Compute SSIM.
# The SSIM formula is applied directly, except that local means are obtained
# by convolving with a normalised Gaussian kernel rather than by plain pixel
# averaging.  Variance and covariance rely on Var(X)=E[X^2]-E[X]^2 and
# cov(X,Y)=E[XY]-E[X]E[Y].
def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    """Return the structural similarity between two image batches.

    Args:
        img1, img2: 4-D tensors; ``img1.size()`` is unpacked as
            (batch, channel, height, width).
        window_size: Requested kernel size, used only when *window* is None.
        window: Optional ready-made convolution kernel; when None a Gaussian
            kernel is created on ``img1``'s device.
        size_average: If True return the mean over the whole SSIM map,
            otherwise one value per batch element.
        full: If True also return the mean contrast-sensitivity term.
        val_range: Dynamic range of the data; guessed from ``img1`` when None.

    Returns:
        The SSIM tensor, or the tuple ``(ssim, cs)`` when *full* is True.
    """
    # Value range can be different from 255. Other common ranges are
    # 1 (sigmoid) and 2 (tanh).
    if val_range is not None:
        dynamic_range = val_range
    else:
        upper = 255 if torch.max(img1) > 128 else 1
        lower = -1 if torch.min(img1) < -0.5 else 0
        dynamic_range = upper - lower

    _, channel, height, width = img1.size()
    if window is None:
        # Never use a kernel larger than the image itself.
        kernel_size = min(window_size, height, width)
        window = create_window(kernel_size, channel=channel).to(img1.device)

    def local_mean(tensor):
        # Unpadded grouped convolution, one group per channel.
        return F.conv2d(tensor, window, padding=0, groups=channel)

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)
    mu1_sq, mu2_sq = mu1.pow(2), mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = local_mean(img1 * img1) - mu1_sq
    sigma2_sq = local_mean(img2 * img2) - mu2_sq
    sigma12 = local_mean(img1 * img2) - mu1_mu2

    # Stabilising constants derived from the dynamic range.
    c1 = (0.01 * dynamic_range) ** 2
    c2 = (0.03 * dynamic_range) ** 2

    cs_numerator = 2.0 * sigma12 + c2
    cs_denominator = sigma1_sq + sigma2_sq + c2
    cs = torch.mean(cs_numerator / cs_denominator)  # contrast sensitivity

    luminance_numerator = 2 * mu1_mu2 + c1
    luminance_denominator = mu1_sq + mu2_sq + c1
    ssim_map = (luminance_numerator * cs_numerator) / (
        luminance_denominator * cs_denominator)

    ret = ssim_map.mean() if size_average else ssim_map.mean(1).mean(1).mean(1)
    return (ret, cs) if full else ret


class SSIM(torch.nn.Module):
    """Module wrapper around ``ssim`` that caches the Gaussian window.

    Args:
        window_size: Side length of the Gaussian kernel.
        size_average: Passed through to ``ssim``; average over the whole map
            when True, otherwise return one value per batch element.
        val_range: Dynamic range passed through to ``ssim``; when None the
            range is guessed from the first image.
    """

    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

        # Assume 1 channel for SSIM
        self.channel = 1
        self.window = create_window(window_size)

    def forward(self, img1, img2):
        """Return the SSIM of ``img1`` and ``img2``.

        The cached window is rebuilt whenever the channel count, dtype or
        device of ``img1`` differs from the cached one.
        """
        (_, channel, _, _) = img1.size()

        # The device must be checked too: the initial window lives on the
        # CPU, and convolving it with a CUDA image would fail.
        cache_is_valid = (
            channel == self.channel
            and self.window.dtype == img1.dtype
            and self.window.device == img1.device
        )
        if not cache_is_valid:
            self.window = create_window(self.window_size, channel).to(
                device=img1.device, dtype=img1.dtype)
            self.channel = channel

        # val_range is forwarded so the value given to the constructor is
        # actually honoured.
        return ssim(img1, img2, window=self.window,
                    window_size=self.window_size,
                    size_average=self.size_average,
                    val_range=self.val_range)

In [None]:
class SRCNN(nn.Module):
    """Four-layer convolutional network with a pixel-shuffle output stage.

    Three convolutions, each followed by ReLU, feed a final convolution that
    emits ``upscale_factor ** 2`` channels; ``nn.PixelShuffle`` then rearranges
    those channels into spatial resolution.

    Args:
        upscale_factor: Factor given to ``nn.PixelShuffle``.
    """

    def __init__(self, upscale_factor):
        super().__init__()

        self.relu = nn.ReLU()
        # Single-channel input; every convolution keeps the spatial size.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        out_channels = upscale_factor ** 2
        self.conv4 = nn.Conv2d(32, out_channels,
                               kernel_size=3, stride=1, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def _initialize_weights(self):
        """Orthogonally initialise all convolution weights.

        Layers followed by ReLU use the ReLU gain; the last layer uses the
        default gain.
        """
        relu_gain = init.calculate_gain('relu')
        for layer in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(layer.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)

    def forward(self, x):
        """Run the network on ``x`` and return the pixel-shuffled result."""
        for layer in (self.conv1, self.conv2, self.conv3):
            x = self.relu(layer(x))
        return self.pixel_shuffle(self.conv4(x))

In [None]:
# Experiment configuration and data loaders used by the cells below.

# Scaling factor handed to DatasetFromFolder to build the degraded inputs.
zoom_factor = 2
# Number of passes over the training loader.
nb_epochs = 500
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Seed the CPU and CUDA random number generators.
torch.manual_seed(0)
torch.cuda.manual_seed(0)
BATCH_SIZE = 4
NUM_WORKERS = 0
trainset = DatasetFromFolder(r"./data/images/train", zoom_factor)
# NOTE(review): the validation set reads the same directory as the training
# set, so validation metrics (and the checkpoint chosen from them) are
# measured on training images -- confirm whether a separate folder was meant.
valset = DatasetFromFolder(r"./data/images/train", zoom_factor)
trainloader = DataLoader(dataset=trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
valloader = DataLoader(dataset=valset, batch_size=BATCH_SIZE,
                       shuffle=False, num_workers=NUM_WORKERS)

In [None]:
# Build the model, then train it while tracking validation PSNR/SSIM and
# keeping the checkpoint with the best average PSNR.
model = SRCNN(1).to(device)
criterion = nn.MSELoss()
# Per-layer learning rates.  Every convolution has to be listed: parameters
# that are not handed to the optimizer are never updated, which previously
# left conv4 frozen at its initial weights.
optimizer = optim.Adam(
    [
        {"params": model.conv1.parameters(), "lr": 0.0001},
        {"params": model.conv2.parameters(), "lr": 0.0001},
        {"params": model.conv3.parameters(), "lr": 0.00001},
        {"params": model.conv4.parameters(), "lr": 0.00001},
    ]
)


best_psnr = 0.0
for epoch in range(nb_epochs):
    # Train
    model.train()
    epoch_loss = 0
    for batch in trainloader:
        inputs, targets = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        out = model(inputs)
        loss = criterion(out, targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f"Epoch {epoch}. Training loss: {epoch_loss / len(trainloader)}")
    # Val
    model.eval()
    sum_psnr = 0.0
    sum_ssim = 0.0
    with torch.no_grad():
        for batch in valloader:
            inputs, targets = batch[0].to(device), batch[1].to(device)
            out = model(inputs)
            loss = criterion(out, targets)
            sum_psnr += psnr(loss)
            # Quality is measured against the ground truth, not against the
            # degraded network input.  The range is fixed to 1 (the same peak
            # psnr() assumes) so an out-of-range prediction cannot change
            # the auto-detected dynamic range.
            sum_ssim += ssim(out, targets, val_range=1).item()
    avg_psnr = sum_psnr / len(valloader)
    avg_ssim = sum_ssim / len(valloader)
    print(f"Average PSNR: {avg_psnr} dB.")
    print(f"Average SSIM: {avg_ssim} ")
    if avg_psnr >= best_psnr:
        best_psnr = avg_psnr
        # The whole module is pickled (not just a state_dict) because the
        # evaluation cell loads it back as a ready-to-use model.
        torch.save(model, r"best_model_SRCNN.pth")

In [None]:
# Evaluate the saved best checkpoint on the test folder.
BATCH_SIZE = 4
model_path = "best_model_SRCNN.pth"
testset = DatasetFromFolder(r"./data/images/test", zoom_factor)
testloader = DataLoader(dataset=testset, batch_size=BATCH_SIZE,
                        shuffle=False, num_workers=NUM_WORKERS)
sum_psnr = 0.0
sum_ssim = 0.0
# The checkpoint holds a whole pickled module, so it has to be unpickled in
# full (weights_only=False; recent PyTorch releases default to True and would
# refuse it).  Unpickling can execute arbitrary code: only load checkpoints
# you created yourself.  map_location lets a GPU-trained file load on CPU.
try:
    model = torch.load(model_path, map_location=device, weights_only=False)
except TypeError:
    # Older PyTorch releases do not know the weights_only keyword.
    model = torch.load(model_path, map_location=device)
model = model.to(device)
model.eval()
criterion = nn.MSELoss()
with torch.no_grad():
    for batch in testloader:
        inputs, targets = batch[0].to(device), batch[1].to(device)
        out = model(inputs)
        loss = criterion(out, targets)
        sum_psnr += psnr(loss)
        # Compare the prediction with the ground truth (not with the degraded
        # input); the range is fixed to 1, the same peak psnr() assumes.
        sum_ssim += ssim(out, targets, val_range=1).item()
print(f"Test Average PSNR: {sum_psnr / len(testloader)} dB")
print(f"Test Average SSIM: {sum_ssim / len(testloader)} ")