[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/YuichiKZ/pytorch_sr/blob/main/Section_3_2_SRCNN.ipynb)


# SRCNNのノートブック
- このノートブックではPyTorchによるSRCNNネットワークの実装と学習を行い超解像画像生成を体験します。

# 学習目標
- SRCNNのネットワーク構造を理解する
- モデル学習時に必要となるデータセットクラスが構築できるようにする
- SRCNNの損失関数と最適化アルゴリズムを実装し、ニューラルネットワークの学習を行う
- SRCNNで超解像画像の生成を体験する


### GPU確認

In [1]:
# Show the NVIDIA driver / GPU status of this runtime.
!nvidia-smi

Sat Aug 20 05:17:19 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    24W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## SRCNN



### Google Driveマウント

In [19]:
# Mount Google Drive at /content/drive; checkpoints and result images are written below this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### ライブラリ

In [2]:
import torch
import torch.utils.data as data
from torchvision import transforms
from torchvision.transforms import ToTensor, RandomCrop
from PIL import Image, ImageOps
import random

from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from pathlib import Path
from math import log10

### デバイス

In [3]:
# Pick the device used for training: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


### パラメータ

In [4]:
# Training configuration.
num_epochs = 30_000   # number of passes over the training loader
lr1 = 1e-4            # default Adam learning rate
lr2 = 1e-5            # learning rate of the parameter group that overrides the default
batch_size = 10       # patches per training batch
scale_factor = 4      # down-/up-sampling factor used to synthesise low-resolution inputs
patch_size = 96       # side length (pixels) of the square training crops

# Google Drive folder that receives checkpoints and result images.
g_drive_dir = "/content/drive/MyDrive/SRCNN/"

### Google Driveに保存ディレクトリ作成

In [20]:
import os
# Create the output folder together with any missing parents.
# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race.
os.makedirs(g_drive_dir, exist_ok=True)

### General-100

In [5]:
# Download the General-100 archive from the GitHub repository and store it as General-100.zip.
!wget -O General-100.zip https://github.com/YuichiKZ/pytorch_sr/blob/main/General-100.zip?raw=true

--2022-08-20 05:17:31--  https://github.com/YuichiKZ/pytorch_sr/blob/main/General-100.zip?raw=true
Resolving github.com (github.com)... 140.82.121.3
Connecting to github.com (github.com)|140.82.121.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github.com/YuichiKZ/pytorch_sr/raw/main/General-100.zip [following]
--2022-08-20 05:17:31--  https://github.com/YuichiKZ/pytorch_sr/raw/main/General-100.zip
Reusing existing connection to github.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/YuichiKZ/pytorch_sr/main/General-100.zip [following]
--2022-08-20 05:17:31--  https://raw.githubusercontent.com/YuichiKZ/pytorch_sr/main/General-100.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response.

In [None]:
# Extract the downloaded archive; the split cell expects the PNG files under ./General-100/.
!unzip General-100.zip

In [7]:
import os
import shutil
from glob import glob
import numpy as np

if not os.path.exists('./General-100/train'):
    os.makedirs('./General-100/train')
if not os.path.exists('./General-100/test'):
    os.makedirs('./General-100/test')
if not os.path.exists('./General-100/val'):
    os.makedirs('./General-100/val')

filenames = np.array(glob('./General-100/*.png'))

np.random.seed(0)
train_files = np.random.choice(filenames, size=80, replace=False)
for filename in train_files:
    shutil.move(filename, './General-100/train')

test_val_files = np.array(list(set(filenames) - set(train_files)))
test_files = np.random.choice(test_val_files, size=10, replace=False)
for filename in test_files:
    shutil.move(filename, './General-100/test')

val_files = np.array(list(set(test_val_files) - set(test_files)))
for filename in val_files:
    shutil.move(filename, './General-100/val')

### データセットの定義

