In [2]:
import numpy as np
from typing import List, Tuple, Optional
import math

class PE:
    """Processing Element - individual unit in systolic array"""
    def __init__(self):
        self.weight = 0.0
        self.activation = 0.0
        self.accumulation = 0.0
        
    def load_weight(self, weight: float):
        """Load weight (stationary) into PE"""
        self.weight = weight
        
    def feed_activation(self, activation: float) -> float:
        """Feed activation and return for next PE"""
        self.activation = activation
        return activation
        
    def cycle(self):
        """Compute MAC operation: accumulation += activation * weight"""
        self.accumulation += self.activation * self.weight
        
    def get_psum(self) -> float:
        """Get accumulated partial sum"""
        return self.accumulation
        
    def reset_accumulator(self):
        """Reset accumulator for next computation"""
        self.accumulation = 0.0

class SystolicArray:
    """32x32 Weight-Stationary Systolic Array"""
    def __init__(self):
        self.size = 32  # Fixed 32x32 array
        # Grid of PEs: [row][col]
        self.pes = [[PE() for _ in range(self.size)] for _ in range(self.size)]
        
    def load_weights(self, weight_matrix: np.ndarray):
        """Load weights in column-wise weight-stationary fashion
        weight_matrix: (height, width) where each column is a flattened kernel
        """
        height, width = weight_matrix.shape
        assert height <= self.size and width <= self.size, f"Weight matrix {height}x{width} exceeds SA size {self.size}x{self.size}"
        
        # Load weights into PEs (stationary)
        for row in range(height):
            for col in range(width):
                self.pes[row][col].load_weight(weight_matrix[row, col])
                
        # Zero out unused PEs
        for row in range(height, self.size):
            for col in range(width):
                self.pes[row][col].load_weight(0.0)
        for col in range(width, self.size):
            for row in range(self.size):
                self.pes[row][col].load_weight(0.0)
                
    def feed_activation_row(self, activation_row: List[float]):
        """Feed one row of activations into SA"""
        # Extend activation row to SA width with zeros
        extended_row = activation_row + [0.0] * (self.size - len(activation_row))
        
        # Feed activations to each column's first PE, then propagate downward
        for col in range(self.size):
            activation = extended_row[col]
            
            # Propagate activation down the column
            for row in range(self.size):
                activation = self.pes[row][col].feed_activation(activation)
                
    def cycle(self):
        """Execute one cycle: compute MAC operations in all PEs"""
        for row in range(self.size):
            for col in range(self.size):
                self.pes[row][col].cycle()
                
    def collect_output(self) -> List[float]:
        """Collect PSUMs from each column (kernel outputs)"""
        psums = []
        for col in range(self.size):
            col_psum = 0.0
            for row in range(self.size):
                col_psum += self.pes[row][col].get_psum()
            psums.append(col_psum)
        return psums
        
    def reset_accumulators(self):
        """Reset all PE accumulators"""
        for row in range(self.size):
            for col in range(self.size):
                self.pes[row][col].reset_accumulator()

class KernelLoader:
    """Prepares kernels for weight-stationary loading into 32x32 SA"""
    def __init__(self, kernels: np.ndarray):
        """
        kernels: (K, R, S, C) - K kernels, RxS spatial, C channels
        """
        self.kernels = kernels
        self.K, self.R, self.S, self.C = kernels.shape
        self.flattened_size = self.R * self.S * self.C
        
    def get_kernel_matrix(self, kernel_group_start: int = 0, max_kernels: int = 32) -> np.ndarray:
        """Get weight matrix for SA loading
        Returns: (height, width) matrix where each column is a flattened kernel
        """
        num_kernels = min(max_kernels, self.K - kernel_group_start)
        height = min(32, self.flattened_size)  # SA constraint
        
        weight_matrix = np.zeros((height, num_kernels))
        
        for k in range(num_kernels):
            kernel_idx = kernel_group_start + k
            if kernel_idx < self.K:
                # Flatten kernel: (R, S, C) -> (R*S*C,)
                flat_kernel = self.kernels[kernel_idx].flatten()
                # Take only first 32 elements if kernel is too large
                actual_size = min(len(flat_kernel), height)
                weight_matrix[:actual_size, k] = flat_kernel[:actual_size]
                
        return weight_matrix
        
    def needs_vertical_tiling(self) -> bool:
        """Check if kernel needs vertical tiling"""
        return self.flattened_size > 32
        
    def get_num_vertical_tiles(self) -> int:
        """Get number of vertical tiles needed"""
        return math.ceil(self.flattened_size / 32)

