In [None]:
import h5py
import numpy as np
from torch.utils.data import Dataset


class TrainDataset(Dataset):
    """Training pairs read from an HDF5 file.

    The file must hold two indexable datasets, ``'lr'`` and ``'hr'``, of equal
    length. Sample ``idx`` is the pair of entries at that index, each divided by
    255 and given a leading axis of size 1 (the channel axis). The file is
    opened afresh for every ``__getitem__`` / ``__len__`` call and closed again.
    """

    def __init__(self, h5_file):
        super().__init__()
        self.h5_file = h5_file  # path handed to h5py.File on each access

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as h5:
            low_res = h5['lr'][idx] / 255.
            high_res = h5['hr'][idx] / 255.
        return np.expand_dims(low_res, 0), np.expand_dims(high_res, 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as h5:
            n_samples = len(h5['lr'])
        return n_samples


class EvalDataset(Dataset):
    """Evaluation pairs read from an HDF5 file with one 2-D dataset per image.

    Image ``idx`` lives at ``'lr/<idx>'`` and ``'hr/<idx>'`` (the index is
    converted with ``str``). Each returned image is divided by 255 and given a
    leading axis of size 1. The length is the number of members of ``'lr'``.
    The file is opened afresh on every call.
    """

    def __init__(self, h5_file):
        super().__init__()
        self.h5_file = h5_file  # path handed to h5py.File on each access

    def __getitem__(self, idx):
        key = str(idx)
        with h5py.File(self.h5_file, 'r') as h5:
            pair = tuple(np.expand_dims(h5[group][key][:, :] / 255., 0) for group in ('lr', 'hr'))
        return pair

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as h5:
            n_images = len(h5['lr'])
        return n_images

In [None]:
from torch import nn


class SRCNN(nn.Module):
    """Three-convolution super-resolution network (ReLU after the first two).

    Maps ``(N, num_channels, H, W)`` to ``(N, num_channels, H, W)``. Every
    convolution pads by ``kernel_size // 2`` with an odd kernel, so the spatial
    size is preserved. Channel widths go num_channels -> 64 -> 32 -> num_channels.
    """

    def __init__(self, num_channels=1):
        super().__init__()
        # Creation order fixes seeded initialisation; attribute names are the
        # state_dict keys and are referenced by the optimizer setup.
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        hidden = self.relu(self.conv2(hidden))
        return self.conv3(hidden)

In [None]:
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

def convert_rgb_to_y(img):
    """Return the Y (luma) plane of an RGB image; inputs are expected on a 0-255 scale.

    Args:
        img: a NumPy array laid out ``(H, W, 3)``, or a torch tensor laid out
            ``(3, H, W)``. A 4-D tensor first has ``squeeze(0)`` applied, so a
            leading batch axis of size 1 is dropped.

    Returns:
        An ``(H, W)`` array/tensor equal to
        ``16 + (64.738 R + 129.057 G + 25.064 B) / 256``.

    Raises:
        Exception: ``('Unknown Type', type(img))`` for any other input type.
    """
    # NOTE(review): 64.738 looks like a typo for 65.738 (= 65.481 * 256 / 255);
    # left unchanged so values match data prepared with this formula -- confirm.
    if type(img) is np.ndarray:
        red, green, blue = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    elif type(img) is torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        red, green, blue = img[0, :, :], img[1, :, :], img[2, :, :]
    else:
        raise Exception('Unknown Type', type(img))
    return 16. + (64.738 * red + 129.057 * green + 25.064 * blue) / 256.


def convert_rgb_to_ycbcr(img):
    """Convert an RGB image (0-255 scale) to YCbCr with a fixed linear transform.

    Args:
        img: a NumPy array laid out ``(H, W, 3)``, or a torch tensor laid out
            ``(3, H, W)``. A 4-D tensor first has ``squeeze(0)`` applied, so a
            leading batch axis of size 1 is dropped.

    Returns:
        An ``(H, W, 3)`` array/tensor with channels (Y, Cb, Cr), for both input
        types.

    Raises:
        Exception: ``('Unknown Type', type(img))`` for any other input type.
    """
    # NOTE(review): the Y weight 64.738 looks like a typo for 65.738
    # (= 65.481 * 256 / 255), which is why convert_ycbcr_to_rgb does not invert
    # this exactly; left unchanged to match previously prepared data -- confirm.
    if type(img) == np.ndarray:
        y = 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
        cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256.
        cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256.
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    elif type(img) == torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        y = 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
        cb = 128. + (-37.945 * img[0, :, :] - 74.494 * img[1, :, :] + 112.439 * img[2, :, :]) / 256.
        cr = 128. + (112.439 * img[0, :, :] - 94.154 * img[1, :, :] - 18.285 * img[2, :, :]) / 256.
        # stack (not cat) creates the channel axis -> (3, H, W); cat along dim 0
        # gave a 2-D (3H, W) tensor on which permute(1, 2, 0) raised.
        return torch.stack([y, cb, cr], 0).permute(1, 2, 0)
    else:
        raise Exception('Unknown Type', type(img))


def convert_ycbcr_to_rgb(img):
    """Convert a YCbCr image (0-255 scale) back to RGB with a fixed linear transform.

    Args:
        img: a NumPy array laid out ``(H, W, 3)``, or a torch tensor laid out
            ``(3, H, W)``, with channels (Y, Cb, Cr). A 4-D tensor first has
            ``squeeze(0)`` applied, so a leading batch axis of size 1 is dropped.

    Returns:
        An ``(H, W, 3)`` array/tensor with channels (R, G, B), for both input
        types. Values are not clipped; callers clip to [0, 255] as needed.

    Raises:
        Exception: ``('Unknown Type', type(img))`` for any other input type.
    """
    if type(img) == np.ndarray:
        r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921
        g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576
        b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. - 276.836
        return np.array([r, g, b]).transpose([1, 2, 0])
    elif type(img) == torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        r = 298.082 * img[0, :, :] / 256. + 408.583 * img[2, :, :] / 256. - 222.921
        g = 298.082 * img[0, :, :] / 256. - 100.291 * img[1, :, :] / 256. - 208.120 * img[2, :, :] / 256. + 135.576
        b = 298.082 * img[0, :, :] / 256. + 516.412 * img[1, :, :] / 256. - 276.836
        # stack (not cat) creates the channel axis -> (3, H, W); cat along dim 0
        # gave a 2-D (3H, W) tensor on which permute(1, 2, 0) raised.
        return torch.stack([r, g, b], 0).permute(1, 2, 0)
    else:
        raise Exception('Unknown Type', type(img))


def calc_psnr(img1, img2, max_val=1.):
    """Peak signal-to-noise ratio, in dB, between two tensors.

    Args:
        img1, img2: tensors of broadcast-compatible shape.
        max_val: peak intensity of the scale the images are on. The default 1.0
            (the value previously hard-coded) suits images normalised to [0, 1];
            pass 255.0 for 8-bit-scaled data.

    Returns:
        0-dim tensor ``10 * log10(max_val**2 / mean((img1 - img2)**2))``; this
        is ``inf`` when the inputs are identical (zero MSE).
    """
    mse = torch.mean((img1 - img2) ** 2)
    return 10. * torch.log10(max_val ** 2 / mse)


class AverageMeter(object):
    """Running weighted mean of a stream of values (used here for per-batch loss/PSNR).

    Attributes:
        val: the most recent value passed to ``update``.
        sum: weighted total, the sum of ``val * n`` over all updates.
        count: total weight, the sum of ``n``.
        avg: ``sum / count``; 0 until the first update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all recorded values."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (callers pass the batch size)."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def create_loss_model(vgg, end_layer, use_maxpool=True, use_cuda=False):
    """Build a truncated copy of a layer stack (e.g. ``vgg16().features``) for a feature loss.

    Args:
        vgg: an iterable of ``nn.Module`` layers. It is deep-copied, so the
            caller's modules and weights are never shared or modified.
        end_layer: index (0-based, inclusive) of the last layer of ``vgg`` to keep.
        use_maxpool: keep ``nn.MaxPool2d`` layers when True; otherwise replace
            each one with an ``nn.AvgPool2d`` that has the same kernel size,
            stride, padding and ceil_mode.
        use_cuda: when True, move the result to the default CUDA device.

    Returns:
        ``nn.Sequential`` with the layers ``0..end_layer`` of ``vgg`` in order,
        named ``conv_<i>`` / ``relu_<i>`` / ``pool_<i>``, or ``layer_<i>`` for
        any other type, where ``<i>`` is the layer's index in ``vgg``. Parameter
        ``requires_grad`` flags are copied as-is; freezing is left to the caller.
    """
    vgg = copy.deepcopy(vgg)
    model = nn.Sequential()

    for i, layer in enumerate(vgg):
        if i > end_layer:
            break

        if isinstance(layer, nn.Conv2d):
            name = "conv_" + str(i)
        elif isinstance(layer, nn.ReLU):
            name = "relu_" + str(i)
        elif isinstance(layer, nn.MaxPool2d):
            name = "pool_" + str(i)
            if not use_maxpool:
                layer = nn.AvgPool2d(kernel_size=layer.kernel_size, stride=layer.stride,
                                     padding=layer.padding, ceil_mode=layer.ceil_mode)
        else:
            # Other layer types (e.g. BatchNorm2d in vgg*_bn) used to be dropped
            # silently, which changed the network's features.
            name = "layer_" + str(i)
        model.add_module(name, layer)

    if use_cuda:
        # The flag was previously accepted but ignored.
        model = model.cuda()
    return model



In [None]:

import os
import copy
import h5py
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
import torch
from torch import nn
import torchvision.models as models




if __name__ == '__main__':
    # ---- configuration ----------------------------------------------------
    seed = 123
    learn_rate = 1e-4
    tf = "/content/drive/My Drive/DataSet/91-image_x3.h5"  # training pairs
    batch = 8
    ef = "/content/drive/My Drive/DataSet/Set5_x3.h5"  # evaluation pairs
    ne = 100  # number of epochs
    out_dir = "/content/output"
    vgg_end_layer = 8  # last index of vgg16.features kept in the loss network
    # Weight of the VGG feature loss relative to the pixel MSE. This is an
    # assumed starting point, not a tuned value: the two terms differ in scale,
    # so adjust it while watching the eval PSNR.
    perceptual_weight = 1e-2
    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(seed)
    # torch.save below fails if the output directory does not exist yet.
    os.makedirs(out_dir, exist_ok=True)

    model = SRCNN().to(device)
    criterion = nn.MSELoss()

    # Pretrained VGG16 is used only as a fixed feature extractor for the loss.
    # It is frozen (it was set to requires_grad=True), in eval mode, and on the
    # same device as the model (it was left on the CPU).
    vgg16 = models.vgg16(pretrained=True).features
    vgg_loss = create_loss_model(vgg16, vgg_end_layer).to(device).eval()
    for param in vgg_loss.parameters():
        param.requires_grad = False

    # ImageNet mean/std for the torchvision pretrained weights, broadcast over (N, 3, H, W).
    vgg_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    vgg_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

    def vgg_features(y):
        """VGG features of an (N, 1, H, W) batch; assumes values in [0, 1] as the
        datasets produce. The single channel is replicated to RGB and normalised."""
        return vgg_loss((y.repeat(1, 3, 1, 1) - vgg_mean) / vgg_std)

    optimizer = optim.Adam([
        {'params': model.conv1.parameters()},
        {'params': model.conv2.parameters()},
        {'params': model.conv3.parameters(), 'lr': learn_rate * 0.1}
    ], lr=learn_rate)

    train_dataset = TrainDataset(tf)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=batch,
                                  shuffle=True,
                                  num_workers=8,
                                  pin_memory=True,
                                  drop_last=True)
    eval_dataset = EvalDataset(ef)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    for epoch in range(ne):
        model.train()
        epoch_losses = AverageMeter()

        # drop_last=True, so only full batches are counted.
        with tqdm(total=(len(train_dataset) - len(train_dataset) % batch)) as t:
            t.set_description('epoch: {}/{}'.format(epoch, ne - 1))

            for inputs, labels in train_dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)

                # Every loss term must depend on `preds`. The previous loss
                # compared VGG features of the *inputs* with the labels, so no
                # gradient reached SRCNN and it never trained (flat loss,
                # constant eval PSNR in the saved log).
                pixel_loss = criterion(preds, labels)
                with torch.no_grad():
                    target_features = vgg_features(labels)
                perceptual_loss = criterion(vgg_features(preds), target_features)
                loss = pixel_loss + perceptual_weight * perceptual_loss

                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))

        torch.save(model.state_dict(), os.path.join(out_dir, 'epoch_{}.pth'.format(epoch)))

        model.eval()
        epoch_psnr = AverageMeter()

        for inputs, labels in eval_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = model(inputs).clamp(0.0, 1.0)

            # .item() keeps the meter on plain floats instead of device tensors.
            epoch_psnr.update(calc_psnr(preds, labels).item(), len(inputs))

        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())

    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join(out_dir, 'best.pth'))



Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))




epoch: 0/99:   0%|          | 0/21880 [00:00<?, ?it/s]

here


epoch: 0/99: 100%|██████████| 21880/21880 [15:57<00:00, 22.84it/s, loss=0.399684]
epoch: 1/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 1/99: 100%|██████████| 21880/21880 [10:10<00:00, 35.85it/s, loss=0.399573]
epoch: 2/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 2/99: 100%|██████████| 21880/21880 [10:13<00:00, 35.66it/s, loss=0.399698]
epoch: 3/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 3/99: 100%|██████████| 21880/21880 [10:12<00:00, 35.74it/s, loss=0.399719]
epoch: 4/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 4/99: 100%|██████████| 21880/21880 [10:06<00:00, 36.09it/s, loss=0.399677]
epoch: 5/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 5/99: 100%|██████████| 21880/21880 [10:07<00:00, 36.00it/s, loss=0.399666]
epoch: 6/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 6/99: 100%|██████████| 21880/21880 [10:08<00:00, 35.95it/s, loss=0.399699]
epoch: 7/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 7/99: 100%|██████████| 21880/21880 [10:07<00:00, 36.03it/s, loss=0.399673]
epoch: 8/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 8/99: 100%|██████████| 21880/21880 [10:24<00:00, 35.05it/s, loss=0.399662]
epoch: 9/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 9/99: 100%|██████████| 21880/21880 [10:33<00:00, 34.53it/s, loss=0.399672]
epoch: 10/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 10/99: 100%|██████████| 21880/21880 [10:28<00:00, 34.83it/s, loss=0.399644]
epoch: 11/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 11/99: 100%|██████████| 21880/21880 [10:35<00:00, 34.45it/s, loss=0.399699]
epoch: 12/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 12/99: 100%|██████████| 21880/21880 [10:45<00:00, 33.88it/s, loss=0.399710]
epoch: 13/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 13/99: 100%|██████████| 21880/21880 [10:19<00:00, 35.30it/s, loss=0.399698]
epoch: 14/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 14/99: 100%|██████████| 21880/21880 [10:14<00:00, 35.61it/s, loss=0.399673]
epoch: 15/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 15/99: 100%|██████████| 21880/21880 [10:15<00:00, 35.57it/s, loss=0.399711]
epoch: 16/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 16/99: 100%|██████████| 21880/21880 [10:12<00:00, 35.71it/s, loss=0.399696]
epoch: 17/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 17/99: 100%|██████████| 21880/21880 [10:11<00:00, 35.80it/s, loss=0.399719]
epoch: 18/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 18/99: 100%|██████████| 21880/21880 [10:13<00:00, 35.65it/s, loss=0.399685]
epoch: 19/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 19/99: 100%|██████████| 21880/21880 [10:10<00:00, 35.82it/s, loss=0.399677]
epoch: 20/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 20/99: 100%|██████████| 21880/21880 [10:15<00:00, 35.58it/s, loss=0.399691]
epoch: 21/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 21/99: 100%|██████████| 21880/21880 [10:12<00:00, 35.72it/s, loss=0.399692]
epoch: 22/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 22/99: 100%|██████████| 21880/21880 [10:12<00:00, 35.72it/s, loss=0.399673]
epoch: 23/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 23/99: 100%|██████████| 21880/21880 [10:11<00:00, 35.75it/s, loss=0.399699]
epoch: 24/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 24/99: 100%|██████████| 21880/21880 [10:20<00:00, 35.24it/s, loss=0.399718]
epoch: 25/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 25/99: 100%|██████████| 21880/21880 [10:22<00:00, 35.14it/s, loss=0.399614]
epoch: 26/99:   0%|          | 0/21880 [00:00<?, ?it/s]

eval psnr: 6.67


epoch: 26/99:  42%|████▏     | 9104/21880 [04:17<06:03, 35.13it/s, loss=0.397980]

TEST (UNMODIFIED)

In [None]:


import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image

from models import SRCNN
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr


if __name__ == '__main__':
    # argparse was used without being imported (NameError); os is needed to
    # build the output paths. Both are imported here because this cell's import
    # list does not include them.
    import argparse
    import os

    # NOTE(review): intended to run as a script. Inside a Jupyter kernel,
    # parse_args() typically also sees the kernel's own arguments (e.g. -f ...).
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=True)
    parser.add_argument('--image-file', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    args = parser.parse_args()

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = SRCNN().to(device)

    # Copy weights into the model's state_dict; unknown keys raise KeyError.
    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()

    # Split once so only the extension's dot is used. str.replace('.', ...)
    # rewrote every dot in the path (e.g. './img.png', 'img.v2.png').
    base, ext = os.path.splitext(args.image_file)

    image = pil_image.open(args.image_file).convert('RGB')

    # High-resolution reference, resized so both sides are multiples of `scale`.
    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale
    hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    # Network input: the reference bicubic-downscaled by `scale`, then back up.
    lr = hr.resize((hr.width // args.scale, hr.height // args.scale), resample=pil_image.BICUBIC)
    bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
    bicubic.save('{}_bicubic_x{}{}'.format(base, args.scale, ext))

    def y_channel(ycbcr_image):
        """Y plane of an (H, W, 3) YCbCr array, divided by 255, as a (1, 1, H, W) tensor."""
        return torch.from_numpy(ycbcr_image[..., 0] / 255.).to(device).unsqueeze(0).unsqueeze(0)

    hr_y = y_channel(convert_rgb_to_ycbcr(np.array(hr).astype(np.float32)))
    ycbcr = convert_rgb_to_ycbcr(np.array(bicubic).astype(np.float32))
    y = y_channel(ycbcr)

    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)

    # Score against the HR reference. The previous code compared preds with the
    # bicubic input `y`, which only measures how much the network changed it.
    print('PSNR (bicubic): {:.2f}'.format(calc_psnr(hr_y, y).item()))
    psnr = calc_psnr(hr_y, preds)
    print('PSNR: {:.2f}'.format(psnr.item()))

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    # Recombine the predicted Y with the bicubic image's Cb/Cr channels.
    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    output.save('{}_srcnn_x{}{}'.format(base, args.scale, ext))