In [1]:
import ioutils as io
import skimage
from PIL import Image
import glob
import os
import natsort
import numpy as np
from skimage.transform import resize
import random
import cv2
from PIL import Image
from collections import OrderedDict
import random
import sys
import pickle
import open3d as o3d

from torchvision import transforms
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable

sys.path = ['./cyclegan/'] + sys.path

from my_networks import scr_net, CreateDiscriminator
from models.base_model import BaseModel
from models import networks

from scr_utils import *

import warnings
warnings.filterwarnings('ignore')


In [2]:
# ---- Experiment configuration -------------------------------------------
# Root folder handed to make_render_dataset() for both train and val splits.
data_path = '/scratch/zq415/work/pose/data'

# Optimisation settings.
norm = 2              # forwarded to the loss as `p=norm` inside train()
epochs = 20           # number of passes of the outer training loop
b_size = 48           # batch size of the training DataLoader
print_interval = 100  # batches between two running-loss log lines
# train_real_proba = 1.0

# Run identifier: used for the log file name; `date` also tags checkpoints.
date = '0131'
NAME = 'scr_cyclegan_b{}_{}_train_his_cyclegan'.format(b_size, date)


In [3]:
# Training split: label (scr) paths and poses come from make_render_dataset();
# the image list it returns is replaced below by the histogram-matched renders.
render_img_paths, render_scr_paths, render_poses = make_render_dataset(data_path)
### set the render_img_paths to histogram matched
# render_img_paths = natsort.natsorted(glob.glob('/scratch/zq415/grammar_cor/pose/pose_estimate/train_render_matched'+'/**/*.png', recursive=True))
# Build the glob pattern from data_path so the dataset root is configured in
# exactly one place (resolves to <data_path>/rendered_histogram/**/*.png).
histogram_pattern = os.path.join(data_path, 'rendered_histogram', '**', '*.png')
render_img_paths = natsort.natsorted(glob.glob(histogram_pattern, recursive=True))
print(len(render_img_paths), len(render_scr_paths), render_poses.shape)

# The swapped-in images are paired with labels/poses purely by position, so a
# missing or extra file would silently misalign every training sample.
assert len(render_img_paths) == len(render_scr_paths) == len(render_poses), (
    'histogram-matched images ({}) do not line up with scr labels ({}) / poses ({})'.format(
        len(render_img_paths), len(render_scr_paths), len(render_poses)))


# Validation split (train_flag=False); used unmodified.
real_val_img_paths, val_scr_label_paths, val_poses = make_render_dataset(data_path, train_flag=False)


print(len(real_val_img_paths), len(val_scr_label_paths), val_poses.shape)


100000 100000 (100000, 4, 4)
1637 1637 (1637, 4, 4)


In [4]:
# Validation pipeline: fix_crop transform, fixed sample order, batches of 100.
val_transform = transforms.Compose([fix_crop()])
real_val_dataset = dataset_scr(
    real_val_img_paths,
    val_poses,
    val_scr_label_paths,
    transform=val_transform,
)
real_val_dataloader = DataLoader(
    real_val_dataset,
    batch_size=100,
    shuffle=False,
    num_workers=4,
)


# Training pipeline: random_crop transform, reshuffled every epoch,
# batches of `b_size`.
train_transform = transforms.Compose([random_crop()])
render_train_dataset = dataset_scr(
    render_img_paths,
    render_poses,
    render_scr_paths,
    transform=train_transform,
)
render_train_dataloader = DataLoader(
    render_train_dataset,
    batch_size=b_size,
    shuffle=True,
    num_workers=4,
)


