In [1]:
import torch
import torch.nn as nn
import functools
import torch.optim as optim
from torch.nn import init
import torch.utils.data as data
from pathlib import Path
from PIL import Image
import torchvision.transforms as transforms
import time
import numpy as np
from torch.optim import lr_scheduler

In [2]:
# from google.colab import drive
# drive.mount('/content/drive')

In [5]:
opt = dict(
    # ---- options for model training ----
    mode='test',             # 'train' or 'test': selects which pipeline cell below runs
    lr=0.0002,               # initial Adam learning rate for both G and D
    lambdaG=1,               # weight of the adversarial term in the generator loss
    lambdaA=10,              # weight of the min-over-references L1 term in the generator loss
    epoch_count=1,           # offset added to the epoch index by the LR schedule
    niter=100,               # epochs (approximately) before the linear LR decay starts
    niter_decay=50,          # length, in epochs, of the linear LR decay
    model_save_path='/content/drive/My Drive/image2sketch/out/model',

    # ---- dataset ----
    data_path=Path('/content/drive/My Drive/image2sketch/all-in-one'),
    n_gt=5,                  # reference sketches loaded per photo
    width=5,                 # selects the sketch-rendered/width-<N> sub-directory
    img_size=256,            # photos and sketches are resized to img_size x img_size

    # ---- loader ----
    batch_size=16,
    is_shuffle=False,

    # ---- logging / checkpoints ----
    save_freq=2,             # save checkpoints and sample images every N epochs
    log_freq=20,             # print batch losses every N batches
    img_log_path='/content/drive/My Drive/image2sketch/out/img',
    device='cpu',

    # ---- options for model test ----
    test_image='test.jpg',
    test_output='test_result.jpg',
    model_load_path='./output/model',
    model_load_epoch=120,    # test mode loads <model_load_path>/G_net_<N>.pth
)

# networks
## unet generator
### unet module

In [19]:
class unetModule(nn.Module):
    """One level of the U-Net: downsample -> (optional nested level) -> upsample.

    Each level halves the spatial size with a stride-2 4x4 convolution and
    restores it with a stride-2 4x4 transposed convolution. Layers are ordered
    conv -> (BatchNorm) -> activation. The layer order inside ``self.model``
    is fixed so that saved ``state_dict`` keys stay loadable.

    Args:
        input_nc: channels of the tensor entering this level.
        inner_nc: channels produced by the downsampling convolution.
        output_nc: channels of the final output; required (and only used)
            when ``is_outest`` is True.
        sub_module: the next-deeper level, or ``None`` for the innermost level.
        is_outest: True for the outermost level, which ends in ``Tanh`` and
            returns its output directly instead of concatenating the input.

    Raises:
        ValueError: if ``is_outest`` is True and ``output_nc`` is None.
    """

    def __init__(self, input_nc, inner_nc, output_nc=None, sub_module=None, is_outest=False):
        super(unetModule, self).__init__()
        self.is_outest = is_outest
        # Explicit None check: an empty nn.Sequential/ModuleList is falsy.
        has_sub = sub_module is not None

        conv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1)
        down_relu = nn.LeakyReLU(0.2, True)
        up_relu = nn.ReLU(True)

        # A nested (non-outermost) level returns cat(x, y), doubling the
        # channel count seen by this level's transposed convolution.
        up_in_nc = inner_nc * 2 if has_sub else inner_nc

        if is_outest:
            if output_nc is None:
                raise ValueError('output_nc must be given for the outermost U-Net level')
            convT = nn.ConvTranspose2d(up_in_nc, output_nc, kernel_size=4, stride=2, padding=1)
            down = [conv, down_relu]
            up = [convT, nn.Tanh()]
        elif has_sub:
            convT = nn.ConvTranspose2d(up_in_nc, input_nc, kernel_size=4, stride=2, padding=1)
            down = [conv, nn.BatchNorm2d(inner_nc), down_relu]
            up = [convT, nn.BatchNorm2d(input_nc), up_relu]
        else:
            # Innermost level: no BatchNorm on the downsampling path.
            convT = nn.ConvTranspose2d(up_in_nc, input_nc, kernel_size=4, stride=2, padding=1)
            down = [conv, down_relu]
            up = [convT, nn.BatchNorm2d(input_nc), up_relu]

        middle = [sub_module] if has_sub else []
        self.model = nn.Sequential(*(down + middle + up))

    def forward(self, x):
        """Run this level; non-outermost levels concatenate ``x`` as a skip connection on dim 1."""
        y = self.model(x)
        if self.is_outest:
            return y
        return torch.cat((x, y), 1)

