In [2]:
import numpy as np

In [3]:
def generate_index_shapes(shape: tuple[int, ...]) -> list[tuple[int, ...]]:
    """Return, for every axis of ``shape``, a broadcastable shape that keeps only that axis.

    Entry ``i`` has ``shape[i]`` at position ``i`` and ``1`` everywhere else, e.g.
    ``(2, 3, 4) -> [(2, 1, 1), (1, 3, 1), (1, 1, 4)]``. An empty ``shape`` gives ``[]``.

    Args:
        shape: Sequence of axis sizes (it is indexed and measured, so it is not an int).

    Returns:
        One shape tuple per axis, in axis order.
    """
    ndim = len(shape)
    return [(1,) * axis + (size,) + (1,) * (ndim - axis - 1) for axis, size in enumerate(shape)]


def generate_tile_shapes(shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
    """Return, for every axis, ``np.tile`` repetitions that expand a one-axis array to ``shape``.

    Entry ``i`` is ``shape`` with axis ``i`` replaced by ``1``, e.g.
    ``(2, 3, 4) -> ((1, 3, 4), (2, 1, 4), (2, 3, 1))``. Tiling an array whose shape is
    ``shape[i]`` on axis ``i`` and ``1`` elsewhere by entry ``i`` yields exactly ``shape``.
    An empty ``shape`` gives ``()``.
    """
    return tuple(
        tuple(1 if other == axis else size for other, size in enumerate(shape))
        for axis in range(len(shape))
    )


def create_index_matrix(shape: tuple[int, ...]) -> tuple[np.ndarray, ...]:
    """Build centred coordinate grids for an array of ``shape``.

    Returns one integer array per axis, each of full ``shape``. The array for axis
    ``k`` holds the axis-``k`` coordinate minus ``shape[k] // 2``. Because negative
    values wrap when used as indices, ``arr[create_index_matrix(arr.shape)]`` takes
    element ``i`` from ``arr[i - shape // 2]`` on every axis, which is the same array
    as ``np.fft.fftshift(arr)``.

    An empty ``shape`` returns an empty tuple.
    """
    # np.indices stacks the per-axis coordinate grids (shape: (ndim, *shape)).
    # Subtracting each axis's centre gives the same result as reshape + np.tile, without
    # reusing one name for both the list of centres and its elements.
    grids = np.indices(shape)
    return tuple(grid - size // 2 for grid, size in zip(grids, shape))


def select_data(data: np.ndarray, indices: np.ndarray) -> np.ndarray:
    # Check dimensions
    indices_ndim = len(indices) if isinstance(indices, tuple) else indices.ndim
    if indices_ndim >= data.ndim:
        # Directly use indices for selection
        return data[indices]
    else:
        # Prepare a tuple for advanced indexing
        idx = [slice(None)] * (data.ndim - indices_ndim) + [indices]
        print(tuple(idx))
        return data[tuple(idx)]

def wrap_indices(indices: tuple[int, ...], shape: tuple[int, ...]) -> tuple[int, ...]:
    """Map each index onto its axis range, treating every axis as periodic.

    Args:
        indices: One (possibly negative or too-large) index per axis.
        shape: The size of each axis.

    Returns:
        A tuple with ``((index % size) + size) % size`` for each axis.

    Raises:
        ValueError: if ``indices`` and ``shape`` differ in length.
    """
    if len(shape) != len(indices):
        raise ValueError('Indices and shape must have the same length')
    wrapped = []
    for index, size in zip(indices, shape):
        wrapped.append(((index % size) + size) % size)
    return tuple(wrapped)

In [4]:
class WrapperIndexer:
    """Subscript helper: ``WrapperIndexer(shape)[idx]`` returns ``wrap_indices(idx, shape)``.

    ``idx`` must have one entry per axis; otherwise ``wrap_indices`` raises ``ValueError``.
    """

    def __init__(self, shape: tuple[int, ...]):
        # Per-axis sizes used as the wrap-around modulus.
        self.shape = shape

    def __getitem__(self, item: tuple[int, ...]):
        target_shape = self.shape
        return wrap_indices(item, target_shape)

class AttractorIndexer:
    """Turns a shift into open-mesh indices that circularly roll an array of ``shape``.

    ``indexer[shift]`` first wraps ``shift`` onto ``shape``. It then returns the
    ``np.ix_`` tuple whose axis-``k`` entry is
    ``(arange(shape[k]) - shift[k]) % shape[k]``. Indexing an array of ``shape``
    with that tuple gives element ``i`` from ``arr[i - shift]``, like ``np.roll(arr, shift)``.
    """

    def __init__(self, shape: tuple[int, ...]):
        self.shape = shape
        # Wraps incoming shifts before the per-axis index vectors are built.
        self.indexer = WrapperIndexer(shape)

    def __getitem__(self, item: tuple[int, ...]):
        wrapped_shift = self.indexer[item]
        per_axis = []
        for offset, size in zip(wrapped_shift, self.shape):
            per_axis.append((np.arange(size) - offset) % size)
        return np.ix_(*per_axis)

class AttractorState:
    """A kernel bound to a weight array; ``state @ x`` accumulates kernel-weighted weights into ``x``.

    ``state[idx]`` first wraps ``idx`` onto the kernel shape. It then returns ``weights``
    re-indexed by ``indexer[idx]``, scaled by ``kernel[idx]``. ``state @ x`` adds
    ``state[idx]`` to ``x`` for every ``idx`` where the kernel is nonzero. With an
    ``AttractorIndexer`` the re-indexing is a circular roll of ``weights`` by ``idx``,
    so the accumulated sum is a circular convolution of ``weights`` with ``kernel``.

    Args:
        kernel: Kernel array; its nonzero entries drive ``__matmul__``.
        weights: Array that is rolled and scaled (presumably the kernel's shape;
            TODO confirm whether higher-rank weights are intended).
        indexer: Maps a wrapped index tuple to the indices applied to ``weights``.
        inplace: If True, ``__matmul__`` adds into its operand instead of a copy.
    """

    def __init__(
        self,
        kernel: np.ndarray,
        weights: np.ndarray,
        indexer: AttractorIndexer,
        inplace: bool = False
    ):
        self.kernel = kernel
        self.weights = weights
        self.indexer = indexer
        self.inplace = inplace

    @property
    def shape(self):
        """Shape of the bound kernel."""
        return self.kernel.shape

    @property
    def ndim(self):
        """Rank of the bound kernel."""
        return self.kernel.ndim

    def __getitem__(self, indices: tuple[int, ...]) -> np.ndarray:
        """Return ``weights`` re-indexed for ``indices`` and scaled by the kernel value there."""
        indices = wrap_indices(indices, self.kernel.shape)
        # A plain tuple subscript is equivalent to kernel[*indices] and also works before Python 3.11.
        ratio = self.kernel[indices]
        return select_data(self.weights, self.indexer[indices]) * ratio

    def __matmul__(self, other: np.ndarray) -> np.ndarray:
        """Return ``other`` plus ``self[idx]`` summed over every nonzero kernel position.

        ``other`` is copied first unless ``inplace`` is set. The accumulation uses
        in-place ``+=``, so ``other`` needs a dtype that can hold the result
        (an integer ``other`` with a float kernel raises a casting error).
        """
        if not self.inplace:
            other = other.copy()
        for indices in zip(*np.nonzero(self.kernel)):
            other += self[indices]
        return other



class Attractor:
    """Pre-shifts a kernel once and binds it to weight arrays as ``AttractorState`` objects.

    Attributes:
        kernel: The kernel as given.
        kernel_shifted: ``kernel`` indexed with ``create_index_matrix(kernel.shape)``, so
            element ``i`` comes from ``kernel[i - shape // 2]`` on every axis (the same
            layout as ``np.fft.fftshift(kernel)``).
        indexer: ``AttractorIndexer`` over the kernel's shape.
        inplace: If False, ``__call__`` copies the weights before binding them.

    NOTE(review): with this shift, the kernel's centre element (index ``n // 2`` on an
    axis of size ``n``) lands at ``2 * (n // 2) % n``. That is index 0 for even ``n``
    but the last index for odd ``n`` (``np.fft.ifftshift`` would put it at 0 in both
    cases). Confirm this is intended.
    """

    def __init__(self, kernel: np.ndarray, inplace: bool = False):
        self.kernel = kernel
        centred_indices = create_index_matrix(self.kernel.shape)
        self.kernel_shifted = self.kernel[centred_indices]
        self.indexer = AttractorIndexer(self.shape)
        self.inplace = inplace

    @property
    def shape(self):
        """Shape of the (unshifted) kernel."""
        return self.kernel.shape

    def __call__(self, weights: np.ndarray) -> AttractorState:
        """Bind ``weights`` (copied unless ``inplace``) to the shifted kernel."""
        bound_weights = weights if self.inplace else weights.copy()
        return AttractorState(self.kernel_shifted, bound_weights, self.indexer, self.inplace)

In [None]:
class DoebleAttractorState:
    """Two coupled attractor rings.

    ``state @ (x1, x2)`` adds ring 2's kernel-weighted update to ``x1`` and ring 1's
    update to ``x2``. The class name keeps its original spelling for compatibility;
    ``DoubleAttractorState`` is an alias.

    Args:
        kernels: ``(ring1_kernel, ring2_kernel)`` with a shared shape.
        weights: ``(ring1_weights, ring2_weights)``.
        indexer: Maps a wrapped index tuple to the indices applied to the weights.
        inplace: If True, ``__matmul__`` adds into its operands instead of copies.
    """

    def __init__(
        self,
        kernels: tuple[np.ndarray, np.ndarray],
        weights: tuple[np.ndarray, np.ndarray],
        indexer: "AttractorIndexer",
        inplace: bool = False
    ):
        self.ring1_kernel, self.ring2_kernel = kernels
        self.ring1_weights, self.ring2_weights = weights
        self.indexer = indexer
        self.inplace = inplace

    @property
    def kernels(self) -> tuple[np.ndarray, np.ndarray]:
        """Both kernels as ``(ring1, ring2)``."""
        return self.ring1_kernel, self.ring2_kernel

    @property
    def weights(self) -> tuple[np.ndarray, np.ndarray]:
        """Both weight arrays as ``(ring1, ring2)``."""
        return self.ring1_weights, self.ring2_weights

    @property
    def shape(self):
        """Shape shared by the kernels (taken from ring 1)."""
        return self.ring1_kernel.shape

    @property
    def ndim(self):
        """Rank shared by the kernels (taken from ring 1)."""
        return self.ring1_kernel.ndim

    def __getitem__(self, indices: tuple[int, ...]) -> tuple[np.ndarray, np.ndarray]:
        """Return ``(ring1_update, ring2_update)`` for one kernel position."""
        indices = wrap_indices(indices, self.shape)
        return tuple(
            select_data(weights, self.indexer[indices]) * kernel[indices]
            for weights, kernel in zip(self.weights, self.kernels)
        )

    def __matmul__(self, other: tuple[np.ndarray, np.ndarray]) -> tuple[np.ndarray, np.ndarray]:
        """Accumulate the cross-coupled updates into ``other`` (copied unless ``inplace``)."""
        first, second = other
        if not self.inplace:
            first, second = first.copy(), second.copy()
        # Visit every position where EITHER kernel is nonzero. The nonzero positions of
        # k1 + k2 would miss entries where the two kernels cancel out.
        active = (self.ring1_kernel != 0) | (self.ring2_kernel != 0)
        for indices in zip(*np.nonzero(active)):
            update1, update2 = self[indices]
            # Cross-coupling: ring 1 receives ring 2's update and vice versa.
            # NOTE(review): this swap is kept as written; confirm it is intended.
            first += update2
            second += update1
        return first, second


DoubleAttractorState = DoebleAttractorState

In [10]:
# Scratch check: the last element followed by the interior (first and last dropped).
list(range(10)[-1:]) + list(range(10)[1:-1])

[9, 1, 2, 3, 4, 5, 6, 7, 8]

In [None]:
class RingAttractorState:
    """Any number of attractor rings, each with its own kernel and weights.

    ``state @ xs`` adds ring ``i``'s kernel-weighted, re-indexed weights to ``xs[i]``
    for every position where at least one kernel is nonzero.

    Args:
        kernels: One kernel per ring, all of the same shape.
        weights: One weight array per ring.
        indexer: Maps a wrapped index tuple to the indices applied to the weights.
        inplace: If True, ``__matmul__`` adds into its operands instead of copies.
    """

    def __init__(
        self,
        kernels: tuple[np.ndarray, ...],
        weights: tuple[np.ndarray, ...],
        indexer: "AttractorIndexer",
        inplace: bool = False
    ):
        # Materialise as tuples: a generator argument would be exhausted after the first
        # __getitem__ and every later update would silently be empty.
        self.kernels = tuple(kernels)
        self.weights = tuple(weights)
        self.indexer = indexer
        self.inplace = inplace

    @property
    def shape(self):
        """Shape shared by the kernels (taken from the first one)."""
        return self.kernels[0].shape

    @property
    def ndim(self):
        """Rank shared by the kernels (taken from the first one)."""
        return self.kernels[0].ndim

    def __getitem__(self, indices: tuple[int, ...]) -> tuple[np.ndarray, ...]:
        """Return one update per ring for a single kernel position."""
        indices = wrap_indices(indices, self.shape)
        return tuple(
            select_data(weights, self.indexer[indices]) * kernel[indices]
            for weights, kernel in zip(self.weights, self.kernels)
        )

    def __matmul__(self, other: tuple[np.ndarray, ...]) -> tuple[np.ndarray, ...]:
        """Accumulate each ring's updates into the matching entry of ``other``.

        Entries are copied first unless ``inplace`` is set. The result is a tuple of
        arrays, one per ring.
        """
        rings = tuple(other) if self.inplace else tuple(ring.copy() for ring in other)
        # Union of the nonzero positions of all kernels. np.add(*kernels) only handled
        # exactly two kernels (a third positional argument is `out` and gets overwritten)
        # and skipped positions where kernels cancel.
        active = np.zeros(self.shape, dtype=bool)
        for kernel in self.kernels:
            active |= kernel != 0
        for indices in zip(*np.nonzero(active)):
            for ring, update in zip(rings, self[indices]):
                ring += update
        return rings


class RingAttractor:
    """Pre-shifts several same-shaped kernels and binds them to per-ring weights.

    Each kernel is indexed with ``create_index_matrix(shape)``, the same shift
    ``Attractor`` applies. Calling the instance with one weight array per kernel
    returns a ``RingAttractorState``.

    Raises:
        ValueError: If no kernels are given or their shapes differ.
    """

    def __init__(
        self,
        kernels: tuple[np.ndarray, ...],
        inplace: bool = False
    ):
        kernels = tuple(kernels)
        if not kernels:
            raise ValueError('RingAttractor requires at least one kernel')
        if any(kernel.shape != kernels[0].shape for kernel in kernels):
            raise ValueError('All kernels of a RingAttractor must have the same shape')
        self.kernels = kernels
        # All kernels share one shape, so the centred index grid is built once. The
        # result is a tuple (not a generator) so every __call__ can reuse it.
        centred_indices = create_index_matrix(self.shape)
        self.kernels_shifted = tuple(kernel[centred_indices] for kernel in kernels)
        self.indexer = AttractorIndexer(self.shape)
        self.inplace = inplace

    @property
    def shape(self):
        """Shape shared by all kernels."""
        return self.kernels[0].shape

    def __call__(self, weights: tuple[np.ndarray, ...]) -> "RingAttractorState":
        """Bind one weight array per kernel (copied unless ``inplace``) to the shifted kernels."""
        weights = tuple(weights) if self.inplace else tuple(weight.copy() for weight in weights)
        return RingAttractorState(self.kernels_shifted, weights, self.indexer, self.inplace)

In [5]:
# Scratch check: element-wise sum of two independent 3x3 ramps (same as np.add(a, b)).
a, b = (np.arange(9).reshape(3, 3) for _ in range(2))
a + b

array([[ 0,  2,  4],
       [ 6,  8, 10],
       [12, 14, 16]])