In [1]:
import sys
from pathlib import Path

# Make the project's `src` directory (one level above this notebook's working
# directory) importable. The membership test keeps re-runs of this cell from
# appending the same entry to sys.path more than once.
src = str(Path('../src').resolve())
if src not in sys.path:
    sys.path.append(src)
    
import os
# Ask JAX to use the CPU backend. This is set here, before `jax` is imported in
# the next cell, because the variable is presumably read when JAX initialises.
os.environ['JAX_PLATFORMS'] = 'cpu'

In [2]:
import jax
import jax.numpy as jnp
import jax.random as jrandom
import equinox as eqx

def generate_synthetic_data(batch_size, static_in_size, daily_in_size, irregular_in_size, seq_length, key=None):
    """Build a dictionary of standard-normal synthetic model inputs.

    Args:
        batch_size: Number of samples (leading axis of every array).
        static_in_size: Number of static features per sample.
        daily_in_size: Number of features in the 'x_dd' sequence.
        irregular_in_size: Number of features in the 'x_di' sequence.
        seq_length: Number of time steps in both sequences.
        key: Optional JAX PRNG key. When omitted, the fixed keys
            ``PRNGKey(0)``, ``PRNGKey(1)`` and ``PRNGKey(2)`` are used for
            'x_s', 'x_dd' and 'x_di' respectively, so the output is the same on
            every call. When given, it is split into three independent subkeys.

    Returns:
        dict with
            'x_s':  array of shape (batch_size, static_in_size)
            'x_dd': array of shape (batch_size, seq_length, daily_in_size)
            'x_di': array of shape (batch_size, seq_length, irregular_in_size)
    """
    if key is None:
        # Historical default: one fixed seed per array.
        static_key, daily_key, irregular_key = (jax.random.PRNGKey(seed) for seed in range(3))
    else:
        static_key, daily_key, irregular_key = jax.random.split(key, 3)

    return {
        'x_s': jax.random.normal(static_key, (batch_size, static_in_size)),
        'x_dd': jax.random.normal(daily_key, (batch_size, seq_length, daily_in_size)),
        'x_di': jax.random.normal(irregular_key, (batch_size, seq_length, irregular_in_size)),
    }

# Dimensions of the synthetic inputs
batch_size = 8
seq_length = 365
static_in_size = 14
daily_in_size = 22
irregular_in_size = 12

# Model hyperparameters (passed to the model constructor in a later cell)
hidden_size = 64
num_heads = 4
out_size = 4
dropout = 0.5

# Generate synthetic data
synthetic_data = generate_synthetic_data(
    batch_size=batch_size,
    static_in_size=static_in_size,
    daily_in_size=daily_in_size,
    irregular_in_size=irregular_in_size,
    seq_length=seq_length,
)

In [3]:
# Re-import the project package so edits made on disk are picked up without
# restarting the kernel.
# NOTE(review): importlib.reload reloads only the module object it is given;
# `models.tft_mha` is presumably NOT reloaded by this call, so edits to
# TFT_MHA itself may still require a kernel restart — confirm.
import models
from importlib import reload
reload(models)
from models.tft_mha import TFT_MHA

# Initialize the TFT model. `dynamic_sizes` maps each dynamic input name (the
# same keys used in `synthetic_data`) to its feature count.
dynamic_sizes = {'x_dd': daily_in_size, 'x_di': irregular_in_size}
model = TFT_MHA(
    dynamic_sizes = dynamic_sizes,
    static_size = static_in_size,
    hidden_size=hidden_size,
    out_size=out_size,
    num_heads=num_heads,
    dropout=dropout,
    seed=0
)

# Report the model size: count_parameters yields a (parameter count, bytes)
# pair and human_readable_size turns the byte count into a (value, unit) pair.
num_params, memory_bytes = models.count_parameters(model)
size, unit = models.human_readable_size(memory_bytes)
print(f"Model contains {num_params:,} parameters, using {size:.2f}{unit} memory.")

# Run one step of the model. The model is written for a single sample, so it is
# vmapped over the leading (batch) axis of every array in `synthetic_data` and
# over `batch_keys`, giving each sample its own PRNG key (presumably consumed
# by dropout — confirm against TFT_MHA).
key = jax.random.PRNGKey(0)
batch_keys = jax.random.split(key, batch_size)
output = jax.vmap(model)(synthetic_data, batch_keys)

