In [None]:
from IPython.display import display, HTML
# Widen the notebook's main container to the full browser width.
# (Targets the `.container` CSS class used by the classic Notebook UI.)
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
import skimage as ski
import scipy.sparse
import time
from sklearn.cluster import KMeans

In [None]:
def forward_differences(M,N):
    """Build the sparse forward-difference (discrete gradient) operator.

    Pixels of an M x N image are numbered in row-major order, i.e. pixel
    (r, c) has flat index r*N + c.  The operator maps a flattened image x of
    length M*N to the stacked vector [Dx x; Dy x] of length 2*M*N, where

    * Dx x at (r, c) = x[r, c+1] - x[r, c], and 0 in the last column;
    * Dy x at (r, c) = x[r+1, c] - x[r, c], and 0 in the last row.

    The zero differences on the border arise because the "neighbour" of a
    last-column / last-row pixel is the pixel itself, so +1 and -1 cancel.

    Parameters
    ----------
    M, N : int
        Image height (rows) and width (columns).

    Returns
    -------
    scipy.sparse matrix of shape (2*M*N, M*N)
    """
    size = M*N
    ones = np.ones(size)
    rows = np.arange(size)
    grid = rows.reshape(M, N)

    # flat index of the right / lower neighbour, replicating the last column / row
    right = np.concatenate([grid[:, 1:], grid[:, -1:]], axis=1)
    below = np.concatenate([grid[1:, :], grid[-1:, :]], axis=0)

    def shifted_minus_identity(neighbour):
        # +1 in the neighbour's column, -1 on the diagonal
        plus = scipy.sparse.coo_matrix((ones, (rows, neighbour.ravel())), shape=(size, size))
        minus = scipy.sparse.coo_matrix((ones, (rows, grid.ravel())), shape=(size, size))
        return plus - minus

    return scipy.sparse.vstack([shifted_minus_identity(right),
                                shifted_minus_identity(below)])

In [None]:
def proj_inf_l2(p, tau):
    """Project each column of ``p`` onto the Euclidean ball of radius ``tau``.

    ``p`` has shape (K, N): the l2 norm is taken over the first axis (the K
    components of one vector), and the constraint is imposed independently
    on each of the N columns.  This is the projection onto
    {p : max_n ||p[:, n]||_2 <= tau}.

    A new array is returned; the caller's array is not modified, and
    integer input is accepted.

    Parameters
    ----------
    p : array_like, shape (K, N)
    tau : float
        Ball radius (must be positive).

    Returns
    -------
    ndarray, shape (K, N)
    """
    p = np.asarray(p)
    norm_p = np.sqrt(np.sum(p**2, axis=0, keepdims=True))
    # columns already inside the ball are divided by 1, i.e. left unchanged
    return p / np.maximum(1, norm_p/tau)

def proj_inf(p, tau):
    """Clamp every entry of ``p`` to [-tau, tau].

    This is the Euclidean projection onto the l-infinity ball of radius tau.
    """
    lower_clamped = np.maximum(p, -tau)
    return np.minimum(lower_clamped, tau)

def proj_simplex(x):
    """Euclidean projection onto the unit simplex along the first axis.

    Every slice x[:, j, ...] is mapped to the closest point of
    {y : y >= 0, sum(y) = 1}.  Code adapted from Laurent Condat: after
    sorting each slice in decreasing order, the shift is
    mu = max_k (sum of the k largest entries - 1) / k, and the projection is
    max(x - mu, 0).

    Parameters
    ----------
    x : ndarray, shape (K, ...)

    Returns
    -------
    ndarray with the same shape as ``x``.
    """
    shape = x.shape
    n_labels = shape[0]
    flat = x.reshape(n_labels, -1)
    descending = np.flip(np.sort(flat, axis=0), axis=0)
    counts = np.arange(1, n_labels + 1, dtype=float)[:, np.newaxis]
    thresholds = (np.cumsum(descending, axis=0) - 1.0) / counts
    mu = thresholds.max(axis=0, keepdims=True)
    return np.maximum(0.0, flat - mu).reshape(shape)

