In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

import cv2
import os
import numpy as np
import glob
import random
import matplotlib.pyplot as plt
import albumentations as A

# Training images: every JPEG under the YOLOv5 train image directory.
# glob.glob returns paths in arbitrary, filesystem-dependent order, so the list is
# sorted to make the file order (and the `data[:4000]` training subset) reproducible.
data = sorted(glob.glob("./yolov5/data/images/train/*.jpg"))
# Inference inputs: the folder where the cropped images were collected beforehand.
test = sorted(glob.glob("./yolov5/final_data_640/images/train/*.png"))

In [None]:
class SRdataset(Dataset):
    """Dataset of (blurred input, sharp label) image pairs for super-resolution.

    Each image is read from disk, converted to RGB, resized, scaled to [0, 1] and
    paired with a Gaussian-blurred copy of itself that acts as the low-resolution input.

    Args:
        pth: Sequence of image file paths.
        size: (width, height) every image is resized to. Defaults to 640 x 640, the
            size used for training the YOLO model.
        blur_sigma: Standard deviation of the Gaussian blur that produces the input.
    """

    def __init__(self, pth, size=(640, 640), blur_sigma=2):
        self.pth = pth
        self.size = tuple(size)
        self.blur_sigma = blur_sigma

    def __len__(self):
        """Return the number of image paths."""
        return len(self.pth)

    def __getitem__(self, idx):
        """Load image `idx` and return `(inp, label)` as float32 CHW tensors.

        Raises:
            FileNotFoundError: If the file at the given path cannot be read as an image.
        """
        path = self.pth[idx]
        label = cv2.imread(path)
        if label is None:
            # cv2.imread reports failure by returning None instead of raising; fail
            # here with the offending path rather than with an opaque cvtColor error.
            raise FileNotFoundError(f"Could not read image: {path}")
        label = cv2.cvtColor(label, cv2.COLOR_BGR2RGB)
        # Bring every image to the common training resolution.
        label = cv2.resize(label, dsize=self.size, interpolation=cv2.INTER_AREA)
        # Normalize pixel values from [0, 255] to [0, 1].
        label = label.astype(np.float32) / 255.0
        # Simulate the low-resolution input by blurring the label; a (0, 0) kernel
        # size lets OpenCV derive the kernel size from sigma.
        inp = cv2.GaussianBlur(label, (0, 0), self.blur_sigma)
        # HWC -> CHW, the layout expected by the convolution layers.
        label = np.transpose(label, (2, 0, 1))
        inp = np.transpose(inp, (2, 0, 1))

        inp = torch.tensor(inp, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.float32)

        return inp, label

In [None]:
# The training set uses the roughly 4000 original images that were collected.
train = SRdataset(data[:4000])
# Shuffle so that each epoch visits the images in a different order instead of
# replaying identical batches every epoch.
train_loader = DataLoader(train, batch_size=16, shuffle=True)
len(train), len(train_loader)

In [None]:
# Wrap the cropped-image paths in the same dataset and serve them one image per batch.
# shuffle=False (the DataLoader default) is spelled out because the inference cell
# pairs each batch with the path at the same position in `test`.
test_ = SRdataset(pth=test)
test_loader = DataLoader(dataset=test_, batch_size=1, shuffle=False)

