In [1]:
import torch
import torch.nn as nn
import numpy as np

import sys
sys.path.append('../')
import layers
import patterns

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# Create innvestigate network with dilation layer

In [2]:
import innvestigate
import keras

Using TensorFlow backend.


In [3]:
inp_size = (7,7)
ks = (3,3)
ks_prod = np.product(ks)

In [4]:
inp_np = np.arange(-5,inp_size[0]*inp_size[1]-5,1).reshape(1,inp_size[0],inp_size[1],1).astype(float)
inp_2 = np.random.rand(10,inp_size[0],inp_size[1],1)*5

In [5]:
# create one layer network with a convolutional layer with dilation > 1
keras_conv = keras.models.Sequential([
        keras.layers.Conv2D(filters=2, kernel_size=ks, 
                            dilation_rate=(2,2), input_shape=(inp_size[0],inp_size[1],1),
                            activation=None, use_bias=False)])

# created weights such that there are both positive and negative results
kernel_weights = np.arange(-10,ks_prod*2-10,1).reshape(ks[0],ks[1],1,2)
kernel_weights[:,:,0,0] -= 5
keras_conv.set_weights([kernel_weights])

In [6]:
analyzer = innvestigate.create_analyzer('pattern.net', keras_conv)
analyzer.fit(np.concatenate([inp_np,inp_2]))

In [7]:
patterns_keras = analyzer._patterns  # relu patterns are in patterns_keras[0]
stats_keras = analyzer.computer._pattern_instances['relu'][0].stats_dict

# Pytorch dilation layer

In [8]:
# check if kernel with zeros in between same result as with dilation=2 --> tested, not the same!
# create pytorch input data
inp_torch = torch.FloatTensor(inp_np).permute(0,3,1,2)
inp_2_torch = torch.FloatTensor(inp_2).permute(0,3,1,2)

# normal conv layer with dilation (2,2)
conv_dil = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=ks, dilation=2, bias=False)

# use same weights as innvestigate layer
k3 = torch.FloatTensor(kernel_weights).permute(3,2,0,1)
conv_dil.weight.data = k3

In [9]:
# use this function to create a dilation kernel with zeros from the innvestigate kernel
def dilation_kernel(kernel, dilation):
    ks = kernel.shape
    dil_kernel = torch.zeros(ks[0], ks[1], ks[2]+(dilation[0]-1)*(ks[2]-1), 
                             ks[3]+(dilation[1]-1)*(ks[3]-1))

    
    locs_x = np.arange(0,dil_kernel.shape[2], dilation[0]) # right dimension?
    locs_y = np.arange(0,dil_kernel.shape[3], dilation[1])
    inds_x, inds_y = np.meshgrid(locs_x, locs_y)
    
    dil_kernel[:,:,inds_x, inds_y] = kernel.permute(0,1,3,2)
    
    return dil_kernel


# sets values in pattern to zero according to dilation pattern
def dilation_pattern(pattern, dilation):
    
    locs_x = np.arange(0,pattern.shape[2], dilation[0]) # right dimesion?
    locs_y = np.arange(0,pattern.shape[3], dilation[1])
    inds_x, inds_y = np.meshgrid(locs_x, locs_y)
    
    mask = torch.ones(pattern.shape)
    mask[:,:,inds_x, inds_y] = 0
    
    pattern[mask == 1] = 0
    

def dilation_mask(kernel_size, dilation):
    
    mask = torch.zeros(kernel_size[0], kernel_size[1], 
                       kernel_size[2]+(dilation[0]-1)*(kernel_size[2]-1),
                       kernel_size[3]+(dilation[1]-1)*(kernel_size[3]-1))
    
    locs_x = np.arange(0, mask.shape[2], dilation[0])
    locs_y = np.arange(0, mask.shape[3], dilation[1])
    inds_x, inds_y = np.meshgrid(locs_x, locs_y)
    
    mask[:,:,inds_x, inds_y] = 1
    
    return mask

In [10]:
# create pattern layer
pattern_layer_dil = layers.PatternConv2d(conv_dil)
conv_out_dil = pattern_layer_dil(inp_torch)
pattern_layer_dil.compute_statistics(inp_torch, conv_out_dil)
pattern_layer_dil.compute_statistics(inp_2_torch, pattern_layer_dil(inp_2_torch))

In [11]:
print('Output of dil layer and its pattern layer the same:', torch.equal(pattern_layer_dil(inp_2_torch), 
                                                                         conv_dil(inp_2_torch)))

