In [2]:
%matplotlib notebook
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
import torch
import torchvision.transforms as transforms
import numpy as np
import yaml
import sys
import time
from scipy import interpolate
import os
sys.path.append("../")
#sys.path.append("/mnt/lustre/share/zhanxiaohang/videoseg/lib/pydensecrf")
import pdb
os.environ["CUDA_VISIBLE_DEVICES"] = '3'
import flowlib
import models
import utils
import importlib
importlib.reload(utils)
importlib.reload(models)
# Experiment directory, relative to this notebook; the config file is looked up inside it.
exp = '../experiments/vip+mpii/resnet50_skiplayer_residual'
# Iteration number of the checkpoint to restore (handed to Demo as `load_iter`).
load_iter = 46000
# Path of the experiment's YAML configuration file.
config = "{}/config.yaml".format(exp)

In [3]:
class ArgObj(object):
    """Empty attribute container; fields are attached to instances after construction."""

    def __init__(self):
        # Deliberately stateless: a fresh instance carries no attributes of its own.
        return None

class suibian(torch.nn.Module):
    """Thin ``nn.Module`` wrapper that stores the wrapped network as ``self.module``.

    ``forward`` passes its two inputs straight through to the wrapped module and
    returns whatever that module returns.
    """

    def __init__(self, m):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a sub-module named "module".
        self.module = m

    def forward(self, input1, input2):
        wrapped = self.module
        return wrapped(input1, input2)

def image_resize(img, short_size):
    """Resize ``img`` so that its shorter side becomes ``short_size`` pixels.

    The aspect ratio is kept; the longer side is truncated to an integer.

    Args:
        img: image object exposing ``width``, ``height`` and ``resize`` (a PIL image).
        short_size: target length of the shorter side.

    Returns:
        The result of ``img.resize`` with bicubic resampling.
    """
    width, height = img.width, img.height
    portrait = width < height
    if portrait:
        # Width is the short side: pin it and scale the height.
        target = (short_size, int(short_size / float(width) * height))
    else:
        # Height is the short side (or the image is square): pin it and scale the width.
        target = (int(short_size / float(height) * width), short_size)
    return img.resize(target, Image.BICUBIC)

def image_crop(img, crop_size):
    """Center-crop ``img`` to ``crop_size``, padding with black first if it is too small.

    Args:
        img: PIL image; the pad colour is given as an RGB triple.
        crop_size: ``(height, width)`` of the returned crop.

    Returns:
        The cropped image, as returned by ``img.crop``.
    """
    target_h, target_w = crop_size[0], crop_size[1]
    missing_h = max(target_h - img.height, 0)
    missing_w = max(target_w - img.width, 0)
    if missing_h > 0 or missing_w > 0:
        top = int(missing_h / 2)
        left = int(missing_w / 2)
        # Border order is (left, top, right, bottom); an odd remainder goes to right/bottom.
        img = ImageOps.expand(
            img,
            border=(left, top, missing_w - left, missing_h - top),
            fill=(0, 0, 0),
        )
    # Offsets are taken from the (possibly padded) image so the crop stays centred.
    y0 = (img.height - target_h) // 2
    x0 = (img.width - target_w) // 2
    return img.crop((x0, y0, x0 + target_w, y0 + target_h))

def flow_crop(flow, crop_size):
    """Center-crop a flow field to ``crop_size``, zero-padding first if it is too small.

    Args:
        flow: array indexed as ``[row, col, channel]``.
        crop_size: ``(height, width)`` of the returned crop.

    Returns:
        Array of shape ``(crop_size[0], crop_size[1], channels)``. When padding was
        needed the result is a new float32 array; otherwise it is a slice of ``flow``.
    """
    # Sizes come from the flow itself (the previous code read an unrelated `img` name).
    height, width = flow.shape[0], flow.shape[1]
    pad_h = max(crop_size[0] - height, 0)
    pad_w = max(crop_size[1] - width, 0)
    if pad_h > 0 or pad_w > 0:
        top = pad_h // 2
        left = pad_w // 2
        padded = np.zeros((height + pad_h, width + pad_w, flow.shape[2]), dtype=np.float32)
        padded[top:top + height, left:left + width, :] = flow
        flow = padded
        # Offsets below must be measured on the padded field, not the original one.
        height, width = flow.shape[0], flow.shape[1]
    hoff = (height - crop_size[0]) // 2
    woff = (width - crop_size[1]) // 2
    return flow[hoff:hoff + crop_size[0], woff:woff + crop_size[1], :]