print("Output shape:", output.shape)

Model contains 286,084 parameters, using 1.09MB memory.
Output shape: (8, 4)


In [4]:
# Display the structure of the model's encoder submodule (layer types and
# parameter shapes) via its repr.
model.encoder

HybridEncoder(
  dynamic_blocks={
    'x_dd':
    HybridEncoderBlock(
      head_context=StaticContextHeadBias(
        out_shape=(4, 16),
        linear=Linear(
          weight=f32[64,64],
          bias=f32[64],
          in_features=64,
          out_features=64,
          use_bias=True
        ),
        dropout=Dropout(p=0.5, inference=False),
        layernorm=LayerNorm(
          shape=(4, 16),
          eps=1e-05,
          use_weight=True,
          use_bias=True,
          weight=f32[4,16],
          bias=f32[4,16]
        )
      ),
      dynamic_proj=Linear(
        weight=f32[64,22],
        bias=f32[64],
        in_features=22,
        out_features=64,
        use_bias=True
      ),
      attn=MultiheadAttention(
        query_proj=Linear(
          weight=f32[64,64],
          bias=None,
          in_features=64,
          out_features=64,
          use_bias=False
        ),
        key_proj=Linear(
          weight=f32[64,64],
          bias=None,
          in_features

In [11]:
# Measure how much of the full model's parameter budget sits in the encoder's
# per-input dynamic blocks. `num_params` comes from the earlier model-building cell.
submodel = model.encoder.dynamic_blocks

# NOTE(review): this rebinds the notebook-level `memory_bytes`, which previously
# held the byte count of the full model.
sub_num_params, memory_bytes = models.count_parameters(submodel)
# NOTE(review): the saved output of this cell reports more parameters for the
# submodel than the earlier cell reports for the whole model, so the recorded
# outputs appear to come from different model configurations — re-run the
# notebook top to bottom before trusting the percentage.
print(f"Submodel contains {sub_num_params:,} parameters ({sub_num_params/float(num_params)*100:0.2f}% of model)")

Submodel contains 1,717,248 parameters (54.21% of model)


In [None]:
class GatedLinearUnit(eqx.Module):
    """Element-wise gated linear layer: ``sigmoid(gates(x)) * linear(x)``."""

    gates: eqx.nn.Linear
    linear: eqx.nn.Linear

    def __init__(self, input_size: int, output_size: int, *, key):
        """Create the gate and value projections from two subkeys of `key`.

        Args:
            input_size: Length of the input vector.
            output_size: Length of the output vector.
            key: JAX PRNG key, split once to initialise both projections.
        """
        gate_key, value_key = jrandom.split(key)
        self.gates = eqx.nn.Linear(input_size, output_size, key=gate_key)
        self.linear = eqx.nn.Linear(input_size, output_size, key=value_key)

    def __call__(self, gamma):
        """Return the linear projection of `gamma` scaled by its sigmoid gate."""
        gate_values = jax.nn.sigmoid(self.gates(gamma))
        projected = self.linear(gamma)
        return gate_values * projected

class GatedResidualNetwork(eqx.Module):
    """Gated residual block with an optional additive context term.

    For a single input vector the block computes::

        eta2 = elu(eta2_dynamic(input) + eta2_static(context) + eta2_bias)
        eta1 = eta1_linear(eta2)
        out  = layer_norm(input + glu(eta1))

    The context term is dropped (treated as 0) when the block is built without
    a ``context_size``.
    """
    eta2_dynamic: eqx.nn.Linear
    eta2_static: eqx.nn.Linear  # None when the block is built without a context_size
    eta2_bias: jnp.ndarray
    eta1_linear: eqx.nn.Linear
    glu: GatedLinearUnit
    layer_norm: eqx.nn.LayerNorm

    def __init__(self, grn_size, context_size=None, *, key):
        """Create the block's layers.

        Args:
            grn_size: Either an int used as the input, hidden and output size,
                or a tuple ``(input_size, hidden_size, output_size)``.
            context_size: Length of the context vector, or None for a block
                that takes no context.
            key: JAX PRNG key; split into five subkeys, one per parameter group.

        Raises:
            ValueError: If `grn_size` is neither a tuple nor an int.
        """
        if isinstance(grn_size, tuple):
            input_size, hidden_size, output_size = grn_size
        elif isinstance(grn_size, int):
            input_size = hidden_size = output_size = grn_size
        else:
            raise ValueError("grn_size must either be a tuple or int for input, hidden, and output sizes")

        keys = jax.random.split(key, 5)
        self.eta2_dynamic = eqx.nn.Linear(input_size, hidden_size, use_bias=False, key=keys[0])
        if context_size is not None:
            self.eta2_static = eqx.nn.Linear(context_size, hidden_size, use_bias=False, key=keys[1])
        else:
            self.eta2_static = None
        # Single bias shared by the dynamic and context projections (both are bias-free).
        self.eta2_bias = jax.random.uniform(keys[2], (hidden_size,))

        self.eta1_linear = eqx.nn.Linear(hidden_size, output_size, key=keys[3])
        # The GLU consumes eta1 (length output_size) and its result is fed to
        # LayerNorm(output_size), so both of its dimensions must be output_size.
        # Sizing it with hidden_size only worked when all three sizes were equal.
        self.glu = GatedLinearUnit(output_size, output_size, key=keys[4])
        self.layer_norm = eqx.nn.LayerNorm(output_size)

    def __call__(self, input: jnp.ndarray, context: jnp.ndarray = None) -> jnp.ndarray:
        """Apply the block to one input vector.

        Args:
            input: Vector of length input_size. The residual sum below requires
                it to be broadcast-compatible with the output_size-long GLU result.
            context: Vector of length context_size; must be given exactly when
                the block was built with a context_size.

        Returns:
            The layer-normalised vector of length output_size.

        Raises:
            ValueError: If a context is passed to a block built without one, or
                omitted for a block built with one.
        """
        # Explicit None checks: do not rely on the truthiness of a module object.
        has_static = self.eta2_static is not None
        has_context = context is not None
        if has_static != has_context:
            raise ValueError("Either eta2_static was created and no context was passed during call, " +
                             "or context was passed during call with no eta2_static created during init.")
        context_term = self.eta2_static(context) if has_static else 0

        eta2 = jax.nn.elu(self.eta2_dynamic(input) + context_term + self.eta2_bias)
        eta1 = self.eta1_linear(eta2)
        return self.layer_norm(input + self.glu(eta1))
    
    
class VariableSelectionNetwork(eqx.Module):
    """Embed each input variable, refine it with its own GRN, and combine the
    variables with softmax weights produced by a shared weighting GRN."""
    variable_transformers: list
    variable_processors: list
    weights_grn: GatedResidualNetwork

    def __init__(self, hidden_size, num_variables, context_size=None, key=None):
        """Create the per-variable layers and the weighting GRN.

        Args:
            hidden_size: Embedding length used for every variable.
            num_variables: Number of scalar input variables.
            context_size: Length of the optional context vector fed to the
                weighting GRN, or None for no context.
            key: JAX PRNG key (required despite the None default).

        Raises:
            ValueError: If `key` is not provided.
        """
        if key is None:
            raise ValueError("VariableSelectionNetwork requires a PRNG key; pass key=...")

        # Four subkeys are drawn (only the first three are consumed) so that the
        # values of the derived keys stay the same as before.
        keys = jax.random.split(key, 4)

        # One scalar -> hidden_size embedding per variable.
        transformer_keys = jax.random.split(keys[0], num_variables)
        self.variable_transformers = [
            eqx.nn.Linear(1, hidden_size, key=k)
            for k in transformer_keys
        ]

        # One context-free GRN per variable.
        processor_keys = jax.random.split(keys[1], num_variables)
        self.variable_processors = [
            GatedResidualNetwork(hidden_size, key=k)
            for k in processor_keys
        ]

        # The weighting GRN sees all variables' features concatenated.
        weights_size = num_variables*hidden_size
        self.weights_grn = GatedResidualNetwork(weights_size, context_size, key=keys[2])

    def __call__(self, inputs, context=None):
        """Apply variable selection to one (unbatched) sequence.

        Args:
            inputs: Array of shape (seq_length, num_variables).
            context: Optional context for the weighting GRN. A 1-D vector is
                shared across all time steps; an array with a leading
                seq_length axis supplies one context vector per step.

        Returns:
            Array of shape (seq_length, hidden_size).
        """
        # inputs shape: (seq_length, num_variables); unpacking also enforces rank 2.
        seq_length, _ = inputs.shape

        # Embed each scalar variable: (seq_length, num_variables, hidden_size).
        transformed_inputs = jnp.stack([
            jax.vmap(transformer)(inputs[:, i][:, jnp.newaxis])
            for i, transformer in enumerate(self.variable_transformers)
        ], axis=1)

        # Process each variable with its own GRN (no context).
        processed_inputs = jnp.stack([
            jax.vmap(processor)(transformed_inputs[:, i], None)
            for i, processor in enumerate(self.variable_processors)
        ], axis=1)

        # Generate variable selection weights. Only the time axis of `flattened`
        # is mapped; a 1-D context is broadcast to every step instead of being
        # sliced along its feature axis (which made vmap fail on mismatched
        # axis sizes).
        flattened = processed_inputs.reshape([seq_length, -1])
        context_axis = 0 if (context is not None and jnp.ndim(context) > 1) else None
        flat_weights = jax.vmap(self.weights_grn, in_axes=(0, context_axis))(flattened, context)
        # NOTE(review): the softmax normalises over the flattened
        # (num_variables * hidden_size) axis, not over the variable axis alone —
        # confirm this is the intended weighting.
        flat_weights = jax.nn.softmax(flat_weights, axis=-1)
        variable_weights = flat_weights.reshape(transformed_inputs.shape)

        # Weight and sum the processed inputs over the variable axis.
        weighted_inputs = variable_weights * processed_inputs
        outputs = jnp.sum(weighted_inputs, axis=1)

        return outputs

In [None]:
# Shape of the first sample's static features after adding a leading length-1
# axis, i.e. a (1, static_in_size) array.
synthetic_data['x_s'][0,:][jnp.newaxis,:].shape

In [None]:

# Shape check on the context vector.
# NOTE(review): in the visible notebook `d_context` is only assigned in a later
# cell, so this cell fails on a fresh top-to-bottom run.
d_context.shape

In [None]:
# Shape check on the static variable-selection output.
# NOTE(review): in the visible notebook `static_out` is only assigned in a later
# cell, so this cell fails on a fresh top-to-bottom run.
static_out.shape

In [None]:
# Display the weighting GRN held by the daily variable-selection network.
# NOTE(review): in the visible notebook `daily_vsn` is only assigned in a later
# cell, so this cell fails on a fresh top-to-bottom run.
daily_vsn.weights_grn

In [None]:
# Shape of the first sample's static feature vector (batch axis indexed away).
synthetic_data['x_s'][0,...].shape

In [None]:
# Shape of the first row of the static variable-selection output.
# NOTE(review): in the visible notebook `static_out` is only assigned in a later
# cell, so this cell fails on a fresh top-to-bottom run.
static_out[0,].shape

In [None]:
# Smoke test of the variable-selection and context modules on the first sample.
# Each module gets its own subkey: reusing one key for several initialisations
# would make their random parameters correlated.
key = jax.random.PRNGKey(0)
static_vsn_key, context_grn_key, daily_vsn_key = jax.random.split(key, 3)

# Static features: add a length-1 leading axis so the VSN sees (1, static_in_size).
batch_xs = synthetic_data['x_s'][0,...]
static_vsn = VariableSelectionNetwork(hidden_size, static_in_size, key=static_vsn_key)
static_out = static_vsn(batch_xs[jnp.newaxis,:])

# Turn the single static output row into a context vector.
d_context_grn = GatedResidualNetwork(hidden_size, key=context_grn_key)
d_context = d_context_grn(static_out[0,:])

# Daily sequence of the first sample: (seq_length, daily_in_size).
batch_xdd = synthetic_data['x_dd'][0,...]
daily_vsn = VariableSelectionNetwork(hidden_size, daily_in_size, key=daily_vsn_key)
daily_out = daily_vsn(batch_xdd)

daily_out.shape

In [None]:
# Variable Selection Network — step-by-step scratch version of the logic in
# VariableSelectionNetwork, run on the daily inputs of the first sample.
# Relies on `key` and `d_context` from earlier cells.
# INIT
keys = jax.random.split(key, 4)   

# One scalar -> hidden_size embedding per daily variable.
transformer_keys = jax.random.split(keys[0], daily_in_size)
variable_transformers = [
    eqx.nn.Linear(1, hidden_size, key=k)
    for k in transformer_keys
]

# One context-free GRN per daily variable.
processor_keys = jax.random.split(keys[1], daily_in_size)
variable_processors = [
    GatedResidualNetwork(hidden_size, key=k)
    for k in processor_keys
]

# Weighting GRN over all variables' features, conditioned on a context vector
# of length hidden_size.
weights_size = daily_in_size*hidden_size
weights_grn = GatedResidualNetwork(weights_size, hidden_size, key=keys[2])

# CALL
# Individual batch.
inputs = synthetic_data['x_dd'][0,...]

# Transform each variable
transformed_inputs = jnp.stack([
    jax.vmap(processor)(inputs[:, i][:,jnp.newaxis])
    for i, processor in enumerate(variable_transformers)
], axis=1)

# Process each variable
# NOTE(review): `processed_inputs` is computed but not used below in this cell.
processed_inputs = jnp.stack([
    jax.vmap(processor)(transformed_inputs[:, i], None)
    for i, processor in enumerate(variable_processors)
], axis=1)

# NOTE(review): the weights here are derived from `transformed_inputs`, whereas
# VariableSelectionNetwork.__call__ flattens `processed_inputs` — confirm which
# one is intended.
flattened = transformed_inputs.reshape([seq_length,-1])
# in_axes=(0, None): map over time steps only and share `d_context` across them.
flat_weights = jax.vmap(weights_grn, in_axes=(0,None))(flattened, d_context)
flat_weights = jax.nn.softmax(flat_weights, axis=-1)

# Back to the (seq_length, num_variables, hidden_size) layout of transformed_inputs.
variable_weights = flat_weights.reshape(transformed_inputs.shape)
print("Done!")

In [None]:
# Shape of the selection weights after reshaping to the layout of `transformed_inputs`.
variable_weights.shape

In [None]:
# Shape of the context vector passed to the weighting GRN.
d_context.shape

In [None]:
# NOTE(review): no visible cell assigns `weighted_inputs` at notebook scope (it
# is only a local inside VariableSelectionNetwork.__call__), so this cell
# presumably depends on leftover kernel state and fails on a fresh run.
weighted_inputs.shape

In [None]:
# Shape of whatever `inputs` currently refers to; several cells rebind this
# name, so the result depends on which of them ran last.
inputs.shape

In [None]:
# Manual replay of the GatedResidualNetwork forward pass (without a context
# term) on the first time step of `flattened`.
# NOTE(review): this rebinds the notebook-level `inputs` to a single flattened row.
inputs = flattened[0]

# NOTE(review): `eta2_dynamic`, `eta2_bias`, `eta1_linear`, `glu` and
# `layer_norm` are not assigned at notebook scope in any visible cell (they
# exist only as attributes of GatedResidualNetwork instances), so this cell
# presumably depends on leftover kernel state and fails on a fresh run.
eta2 = jax.nn.elu(eta2_dynamic(inputs) + eta2_bias)
eta1 = eta1_linear(eta2)
layer_norm(inputs + glu(eta1))

In [None]:
# Apply the context-free GRN to the first time step of `flattened`.
# NOTE(review): in the visible notebook `grn` is only assigned in the following
# cell, so this cell fails on a fresh top-to-bottom run.
grn(flattened[0],None)

In [None]:
# Context-free GRN over the flattened per-variable embeddings. The constructor
# takes one size argument (an int is used as the input, hidden and output size)
# followed by the optional context size; passing the three sizes as separate
# positional arguments raises a TypeError.
grn = GatedResidualNetwork(daily_in_size*hidden_size, None, key=keys[2])

# One row per time step, all variables' embeddings concatenated.
flattened = transformed_inputs.reshape([seq_length,-1])
selection_weights = jax.vmap(grn)(flattened, None)
selection_weights = jax.nn.softmax(selection_weights, axis=-1)



In [None]:
# Recover the (time step, variable, feature) layout. The leading dimensions are
# taken from the configuration instead of hard-coded literals, which stop
# matching as soon as seq_length or daily_in_size changes.
selection_weights.reshape([seq_length, daily_in_size, -1]).shape