In [1]:
# imports
import jax
import jax.numpy as jnp
import numpy as np
# Node and Network
from Node import Node
from Network import Cluster

In [2]:
class DataHandler:
    def __init__(self, cluster: Cluster, data: jnp.ndarray = jnp.array([]), target: jnp.ndarray = jnp.array([]), encoding : str = "spike_amp"): # TODO, encoding type, decoding type, start iterations, number of iterations to get the network started, batch size, sampling steps
        """
        This class handles input and output data.
        For now it only offer a simple way to input and output data trough a network. 

        Data encoding uses algorithms similar to SNNs

        # NOTE: Optimization is handled by the Cluster class, maybe change to runner class
        # TODO: Diffrerent learning type depending on network running type, pretraining, running...
        # TODO: Create subclasses which read in diffrent data formats, like csv, images, ...
        # TODO: Batch size and sampling steps
        # TODO: Dependency on network type, like pretrained or functional
        # TODO: Put some features in sperate network runner
        # TODO: Error handling
Planned:
- Encoding types (rate dependent and continous spike dependent)
- Output decoding (takes the average, median, ...)
- Start iterations, number of iterations to get the network started

        Args:
            cluster (Cluster): Cluster to which the data is fed
            data (jnp.ndarray): Input data
            target (jnp.ndarray): Target data
        """
        # size definitions from the cluster
        self.cluster = cluster
        self.input_size = len(cluster.input_nodes)
        self.output_size = len(cluster.output_nodes)
        # data
        self.data = data
        self.target = target
        self.encoding = encoding
        # size check
        # TODO : Make same error type as used in other files
        assert len(data) == len(target), "Data and target size mismatch"
        assert len(data[0]) == self.input_size, "Data input size mismatch"
        assert len(target[0]) == self.output_size, "Data output size mismatch"
        # data info
        self.index = 0
        self.data_size = len(data)

    # TODO: Create subclasses, which take images, csv, ...
    def load_data(self, data):
        """
        Takes in Data and stores it in the class to be later accessible by the runner
        # TODO: Implement functions to load data from diffrent formats

        Args:
            data (jnp.ndarray): Data to be loaded

        """
        raise NotImplementedError("Implemented trough a subclass")

    """
    # TODO : Check working
    Idea for Csv read in
    @classmethod
    def from_csv(cls, path):
        # convert csv to jnp array
        data = jnp.loadtxt(csv_file_path, delimiter=',')
        target = jnp.zeros(data.shape)
        return cls(cluster, data, target, encoding_type, decoding_type, batch_size, sampling_steps) # call cunstructor
    """
    # TODO: other formats and images
    
# Signal processing
    def encode_spike_amp(self, cur_data):
        """
        Encodes the data into a spike amplitude format.
        
        This format creates an array of the diffrent signal timesteps used for running the network
        """
        # give out the current data
        spike_amp_data = jnp.array([cur_data]) # TODO: Check if necessary
        # create zeros for the next spike
        zeros = jnp.zeros(spike_amp_data.shape)
        # merge to for two amplitudes for a spike
        cur_data = jnp.vstack([cur_data, zeros])
        return cur_data
    
    # TODO: Implement rate encoding

# Iterator
    """
    The iterator includes has batch size elements
    The elements have 2 features one is temporal and the other is the data for each input
    The temporal features should be switched in a circle for each iteration of the network
    """
    def  __iter__(self):
        return self
    
    def __next__(self):
        # TODO get batch
        # Return one element for now

        # select encoding mode and create elements
        if self.encoding == "spike_amp":
            cur_data = self.encode_spike_amp(self.data[self.index])
            cur_target = self.target[self.index]
        else:
            raise NotImplementedError("Encoding not implemented, try spike_amp")
        # increment index
        self.index += 1
        # check if index is out of bounds
        if self.index == self.data_size + 1: # TODO: One bigger sice data is accesed before
            self.index = 0
            raise StopIteration
        return cur_data, cur_target

# Helper functions
    def reset(self):
        """
        Resets the index of the data to 0
        """
        self.index = 0
    
    def convert_data(self, data, encoder):
        """ 
        Converts the data into the encoding format

        Args:
            data (jnp.ndarray): Data to be converted
            encoder (str): Encoding type
        """
        if encoder == "spike_amp":
            return self.encode_spike_amp(data)
        else:
            raise NotImplementedError("Encoding not implemented, try spike_amp")

In [3]:
# Assuming cur_data is already a jnp.array
cur_data = jnp.array([1, 2, 3, 4, 5])  # Example data
spike_amp_data = jnp.array([cur_data])

spike_amp_data = jnp.vstack([spike_amp_data, jnp.zeros(cur_data.shape)])  # Append a zero array to the spike_amp_data

# Print the spike_amp_data to see its contents
print(spike_amp_data)
# If using a Jupyter Notebook, you can also display it
spike_amp_data

[[1. 2. 3. 4. 5.]
 [0. 0. 0. 0. 0.]]


Array([[1., 2., 3., 4., 5.],
       [0., 0., 0., 0., 0.]], dtype=float32)

In [4]:
# Test DataHandler
# Nodes
# Create a few more nodes for testing
Node0 = Node(0, jnp.array([]), jnp.array([]), 0)
Node1 = Node(0, jnp.array([]), jnp.array([]), 1)
Node2 = Node(0, jnp.array([]), jnp.array([]), 2)
nodes = [Node0, Node1, Node2]
# Create a cluster, which is just a short linear connection
cluster = Cluster(1, 1, nodes=nodes, init_net=False)
cluster_big = Cluster(3, 3, nodes=nodes, init_net=False)
# add connections
cluster.add_connection(Node0, Node1) # NOTE: This needs to be changed when changed in cluster
cluster.add_connection(Node1, Node2)
# TODO: length finding for priming
# DataLoader
# Frist dimension is the element, second is the data
data = jnp.array([[1],[2],[3]]) # jnp.array([[1, 2, 3],[1,4,6]]) for 3 inputs
target = jnp.array([[1],[2],[3]])
data_big = jnp.array([[1, 2, 3],[1,4,6]])
target_big = jnp.array([[1, 2, 3],[1,4,6]])
print("Data and target size", data.shape, data.shape)
print("Data and target first element", data[0], target[0])

data_handler = DataHandler(cluster, data, target)
data_handler_big = DataHandler(cluster, data_big, target_big)


Data and target size (3, 1) (3, 1)
Data and target first element [1] [1]


AssertionError: Data input size mismatch

In [None]:
print(data_handler.convert_data(data[1], "spike_amp"))

[[2.]
 [0.]]


In [None]:
cluster.print_connections()
cluster.get_neighbors(Node0, 1)
input = jnp.array([1])
print(input)
print(input.shape)
print(cluster.run(input))

[[0. 1. 0.]
 [0. 0. 1.]
 [0. 0. 0.]]
[1]
(1,)
inputs:  [Array([0.], dtype=float32), Array([0.], dtype=float32), Array([], shape=(0,), dtype=float32)]
out_conn:  [[1], [2], []]
in_conn:  [[], [0], [1]]
[0.]


In [None]:
for data, target in data_handler:
    for part in data:
        print(part)

[1.]
[0.]
[2.]
[0.]
[3.]
[0.]


In [None]:
for data, target in data_handler_big:
    for part in data:
        print(part)