In [None]:
def potts_simple_pd(f, maxit=1000, check=100, verbose=0):
    """Solve the simple convex relaxation (Zach) of the Potts model.

    Minimises  sum_i <f_i, u_i> + sup_p sum_i <D u_i, p_i>  with |p_i| <= 1/2
    pointwise, i.e. one half of the total variation of every indicator u_i.
    The u are constrained to the unit simplex at each pixel.  Uses a
    first-order primal-dual scheme with diagonally preconditioned step sizes.

    Parameters
    ----------
    f : ndarray, shape (K, M, N)
        Data term; f[i] is the per-pixel cost of label i.
    maxit : int
        Number of primal-dual iterations.
    check : int
        When ``verbose > 0``, progress is printed every ``check`` iterations;
        a non-positive value disables the progress line.
    verbose : int
        Print progress (and a final newline) if > 0.

    Returns
    -------
    ndarray, shape (K, M, N)
        Relaxed label indicator functions, projected onto the unit simplex in
        every iteration.
    """
    K,M,N = f.shape
    f = f.reshape(K,M*N)

    # primal variable: one flattened indicator function per label
    u = np.zeros((K,M*N))

    # dual variable: per label, the stacked x- and y-components of a vector field
    p = np.zeros((K,2*M*N))

    # make nabla operator; its transpose is reused K times per iteration
    FD = forward_differences(M,N)
    FDt = FD.T

    # Primal and dual step sizes from diagonal preconditioning:
    # each row of FD has at most 2 nonzeros, each column at most 4.
    fact = 1.0
    tau_u = fact*1.0/4.0
    sigma_p = 1.0/(2.0)/fact
    theta = 1.0

    show_progress = verbose > 0 and check > 0
    t0 = time.time()
    for it in range(maxit):

        # remember the previous iterate for the over-relaxation step
        u_old = u.copy()

        # primal gradient step, then projection of u onto the simplex
        for i in range(K):
            u[i] -= tau_u*(FDt@p[i] + f[i])
        u = proj_simplex(u)

        # over-relaxation
        u_bar = u + theta*(u-u_old)

        # dual ascent step, then project every 2-vector onto the ball of radius 0.5
        for i in range(K):
            p[i] += sigma_p*(FD@u_bar[i])
            p[i] = proj_inf_l2(p[i].reshape(2,M*N), 0.5).reshape(2*M*N)

        if show_progress and it%check == check-1:
            # dual estimate of the regulariser: sum_i <D u_i, p_i>
            TV = sum(np.sum((FD@u[i])*p[i]) for i in range(K))

            print("iter = ", it,
                  ", time = ", "{:.3f}".format(time.time()-t0),
                  ", TV = ", "{:.3f}".format(TV),
                  end="\r")
    if verbose > 0:
        # terminate the carriage-return progress line
        print("\n")
    return u.reshape(K,M,N)

In [None]:

def potts_tight_pd(f, maxit=1000, check=100, verbose=0):
    """Solve the tight convex relaxation of the Potts model (CCP 2008).

    Minimises  sum_i <f_i, u_i> + sup_p sum_i <D u_i, p_i>  over per-pixel
    simplex-valued u.  The dual fields must satisfy |p_i - p_j| <= 1
    (pointwise l2 norm) for every label pair i < j.  These pair constraints
    are split off as q_ij = p_i - p_j.  Each q_ij is projected onto the unit
    ball, and the equality is enforced by a multiplier v_ij that is updated
    alongside u as an extra primal variable.  Iterations follow a
    diagonally preconditioned first-order primal-dual scheme.

    Parameters
    ----------
    f : ndarray, shape (K, M, N)
        Data term; f[i] is the per-pixel cost of label i.
    maxit : int
        Number of primal-dual iterations.
    check : int
        When ``verbose > 0``, progress is printed every ``check`` iterations;
        a non-positive value disables the progress line.
    verbose : int
        Print progress (and a final newline) if > 0.

    Returns
    -------
    ndarray, shape (K, M, N)
        Relaxed label indicator functions, projected onto the unit simplex in
        every iteration.
    """
    K,M,N = f.shape
    f = f.reshape(K,M*N)

    # label pairs (i, j) with i < j; the position of a pair indexes v and q
    pairs = [(i, j) for i in range(K-1) for j in range(i+1, K)]
    # number of pair constraints, K*(K-1)/2, as a plain int
    KK = len(pairs)

    # primal variables
    u = np.zeros((K,M*N))
    v = np.zeros((KK,2*M*N))

    # dual variables
    p = np.zeros((K,2*M*N))
    q = np.zeros((KK,2*M*N))

    # make nabla operator; its transpose is reused K times per iteration
    FD = forward_differences(M,N)
    FDt = FD.T

    # Primal and dual step sizes from diagonal preconditioning:
    # a u-column has up to 4 entries of FD; a v_ij-column touches p_i, p_j
    # and q_ij (3). A p-row has 2 entries of FD plus K-1 pair couplings,
    # and a q-row has 1.
    fact = 1.0
    tau_u = fact*1.0/4.0
    tau_v = fact*1.0/3.0
    sigma_p = 1.0/(2.0+K-1)/fact
    sigma_q = 1.0/fact
    theta = 1.0

    show_progress = verbose > 0 and check > 0
    t0 = time.time()
    for it in range(maxit):

        # remember the previous iterates for the over-relaxation step
        u_old = u.copy()
        v_old = v.copy()

        # primal update
        for i in range(K):
            u[i] -= tau_u*(FDt@p[i] + f[i])
        for idx, (i, j) in enumerate(pairs):
            v[idx] -= tau_v*(p[i]-p[j]-q[idx])

        # projection of u onto simplex
        u = proj_simplex(u)

        # over-relaxation
        u_bar = u + theta*(u-u_old)
        v_bar = v + theta*(v-v_old)

        # dual update
        for i in range(K):
            p[i] += sigma_p*(FD@u_bar[i])
        for idx, (i, j) in enumerate(pairs):
            p[i] += sigma_p*v_bar[idx]
            p[j] -= sigma_p*v_bar[idx]
            q[idx] -= sigma_q*v_bar[idx]

        # project every 2-vector of each q_ij onto the unit l2 ball
        for idx in range(KK):
            q[idx] = proj_inf_l2(q[idx].reshape(2,M*N), 1.0).reshape(2*M*N)

        if show_progress and it%check == check-1:
            # dual estimate of the regulariser: sum_i <D u_i, p_i>
            TV = sum(np.sum((FD@u[i])*p[i]) for i in range(K))

            print("iter = ", it,
                  ", time = ", "{:.3f}".format(time.time()-t0),
                  ", TV = ", "{:.3f}".format(TV),
                  end="\r")
    if verbose > 0:
        # terminate the carriage-return progress line
        print("\n")
    return u.reshape(K,M,N)

In [None]:
def make_rgb(u,means):
    """Compose a colour image from label indicator functions.

    Each pixel gets the sum over labels of u[label] times that label's
    colour, so a hard labelling reproduces the label colours exactly and a
    relaxed one blends them.

    Parameters
    ----------
    u : ndarray, shape (K, M, N)
        Per-label weights.
    means : indexable of K colours
        means[label][0..2] are the three channel values of that label.

    Returns
    -------
    ndarray, shape (M, N, 3)
    """
    n_labels, height, width = u.shape
    rgb = np.zeros((height, width, 3))
    for label in range(n_labels):
        colour = means[label]
        for channel in range(3):
            rgb[:, :, channel] += u[label]*colour[channel]
    return rgb

In [None]:
def compute_squared_dist(g, means):
    """Data term: half squared Euclidean distance of every pixel to each mean.

    Parameters
    ----------
    g : ndarray, shape (M, N, C)
        Image with C channels.
    means : ndarray, shape (K, C)
        One colour per label.

    Returns
    -------
    ndarray, shape (K, M, N)
        f[i, m, n] = 0.5 * ||g[m, n] - means[i]||^2
    """
    M,N,_ = g.shape
    K = means.shape[0]
    # Flatten the image passed in; do not rely on a notebook-global `pixels`.
    pixels = g.reshape(M*N, -1)
    f = np.zeros((K, M,N))
    for i in range(K):
        f[i] = np.sum(0.5*(pixels-means[i])**2, axis=1).reshape(M,N)
    return f

