In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.autograd import Variable
from torchvision import transforms
from torch.utils import data

from PIL import Image

In [2]:
class SRCNN(nn.Module):
    def __init__(self):
        super(SRCNN,self).__init__()
        self.patch_extraction=nn.Conv2d(in_channels=3,out_channels=64,kernel_size=5,stride=1,padding=4)
        self.non_linear=nn.Conv2d(in_channels=64,out_channels=32,kernel_size=1,stride=1,padding=0)
        self.reconstruction=nn.Conv2d(in_channels=32,out_channels=3,kernel_size=9,stride=1,padding=2)
    def forward(self,x):
        fm_1=F.relu(self.patch_extraction(x))
        fm_2=F.relu(self.non_linear(fm_1))
        fm_3=self.reconstruction(fm_2)
        return fm_3

In [3]:
srcnn=SRCNN()
print(srcnn)


SRCNN(
  (patch_extraction): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(4, 4))
  (non_linear): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (reconstruction): Conv2d(32, 3, kernel_size=(9, 9), stride=(1, 1), padding=(2, 2))
)


In [4]:
transform=transforms.Compose([transforms.ToTensor(),])
class I91DataSet(data.Dataset):
    def __init__(self,root):
        imgs=os.listdir(root)
        self.imgs=[os.path.join(root,k) for k in imgs]
        self.transforms=transform
    def __getitem__(self, item):
        img_path=self.imgs[item]
        pil_img=Image.open(img_path)
        LR_pil_img=pil_img.resize((11,11),Image.BICUBIC)
        LR_pil_img=LR_pil_img.resize(pil_img.size)
        #LR_pil_img.show()
        #pil_img.show()
        return transforms.ToTensor()(pil_img),transforms.ToTensor()(LR_pil_img)
    def __len__(self):
        return len(self.imgs)

In [5]:
dataSet=I91DataSet('CropData')
print(dataSet[0])

