In [2]:
import jax

# 1. Abstract method to enforce overloading of certain functions

In [3]:
# create a dummy model as a ABC class with dataclass
from abc import ABC, abstractmethod

# AbstractLayer is an abstract class which cannot be instantiated
# ABC classes with abstract methods insure that all subclasses implement the abstract methods
# (It doesn't make sense to use the @dataclass decorator with an ABC class)
class AbstractLayer(ABC):
    @abstractmethod
    def __call__(self, x):
        pass

layer = AbstractLayer() # this will raise an error

TypeError: Can't instantiate abstract class AbstractLayer without an implementation for abstract method '__call__'

In [4]:
class DummyLayer(AbstractLayer):
    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __repr__(self):
        return f"DummyLayer(\nweight=\n{self.weight},\nbias={self.bias}\n)"

    def __call__(self, x):
        return self.weight @ x + self.bias
    
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
w = jax.random.normal(subkey, (3, 2))
b = jax.random.normal(subkey, (3,))
layer = DummyLayer(w, b) # works fine
print(layer)

DummyLayer(
weight=
[[-1.4581939 -2.047044 ]
 [-1.4242861  1.1684095]
 [-0.9758364 -1.2718494]],
bias=[ 1.1378784  -1.2209548  -0.59153634]
)


We can recover the "genealogy" of a class with the __bases__ and __name__ attributes. <font color='red'>WARNING: It doesn't provide all the information. Example: ```Linear(eqx.Module)``` doesn't show the ```ABC``` class. </font>

In [7]:
print(type(layer))
print(type(layer).__bases__)
print(type(layer).__bases__[0].__name__)
print(type(layer).__bases__[0].__bases__)

<class '__main__.DummyLayer'>
(<class '__main__.AbstractLayer'>,)
AbstractLayer
(<class 'abc.ABC'>,)


In [None]:
import equinox as eqx
class Linear(eqx.Module):
    weight: jnp.ndarray
    bias: jnp.ndarray

    def __call__(self, x):
        return self.weight @ x + self.bias
    

print(type(model))
print(type(model).__bases__)
print(type(model).__bases__[0].__bases__)

# 2. @dataclass decorator to simplify the class definition (```__init__``` and ```__repr__``` are implicitly defined)
Also it adds three methods

In [None]:
from dataclasses import dataclass

class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

@dataclass
class PointDataclass:
    def __init__(self, x, y):
        self.x = x
        self.y = y

p1 = Point(1, 2)
p2 = PointDataclass(1, 2)

print(f'The class representation doesn\'t change with @dataclass')
print(f'{Point=}')
print(f'{PointDataclass=}')

print(f'\nHowever, the instance representation changes with @dataclass')
print(f'{p1=}')
print(f'{p2=}')

print(f'\n@dataclass adds attributes/methods:')
print(f'{dir(Point)=}')
print(f'{dir(PointDataclass)=}')
set_a = set(dir(Point))
set_b = set(dir(PointDataclass))
print(f'More precisely, @dataclass adds the following attributes/methods: {set_b - set_a}')

The class representation doesn't change with @dataclass
Point=<class '__main__.Point'>
PointDataclass=<class '__main__.PointDataclass'>

However, the instance representation changes with @dataclass
p1=<__main__.Point object at 0x11f558e30>
p2=PointDataclass()

@dataclass adds attributes/methods:
dir(Point)=['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__']
dir(PointDataclass)=['__class__', '__dataclass_fields__', '__dataclass_params__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__match_args__', '__module__', '__ne__', '__new__', '_

In [None]:
@dataclass
class EnhancedDummyLayer(AbstractLayer):
    weight: jax.Array
    bias: jax.Array

    def __call__(self, x):
        return self.weight @ x + self.bias