In [97]:
import jax 
import jax.numpy as jnp
from jax import random
from jax import grad, jit, vmap
import flax 
from flax import linen as nn
from functools import partial
from jax.flatten_util import ravel_pytree


import jax.tree_util as jtu
import matplotlib.pyplot as plt
import equinox as eqx
import optax
from tqdm.autonotebook import tqdm

  from tqdm.autonotebook import tqdm


## Autoregressive Residual Network (AResNet)

In [226]:
from typing import Callable

class ResBlockPostActivationPeriodic1d(eqx.Module):
    """Two-convolution residual block with periodic (circular) padding.

    Each convolution is followed by a residual addition and then the activation
    ("post-activation"), i.e. ``x -> act(conv(x) + x)`` applied twice. Both
    convolutions keep the channel count at ``hidden_channels`` and use
    ``padding="SAME"`` so the spatial length is preserved, which is what makes
    the residual additions shape-compatible.
    """

    conv_1: eqx.nn.Conv1d  # first convolution, hidden -> hidden channels
    conv_2: eqx.nn.Conv1d  # second convolution, hidden -> hidden channels
    activation: Callable  # elementwise nonlinearity applied after each residual add

    def __init__(
        self,
        hidden_channels: int,
        activation: Callable,
        *,  # everything after this marker is keyword-only
        key,
        kernel_size: int = 2,
    ):
        """Build the block.

        Args:
            hidden_channels: number of input and output channels of both convolutions.
            activation: elementwise function applied after each residual addition.
            key: PRNG key; split into one key per convolution.
            kernel_size: width of both convolutions (how many neighbouring points
                each output mixes). Defaults to 2, the value this block always
                used. With an even kernel the "SAME" padding cannot be split
                evenly between the two sides of the signal.
        """
        c_1_key, c_2_key = jax.random.split(key)

        # String padding and padding_mode="CIRCULAR" require an up-to-date Equinox.
        # Circular padding makes the convolution treat the signal as periodic.
        conv_kwargs = dict(kernel_size=kernel_size, padding="SAME", padding_mode="CIRCULAR")
        self.conv_1 = eqx.nn.Conv1d(hidden_channels, hidden_channels, key=c_1_key, **conv_kwargs)
        self.conv_2 = eqx.nn.Conv1d(hidden_channels, hidden_channels, key=c_2_key, **conv_kwargs)

        self.activation = activation

    def __call__(
        self,
        x,
    ):
        """Return ``act(conv_2(h) + h)`` where ``h = act(conv_1(x) + x)``.

        ``x`` uses Equinox's unbatched channels-first layout
        ``(hidden_channels, length)``; the output has the same shape.
        """
        for conv in (self.conv_1, self.conv_2):
            x = self.activation(conv(x) + x)  # residual connection, then activation
        return x

    def get_params(self):
        """Return ``[conv_1.weight, conv_1.bias, conv_2.weight, conv_2.bias]``.

        The order matters: ``ResNetPeriodic1d.set_params`` consumes it positionally.
        """
        return [self.conv_1.weight,
                self.conv_1.bias,
                self.conv_2.weight,
                self.conv_2.bias]


