In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.autograd import Variable
from torchvision import transforms
from torch.utils import data

from PIL import Image

In [2]:
class SRCNN(nn.Module):
    def __init__(self):
        super(SRCNN,self).__init__()
        self.patch_extraction=nn.Conv2d(in_channels=3,out_channels=64,kernel_size=5,stride=1,padding=4)
        self.non_linear=nn.Conv2d(in_channels=64,out_channels=32,kernel_size=1,stride=1,padding=0)
        self.reconstruction=nn.Conv2d(in_channels=32,out_channels=3,kernel_size=9,stride=1,padding=2)
    def forward(self,x):
        fm_1=F.relu(self.patch_extraction(x))
        fm_2=F.relu(self.non_linear(fm_1))
        fm_3=self.reconstruction(fm_2)
        return fm_3

In [3]:
srcnn=SRCNN()
print(srcnn)


SRCNN(
  (patch_extraction): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(4, 4))
  (non_linear): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (reconstruction): Conv2d(32, 3, kernel_size=(9, 9), stride=(1, 1), padding=(2, 2))
)


In [4]:
transform=transforms.Compose([transforms.ToTensor(),])
class I91DataSet(data.Dataset):
    def __init__(self,root):
        imgs=os.listdir(root)
        self.imgs=[os.path.join(root,k) for k in imgs]
        self.transforms=transform
    def __getitem__(self, item):
        img_path=self.imgs[item]
        pil_img=Image.open(img_path)
        LR_pil_img=pil_img.resize((11,11),Image.BICUBIC)
        LR_pil_img=LR_pil_img.resize(pil_img.size)
        #LR_pil_img.show()
        #pil_img.show()
        return transforms.ToTensor()(pil_img),transforms.ToTensor()(LR_pil_img)
    def __len__(self):
        return len(self.imgs)

In [5]:
dataSet=I91DataSet('CropData')
print(dataSet[0])

