In [1]:
import numpy as np
import os
import glob
from typing import List, Tuple, Optional
import re

def read_snn_memory_dumps(
    dump_directory: str = ".",
    bits_per_neuron: int = 9,
    out_channels: int = 4,
    img_width: int = 8,
    img_height: int = 8,
    dump_pattern: str = "feature_map_mem_*.mem"
) -> Tuple[np.ndarray, List[str]]:
    """
    Read SNN memory dump files and return organized membrane potential data.

    Each dump is a text memory image in Verilog ``$readmemb`` style:

    * one or more whitespace-separated binary words per line;
    * ``//`` comments (whole-line or trailing) and ``_`` digit separators are ignored;
    * an ``@<hex>`` token sets the address of the next word; otherwise words
      are assigned to consecutive addresses starting at 0.

    A word is ``out_channels * bits_per_neuron`` bits wide and packs one
    two's-complement neuron value per channel. Channel 0 is read from the
    LEFTMOST ``bits_per_neuron`` characters of the word.
    NOTE(review): ``_load_vivado_weights`` unpacks output channel 0 from the
    least significant bits instead -- confirm which order the testbench's
    feature-map memory actually uses.

    Parameters:
    -----------
    dump_directory : str
        Directory containing the memory dump files
    bits_per_neuron : int
        Number of bits per neuron (should match BITS_PER_NEURON from testbench)
    out_channels : int
        Number of output channels (should match OUT_CHANNELS from testbench)
    img_width : int
        Image width (should match IMG_WIDTH from testbench)
    img_height : int
        Image height (should match IMG_HEIGHT from testbench)
    dump_pattern : str
        Glob pattern for dump files. Files are ordered numerically by the last
        integer in their name (so ``..._10.mem`` sorts after ``..._9.mem``).

    Returns:
    --------
    membrane_potentials : np.ndarray
        int32 array with shape (num_dumps, img_height, img_width, out_channels)
        holding signed membrane potentials. Addresses are row-major:
        ``addr = y * img_width + x``.
    dump_files : List[str]
        Basenames of the dump files, in the same order as axis 0

    Raises:
    -------
    FileNotFoundError
        If no file matches ``dump_pattern`` inside ``dump_directory``.

    Notes:
    ------
    Processing is best-effort. An unreadable or unparsable file, or a word that
    has the wrong width, non-binary digits (e.g. ``x``/``z`` from uninitialised
    simulator memory) or an out-of-range address, is reported with a printed
    warning and skipped. The affected entries stay 0.
    """
    # Find all memory dump files
    search_pattern = os.path.join(dump_directory, dump_pattern)
    dump_files = glob.glob(search_pattern)

    if not dump_files:
        raise FileNotFoundError(f"No memory dump files found matching pattern: {search_pattern}")

    # glob() returns files in arbitrary order; sort by dump index.
    dump_files.sort(key=_dump_sort_key)

    print(f"Found {len(dump_files)} memory dump files")

    # Calculate memory organization
    total_memory_locations = img_width * img_height
    memory_width_bits = out_channels * bits_per_neuron

    print(f"Memory organization:")
    print(f"  - Total locations: {total_memory_locations}")
    print(f"  - Width per location: {memory_width_bits} bits")
    print(f"  - Bits per neuron: {bits_per_neuron}")
    print(f"  - Output channels: {out_channels}")
    print(f"  - Image size: {img_width}x{img_height}")

    membrane_potentials = np.zeros((len(dump_files), img_height, img_width, out_channels), dtype=np.int32)

    for dump_idx, dump_file in enumerate(dump_files):
        print(f"\nProcessing dump {dump_idx}: {os.path.basename(dump_file)}")

        try:
            words = _read_mem_words(dump_file)
        except (OSError, ValueError) as e:
            # Unreadable file, bad encoding or malformed @address: skip this dump only.
            print(f"Error processing {dump_file}: {e}")
            continue

        if len(words) != total_memory_locations:
            print(f"Warning: Expected {total_memory_locations} memory locations, got {len(words)}")

        out_of_range = 0
        non_binary_addrs = []
        for addr, mem_line in words:
            if not 0 <= addr < total_memory_locations:
                out_of_range += 1
                continue

            if len(mem_line) != memory_width_bits:
                print(f"Warning: Memory location {addr} has {len(mem_line)} bits, expected {memory_width_bits}")
                continue

            # Row-major addressing: addr = y * img_width + x
            y, x = divmod(addr, img_width)

            try:
                values = [
                    binary_to_signed_int(
                        mem_line[ch * bits_per_neuron:(ch + 1) * bits_per_neuron],
                        bits_per_neuron,
                    )
                    for ch in range(out_channels)
                ]
            except ValueError:
                # e.g. 'x'/'z' digits: skip this word, keep the rest of the dump.
                non_binary_addrs.append(addr)
                continue

            membrane_potentials[dump_idx, y, x, :] = values

        if out_of_range:
            print(f"Warning: {out_of_range} memory word(s) fall outside the "
                  f"{total_memory_locations}-location feature map and were ignored")
        if non_binary_addrs:
            print(f"Warning: {len(non_binary_addrs)} memory location(s) contain non-binary "
                  f"digits and were skipped (first: {non_binary_addrs[:8]})")

        print(f"  Processed {len(words)} memory locations")

    return membrane_potentials, [os.path.basename(f) for f in dump_files]


