[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/YuichiKZ/pytorch_sr/blob/main/Section_4_2_FSRCNN.ipynb)


#FSRCNNのノートブック
- このノートブックではPyTorchによるFSRCNNネットワークの実装と学習を行い超解像画像生成を体験します。

#学習目標
- FSRCNNのネットワーク構造を理解する
- モデル学習時に必要となるデータセットクラスが構築できるようにする
- FSRCNNの損失関数と最適化アルゴリズムを実装し、ニューラルネットワークの学習を行う
- FSRCNNで超解像画像の生成を体験する

### GPU確認

In [1]:
!nvidia-smi

Sat Aug 20 08:03:39 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    24W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## FSRCNN

### Google Drive接続

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### ライブラリ

In [10]:
import torch
import torch.utils.data as data
from torchvision import transforms
from torchvision.transforms import ToTensor, RandomCrop
from PIL import Image, ImageOps
import random

from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from pathlib import Path
from math import log10

### デバイス

In [12]:
# 学習に使用するデバイスを得る。可能ならGPUを使用する
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


### パラメータ

In [3]:
num_epochs = 30000
lr1=1e-3
lr2=1e-4
batch_size = 10
scale_factor = 4
patch_size = 96

g_drive_dir = "/content/drive/MyDrive/FSRCNN/"

### Google Driveに保存ディレクトリ作成

In [6]:
import os
if not os.path.exists(g_drive_dir):
    os.makedirs(g_drive_dir)

### General-100

In [7]:
!wget -O General-100.zip https://github.com/YuichiKZ/pytorch_sr/blob/main/General-100.zip?raw=true

--2022-08-20 09:19:29--  https://github.com/YuichiKZ/pytorch_sr/blob/main/General-100.zip?raw=true
Resolving github.com (github.com)... 140.82.121.3
Connecting to github.com (github.com)|140.82.121.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github.com/YuichiKZ/pytorch_sr/raw/main/General-100.zip [following]
--2022-08-20 09:19:29--  https://github.com/YuichiKZ/pytorch_sr/raw/main/General-100.zip
Reusing existing connection to github.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/YuichiKZ/pytorch_sr/main/General-100.zip [following]
--2022-08-20 09:19:29--  https://raw.githubusercontent.com/YuichiKZ/pytorch_sr/main/General-100.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response.

In [None]:
!unzip General-100.zip

In [None]:
import os
import shutil
from glob import glob
import numpy as np

if not os.path.exists('./General-100/train'):
    os.makedirs('./General-100/train')
if not os.path.exists('./General-100/test'):
    os.makedirs('./General-100/test')
if not os.path.exists('./General-100/val'):
    os.makedirs('./General-100/val')

filenames = np.array(glob('./General-100/*.png'))

np.random.seed(0)
train_files = np.random.choice(filenames, size=80, replace=False)
for filename in train_files:
    shutil.move(filename, './General-100/train')

test_val_files = np.array(list(set(filenames) - set(train_files)))
test_files = np.random.choice(test_val_files, size=10, replace=False)
for filename in test_files:
    shutil.move(filename, './General-100/test')

val_files = np.array(list(set(test_val_files) - set(test_files)))
for filename in val_files:
    shutil.move(filename, './General-100/val')

### データセットの定義