In [8]:
class DatasetFromFolder(data.Dataset):
    """Training dataset yielding ``(degraded_patch, original_patch)`` tensor pairs.

    Every item is a random ``patch_size`` x ``patch_size`` crop of one image,
    optionally flipped / mirrored / rotated. The network input is that crop
    bicubically down-sampled by ``scale_factor`` and up-sampled back to the crop
    size; the target is the crop itself.

    :param image_dir: folder whose image files are used
    :param patch_size: side length of the square crop
    :param scale_factor: factor by which the crop is shrunk before re-enlarging
    :param data_augmentation: apply the random flip / mirror / rotation when True
    """

    # Lower-case file extensions that are treated as images.
    IMAGE_SUFFIXES = ('.bmp', '.jpg', '.png')

    def __init__(self, image_dir, patch_size, scale_factor, data_augmentation=True):
        super(DatasetFromFolder, self).__init__()
        self.filenames = self._list_image_files(image_dir)
        self.patch_size = patch_size
        self.scale_factor = scale_factor
        self.data_augmentation = data_augmentation
        # Transform that cuts a random patch_size x patch_size region out of an image.
        self.crop = RandomCrop(self.patch_size)

    @staticmethod
    def _list_image_files(image_dir):
        """Return the sorted paths of the image files directly inside ``image_dir``.

        Sorting makes the index -> file mapping independent of the order in which
        the file system lists entries; the suffix test ignores case and
        directories are skipped.
        """
        return sorted(
            str(path)
            for path in Path(image_dir).glob('*')
            if path.is_file() and path.suffix.lower() in DatasetFromFolder.IMAGE_SUFFIXES
        )

    def __getitem__(self, index):
        """Return ``(input_tensor, target_tensor)`` for the image at ``index``."""
        # convert() yields an independent in-memory copy, so the file can be closed right away.
        with Image.open(self.filenames[index]) as source_img:
            target_img = source_img.convert('RGB')

        # Random crop (not a resize) down to patch_size x patch_size.
        target_img = self.crop(target_img)

        # Data augmentation: each transform is applied independently with probability 0.5.
        if self.data_augmentation:
            if random.random() < 0.5:
                # vertical flip
                target_img = ImageOps.flip(target_img)
            if random.random() < 0.5:
                # horizontal mirror
                target_img = ImageOps.mirror(target_img)
            if random.random() < 0.5:
                # 180 degree rotation
                target_img = target_img.rotate(180)

        # Synthesise the low-resolution input: shrink by scale_factor
        # (e.g. 96 -> 24 for a factor of 4) ...
        low_res_side = self.patch_size // self.scale_factor
        input_img = target_img.resize((low_res_side, low_res_side), Image.BICUBIC)
        # ... then enlarge back to the patch size so input and target shapes match.
        input_img = input_img.resize((self.patch_size, self.patch_size), Image.BICUBIC)

        return ToTensor()(input_img), ToTensor()(target_img)

    def __len__(self):
        """Number of image files found in ``image_dir``."""
        return len(self.filenames)

In [9]:
class DatasetFromFolderEval(data.Dataset):
    """Evaluation dataset yielding ``(degraded_image, original_image, file_stem)``.

    Images are used at full size (no cropping or augmentation). The network
    input is the image bicubically down-sampled by ``scale_factor`` and
    up-sampled back to its original size.

    :param image_dir: folder whose image files are used
    :param scale_factor: factor by which each image is shrunk before re-enlarging
    """

    # Lower-case file extensions that are treated as images.
    IMAGE_SUFFIXES = ('.bmp', '.jpg', '.png')

    def __init__(self, image_dir, scale_factor):
        super(DatasetFromFolderEval, self).__init__()
        self.filenames = self._list_image_files(image_dir)
        self.scale_factor = scale_factor

    @staticmethod
    def _list_image_files(image_dir):
        """Return the sorted paths of the image files directly inside ``image_dir``.

        Sorting makes the index -> file mapping independent of the order in which
        the file system lists entries; the suffix test ignores case and
        directories are skipped.
        """
        return sorted(
            str(path)
            for path in Path(image_dir).glob('*')
            if path.is_file() and path.suffix.lower() in DatasetFromFolderEval.IMAGE_SUFFIXES
        )

    def __getitem__(self, index):
        """Return ``(input_tensor, target_tensor, stem)`` for the image at ``index``."""
        # convert() yields an independent in-memory copy, so the file can be closed right away.
        with Image.open(self.filenames[index]) as source_img:
            target_img = source_img.convert('RGB')

        # Down-sample to the low-resolution size; clamp to 1 px so that images
        # narrower than scale_factor do not request a zero-sized resize.
        width, height = target_img.size
        low_res_size = (max(1, width // self.scale_factor), max(1, height // self.scale_factor))
        input_img = target_img.resize(low_res_size, Image.BICUBIC)
        # Enlarge back to the original size so the image can be fed to SRCNN.
        input_img = input_img.resize(target_img.size, Image.BICUBIC)

        return ToTensor()(input_img), ToTensor()(target_img), Path(self.filenames[index]).stem

    def __len__(self):
        """Number of image files found in ``image_dir``."""
        return len(self.filenames)

In [10]:
# Training data: random augmented patches, reshuffled every epoch.
train_set = DatasetFromFolder(
    image_dir='/content/General-100/train',
    patch_size=patch_size,
    scale_factor=scale_factor,
    data_augmentation=True,
)
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)

# Validation data: full-size images, one per batch, in a fixed order.
val_set = DatasetFromFolderEval(
    image_dir='/content/General-100/val',
    scale_factor=scale_factor,
)
val_loader = DataLoader(dataset=val_set, batch_size=1, shuffle=False)

### ネットワーク構造

In [11]:
from torch import nn

class SRCNN(nn.Module):
    """Three-stage convolutional network mapping a 3-channel image to a 3-channel image.

    Every convolution is padded so the spatial size is preserved, and every
    stage ends in a ReLU, so the network output is non-negative.
    """

    def __init__(self):
        super().__init__()

        # Stage 1: patch feature extraction (9x9 kernels, 3 -> 64 channels).
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.ReLU(),
        )
        # Stage 2: non-linear mapping (1x1 kernels, 64 -> 32 channels).
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=1, padding=0),
            nn.ReLU(),
        )
        # Stage 3: image reconstruction (5x5 kernels, 32 -> 3 channels).
        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=5, padding=2),
            nn.ReLU(),
        )

    def forward(self, x):
        """Run ``x`` through the three stages in order and return the result."""
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        return x