# inner_test = torch.ones(16, 256, 4, 4)
# unetModule_test = unetModule(256, 512)
# rslt = unetModule_test(inner_test)
# rslt.shape

torch.Size([16, 512, 4, 4])

### build unet

In [20]:
class unetG(nn.Module):
    """U-Net generator built from nested ``unetModule`` levels.

    The network has ``num_wrapper + 1`` levels:
    - the innermost bottleneck;
    - ``num_wrapper - 4`` extra levels at ``first_nc * 8`` channels;
    - three levels widening ``first_nc`` -> 2x -> 4x -> 8x;
    - the outermost level.

    Each level halves the spatial size on the way down. Input height and width
    should therefore be divisible by ``2 ** (num_wrapper + 1)`` for the skip
    concatenations to line up.

    Args:
        input_nc: channels of the input image.
        output_nc: channels of the generated image.
        num_wrapper: number of levels wrapped around the innermost one (>= 4).
        first_nc: channels after the first downsampling convolution.

    Raises:
        ValueError: if ``num_wrapper < 4``.
    """

    def __init__(self, input_nc, output_nc, num_wrapper, first_nc=64):
        super(unetG, self).__init__()
        if num_wrapper < 4:
            # The four outer levels below always exist; smaller values would
            # silently build a deeper network than requested.
            raise ValueError('num_wrapper must be >= 4, got %d' % num_wrapper)

        deep_nc = first_nc * 2**3
        block = unetModule(deep_nc, deep_nc)
        for _ in range(num_wrapper - 4):
            block = unetModule(deep_nc, deep_nc, sub_module=block)

        block = unetModule(first_nc * 2**2, deep_nc, sub_module=block)
        block = unetModule(first_nc * 2**1, first_nc * 2**2, sub_module=block)
        block = unetModule(first_nc, first_nc * 2**1, sub_module=block)
        self.model = unetModule(input_nc, first_nc, sub_module=block,
                                output_nc=output_nc, is_outest=True)

    def forward(self, x):
        """Map ``x`` (N, input_nc, H, W) to (N, output_nc, H, W)."""
        return self.model(x)

# Smoke test: a 7-wrapper U-Net keeps the 256x256 resolution and maps 3 -> 1 channels.
x_test = torch.ones(16, 3, 256, 256)
unetG_test = unetG(3, 1, num_wrapper=7)
with torch.no_grad():  # shape check only; no autograd graph needed
    rslt = unetG_test(x_test)
assert rslt.shape == (16, 1, 256, 256)
print(rslt.shape)

torch.Size([16, 1, 256, 256])


## PatchGAN discriminator
### build net