In [171]:
class ResNetPeriodic1d(eqx.Module):
    """Periodic 1-D residual network: lifting -> residual blocks -> projection.

    Maps a single-channel signal of shape ``(1, length)`` to a single-channel
    signal of the same length. The lifting and projection layers are pointwise
    (kernel size 1) convolutions that change only the number of channels.
    """

    lifting: eqx.nn.Conv1d  # pointwise conv, 1 -> hidden_channels
    blocks: list[ResBlockPostActivationPeriodic1d]  # residual blocks operating at hidden width
    projection: eqx.nn.Conv1d  # pointwise conv, hidden_channels -> 1

    def __init__(
        self,
        hidden_channels,
        num_blocks,
        activation,
        *,
        key,
    ):
        """Build the network.

        Args:
            hidden_channels: channel width inside the residual blocks.
            num_blocks: number of residual blocks.
            activation: elementwise nonlinearity used inside every block.
            key: PRNG key; split into one key per layer.
        """
        # One key for the lifting layer, one per block, one for the projection.
        l_key, *block_keys, p_key = jax.random.split(key, num_blocks + 2)

        self.lifting = eqx.nn.Conv1d(1, hidden_channels, kernel_size=1, key=l_key)
        self.blocks = [
            ResBlockPostActivationPeriodic1d(hidden_channels, activation, key=k)
            for k in block_keys
        ]
        self.projection = eqx.nn.Conv1d(hidden_channels, 1, kernel_size=1, key=p_key)

    def __call__(self, x):
        """Apply lifting, every residual block in order, then the projection."""
        x = self.lifting(x)
        for block in self.blocks:
            x = block(x)
        x = self.projection(x)

        return x

    def get_params_tree(self):
        """Return the layer parameters as a nested list.

        Layout: ``[[lifting_w, lifting_b],
        [[c1_w, c1_b, c2_w, c2_b], ...one list per block...],
        [proj_w, proj_b]]``. This is the layout ``set_params`` expects.
        """
        lifting_params = [
            self.lifting.weight,
            self.lifting.bias
        ]

        block_params = [block.get_params() for block in self.blocks]

        projection_params = [
            self.projection.weight,
            self.projection.bias
        ]

        return [
            lifting_params,
            block_params,
            projection_params
        ]

    def set_params(self, params):
        """Return a copy of the network whose parameters are replaced by ``params``.

        ``params`` must follow the layout returned by ``get_params_tree``.
        Equinox modules are immutable: ``self`` is not modified, so the caller
        must use the returned module.

        Raises:
            ValueError: if the number of per-block parameter groups differs
                from the number of residual blocks.
        """
        lifting_params, block_params, projection_params = params
        if len(block_params) != len(self.blocks):
            raise ValueError(
                f"expected parameters for {len(self.blocks)} blocks, "
                f"got {len(block_params)}"
            )

        def _param_nodes(model):
            # Every parameter leaf, in the same order get_params_tree lists them.
            nodes = [model.lifting.weight, model.lifting.bias]
            for block in model.blocks:
                nodes += [block.conv_1.weight, block.conv_1.bias,
                          block.conv_2.weight, block.conv_2.bias]
            nodes += [model.projection.weight, model.projection.bias]
            return nodes

        new_values = [
            lifting_params[0], lifting_params[1],
            *(p for group in block_params for p in group[:4]),
            projection_params[0], projection_params[1],
        ]
        # A single tree_at rebuilds the module once instead of once per leaf.
        return eqx.tree_at(_param_nodes, self, new_values)

### Testing get_params_tree() and set_params() methods

In [229]:
rede = ResNetPeriodic1d(1, 1, jax.nn.relu, key=jax.random.PRNGKey(1))

# Flatten the nested parameter list into one vector; `unravel_params` inverts it.
flat_params, unravel_params = ravel_pytree(rede.get_params_tree())

# Round-trip check of set_params: write random values in, read them back out.
# set_params returns a new module, so keep the result under its own name.
random_flat = jax.random.uniform(jax.random.PRNGKey(0), shape=flat_params.shape)
rede_random = rede.set_params(unravel_params(random_flat))
roundtrip_flat, _ = ravel_pytree(rede_random.get_params_tree())
assert jnp.allclose(roundtrip_flat, random_flat), "set_params/get_params_tree round-trip failed"

rede.get_params_tree()


[[Array([[[0.31059694]]], dtype=float32), Array([[0.5476253]], dtype=float32)],
 [[Array([[[-0.09132952, -0.01586188]]], dtype=float32),
   Array([[0.3417482]], dtype=float32),
   Array([[[ 0.05476376, -0.09840479]]], dtype=float32),
   Array([[0.43600458]], dtype=float32)]],
 [Array([[[-0.19735742]]], dtype=float32),
  Array([[0.8156264]], dtype=float32)]]

## Multi Layer Perceptron (MLP)

In [101]:
# Creating a MultiLayer Perceptron (MLP) model in JAX
import jax 
import jax.numpy as jnp
import equinox as eqx
import optax
from typing import List

In [127]:
# Hypernetwork / training configuration
N_SAMPLES = 200          # number of training samples
LAYERS = [1, 2, 1]       # MLP layer widths: [input, hidden..., output]
ACTIVATION = jax.nn.relu  # nonlinearity between hidden layers
LEARNING_RATE = 0.001    # Adam step size
N_ITER = 1_000           # number of optimisation steps

