Skip to content

Commit

Permalink
fixed architecture logging
Browse files Browse the repository at this point in the history
  • Loading branch information
Aman Chadha authored and Aman Chadha committed Dec 7, 2019
1 parent ba0d255 commit 6f10314
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 36 deletions.
55 changes: 27 additions & 28 deletions iSeeBetterTest.py
Expand Up @@ -3,21 +3,16 @@

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from rbpn import Net as RBPN
from data import get_test_set
from functools import reduce
import numpy as np
import utils
from scipy.misc import imsave
import scipy.io as sio
import time
import cv2
import math
import pdb
import logger

# Training settings
parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
Expand All @@ -36,32 +31,33 @@
parser.add_argument('--model_type', type=str, default='RBPN')
parser.add_argument('--residual', type=bool, default=False)
parser.add_argument('--output', default='Results/', help='Location to save checkpoint models')
#parser.add_argument('--model', default='weights/netG_epoch_4_1_APITLoss.pth', help='sr pretrained base model')
parser.add_argument('--model', default='weights/netG_epoch_4_1.pth', help='sr pretrained base model')
#parser.add_argument('--model', default='weights/netG_epoch_4_1.pth', help='sr pretrained base model')
#parser.add_argument('--model', default='weights/RBPN_4x.pth', help='sr pretrained base model')
parser.add_argument('-v', '--debug', default=False, action='store_true', help='Print debug spew.')

opt = parser.parse_args()
args = parser.parse_args()

gpus_list=range(opt.gpus)
print(opt)
gpus_list=range(args.gpus)
print(args)

cuda = opt.gpu_mode
cuda = args.gpu_mode
if cuda:
print("Using GPU mode")
if not torch.cuda.is_available():
raise Exception("No GPU found, please run without --cuda")

torch.manual_seed(opt.seed)
torch.manual_seed(args.seed)
if cuda:
torch.cuda.manual_seed(opt.seed)
torch.cuda.manual_seed(args.seed)

print('==> Loading datasets')
test_set = get_test_set(opt.data_dir, opt.nFrames, opt.upscale_factor, opt.file_list, opt.other_dataset, opt.future_frame)
testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)
test_set = get_test_set(args.data_dir, args.nFrames, args.upscale_factor, args.file_list, args.other_dataset, args.future_frame)
testing_data_loader = DataLoader(dataset=test_set, num_workers=args.threads, batch_size=args.testBatchSize, shuffle=False)

print('==> Building model ', opt.model_type)
if opt.model_type == 'RBPN':
model = RBPN(num_channels=3, base_filter=256, feat = 64, num_stages=3, n_resblock=5, nFrames=opt.nFrames, scale_factor=opt.upscale_factor)
print('==> Building model ', args.model_type)
if args.model_type == 'RBPN':
model = RBPN(num_channels=3, base_filter=256, feat = 64, num_stages=3, n_resblock=5, nFrames=args.nFrames, scale_factor=args.upscale_factor)

if cuda:
model = torch.nn.DataParallel(model, device_ids=gpus_list)
Expand All @@ -72,12 +68,15 @@
model = model.cuda(gpus_list[0])

def eval():
# Initialize Logger
logger.initLogger(args.debug)

# print iSeeBetter architecture
utils.printNetworkArch(netG=model, netD={})
utils.printNetworkArch(netG=model, netD=None)

# load model
modelPath = os.path.join(opt.model)
utils.loadPreTrainedModel(gpuMode=opt.gpu_mode, model=model, modelPath=modelPath)
modelPath = os.path.join(args.model)
utils.loadPreTrainedModel(gpuMode=args.gpu_mode, model=model, modelPath=modelPath)

model.eval()
count = 0
Expand All @@ -98,14 +97,14 @@ def eval():
flow = [Variable(j).to(device=device, dtype=torch.float) for j in flow]

t0 = time.time()
if opt.chop_forward:
if args.chop_forward:
with torch.no_grad():
prediction = chop_forward(input, neigbor, flow, model, opt.upscale_factor)
prediction = chop_forward(input, neigbor, flow, model, args.upscale_factor)
else:
with torch.no_grad():
prediction = model(input, neigbor, flow)

if opt.residual:
if args.residual:
prediction = prediction + bicubic

t1 = time.time()
Expand All @@ -120,7 +119,7 @@ def eval():
target = target.squeeze().numpy().astype(np.float32)
target = target*255.

psnr_predicted = PSNR(prediction, target, shave_border=opt.upscale_factor)
psnr_predicted = PSNR(prediction, target, shave_border=args.upscale_factor)
print("PSNR Predicted = ", psnr_predicted)
avg_psnr_predicted += psnr_predicted
count += 1
Expand All @@ -131,12 +130,12 @@ def save_img(img, img_name, pred_flag):
save_img = img.squeeze().clamp(0, 1).numpy().transpose(1,2,0)

# save img
save_dir=os.path.join(opt.output, opt.data_dir, os.path.splitext(opt.file_list)[0]+'_'+str(opt.upscale_factor)+'x')
save_dir=os.path.join(args.output, args.data_dir, os.path.splitext(args.file_list)[0]+'_'+str(args.upscale_factor)+'x')
if not os.path.exists(save_dir):
os.makedirs(save_dir)

if pred_flag:
save_fn = save_dir +'/'+ img_name+'_'+opt.model_type+'F'+str(opt.nFrames)+'.png'
save_fn = save_dir +'/'+ img_name+'_'+args.model_type+'F'+str(args.nFrames)+'.png'
else:
save_fn = save_dir +'/'+ img_name+'.png'
cv2.imwrite(save_fn, cv2.cvtColor(save_img*255, cv2.COLOR_BGR2RGB), [cv2.IMWRITE_PNG_COMPRESSION, 0])
Expand All @@ -151,7 +150,7 @@ def PSNR(pred, gt, shave_border=0):
return 100
return 20 * math.log10(255.0 / rmse)

def chop_forward(x, neigbor, flow, model, scale, shave=8, min_size=2000, nGPUs=opt.gpus):
def chop_forward(x, neigbor, flow, model, scale, shave=8, min_size=2000, nGPUs=args.gpus):
b, c, h, w = x.size()
h_half, w_half = h // 2, w // 2
h_size, w_size = h_half + shave, w_half + shave
Expand Down
18 changes: 10 additions & 8 deletions utils.py
Expand Up @@ -43,14 +43,16 @@ def _printNetworkArch(net):
num_params = 0
for param in net.parameters():
num_params += param.numel()
print(net)
print('Total number of parameters: %d' % num_params)
logger.info(net)
logger.info('Total number of parameters: %d' % num_params)

def printNetworkArch(netG, netD):
logger.info('------------- iSeeBetter Network Architecture -------------')
logger.info('----------------- Generator Architecture ------------------')
_printNetworkArch(netG)

logger.info('--------------- Discriminator Architecture ----------------')
_printNetworkArch(netD)
logger.info('-----------------------------------------------------------')
if netG:
logger.info('----------------- Generator Architecture ------------------')
_printNetworkArch(netG)

if netD:
logger.info('--------------- Discriminator Architecture ----------------')
_printNetworkArch(netD)
logger.info('-----------------------------------------------------------')

0 comments on commit 6f10314

Please sign in to comment.