In [1]:
import os
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as V
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from PIL import Image
import numpy as np

In [2]:
class SRCNN(nn.Module):
    """Three-layer convolutional network for image super-resolution.

    Layers: 9x9 conv (3 -> 64 channels), 1x1 conv (64 -> 32), 5x5 conv
    (32 -> 3), with ReLU after the first two. No padding is used, so each
    spatial dimension of the output is 12 pixels smaller than the input
    (8 lost in the 9x9 layer, 4 in the 5x5 layer).
    """

    def __init__(self):
        super().__init__()

        # Attribute names are part of the state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, stride=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, stride=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=5, stride=1)

    def forward(self, x):
        """Map a (N, 3, H, W) batch to a (N, 3, H-12, W-12) batch."""
        features = F.relu(self.conv1(x))
        mapped = F.relu(self.conv2(features))
        # The last layer is linear: no activation on the reconstructed image.
        return self.conv3(mapped)

In [3]:
# Shared preprocessing for inputs and targets: convert to a tensor, then
# apply (x - 0.5) / 0.5 per channel. Later cells undo this with (y + 1) / 2.
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
])

In [4]:
class Mydata(Dataset):
    """Paired low-/high-resolution image dataset.

    Inputs are read from ``<root>/lr`` and targets from ``<root>/hr6``.
    Files are paired by position after sorting the file names of each
    directory, so matching images are expected to sort identically.
    """

    def __init__(self, root, transforms=None):
        """Index the image files under ``root``.

        Args:
            root: directory containing the ``lr`` and ``hr6`` sub-directories.
            transforms: optional callable applied to both the input and the
                target image in ``__getitem__``.

        Raises:
            ValueError: if the two directories hold a different number of
                files, since index-based pairing would then be misaligned.
        """
        lr_dir = os.path.join(root, 'lr')
        hr_dir = os.path.join(root, 'hr6')

        # os.listdir returns entries in arbitrary order; sort both listings so
        # that index i refers to the same image in each directory.
        self.X_imgs = [os.path.join(lr_dir, name) for name in sorted(os.listdir(lr_dir))]
        self.L_imgs = [os.path.join(hr_dir, name) for name in sorted(os.listdir(hr_dir))]

        if len(self.X_imgs) != len(self.L_imgs):
            raise ValueError(
                'lr/hr6 file counts differ: %d vs %d'
                % (len(self.X_imgs), len(self.L_imgs))
            )

        self.transforms = transforms

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair at ``index``."""
        # convert() loads the pixel data, so the file can be closed right away;
        # forcing RGB keeps the channel count at the 3 the transform expects.
        with Image.open(self.L_imgs[index]) as hr_file:
            label = hr_file.convert('RGB')
        with Image.open(self.X_imgs[index]) as lr_file:
            data = lr_file.convert('RGB')

        if self.transforms:
            data = self.transforms(data)
            label = self.transforms(label)

        return data, label

    def __len__(self):
        """Number of image pairs."""
        return len(self.X_imgs)

In [5]:
# Model under training.
mysrcnn = SRCNN()

# Training images: inputs and targets are looked up under this directory.
root = 'D:\\data\\archive\\train'
mydata = Mydata(root, transforms=transform)
mydataloader = DataLoader(
    dataset=mydata,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)

# Pixel-wise mean squared error, minimised with momentum SGD plus L2 decay.
criterion = nn.MSELoss()
optimizer = optim.SGD(
    params=mysrcnn.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=0.01,
)

In [6]:
# Number of full passes over the training data made by the training loop.
epochs = 20

In [7]:
# Train the network: one optimizer step per mini-batch, logging the loss of
# the current batch every 100 iterations.
mysrcnn.train()
for epoch in range(epochs):

    for i, (inputs, labels) in enumerate(mydataloader, 0):
        # Tensors from the DataLoader are used directly; the deprecated
        # autograd.Variable wrapper is no longer needed.
        optimizer.zero_grad()

        outputs = mysrcnn(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        optimizer.step()

        if i % 100 == 0:
            # .item() extracts the Python float from the scalar loss tensor.
            print('[%d, %3d] loss: %.3f' % (epoch+1, i+1, loss.item()))
print('Finished')

[1,   1] loss: 0.376
[1, 101] loss: 0.016
[2,   1] loss: 0.012
[2, 101] loss: 0.015
[3,   1] loss: 0.014
[3, 101] loss: 0.011
[4,   1] loss: 0.006
[4, 101] loss: 0.024
[5,   1] loss: 0.004
[5, 101] loss: 0.008
[6,   1] loss: 0.008
[6, 101] loss: 0.012
[7,   1] loss: 0.016
[7, 101] loss: 0.006
[8,   1] loss: 0.011
[8, 101] loss: 0.009
[9,   1] loss: 0.009
[9, 101] loss: 0.018
[10,   1] loss: 0.011
[10, 101] loss: 0.004
[11,   1] loss: 0.019
[11, 101] loss: 0.009
[12,   1] loss: 0.012
[12, 101] loss: 0.015
[13,   1] loss: 0.005
[13, 101] loss: 0.016
[14,   1] loss: 0.018
[14, 101] loss: 0.007
[15,   1] loss: 0.010
[15, 101] loss: 0.012
[16,   1] loss: 0.022
[16, 101] loss: 0.010
[17,   1] loss: 0.012
[17, 101] loss: 0.018
[18,   1] loss: 0.006
[18, 101] loss: 0.016
[19,   1] loss: 0.016
[19, 101] loss: 0.017
[20,   1] loss: 0.003
[20, 101] loss: 0.022
Finished


In [8]:
class Mytestdata(Dataset):
    """Unlabelled low-resolution image dataset read from ``<root>/lr``.

    Items are returned in sorted file-name order, so an item's index maps to
    the same file on every run.
    """

    def __init__(self, root, transforms=None):
        """Index the image files under ``<root>/lr``.

        Args:
            root: directory containing the ``lr`` sub-directory.
            transforms: optional callable applied to each image in
                ``__getitem__``.
        """
        lr_dir = os.path.join(root, 'lr')
        # os.listdir returns entries in arbitrary order; sort for a stable,
        # reproducible index -> file mapping.
        self.X_imgs = [os.path.join(lr_dir, name) for name in sorted(os.listdir(lr_dir))]
        self.transforms = transforms

    def __getitem__(self, index):
        """Return the (optionally transformed) image at ``index``."""
        # convert() loads the pixel data, so the file can be closed right away;
        # forcing RGB keeps the channel count at the 3 the transform expects.
        with Image.open(self.X_imgs[index]) as lr_file:
            data = lr_file.convert('RGB')

        if self.transforms:
            data = self.transforms(data)

        return data

    def __len__(self):
        """Number of images."""
        return len(self.X_imgs)

In [9]:
# Run the trained network over the test images and save each result as a PNG.
root2 = 'D:\\data\\archive\\test'
testdata = Mytestdata(root2, transform)
testloader = DataLoader(testdata, batch_size = 1, shuffle=False, num_workers=0)

show = T.ToPILImage()
path = 'D:\\data\\archive\\test\\SRCNN\\'
# Create the output directory up front so the first save cannot fail on it.
os.makedirs(path, exist_ok=True)

mysrcnn.eval()
# Inference only: skip building the autograd graph.
with t.no_grad():
    for i, data in enumerate(testloader, 0):
        output = mysrcnn(data)
        # Undo the Normalize step ((y + 1) / 2), then clamp to [0, 1]: the
        # network output is unbounded and out-of-range values would wrap
        # around when converted to 8-bit pixels.
        restored = ((output.squeeze(0) + 1) / 2).clamp(0, 1)
        img = show(restored).resize((80, 80))
        img.save(path+'img'+str(500+i)+'.png')

In [18]:
# Spot check: run one training sample through the network and save the result.
(data, label) = mydata[100]
# Inference only: skip building the autograd graph.
with t.no_grad():
    out = mysrcnn(data.unsqueeze(0))
show = T.ToPILImage()
# Undo the Normalize step ((y + 1) / 2), then clamp to [0, 1] so that
# out-of-range values do not wrap around in the 8-bit conversion.
img = show(((out.squeeze(0)+1)/2).clamp(0, 1)).resize((80, 80))
img.save('D:\\data\\archive\\test\\SRCNN\\test.png')