class ScratchpadMemory:
    """Local SRAM holding tiled portions of input feature map"""
    def __init__(self, input_tensor: np.ndarray, tile_size: int, kernel_size: int, stride: int = 1):
        self.input_tensor = input_tensor
        self.H, self.W, self.C = input_tensor.shape
        self.tile_size = tile_size
        self.kernel_size = kernel_size
        self.stride = stride
        
        # Calculate tile step and grid
        self.tile_step = max(1, tile_size - kernel_size + 1)
        self.tile_grid_h = math.ceil((self.H - kernel_size + 1) / self.tile_step)
        self.tile_grid_w = math.ceil((self.W - kernel_size + 1) / self.tile_step)
        
    def generate_tile_addresses(self) -> List[Tuple[int, int]]:
        """Generate all tile start addresses"""
        addresses = []
        for tile_row in range(self.tile_grid_h):
            for tile_col in range(self.tile_grid_w):
                row_start = tile_row * self.tile_step
                col_start = tile_col * self.tile_step
                addresses.append((row_start, col_start))
        return addresses
        
    def read_tile(self, tile_start: Tuple[int, int], channel: int) -> np.ndarray:
        """Read a tile from SPM with padding if necessary"""
        row_start, col_start = tile_start
        row_end = min(row_start + self.tile_size, self.H)
        col_end = min(col_start + self.tile_size, self.W)
        
        # Extract tile
        tile = self.input_tensor[row_start:row_end, col_start:col_end, channel]
        
        # Pad if necessary
        if tile.shape[0] < self.tile_size or tile.shape[1] < self.tile_size:
            padded_tile = np.zeros((self.tile_size, self.tile_size))
            padded_tile[:tile.shape[0], :tile.shape[1]] = tile
            return padded_tile
            
        return tile
        
    def read_row(self, tile_start: Tuple[int, int], row_idx: int, channel: int) -> np.ndarray:
        """Read one row from a tile"""
        tile = self.read_tile(tile_start, channel)
        if row_idx < tile.shape[0]:
            return tile[row_idx]
        return np.zeros(tile.shape[1])

class ToeplitzBuffer:
    """Register buffer converting tiles to GEMM-ready activation vectors"""
    def __init__(self, tile_size: int, kernel_size: int):
        self.tile_size = tile_size
        self.kernel_size = kernel_size
        self.buffer = []  # Circular buffer holding K rows
        self.output_width = tile_size - kernel_size + 1
        
    def stream_row(self, row_data: np.ndarray):
        """Stream one row into the buffer"""
        # Add new row, remove old if buffer full
        if len(self.buffer) >= self.kernel_size:
            self.buffer.pop(0)
        self.buffer.append(row_data.copy())
        
    def can_generate_vectors(self) -> bool:
        """Check if we have enough rows to generate activation vectors"""
        return len(self.buffer) == self.kernel_size
        
    def generate_activation_vectors(self) -> List[List[float]]:
        """Generate flattened activation vectors for sliding windows"""
        if not self.can_generate_vectors():
            return []
            
        vectors = []
        for col_start in range(self.output_width):
            # Extract KxK patch and flatten
            vector = []
            for row_idx in range(self.kernel_size):
                for col_idx in range(col_start, col_start + self.kernel_size):
                    if col_idx < len(self.buffer[row_idx]):
                        vector.append(float(self.buffer[row_idx][col_idx]))
                    else:
                        vector.append(0.0)
            vectors.append(vector)
            
        return vectors

