In [1]:
import time, itertools
from dataset import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader
from networks import *
from utils import *
from glob import glob

class UGATIT(object) :
    def __init__(self, args):
        self.light = args.light

        if self.light :
            self.model_name = 'UGATIT_light'
        else :
            self.model_name = 'UGATIT'
        self.result_dir = args.result_dir
        self.dataset = args.dataset
        """ Weight """
#         self.adv_weight = args.adv_weight
#         self.cycle_weight = args.cycle_weight
#         self.identity_weight = args.identity_weight
#         self.cam_weight = args.cam_weight

        """ Generator """
        self.n_res = args.n_res
        self.ch = args.ch
        self.img_size = args.img_size
        self.img_ch = args.img_ch

        self.device = args.device
#         self.benchmark_flag = args.benchmark_flag
#         self.resume = args.resume

#         if torch.backends.cudnn.enabled and self.benchmark_flag:
#             print('set benchmark !')
#             torch.backends.cudnn.benchmark = True

        print()

        print("##### Information #####")
        print("# light : ", self.light)
#         print("# dataset : ", self.dataset)
#         print("# batch_size : ", self.batch_size)
#         print("# iteration per epoch : ", self.iteration)

#         print()

        print("##### Generator #####")
        print("# residual blocks : ", self.n_res)

        print()

#         print("##### Discriminator #####")
#         print("# discriminator layer : ", self.n_dis)

#         print()

#         print("##### Weight #####")
#         print("# adv_weight : ", self.adv_weight)
#         print("# cycle_weight : ", self.cycle_weight)
#         print("# identity_weight : ", self.identity_weight)
#         print("# cam_weight : ", self.cam_weight)

    ##################################################################################
    # Model
    ##################################################################################

    def build_model(self):
        """ DataLoader """
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((self.img_size + 30, self.img_size+30)),
            transforms.RandomCrop(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
        test_transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        """ Define Generator, Discriminator """
        self.genA2B = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.ch, n_blocks=self.n_res, img_size=self.img_size, light=self.light).to(self.device)

        """ Define Loss """
        self.L1_loss = nn.L1Loss().to(self.device)
        self.MSE_loss = nn.MSELoss().to(self.device)
        self.BCE_loss = nn.BCEWithLogitsLoss().to(self.device)

        """ Trainer """
        #self.G_optim = torch.optim.Adam(itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()), lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay)

        """ Define Rho clipper to constraint the value of rho in AdaILN and ILN"""
        self.Rho_clipper = RhoClipper(0, 1)

    def load(self, dir, step):
        params = torch.load(os.path.join(dir, self.dataset + '_params_%07d.pt' % step))
        self.genA2B.load_state_dict(params['genA2B'])

    def test(self, img):
        model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt'))
        if not len(model_list) == 0:
            model_list.sort()
            iter = int(model_list[-1].split('_')[-1].split('.')[0])
            self.load(os.path.join(self.result_dir, self.dataset, 'model'), iter)
        else:
            return 'error'
        self.genA2B.eval()
        img = torch.from_numpy(img)
        img = img.to(self.device)
        fake_A2B, _, fake_A2B_heatmap = self.genA2B(img)
        out = RGB2BGR(tensor2numpy(denorm(fake_A2B[0])))
        return out

