In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.autograd import Variable
from torchvision import transforms
from torch.utils import data
from PIL import Image

In [2]:
import cbam
class SRCNN(nn.Module):
    def __init__(self):
        super(SRCNN,self).__init__()
        self.patch_extraction=nn.Conv2d(in_channels=3,out_channels=64,kernel_size=5,stride=1,padding=4)
        self.non_linear=nn.Conv2d(in_channels=64,out_channels=32,kernel_size=1,stride=1,padding=0)
        self.reconstruction=nn.Conv2d(in_channels=32,out_channels=3,kernel_size=9,stride=1,padding=2)
    def forward(self,x):
        fm_1=F.relu(self.patch_extraction(x))
        fm_2=F.relu(self.non_linear(fm_1))
        #fm_3=self.cbam(fm_2)
        fm_4=self.reconstruction(fm_2)
        return fm_4

In [3]:
srcnn=SRCNN()
print(srcnn)


SRCNN(
  (patch_extraction): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(4, 4))
  (non_linear): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (reconstruction): Conv2d(32, 3, kernel_size=(9, 9), stride=(1, 1), padding=(2, 2))
)


In [4]:
import numpy as np
transform=transforms.Compose([transforms.ToTensor(),])
class I91DataSet(data.Dataset):
    def __init__(self,root):
        imgs=os.listdir(root)
        self.imgs=[os.path.join(root,k) for k in imgs]
        self.transforms=transform
    def __getitem__(self, item):
        img_path=self.imgs[item]
        pil_img=Image.open(img_path)
        LR_pil_img=pil_img.resize((11,11),Image.BICUBIC)
        LR_pil_img=LR_pil_img.resize(pil_img.size)
        arr_LR_pil_img=np.array(LR_pil_img).astype(int)
        arr_pil_img=np.array(pil_img).astype(int)
        #print(arr_pil_img.dtype)
        #pil_img.show()
        #print(transforms.ToTensor()(LR_pil_img).dtype)
        #print(transforms.ToTensor()((arr_LR_pil_img-arr_pil_img).astype(np.float32))/255)
        return transforms.ToTensor()(LR_pil_img),transforms.ToTensor()(pil_img)

        #return transforms.ToTensor()(LR_pil_img),transforms.ToTensor()((arr_pil_img-arr_LR_pil_img).astype(np.float32))/255
    def __len__(self):
        return len(self.imgs)

In [5]:
dataSet=I91DataSet('CropData')
print(dataSet.__len__())

3513


In [6]:
optimizer=optim.Adam(srcnn.parameters(),lr=0.0001)
trainLoader=data.DataLoader(dataset=dataSet,batch_size=8,shuffle=True)
criterion=nn.MSELoss()
EPOCH=200
for epoch in range(EPOCH):
    print('------------epoch {}-------------'.format(epoch))
    avg=0
    for i,(x,y) in enumerate(trainLoader):
        batch_x=Variable(x)
        batch_y=Variable(y)
        output=srcnn(batch_x)
        
        loss=criterion(output,batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg+=loss
    print(avg)

------------epoch 0-------------
tensor(8.9957, grad_fn=<AddBackward0>)
------------epoch 1-------------
tensor(1.3880, grad_fn=<AddBackward0>)
------------epoch 2-------------
tensor(1.1069, grad_fn=<AddBackward0>)
------------epoch 3-------------
tensor(0.9727, grad_fn=<AddBackward0>)
------------epoch 4-------------
tensor(0.9013, grad_fn=<AddBackward0>)
------------epoch 5-------------
tensor(0.8703, grad_fn=<AddBackward0>)
------------epoch 6-------------
tensor(0.8451, grad_fn=<AddBackward0>)
------------epoch 7-------------
tensor(0.8330, grad_fn=<AddBackward0>)
------------epoch 8-------------
tensor(0.8191, grad_fn=<AddBackward0>)
------------epoch 9-------------
tensor(0.8126, grad_fn=<AddBackward0>)
------------epoch 10-------------
tensor(0.8028, grad_fn=<AddBackward0>)
------------epoch 11-------------
tensor(0.7986, grad_fn=<AddBackward0>)
------------epoch 12-------------
tensor(0.7936, grad_fn=<AddBackward0>)
------------epoch 13-------------
tensor(0.7864, grad_fn=<Add

KeyboardInterrupt: 

In [7]:
#197*176
pil_test=Image.open('t1.bmp')
pil_test_y=transforms.ToTensor()(pil_test)
pil_test_y=pil_test_y.unsqueeze(0)
LR_pil_test=pil_test.resize((197//3,176//3),Image.BICUBIC)
LR_pil_test=LR_pil_test.resize(pil_test.size)

#LR_pil_test.show()
tensor_test=transforms.ToTensor()(LR_pil_test)
LR_pil_test=LR_pil_test.save('sourcet1.jpg')
tensor_test=tensor_test.unsqueeze(0)
print(tensor_test)
tensor_test_y=srcnn(tensor_test)
#tensor_test_y=tensor_test_y+tensor_test
print(tensor_test_y)
loss=criterion(tensor_test_y,pil_test_y)
print(loss*100)
loss=criterion(tensor_test,pil_test_y)
print(loss*100)
tensor_test_y=tensor_test_y.squeeze(0)
result=transforms.ToPILImage()(tensor_test_y)
result.show()
result.save('resultt1.jpg')

tensor([[[[0.5333, 0.5373, 0.5451,  ..., 0.0784, 0.0824, 0.0824],
          [0.5373, 0.5412, 0.5490,  ..., 0.0784, 0.0824, 0.0824],
          [0.5412, 0.5451, 0.5529,  ..., 0.0784, 0.0824, 0.0824],
          ...,
          [0.2745, 0.2784, 0.2824,  ..., 0.7412, 0.7373, 0.7333],
          [0.2627, 0.2667, 0.2745,  ..., 0.7490, 0.7451, 0.7412],
          [0.2588, 0.2627, 0.2706,  ..., 0.7529, 0.7490, 0.7451]],

         [[0.1804, 0.1843, 0.1922,  ..., 0.0902, 0.0941, 0.0941],
          [0.1843, 0.1882, 0.1961,  ..., 0.0902, 0.0941, 0.0941],
          [0.1961, 0.2000, 0.2039,  ..., 0.0902, 0.0941, 0.0941],
          ...,
          [0.1843, 0.1804, 0.1804,  ..., 0.5647, 0.5647, 0.5647],
          [0.1843, 0.1843, 0.1843,  ..., 0.5725, 0.5765, 0.5765],
          [0.1843, 0.1843, 0.1882,  ..., 0.5765, 0.5804, 0.5804]],

         [[0.5529, 0.5569, 0.5686,  ..., 0.0275, 0.0314, 0.0314],
          [0.5569, 0.5608, 0.5725,  ..., 0.0275, 0.0314, 0.0314],
          [0.5686, 0.5725, 0.5804,  ..., 0

In [9]:
torch.save(srcnn,"srcnn.pkl")
print("123")

123
