In [None]:
import jax
from nanodl import GAT

# Root PRNG key for JAX. JAX random functions are pure functions of their key,
# so reusing one key for several draws yields correlated (not independent)
# samples. Split it into one subkey per independent use instead.
key = jax.random.PRNGKey(0)
x_key, adj_key, init_key = jax.random.split(key, 3)

# Create dummy input data
num_nodes = 10
num_features = 5
edge_prob = 0.3  # probability that each entry (i, j) of the adjacency matrix is True
x = jax.random.normal(x_key, (num_nodes, num_features))  # Features for each node
# Random boolean adjacency matrix of shape (num_nodes, num_nodes).
adj = jax.random.bernoulli(adj_key, edge_prob, (num_nodes, num_nodes))
# Add self-loops so every node has at least one neighbour (itself), matching
# the usual GAT convention that a node's neighbourhood includes the node.
# Without this, a row that happens to be all False leaves that node with
# no valid neighbours to attend to.
adj = adj | jax.numpy.eye(num_nodes, dtype=bool)

# Initialize the GAT model (nfeat input features per node, nclass outputs).
model = GAT(nfeat=num_features, nhid=8, nclass=3, dropout_rate=0.5, alpha=0.2, nheads=3)

# Initialize the model parameters with their own subkey.
# NOTE(review): deterministic=True presumably disables dropout so no separate
# 'dropout' RNG stream is needed at init — confirm against nanodl's GAT signature.
params = model.init(init_key, x, adj, deterministic=True)

# Apply the model in inference mode (deterministic=True)
output = model.apply(params, x, adj, deterministic=True)

# Print the output shape and a sample of the output
print("Output shape:", output.shape)
print("Output sample:", output[:2])