In [1]:
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F

import sys

sys.path.append("../stochman")

from stochman import nnj

def compute_output_edge(input_edge, kernel_size=1, padding=0, stride=1, dilation=1):
    """Return the length of one spatial edge of a convolution's output.

    Uses the per-dimension output-size formula documented for
    ``torch.nn.Conv2d``::

        floor((input_edge + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1

    Args:
        input_edge: number of pixels along the input edge.
        kernel_size: kernel extent along this edge.
        padding: zero-padding added on each side of this edge.
        stride: convolution stride along this edge.
        dilation: spacing between kernel elements along this edge.

    Returns:
        The output edge length as an ``int``.

    Raises:
        ValueError: if ``stride`` is not positive, or if the configuration
            produces an empty (non-positive) output edge.
    """
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")
    # Floor division keeps the computation in exact integer arithmetic and
    # matches the floor in PyTorch's formula. int() of a float quotient
    # truncates toward zero instead, which differs for negative numerators.
    numerator = input_edge + 2 * padding - dilation * (kernel_size - 1) - 1
    output_edge = int(numerator // stride) + 1
    if output_edge < 1:
        raise ValueError(
            f"convolution produces an empty output edge ({output_edge}) for "
            f"input_edge={input_edge}, kernel_size={kernel_size}, padding={padding}, "
            f"stride={stride}, dilation={dilation}"
        )
    return output_edge

In [2]:
# Seed torch's global RNG so every random draw in this notebook (e.g. the
# `torch.randint` test matrix) is identical after Restart & Run All.
torch.manual_seed(0)

batch_size = 2
IN_c, OUT_c = 1, 1  # number of channels (input, output)
IN_h, IN_w = 28, 28  # number of pixels along each input edge
kernel_h, kernel_w = 3, 3
padding_h, padding_w = 1, 1

# Derived sizes used by the tests below. stride=1 / dilation=1 are meant to
# match the convolution built in the next cell, which does not pass them
# (presumably defaulting to 1, as in torch.nn.Conv2d).
OUT_h = compute_output_edge(IN_h, kernel_size=kernel_h, padding=padding_h, stride=1, dilation=1)
OUT_w = compute_output_edge(IN_w, kernel_size=kernel_w, padding=padding_w, stride=1, dilation=1)
# Flattened feature sizes: channels * height * width.
IN_size, OUT_size = IN_c * IN_h * IN_w, OUT_c * OUT_h * OUT_w

In [3]:
conv = nnj.Conv2d(
    IN_c,
    OUT_c,
    kernel_size=(kernel_h, kernel_w),
    padding=(padding_h, padding_w),
    # torch.nn.Conv2d documents `bias` as a bool; None only worked because it is falsy.
    bias=False,
)

# NOTE(review): the input is all zeros. A bias-free convolution's output is
# linear in the weights with coefficients taken from the input, so the
# weight-Jacobian checks further down compare all-zero tensors. A random input
# (together with a relative tolerance) would make those checks meaningful.
feat_in = torch.zeros(batch_size, IN_c, IN_h, IN_w)
feat_out = conv(feat_in)

# Right and left multiplications (J wrt input)

In [4]:
# Random integer-valued test matrix of shape (batch_size, OUT_size, OUT_size),
# sandwiched between input Jacobians below.
tmp = torch.randint(0, 10, (batch_size, OUT_size, OUT_size)).type(torch.float)

# Dense reference J^T * tmp * J: build the full input Jacobian once and reuse it
# for both sides (correct by construction but NOT memory efficient).
jac_in = conv._jacobian_wrt_input(feat_in, feat_out)
slow_Jt_tmp_J = torch.einsum('Bji,Bjk,Bkq->Biq', jac_in, tmp, jac_in)

# (J^T * tmp) * J, without materialising J
fast_Jt_tmp_J_1 = conv._jacobian_wrt_input_mult_left(
    feat_in, feat_out,
    conv._jacobian_wrt_input_T_mult_right(feat_in, feat_out, tmp),
)

# J^T * (tmp * J), without materialising J
fast_Jt_tmp_J_2 = conv._jacobian_wrt_input_T_mult_right(
    feat_in, feat_out,
    conv._jacobian_wrt_input_mult_left(feat_in, feat_out, tmp),
)

# check that the shapes agree
print(slow_Jt_tmp_J.shape, fast_Jt_tmp_J_1.shape, fast_Jt_tmp_J_2.shape)

# Check the largest *absolute* deviation. Taking the max of the signed
# difference first would let arbitrarily large negative errors pass.
assert torch.max(torch.abs(slow_Jt_tmp_J - fast_Jt_tmp_J_1)) < 1e-5
assert torch.max(torch.abs(slow_Jt_tmp_J - fast_Jt_tmp_J_2)) < 1e-5

torch.Size([2, 784, 784]) torch.Size([2, 784, 784]) torch.Size([2, 784, 784])


# Right multiplication by J^T (J wrt weight)

In [5]:
# Dense reference J_w^T * tmp, with J_w the full Jacobian wrt the conv weights
# (correct by construction but NOT memory efficient).
# NOTE(review): feat_in is all zeros (see the conv setup cell), so J_w is zero
# and this comparison is currently vacuous.
slow_Jt_tmp = torch.einsum('Bji,Bjq->Biq', conv._jacobian_wrt_weight(feat_in, feat_out), tmp)

if batch_size == 1:
    # Single-sample path: pass the un-batched matrix. The fast result is
    # compared against slow_Jt_tmp below, relying on broadcasting if it lacks
    # the leading batch dimension (TODO confirm the returned shape).
    fast_Jt_tmp = conv._jacobian_wrt_weight_T_mult_right(feat_in, feat_out, tmp[0], use_less_memory=False)
else:
    # Reference summed over the batch dimension. It is kept for inspection
    # and is not used by the per-sample check below.
    slow_Jt_tmp_sum = slow_Jt_tmp.sum(dim=0)
    fast_Jt_tmp = conv._jacobian_wrt_weight_T_mult_right(feat_in, feat_out, tmp, use_less_memory=True)

# check that the shapes agree
print(slow_Jt_tmp.shape, fast_Jt_tmp.shape)
# check that this cell's slow and fast results agree element-wise
assert torch.max(torch.abs(slow_Jt_tmp - fast_Jt_tmp)) < 1e-5

torch.Size([2, 9, 784]) torch.Size([2, 9, 784])


# Diagonal approximation wrt input

In [6]:
# Per-sample diagonal of the test matrix: shape (batch_size, OUT_size).
diag_tmp = torch.diagonal(tmp, dim1=1, dim2=2)

# SLOW METHOD (dense reference)
# upscale the tmp diagonal to a full (diagonal) matrix
tmp_simple = torch.diag_embed(diag_tmp)
# build the full input Jacobian once and compute J^T * tmp * J
jac_in = conv._jacobian_wrt_input(feat_in, feat_out)
slow_Jt_tmp_J = torch.einsum('Bji,Bjk,Bkq->Biq', jac_in, tmp_simple, jac_in)
# keep only the diagonal (discard all other elements)
slow_diag_Jt_tmp_J = torch.diagonal(slow_Jt_tmp_J, dim1=1, dim2=2)

# LESS SLOW METHOD
# compute J^T * tmp * J without materialising J
lessslow_Jt_tmp_J = conv._jacobian_wrt_input_sandwich_full_to_full(feat_in, feat_out, tmp_simple)
# keep only the diagonal (discard all other elements)
lessslow_diag_Jt_tmp_J = torch.diagonal(lessslow_Jt_tmp_J, dim1=1, dim2=2)

# FAST METHOD: diagonal in, diagonal out
fast_diag_Jt_tmp_J = conv._jacobian_wrt_input_sandwich_diag_to_diag(feat_in, feat_out, diag_tmp)

##################################################################################
print(slow_diag_Jt_tmp_J.shape, lessslow_diag_Jt_tmp_J.shape, fast_diag_Jt_tmp_J.shape)
assert torch.max(torch.abs(slow_diag_Jt_tmp_J - lessslow_diag_Jt_tmp_J)) < 1e-5
assert torch.max(torch.abs(slow_diag_Jt_tmp_J - fast_diag_Jt_tmp_J)) < 1e-5

torch.Size([2, 784]) torch.Size([2, 784]) torch.Size([2, 784])


# Diagonal approximation wrt weights


In [7]:
# Per-sample diagonal of the test matrix: shape (batch_size, OUT_size).
diag_tmp = torch.diagonal(tmp, dim1=1, dim2=2)

# SLOW METHOD (dense reference)
# upscale the tmp diagonal to a full (diagonal) matrix
tmp_simple = torch.diag_embed(diag_tmp)
# Build the full weight Jacobian once and compute J^T * tmp * J.
# NOTE(review): feat_in is all zeros (see the conv setup cell), so this
# Jacobian is zero and the checks below currently compare zeros.
jac_w = conv._jacobian_wrt_weight(feat_in, feat_out)
slow_Jt_tmp_J = torch.einsum('Bji,Bjk,Bkq->Biq', jac_w, tmp_simple, jac_w)
# keep only the diagonal (discard all other elements)
slow_diag_Jt_tmp_J = torch.diagonal(slow_Jt_tmp_J, dim1=1, dim2=2)

# LESS SLOW METHOD
# compute J^T * tmp * J without materialising J
lessslow_Jt_tmp_J = conv._jacobian_wrt_weight_sandwich_full_to_full(feat_in, feat_out, tmp_simple)
# keep only the diagonal (discard all other elements)
lessslow_diag_Jt_tmp_J = torch.diagonal(lessslow_Jt_tmp_J, dim1=1, dim2=2)

# FAST METHOD: diagonal in, diagonal out
fast_diag_Jt_tmp_J = conv._jacobian_wrt_weight_sandwich_diag_to_diag(feat_in, feat_out, diag_tmp)

##################################################################################
print(slow_diag_Jt_tmp_J.shape, lessslow_diag_Jt_tmp_J.shape, fast_diag_Jt_tmp_J.shape)
assert torch.max(torch.abs(slow_diag_Jt_tmp_J - lessslow_diag_Jt_tmp_J)) < 1e-5
assert torch.max(torch.abs(slow_diag_Jt_tmp_J - fast_diag_Jt_tmp_J)) < 1e-5

torch.Size([2, 9]) torch.Size([2, 9]) torch.Size([2, 9])