(tensor([[[0.5490, 0.5922, 0.5490,  ..., 0.9294, 0.8471, 0.9451],
         [0.4549, 0.5647, 0.5059,  ..., 0.9843, 0.9922, 0.9412],
         [0.6157, 0.5490, 0.4980,  ..., 0.9804, 1.0000, 0.9725],
         ...,
         [0.7686, 0.7922, 0.8196,  ..., 0.5686, 0.5647, 0.5725],
         [0.8118, 0.8745, 0.8902,  ..., 0.5608, 0.5843, 0.5686],
         [0.8863, 0.9216, 0.9529,  ..., 0.4667, 0.4941, 0.5216]],

        [[0.4353, 0.5059, 0.4745,  ..., 0.7490, 0.6549, 0.7569],
         [0.3176, 0.4431, 0.4078,  ..., 0.8000, 0.8078, 0.7608],
         [0.4471, 0.3882, 0.3608,  ..., 0.7922, 0.8588, 0.8157],
         ...,
         [0.5569, 0.5882, 0.6235,  ..., 0.3529, 0.3490, 0.3569],
         [0.5922, 0.6588, 0.6863,  ..., 0.3529, 0.3647, 0.3451],
         [0.6667, 0.7098, 0.7490,  ..., 0.2706, 0.2824, 0.3020]],

        [[0.2157, 0.2784, 0.2431,  ..., 0.4471, 0.3451, 0.4353],
         [0.0980, 0.2196, 0.1843,  ..., 0.4941, 0.5020, 0.4549],
         [0.2314, 0.1765, 0.1490,  ..., 0.4627, 0.5647, 0

In [None]:
optimizer=optim.SGD(srcnn.parameters(),lr=0.0001)
trainLoader=data.DataLoader(dataset=dataSet,batch_size=8,shuffle=True)
criterion=nn.MSELoss()
EPOCH=200
for epoch in range(EPOCH):
    print('------------epoch {}-------------'.format(epoch))
    for i,(x,y) in enumerate(trainLoader):
        batch_x=Variable(x)
        batch_y=Variable(y)
        output=srcnn(batch_x)
        loss=criterion(output,batch_y)
        loss.backward()
        if(i%100==0):
            print(loss)

------------epoch 0-------------
tensor(0.2158, grad_fn=<MseLossBackward>)
tensor(0.2002, grad_fn=<MseLossBackward>)
tensor(0.2528, grad_fn=<MseLossBackward>)
tensor(0.2774, grad_fn=<MseLossBackward>)
tensor(0.2869, grad_fn=<MseLossBackward>)
------------epoch 1-------------
tensor(0.2321, grad_fn=<MseLossBackward>)
tensor(0.2108, grad_fn=<MseLossBackward>)
tensor(0.2321, grad_fn=<MseLossBackward>)
tensor(0.2258, grad_fn=<MseLossBackward>)
tensor(0.2528, grad_fn=<MseLossBackward>)
------------epoch 2-------------
tensor(0.2098, grad_fn=<MseLossBackward>)
tensor(0.3511, grad_fn=<MseLossBackward>)
tensor(0.1253, grad_fn=<MseLossBackward>)
tensor(0.2143, grad_fn=<MseLossBackward>)
tensor(0.2114, grad_fn=<MseLossBackward>)
------------epoch 3-------------
tensor(0.2081, grad_fn=<MseLossBackward>)
tensor(0.2315, grad_fn=<MseLossBackward>)
tensor(0.2112, grad_fn=<MseLossBackward>)
tensor(0.1476, grad_fn=<MseLossBackward>)
tensor(0.2342, grad_fn=<MseLossBackward>)
------------epoch 4---------

tensor(0.2510, grad_fn=<MseLossBackward>)
tensor(0.2176, grad_fn=<MseLossBackward>)
------------epoch 34-------------
tensor(0.1809, grad_fn=<MseLossBackward>)
tensor(0.2534, grad_fn=<MseLossBackward>)
tensor(0.1333, grad_fn=<MseLossBackward>)
tensor(0.1727, grad_fn=<MseLossBackward>)
tensor(0.2575, grad_fn=<MseLossBackward>)
------------epoch 35-------------
tensor(0.2443, grad_fn=<MseLossBackward>)
tensor(0.2043, grad_fn=<MseLossBackward>)
tensor(0.1588, grad_fn=<MseLossBackward>)
tensor(0.2410, grad_fn=<MseLossBackward>)
tensor(0.3063, grad_fn=<MseLossBackward>)
------------epoch 36-------------
tensor(0.2555, grad_fn=<MseLossBackward>)
tensor(0.2167, grad_fn=<MseLossBackward>)
tensor(0.1649, grad_fn=<MseLossBackward>)
tensor(0.2204, grad_fn=<MseLossBackward>)
tensor(0.2074, grad_fn=<MseLossBackward>)
------------epoch 37-------------
tensor(0.1974, grad_fn=<MseLossBackward>)
tensor(0.2583, grad_fn=<MseLossBackward>)
tensor(0.1844, grad_fn=<MseLossBackward>)
tensor(0.3046, grad_fn=<

tensor(0.2948, grad_fn=<MseLossBackward>)
tensor(0.2239, grad_fn=<MseLossBackward>)
tensor(0.2375, grad_fn=<MseLossBackward>)
tensor(0.1623, grad_fn=<MseLossBackward>)
------------epoch 68-------------
tensor(0.2117, grad_fn=<MseLossBackward>)
tensor(0.1985, grad_fn=<MseLossBackward>)
tensor(0.2010, grad_fn=<MseLossBackward>)
tensor(0.1924, grad_fn=<MseLossBackward>)
tensor(0.1845, grad_fn=<MseLossBackward>)
------------epoch 69-------------
tensor(0.2672, grad_fn=<MseLossBackward>)
tensor(0.2894, grad_fn=<MseLossBackward>)
tensor(0.1968, grad_fn=<MseLossBackward>)
tensor(0.1514, grad_fn=<MseLossBackward>)
tensor(0.1729, grad_fn=<MseLossBackward>)
------------epoch 70-------------
tensor(0.2788, grad_fn=<MseLossBackward>)
tensor(0.2659, grad_fn=<MseLossBackward>)
tensor(0.3539, grad_fn=<MseLossBackward>)
tensor(0.2553, grad_fn=<MseLossBackward>)
tensor(0.2481, grad_fn=<MseLossBackward>)
------------epoch 71-------------
tensor(0.1704, grad_fn=<MseLossBackward>)
tensor(0.2426, grad_fn=<

------------epoch 101-------------
tensor(0.2353, grad_fn=<MseLossBackward>)
tensor(0.1954, grad_fn=<MseLossBackward>)
tensor(0.3116, grad_fn=<MseLossBackward>)
tensor(0.3079, grad_fn=<MseLossBackward>)
tensor(0.1427, grad_fn=<MseLossBackward>)
------------epoch 102-------------
tensor(0.2436, grad_fn=<MseLossBackward>)
tensor(0.2336, grad_fn=<MseLossBackward>)
tensor(0.2270, grad_fn=<MseLossBackward>)
tensor(0.2069, grad_fn=<MseLossBackward>)
tensor(0.3158, grad_fn=<MseLossBackward>)
------------epoch 103-------------
tensor(0.2034, grad_fn=<MseLossBackward>)
tensor(0.2328, grad_fn=<MseLossBackward>)
tensor(0.2722, grad_fn=<MseLossBackward>)
tensor(0.3640, grad_fn=<MseLossBackward>)
tensor(0.2255, grad_fn=<MseLossBackward>)
------------epoch 104-------------
tensor(0.1783, grad_fn=<MseLossBackward>)
tensor(0.2454, grad_fn=<MseLossBackward>)
tensor(0.2371, grad_fn=<MseLossBackward>)
tensor(0.2015, grad_fn=<MseLossBackward>)
tensor(0.1895, grad_fn=<MseLossBackward>)
------------epoch 10

tensor(0.1801, grad_fn=<MseLossBackward>)
tensor(0.2231, grad_fn=<MseLossBackward>)
tensor(0.1938, grad_fn=<MseLossBackward>)
------------epoch 135-------------
tensor(0.3021, grad_fn=<MseLossBackward>)
tensor(0.2164, grad_fn=<MseLossBackward>)
tensor(0.1478, grad_fn=<MseLossBackward>)
tensor(0.2387, grad_fn=<MseLossBackward>)
tensor(0.2704, grad_fn=<MseLossBackward>)
------------epoch 136-------------
tensor(0.2002, grad_fn=<MseLossBackward>)
tensor(0.2255, grad_fn=<MseLossBackward>)
tensor(0.2602, grad_fn=<MseLossBackward>)
tensor(0.1809, grad_fn=<MseLossBackward>)
tensor(0.2771, grad_fn=<MseLossBackward>)
------------epoch 137-------------
tensor(0.2835, grad_fn=<MseLossBackward>)
tensor(0.3131, grad_fn=<MseLossBackward>)
tensor(0.2546, grad_fn=<MseLossBackward>)
tensor(0.1394, grad_fn=<MseLossBackward>)
tensor(0.2281, grad_fn=<MseLossBackward>)
------------epoch 138-------------
tensor(0.2278, grad_fn=<MseLossBackward>)
tensor(0.3218, grad_fn=<MseLossBackward>)
tensor(0.1752, grad_

In [None]:
pil_test=Image.open("Data/t1.bmp")
pil_test_y=transforms.ToTensor()(pil_test)
pil_test_y=pil_test_y.unsqueeze(0)
LR_pil_test=pil_test.resize((66,59),Image.BICUBIC)
LR_pil_test=LR_pil_test.resize(pil_test.size)
#LR_pil_test.show()
tensor_test=transforms.ToTensor()(LR_pil_test)
tensor_test=tensor_test.unsqueeze(0)
print(tensor_test)
tensor_test_y=srcnn(tensor_test)
print(tensor_test_y)
loss=criterion(tensor_test_y,pil_test_y)
print(loss)
tensor_test_y=tensor_test_y.squeeze(0)
result=transforms.ToPILImage()(tensor_test_y)
result.show()