In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.autograd import Variable
from torchvision import transforms
from torch.utils import data

from PIL import Image

In [2]:
class SRCNN(nn.Module):
    def __init__(self):
        super(SRCNN,self).__init__()
        self.patch_extraction=nn.Conv2d(in_channels=3,out_channels=64,kernel_size=5,stride=1,padding=4)
        self.non_linear=nn.Conv2d(in_channels=64,out_channels=32,kernel_size=1,stride=1,padding=0)
        self.reconstruction=nn.Conv2d(in_channels=32,out_channels=3,kernel_size=9,stride=1,padding=2)
    def forward(self,x):
        fm_1=F.relu(self.patch_extraction(x))
        fm_2=F.relu(self.non_linear(fm_1))
        fm_3=self.reconstruction(fm_2)
        return fm_3

In [3]:
srcnn=SRCNN()
print(srcnn)


SRCNN(
  (patch_extraction): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(4, 4))
  (non_linear): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (reconstruction): Conv2d(32, 3, kernel_size=(9, 9), stride=(1, 1), padding=(2, 2))
)


In [4]:
transform=transforms.Compose([transforms.ToTensor(),])
class I91DataSet(data.Dataset):
    def __init__(self,root):
        imgs=os.listdir(root)
        self.imgs=[os.path.join(root,k) for k in imgs]
        self.transforms=transform
    def __getitem__(self, item):
        img_path=self.imgs[item]
        pil_img=Image.open(img_path)
        LR_pil_img=pil_img.resize((11,11),Image.BICUBIC)
        LR_pil_img=LR_pil_img.resize(pil_img.size)
        #LR_pil_img.show()
        #pil_img.show()
        return transforms.ToTensor()(pil_img),transforms.ToTensor()(LR_pil_img)
    def __len__(self):
        return len(self.imgs)

In [5]:
dataSet=I91DataSet('CropData')
print(dataSet[0])

