In [None]:
"""
Copyright (C) 2018 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

from __future__ import print_function
from utils import get_config, get_data_loader_folder
from trainer import MUNIT_Trainer, UNIT_Trainer
import argparse
from torch.autograd import Variable
import torchvision.utils as vutils
import sys
import torch
import os
from torchvision import transforms
from PIL import Image

import os
# Enumerate GPUs in PCI bus order and expose only GPU 0 to this process.
# NOTE: CUDA reads these variables when it initializes, so they only take
# effect if no CUDA call has happened yet in this kernel.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID",
                   "CUDA_VISIBLE_DEVICES": "0"})
# Translators already built by get_res, keyed by everything that determines the
# loaded model, so a loop over many images builds the networks and reads the
# checkpoint once instead of once per image.
_TRANSLATOR_CACHE = {}


def _load_translator(params, dev):
    """Return ``(config, trainer)`` for ``params``, building it on first use.

    The trainer has ``gen_a``/``gen_b`` loaded from ``params['checkpoint']``,
    is moved to CUDA device ``dev`` and put in eval mode. Exits the process via
    ``sys.exit`` if ``params['trainer']`` is neither 'MUNIT' nor 'UNIT'.
    """
    key = (params['config'], params['checkpoint'], params['trainer'],
           params['output_path'], dev)
    if key not in _TRANSLATOR_CACHE:
        config = get_config(params['config'])
        config['vgg_model_path'] = params['output_path']
        if params['trainer'] == 'MUNIT':
            trainer = MUNIT_Trainer(config)
        elif params['trainer'] == 'UNIT':
            trainer = UNIT_Trainer(config)
        else:
            sys.exit("Only support MUNIT|UNIT")
        state_dict = torch.load(params['checkpoint'])
        trainer.gen_a.load_state_dict(state_dict['a'])
        trainer.gen_b.load_state_dict(state_dict['b'])
        trainer.cuda(dev)
        trainer.eval()
        _TRANSLATOR_CACHE[key] = (config, trainer)
    return _TRANSLATOR_CACHE[key]


def _load_image(image_path, transform, dev):
    """Open ``image_path`` as RGB, apply ``transform``, add a leading batch
    dimension and move the result to CUDA device ``dev``."""
    # The context manager closes the file before get_res may overwrite it.
    with Image.open(image_path) as img:
        tensor = transform(img.convert('RGB'))
    return tensor.unsqueeze(0).cuda(dev)


def get_res(img_path, save_path=None):
    """Translate one image with the pretrained UNIT/MUNIT generator and save it.

    All settings live in the hard-coded ``params`` dict below (the config and
    checkpoint names suggest a GTA -> Cityscapes UNIT model, direction a2b).

    Args:
        img_path: Path of the input image.
        save_path: UNIT only. Where to write the translation. Defaults to
            ``img_path``, i.e. the input file is OVERWRITTEN in place.

    Side effects:
        * UNIT: writes the translated image to ``save_path`` / ``img_path``.
        * MUNIT: writes ``num_style`` files ``outputNNN.jpg`` into
          ``params['output_folder']``.
        * If ``params['output_only']`` is false, also writes the transformed
          input as ``input_<basename>`` into ``params['output_folder']``.
        * Creates ``params['output_folder']`` if it does not exist.
        * Exits the process (``sys.exit``) for an unsupported trainer name.
    """
    dev = 0
    params = {'config' : './configs/unit_gta2city_list.yaml',
              'input' : img_path,
              'output_folder' : './outputs/gta2city',
              'checkpoint' : './models/unit_gta2city.pt',
              'style' : '',
              'a2b' : 1,
              'seed' : 10,
              'num_style' : 10,
              'synchronized' : True,
              'output_only' : True,
              'output_path' : './',
              'trainer' : 'UNIT'
    }

    torch.manual_seed(params['seed'])
    torch.cuda.manual_seed(params['seed'])
    if not os.path.exists(params['output_folder']):
        os.makedirs(params['output_folder'])

    # An explicit style image yields exactly one output.
    params['num_style'] = 1 if params['style'] != '' else params['num_style']

    config, trainer = _load_translator(params, dev)
    if params['a2b']:
        encode = trainer.gen_a.encode
        style_encode = trainer.gen_b.encode
        decode = trainer.gen_b.decode
    else:
        encode = trainer.gen_b.encode
        style_encode = trainer.gen_a.encode
        decode = trainer.gen_a.decode

    if 'new_size' in config:
        new_size = config['new_size']
    elif params['a2b']:
        new_size = config['new_size_a']
    else:
        new_size = config['new_size_b']

    transform = transforms.Compose([transforms.Resize(new_size),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # Inference only: no_grad replaces the removed `volatile=True` flag and
    # stops autograd from recording the forward pass.
    with torch.no_grad():
        image = _load_image(params['input'], transform, dev)
        content, _ = encode(image)

        if params['trainer'] == 'MUNIT':
            style_dim = config['gen']['style_dim']
            if params['style'] != '':
                _, style = style_encode(_load_image(params['style'], transform, dev))
            else:
                # Re-seed right before sampling so every image gets the same
                # random style codes, whether or not the model was built (which
                # may draw from the RNG) during this call.
                torch.manual_seed(params['seed'])
                style = torch.randn(params['num_style'], style_dim, 1, 1).cuda(dev)
            for j in range(params['num_style']):
                outputs = (decode(content, style[j].unsqueeze(0)) + 1) / 2.
                path = os.path.join(params['output_folder'], 'output{:03d}.jpg'.format(j))
                vutils.save_image(outputs, path, padding=0, normalize=True)
        elif params['trainer'] == 'UNIT':
            outputs = (decode(content) + 1) / 2.
            # Default (save_path=None) overwrites the input image in place.
            target = img_path if save_path is None else save_path
            vutils.save_image(outputs, target, padding=0, normalize=True)

    if not params['output_only']:
        # Also save the transformed input image next to the outputs.
        input_name = 'input_' + os.path.basename(img_path)
        vutils.save_image(image, os.path.join(params['output_folder'], input_name),
                          padding=0, normalize=True)

In [None]:
# Root directory of the images to translate. Set the RGB_DIR environment
# variable to point elsewhere instead of editing the notebook; the default
# (/content/...) looks like a Colab path.
path = os.environ.get('RGB_DIR', '/content/RGB/')

In [None]:
def get_file_paths(path):
    """Return the paths of every file under ``path``, recursing into subdirectories.

    The list is sorted so that slicing it into batches selects the same files
    on every run; ``os.walk`` itself yields entries in arbitrary,
    filesystem-dependent order.

    Args:
        path: Root directory to scan. A path that does not exist yields ``[]``
            (``os.walk`` silently reports nothing for it).

    Returns:
        Sorted list of file paths, each joined onto its containing directory.
    """
    return sorted(
        os.path.join(root, name)
        for root, _dirs, files in os.walk(path)
        for name in files
    )

In [None]:
# Every file under `path`, in the order get_file_paths returns them; the batch
# loop below slices this list.
all_file_paths = get_file_paths(path=path)

In [None]:
# Half-open slice [BATCH_START, BATCH_END) of all_file_paths handled in this run;
# change these to resume or split the work across sessions.
BATCH_START = 2500
BATCH_END = 5000

# NOTE: get_res saves each UNIT translation over its input file, so re-running
# a batch that already finished translates those images a second time.
for count, img_path in enumerate(all_file_paths[BATCH_START:BATCH_END], start=1):
    print(str(count) + ' :: loading ... : ' + img_path)
    get_res(img_path)