In [1]:
from torchvision.datasets import ImageFolder
from torchvision import transforms


# The goal is a model that enlarges (32, 32) pixel images to (128, 128) pixels, so every
# sample is resized twice: X is the (32, 32) version and Y is the (128, 128) version.
# The LFW dataset used here consists of square (250, 250) images; for data that is not
# square, a CenterCrop may be inserted after the Resize.
class DownSizePairImageFolder(ImageFolder):
    """ImageFolder variant that yields ``(small_img, large_img)`` pairs.

    Each item is the same source image resized to two resolutions; the class
    label that ImageFolder associates with the file is discarded.

    Args:
        root: dataset directory, forwarded to ``ImageFolder``.
        transform: optional callable applied to the large and to the small
            image separately (two independent calls, so a randomised transform
            would not be synchronised between the two).
        large_size: size handed to ``transforms.Resize`` for the target image.
        small_size: size handed to ``transforms.Resize`` for the input image.
        **kwds: any further keyword arguments, forwarded to ``ImageFolder``.

    NOTE(review): Resize receives a single int, so non-square sources are
    presumably not forced to a square - see the CenterCrop remark above; confirm.
    """

    def __init__(self, root, transform=None, large_size=128, small_size=32, **kwds):
        super().__init__(root, transform=transform, **kwds)
        self.small_resizer = transforms.Resize(small_size)
        self.large_resizer = transforms.Resize(large_size)

    def __getitem__(self, index):
        """Return ``(small_img, large_img)`` for the file at position ``index``."""
        image_path = self.imgs[index][0]  # second entry is the class label, unused here
        source = self.loader(image_path)
        target = self.large_resizer(source)
        reduced = self.small_resizer(source)
        if self.transform is None:
            return reduced, target
        # Same call order as before: large image first, then the small one.
        target = self.transform(target)
        reduced = self.transform(reduced)
        return reduced, target

In [2]:
from torch.utils.data import DataLoader


# Mini-batch size shared by the training and the test loader.
batch_size = 32

# Datasets of (small, large) image pairs; ToTensor converts each image of a pair to a tensor.
train_data = DownSizePairImageFolder(
    './lfw-deepfunneled/train',
    transform=transforms.ToTensor(),
)
test_data = DownSizePairImageFolder(
    './lfw-deepfunneled/test',
    transform=transforms.ToTensor(),
)

# Only the training loader shuffles; both use 4 worker processes.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4)

In [3]:
from torch import nn

# CNN made of 2 convolution layers followed by 4 transposed-convolution layers.
# Glossary:
#   Conv2d:          2-D image convolution layer
#   MaxPool2d:       pooling layer (not used in this network)
#   BatchNorm2d:     batch normalisation for images
#   Dropout2d:       dropout for images (not used in this network)
#   ConvTranspose2d: transposed ("de-")convolution layer for images
#
# Every (transposed) convolution uses kernel 4, stride 2, padding 1, so each Conv2d
# halves and each ConvTranspose2d doubles the spatial size. The sizes in the comments
# below are for a (32, 32) input: 32 -> 16 -> 8 -> 16 -> 32 -> 64 -> 128.
net = nn.Sequential(
    # Encoder: 32 -> 16
    nn.Conv2d(in_channels=3, out_channels=256, kernel_size=4, stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(num_features=256),
    # Encoder: 16 -> 8
    nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(num_features=512),
    # Decoder: 8 -> 16
    nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(num_features=256),
    # Decoder: 16 -> 32
    nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(num_features=128),
    # Decoder: 32 -> 64
    nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(num_features=64),
    # Output layer: 64 -> 128, back to 3 colour channels, no activation afterwards
    nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=2, padding=1),
)

In [4]:
import torch
from torch.autograd import Variable as V
import math


# PSNR (Peak Signal-to-Noise Ratio): a kind of signal-to-noise ratio.
# psnr = 20 * log10(MAX / sqrt(MSE)) = 10 * log10(MAX**2 / MSE)
def psnr(mse, max_v=1.0):
    """Convert a mean squared error into a peak signal-to-noise ratio.

    Args:
        mse: mean squared error between prediction and target.
        max_v: largest possible signal value (``MAX`` in the formula above).

    Returns:
        ``10 * log10(max_v**2 / mse)`` as a float, or ``math.inf`` when
        ``mse`` is exactly zero (a perfect reconstruction has no noise).

    Raises:
        ValueError: if ``max_v**2 / mse`` is not positive (e.g. negative ``mse``).
    """
    # A zero error used to crash with ZeroDivisionError; the limit of the
    # formula for mse -> 0 is +infinity, so return that instead.
    if mse == 0:
        return math.inf
    return 10 * math.log10(max_v**2 / mse)