In [5]:
def train(model, device, train_loader, optimizer, criterion, epoch, record):
    """Train `model` for one pass over `train_loader`.

    Args:
        model: network whose forward pass returns an indexable result; element
            [1] is used as the prediction.
        device: torch device the batches are moved to.
        train_loader: iterable of dict batches with 'image' and 'scr' entries.
        optimizer: optimizer stepping the parameters of `model`.
        criterion: loss callable invoked as criterion(pred, target, p=norm).
        epoch: zero-based epoch index (logged as epoch + 1).
        record: writable text file object receiving the same log lines
            that are printed.

    Reads the notebook-level settings `norm` and `print_interval`.
    """
    model.train()
    running_loss = 0.0
    for i_batch, sample_batched in enumerate(train_loader):
        inputs = sample_batched['image'].to(device, dtype=torch.float)
        labels = sample_batched['scr'].to(device, dtype=torch.float)

        # zero the parameter gradients
        optimizer.zero_grad()

        outputs = model(inputs)[1]
        # The labels are downsampled by 8 before being compared with the prediction.
        targets = torch.nn.functional.interpolate(labels, scale_factor=1/8.0, mode='bilinear')
        loss = criterion(outputs, targets, p=norm)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        if (i_batch+1) % print_interval == 0:
            # Mean loss over the last `print_interval` batches.
            out_str = "epoch {}, batch {}, current loss {}.\n".format(epoch+1,i_batch,running_loss/print_interval)
            print(out_str)
            record.write(out_str)
            record.flush()

            # Channels-last shape of one input sample. Read from tensor
            # metadata: permute() is a view, so no upsampling or device-to-host
            # copy is needed just to report the shape.
            print(tuple(inputs[0].permute(1, 2, 0).shape))
            running_loss = 0.0
            
def test(net, dataloader, device=None):
    """Run `net` over `dataloader` in eval mode and collect predictions.

    Args:
        net: network whose forward pass returns an indexable result; element
            [1] is used as the prediction.
        dataloader: sized iterable of dict batches with 'image', 'scr' and
            'pose' entries.
        device: device the inputs are moved to. When omitted, the device of
            the network's first parameter is used (CPU if it has none), so the
            function no longer depends on a notebook-level `device` variable.

    Returns:
        Tuple (scr_predicts, scr_labels, true_poses) of numpy arrays, each
        concatenated over all batches along axis 0. The labels are the 'scr'
        maps downsampled by a factor of 8.
    """
    if device is None:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    net.eval()
    scr_predicts, true_poses = [], []
    scr_labels = []

    # Passing the loader itself (with its length) gives tqdm a total, so a real
    # progress bar is shown instead of a bare iteration counter.
    for sample_batched in tqdm(dataloader, total=len(dataloader)):
        inputs, scrs, poses = sample_batched['image'], sample_batched['scr'], sample_batched['pose']
        inputs = inputs.to(device, dtype=torch.float)
        with torch.no_grad():
            scr_predict = net(inputs)[1]

        scr_predicts.append(scr_predict.cpu().numpy())
        # The labels are downsampled by 8 in the same way as during training.
        scr_labels.append(torch.nn.functional.interpolate(scrs.cpu(), scale_factor=1/8.0, mode='bilinear').numpy())
        true_poses.append(poses.numpy())

    scr_predicts = np.concatenate(scr_predicts)
    scr_labels = np.concatenate(scr_labels)
    true_poses = np.concatenate(true_poses)
    print(scr_predicts.shape, scr_labels.shape, true_poses.shape)
    return scr_predicts, scr_labels, true_poses



In [6]:
# Pick the compute device once; every later cell reuses `device`.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
if use_cuda:
    # cuDNN switches are only touched when a GPU is actually used.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
print('using GPU for training' if use_cuda else 'using CPU for training')
print(device)


using GPU for training
cuda:0


In [7]:
#### sc network
# Build the network and place it on the selected device
# (nn.Module.to returns the module itself, so the call can be chained).
net = scr_net().to(device)

# Wrap in DataParallel only when more than one GPU is visible.
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print("Let's use", gpu_count, "GPUs!")
    net = nn.DataParallel(net)

# Number of trainable scalar parameters.
trainable_params = [param for param in net.parameters() if param.requires_grad]
model_total_params = sum(param.numel() for param in trainable_params)
print(model_total_params)


11572803


In [8]:
# Loss callable; train() invokes it as criterion(prediction, target, p=norm).
# maskedPnorm is not defined in this notebook -- presumably it arrives via
# `from scr_utils import *`.
criterion = maskedPnorm

