In [37]:
import torch
from torch import nn
from model import SRResNet
#from dataset import dataSet
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
import numpy as np


In [38]:
class ResBlock(nn.Module):
    """Bottleneck-style residual block: 1x1 conv -> 3x3 conv -> 1x1 conv + skip.

    Args:
        inputC: number of channels of the incoming feature map.
        outputC: number of channels produced by the block.

    The spatial size is preserved (1x1 convs, and the 3x3 conv uses padding=1).
    A single ``nn.PReLU`` instance (one learnable slope) is shared by all three
    activations in the block.
    """
    def __init__(self,inputC,outputC):
        super(ResBlock,self).__init__()
        self.conv1 = nn.Conv2d(inputC,outputC,kernel_size=1,bias=False)
        self.bn1 = nn.BatchNorm2d(outputC)
        self.conv2 = nn.Conv2d(outputC,outputC,kernel_size=3,stride=1,padding=1,bias=False)
        self.bn2 = nn.BatchNorm2d(outputC)
        self.conv3 = nn.Conv2d(outputC,outputC,kernel_size=1,bias=False)
        self.relu = nn.PReLU()

        # The skip connection can only be added element-wise when the channel
        # counts agree; otherwise project the input with a 1x1 conv first.
        # nn.Identity has no parameters, so the state_dict of the common
        # inputC == outputC case is unchanged.
        if inputC == outputC:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inputC,outputC,kernel_size=1,bias=False),
                nn.BatchNorm2d(outputC),
            )

    def forward(self,x):
        """Apply the block to ``x`` and return a tensor with ``outputC`` channels."""
        residual = self.shortcut(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # Each stage must consume the previous stage's output. Feeding ``x``
        # again (as the code originally did) silently discards conv1/conv2.
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)

        out = out + residual
        out = self.relu(out)
        return out


