In [1]:
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from solvers import power_method, uball_project
from utils   import pre_process_3d, post_process_3d

def ST(x, t):
    """ Soft-thresholding (shrinkage) operator.

    Reduces the magnitude of every entry of x by t, clips at zero, and keeps
    the original sign:  ST(x, t) = sign(x) * max(|x| - t, 0).
    t may be a scalar or a tensor broadcastable against x.
    """
    shrunk_magnitude = F.relu(x.abs() - t)
    return shrunk_magnitude * x.sign()

class CDLNetVideo(nn.Module):
    """ Convolutional Dictionary Learning Network for Video Denoising:
    Interpretable denoising DNN with adaptive thresholds for robustness.

    Spatio-temporal (3D) variant of CDLNet: K unrolled LISTA iterations with
    learned analysis filters A[k] (Conv3d) and synthesis filters B[k]
    (ConvTranspose3d), followed by a dictionary synthesis D (aliased to B[0]).

    Args:
        K: number of unrolled iterations.
        M: number of filters in each filter bank.
        P: cubic filter side length.
        s: stride of the convolutions.
        C: number of input channels.
        t0: initial value of every learned threshold parameter.
        adaptive: if True, thresholds grow with the noise level sigma.
        init: if True, estimate the largest eigenvalue of D A on the initial
            dictionary with the power method and spectrally normalize all
            filters; pass False to skip this (e.g. when a trained state-dict
            will be loaded afterwards).

    Raises:
        RuntimeError: if init is True and the power method returns a
            non-positive estimate.
    """
    def __init__(self,
                 K=3,            # num. unrollings
                 M=64,           # num. filters in each filter bank operation
                 P=7,            # cubic filter side length
                 s=1,            # stride of convolutions
                 C=1,            # num. input channels
                 t0=0,           # initial threshold
                 adaptive=False, # noise-adaptive thresholds
                 init=True):     # True -> normalize initial weights via power-method
        super(CDLNetVideo, self).__init__()

        # -- OPERATOR INIT --
        self.A = nn.ModuleList([nn.Conv3d(C, M, P, stride=s, padding=(P-1)//2, bias=False) for _ in range(K)])
        self.B = nn.ModuleList([nn.ConvTranspose3d(M, C, P, stride=s, padding=(P-1)//2, output_padding=s-1, bias=False) for _ in range(K)])
        self.D = self.B[0]                              # alias D to B[0], otherwise unused as z0 is zero
        # learned thresholds: t[k, 0] is the base threshold and t[k, 1] the
        # noise-level coefficient of iteration k, one value per filter; the
        # trailing singleton dims broadcast over the (depth, height, width)
        # dims of the Conv3d sparse codes.
        self.t = nn.Parameter(t0 * torch.ones(K, 2, M, 1, 1, 1))

        # set weights: every A[k] and B[k] starts from the same random dictionary
        # (Conv3d and ConvTranspose3d weights are both shaped (M, C, P, P, P) here)
        W = torch.randn(M, C, P, P, P)
        for k in range(K):
            self.A[k].weight.data = W.clone()
            self.B[k].weight.data = W.clone()

        # Don't bother running code if initializing trained model from state-dict
        if init:
            print("Running power-method on initial dictionary...")
            with torch.no_grad():
                DDt = lambda x: self.D(self.A[0](x))
                L = power_method(DDt, torch.rand(1, C, 16, 128, 128), num_iter=200, verbose=False)[0]
                print(f"Done. L={L:.3e}.")

            # A non-positive estimate would make the division by sqrt(L) below
            # produce inf/nan weights; raise instead of calling sys.exit().
            if L <= 0:
                raise RuntimeError(f"power-method returned a non-positive estimate L={L:.3e}; "
                                   "cannot normalize the initial dictionary.")

            # spectral normalization (note: D is aliased to B[0], so it is normalized too)
            for k in range(K):
                self.A[k].weight.data /= np.sqrt(L)
                self.B[k].weight.data /= np.sqrt(L)

        # set parameters
        self.K = K
        self.M = M
        self.P = P
        self.s = s
        self.t0 = t0
        self.adaptive = adaptive

    @torch.no_grad()
    def project(self):
        r""" \ell_2 ball projection for filters, R_+ projection for thresholds
        """
        # thresholds must stay non-negative
        self.t.clamp_(0.0)
        for k in range(self.K):
            # dims (2, 3, 4) are the P x P x P kernel dims of the (M, C, P, P, P) weights
            self.A[k].weight.data = uball_project(self.A[k].weight.data, dim=(2,3,4))
            self.B[k].weight.data = uball_project(self.B[k].weight.data, dim=(2,3,4))

    def _noise_scale(self, sigma):
        """ Threshold scale-factor c: sigma / 255 when the model is noise-adaptive
        and sigma is given, otherwise 0.
        """
        return 0 if sigma is None or not self.adaptive else sigma / 255.0

    def _threshold(self, k, c):
        """ Threshold of LISTA iteration k: t[k, 0] + c * t[k, 1], shaped (1, M, 1, 1, 1).
        """
        return self.t[k, :1] + c * self.t[k, 1:2]

    def forward(self, y, sigma=None, mask=1):
        """ LISTA + D w/ noise-adaptive thresholds

        Args:
            y: input batch; NOTE(review): presumably (N, C, depth, height, width)
                as required by nn.Conv3d — confirm against pre_process_3d.
            sigma: noise level, divided by 255 here (presumably on the 0-255
                scale); ignored unless the model was built with adaptive=True.
            mask: forwarded to pre_process_3d; the returned mask multiplies the
                synthesized signal inside each LISTA step.

        Returns:
            (xhat, z): the post-processed reconstruction and the final sparse code.
        """
        yp, params, mask = pre_process_3d(y, self.s, mask=mask)

        # THRESHOLD SCALE-FACTOR c
        c = self._noise_scale(sigma)

        # LISTA
        z = ST(self.A[0](yp), self._threshold(0, c))
        for k in range(1, self.K):
            z = ST(z - self.A[k](mask * self.B[k](z) - yp), self._threshold(k, c))

        # DICTIONARY SYNTHESIS
        xphat = self.D(z)
        xhat = post_process_3d(xphat, params)
        return xhat, z

    def forward_generator(self, y, sigma=None, mask=1):
        """ same as forward but yields intermediate sparse codes

        Yields the sparse code after each of the K LISTA iterations, then the
        post-processed reconstruction xhat as the final item.
        """
        yp, params, mask = pre_process_3d(y, self.s, mask=mask)
        c = self._noise_scale(sigma)
        z = ST(self.A[0](yp), self._threshold(0, c))
        yield z
        for k in range(1, self.K):
            z = ST(z - self.A[k](mask * self.B[k](z) - yp), self._threshold(k, c))
            yield z
        xphat = self.D(z)
        xhat = post_process_3d(xphat, params)
        yield xhat

In [2]:
# Seed the global RNG so the random initial dictionary (torch.randn) and the
# power-method start vector (torch.rand) are reproducible across runs.
torch.manual_seed(0)

model = CDLNetVideo(
    K=3,            # unrolled LISTA iterations
    M=64,           # filters per filter bank
    P=7,            # 7x7x7 filters
    s=1,            # convolution stride
    C=1,            # single input channel
    t0=0.1,         # initial threshold value
    adaptive=True,  # scale thresholds with the noise level sigma
    init=True       # run power-method + spectral normalization
)

Running power-method on initial dictionary...
Done. L=3.161e+04.


In [3]:
# Input layout (batch, channels, depth/frames, height, width), as nn.Conv3d expects.
batch_size = 1
channels = 1    # must match the model's C
depth = 16
height = 128
width = 128
# A dedicated seeded generator keeps this cell reproducible when re-run on its own
# and independent of the global RNG stream used to initialize the model.
dummy_input = torch.randn(batch_size, channels, depth, height, width,
                          generator=torch.Generator().manual_seed(1))

In [4]:
# Forward pass at noise level sigma=25 (the model divides sigma by 255);
# returns the reconstruction and the final sparse code.
(output, z) = model(
    dummy_input,
    sigma=25,
)

In [5]:
# Sanity check: backpropagate an MSE loss against a random target and confirm
# that every learnable parameter receives a gradient.
target = torch.randn_like(output)

criterion = nn.MSELoss()

loss = criterion(output, target)

model.zero_grad()
loss.backward()

for name, param in model.named_parameters():
    if param.grad is not None:
        print(f'{name} - Grad shape: {param.grad.shape}')
    else:
        print(f'{name} - No gradient')

missing_grads = [name for name, param in model.named_parameters() if param.grad is None]
assert not missing_grads, f"parameters without gradient: {missing_grads}"

t - Grad shape: torch.Size([3, 2, 64, 1, 1, 1])
A.0.weight - Grad shape: torch.Size([64, 1, 7, 7, 7])
A.1.weight - Grad shape: torch.Size([64, 1, 7, 7, 7])
A.2.weight - Grad shape: torch.Size([64, 1, 7, 7, 7])
B.0.weight - Grad shape: torch.Size([64, 1, 7, 7, 7])
B.1.weight - Grad shape: torch.Size([64, 1, 7, 7, 7])
B.2.weight - Grad shape: torch.Size([64, 1, 7, 7, 7])