# Adam over every network parameter, with a small weight decay.
optimizer = optim.Adam(
    net.parameters(),
    lr=0.001,
    weight_decay=0.00001,
)


In [None]:
# Create the output folders up front so a long run cannot fail on the first
# log write or checkpoint save because a directory is missing.
os.makedirs('./save_info/', exist_ok=True)
os.makedirs('./check_point/', exist_ok=True)

# The context manager guarantees the log file is closed (and flushed) even if
# training is interrupted or raises part-way through.
with open('./save_info/'+ NAME +'.txt','w+') as record:

    basic_str = "norm: {}, batch_size: {}.\n".format(norm, b_size) 
    print(basic_str)
    record.write(basic_str)
    record.flush()

    for epoch in range(epochs):

        # One pass over the rendered training data.
        train(net, device, render_train_dataloader, optimizer, criterion, epoch, record)

        # Evaluate on the validation loader after every epoch.
        real_val_scr_predicts, val_scrs, val_true_poses = test(net, real_val_dataloader)

        out_str = 'real_mean_scr_error: {}.\n'.format(compute_mean_error(val_scrs, real_val_scr_predicts))
        print(out_str)
        record.write(out_str)
        record.flush()

        # get_median returns an indexable result; [0] is logged as r_median
        # and [1] as t_median.
        real_val_median = get_median(real_val_scr_predicts, val_true_poses)
        out_str = 'real_val, r_median: {}, t_median: {}.\n'.format(real_val_median[0], real_val_median[1])
        print(out_str)
        record.write(out_str)
        record.flush()

        # NOTE: the checkpoint name uses the zero-based epoch index, whereas
        # the log lines written by train() use epoch + 1.
        torch.save(net.state_dict(), './check_point/scr_b{}_e{}_norm{}_{}_train_his_cyclegan.pth'.format(b_size,epoch,norm,date))
    
    

norm: 2, batch_size: 48.

epoch 1, batch 99, current loss 2.929746297597885.

(320, 320, 3)
epoch 1, batch 199, current loss 1.3059980708360672.

(320, 320, 3)
epoch 1, batch 299, current loss 1.0534222757816314.

(320, 320, 3)
epoch 1, batch 399, current loss 0.8925641840696334.

(320, 320, 3)
epoch 1, batch 499, current loss 0.7500084942579269.

(320, 320, 3)
epoch 1, batch 599, current loss 0.609468070268631.

(320, 320, 3)
epoch 1, batch 699, current loss 0.5636269581317902.

(320, 320, 3)
epoch 1, batch 799, current loss 0.5055500927567482.

(320, 320, 3)
epoch 1, batch 899, current loss 0.47024311661720275.

(320, 320, 3)
epoch 1, batch 999, current loss 0.6267395177483559.

(320, 320, 3)
epoch 1, batch 1099, current loss 0.49938901782035827.

(320, 320, 3)
epoch 1, batch 1199, current loss 0.4156216859817505.

(320, 320, 3)
epoch 1, batch 1299, current loss 0.412984139919281.

(320, 320, 3)
epoch 1, batch 1399, current loss 0.3664962854981422.

(320, 320, 3)
epoch 1, batch 1499,

17it [01:33,  5.50s/it]
  0%|          | 0/1637 [00:00<?, ?it/s]

(1637, 3, 42, 74) (1637, 3, 42, 74) (1637, 4, 4)
real_mean_scr_error: 0.635046660900116.



100%|██████████| 1637/1637 [01:10<00:00, 23.31it/s]


real_val, r_median: 107.08593896958797, t_median: 9.7109539441329.

epoch 2, batch 99, current loss 0.41751931309700013.

(320, 320, 3)
epoch 2, batch 199, current loss 0.3067796126008034.

(320, 320, 3)
epoch 2, batch 299, current loss 0.29421899244189265.

(320, 320, 3)
epoch 2, batch 399, current loss 0.3063937829434872.

(320, 320, 3)
epoch 2, batch 499, current loss 0.30761067882180215.

