**This file contains 3 sections:**

1. Details on the Deep Submodular Function (DSF)
2. Loss computation and the implementation of Greedy Cardinality Constrained Submodular Maximization
3. Relevant portions of our original code, with comments


# <span style="color:green">SECTION (1/3) Deep Submodular Function - DSF</span>
**Architecture**
* Currently, our work trains a separate DSF for each image.
* Each DSF is learned on an image of size 28 x 28, and we need to assign a pixel-wise importance score to each image. Hence, the inputs to our DSF are bit-vectors of size 28 x 28 = 784 (one bit per pixel). The architecture is as follows:


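        import torch
        import torch.nn as nn

        def sqrt(input):
            # Concave, non-decreasing activation. Its argument must be >= 0, which holds for
            # bit-vector inputs once all weights and biases are non-negative (see PROJECTION below).
            return torch.sqrt(input)

        class DSF(nn.Module): # In PyTorch, neural networks sub-class the base class nn.Module.
            def __init__(self):
                super(DSF, self).__init__()
                # NOTE: nn.Linear's default initialisation can produce negative weights and biases,
                # which would make sqrt return NaN; apply the projection max(0, w) before the first forward pass.
                self.fc1 = nn.Linear(28 * 28, 512)
                self.fc2 = nn.Linear(512, 256)
                self.fc3 = nn.Linear(256, 32)
                self.fc4 = nn.Linear(32, 1)

            def forward(self, x):
                x = x.reshape(-1, 28 * 28)  # flatten (batch, 1, 28, 28) inputs to (batch, 784)
                x = self.fc1(x)
                x = sqrt(x)
                x = self.fc2(x)
                x = sqrt(x)
                x = self.fc3(x)
                x = sqrt(x)
                x = self.fc4(x)  # linear output layer: one scalar f(A) per input
                return x

        # The code in Sections 2-3 feeds float64 (.double()) inputs, so the model must be cast too:
        # f = DSF().double()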
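
Why must the weights stay non-negative? A non-negative combination of monotone submodular functions is monotone submodular, and so is a concave, non-decreasing function (such as $\sqrt{\cdot}$ on $[0,\infty)$) applied to one; stacking such layers therefore keeps the DSF monotone submodular. The pure-Python sketch below is a hypothetical toy, not our PyTorch model: `make_tiny_dsf`, `indicator` and `check_monotone_submodular` are illustrative names, the layer sizes and random non-negative weights are arbitrary assumptions, and the properties are checked by brute force over all subsets of a 4-element ground set.

```python
import itertools
import math
import random


def make_tiny_dsf(n_items, hidden=(4, 3, 2), seed=0):
    """Hypothetical toy DSF: non-negative weights/biases, sqrt after every hidden layer."""
    rng = random.Random(seed)
    sizes = [n_items, *hidden, 1]
    layers = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        W = [[rng.random() for _ in range(fan_in)] for _ in range(fan_out)]  # weights in [0, 1)
        b = [rng.random() for _ in range(fan_out)]                           # biases in [0, 1)
        layers.append((W, b))

    def f(x):
        h = list(x)
        for depth, (W, b) in enumerate(layers):
            h = [sum(w * v for w, v in zip(row, h)) + bias for row, bias in zip(W, b)]
            if depth < len(layers) - 1:        # no activation after the output layer
                h = [math.sqrt(v) for v in h]  # v >= 0: inputs, weights, biases are all >= 0
        return h[0]

    return f


def indicator(S, n):
    """Bit-vector of the subset S of {0, ..., n-1}."""
    return [1.0 if i in S else 0.0 for i in range(n)]


def check_monotone_submodular(f, n, tol=1e-12):
    """Brute-force check of monotonicity and diminishing returns for all A subset of B."""
    subsets = [frozenset(c) for r in range(n + 1) for c in itertools.combinations(range(n), r)]
    value = {S: f(indicator(S, n)) for S in subsets}
    for A in subsets:
        for B in subsets:
            if not A <= B:
                continue
            if value[A] > value[B] + tol:  # monotonicity violated
                return False
            for v in set(range(n)) - B:
                gain_A = value[A | {v}] - value[A]
                gain_B = value[B | {v}] - value[B]
                if gain_A < gain_B - tol:  # diminishing returns violated
                    return False
    return True


f = make_tiny_dsf(4)
print(check_monotone_submodular(f, 4))
```

Allowing negative values in place of `rng.random()` removes this guarantee (and can make the square root undefined), which is why the weights are projected onto $w \ge 0$ during training.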

<hr>

**Training**
* We use **batch gradient descent** (every update uses the whole training set), since we do not have a large enough dataset for a mini-batch setup.
* OPTIMIZER: We use Adagrad, which adapts the learning rate of each parameter individually. Adagrad is usually preferred when the data is sparse, and this matched our observations.
* GRADIENT DESCENT: At each epoch, we backpropagate (using "**loss.backward()**") and update the weights with an Adagrad step (using "**optimizer.step()**").
* PROJECTION: Projecting onto the non-negativity constraint is simply the element-wise operation max(0, w) on the weights w (we apply it to the biases as well). Hence, after each weight update, we apply an instance of the "**clamp_zero**" class to every layer via "**f.apply(clipper)**":

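        import torch

        class clamp_zero(object):
            """Projects every weight and bias of a module onto the non-negative orthant."""
            def __init__(self):
                pass

            def __call__(self, module):
                # Module.apply() visits every sub-module; skip those without parameters
                # (e.g. the DSF container itself) and layers created with bias=False.
                weight = getattr(module, 'weight', None)
                if weight is not None:
                    weight.data.clamp_(min=0)  # in-place max(0, w)
                bias = getattr(module, 'bias', None)
                if bias is not None:
                    bias.data.clamp_(min=0)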
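
To make the update-then-project sequence concrete, here is a minimal pure-Python sketch of one projected Adagrad step on a flat list of weights. It assumes the plain Adagrad rule as documented for `torch.optim.Adagrad` with default settings (no learning-rate decay, no weight decay, zero initial accumulator, `eps=1e-10`); `projected_adagrad_step` is a hypothetical helper. In our code the two halves correspond to `optimizer.step()` followed by `f.apply(clipper)`.

```python
import math


def projected_adagrad_step(w, grad, state_sum, lr=0.01, eps=1e-10):
    """Hypothetical helper: one Adagrad step, then the projection max(0, w); updates lists in place."""
    for j in range(len(w)):
        state_sum[j] += grad[j] ** 2                            # running sum of squared gradients
        w[j] -= lr * grad[j] / (math.sqrt(state_sum[j]) + eps)  # per-coordinate step size
        w[j] = max(0.0, w[j])                                   # projection onto w >= 0
    return w, state_sum


w = [0.005, 0.5]
state = [0.0, 0.0]
projected_adagrad_step(w, [1.0, -1.0], state)
print(w)  # the first weight is clipped to 0.0; the second grows to about 0.51
```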
                    




# <span style="color:green">SECTION (2/3) MORE ON LOSS COMPUTATION</span>
* The [original DSF paper](https://arxiv.org/pdf/1701.08939.pdf) trains DSFs with discrete supervision only.

* Our loss (equation (2) in our [paper](https://arxiv.org/pdf/2104.09073.pdf)) combines supervision via real-valued inputs (weighted by $\lambda_1$ in the paper) and supervision via bit-vectors (weighted by $\lambda_2$ in the paper).

* To compute the discrete-supervision part of our loss, we need to solve Cardinality Constrained Submodular Maximization problems at each training epoch.
    * We need solutions for a list of cardinalities. However, the greedy algorithm builds its solution one element at a time, so the greedy solution for cardinality $k$ is simply the set chosen in its first $k$ iterations (or its final set, if it stopped earlier because no element increased $f$). Hence we only need to run it once, up to the maximum value in the list of cardinalities.
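
The structure of this loss, as implemented by the hinge terms of `loss_1` and `loss_2` in Section 3, can be sketched in pure Python. Everything here is an illustrative assumption: `dsf_style_loss` and `indicator` are hypothetical helpers, `f` is any set function on bit-vectors (a toy one below, not a trained DSF), the two weights are passed as `w_discrete` and `w_real`, and the greedy solution for cardinality $k$ is taken as the first $k$ greedy picks, so one greedy run up to the largest $k$ serves every cardinality.

```python
import math


def indicator(S, n):
    """Bit-vector of the subset S of {0, ..., n-1}."""
    return [1.0 if j in S else 0.0 for j in range(n)]


def dsf_style_loss(f, n, greedy_order, hard_maps, soft_maps, delta, w_discrete, w_real):
    # Discrete term: for each cardinality k, compare the greedy solution A*_k (the
    # first k greedy picks) with every hard-thresholded map H that has k pixels on.
    loss_discrete = 0.0
    for k, maps in hard_maps.items():
        f_A_star = f(indicator(greedy_order[:k], n))
        for H in maps:
            loss_discrete += max(0.0, delta + f_A_star - f(H))
    # Real-valued term: compare the all-ones map H* with every real-valued map h.
    f_ones = f([1.0] * n)
    loss_real = sum(max(0.0, f_ones - f(h)) for h in soft_maps)
    return w_discrete * loss_discrete + w_real * loss_real


weights = [1.0, 2.0, 3.0]


def f(x):
    """Toy monotone submodular function: sqrt of a non-negative weighted sum."""
    return math.sqrt(sum(w * xi for w, xi in zip(weights, x)))


loss = dsf_style_loss(f, 3, greedy_order=[2, 1, 0],
                      hard_maps={1: [indicator({0}, 3)]},
                      soft_maps=[[0.5, 0.5, 0.5]],
                      delta=0.1, w_discrete=1.0, w_real=1.0)
print(loss)  # (0.1 + sqrt(3) - 1) + (sqrt(6) - sqrt(3)) = sqrt(6) - 0.9, about 1.5495
```

Unlike our training code, which stops when every hinge term is inactive, this sketch simply returns 0 in that case.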

**Overview of the Greedy Cardinality Constrained Submodular Maximization**:

$\qquad$Initially, $A = \emptyset$ and $f(A) = f(\emptyset)$, i.e., $f$ evaluated on the all-zeros bit-vector.

$\qquad$1. Let $\bar{A} = A \cup \{\arg\max_{v\in V\setminus A}f(A\cup\{v\})\}$

$\qquad$2. if $f(\bar{A})>f(A)$:

$\qquad$   $\qquad A = \bar{A}$

$\qquad$   else:

$\qquad$   $\qquad \textrm{return } A$

The above two steps are repeated at most $K$ times, after which $A$ is returned.
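
The steps above translate almost line by line into Python. In this sketch `greedy_max` is a hypothetical helper that works on any set function `f`, and the toy coverage function (the number of distinct items covered by the chosen sets, a standard monotone submodular function) is only an illustration. Note that each iteration calls `f` once per candidate, which is exactly the cost discussed next.

```python
def greedy_max(f, V, K):
    """Hypothetical helper: greedy cardinality-constrained maximization of a set function f over V."""
    A, fA = set(), f(set())                   # A = {}, f(A) = f(empty set)
    for _ in range(K):                        # at most K iterations
        candidates = [v for v in V if v not in A]
        if not candidates:
            break
        # Step 1: the candidate with the largest f(A | {v}); one f-call per candidate
        best = max(candidates, key=lambda v: f(A | {v}))
        f_bar = f(A | {best})
        # Step 2: accept only a strict improvement, otherwise return A
        if f_bar > fA:
            A, fA = A | {best}, f_bar
        else:
            break
    return A, fA


covers = {"a": {1, 2, 3}, "b": {3, 4}, "c": {4, 5, 6, 7}, "d": {1}}


def coverage(S):
    """Number of distinct items covered by the sets chosen in S."""
    return len(set().union(*(covers[v] for v in S)))


print(greedy_max(coverage, list(covers), 2))  # A = {'a', 'c'}, f(A) = 7 (set order may vary)
print(greedy_max(coverage, list(covers), 4))  # same result: no further element increases f
```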

**Problem with a naive implementation**: Evaluating $f(A\cup\{v\})$ separately for every candidate $v$ would demand $O(|V|K)$ calls to the DSF neural network at every training epoch.

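**Solution**: We can reduce this to $O(K)$ calls to the DSF per training epoch: in each iteration, we run inference once on a batch of $|V|$ inputs, where the $i^{th}$ input in the batch represents $A\cup\{v_i\}$. Each call is larger, but one batched forward pass is typically much faster than $|V|$ separate ones, especially on a GPU.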
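
For a rough count, assume the naive version evaluates $f(A\cup\{v\})$ only for $v\in V\setminus A$ (so iteration $t$ has $|V|-t$ candidates), while the batched version makes one extra call for $f(\emptyset)$, as in the implementation below. Over $K$ iterations:

$$
\begin{aligned}
\text{naive: } \sum_{t=0}^{K-1} \left(|V| - t\right) &= K|V| - \frac{K(K-1)}{2} = O(|V|K) \text{ single-input forward passes},\\
\text{batched: } 1 + K &= O(K) \text{ forward passes on batches of } |V| \text{ inputs}.
\end{aligned}
$$

For instance, with $|V| = 28 \times 28 = 784$ and an illustrative $K = 10$, the naive version needs $7840 - 45 = 7795$ forward passes, whereas the batched version needs $11$.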

<hr>

### <span style="color:green">Implementation Details with an example</span>
* Let $V$ be the universe (ground set) with $|V| = 4$, indexed from 0 as $V = \{v_0, v_1, v_2, v_3\}$. In our work, $|V|$ is the number of pixels of the (sub-sampled) image, which was always a perfect square in the datasets we used.

* Initially,  $A = \{\}; f(A) = f(0)$.

* We use a matrix $\mathbf{x}$ whose column $i$ is the bit-vector of $A\cup \{v_i\}$ for $v_i\in V$. Initially, $A = \emptyset$, so $\mathbf{x}$ is the identity matrix: $\mathbf{x}$ = 
\begin{bmatrix}
1 & 0 & 0 & 0\\
0 & 1 & 0 & 0\\
0 & 0 & 1 & 0\\
0 & 0 & 0 & 1
\end{bmatrix}

* We repeat the following at most $K$ times:

$\qquad$ We reshape $\mathbf{x}^\top$ (so that each column of $\mathbf{x}$ becomes one input) into $\mathbf{inputs}$ of shape *(batch_size, number_of_channels, height, width)*, the usual PyTorch layout for batches of images. For our work, batch_size is $|V|$, number_of_channels is 1, and height = width = $\sqrt{|V|}$.

$\qquad$ $\mathbf{outputs} = f(\mathbf{inputs})$  # *The $i^{th}$ entry in this vector corresponds to $f(A\cup \{v_i\})$*

$\qquad$ $i = \arg\max_j ~ \mathbf{outputs}[j]$

$\qquad$ if $\mathbf{outputs}[i]>f(A):$

$\qquad$ $\qquad f(A) = \mathbf{outputs}[i]$ # *We include \{$v_i$\} in A and update f(A)*

$\qquad$ $\qquad$ As $A$ has been updated to $A\cup \{v_i\}$, $v_i$ now belongs to every candidate set, so we set the $i^{th}$ row of $\mathbf{x}$ to all 1's. E.g., if $i = 1$, then 
$\mathbf{x}$ = \begin{bmatrix}
1 & 0 & 0 & 0\\
1 & 1 & 1 & 1\\
0 & 0 & 1 & 0\\
0 & 0 & 0 & 1
\end{bmatrix}
    
$\qquad$ $\qquad$ The 2nd column (index 1) is no longer of interest, as $v_1$ has already been chosen. We could remove that column and reduce the batch size by 1 but, for simplicity, we did not. *NOTE that this column now equals $A$ itself, while every other column is $A\cup\{v_j\}$; by monotonicity of the DSF its value cannot exceed theirs, and even in a tie it cannot pass the strict test $\mathbf{outputs}[i]>f(A)$, so it is never selected again.*
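
The bookkeeping above can be replayed on the 4-element example in pure Python, without PyTorch. This is a sketch under stated assumptions: `greedy_via_candidate_matrix` is a hypothetical helper, nested lists stand in for the tensor $\mathbf{x}$, a list comprehension over the columns stands in for the single batched forward pass, and the toy function $f(S) = \sqrt{\sum_{v_j \in S} w_j}$ with weights $(1, 4, 2, 3)$ (a simple monotone submodular function) replaces the DSF. With these weights the first pick is $v_1$, reproducing the matrix shown above.

```python
import math


def greedy_via_candidate_matrix(f, n_items, K):
    # x[r][c] == 1 iff v_r belongs to candidate set c, so column c encodes A | {v_c}
    x = [[1 if r == c else 0 for c in range(n_items)] for r in range(n_items)]
    fA = f([0] * n_items)                 # f(empty set)
    selected = [0] * n_items              # bit-vector of A
    history = []                          # copy of x after each accepted pick
    for _ in range(K):
        # Transpose: one candidate bit-vector per "batch element"
        columns = [[x[r][c] for r in range(n_items)] for c in range(n_items)]
        outputs = [f(col) for col in columns]             # stands in for one batched call
        i = max(range(n_items), key=outputs.__getitem__)  # first maximiser, like argmax
        if outputs[i] > fA:
            fA = outputs[i]
            selected = columns[i]
            x[i] = [1] * n_items          # v_i joins A, hence every candidate set
            history.append([row[:] for row in x])
        else:
            break
    return selected, fA, history


weights = [1, 4, 2, 3]


def f(bits):
    """Toy monotone submodular function: sqrt of a non-negative weighted count."""
    return math.sqrt(sum(w * b for w, b in zip(weights, bits)))


selected, fA, history = greedy_via_candidate_matrix(f, 4, K=2)
print(history[0])  # [[1, 0, 0, 0], [1, 1, 1, 1], [0, 0, 1, 0], [0, 0, 0, 1]]
print(selected)    # [0, 1, 0, 1], i.e. A = {v_1, v_3}
```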

### <span style="color:green">Implementation of the procedure described above</span>

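    import torch

    @torch.no_grad()  # selecting the set needs no gradients
    def greedy_cardinality_constrained_submodular_max(f, k, n, device):
        """
        Returns solution of cardinality constrained submodular maximization
            Parameters:
                f (PyTorch model) : DSF (with float64 parameters, as inputs are cast to double)
                k (int)           : Cardinality we want to solve for
                n (int)           : Side length; each input is an n x n image
                device (str)      : Device ('cpu' or 'cuda')
            Returns :
                selected (tensor) : Bit-vector of length n*n marking the chosen elements
        """
        card_V = n*n
        x = torch.eye(card_V)  # column i is the bit-vector of A ∪ {v_i}; initially A = {}
        selected = torch.zeros(card_V)  # returned as-is if no element increases f
        fA = f(torch.zeros(n, n).view(1, 1, n, n).double().to(device)).item()  # f({}); reshaping for PyTorch
        for iteration in range(1, k+1):
            inputs = x.t().reshape(card_V, 1, n, n)  # one candidate set per batch element
            outputs = f(inputs.double().to(device))  # outputs[i] = f(A ∪ {v_i})
            i = outputs.argmax(dim=0).item()
            if outputs[i].item() > fA:
                fA = outputs[i].item()
                selected = x[:, i].clone()  # solution for cardinality `iteration` (clone: x changes below)
                x[i, :] = 1  # v_i is now in A, hence in every candidate set
            else:
                break  # no candidate increases f(A)
        return selected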

**<span style="color:red">NOTE: The recently released [submodlib](https://arxiv.org/pdf/2202.10680.pdf) package may provide a more efficient solver.</span>**

# <span style="color:green">SECTION (3/3) Relevant portions of our original codes with comments</span>

In [None]:
"""
_____Greedy cardinality constrained submodular maximization
"""

import torch

@torch.no_grad() # greedy selection needs no gradients; this saves memory during training
def c_sb_mx(f, Klist, sq_n_sb_px, device):
    '''
    Returns solution of cardinality constrained submodular maximization
            Parameters:
                    f (PyTorch model): DSF (float64 parameters, since inputs are cast to double)
                    Klist (list)     : List of cardinalities
                    sq_n_sb_px (int) : square root of no. of sub-pixels (i.e. resolution of the sub-sampled image)
                    device (str)     : device ('cpu' or 'cuda') on which to run
            Returns:
                    AList (dict): Dictionary with keys as cardinalities and values as solutions (bit-vectors)
    '''
    k = int(max(Klist))#we only need to solve for max cardinality
    card_V = sq_n_sb_px*sq_n_sb_px#cardinality of V
    x = torch.eye(card_V)#card_V number of candidate A's each arranged in columns
    fA = f(torch.zeros(sq_n_sb_px, sq_n_sb_px).view(1, 1, sq_n_sb_px, sq_n_sb_px).double().to(device)).item()#f(A), initially A is {}
    AList = {}#dict with key k, value A*_k where A*_k is the greedy solution of cardinality k
    selected = None#latest greedy solution; stays None if no element is ever chosen

    for it in range(1, k+1):#iteration `it` solves for cardinality `it`
        inputs = x.t().reshape(card_V, 1, sq_n_sb_px, sq_n_sb_px)#'x' reshaped as PyTorch input
        outputs = f(inputs.double().to(device))
        i = outputs.argmax(dim=0).item()
        if outputs[i].item()>fA:
            fA = outputs[i].item()
            selected = x[:, i].clone() #solution (cloned, as x is modified below)
            x[i, :] = 1
            if it in Klist: #Recall that in iteration `it`, we are solving for cardinality `it`.
                AList[it] = selected.clone()
        else:
            break
    if selected is None:
        # No element was chosen: return the empty set for every cardinality,
        # so that callers always receive a dictionary.
        print("EmptySet{}".format(fA))
        return {kk: torch.zeros(card_V) for kk in Klist}
    for it in Klist:
        if it not in AList: #e.g. we want a solution for cardinality j but there is no further gain after i<j elements.
            AList[it] = selected.clone()
    return AList

In [None]:
"""
______TRAINING DSF_____
"""

sp_w, sp_h = 28, 28 #super-pixel width & height
sq_n_sb_px = 28 # square root of resolution of sub-sampled image
ht = pre_process.final_ht_proc(sp_w, sp_h, thresholds, I_ALL) # hard-thresholded maps
sub_h = pre_process.final_subI_proc(sp_w, sp_h, I_ALL) # sub-sampled hard-thresholded maps

clipper = clamp_zero()

for epoch in range(epochs):
    # loss_1: loss with hard thresholded maps sub-sampled
    # loss_2: loss with original attribution maps
    loss_1 = None; loss_2 = None 
    """
    Computing loss_1
    """
    # Get solutions to the submodular maximization problem for list of cardinalities given by ht.keys()
    # Adic is a dictionary with key as the cardinality & corresponding value as the solution for that cardinality
    Adic = submod.c_sb_mx(f, list(ht.keys()), sq_n_sb_px, device)
    
    # Feed all these solutions to the DSF neural network, listed in the order of ht's keys
    # so that AList_f[xk] below is the solution for the xk-th cardinality of ht
    ASList = [Adic[k] for k in ht]
    AList_f = f(torch.stack(ASList).double().view(len(ASList), 1, sq_n_sb_px, sq_n_sb_px).to(device))
    tensor_ht = {}
    
    for xk, k in enumerate(ht): #Here k is the cardinality
        tensor_ht[k] = [torch.Tensor(h) for h in ht[k]] # hard-thresholded sub-sampled maps (of all methods) having cardinality k
        # Feed all the hard-thresholded maps having cardinality k to the DSF neural network
        all_S_f = f(torch.stack(tensor_ht[k]).double().view(len(tensor_ht[k]), 1, sq_n_sb_px, sq_n_sb_px).to(device))
        for xs, _ in enumerate(tensor_ht[k]):
            to_add = AList_f[xk]-all_S_f[xs]+delta # computes \delta + f_w(A^*) - f_w(\mathcal{H}_i^k) in equation (2)
            if to_add>0:
                if loss_1 is None:
                    loss_1 = to_add
                else:
                    loss_1 = loss_1+to_add
    """
    Computing loss_2
    """
    ones_f = f(torch.ones(sq_n_sb_px*sq_n_sb_px).double().view(1, 1, sq_n_sb_px, sq_n_sb_px).to(device))#f_w(\mathcal{H}^*)
    tensor_sub_h = [torch.Tensor(s_h) for s_h in sub_h] #list of sub-sampled heatmaps
    
    # Feed all sub-sampled heatmaps to the DSF neural network
    sub_h_f = f(torch.stack(tensor_sub_h).double().view(len(tensor_sub_h), 1, sq_n_sb_px, sq_n_sb_px).to(device))
    for xs_h, _ in enumerate(tensor_sub_h):
        to_also_add = ones_f-sub_h_f[xs_h] #computes f_w(\mathcal{H}^*)-f_w(\mathcal{H}_i)
        if to_also_add>0:
            if loss_2 is None:
                loss_2 = to_also_add
            else:
                loss_2 = loss_2+to_also_add
    loss = None
    if loss_1 is not None:
        loss = ld1*loss_1
    if loss_2 is not None:
        if loss is not None:
            loss = loss+ld2*loss_2
        else:
            loss = ld2*loss_2
    if loss is None:
        break
    loss_plt.append(loss.item())
    f.zero_grad()
    loss.backward()
    optimizer.step()
    f.apply(clipper)
