In [1]:
import os
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as V
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from PIL import Image
import numpy as np

In [2]:
class SRCNN(nn.Module):
    """Three-layer convolutional network mapping 3-channel images to 3-channel images.

    Layers, all with stride 1 and no padding:
      * conv1: 9x9 kernel, 3 -> 64 channels, followed by ReLU
      * conv2: 1x1 kernel, 64 -> 32 channels, followed by ReLU
      * conv3: 5x5 kernel, 32 -> 3 channels, no activation

    Because no padding is used, each spatial dimension of the output is
    12 pixels smaller than the input (8 lost in conv1, 4 lost in conv3).
    """

    def __init__(self):
        super(SRCNN, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, stride=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, stride=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=5, stride=1)

    def forward(self, x):
        """Run the three convolutions in order and return the final feature map."""
        features = F.relu(self.conv1(x))
        mapped = F.relu(self.conv2(features))
        return self.conv3(mapped)

In [3]:
# Preprocessing pipeline handed to the datasets; its only step is
# torchvision's ToTensor, applied to each image the datasets open.
transform = T.Compose([T.ToTensor()])

In [4]:
class Mydata(Dataset):
    """Paired dataset of input images (``x_path``) and label images (``l_path``).

    Item ``i`` is the pair made of the ``i``-th file of each directory after
    sorting both listings by file name.

    Args:
        x_path: directory holding the input images.
        l_path: directory holding the label images.
        transforms: optional callable applied to both images of a pair.
    """

    def __init__(self, x_path, l_path, transforms=None):
        # os.listdir returns entries in arbitrary order, so the two listings
        # are sorted to make index i refer to corresponding files.
        # NOTE(review): assumes matching input/label files sort to the same
        # position (e.g. share file names) - confirm against the data layout.
        self.X_imgs = [os.path.join(x_path, name) for name in sorted(os.listdir(x_path))]
        self.L_imgs = [os.path.join(l_path, name) for name in sorted(os.listdir(l_path))]
        self.transforms = transforms

    @staticmethod
    def _load_rgb(img_path):
        """Open ``img_path``, return it as an RGB image, and close the file."""
        # Converting inside the ``with`` block forces the pixel data to be read
        # before the file handle is released, and guarantees 3 channels even
        # for grayscale, palette or RGBA files.
        with Image.open(img_path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        """Return ``(data, label)`` for ``index``, transformed if a transform was given."""
        data = self._load_rgb(self.X_imgs[index])
        label = self._load_rgb(self.L_imgs[index])

        if self.transforms:
            data = self.transforms(data)
            label = self.transforms(label)

        return data, label

    def __len__(self):
        """Number of input images."""
        return len(self.X_imgs)

In [5]:
class Mytestdata(Dataset):
    """Unlabelled image dataset built from every file in ``test_path``.

    Files are ordered by name so that a given index always refers to the same
    file across runs.

    Args:
        test_path: directory holding the images.
        transforms: optional callable applied to each loaded image.
    """

    def __init__(self, test_path, transforms=None):
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping deterministic.
        self.X_imgs = [os.path.join(test_path, name) for name in sorted(os.listdir(test_path))]
        self.transforms = transforms

    def __getitem__(self, index):
        """Return the image at ``index`` as RGB, transformed if a transform was given."""
        # Converting inside the ``with`` block reads the pixel data before the
        # file handle is released and guarantees 3 channels even for
        # grayscale, palette or RGBA files.
        with Image.open(self.X_imgs[index]) as img:
            data = img.convert('RGB')

        if self.transforms:
            data = self.transforms(data)

        return data

    def __len__(self):
        """Number of images."""
        return len(self.X_imgs)

In [6]:
# Locations of the training inputs and their labels.
x_path = '/media/sinong/DATA/data/archive/train_lrx4'
l_path = '/media/sinong/DATA/data/archive/train_20'

# Network, placed on the GPU.
mysrcnn = SRCNN().cuda()

# Shuffled mini-batches of 16 (input, label) pairs, loaded in the main process.
mydata = Mydata(x_path, l_path, transform)
mydataloader = DataLoader(mydata, batch_size=16, shuffle=True, num_workers=0)

# Split the parameters into two optimizer groups: conv3 trains with `lr`,
# every other parameter falls back to the optimizer default of `lr * 10`.
conv3_params = [id(p) for p in mysrcnn.conv3.parameters()]
base_params = (p for p in mysrcnn.parameters() if id(p) not in conv3_params)
lr = 0.00001
optimizer = optim.SGD(
    [
        {'params': base_params},
        {'params': mysrcnn.conv3.parameters(), 'lr': lr},
    ],
    lr=lr * 10,
    momentum=0.9,
)

# Mean squared error between the network output and the label image.
criterion = nn.MSELoss()


In [7]:
# Number of full passes over the training DataLoader.
epochs = 500

In [8]:
# Train the network on the GPU for `epochs` passes over the training loader.
# The model is moved explicitly so the cell also works after a later cell
# has moved it to the CPU.
mysrcnn = mysrcnn.cuda()
for epoch in range(epochs):

    for i, (inputs, labels) in enumerate(mydataloader):
        # Plain tensors carry autograd state themselves, so the deprecated
        # Variable wrapper is not needed.
        inputs, labels = inputs.cuda(), labels.cuda()

        optimizer.zero_grad()

        outputs = mysrcnn(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        optimizer.step()

        # Only the first mini-batch of each epoch is logged; batches are
        # shuffled, so this is a noisy sample rather than an epoch average.
        if i == 0:
            print('[%3d, %d] loss: %.4f' % (epoch+1, i, loss.item()))
print('Finished')

[  1, 0] loss: 0.2998
[  2, 0] loss: 0.0233
[  3, 0] loss: 0.0251
[  4, 0] loss: 0.0133
[  5, 0] loss: 0.0134
[  6, 0] loss: 0.0110
[  7, 0] loss: 0.0088
[  8, 0] loss: 0.0121
[  9, 0] loss: 0.0121
[ 10, 0] loss: 0.0088
[ 11, 0] loss: 0.0043
[ 12, 0] loss: 0.0057
[ 13, 0] loss: 0.0065
[ 14, 0] loss: 0.0068
[ 15, 0] loss: 0.0037
[ 16, 0] loss: 0.0046
[ 17, 0] loss: 0.0023
[ 18, 0] loss: 0.0039
[ 19, 0] loss: 0.0089
[ 20, 0] loss: 0.0056
[ 21, 0] loss: 0.0058
[ 22, 0] loss: 0.0065
[ 23, 0] loss: 0.0067
[ 24, 0] loss: 0.0050
[ 25, 0] loss: 0.0047
[ 26, 0] loss: 0.0065
[ 27, 0] loss: 0.0066
[ 28, 0] loss: 0.0073
[ 29, 0] loss: 0.0055
[ 30, 0] loss: 0.0061
[ 31, 0] loss: 0.0070
[ 32, 0] loss: 0.0024
[ 33, 0] loss: 0.0045
[ 34, 0] loss: 0.0047
[ 35, 0] loss: 0.0040
[ 36, 0] loss: 0.0039
[ 37, 0] loss: 0.0064
[ 38, 0] loss: 0.0075
[ 39, 0] loss: 0.0078
[ 40, 0] loss: 0.0063
[ 41, 0] loss: 0.0051
[ 42, 0] loss: 0.0046
[ 43, 0] loss: 0.0028
[ 44, 0] loss: 0.0066
[ 45, 0] loss: 0.0017
[ 46, 0] l

[374, 0] loss: 0.0054
[375, 0] loss: 0.0031
[376, 0] loss: 0.0031
[377, 0] loss: 0.0026
[378, 0] loss: 0.0016
[379, 0] loss: 0.0021
[380, 0] loss: 0.0052
[381, 0] loss: 0.0043
[382, 0] loss: 0.0055
[383, 0] loss: 0.0024
[384, 0] loss: 0.0074
[385, 0] loss: 0.0048
[386, 0] loss: 0.0021
[387, 0] loss: 0.0039
[388, 0] loss: 0.0039
[389, 0] loss: 0.0042
[390, 0] loss: 0.0020
[391, 0] loss: 0.0053
[392, 0] loss: 0.0033
[393, 0] loss: 0.0057
[394, 0] loss: 0.0032
[395, 0] loss: 0.0087
[396, 0] loss: 0.0036
[397, 0] loss: 0.0027
[398, 0] loss: 0.0021
[399, 0] loss: 0.0020
[400, 0] loss: 0.0069
[401, 0] loss: 0.0028
[402, 0] loss: 0.0035
[403, 0] loss: 0.0024
[404, 0] loss: 0.0036
[405, 0] loss: 0.0070
[406, 0] loss: 0.0028
[407, 0] loss: 0.0036
[408, 0] loss: 0.0029
[409, 0] loss: 0.0041
[410, 0] loss: 0.0069
[411, 0] loss: 0.0059
[412, 0] loss: 0.0038
[413, 0] loss: 0.0069
[414, 0] loss: 0.0034
[415, 0] loss: 0.0041
[416, 0] loss: 0.0034
[417, 0] loss: 0.0027
[418, 0] loss: 0.0063
[419, 0] l

In [9]:
# Run the trained network on every image in the test directory and save each
# prediction as a PNG.
test_path1 = '/media/sinong/DATA/data/archive/Set14/Set14_lrx4'
test_path2 = ''
testdata = Mytestdata(test_path1, transform)
testloader = DataLoader(testdata, batch_size = 1, shuffle=False, num_workers=0)

show = T.ToPILImage()

# Output directory; created up front so saving does not fail on a fresh machine.
path = '/media/sinong/DATA/data/archive/SRCNN_predict/Set14/'
if not os.path.isdir(path):
    os.makedirs(path)

# Inference runs on the CPU; moving the model once is enough.
mysrcnn = mysrcnn.cpu()

# No gradients are needed for prediction, so skip building the autograd graph.
with t.no_grad():
    for i, data in enumerate(testloader, 0):
        output = mysrcnn(data)
        # The network output is unbounded, while ToPILImage scales float
        # tensors by 255 and casts them to 8-bit; clamp to [0, 1] first so
        # out-of-range values saturate instead of overflowing the cast.
        img = show(output.squeeze(0).clamp(0, 1))
        img.save(os.path.join(path, 'img' + str(i) + '.png'))