# Weakly-Convex-Ridge Regularizers

This notebook gives a few basic snippets for using the pretrained WCRR-NN.

In [1]:
import torch
from torchmetrics.functional import peak_signal_noise_ratio as psnr

import sys
sys.path.append('../')
from models import utils

In [2]:
device = 'cuda:0'
torch.set_grad_enabled(False)
torch.set_num_threads(4)

## 1. Loading a Model

The default WCRR-NN was trained on BSD images with a deep-equilibrium (DEQ) scheme, for noise levels $\sigma\in [0, 25/255]$.

In [3]:
model = utils.load_model("WCRR-CNN", device)
# update the Lipschitz bound (spectral norm) of the convolutional layer using 200 power-method iterations
sn_pm = model.conv_layer.spectral_norm(mode="power_method", n_steps=200)

/home/goujon/weakly_convex_ridge_regularizer/trained_models/WCRR-CNN/checkpoints/*.pth
Multi convolutionnal layer:  {'num_channels': [1, 4, 8, 60], 'size_kernels': [5, 5, 5]}
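For intuition, here is a minimal sketch of the power method applied to a dense matrix that stands in for the convolutional layer $W$. This is an assumption for illustration: the repository works with the convolution itself, and `spectral_norm_power_method` is a hypothetical helper, not part of `models`. Each iteration applies $W^T W$ and renormalizes, so the iterate aligns with the leading right singular vector and $\|Wx\|$ approaches $\|W\|$.

```python
import numpy as np

def spectral_norm_power_method(W, n_steps=200, seed=0):
    """Hypothetical helper: estimate the spectral norm ||W|| (largest singular value)."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(W.shape[1])
    x /= np.linalg.norm(x)
    for _ in range(n_steps):
        # one application of W^T W, followed by renormalization
        x = W.T @ (W @ x)
        x /= np.linalg.norm(x)
    # for a unit vector x close to the top singular vector, ||W x|| ~ ||W||
    return np.linalg.norm(W @ x)

# toy stand-in for the convolutional layer: a random 60 x 25 matrix
rng = np.random.default_rng(1)
W = rng.standard_normal((60, 25))
est = spectral_norm_power_method(W, n_steps=200)
exact = np.linalg.norm(W, 2)  # exact spectral norm via SVD
print(est, exact)
```

Since the estimate approaches $\|W\|$ from below, a fairly large number of iterations (such as `n_steps=200` above) helps keep the bound tight.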


## 2. Using a Model

Recall that $R$ and $\nabla R$ take two inputs:
- a tensor (image or batch of images) with the following dimensions:
    1. batch
    2. channel (1, since only grayscale images are supported)
    3. height
    4. width
- a tensor of noise levels with shape (batch, 1, 1, 1)
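The (batch, 1, 1, 1) shape lets each noise level broadcast against its own image. A minimal sketch of building a batch with one noise level per image, using NumPy arrays in place of the notebook's torch tensors (torch follows the same broadcasting rules; the batch size and noise values are illustrative only):

```python
import numpy as np

batch, height, width = 3, 100, 100
# batch of grayscale images: (batch, channel=1, height, width)
ims = np.random.default_rng(0).uniform(size=(batch, 1, height, width))
# one noise level per image, reshaped to (batch, 1, 1, 1)
sigmas = np.array([5.0, 15.0, 25.0]).reshape(-1, 1, 1, 1)

# (batch, 1, 1, 1) broadcasts against (batch, 1, H, W),
# so image i is paired with sigmas[i]
scaled = ims * sigmas
print(ims.shape, sigmas.shape, scaled.shape)
```

In the notebook itself the equivalent reshape is `torch.tensor([...], device=device).view(-1, 1, 1, 1)`.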


### Gradient

In [4]:
im = torch.empty((1, 1, 100, 100), device=device).uniform_()
sigma = torch.tensor([25.], device=device).view(-1,1,1,1)

grad = model.grad(im, sigma)

### Lipschitz Constant of the Gradient

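Recall that $\nabla R = W^T \sigma (W \cdot)$ (here $\sigma$ denotes the activation, not the noise level), where:
- $\sigma = \mu \sigma_+ - \sigma_-$ is a pointwise activation with $\sigma_+', \sigma_-' \in [0,1]$
- $\mu>1$
- $\|W\|=1$

Hence $\sigma' \in [-1, \mu]$, so $|\sigma'| \leq \max(\mu, 1) = \mu$ and the activation is $\mu$-Lipschitz. It follows that $\mathrm{Lip}(\nabla R)\leq \|W^T\|\,\mu\,\|W\| = \mu$.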
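The bound can be checked numerically on a toy model. This is a sketch under assumptions: a random dense matrix rescaled to $\|W\|=1$ stands in for the convolutional layers, $\sigma_+$ (a clip) and $\sigma_-$ (a soft threshold) are hypothetical choices whose slopes lie in $[0,1]$, and $\mu = 2$ is arbitrary. None of these are the trained model's actual activations.

```python
import numpy as np

mu = 2.0  # assumed value; only mu > 1 matters for the bound

def sigma_plus(t):
    # slope 1 on [-1, 1], slope 0 outside: derivative in [0, 1]
    return np.clip(t, -1.0, 1.0)

def sigma_minus(t):
    # slope 0 on [-1, 1], slope 1 outside: derivative in [0, 1]
    return np.sign(t) * np.maximum(np.abs(t) - 1.0, 0.0)

def sigma(t):
    # toy activation mu * sigma_plus - sigma_minus (non-monotone)
    return mu * sigma_plus(t) - sigma_minus(t)

rng = np.random.default_rng(0)
W = rng.standard_normal((40, 20))
W /= np.linalg.norm(W, 2)  # enforce ||W|| = 1 (spectral norm)

def grad_R(x):
    return W.T @ sigma(W @ x)

# largest observed ratio ||grad_R(x) - grad_R(y)|| / ||x - y|| over random pairs
ratios = []
for _ in range(1000):
    x = 3 * rng.standard_normal(20)
    y = x + 0.5 * rng.standard_normal(20)
    ratios.append(np.linalg.norm(grad_R(x) - grad_R(y)) / np.linalg.norm(x - y))
max_ratio = max(ratios)
print(f"largest observed ratio: {max_ratio:.3f} (bound mu = {mu})")
```

Sampling only gives a lower estimate of $\mathrm{Lip}(\nabla R)$, so the observed ratio should never exceed $\mu$ but may stay below it.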

In [5]:
Lip_grad = model.get_mu()  # mu, an upper bound on Lip(grad R)

### Regularization Cost

The regularization cost is not used during training.

On the first call, the potential function is constructed. The activation functions are expressed with linear B-splines, so their antiderivatives, the profile functions, are quadratic B-splines.
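Integrating a linear spline in closed form shows why the profiles come out quadratic. A minimal sketch, assuming uniform knots $t_k = t_0 + kh$ and a linear B-spline activation whose coefficients $c_k$ equal its values at the knots; `integrate_linear_spline` is a hypothetical helper, and the repository's choice of reference point and extrapolation outside the knot range may differ. On $[t_k, t_{k+1}]$ with $u = t - t_k$, the antiderivative is $F_k + c_k u + \frac{c_{k+1}-c_k}{2h} u^2$, a quadratic piece.

```python
import numpy as np

def integrate_linear_spline(t0, h, c):
    """Return psi(t) = integral from t0 to t of the linear spline with
    values c[k] at knots t0 + k*h (valid for t0 <= t <= t0 + (len(c)-1)*h)."""
    c = np.asarray(c, dtype=float)
    # F[k] = integral from t0 to t_k (trapezoid rule is exact for linear pieces)
    F = np.concatenate(([0.0], np.cumsum(h * (c[:-1] + c[1:]) / 2)))

    def psi(t):
        t = np.asarray(t, dtype=float)
        # index k of the interval [t_k, t_{k+1}] containing t
        k = np.clip(np.floor((t - t0) / h).astype(int), 0, len(c) - 2)
        u = t - (t0 + k * h)
        slope = (c[k + 1] - c[k]) / h
        # quadratic piece: F_k + c_k u + slope u^2 / 2
        return F[k] + c[k] * u + slope * u**2 / 2

    return psi

# example: identity activation sigma(t) = t sampled on knots -2, -1.5, ..., 2
knots = np.arange(-2.0, 2.5, 0.5)
psi = integrate_linear_spline(-2.0, 0.5, knots)
ts = np.linspace(-2.0, 2.0, 101)
# for sigma(t) = t, the exact antiderivative from -2 is (t**2 - 4) / 2
print(np.max(np.abs(psi(ts) - (ts**2 - 4) / 2)))
```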

In [6]:
model.cost(im, sigma)

**** Updating integrated spline coefficients ****
**** Updating integrated spline coefficients ****


tensor([159.7290], device='cuda:0')
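Since $\nabla R$ is the gradient of the cost, the two can be cross-checked with finite differences. A minimal sketch on a toy ridge regularizer $R(x) = \sum_i \psi(w_i^T x)$ with $\psi' = \sigma$, assuming a random dense $W$ instead of the convolutional layers and hypothetical closed-form activation and profile (not the trained splines); `cost` and `grad` here are illustrative stand-ins for `model.cost` and `model.grad`.

```python
import numpy as np

mu = 2.0  # assumed value of mu (> 1)

def sigma(t):
    # toy activation: mu * clip(t) - soft_threshold(t)
    return mu * np.clip(t, -1.0, 1.0) - np.sign(t) * np.maximum(np.abs(t) - 1.0, 0.0)

def psi(t):
    # toy profile: antiderivative of sigma with psi(0) = 0 (piecewise quadratic)
    a = np.abs(t)
    psi_plus = np.where(a <= 1.0, t**2 / 2, a - 0.5)
    psi_minus = np.where(a <= 1.0, 0.0, (a - 1.0) ** 2 / 2)
    return mu * psi_plus - psi_minus

rng = np.random.default_rng(0)
W = rng.standard_normal((40, 20))
W /= np.linalg.norm(W, 2)  # ||W|| = 1

def cost(x):
    # ridge regularizer R(x) = sum_i psi(w_i^T x)
    return psi(W @ x).sum()

def grad(x):
    # grad R(x) = W^T sigma(W x)
    return W.T @ sigma(W @ x)

x = 3 * rng.standard_normal(20)
g = grad(x)
eps = 1e-6
# central differences of the cost along each coordinate direction
fd = np.array([(cost(x + eps * e) - cost(x - eps * e)) / (2 * eps) for e in np.eye(20)])
print(np.max(np.abs(fd - g)))
```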