In [1]:
import os
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as V
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from PIL import Image
import numpy as np

In [2]:
class SRCNN(nn.Module):
    
    def __init__(self):
        
        super(SRCNN, self).__init__()
    
        self.conv1 = nn.Conv2d(3, 64, 9, 1)
        self.conv2 = nn.Conv2d(64, 32, 1 ,1)
        self.conv3 = nn.Conv2d(32, 3, 5, 1)
    
    def forward(self, x):
        
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.conv3(x)
        return x

In [3]:
transform = T.Compose([T.ToTensor()])

In [4]:
class Mydata(Dataset):
    
    def __init__(self, root, transforms=None):
        
        X_imgs = os.listdir(root + '\\tl')
        L_imgs = os.listdir(root + '\\t916')
        self.X_imgs = [os.path.join(root+'\\tl', x) for x in X_imgs]
        self.L_imgs = [os.path.join(root+'\\t916', x) for x in L_imgs]
        self.transforms = transforms
        
    def __getitem__(self, index):
        
        X_img_path = self.X_imgs[index]
        L_img_path = self.L_imgs[index]
        label = Image.open(L_img_path)
        data = Image.open(X_img_path)
        
        if self.transforms:
            data = self.transforms(data)
            label = self.transforms(label)
            
        return data, label
    
    def __len__(self):
        
        return len(self.X_imgs)

In [5]:
mysrcnn = SRCNN().cuda()
root = 'D:\\data\\archive\\train'

mydata = Mydata(root, transform)
mydataloader = DataLoader(mydata, batch_size=16, shuffle=True, num_workers=0)

conv3_params = list(map(id, mysrcnn.conv3.parameters()))
base_params = filter(lambda p: id(p) not in conv3_params,
                     mysrcnn.parameters())
lr = 0.00001
optimizer = optim.SGD([
            {'params': base_params},
            {'params': mysrcnn.conv3.parameters(), 'lr': lr }],
             lr=lr*10, momentum=0.9)

criterion = nn.MSELoss()


In [6]:
epochs = 60

In [7]:
for epoch in range(epochs):
    
    loss = 0.0
    for i, data in enumerate(mydataloader, 0):
        inputs, labels = data
        inputs, labels = V(inputs).cuda(), V(labels).cuda()
        
        optimizer.zero_grad()
        
        outputs = mysrcnn(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        
        optimizer.step()
        
        if i % 300 == 299:
            print('[%d, %4d] loss: %.4f' % (epoch+1, i+1, loss.data.item()))
print('Finished')

[1,  300] loss: 0.1064
[1,  600] loss: 0.0379
[1,  900] loss: 0.0271
[1, 1200] loss: 0.0274
[2,  300] loss: 0.0209
[2,  600] loss: 0.0168
[2,  900] loss: 0.0197
[2, 1200] loss: 0.0190
[3,  300] loss: 0.0330
[3,  600] loss: 0.0206
[3,  900] loss: 0.0265
[3, 1200] loss: 0.0208
[4,  300] loss: 0.0160
[4,  600] loss: 0.0190
[4,  900] loss: 0.0153
[4, 1200] loss: 0.0173
[5,  300] loss: 0.0115
[5,  600] loss: 0.0086
[5,  900] loss: 0.0104
[5, 1200] loss: 0.0134
[6,  300] loss: 0.0160
[6,  600] loss: 0.0077
[6,  900] loss: 0.0121
[6, 1200] loss: 0.0097
[7,  300] loss: 0.0092
[7,  600] loss: 0.0140
[7,  900] loss: 0.0057
[7, 1200] loss: 0.0105
[8,  300] loss: 0.0093
[8,  600] loss: 0.0064
[8,  900] loss: 0.0088
[8, 1200] loss: 0.0114
[9,  300] loss: 0.0092
[9,  600] loss: 0.0082
[9,  900] loss: 0.0113
[9, 1200] loss: 0.0068
[10,  300] loss: 0.0059
[10,  600] loss: 0.0097
[10,  900] loss: 0.0135
[10, 1200] loss: 0.0044
[11,  300] loss: 0.0061
[11,  600] loss: 0.0070
[11,  900] loss: 0.0078
[11,

In [8]:
class Mytestdata(Dataset):
    
    def __init__(self, transforms=None):
        
        X_imgs = os.listdir('D:\\data\\archive\\Set14\\set14_lr128')
        self.X_imgs = [os.path.join('D:\\data\\archive\\Set14\\set14_lr128', x) for x in X_imgs]
        self.transforms = transforms
        
    def __getitem__(self, index):
        
        X_img_path = self.X_imgs[index]
        data = Image.open(X_img_path)
        
        if self.transforms:
            data = self.transforms(data)
            
        return data
    
    def __len__(self):
        
        return len(self.X_imgs)

In [9]:

testdata = Mytestdata(transform)
testloader = DataLoader(testdata, batch_size = 1, shuffle=False, num_workers=0)

show = T.ToPILImage()
for i, data in enumerate(testloader, 0):
    input = V(data)
    mysrcnn = mysrcnn.cpu()
    output = mysrcnn(input)
    img = show(output.squeeze(0))
    path = 'D:\\data\\archive\\test\\SRCNN\\'
    img.save(path+'img'+str(i)+'.png')

In [None]:
(data, label) = mydata[10]
out = mysrcnn(data.unsqueeze(0))
show = T.ToPILImage()
img = show(out.squeeze())
img.save('D:\\data\\archive\\test\\SRCNN\\test.png')