In [2]:
# Scratch sanity check left over from experimentation: the comparison is only
# displayed as cell output and its result is not assigned to anything.
5 == 4

False

In [7]:
import os
import torch
import torch.nn as nn
import torch.optim as optim 
from torch.utils.data import DataLoader, Dataset 
from torchvision import transforms
from PIL import Image
import math 

In [14]:
# Copyright 2020 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import math

import torch
from torch import nn, Tensor

# Public names exported by ``from <module> import *``: the model class and
# one factory per supported upscaling factor.
__all__ = [
    "ESPCN",
    "espcn_x2",
    "espcn_x3",
    "espcn_x4",
    "espcn_x8",
]


class ESPCN(nn.Module):
    """Efficient Sub-Pixel Convolutional Network (ESPCN) for super-resolution.

    Layers:
        - Two feature-mapping convolutions run at low resolution: 5x5 then 3x3,
          each followed by Tanh.
        - A 3x3 convolution producing ``out_channels * upscale_factor ** 2``
          channels.
        - ``nn.PixelShuffle``, which rearranges those channels into an output
          ``upscale_factor`` times larger in height and width.

    The output is clamped to [0, 1].

    Args:
        in_channels: Channels of the low-resolution input.
        out_channels: Channels of the super-resolved output.
        channels: Width of the first feature layer; the second layer uses
            ``channels // 2``.
        upscale_factor: Integer spatial upscaling factor.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            channels: int,
            upscale_factor: int,
    ) -> None:
        super().__init__()
        hidden_channels = channels // 2
        out_channels = int(out_channels * (upscale_factor ** 2))

        # Feature mapping
        self.feature_maps = nn.Sequential(
            nn.Conv2d(in_channels, channels, (5, 5), (1, 1), (2, 2)),
            nn.Tanh(),
            nn.Conv2d(channels, hidden_channels, (3, 3), (1, 1), (1, 1)),
            nn.Tanh(),
        )

        # Sub-pixel convolution layer
        self.sub_pixel = nn.Sequential(
            nn.Conv2d(hidden_channels, out_channels, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(upscale_factor),
        )

        self._init_weights()

    def _init_weights(self) -> None:
        """Initialize every convolution; all biases start at zero.

        The sub-pixel convolution gets N(0, 0.001^2) weights. Every other
        convolution gets He-normal weights with std ``sqrt(2 / fan_out)``, where
        ``fan_out = out_channels * kernel_height * kernel_width``.
        """
        # Identify the sub-pixel conv by object identity, not by channel count.
        # A channel-count test depends on ``channels`` and can also match the
        # first layer when the input itself has that many channels.
        sub_pixel_conv = self.sub_pixel[0]
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            if module is sub_pixel_conv:
                std = 0.001
            else:
                fan_out = module.out_channels * module.kernel_size[0] * module.kernel_size[1]
                std = math.sqrt(2 / fan_out)
            nn.init.normal_(module.weight, 0.0, std)
            nn.init.zeros_(module.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Upscale ``x``.

        Expects shape ``(N, in_channels, H, W)`` and returns
        ``(N, out_channels, H * upscale_factor, W * upscale_factor)``.
        """
        return self._forward_impl(x)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.feature_maps(x)
        x = self.sub_pixel(x)
        # Out-of-place clamp to the [0, 1] image range; leaves the sub-pixel
        # output tensor unmodified.
        return torch.clamp(x, 0.0, 1.0)


def espcn_x2(**kwargs) -> ESPCN:
    """Build an ``ESPCN`` with ``upscale_factor=2``.

    Remaining constructor arguments (``in_channels``, ``out_channels``,
    ``channels``) must be supplied as keyword arguments.
    """
    return ESPCN(upscale_factor=2, **kwargs)


def espcn_x3(**kwargs) -> ESPCN:
    """Build an ``ESPCN`` with ``upscale_factor=3``.

    Remaining constructor arguments (``in_channels``, ``out_channels``,
    ``channels``) must be supplied as keyword arguments.
    """
    return ESPCN(upscale_factor=3, **kwargs)


