In [27]:
import torch
import torch.nn as nn
import numpy as np
import json
import glob
import re
from pathlib import Path

def _load_weights_mem(weights_file):
    """
    Parse a Vivado-style ``.mem`` kernel-weight file.

    The header comments (``// Input channels:``, ``// Output channels:``,
    ``// Kernel size: KxK``, ``// Bits per weight:``) give the geometry. Each
    data line has the form ``@<hex address> <hex packed word>``; the address
    enumerates (in_channel, row, col) in in-channel-major, row-major order and
    the word packs one two's-complement weight per output channel, with output
    channel 0 in the least-significant bits.

    Returns
    -------
    weights : np.ndarray
        float32 array of shape (in_channels, out_channels, kernel_size, kernel_size).
    config : dict
        The parsed header fields.

    Raises
    ------
    ValueError
        If a required header field is missing or a data line has no data word.
    """
    config = {}
    data_lines = []

    with open(weights_file, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith('// Input channels:'):
                config['in_channels'] = int(line.split(':')[1].strip())
            elif line.startswith('// Output channels:'):
                config['out_channels'] = int(line.split(':')[1].strip())
            elif line.startswith('// Kernel size:'):
                # e.g. "3x3" -> 3 (only square kernels are supported)
                kernel_info = line.split(':')[1].strip()
                config['kernel_size'] = int(kernel_info.split('x')[0])
            elif line.startswith('// Bits per weight:'):
                config['bits_per_weight'] = int(line.split(':')[1].strip())
            elif line.startswith('@'):
                data_lines.append(line)

    required = ('in_channels', 'out_channels', 'kernel_size', 'bits_per_weight')
    missing = [key for key in required if key not in config]
    if missing:
        raise ValueError(
            f"{weights_file}: missing .mem header field(s): {', '.join(missing)}"
        )

    in_channels = config['in_channels']
    out_channels = config['out_channels']
    kernel_size = config['kernel_size']
    bits_per_weight = config['bits_per_weight']

    weights = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.int32)

    # Loop invariants, hoisted out of the per-line / per-channel loops.
    total_positions = kernel_size * kernel_size
    mask = (1 << bits_per_weight) - 1
    sign_bit = 1 << (bits_per_weight - 1)
    modulus = 1 << bits_per_weight

    for line in data_lines:
        # Drop any trailing "//" comment before splitting into fields.
        parts = line.split('//', 1)[0].split()
        if len(parts) < 2:
            raise ValueError(
                f"{weights_file}: malformed data line (expected '@<addr> <hex>'): {line!r}"
            )
        address = int(parts[0][1:], 16)
        packed_value = int(parts[1], 16)

        # Decode address: in-channel major, then row-major kernel position.
        in_ch, pos = divmod(address, total_positions)
        row, col = divmod(pos, kernel_size)

        # Output channel k occupies bits [k*bpw, (k+1)*bpw) counted from the LSB.
        for out_ch in range(out_channels):
            weight_unsigned = (packed_value >> (out_ch * bits_per_weight)) & mask
            # Two's-complement sign extension.
            if weight_unsigned >= sign_bit:
                weight_unsigned -= modulus
            weights[in_ch, out_ch, row, col] = weight_unsigned

    return weights.astype(np.float32), config


