## Tutorial \#1: Modules and Components

<b><i>Spark</i></b> is intended as a research framework. As such, one of the first things you may want to do is add a new element to the system. <br>
<br>
Fortunately, adding new components is quite straightforward, although it may be a little confusing the first time. Let's start from the beginning. In <b><i>Spark</i></b>, as in other major libraries, every module inherits from <b>spark.nn.Module</b>. This class is at the heart of most of <b><i>Spark</i></b>'s magic. Depending on the kind of module you are building, it may make more sense to subclass one of these more specialized classes:

| Name | Purpose |
| --- | --- |
| spark.nn.Component | Arbitrary component of a neuron |
| spark.nn.somas.Soma | Soma models |
| spark.nn.synapses.Synapses | Synaptic models |
| spark.nn.learning_rules.LearningRule | Learning rules |
| spark.nn.delays.Delays | Spike delay mechanisms |

This list is not exhaustive, but it contains the most important subclasses of <b>spark.nn.Module</b>. 

<b> REMARK: </b>One important thing to mention before we start is that Python does not strictly enforce typing, but Spark relies on type hints for a few of its core features. Although some of your code may work without them, a few things require type hints if you want your components to play nicely with the Spark ecosystem. In general, this is quite straightforward and only required in a few places, so if you really despise typing, most of your code can still be untyped.

For the time being let's just create a simple module.

```python
# Imports
import sys
sys.path.insert(1, './..')

import spark
import typing as tp
import jax.numpy as jnp
```

One of the first things to notice is that new components are defined in Module/Configuration pairs: the Module defines the logic, and the Configuration defines all the parameters required to initialize that Module. By convention, we name these pairs <b>Module</b> and <b>ModuleConfig</b>.

Every time a new module is defined, we need to link it to its default configuration. This is done by simply adding the class-level annotation <b>config: ConfigClass</b> to the Module class definition.

Another thing to notice at this point is that all <b>\_\_init\_\_</b> methods share the same signature across modules: they accept only a configuration object (<b>None</b> by default) and keyword arguments. Similarly, the first thing to do in any <b>\_\_init\_\_</b> method is to invoke the <b>\_\_init\_\_</b> method of the superclass, following the pattern indicated below.

```python
class MyAwesomeModuleConfig(spark.nn.Config):
    pass

class MyAwesomeModule(spark.nn.Module):
    config: MyAwesomeModuleConfig       # <--- Default configuration class MUST always be indicated

    def __init__(self, config: MyAwesomeModuleConfig = None, **kwargs):
        super().__init__(config=config, **kwargs)       # <--- super().__init__ should always be invoked first as follows
```
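
To make the Module/Config linkage concrete, here is a minimal plain-Python sketch of how a base class <i>could</i> use the class-level <b>config</b> annotation to find the default configuration class, and how keyword arguments could be merged into it. <b>BaseConfig</b>, <b>BaseModule</b>, <b>AwesomeConfig</b> and <b>AwesomeModule</b> are hypothetical stand-ins built on <b>dataclasses</b>; this is an assumption about the general mechanism, not Spark's actual implementation.

```python
import dataclasses
import typing as tp


@dataclasses.dataclass
class BaseConfig:
    """Hypothetical stand-in for spark.nn.Config (not Spark's real class)."""


@dataclasses.dataclass
class AwesomeConfig(BaseConfig):
    foo: int
    bar: float = 2.0


class BaseModule:
    """Hypothetical stand-in for spark.nn.Module (not Spark's real class)."""
    config: BaseConfig

    def __init__(self, config: tp.Optional[BaseConfig] = None, **kwargs):
        # Read the `config` annotation of the concrete subclass; get_type_hints
        # walks the MRO, so the subclass annotation overrides the base one.
        config_cls = tp.get_type_hints(type(self))['config']
        if config is None:
            # No config passed: build the default config from the kwargs.
            config = config_cls(**kwargs)
        elif kwargs:
            # Config passed: keyword arguments override individual fields.
            config = dataclasses.replace(config, **kwargs)
        self.config = config


class AwesomeModule(BaseModule):
    config: AwesomeConfig  # links the module to its default config class


print(AwesomeModule(foo=1).config)  # AwesomeConfig(foo=1, bar=2.0)
print(AwesomeModule(config=AwesomeConfig(foo=1), bar=-1.0).config)  # AwesomeConfig(foo=1, bar=-1.0)
```

The <b>elif</b> branch mirrors Method 3 in the usage cell at the end of this tutorial, where passing <b>bar=-1</b> next to a config cancels out <b>foo=1</b>.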

At this point you may be wondering where all your important init arguments go. The answer is the configuration class.<br>
<br>
Spark modules separate model definition from model logic. This allows us to do some neat tricks behind the scenes, and it is extremely useful for reproducibility. Configuration classes should always be typed, using the notation <b>variable_name: variable_type = default_value</b>.

All the variables defined inside the configuration class are available in the <b>\_\_init\_\_</b> method, after calling <b>super().\_\_init\_\_</b>, under the namespace <b>self.config</b>. 

Another important thing to notice here is that we cannot store arrays directly. Every array must be wrapped in a <b>spark.Constant</b> or a <b>spark.Variable</b>. These wrappers are necessary when we JIT-compile the model, to let JAX know which arrays are mutable and which are simple constants. Some basic Python types work with JIT out of the box, but we highly recommend wrapping everything in a Constant or a Variable according to its role in your model.

```python
class MyAwesomeModuleConfig(spark.nn.Config):
    foo: int
    bar: float = 2.0

class MyAwesomeModule(spark.nn.Module):
    config: MyAwesomeModuleConfig

    def __init__(self, config: MyAwesomeModuleConfig = None, **kwargs):
        super().__init__(config=config, **kwargs)
        # self.foo = jnp.array(self.config.foo)  # <--- Will throw an error.
        self.foo = spark.Constant(jnp.array(self.config.foo))
        self.bar = spark.Variable(jnp.array(self.config.bar))
```
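
The role of the two wrappers can be illustrated with a small plain-Python sketch: a helper walks a module's attributes and sorts wrapped values into mutable state and constants, which is the kind of split a JIT-compiled step function needs (variables flow in and out of each compiled call, constants stay fixed). <b>ToyConstant</b>, <b>ToyVariable</b> and <b>split_state</b> are hypothetical names; Spark's own wrappers and bookkeeping are not shown in this tutorial and may work differently.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class ToyConstant:
    """Hypothetical wrapper for a value that never changes during a run."""
    value: Any


@dataclasses.dataclass
class ToyVariable:
    """Hypothetical wrapper for a value that may be updated between steps."""
    value: Any


def split_state(module: Any) -> tuple:
    """Sort a module's wrapped attributes into (variables, constants)."""
    variables, constants = {}, {}
    for name, attr in vars(module).items():
        if isinstance(attr, ToyVariable):
            variables[name] = attr.value
        elif isinstance(attr, ToyConstant):
            constants[name] = attr.value
        # Unwrapped attributes are skipped: the helper cannot tell their role.
    return variables, constants


class ToyModule:
    def __init__(self, foo: int, bar: float):
        self.foo = ToyConstant(foo)
        self.bar = ToyVariable(bar)
        self.label = 'plain attribute'  # not wrapped, so not tracked


variables, constants = split_state(ToyModule(foo=1, bar=2.0))
print(variables)  # {'bar': 2.0}
print(constants)  # {'foo': 1}
```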

Now, <b>spark.nn.Module</b> is an abstract class, so it cannot be instantiated until its abstract methods are defined. In this case we only need the <b>\_\_init\_\_</b> and <b>\_\_call\_\_</b> methods. Note, however, that other subclasses of <b>Module</b> may require additional methods.
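
Python's standard <b>abc</b> module shows the same behaviour in isolation: a class that leaves an abstract method undefined raises <b>TypeError</b> when instantiated. Whether <b>spark.nn.Module</b> uses <b>abc</b> internally is an assumption; the classes below are hypothetical.

```python
import abc


class AbstractToyModule(abc.ABC):
    """Hypothetical stand-in for an abstract base such as spark.nn.Module."""

    @abc.abstractmethod
    def __call__(self, **inputs):
        """Subclasses must implement the forward computation."""


class Incomplete(AbstractToyModule):
    pass  # forgot to define __call__


class Complete(AbstractToyModule):
    def __call__(self, **inputs):
        return {'out': sum(inputs.values())}


try:
    Incomplete()
except TypeError as err:
    print('cannot instantiate:', err)

print(Complete()(a=1, b=2))  # {'out': 3}
```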

Apart from the configuration class, <b>\_\_call\_\_</b> is the other strongly typed element in <b><i>Spark</i></b>. The first thing to notice is that <b>\_\_call\_\_</b> does NOT accept positional arguments, only keyword arguments. These keyword arguments must always be typed, and each type must inherit from <b>spark.SparkPayload</b>. Payloads are just wrappers around arrays that help the internal machinery know what can connect with what. The array within any default payload can be accessed via <b>my_payload.value</b>.

Next, <b>\_\_call\_\_</b> must always declare what it returns. This is done through the return annotation <b>\-\></b> at the end of the <b>\_\_call\_\_</b> signature. The return type is always a TypedDict, which defines the name and type of each variable that <b>\_\_call\_\_</b> returns. Again, every field type in that TypedDict must inherit from <b>spark.SparkPayload</b>.

Finally, <b>\_\_call\_\_</b> returns a dictionary containing every variable declared in the TypedDict, each wrapped in its declared payload type.
These few type hints give the internal machinery what it needs to know how modules can be used and connected.

And that is all the typing you need to do! No more type hints after this if you do not like them.

```python
class MyAwesomeOutput(tp.TypedDict):
    my_awesome_output: spark.FloatArray

class MyAwesomeModuleConfig(spark.nn.Config):
    foo: int
    bar: float = 2.0

class MyAwesomeModule(spark.nn.Module):
    config: MyAwesomeModuleConfig

    def __init__(self, config: MyAwesomeModuleConfig = None, **kwargs):
        super().__init__(config=config, **kwargs)
        self.foo = spark.Constant(jnp.array(self.config.foo))
        self.bar = spark.Variable(jnp.array(self.config.bar))

    def __call__(self, my_awesome_input: spark.FloatArray) -> MyAwesomeOutput:
        awesome_output = self.foo + self.bar + my_awesome_input
        return {
            'my_awesome_output': spark.FloatArray(awesome_output)
        }
```
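
Because these hints are ordinary Python annotations, they can be read back at runtime. The sketch below is an assumption about how a framework might use them, not Spark's code: <b>describe_ports</b> uses <b>typing.get_type_hints</b> to recover the input names and payload types from <b>\_\_call\_\_</b> and the output names and types from the returned TypedDict, and <b>can_connect</b> checks whether an output port can feed an input port. <b>ToyPayload</b>, <b>ToyFloatArray</b>, <b>ToyModule</b> and both helpers are hypothetical.

```python
import typing as tp


class ToyPayload:
    """Hypothetical stand-in for spark.SparkPayload: wraps a value."""

    def __init__(self, value):
        self.value = value


class ToyFloatArray(ToyPayload):
    """Hypothetical float payload, analogous to spark.FloatArray."""


class ToyOutput(tp.TypedDict):
    my_awesome_output: ToyFloatArray


class ToyModule:
    def __call__(self, my_awesome_input: ToyFloatArray) -> ToyOutput:
        return {'my_awesome_output': ToyFloatArray(my_awesome_input.value + 3)}


def describe_ports(module_cls: type) -> tuple:
    """Return ({input_name: type}, {output_name: type}) from __call__ hints."""
    hints = tp.get_type_hints(module_cls.__call__)
    outputs = tp.get_type_hints(hints.pop('return'))  # fields of the TypedDict
    # Every port must carry a payload type, mirroring the rule in the text.
    for name, typ in {**hints, **outputs}.items():
        if not (isinstance(typ, type) and issubclass(typ, ToyPayload)):
            raise TypeError(f'port {name!r} is not annotated with a payload type')
    return hints, outputs


def can_connect(src_cls: type, src_port: str, dst_cls: type, dst_port: str) -> bool:
    """An output can feed an input if its payload type subclasses the input's."""
    _, src_outputs = describe_ports(src_cls)
    dst_inputs, _ = describe_ports(dst_cls)
    return issubclass(src_outputs[src_port], dst_inputs[dst_port])


inputs, outputs = describe_ports(ToyModule)
print(sorted(inputs), sorted(outputs))  # ['my_awesome_input'] ['my_awesome_output']
print(can_connect(ToyModule, 'my_awesome_output', ToyModule, 'my_awesome_input'))  # True
```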

```python
class MyAwesomeOutput(tp.TypedDict):
    my_awesome_output: spark.FloatArray

class MyAwesomeModuleConfig(spark.nn.Config):
    foo: int
    bar: float = 2.0

class MyAwesomeModule(spark.nn.Module):
    config: MyAwesomeModuleConfig

    def __init__(self, config: MyAwesomeModuleConfig = None, **kwargs):
        super().__init__(config=config, **kwargs)
        self.foo = spark.Constant(jnp.array(self.config.foo))

    def build(self, input_specs):
        mai_spec = input_specs['my_awesome_input']
        self.bar = spark.Variable(
            mai_spec.payload_type( self.config.bar * jnp.ones(mai_spec.shape) )
        )

    def __call__(self, my_awesome_input: spark.FloatArray) -> MyAwesomeOutput:
        awesome_output = self.foo + self.bar + my_awesome_input
        return {
            'my_awesome_output': spark.FloatArray(awesome_output)
        }
```
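
The cell above introduces a <b>build</b> method. Reading the code, <b>build</b> appears to receive a mapping from input names to specs that expose <b>shape</b> and <b>payload_type</b>, and uses it to create <b>bar</b> with the same shape as the incoming array instead of as a scalar; when Spark calls <b>build</b> is not stated here. The plain-Python sketch below assumes a lazy build triggered by the first call. <b>ToySpec</b> and <b>LazyToyModule</b> are hypothetical, and lists stand in for arrays.

```python
import dataclasses
import math


@dataclasses.dataclass(frozen=True)
class ToySpec:
    """Hypothetical input spec; only the shape is modelled here."""
    shape: tuple


class LazyToyModule:
    """Hypothetical module whose `bar` state depends on the input shape."""

    def __init__(self, foo: float = 1.0, bar: float = 2.0):
        self.foo = foo
        self.bar_init = bar
        self.bar = None      # cannot be created yet: its size depends on the input
        self.built = False

    def build(self, input_specs: dict) -> None:
        spec = input_specs['my_awesome_input']
        # One entry per input element, all initialised to the configured value.
        self.bar = [self.bar_init] * math.prod(spec.shape)
        self.built = True

    def __call__(self, my_awesome_input: list) -> dict:
        if not self.built:
            # Derive the spec from the first input (a flat list => 1-D shape).
            self.build({'my_awesome_input': ToySpec(shape=(len(my_awesome_input),))})
        out = [self.foo + b + x for b, x in zip(self.bar, my_awesome_input)]
        return {'my_awesome_output': out}


module = LazyToyModule()
print(module(my_awesome_input=[0, 1, 2]))  # {'my_awesome_output': [3.0, 4.0, 5.0]}
```

Deferring creation this way lets one class adapt its state to whatever input shape it first receives.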

```python
# Method 1: pass the config fields directly as keyword arguments
awesome = MyAwesomeModule(foo=1)
my_awesome_input = spark.FloatArray(jnp.array(1))
res = awesome(my_awesome_input=my_awesome_input)
print(f'Method 1\n {res["my_awesome_output"].value}\n')

# Method 2: build a config object and pass it explicitly
awesome_config = MyAwesomeModuleConfig(foo=1)
awesome = MyAwesomeModule(config=awesome_config)
my_awesome_input = spark.FloatArray(jnp.arange(5))
res = awesome(my_awesome_input=my_awesome_input)
print(f'Method 2\n {res["my_awesome_output"].value}\n')

# Method 3: pass a config object and override individual fields with keyword arguments
awesome_config = MyAwesomeModuleConfig(foo=1)
awesome = MyAwesomeModule(config=awesome_config, bar=-1)
my_awesome_input = spark.FloatArray(jnp.arange(4).reshape(2, 2))
res = awesome(my_awesome_input=my_awesome_input)
print(f'Method 3\n {res["my_awesome_output"].value}\n')
```

```text
Method 1
 [4.]

Method 2
 [3. 4. 5. 6. 7.]

Method 3
 [[0. 1.]
 [2. 3.]]
```