def _dump_sort_key(path: str) -> Tuple[int, str]:
    """Sort key for dump files: last integer in the file stem (0 if none), then the basename."""
    name = os.path.basename(path)
    numbers = re.findall(r'\d+', os.path.splitext(name)[0])
    return (int(numbers[-1]) if numbers else 0, name)


def _read_mem_words(mem_path: str) -> List[Tuple[int, str]]:
    """Parse a $readmemb-style text file into (address, binary-word) pairs in file order."""
    words = []
    next_addr = 0
    with open(mem_path, 'r') as f:
        for raw_line in f:
            # Drop trailing/whole-line '//' comments before tokenising.
            content = raw_line.split('//', 1)[0]
            for token in content.split():
                if token.startswith('@'):
                    # '@<hex>' sets the address of the next word.
                    next_addr = int(token[1:], 16)
                    continue
                words.append((next_addr, token.replace('_', '')))
                next_addr += 1
    return words

def binary_to_signed_int(binary_str: str, bit_width: int) -> int:
    """
    Interpret a binary digit string as a two's-complement signed integer.

    Parameters:
    -----------
    binary_str : str
        Binary digits, most significant first (e.g. "111111111" -> -1 for 9 bits)
    bit_width : int
        Expected number of digits; also the two's-complement word size

    Returns:
    --------
    int
        The signed value, in the range [-2**(bit_width-1), 2**(bit_width-1) - 1]

    Raises:
    -------
    ValueError
        If len(binary_str) != bit_width, or binary_str is not valid base-2 text.
    """
    if len(binary_str) != bit_width:
        raise ValueError(f"Binary string length ({len(binary_str)}) doesn't match bit width ({bit_width})")

    raw = int(binary_str, 2)
    half_range = 1 << (bit_width - 1)
    # Values at or above half the range have the sign bit set: wrap them negative.
    return raw - (1 << bit_width) if raw >= half_range else raw

def analyze_membrane_potentials(membrane_potentials: np.ndarray, dump_files: List[str]) -> None:
    """
    Print summary statistics of the membrane potential data.

    Parameters:
    -----------
    membrane_potentials : np.ndarray
        4D array (num_dumps, height, width, channels) from read_snn_memory_dumps
    dump_files : List[str]
        Dump filenames, one per entry along axis 0

    Notes:
    ------
    An empty array (any dimension of size 0) is reported instead of raising,
    because numpy's min/max reductions fail on zero-size arrays.
    """
    num_dumps, height, width, channels = membrane_potentials.shape
    # Number of neurons in one dump (loop-invariant).
    total_elements = height * width * channels

    print(f"\n=== Membrane Potential Analysis ===")
    print(f"Shape: {membrane_potentials.shape}")
    print(f"Data type: {membrane_potentials.dtype}")
    print(f"Number of dumps: {num_dumps}")
    print(f"Spatial dimensions: {height}x{width}")
    print(f"Channels: {channels}")

    if membrane_potentials.size == 0:
        print("\nNo membrane potential values to analyze (empty array).")
        return

    print(f"\nValue statistics across all dumps:")
    print(f"  Min: {np.min(membrane_potentials)}")
    print(f"  Max: {np.max(membrane_potentials)}")
    print(f"  Mean: {np.mean(membrane_potentials):.2f}")
    print(f"  Std: {np.std(membrane_potentials):.2f}")

    print(f"\nNon-zero values per dump:")
    for i in range(num_dumps):
        non_zero_count = np.count_nonzero(membrane_potentials[i])
        percentage = (non_zero_count / total_elements) * 100
        print(f"  Dump {i} ({dump_files[i]}): {non_zero_count}/{total_elements} ({percentage:.1f}%)")

