In [2]:
import torch
from PIL import Image
from torchvision import transforms
from timeit import default_timer as timer

This is a PyTorch implementation of ARAP Image Deformation, following the version implemented in Opt/Thallo. The Thallo implementation is reproduced below for reference.

```lua
local W,H = Dims("W","H")
Inputs {
	Offset = Unknown(float2,{W,H},0),
	Angle = Unknown(float,{W,H},1),	
	UrShape = Array(float2,{W,H},2), --original mesh position
	Constraints = Array(float2,{W,H},3), -- user constraints
	Mask = Array(float, {W,H},4), -- validity mask for mesh
	w_fitSqrt = Param(float, 5),
	w_regSqrt = Param(float, 6)
}

local x,y = W(),H()
Offset:Exclude(Not(eq(Mask(x,y),0)))
Angle:Exclude(Not(eq(Mask(x,y),0)))

local regs = {}
for dx,dy in Stencil { {1,0}, {-1,0}, {0,1}, {0, -1} } do
    local e_reg = w_regSqrt*((Offset(x,y) - Offset(x+dx,y+dy)) 
                             - Rotate2D(Angle(x,y),(UrShape(x,y) - UrShape(x+dx,y+dy))))
    local valid = InBounds(x+dx,y+dy) * eq(Mask(x,y),0) * eq(Mask(x+dx,y+dy),0)
    regs[#regs+1] = Select(valid,e_reg,0)
end
local e_fit = (Offset(x,y) - Constraints(x,y))
local valid = All(greatereq(Constraints(x,y),0))*eq(Mask(x,y),0)
r = Residuals {
    reg = regs,
    fit = w_fitSqrt*Select(valid, e_fit , 0.0)
}
```

In [3]:
# Open the mask image; PIL reports its size as (width, height).
mask_path = "../data/mask0.png"
mask_img = Image.open(mask_path)
print(mask_img.size)
 
# Convert the PIL image to a tensor. The recorded output for the (218, 300)
# image is torch.Size([1, 300, 218]), i.e. the tensor is laid out (C, H, W).
Mask = transforms.ToTensor()(mask_img)
print(Mask.shape) 

(218, 300)
torch.Size([1, 300, 218])


In [4]:
# Device passed to the tensor constructors and benchmark helpers in later cells.
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

In [5]:
W, H = mask_img.size  # local W,H = Dims("W","H"); PIL size is (width, height)

In [59]:
# Unknowns of the solve, stacked into one leaf tensor so autograd sees a single variable:
#   Offset = Unknown(float2,{W,H},0)  -> Offset_And_Angle[:2:]
#   Angle  = Unknown(float,{W,H},1)   -> Offset_And_Angle[2::]
Offset_And_Angle = torch.zeros((3, H, W), dtype=torch.float32, device=device, requires_grad=True)

# Fixed inputs, mirroring the Thallo declarations:
#   UrShape     = Array(float2,{W,H},2) -- original mesh position
#   Constraints = Array(float2,{W,H},3) -- user constraints
#   Mask        = Array(float, {W,H},4) -- validity mask for mesh
UrShape     = torch.zeros((2, H, W), dtype=torch.float32, device=device)
Constraints = torch.zeros((2, H, W), dtype=torch.float32, device=device)
Mask        = torch.zeros((1, H, W), dtype=torch.float32, device=device)

# Scalar weights of the two energy terms:
#   w_fitSqrt = Param(float, 5)
#   w_regSqrt = Param(float, 6)
w_fitSqrt = 10.0
w_regSqrt = 0.1

In [42]:
# ToTensor builds the tensor from the PIL image without any device argument, so
# move it explicitly onto `device`: the residuals combine Mask with tensors that
# were allocated there, and mixing devices fails once `device` is a GPU.
Mask = transforms.ToTensor()(mask_img).to(device)
#TODO: data import here

Fit residual:
```lua
local e_fit = (Offset(x,y) - Constraints(x,y))
local valid = All(greatereq(Constraints(x,y),0))*eq(Mask(x,y),0)
fit = w_fitSqrt*Select(valid, e_fit , 0.0)
```