def eval_net(net, data_loader):
    """Compute the mean squared error of ``net`` over a whole data loader.

    The network is switched to evaluation mode and is left in that mode.

    Args:
        net: model that is called as ``net(x)``.
        data_loader: iterable of ``(x, y)`` batches; ``net(x)`` is compared with ``y``.

    Returns:
        The squared error averaged over every element of every batch, as a
        Python float.

    Raises:
        ValueError: if ``data_loader`` yields no elements.
    """
    net.eval()
    squared_error_sum = 0.0
    n_elements = 0
    # Evaluation needs no gradients; without no_grad every prediction would
    # keep its autograd graph alive for the entire pass over the data.
    with torch.no_grad():
        for x, y in data_loader:
            y_pred = net(x)
            # Accumulate the sum (not the mean) so differently sized batches
            # are weighted correctly, and only a scalar is kept per batch
            # instead of all predictions and targets.
            squared_error_sum += nn.functional.mse_loss(y_pred, y, reduction='sum').item()
            n_elements += y.numel()
    if n_elements == 0:
        raise ValueError('data_loader yielded no samples to evaluate')
    # Prediction quality (MSE) over the full data set.
    return squared_error_sum / n_elements

In [5]:
from torch import optim
from tqdm import tqdm


def train_net(net, train_loader, test_loader, optimizer_cls=optim.Adam, loss_fn=nn.MSELoss(), n_iter=10):
    """Train ``net`` and report train/test quality after every epoch.

    After each epoch one line is printed containing: the epoch index, the mean
    training loss, the PSNR of that training loss and the PSNR of the test MSE
    returned by ``eval_net``.

    Args:
        net: model to optimise in place.
        train_loader: iterable of ``(x, y)`` training batches; must support ``len()``.
        test_loader: iterable of ``(x, y)`` batches passed to ``eval_net``.
        optimizer_cls: optimizer class, instantiated as ``optimizer_cls(net.parameters())``.
        loss_fn: callable ``loss_fn(y_pred, y)`` returning a scalar loss tensor.
        n_iter: number of epochs.

    Returns:
        ``(train_losses, val_losses)``: two lists with one float per epoch -
        the per-batch training loss averaged over the epoch, and the test MSE.
    """
    train_losses = []
    val_losses = []
    optimizer = optimizer_cls(net.parameters())
    for epoch in range(n_iter):
        running_loss = 0.0
        # eval_net leaves the model in eval mode, so re-enable training mode every epoch.
        net.train()
        for x, y in tqdm(train_loader, total=len(train_loader)):
            y_pred = net(x)
            loss = loss_fn(y_pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() extracts the Python float; indexing a 0-dim tensor
            # (loss.data[0]) raises on current PyTorch versions.
            running_loss += loss.item()
        # Mean of the per-batch losses (each batch counts equally).
        train_losses.append(running_loss / len(train_loader))
        val_losses.append(eval_net(net, test_loader))
        print(epoch, train_losses[-1], psnr(train_losses[-1]), psnr(val_losses[-1]), flush=True)
    return train_losses, val_losses

In [6]:
# Run the training with train_net's defaults (Adam, MSE loss, 10 epochs).
# Each printed line reads: epoch, train loss, train PSNR, test PSNR.
train_net(net, train_loader=train_loader, test_loader=test_loader)

100%|██████████| 409/409 [11:05<00:00,  1.63s/it]


0 0.015116326122096123 18.205537471136527 25.55763781582705


100%|██████████| 409/409 [11:09<00:00,  1.64s/it]


1 0.003073885794080029 25.123122721395347 26.129364900174924


100%|██████████| 409/409 [14:29<00:00,  2.13s/it]


2 0.0026717450572769806 25.73204985380699 26.18474647103931


100%|██████████| 409/409 [12:34<00:00,  1.84s/it]


3 0.002590846036638254 25.865583945885295 26.8263215719221


100%|██████████| 409/409 [14:27<00:00,  2.12s/it]


4 0.0023641477855295854 26.263253787070667 25.91946915436635


100%|██████████| 409/409 [15:21<00:00,  2.25s/it]


5 0.0022378019908270035 26.5017834412386 27.31573092628002


100%|██████████| 409/409 [20:49<00:00,  3.05s/it]


6 0.0022008408172383975 26.57411368030487 26.943016642401368


100%|██████████| 409/409 [19:26<00:00,  2.85s/it]


7 0.0022078453162389018 26.56031356965084 26.218687168849023


100%|██████████| 409/409 [19:29<00:00,  2.86s/it]


8 0.0021834944065746765 26.608479162669664 27.490337290672663


100%|██████████| 409/409 [19:24<00:00,  2.85s/it]


9 0.002133573506205072 26.708923901661276 27.792880789802233


In [7]:
from torchvision.utils import save_image


# Take 4 random samples from the test data.
random_test_loader = DataLoader(test_data, batch_size=4, shuffle=True)
# Turn the loader into an iterator and pull out a single batch of 4 image pairs.
it = iter(random_test_loader)
x, y = next(it)
# Inference only: set eval mode explicitly (do not rely on whatever ran last)
# and skip autograd bookkeeping.
net.eval()
with torch.no_grad():
    # Baseline: enlarge the small images to 128 pixels with bilinear interpolation.
    # interpolate replaces the deprecated upsample; align_corners is spelled out
    # to make the interpolation convention explicit.
    bl_recon = torch.nn.functional.interpolate(x, size=128, mode='bilinear', align_corners=False)
    yp = net(x)
# Concatenate original, bilinear and CNN images with torch.cat and write them to an
# image file with save_image (nrow=4: one row per method).
save_image(torch.cat([y, bl_recon, yp], 0), 'cnn_upscale.jpg', nrow=4)