# Quick Start

## Basic Usage

The main interface is the `show()` function. It takes a function, jitted or not, followed by the arguments to call it with:

```python
import jax
import jax.numpy as jnp
from visu_hlo import show


def simple_function(x):
    return 10 * x + 2 + 3


# Visualize the function
show(simple_function, jnp.array([1.0, 2.0, 3.0]));
```

This will:
1. Lower the function to HLO
2. Generate a DOT graph representation
3. Convert it to SVG format
4. Display it using your system's default SVG viewer
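Step 1 can be reproduced with JAX's public ahead-of-time API, which is handy for reading the same HLO as text. This is a minimal sketch using `jax.jit(...).lower(...)`; it is not necessarily how `visu-hlo` performs the lowering internally, and it stops before the DOT and SVG steps.

```python
import jax
import jax.numpy as jnp


def simple_function(x):
    return 10 * x + 2 + 3


# Trace and lower to HLO without compiling: no XLA optimization passes have run yet.
lowered = jax.jit(simple_function).lower(jnp.array([1.0, 2.0, 3.0]))

# Render the module in the classic HLO text dialect
# (the default dialect of as_text() is StableHLO/MLIR).
hlo_text = lowered.as_text(dialect="hlo")
print(hlo_text)
```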

## JIT vs Non-JIT Functions

visu-hlo automatically detects whether a function is jitted and picks the visualization method accordingly.
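One plausible way to tell the two cases apart is to check for the `lower` method that `jax.jit` attaches to the functions it wraps. This is an assumption about the mechanism, not a description of `visu-hlo`'s code, and `is_jitted` is a hypothetical helper.

```python
import jax


def is_jitted(fn) -> bool:
    """Hypothetical helper: functions wrapped by jax.jit expose a callable `.lower`."""
    return callable(getattr(fn, "lower", None))


def simple_function(x):
    return 10 * x + 2 + 3


print(is_jitted(simple_function))           # False: plain Python function
print(is_jitted(jax.jit(simple_function)))  # True: jax.jit wrapper
```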

### Original Function

For non-jitted functions, `show()` displays the function's Jaxpr computation graph converted to HLO, before any XLA optimization:

```python
show(simple_function, jnp.ones(10))
```
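Assuming the non-jitted view starts from the Jaxpr as described above, `jax.make_jaxpr` prints that Jaxpr directly. The exact formatting varies between JAX versions, but the multiplication and the two separate additions should all still be present, since nothing has been folded at this stage.

```python
import jax
import jax.numpy as jnp


def simple_function(x):
    return 10 * x + 2 + 3


# Trace the function to a Jaxpr; no lowering or optimization happens here.
closed_jaxpr = jax.make_jaxpr(simple_function)(jnp.ones(10))
print(closed_jaxpr)

# Primitive names of the traced equations, in program order.
primitives = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitives)
```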

### Jitted Function

For jitted functions, `show()` displays the optimized computation graph produced by XLA compilation:

```python
jitted_simple_function = jax.jit(simple_function)

show(jitted_simple_function, jnp.ones(10))
```
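To read the optimized module as text rather than as a graph, the jitted function can be lowered and compiled ahead of time with JAX's own API; `Compiled.as_text()` returns the HLO after XLA's optimization passes. The result depends on the backend (CPU, GPU or TPU), so treat the printed module as illustrative.

```python
import jax
import jax.numpy as jnp


def simple_function(x):
    return 10 * x + 2 + 3


jitted_simple_function = jax.jit(simple_function)

# lower() traces and converts to HLO; compile() runs XLA's optimization pipeline.
compiled = jitted_simple_function.lower(jnp.ones(10)).compile()

# Optimized HLO for the current default backend (may be None on some backends).
optimized_hlo = compiled.as_text()
print(optimized_hlo)
```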

### Key Differences

When comparing the two visualizations:

1. **Optimization**: The jitted version shows fused operations.
2. **Constant Folding**: Constant sub-expressions can be pre-computed (folded); in `simple_function`, `2 + 3` may appear as a single constant `5`.
3. **Memory Layout**: Different memory layouts and access patterns may be visible.
4. **Operation Count**: The optimized version typically has fewer nodes.
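A rough way to check the last point without opening the SVGs is to count instruction lines in the HLO text before and after compilation. `count_instructions` is a hypothetical heuristic made up for this illustration: it also counts instructions inside fused sub-computations, so the numbers are only indicative and will vary by backend.

```python
import jax
import jax.numpy as jnp


def simple_function(x):
    return 10 * x + 2 + 3


def count_instructions(hlo_text: str) -> int:
    """Hypothetical heuristic: HLO instruction lines have the form `name = shape op(...)`."""
    return sum(1 for line in hlo_text.splitlines() if " = " in line)


lowered = jax.jit(simple_function).lower(jnp.ones(10))
unoptimized = lowered.as_text(dialect="hlo")  # before XLA optimization passes
optimized = lowered.compile().as_text()       # after XLA optimization passes

print("unoptimized instructions:", count_instructions(unoptimized))
print("optimized instructions:  ", count_instructions(optimized))
```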

## Function Arguments

You can pass multiple arguments and keyword arguments:

```python
def multi_arg_func(x, y, scale=1.0):
    return (x + y) * scale


show(multi_arg_func, jnp.ones(5), jnp.zeros(5), scale=2.0)
```
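How `show()` treats keyword arguments internally is not specified here. For comparison, plain `jax.jit` traces a keyword argument like `scale` as an ordinary runtime input, whereas listing it in `static_argnames` fixes its value at trace time so it appears as a constant. The sketch below counts HLO parameters in each case; counting `parameter(` occurrences is a rough heuristic that only works because this small module has no sub-computations.

```python
import jax
import jax.numpy as jnp


def multi_arg_func(x, y, scale=1.0):
    return (x + y) * scale


args = (jnp.ones(5), jnp.zeros(5))

# Default jit: `scale` becomes an extra traced input of the HLO module.
traced = jax.jit(multi_arg_func).lower(*args, scale=2.0).as_text(dialect="hlo")

# static_argnames: `scale` is baked in at trace time and shows up as a constant.
static = (
    jax.jit(multi_arg_func, static_argnames="scale")
    .lower(*args, scale=2.0)
    .as_text(dialect="hlo")
)

print("parameters (traced scale):", traced.count("parameter("))
print("parameters (static scale):", static.count("parameter("))
```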

## Understanding the Output

The generated SVG shows:
- **Nodes**: Operations (add, multiply, etc.)
- **Edges**: Data flow between operations
- **Colors**: Distinguish operation types
- **Labels**: Operation names and shapes

Each node contains:
- Operation name (e.g., `add.1`, `multiply.2`)
- Input/output shapes
- Source location in your code (when available)
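To make the node and edge vocabulary concrete, here is a small Graphviz DOT description built in the same spirit: one node per operation, labelled with its name and shape and coloured by operation type, plus one edge per data dependency. It is a hypothetical illustration (the `Node` type, the `to_dot` helper, the node names and the colour palette are all invented), not the DOT that `visu-hlo` actually emits.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical graph node: an operation with its output shape and input node names."""
    name: str    # e.g. "add.1"
    opcode: str  # e.g. "add"
    shape: str   # e.g. "f32[10]"
    inputs: list[str] = field(default_factory=list)


# Illustrative colour per operation type (an assumption, not visu-hlo's palette).
COLORS = {"parameter": "lightblue", "constant": "lightgrey", "multiply": "orange", "add": "palegreen"}


def to_dot(nodes: list[Node]) -> str:
    """Render nodes as a DOT digraph: labels carry name and shape, edges carry data flow."""
    lines = ["digraph hlo {"]
    for n in nodes:
        color = COLORS.get(n.opcode, "white")
        # "\\n" emits a literal \n, which DOT renders as a line break inside the label.
        lines.append(f'  "{n.name}" [label="{n.name}\\n{n.shape}", style=filled, fillcolor={color}];')
    for n in nodes:
        for src in n.inputs:
            lines.append(f'  "{src}" -> "{n.name}";')
    lines.append("}")
    return "\n".join(lines)


graph = [
    Node("Arg_0.1", "parameter", "f32[10]"),
    Node("constant.2", "constant", "f32[10]"),
    Node("multiply.3", "multiply", "f32[10]", ["constant.2", "Arg_0.1"]),
]
print(to_dot(graph))
```

The resulting text can be rendered with the Graphviz `dot` command-line tool, e.g. `dot -Tsvg graph.dot -o graph.svg`.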