def image_flow_warp(img, flow, mask_th=1, copy=True, interp=True, interp_mode=0):
    """Forward-warp ``img`` by moving pixels along ``flow``.

    Args:
        img: ``(H, W, C)`` array; only the first three channels are processed.
        flow: ``(H, W, 2)`` array; channel 0 is the column (x) displacement and
            channel 1 the row (y) displacement.
        mask_th: a pixel counts as moving when ``|vx|`` or ``|vy|`` exceeds this value.
        copy: if true, non-moving pixels are copied unchanged into the output.
        interp: if true, fill moving pixels that ended up all-zero ("holes").
        interp_mode: ``0`` fills each hole from the source pixel one flow step
            backwards; any other value fills by cubic ``griddata`` interpolation.

    Returns:
        Warped image with the same shape and dtype as ``img``.
    """
    warp_img = np.zeros(img.shape, dtype=img.dtype)
    flow_x, flow_y = flow[:, :, 0], flow[:, :, 1]
    flow_mask = (np.abs(flow_x) > mask_th) | (np.abs(flow_y) > mask_th)
    pts = np.where(flow_mask)
    # Builtin int replaces the np.int alias, which newer NumPy releases no longer provide.
    vx_pts = flow_x[pts].astype(int)
    vy_pts = flow_y[pts].astype(int)
    # Ascending squared magnitude: when several sources land on one target,
    # the one with the largest motion is written last and wins.
    sortidx = np.argsort(flow_x[pts] ** 2 + flow_y[pts] ** 2)
    dst = (
        np.clip(pts[0] + vy_pts, 0, img.shape[0] - 1)[sortidx],
        np.clip(pts[1] + vx_pts, 0, img.shape[1] - 1)[sortidx],
    )
    # Must be a tuple: a list of index arrays is not treated as a (row, col) pair.
    src = (pts[0][sortidx], pts[1][sortidx])
    for c in range(3):
        if copy:
            warp_img[:, :, c][~flow_mask] = img[:, :, c][~flow_mask]
        warp_img[:, :, c][dst] = img[:, :, c][src]
    if interp:
        # Holes: moving pixels that nothing was written into (all channels zero).
        holes = (warp_img.sum(axis=2) == 0) & flow_mask
        if interp_mode == 0:
            hpts = np.where(holes)
            opts = (
                np.clip(hpts[0] - flow_y[hpts].astype(int), 0, img.shape[0] - 1),
                np.clip(hpts[1] - flow_x[hpts].astype(int), 0, img.shape[1] - 1),
            )
            for c in range(3):
                warp_img[:, :, c][hpts] = img[:, :, c][opts]
        else:
            for c in range(3):
                warp_img[:, :, c][holes] = interpolate.griddata(
                    np.where(~holes), warp_img[:, :, c][~holes], np.where(holes), method='cubic')
    return warp_img
    
    
class Demo(object):
    """Loads a trained model from an experiment directory and runs it on one image."""

    def __init__(self, configfn, load_iter):
        """Read the YAML config, build the model and restore a checkpoint.

        Args:
            configfn: path of the experiment's YAML config; its directory is used
                as the experiment path and is expected to contain ``checkpoints``.
            load_iter: iteration number passed to ``model.load_state``.
        """
        args = ArgObj()
        with open(configfn) as f:
            # safe_load: yaml.load() without an explicit Loader is rejected by
            # current PyYAML and is unsafe on older versions.
            config = yaml.safe_load(f)
        for k, v in config.items():
            setattr(args, k, v)
        setattr(args, 'load_iter', load_iter)
        setattr(args, 'exp_path', os.path.dirname(configfn))

        # The architecture is looked up by name among the attributes of `models`.
        self.model = models.__dict__[args.model['arch']](args.model, dist_model=False)

        self.model.load_state("{}/checkpoints".format(args.exp_path), args.load_iter, False)
        self.model.switch_to('eval')

        self.data_mean = args.data['data_mean']
        self.data_div = args.data['data_div']

        self.img_transform = transforms.Compose([
            transforms.Normalize(self.data_mean, self.data_div)])

        self.args = args

    def def_input(self, image, repeat=1):
        """Store ``image`` and its normalized ``(repeat, C, H, W)`` float tensor.

        Args:
            image: anything ``np.array`` turns into an ``(H, W, C)`` array.
            repeat: number of copies stacked along the batch dimension.
        """
        self.rgb = image
        chw = np.array(image).astype(np.float32).transpose((2, 0, 1))
        tensor = self.img_transform(torch.from_numpy(chw))
        self.image = tensor.unsqueeze(0).repeat(repeat, 1, 1, 1)

    def run(self, arrows):
        """Run the model with sparse motion guidance on the stored image.

        Args:
            arrows: iterable of ``[x, y, dx, dy]``; ``(dx, dy)`` is written at pixel
                ``(int(y), int(x))`` of the sparse flow and the mask is set to 1 there.

        Returns:
            ``(flow, warped, rgb_gen)``: the model's flow as an ``(H, W, C)`` array
            and the two de-normalized images as ``(H, W, C)`` uint8 arrays.
            Requires a CUDA device.
        """
        height, width = self.image.size(2), self.image.size(3)
        sparse = np.zeros((1, 2, height, width), dtype=np.float32)
        mask = np.zeros((1, 2, height, width), dtype=np.float32)
        for arr in arrows:
            row, col = int(arr[1]), int(arr[0])
            sparse[0, :, row, col] = np.array(arr[2:4])
            mask[0, :, row, col] = np.array([1, 1])
        image = self.image.cuda()
        sparse = torch.from_numpy(sparse).cuda()
        mask = torch.from_numpy(mask).cuda()

        # The guidance fed to the model is [sparse flow, mask] concatenated on channels.
        self.model.set_input(image, torch.cat([sparse, mask], dim=1))
        self.model.forward_gie()
        out_flow = self.model.flow.detach().cpu().numpy()[0].transpose((1,2,0))
        out_warped = torch.clamp(utils.unormalize(self.model.warped.detach().cpu(), mean=self.data_mean, div=self.data_div), 0, 255).numpy()[0].transpose((1,2,0))
        out_rgb_gen = torch.clamp(utils.unormalize(self.model.rgb_gen.detach().cpu(), mean=self.data_mean, div=self.data_div), 0, 255).numpy()[0].transpose((1,2,0))
        return out_flow, out_warped.astype(np.uint8), out_rgb_gen.astype(np.uint8)

