In [14]:
import jax 
import jax.numpy as jnp
from jax import random
from jax import grad, jit, vmap
import flax 
from flax import linen as nn
from functools import partial
from jax.flatten_util import ravel_pytree


import jax.tree_util as jtu
import matplotlib.pyplot as plt
import equinox as eqx
import optax
from tqdm.autonotebook import tqdm

## Autoregressive Residual Network (AResNet)

In [15]:
from typing import Callable

class ResBlockPostActivationPeriodic1d(eqx.Module):
    conv_1: eqx.nn.Conv1d # This is the first convolutional layer
    conv_2: eqx.nn.Conv1d # This is the second convolutional layer
    activation: Callable # This is the activation function, and the Callable type is a function that returns a value

    def __init__(
        self,
        hidden_channels: int,
        activation: Callable,
        *, # This is a separator to indicate that the following arguments are keyword-only
        key,
    ):
        c_1_key, c_2_key = jax.random.split(key)

        # Requires an up-to-dat version of Equinox
        self.conv_1 = eqx.nn.Conv1d( # This is the first convolutional layer
            hidden_channels, # This is the number of input channels
            hidden_channels, # This is the number of output channels
            kernel_size=2, # This is the kernel size, it means that the convolutional layer will use 3 points to make the convolution
            padding="SAME", # This is the padding, it means that the convolutional layer will use the same size of the input
            padding_mode="CIRCULAR", # This is the padding mode, it means that the convolutional layer will use the circular padding
            key=c_1_key 
        )
        self.conv_2 = eqx.nn.Conv1d(
            hidden_channels,
            hidden_channels,
            kernel_size=2,
            padding="SAME",
            padding_mode="CIRCULAR",
            key=c_2_key
        )

        self.activation = activation

    def __call__(
        self,
        x,
    ):
        x_skip_1 = x
        x = self.conv_1(x)
        x = x + x_skip_1 # Residual connection
        x = self.activation(x)
        x_skip_2 = x
        x = self.conv_2(x)
        x = x + x_skip_2 # Residual connection
        x = self.activation(x)

        return x
    
    def get_params(self):
        return [self.conv_1.weight,
                self.conv_1.bias, 
                self.conv_2.weight, 
                self.conv_2.bias]


In [53]:
class ResNetPeriodic1d(eqx.Module):
    lifting: eqx.nn.Conv1d # This is the first layer, the lifting layer that maps the input to the hidden space of the network with upper dimension
    blocks: list[ResBlockPostActivationPeriodic1d] # This is the list of the ResNet blocks that will be used in the network.
    projection: eqx.nn.Conv1d # This is another 1D convolutional layer with kernel size 1, which projects the final transformed data back to a lower-dimensional space.

    def __init__(
        self,
        hidden_channels,
        num_blocks,
        activation,
        *,
        key,
    ):
        l_key, *block_keys, p_key = jax.random.split(key, num_blocks+2)

        self.lifting = eqx.nn.Conv1d(1, hidden_channels, kernel_size=1, key=l_key)
        self.blocks = [
            ResBlockPostActivationPeriodic1d(hidden_channels, activation, key=k)
            for k in block_keys
        ]
        self.projection = eqx.nn.Conv1d(hidden_channels, 1, kernel_size=1, key=p_key)

    def __call__(self, x):
        x = self.lifting(x)
        for block in self.blocks:
            x = block(x)
        x = self.projection(x)

        return x
    
    def get_params_tree(self):
        # Retorna os parâmetros das camadas
        lifting_params = [
            self.lifting.weight,
            self.lifting.bias
        ]

        block_params = [block.get_params() for block in self.blocks]
        
        projection_params = [
            self.projection.weight,
            self.projection.bias
            ]

        return [
            lifting_params,
            block_params,
            projection_params
        ]
    
    def set_params(self, params):
        # Update lifting layer
        self = eqx.tree_at(lambda m: m.lifting.weight, self, params[0][0])
        self = eqx.tree_at(lambda m: m.lifting.bias, self, params[0][1])

        # Update ResBlocks
        for i, block in enumerate(self.blocks):
            self = eqx.tree_at(lambda m: m.blocks[i].conv_1.weight, self, params[1][i][0])
            self = eqx.tree_at(lambda m: m.blocks[i].conv_1.bias, self, params[1][i][1])
            self = eqx.tree_at(lambda m: m.blocks[i].conv_2.weight, self, params[1][i][2])
            self = eqx.tree_at(lambda m: m.blocks[i].conv_2.bias, self, params[1][i][3])

        # Update projection layer
        self = eqx.tree_at(lambda m: m.projection.weight, self, params[2][0])
        self = eqx.tree_at(lambda m: m.projection.bias, self, params[2][1])

        return self
    
    def get_num_params(self):
        return sum([p.size for p in jax.tree.flatten(self.get_params_tree())[0]])

### Testing get_params_tree() and set_params() methods

In [54]:
rede = ResNetPeriodic1d(1, 1, jax.nn.relu, key=jax.random.PRNGKey(1))
# Ravel the parameters into a flat vector for manipulation
param_list = (rede.get_params_tree())
tree, tree_map = ravel_pytree(param_list)
print(jax.tree_map(lambda x: x.shape, param_list))

print(tree.shape[0])
(param_list)

rede.get_num_params()
# key = jax.random.PRNGKey(0)
# random_tree = jax.random.uniform(key, shape=tree.shape)
# random_tree = tree_map(random_tree)

# print(random_tree)