Output of dil layer and its pattern layer the same: True


## Compare pytorch and innvestigate prediction

In [12]:
out_dil = pattern_layer_dil(inp_2_torch)
out_inn = keras_conv.predict(inp_2)

print(out_dil.shape, torch.FloatTensor(out_inn).permute(0,3,1,2).shape)
print(torch.equal(out_dil, torch.FloatTensor(out_inn).permute(0,3,1,2)))
if not torch.equal(out_dil, torch.FloatTensor(out_inn).permute(0,3,1,2)):
    print('Maximum difference:', 
           torch.max(out_dil - torch.FloatTensor(out_inn).permute(0,3,1,2)).detach())

torch.Size([10, 2, 3, 3]) torch.Size([10, 2, 3, 3])
False
Maximum difference: tensor(1.5259e-05)


## Compare pytorch and innvestigate statistics

In [13]:
keras_keys = ['cnt_x', 'e_y', 'e_x', 'e_xy']
torch_keys = ['cnt', 'e_y', 'e_x', 'e_xy']

k3_inds = torch.tensor([0,2,4,10,12,14,20,22,24], dtype=torch.long)

for i in range(len(keras_keys)):
    print(torch_keys[i])
    
    if torch_keys[i] in ['e_x', 'e_xy']:
        stat_inn = torch.FloatTensor(stats_keras[keras_keys[i]])
        stat_dil = pattern_layer_dil.statistics['positive'][torch_keys[i]].detach().permute(1,0)
        print('Innvestigate and standard dilation layer the same:', torch.equal(stat_inn, stat_dil))
        if not torch.equal(stat_inn, stat_dil):
            print('Maximum difference:', torch.max(stat_inn - stat_dil))
            
    elif torch_keys[i] == 'e_y':
        stat_inn = torch.FloatTensor(stats_keras[keras_keys[i]])
        stat_dil = pattern_layer_dil.statistics['positive'][torch_keys[i]].detach()[0]

        print('Innvestigate and standard dilation layer the same:', torch.equal(stat_inn,stat_dil))
        if not torch.equal(stat_inn,stat_dil):
            print('Maximum difference:', torch.max(stat_inn - stat_dil))
            
    else:
        stat_inn = torch.FloatTensor(stats_keras[keras_keys[i]])
        stat_dil = pattern_layer_dil.statistics['positive'][torch_keys[i]].detach()

        print('Innvestigate and standard dilation layer the same:',torch.equal(stat_inn, stat_dil))
        if not torch.equal(stat_inn, stat_dil):
            print('Maximum difference:', torch.max(stat_inn - stat_dil))
    print()


cnt
Innvestigate and standard dilation layer the same: True

e_y
Innvestigate and standard dilation layer the same: False
Maximum difference: tensor(0.)

e_x
Innvestigate and standard dilation layer the same: False
Maximum difference: tensor(1.1921e-07)

e_xy
Innvestigate and standard dilation layer the same: False
Maximum difference: tensor(0.)



## Compare patterns

In [14]:
pattern_layer_dil.compute_patterns()
pattern_layer_dil.set_patterns()

In [15]:
# first convert innvestigate pattern to pytorch tensor, then permute it
pattern_inn = torch.FloatTensor(patterns_keras[0]).permute(2,3,0,1)

def revert_tensor(tensor, axis=0):
    idx = [i for i in range(tensor.size(axis) - 1, -1, -1)]
    idx = torch.LongTensor(idx)
    return tensor.index_select(axis, idx)

pattern_inn = revert_tensor(revert_tensor(pattern_inn,2),3)

print('Patterns innvestigate and standard dilation layer the same:', torch.equal(pattern_inn,
                                                                                 pattern_layer_dil.patterns['A_plus']))
if not torch.equal(pattern_inn, pattern_layer_dil.patterns['A_plus']):
    print('Maximum difference:', torch.max(torch.abs(pattern_inn - pattern_layer_dil.patterns['A_plus'].detach())))

Patterns innvestigate and standard dilation layer the same: False
Maximum difference: tensor(1.4901e-08)


## Compare signal

In [16]:
signal_inn = torch.FloatTensor(analyzer.analyze(inp_np).transpose(0,3,1,2))
signal_dil = pattern_layer_dil.backward(conv_out_dil)

print('Signal innvestigate and standard dilation layer the same:', torch.equal(signal_inn, signal_dil))
# if not torch.equal(signal_inn, signal_dil):
#     print('Maximum difference:', torch.max(signal_inn - signal_dil.detach()))    

