In this network we try to build a multitask network using jax.

In [129]:
import jax
from typing import Sequence, Callable
from jax import numpy as jnp
from flax import linen as nn
from jax import lax

In [130]:
key = jax.random.PRNGKey(0)
key_data, key_network = jax.random.split(key)

In [141]:
X = jnp.expand_dims(jax.random.normal(key_data, (100, 2)), axis=0)
print(X.shape)
X = jnp.repeat(X, repeats=1, axis=0)
print(X.shape)

(1, 100, 2)
(1, 100, 2)


In [164]:
class MultiTaskDense(nn.Module):
    features: int
    n_tasks: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros

    @nn.compact
    def __call__(self, inputs):
        kernel = self.param('kernel', self.kernel_init,(self.n_tasks, inputs.shape[-1], self.features)) 
        y = lax.dot_general(inputs, kernel, dimension_numbers=(((2, ), (1, )), ((0, ), (0, ))))
        bias = self.param('bias', self.bias_init, (self.n_tasks, 1, self.features))
        y = y + bias
        return y

In [145]:
kernel = jax.random.normal(key, (5, 2, 50))

In [151]:
kernel.shape

(5, 2, 50)

In [157]:
lax.dot_general(X, kernel, dimension_numbers=(((2, ), (1, )), ((), ()))).shape

(1, 100, 5, 50)

In [149]:
forward = jax.vmap(lambda x: lax.dot_general(x, kernel, dimension_numbers=(((2, ), (1, )), ((0, ), (0, )))), in_axes=0, out_axes=0)

In [None]:
forward(X)

In [167]:
layer = MultiTaskDense(50, 5)

In [168]:
params = layer.init(key, X)

TypeError: dot_general requires lhs batch dimensions and rhs batch dimensions to have the same shape, got [1] and [5].

In [135]:
layer.apply(params, X).shape

(5, 100, 50)

In [136]:
class MLP(nn.Module):
    """Simple feed-forward NN.
    """

    features: Sequence[int]  # this is dataclass, so we dont use __init__
    n_networks: int

    @nn.compact  # this function decorator lazily intializes the model, so it makes the layers the first time we call it
    def __call__(self, inputs):
        x = inputs  # we overwrite x so we copy it to a new tensor
        for feature in self.features[:-1]:
            x = nn.tanh(MultiTaskDense(feature, self.n_networks)(x))
        x = nn.Dense(self.features[-1], self.n_networks)(x)
        return x

In [137]:
model = MLP([10, 10], 5)

In [138]:
params = model.init(key_network, X)

In [139]:
model.apply(params, X)[0, :, :] - model.apply(params, X)[1, :, :]

DeviceArray([[-4.15464878e-01, -1.53274387e-01,  3.99248540e-01,
              -2.81798720e-01,  1.09057570e+00, -1.27116546e-01,
              -3.25902462e-01,  2.40228102e-01,  1.06225275e-01,
               1.56591713e-01],
             [ 1.30599618e-01,  5.47062531e-02, -1.59760609e-01,
               1.26406267e-01, -3.15526485e-01,  9.05806124e-02,
               4.28541899e-02,  1.50267854e-02, -1.48729421e-02,
              -2.99254470e-02],
             [-1.96892768e-02, -6.60598278e-04, -4.98154461e-01,
               3.58352840e-01, -5.09266108e-02,  3.91472638e-01,
              -3.64516377e-01,  5.62013865e-01,  2.60310397e-02,
               5.06199375e-02],
             [ 3.19861054e-01,  1.26700684e-01, -3.81382912e-01,
               2.95952618e-01, -8.10271621e-01,  1.86908230e-01,
               1.45442873e-01, -3.20778042e-02, -5.25340736e-02,
              -8.56416076e-02],
             [-4.01072949e-01, -1.48670033e-01,  5.50629318e-01,
              -4.23782676e-

In [None]:
class MultiTaskMLP(nn.Module):
    """Simple feed-forward NN.
    """

    shared_features: Sequence[int]  # this is dataclass, so we dont use __init__
    specific_features: Sequence[int]
    n_tasks: int

    @nn.compact  # this function decorator lazily intializes the model, so it makes the layers the first time we call it
    def __call__(self, inputs):
        x = inputs  # we overwrite x so we copy it to a new tensor
        for feature in self.shared_features:
            x = nn.tanh(nn.Dense(feature)(x))
        x = jnp.repeat(
            jnp.expand_dims(x, axis=0), repeats=self.n_tasks, axis=0
        )  # If we batch, can we do without copying data?
        for feature in self.specific_features[:-1]:
            x = nn.tanh(MultiTaskDense(feature, self.n_tasks)(x))
        x = MultiTaskDense(self.specific_features[-1], self.n_tasks)(x)
        return x