In [4]:
class Draw(object):
    """Interactive matplotlib front-end for ``Demo``.

    Mouse (on the first panel): two left clicks define an arrow (start, then end);
    a right click removes the arrow whose start point is nearest.
    Keys: ``e`` runs the model on the current arrows, ``[`` shows the input image
    in the last panel, ``]`` shows the generated image there again.
    """

    def __init__(self):
        self.demo = Demo(config, load_iter)

    def init_image(self, img, scale=[1.0]):
        """Prepare ``img``, create the four-panel figure and hook up the callbacks.

        Args:
            img: PIL image; it is resized and cropped according to the data config.
            scale: accepted for backward compatibility; not used.
        """
        img = image_resize(img, self.demo.args.data['short_size'])
        img = image_crop(img, self.demo.args.data['crop_size'])
        self.img = img
        self.fig = plt.figure(figsize=(10, 5))
        self.ax1 = self.fig.add_subplot(141)
        self.ax2 = self.fig.add_subplot(142)
        self.ax3 = self.fig.add_subplot(143)
        self.ax4 = self.fig.add_subplot(144)
        self.status = self.ax1.text(0, 0, "", va="bottom", ha="left", color='r')
        self.ax1.imshow(img)
        # The three output panels start as white placeholders.
        blank = np.ones((img.height, img.width, 3), dtype=np.uint8) * 255
        for ax in (self.ax2, self.ax3, self.ax4):
            ax.imshow(blank)
        self.demo.def_input(img)
        self.ax1.figure.show()
        # All interaction state is set before the callbacks can fire.
        self.coords = []
        self.start = (0, 0)
        self.obj_arrows = []
        self.pressed = False
        # Results of a previous image are no longer valid for this one.
        self.flow = self.warped = self.rgb_gen = None
        self.connect()

    def connect(self):
        """Register the mouse and keyboard callbacks on the figure canvas."""
        canvas = self.ax1.figure.canvas
        self.cidpress = canvas.mpl_connect('button_press_event', self.on_press)
        self.cidkeypress = canvas.mpl_connect('key_press_event', self.on_key)

    def on_key(self, event):
        """Handle ``e`` (run), ``[`` (show input) and ``]`` (show generated image)."""
        if event.key == 'e':
            self.status.set_text("Running")
            self.ax2.figure.canvas.draw()
            self.flow, self.warped, self.rgb_gen = self.demo.run(self.coords)
            self.ax2.imshow(flowlib.flow_to_image(self.flow))
            self.ax3.imshow(self.warped)
            self.ax4.imshow(self.rgb_gen)
            # Status is updated before the redraw so that it is actually shown.
            self.status.set_text("Done")
            self.ax4.figure.canvas.draw()
        elif event.key == '[':
            self.ax4.imshow(self.img)
            self.status.set_text("Origin")
            self.ax4.figure.canvas.draw()
        elif event.key == ']':
            # Restore what `e` put into the last panel; nothing to show before a run.
            result = getattr(self, 'rgb_gen', None)
            if result is None:
                self.status.set_text("No result yet")
            else:
                self.ax4.imshow(result)
                self.status.set_text("GIE")
            self.ax4.figure.canvas.draw()

    def on_press(self, event):
        """Handle mouse clicks on the input panel (add or remove arrows)."""
        # Clicks outside the input panel carry no usable data coordinates.
        if event.inaxes is not self.ax1 or event.xdata is None or event.ydata is None:
            return
        if event.button == 1:
            if not self.pressed:
                # First click: remember and mark the arrow's start point.
                self.start = (event.xdata, event.ydata)
                self.pressed = True
                self.obj_start_point = self.ax1.plot(self.start[0], self.start[1], '+', color='r')
                self.status.set_text("start point added")
            else:
                # Second click: the arrow runs from the start point to this point.
                self.pressed = False
                dx = event.xdata - self.start[0]
                dy = event.ydata - self.start[1]
                curr_arr = self.ax1.arrow(self.start[0], self.start[1], dx, dy, head_width=5, head_length=10, color='r')
                self.obj_arrows.append(curr_arr)
                self.coords.append([self.start[0], self.start[1], dx, dy])
                self.obj_start_point.pop().remove()
                self.status.set_text("new arrow #{} added".format(len(self.coords)-1))
        elif event.button == 3:
            x, y = event.xdata, event.ydata
            if not self.pressed and len(self.coords) > 0:
                # Nearest start point by L1 distance.
                dist = [abs(x - cd[0]) + abs(y - cd[1]) for cd in self.coords]
                delidx = int(np.argmin(dist))
                self.obj_arrows[delidx].remove()
                del self.obj_arrows[delidx]
                del self.coords[delidx]
                self.status.set_text("arrow #{} removed".format(delidx))
        self.ax1.figure.canvas.draw()

    def disconnect(self):
        """Unregister the callbacks installed by ``connect``."""
        canvas = self.ax1.figure.canvas
        canvas.mpl_disconnect(self.cidpress)
        canvas.mpl_disconnect(self.cidkeypress)


