In [1]:
import pandas as pd
import numpy as np

# Generate sample data
dates = pd.date_range(start='2023-01-01', periods=100, freq='D')
num_bonds = 10
data = {
    'date': np.tile(dates, num_bonds),
    'bond_id': np.repeat(np.arange(num_bonds), len(dates)),
    'feature1': np.random.randn(len(dates) * num_bonds),
    'feature2': np.random.randn(len(dates) * num_bonds),
}

df = pd.DataFrame(data)
df.set_index(['date', 'bond_id'], inplace=True)
print(df.sort_index().head())

                    feature1  feature2
date       bond_id                    
2023-01-01 0       -0.343461 -0.978529
           1        1.228038 -2.134945
           2       -0.799468 -2.024358
           3       -0.575496  0.220507
           4       -0.270393  0.764598


In [2]:
# Pivot the dataframe to get features as columns and bond_id as rows
bond_features = df.reset_index().pivot(index='bond_id', columns='date', values=['feature1', 'feature2'])
bond_features = bond_features.fillna(0)  # Handle NaNs if any

# Create the adjacency matrix (using correlation as an example)
correlation_matrix = np.corrcoef(bond_features.values.reshape(bond_features.shape[0], -1))
adjacency_matrix = (correlation_matrix > 0.05).astype(int)

# Node features
node_features = bond_features.values.reshape(bond_features.shape[0], -1)

In [3]:
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax

In [5]:
class GATLayer(nn.Module):
    in_features: int
    out_features: int
    num_heads: int = 1

    def setup(self):
        self.attn_weights = self.param('attn_weights', jax.nn.initializers.glorot_uniform(), 
                                       (self.num_heads, self.in_features, self.out_features))
        self.attn_biases = self.param('attn_biases', jax.nn.initializers.zeros, (self.num_heads, self.out_features))

    def __call__(self, x, edge_index):
        attn_weights = self.attn_weights
        attn_biases = self.attn_biases
        
        def apply_attention(head, node_features):
            x_w = jnp.dot(node_features, attn_weights[head]) + attn_biases[head]
            alpha = jax.nn.softmax(jnp.dot(x_w, x_w.T), axis=-1)
            return jnp.dot(alpha, x_w)
        
        multihead_outputs = [apply_attention(head, x) for head in range(self.num_heads)]
        return jnp.concatenate(multihead_outputs, axis=-1)

class GATModel(nn.Module):
    in_features: int
    out_features: int
    hidden_dim: int
    num_heads: int

    def setup(self):
        self.gat1 = GATLayer(self.in_features, self.hidden_dim, self.num_heads)
        self.gat2 = GATLayer(self.hidden_dim * self.num_heads, self.out_features, 1)

    def __call__(self, x, edge_index):
        x = self.gat1(x, edge_index)
        x = nn.elu(x)
        x = self.gat2(x, edge_index)
        return x

# Initialize model
model = GATModel(in_features=node_features.shape[1], out_features=1, hidden_dim=8, num_heads=8)
params = model.init(jax.random.PRNGKey(0), node_features, adjacency_matrix)

In [6]:
from flax.training import train_state

class TrainState(train_state.TrainState):
    pass #batch_stats: dict

# Define loss function
def loss_fn(params, batch):
    inputs, targets = batch
    predictions = model.apply(params, inputs, adjacency_matrix)
    loss = jnp.mean((predictions - targets) ** 2)
    return loss

In [7]:
# Initialize optimizer
learning_rate = 0.005
optimizer = optax.adam(learning_rate)
state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

# Training loop
@jax.jit
def train_step(state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(state.params, batch)
    state = state.apply_gradients(grads=grads)
    return state, loss

In [8]:
# Example training batch (dummy data)
targets = jnp.array([0.5] * node_features.shape[0])
batch = (node_features, targets)

# Train for a few epochs
num_epochs = 200
for epoch in range(num_epochs):
    state, loss = train_step(state, batch)
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss}')

Epoch 0, Loss: 0.05672908574342728
Epoch 10, Loss: 0.4664713144302368
Epoch 20, Loss: 0.054832398891448975
Epoch 30, Loss: 0.008281473070383072
Epoch 40, Loss: 0.002554930280894041
Epoch 50, Loss: 0.00029670834192074835
Epoch 60, Loss: 0.00018542220641393214
Epoch 70, Loss: 0.00014577849651686847
Epoch 80, Loss: 7.385780190816149e-05
Epoch 90, Loss: 3.337900488986634e-05
Epoch 100, Loss: 1.6240834156633355e-05
Epoch 110, Loss: 9.31080194277456e-06
Epoch 120, Loss: 6.5535159592400305e-06
Epoch 130, Loss: 5.49555579709704e-06
Epoch 140, Loss: 4.946053195453715e-06
Epoch 150, Loss: 4.429461569088744e-06
Epoch 160, Loss: 3.949753136112122e-06
Epoch 170, Loss: 3.5682755878951866e-06
Epoch 180, Loss: 3.2438422294944758e-06
Epoch 190, Loss: 2.956787511720904e-06


In [9]:
predictions = model.apply(state.params, node_features, adjacency_matrix)
print(predictions)

[[0.4997275 ]
 [0.50127506]
 [0.49976474]
 [0.49533224]
 [0.50070435]
 [0.50047666]
 [0.5011364 ]
 [0.5003159 ]
 [0.50117505]
 [0.5001677 ]]
