SRCNN training data

In [1]:
import torch
import cv2
import h5py
import numpy as np
from torch import nn,optim
import torch.nn.functional as F
from torchvision import transforms
import torch.utils.data as dataf
import matplotlib.pyplot as plt
import torchvision

In [2]:
file = 'train_data.h5'

In [3]:
 with h5py.File(file, 'r') as hf:
        data = np.array(hf.get('data'))
        label = np.array(hf.get('label'))
        print(data.shape)
        print(label.shape)
        #train_data = np.transpose(data, (0, 2, 3, 1))  #改为(14901,32,32,1)
        #train_label = np.transpose(label, (0, 2, 3, 1))
        

(14901, 32, 32, 3)
(14901, 20, 20, 3)


In [4]:
class SRCNN(nn.Module):
    def __init__(self):
        super(SRCNN,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1,128,kernel_size=9),
            nn.ReLU(True),
            nn.Conv2d(128,64,kernel_size=1),
            nn.ReLU(True),
            nn.Conv2d(64,1,kernel_size=5),
            
            
        )
    def forward(self,x):
        out = self.conv(x)
        return out

In [5]:
model = SRCNN()
loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(),lr = 0.01)

In [6]:
print(model)

SRCNN(
  (conv): Sequential(
    (0): Conv2d(1, 128, kernel_size=(9, 9), stride=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
    (3): ReLU(inplace)
    (4): Conv2d(64, 1, kernel_size=(5, 5), stride=(1, 1))
  )
)


In [7]:
from torchsummary import summary
summary(model,(1,28,28))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 128, 20, 20]          10,496
              ReLU-2          [-1, 128, 20, 20]               0
            Conv2d-3           [-1, 64, 20, 20]           8,256
              ReLU-4           [-1, 64, 20, 20]               0
            Conv2d-5            [-1, 1, 16, 16]           1,601
Total params: 20,353
Trainable params: 20,353
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 1.17
Params size (MB): 0.08
Estimated Total Size (MB): 1.25
----------------------------------------------------------------


In [None]:
epoches = 1  #200

In [None]:
train_data = torch.from_numpy(data) 
train_data = torch.tensor(train_data)
train_label = torch.from_numpy(label)
train_label = torch.tensor(train_label)

print(train_data.shape,train_label.shape)

In [None]:
dataset = dataf.TensorDataset(train_data,train_label)

In [None]:
loader = dataf.DataLoader(dataset,batch_size=32,shuffle=True)

In [None]:
dataiter = iter(loader)
datas,labels = dataiter.next()
npimg = datas[0,:,:,:].numpy()
npimg = np.reshape(npimg,(32,32))
plt.imshow(npimg)
plt.show()

In [None]:
img = torchvision.utils.make_grid(datas)  #将32个batch的MNIST拼成一个图像
npimg = img.numpy()
np.shape(npimg)
plt.imshow(np.transpose(npimg, (1, 2, 0))) # (channel,pixel,pixel)-> (pixel,pixel,channel)
plt.show()

In [None]:
for epoch in range(epoches):
    print('epoch {}'.format(epoch + 1))
    print('*' * 10)
    running_loss = 0.0
    
    for i,data in enumerate(loader,1):
        lr_img,hr_img = data
        #print(type(lr_img),type(hr_img))  <class 'torch.Tensor'>
        # print(lr_img.shape)   torch.Size([32, 1, 32, 32])
        # print(hr_img.shape)   torch.Size([32, 1, 20, 20])
        lr_img = torch.tensor(lr_img,requires_grad = True)
        
        # forward
        out = model(lr_img)
        mse_loss = loss(out,hr_img)
        running_loss += mse_loss.item()
        
        optimizer.zero_grad()
        mse_loss.backward()
        optimizer.step()
        #if i%1000 == 0:
        #    print('[{}/{}] Loss: {:.6f}'.format(
        #        epoch + 1, epoches, running_loss))
    print('Finish {} epoch, Loss: {:.6f}'.format(
        epoch + 1, running_loss ))
    
    

In [None]:
testimg = 'Test/Set14/flowers.bmp'
img = cv2.imread(testimg,cv2.IMREAD_COLOR)
img.shape


In [None]:
img = img[:,:,0]
img = torch.from_numpy(img)
img = torch.tensor(img,requires_grad=True,dtype = torch.float)
img = img.view(1,1,362,500)
out = model(img)

In [None]:
npimg = out[0,:,:,:].detach().numpy()
npimg.shape

In [None]:
npimg = np.reshape(npimg,(350,488))
np.shape(npimg)
plt.imshow(npimg,cmap = 'gray')
plt.show() 