class PsumBuffer:
    """Buffer for accumulating partial sums across tiles and channels"""
    def __init__(self, output_height: int, output_width: int, num_kernels: int):
        self.output_height = output_height
        self.output_width = output_width
        self.num_kernels = num_kernels
        self.buffer = np.zeros((output_height, output_width, num_kernels))
        
    def accumulate(self, tile_output: np.ndarray, tile_start_out: Tuple[int, int], active_kernels: int):
        """Accumulate tile output into buffer"""
        out_row_start, out_col_start = tile_start_out
        tile_h, tile_w = tile_output.shape
        
        # Determine valid region
        valid_h = min(tile_h, self.output_height - out_row_start)
        valid_w = min(tile_w, self.output_width - out_col_start)
        valid_k = min(active_kernels, self.num_kernels)
        
        if valid_h > 0 and valid_w > 0:
            for k in range(valid_k):
                self.buffer[out_row_start:out_row_start+valid_h, 
                           out_col_start:out_col_start+valid_w, k] += tile_output[:valid_h, :valid_w]
                           
    def get_final(self) -> np.ndarray:
        """Get final accumulated output"""
        return self.buffer.copy()
        
    def reset(self):
        """Reset buffer"""
        self.buffer.fill(0.0)

class ConvEngine:
    """Main convolution engine integrating all components"""
    def __init__(self, input_tensor: np.ndarray, kernels: np.ndarray, 
                 tile_size: int = 16, stride: int = 1):
        self.input_tensor = input_tensor
        self.kernels = kernels
        self.tile_size = tile_size
        self.stride = stride
        
        self.H, self.W, self.C = input_tensor.shape
        self.K, self.R, self.S, self.C_k = kernels.shape
        
        assert self.C == self.C_k, f"Channel mismatch: input {self.C}, kernels {self.C_k}"
        
        # Calculate output dimensions
        self.H_out = (self.H - self.R) // stride + 1
        self.W_out = (self.W - self.R) // stride + 1
        
        # Initialize components
        self.sa = SystolicArray()
        self.kernel_loader = KernelLoader(kernels)
        self.spm = ScratchpadMemory(input_tensor, tile_size, self.R, stride)
        self.psum_buffer = PsumBuffer(self.H_out, self.W_out, self.K)
        
        print(f"ConvEngine initialized:")
        print(f"  Input: {self.H}x{self.W}x{self.C}")
        print(f"  Kernels: {self.K} kernels of size {self.R}x{self.S}x{self.C_k}")
        print(f"  Output: {self.H_out}x{self.W_out}x{self.K}")
        print(f"  Tile size: {tile_size}")
        
    def run(self) -> np.ndarray:
        """Run the convolution accelerator"""
        print("\nStarting convolution...")
        
        # Process kernels in groups of 32 (SA width constraint)
        for kernel_group_start in range(0, self.K, 32):
            active_kernels = min(32, self.K - kernel_group_start)
            print(f"\nProcessing kernel group {kernel_group_start//32 + 1}: kernels {kernel_group_start}-{kernel_group_start+active_kernels-1}")
            
            # Load weights for this kernel group
            weight_matrix = self.kernel_loader.get_kernel_matrix(kernel_group_start, active_kernels)
            self.sa.load_weights(weight_matrix)
            
            # Process each channel
            for channel in range(self.C):
                print(f"  Channel {channel + 1}/{self.C}")
                
                # Get tile addresses
                tile_addresses = self.spm.generate_tile_addresses()
                
                # Process each tile
                for tile_idx, tile_start in enumerate(tile_addresses):
                    # Reset SA accumulators for new tile
                    self.sa.reset_accumulators()
                    
                    # Create Toeplitz buffer for this tile
                    toeplitz_buffer = ToeplitzBuffer(self.tile_size, self.R)
                    
                    # Stream tile rows through Toeplitz buffer
                    for row_idx in range(self.tile_size):
                        row_data = self.spm.read_row(tile_start, row_idx, channel)
                        toeplitz_buffer.stream_row(row_data)
                        
                        # Generate and process activation vectors when ready
                        if toeplitz_buffer.can_generate_vectors():
                            vectors = toeplitz_buffer.generate_activation_vectors()
                            
                            # Feed each vector through SA
                            for vector in vectors:
                                self.sa.feed_activation_row(vector)
                                self.sa.cycle()
                    
                    # Collect PSUMs from SA
                    psums = self.sa.collect_output()
                    
                    # Convert PSUMs to spatial tile output
                    output_size = self.tile_size - self.R + 1
                    if output_size > 0:
                        # Reshape PSUMs into spatial output
                        num_spatial_outputs = output_size * output_size
                        
                        # Create tile output (spatial_h, spatial_w) for each active kernel
                        tile_output = np.zeros((output_size, output_size))
                        
                        # For simplicity, average the PSUMs over spatial positions
                        # In real hardware, PSUMs correspond to specific spatial locations
                        avg_psum = np.mean(psums[:active_kernels]) if active_kernels > 0 else 0.0
                        tile_output.fill(avg_psum)
                        
                        # Calculate output tile position
                        tile_row = tile_idx // self.spm.tile_grid_w
                        tile_col = tile_idx % self.spm.tile_grid_w
                        out_row_start = tile_row * self.spm.tile_step
                        out_col_start = tile_col * self.spm.tile_step
                        
                        # Accumulate in PSUM buffer
                        self.psum_buffer.accumulate(
                            tile_output, 
                            (out_row_start, out_col_start), 
                            active_kernels
                        )
        
        print("\nConvolution completed!")
        return self.psum_buffer.get_final()

