In [1]:
import ioutils as io
import skimage
from PIL import Image
import glob
import os
import natsort
import numpy as np
from skimage.transform import resize
import random
import cv2
from PIL import Image
from collections import OrderedDict
import random
import sys
import pickle

from torchvision import transforms
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable

sys.path = ['./cyclegan/'] + sys.path

from my_networks import scr_net, CreateDiscriminator
from models.base_model import BaseModel
from models import networks

from scr_utils import *

import warnings
warnings.filterwarnings('ignore')

In [2]:
# ---- Loss configuration ----
# Exponent handed to the regression loss as `p`: 1 selects the L1 norm, 2 the L2 norm.
norm = 2
# Weight of the feature-level adversarial loss. The output-level adversarial loss is
# weighted by 0.4 * lambda_adv_target, so with the regression weight fixed at 1 the
# effective weights are: L2 = 1, GAN_F = 0.7, GAN_Y = 0.4 * 0.7 = 0.28.
lambda_adv_target = 0.7

# ---- Optimisation schedule ----
# Number of samples per training batch.
b_size = 48
# Total number of parameter updates.
num_steps = 20000
# The running training loss is reported every `print_interval` updates.
print_interval = 500
# The model is validated (and checkpointed) every `test_interval` updates.
test_interval = 1000

# ---- Run identification ----
date, version = '0118', 0
NAME = 'scr_cyclegan_b{}_{}_his_cyclegan_adv_2gan_version{}'.format(b_size, date, version)

# Root directory of the dataset.
data_path = '/scratch/zq415/grammar_cor/pose/pose_estimate/icra_data'


In [3]:
def _pose_from_rvec_tvec(rvec, tvec):
    """Build a 4x4 homogeneous pose from a Rodrigues rotation vector and a translation."""
    pose = np.eye(4)
    pose[:3, :3] = cv2.Rodrigues(np.squeeze(np.array(rvec)))[0]
    # Accepts a flat (3,) translation as well as a (3, 1) column.
    pose[:3, 3] = np.array(tvec).reshape(3)
    return pose


# Rendered (source-domain) training set: image paths, scene-coordinate label paths, poses.
render_img_paths, render_scr_paths, render_poses = make_render_dataset(data_path)
# Replace the rendered images by their histogram-matched versions; labels and poses are kept.
# The matched renders live in a sibling directory of `data_path`.
render_matched_dir = os.path.join(os.path.dirname(data_path), 'train_render_matched')
render_img_paths = natsort.natsorted(glob.glob(os.path.join(render_matched_dir, '**', '*.png'), recursive=True))
# The swapped-in image list must still line up one-to-one with the label list.
assert len(render_img_paths) == len(render_scr_paths), \
    'histogram-matched renders ({}) do not match scene-coordinate labels ({})'.format(
        len(render_img_paths), len(render_scr_paths))
print(len(render_img_paths), len(render_scr_paths), render_poses.shape)


# Real validation set (same helper, train_flag=False).
real_val_img_paths, val_scr_label_paths, val_poses = make_render_dataset(data_path, train_flag=False)


print(len(real_val_img_paths), len(val_scr_label_paths), val_poses.shape)


# Real (target-domain) training set, located below `data_path`.
real_train_img_paths = natsort.natsorted(glob.glob(os.path.join(data_path, 'train_real_imgs', '**', '*.png'), recursive=True))
train_scr_label_paths = natsort.natsorted(glob.glob(os.path.join(data_path, 'train_real_scene_coords', '**', '*.tiff'), recursive=True))
# Poses are stored per sequence (2 and 5) as rotation / translation vectors.
real_train_rvecs = io.load(os.path.join(data_path, 'train_real_poses/2/rvecs.json'))
real_train_rvecs += io.load(os.path.join(data_path, 'train_real_poses/5/rvecs.json'))
real_train_tvecs = io.load(os.path.join(data_path, 'train_real_poses/2/tvecs.json'))
real_train_tvecs += io.load(os.path.join(data_path, 'train_real_poses/5/tvecs.json'))
assert len(real_train_rvecs) == len(real_train_tvecs), \
    'rvecs ({}) and tvecs ({}) differ in length'.format(len(real_train_rvecs), len(real_train_tvecs))
# One 4x4 matrix per frame, stacked to shape (N, 4, 4).
train_poses = np.stack([_pose_from_rvec_tvec(rvec, tvec)
                        for rvec, tvec in zip(real_train_rvecs, real_train_tvecs)])
assert len(real_train_img_paths) == len(train_poses), \
    'real training images ({}) do not match poses ({})'.format(len(real_train_img_paths), len(train_poses))

print(len(real_train_img_paths), len(real_train_rvecs), train_poses.shape)