# # Set new random parameters using the modified set_params method
# rede = rede.set_params(random_tree)

# # Check updated parameters
# print(rede.get_params_tree())


[[(1, 1, 1), (1, 1)], [[(1, 1, 2), (1, 1), (1, 1, 2), (1, 1)]], [(1, 1, 1), (1, 1)]]
10


10

## Multi Layer Perceptron (MLP)

In [18]:
# Creating a MultiLayer Perceptron (MLP) model in JAX
import jax 
import jax.numpy as jnp
import equinox as eqx
import optax
from typing import List

In [42]:
N_SAMPLES = 200
LAYERS = [1,1]
ACTIVATION = jax.nn.relu
LEARNING_RATE = 1e-3
N_ITER = 1000

In [38]:
class HyperNet(eqx.Module):
    layers: List[eqx.nn.Linear]

    def __init__(
        self,
        layers_size = [],
        last_layer_size = None,
        key = None,
    ):
        self.layers = []
        for (fan_in, fan_out) in zip(layers_size[:-1], layers_size[1:]):
            key, subkey = jax.random.split(key)
            self.layers.append(eqx.nn.Linear(fan_in, fan_out, key=subkey, use_bias=True))
        
        self.layers.append(eqx.nn.Linear(layers_size[-1], last_layer_size, key=subkey, use_bias=True))

    def __call__(self, x):
        for layer in self.layers[:-1]: # Iterate over all layers except the last one
            x = jax.nn.relu(layer(x))
            # the last layer must not have an activation function in addition the output must have the last_layer_size shape
        return self.layers[-1](x)
    
    def get_params(self):
        return [[layer.weight] + [layer.bias] for layer in self.layers]

### Testing get_params() method

In [46]:
key = jax.random.PRNGKey(1)
modelHN = HyperNet(LAYERS, last_layer_size= 2, key=key)

modelHN.get_params()

[[Array([[0.5729215]], dtype=float32), Array([-0.37414527], dtype=float32)],
 [Array([[ 0.4417026],
         [-0.981369 ]], dtype=float32),
  Array([ 0.73462176, -0.4050598 ], dtype=float32)]]

## Hyper Network

In [48]:
#create a class that inherits from the HyperNet and ResNetPeriodic1d

class HyperResidualRNN(eqx.Module):
    hypernet: HyperNet
    targetnet: ResNetPeriodic1d
    def __init__(self, hypernet, targetnet):
        self.hypernet = hypernet
        self.targetnet = targetnet

        # Get the parameters for the targetnet
    def get_unravel(self):
        targetnet_params = self.targetnet.get_params_tree() #TODO removing the last layer
        # Get the unravelling function for the targetnet
        _,unravel = ravel_pytree(targetnet_params)
        return unravel

    def __call__(self, input_target, input_hyper,unravel_fn):
        # Calculate the output of the hypernet
        input_hyper = self.hypernet(input_hyper) # input in the hypernet to the output of the hypernet with the vector of targetnet parameters

        # Unravel the parameters
        targetnet_params = unravel_fn(input_hyper) # transform the output of the hypernet into the parameters of the targetnet

        # Set the parameters of the targetnet
        self.targetnet.set_params(targetnet_params) # set the parameters of the targetnet with the output of the hypernet

        input_target = self.targetnet(input_target) 

        return input_target
        



In [56]:
# Create the TargetNet
targetnet = ResNetPeriodic1d(1, 2, jax.nn.relu, key=jax.random.PRNGKey(1))
targetnet_num_params = targetnet.get_num_params()

# Create the HyperNet
hypernet = HyperNet(LAYERS,last_layer_size = targetnet_num_params, key=key, activation=ACTIVATION)


ravel_pytree(targetnet.get_params_tree())

# Create the HyperResidualRNN
model = HyperResidualRNN(hypernet, targetnet)

unravel = model.get_unravel()

optimizer = optax.adam(LEARNING_RATE)

opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

# Creating the loss function
@jax.jit
def loss_fn(model, batch):
    inputs, targets = batch
    input_target, input_hyper = inputs
    predictions = model(input_target, input_hyper, unravel)
    return jnp.mean((predictions - targets) ** 2)

# Função de atualização do modelo
@eqx.filter_jit
def step_fn(model, state, batch):
    loss, grad = eqx.filter_value_and_grad(loss_fn)(model, batch)
    updates, new_state = optimizer.update(grad, state, model)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_state, loss

opt_state

(ScaleByAdamState(count=Array(0, dtype=int32), mu=HyperResidualRNN(
   hypernet=HyperNet(
     layers=[
       Linear(
         weight=f32[1,1],
         bias=f32[1],
         in_features=1,
         out_features=1,
         use_bias=True
       ),
       Linear(
         weight=f32[16,1],
         bias=f32[16],
         in_features=1,
         out_features=16,
         use_bias=True
       )
     ]
   ),
   targetnet=ResNetPeriodic1d(
     lifting=Conv1d(
       num_spatial_dims=1,
       weight=f32[1,1,1],
       bias=f32[1,1],
       in_channels=1,
       out_channels=1,
       kernel_size=(1,),
       stride=(1,),
       padding=((0, 0),),
       dilation=(1,),
       groups=1,
       use_bias=True,
       padding_mode='ZEROS'
     ),
     blocks=[
       ResBlockPostActivationPeriodic1d(
         conv_1=Conv1d(
           num_spatial_dims=1,
           weight=f32[1,1,2],
           bias=f32[1,1],
           in_channels=1,
           out_channels=1,
           kernel_size=(2,),
   