In [1]:
import os
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as V
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from PIL import Image
import numpy as np

In [2]:
class SRCNN(nn.Module):
    """Three-layer convolutional network for image super-resolution.

    The layers are a 9x9 convolution (3 -> 64 channels), a 1x1 convolution
    (64 -> 32) and a 5x5 convolution (32 -> 3), all with stride 1. No padding
    is applied, so the output is 12 pixels smaller than the input along both
    height and width.
    """

    def __init__(self):
        super().__init__()

        # Layers are created in forward order, so parameter registration (and
        # therefore random initialisation) follows conv1, conv2, conv3.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, stride=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, stride=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=5, stride=1)

    def forward(self, x):
        """Return the network output for a batch ``x`` of 3-channel images."""
        features = F.relu(self.conv1(x))
        mapped = F.relu(self.conv2(features))
        # The last layer is linear: no activation is applied to its output.
        return self.conv3(mapped)

In [3]:
# Preprocessing pipeline handed to the datasets: converts a PIL image to a tensor.
transform = T.Compose(transforms=[T.ToTensor()])

In [4]:
class Mydata(Dataset):
    """Paired dataset of network inputs and target images.

    Inputs are read from the ``tl`` sub-directory of ``root`` and targets from
    the ``t916`` sub-directory. The two directories are paired by position
    after sorting their file names, so the i-th input goes with the i-th
    target.

    Args:
        root: Directory containing the ``tl`` and ``t916`` sub-directories.
        transforms: Optional callable applied to both images of a pair.

    Raises:
        ValueError: If the two directories hold a different number of files.
    """

    def __init__(self, root, transforms=None):
        x_dir = os.path.join(root, 'tl')
        l_dir = os.path.join(root, 't916')

        # os.listdir returns entries in arbitrary order; sort both listings so
        # that index-based pairing of input and label is deterministic.
        self.X_imgs = [os.path.join(x_dir, name) for name in sorted(os.listdir(x_dir))]
        self.L_imgs = [os.path.join(l_dir, name) for name in sorted(os.listdir(l_dir))]

        if len(self.X_imgs) != len(self.L_imgs):
            raise ValueError(
                'input/label count mismatch: %d files in %s, %d files in %s'
                % (len(self.X_imgs), x_dir, len(self.L_imgs), l_dir))

        self.transforms = transforms

    @staticmethod
    def _load_rgb(path):
        """Decode the image at ``path`` as 3-channel RGB and release the file."""
        # convert() forces the pixel data to be read, so the file handle can be
        # closed as soon as the ``with`` block exits.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        """Return the ``(input, label)`` pair stored at position ``index``."""
        data = self._load_rgb(self.X_imgs[index])
        label = self._load_rgb(self.L_imgs[index])

        if self.transforms:
            data = self.transforms(data)
            label = self.transforms(label)

        return data, label

    def __len__(self):
        """Return the number of input images."""
        return len(self.X_imgs)

In [5]:
# Run on the GPU when one is present and fall back to the CPU otherwise, so
# this cell no longer fails on machines without CUDA.
device = t.device('cuda' if t.cuda.is_available() else 'cpu')
mysrcnn = SRCNN().to(device)
root = 'D:\\data\\archive\\train'

mydata = Mydata(root, transform)
mydataloader = DataLoader(mydata, batch_size=16, shuffle=True, num_workers=0)

# Split the parameters into two optimizer groups: everything except conv3
# trains with the optimizer's default rate (lr * 10), conv3 with ``lr``.
conv3_params = list(map(id, mysrcnn.conv3.parameters()))
# A list (rather than a one-shot ``filter`` iterator) stays inspectable after
# the optimizer has consumed it.
base_params = [p for p in mysrcnn.parameters() if id(p) not in conv3_params]
lr = 0.00001
optimizer = optim.SGD([
            {'params': base_params},
            {'params': mysrcnn.conv3.parameters(), 'lr': lr }],
             lr=lr*10, momentum=0.9)

criterion = nn.MSELoss()


In [6]:
# Number of passes over the training DataLoader made by the training loop.
epochs = 500

In [None]:
# Feed every batch to whichever device the model's parameters live on instead
# of hard-coding CUDA.
device = next(mysrcnn.parameters()).device