100000 100000 (100000, 4, 4)
1637 1637 (1637, 4, 4)
28411 28411 (28411, 4, 4)


In [4]:
# Validation on real images: `fix_crop` transform, fixed order, batches of 100.
_val_transform = transforms.Compose([fix_crop()])
real_val_dataset = dataset_scr(real_val_img_paths, val_poses, val_scr_label_paths, transform=_val_transform)
real_val_dataloader = DataLoader(real_val_dataset, shuffle=False, batch_size=100, num_workers=4)

# Settings shared by both training loaders.
# `max_iters` is presumably the number of samples the dataset exposes, so that each
# loader yields `num_steps` batches — TODO confirm against dataset_scr in scr_utils.
_train_samples = num_steps * b_size
_train_loader_kwargs = dict(batch_size=b_size, shuffle=True, num_workers=4)

# Source domain: rendered images with their scene-coordinate labels.
render_train_dataset = dataset_scr(
    render_img_paths, render_poses, render_scr_paths,
    transform=transforms.Compose([random_crop()]),
    max_iters=_train_samples,
)
render_train_dataloader = DataLoader(render_train_dataset, **_train_loader_kwargs)

# Target domain: real images.
real_train_dataset = dataset_scr(
    real_train_img_paths, train_poses, train_scr_label_paths,
    transform=transforms.Compose([random_crop()]),
    max_iters=_train_samples,
)
real_train_dataloader = DataLoader(real_train_dataset, **_train_loader_kwargs)



In [5]:
class ImagePool():
    """History buffer of previously seen image tensors.

    The buffer lets a discriminator be updated on a mix of current and older
    samples instead of only the most recent ones.
    """

    def __init__(self, pool_size=200):
        """Create an empty buffer.

        Parameters:
            pool_size (int): maximum number of images kept; 0 disables buffering.
        """
        self.pool_size = pool_size
        self.num_imgs = 0
        self.images = []

    def query(self, images): # images b,c,h,w torch tensor.
        """Exchange a batch with the buffer.

        Parameters:
            images: batch tensor whose first dimension indexes the images.

        Returns:
            A detached tensor with the same batch size. While the buffer is not
            full every image is stored and returned unchanged. Once it is full,
            each image is, with probability 0.5, swapped with a randomly chosen
            stored image (the stored one is returned); otherwise it is returned
            as is. If `pool_size` is 0 the input tensor itself is returned.
        """
        if self.pool_size == 0:  # buffering disabled: pass the batch through untouched
            return images
        if len(images) == 0:
            # torch.cat cannot concatenate an empty list; nothing to store or swap.
            return images.detach()
        return_images = []
        for image in images:
            # Detach so stored images do not keep an autograd graph alive.
            image = image.detach().unsqueeze(0)
            if self.num_imgs < self.pool_size:   # still filling: store and return the current image
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:  # swap: return an older image and store the current one in its slot
                    random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:       # keep: return the current image, buffer unchanged
                    return_images.append(image)
        return torch.cat(return_images, 0)   # reassemble the batch