In [129]:
class HyperNet(eqx.Module):
    """Plain MLP used as the hypernetwork.

    ``layers_size`` lists the layer widths ``[in, hidden..., out]``; one
    ``eqx.nn.Linear`` is built for each consecutive pair. The activation is
    applied after every layer except the last, so the output is unbounded.
    """

    layers: List[eqx.nn.Linear]
    activation: Callable  # nonlinearity applied between layers

    def __init__(
        self,
        layers_size=(),
        key=None,
        activation=jax.nn.relu
    ):
        """Build the MLP.

        Args:
            layers_size: sequence of layer widths, input first.
            key: PRNG key, split once per layer. Required when at least one
                layer is built.
            activation: elementwise nonlinearity used between layers.

        Raises:
            ValueError: if layers must be initialised but no ``key`` is given.
        """
        pairs = list(zip(layers_size[:-1], layers_size[1:]))
        if pairs and key is None:
            raise ValueError("HyperNet requires a PRNG `key` to initialise its layers")

        layers = []
        for fan_in, fan_out in pairs:
            key, subkey = jax.random.split(key)
            layers.append(eqx.nn.Linear(fan_in, fan_out, key=subkey, use_bias=True))
        self.layers = layers
        # Store the activation instead of reading a global, so the argument is honoured.
        self.activation = activation

    def __call__(self, x):
        """Run the MLP on an unbatched input vector of length ``layers_size[0]``."""
        for layer in self.layers[:-1]:  # every layer except the last
            x = self.activation(layer(x))
        return self.layers[-1](x)  # final layer: no activation

    def get_params(self):
        """Return all weights (layer order) followed by all biases (layer order)."""
        return [layer.weight for layer in self.layers] + [layer.bias for layer in self.layers]

### Testing get_params() method

In [174]:
# Define the PRNG key explicitly so this cell runs on a fresh kernel
# (previously `key` only existed as leftover kernel state).
key = jax.random.PRNGKey(0)
modelHN = HyperNet(LAYERS, key=key, activation=ACTIVATION)

modelHN.get_params()

[Array([[-0.19835997],
        [-0.7369585 ]], dtype=float32),
 Array([[-0.3148051 ,  0.26949587]], dtype=float32),
 Array([0.69720936, 0.8496475 ], dtype=float32),
 Array([-0.09411155], dtype=float32)]

## Hyper Network

In [193]:
# Compose (not inherit) a HyperNet and a ResNetPeriodic1d: the HyperNet predicts
# the full flat parameter vector of the target network, which is then evaluated.

class HyperResidualRNN(eqx.Module):
    """Hypernetwork wrapper: ``hypernet(input_hyper)`` supplies the weights of ``targetnet``.

    The hypernet's output length must equal the number of target-network
    parameters, otherwise the unravel function cannot rebuild the parameter tree.
    """

    hypernet: HyperNet
    targetnet: ResNetPeriodic1d

    def __init__(self, hypernet, targetnet):
        self.hypernet = hypernet
        self.targetnet = targetnet

    def get_unravel(self):
        """Return the function that maps a flat vector back to the target-net parameter tree."""
        targetnet_params = self.targetnet.get_params_tree()  # TODO: removing the last layer
        _, unravel = ravel_pytree(targetnet_params)
        return unravel

    def __call__(self, input_target, input_hyper, unravel_function=None):
        """Evaluate the target network with weights generated from ``input_hyper``.

        Args:
            input_target: input passed to the target network.
            input_hyper: input passed to the hypernetwork.
            unravel_function: flat-vector -> parameter-tree function. If omitted,
                it is derived from ``self.targetnet`` via ``get_unravel``.
        """
        if unravel_function is None:
            unravel_function = self.get_unravel()

        flat_params = self.hypernet(input_hyper)
        targetnet_params = unravel_function(flat_params)

        # set_params returns a new module rather than mutating self.targetnet,
        # so the returned copy is the one that must be evaluated.
        targetnet = self.targetnet.set_params(targetnet_params)

        return targetnet(input_target)
        



In [195]:
key = jax.random.PRNGKey(0)

# Create the TargetNet first: its parameter count fixes the hypernet's output size.
targetnet = ResNetPeriodic1d(1, 2, jax.nn.relu, key=jax.random.PRNGKey(1))
flat_target_params, _ = ravel_pytree(targetnet.get_params_tree())
n_target_params = flat_target_params.size

# Create the HyperNet. Its last layer must emit one value per target parameter,
# otherwise `unravel` cannot rebuild the target parameter tree.
hypernet = HyperNet([*LAYERS[:-1], n_target_params], key=key, activation=ACTIVATION)

# Create the HyperResidualRNN
model = HyperResidualRNN(hypernet, targetnet)

unravel = model.get_unravel()

optimizer = optax.adam(LEARNING_RATE)

opt_state = optimizer.init(eqx.filter(model, eqx.is_array))


# The model holds non-array leaves (activation functions), so plain jax.jit
# cannot trace it; eqx.filter_jit treats those leaves as static.
@eqx.filter_jit
def loss_fn(model, batch):
    """Mean squared error of the model on ``batch = ((input_target, input_hyper), targets)``."""
    inputs, targets = batch
    input_target, input_hyper = inputs
    predictions = model(input_target, input_hyper, unravel)
    return jnp.mean((predictions - targets) ** 2)