In [12]:
def transform_test(img):
    test_transform = transforms.Compose([
        transforms.Resize((260, 260)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    return test_transform(img)

In [17]:
a = Image.fromarray(open_cv_image)
b = transform_test(a)

In [26]:
Image.fromarray(b.numpy()[:, :, ::-1].copy() )

TypeError: Cannot handle this data type: (1, 1, 260), <f4

In [9]:
import requests
from PIL import Image
from io import BytesIO

In [10]:
url = 'https://images.unsplash.com/photo-1596330112119-eafc0fb76775?ixlib=rb-1.2.1&ixid=eyJhcHBfaWQiOjEyMDd9&auto=format&fit=crop&w=234&q=80'
response = requests.get(url)
img = Image.open(BytesIO(response.content)).convert('RGB')
open_cv_image = np.array(img) 
open_cv_image = open_cv_image[:, :, ::-1].copy() 

In [65]:
args = args_loader()


In [66]:
gan = UGATIT(args)


##### Information #####
# light :  True
##### Generator #####
# residual blocks :  4



In [67]:
gan.build_model()

In [None]:
gan.resul

In [70]:
model_list = glob(os.path.join(gan.result_dir, gan.dataset, 'model', '*.pt'))
if not len(model_list) == 0:
    model_list.sort()
    iter = int(model_list[-1].split('_')[-1].split('.')[0])
    gan.load(os.path.join(gan.result_dir, gan.dataset, 'model'), iter)
else:
    print('error')
gan.genA2B.eval()

ResnetGenerator(
  (gap_fc): Linear(in_features=256, out_features=1, bias=False)
  (gmp_fc): Linear(in_features=256, out_features=1, bias=False)
  (conv1x1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
  (relu): ReLU(inplace=True)
  (gamma): Linear(in_features=256, out_features=256, bias=False)
  (beta): Linear(in_features=256, out_features=256, bias=False)
  (UpBlock1_1): ResnetAdaILNBlock(
    (pad1): ReflectionPad2d((1, 1, 1, 1))
    (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (norm1): adaILN()
    (relu1): ReLU(inplace=True)
    (pad2): ReflectionPad2d((1, 1, 1, 1))
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (norm2): adaILN()
  )
  (UpBlock1_2): ResnetAdaILNBlock(
    (pad1): ReflectionPad2d((1, 1, 1, 1))
    (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (norm1): adaILN()
    (relu1): ReLU(inplace=True)
    (pad2): ReflectionPad2d((1, 1, 1, 1))
    (conv2): Conv2d(256, 256,

In [72]:
url = 'https://images.unsplash.com/photo-1596330112119-eafc0fb76775?ixlib=rb-1.2.1&ixid=eyJhcHBfaWQiOjEyMDd9&auto=format&fit=crop&w=234&q=80'
response = requests.get(url)
img = Image.open(BytesIO(response.content)).convert('RGB')
open_cv_image = np.array(img) 
open_cv_image = open_cv_image[:, :, ::-1].copy() 

In [81]:
a = transforms.ToTensor()(open_cv_image).unsqueeze_(0)
a.shape, open_cv_image.shape, img_t.shape

(torch.Size([3, 351, 234]), (351, 234, 3), torch.Size([351, 234, 3]))

In [83]:
# img_t = torch.from_numpy(open_cv_image)
img_t = transforms.ToTensor()(open_cv_image).unsqueeze_(0)
img_t = img_t.to(gan.device)
fake_A2B, _, fake_A2B_heatmap = gan.genA2B(img_t)
out = RGB2BGR(tensor2numpy(denorm(fake_A2B[0])))

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.imshow(out)

In [None]:
#         self.adv_weight = args.adv_weight
#         self.cycle_weight = args.cycle_weight
#         self.identity_weight = args.identity_weight
#         self.cam_weight = args.cam_weight
        
#     parser.add_argument('--dataset', type=str, default='testing', help='dataset_name')

#     parser.add_argument('--iteration', type=int, default=1000000, help='The number of training iterations')
#     parser.add_argument('--batch_size', type=int, default=1, help='The size of batch size')
#     parser.add_argument('--print_freq', type=int, default=1000, help='The number of image print freq')
#     parser.add_argument('--save_freq', type=int, default=100000, help='The number of model save freq')
#     parser.add_argument('--decay_flag', type=str2bool, default=True, help='The decay_flag')

#     parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
#     parser.add_argument('--weight_decay', type=float, default=0.0001, help='The weight decay')
#     parser.add_argument('--adv_weight', type=int, default=1, help='Weight for GAN')
#     parser.add_argument('--cycle_weight', type=int, default=10, help='Weight for Cycle')
#     parser.add_argument('--identity_weight', type=int, default=10, help='Weight for Identity')
#     parser.add_argument('--cam_weight', type=int, default=1000, help='Weight for CAM')

#     parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
#     parser.add_argument('--n_res', type=int, default=4, help='The number of resblock')
#     parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layer')


#     parser.add_argument('--result_dir', type=str, default='results', help='Directory name to save the results')
#     parser.add_argument('--benchmark_flag', type=str2bool, default=False)
#     parser.add_argument('--resume', type=str2bool, default=False)