<img alt="[GoMLX Mascot]" src="gomlx_gopher.jpg" style="height:10em; float:left; margin:1em;"/>

# GoMLX Tutorial 

If you just want to quickly look at a working example, check out [examples/adult/demo/adult.ipynb](https://github.com/gomlx/gomlx/blob/main/examples/adult/demo/adult.ipynb), for a model trained on the UCI Adult Income dataset. More examples are in the [examples/](https://github.com/gomlx/gomlx/tree/main/examples) subdirectory. 

The tutorial won't detail the whole API, but it should present all the important concepts. Everything else is
well documented in godoc (in the code), also available on [pkg.go.dev](https://pkg.go.dev/github.com/gomlx/gomlx#section-readme).

The tutorial was written using a Jupyter notebook with [GoNB](https://github.com/janpfeifer/gonb), a kernel for Go that was co-developed with **GoMLX**. It has its own [short tutorial](https://github.com/janpfeifer/gonb/blob/main/examples/tutorial.ipynb) for those interested.

If you are seeing this tutorial from a [GitHub](https://github.com/gomlx/gomlx/blob/main/examples/tutorial/tutorial.ipynb) snapshot, you won't be able to interact with it. To be able to play with it, try **installing GoMLX**, see its [README Installation section](https://github.com/gomlx/gomlx#installation). The easiest way is to start the pre-generated docker and use the Jupyter notebook there -- this tutorial can be opened from there in an interactive way.





## Computation Graphs

> [Package `graph` reference documentation](https://pkg.go.dev/github.com/gomlx/gomlx/graph)

To do machine learning based on neural networks and gradient descent, one of the most important
requirements is the ability to do mathematical computations (mostly matrix multiplications)
fast.

GoMLX is built on the concept of building "computation graphs", just-in-time compiling them 
and only then executing them to get the desired results. That means one has to write code
that generates another kind of code (the computation graph), so to speak. We do this because then we are
able to execute it really fast using [XLA](https://github.com/openxla/xla).

For example, let's create a computation graph to sum two values:

> **Note**
> - If executing this on a notebook, notice the very first cell execution takes a few seconds for Go to fetch the required dependencies.

In [26]:
!*go mod edit --replace github.com/gomlx/gomlx=${HOME}/Projects/gomlx
import . "github.com/gomlx/gomlx/graph"

func SumGraph(a, b *Node) *Node {
    return Add(a, b)
}

> **Note**
> 
> - The `import . "github.com/gomlx/gomlx/graph"` imports all definitions of the `graph` package into the current scope. Usually,
>   it's easier to work like this in Go files that are going to implement graph building functions.
> - Our function is named `SumGraph`: the suffix `Graph` is just a convention, but it helps identifying functions 
>   that do graph building.
> - The type `*Node` represents a node in the computation graph. All graph operations either take a
>   `*Graph` object to start with, or a `*Node`, and create new nodes with the corresponding operations.
>   So our example will create an `Add` node that will take the nodes pointed to by `a` and `b`, build a node 
>   that represents their summation and then return this `*Node`. 
> - Every node contains a reference to the `*Graph` it's part of (see [`Node.Graph()`](https://pkg.go.dev/github.com/gomlx/gomlx/graph#Node.Graph)).
> - There is a rich set of operations available in GoMLX, see [`Node` documentation](https://pkg.go.dev/github.com/gomlx/gomlx/graph#Node).

Ok, but this won't tell us what is 1+1 yet. We need to compile and then execute this graph with
some input values.
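
The split between building a graph and executing it can be illustrated with a toy example in plain Go that does not use GoMLX at all. This is only a sketch of the idea: the names `node`, `param`, `add` and `eval` are hypothetical and are not related to how GoMLX or XLA are implemented (in particular, nothing is compiled here, the graph is simply interpreted).

```go
package main

import "fmt"

// node is one vertex of a toy computation graph (hypothetical type, not GoMLX's *Node).
type node struct {
	op     string  // "param" or "add".
	index  int     // Parameter position, used when op == "param".
	inputs []*node // Operands, used when op == "add".
}

// param creates a placeholder for the i-th input value.
func param(i int) *node { return &node{op: "param", index: i} }

// add only records the operation; nothing is computed at this point.
func add(a, b *node) *node { return &node{op: "add", inputs: []*node{a, b}} }

// eval walks the graph and computes its value for concrete inputs.
func eval(n *node, args []float64) float64 {
	switch n.op {
	case "param":
		return args[n.index]
	case "add":
		return eval(n.inputs[0], args) + eval(n.inputs[1], args)
	}
	panic("unknown op " + n.op)
}

func main() {
	sum := add(param(0), param(1))              // Build the graph once...
	fmt.Println(eval(sum, []float64{1, 1}))     // ...then execute it as often as needed: prints 2
	fmt.Println(eval(sum, []float64{3.5, 1.5})) // prints 5
}
```

Calling `add` returns a description of the computation, not a number. The number only appears when the graph is executed with concrete inputs.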

## Executing Graphs with `Exec`

[`Exec`](https://pkg.go.dev/github.com/gomlx/gomlx/graph#Exec) is the easiest way to compile and execute computation graphs in GoMLX. 
To run our `SumGraph` function above we can do:

> **Note**: if you have a GPU and this is the very first time you run, NVidia/XLA take a couple of minutes pre-caching "stuff" (only happens once).

In [28]:
var manager = NewManager()

%%
exec := NewExec(manager, SumGraph)  
two := exec.Call(1, 1)[0]
fmt.Printf("1+1=%v\n", two.Value())

1+1=2


> **Note**
>
> - `%%` is a shortcut for `func main()`: everything after it is put inside a `main` function by **GoNB**, the Go Notebook kernel.
> - `NewManager()` creates a `Manager` object, which connects to an accelerator if present. 
>   Usually one creates one at the beginning of the program and passes it around.
>   Here **GoNB** will keep the global variable `manager` available to all cells, so we don't need
>   to define it again.
>   - The API also offers `BuildManager`, which allows fine-grained control over which accelerator (if any)
>     to use, the number of threads, etc.
> - The `Exec` object created is associated with a graph building function (`SumGraph` in this case). It lazily 
>   compiles and executes the compiled computation as needed. Naturally the first time `Call()` is invoked
>   it is slow: it has to build the graph and *just-in-time* (JIT) compile it. But the compiled graph afterwards 
>   is optimized and very fast to execute, which is what we want for machine learning.
> - The `Exec.Call(1, 1)` method always returns a slice of results, 
>   independent of the number of outputs of the graph building function. That's why we use `[0]` to access the
>   first result.
> - The results of graph execution are always **tensors** (see section below). They can be converted to Go types
>   using `.Value()` method.

Something important to understand is that **graphs have static (fixed) shapes for their inputs and outputs**. This
means that, for example, if you are going to sum floats instead of ints, `Exec` has to rebuild the graph
to take two floats as input. The same happens if you want to sum vectors or matrices of ints, or values of any other `shapes.Shape`.

> For a detailed explanation of `Shape` and the associated concepts of Axis, Dimensions and `DType` (the underlying data type), 
> see [package `shapes` documentation](https://pkg.go.dev/github.com/gomlx/gomlx/types/shapes).

To exemplify, let's expand our code a bit:

In [30]:
import (
    "fmt"
    . "github.com/gomlx/gomlx/graph"
)

func SumGraph(a, b *Node) *Node {
    fmt.Printf("* building graph for a.shape=%s and b.shape=%s\n", a.Shape(), b.Shape())
    return Add(a, b)
}


func main() {
    sumExec := NewExec(manager, SumGraph)
    two := sumExec.Call(1, 1)[0]
    fmt.Printf("1+1=%v\n", two.Value())

    for ii := 0; ii < 5; ii++ {
        sumInts := sumExec.Call(ii, ii)[0]
        fmt.Printf("%d+%d=%v\n", ii, ii, sumInts.Value())
    }

    five := sumExec.Call(3.5, 1.5)[0]
    fmt.Printf("3.5+1.5=%v\n", five.Value())

    many := sumExec.Call([]float32{1.1, 2.2, 3.3}, []float32{10, 10, 10})[0]
    fmt.Printf("[1.1, 2.2, 3.3] + [10, 10, 10] = %v\n", many.Value())
}

* building graph for a.shape=(Int64)[] and b.shape=(Int64)[]
1+1=2
0+0=0
1+1=2
2+2=4
3+3=6
4+4=8
* building graph for a.shape=(Float64)[] and b.shape=(Float64)[]
3.5+1.5=5
* building graph for a.shape=(Float32)[3] and b.shape=(Float32)[3]
[1.1, 2.2, 3.3] + [10, 10, 10] = [11.1 12.2 13.3]


> **Note**
>   - Each time a new graph is created, we added a `fmt.Printf` to tell us
>     the shape of the graph operands. Notice that `fmt.Printf` is not included in the graph,
>     it's only part of the graph building function. We'll see later how to print
>     intermediary results in the middle of the execution of the graph.
>   - Every `Node` has an associated shape (`shapes.Shape` type). A shape is defined by its underlying data type 
>     `shapes.DType` and its axes dimensions. For scalars, the shape has zero axes (dimensions). E.g.: `(Int64)[]` represents
>     a scalar `int` value, and `(Float32)[3]` represents a vector with 3 `float32` values. More details
>     and the list of supported types are in the package `gomlx/types/shapes`.
>   - `Exec` automatically calls `SumGraph` whenever the `Call()` method sees parameters of shapes
>     different from those it has seen before (a limited-size cache of pre-compiled graphs is kept in memory).
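
The rebuild-on-new-shape behavior seen in the output above can be mimicked with a small cache keyed by the shapes of the arguments. The following is a plain Go sketch under stated assumptions: `shapeOf` and `shapeCache` are hypothetical names, only three Go types are handled, and unlike the real `Exec` cache this one has no size limit and does not compile anything.

```go
package main

import "fmt"

// shapeOf returns a textual shape for the few Go types this sketch supports.
func shapeOf(v any) string {
	switch x := v.(type) {
	case int:
		return "(Int64)[]"
	case float64:
		return "(Float64)[]"
	case []float32:
		return fmt.Sprintf("(Float32)[%d]", len(x))
	}
	panic(fmt.Sprintf("unsupported type %T", v))
}

// shapeCache remembers which combinations of argument shapes were already "built".
type shapeCache struct {
	built  map[string]bool
	builds int // Number of times a new graph had to be built.
}

func newShapeCache() *shapeCache { return &shapeCache{built: map[string]bool{}} }

// call reports whether a new graph had to be built for the shapes of args.
func (c *shapeCache) call(args ...any) (rebuilt bool) {
	key := ""
	for _, a := range args {
		key += shapeOf(a) + ";"
	}
	if c.built[key] {
		return false // Shapes seen before: reuse the cached graph.
	}
	c.built[key] = true
	c.builds++
	return true
}

func main() {
	c := newShapeCache()
	for i := 0; i < 5; i++ {
		c.call(i, i) // Same shapes every time: built only on the first iteration.
	}
	c.call(3.5, 1.5)
	c.call([]float32{1.1, 2.2, 3.3}, []float32{10, 10, 10})
	fmt.Println("graphs built:", c.builds) // prints "graphs built: 3"
}
```

The three builds correspond to the three `* building graph ...` lines printed by the tutorial cell: one per distinct combination of input shapes, regardless of how many times each is called.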

In general the graph operations only work with operands of the same `shapes.DType` ("data type"). 
If they are different, the mismatch is reported with a `panic` (which works like an exception, and can be caught) carrying an error with a full stack-trace. 

Let's create an example with an error to see how this goes:


In [35]:
%%
sumExec := NewExec(manager, SumGraph)
_ = sumExec.Call(1.1, 2) // Error: arguments have different dtypes float64 and int.


* building graph for a.shape=(Float64)[] and b.shape=(Int64)[]


panic: operands of AddNode have different dtypes (Float64 and Int64)

goroutine 1 [running]:
github.com/gomlx/gomlx/types/exceptions.Panicf(...)
	/home/janpf/Projects/gomlx/types/exceptions/exceptions.go:69
github.com/gomlx/gomlx/graph.twoArgsNode(0x2f, 0xc0000e00a0, 0xc0000e01e0)
	/home/janpf/Projects/gomlx/graph/node.go:426 +0x1a5
github.com/gomlx/gomlx/graph.Add(...)
	/home/janpf/Projects/gomlx/graph/node.go:433
main.SumGraph(0xc0000e00a0, 0xc0000e01e0)
	[[ Cell [30] Line 8 ]] /tmp/gonb_7be75428/main.go:15 +0x171
reflect.Value.call({0x523760?, 0x552a50?, 0xc0000cd8d0?}, {0x544cb5, 0x4}, {0xc0000aca80, 0x2, 0xc0000cda98?})
	/home/janpf/src/golang/go/src/reflect/value.go:586 +0xb0b
reflect.Value.Call({0x523760?, 0x552a50?, 0x0?}, {0xc0000aca80?, 0x0?, 0x0?})
	/home/janpf/src/golang/go/src/reflect/value.go:370 +0xbc
github.com/gomlx/gomlx/graph.(*Exec).createAndCacheGraph(0xc0000dc000, {0xc0000be150, 0x2, 0x0?})
	/home/janpf/Projects/gomlx/graph/exec.go:450 +0x854
github.com/g

> **Note**
> - In the stack-trace above there are 2 lines of interest, that typically help to debug such issues:
>   1. Where in the graph building function `main.SumGraph` the invalid operation was created: **Line 8 of the previous cell**.
>   2. Where in the `main` function the graph execution was attempted: **Line 3 of this cell**.
> - You can enable displaying line numbers in JupyterLab with "ESC+L" (upper-case L).

If you want to catch errors, **GoMLX** provides a small `exceptions` library that defines `TryCatch[E]`, which catches `panic` (thrown) exceptions. **GoMLX** only throws exceptions of type `error`, so you could do:

```go
err := TryCatch[error](func() {_ = exec.Call(1.1, 2)}) // Error: different types (float64 and int) !?
if err != nil { … }
```


## Tensors

Tensors are multidimensional arrays of a given data type (`shapes.DType`), defined in the package `gomlx/types/tensor`. 

In GoMLX, tensors work as **containers of data** that are used as concrete inputs and outputs for the
execution of computation graphs. There is only basic support for manipulating tensors directly (including direct access to their data),
because one is expected to do that with computation graphs. Tensors have a shape (`shapes.Shape`)
just like `Node`.

The package includes a generic `tensor.Tensor` interface, which is implemented by two types of tensors,
`tensor.Local` and `tensor.Device`, that differ in where their data is stored. 

`tensor.Device` tensors have their data on the accelerator device that is going to run the computation 
(even if it is "Host", which represents the CPU). These are used as inputs and outputs of the
computation graph.

`Local` means a tensor stored in local memory, and it can be directly mutated -- but generally
we only use it to input or output data.

The generic `Tensor` interface includes methods to transfer the tensor from `Local` to `Device` and vice-versa.
It includes a cache to avoid transferring the same data multiple times. Most APIs take `Tensor`
parameters and transfer as needed. 

There is a **cost in transferring** between `Local` and `Device`, so be mindful when handling 
large data values. The API includes `Finalize` to force immediate _finalization_ of a tensor 
and free its memory, as opposed to waiting for the GC (for those very large models where GPU
memory is at a premium).

Graph execution only consumes (as input) and outputs `Device` tensors, but the library provides
all the conversion tools needed to make that simple.

Example:

In [20]:
%%
manager := BuildManager().Platform("Host").MustDone()
onePlusExec := NewExec(manager, func (x *Node) *Node {
    return OnePlus(x) 
})
// exec.Call will return a tensor.Device.
counter := onePlusExec.Call(0)[0]
// counter.Value() will first transfer counter to local with counter.Local().
fmt.Printf("counter.type=%s, counter.shape=%s, counter=%v\n", reflect.TypeOf(counter), counter.Shape(), counter.Value())
for ii := 0; ii < 10; ii++ {
    // exec.Call will use counter.Device(): which uses the current value (no need to transfer) that is already
    // on Device.
    counter = onePlusExec.Call(counter)[0]
}
// counter.Value() will first convert counter to local with counter.Local().
fmt.Printf("counter=%v\n", counter.Value())

counter.type=*tensor.Device, counter.shape=(Int64)[], counter=1
counter=11


> **Note**:
> - In the first call to `onePlusExec.Call(0)`, the Go constant `0` is automatically converted to a `*tensor.Device`
>   by `Exec` and fed to the graph. It returns a `[]tensor.Tensor` with one element, containing `0+1=1`.
> - The returned tensor is actually a `*tensor.Device` (which implements the `tensor.Tensor` interface), so
>   its data is stored on the device.
> - When we loop the counter, note that we never move `counter` to a `*tensor.Local`. If we were to execute this on a 
>   GPU or TPU, the data would not have been moved back to the CPU while executing the loop -- making it faster.
> - When we print the final result in `counter.Value()` the actual data is converted to a `tensor.Local`
>   and its content automatically converted back to a Go `int` type.

There are several ways to create `Local` tensors, the most common:

 - `FromValue[S](value S)`: Generics conversion, works with any supported `DType` scalar
   as well as with any arbitrary multidimensional slice. Slices of rank > 1 must be regular, that is,
   all the sub-slices must have the same shape. E.g.: `FromValue([][]float64{{1,2}, {3, 5}, {7, 11}})`
 - `FromShape(shape shapes.Shape)`: creates a Local tensor with the given shape, and uninitialized values. See
   documentation on how to mutate `Local` tensors in place with `Local.AcquireData()`.
 - `FromScalarAndDimensions[T](value T, dimensions ...int)`: creates a Local tensor with the
   given dimensions, filled with the scalar value given. `T` must be one of the supported types.

`Local` tensors also provide functions to serialize/deserialize in binary format.

`Device` tensors are created only by transferring local tensors to the device, or when they are returned by the execution of a graph. 

See more documentation on [pkg.go.dev](https://pkg.go.dev/github.com/gomlx/gomlx/types/tensor).

Errors in the manipulation of tensors (e.g. invalid values) are reported with `panic`, with full stack-traces, just as
with the `graph` package described in the previous section. The errors can easily be caught (with `recover()` or with the `exceptions.TryCatch` helper) when needed.

## Gradients

Another important functionality required to train machine learning models based on gradient descent is calculating 
the gradients of some value being optimized with respect to some variable / quantity. 

GoMLX does this statically, during graph building time. It adds to the graph the computation for the gradient.

Example: let's calculate the gradient of the function $f(x, y) = x^2 + xy$ for a few values of $x$ and $y$.
Algebraically we have $df/dx(x,y) = 2x + y$ and $df/dy(x,y) = x$.


In [36]:
func f(x, y *Node) *Node {
    return Add(Square(x), Mul(x, y))
}

%%
gradOfFExec := NewExec(manager, func(x, y *Node) (output, gradX, gradY *Node) {
    output = f(x, y)
    sum := ReduceAllSum(output) // In case x and y are not scalars.
    grads := Gradient(sum, x, y)
    gradX, gradY = grads[0], grads[1] // df/dx, df/dy
    return output, gradX, gradY
})

results := gradOfFExec.Call([]float64{0, 1, 2}, []float64{10, 20, 30})
fmt.Printf("f=%v, df/dx=%v, df/dy=%v\n", results[0].Value(), results[1].Value(), results[2].Value())


f=[0 21 64], df/dx=[10 22 34], df/dy=[0 1 2]


> **Note**:
> - For now GoMLX only calculates gradients of a scalar (typically a model loss) with respect to arbitrary tensors.
>   It does not yet calculate **jacobians**, that is, gradients of a value that is not a scalar. That's the reason
>   for the `ReduceAllSum` in the example: the gradient is taken of the sum of `f` over the 3 input pairs, and since each
>   term of the sum depends only on its own `x` and `y`, the result is the element-wise derivative.
> - A question that may arise is whether it calculates the second derivative (*hessian*). In principle the machinery
>   to do that is in place, but there are 2 limitations: (1) not all operations have their derivative implemented,
>   in particular some of the operations that are only used when calculating the first derivative; (2) it only 
>   calculates the gradient of a scalar, and in most cases the hessian is the gradient of a gradient,
>   usually of higher rank. Contributions to the project here are welcome ;)  
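
As a sanity check that is independent of GoMLX, the algebraic derivatives $df/dx = 2x + y$ and $df/dy = x$ can be compared against a numerical approximation. The sketch below uses central finite differences in plain Go; `numGrad` and `closeTo` are hypothetical helper names, and the step size `h` is an arbitrary choice. Note that this is a different technique from what GoMLX does: GoMLX derives the gradient symbolically while building the graph, whereas this only approximates it numerically.

```go
package main

import "fmt"

// f mirrors the tutorial's f(x, y) = x^2 + x*y for scalar inputs.
func f(x, y float64) float64 { return x*x + x*y }

// numGrad approximates (df/dx, df/dy) at (x, y) with central differences.
func numGrad(x, y float64) (dfdx, dfdy float64) {
	const h = 1e-4
	dfdx = (f(x+h, y) - f(x-h, y)) / (2 * h)
	dfdy = (f(x, y+h) - f(x, y-h)) / (2 * h)
	return
}

// closeTo reports whether a and b differ by less than 1e-6.
func closeTo(a, b float64) bool {
	d := a - b
	return d < 1e-6 && d > -1e-6
}

func main() {
	xs := []float64{0, 1, 2}
	ys := []float64{10, 20, 30}
	for i := range xs {
		dfdx, dfdy := numGrad(xs[i], ys[i])
		// Compare with the algebraic results 2x+y and x.
		fmt.Printf("x=%g y=%g: df/dx matches 2x+y: %v, df/dy matches x: %v\n",
			xs[i], ys[i], closeTo(dfdx, 2*xs[i]+ys[i]), closeTo(dfdy, xs[i]))
	}
}
```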


## Variables and Context

Computation graphs are [pure functions](https://en.wikipedia.org/wiki/Pure_function): they have no state,
they take inputs, return outputs and everything in between is transient[^1].

[^1]: With a few exceptions, like the random number generator.

For machine learning, as well as many other types of computations, it's convenient to store intermediary results
(the model parameters for ML) in between the executions of the computation graphs.

For that GoMLX offers the `context.Context` object (completely unrelated to Go's usual `context` package), 
and a corresponding `context.Exec`. It is a container of
variables (whose values are tensors on device), and it automatically manages their updates, passing them as extra
inputs and taking them out (if changed) as extra outputs of the computation graph.

This may sound more complex than it is in practice. Let's see an example, where we try to find $\mathrm{argmin}_{x}{f(x)}$ where
$f(x) = ax^2 + bx + c$. If we solve it analytically we get, for $a > 0$, $\mathrm{argmin}_{x}{f(x)} = \frac{-b}{2a}$. Instead
we solve it numerically, by gradient descent:

In [38]:
import "flag"

var (
    flagA = flag.Float64("a", 1.0, "Value of a in the equation ax^2+bx+c")
    flagB = flag.Float64("b", 2.0, "Value of b in the equation ax^2+bx+c")
    flagC = flag.Float64("c", 4.0, "Value of c in the equation ax^2+bx+c")
    flagNumSteps = flag.Int("steps", 10, "Number of gradient descent steps to perform")
    flagLearningRate    = flag.Float64("lr", 0.1, "Initial learning rate.")
)

// f(x) = ax^2 + bx + c
func fGraph(x *Node) *Node {
    f := MulScalar(Square(x), *flagA)
    f = Add(f, MulScalar(x, *flagB))
    f = AddScalar(f, *flagC)
    return f
}

// minimizeF does one gradient descent step on f(x) by moving the variable "x",
// and returns the value of the function at "x" before the update.
func minimizeF(ctx *context.Context, graph *Graph) *Node {
    xVar := ctx.VariableWithValue("x", 0.0) // Variable reference.
    x := xVar.ValueGraph(graph)             // Read variable for the current graph.
    f := fGraph(x)                          // Value of f(x).
    
    // Gradient always return a slice, we take the first element for grad of X.
    gradX := Gradient(f, x)[0] 
    
    // stepNum += 1
    stepNumVar := ctx.VariableWithValue("stepNum", 0.0)
    stepNum := stepNumVar.ValueGraph(graph)
    stepNum = OnePlus(stepNum)
    stepNumVar.SetValueGraph(stepNum)
    
    // step = -learningRate * gradX / Sqrt(stepNum)
    step := Div(gradX, Sqrt(stepNum))
    step = MulScalar(step, -*flagLearningRate)
    
    // x += step
    x = Add(x, step)
    xVar.SetValueGraph(x)
    return f  // f(x)
}

func Solve() {
    ctx := context.NewContext(manager)
    exec := context.NewExec(manager, ctx, minimizeF)
    
    for ii := 0; ii < *flagNumSteps-1; ii++ {
        _ = exec.Call()
    }
    f := exec.Call()[0]
    x := ctx.InspectVariable(ctx.Scope(), "x").Value()
    stepNum := ctx.InspectVariable(ctx.Scope(), "stepNum").Value()
    fmt.Printf("Minimum found at x=%g, f(x)=%g after %f steps.\n", x.Value(), f.Value(), stepNum.Value())
}

The code above defines `Solve()`, which finds the minimum for the values set by the flags `a`, `b`, and `c`.

Let's try a few values:

> **Note**: `%%` in GoNB automatically creates a `func main()` and passes the extra arguments to the Go program.

In [40]:
%% --a=1 --b=2 --c=3 --steps=10 --lr=0.5
Solve()

Minimum found at x=-1, f(x)=2 after 10.000000 steps.


In [41]:
%% --a=2 --b=12 --c=20 --steps=10 --lr=0.5
Solve()

Minimum found at x=-3, f(x)=2 after 10.000000 steps.


> **Note**:
> - We are using `context.Exec`, while before we were using `graph.Exec`. The main difference is that
>   `context.Exec` compiles and executes graph functions that take a context as their first parameter, and it 
>   automatically handles the passing of variables as side inputs and outputs (for those variables updated)
>   of the computation graph.
> - During graph building, we access and set the variables with `Variable.ValueGraph` and `Variable.SetValueGraph`: 
>   they return/take `*Node` types, that can be used in the graph.
> - Outside graph building, we can access the last value set to a variable by using `Variable.Value()` and 
>   `Variable.SetValue`. They return/take concrete `tensor.Tensor` types, usually a `*tensor.Device`, the
>   value actually being stored in the accelerator.
> - We created two variables, one for "x" that we were optimizing to minimize $f(x)$, and one variable "stepNum",
>   used to keep track of how many steps were already executed. 
> - Yes, for a quadratic $f(x)$ and starting from $x=0$, setting the learning rate to $1/(2a)$ (e.g. `--lr=0.5` for `a=1`, as in the first run above) gets to the minimum in one step. 😉
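
The update rule used by `minimizeF` can be replayed in plain Go to see why a single step can be enough. Starting from $x_0 = 0$, the first step (with `stepNum` equal to 1) gives $x_1 = x_0 - lr \cdot f'(x_0) = -lr \cdot b$, which equals the analytical minimum $-b/(2a)$ exactly when $lr = 1/(2a)$. The function `minimize` below is a hypothetical re-implementation for illustration; it uses ordinary `float64` arithmetic rather than a computation graph.

```go
package main

import (
	"fmt"
	"math"
)

// minimize applies the same update rule as the tutorial's minimizeF:
//
//	x -= lr * f'(x) / sqrt(stepNum)
//
// for f(x) = ax^2 + bx + c, starting at x = 0. It returns the final x and
// the value of f evaluated just before the last update.
func minimize(a, b, c float64, steps int, lr float64) (x, fx float64) {
	for stepNum := 1; stepNum <= steps; stepNum++ {
		fx = a*x*x + b*x + c
		grad := 2*a*x + b
		x -= lr * grad / math.Sqrt(float64(stepNum))
	}
	return
}

func main() {
	x, fx := minimize(1, 2, 3, 10, 0.5)
	fmt.Printf("a=1 b=2 c=3 lr=0.5, 10 steps: x=%g f(x)=%g\n", x, fx) // x=-1 f(x)=2

	x, _ = minimize(2, 12, 20, 1, 0.25) // lr = 1/(2a): one step suffices.
	fmt.Printf("a=2 b=12 c=20 lr=0.25, 1 step: x=%g\n", x) // x=-3
}
```

For other learning rates the first step overshoots or undershoots, and the division by $\sqrt{stepNum}$ progressively shrinks the following steps.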

There is more to `context.Context`: some of it we'll present in the next section on *Machine Learning*, the rest can
be found in its documentation. A few things worth mentioning in advance:

* `Context` is always configured at a certain *scope*, and variables are unique within its scope. Scope
  is easily changed with `ctx.In("new_scope")`. So the `Context` object is a scope (a string) and a pointer to 
  the actual data (variables, graph and model parameters).
* `Context` also holds model parameters (concrete Go values), which are also scoped. 
  Those can be hyperparameters for the models (learning rate, regularization, etc.) or anything the user or
  any library may create a convention for.
* Similarly `Context` also holds "Graph parameters". Those are very similar to model parameters, but they
  have one value per Graph. So if a model is created with parameters of different shape (or for training/evaluation),
  each version will have its own Graph parameters. Don't worry about this now -- if you need it later
  when building complex graphs, the functionality will be there.
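
The idea that a context is "a scope (a string) and a pointer to the actual data" can be sketched in a few lines of plain Go. Everything in this example is hypothetical: the type `scopedContext`, its methods, the use of `/` as scope separator and the storing of plain `float64` values are assumptions made for illustration, not the `context.Context` implementation.

```go
package main

import "fmt"

// scopedContext is a view (a scope) over variable storage shared by all views.
type scopedContext struct {
	scope string
	vars  map[string]float64 // Shared storage, keyed by scope + variable name.
}

func newScopedContext() *scopedContext {
	return &scopedContext{scope: "/", vars: map[string]float64{}}
}

// in returns a new view in a sub-scope; the storage is shared, not copied.
func (c *scopedContext) in(name string) *scopedContext {
	return &scopedContext{scope: c.scope + name + "/", vars: c.vars}
}

// set stores a variable in the current scope.
func (c *scopedContext) set(name string, v float64) { c.vars[c.scope+name] = v }

// get reads a variable from the current scope only.
func (c *scopedContext) get(name string) (float64, bool) {
	v, ok := c.vars[c.scope+name]
	return v, ok
}

func main() {
	ctx := newScopedContext()
	ctx.in("layer_0").set("bias", 0.1)
	ctx.in("layer_1").set("bias", 0.2) // Same name, different scope: no clash.

	b0, _ := ctx.in("layer_0").get("bias")
	b1, _ := ctx.in("layer_1").get("bias")
	_, inRoot := ctx.get("bias")
	fmt.Println(b0, b1, inRoot) // prints "0.1 0.2 false"
}
```

Because changing scope only creates a cheap new view, two layers can both use a variable named `bias` as long as each is built in its own scope.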

## Machine Learning (ML)

The previous sections presented the fundamentals of what is needed to implement machine learning. In this section
we present various sub-packages that provide high level ML layers and tools that make building, training and 
evaluating a model much simpler.

First is the package `layers` (see code in [ml/layers](../../../tree/main/ml/layers/)). It provides several composable ML layers. 
These are graph building functions, most of which take a `*context.Context` as first parameter, where they store 
variables or access hyperparameters. 
There are several such layers, for example: `layers.Dense`, 
`layers.Dropout`, `layers.PiecewiseLinearCalibration` (very good for normalization of inputs), `layers.BatchNorm`, 
`layers.LayerNorm`, `layers.Convolution`, `layers.MultiHeadAttention` (for [Transformers](https://arxiv.org/abs/1706.03762) 
layers), etc.

The package `train` offers two main functionalities. `train.Trainer` builds a *train step* and an *eval step*
graph, given a model graph building function and an optimizer. These graphs can be executed in sequence to train a model.
The package also provides `train.Loop`, which simply loops over a `train.Dataset` interface, reading data and feeding
it to `Trainer.TrainStep` or `Trainer.EvalStep`, along with executing configurable hooks on the training loop. One
such hook is provided by `gomlx/ml/train/commandline.AttachProgressBar(loop)`: it pretty-prints the progress during
training on the command line.

There is also a collection of optimizers, loss functions, metrics, etc. For any functionality there is always an
example under the [examples/](../../../tree/main/examples/) subdirectory.

Let's look at the simplest ML example: [`linear`](../../../tree/main/examples/linear/linear.go), 
which trains a linear model on noisy generated data.

Here are the constants of our problem:

In [43]:
const (
    CoefficientMu    = 0.0
    CoefficientSigma = 5.0
    BiasMu           = 1.0
    BiasSigma        = 10.0
)

To generate synthetic data, the example first randomly chooses some coefficients and a bias, based on which the data is 
generated. These selected coefficients are the ones we want to try to learn using ML. The coefficients could
have been selected in Go directly using `math/rand`, but just for fun, we do it using a computation graph.

In [47]:
import (
    "github.com/gomlx/gomlx/types/shapes"
    "github.com/gomlx/gomlx/types/tensor"
)

// initCoefficients chooses random coefficients and bias. These are the true values the model will
// attempt to learn.
func initCoefficients(manager *Manager, numVariables int) (coefficients, bias tensor.Tensor) {
    e := NewExec(manager, func(g *Graph) (coefficients, bias *Node) {
        rngState := Const(g, RngState())
        rngState, coefficients = RandomNormal(rngState, shapes.Make(shapes.F64, numVariables))
        coefficients = AddScalar(MulScalar(coefficients, CoefficientSigma), CoefficientMu)
        rngState, bias = RandomNormal(rngState, shapes.Make(shapes.F64))
        bias = AddScalar(MulScalar(bias, BiasSigma), BiasMu)
        return
    })
    results := e.Call()
    coefficients, bias = results[0], results[1]
    return
}

%%
coef, bias := initCoefficients(manager, 3)
fmt.Printf("Example of target: coefficients=%0.3v, bias=%0.3v\n", coef.Value(), bias.Value()) 

Example of target: coefficients=[-8.83 2.06 5.73], bias=-1.16


> **Note**
> - This code should look familiar, using things we presented earlier in the tutorial. It creates a computation
>   graph to generate randomly the `coefficients` and `bias`. Then it executes it and returns the result. 
> - Notice that since the *computation graph is functional*, we need to pass around the random number
>   generator state, which gets updated at each call to `RandomUniform` or `RandomNormal`.
>   - Alternatively the `context.Context` introduced earlier can keep the state
>     as a variable, and provides a simpler interface: see `Context.RandomUniform` and `Context.RandomNormal`.

Next, we want to generate the data (examples): we generate random inputs, and then the label using the
selected coefficients plus some normal noise.


In [51]:
func buildExamples(manager *Manager, coef, bias tensor.Tensor, numExamples int, noise float64) (inputs, labels tensor.Tensor) {
    e := NewExec(manager, func(coef, bias *Node) (inputs, labels *Node) {
        g := coef.Graph()
        numFeatures := coef.Shape().Dimensions[0]

        // Random inputs (observations).
        rngState := Const(g, RngState())
        rngState, inputs = RandomNormal(rngState, shapes.Make(shapes.F64, numExamples, numFeatures))
        coef = ExpandDims(coef, 0)

        // Calculate perfect labels.
        labels = ReduceAndKeep(Mul(inputs, coef), ReduceSum, -1)
        labels = Add(labels, bias)
        if noise > 0 {
            // Add some noise to the labels.
            var noiseVector *Node
            rngState, noiseVector = RandomNormal(rngState, labels.Shape())
            noiseVector = MulScalar(noiseVector, noise)
            labels = Add(labels, noiseVector)
        }
        return
    })
    examples := e.Call(coef, bias)
    inputs, labels = examples[0], examples[1]
    return
}

%%
coef, bias := initCoefficients(manager, 3)
numExamples := 5
inputsTensor, labelsTensor := buildExamples(manager, coef, bias, numExamples, 0.2)
fmt.Printf("Target: coefficients=%0.3v, bias=%0.3v\n", coef.Value(), bias.Value()) 

fmt.Printf("%d dataset examples:\n", numExamples)
inputs := inputsTensor.Local().Value().([][]float64)
labels := labelsTensor.Local().Value().([][]float64)
for ii := 0; ii < numExamples; ii ++ {
    fmt.Printf("\tx=%0.3v; label=%0.3v\n", inputs[ii], labels[ii])
}

Target: coefficients=[-2.12 0.37 11.2], bias=-25.7
5 dataset examples:
	x=[0.73 0.976 1.4]; label=[-11.4]
	x=[2.18 -1.06 -0.082]; label=[-31.7]
	x=[-0.335 0.113 -0.313]; label=[-28.7]
	x=[0.623 0.269 -0.979]; label=[-37.8]
	x=[1.87 0.309 -0.419]; label=[-34.3]


### Dataset

Now the first new concept of this section: `train.Dataset` is the interface used to feed data
during a training loop or evaluation. 

There are three methods: `Dataset.Yield` returns the next batch of examples; 
`Dataset.Reset` restarts the dataset, for datasets that don't loop indefinitely; 
finally, `Dataset.Name` returns the dataset name, usually used for metric names, logging and printing.

Datasets also yield a `spec`, an opaque type for GoMLX (defined as `any`), that allows the dataset to communicate
to the model which type of data it is generating. In our case, since it's always the same data, we don't need it, 
so we keep it set to `nil`. 
If one were to implement a dataset for something like a generic CSV file, one might want to communicate 
the field names to the model through the `spec`, for instance. See the documentation for more details.

For our linear synthetic data we implement the simplest `train.Dataset`: the whole data is pre-generated, and we return a giant batch with the full data every time:

In [52]:
import "github.com/gomlx/gomlx/ml/train"

// TrivialDataset always returns the whole data.
type TrivialDataset struct {
    name string
    inputs, labels []tensor.Tensor
}

var (
    // Assert TrivialDataset implements train.Dataset.
    _ train.Dataset = &TrivialDataset{}
)

// Name implements train.Dataset.
func (ds *TrivialDataset) Name() string { return ds.name }

// Yield implements train.Dataset.
func (ds *TrivialDataset) Yield() (spec any, inputs, labels []tensor.Tensor, err error) {
    return nil, ds.inputs, ds.labels, nil
}

// Reset implements train.Dataset.
func (ds *TrivialDataset) Reset() {}

> **Note**:
> * Pre-processing data is often more work than actually building an ML model ... that's life 🙁
> * In the [examples/](../../../tree/main/examples/) subdirectory we implement `train.Dataset` for some
>   well known data sets: UCI Adult, Cifar-10, Cifar-100, IMDB Reviews, Kaggle's Dogs vs Cats, Oxford Flowers 102.
>   These can be used as libraries to easily try different models. 
>   If you are working on public datasets, please consider contributing similar libraries.
> * **GoMLX** also includes the `data.InMemoryDataset`, which can be created from tensors in one line -- we could
>   have used that instead of defining `TrivialDataset`.

The package [github.com/gomlx/gomlx/ml/data](https://pkg.go.dev/github.com/gomlx/gomlx/ml/data) provides several
tools to facilitate the work here:

* `Parallel`: parallelizes any dataset, includes some buffer.
* `InMemory`: reads a dataset into (accelerator) memory, and then serves it from there -- greatly accelerates training.
* Downloading (with progress-bar) and checksum functions.


### ModelFn

Next we build a model, that for our `train` package means implementing a function with the following signature:

```go
type ModelFn func(ctx *context.Context, spec any, inputs []*graph.Node) (predictions []*graph.Node)
```

It takes a `context.Context` for the variables and hyperparameters, the `spec` and a slice of `inputs` --
the last two are fed by `Dataset.Yield` above. It returns a slice of `predictions` -- in most cases there 
is just one value in the slice (only one prediction). During training `predictions` are fed to the loss function,
and during inference they can be returned directly.

Our linear example has the simplest model possible:

In [14]:
func modelGraph(ctx *context.Context, spec any, inputs []*Node) ([]*Node) {
    _ = spec  // Not needed here, we know the dataset.
    logits := layers.Dense(ctx, inputs[0], /* useBias= */ true, /* outputDim= */ 1)
    return []*Node{logits}
}

> **Note**
> - It uses the `layers.Dense` layer, which simply multiplies the input by a learnable matrix (weights)
>   and optionally adds a learnable bias. It's the most basic building block of neural networks (NNs).
>   The implementation of `layers.Dense` is pretty simple, and worth checking out to refresh how
>   variables from the `Context` are used.
> - Since it's a linear model, we don't use an activation function. The usual are available for NNs (`Relu`,
>   `Sigmoid`, `Tanh` and more to come).
> - The `spec` parameter allows the creation of a `ModelFn` that can be used for different types of data. The dataset
>   can also yield a `spec` describing the type of data it is reading. Each different value of `spec` will trigger
>   the creation of a different computation graph, so ideally there should be at most a few different values of
>   `spec`. Most commonly there is only one, like in this example, and the parameter can be ignored.
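
To make the arithmetic of a dense layer concrete, the following plain Go sketch computes `y = x·W + b` on slices. It only illustrates the math and is not how `layers.Dense` is implemented (that one builds graph nodes and reads its weights from the `Context`); the `denseForward` function is a hypothetical helper.

```go
package main

import "fmt"

// denseForward computes y = x·W + b for a batch of inputs.
// Shapes: x is [batch][inputDim], w is [inputDim][outputDim], b is [outputDim].
func denseForward(x, w [][]float64, b []float64) [][]float64 {
	out := make([][]float64, len(x))
	for i, row := range x {
		out[i] = make([]float64, len(b))
		for j := range b {
			sum := b[j] // Start from the bias of output j.
			for k, v := range row {
				sum += v * w[k][j]
			}
			out[i][j] = sum
		}
	}
	return out
}

func main() {
	x := [][]float64{{1, 2, 3}, {0, 1, 1}} // 2 examples, 3 features each.
	w := [][]float64{{1}, {2}, {-1}}      // 3 features -> 1 output.
	b := []float64{0.5}
	fmt.Println(denseForward(x, w, b)) // prints [[2.5] [1.5]]
}
```

With `outputDim = 1` and no activation function, this is exactly a linear regression: one coefficient per feature plus a bias.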

### Trainer and Loop

The last part is to put together the `train.Trainer` and `train.Loop` objects in our `main()` function. The
first stitches together the model, the optimizer and the loss function, and is able to run training
steps and evaluations. The second, `train.Loop`, loops over the dataset executing one training step at
a time and supports a subscription (hooking) system, where one attaches things like a progress bar
or the plotting of a graph.

In [15]:
import "github.com/gomlx/gomlx/ml/train/commandline"

var (
    flagNumExamples  = flag.Int("num_examples", 10000, "Number of examples to generate")
    flagNumFeatures  = flag.Int("num_features", 3, "Number of features")
    flagNoise        = flag.Float64("noise", 0.2, "Noise in synthetic data generation")
)

// AttachToLoop is a hook to allow one to attach different functionality to the loop.
func AttachToLoop(loop *train.Loop) {
    commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.
}

// TrainMain() does everything to train the linear model.
func TrainMain() {
    flag.Parse()
    manager := BuildManager().Platform("Host").MustDone()

    // Select coefficients that we will try to predict.
    trueCoefficients, trueBias := initCoefficients(manager, *flagNumFeatures)
    fmt.Printf("Coefficients: %0.5v\n", trueCoefficients.Value())
    fmt.Printf("Bias: %0.5v\n\n", trueBias.Value())

    // Generate training data with noise.
    inputs, labels := buildExamples(manager, trueCoefficients, trueBias, *flagNumExamples, *flagNoise)
    fmt.Printf("Training data (inputs, labels): (%s, %s)\n\n", inputs.Shape(), labels.Shape())
    dataset := &TrivialDataset{"linear", []tensor.Tensor{inputs}, []tensor.Tensor{labels}}

    // Creates Context with learned weights and bias.
    ctx := context.NewContext(manager)
    ctx.SetParam(optimizers.LearningRateKey, *flagLearningRate)

    // train.Trainer executes a training step.
    trainer := train.NewTrainer(manager, ctx, modelGraph,
        losses.MeanSquaredError,
        optimizers.StochasticGradientDescent(),
        nil, nil) // trainMetrics, evalMetrics
    loop := train.NewLoop(trainer)
    AttachToLoop(loop)

    // Loop for given number of steps.
    _, err := loop.RunSteps(dataset, *flagNumSteps)
    AssertNoError(err)

    // Print learned coefficients and bias -- from the weights in the dense layer.
    fmt.Println()
    coefVar, biasVar := ctx.InspectVariable("/dense", "weights"), ctx.InspectVariable("/dense", "biases")
    learnedCoef, learnedBias := coefVar.Value(), biasVar.Value()
    fmt.Printf("Learned coefficients: %0.5v\n", learnedCoef.Value())
    fmt.Printf("Learned bias: %0.5v\n", learnedBias.Value())
}

%%
TrainMain()

Coefficients: [-1.8218 5.2781 -0.99633]
Bias: 17.422

Training data (inputs, labels): ((Float64)[10000 3], (Float64)[10000 1])

Training (10 steps):  100% [========================================] (151 steps/s) [loss=44.794] [~loss=125.460]

Learned coefficients: [[-1.2449] [3.5755] [-0.73268]]
Learned bias: [11.466]


> **Note**:
> - Hyperparameters are set on the context. Layers and optimizers can define their own
>   hyperparameters independently. `Context` uses a scoping system (like directories), 
>   and hyperparameters can take specialized values under specific scopes -- e.g.: 
>   doing `ctx.In("dense_5").SetParam(layers.L2RegularizationKey, 0.1)` would set
>   L2 regularization only for the layer `dense_5` of the model to `0.1`.
> - The `trainer` constructor also takes as input arbitrary metrics (for training and evaluation).
>   The metric of the loss of the last batch, and a moving average of the loss are always included
>   automatically. There are many others, that can be means or moving averages, etc.
> - The `Loop` object is very flexible. One can attach any functionality with `OnStart`, `OnStep`,
>   `OnEnd`, `EveryNSteps` or `NTimesDuringLoop`. The most commonly attached functionality is
>   `commandline.AttachProgressBar`. There is also support for plotting any arbitrary metric or any
>   arbitrary `Node` in the computation graph.
> - `Loop.RunSteps()` also returns the final metrics from the training, which are usually printed out.
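
The subscription idea can be sketched in a few lines of plain Go. This is only an illustration under simplifying assumptions, not the `train.Loop` implementation: the `Loop`, `StepFn` and `Hook` types below are hypothetical, and while the method names mirror the ones mentioned above, their signatures are made up for this sketch. The "training step" is gradient descent on a single example.

```go
package main

import "fmt"

// StepFn runs one training step and returns the loss of that step.
type StepFn func(step int) float64

// Hook is called after a training step with the step index and its loss.
type Hook func(step int, loss float64)

// Loop is a hypothetical, minimal training loop with a subscription system.
type Loop struct {
	stepFn StepFn
	hooks  []Hook
}

func NewLoop(stepFn StepFn) *Loop { return &Loop{stepFn: stepFn} }

// OnStep registers a hook that runs after every step.
func (l *Loop) OnStep(h Hook) { l.hooks = append(l.hooks, h) }

// EveryNSteps registers a hook that runs after every n-th step.
func (l *Loop) EveryNSteps(n int, h Hook) {
	l.OnStep(func(step int, loss float64) {
		if (step+1)%n == 0 {
			h(step, loss)
		}
	})
}

// RunSteps executes numSteps training steps and returns the last loss.
func (l *Loop) RunSteps(numSteps int) float64 {
	var loss float64
	for step := 0; step < numSteps; step++ {
		loss = l.stepFn(step)
		for _, h := range l.hooks {
			h(step, loss)
		}
	}
	return loss
}

func main() {
	// Fit y = w*x on one example (x=2, y=6) by gradient descent on the squared error.
	w, lr := 0.0, 0.1
	loop := NewLoop(func(step int) float64 {
		x, y := 2.0, 6.0
		diff := w*x - y
		w -= lr * 2 * diff * x // d/dw (w*x - y)^2 = 2*(w*x - y)*x
		return diff * diff
	})
	loop.EveryNSteps(5, func(step int, loss float64) {
		fmt.Printf("step %d: loss=%.6f w=%.4f\n", step+1, loss, w)
	})
	loop.RunSteps(20)
}
```

Keeping the loop unaware of what the hooks do is what lets a progress bar, a plotter or a checkpointer be attached without touching the training code.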

### Training and Plotting

As our last example, let's train it "for real", that is, with more steps.

And to make things prettier, let's also attach a plot of the registered metrics. In our example
the only metrics are the default ones: the loss of the last batch and the mean of the loss -- the mean squared error.

In [16]:
import (
    "github.com/gomlx/gomlx/examples/notebook/gonb/margaid"
    "github.com/gomlx/gomlx/ml/train"
    "github.com/gomlx/gomlx/ml/train/commandline"
)

func AttachToLoop(loop *train.Loop) {
    commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop. 
    margaid.New(1024, 400).DynamicUpdates().Attach(loop, /* num plot points: */ *flagNumSteps)  // Generates a new plot point at every step.
}

%% --steps=100
TrainMain()

Coefficients: [-4.3234 3.3146 0.12164]
Bias: 6.2798

Training data (inputs, labels): ((Float64)[10000 3], (Float64)[10000 1])



Training (100 steps):  100% [========================================] (994 steps/s) [loss=0.073] [~loss=3.528]



Learned coefficients: [[-4.2297] [3.2544] [0.11925]]
Learned bias: [6.1451]


> **Note**:
> The `gomlx/examples/notebook/gonb/margaid` package will automatically plot all the metrics registered in the trainer. When it attaches itself to 
> `loop` it collects the metric values (and runs evaluation on any requested datasets), and at the end of the loop it plots them. Optionally, with `Plots.DynamicUpdates()`,
> it also plots intermediary results, displaying the progress of the training as it happens -- it's too fast here to notice.
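
The progress bars above show two values, `loss` (last batch) and `~loss` (a moving average), and the second lags well behind the first while the loss drops quickly. The sketch below shows why, assuming an exponential moving average; the exact formula and coefficient used by GoMLX may differ, and the `movingAverage` type is hypothetical.

```go
package main

import "fmt"

// movingAverage keeps an exponential moving average of the values it observes.
type movingAverage struct {
	decay  float64 // Weight given to the previous average, in [0, 1).
	value  float64
	seeded bool
}

// Update folds x into the average and returns the new average.
func (m *movingAverage) Update(x float64) float64 {
	if !m.seeded {
		// The first observation initializes the average.
		m.value, m.seeded = x, true
		return m.value
	}
	m.value = m.decay*m.value + (1-m.decay)*x
	return m.value
}

func main() {
	avg := &movingAverage{decay: 0.5}
	for step, loss := range []float64{8, 4, 2, 1} {
		fmt.Printf("step %d: loss=%.3f ~loss=%.3f\n", step+1, loss, avg.Update(loss))
	}
}
```

With a decay of `0.5` the averages are 8, 6, 4 and 2.5: always above the current loss while it is decreasing, which matches the gap seen in the outputs above.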

There is much more in the libraries. Mostly it's well documented, and the implementation is
generally simple. We highly recommend looking at it when trying to understand, for instance, how a layer works.


## Debugging

Unfortunately, computers just "don't get it": they do exactly what we tell them to do,
as opposed to what we want them to do, and thus programs fail or crash.
GoMLX provides different ways to track down various types of errors. The most commonly used ones are described below.

### Good old "printf"

It's convenient because of Go's fast compilation: changing something and running it to see the result is
almost instant. Logging results to stdout is a valid way of developing. While developing graph building code,
one often prints the shape of the `Node` being operated on to confirm (or not) one's expectations.

### Delayed Errors

Errors during the building of the graph are reported to the `Graph` or the `Context` or both. They,
as well as `Node` and tensors, implement the methods `Ok()` and `Error()` to check if there has been 
an error, and what it is. The errors always include a stack-trace -- print error with `"%+v"` to get 
full stack-trace output.

To avoid error checking at every step (it would make the code too cumbersome), the idea is to check for errors
only sporadically. The suggestion is to do it at the start of a graph function. If an error happens in between,
most operations and layers are able to handle an invalid `Node` as input by returning invalid nodes themselves. The
error stored is always the first one that happened.

This scheme has proven very effective during the development of the various operations. 

Tensors also support delayed errors, and can be checked similarly. `Exec` objects report failures
in execution through the returned tensors and the `Context` object.

More discussion on [error handling here](error_handling.go).
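
The delayed error pattern itself is independent of GoMLX and can be sketched in plain Go. The `node` type and `add` function below are hypothetical stand-ins (the real `Node` carries much more, including stack traces); the point is only that an invalid operand is passed through untouched, so a single check at the end reports the first error.

```go
package main

import "fmt"

// node is a hypothetical stand-in for a graph node that may carry an error.
type node struct {
	dims []int
	err  error
}

// Ok reports whether the node is valid.
func (n *node) Ok() bool { return n.err == nil }

// Error returns the error stored in the node, if any.
func (n *node) Error() error { return n.err }

// add returns the element-wise sum node, or an invalid node on failure.
// Invalid operands are returned as-is, so the first error is preserved.
func add(a, b *node) *node {
	if !a.Ok() {
		return a
	}
	if !b.Ok() {
		return b
	}
	if len(a.dims) != len(b.dims) {
		return &node{err: fmt.Errorf("add: rank mismatch %v vs %v", a.dims, b.dims)}
	}
	for i := range a.dims {
		if a.dims[i] != b.dims[i] {
			return &node{err: fmt.Errorf("add: shape mismatch %v vs %v", a.dims, b.dims)}
		}
	}
	return &node{dims: a.dims}
}

func main() {
	x := &node{dims: []int{2, 3}}
	y := &node{dims: []int{2, 4}}
	// The innermost add fails; the outer ones just pass the invalid node along.
	z := add(add(add(x, y), x), x)
	if !z.Ok() {
		// prints: graph building failed: add: shape mismatch [2 3] vs [2 4]
		fmt.Printf("graph building failed: %v\n", z.Error())
	}
}
```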

### Node Shape Asserts

When writing complex models, it's very common to add comments on the expected shapes of the graph nodes, to
help the reader (and developer) of the code keep the right mental model of what is going on.

GoMLX provides a series of _assert_ methods that can be used instead. They serve both as documentation and as an early
exit in case of unexpected results. If the check fails, they set the appropriate error in the `Graph`.

For example, a `modelGraph` function could contain:

```go
    batchSize := inputs[0].Shape().Dimensions[0]
    ...
    layer := Concatenate(allEmbeddings, -1)
    if !layer.AssertDims(batchSize, -1) {  // 2D tensor, with batch size as the leading dimension.
        return nil
    }
```

Although using these when building graphs is the most common case, there are similar assert functions for tensors and shapes themselves in the package `gomlx/types/shapes`.
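
As an illustration of what such a check involves, here is a plain Go sketch of a dimension assert. It is not the GoMLX implementation: `checkDims` is a hypothetical helper, and it assumes (as the `AssertDims(batchSize, -1)` call above suggests) that a negative value means "any size is accepted for this axis".

```go
package main

import "fmt"

// checkDims returns an error if dims does not match want.
// A negative entry in want is a wildcard that matches any size.
func checkDims(dims []int, want ...int) error {
	if len(dims) != len(want) {
		return fmt.Errorf("rank mismatch: shape %v has %d axes, want %d",
			dims, len(dims), len(want))
	}
	for axis, w := range want {
		if w >= 0 && dims[axis] != w {
			return fmt.Errorf("axis %d mismatch: shape %v has size %d, want %d",
				axis, dims, dims[axis], w)
		}
	}
	return nil
}

func main() {
	batchSize := 32
	layerDims := []int{32, 128}

	// 2D, with the batch size as the leading dimension: passes.
	fmt.Println(checkDims(layerDims, batchSize, -1))

	// Wrong leading dimension: returns a descriptive error.
	fmt.Println(checkDims(layerDims, 64, -1))
}
```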


### Graph Execution Logging

Every `Node` of the graph can be flagged with `SetLogged(msg)`.
At the end of the execution, the executor (`Exec`) will log all these values. The default logger
(set with `Exec.SetLogger`) will simply print the message `msg` along with the value of the `Node` of
interest.

In package [`github.com/gomlx/gomlx/experimental/collector`](https://github.com/gomlx/gomlx/blob/main/experimental/collector/collector.go) there is also a specialized logger that will collect
these values for plotting later. Creating a new specialized logger of any type is trivial.
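
A specialized logger is essentially a callback that receives the messages and values of the logged nodes after each execution. The sketch below shows the idea with a collector that stores values per message; the `LoggerFn` signature and the `collector` type are assumptions made for this example (plain `float64` values instead of tensors), so check the `Exec.SetLogger` documentation for the actual signature.

```go
package main

import "fmt"

// LoggerFn is a hypothetical logger signature: it receives the messages and
// values of all nodes flagged for logging in one execution.
type LoggerFn func(messages []string, values []float64)

// collector accumulates logged values per message, e.g. for plotting later.
type collector struct {
	series map[string][]float64
}

func newCollector() *collector {
	return &collector{series: make(map[string][]float64)}
}

// Log appends each value to the series identified by its message.
func (c *collector) Log(messages []string, values []float64) {
	for i, msg := range messages {
		c.series[msg] = append(c.series[msg], values[i])
	}
}

func main() {
	c := newCollector()
	var logger LoggerFn = c.Log

	// Simulate three executions, each logging two flagged nodes.
	for step := 1; step <= 3; step++ {
		logger([]string{"loss", "step"}, []float64{1 / float64(step), float64(step)})
	}
	fmt.Println(c.series["step"]) // prints [1 2 3]
}
```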

### More Debugging

Tests and these methods have been enough to develop most of GoMLX so far. But there are other
debugging tools that could be built; see the discussion in the [Debugging](https://github.com/gomlx/gomlx/blob/main/docs/debugging.md) document. Let us know if you need something specialized.

---

Happy coding and [good luck](https://arxiv.org/abs/1803.03635) on modeling!!

<img alt="[Zürich See]" src="zurich_see.jpg" style="width:100%;"/>