@eqx.filter_jit
def step_fn(model, state, batch):
    """Apply one Adam update; return ``(new_model, new_state, loss)``."""
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, batch)
    updates, new_state = optimizer.update(grads, state, eqx.filter(model, eqx.is_array))
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_state, loss



In [223]:
import pandas as pd
import numpy as np

# Synthetic dataset configuration
SEED = 0  # seeds NumPy's global RNG so the generated dataset is reproducible
np.random.seed(SEED)

n_samples = 1000  # number of samples
time_steps = 40  # time steps per sample (for the RNN)
n_features = 5  # number of time-varying features (for the RNN)
n_const_features = 3  # number of constant features (for the MLP)
noise_level = 0.2  # scale of the additive Gaussian noise

# Função para gerar sinal senoidal com frequência variável
def generate_sinusoidal_signal(time_steps, freq_base, freq_variation, noise_level):
    """Return a noisy sine wave with ``time_steps`` samples.

    The frequency is ``freq_base`` plus a uniform jitter in
    ``[-freq_variation, freq_variation]``, measured in cycles per
    ``time_steps`` samples. Gaussian noise scaled by ``noise_level`` is added.
    Draws from NumPy's global RNG: first the jitter, then the noise.
    """
    jittered_freq = freq_base + np.random.uniform(-freq_variation, freq_variation)
    t = np.arange(time_steps)
    clean = np.sin(2 * np.pi * jittered_freq * t / time_steps)
    return clean + noise_level * np.random.randn(time_steps)

# Time-varying data: one noisy sinusoid per (sample, feature), each with its own
# random base frequency. Array shape: (n_samples, time_steps, n_features).
temporal_data = np.zeros((n_samples, time_steps, n_features))
freq_variation = 0.05  # maximum jitter added to each base frequency

for i in range(n_samples):
    for j in range(n_features):
        freq_base = np.random.uniform(0.1, 1.0)  # random base frequency
        temporal_data[i, :, j] = generate_sinusoidal_signal(
            time_steps, freq_base, freq_variation, noise_level
        )

# Constant (static) features for the MLP, drawn at random.
const_data = np.random.randn(n_samples, n_const_features)

# Flatten the temporal block to one row per sample. The reshape is row-major, so
# despite the names, column `feature_t{k}` holds time step (k-1) // n_features
# of feature (k-1) % n_features.
temporal_columns = [f'feature_t{k}' for k in range(1, time_steps * n_features + 1)]
df_temporal = pd.DataFrame(temporal_data.reshape(n_samples, -1), columns=temporal_columns)

const_columns = [f'const_feature{k}' for k in range(1, n_const_features + 1)]
df_const = pd.DataFrame(const_data, columns=const_columns)

# Temporal and constant features side by side.
df = pd.concat([df_temporal, df_const], axis=1)

# Placeholder regression target: pure noise, unrelated to the features.
df['target'] = np.random.randn(n_samples)

df.head()


Unnamed: 0,feature_t1,feature_t2,feature_t3,feature_t4,feature_t5,feature_t6,feature_t7,feature_t8,feature_t9,feature_t10,...,feature_t195,feature_t196,feature_t197,feature_t198,feature_t199,feature_t200,const_feature1,const_feature2,const_feature3,target
0,0.316419,0.094006,0.180198,0.063395,-0.038635,-0.266741,0.46365,-0.027799,0.281768,0.130849,...,-0.183962,0.830583,-0.634478,0.580287,1.022595,-0.281525,-0.274235,-0.193651,-0.583533,0.722456
1,0.287647,-0.141836,-0.014514,-0.248675,0.633971,-0.472422,0.048249,0.410364,0.127561,0.005904,...,0.428507,-0.645935,-0.264288,-0.563421,-0.999967,0.457197,0.161862,-0.130422,0.33403,0.218963
2,0.431424,-0.11079,-0.12546,0.209351,-0.083805,-0.112344,0.197484,0.016432,-0.077888,-0.127886,...,-1.034519,-0.627236,-0.804753,-0.80435,-0.59027,-0.744307,-1.062656,-0.508828,0.316776,-0.856118
3,-0.091222,-0.278331,-0.173858,0.181299,-0.160004,0.034044,-0.264407,0.028212,-0.001304,-0.235955,...,-0.989921,0.925385,-0.800712,0.223292,-0.795376,-0.598333,0.241152,0.732813,-0.329903,-0.398819
4,-0.168778,0.210203,-0.084391,-0.261193,0.218225,-0.118044,-0.007894,0.172762,-0.115775,0.083989,...,0.528696,-0.916553,0.057136,-0.111954,0.328855,0.925494,0.171832,-1.219297,-0.579278,-1.25439