In [43]:
def fit_residual(Offset_And_Angle):
    """Fit term of the ARAP energy: pull each offset towards its user constraint.

    Mirrors the Thallo reference
        e_fit = Offset(x,y) - Constraints(x,y)
        valid = All(greatereq(Constraints(x,y),0)) * eq(Mask(x,y),0)
        fit   = w_fitSqrt * Select(valid, e_fit, 0.0)

    Args:
        Offset_And_Angle: stacked unknowns; the first two slices along dim 0
            are the offset, anything after that is ignored here.

    Returns:
        The element-wise square of the weighted, masked fit residual, with the
        same shape as the offset slice.
        NOTE(review): the Thallo residual is unsquared; the squaring is kept
        because the rest of the notebook sums this output as the cost.

    Reads the module-level ``Constraints``, ``Mask`` and ``w_fitSqrt``.
    """
    Offset = Offset_And_Angle[:2]
    e_fit = Offset - Constraints
    # A pixel is constrained only when *all* constraint components are >= 0
    # (the reference uses All(greatereq(..., 0))), and the mask marks it valid.
    constrained = (Constraints >= 0).all(dim=0, keepdim=True)
    valid = constrained & (Mask == 0.0)
    fit = w_fitSqrt * valid * e_fit
    return fit.square()

In [68]:
def vjp(f, x, v, create_graph=False):
    """Vector-Jacobian product: return J(x)^T v for y = f(x).

    Args:
        f: callable mapping a tensor to a tensor.
        x: point at which to differentiate; it is detached first, so the
            caller's tensor (and its ``.grad``) is never touched.
        v: cotangent with the shape of ``f(x)``.
        create_graph: when True the result stays differentiable (with respect
            to ``v`` in particular), which the double-vjp trick relies on.

    Returns:
        A tensor shaped like ``x``, or None if ``f(x)`` does not depend on ``x``.
    """
    x = x.detach().requires_grad_()
    y = f(x)
    # torch.autograd.grad instead of y.backward(): backward(create_graph=True)
    # creates a parameter<->gradient reference cycle (PyTorch warns about the
    # resulting leak) and also accumulates into every other leaf in the graph.
    (grad_x,) = torch.autograd.grad(
        y, x, grad_outputs=v, create_graph=create_graph, allow_unused=True
    )
    return grad_x
"""
def alternative_Rop(f, x, u):
    v = f.type('v')       # Dummy variable v of same type as f
    g = T.Lop(f, x, v)    # Jacobian of f left multiplied by v
    return T.Lop(g, v, u)
"""
def jvp(f, x, u, v=False, create_graph=False):
    """Jacobian-vector product J(x) u, computed with two reverse-mode passes.

    The map ``c -> vjp(f, x, c)`` is linear in the cotangent ``c`` and its
    Jacobian is J^T, so taking the vjp of that map with ``u`` yields J u.

    Args:
        f: callable mapping a tensor to a tensor.
        x: point at which to differentiate.
        u: tangent vector with the shape of ``x``.
        v: optional dummy cotangent shaped like ``f(x)``; pass a tensor to
            avoid the extra evaluation of ``f`` needed to build one.
        create_graph: forwarded to the outer vjp.

    Returns:
        A tensor shaped like ``f(x)``.
    """
    if not torch.is_tensor(v):
        # ones_like inherits dtype and device from f(x); forcing the global
        # device here would break whenever x lives on a different device.
        v = torch.ones_like(f(x), requires_grad=True) # TODO: just use dimensions of result, don't compute f(x)

    def pullback(cotangent):
        # create_graph=True keeps J^T c differentiable with respect to c.
        return vjp(f, x, cotangent, create_graph=True)

    return vjp(pullback, v, u, create_graph=create_graph)

def gn_step(f,x,p):
    """Return the Gauss-Newton product J^T (J p) of ``f`` at ``x``.

    Args:
        f: residual function.
        x: point at which the Jacobian J is taken.
        p: direction with the shape of ``x``.

    Returns:
        A tensor shaped like ``x``. Previously the product was computed and
        then dropped, so callers always received None.
    """
    return vjp(f,x,jvp(f,x,p))
    
def gn_benchmark(f,x,iterations=1000):
    """Average wall-clock seconds for one J^T (J p) evaluation of ``f`` at ``x``.

    Args:
        f: residual function.
        x: point at which to evaluate; helper tensors are created on its device.
        iterations: number of timed repetitions.

    Returns:
        Elapsed seconds divided by ``iterations``.
    """
    p = torch.ones_like(x)
    # Built once so jvp does not re-evaluate f for its dummy cotangent each time.
    dummy = torch.ones_like(f(x), requires_grad=True)
    on_gpu = x.is_cuda
    # CUDA kernels launch asynchronously; without synchronizing, the timer
    # would measure launch overhead instead of the actual work.
    if on_gpu:
        torch.cuda.synchronize()
    start = timer()
    for _ in range(iterations):
        vjp(f,x,jvp(f,x,p,dummy))
    if on_gpu:
        torch.cuda.synchronize()
    end = timer()
    return (end-start) / iterations