for epoch in range(epochs):

    for i, data in enumerate(mydataloader, 0):
        inputs, labels = data
        # Tensors track gradients themselves since PyTorch 0.4, so the
        # deprecated Variable wrapper is not needed.
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = mysrcnn(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        optimizer.step()

        # Only the first batch of each epoch is logged, to keep the output short.
        if i == 0:
            print('[%d, %4d] loss: %.4f' % (epoch, i, loss.item()))
print('Finished')

[0,    0] loss: 0.2819
[1,    0] loss: 0.0226
[2,    0] loss: 0.0185
[3,    0] loss: 0.0196
[4,    0] loss: 0.0154
[5,    0] loss: 0.0123
[6,    0] loss: 0.0098
[7,    0] loss: 0.0090
[8,    0] loss: 0.0073
[9,    0] loss: 0.0083
[10,    0] loss: 0.0097
[11,    0] loss: 0.0058
[12,    0] loss: 0.0060
[13,    0] loss: 0.0055
[14,    0] loss: 0.0082
[15,    0] loss: 0.0145
[16,    0] loss: 0.0062
[17,    0] loss: 0.0049
[18,    0] loss: 0.0040
[19,    0] loss: 0.0061
[20,    0] loss: 0.0076
[21,    0] loss: 0.0074
[22,    0] loss: 0.0048
[23,    0] loss: 0.0070
[24,    0] loss: 0.0054
[25,    0] loss: 0.0064
[26,    0] loss: 0.0032
[27,    0] loss: 0.0063
[28,    0] loss: 0.0048
[29,    0] loss: 0.0049
[30,    0] loss: 0.0055
[31,    0] loss: 0.0041
[32,    0] loss: 0.0052
[33,    0] loss: 0.0036
[34,    0] loss: 0.0067
[35,    0] loss: 0.0051
[36,    0] loss: 0.0052
[37,    0] loss: 0.0068
[38,    0] loss: 0.0039
[39,    0] loss: 0.0032
[40,    0] loss: 0.0037
[41,    0] loss: 0.0042
[4

[333,    0] loss: 0.0027
[334,    0] loss: 0.0061
[335,    0] loss: 0.0027
[336,    0] loss: 0.0050
[337,    0] loss: 0.0034
[338,    0] loss: 0.0043
[339,    0] loss: 0.0054
[340,    0] loss: 0.0029
[341,    0] loss: 0.0049
[342,    0] loss: 0.0049
[343,    0] loss: 0.0050
[344,    0] loss: 0.0024
[345,    0] loss: 0.0032
[346,    0] loss: 0.0036
[347,    0] loss: 0.0071
[348,    0] loss: 0.0030
[349,    0] loss: 0.0061
[350,    0] loss: 0.0030
[351,    0] loss: 0.0058
[352,    0] loss: 0.0039
[353,    0] loss: 0.0042
[354,    0] loss: 0.0040
[355,    0] loss: 0.0023
[356,    0] loss: 0.0048
[357,    0] loss: 0.0046
[358,    0] loss: 0.0061
[359,    0] loss: 0.0052
[360,    0] loss: 0.0046
[361,    0] loss: 0.0041
[362,    0] loss: 0.0034
[363,    0] loss: 0.0028
[364,    0] loss: 0.0069
[365,    0] loss: 0.0036
[366,    0] loss: 0.0072
[367,    0] loss: 0.0047
[368,    0] loss: 0.0044
[369,    0] loss: 0.0025
[370,    0] loss: 0.0029
[371,    0] loss: 0.0046
[372,    0] loss: 0.0053


In [None]:
class Mytestdata(Dataset):
    """Unlabelled dataset of images to run through the network.

    Args:
        transforms: Optional callable applied to every loaded image.
        root: Directory holding the images. Defaults to ``DEFAULT_ROOT``,
            the directory this class originally had hard-coded.
    """

    # Directory used when no ``root`` is given.
    DEFAULT_ROOT = 'D:\\data\\archive\\Set14\\set14_lr128'

    def __init__(self, transforms=None, root=None):
        if root is None:
            root = self.DEFAULT_ROOT

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping (and any index-based output naming) stable.
        self.X_imgs = [os.path.join(root, name) for name in sorted(os.listdir(root))]
        self.transforms = transforms

    def __getitem__(self, index):
        """Return the (optionally transformed) image at position ``index``."""
        # Decode as 3-channel RGB, matching the 3 input channels of
        # SRCNN.conv1; convert() also forces the read so the file can be closed.
        with Image.open(self.X_imgs[index]) as img:
            data = img.convert('RGB')

        if self.transforms:
            data = self.transforms(data)

        return data

    def __len__(self):
        """Return the number of images found in the directory."""
        return len(self.X_imgs)

In [None]:

testdata = Mytestdata(transform)
testloader = DataLoader(testdata, batch_size = 1, shuffle=False, num_workers=0)

show = T.ToPILImage()
path = 'D:\\data\\archive\\test\\SRCNN\\'
# Create the output directory up front so img.save() cannot fail on a missing folder.
os.makedirs(path, exist_ok=True)

# Inference runs on the CPU; move the model once instead of on every iteration.
mysrcnn = mysrcnn.cpu()

# No gradients are needed for inference, so skip building the autograd graph.
with t.no_grad():
    for i, data in enumerate(testloader, 0):
        output = mysrcnn(data)
        # The network's last layer is linear, so values can fall outside
        # [0, 1]; clamp before the uint8 conversion in ToPILImage, which would
        # otherwise wrap out-of-range values around.
        img = show(output.squeeze(0).clamp(0, 1))
        img.save(path+'img'+str(i)+'.png')

In [None]:
# Run one training sample through the network and save the result.
(data, label) = mydata[10]

# Send the sample to the model's device instead of relying on an earlier cell
# having moved the model to the CPU.
device = next(mysrcnn.parameters()).device
with t.no_grad():
    out = mysrcnn(data.unsqueeze(0).to(device))

show = T.ToPILImage()
# Drop the batch dimension and clamp to [0, 1] so the uint8 conversion in
# ToPILImage does not wrap out-of-range values.
img = show(out.squeeze(0).clamp(0, 1).cpu())
img.save('D:\\data\\archive\\test\\SRCNN\\test.png')