
In [1]:
import numpy as np
import torch
import torch.nn as nn
import random
import torch.optim as optim

In [2]:
### Toy single layer
class toy_single(nn.Module):
    def __init__(self, operation, constant):
        super(toy_single, self).__init__()
        self.operation = operation
        self.constant = constant

    def forward(self, input):
        if self.operation == 'add':
            self.output = input + self.constant
        elif self.operation == 'mul':  # pointwise multiplication
            self.output = input * self.constant
        else:
            raise ValueError(f"unsupported operation: {self.operation!r}")
        return self.output

### hook functions

In [3]:
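def dynamic_hook_module(module, grad_input, grad_output):
    # Forward to whatever hook is currently bound to `global_backward_hook`,
    # so the behaviour can be switched after the hook has been registered.
    global global_backward_hook
    return global_backward_hook(module, grad_input, grad_output)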

In [4]:
def printgrad_module(module, grad_input, grad_output):
    print('-------print grad enabled--------')
    print(grad_input)

In [5]:
def dummygrad_module(module, grad_input, grad_output):
    pass

In [6]:
def modifygradnorm_module(module, grad_input, grad_output):
    # Rescale the per-sample gradient w.r.t. the module input to L2 norm `target_norm`.
    global target_norm
    grad_input_shape = grad_input[0].size()
    bs = grad_input[0].size()[0]
    grad_input_ = grad_input[0].view(bs, -1)  # flatten to (batch, features)
    grad_input_norm = torch.norm(grad_input_, p=2, dim=-1)  # per-sample L2 norm
    clip_coef = target_norm / (grad_input_norm + 1e-6)  # the "factor"; 1e-6 avoids division by zero
    clip_coef = clip_coef.unsqueeze(-1)
    grad_input_ = clip_coef * grad_input_
    print('-------modify gradnorm enabled--------')
    print(grad_input)
    print('grad_input norm(before):', grad_input_norm)
    grad_input = (grad_input_.view(grad_input_shape), grad_input[1])
    print('grad_input norm(after):', torch.norm(grad_input_, p=2, dim=-1))
    return tuple(grad_input)

### Example (dummy hook)

Construct a toy **generator** that receives a (batch of) 3-dim latent codes $z=[z_1,z_2,z_3]$ and outputs:   
$\qquad \mathcal{G}(z)=[z_1,z_2,z_3] \odot [w_1,w_2,w_3]+[b_1,b_2,b_3]$  
&nbsp;   

Construct a toy **discriminator** that receives a (batch of) "generated samples" $\mathcal{G}(z)$ and outputs (the summation here is taken over the feature dimension):   
$\qquad \mathcal{D}(\mathcal{G}(z))=\sum \left(\left([z_1,z_2,z_3] \odot [w_1,w_2,w_3]+[b_1,b_2,b_3]\right)\cdot d^{param 1}+[d_1^{param 2},d_2^{param 2},d_3^{param 2}]\right)\cdot d^{param 3}$    
&nbsp;   

The loss, averaged over the batch, is then ($N$ is the number of samples, and the summation here is taken over the samples within a batch):   
$\qquad \mathcal{L} =\frac{1}{N}\sum_{i=1}^N \mathcal{D}(\mathcal{G}(z^i))$    
&nbsp;  

Computed by hand, the gradients should be:   
$\qquad \nabla_\mathbf{w}\mathcal{L}=\frac{d^{param 1}\cdot d^{param 3}}{N} \sum_{i=1}^N [z^i_1,z^i_2,z^i_3]$   
$\qquad \nabla_\mathbf{b}\mathcal{L}=d^{param 1}\cdot d^{param 3}\cdot [1,1,1]$
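
As a quick arithmetic check of the two closed-form expressions above, the following dependency-free sketch evaluates them for the toy values used in this notebook ($N=4$, $d^{param 1}=2$, $d^{param 3}=-1$). The helper name `hand_gradients` is hypothetical and is not part of the notebook; note that $d^{param 2}$ does not appear because it drops out of both gradients.

```python
# Values copied from the toy setup in this notebook.
z = [[1, 2, 1.5], [2, 0, 2.5], [-1, -1, -1.5], [0.5, -0.5, 1]]
d_param_1, d_param_3 = 2.0, -1.0


def hand_gradients(z, d1, d3):
    """Evaluate the closed-form gradients of the batch-averaged loss."""
    n = len(z)
    dim = len(z[0])
    # grad_w = d1 * d3 / N * sum_i z^i  (feature-wise sum over the batch)
    grad_w = [d1 * d3 / n * sum(row[k] for row in z) for k in range(dim)]
    # grad_b = d1 * d3 * [1, 1, 1]  (the 1/N cancels against the sum over N samples)
    grad_b = [d1 * d3 for _ in range(dim)]
    return grad_w, grad_b


grad_w, grad_b = hand_gradients(z, d_param_1, d_param_3)
print(grad_w)  # [-1.25, -0.25, -1.75]
print(grad_b)  # [-2.0, -2.0, -2.0]
```

These are the values that `g_weights.grad` and `g_bias.grad` are expected to take when the hook leaves the gradients untouched.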

In [7]:
## construct a toy 'Generator'
z = torch.tensor([[1,2,1.5],[2,0,2.5],[-1,-1,-1.5], [0.5,-0.5,1]])  # dim0(sample idx), dim1(feature idx)
g_weights = torch.tensor([1.,-1.,1.])
g_bias = torch.tensor([0,1.,1.])
g_weights.requires_grad=True
g_bias.requires_grad=True
G_layer1 = toy_single('mul', g_weights)
G_layer2 = toy_single('add', g_bias)
G_layer1.register_backward_hook(dummygrad_module)
G_layer2.register_backward_hook(dummygrad_module)
G_out = G_layer2(G_layer1(z))  # G_out is the toy 'generated samples'
optimizerG = optim.SGD([g_weights,g_bias], lr=0.5)

## construct a toy 'Discriminator'
d_param_1 = torch.tensor(2.)
d_param_2 = torch.tensor([-2.,2.,1.])
d_param_3 = torch.tensor(-1.)
d_param_1.requires_grad = False
d_param_2.requires_grad = False
d_param_3.requires_grad = False
D_layer1 = toy_single('mul', d_param_1)
D_layer2 = toy_single('add', d_param_2)
D_layer3 = toy_single('mul', d_param_3)
D_layer2.register_backward_hook(dummygrad_module)
D_layer3.register_backward_hook(dummygrad_module)
D_layer1.register_backward_hook(dynamic_hook_module)
D_out = D_layer3(D_layer2(D_layer1(G_out))).sum(dim=1) # D_out mimics the (negated) per_sample loss for generator
global_backward_hook = printgrad_module



In [8]:
### check gradients on leaf nodes
D_out.mean().backward()
print('============grad============')
print(g_weights.grad, g_bias.grad) 
print('============value(before)===========')
print(g_weights,g_bias)
optimizerG.step()
print('============value(after)===========')
print(g_weights,g_bias)

-------print grad enabled--------
(tensor([[-0.5000, -0.5000, -0.5000],
        [-0.5000, -0.5000, -0.5000],
        [-0.5000, -0.5000, -0.5000],
        [-0.5000, -0.5000, -0.5000]]), None)
============grad============
tensor([-1.2500, -0.2500, -1.7500]) tensor([-2., -2., -2.])
============value(before)===========
tensor([ 1., -1.,  1.], requires_grad=True) tensor([0., 1., 1.], requires_grad=True)
============value(after)===========
tensor([ 1.6250, -0.8750,  1.8750], requires_grad=True) tensor([1., 2., 2.], requires_grad=True)


### Example (modify grad_norm)

Recall the formulation of the toy **generator**, the toy **discriminator**, and the loss:   
$\qquad \mathcal{G}(z)=[z_1,z_2,z_3] \odot [w_1,w_2,w_3]+[b_1,b_2,b_3]$  
$\qquad \mathcal{D}(\mathcal{G}(z))=\sum \left(\left([z_1,z_2,z_3] \odot [w_1,w_2,w_3]+[b_1,b_2,b_3]\right)\cdot d^{param 1}+[d_1^{param 2},d_2^{param 2},d_3^{param 2}]\right)\cdot d^{param 3}$    
$\qquad \mathcal{L} =\frac{1}{N}\sum_{i=1}^N \mathcal{D}(\mathcal{G}(z^i))$    
&nbsp;  

The gradient of the loss w.r.t. each generated sample, which is the quantity we modify, is:  
$\qquad \nabla_{\mathcal{G}(z)}\mathcal{L}=\frac{d^{param 1}\cdot d^{param 3}}{N} \cdot [1,1,1]$   
After the modification, this gradient should be ("factor" is the multiplication factor that changes the gradient norm):  
$\qquad \nabla_{\mathcal{G}(z)}\widehat{\mathcal{L}}=\text{factor}\cdot\frac{d^{param 1}\cdot d^{param 3}}{N} \cdot [1,1,1]$  
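
To make the "factor" concrete, here is a small arithmetic sketch for the toy values used in this notebook. It assumes the factor is `target_norm / (norm + 1e-6)`, which is what `modifygradnorm_module` computes, and a target norm of 1; the variable names are illustrative only.

```python
import math

n = 4                                  # batch size N
d_param_1, d_param_3 = 2.0, -1.0
target_norm = 1.0                      # assumed target for this example

# Per-sample gradient of the loss w.r.t. G(z): (d1 * d3 / N) * [1, 1, 1]
grad_sample = [d_param_1 * d_param_3 / n] * 3

# L2 norm of that gradient and the resulting rescaling factor
norm_before = math.sqrt(sum(g * g for g in grad_sample))
factor = target_norm / (norm_before + 1e-6)

# Rescaled gradient and its norm
grad_modified = [factor * g for g in grad_sample]
norm_after = math.sqrt(sum(g * g for g in grad_modified))

print(round(norm_before, 4))  # 0.866
print(round(factor, 4))       # 1.1547
print(round(norm_after, 4))   # 1.0
```

Because every sample in this toy batch has the same gradient, all four samples share the same factor; in general the factor differs per sample.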

We also need the following Jacobians to compute the parameter gradients by hand:  
$\qquad J_\mathbf{w}(\mathcal{G}(z))=\begin{bmatrix} z_1 & 0& 0 \\ 
0 & z_2 & 0 \\
0 & 0 & z_3 \end{bmatrix}$  
$\qquad J_\mathbf{b}(\mathcal{G}(z))=\begin{bmatrix} 1 & 0& 0 \\ 
0 & 1 & 0 \\
0 & 0 & 1 \end{bmatrix}$  

Thus, summing the per-sample contributions over the batch (with $\text{factor}^i$ the factor of the $i$-th sample), we have:  
$\qquad \nabla_\mathbf{w}\widehat{\mathcal{L}}=\sum_{i=1}^N \nabla_{\mathcal{G}(z^i)}\widehat{\mathcal{L}} \cdot J_\mathbf{w}(\mathcal{G}(z^i)) = \frac{d^{param 1}\cdot d^{param 3}}{N} \sum_{i=1}^N \text{factor}^i \cdot[z^i_1,z^i_2,z^i_3]$   
$\qquad \nabla_\mathbf{b}\widehat{\mathcal{L}}=\sum_{i=1}^N \nabla_{\mathcal{G}(z^i)}\widehat{\mathcal{L}} \cdot J_\mathbf{b}(\mathcal{G}(z^i))=\frac{d^{param 1}\cdot d^{param 3}}{N}\cdot [1,1,1] \cdot  \sum_{i=1}^N\text{factor}^i $
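
The parameter gradients can also be cross-checked with autograd. The sketch below does not use the notebook's module backward hooks; as a substitute, it attaches `Tensor.register_hook` to the generated samples and rescales each sample's gradient to `target_norm`. The function name `rescale_per_sample` is hypothetical. The closed-form values it compares against weight the sum of the factors by $\frac{1}{N}$, consistent with the batch-averaged loss.

```python
import torch

z = torch.tensor([[1, 2, 1.5], [2, 0, 2.5], [-1, -1, -1.5], [0.5, -0.5, 1]])
g_weights = torch.tensor([1., -1., 1.], requires_grad=True)
g_bias = torch.tensor([0., 1., 1.], requires_grad=True)
d_param_1, d_param_2, d_param_3 = 2.0, torch.tensor([-2., 2., 1.]), -1.0
target_norm = 1.0


def rescale_per_sample(grad):
    """Rescale each row of the incoming gradient to L2 norm `target_norm`."""
    norms = grad.flatten(start_dim=1).norm(p=2, dim=-1)
    factor = target_norm / (norms + 1e-6)
    return grad * factor.view(-1, *([1] * (grad.dim() - 1)))


G_out = z * g_weights + g_bias
G_out.register_hook(rescale_per_sample)  # replaces dL/dG_out during backward
D_out = ((G_out * d_param_1 + d_param_2) * d_param_3).sum(dim=1)
D_out.mean().backward()

# Closed-form prediction. In this toy batch every sample has the same
# gradient norm |d1 * d3| / N * sqrt(3), hence the same factor.
N = z.size(0)
norm = abs(d_param_1 * d_param_3) / N * 3 ** 0.5
factor = target_norm / (norm + 1e-6)
factor_sum = N * factor
expected_w = d_param_1 * d_param_3 / N * factor * z.sum(dim=0)
expected_b = d_param_1 * d_param_3 / N * factor_sum * torch.ones(3)

print(g_weights.grad, expected_w)
print(g_bias.grad, expected_b)
```

A tensor hook is used here only because it keeps the example short; it acts on the same quantity, the gradient of the loss w.r.t. the generated samples.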

In [9]:
## construct a toy 'Generator'
z = torch.tensor([[1,2,1.5],[2,0,2.5],[-1,-1,-1.5], [0.5,-0.5,1]])  # dim0(sample idx), dim1(feature idx)
g_weights = torch.tensor([1.,-1.,1.])
g_bias = torch.tensor([0,1.,1.])
g_weights.requires_grad=True
g_bias.requires_grad=True
G_layer1 = toy_single('mul', g_weights)
G_layer2 = toy_single('add', g_bias)
G_layer1.register_backward_hook(dummygrad_module)
G_layer2.register_backward_hook(dummygrad_module)
G_out = G_layer2(G_layer1(z))  # G_out is the toy 'generated samples'
optimizerG = optim.SGD([g_weights,g_bias], lr=0.5)

## construct a toy 'Discriminator'
d_param_1 = torch.tensor(2.)
d_param_2 = torch.tensor([-2.,2.,1.])
d_param_3 = torch.tensor(-1.)
d_param_1.requires_grad = False
d_param_2.requires_grad = False
d_param_3.requires_grad = False
D_layer1 = toy_single('mul', d_param_1)
D_layer2 = toy_single('add', d_param_2)
D_layer3 = toy_single('mul', d_param_3)
D_layer2.register_backward_hook(dummygrad_module)
D_layer3.register_backward_hook(dummygrad_module)
D_layer1.register_backward_hook(dynamic_hook_module)
D_out = D_layer3(D_layer2(D_layer1(G_out))).sum(dim=1) # D_out mimics the (negated) per_sample loss for generator
global_backward_hook = printgrad_module

## modify the grad_norm: rebind the dynamically dispatched hook
## (this overrides the printgrad_module assignment above)
target_norm = 1.
global_backward_hook = modifygradnorm_module

In [10]:
### check gradients on leaf nodes
D_out.mean().backward()
print('============grad============')
print(g_weights.grad, g_bias.grad)
print('============value(before)===========')
print(g_weights,g_bias)
optimizerG.step()
print('============value(after)===========')
print(g_weights,g_bias)

-------modify gradnorm enabled--------
(tensor([[-0.5000, -0.5000, -0.5000],
        [-0.5000, -0.5000, -0.5000],
        [-0.5000, -0.5000, -0.5000],
        [-0.5000, -0.5000, -0.5000]]), None)
grad_input norm(before): tensor([0.8660, 0.8660, 0.8660, 0.8660])
grad_input norm(after): tensor([1.0000, 1.0000, 1.0000, 1.0000])
============grad============
tensor([-1.4434, -0.2887, -2.0207]) tensor([-2.3094, -2.3094, -2.3094])
============value(before)===========
tensor([ 1., -1.,  1.], requires_grad=True) tensor([0., 1., 1.], requires_grad=True)
============value(after)===========
tensor([ 1.7217, -0.8557,  2.0104], requires_grad=True) tensor([1.1547, 2.1547, 2.1547], requires_grad=True)
