In [1]:
import glob
import os
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.autograd import Variable

In [2]:
class Model(nn.Module):
    """FSRCNN-style single-channel super-resolution network.

    Pipeline: feature extraction -> shrinking -> ``m`` mapping convs (with a
    residual connection around them) -> expanding -> sub-pixel upsampling
    (conv to ``upscale**2`` channels + PixelShuffle) -> ReLU.

    The attribute names (``fe_layer``, ``sh_layer``, ``map_layer{i}``,
    ``ex_layer``, ``deconv``, ``deconv1``, ``prelu``) are the ``state_dict``
    keys of saved checkpoints, so they must not be renamed.

    Args:
        d: channels of the feature-extraction / expanding layers.
        s: channels of the shrinking / mapping layers.
        m: number of 3x3 mapping layers.
        upscale: spatial upscaling factor of the output.
    """

    def __init__(self, d=48, s=12, m=2, upscale=2):
        super().__init__()

        # Feature extraction: 1 -> d channels, 3x3 conv, spatial size preserved.
        self.fe_layer = nn.Conv2d(1, d, kernel_size=3, stride=1, padding=1)
        # Shrinking: d -> s channels with a 1x1 conv.
        self.sh_layer = nn.Conv2d(d, s, kernel_size=1, stride=1, padding=0)

        # Mapping: m 3x3 convs registered as map_layer1 .. map_layer{m}.
        # setattr (rather than nn.ModuleList) keeps the existing checkpoint key names.
        self.m = m
        for i in range(m):
            setattr(self, f'map_layer{i+1}', nn.Conv2d(s, s, kernel_size=3, stride=1, padding=1))

        # Expanding: s -> d channels with a 1x1 conv.
        self.ex_layer = nn.Conv2d(s, d, kernel_size=1, stride=1, padding=0)

        # Transposed-conv upsampler. forward() does not use it (the sub-pixel
        # path below does the upsampling), but it is kept so existing
        # checkpoints still load with strict key matching. Its stride follows
        # `upscale`: output = (H-1)*upscale - 2 + 2 + (upscale-1) + 1 = H*upscale.
        # For upscale=2 this equals the former hard-coded stride=2/output_padding=1.
        self.deconv = nn.ConvTranspose2d(d, 1, kernel_size=3, stride=upscale, padding=1,
                                         output_padding=upscale - 1)
        # Sub-pixel upsampler: conv to upscale**2 channels, then PixelShuffle
        # rearranges them into one channel at upscale x the input resolution.
        self.deconv1 = nn.Conv2d(d, upscale**2, kernel_size=3, stride=1, padding=1)
        self.deconv2 = nn.PixelShuffle(upscale_factor=upscale)

        self.relu = nn.ReLU()
        # A single PReLU instance (one learnable slope) is shared by every hidden layer.
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Upscale a batch of single-channel images.

        Args:
            x: input tensor, expected as (N, 1, H, W) since fe_layer takes 1 channel.

        Returns:
            Tensor of shape (N, 1, H*upscale, W*upscale), clamped to >= 0 by the final ReLU.
        """
        out = self.prelu(self.fe_layer(x))
        out = self.prelu(self.sh_layer(out))
        shrunk = out  # residual input, added back after the mapping stack
        for i in range(self.m):
            out = self.prelu(getattr(self, f'map_layer{i+1}')(out))
        out = self.prelu(self.ex_layer(out + shrunk))
        out = self.deconv1(out)
        return self.relu(self.deconv2(out))

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Path to the trained FSRCNN generator weights.
MODEL_PATH = './G_model.pth'
model_torch = Model().to(device)
# Inference and ONNX export below should run in eval mode.
model_torch.eval()
# Load the FSRCNN weights. map_location remaps tensors onto `device`, so a
# checkpoint saved on a GPU also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
model_torch.load_state_dict(torch.load(MODEL_PATH, map_location=device))

<All keys matched successfully>

In [4]:
# Constant all-ones 1x1x128x128 float32 input; also reused below as the ONNX tracing example.
dummy_input = torch.ones((1, 1, 128, 128), dtype=torch.float32, device=device)

with torch.no_grad():
    # Warm-up call: the first forward pays one-off setup costs that would distort the timing.
    model_torch(dummy_input)
    if device.type == 'cuda':
        torch.cuda.synchronize(device)
    start = time.perf_counter()
    dummy_output = model_torch(dummy_input)
    # CUDA kernels run asynchronously; wait for them before stopping the clock,
    # otherwise only the kernel-launch overhead is measured.
    if device.type == 'cuda':
        torch.cuda.synchronize(device)
    end = time.perf_counter()

print(dummy_output.size())
print(f'time = {end-start}')

torch.Size([1, 1, 256, 256])
time = 0.0028531551361083984


In [5]:
# Export the loaded model to ONNX by tracing it with the dummy input built above.
ONNX_PATH = './model.onnx'
torch.onnx.export(model_torch, dummy_input, ONNX_PATH)