def gradient_benchmark(f,x,iterations=1000):
    """Average wall-clock seconds for one reverse-mode pass through ``f`` at ``x``.

    The forward pass runs once, outside the timed region; each iteration then
    back-propagates an all-ones cotangent through the retained graph.

    Args:
        f: residual function.
        x: point at which to evaluate. A detached copy is differentiated, so
            the caller's ``x.grad`` is left untouched and ``x`` does not need
            ``requires_grad``.
        iterations: number of timed repetitions.

    Returns:
        Elapsed seconds divided by ``iterations``.
    """
    probe = x.detach().requires_grad_()
    y = f(probe)
    seed = torch.ones_like(y)
    on_gpu = probe.is_cuda
    # CUDA kernels launch asynchronously; synchronize so the timer brackets
    # the actual work rather than just the launches.
    if on_gpu:
        torch.cuda.synchronize()
    start = timer()
    for _ in range(iterations):
        # autograd.grad returns the gradient without accumulating into .grad.
        torch.autograd.grad(y, probe, seed, retain_graph=True)
    if on_gpu:
        torch.cuda.synchronize()
    end = timer()
    return (end-start) / iterations

def residual_benchmark(f,x,iterations=1000):
    """Average wall-clock seconds for one evaluation of ``f``.

    Args:
        f: residual function.
        x: template for the input; ``f`` is called on an all-ones tensor with
            the same shape, dtype and device, without gradient tracking.
        iterations: number of timed repetitions.

    Returns:
        Elapsed seconds divided by ``iterations``.
    """
    x_0 = torch.ones_like(x, requires_grad=False)
    on_gpu = x_0.is_cuda
    # CUDA kernels launch asynchronously; synchronize so the timer brackets
    # the actual work rather than just the launches.
    if on_gpu:
        torch.cuda.synchronize()
    start = timer()
    for _ in range(iterations):
        f(x_0)
    if on_gpu:
        torch.cuda.synchronize()
    end = timer()
    return (end-start) / iterations

def cost_benchmark(f,x,iterations=1000):
    """Average wall-clock seconds for evaluating ``f`` and summing its output.

    Args:
        f: residual function.
        x: template for the input; ``f`` is called on an all-ones tensor with
            the same shape, dtype and device, without gradient tracking.
        iterations: number of timed repetitions.

    Returns:
        Elapsed seconds divided by ``iterations``.
    """
    x_0 = torch.ones_like(x, requires_grad=False)
    on_gpu = x_0.is_cuda
    # CUDA kernels launch asynchronously; synchronize so the timer brackets
    # the actual work rather than just the launches.
    if on_gpu:
        torch.cuda.synchronize()
    start = timer()
    for _ in range(iterations):
        f(x_0).sum()
    if on_gpu:
        torch.cuda.synchronize()
    end = timer()
    return (end-start) / iterations

def all_benchmarks(f,x,iterations=10000):
    """Run the four timing helpers on ``f``/``x`` and print each result in milliseconds.

    The benchmarks run in the order residuals, cost, gradient, J^T J p, and each
    line is printed as soon as its measurement finishes.
    """
    suite = (
        ("Residuals: ", residual_benchmark),
        ("Cost:      ", cost_benchmark),
        ("Gradient:  ", gradient_benchmark),
        ("JtJp:      ", gn_benchmark),
    )
    for label, bench in suite:
        seconds = bench(f, x, iterations)
        print(label + str(seconds * 1000.0) + "ms")

In [54]:
# Time the fit residual on its own, with 3000 iterations per measurement.
all_benchmarks(fit_residual,Offset_And_Angle,3000)

Residuals: 0.26379225566688547ms
Cost:      0.29175251400010893ms
torch.Size([3, 300, 218])
torch.Size([2, 300, 218])
Gradient:  0.3642981050000647ms
JtJp:      2.8534251909998907ms


In [71]:
"""
local regs = {}
for dx,dy in Stencil { {1,0}, {-1,0}, {0,1}, {0, -1} } do
    local e_reg = w_regSqrt*((Offset(x,y) - Offset(x+dx,y+dy)) 
                             - Rotate2D(Angle(x,y),(UrShape(x,y) - UrShape(x+dx,y+dy))))
    local valid = InBounds(x+dx,y+dy) * eq(Mask(x,y),0) * eq(Mask(x+dx,y+dy),0)
    regs[#regs+1] = Select(valid,e_reg,0)
end
local e_fit = (Offset(x,y) - Constraints(x,y))
local valid = All(greatereq(Constraints(x,y),0))*eq(Mask(x,y),0)
r = Residuals {
    reg = regs,
    fit = w_fitSqrt*Select(valid, e_fit , 0.0)
}
"""