In [2]:
# Directory with the feature_map_mem_*.mem dumps from the xsim behavioural run.
# Set the SNN_DUMP_DIR environment variable to point elsewhere instead of
# editing this machine-specific default path.
dump_directory = os.environ.get(
    "SNN_DUMP_DIR",
    r"E:\rtsprojects\general_conv\general_conv.sim\test_all\behav\xsim",
)
# These should match the testbench parameters of the same names.
BITS_PER_NEURON = 9
OUT_CHANNELS = 4
IMG_WIDTH, IMG_HEIGHT = 8, 8
membrane_potentials, dump_files = read_snn_memory_dumps(
    dump_directory=dump_directory,
    bits_per_neuron=BITS_PER_NEURON,
    out_channels=OUT_CHANNELS,
    img_width=IMG_WIDTH,
    img_height=IMG_HEIGHT
)

Found 9 memory dump files
Memory organization:
  - Total locations: 64
  - Width per location: 36 bits
  - Bits per neuron: 9
  - Output channels: 4
  - Image size: 8x8

Processing dump 0: feature_map_mem_0.mem
  Processed 64 memory locations

Processing dump 1: feature_map_mem_1.mem
  Processed 64 memory locations

Processing dump 2: feature_map_mem_2.mem
  Processed 64 memory locations

Processing dump 3: feature_map_mem_3.mem
  Processed 64 memory locations

Processing dump 4: feature_map_mem_4.mem
  Processed 64 memory locations

Processing dump 5: feature_map_mem_5.mem
  Processed 64 memory locations

Processing dump 6: feature_map_mem_6.mem
  Processed 64 memory locations

Processing dump 7: feature_map_mem_7.mem
  Processed 64 memory locations

Processing dump 8: feature_map_mem_8.mem
  Processed 64 memory locations


In [3]:
import torch
import torch.nn as nn
import numpy as np
import json
from pathlib import Path
from typing import Union, Tuple, Optional

def load_kernel_weights_to_conv2d(
    weights_file: Union[str, Path],
    file_format: str = "auto",
    normalize_weights: bool = False,
    bias: bool = False
) -> Tuple[nn.Conv2d, dict]:
    """
    Load kernel weights from the SNN generator and create a PyTorch Conv2d layer.

    Parameters:
    -----------
    weights_file : str or Path
        Path to the weights file (.json, .npy, or .mem)
    file_format : str
        "json", "numpy", "vivado", or "auto" (detect from the file extension,
        case-insensitively)
    normalize_weights : bool
        If True, divide all weights by their maximum absolute value so they lie
        in [-1, 1]. All-zero weights are left unchanged.
    bias : bool
        Whether the Conv2d layer has a bias term (initialized to zero)

    Returns:
    --------
    conv_layer : nn.Conv2d
        Stride-1 Conv2d with padding kernel_size // 2 holding the loaded weights
    metadata : dict
        Loader config, PyTorch weight shape, original value range, flags and
        source file. 'normalization_factor' (present only when
        normalize_weights is True) is the divisor actually applied: 1.0 when
        the weights were all zero.

    Raises:
    -------
    ValueError
        If the format cannot be detected or is unsupported, or the loaded weight
        array shape is not (in_channels, out_channels, kernel_size, kernel_size)
        as declared by the file's config.
    """
    weights_file = Path(weights_file)

    # Auto-detect file format (case-insensitive, so ".JSON" also works)
    if file_format == "auto":
        suffix_to_format = {".json": "json", ".npy": "numpy", ".mem": "vivado"}
        file_format = suffix_to_format.get(weights_file.suffix.lower())
        if file_format is None:
            raise ValueError(f"Cannot auto-detect format for file: {weights_file}")

    print(f"Loading weights from {weights_file} (format: {file_format})")

    # Load weights based on format
    if file_format == "json":
        weights, config = _load_json_weights(weights_file)
    elif file_format == "numpy":
        weights, config = _load_numpy_weights(weights_file)
    elif file_format == "vivado":
        weights, config = _load_vivado_weights(weights_file)
    else:
        raise ValueError(f"Unsupported file format: {file_format}")

    # Extract configuration
    in_channels = config['in_channels']
    out_channels = config['out_channels']
    kernel_size = config['kernel_size']
    bits_per_weight = config.get('bits_per_weight', None)

    print(f"Loaded weights configuration:")
    print(f"  Input channels: {in_channels}")
    print(f"  Output channels: {out_channels}")
    print(f"  Kernel size: {kernel_size}")
    print(f"  Bits per weight: {bits_per_weight}")
    print(f"  Weight shape: {weights.shape}")
    print(f"  Weight range: [{weights.min()}, {weights.max()}]")

    # Fail early with a clear message instead of a shape error inside copy_().
    expected_shape = (in_channels, out_channels, kernel_size, kernel_size)
    if tuple(weights.shape) != expected_shape:
        raise ValueError(
            f"Weight array shape {tuple(weights.shape)} does not match config "
            f"(in_channels, out_channels, kernel_size, kernel_size) = {expected_shape}"
        )

    # Generator layout is [in_channels, out_channels, h, w];
    # PyTorch Conv2d expects [out_channels, in_channels, h, w].
    torch_weights = torch.from_numpy(weights).float().permute(1, 0, 2, 3)

    normalization_factor = 1.0
    if normalize_weights:
        weight_max = torch.abs(torch_weights).max()
        if weight_max > 0:
            torch_weights = torch_weights / weight_max
            normalization_factor = float(weight_max)
            print(f"Normalized weights by factor {weight_max:.3f}")

    conv_layer = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=1,
        padding=kernel_size // 2,  # "same" output size for odd kernel sizes only
        bias=bias
    )

    with torch.no_grad():
        conv_layer.weight.copy_(torch_weights)
        if bias:
            conv_layer.bias.zero_()

    metadata = {
        'original_config': config,
        'pytorch_shape': list(torch_weights.shape),
        'original_range': [float(weights.min()), float(weights.max())],
        'normalized': normalize_weights,
        'has_bias': bias,
        'file_format': file_format,
        'source_file': str(weights_file)
    }

    if normalize_weights:
        metadata['normalization_factor'] = normalization_factor

    print(f"Created Conv2d layer: {conv_layer}")

    return conv_layer, metadata