def load_weights_to_conv2d(weights_file):
    """
    Load kernel weights and create a PyTorch Conv2d layer.

    Parameters:
    -----------
    weights_file : str or Path
        Path to weights file (.json, .npy, or .mem; suffix match is
        case-insensitive).
        - .json: ``{"weights": <nested list>, "config": {...}}``
        - .npy:  4-D array laid out as (in_channels, out_channels, k, k)
        - .mem:  Vivado-style packed hex file with a comment header

    Returns:
    --------
    conv_layer : nn.Conv2d
        Bias-free PyTorch convolution layer (stride 1, padding k // 2, which
        preserves spatial size for odd k) with the loaded weights.

    Raises:
    -------
    ValueError
        On an unsupported suffix, a non-4-D .npy array, or a malformed .mem file.
    """
    weights_file = Path(weights_file)
    suffix = weights_file.suffix.lower()

    if suffix == ".json":
        with open(weights_file, 'r') as f:
            data = json.load(f)
        weights = np.array(data['weights'], dtype=np.float32)
        config = data['config']

    elif suffix == ".npy":
        weights = np.load(weights_file).astype(np.float32)
        if weights.ndim != 4:
            raise ValueError(
                f"Expected a 4-D (in, out, k, k) weight array in {weights_file}, "
                f"got shape {weights.shape}"
            )
        # Infer config from shape
        in_channels, out_channels, kernel_size, _ = weights.shape
        config = {
            'in_channels': in_channels,
            'out_channels': out_channels,
            'kernel_size': kernel_size
        }

    elif suffix == ".mem":
        weights, config = _load_weights_mem(weights_file)

    else:
        raise ValueError(f"Unsupported file format: {weights_file.suffix}")

    # Files store [in_channels, out_channels, h, w]; PyTorch expects
    # [out_channels, in_channels, h, w].
    torch_weights = torch.from_numpy(weights).permute(1, 0, 2, 3)

    conv_layer = nn.Conv2d(
        in_channels=config['in_channels'],
        out_channels=config['out_channels'],
        kernel_size=config['kernel_size'],
        stride=1,
        padding=config['kernel_size'] // 2,
        bias=False
    )

    # no_grad: copying into a leaf Parameter must not be recorded by autograd.
    with torch.no_grad():
        conv_layer.weight.copy_(torch_weights)

    return conv_layer

def load_events(events_file, bits_per_coord=7, in_channels=None):
    """
    Load events and return coordinates and channels.

    Parameters:
    -----------
    events_file : str or Path
        Path to events file (.json or .mem; suffix match is case-insensitive)
    bits_per_coord : int, optional
        Width in bits of each of the x and y fields of a ``.mem`` record.
        Defaults to 7 (the value previously hard-coded here). Ignored for JSON.
    in_channels : int or None, optional
        Expected width in bits of the spike field of a ``.mem`` record. When
        given, a record of any other total length raises ``ValueError``; when
        None (default) every bit after y is taken as the spike field.

    ``.mem`` records are binary strings laid out as
    ``[timestep (1 bit)][x][y][spikes]``; ``//`` comments are ignored.

    Returns:
    --------
    events : np.ndarray
        int32 array of shape (n_events, 4) with columns [x, y, channels, timestep]
        where channels is the integer spike pattern. No events -> shape (0, 4).
    """
    events_file = Path(events_file)
    suffix = events_file.suffix.lower()

    if suffix == ".json":
        with open(events_file, 'r') as f:
            data = json.load(f)
        events = [
            [event['x'], event['y'], event['spikes'], event['timestep']]
            for event in data['events']
        ]

    elif suffix == ".mem":
        events = []
        spikes_start = 1 + 2 * bits_per_coord
        with open(events_file, 'r') as f:
            for raw_line in f:
                # Keep only the binary word: drop "//" comments (whole-line or
                # trailing, with or without a separating space) and whitespace.
                fields = raw_line.split('//', 1)[0].split()
                if not fields:
                    continue
                binary_str = fields[0]

                if in_channels is not None and len(binary_str) != spikes_start + in_channels:
                    raise ValueError(
                        f"{events_file}: expected {spikes_start + in_channels}-bit records, "
                        f"got {len(binary_str)} bits: {binary_str!r}"
                    )

                timestep = int(binary_str[0], 2)
                x = int(binary_str[1:1 + bits_per_coord], 2)
                y = int(binary_str[1 + bits_per_coord:spikes_start], 2)
                spikes = int(binary_str[spikes_start:], 2)

                events.append([x, y, spikes, timestep])

    else:
        raise ValueError(f"Unsupported file format: {events_file.suffix}")

    # reshape keeps the documented (n_events, 4) shape even when there are no events.
    return np.array(events, dtype=np.int32).reshape(-1, 4)

