In [16]:
import argparse
import os
import copy
import torch
from torch import nn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
import numpy as np
from models import SRCNN
from torchmetrics import StructuralSimilarityIndexMeasure as SSIM
from GLoss import GradientVariance
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr, convert_rgb_to_ycbcr, convert_ycbcr_to_rgb

import torch.optim as optim
import torch.backends.cudnn as cudnn
import PIL.Image as pil_image


cudnn.benchmark = True

In [17]:
parser = argparse.ArgumentParser()
parser.add_argument('--train-file', type=str, required=True)
parser.add_argument('--eval-file', type=str, required=True)
parser.add_argument('--outputs-dir', type=str, required=True)
parser.add_argument('--scale', type=int, default=3)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--num-epochs', type=int, default=400)
parser.add_argument('--num-workers', type=int, default=8)
parser.add_argument('--seed', type=int, default=123)
parser.add_argument('--loss-patch-size', type=int, default=3) 
parser.add_argument('--gradloss-weight', type=float, default=1.0)

args = parser.parse_args([
    '--train-file', 'data/91-image_x2.h5',
    '--eval-file', 'data/Set5_x2.h5',
    '--outputs-dir', 'weights/SRCNN_GELU_x2',
    '--scale', '2',
    '--lr', '0.0001',
    '--batch-size', '16',
    '--num-epochs', '400',
    '--num-workers', '8',
    '--seed', '123',
    '--loss-patch-size', '8'
])

args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))

if not os.path.exists(args.outputs_dir):
    os.makedirs(args.outputs_dir)

cudnn.benchmark = True
print(torch.cuda.is_available())
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

torch.manual_seed(args.seed)

model = SRCNN().to(device)

# criterion
criterion = SSIM(data_range=1.0).to(device)
grad_criterion = GradientVariance(patch_size=args.loss_patch_size, cpu=True).to(device)

optimizer = optim.Adam([
    {'params': model.conv1.parameters()},
    {'params': model.conv2.parameters()},
    {'params': model.conv3.parameters(), 'lr': args.lr * 0.1}
], lr=args.lr)

train_dataset = TrainDataset(args.train_file)
train_dataloader = DataLoader(dataset=train_dataset,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=args.num_workers,
                                pin_memory=True,
                                drop_last=True)
eval_dataset = EvalDataset(args.eval_file)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

best_weights = copy.deepcopy(model.state_dict())
best_epoch = 0
best_psnr = 0.0

for epoch in range(args.num_epochs):
    model.train()
    epoch_losses = AverageMeter()

    with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:
        t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))

        for data in train_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            preds = model(inputs)

            loss_mse =  criterion(preds, labels)
            loss_grad = args.gradloss_weight * grad_criterion(preds, labels)
            loss = loss_mse + loss_grad

            epoch_losses.update(loss.item(), len(inputs))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
            t.update(len(inputs))

    torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

    model.eval()
    epoch_psnr = AverageMeter()

    for data in eval_dataloader:
        inputs, labels = data

        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.inference_mode():
            preds = model(inputs).clamp(0.0, 1.0)

        epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

    print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

    if epoch_psnr.avg > best_psnr:
        best_epoch = epoch
        best_psnr = epoch_psnr.avg
        best_weights = copy.deepcopy(model.state_dict())

print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

False


  0%|          | 0/21904 [00:00<?, ?it/s]

epoch: 0/399:   0%|          | 0/21904 [00:06<?, ?it/s]


RuntimeError: Given groups=1, weight of size [1, 1, 3, 3], expected input[16, 0, 33, 33] to have 1 channels, but got 0 channels instead

In [None]:
parser = argparse.ArgumentParser()
parser.add_argument('--weights-file', type=str, required=True)
parser.add_argument('--image-file', type=str, required=True)
parser.add_argument('--scale-down', type=bool, default=False)
parser.add_argument('--scale', type=int, default=3)
args = parser.parse_args([
    '--weights-file', 'weights/SRCNN_GLoss_x2/best.pth',
    '--image-file', 'data/butterfly_GT.bmp',
    '--scale-down', 'False',
    '--scale', '2'
])

cudnn.benchmark = True
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = SRCNN().to(device)

state_dict = model.state_dict()
for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
    if n in state_dict.keys():
        state_dict[n].copy_(p)
    else:
        raise KeyError(n)

model.eval()

image = pil_image.open(args.image_file).convert('RGB')

image_width = (image.width // args.scale) * args.scale
image_height = (image.height // args.scale) * args.scale
image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
if args.scale_down:
    image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC)
    image = image.resize((image.width * args.scale, image.height * args.scale), resample=pil_image.BICUBIC)
    image.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))

image = np.array(image).astype(np.float32)
ycbcr = convert_rgb_to_ycbcr(image)

y = ycbcr[..., 0]
y /= 255.
y = torch.from_numpy(y).to(device)
y = y.unsqueeze(0).unsqueeze(0)

with torch.inference_mode():
    preds = model(y).clamp(0.0, 1.0)

psnr = calc_psnr(y, preds)
print('PSNR: {:.2f}'.format(psnr))

preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
output = pil_image.fromarray(output)
output.save(args.image_file.replace('.', '_srcnn_GLoss_x{}.'.format(args.scale)))