<a href="https://colab.research.google.com/github/DingfanChen/GS-WGAN/blob/main/source/sanity_check.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import torch
import torch.nn as nn
import random
import torch.optim as optim

In [2]:
### Toy single layer
class toy_single(nn.Module):
    """Toy layer applying one elementwise op between its input and a fixed constant.

    Args:
        operation: ``'add'`` computes ``input + constant``; ``'mul'`` computes
            the pointwise product ``input * constant``.
        constant: tensor or number combined with the input (PyTorch
            broadcasting rules apply). It is stored as a plain attribute, not
            registered as an ``nn.Parameter``, so it only receives a gradient
            if the caller set ``requires_grad`` on it.
    """

    def __init__(self, operation, constant):
        super(toy_single, self).__init__()
        self.operation = operation
        self.constant = constant

    def forward(self, input):
        """Return ``input <operation> constant`` and cache it on ``self.output``.

        Raises:
            ValueError: if ``self.operation`` is neither ``'add'`` nor ``'mul'``.
                Raising here avoids silently returning the ``self.output`` left
                over from a previous call.
        """
        if self.operation == 'add':
            self.output = input + self.constant
        elif self.operation == 'mul':  # pointwise multiplication
            self.output = input * self.constant
        else:
            raise ValueError(
                f"toy_single: unsupported operation {self.operation!r}; "
                "expected 'add' or 'mul'"
            )
        return self.output

### Backward-hook functions

In [3]:
def dynamic_hook_module(module, grad_input, grad_output):
    """Backward hook that delegates to the module-level ``global_backward_hook``.

    The delegate is looked up every time this hook fires, not when it is
    registered. Rebinding ``global_backward_hook`` therefore changes the
    behaviour of every module this hook is attached to, without
    re-registering anything. The delegate's return value (``None`` or a
    replacement ``grad_input`` tuple) is returned unchanged.
    """
    delegate = global_backward_hook  # resolved at call time
    return delegate(module, grad_input, grad_output)

In [4]:
def printgrad_module(module, grad_input, grad_output):
    """Backward hook that prints a banner followed by ``grad_input``.

    It returns ``None``, so the gradients are left exactly as autograd
    computed them.
    """
    print('-------print grad enabled--------', grad_input, sep='\n')
    return None

In [5]:
def dummygrad_module(module, grad_input, grad_output):
    """No-op backward hook: ignores its arguments and returns ``None``."""
    return None

In [6]:
def modifygradnorm_module(module, grad_input, grad_output):
    """Backward hook that rescales each sample's input-gradient to ``target_norm``.

    Reads the module-level ``target_norm``. The first entry of ``grad_input``
    is flattened per sample, treating dim 0 as the batch axis. Each row is
    multiplied by ``target_norm / (||row||_2 + 1e-6)`` and then reshaped back.
    This *sets* the per-sample norm rather than clipping it: rows with a
    smaller norm are scaled up. All other entries of ``grad_input`` are passed
    through unchanged, whatever their number.

    Args:
        module: the hooked module (unused).
        grad_input: gradient tuple as supplied to a ``register_backward_hook``
            hook; only entry 0 is rescaled.
        grad_output: unused.

    Returns:
        The replacement ``grad_input`` tuple. Returns ``None`` (keep gradients
        unchanged) when entry 0 is ``None`` or empty, since there is nothing to
        rescale.
    """
    grad = grad_input[0]
    if grad is None or grad.numel() == 0:
        return None
    grad_shape = grad.size()
    bs = grad_shape[0]
    # reshape rather than view: the gradient handed to a hook is not
    # guaranteed to be contiguous.
    flat = grad.reshape(bs, -1)
    per_sample_norm = torch.norm(flat, p=2, dim=-1)
    # 1e-6 guards against division by zero for all-zero rows (they stay zero).
    scale = (target_norm / (per_sample_norm + 1e-6)).unsqueeze(-1)
    flat = scale * flat
    print('-------modify gradnorm enabled--------')
    print(grad_input)
    print('grad_input norm:', torch.norm(flat, p=2, dim=-1))
    return (flat.reshape(grad_shape),) + tuple(grad_input[1:])

### Example (baseline: hooks only print or pass through the gradient, nothing is modified)

In [7]:
## construct a toy 'Generator'
# G_out = z * weights_1 + bias_1 (weights_1 / bias_1 broadcast over the rows of z).
z = torch.tensor([[1,1,1.5],[2,2,2.5],[-1,-1,-1.5]])  # dim0(sample idx), dim1(feature idx)
weights_1 = torch.tensor([1.,-1.,1.])
bias_1 = torch.tensor([0,1.,1.])
# Make z and the G "parameters" autograd leaves so backward() fills their .grad.
z.requires_grad = True
weights_1.requires_grad=True
bias_1.requires_grad=True
G_layer1 = toy_single('mul',weights_1)
G_layer2 = toy_single('add',bias_1)
# dummygrad_module does nothing and returns None, so these hooks leave gradients unchanged.
G_layer1.register_backward_hook(dummygrad_module)
G_layer2.register_backward_hook(dummygrad_module)
G_out = G_layer2(G_layer1(z))
# Only weights_1 and bias_1 are stepped; z is not handed to the optimizer.
optimizerG = optim.SGD([weights_1,bias_1], lr=0.5)