In [6]:
def train(model, device, render_train_loader, optimizer, criterion, record, model_D, optimizer_D, model_D_2, optimizer_D_2,
         real_train_dataloader):
    """Train the regressor on translated renders with two adversarial alignment losses.

    Each of the `num_steps` iterations:
      1. translates a source batch with the frozen global generator `netG` and
         back-propagates the regression loss against the labels downsampled by 8;
      2. runs a target batch through `model` and back-propagates the adversarial
         losses of both discriminators (queried with label 0) while their
         parameters are frozen;
      3. updates both discriminators on detached inputs (source with label 0,
         target with label 1, each loss halved). The second discriminator gets
         downsampled source labels and target predictions drawn through
         `ImagePool` history buffers.

    Parameters:
        model: network returning `(features, prediction)` for an image batch.
        device: torch device the batches are moved to.
        render_train_loader: loader of source batches (dicts with 'image' and 'scr').
        optimizer: optimizer of `model`.
        criterion: regression loss, called as `criterion(prediction, label, p=norm)`.
        record: writable text file that receives the log lines.
        model_D, optimizer_D: discriminator applied to the features, and its optimizer.
        model_D_2, optimizer_D_2: discriminator applied to the predictions, and its optimizer.
        real_train_dataloader: loader of target batches (dicts with 'image').

    The discriminators are called as `model_D(x, label)` and expose the resulting
    loss as `model_D.loss`. Both loaders must provide at least `num_steps` batches,
    otherwise `next()` raises StopIteration.

    Reads module-level names defined outside this function: `num_steps`, `norm`,
    `lambda_adv_target`, `print_interval`, `test_interval`, `b_size`, `date`,
    `version`, `netG`, `real_val_dataloader`, `validate`, `compute_mean_error`,
    `get_median`.

    Side effects: writes log lines to `record` and stdout, and saves a checkpoint
    of `model` under './check_point/' every `test_interval` steps.
    """
    model.train()
    model_D.train()
    model_D_2.train()
    trg_lbl_pool = ImagePool(pool_size=200)
    src_lbl_pool = ImagePool(pool_size=200)
    targetloader_iter, sourceloader_iter = iter(real_train_dataloader), iter(render_train_loader)
    # Make sure the checkpoint directory exists before any time is spent training.
    os.makedirs('./check_point', exist_ok=True)

    running_loss = 0.0
    for i in range(num_steps):
        optimizer.zero_grad()
        optimizer_D.zero_grad()
        optimizer_D_2.zero_grad()

        # ---- regressor gradients: discriminators are frozen ----
        for param in model_D.parameters():
            param.requires_grad = False
        for param in model_D_2.parameters():
            param.requires_grad = False

        # Builtin next(): DataLoader iterators do not offer a .next() method in current PyTorch.
        source_sample_batch = next(sourceloader_iter)
        src_img, src_lbl = source_sample_batch['image'], source_sample_batch['scr']
        src_img, src_lbl = src_img.to(device, dtype=torch.float), src_lbl.to(device, dtype=torch.float)
        with torch.no_grad():
            # Inputs are rescaled with (x - 0.5) / 0.5 for netG and mapped back afterwards
            # (i.e. [0, 1] <-> [-1, 1] if the images are in [0, 1] — confirm against the dataset).
            src_img = netG((src_img-0.5)/0.5)*0.5+0.5
        # Labels at 1/8 resolution; reused below for the second discriminator.
        src_lbl_small = torch.nn.functional.interpolate(src_lbl, scale_factor=1/8.0, mode='bilinear')

        src_feat, src_pre = model(src_img)
        loss_reg_src = criterion(src_pre, src_lbl_small, p=norm)
        loss_reg_src.backward()

        # Target labels are not read: no supervised loss is applied to the target batch.
        target_sample_batch = next(targetloader_iter)
        trg_img = target_sample_batch['image'].to(device, dtype=torch.float)
        trg_feat, trg_pre = model(trg_img)

        # Query both discriminators on target data with label 0 (the label used for source below).
        model_D(trg_feat, 0)
        loss_D_trg_fake = model_D.loss

        model_D_2(trg_pre, 0)
        loss_D_2_trg_fake = model_D_2.loss

        loss_trg = lambda_adv_target * loss_D_trg_fake + 0.4 * lambda_adv_target * loss_D_2_trg_fake
        loss_trg.backward()

        # ---- discriminator gradients: inputs are detached from the regressor ----
        for param in model_D.parameters():
            param.requires_grad = True
        for param in model_D_2.parameters():
            param.requires_grad = True

        src_feat, trg_feat = src_feat.detach(), trg_feat.detach()
        model_D(src_feat, 0)
        loss_D_src_real = model_D.loss / 2
        loss_D_src_real.backward()
        model_D(trg_feat, 1)
        loss_D_trg_real = model_D.loss / 2
        loss_D_trg_real.backward()

        # Second discriminator: ground-truth source labels vs. target predictions,
        # both routed through their history buffers.
        src_pre = src_lbl_pool.query(src_lbl_small)
        trg_pre = trg_lbl_pool.query(trg_pre.detach())
        model_D_2(src_pre, 0)
        loss_D_src_real_2 = model_D_2.loss / 2
        loss_D_src_real_2.backward()
        model_D_2(trg_pre, 1)
        loss_D_trg_real_2 = model_D_2.loss / 2
        loss_D_trg_real_2.backward()

        optimizer.step()
        optimizer_D.step()
        optimizer_D_2.step()

        running_loss += loss_reg_src.item()

        if (i+1) % print_interval == 0:
            out_str = "iter {}, current loss {}.\n".format(i+1,running_loss/print_interval)
            print(out_str)
            record.write(out_str)
            record.flush()
            running_loss = 0.0

        if (i+1) % test_interval == 0:
            real_val_scr_predicts, val_scrs, val_true_poses = validate(model, real_val_dataloader)
            # validate() is defined elsewhere; in case it switches the model to eval
            # mode, restore training mode before the next update.
            model.train()

            out_str = 'real_mean_scr_error: {}.\n'.format(compute_mean_error(val_scrs, real_val_scr_predicts))
            print(out_str)
            record.write(out_str)
            record.flush()

            real_val_median = get_median(real_val_scr_predicts, val_true_poses)
            out_str = 'real_val, r_median: {}, t_median: {}.\n'.format(real_val_median[0], real_val_median[1])
            print(out_str)
            record.write(out_str)
            record.flush()

            # Note: the file name carries the zero-based step index `i`, not `i+1`.
            torch.save(model.state_dict(), './check_point/scr_b{}_iter{}_norm{}_{}_his_cyclegan_adv2gan_ensemble{}.pth'.format(b_size,i,norm,date,version))
                