(tensor([[[0.5490, 0.5922, 0.5490,  ..., 0.9294, 0.8471, 0.9451],
         [0.4549, 0.5647, 0.5059,  ..., 0.9843, 0.9922, 0.9412],
         [0.6157, 0.5490, 0.4980,  ..., 0.9804, 1.0000, 0.9725],
         ...,
         [0.7686, 0.7922, 0.8196,  ..., 0.5686, 0.5647, 0.5725],
         [0.8118, 0.8745, 0.8902,  ..., 0.5608, 0.5843, 0.5686],
         [0.8863, 0.9216, 0.9529,  ..., 0.4667, 0.4941, 0.5216]],

        [[0.4353, 0.5059, 0.4745,  ..., 0.7490, 0.6549, 0.7569],
         [0.3176, 0.4431, 0.4078,  ..., 0.8000, 0.8078, 0.7608],
         [0.4471, 0.3882, 0.3608,  ..., 0.7922, 0.8588, 0.8157],
         ...,
         [0.5569, 0.5882, 0.6235,  ..., 0.3529, 0.3490, 0.3569],
         [0.5922, 0.6588, 0.6863,  ..., 0.3529, 0.3647, 0.3451],
         [0.6667, 0.7098, 0.7490,  ..., 0.2706, 0.2824, 0.3020]],

        [[0.2157, 0.2784, 0.2431,  ..., 0.4471, 0.3451, 0.4353],
         [0.0980, 0.2196, 0.1843,  ..., 0.4941, 0.5020, 0.4549],
         [0.2314, 0.1765, 0.1490,  ..., 0.4627, 0.5647, 0

In [12]:
optimizer=optim.Adam(srcnn.parameters(),lr=0.0001)
trainLoader=data.DataLoader(dataset=dataSet,batch_size=8,shuffle=True)
criterion=nn.MSELoss()
EPOCH=50
for epoch in range(EPOCH):
    print('------------epoch {}-------------'.format(epoch))
    for i,(x,y) in enumerate(trainLoader):
        batch_x=Variable(x)
        batch_y=Variable(y)
        output=srcnn(batch_x)
        
        loss=criterion(output,batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if(i%100==0):
            print(loss)

------------epoch 0-------------
tensor(0.3377, grad_fn=<MseLossBackward>)
tensor(0.0177, grad_fn=<MseLossBackward>)
tensor(0.0053, grad_fn=<MseLossBackward>)
tensor(0.0032, grad_fn=<MseLossBackward>)
tensor(0.0014, grad_fn=<MseLossBackward>)
------------epoch 1-------------
tensor(0.0012, grad_fn=<MseLossBackward>)
tensor(0.0011, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0009, grad_fn=<MseLossBackward>)
tensor(0.0006, grad_fn=<MseLossBackward>)
------------epoch 2-------------
tensor(0.0008, grad_fn=<MseLossBackward>)
tensor(0.0009, grad_fn=<MseLossBackward>)
tensor(0.0005, grad_fn=<MseLossBackward>)
tensor(0.0009, grad_fn=<MseLossBackward>)
tensor(0.0003, grad_fn=<MseLossBackward>)
------------epoch 3-------------
tensor(0.0006, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0003, grad_fn=<MseLossBackward>)
tensor(0.0002, grad_fn=<MseLossBackward>)
tensor(0.0004, grad_fn=<MseLossBackward>)
------------epoch 4---------

tensor(7.0710e-05, grad_fn=<MseLossBackward>)
tensor(6.4401e-05, grad_fn=<MseLossBackward>)
tensor(7.2163e-05, grad_fn=<MseLossBackward>)
tensor(7.1149e-05, grad_fn=<MseLossBackward>)
------------epoch 34-------------
tensor(0.0001, grad_fn=<MseLossBackward>)
tensor(0.0001, grad_fn=<MseLossBackward>)
tensor(0.0001, grad_fn=<MseLossBackward>)
tensor(6.3966e-05, grad_fn=<MseLossBackward>)
tensor(5.9999e-05, grad_fn=<MseLossBackward>)
------------epoch 35-------------
tensor(0.0002, grad_fn=<MseLossBackward>)
tensor(6.0937e-05, grad_fn=<MseLossBackward>)
tensor(5.4921e-05, grad_fn=<MseLossBackward>)
tensor(0.0001, grad_fn=<MseLossBackward>)
tensor(9.8094e-05, grad_fn=<MseLossBackward>)
------------epoch 36-------------
tensor(0.0001, grad_fn=<MseLossBackward>)
tensor(5.1719e-05, grad_fn=<MseLossBackward>)
tensor(0.0001, grad_fn=<MseLossBackward>)
tensor(9.1182e-05, grad_fn=<MseLossBackward>)
tensor(5.6950e-05, grad_fn=<MseLossBackward>)
------------epoch 37-------------
tensor(0.0001, gra

KeyboardInterrupt: 

In [13]:
pil_test=Image.open('test.jpg')
print(pil_test)
pil_test_y=transforms.ToTensor()(pil_test)
pil_test_y=pil_test_y.unsqueeze(0)
LR_pil_test=pil_test.resize((640,480),Image.BICUBIC)
LR_pil_test=LR_pil_test.resize(pil_test.size)

#LR_pil_test.show()
tensor_test=transforms.ToTensor()(LR_pil_test)
LR_pil_test=LR_pil_test.save('source3.jpg')
tensor_test=tensor_test.unsqueeze(0)
print(tensor_test)
tensor_test_y=srcnn(tensor_test)
print(tensor_test_y)
loss=criterion(tensor_test_y,pil_test_y)
print(loss)
tensor_test_y=tensor_test_y.squeeze(0)
result=transforms.ToPILImage()(tensor_test_y)
result.show()
result.save('result3.jpg')

<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1920x1080 at 0x7F30308DEDF0>
tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.1098, 0.1020, 0.0980],
          [0.0000, 0.0000, 0.0000,  ..., 0.1098, 0.1020, 0.0980],
          [0.0000, 0.0000, 0.0000,  ..., 0.1059, 0.1020, 0.0980],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.2039, 0.2000, 0.2000],
          [0.0000, 0.0000, 0.0000,  ..., 0.2039, 0.2000, 0.2000],
          [0.0000, 0.0000, 0.0000,  ..., 0.2000, 0.1961, 0.1961],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.1922, 0.1882, 0.1882],
          [0.0000, 0.000