In [39]:
class SRResNet(nn.Module):
    """SRResNet-style 4x super-resolution network.

    Pipeline: 9x9 input conv -> ``num_blocks`` residual blocks -> 3x3 conv + BN
    with a long skip connection -> two sub-pixel (PixelShuffle x2) upsampling
    stages -> 9x9 output conv.

    Args:
        num_blocks: number of ``ResBlock`` modules in the residual trunk.
            At least one block is always created.

    Input is a 3-channel image batch ``(N, 3, H, W)``; output is
    ``(N, 3, 4H, 4W)``.
    """
    def __init__(self, num_blocks = 16):

        super(SRResNet,self).__init__()

        # Input stage: 3 -> 64 channels, spatial size preserved.
        self.conv1 = nn.Conv2d(3,64,kernel_size=9,padding=4,padding_mode='reflect',stride=1)
        self.relu = nn.PReLU()

        # Residual trunk. ``num_blocks`` is honoured here; previously it was
        # ignored in favour of a hard-coded 16.
        self.res_block = self._makeLayer_(64,64,num_blocks)

        self.conv2 = nn.Conv2d(64,64,kernel_size=3,stride=1,padding=1,bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        # Not used in forward(); kept so existing checkpoints still load.
        self.relu2 = nn.PReLU()

        # Sub-pixel convolution stages, each doubling the resolution.
        # subConv1 uses padding=2 with a 3x3 kernel, growing H to H+2. After
        # both shuffles this becomes 4H+8, and the unpadded 9x9 conv3 below
        # trims exactly those 8 pixels, so the final size is 4H.
        self.subConv1 = nn.Conv2d(64,256,kernel_size=3,stride=1,padding=2,padding_mode='reflect')
        self.shuffle1 = nn.PixelShuffle(2)
        self.reluSub1 = nn.PReLU()

        self.subConv2 = nn.Conv2d(64,256,kernel_size=3,stride=1,padding=1,padding_mode='reflect')
        self.shuffle2 = nn.PixelShuffle(2)
        self.reluSub2 = nn.PReLU()

        # Output stage: 64 -> 3 channels, no padding (see note above).
        self.conv3 = nn.Conv2d(64,3,kernel_size=9,stride=1)

    def _makeLayer_(self,inputC,outputC,blocks):
        """Build a sequential stack of ``blocks`` residual blocks (minimum one).

        The first block maps ``inputC`` -> ``outputC``; the rest keep ``outputC``.
        """
        layers = [ResBlock(inputC,outputC)]
        layers.extend(ResBlock(outputC,outputC) for _ in range(1,blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Upscale the low-resolution batch ``x`` by a factor of 4."""
        output = self.conv1(x)
        output = self.relu(output)
        residual = output

        # Bug fix: this was ``output - self.res_block(output)``, a subtraction
        # whose result was discarded, so the residual trunk was never applied.
        output = self.res_block(output)
        output = self.conv2(output)
        output = self.bn2(output)
        output = output + residual

        output = self.subConv1(output)
        output = self.shuffle1(output)
        output = self.reluSub1(output)

        output = self.subConv2(output)
        output = self.shuffle2(output)
        output = self.reluSub2(output)
        output = self.conv3(output)

        return output

In [40]:
from torch.utils.data import Dataset
import os
import json
from PIL import Image
from torchvision import transforms
# Default preprocessing for high-resolution samples: take a random 96x96 crop,
# then convert the PIL image to a tensor.
transform = transforms.Compose([
    transforms.RandomCrop(96),
    transforms.ToTensor(),
])
    
    
class dataSet(Dataset):
    """Dataset yielding (high-resolution, low-resolution) image tensor pairs.

    The image list is read from ``train_images.json`` (when ``dataType`` is
    ``'train'``) or ``test_images.json`` (any other value) inside ``dataPath``.
    Each entry of that JSON is passed to ``PIL.Image.open``, so it is expected
    to be an image file path.
    """

    def __init__(self, dataPath, dataType,transforms = transform):
        """
        :dataPath = directory containing the JSON image lists
        :dataType = 'train' selects train_images.json, anything else test_images.json
        :transforms = callable applied to each RGB PIL image to produce the
                      high-resolution tensor
        """

        self.dataPath = dataPath
        self.dataType = dataType
        # Bug fix: store the argument, not the module-level ``transform``.
        # Previously a caller-supplied transform was silently ignored.
        self.transforms = transforms

        if self.dataType == 'train':
            list_name = 'train_images.json'
        else:
            list_name = 'test_images.json'
        with open(os.path.join(dataPath, list_name), 'r') as f:
            self.imgs = json.load(f)

    def __len__(self):
        """Number of images listed in the JSON file."""
        return len(self.imgs)

    def __getitem__(self,index):
        """Return ``(result, cropResult)``: the transformed image and its 4x-downsampled copy."""
        # Use a context manager so the underlying file handle is released
        # promptly; convert() returns an independent, fully loaded image.
        with Image.open(self.imgs[index], mode='r') as raw:
            img = raw.convert('RGB')

        result = self.transforms(img)  # process the original image
        # Low-resolution input: 4x4 max pooling of the high-resolution tensor.
        # NOTE(review): max pooling is an unusual downsampling choice for
        # super-resolution (bicubic is more common); kept to preserve behavior.
        cropResult = torch.nn.functional.max_pool2d(result, 4)

        return result, cropResult




In [43]:
# Shared training configuration used by run().
# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Directory holding train_images.json / test_images.json.
dataPath = './output/'

# TensorBoard writer using the default log directory.
writer = SummaryWriter()

def run(epochs=50, batch_size=64, lr=0.001):
    """Train SRResNet on the training split and log progress to TensorBoard.

    Args:
        epochs: number of passes over the training set.
        batch_size: mini-batch size for the training DataLoader.
        lr: learning rate for the Adam optimizer.

    Uses the module-level ``device``, ``dataPath`` and ``writer``. The writer
    is closed when training finishes or fails.

    Raises:
        RuntimeError: if the training DataLoader yields no batches.
    """
    SRnet = SRResNet().to(device)
    loss_function = nn.MSELoss().to(device)
    optimizer = optim.Adam(SRnet.parameters(),lr=lr)

    train_dataset = dataSet(dataPath=dataPath, dataType='train')
    # Shuffle so each epoch sees the samples in a different order.
    train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
    if len(train_loader) == 0:
        raise RuntimeError("training DataLoader is empty; nothing to train on")

    # 1-based index of the batch whose images are written to TensorBoard.
    image_log_batch = 4

    try:
        for epoch in range(epochs):
            SRnet.train()
            total_loss = 0.0
            num_batches = 0

            # Wrap the loader itself (not enumerate) so tqdm knows the total.
            progress = tqdm(train_loader, desc="epoch {}/{}".format(epoch + 1, epochs))
            for i,(result, cropResult) in enumerate(progress,1):
                result = result.to(device)
                cropResult = cropResult.to(device)

                # forward
                next_imgs = SRnet(cropResult)

                # calculate loss
                loss = loss_function(next_imgs,result)
                # backward
                optimizer.zero_grad()
                loss.backward()
                # update model parameters
                optimizer.step()

                total_loss += loss.item()
                num_batches = i

                if i == image_log_batch:
                    # Low-res input, network output and high-res target.
                    writer.add_image('SRResNet/epoch_'+str(epoch)+'_1', make_grid(cropResult[:4,:3,:,:].detach().cpu(), nrow=4, normalize=True),epoch)
                    writer.add_image('SRResNet/epoch_'+str(epoch)+'_2', make_grid(next_imgs[:4,:3,:,:].detach().cpu(), nrow=4, normalize=True),epoch)
                    writer.add_image('SRResNet/epoch_'+str(epoch)+'_3', make_grid(result[:4,:3,:,:].detach().cpu(), nrow=4, normalize=True),epoch)

            # enumerate() starts at 1, so ``num_batches`` is already the batch
            # count; the former ``total_loss/(i+1)`` under-reported the mean.
            average_lose = total_loss/num_batches
            writer.add_scalar('SRResNet/MSE_Loss', average_lose, epoch)
    finally:
        writer.close()

    print("trainning end...")




In [44]:
# Start training with the defaults defined in run().
run()

4it [00:32,  8.03s/it]
4it [00:32,  8.12s/it]
4it [00:31,  7.99s/it]
4it [00:32,  8.08s/it]
4it [00:31,  7.87s/it]
4it [00:31,  7.99s/it]
4it [00:31,  7.94s/it]
4it [00:31,  7.94s/it]
4it [00:32,  8.01s/it]
4it [00:32,  8.02s/it]
4it [00:31,  7.96s/it]
4it [00:32,  8.02s/it]
4it [00:31,  7.93s/it]
4it [00:32,  8.01s/it]
4it [00:31,  7.89s/it]
4it [00:31,  7.98s/it]
4it [00:31,  7.91s/it]
4it [00:31,  7.91s/it]
4it [00:32,  8.05s/it]
4it [00:32,  8.05s/it]
4it [00:31,  8.00s/it]
4it [00:31,  7.99s/it]
4it [00:31,  8.00s/it]
4it [00:31,  7.94s/it]
4it [00:32,  8.11s/it]
4it [00:32,  8.05s/it]
4it [00:32,  8.08s/it]
4it [00:32,  8.00s/it]
4it [00:32,  8.06s/it]
4it [00:32,  8.08s/it]
4it [00:32,  8.10s/it]
4it [00:32,  8.07s/it]
4it [00:31,  7.98s/it]
4it [00:31,  7.98s/it]
4it [00:31,  7.94s/it]
4it [00:32,  8.23s/it]
4it [00:31,  7.92s/it]
4it [00:31,  7.92s/it]
4it [00:31,  7.92s/it]
4it [00:32,  8.07s/it]
4it [00:31,  7.96s/it]
4it [00:31,  7.90s/it]
4it [00:31,  7.94s/it]
4it [00:31,

trainning end...



