# Part II: Model-based Regularization

____________

Lorenz: some total variation (TV) auxiliaries; these can be exported to a separate file.

In [None]:
class total_variation():
    """
    Total variation of a 2D image u with shape (n1, n2), scaled by a constant
    regularization parameter `scale`. Corresponds to the functional
        scale * TV(u)
    with u in R^{n1 x n2}.

    __init__ input:
        - n1, n2:   shape of u
        - scale:    scaling factor, usually a regularization parameter

    __call__ input:
        - u:        image of shape (n1, n2) or flattened of shape (n1*n2,)

    TV is computed on a grid via forward finite differences (isotropic TV,
    i.e. the sum of pointwise Euclidean norms of the gradient), assuming
    equidistant spacing of the grid. The gradient of this potential does not
    exist since TV is not smooth.
    The proximal mapping is approximated by solving the dual problem with
    FISTA. Pass either a maximum number of steps, an accuracy (in the
    primal-dual gap), or both to the prox evaluation; for more details see
        total_variation.inexact_prox
    """
    def __init__(self, n1, n2, scale=1):
        self.n1 = n1
        self.n2 = n2
        self.scale = scale

    def _imgrad(self, u):
        """
        Apply the discrete 2D image gradient (forward differences) to u.

        Parameters
        ----------
        u : numpy 2D array, shape (n1, n2)
            Image

        Returns
        -------
        numpy array of shape (2, n1, n2): the stacked gradients (px, py) along
        the first (row) and second (column) axis. The last row of px and the
        last column of py are zero (no difference is taken across the
        boundary).
        """
        # float64 (or complex) result, matching concatenation with float zeros
        grad = np.zeros((2, self.n1, self.n2), dtype=np.result_type(u, np.float64))
        grad[0, :-1, :] = u[1:, :] - u[:-1, :]
        grad[1, :, :-1] = u[:, 1:] - u[:, :-1]
        return grad

    def _imdiv(self, p):
        """
        Compute the negative divergence of the 2D vector field p = (px, py).

        This is the adjoint of `_imgrad`, i.e.
            sum(_imgrad(u) * p) == sum(u * _imdiv(p))
        for all u and p, including degenerate shapes with n1 == 1 or n2 == 1.
        It can also be seen as a linear map from R^(2 x n1 x n2) to
        R^(n1 x n2). Only the entries of p that `_imgrad` can populate are
        used: the last row of px and the last column of py are ignored.

        Parameters
        ----------
            - p : np.array of shape (2, n1, n2)

        Returns
        -------
            - negative divergence, np.array of shape (n1, n2)
        """
        div = np.zeros((self.n1, self.n2), dtype=np.result_type(p, np.float64))
        px = p[0, :-1, :]
        py = p[1, :, :-1]
        # transpose of the forward difference u[i+1] - u[i]:
        # -p[i] lands on row i, +p[i] lands on row i+1 (likewise for columns)
        div[:-1, :] -= px
        div[1:, :] += px
        div[:, :-1] -= py
        div[:, 1:] += py
        return div

    def __call__(self, u):
        """
        Compute the scaled TV seminorm of u.

        Parameters
        ----------
        u : numpy array of shape (n1, n2) or (n1*n2,)

        Returns
        -------
        scale * TV(u) (scalar)
        """
        u = np.reshape(u, (self.n1, self.n2))
        return self.scale * np.sum(np.sqrt(np.sum(self._imgrad(u)**2, axis=0)))

    def inexact_prox(self, u, gamma=1, epsilon=None, max_iter=np.inf, verbose=False):
        """
        Approximate the proximal mapping of gamma * scale * TV at u, i.e.
            argmin_x  1/(2*gamma) * ||x - u||^2 + scale * TV(x)   (ROF model),
        by running FISTA on the dual problem
            min_{|p_ij| <= scale}  gamma/2 * ||D p||^2 - <D p, u>,
        where D = _imdiv is the negative divergence. The primal iterate is
        recovered as x = u - gamma * D p.

        Parameters
        ----------
            - u:        image to be denoised, shape (n1, n2) or (n1*n2,)
            - gamma:    prox step size
            - epsilon:  accuracy for the duality gap stopping criterion; if
                        None, the gap is not evaluated
            - max_iter: maximum number of iterations (default: unbounded, in
                        which case epsilon must be given)
            - verbose:  if True, write a progress indicator to stdout

        Returns
        -------
            - (x, i): the approximate prox, with the same shape as u, and the
                      number of FISTA iterations performed

        Raises
        ------
            - ValueError if neither epsilon nor a finite max_iter is given,
              or if a negative duality gap beyond round-off is encountered
              (which indicates a bug in this routine).
        """
        if epsilon is None and not np.isfinite(max_iter):
            raise ValueError('provide either an accuracy or a maximum number of iterations to the tv prox please')
        check_accuracy = epsilon is not None

        # work on the 2D image internally; the result is returned in the caller's shape
        in_shape = np.shape(u)
        u = np.reshape(u, (self.n1, self.n2))

        # FISTA iterates: dual variable p and extrapolated point q
        p = np.zeros((2, self.n1, self.n2))
        q = np.copy(p)
        t = 1
        # the dual gradient is Lipschitz with constant gamma * ||D||^2 <= 8 * gamma
        step = 1 / (8 * gamma)

        stopcrit = False
        i = 0
        if verbose:
            sys.stdout.write('run FISTA on dual ROF model: {:3d}% '.format(0))
            sys.stdout.flush()

        while i < max_iter and not stopcrit:
            i += 1
            # p is rebound (not mutated in place) below, so no copy is needed
            p_prev = p

            # gradient step on the dual objective, then pointwise projection
            # onto the admissible set {p : |p_ij| <= scale}
            v = q - step * self._imgrad(gamma * self._imdiv(q) - u)
            p = v / np.maximum(1, np.sqrt(np.sum(v**2, axis=0)) / self.scale)[np.newaxis, :, :]

            # Nesterov/FISTA momentum update
            t_new = (1 + np.sqrt(1 + 4 * t**2)) / 2
            q = p + (t - 1) / t_new * (p - p_prev)
            t = t_new

            # stopping criterion: check if primal-dual gap <= epsilon
            if check_accuracy:
                div_p = self._imdiv(p)
                h = gamma / 2 * np.sum(div_p**2)
                primal = self(u - gamma * div_p) + h
                conj_TV_p = np.max(np.sqrt(np.sum(p**2, axis=0)))
                dual_inadmissible = conj_TV_p > self.scale + 1e-12
                # dual iterate should never be inadmissible since it is projected above
                dual = np.inf if dual_inadmissible else h - np.sum(div_p * u)
                dgap = primal + dual
                stopcrit = dgap <= epsilon
                # weak duality guarantees dgap >= 0; allow round-off, which
                # grows with the magnitude of the objective values
                if dgap < -max(5e-15, 1e-12 * abs(primal)):
                    raise ValueError('Duality gap was negative (which should never happen), please check the prox computation routine!')

            if verbose:
                sys.stdout.write('\b'*5 + '{:3d}% '.format(int(i/max_iter*100)))
                sys.stdout.flush()

        if verbose:
            sys.stdout.write('\b'*5 + '100% ')
            sys.stdout.flush()
        x = u - gamma * self._imdiv(p)
        return np.reshape(x, in_shape), i