def espcn_x4(**kwargs) -> ESPCN:
    """Build an ``ESPCN`` with ``upscale_factor=4``.

    Remaining constructor arguments (``in_channels``, ``out_channels``,
    ``channels``) must be supplied as keyword arguments.
    """
    return ESPCN(upscale_factor=4, **kwargs)


def espcn_x8(**kwargs) -> ESPCN:
    """Build an ``ESPCN`` with ``upscale_factor=8``.

    Remaining constructor arguments (``in_channels``, ``out_channels``,
    ``channels``) must be supplied as keyword arguments.
    """
    return ESPCN(upscale_factor=8, **kwargs)

In [None]:
class ImageDataset(Dataset):
    """Images from a directory, served as super-resolution samples.

    In ``'train'`` mode each item is an ``(lr, hr)`` tensor pair:
        - ``hr`` is the image cropped so both sides are multiples of
          ``upscale_factor``.
        - ``lr`` is its bicubic downsample by ``upscale_factor``.

    In any other mode (e.g. ``'test'``) the file is treated as a
    low-resolution input, and only the LR tensor is returned.

    Tensors come from ``transforms.ToTensor()`` and so lie in [0, 1]. This is
    the same range the ESPCN output is clamped to, and the peak value (1.0)
    that the PSNR helper assumes.

    Args:
        image_dir: Directory containing the image files.
        upscale_factor: Integer factor used to synthesize LR images.
        mode: ``'train'`` for ``(lr, hr)`` pairs; anything else for LR only.
    """

    def __init__(self, image_dir, upscale_factor, mode='train'):
        self.image_dir = image_dir
        self.upscale_factor = upscale_factor
        self.mode = mode
        # Keep only regular files with an extension Pillow can open (skips
        # sub-directories and stray files such as Thumbs.db). Sorting keeps
        # item indices stable across runs.
        image_extensions = {ext.lower() for ext in Image.registered_extensions()}
        self.image_filenames = sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if os.path.splitext(name)[1].lower() in image_extensions
            and os.path.isfile(os.path.join(image_dir, name))
        )
        # No Normalize: the model output is clamped to [0, 1], so inputs and
        # targets must stay in the [0, 1] range produced by ToTensor.
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        # The context manager closes the file; convert() returns a fully
        # loaded, independent image.
        with Image.open(self.image_filenames[idx]) as img:
            image = img.convert('RGB')

        if self.mode != 'train':
            # Inference: the file itself is the low-resolution input.
            return self.transform(image)

        # Crop so both sides divide by the factor; otherwise the model output
        # (LR size * factor) would not match the HR target's size.
        factor = self.upscale_factor
        width, height = image.size
        hr_image = image.crop((0, 0, width - width % factor, height - height % factor))
        lr_image = hr_image.resize(
            (hr_image.width // factor, hr_image.height // factor), Image.BICUBIC)
        return self.transform(lr_image), self.transform(hr_image)



In [15]:
def train(model, train_loader, criterion, optimizer, device):
    """Run one training epoch and return the average loss per sample.

    Args:
        model: Network mapping LR batches to SR batches.
        train_loader: Iterable of ``(lr_images, hr_images)`` batches.
        criterion: Loss ``criterion(outputs, targets)``; assumed to reduce
            with a mean over the batch (as ``nn.MSELoss`` does by default).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to.

    Returns:
        Batch-size-weighted mean of the loss over all samples, or ``nan`` if
        the loader yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_samples = 0
    for lr_images, hr_images in train_loader:
        lr_images = lr_images.to(device)
        hr_images = hr_images.to(device)

        optimizer.zero_grad()
        outputs = model(lr_images)
        loss = criterion(outputs, hr_images)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch is not over-counted.
        batch_size = lr_images.size(0)
        running_loss += loss.item() * batch_size
        num_samples += batch_size

    return running_loss / num_samples if num_samples else float('nan')


def test(model, test_loader, device):
    """Evaluate ``model`` and return its mean PSNR (dB) over ``test_loader``.

    Args:
        model: Network mapping LR batches to SR batches.
        test_loader: Iterable of ``(lr_images, hr_images)`` batches. PSNR is
            measured against the HR images, so ground truth is required.
        device: Device each batch is moved to.

    Returns:
        Mean of the per-batch PSNR values (with one image per batch, this is
        the mean per-image PSNR), or ``nan`` if the loader is empty.

    Raises:
        ValueError: If a batch is not an ``(lr, hr)`` pair.
    """
    model.eval()
    total_psnr = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in test_loader:
            if not (isinstance(batch, (list, tuple)) and len(batch) == 2):
                raise ValueError(
                    'test() needs (lr_images, hr_images) batches to measure PSNR; '
                    'got a batch without HR targets'
                )
            lr_images, hr_images = batch
            lr_images = lr_images.to(device)
            hr_images = hr_images.to(device)
            outputs = model(lr_images).clamp(0.0, 1.0)
            # Compare against the HR ground truth, which has the output's size.
            total_psnr += float(calculate_psnr(outputs, hr_images))
            num_batches += 1

    return total_psnr / num_batches if num_batches else float('nan')


def calculate_psnr(img1, img2, max_val=1.0):
    """Peak signal-to-noise ratio between two equally shaped tensors, in dB.

    Args:
        img1: First tensor.
        img2: Second tensor, same shape as ``img1``; the MSE is taken over
            all elements.
        max_val: Peak signal value (1.0 for images in [0, 1], 255 for 8-bit).

    Returns:
        ``10 * log10(max_val ** 2 / MSE)`` as a 0-dim tensor, or
        ``float('inf')`` when the inputs are identical.
    """
    # Integer tensors (e.g. uint8) would wrap around on subtraction.
    if not img1.is_floating_point():
        img1 = img1.float()
    if not img2.is_floating_point():
        img2 = img2.float()
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 10 * torch.log10(max_val ** 2 / mse)


def save_model(model, path):
    """Write ``model.state_dict()`` (parameters and buffers) to ``path`` via ``torch.save``."""
    state = model.state_dict()
    torch.save(state, path)

def load_model(model, path, device):
    """Load the state_dict stored at ``path`` into ``model`` (in place).

    ``map_location=device`` remaps the stored tensors onto ``device`` while
    loading. Returns ``None``.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources (newer torch versions accept weights_only=True).
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint)


In [16]:
def main():
    """Train ESPCN and report PSNR on a held-out image set every 10 epochs.

    Dataset directories default to the author's local paths. They can be
    overridden with the ``ESPCN_TRAIN_DIR`` and ``ESPCN_TEST_DIR``
    environment variables.

    A checkpoint named ``espcn_x{upscale_factor}_epoch_{N}.pth`` is written
    to the current working directory at each evaluation.
    """
    upscale_factor = 2
    in_channels = 3
    out_channels = 3
    channels = 64
    # Full-resolution images differ in size, so they cannot be stacked into
    # batches larger than 1 without cropping.
    batch_size = 1
    num_epochs = 50
    learning_rate = 0.001
    eval_every = 10
    seed = 0
    train_image_dir = os.environ.get(
        'ESPCN_TRAIN_DIR', r'C:\Users\91995\Downloads\div2k\DIV2K_train_HR\DIV2K_train_HR')
    test_image_dir = os.environ.get(
        'ESPCN_TEST_DIR', r'C:\Users\91995\OneDrive\Desktop\ESCPN++\Set14\Set14_4')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(seed)

    model = ESPCN(in_channels, out_channels, channels, upscale_factor).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    train_dataset = ImageDataset(train_image_dir, upscale_factor, mode='train')
    # PSNR needs ground truth, so the evaluation set is also built in 'train'
    # mode: each image becomes an (LR, HR) pair via bicubic downsampling.
    test_dataset = ImageDataset(test_image_dir, upscale_factor, mode='train')
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    for epoch in range(num_epochs):
        train_loss = train(model, train_loader, criterion, optimizer, device)
        print(f'Epoch {epoch + 1}/{num_epochs} - loss: {train_loss}')

        if (epoch + 1) % eval_every == 0:
            avg_psnr = test(model, test_loader, device)
            print(f'PSNR : {avg_psnr:.2f}dB')
            save_model(model, f'espcn_x{upscale_factor}_epoch_{epoch + 1}.pth')

        

In [17]:
# Start training when run as a script or directly in a notebook kernel (both
# set __name__ to '__main__'); importing this code does not start training.
if __name__ == '__main__':
    main()