(320, 320, 3)
epoch 2, batch 599, current loss 0.2903783640265465.

(320, 320, 3)
epoch 2, batch 699, current loss 0.269980860799551.

(320, 320, 3)
epoch 2, batch 799, current loss 0.2578193497657776.

(320, 320, 3)
epoch 2, batch 899, current loss 0.46974185526371004.

(320, 320, 3)
epoch 2, batch 999, current loss 0.45532130151987077.

(320, 320, 3)
epoch 2, batch 1099, current loss 0.31641054034233096.

(320, 320, 3)
epoch 2, batch 1199, current loss 0.3063148435950279.

(320, 320, 3)
epoch 2, batch 1299, current loss 0.2910547052323818.

(320, 320, 3)
epoch 2, batch 1399, current loss 0.26530

17it [01:22,  4.86s/it]
  0%|          | 0/1637 [00:00<?, ?it/s]

(1637, 3, 42, 74) (1637, 3, 42, 74) (1637, 4, 4)
real_mean_scr_error: 0.5896453261375427.



100%|██████████| 1637/1637 [01:06<00:00, 24.56it/s]


real_val, r_median: 117.95274406327417, t_median: 8.58408712356746.

epoch 3, batch 99, current loss 0.24906323105096817.

(320, 320, 3)
epoch 3, batch 199, current loss 0.23136330187320708.

(320, 320, 3)
epoch 3, batch 299, current loss 0.2278200300037861.

(320, 320, 3)
epoch 3, batch 399, current loss 0.23351360261440277.

(320, 320, 3)
epoch 3, batch 499, current loss 0.21509096547961234.

(320, 320, 3)
epoch 3, batch 599, current loss 0.22863618552684783.

(320, 320, 3)
epoch 3, batch 699, current loss 0.2345196269452572.

(320, 320, 3)
epoch 3, batch 799, current loss 0.2929271651804447.

(320, 320, 3)
epoch 3, batch 899, current loss 0.2795758725702763.

(320, 320, 3)
epoch 3, batch 999, current loss 0.25489909663796423.

(320, 320, 3)
epoch 3, batch 1099, current loss 0.21496190384030342.

(320, 320, 3)
epoch 3, batch 1199, current loss 0.22460278928279875.

(320, 320, 3)
epoch 3, batch 1299, current loss 0.20632360458374024.

(320, 320, 3)
epoch 3, batch 1399, current loss 0.

17it [00:57,  3.40s/it]
  0%|          | 0/1637 [00:00<?, ?it/s]

(1637, 3, 42, 74) (1637, 3, 42, 74) (1637, 4, 4)
real_mean_scr_error: 0.6277364492416382.



100%|██████████| 1637/1637 [01:07<00:00, 24.32it/s]


real_val, r_median: 82.96733770109029, t_median: 8.992603107441212.

epoch 4, batch 99, current loss 0.21396866723895072.

(320, 320, 3)
epoch 4, batch 199, current loss 0.2026227517426014.

(320, 320, 3)
epoch 4, batch 299, current loss 0.20354780688881874.

(320, 320, 3)
epoch 4, batch 399, current loss 0.22474947839975357.

(320, 320, 3)
epoch 4, batch 499, current loss 0.9113996756076813.

(320, 320, 3)
epoch 4, batch 599, current loss 0.4280404564738274.

(320, 320, 3)
epoch 4, batch 699, current loss 0.3185541374981403.

(320, 320, 3)
epoch 4, batch 799, current loss 0.28733678847551347.

(320, 320, 3)
epoch 4, batch 899, current loss 0.25420532435178755.

(320, 320, 3)
epoch 4, batch 999, current loss 0.23524843096733095.

(320, 320, 3)
epoch 4, batch 1099, current loss 0.24590095445513727.

(320, 320, 3)
epoch 4, batch 1199, current loss 0.240287850946188.

(320, 320, 3)
epoch 4, batch 1299, current loss 0.22643935561180115.

(320, 320, 3)
epoch 4, batch 1399, current loss 0.21