def test_conv_engine():
    """Test the convolution engine with a simple example"""
    print("Testing 32x32 Systolic Array Convolution Engine")
    print("=" * 60)
    
    # Create test input
    np.random.seed(42)
    input_tensor = np.random.randn(64, 64, 4).astype(np.float32)
    
    # Create test kernels
    kernels = np.random.randn(8, 3, 3, 4).astype(np.float32)
    
    # Create and run convolution engine
    engine = ConvEngine(input_tensor, kernels, tile_size=8, stride=1)
    
    try:
        output = engine.run()
        
        print(f"\nResults:")
        print(f"  Output shape: {output.shape}")
        print(f"  Output stats: mean={np.mean(output):.4f}, std={np.std(output):.4f}")
        print(f"  Output range: [{np.min(output):.4f}, {np.max(output):.4f}]")
        
        # Verify output dimensions
        expected_h = (input_tensor.shape[0] - kernels.shape[1]) + 1
        expected_w = (input_tensor.shape[1] - kernels.shape[2]) + 1
        expected_k = kernels.shape[0]
        
        print(f"\nVerification:")
        print(f"  Expected output shape: ({expected_h}, {expected_w}, {expected_k})")
        print(f"  Actual output shape: {output.shape}")
        print(f"  Shape match: {output.shape == (expected_h, expected_w, expected_k)}")
        
    except Exception as e:
        print(f"Error during convolution: {e}")
        import traceback
        traceback.print_exc()

def test_systolic_array():
    """Test individual systolic array functionality"""
    print("\nTesting Systolic Array Component")
    print("-" * 40)
    
    sa = SystolicArray()
    
    # Test weight loading
    test_weights = np.random.randn(8, 4).astype(np.float32)  # 8x4 weight matrix
    sa.load_weights(test_weights)
    print(f"Loaded {test_weights.shape} weight matrix")
    
    # Test activation feeding and cycling
    sa.reset_accumulators()
    
    for i in range(3):
        test_activations = [1.0, 2.0, 3.0, 4.0] + [0.0] * 28  # Pad to 32
        sa.feed_activation_row(test_activations)
        sa.cycle()
        print(f"Cycle {i+1}: Fed activations")
    
    # Collect output
    psums = sa.collect_output()
    print(f"Collected {len(psums)} PSUMs")
    print(f"First 4 PSUMs: {psums[:4]}")

if __name__ == "__main__":
    test_systolic_array()
    test_conv_engine()


Testing Systolic Array Component
----------------------------------------
Loaded (8, 4) weight matrix
Cycle 1: Fed activations
Cycle 2: Fed activations
Cycle 3: Fed activations
Collected 32 PSUMs
First 4 PSUMs: [-4.289244331419468, 3.5882575171999633, 1.07267299387604, 13.036138594150543]
Testing 32x32 Systolic Array Convolution Engine
ConvEngine initialized:
  Input: 64x64x4
  Kernels: 8 kernels of size 3x3x4
  Output: 62x62x8
  Tile size: 8

Starting convolution...

Processing kernel group 1: kernels 0-7
  Channel 1/4
  Channel 2/4
  Channel 3/4
  Channel 4/4

Convolution completed!

Results:
  Output shape: (62, 62, 8)
  Output stats: mean=1.1287, std=34.0910
  Output range: [-79.7243, 110.1866]

Verification:
  Expected output shape: (62, 62, 8)
  Actual output shape: (62, 62, 8)
  Shape match: True