# Build the interactive tool. Constructing Draw creates Demo(config, load_iter),
# which reads the experiment config and restores the checkpoint.
# NOTE(review): the output recorded for this cell is "KeyError: 'warp_mode'" —
# presumably a key missing from the experiment config; confirm before relying on `obj`.
obj = Draw()

KeyError: 'warp_mode'

In [None]:
# Pick the input frame: exactly one `fn` assignment is active, the commented-out
# lines are alternative sample frames.
# NOTE(review): these are absolute cluster paths; the cell fails wherever they do not exist.
fn = '/mnt/lustre/share/panxingang/VIP/group1/videos9/000000000801/000000000798.jpg'
#fn = '/mnt/lustre/share/panxingang/VIP/group1/videos9/000000001326/000000001334.jpg'
#fn = '/mnt/lustre/share/panxingang/VIP/group1/videos9/000000000901/000000000908.jpg'
#fn = '/mnt/lustre/share/panxingang/VIP/group1/videos21/000000001776/000000001775.jpg'
#fn = '/mnt/lustre/share/panxingang/VIP/group1/videos11/000000000101/000000000105.jpg'
#fn = '/mnt/lustre/share/panxingang/VIP/group1/videos16/000000001151/000000001143.jpg'
#fn = '/mnt/lustre/share/panxingang/VIP/group1/videos2/000000000976/000000000966.jpg'

#fn = '/mnt/lustre/share/panxingang/youtube9000/images/06/0364/002970.jpg'
#fn = '/mnt/lustre/share/panxingang/youtube9000/images/07/0343/003000.jpg'
#fn = '/mnt/lustre/share/panxingang/youtube9000/images/07/0355/002520.jpg'
#fn = '/mnt/lustre/share/panxingang/youtube9000/images/11/0355/005310.jpg'
#fn = '/mnt/lustre/share/panxingang/youtube9000/images/18/0479/001500.jpg'
#fn = '/mnt/lustre/share/panxingang/youtube9000/images/06/0196/003660.jpg'
#fn = '/mnt/lustre/share/panxingang/youtube9000/images/06/0348/003870.jpg'
#fn = '/mnt/lustre/share/panxingang/youtube9000/images/04/0411/001710.jpg'
# Force three-channel RGB, then open the interactive figure for this image.
img = Image.open(fn).convert("RGB")
obj.init_image(img)
# Call form with extra keyword arguments, kept for reference only: the
# init_image defined in this notebook accepts just `img` and `scale`.
#obj.init_image(img, warp=True, scale=[0.5], crf=False,
#               warp_mask_th=1, warp_copy=True, warp_interp=True, warp_interp_mode=0)

In [7]:
# Smallest pixel value of the last warped output; `obj.warped` only holds a result
# after the model has been run by pressing `e` in the figure.
obj.warped.min()

0.0