In [21]:
class patchDiscriminator(nn.Module):
    """PatchGAN discriminator producing a spatial grid of raw scores (no sigmoid).

    Layer stack:
    - a stride-2 conv + LeakyReLU;
    - ``num_layers - 1`` stride-2 conv/BatchNorm/LeakyReLU blocks, each doubling the channels;
    - a stride-1 conv/BatchNorm/LeakyReLU block that doubles the channels once more;
    - a final stride-1 conv down to a single channel.

    Args:
        input_nc: channels of the input (e.g. an image and a sketch concatenated).
        first_nc: channels after the first convolution.
        num_layers: number of stride-2 convolutions (>= 1).

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, input_nc, first_nc=64, num_layers=3):
        super(patchDiscriminator, self).__init__()
        if num_layers < 1:
            raise ValueError('num_layers must be >= 1, got %d' % num_layers)

        model = [nn.Conv2d(input_nc, first_nc, kernel_size=4, stride=2, padding=1),
                 nn.LeakyReLU(0.2, True)]
        for i_layer in range(num_layers - 1):
            in_nc = first_nc * 2**i_layer
            out_nc = first_nc * 2**(i_layer + 1)
            model += [nn.Conv2d(in_nc, out_nc, kernel_size=4, stride=2, padding=1),
                      nn.BatchNorm2d(out_nc),
                      nn.LeakyReLU(0.2, True)]

        # Channels produced by the last stride-2 block. Derived from num_layers
        # (it equals first_nc * 4 for the default num_layers=3).
        last_nc = first_nc * 2**(num_layers - 1)
        model += [nn.Conv2d(last_nc, last_nc * 2, kernel_size=4, stride=1, padding=1),
                  nn.BatchNorm2d(last_nc * 2),
                  nn.LeakyReLU(0.2, True)]

        model.append(nn.Conv2d(last_nc * 2, 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Return a (N, 1, H', W') grid of scores for ``x``."""
        return self.model(x)

# Smoke test: a 3-layer PatchGAN turns a 512x512 input into a 62x62 grid of scores.
x_test = torch.ones(16, 3, 512, 512)
test_patchDiscriminator = patchDiscriminator(3)
with torch.no_grad():  # shape check only; no autograd graph needed
    rslt = test_patchDiscriminator(x_test)
assert rslt.shape == (16, 1, 62, 62)
rslt.shape

torch.Size([16, 1, 62, 62])

## network utilities