(tensor([[[0.5490, 0.5922, 0.5490,  ..., 0.9294, 0.8471, 0.9451],
         [0.4549, 0.5647, 0.5059,  ..., 0.9843, 0.9922, 0.9412],
         [0.6157, 0.5490, 0.4980,  ..., 0.9804, 1.0000, 0.9725],
         ...,
         [0.7686, 0.7922, 0.8196,  ..., 0.5686, 0.5647, 0.5725],
         [0.8118, 0.8745, 0.8902,  ..., 0.5608, 0.5843, 0.5686],
         [0.8863, 0.9216, 0.9529,  ..., 0.4667, 0.4941, 0.5216]],

        [[0.4353, 0.5059, 0.4745,  ..., 0.7490, 0.6549, 0.7569],
         [0.3176, 0.4431, 0.4078,  ..., 0.8000, 0.8078, 0.7608],
         [0.4471, 0.3882, 0.3608,  ..., 0.7922, 0.8588, 0.8157],
         ...,
         [0.5569, 0.5882, 0.6235,  ..., 0.3529, 0.3490, 0.3569],
         [0.5922, 0.6588, 0.6863,  ..., 0.3529, 0.3647, 0.3451],
         [0.6667, 0.7098, 0.7490,  ..., 0.2706, 0.2824, 0.3020]],

        [[0.2157, 0.2784, 0.2431,  ..., 0.4471, 0.3451, 0.4353],
         [0.0980, 0.2196, 0.1843,  ..., 0.4941, 0.5020, 0.4549],
         [0.2314, 0.1765, 0.1490,  ..., 0.4627, 0.5647, 0

In [7]:
optimizer=optim.SGD(srcnn.parameters(),lr=0.0001)
trainLoader=data.DataLoader(dataset=dataSet,batch_size=8,shuffle=True)
criterion=nn.MSELoss()
EPOCH=5
for epoch in range(EPOCH):
    for i,(x,y) in enumerate(trainLoader):
        batch_x=Variable(x)
        batch_y=Variable(y)
        output=srcnn(batch_x)
        loss=criterion(output,batch_y)
        loss.backward()
        print(loss)

tensor(0.3928, grad_fn=<MseLossBackward>)
tensor(0.1945, grad_fn=<MseLossBackward>)
tensor(0.4562, grad_fn=<MseLossBackward>)
tensor(0.4815, grad_fn=<MseLossBackward>)
tensor(0.4123, grad_fn=<MseLossBackward>)
tensor(0.4033, grad_fn=<MseLossBackward>)
tensor(0.3177, grad_fn=<MseLossBackward>)
tensor(0.2045, grad_fn=<MseLossBackward>)
tensor(0.3500, grad_fn=<MseLossBackward>)
tensor(0.2634, grad_fn=<MseLossBackward>)
tensor(0.2324, grad_fn=<MseLossBackward>)
tensor(0.3036, grad_fn=<MseLossBackward>)
tensor(0.3346, grad_fn=<MseLossBackward>)
tensor(0.3267, grad_fn=<MseLossBackward>)
tensor(0.3285, grad_fn=<MseLossBackward>)
tensor(0.3362, grad_fn=<MseLossBackward>)
tensor(0.3619, grad_fn=<MseLossBackward>)
tensor(0.2869, grad_fn=<MseLossBackward>)
tensor(0.3105, grad_fn=<MseLossBackward>)
tensor(0.2552, grad_fn=<MseLossBackward>)
tensor(0.3108, grad_fn=<MseLossBackward>)
tensor(0.3521, grad_fn=<MseLossBackward>)
tensor(0.2866, grad_fn=<MseLossBackward>)
tensor(0.1876, grad_fn=<MseLossBac

tensor(0.3069, grad_fn=<MseLossBackward>)
tensor(0.3114, grad_fn=<MseLossBackward>)
tensor(0.2519, grad_fn=<MseLossBackward>)
tensor(0.3152, grad_fn=<MseLossBackward>)
tensor(0.3360, grad_fn=<MseLossBackward>)
tensor(0.2451, grad_fn=<MseLossBackward>)
tensor(0.2500, grad_fn=<MseLossBackward>)
tensor(0.2591, grad_fn=<MseLossBackward>)
tensor(0.3440, grad_fn=<MseLossBackward>)
tensor(0.3058, grad_fn=<MseLossBackward>)
tensor(0.4074, grad_fn=<MseLossBackward>)
tensor(0.3773, grad_fn=<MseLossBackward>)
tensor(0.2853, grad_fn=<MseLossBackward>)
tensor(0.2618, grad_fn=<MseLossBackward>)
tensor(0.3898, grad_fn=<MseLossBackward>)
tensor(0.3710, grad_fn=<MseLossBackward>)
tensor(0.2874, grad_fn=<MseLossBackward>)
tensor(0.2939, grad_fn=<MseLossBackward>)
tensor(0.3651, grad_fn=<MseLossBackward>)
tensor(0.3090, grad_fn=<MseLossBackward>)
tensor(0.2519, grad_fn=<MseLossBackward>)
tensor(0.3518, grad_fn=<MseLossBackward>)
tensor(0.3187, grad_fn=<MseLossBackward>)
tensor(0.3861, grad_fn=<MseLossBac

tensor(0.2510, grad_fn=<MseLossBackward>)
tensor(0.3418, grad_fn=<MseLossBackward>)
tensor(0.4618, grad_fn=<MseLossBackward>)
tensor(0.2981, grad_fn=<MseLossBackward>)
tensor(0.3417, grad_fn=<MseLossBackward>)
tensor(0.3384, grad_fn=<MseLossBackward>)
tensor(0.2452, grad_fn=<MseLossBackward>)
tensor(0.2863, grad_fn=<MseLossBackward>)
tensor(0.2781, grad_fn=<MseLossBackward>)
tensor(0.2604, grad_fn=<MseLossBackward>)
tensor(0.2389, grad_fn=<MseLossBackward>)
tensor(0.3174, grad_fn=<MseLossBackward>)
tensor(0.2068, grad_fn=<MseLossBackward>)
tensor(0.3275, grad_fn=<MseLossBackward>)
tensor(0.3368, grad_fn=<MseLossBackward>)
tensor(0.4744, grad_fn=<MseLossBackward>)
tensor(0.3087, grad_fn=<MseLossBackward>)
tensor(0.3427, grad_fn=<MseLossBackward>)
tensor(0.2557, grad_fn=<MseLossBackward>)
tensor(0.2941, grad_fn=<MseLossBackward>)
tensor(0.2253, grad_fn=<MseLossBackward>)
tensor(0.4005, grad_fn=<MseLossBackward>)
tensor(0.2427, grad_fn=<MseLossBackward>)
tensor(0.3837, grad_fn=<MseLossBac

tensor(0.2703, grad_fn=<MseLossBackward>)
tensor(0.2504, grad_fn=<MseLossBackward>)
tensor(0.3512, grad_fn=<MseLossBackward>)
tensor(0.3711, grad_fn=<MseLossBackward>)
tensor(0.2852, grad_fn=<MseLossBackward>)
tensor(0.2621, grad_fn=<MseLossBackward>)
tensor(0.3067, grad_fn=<MseLossBackward>)
tensor(0.3819, grad_fn=<MseLossBackward>)
tensor(0.4190, grad_fn=<MseLossBackward>)
tensor(0.1801, grad_fn=<MseLossBackward>)
tensor(0.2688, grad_fn=<MseLossBackward>)
tensor(0.5477, grad_fn=<MseLossBackward>)
tensor(0.3246, grad_fn=<MseLossBackward>)
tensor(0.2115, grad_fn=<MseLossBackward>)
tensor(0.2212, grad_fn=<MseLossBackward>)
tensor(0.2441, grad_fn=<MseLossBackward>)
tensor(0.3253, grad_fn=<MseLossBackward>)
tensor(0.1944, grad_fn=<MseLossBackward>)
tensor(0.3924, grad_fn=<MseLossBackward>)
tensor(0.2736, grad_fn=<MseLossBackward>)
tensor(0.2835, grad_fn=<MseLossBackward>)
tensor(0.3349, grad_fn=<MseLossBackward>)
tensor(0.4240, grad_fn=<MseLossBackward>)
tensor(0.3682, grad_fn=<MseLossBac

tensor(0.2843, grad_fn=<MseLossBackward>)
tensor(0.2565, grad_fn=<MseLossBackward>)
tensor(0.3736, grad_fn=<MseLossBackward>)
tensor(0.2721, grad_fn=<MseLossBackward>)
tensor(0.2468, grad_fn=<MseLossBackward>)
tensor(0.2559, grad_fn=<MseLossBackward>)
tensor(0.3898, grad_fn=<MseLossBackward>)
tensor(0.2874, grad_fn=<MseLossBackward>)
tensor(0.2558, grad_fn=<MseLossBackward>)
tensor(0.3282, grad_fn=<MseLossBackward>)
tensor(0.2616, grad_fn=<MseLossBackward>)
tensor(0.3388, grad_fn=<MseLossBackward>)
tensor(0.2752, grad_fn=<MseLossBackward>)
tensor(0.3062, grad_fn=<MseLossBackward>)
tensor(0.2808, grad_fn=<MseLossBackward>)
tensor(0.2873, grad_fn=<MseLossBackward>)
tensor(0.2835, grad_fn=<MseLossBackward>)
tensor(0.2431, grad_fn=<MseLossBackward>)
tensor(0.3824, grad_fn=<MseLossBackward>)
tensor(0.3847, grad_fn=<MseLossBackward>)
tensor(0.2105, grad_fn=<MseLossBackward>)
tensor(0.3972, grad_fn=<MseLossBackward>)
tensor(0.3016, grad_fn=<MseLossBackward>)
tensor(0.3247, grad_fn=<MseLossBac

tensor(0.2570, grad_fn=<MseLossBackward>)
tensor(0.3671, grad_fn=<MseLossBackward>)
tensor(0.4182, grad_fn=<MseLossBackward>)
tensor(0.3510, grad_fn=<MseLossBackward>)
tensor(0.3611, grad_fn=<MseLossBackward>)
tensor(0.3046, grad_fn=<MseLossBackward>)
tensor(0.3966, grad_fn=<MseLossBackward>)
tensor(0.5045, grad_fn=<MseLossBackward>)
tensor(0.3107, grad_fn=<MseLossBackward>)
tensor(0.4126, grad_fn=<MseLossBackward>)
tensor(0.2164, grad_fn=<MseLossBackward>)
tensor(0.2505, grad_fn=<MseLossBackward>)
tensor(0.3853, grad_fn=<MseLossBackward>)
tensor(0.2372, grad_fn=<MseLossBackward>)
tensor(0.2565, grad_fn=<MseLossBackward>)
tensor(0.2538, grad_fn=<MseLossBackward>)
tensor(0.3434, grad_fn=<MseLossBackward>)
tensor(0.3359, grad_fn=<MseLossBackward>)
tensor(0.3484, grad_fn=<MseLossBackward>)
tensor(0.3254, grad_fn=<MseLossBackward>)
tensor(0.2832, grad_fn=<MseLossBackward>)
tensor(0.2604, grad_fn=<MseLossBackward>)
tensor(0.2981, grad_fn=<MseLossBackward>)
tensor(0.3087, grad_fn=<MseLossBac

tensor(0.3521, grad_fn=<MseLossBackward>)
tensor(0.3097, grad_fn=<MseLossBackward>)
tensor(0.3280, grad_fn=<MseLossBackward>)
tensor(0.2774, grad_fn=<MseLossBackward>)
tensor(0.3910, grad_fn=<MseLossBackward>)
tensor(0.2428, grad_fn=<MseLossBackward>)
tensor(0.1793, grad_fn=<MseLossBackward>)
tensor(0.2857, grad_fn=<MseLossBackward>)
tensor(0.3941, grad_fn=<MseLossBackward>)
tensor(0.4402, grad_fn=<MseLossBackward>)
tensor(0.3262, grad_fn=<MseLossBackward>)
tensor(0.3253, grad_fn=<MseLossBackward>)
tensor(0.2357, grad_fn=<MseLossBackward>)
tensor(0.3206, grad_fn=<MseLossBackward>)
tensor(0.2626, grad_fn=<MseLossBackward>)
tensor(0.3937, grad_fn=<MseLossBackward>)
tensor(0.3397, grad_fn=<MseLossBackward>)
tensor(0.2702, grad_fn=<MseLossBackward>)
tensor(0.3149, grad_fn=<MseLossBackward>)
tensor(0.2565, grad_fn=<MseLossBackward>)
tensor(0.1798, grad_fn=<MseLossBackward>)
tensor(0.2499, grad_fn=<MseLossBackward>)
tensor(0.2864, grad_fn=<MseLossBackward>)
tensor(0.3067, grad_fn=<MseLossBac

tensor(0.3789, grad_fn=<MseLossBackward>)
tensor(0.3010, grad_fn=<MseLossBackward>)
tensor(0.3418, grad_fn=<MseLossBackward>)
tensor(0.3074, grad_fn=<MseLossBackward>)
tensor(0.2150, grad_fn=<MseLossBackward>)
tensor(0.3734, grad_fn=<MseLossBackward>)
tensor(0.3790, grad_fn=<MseLossBackward>)
tensor(0.3996, grad_fn=<MseLossBackward>)
tensor(0.2785, grad_fn=<MseLossBackward>)
tensor(0.3276, grad_fn=<MseLossBackward>)
tensor(0.1247, grad_fn=<MseLossBackward>)
tensor(0.2605, grad_fn=<MseLossBackward>)
tensor(0.1601, grad_fn=<MseLossBackward>)
tensor(0.2747, grad_fn=<MseLossBackward>)
tensor(0.2968, grad_fn=<MseLossBackward>)
tensor(0.2684, grad_fn=<MseLossBackward>)
tensor(0.4587, grad_fn=<MseLossBackward>)
tensor(0.1695, grad_fn=<MseLossBackward>)
tensor(0.2873, grad_fn=<MseLossBackward>)
tensor(0.2621, grad_fn=<MseLossBackward>)
tensor(0.2792, grad_fn=<MseLossBackward>)
tensor(0.3970, grad_fn=<MseLossBackward>)
tensor(0.2589, grad_fn=<MseLossBackward>)
tensor(0.2626, grad_fn=<MseLossBac

tensor(0.2226, grad_fn=<MseLossBackward>)
tensor(0.2165, grad_fn=<MseLossBackward>)
tensor(0.4459, grad_fn=<MseLossBackward>)
tensor(0.3404, grad_fn=<MseLossBackward>)
tensor(0.2388, grad_fn=<MseLossBackward>)
tensor(0.2550, grad_fn=<MseLossBackward>)
tensor(0.3102, grad_fn=<MseLossBackward>)
tensor(0.2590, grad_fn=<MseLossBackward>)
tensor(0.3089, grad_fn=<MseLossBackward>)
tensor(0.2461, grad_fn=<MseLossBackward>)
tensor(0.2425, grad_fn=<MseLossBackward>)
tensor(0.3050, grad_fn=<MseLossBackward>)
tensor(0.3779, grad_fn=<MseLossBackward>)
tensor(0.4210, grad_fn=<MseLossBackward>)
tensor(0.1914, grad_fn=<MseLossBackward>)
tensor(0.3397, grad_fn=<MseLossBackward>)
tensor(0.3556, grad_fn=<MseLossBackward>)
tensor(0.4336, grad_fn=<MseLossBackward>)
tensor(0.2917, grad_fn=<MseLossBackward>)
tensor(0.2887, grad_fn=<MseLossBackward>)
tensor(0.4429, grad_fn=<MseLossBackward>)
tensor(0.3061, grad_fn=<MseLossBackward>)
tensor(0.3743, grad_fn=<MseLossBackward>)
tensor(0.3219, grad_fn=<MseLossBac

tensor(0.2748, grad_fn=<MseLossBackward>)
tensor(0.2940, grad_fn=<MseLossBackward>)
tensor(0.2863, grad_fn=<MseLossBackward>)
tensor(0.1819, grad_fn=<MseLossBackward>)
tensor(0.4353, grad_fn=<MseLossBackward>)
tensor(0.2755, grad_fn=<MseLossBackward>)
tensor(0.3184, grad_fn=<MseLossBackward>)
tensor(0.3214, grad_fn=<MseLossBackward>)
tensor(0.3165, grad_fn=<MseLossBackward>)
tensor(0.2221, grad_fn=<MseLossBackward>)
tensor(0.2273, grad_fn=<MseLossBackward>)
tensor(0.2730, grad_fn=<MseLossBackward>)
tensor(0.3366, grad_fn=<MseLossBackward>)
tensor(0.2265, grad_fn=<MseLossBackward>)
tensor(0.2722, grad_fn=<MseLossBackward>)
tensor(0.3860, grad_fn=<MseLossBackward>)
tensor(0.2876, grad_fn=<MseLossBackward>)
tensor(0.2104, grad_fn=<MseLossBackward>)
tensor(0.4707, grad_fn=<MseLossBackward>)
tensor(0.2307, grad_fn=<MseLossBackward>)
tensor(0.2330, grad_fn=<MseLossBackward>)
tensor(0.3390, grad_fn=<MseLossBackward>)
tensor(0.3177, grad_fn=<MseLossBackward>)
tensor(0.3403, grad_fn=<MseLossBac

tensor(0.2119, grad_fn=<MseLossBackward>)
tensor(0.2009, grad_fn=<MseLossBackward>)
tensor(0.3403, grad_fn=<MseLossBackward>)
tensor(0.2780, grad_fn=<MseLossBackward>)
tensor(0.2167, grad_fn=<MseLossBackward>)
tensor(0.3971, grad_fn=<MseLossBackward>)
tensor(0.3152, grad_fn=<MseLossBackward>)
tensor(0.3360, grad_fn=<MseLossBackward>)
tensor(0.3236, grad_fn=<MseLossBackward>)
tensor(0.2745, grad_fn=<MseLossBackward>)
tensor(0.2161, grad_fn=<MseLossBackward>)
tensor(0.3863, grad_fn=<MseLossBackward>)
tensor(0.3048, grad_fn=<MseLossBackward>)
tensor(0.4655, grad_fn=<MseLossBackward>)
tensor(0.2065, grad_fn=<MseLossBackward>)
tensor(0.3970, grad_fn=<MseLossBackward>)
tensor(0.2750, grad_fn=<MseLossBackward>)
tensor(0.2342, grad_fn=<MseLossBackward>)
tensor(0.3163, grad_fn=<MseLossBackward>)
tensor(0.1890, grad_fn=<MseLossBackward>)
tensor(0.2502, grad_fn=<MseLossBackward>)
tensor(0.2948, grad_fn=<MseLossBackward>)
tensor(0.3161, grad_fn=<MseLossBackward>)
tensor(0.2611, grad_fn=<MseLossBac