In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pathlib 
import os

import torch
from torch import nn
# from torch.utils.data import DataLoader

import PT_files.save_load as sl
from DnCNN_NP.layers  import relu, np_Conv2d

In [2]:
def pytorch_numpy_comparison(input_data,
                             pytorch_output,
                             numpy_output,
                             sample_idx):
    
    
    
    fig, ax = plt.subplots(1, 3, figsize=(24,20))
    vmin, vmax = np.percentile(input_data[sample_idx], (1,99))
    # vmin, vmax = np.percentile(pytorch_output[sample][feature_map], (1,99))


    ax[0].imshow(pytorch_output[sample_idx][0], vmin=vmin, vmax=vmax, origin='lower', interpolation='none')
    ax[0].axis('off')
    ax[0].set_title('Pytorch BatchNorm', fontsize=30)
    ax[1].imshow(input_data[sample_idx][0],vmin=vmin, vmax=vmax, origin='lower', interpolation='none')
    ax[1].axis('off')
    ax[1].set_title('Input Sample', fontsize=30)
    ax[2].imshow(numpy_output[sample_idx][0], vmin=vmin, vmax=vmax, origin='lower', interpolation='none')
    ax[2].axis('off')
    ax[2].set_title('Numpy BatchNorm', fontsize=30)
    
def np_BatchNorm2d(x, weights_dict, prefix, epsilon=1e-5):
    """
    Computes the batch normalized version of the input.
    
    This function implements a BatchNorm2d from PyTorch. A caveat to
    remember is that this implementation is equivalent to nn.BatchNorm2d
    in `model.eval()` mode. Batch normalization renormalizes the input 
    to the layer to a more parsable data range.
    
    Parameters:
    -----------
    x: numpy.ndarray
        Input image data.
    mean: numpy.ndarray
        Running mean of the dataset, computed during training.
    var: numpy.ndarray
        Running variance of the dataset, computed during training.
    beta: numpy.ndarray
        Offset value added to the normalized output.
        (These are the biases from the model parameter dictionary).
    gamma: numpy.ndarray
        Scale value to rescale the normalzied output.
        (These are the weights from the model parameter dictionary).
    epsilon: float
        Small constant for numerical stability. 
        Default = 1e-5.
        
    Returns:
    --------
    numpy.ndarray
        Output of the batch normalization.
        
    Notes:
    ------
    The operation implemented in this function is:
    
    .. math:: \\frac{\gamma (x - \mu)}{\sigma + \epsilon} + \\beta
    
    where :math:`\mu` is the running mean of the dataset and :math:`\sigma` is
    the running variance of the dataset, both of which are computed during
    training.
    
    For more details and documentation on the PyTorch BatchNorm2d function
    that this function mimics can be found at 
    https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html
    """
    gamma = weights_dict[str(prefix) + 'weight'].detach().numpy().reshape(-1, 1, 1)
    beta = weights_dict[str(prefix) + 'bias'].detach().numpy().reshape(-1, 1, 1)
    mean = weights_dict[str(prefix) + 'running_mean'].detach().numpy().reshape(-1, 1, 1)
    var = weights_dict[str(prefix) + 'running_var'].detach().numpy().reshape(-1, 1, 1)
        
        
    output = ((x - mean) / np.sqrt(var + epsilon)) * gamma + beta
    return output

In [3]:
#Load the actual data that we're working on & print the shape of this data
test_data = sl.NERSC_load('test_data_40%_6000.npy')
sample = test_data[0]
print('Shape of test set=', sample.shape)

# Create a minibatch of size 3 and cut the samples into 200x200 patch_sizes
# as well as converting it to pytorch for it to be used in the pytorch model
sample = sample[0:3, :, 1400:1600, 1400:1600]
sample_torch = torch.from_numpy(sample)

# Create the first layer of DnCNN from pytorch 
# & get the pytorch dictionary that is created to be used in the numpy version of Conv2d
# & get the output of the first layer
model = nn.Conv2d(in_channels=1, out_channels=58, kernel_size=3, stride=1, padding='same') # 1 input channel, 1 output channels, kernelsize=3, stride=1, padding=0
params = model.state_dict()
pytorch_conv_output = model(sample_torch)
# print the output shape
print('PyTorch Conv shape output =', pytorch_conv_output.shape)

model = nn.BatchNorm2d(num_features=58)#, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) # 1 output channels (everything after num features is default)
batch_params = model.state_dict()
model.eval()
with torch.no_grad():
    pytorch_batch_output = model(pytorch_conv_output)