def _load_json_weights(weights_file: Path) -> Tuple[np.ndarray, dict]:
    """Load weights from JSON file generated by the kernel generator."""
    with open(weights_file, 'r') as f:
        data = json.load(f)
    
    weights = np.array(data['weights'], dtype=np.float32)
    config = data['config']
    
    return weights, config

def _load_numpy_weights(weights_file: Path) -> Tuple[np.ndarray, dict]:
    """Load weights from numpy file."""
    weights = np.load(weights_file)
    
    # Try to load accompanying JSON config file
    json_file = weights_file.with_suffix('.json')
    if json_file.exists():
        with open(json_file, 'r') as f:
            data = json.load(f)
            config = data['config']
    else:
        # Infer configuration from shape
        in_channels, out_channels, kernel_size, _ = weights.shape
        config = {
            'in_channels': in_channels,
            'out_channels': out_channels,
            'kernel_size': kernel_size,
            'bits_per_weight': None
        }
        print(f"Warning: No JSON config found, inferred configuration from shape")
    
    return weights.astype(np.float32), config

def _load_vivado_weights(weights_file: Path) -> Tuple[np.ndarray, dict]:
    """Load weights from Vivado .mem file."""
    
    # Parse header to extract configuration
    config = {}
    data_lines = []
    
    with open(weights_file, 'r') as f:
        for line in f:
            line = line.strip()
            if line.startswith('// Input channels:'):
                config['in_channels'] = int(line.split(':')[1].strip())
            elif line.startswith('// Output channels:'):
                config['out_channels'] = int(line.split(':')[1].strip())
            elif line.startswith('// Kernel size:'):
                kernel_info = line.split(':')[1].strip()
                kernel_size = int(kernel_info.split('x')[0])
                config['kernel_size'] = kernel_size
            elif line.startswith('// Bits per weight:'):
                config['bits_per_weight'] = int(line.split(':')[1].strip())
            elif line.startswith('@') and not line.startswith('//'):
                data_lines.append(line)
    
    # Validate configuration
    required_keys = ['in_channels', 'out_channels', 'kernel_size', 'bits_per_weight']
    for key in required_keys:
        if key not in config:
            raise ValueError(f"Could not extract {key} from Vivado file header")
    
    # Parse data lines and reconstruct weights
    in_channels = config['in_channels']
    out_channels = config['out_channels']
    kernel_size = config['kernel_size']
    bits_per_weight = config['bits_per_weight']
    
    weights = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.int32)
    
    for line in data_lines:
        # Parse line: @ADDRESS HEXDATA // comment
        parts = line.split()
        if len(parts) < 2:
            continue
            
        address = int(parts[0][1:], 16)  # Remove @ and convert hex to int
        hex_data = parts[1]
        packed_value = int(hex_data, 16)
        
        # Decode address
        total_positions = kernel_size * kernel_size
        in_ch = address // total_positions
        pos = address % total_positions
        row = pos // kernel_size
        col = pos % kernel_size
        
        # Unpack weights for all output channels
        for out_ch in range(out_channels):
            # Extract bits for this channel
            shift = out_ch * bits_per_weight
            mask = (1 << bits_per_weight) - 1
            weight_unsigned = (packed_value >> shift) & mask
            
            # Convert to signed
            if weight_unsigned >= (1 << (bits_per_weight - 1)):
                weight_signed = weight_unsigned - (1 << bits_per_weight)
            else:
                weight_signed = weight_unsigned
                
            weights[in_ch, out_ch, row, col] = weight_signed
    
    return weights.astype(np.float32), config

