# Design Decisions

This page explains why we made the design decisions we did when building Solstice.

## The problem

In machine learning projects, you generally have at least four main parts of your code:

- model
- dataset
- training strategy (i.e. optimization, train/eval step logic, metrics calculation)
- training/testing loops (including logging, checkpointing, etc.)

In many research projects, it is helpful to be able to swap out any of these parts on a whim (ideally, just by changing your config files). It follows that **the better you can decouple these four parts of your code, the faster you can iterate on your experiments.**
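
As a toy illustration of swapping components through config, here is a minimal sketch that assumes a hypothetical registry-based setup. The `MODELS` and `DATASETS` dictionaries, `Config`, `run_loop` and `run_experiment` are invented for illustration and are not part of Solstice. To keep it short, it only covers the model and the dataset.

```python
import dataclasses
from typing import Callable

# Hypothetical component registries: each maps a config string to a factory.
# In a real project these factories would build models and datasets.
MODELS: dict[str, Callable[[], Callable[[float], float]]] = {
    "double": lambda: (lambda x: 2 * x),
    "square": lambda: (lambda x: x * x),
}
DATASETS: dict[str, Callable[[], list[float]]] = {
    "small": lambda: [1.0, 2.0, 3.0],
    "large": lambda: [float(i) for i in range(10)],
}


@dataclasses.dataclass(frozen=True)
class Config:
    model: str
    dataset: str


def run_loop(model: Callable[[float], float], dataset: list[float]) -> list[float]:
    # The loop only relies on `model` being callable and `dataset` being
    # iterable, so it never changes when either component is swapped.
    return [model(x) for x in dataset]


def run_experiment(cfg: Config) -> list[float]:
    # Look up each component by name and hand it to the generic loop.
    return run_loop(MODELS[cfg.model](), DATASETS[cfg.dataset]())


print(run_experiment(Config(model="double", dataset="small")))  # [2.0, 4.0, 6.0]
print(run_experiment(Config(model="square", dataset="small")))  # [1.0, 4.0, 9.0]
```

With this layout, swapping a component is a one-word change to the config, and `run_loop` never needs editing.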

Much attention is paid to each of the four components individually, but many researchers then just throw everything together haphazardly. Properly decoupling all the components can take significant engineering effort, so most people don't bother. Wouldn't it be great if there were a standard way of organising your code to rapidly scale and iterate on experiments?

PyTorch Lightning has filled this niche for PyTorch, but there are still few options when doing research in JAX.

## A key idea

In object-oriented code, classes encapsulate both state and the transformations on that state. Crucially, the state is mutable, so calling a method might return nothing but still alter the object's state. This is a *side effect*. Functional programming avoids side effects by making state immutable: instead of calling a method that updates an object in place, you pass the object to a pure function, which returns a new object with the altered state. This paradigm is generally easier to reason about, and it is also what JAX expects. Transformations such as `jax.jit` and `jax.grad` are designed for pure functions operating on immutable arrays, which is what lets JAX trace your code and hand it to XLA for compilation.
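
To make the immutability point concrete, here is a small sketch using standard JAX calls. The `increment` function is just an illustrative name. It shows that in-place assignment on a JAX array is rejected, that `.at[...].set(...)` returns a new array instead, and that a pure function can be compiled with `jax.jit`.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(3)

# JAX arrays are immutable: in-place item assignment raises a TypeError.
try:
    x[0] = 1.0
except TypeError as err:
    print("in-place update rejected:", type(err).__name__)

# `.at[...].set(...)` is the pure alternative: it returns a new array
# and leaves the original untouched.
y = x.at[0].set(1.0)
print(x)  # [0. 0. 0.]
print(y)  # [1. 0. 0.]


# A pure function (output depends only on its input, no side effects)
# can be traced by JAX and compiled by XLA.
@jax.jit
def increment(count: jax.Array) -> jax.Array:
    return count + 1


print(increment(jnp.array(0)))  # 1
```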

This is all great, but Python was not really designed for the functional paradigm, so it is difficult to fully decouple all the parts of your code with pure functions alone. Type hinting functions with `typing.Protocol` can get you surprisingly far, but at some point you will probably want some level of encapsulation, and abstract base classes to achieve dependency inversion.
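
The two options can be contrasted in a short sketch. All class and function names here (`Model`, `evaluate`, `Doubler`, `Strategy`, `SquaredError`) are hypothetical and are not Solstice's API. A `Protocol` gives structural typing, where any object with the right methods fits. An abstract base class gives nominal typing: high-level code depends on the abstraction, and concrete classes must subclass it.

```python
import abc
from typing import Protocol


# Structural typing: anything with a matching `__call__` satisfies this
# protocol, no inheritance required (checked statically by e.g. mypy).
class Model(Protocol):
    def __call__(self, x: float) -> float: ...


def evaluate(model: Model, inputs: list[float]) -> list[float]:
    return [model(x) for x in inputs]


class Doubler:  # does not inherit from Model, yet type-checks as one
    def __call__(self, x: float) -> float:
        return 2 * x


# Nominal typing with an abstract base class: code depends on `Strategy`,
# and a subclass cannot be instantiated until it implements every
# abstract method.
class Strategy(abc.ABC):
    @abc.abstractmethod
    def loss(self, prediction: float, target: float) -> float: ...


class SquaredError(Strategy):
    def loss(self, prediction: float, target: float) -> float:
        return (prediction - target) ** 2


print(evaluate(Doubler(), [1.0, 2.0]))  # [2.0, 4.0]
print(SquaredError().loss(3.0, 1.0))  # 4.0
```

Calling `Strategy()` directly raises a `TypeError`, because its abstract method is not implemented. This is the encapsulation and enforcement that a bare `Protocol` does not give you at runtime.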

The approach we take in Solstice is to use immutable dataclasses to try to get the best of both worlds. The code below shows how you would implement a simple counter in each of the three styles: object-oriented, functional, and Solstice-style.

```python
class OOPCounter:
    def __init__(self, initial_value: int = 0) -> None:
        self.count = initial_value

    def increment(self) -> None:
        self.count += 1

# 'initialise' the OO counter
oop_counter = OOPCounter()
print(f"{oop_counter.count=}, object id: {id(oop_counter)}")

# 'apply' the increment method, updating the counter object's internal state
oop_counter.increment()
print(f"after incrementing {oop_counter.count=}, object id: {id(oop_counter)}")

###############################################################################

def functional_increment(current_value: int) -> int:
    return current_value + 1

# 'initialise' the functional counter
functional_count = 0
print(f"{functional_count=}, object id: {id(functional_count)}")

# 'apply' the functional increment function, returning a new state object
functional_count = functional_increment(functional_count)
print(f"after incrementing {functional_count=}, object id: {id(functional_count)}")

###############################################################################

import dataclasses

@dataclasses.dataclass(frozen=True)
class SolsticeStyleCounter:
    count: int = 0

    def increment(self) -> "SolsticeStyleCounter":
        # frozen=True forbids mutation, so build a new instance instead
        return dataclasses.replace(self, count=self.count + 1)

# 'initialise' the SolsticeStyleCounter
solstice_style_counter = SolsticeStyleCounter()
print(f"{solstice_style_counter.count=}, object id: {id(solstice_style_counter)}")

# 'apply' the increment method, returning a new state object
solstice_style_counter = solstice_style_counter.increment()
print(f"after incrementing {solstice_style_counter.count=}, object id: {id(solstice_style_counter)}")
```

Output (the object ids will differ between runs):

```text
oop_counter.count=0, object id: 140693673590944
after incrementing oop_counter.count=1, object id: 140693673590944
functional_count=0, object id: 140693771819280
after incrementing functional_count=1, object id: 140693771819312
solstice_style_counter.count=0, object id: 140693699787456
after incrementing solstice_style_counter.count=1, object id: 140693699786064
```

Notice that the Solstice-style counter did not mutate its state. Instead, it returned a new instance of itself, as the changed object id shows. In machine learning, this means we can replace the common pair of pure `init`/`apply` functions with methods on a frozen dataclass (usually `__init__()` and `__call__()`). There is one final matter to take care of: JAX transformations only operate on PyTrees, and JAX does not know how to look inside a plain dataclass. This is why we build Solstice on top of Equinox: an `equinox.Module` is just a dataclass that is registered as a PyTree.
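
To see what "registered as a PyTree" buys you, here is a minimal sketch that performs the registration by hand with `jax.tree_util.register_pytree_node`. This only illustrates the idea and is not how Equinox implements it internally. `PlainCounter`, `PyTreeCounter` and `step` are hypothetical names. An unregistered dataclass is treated by JAX as one opaque leaf. A registered one is flattened into its array fields, so it can be passed through `jax.jit`.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class PlainCounter:  # not registered: JAX cannot see inside it
    count: jax.Array


@dataclasses.dataclass(frozen=True)
class PyTreeCounter:
    count: jax.Array

    def increment(self) -> "PyTreeCounter":
        return dataclasses.replace(self, count=self.count + 1)


# Tell JAX how to flatten the dataclass into (children, aux_data)
# and how to rebuild it from (aux_data, children).
jax.tree_util.register_pytree_node(
    PyTreeCounter,
    lambda c: ((c.count,), None),
    lambda aux, children: PyTreeCounter(*children),
)

plain = PlainCounter(jnp.array(0))
tree = PyTreeCounter(jnp.array(0))

# The unregistered dataclass is a single leaf: the object itself.
print(jax.tree_util.tree_leaves(plain)[0] is plain)  # True
# The registered one is flattened down to its array field.
print(len(jax.tree_util.tree_leaves(tree)))  # 1


@jax.jit
def step(counter: PyTreeCounter) -> PyTreeCounter:
    # Pure update: returns a new counter, the input is left unchanged.
    return counter.increment()


new_tree = step(tree)
print(int(tree.count), int(new_tree.count))  # 0 1
```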

## Why these four abstractions?

TODO