In [None]:
# Load image
g = np.double(ski.io.imread("col3.png"))
M,N,_ = g.shape
pixels = g.reshape(-1,3)

# number of phases
K = 3

# mean values
means_rgb = np.array([[255., 0.,0.],[0.,255,0.],[0.,0.,255.]])

# compute field
f = compute_squared_dist(g, means_rgb)

In [None]:
print("Simple convex relaxation")
u_simple = potts_simple_pd(f, verbose=1, maxit=5000, check=100)
u_rgb = make_rgb(u_simple, means_rgb)

# composite colour image followed by one grey-scale panel per label
n_labels = u_simple.shape[0]
plt.figure(figsize=(3*(n_labels+1), 3))

plt.subplot(1, n_labels+1, 1)
# clip guards against round-off just outside [0, 1], which imshow would warn about
plt.imshow(np.clip(u_rgb/255.0, 0.0, 1.0))
plt.title("simple relaxation")

for i in range(n_labels):
    plt.subplot(1, n_labels+1, i+2)
    plt.imshow(u_simple[i], cmap="gray")
    plt.title(f"u[{i}]")
plt.show()

In [None]:
print("Tight convex relaxation")
u_tight = potts_tight_pd(f, verbose=1, maxit=5000, check=100)
u_rgb = make_rgb(u_tight, means_rgb)

# composite colour image followed by one grey-scale panel per label
n_labels = u_tight.shape[0]
plt.figure(figsize=(3*(n_labels+1), 3))

plt.subplot(1, n_labels+1, 1)
# clip guards against round-off just outside [0, 1], which imshow would warn about
plt.imshow(np.clip(u_rgb/255.0, 0.0, 1.0))
plt.title("tight relaxation")

for i in range(n_labels):
    plt.subplot(1, n_labels+1, i+2)
    plt.imshow(u_tight[i], cmap="gray")
    plt.title(f"u[{i}]")
plt.show()

In [None]:
# Load image
# Read the second test image and divide by 255 so intensities lie in [0, 1]
# (assumes 8-bit input).
g = np.divide(ski.io.imread("tagpfauenauge.jpg"), 255.0)
M, N, _ = g.shape
pixels = g.reshape(-1, 3)

In [None]:
# Number of phases
K = 16

# K-means clustering
kmeans = KMeans(n_clusters=K, random_state=0).fit(pixels)

# compute field
f = compute_squared_dist(g, kmeans.cluster_centers_)

In [None]:
# weight of the data term relative to the (fixed-weight) regulariser
lamb = 5.0
u_simple = potts_simple_pd(f*lamb, maxit=300, check=100, verbose=1)
u_rgb = make_rgb(u_simple, kmeans.cluster_centers_)

plt.figure(figsize=(10,3))

plt.subplot(121)
plt.imshow(g)
plt.title("input")

plt.subplot(122)
# clip guards against round-off just outside [0, 1], which imshow would warn about
plt.imshow(np.clip(u_rgb, 0.0, 1.0))
plt.title(f"simple relaxation, lambda = {lamb}")
plt.show()

In [None]:
# weight of the data term relative to the (fixed-weight) regulariser
lamb = 5.0
u_tight = potts_tight_pd(f*lamb, maxit=300, check=100, verbose=1)
u_rgb = make_rgb(u_tight, kmeans.cluster_centers_)

plt.figure(figsize=(10,3))

plt.subplot(121)
plt.imshow(g)
plt.title("input")

plt.subplot(122)
# clip guards against round-off just outside [0, 1], which imshow would warn about
plt.imshow(np.clip(u_rgb, 0.0, 1.0))
plt.title(f"tight relaxation, lambda = {lamb}")
plt.show()