### 重み初期化関数

In [12]:
def weights_init(m):
    """
    Initialise the parameters of a convolution layer in place.

    Intended to be passed to ``Module.apply`` so it is called on every
    sub-module. Modules whose class name contains ``'Conv'`` get weights drawn
    from a normal distribution (mean 0, std 0.001) and, when they have a bias,
    a zeroed bias; all other modules are left untouched.

    :param m: a layer of the neural network
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:                     # convolution layer
        m.weight.data.normal_(0.0, 0.001)
        # Layers created with bias=False expose bias as None.
        if m.bias is not None:
            m.bias.data.fill_(0)

### モデルインスタンス化

In [13]:
# Build the network, move it to the selected device, then re-initialise its convolution layers.
# The final expression returns the module itself, so the cell displays its structure.
model = SRCNN()
model = model.to(device)
model.apply(weights_init)

SRCNN(
  (layer1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
    (1): ReLU()
  )
  (layer2): Sequential(
    (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
  )
  (layer3): Sequential(
    (0): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
  )
)

In [16]:
import torchsummary
# Summarise per-layer output shapes and parameter counts for one training-sized RGB patch.
# Using patch_size keeps the summary in step with the configured crop size.
torchsummary.summary(model, (3, patch_size, patch_size))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 96, 96]          15,616
              ReLU-2           [-1, 64, 96, 96]               0
            Conv2d-3           [-1, 32, 96, 96]           2,080
              ReLU-4           [-1, 32, 96, 96]               0
            Conv2d-5            [-1, 3, 96, 96]           2,403
              ReLU-6            [-1, 3, 96, 96]               0
Total params: 20,099
Trainable params: 20,099
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.11
Forward/backward pass size (MB): 13.92
Params size (MB): 0.08
Estimated Total Size (MB): 14.10
----------------------------------------------------------------


### 学習

In [17]:
# Mean squared error between the network output and the target image.
criterion = nn.MSELoss()

# One parameter group per stage: layer1 and layer2 use the default rate lr1,
# while layer3 overrides it with lr2.
param_groups = [
    {'params': model.layer1.parameters()},
    {'params': model.layer2.parameters()},
    {'params': model.layer3.parameters(), 'lr': lr2},
]
optimizer = optim.Adam(param_groups, lr=lr1)

In [None]:
def _psnr_from_mse(mse):
    """Convert a mean squared error to PSNR in dB for a peak signal value of 1.

    A zero error maps to ``inf`` instead of raising ``ZeroDivisionError``.
    """
    return float('inf') if mse == 0 else 10 * log10(1 / mse)


PSNR_list = []       # mean training PSNR of every epoch
PSNR_val_list = []   # mean validation PSNR of every epoch

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    epoch_loss, epoch_psnr = 0.0, 0.0
    for batch in train_loader:
        inputs, targets = batch[0].to(device), batch[1].to(device)

        optimizer.zero_grad()
        prediction = model(inputs)
        loss = criterion(prediction, targets)
        loss.backward()
        optimizer.step()

        # .item() copies the scalar loss to a Python float (no autograd history kept).
        batch_mse = loss.item()
        epoch_loss += batch_mse
        epoch_psnr += _psnr_from_mse(batch_mse)

    PSNR_list.append(epoch_psnr / len(train_loader))
    if (epoch + 1) % 10 == 0:
        print('[Epoch {}] Loss: {:.4f}, PSNR: {:.4f} dB'.format(epoch + 1, epoch_loss / len(train_loader), epoch_psnr / len(train_loader)))

    # ---- validation pass (no gradients) ----
    model.eval()
    val_loss, val_psnr = 0.0, 0.0
    with torch.no_grad():
        for batch in val_loader:
            inputs, targets = batch[0].to(device), batch[1].to(device)

            prediction = model(inputs)
            batch_mse = criterion(prediction, targets).item()
            val_loss += batch_mse
            val_psnr += _psnr_from_mse(batch_mse)

    PSNR_val_list.append(val_psnr / len(val_loader))
    if (epoch + 1) % 10 == 0:
        print("===> Validation Loss: {:.4f}, Validation PSNR: {:.4f} dB".format(val_loss / len(val_loader), val_psnr / len(val_loader)))

    # Save a checkpoint of the model weights every 2000 epochs.
    if (epoch + 1) % 2000 == 0:
        torch.save(model.state_dict(), '{}/SRCNN_epoch_{}.pth'.format(g_drive_dir, epoch + 1))
    

### PSNR表示

In [None]:
from matplotlib import pyplot as plt

# Plot the per-epoch training and validation PSNR curves and save the figure as PSNR.png.
fig, ax = plt.subplots()
ax.plot(PSNR_list, label='PSNR')
ax.plot(PSNR_val_list, label='PSNR_val')
# Axis labels let the figure stand on its own: one point per epoch, PSNR values in dB.
ax.set_xlabel('Epoch')
ax.set_ylabel('PSNR (dB)')
ax.legend()
fig.savefig('PSNR.png')

### テスト

In [22]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from pathlib import Path
from math import log10

def _psnr_from_mse(mse):
    """Convert a mean squared error to PSNR in dB for a peak signal value of 1.

    A zero error maps to ``inf`` instead of raising ``ZeroDivisionError``.
    """
    return float('inf') if mse == 0 else 10 * log10(1 / mse)


criterion = nn.MSELoss()

# Folder on Google Drive that receives the result images; parents are created if missing.
save_dir = os.path.join(g_drive_dir, 'pic')
os.makedirs(save_dir, exist_ok=True)

# Evaluate at the configured scale so the test matches the scale used for training.
test_set = DatasetFromFolderEval(image_dir='/content/General-100/test', scale_factor=scale_factor)
test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)

model.eval()
total_loss, total_psnr = 0.0, 0.0        # SRCNN output vs. ground truth
total_loss_b, total_psnr_b = 0.0, 0.0    # bicubic-interpolated input vs. ground truth
with torch.no_grad():
    for batch in test_loader:
        inputs, targets = batch[0].to(device), batch[1].to(device)
        stem = batch[2][0]  # batch_size is 1, so the collated stems hold a single entry

        prediction = model(inputs)
        sr_mse = criterion(prediction, targets).item()
        total_loss += sr_mse
        total_psnr += _psnr_from_mse(sr_mse)

        # Baseline: error of the bicubic input itself.
        bicubic_mse = criterion(inputs, targets).item()
        total_loss_b += bicubic_mse
        total_psnr_b += _psnr_from_mse(bicubic_mse)

        # Save the super-resolved, bicubic and ground-truth images to Google Drive.
        save_image(prediction, Path(save_dir) / '{}_sr.png'.format(stem), nrow=1)
        save_image(inputs, Path(save_dir) / '{}_lr.png'.format(stem), nrow=1)
        save_image(targets, Path(save_dir) / '{}_hr.png'.format(stem), nrow=1)

print("===> [Bicubic] Avg. Loss: {:.4f}, PSNR: {:.4f} dB".format(total_loss_b / len(test_loader), total_psnr_b / len(test_loader)))
print("===> [SRCNN] Avg. Loss: {:.4f}, PSNR: {:.4f} dB".format(total_loss / len(test_loader), total_psnr / len(test_loader)))

===> [Bicubic] Avg. Loss: 0.0046, PSNR: 25.1120 dB
===> [SRCNN] Avg. Loss: 0.0039, PSNR: 25.9509 dB


### 学習済み重みを利用する

In [14]:
# Download the model weights trained for 30000 epochs from GitHub.
!wget -O SRCNN_epoch_30000.pth https://github.com/YuichiKZ/pytorch_sr/blob/main/weights/SRCNN_epoch_30000.pth?raw=true

--2022-08-20 05:18:15--  https://github.com/YuichiKZ/pytorch_sr/blob/main/weights/SRCNN_epoch_30000.pth?raw=true
Resolving github.com (github.com)... 140.82.121.3
Connecting to github.com (github.com)|140.82.121.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github.com/YuichiKZ/pytorch_sr/raw/main/weights/SRCNN_epoch_30000.pth [following]
--2022-08-20 05:18:16--  https://github.com/YuichiKZ/pytorch_sr/raw/main/weights/SRCNN_epoch_30000.pth
Reusing existing connection to github.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/YuichiKZ/pytorch_sr/main/weights/SRCNN_epoch_30000.pth [following]
--2022-08-20 05:18:16--  https://raw.githubusercontent.com/YuichiKZ/pytorch_sr/main/weights/SRCNN_epoch_30000.pth
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|18

In [21]:
def _psnr_from_mse(mse):
    """Convert a mean squared error to PSNR in dB for a peak signal value of 1.

    A zero error maps to ``inf`` instead of raising ``ZeroDivisionError``.
    """
    return float('inf') if mse == 0 else 10 * log10(1 / mse)


# Defined here as well so this cell also runs when the training section was skipped.
criterion = nn.MSELoss()

# Folder on Google Drive that receives the result images; parents are created if missing.
save_dir = os.path.join(g_drive_dir, 'pic')
os.makedirs(save_dir, exist_ok=True)

model = SRCNN().to(device)
# map_location remaps the stored tensors onto the active device, so the checkpoint
# also loads on a CPU-only runtime.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
model.load_state_dict(torch.load("/content/SRCNN_epoch_30000.pth", map_location=device))

# NOTE(review): the scale is fixed at 4 here, presumably the scale the downloaded
# weights were trained at - confirm before changing.
test_set = DatasetFromFolderEval(image_dir='/content/General-100/test', scale_factor=4)
test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)

model.eval()
total_loss, total_psnr = 0.0, 0.0        # SRCNN output vs. ground truth
total_loss_b, total_psnr_b = 0.0, 0.0    # bicubic-interpolated input vs. ground truth
with torch.no_grad():
    for batch in test_loader:
        inputs, targets = batch[0].to(device), batch[1].to(device)
        stem = batch[2][0]  # batch_size is 1, so the collated stems hold a single entry

        prediction = model(inputs)
        sr_mse = criterion(prediction, targets).item()
        total_loss += sr_mse
        total_psnr += _psnr_from_mse(sr_mse)

        # Baseline: error of the bicubic input itself.
        bicubic_mse = criterion(inputs, targets).item()
        total_loss_b += bicubic_mse
        total_psnr_b += _psnr_from_mse(bicubic_mse)

        # Save the super-resolved, bicubic and ground-truth images to Google Drive.
        save_image(prediction, Path(save_dir) / '{}_sr.png'.format(stem), nrow=1)
        save_image(inputs, Path(save_dir) / '{}_lr.png'.format(stem), nrow=1)
        save_image(targets, Path(save_dir) / '{}_hr.png'.format(stem), nrow=1)

print("===> [Bicubic] Avg. Loss: {:.4f}, PSNR: {:.4f} dB".format(total_loss_b / len(test_loader), total_psnr_b / len(test_loader)))
print("===> [SRCNN] Avg. Loss: {:.4f}, PSNR: {:.4f} dB".format(total_loss / len(test_loader), total_psnr / len(test_loader)))

===> [Bicubic] Avg. Loss: 0.0046, PSNR: 25.1120 dB
===> [SRCNN] Avg. Loss: 0.0039, PSNR: 25.9509 dB