def load_memory_dumps(dump_directory=".", bits_per_neuron=9, out_channels=4, pattern="feature_map_mem_*.mem"):
    """
    Load Vivado memory dumps and return signed membrane potential values.

    Parameters:
    -----------
    dump_directory : str or Path
        Directory containing memory dump files
    bits_per_neuron : int
        Number of bits per neuron value
    out_channels : int
        Number of output channels (neurons per memory location)
    pattern : str
        Glob filename pattern for memory dumps. Matches are ordered by the
        integer at the end of the file stem (e.g. ``..._2`` before ``..._10``),
        ties broken by filename.

    Each non-comment line of a dump is a binary string; channel ``ch`` is the
    two's-complement field at characters ``[ch*bits, (ch+1)*bits)``, so channel 0
    is the leftmost field. Missing fields on a short line are filled with 0.

    Returns:
    --------
    memory_data : list of np.ndarray
        List of memory dumps, each int32 with shape (memory_depth, out_channels)
        containing signed membrane potential values (memory_depth may be 0)
    dump_files : list of str
        List of dump filenames in order

    Raises:
    -------
    FileNotFoundError
        If no file matches ``pattern`` in ``dump_directory``.
    """
    search_pattern = str(Path(dump_directory) / pattern)
    dump_files = glob.glob(search_pattern)

    if not dump_files:
        raise FileNotFoundError(f"No memory dump files found: {search_pattern}")

    def dump_sort_key(filename):
        # Numeric order on the trailing integer of the stem works for any
        # `pattern`, not just feature_map_mem_*; names without a number sort
        # as 0. The name tie-break removes dependence on filesystem order.
        path = Path(filename)
        match = re.search(r'(\d+)$', path.stem)
        return (int(match.group(1)) if match else 0, path.name)

    dump_files.sort(key=dump_sort_key)

    # Two's-complement constants, hoisted out of the per-neuron loop.
    sign_bit = 1 << (bits_per_neuron - 1)
    modulus = 1 << bits_per_neuron

    memory_data = []

    for dump_file in dump_files:
        dump_rows = []
        with open(dump_file, 'r') as f:
            for raw_line in f:
                # Drop "//" comments (whole-line or trailing) and whitespace.
                mem_line = raw_line.split('//', 1)[0].strip()
                if not mem_line:
                    continue

                neuron_values = []
                for ch in range(out_channels):
                    start_bit = ch * bits_per_neuron
                    end_bit = start_bit + bits_per_neuron
                    if end_bit > len(mem_line):
                        # Line shorter than expected: pad missing channels with 0.
                        neuron_values.append(0)
                        continue
                    unsigned_val = int(mem_line[start_bit:end_bit], 2)
                    if unsigned_val >= sign_bit:
                        unsigned_val -= modulus
                    neuron_values.append(unsigned_val)

                dump_rows.append(neuron_values)

        if dump_rows:
            memory_data.append(np.array(dump_rows, dtype=np.int32))
        else:
            # Keep the documented 2-D shape for an empty dump.
            memory_data.append(np.zeros((0, out_channels), dtype=np.int32))

    return memory_data, [Path(f).name for f in dump_files]

In [28]:
# Input locations for this notebook. These are absolute, machine-specific
# Windows paths; edit them to match your checkout / simulation directory.
CONV_TEST_SCRIPTS_DIR = r'C:\Users\alext\fenrir\python\conv_test_scripts'
kernel_weights_path = CONV_TEST_SCRIPTS_DIR + r'\kernel_weights.json'
events_path = CONV_TEST_SCRIPTS_DIR + r'\snn_test_events.json'
# Directory searched by load_memory_dumps (presumably the xsim run directory).
memory_dumps_path = r'E:\rtsprojects\general_conv\general_conv.sim\test_all\behav\xsim'

In [41]:
# Load every feature-map dump (ordered by dump number) and stack them into one array.
dump_arrays, dump_files = load_memory_dumps(memory_dumps_path, bits_per_neuron=9)
memory_dumps = np.array(dump_arrays)
# Overall shape, per-dump shape, and the per-channel values at memory row 49 of the first dump.
memory_dumps.shape, memory_dumps[0].shape, memory_dumps[0][49]

((9, 64, 4), (64, 4), array([ 0,  0,  0, -2]))

In [6]:
# Build the Conv2d layer from the exported kernel weights and show its repr.
conv_layer = load_weights_to_conv2d(weights_file=kernel_weights_path)
conv_layer

Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

In [8]:
# Load the input events as rows of [x, y, channels, timestep].
events = load_events(events_file=events_path)
events

array([[7, 2, 9, 0],
       [2, 7, 1, 0],
       [1, 6, 8, 0],
       [7, 0, 4, 0],
       [2, 0, 8, 0],
       [1, 7, 5, 0],
       [0, 7, 1, 0],
       [7, 1, 8, 0],
       [0, 7, 0, 0],
       [4, 7, 3, 0]])

In [13]:
n = 1
# Zero-padded, 4-digit binary representation of n.
bit_string = f"{n:04b}"
bit_string

'0001'