## construct a toy 'Discriminator'
# D_out = d_param_3 * (d_param_2 * (G_out + d_param_1)). The d_params are created
# without requires_grad (default False), so they receive no gradients.
d_param_1 = torch.tensor(2.)
d_param_2 = torch.tensor([-2.,2.,1.])
d_param_3 = torch.tensor(-1.)
D_layer1 = toy_single('add', d_param_1)
D_layer2 = toy_single('mul', d_param_2)
D_layer3 = toy_single('mul', d_param_3)
D_layer2.register_backward_hook(dummygrad_module)
D_layer3.register_backward_hook(dummygrad_module)
# D_layer1 consumes G_out, so its hook sees the gradient flowing back into G.
# dynamic_hook_module looks up `global_backward_hook` when backward runs, which is
# why that name may be bound after the forward pass (last line of this cell).
# NOTE(review): register_backward_hook is deprecated in recent PyTorch releases;
# register_full_backward_hook reports grad_input differently -- confirm before migrating.
D_layer1.register_backward_hook(dynamic_hook_module)
D_out = D_layer3(D_layer2(D_layer1(G_out)))
# printgrad_module prints grad_input and returns None (gradients unchanged).
global_backward_hook = printgrad_module



In [8]:
### check d_param status
# All three should print False: the D constants were created without
# requires_grad, so backward() will not accumulate gradients into them.
print(d_param_1.requires_grad)
print(d_param_2.requires_grad)
print(d_param_3.requires_grad)

False
False
False


In [9]:
### check gradients on leaf nodes
# Backprop the mean D score; the hook on D_layer1 fires during this call.
D_out.mean().backward()
print('============grad============')
print(weights_1.grad, bias_1.grad)
print('============value(before)===========')
print(weights_1,bias_1)
# Plain SGD with lr=0.5: param <- param - 0.5 * param.grad.
optimizerG.step()
print('============value(after)===========')
print(weights_1,bias_1)

-------print grad enabled--------
(tensor([[ 0.2222, -0.2222, -0.1111],
        [ 0.2222, -0.2222, -0.1111],
        [ 0.2222, -0.2222, -0.1111]]), None)
tensor([ 0.4444, -0.4444, -0.2778]) tensor([ 0.6667, -0.6667, -0.3333])
tensor([ 1., -1.,  1.], requires_grad=True) tensor([0., 1., 1.], requires_grad=True)
tensor([ 0.7778, -0.7778,  1.1389], requires_grad=True) tensor([-0.3333,  1.3333,  1.1667], requires_grad=True)


### Example (rescale each sample's gradient norm to `target_norm` before it reaches G)

In [10]:
## construct a toy 'Generator'
# Same toy G as the previous example, rebuilt from scratch so .grad starts empty:
# G_out = z * weights_1 + bias_1.
z = torch.tensor([[1,1,1.5],[2,2,2.5],[-1,-1,-1.5]])
weights_1 = torch.tensor([1.,-1.,1.])
bias_1 = torch.tensor([0,1.,1.])
z.requires_grad = True
weights_1.requires_grad=True
bias_1.requires_grad=True
G_layer1 = toy_single('mul',weights_1)
G_layer2 = toy_single('add',bias_1)
# No-op hooks (dummygrad_module returns None, gradients unchanged).
G_layer1.register_backward_hook(dummygrad_module)
G_layer2.register_backward_hook(dummygrad_module)
G_out = G_layer2(G_layer1(z))  # G_out is the toy 'generated samples'
optimizerG = optim.SGD([weights_1,bias_1], lr=0.5)

## construct a toy 'Discriminator'
# D_out = d_param_3 * (d_param_2 * (G_out + d_param_1)); the d_params do not require grad.
d_param_1 = torch.tensor(2.)
d_param_2 = torch.tensor([-2.,2.,1.])
d_param_3 = torch.tensor(-1.)
D_layer1 = toy_single('add', d_param_1)
D_layer2 = toy_single('mul', d_param_2)
D_layer3 = toy_single('mul', d_param_3)
D_layer2.register_backward_hook(dummygrad_module)
D_layer3.register_backward_hook(dummygrad_module)
# dynamic_hook_module dispatches to whatever `global_backward_hook` is bound to at backward time.
D_layer1.register_backward_hook(dynamic_hook_module)
D_out = D_layer3(D_layer2(D_layer1(G_out))) # D_out mimics the score

## modify the grad_norm
# modifygradnorm_module reads the global target_norm and rescales each sample's
# gradient at D_layer1's input to that L2 norm before it flows back into G.
target_norm = 1.
global_backward_hook = modifygradnorm_module

In [11]:
### check d_param status
# Expected False for all three, as in the previous example: D stays frozen.
print(d_param_1.requires_grad)
print(d_param_2.requires_grad)
print(d_param_3.requires_grad)

False
False
False


In [12]:
### check gradients on leaf nodes
# Backprop through D; the hook on D_layer1 now rescales each row of the incoming
# gradient to norm target_norm. In the recorded run each row's norm goes from 1/3
# to 1, so G's gradients come out ~3x larger than in the unmodified example.
D_out.mean().backward()
print('============grad============')
print(weights_1.grad, bias_1.grad)
print('============value(before)===========')
print(weights_1,bias_1)
# Plain SGD with lr=0.5: param <- param - 0.5 * param.grad.
optimizerG.step()
print('============value(after)===========')
print(weights_1,bias_1)

-------modify gradnorm enabled--------
(tensor([[ 0.2222, -0.2222, -0.1111],
        [ 0.2222, -0.2222, -0.1111],
        [ 0.2222, -0.2222, -0.1111]]), None)
grad_input norm: tensor([1.0000, 1.0000, 1.0000])
tensor([ 1.3333, -1.3333, -0.8333]) tensor([ 2.0000, -2.0000, -1.0000])
tensor([ 1., -1.,  1.], requires_grad=True) tensor([0., 1., 1.], requires_grad=True)
tensor([ 0.3333, -0.3333,  1.4167], requires_grad=True) tensor([-1.0000,  2.0000,  1.5000], requires_grad=True)