In [9]:
def init_weight(network, init_gain=0.02):
    """Re-initialise ``network``'s layers in place from normal distributions.

    - Conv* / Linear layers: weight ~ N(0, init_gain); bias set to 0.
    - BatchNorm2d layers: weight ~ N(1, init_gain); bias set to 0.

    Layers whose weight/bias is ``None`` (e.g. ``BatchNorm2d(affine=False)`` or
    ``bias=False`` convs) are skipped instead of raising.

    Args:
        network: module whose submodules are visited via ``network.apply``.
        init_gain: standard deviation of the normal distributions (default 0.02).
    """
    def weights_init_normal(m):
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if 'Conv' in classname or 'Linear' in classname:
            if weight is not None:
                init.normal_(weight.data, 0.0, init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            if weight is not None:
                init.normal_(weight.data, 1.0, init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)

    network.apply(weights_init_normal)
    
# test_init = init_weight(unetG_test)

# Loss function

In [10]:
class baseGANLoss:
    """Adversarial loss against an all-real (1) or all-fake (0) target.

    Args:
        gan_type: 'BCE' (default) uses ``BCEWithLogitsLoss`` on raw scores;
            'MSE' uses ``MSELoss`` (least-squares GAN objective).
        device: device the loss module is moved to.

    Raises:
        ValueError: for an unsupported ``gan_type``.
    """

    def __init__(self, gan_type='BCE', device='cuda'):
        self.device = device
        self.gan_type = gan_type
        if gan_type == 'BCE':
            loss = nn.BCEWithLogitsLoss()
        elif gan_type == 'MSE':
            loss = nn.MSELoss()
        else:
            raise ValueError("gan_type must be 'BCE' or 'MSE', got %r" % (gan_type,))
        self.loss = loss.to(device)

    def __call__(self, prediction, is_real):
        """Return the loss of ``prediction`` against a constant target (1 if ``is_real`` else 0).

        The target is built with ``*_like`` so it always matches the
        prediction's shape, dtype and device.
        """
        if is_real:
            target_tensor = torch.ones_like(prediction)
        else:
            target_tensor = torch.zeros_like(prediction)
        return self.loss(prediction, target_tensor)

# test_baseGANLoss = baseGANLoss()
# test_baseGANLoss(torch.tensor([1, 0.1, 0.5]).to('cuda'), True)

# model
## pix2pix

In [16]:
class pix2pix:
    """pix2pix-style photo-to-sketch model trained against several reference sketches.

    The generator maps a 3-channel photo ``A`` to a 1-channel sketch. The
    discriminator scores (photo, sketch) pairs concatenated on the channel
    axis. The generator's reconstruction term is the L1 distance to the
    *closest* of the reference sketches stacked along dim 1 of ``B``.

    Args:
        opt: option dict. Keys read here include 'mode', 'device', 'lr',
            'lambdaG', 'lambdaA', 'epoch_count', 'niter', 'niter_decay',
            'model_save_path', 'model_load_path' and 'model_load_epoch'.
    """

    def __init__(self, opt):
        self.opt = opt
        device = opt['device']

        # networks: photo (3 ch) -> sketch (1 ch); D sees photo + sketch (4 ch)
        self.netG = unetG(3, 1, num_wrapper=7).to(device)
        self.netD = patchDiscriminator(4).to(device)
        init_weight(self.netG)
        init_weight(self.netD)

        # Compare strings with ==, not `is`: identity of equal str objects is
        # an implementation detail (and a SyntaxWarning on Python 3.8+).
        if opt['mode'] == 'train':
            self.criterionGAN = baseGANLoss(device=device)

            self.optimizer_G = optim.Adam(self.netG.parameters(), lr=opt['lr'], betas=(0.5, 0.999))
            self.optimizer_D = optim.Adam(self.netD.parameters(), lr=opt['lr'], betas=(0.5, 0.999))

            self.lr_scheduler_G = self.get_scheduler(self.optimizer_G)
            self.lr_scheduler_D = self.get_scheduler(self.optimizer_D)

        if opt['mode'] == 'test':
            weights_path = Path(opt['model_load_path']) / ('G_net_%d.pth' % opt['model_load_epoch'])
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            self.netG.load_state_dict(torch.load(weights_path, map_location=torch.device(device)))

    def set_inputs(self, inputs):
        """Store a batch: ``inputs['A']`` photos and ``inputs['B']`` reference sketches (may be None at test time)."""
        self.A_real = inputs['A']
        self.B_real = inputs['B']

    def forward(self):
        """Generate ``self.B_fake`` from ``self.A_real``."""
        self.B_fake = self.netG(self.A_real)

    def backward_D(self):
        """Compute and backpropagate the discriminator loss.

        The fake term uses the detached generator output. The real term is the
        mean GAN loss over every reference sketch in ``B_real``. The two terms
        are averaged.
        """
        # fake pair, detached so no gradient flows into G
        input_D_fake = torch.cat((self.A_real, self.B_fake), 1).detach()
        self.loss_D_fake = self.criterionGAN(self.netD(input_D_fake), False)

        # one real pair per reference sketch
        losses_real = []
        for i_real in range(self.B_real.shape[1]):
            B_real_i = self.B_real[:, i_real:i_real + 1]  # slice keeps the channel dim
            input_D_real = torch.cat((self.A_real, B_real_i), 1)
            losses_real.append(self.criterionGAN(self.netD(input_D_real), True))
        self.loss_D_real = torch.stack(losses_real).mean()

        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute and backpropagate the generator loss.

        loss_G = lambdaG * GAN(D(A, G(A)) judged real)
               + lambdaA * batch-mean of the smallest per-reference mean |B_ref - G(A)|.

        Also records ``self.min_index``: for each sample, the index of the
        reference sketch closest to G(A).
        """
        # G(A) should fool the discriminator
        input_D_fake = torch.cat((self.A_real, self.B_fake), 1)
        loss_G_GAN = self.criterionGAN(self.netD(input_D_fake), True)

        # B_fake (N, 1, H, W) broadcasts against B_real (N, n_real, H, W);
        # the mean over pixels gives one L1 value per reference: (N, n_real).
        l1_per_ref = torch.abs(self.B_real - self.B_fake).mean(dim=(2, 3))
        min_loss_G_L1, self.min_index = torch.min(l1_per_ref, 1)

        self.loss_G = loss_G_GAN * self.opt['lambdaG'] + torch.mean(min_loss_G_L1) * self.opt['lambdaA']
        self.loss_G.backward()

    def update(self):
        """Run one optimisation step: forward, update D, then update G with D frozen."""
        def set_requires_grad(net, requires_grad):
            # The attribute is `requires_grad`; assigning any other name just
            # creates an unused attribute and leaves D trainable.
            for param in net.parameters():
                param.requires_grad = requires_grad

        self.forward()

        # update discriminator D
        set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        # update generator G; D is frozen so backward_G does not compute D grads
        set_requires_grad(self.netD, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_scheduler(self, optimizer):
        """Return a ``LambdaLR`` that holds the lr constant, then decays it linearly.

        The multiplier is clamped at 0, so stepping past the end of the
        schedule cannot produce a negative learning rate.
        """
        opt = self.opt

        def lambda_rule(epoch):
            decay_steps = max(0, epoch + 1 + opt['epoch_count'] - opt['niter'])
            return max(0.0, 1.0 - decay_steps / float(opt['niter_decay'] + 1))

        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

    def update_lr(self):
        """Advance both learning-rate schedulers by one epoch."""
        self.lr_scheduler_G.step()
        self.lr_scheduler_D.step()

    def save(self, epoch):
        """Save D and G state dicts as ``<model_save_path>/{D,G}_net_<epoch>.pth``."""
        save_dir = Path(self.opt['model_save_path'])
        for net_type, net in (('D', self.netD), ('G', self.netG)):
            torch.save(net.state_dict(), save_dir / ('%s_net_%d.pth' % (net_type, epoch)))

    def get_loss(self):
        """Return the latest generator and discriminator losses as Python floats."""
        return {'loss_G': self.loss_G.item(), 'loss_D': self.loss_D.item()}

    def get_lr(self):
        """Return the current learning rate of each optimizer."""
        return {'G': self.optimizer_G.param_groups[0]['lr'],
                'D': self.optimizer_D.param_groups[0]['lr']}

    def get_B_fake(self):
        """Return the first sample's first generated channel, detached."""
        return self.B_fake.detach()[0][0]

    def get_B_best_real(self):
        """Return the first sample's reference sketch closest to its generated sketch, detached."""
        return self.B_real.detach()[0][self.min_index[0]]

# opt = {
#     'mode': 'train',
#     'lr': 0.0002,
#     'lambdaG': 1,
#     'lambdaA': 10,
#     'epoch_count': 1,
#     'niter': 100, 
#     'niter_decay': 100,
#     'model_save_path': '/content/drive/My Drive/image2sketch/out/model',
#     'device': 'cuda'
# }

# test_A_real = torch.rand([4, 3, 256, 256], dtype=torch.float32).to(opt['device'])
# test_B_real = torch.rand([4, 4, 256, 256], dtype=torch.float32).to(opt['device'])
# test_inputs= {}
# test_inputs['A'] = test_A_real
# test_inputs['B'] = test_B_real

# test_pix2pix = pix2pix(opt)
# test_pix2pix.set_inputs(test_inputs)
# test_pix2pix.update()
# test_pix2pix.save(-1)
# test_pix2pix.update_lr()

# B_fake = test_pix2pix.get_B_fake()
# B_best_real = test_pix2pix.get_B_best_real()

# B_fake_img = Image.fromarray((B_fake.cpu().numpy() + 1) * 255 / 2)
# B_best_real_img = Image.fromarray((B_best_real.cpu().numpy() + 1) * 255 / 2)

# data
## dataset

In [12]:
class one2pairDataset(data.Dataset):
    """Dataset pairing each photo with its ``n_gt`` reference sketches.

    Files under ``opt['data_path']``:
        list/<dataset_type>.txt                          one file id per line
        image/<id>.jpg                                   the photo
        sketch-rendered/width-<width>/<id>_<kk>.png      sketch kk = 01 .. n_gt

    Each item is ``{'A': photo (3, S, S), 'B': sketches (n_gt, S, S)}`` with
    ``S = opt['img_size']``. Both tensors are normalised with mean 0.5 and
    std 0.5.
    """

    def __init__(self, opt, dataset_type):
        super(one2pairDataset, self).__init__()
        self.opt = opt

        file_path = opt['data_path'] / 'list' / (dataset_type + '.txt')
        with open(file_path) as f:
            # Skip blank lines (e.g. a trailing empty line) so they do not become file ids.
            self.file_names = sorted(line.strip() for line in f if line.strip())

        # Built once rather than on every __getitem__ call.
        self.img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.sketch_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    def __getitem__(self, index):
        file_id = self.file_names[index]
        size = (self.opt['img_size'], self.opt['img_size'])
        root = self.opt['data_path']

        # convert('RGB') / convert('L') pin the channel count (3 for the photo,
        # 1 per sketch) even for grayscale/CMYK/alpha files; the `with` closes
        # the underlying file handle.
        with Image.open(root / 'image' / (file_id + '.jpg')) as raw_img:
            img = raw_img.convert('RGB').resize(size, Image.BICUBIC)
        img = self.img_transform(img)

        sketch_dir = root / 'sketch-rendered' / ('width-%d' % self.opt['width'])
        sketches = []
        for i_gt in range(self.opt['n_gt']):
            with Image.open(sketch_dir / ('%s_%02d.png' % (file_id, i_gt + 1))) as raw_sketch:
                sketch = raw_sketch.convert('L').resize(size, Image.BICUBIC)
            sketches.append(self.sketch_transform(sketch))

        return {'A': img, 'B': torch.cat(sketches, 0)}

    def __len__(self):
        return len(self.file_names)

# opt = {
#     'mode': 'train',
#     'lr': 0.0002,
#     'lambdaG': 1,
#     'lambdaA': 10,
#     'epoch_count': 1,
#     'niter': 100, 
#     'niter_decay': 100,
#     'model_save_path': '/content/drive/My Drive/image2sketch/out/model/',
    
#     'data_path': Path('/content/drive/My Drive/image2sketch/all-in-one'),
#     'mode': 'train',
#     'n_gt': 5,
#     'width': 5,
#     'img_size': 256,
# }
# test_dataset_type = 'train'

# test_one2pairDataset = one2pairDataset(opt, dataset_type=test_dataset_type)
# img_paired = test_one2pairDataset[0]

## dataloader

In [13]:
def get_dataloader(opt, dataset_type=('train', 'val', 'test')):
    """Build one DataLoader per dataset split.

    Args:
        opt: option dict providing 'batch_size' and 'is_shuffle', plus the
            keys read by ``one2pairDataset``.
        dataset_type: a split name or an iterable of split names; each name
            selects ``list/<name>.txt``. Defaults to all three splits. A bare
            string is treated as a single split rather than iterated
            character by character.

    Returns:
        dict mapping split name -> ``torch.utils.data.DataLoader``.
    """
    if isinstance(dataset_type, str):
        dataset_type = (dataset_type,)
    return {
        split: torch.utils.data.DataLoader(one2pairDataset(opt, dataset_type=split),
                                           batch_size=opt['batch_size'],
                                           shuffle=opt['is_shuffle'])
        for split in dataset_type
    }

# opt = {
#     'mode': 'train',
#     'lr': 0.0002,
#     'lambdaG': 1,
#     'lambdaA': 10,
#     'epoch_count': 1,
#     'niter': 100, 
#     'niter_decay': 100,
#     'model_save_path': '/content/drive/My Drive/image2sketch/out/model/l',
    
#     'data_path': Path('/content/drive/My Drive/image2sketch/all-in-one'),
#     'mode': 'train',
#     'n_gt': 5,
#     'width': 5,
#     'img_size': 256,
    
#     'batch_size': 1,
#     'is_shuffle': False,
    
# }
        
# test_dataloaders = get_dataloader(opt)

# train

In [14]:
if opt['mode'] == 'train':
    # Optimise on the 'train' split only.
    dataloaders = get_dataloader(opt, ('train',))
    model = pix2pix(opt)

    def save_sketch(sketch, path):
        """Save a single-channel tensor with values in [-1, 1] as an RGB image."""
        Image.fromarray((sketch.cpu().numpy() + 1) * 255 / 2).convert('RGB').save(path)

    img_save_path = Path(opt['img_log_path'])
    lrs = []
    # Epochs run epoch_count .. niter + niter_decay, which is the range the LR
    # schedule is defined over. Starting at 0 would add one extra epoch whose
    # learning rate comes out negative.
    for epoch in range(opt['epoch_count'], opt['niter'] + opt['niter_decay'] + 1):
        epoch_time = 0
        epoch_cost_G = 0
        epoch_cost_D = 0
        n_batches = 0
        # Loop variable is `batch`, not `data`, so the torch.utils.data alias is not clobbered.
        for i, batch in enumerate(dataloaders['train']):
            batch_time_start = time.time()

            # put data on the training device
            batch['A'] = batch['A'].to(opt['device'])
            batch['B'] = batch['B'].to(opt['device'])

            model.set_inputs(batch)
            model.update()
            batch_loss = model.get_loss()

            batch_time_end = time.time()

            if i % opt['log_freq'] == 0:
                print('Epoch %d, batch %d. loss of G: %3f; loss of D: %3f;'
                      % (epoch, i, batch_loss['loss_G'], batch_loss['loss_D']))

            epoch_time += batch_time_end - batch_time_start
            epoch_cost_G += batch_loss['loss_G']
            epoch_cost_D += batch_loss['loss_D']
            n_batches += 1

        if epoch % opt['save_freq'] == 0:
            model.save(epoch)
            save_sketch(model.get_B_fake(), img_save_path / ('%d_fake.png' % epoch))
            save_sketch(model.get_B_best_real(), img_save_path / ('%d_true.png' % epoch))

        model.update_lr()
        lrs.append(model.get_lr()['G'])

        n_batches = max(n_batches, 1)  # avoid division by zero on an empty loader
        print('Epoch %d finished. cost time %.3f; mean loss of G %.3f; mean loss of D %.3f.'
              % (epoch, epoch_time, epoch_cost_G / n_batches, epoch_cost_D / n_batches))

else:
    print('It is now in test mode, please check opt.')

It is now in test mode, please check opt.


# test

In [18]:
if opt['mode'] == 'test':
    # Inference resolution; should be a multiple of 256 for the 8-level U-Net.
    test_size = opt.get('test_img_size', 1024)

    # Load the photo as 3-channel RGB (the generator takes 3 input channels)
    # and normalise it the same way as the training photos.
    with Image.open(opt['test_image']) as raw_img:
        test_img = raw_img.convert('RGB').resize((test_size, test_size), Image.BICUBIC)
    test_img = transforms.ToTensor()(test_img)
    test_img = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(test_img)
    # add the batch dim and move the input onto the same device as the model
    test_img = torch.unsqueeze(test_img, 0).to(opt['device'])

    # set model to use
    model = pix2pix(opt)
    model.set_inputs({'A': test_img, 'B': None})

    # save output image; inference only, so no autograd graph is built
    with torch.no_grad():
        model.forward()
    B_fake = model.get_B_fake()
    Image.fromarray((B_fake.cpu().numpy() + 1) * 255 / 2).convert('RGB').save(opt['test_output'])
else:
    print('It is now in train mode, please check opt.')