In [1]:
import os
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as V
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from PIL import Image
import numpy as np

In [2]:
class SRCNN(nn.Module):
    
    def __init__(self):
        
        super(SRCNN, self).__init__()
    
        self.conv1 = nn.Conv2d(3, 64, 9, 1)
        self.conv2 = nn.Conv2d(64, 32, 1 ,1)
        self.conv3 = nn.Conv2d(32, 3, 5, 1)
    
    def forward(self, x):
        
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.conv3(x)
        return x

In [3]:
transform = T.Compose([T.ToTensor()])

In [4]:
class Mydata(Dataset):
    
    def __init__(self, x_path, l_path, transforms=None):
        
        X_imgs = os.listdir(x_path)
        L_imgs = os.listdir(l_path)
        self.X_imgs = [os.path.join(x_path, x) for x in X_imgs]
        self.L_imgs = [os.path.join(l_path, x) for x in L_imgs]
        self.transforms = transforms
        
    def __getitem__(self, index):
        
        X_img_path = self.X_imgs[index]
        L_img_path = self.L_imgs[index]
        label = Image.open(L_img_path)
        data = Image.open(X_img_path)
        
        if self.transforms:
            data = self.transforms(data)
            label = self.transforms(label)
            
        return data, label
    
    def __len__(self):
        
        return len(self.X_imgs)

In [5]:
class Mytestdata(Dataset):
    
    def __init__(self, test_path, transforms=None):
        
        X_imgs = os.listdir(test_path)
        self.X_imgs = [os.path.join(test_path, x) for x in X_imgs]
        self.transforms = transforms
        
    def __getitem__(self, index):
        
        X_img_path = self.X_imgs[index]
        data = Image.open(X_img_path)
        
        if self.transforms:
            data = self.transforms(data)
            
        return data
    
    def __len__(self):
        
        return len(self.X_imgs)

In [6]:
mysrcnn = SRCNN().cuda()
x_path = '/media/sinong/DATA/data/archive/train_lrx2'
l_path = '/media/sinong/DATA/data/archive/train_20'

mydata = Mydata(x_path, l_path, transform)
mydataloader = DataLoader(mydata, batch_size=16, shuffle=True, num_workers=0)

conv3_params = list(map(id, mysrcnn.conv3.parameters()))
base_params = filter(lambda p: id(p) not in conv3_params,
                     mysrcnn.parameters())
lr = 0.00001
optimizer = optim.SGD([
            {'params': base_params},
            {'params': mysrcnn.conv3.parameters(), 'lr': lr }],
             lr=lr*10, momentum=0.9)

criterion = nn.MSELoss()


In [7]:
epochs = 2000

In [None]:
mysrcnn = mysrcnn.cuda()
for epoch in range(epochs):
    
    loss = 0.0
    for i, data in enumerate(mydataloader, 0):
        inputs, labels = data
        inputs, labels = V(inputs).cuda(), V(labels).cuda()
        
        optimizer.zero_grad()
        
        outputs = mysrcnn(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        
        optimizer.step()
        
        if i == 0:
            print('[%3d, %d] loss: %.4f' % (epoch+1, i, loss.data.item()))
print('Finished')

[  1, 0] loss: 0.2115
[  2, 0] loss: 0.0222
[  3, 0] loss: 0.0137
[  4, 0] loss: 0.0170
[  5, 0] loss: 0.0134
[  6, 0] loss: 0.0107
[  7, 0] loss: 0.0053
[  8, 0] loss: 0.0068
[  9, 0] loss: 0.0086
[ 10, 0] loss: 0.0047
[ 11, 0] loss: 0.0065
[ 12, 0] loss: 0.0074
[ 13, 0] loss: 0.0069
[ 14, 0] loss: 0.0053
[ 15, 0] loss: 0.0038
[ 16, 0] loss: 0.0052
[ 17, 0] loss: 0.0069
[ 18, 0] loss: 0.0047
[ 19, 0] loss: 0.0066
[ 20, 0] loss: 0.0049
[ 21, 0] loss: 0.0041
[ 22, 0] loss: 0.0030
[ 23, 0] loss: 0.0064
[ 24, 0] loss: 0.0051
[ 25, 0] loss: 0.0061
[ 26, 0] loss: 0.0064
[ 27, 0] loss: 0.0037
[ 28, 0] loss: 0.0038
[ 29, 0] loss: 0.0042
[ 30, 0] loss: 0.0045
[ 31, 0] loss: 0.0090
[ 32, 0] loss: 0.0099
[ 33, 0] loss: 0.0024
[ 34, 0] loss: 0.0052
[ 35, 0] loss: 0.0044
[ 36, 0] loss: 0.0023
[ 37, 0] loss: 0.0026
[ 38, 0] loss: 0.0042
[ 39, 0] loss: 0.0051
[ 40, 0] loss: 0.0030
[ 41, 0] loss: 0.0061
[ 42, 0] loss: 0.0033
[ 43, 0] loss: 0.0054
[ 44, 0] loss: 0.0026
[ 45, 0] loss: 0.0014
[ 46, 0] l

In [None]:
test_path1 = '/media/sinong/DATA/data/archive/Set14/Set14_lrx2'
test_path2 = '/media/sinong/DATA/data/archive/Set5/Set5_lrx4'
testdata = Mytestdata(test_path1, transform)
testloader = DataLoader(testdata, batch_size = 1, shuffle=False, num_workers=0)

show = T.ToPILImage()
for i, data in enumerate(testloader, 0):
    input = V(data)
    mysrcnn = mysrcnn.cpu()
    output = mysrcnn(input)
    img = show(output.squeeze(0))
    path = '/media/sinong/DATA/data/archive/SRCNN_predict/Set14/'
    img.save(path+'img'+str(i)+'.png')