"""
function L.Rotate2D(angle, v)
    local CosAlpha, SinAlpha = ad.cos(angle), ad.sin(angle)
    local matrix = ad.Vector(CosAlpha, -SinAlpha, SinAlpha, CosAlpha)
    return ad.Vector(matrix(0)*v(0)+matrix(1)*v(1), matrix(2)*v(0)+matrix(3)*v(1))
end

"""
def Rotate2D(alpha,v):
    """Rotate the 2-vector field ``v`` by the angle field ``alpha``.

    Port of Thallo's L.Rotate2D, which applies the matrix
        [cos(a) -sin(a)]
        [sin(a)  cos(a)]
    to (v(0), v(1)).

    Args:
        alpha: tensor of angles, broadcast against the component slices of ``v``.
        v: tensor whose first slice along dim 0 is the x component and whose
            remaining slices are treated as the y component.

    Returns:
        The rotated x and y components concatenated along dim 0.
    """
    c, s = torch.cos(alpha), torch.sin(alpha)
    vx, vy = v[:1], v[1:]
    #TODO: Find a way to make this one operation?
    rotated_x = c * vx - s * vy
    rotated_y = s * vx + c * vy
    return torch.cat((rotated_x, rotated_y), dim=0)
    

def reg_residual(Offset_And_Angle,offsets):
    """ARAP regularization term for one neighbour direction.

    Mirrors the Thallo reference
        e_reg = w_regSqrt*((Offset(x,y) - Offset(x+dx,y+dy))
                           - Rotate2D(Angle(x,y),(UrShape(x,y) - UrShape(x+dx,y+dy))))
        valid = InBounds(x+dx,y+dy) * eq(Mask(x,y),0) * eq(Mask(x+dx,y+dy),0)

    Args:
        Offset_And_Angle: stacked unknowns; the first two slices along dim 0
            are the offset, the rest is the angle.
        offsets: pair of shifts applied along dims (1, 2) to fetch the
            neighbour values.

    Returns:
        The element-wise square of the weighted, masked residual, shaped like
        the offset slice.
        NOTE(review): the Thallo residual is unsquared; the squaring is kept
        because the rest of the notebook sums this output as the cost.

    Reads the module-level ``UrShape``, ``Mask`` and ``w_regSqrt``.
    """
    Offset = Offset_And_Angle[:2]
    Angle = Offset_And_Angle[2:]
    spatial_dims = (1, 2)
    #anything cheaper than roll available?
    # roll wraps around at the borders, so border pixels are paired with the
    # opposite edge rather than being rejected.
    Off01 = Offset.roll(shifts=offsets, dims=spatial_dims)
    UrShape01 = UrShape.roll(shifts=offsets, dims=spatial_dims)
    Mask01 = Mask.roll(shifts=offsets, dims=spatial_dims)

    # The weight multiplies the whole difference, as in the reference; the
    # earlier version scaled only the offset term and left the rotated
    # rest-shape edge unweighted.
    e_reg = w_regSqrt * ((Offset - Off01) - Rotate2D(Angle, UrShape - UrShape01))

    # TODO: For now, we assume InBounds(x+dx,y+dy) is redundant with the mask
    valid = (Mask == 0.0) & (Mask01 == 0.0)
    return (e_reg * valid).square()


def make_reg_residual(offsets):
    """Return a one-argument residual function with ``offsets`` bound for reg_residual."""
    def residual(x):
        return reg_residual(x, offsets)
    return residual

# One regularization residual per neighbour shift, followed by the fit residual.
residuals = [make_reg_residual(shift) for shift in ((0,1),(0,-1),(1,0),(-1,0))]
residuals.append(fit_residual)

def all_residuals(x):
    """Evaluate every entry of ``residuals`` on ``x`` and concatenate the results along dim 0."""
    return torch.cat([term(x) for term in residuals])

In [72]:
# Time the concatenated regularization + fit residuals, with 1000 iterations per measurement.
all_benchmarks(all_residuals,Offset_And_Angle,1000)

Residuals: 5.779629956000463ms
Cost:      5.7800542519999ms
Gradient:  6.241290219999428ms
JtJp:      55.89417330299966ms
