First, load the required libraries and select the compute device.

In [1]:
from scipy.spatial import Voronoi, voronoi_plot_2d
import numpy as np
import matplotlib.pyplot as plt
from primaldual_multi import PrimalDual
import torch
from utils import * 


def setDevice():
    """Pick the best available torch device and return it.

    Preference order: CUDA GPU, then Apple-silicon GPU (MPS), then CPU.
    On CUDA the default tensor type is also switched to CUDA floats, so newly
    created float tensors live on the GPU. Gradient tracking is (re-)enabled
    globally in every case.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():  # cuda gpus
        device = torch.device("cuda")
        #torch.cuda.set_device(int(gpu_id))
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    elif torch.backends.mps.is_available():  # mac gpus
        device = torch.device("mps")
    else:
        # "mkl" is not a torch device type (torch.device("mkl") raises a
        # RuntimeError). MKL only accelerates CPU kernels, so every remaining
        # case, with or without MKL, runs on the CPU device.
        device = torch.device("cpu")
    torch.set_grad_enabled(True)
    return device

# Detect the best available device (see setDevice) and activate utils.DeviceMode
# for it so that it acts as the default for the rest of the notebook.
# NOTE(review): the context manager is entered manually and never exited, so
# whatever DeviceMode configures (presumably a default-device override - TODO
# confirm in utils) stays active for the whole kernel session.
# `dev` is already a torch.device; wrapping it in torch.device(...) again just
# produces an equal device object.
dev = setDevice()
g = DeviceMode(torch.device(dev))
g.__enter__()

  from .autonotebook import tqdm as notebook_tqdm


<utils.DeviceMode at 0x1218cb2e0>

In [2]:
def run_primal_dual(f, repeats, level, lmbda, nu, tol=None):
    """Run the (non-scripted) PrimalDual solver once and return its output.

    The argument order mirrors how the scripted module is called in this
    notebook: ``model(f, repeats, level, lmbda, nu, tol)``.

    Args:
        f: input image tensor (the notebook builds it with shape (h, w, nc)).
        repeats, level, lmbda, nu: solver parameters, passed through unchanged.
        tol: stopping tolerance. Defaults to torch.tensor(1e-3), the value used
            when the module is scripted in this notebook.

    Returns:
        Whatever the solver's forward pass returns (the notebook stores it as ``u``).
    """
    if tol is None:
        tol = torch.tensor(1e-3)
    # PrimalDual is constructed with no arguments elsewhere in the notebook; the
    # inputs belong to its forward pass, invoked through the module call.
    model = PrimalDual()
    return model(f, repeats, level, lmbda, nu, tol)


In [17]:
# Compile the PrimalDual module into a TorchScript module.
# The example inputs are only used for profile-directed type inference, which
# needs MonkeyType; without it torch.jit.script ignores them (with a warning).
f = torch.randn(10, 10, 1)
repeats = torch.tensor(10)
level = torch.tensor(16)
lmbda = torch.tensor(1)
nu = torch.tensor(0.1)
tol = torch.tensor(1e-3)

# example_inputs is a list of argument *tuples*, one tuple per example call.
# A flat list of tensors would be unpacked tensor-by-tensor into the module.
scripted_primal_dual = torch.jit.script(
    PrimalDual(), example_inputs=[(f, repeats, level, lmbda, nu, tol)]
)

In [3]:
f.ndim  # rank of the input tensor (expected: 3 -> (h, w, nc))

3

In [12]:
# Serialize the TorchScript module so later sessions can skip scripting.
scripted_primal_dual.save('scripted_primal_dual.pt')

In [19]:
# Reload the serialized TorchScript solver (tensors land on the device they were saved from).
scripted_primal_dual = torch.jit.load("scripted_primal_dual.pt", map_location=None)

In [10]:
# Move the solver's parameters and buffers to the selected compute device.
scripted_primal_dual = scripted_primal_dual.to(device=dev)


In [20]:
import cv2

#----- parameters
image = "marylin.png"  # input file; the name must match the file on disk exactly
gray = True            # True: read as single-channel grayscale, False: 3-channel color
#-------
 


def convert_interleaved_to_layered(aOut, aIn, w, h, nc):
    """Convert interleaved pixel data to a layered (planar) buffer, reversing channel order.

    Element mapping (same as a per-pixel loop):
        aOut[x + w*y + w*h*c] = aIn[(nc-1-c) + nc*(x + w*y)]
    The input stores the nc channels of each pixel contiguously. The output stores
    one full w*h plane per channel, and output channel c comes from input channel
    nc-1-c (e.g. OpenCV BGR pixels -> RGB planes).

    Args:
        aOut: destination with at least w*h*nc elements. A numpy array of any
            shape (e.g. an (h, w, nc) array) is filled in C order; any other
            mutable sequence (e.g. a list) is filled by slice assignment.
        aIn: interleaved source values with at least w*h*nc elements (flat or any shape).
        w, h: image width and height in pixels.
        nc: number of channels.

    Returns:
        aOut after filling it. For nc == 1 both layouts coincide and aIn is
        returned unchanged (aOut is not written).
    """
    if nc == 1:
        return aIn
    n = w * h * nc
    pixels = np.asarray(aIn).reshape(-1)[:n].reshape(h, w, nc)
    # (h, w, nc) -> reversed channels -> (nc, h, w) planes -> flat layered order
    layered = pixels[:, :, ::-1].transpose(2, 0, 1).reshape(-1)
    if isinstance(aOut, np.ndarray):
        # .flat writes through in C order even for multi-dimensional or
        # non-contiguous buffers (plain 1-D indexing would address whole rows).
        aOut.flat[:n] = layered
    else:
        aOut[:n] = layered.tolist()
    return aOut

def convert_mat_to_layered(aOut, mIn, nc, w, h):
    """Image-matrix front end for convert_interleaved_to_layered.

    Takes the size arguments in (nc, w, h) order and forwards everything
    unchanged, returning the converter's result as-is.
    """
    return convert_interleaved_to_layered(aOut=aOut, aIn=mIn, w=w, h=h, nc=nc)



image = "marylin.png"
mIn = cv2.imread(image, (0 if gray else 1))
if mIn is None:
    # cv2.imread signals a missing/unreadable file by returning None, not by raising.
    raise FileNotFoundError(f"cv2.imread could not read {image!r}")
mIn = mIn.astype(np.float32)
mIn /= 255                                   # scale 8-bit intensities to [0, 1]
h, w = mIn.shape[:2]                         # height, width
nc = mIn.shape[2] if mIn.ndim == 3 else 1    # number of channels
# OpenCV returns pixels interleaved as (h, w) or (h, w, nc), and the solver input
# `f` is an (h, w, nc) tensor, so a reshape is all that is needed (no layered
# conversion). Color channels stay in OpenCV's BGR order, which is also the
# order cv2.imwrite expects when the result is written back.
h_img = mIn.reshape(h, w, nc)
f = torch.as_tensor(h_img).detach().clone()

# solver parameters
repeats = torch.tensor(1000)
level = torch.tensor(16)
lmbda = torch.tensor(1)
nu = torch.tensor(0.1)
tol = torch.tensor(5e-5)

u = scripted_primal_dual(f, repeats, level, lmbda, nu, tol)

In [21]:
def interpolate(k, uk0, uk1, l):
    """Sub-level position of the 0.5 crossing between two samples, divided by ``l``.

    ``fraction`` is where the straight line from ``uk0`` to ``uk1`` reaches 0.5
    (0 at ``uk0``, 1 at ``uk1``). The result is ``(k + fraction) / l``. Works
    element-wise on numpy arrays as well as on scalars.
    """
    fraction = (0.5 - uk0) / (uk1 - uk0)
    return (k + fraction) / l

def isosurface(u, l, h, w, nc, fill_value=None):
    """Back out an image from a stack of relaxed level-set values via the 0.5-isosurface.

    For every entry along the last axis of ``u``, find the first index i with
    ``u[..., i] > 0.5`` and ``u[..., i + 1] <= 0.5``. The crossing is located by
    linear interpolation between those two samples, and ``interpolate`` is
    called with k = i + 1.

    Args:
        u: rank-4 torch tensor whose last axis is the level axis scanned for the
            crossing. Presumably (h, w, nc, levels) - TODO confirm against PrimalDual.
        l: divisor applied by ``interpolate`` (the call site passes ``level``).
        h, w, nc: shape of the returned image.
        fill_value: value for entries whose stack never crosses 0.5. If None
            (default), such entries raise ValueError instead.

    Returns:
        float64 numpy array of shape (h, w, nc).

    Raises:
        ValueError: if some entries have no 0.5 crossing and ``fill_value`` is None.
    """
    u = u.detach().cpu().numpy()
    lower = u[:, :, :, :-1]
    upper = u[:, :, :, 1:]
    crossing = (lower > 0.5) & (upper <= 0.5)
    has_crossing = crossing.any(axis=-1)
    if fill_value is None and not has_crossing.all():
        missing = int(has_crossing.size - np.count_nonzero(has_crossing))
        raise ValueError(
            f"isosurface: {missing} of {has_crossing.size} entries never cross 0.5; "
            "pass fill_value to assign them a value"
        )

    # argmax on a boolean array returns the index of the first True, i.e. the first crossing
    first = crossing.argmax(axis=-1)[..., np.newaxis]
    uk0 = np.take_along_axis(lower, first, axis=-1)[..., 0][has_crossing]
    uk1 = np.take_along_axis(upper, first, axis=-1)[..., 0][has_crossing]
    k = first[..., 0][has_crossing] + 1

    h_img = np.full(has_crossing.shape, 0.0 if fill_value is None else fill_value, dtype=np.float64)
    h_img[has_crossing] = interpolate(k, uk0, uk1, l)
    return h_img.reshape(h, w, nc)

        
# Back out the estimated image from the superlevel sets using the 0.5-isosurface.
h_img = isosurface(u, int(level), int(h), int(w), int(nc))

# Save the result as an 8-bit image. The explicit clip/round/uint8 conversion is
# used instead of relying on cv2.imwrite's implicit conversion of a float array.
cv2.imwrite("result.png", np.clip(np.rint(h_img * 255), 0, 255).astype(np.uint8))

True

In [15]:
# Quick look at the raw 8-bit input as loaded by OpenCV.
image = "marylin.png"
cv2.imread(image, cv2.IMREAD_GRAYSCALE if gray else cv2.IMREAD_COLOR)

array([[243, 243, 241, ..., 241, 239, 241],
       [242, 242, 241, ..., 242, 241, 241],
       [244, 242, 243, ..., 239, 240, 241],
       ...,
       [224, 223, 221, ..., 229, 228, 229],
       [227, 223, 222, ..., 230, 231, 230],
       [225, 224, 220, ..., 231, 231, 231]], dtype=uint8)

In [8]:
f.size()  # (h, w, nc) of the solver input

torch.Size([127, 127, 1])