In [7]:
# Choose the compute device once; models and batches are moved to it later.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda:0") if _cuda_ok else torch.device('cpu')
if _cuda_ok:
    # Enable cuDNN and its benchmark mode (auto-tuned convolution algorithms).
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
print('using GPU for training' if _cuda_ok else 'using CPU for training')
print(device)


using GPU for training
cuda:0


In [8]:
# Frozen CycleGAN generator; train() uses it to translate the rendered images.
netG_device = list(range(torch.cuda.device_count()))
netG = networks.define_G(3, 3, 64, 'resnet_9blocks', 'instance', False, 'normal', 0.02, netG_device)
state_dict = torch.load('./cyclegan/checkpoints/match2real_cyclegan_scr_2gpu_1024_noidentity/4_net_G_A.pth', map_location=str(device))
# Alternative checkpoint:
# state_dict = torch.load('./cyclegan/checkpoints/match2real_cyclegan_scr_2gpu_1011/4_net_G_A.pth', map_location=str(device))

# Checkpoint keys need the 'module.' prefix exactly when netG is wrapped in
# nn.DataParallel. Every key is kept: keys that already match are copied as they
# are instead of being dropped.
netG_is_wrapped = isinstance(netG, nn.DataParallel)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    has_prefix = k.startswith('module.')
    if netG_is_wrapped and not has_prefix:
        k = 'module.' + k
    elif has_prefix and not netG_is_wrapped:
        k = k[len('module.'):]
    new_state_dict[k] = v
netG.load_state_dict(new_state_dict)
# The generator is only used for inference: no gradients, eval mode.
for param in netG.parameters():
    param.requires_grad = False
netG.eval()



initialize network with normal


DataParallel(
  (module): ResnetGenerator(
    (model): Sequential(
      (0): ReflectionPad2d((3, 3, 3, 3))
      (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
      (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      (3): ReLU(inplace=True)
      (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      (9): ReLU(inplace=True)
      (10): ResnetBlock(
        (conv_block): Sequential(
          (0): ReflectionPad2d((1, 1, 1, 1))
          (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
          (2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
          (3): ReLU(inplac

In [9]:
# Feature-level discriminator and its optimizer; train() applies it to the feature
# output of the regressor. 512 is presumably the channel count of that feature
# map — TODO confirm against scr_net.
model_D, optimizer_D = CreateDiscriminator(input_channel=512)


In [10]:
# Output-level discriminator and its optimizer; train() applies it to downsampled
# scene-coordinate labels and to the regressor's predictions. The 3 input channels
# presumably are the x, y, z coordinates — TODO confirm.
model_D_2, optimizer_D_2 = CreateDiscriminator(input_channel=3)


In [11]:
# Scene-coordinate regressor; it and both discriminators are placed on `device`.
net = scr_net()
for _module in (net, model_D, model_D_2):
    _module.to(device)

_gpu_count = torch.cuda.device_count()
if _gpu_count > 1:
    # NOTE(review): train() reads `model_D.loss` / `model_D_2.loss`. nn.DataParallel
    # does not expose custom attributes of the wrapped module, so this multi-GPU
    # path likely needs `.module.loss` (or unwrapped discriminators) — confirm.
    print("Let's use", _gpu_count, "GPUs!")
    net = nn.DataParallel(net)
    model_D = nn.DataParallel(model_D)
    model_D_2 = nn.DataParallel(model_D_2)

# Number of trainable scalars in the regressor.
model_total_params = sum(w.numel() for w in filter(lambda w: w.requires_grad, net.parameters()))
print(model_total_params)


11572803


In [12]:
# Regression loss; train() calls it as criterion(prediction, label, p=norm).
criterion = maskedPnorm
# Adam over the regressor's parameters only; the discriminators have their own optimizers.
optimizer = optim.Adam(net.parameters(), weight_decay=1e-5, lr=1e-4)



In [None]:
# Make sure the log directory exists, and close the log file even if training fails.
os.makedirs('./save_info', exist_ok=True)
with open('./save_info/'+ NAME +'.txt','w+') as record:
    # Header line with the main hyper-parameters of this run.
    basic_str = "norm: {}, batch_size: {}.\n".format(norm, b_size)
    print(basic_str)
    record.write(basic_str)
    record.flush()

    train(net, device, render_train_dataloader, optimizer, criterion, record, model_D, optimizer_D, model_D_2, optimizer_D_2,
          real_train_dataloader)


norm: 2, batch_size: 48.