In [None]:
# Sanity check: look at one (low resolution, label) pair before training.
# Left panel: the low-resolution image fed to the model as input.
# Right panel: the high-resolution image used as the label.
i, l = train[0]
fig, (ax_lr, ax_hr) = plt.subplots(1, 2)
ax_lr.set_title("low resolution image")
ax_lr.imshow(i.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
ax_hr.set_title("label image")
ax_hr.imshow(l.permute(1, 2, 0).numpy())

In [None]:
class SRCNN(nn.Module):
    """Three-layer SRCNN: 9x9 conv (3->64), 1x1 conv (64->32), 5x5 conv (32->3).

    Follows the basic SRCNN architecture. The input and output are both
    (N, 3, H, W) tensors of the same spatial size.
    """

    def __init__(self):
        super().__init__()

        # Every layer pads by kernel_size // 2, so each one preserves H x W on its own.
        # Previously the 9x9 layer was padded by only 2 and the 1x1 layer by 2 to make
        # up for it, which filled the two outermost rows/columns of the 1x1 layer's
        # output with replicated features instead of features computed at those pixels.
        # Parameter names and shapes are unchanged.
        self.conv1 = nn.Conv2d(3, 64, 9, padding=4, padding_mode='replicate')
        self.conv2 = nn.Conv2d(64, 32, 1)  # 1x1 kernel needs no padding
        self.conv3 = nn.Conv2d(32, 3, 5, padding=2, padding_mode='replicate')

        # Weight initialization
        torch.nn.init.kaiming_normal_(self.conv1.weight)
        torch.nn.init.kaiming_normal_(self.conv2.weight)
        torch.nn.init.kaiming_normal_(self.conv3.weight)

    def forward(self, x):
        """Map a low-resolution batch `x` to its restored version (same shape)."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # No activation on the last layer: it regresses pixel values directly.
        x = self.conv3(x)

        return x

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the network and move its parameters to the selected device.
model = SRCNN()
model = model.to(device)

In [None]:
# Mean squared error between the network output and the label image.
loss_function = nn.MSELoss(reduction="mean")
# Adam over all model parameters, with the optimizer's default hyper-parameters.
optimizer = optim.Adam(params=model.parameters())

In [None]:
# PSNR: the evaluation metric used for super-resolution.
import math

def psnr(label, outputs, max_val=1.):
    """Return the PSNR between `outputs` and `label`.

    Both tensors are detached, moved to the CPU and compared as NumPy arrays. The
    squared error is averaged over every element, so a whole batch yields one value.

    Args:
        label: Reference tensor.
        outputs: Predicted tensor to compare against `label`.
        max_val: Largest possible pixel value (1.0 by default).

    Returns:
        20 * log10(max_val / RMSE), or 100 when the RMSE is exactly zero.
    """
    reference = label.detach().cpu().numpy()
    predicted = outputs.detach().cpu().numpy()
    root_mse = math.sqrt(np.mean(np.square(predicted - reference)))
    # Identical inputs would divide by zero below, so return a fixed value instead.
    if root_mse == 0:
        return 100
    return 20 * math.log10(max_val / root_mse)

In [None]:
# Training function
def training(model, data_loader):
    """Run one training epoch over `data_loader`.

    Uses the notebook-level `device`, `optimizer`, `loss_function` and `psnr`.

    Args:
        model: Network to train; it is switched to train mode.
        data_loader: Iterable of `(image, label)` batches that also supports `len()`.

    Returns:
        `(mean loss, mean PSNR)`, each averaged over the number of batches.
    """
    model.train()
    loss_total = 0.0
    psnr_total = 0.0

    for image, label in data_loader:
        image, label = image.to(device), label.to(device)

        # Standard optimisation step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(image)
        loss = loss_function(outputs, label)
        loss.backward()
        optimizer.step()

        loss_total += loss.item()
        psnr_total += psnr(label, outputs)

    n_batches = len(data_loader)
    return loss_total / n_batches, psnr_total / n_batches

In [None]:
from tqdm import tqdm

# Train for a fixed number of epochs, keeping the per-epoch metrics and saving the
# model whenever the epoch's training PSNR matches or beats the best value so far.
epochs = 20
best_psnr = 0
train_loss, train_psnr = [], []
for epoch in tqdm(range(epochs)):
    print(f'Epoch {epoch + 1} of {epochs}')
    train_epoch_loss, train_epoch_psnr = training(model, train_loader)
    train_loss += [train_epoch_loss]
    train_psnr += [train_epoch_psnr]
    if train_epoch_psnr >= best_psnr:
        best_psnr = train_epoch_psnr
        # Saves the whole module object (not only its state_dict).
        torch.save(model, "best_psnr.pt")
    print(f'Train PSNR: {train_epoch_psnr:.3f}')

print("Best PSNR: {}".format(best_psnr))

In [None]:
# Pass every cropped image through the model to obtain its high-resolution version,
# then write the result back to the path the crop was loaded from.
# NOTE(review): SRdataset blurs each image before returning it as the model input, so
# the network receives a blurred copy of every crop here as well — confirm this is intended.
model.eval()
with torch.no_grad():
    # test_loader is unshuffled with batch_size=1, so the k-th batch belongs to test[k].
    for (x, _), save_path in zip(tqdm(test_loader), test):
        x = x.to(device)
        output_image = model(x)
        # The network works on [0, 1] floats: clamp, rescale to [0, 255] and round
        # before the uint8 cast; casting directly would truncate almost every pixel to 0.
        output = output_image.squeeze(0).clamp(0.0, 1.0).mul(255.0).round()
        output = output.cpu().numpy().astype(np.uint8)
        # CHW -> HWC, made contiguous for OpenCV.
        output = np.ascontiguousarray(np.transpose(output, (1, 2, 0)))
        # The dataset yields RGB images while cv2.imwrite expects BGR channel order.
        # cv2.imwrite returns False on failure instead of raising, so check it.
        if not cv2.imwrite(save_path, cv2.cvtColor(output, cv2.COLOR_RGB2BGR)):
            raise IOError(f"Failed to write image: {save_path}")