In [None]:
import numpy as np
import re
import matplotlib.pyplot as plt

**Loading and Parsing CNN Coefficients**

This notebook loads the pre-trained CNN coefficients from the text files (one per kernel size, 3×3 and 5×5) and reshapes the flat value lists into NumPy arrays ready for convolution operations.

CNN Architecture (shown for the 3×3 model; the 5×5 model has the same layers with 5×5 spatial kernels):
- Conv Layer 1: 64 filters of size 3×3×3 (3×3 spatial, 3 input channels RGB)
- Conv Layer 2: 32 filters of size 3×3×64 (3×3 spatial, 64 input channels)
- Conv Layer 3: 20 filters of size 3×3×32 (3×3 spatial, 32 input channels)
- Fully Connected Layer 2 (`local3`): 10 output units, each with 180 input weights (a 180×10 weight matrix)

In [None]:
def load_cnn_weights(filename, kernel_size=3):
    """Parse CNN weight tensors from a text dump and reshape the weight tensors.

    The file is split on ``tensor_name:`` markers. In each section the first
    line is the tensor name; all numbers between the first ``[`` and the last
    ``]`` that follow it are collected in reading order (nested brackets are
    ignored) into a flat float32 array.

    Parameters
    ----------
    filename : str or path-like
        Path to the coefficient text file.
    kernel_size : int, default 3
        Spatial size of the convolution kernels stored in the file.

    Returns
    -------
    dict[str, np.ndarray]
        Mapping tensor name -> float32 array. ``conv1/weights``,
        ``conv2/weights`` and ``conv3/weights`` are reshaped to
        ``(kernel_size, kernel_size, in_channels, out_channels)`` and
        ``local3/weights`` to ``(180, 10)``. Every other tensor stays flat.

    Raises
    ------
    KeyError
        If one of the four weight tensors is missing from the file.
    ValueError
        If a weight tensor does not hold the number of values its shape needs
        (e.g. ``kernel_size`` does not match the file).
    """
    # Matches "1.5", "-2.", ".25", "1e-05" and numpy's "1.e-05" form. The old
    # pattern split "1.e-05" into two numbers (1 and -5), corrupting the data.
    number_re = re.compile(r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?')

    with open(filename, 'r') as f:
        content = f.read()

    tensors = {}
    for section in content.split('tensor_name:')[1:]:
        lines = section.strip().split('\n', 1)
        tensor_name = lines[0].strip()
        if len(lines) < 2:
            continue  # a name with no data after it

        array_text = lines[1]
        start_idx = array_text.find('[')
        end_idx = array_text.rfind(']')
        if start_idx == -1 or end_idx == -1:
            continue  # no bracketed array in this section

        numbers = number_re.findall(array_text[start_idx + 1:end_idx])
        tensors[tensor_name] = np.array([float(x) for x in numbers], dtype=np.float32)

    # Conv weights are laid out (kernel_h, kernel_w, in_channels, out_channels).
    expected_shapes = {
        'conv1/weights': (kernel_size, kernel_size, 3, 64),
        'conv2/weights': (kernel_size, kernel_size, 64, 32),
        'conv3/weights': (kernel_size, kernel_size, 32, 20),
        'local3/weights': (180, 10),
    }
    for name, shape in expected_shapes.items():
        if name not in tensors:
            raise KeyError(f"tensor {name!r} not found in {filename}")
        flat = tensors[name]
        expected = int(np.prod(shape))
        if flat.size != expected:
            raise ValueError(
                f"tensor {name!r} has {flat.size} values, expected {expected} "
                f"for shape {shape} (check kernel_size={kernel_size})"
            )
        tensors[name] = flat.reshape(shape)

    return tensors

In [None]:
# One coefficient file per kernel size; kernel_size must match the file so the
# loader's reshapes line up with the stored value counts.
tensors_3x3 = load_cnn_weights('../dataset/CNN_coeff_3x3.txt', kernel_size=3)
tensors_5x5 = load_cnn_weights(filename='../dataset/CNN_coeff_5x5.txt', kernel_size=5)

**Understanding the Weight Structure**

Conv Layer 1:
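- Shape: `(3, 3, 3, 64)`, i.e. `(kernel_height, kernel_width, in_channels, out_channels)`; the 5×5 model gives `(5, 5, 3, 64)`
- Interpretation:
  - 64 different filters/kernels
  - Each filter is 3×3 pixels
  - Each filter spans all 3 RGB input channels
  - Each filter produces 1 output feature map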

Accessing individual kernels (with `conv1_w = tensors_3x3['conv1/weights']`):
- `conv1_w[:, :, :, i]` gives you the i-th filter (3×3×3 array)
- `conv1_w[:, :, c, i]` gives you the 3×3 kernel for channel c of filter i

In [None]:
# Which trained model to inspect (3 -> CNN_coeff_3x3, 5 -> CNN_coeff_5x5)
# and how many of its first-layer filters to draw.
kernel_size = 3
num_filters_to_show = 8

if kernel_size == 5:
    conv1_w = tensors_5x5['conv1/weights']
elif kernel_size == 3:
    conv1_w = tensors_3x3['conv1/weights']
else:
    # Any other value would select the wrong tensor and break the annotation loops.
    raise ValueError(f"kernel_size must be 3 or 5, got {kernel_size}")

# Never ask for more filters than the layer has (the last axis indexes filters).
num_filters_to_show = min(num_filters_to_show, conv1_w.shape[3])

# Symmetric colour range from the weights actually drawn: values outside
# [-1, 1] are no longer clipped, and 0 stays at the colormap's neutral centre.
vlim = float(np.abs(conv1_w[:, :, :, :num_filters_to_show]).max()) or 1.0

# squeeze=False keeps `axes` 2-D even when only one filter row is requested.
fig, axes = plt.subplots(num_filters_to_show, 3,
                         figsize=(8*kernel_size/3, kernel_size*num_filters_to_show),
                         squeeze=False)

for i in range(num_filters_to_show):
    for c in range(3):  # R, G, B input channels
        ax = axes[i, c]
        kernel = conv1_w[:, :, c, i]  # kernel_size x kernel_size slice: channel c of filter i

        im = ax.imshow(kernel, cmap='RdBu', vmin=-vlim, vmax=vlim)
        ax.set_title(f"Filter {i}, {'RGB'[c]} channel")
        ax.axis('off')
        plt.colorbar(im, ax=ax, fraction=0.046)

        # Annotate every cell with its value; use white text on the dark ends of RdBu.
        for row in range(kernel_size):
            for col in range(kernel_size):
                value = kernel[row, col]
                ax.text(col, row, f'{value:.2f}', ha='center', va='center',
                        color='white' if abs(value) > 0.6 * vlim else 'black',
                        fontsize=8)

plt.tight_layout()
plt.show()

print(f"\nShowing first {num_filters_to_show} filters out of {conv1_w.shape[3]} total filters in Conv1")