print()
print()
# print(signal_inn)
print()
print(signal_dil.detach())

Signal innvestigate and standard dilation layer the same: False



tensor([[[[0.0229, 0.0225, 0.0668, 0.0439, 0.1100, 0.0656, 0.0642],
          [0.0196, 0.0191, 0.0568, 0.0373, 0.0936, 0.0558, 0.0544],
          [0.1927, 0.1885, 0.4149, 0.2254, 0.4885, 0.2621, 0.2561],
          [0.1506, 0.1469, 0.3129, 0.1655, 0.3497, 0.1838, 0.1791],
          [0.4551, 0.4445, 0.9267, 0.4813, 1.0000, 0.5178, 0.5053],
          [0.2819, 0.2750, 0.5687, 0.2932, 0.6052, 0.3115, 0.3036],
          [0.2334, 0.2265, 0.4684, 0.2415, 0.4985, 0.2565, 0.2487]]]])


In [20]:
print('INNvestigate')
print()
print('Input for backward pass')
print(keras_conv.predict(inp_np)[0,:,:,0])
print(keras_conv.predict(inp_np)[0,:,:,1])
print(keras_conv.predict(inp_np).shape)
print('Patterns')
print(patterns_keras[0].transpose(2,3,0,1))#[:,:,0,1])
print(patterns_keras[0].shape)
print('Output of backward pass')
print(signal_inn[0,0].numpy())
print(analyzer.analyze(inp_np).shape)

INNvestigate

Input for backward pass
[[ -165.  -228.  -291.]
 [ -606.  -669.  -732.]
 [-1047. -1110. -1173.]]
[[ 429.  420.  411.]
 [ 366.  357.  348.]
 [ 303.  294.  285.]]
(1, 3, 3, 2)
Patterns
[[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[ 0.00571974  0.01117426  0.01671792]
   [ 0.04401109  0.04959474  0.05506112]
   [ 0.08239748  0.0878601   0.0933373 ]]]]
(3, 3, 1, 2)
Output of backward pass
[[ 0.06128033  0.          0.1197191   0.          0.17911293  0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.          0.        ]
 [ 0.47152737  0.          0.5313496   0.          0.58991551  0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.          0.        ]
 [ 0.88279259  0.          0.94131821  0.          1.          0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.          0.   

In [18]:
# pattern_layer_dil.backward_layer.weight

In [19]:
conv_test = nn.Conv2d(2,1, [3,3],padding=(4,4),dilation=(2,2), stride=1,bias=False) # (6,6) and (3,3)
conv_test.weight.data = pattern_layer_dil.patterns['A_plus']
out = conv_test(conv_out_dil).detach()
print(out.shape)
out /= torch.max(torch.abs(out))
print(out)
# print(pattern_layer_dil.patterns['A_plus'].shape)
# print(conv_out_dil.shape)
print()
back_res_0 = conv_out_dil[0,0,0,0]*pattern_layer_dil.patterns['A_plus'][0,0]
back_res_1 = conv_out_dil[0,1,0,0]*pattern_layer_dil.patterns['A_plus'][0,1]
sum_back_res = back_res_0 + back_res_1
# print(back_res_0)
# print()
# print(back_res_1)
print(sum_back_res.shape)
# print(back_res_0 + back_res_1)
# print()
print(revert_tensor(revert_tensor(sum_back_res / torch.max(torch.abs(sum_back_res)),0),1))

torch.Size([1, 1, 7, 7])
tensor([[[[0.0229, 0.0225, 0.0668, 0.0439, 0.1100, 0.0656, 0.0642],
          [0.0196, 0.0191, 0.0568, 0.0373, 0.0936, 0.0558, 0.0544],
          [0.1927, 0.1885, 0.4149, 0.2254, 0.4885, 0.2621, 0.2561],
          [0.1506, 0.1469, 0.3129, 0.1655, 0.3497, 0.1838, 0.1791],
          [0.4551, 0.4445, 0.9267, 0.4813, 1.0000, 0.5178, 0.5053],
          [0.2819, 0.2750, 0.5687, 0.2932, 0.6052, 0.3115, 0.3036],
          [0.2334, 0.2265, 0.4684, 0.2415, 0.4985, 0.2565, 0.2487]]]])

torch.Size([3, 3])
tensor([[0.0613, 0.1197, 0.1791],
        [0.4715, 0.5313, 0.5899],
        [0.8828, 0.9413, 1.0000]], grad_fn=<IndexSelectBackward>)