In [13]:
class DatasetFromFolder(data.Dataset):
    def __init__(self, image_dir, patch_size, scale_factor, data_augmentation=True):
        super(DatasetFromFolder, self).__init__()
        #fileの取り出し
        self.filenames = [str(filename) for filename in Path(image_dir).glob('*') if filename.suffix in ['.bmp', '.jpg', '.png']]
        #patch_sizeの指定
        self.patch_size = patch_size
        #scale_factor(4)の指定
        self.scale_factor = scale_factor
        #data_augmentationの指定
        self.data_augmentation = data_augmentation
        #PatchSizeでRandom Cropする関数を定義する
        self.crop = RandomCrop(self.patch_size)

    def __getitem__(self, index):
        #filenames listから画像ファイルを開く
        target_img = Image.open(self.filenames[index]).convert('RGB')
        #patch_size * patch_sizeにResize
        target_img = self.crop(target_img)
        
        #augmentationデータ水増し
        if self.data_augmentation:
            if random.random() < 0.5:
                #上下反転
                target_img = ImageOps.flip(target_img)
            if random.random() < 0.5:
                #左右反転
                target_img = ImageOps.mirror(target_img)
            if random.random() < 0.5:
                #画像回転
                target_img = target_img.rotate(180)

        #96/4=24より(24,24)にResize
        input_img = target_img.resize((self.patch_size // self.scale_factor, self.patch_size // self.scale_factor), Image.BICUBIC)
        
        #補間 24*24=>96*96の画像とoriginal 96*96の画像テンソルを返す
        return ToTensor()(input_img), ToTensor()(target_img)

    def __len__(self):
        return len(self.filenames)

In [14]:
class DatasetFromFolderEval(data.Dataset):
    def __init__(self, image_dir, scale_factor):
        super(DatasetFromFolderEval, self).__init__()
        #fileの取り出し
        self.filenames = [str(filename) for filename in Path(image_dir).glob('*') if filename.suffix in ['.bmp', '.jpg', '.png']]
        #scale_factorの指定
        self.scale_factor = scale_factor

    def __getitem__(self, index):
        #filenames listから画像ファイルを開く
        target_img = Image.open(self.filenames[index]).convert('RGB')
        #画像サイズの縦横をscale_factorの倍数に変換する
        target_img_row_size = (target_img.size[0] // self.scale_factor) * self.scale_factor
        target_img_col_size = (target_img.size[1] // self.scale_factor) * self.scale_factor
        #target_imgをresizeする
        target_img = target_img.resize((target_img_row_size,target_img_col_size), Image.BICUBIC)

        #補間画像
        interpolated_img = target_img.resize((target_img.size[0] // self.scale_factor,target_img.size[1] // self.scale_factor), Image.BICUBIC)
        interpolated_img = interpolated_img.resize((target_img_row_size,target_img_col_size), Image.BICUBIC)
        
        #scale_factorで圧縮した画像を入力する
        input_img = target_img.resize((target_img.size[0] // self.scale_factor, target_img.size[1] // self.scale_factor), Image.BICUBIC)

        # 入力画像(圧縮), 補間画像,出力画像, 画像ファイル名
        return ToTensor()(input_img), ToTensor()(interpolated_img), ToTensor()(target_img), Path(self.filenames[index]).stem

    def __len__(self):
        return len(self.filenames)

In [15]:
train_set = DatasetFromFolder(image_dir='/content/General-100/train', patch_size=patch_size, scale_factor=scale_factor, data_augmentation=True)
train_loader = DataLoader(dataset=train_set, batch_size=10, shuffle=True)

val_set = DatasetFromFolderEval(image_dir='/content/General-100/val',  scale_factor=scale_factor)
val_loader = DataLoader(dataset=val_set, batch_size=1, shuffle=False)

### ネットワーク構造


In [16]:
from torch import nn
from torch.nn.functional import relu

class FSRCNN(nn.Module):
    def __init__(self,scale_factor,d=56,s=12):
        super(FSRCNN, self).__init__()

        #特徴抽出
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=d, kernel_size=5, padding=5//2),
            nn.PReLU(d)
        )
        #チャンネル縮小
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=d, out_channels=s, kernel_size=1),
            nn.PReLU()
        )
        #非線形変換
        self.layer3=[]
        for _ in range(4):
            self.layer3.extend([nn.Conv2d(in_channels=s, out_channels=s, kernel_size=3, padding=3//2), nn.PReLU()])
        self.layer3 = nn.Sequential(*self.layer3)
        #チャンネル拡張
        self.layer4 = nn.Sequential(
            nn.Conv2d(in_channels=s, out_channels=d, kernel_size=1),
            nn.PReLU()
        )
        #転置畳み込み
        self.layer5 = nn.ConvTranspose2d(in_channels=d, out_channels=3, kernel_size=9, stride=scale_factor, padding=9//2,
                                            output_padding=scale_factor-1)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        return x

### 重み初期化関数

In [14]:
def weights_init(m):
    """
    ニューラルネットワークの重みを初期化する。作成したインスタンスに対しapplyメソッドで適用する
    :param m: ニューラルネットワークを構成する層
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:            # 畳み込み層の場合
        m.weight.data.normal_(0.0, 0.001)
        m.bias.data.fill_(0)

### モデルインスタンス化

In [None]:
model = FSRCNN(scale_factor=4).to(device)
model.apply(weights_init)

### 学習

In [20]:
criterion = nn.MSELoss()
optimizer = optim.Adam([{'params': model.layer1.parameters()},
                        {'params': model.layer2.parameters()},
                        {'params': model.layer3.parameters()},
                        {'params': model.layer4.parameters()},
                        {'params': model.layer5.parameters(), 'lr': lr2}],
                        lr=lr1)

In [None]:
import torchsummary
torchsummary.summary(model,(3,24,24))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 56, 24, 24]           4,256
             PReLU-2           [-1, 56, 24, 24]              56
            Conv2d-3           [-1, 12, 24, 24]             684
             PReLU-4           [-1, 12, 24, 24]               1
            Conv2d-5           [-1, 12, 24, 24]           1,308
             PReLU-6           [-1, 12, 24, 24]               1
            Conv2d-7           [-1, 12, 24, 24]           1,308
             PReLU-8           [-1, 12, 24, 24]               1
            Conv2d-9           [-1, 12, 24, 24]           1,308
            PReLU-10           [-1, 12, 24, 24]               1
           Conv2d-11           [-1, 12, 24, 24]           1,308
            PReLU-12           [-1, 12, 24, 24]               1
           Conv2d-13           [-1, 56, 24, 24]             728
            PReLU-14           [-1, 56,

In [None]:
PSNR_list = []
PSNR_val_list = []

for epoch in range(num_epochs):
    model.train()
    epoch_loss, epoch_psnr = 0, 0
    for batch in train_loader:
        inputs, targets = batch[0].to(device), batch[1].to(device)

        optimizer.zero_grad()
        prediction = model(inputs)

        loss = criterion(prediction, targets)
        epoch_loss += loss.data
        epoch_psnr += 10 * log10(1 / loss.data)
        
        loss.backward()
        optimizer.step()
    
    PSNR_list.append(epoch_psnr / len(train_loader))
    if (epoch + 1) % 10 == 0:
        print('[Epoch {}] Loss: {:.4f}, PSNR: {:.4f} dB'.format(epoch + 1, epoch_loss / len(train_loader), epoch_psnr / len(train_loader)))
    
    model.eval()
    val_loss, val_psnr = 0, 0
    with torch.no_grad():
        for batch in val_loader:
            inputs, interpolated, targets = batch[0].to(device), batch[1].to(device), batch[2].to(device)    

            prediction = model(inputs)
            loss = criterion(prediction, targets)
            val_loss += loss.data
            val_psnr += 10 * log10(1 / loss.data)
    
    PSNR_val_list.append(val_psnr / len(val_loader))
    if (epoch + 1) % 10 == 0:
        print("===> Validation Loss: {:.4f}, Validation PSNR: {:.4f} dB".format(val_loss / len(val_loader), val_psnr / len(val_loader)))
    
    """
    モデルの保存
    """
    if (epoch + 1) % 2000 == 0:   # 2000エポックごとにモデルを保存
        torch.save(model.state_dict(), '{}/FSRCNN_epoch_{}.pth'.format(g_drive_dir, epoch + 1))
    

In [None]:
import torchsummary
torchsummary.summary(model,(3,99,111))

### PSNR表示

In [None]:
from matplotlib import pyplot as plt

plt.plot(PSNR_list, label='PSNR')
plt.plot(PSNR_val_list, label='PSNR_val')
plt.legend()
plt.savefig('PSNR.png')

### テスト

In [23]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from pathlib import Path
from math import log10


save_dir = os.path.join(g_drive_dir,'pic')
if not os.path.exists(save_dir):
    os.mkdir(save_dir)
is_cuda = True

test_set = DatasetFromFolderEval(image_dir='/content/General-100/test', scale_factor=4)
test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)

# model.load_state_dict(torch.load(opt.weight_path, map_location='cuda' if opt.cuda else 'cpu'))

model.eval()
total_loss, total_psnr = 0, 0
total_loss_b, total_psnr_b = 0, 0
with torch.no_grad():
    for batch in test_loader:
        inputs,interpolated, targets = batch[0].to(device), batch[1].to(device), batch[2].to(device)

        prediction = model(inputs)

        loss = criterion(prediction, targets)
        total_loss += loss.data
        total_psnr += 10 * log10(1 / loss.data)

        loss = criterion(prediction, targets)
        total_loss_b += loss.data
        total_psnr_b += 10 * log10(1 / loss.data)

        save_image(prediction, Path(save_dir) / '{}_sr.png'.format(batch[3][0]), nrow=1)
        save_image(interpolated, Path(save_dir) / '{}_lr.png'.format(batch[3][0]), nrow=1)
        save_image(targets, Path(save_dir) / '{}_hr.png'.format(batch[3][0]), nrow=1)

print("===> [Bicubic] Avg. Loss: {:.4f}, PSNR: {:.4f} dB".format(total_loss_b / len(test_loader), total_psnr_b / len(test_loader)))
print("===> [SRCNN] Avg. Loss: {:.4f}, PSNR: {:.4f} dB".format(total_loss / len(test_loader), total_psnr / len(test_loader)))

===> [Bicubic] Avg. Loss: 0.0016, PSNR: 29.2148 dB
===> [SRCNN] Avg. Loss: 0.0016, PSNR: 29.2148 dB


### 学習済みモデルを利用する

In [18]:
# githubから30000epoch学習済みモデルをダウンロード
!wget -O FSRCNN_epoch_30000.pth https://github.com/YuichiKZ/pytorch_sr/blob/main/weights/FSRCNN_epoch_30000.pth?raw=true

--2022-08-20 08:17:19--  https://github.com/YuichiKZ/pytorch_sr/blob/main/weights/FSRCNN_epoch_30000.pth?raw=true
Resolving github.com (github.com)... 140.82.121.3
Connecting to github.com (github.com)|140.82.121.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github.com/YuichiKZ/pytorch_sr/raw/main/weights/FSRCNN_epoch_30000.pth [following]
--2022-08-20 08:17:19--  https://github.com/YuichiKZ/pytorch_sr/raw/main/weights/FSRCNN_epoch_30000.pth
Reusing existing connection to github.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/YuichiKZ/pytorch_sr/main/weights/FSRCNN_epoch_30000.pth [following]
--2022-08-20 08:17:19--  https://raw.githubusercontent.com/YuichiKZ/pytorch_sr/main/weights/FSRCNN_epoch_30000.pth
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.co

In [18]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from pathlib import Path
from math import log10

save_dir = os.path.join(g_drive_dir,'pic')
if not os.path.exists(save_dir):
    os.mkdir(save_dir)

model = FSRCNN(scale_factor=4).to(device)
model.load_state_dict(torch.load("/content/FSRCNN_epoch_30000.pth"))

test_set = DatasetFromFolderEval(image_dir='/content/General-100/test', scale_factor=4)
test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)

criterion = nn.MSELoss()

model.eval()
total_loss, total_psnr = 0, 0
total_loss_b, total_psnr_b = 0, 0
with torch.no_grad():
    for batch in test_loader:
        inputs,interpolated, targets = batch[0].to(device), batch[1].to(device), batch[2].to(device)

        prediction = model(inputs)

        loss = criterion(prediction, targets)
        total_loss += loss.data
        total_psnr += 10 * log10(1 / loss.data)

        loss = criterion(prediction, targets)
        total_loss_b += loss.data
        total_psnr_b += 10 * log10(1 / loss.data)

        save_image(prediction, Path(save_dir) / '{}_sr.png'.format(batch[3][0]), nrow=1)
        save_image(interpolated, Path(save_dir) / '{}_lr.png'.format(batch[3][0]), nrow=1)
        save_image(targets, Path(save_dir) / '{}_hr.png'.format(batch[3][0]), nrow=1)

print("===> [Bicubic] Avg. Loss: {:.4f}, PSNR: {:.4f} dB".format(total_loss_b / len(test_loader), total_psnr_b / len(test_loader)))
print("===> [SRCNN] Avg. Loss: {:.4f}, PSNR: {:.4f} dB".format(total_loss / len(test_loader), total_psnr / len(test_loader)))

===> [Bicubic] Avg. Loss: 0.0016, PSNR: 29.2148 dB
===> [SRCNN] Avg. Loss: 0.0016, PSNR: 29.2148 dB