def create_snn_conv_comparison(
    weights_file: Union[str, Path],
    input_tensor: torch.Tensor,
    normalize_weights: bool = True
) -> dict:
    """
    Build a Conv2d from SNN generator weights and run it once on an input.

    Parameters:
    -----------
    weights_file : str or Path
        Weights file understood by load_kernel_weights_to_conv2d (format auto-detected)
    input_tensor : torch.Tensor
        Input of shape [batch, channels, height, width]
    normalize_weights : bool
        Forwarded to load_kernel_weights_to_conv2d

    Returns:
    --------
    dict
        'conv_layer', 'output' (computed without gradient tracking), 'metadata',
        and the 'input_shape' / 'output_shape' as lists
    """
    layer, layer_meta = load_kernel_weights_to_conv2d(
        weights_file,
        normalize_weights=normalize_weights
    )

    layer.eval()
    with torch.no_grad():
        result = layer(input_tensor)

    return dict(
        conv_layer=layer,
        output=result,
        metadata=layer_meta,
        input_shape=list(input_tensor.shape),
        output_shape=list(result.shape),
    )

def visualize_kernels(conv_layer: nn.Conv2d, save_path: Optional[str] = None):
    """
    Plot every 2-D kernel of a Conv2d layer as a grid of heat maps.

    Rows are output channels and columns are input channels. All panels share
    one colour scale (the layer's global min/max weight), so they can be compared.

    Parameters:
    -----------
    conv_layer : nn.Conv2d
        The convolution layer to visualize
    save_path : str, optional
        If given, save the figure there (dpi=150) and close it instead of
        showing it. Otherwise call plt.show().

    Prints a message and returns if matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("Matplotlib not available for visualization")
        return

    weights = conv_layer.weight.detach().cpu().numpy()
    out_channels, in_channels, _, _ = weights.shape
    # Shared colour scale, computed once rather than per panel.
    vmin, vmax = weights.min(), weights.max()

    # squeeze=False always returns a 2-D array of Axes, including 1xN, Nx1 and 1x1 grids.
    fig, axes = plt.subplots(
        out_channels, in_channels,
        figsize=(in_channels * 2, out_channels * 2),
        squeeze=False,
    )

    for out_ch in range(out_channels):
        for in_ch in range(in_channels):
            ax = axes[out_ch, in_ch]
            im = ax.imshow(weights[out_ch, in_ch], cmap='coolwarm', vmin=vmin, vmax=vmax)
            ax.set_title(f'Out:{out_ch}, In:{in_ch}')
            ax.axis('off')
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close(fig)  # release the figure; it is not displayed when saving
        print(f"Kernel visualization saved to {save_path}")
    else:
        plt.show()


In [4]:
# Kernel weights JSON produced by the kernel generator. Set the SNN_WEIGHTS_FILE
# environment variable to override this machine-specific default path.
weights_file = os.environ.get(
    "SNN_WEIGHTS_FILE",
    r'C:\Users\alext\fenrir\python\conv_test_scripts\kernel_weights.json',
)
conv_layer, metadata = load_kernel_weights_to_conv2d(
    weights_file
)

Loading weights from C:\Users\alext\fenrir\python\conv_test_scripts\kernel_weights.json (format: json)
Loaded weights configuration:
  Input channels: 4
  Output channels: 4
  Kernel size: 3
  Bits per weight: 6
  Weight shape: (4, 4, 3, 3)
  Weight range: [-2.0, 8.0]
Created Conv2d layer: Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
