In [11]:
import os
from os.path import realpath
import torch
from skimage import io
import numpy as np
from util.config import cfg as test_cfg
from data.test_dataset import TestDataset
from util import util
from models.networks import RainNet
from models.normalize import RAIN
import matplotlib.pyplot as plt
from imageio import mimsave
from tqdm import tqdm

%matplotlib inline

In [3]:
def load_network(cfg):
    """Build a RainNet generator and load its weights from a checkpoint file.

    The checkpoint is read from ``<cfg.checkpoints_dir>/<cfg.name>/net_G.pth``.

    Args:
        cfg: configuration object providing ``checkpoints_dir``, ``name``,
            ``input_nc``, ``output_nc``, ``ngf`` and ``no_dropout``.

    Returns:
        The RainNet instance after ``util.copy_state_dict`` has been called
        with its state dict and the state dict read from the checkpoint.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    load_path = os.path.join(cfg.checkpoints_dir, cfg.name, 'net_G.pth')
    # Validate the path first so a wrong configuration fails fast, before the
    # network is constructed.
    if not os.path.exists(load_path):
        # The message goes into the exception itself; wrapping it in print()
        # would hand None to the exception because print() returns None.
        raise FileNotFoundError('%s not exists. Please check the file' % (load_path))

    net = RainNet(input_nc=cfg.input_nc,
                  output_nc=cfg.output_nc,
                  ngf=cfg.ngf,
                  norm_layer=RAIN,
                  use_dropout=not cfg.no_dropout)

    print(f'loading the model from {load_path}')
    # map_location='cpu' allows a checkpoint that was saved from a GPU to be
    # read on a machine without CUDA; the caller moves the network to the
    # target device afterwards.
    state_dict = torch.load(load_path, map_location='cpu')
    util.copy_state_dict(net.state_dict(), state_dict)
    return net

def save_img(path, img):
    """Save ``img`` to ``path`` via ``io.imsave``, creating the parent directory first.

    Args:
        path: destination file path; may be a bare filename with no directory part.
        img: image data passed unchanged to ``io.imsave``.
    """
    out_dir = os.path.dirname(path)
    # A bare filename has an empty directory part, and os.makedirs('') raises
    # FileNotFoundError, so only create the directory when there is one.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    io.imsave(path, img)

In [5]:
# Display the test configuration (imported from util.config) that is passed to load_network below.
test_cfg

CfgNode({'dataset_root': '../dataset/iHarmony4', 'dataset_mode': 'iharmony4', 'batch_size': 10, 'beta1': 0.5, 'checkpoints_dir': './checkpoints', 'crop_size': 256, 'load_size': 256, 'num_threads': 11, 'preprocess': 'none', 'gan_mode': 'wgangp', 'model': 'rainnet', 'netG': 'rainnet', 'normD': 'instance', 'normG': 'RAIN', 'is_train': False, 'input_nc': 3, 'output_nc': 3, 'ngf': 64, 'no_dropout': False, 'name': 'experiment_train', 'gpu_ids': 0, 'lambda_L1': 100, 'print_freq': 400, 'continue_train': False, 'load_iter': 0, 'niter': 100, 'niter_decay': 0})

In [7]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Build the generator and load its checkpoint according to test_cfg.
rainnet = load_network(test_cfg)

loading the model from ./checkpoints/experiment_train/net_G.pth


In [8]:
# Input files for the three bundled examples; the i-th entries of the three
# lists belong together.
comp_path = [
    'examples/1.png',
    'examples/2.png',
    'examples/3.png',
]
mask_path = [
    'examples/1-mask.png',
    'examples/2-mask.png',
    'examples/3-mask.png',
]
real_path = [
    'examples/1-gt.png',
    'examples/2-gt.png',
    'examples/3-gt.png',
]

# Load the testing set: composites are passed as foregrounds, the "-gt" images as backgrounds.
testdata = TestDataset(
    load_size=256,
    foreground_paths=comp_path,
    mask_paths=mask_path,
    background_paths=real_path,
)

total 3 images


In [14]:
# Put the network on the same device the inputs are sent to. Using `device`
# (instead of an unconditional .cuda()) keeps the cell runnable on CPU-only machines.
rainnet.to(device)

repeat_times = 0 # number of extra passes that re-adjust the foreground of the prediction
# Inference only: no_grad avoids recording autograd history for every forward pass.
with torch.no_grad():
    for idx in tqdm(range(len(testdata))):
        sample = testdata[idx]
        # unsqueeze the data to shape of (1, channel, H, W)
        comp = sample['comp'].unsqueeze(0).to(device)
        mask = sample['mask'].unsqueeze(0).to(device)
        # To adjust the background to be compatible with the foreground instead,
        # uncomment the following line:
        # mask = 1 - mask
        # Original note: if the real_path is not given, sample['real'] returns the composite image.
        real = sample['real'].unsqueeze(0).to(device)
        img_path = sample['img_path']

        pred = rainnet.processImage(comp, mask, real)
        # Optionally feed the prediction back in as both input and background.
        for _ in range(repeat_times):
            pred = rainnet.processImage(pred, mask, pred)

        # tensor2image
        pred_rgb = util.tensor2im(pred[0:1])
        comp_rgb = util.tensor2im(comp[:1])
        mask_rgb = util.tensor2im(mask[:1])
        real_rgb = util.tensor2im(real[:1])
        print(img_path)

        # splitext removes only the final extension; split('.')[0] would cut the
        # path at the first dot anywhere in it (e.g. '../data/x.png' -> '').
        result_path = os.path.splitext(img_path)[0] + '-results.png'
        # Side-by-side strip: composite | mask | prediction.
        save_img(result_path, np.hstack([comp_rgb, mask_rgb, pred_rgb]))

  0%|                                                                                      | 0/3 [00:00<?, ?it/s]

examples/1.png


 67%|████████████████████████████████████████████████████                          | 2/3 [00:03<00:01,  1.63s/it]

examples/2.png


100%|██████████████████████████████████████████████████████████████████████████████| 3/3 [00:04<00:00,  1.36s/it]

examples/3.png





In [20]:
# A single foreground object with its mask, to be harmonized against a background image.
comp_path = [
    'dataset/MOCOD_objs/9T234_2/fg_000.png',
]
mask_path = [
    'dataset/MOCOD_objs/9T234_2/mask_000.png',
]
bg_path = [
    'dataset/MOCOD_bgs/bg1.jpg',
]

# Rebuild the test set from the paths above.
testdata = TestDataset(
    load_size=256,
    foreground_paths=comp_path,
    mask_paths=mask_path,
    background_paths=bg_path,
)

total 1 images


In [21]:
# Put the network on the same device the inputs are sent to. Using `device`
# (instead of an unconditional .cuda()) keeps the cell runnable on CPU-only machines.
rainnet.to(device)

repeat_times = 0 # number of extra passes that re-adjust the foreground of the prediction
# Inference only: no_grad avoids recording autograd history for every forward pass.
with torch.no_grad():
    for idx in tqdm(range(len(testdata))):
        sample = testdata[idx]
        # unsqueeze the data to shape of (1, channel, H, W)
        comp = sample['comp'].unsqueeze(0).to(device)
        mask = sample['mask'].unsqueeze(0).to(device)
        # To adjust the background to be compatible with the foreground instead,
        # uncomment the following line:
        # mask = 1 - mask
        # Original note: if the real_path is not given, sample['real'] returns the composite image.
        real = sample['real'].unsqueeze(0).to(device)
        img_path = sample['img_path']

        pred = rainnet.processImage(comp, mask, real)
        # Optionally feed the prediction back in as both input and background.
        for _ in range(repeat_times):
            pred = rainnet.processImage(pred, mask, pred)

        # tensor2image
        pred_rgb = util.tensor2im(pred[0:1])
        comp_rgb = util.tensor2im(comp[:1])
        mask_rgb = util.tensor2im(mask[:1])
        real_rgb = util.tensor2im(real[:1])
        print(img_path)

        # NOTE: the output path is fixed, so with more than one sample each
        # iteration overwrites the previous result.
        # Side-by-side strip: composite | mask | prediction.
        save_img('t/ttt.png', np.hstack([comp_rgb, mask_rgb, pred_rgb]))

  0%|                                                                                      | 0/1 [00:00<?, ?it/s]

dataset/MOCOD_objs/9T234_2/fg_000.png


100%|██████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.57it/s]