print('PyTorch Batch shape output =', pytorch_batch_output.shape)

# Lastly let's visualize the first samples first feature map
# plt.imshow(output[0][0].detach().numpy(), origin='lower')

# Numpy version using the pytorchs weights 
numpy_conv_output = np_Conv2d(input_data=sample,
                         weights_dict=params,
                         padding='same')
print('Numpy Conv output =',numpy_conv_output.shape)


# Numpy version using the pytorchs weights 
numpy_batch_output = np_BatchNorm2d(x=numpy_conv_output,
                         weights_dict=batch_params)
print('Numpy Batch output =',numpy_batch_output.shape)


Shape of test set= (108, 1, 6000, 6000)
PyTorch Conv shape output = torch.Size([3, 58, 200, 200])
PyTorch Batch shape output = torch.Size([3, 58, 200, 200])
Numpy Conv output = (3, 58, 200, 200)
Numpy Batch output = (3, 58, 200, 200)


In [4]:
#np.isclose(numpy_conv_output, pytorch_conv_output.detach().cpu())

**Note for np.allclose**

If the following equation is element-wise True, then allclose returns True.

`np.allclose(a,b)`

`absolute(a - b) <= (atol + rtol * absolute(b))`

In [5]:
atol = 1e-07
rtol = 1e-05
# print(np.abs(numpy_conv_output - pytorch_conv_output.detach().cpu().numpy()) <= (atol + rtol * np.abs(pytorch_conv_output.detach().cpu().numpy())))
print(np.abs(numpy_conv_output - pytorch_conv_output.detach().cpu().numpy())[0, 0, 0, :10])
print((atol + rtol * np.abs(pytorch_conv_output.detach().cpu().numpy()))[0, 0, 0, :10])
                                                                   

[0.00000000e+00 1.49011612e-08 1.49011612e-08 0.00000000e+00
 1.49011612e-08 1.49011612e-08 1.49011612e-08 1.49011612e-08
 1.49011612e-08 1.49011612e-08]
[1.3108317e-06 1.5506392e-06 1.6083989e-06 1.7209094e-06 1.8190431e-06
 1.8757454e-06 1.9428971e-06 1.9566774e-06 1.9642291e-06 1.9737972e-06]


In [6]:
np.allclose(numpy_conv_output, pytorch_conv_output.detach().cpu(), rtol=1e-05, atol=1e-07)

True

In [7]:
np.allclose(numpy_conv_output, pytorch_conv_output.detach().cpu(), rtol=1e-05, atol=0.71e-07)

False

In [8]:
np.allclose(numpy_batch_output, pytorch_batch_output.detach().cpu(), rtol=1e-05, atol=1e-07)

True

In [9]:
numpy_conv_output[0]

array([[[ 0.12108318,  0.14506395,  0.15083991, ...,  0.13537174,
          0.13633195,  0.14160129],
        [ 0.10593177,  0.25874233,  0.3043707 , ...,  0.22633338,
          0.21385492,  0.34497994],
        [ 0.11713965,  0.28264287,  0.3246758 , ...,  0.23837422,
          0.24536915,  0.34443921],
        ...,
        [ 0.16251622,  0.37535477,  0.36967564, ...,  0.1648812 ,
          0.12990579,  0.24573709],
        [ 0.16811386,  0.39523226,  0.3938899 , ...,  0.15333824,
          0.20289654,  0.23135443],
        [ 0.18527606,  0.32578528,  0.3317979 , ...,  0.16770712,
          0.12977284,  0.1615704 ]],

       [[-0.08950314, -0.00235143,  0.0086441 , ...,  0.06907723,
          0.05905644,  0.19233812],
        [-0.193326  , -0.05886739, -0.05855225, ...,  0.03289407,
          0.0086733 ,  0.28571102],
        [-0.22521693, -0.07597569, -0.08498289, ...,  0.01893511,
          0.01440035,  0.29060641],
        ...,
        [-0.25254905, -0.01496804, -0.02354488, ...,  

In [10]:
numpy_conv_output[-1][-2] * batch_params['weight'].detach().numpy()

ValueError: operands could not be broadcast together with shapes (200,200) (58,) 

In [None]:
numpy_conv_output.shape

In [None]:
pytorch_numpy_comparison(input_data=sample,
                             pytorch_output=pytorch_conv_output,
                             numpy_outputnumpy_conv_output,
                             sample_idx=2)

In [None]:
pytorch_numpy_comparison(input_data=sample,
                             pytorch_output=pytorch_batch_output,
                             numpy_output=numpy_batch_output,
                             sample_idx=2)