In [1]:
# 필요한 import문
import io
import numpy as np

from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx

In [2]:
# PyTorch에서 구현된 초해상도 모델
import torch.nn as nn
import torch.nn.init as init


class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)

# 위에서 정의된 모델을 사용하여 초해상도 모델 생성
torch_model = SuperResolutionNet(upscale_factor=3)

In [3]:
# 미리 학습된 가중치를 읽어옵니다
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1    # 임의의 수

# 모델을 미리 학습된 가중치로 초기화합니다
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

# 모델을 추론 모드로 전환합니다
torch_model.eval()

SuperResolutionNet(
  (relu): ReLU()
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(32, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pixel_shuffle): PixelShuffle(upscale_factor=3)
)

In [4]:
# 모델에 대한 입력값
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
torch_out = torch_model(x)

# 모델 변환
torch.onnx.export(torch_model,               # 실행될 모델
                  x,                         # 모델 입력값 (튜플 또는 여러 입력값들도 가능)
                  "super_resolution.onnx",   # 모델 저장 경로 (파일 또는 파일과 유사한 객체 모두 가능)
                  export_params=True,        # 모델 파일 안에 학습된 모델 가중치를 저장할지의 여부
                  opset_version=10,          # 모델을 변환할 때 사용할 ONNX 버전
                  do_constant_folding=True,  # 최적화시 상수폴딩을 사용할지의 여부
                  input_names = ['input'],   # 모델의 입력값을 가리키는 이름
                  output_names = ['output'], # 모델의 출력값을 가리키는 이름
                  dynamic_axes={'input' : {0 : 'batch_size'},    # 가변적인 길이를 가진 차원
                                'output' : {0 : 'batch_size'}})

verbose: False, log level: Level.ERROR



In [43]:
import onnx
# 필요한 import문
import io
import numpy as np

from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx

batch_size = 1
input_SequenceLength = 5
x = torch.randn(input_SequenceLength, batch_size, 3, 224, 224, requires_grad=True)
# x = torch.randn(3, 224, 224, requires_grad=True)
onnx_model = onnx.load("KSLmodel_resnet50.onnx")
onnx.checker.check_model(onnx_model)

In [49]:
import onnxruntime

ort_session = onnxruntime.InferenceSession("KSLmodel_resnet50.onnx")

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# ONNX 런타임에서 계산된 결과값
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
# print(ort_inputs.values())
print(ort_session.get_inputs()[0])
print(ort_inputs['input'].shape)
print(ort_inputs['input'])
ort_outs = ort_session.run(None, ort_inputs)

# # ONNX 런타임과 PyTorch에서 연산된 결과값 비교
# print("Exported model has been tested with ONNXRuntime, and the result looks good!")

NodeArg(name='input', type='tensor(float)', shape=['input_SequenceLength', 'BatchSize', 3, 224, 224])
(5, 1, 3, 224, 224)
[[[[[-3.04876585e-02  6.67223275e-01  1.23180974e+00 ...
     -1.33932590e-01 -1.31061876e+00 -4.81282473e-01]
    [-9.19815063e-01 -1.28460753e+00  5.33898294e-01 ...
      4.18975294e-01 -5.10674238e-01  1.98081100e+00]
    [ 3.11443138e+00  1.64364517e+00  7.04479516e-01 ...
     -2.37442470e+00  1.39032435e-02 -2.47324370e-02]
    ...
    [ 4.35406864e-01  1.39246428e+00  1.63152263e-01 ...
      1.08935285e+00 -3.74491334e-01 -1.04018725e-01]
    [-2.64780015e-01  1.41313589e+00 -1.91482961e+00 ...
      9.66548264e-01  1.38075745e+00  1.65683076e-01]
    [-1.14924169e+00  4.06958431e-01 -1.09354854e+00 ...
     -3.32650155e-01 -2.22459030e+00  8.22066605e-01]]

   [[ 1.88419390e+00 -4.93361562e-01 -2.69458145e-01 ...
      1.41047224e-01 -4.37365845e-02 -3.05013210e-01]
    [-2.30860665e-01  1.23573220e+00  7.34668255e-01 ...
     -1.28302646e+00 -4.96784419e-

In [72]:
from PIL import Image
import torchvision.transforms as transforms

img = Image.open("/home/khs/Documents/final_proj/cat_224x224.jpg")

resize = transforms.Resize([224, 224])
img = resize(img)

img_ycbcr = img.convert('YCbCr')
img_y, img_cb, img_cr = img_ycbcr.split()

to_tensor = transforms.ToTensor()

img_y = to_tensor(img)
img_y.unsqueeze_(0)
print(img_y.shape)
img = torch.stack([img_y,img_y], dim=0)
img = torch.cat([img,img], dim=0)

print(img.shape)

torch.Size([1, 3, 224, 224])
torch.Size([4, 1, 3, 224, 224])


In [76]:
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
ort_outs = ort_session.run(None, ort_inputs)
img_out_y = ort_outs[0]
print(ort_outs[0])

[[0.02254743 0.11896421 0.08581936 0.08800354 0.07835846 0.1518134
  0.16795203 0.07015263 0.21638899]]


In [6]:
img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode='L')

# PyTorch 버전의 후처리 과정 코드를 이용해 결과 이미지 만들기
final_img = Image.merge(
    "YCbCr", [
        img_out_y,
        img_cb.resize(img_out_y.size, Image.BICUBIC),
        img_cr.resize(img_out_y.size, Image.BICUBIC),
    ]).convert("RGB")

# 이미지를 저장하고 모바일 기기에서의 결과 이미지와 비교하기
final_img.save("/home/khs/Documents/final_proj/cat_224x224_2.jpg")