Logistic Regression Part 2
--------------------------

In this example, we extend the code from Part 1 with several important features:
- Instead of just updating the weight matrix ``W``, we add a bias ``b`` and use the ``.variables()`` method to compactly update both variables.
- We attach an additional computation to the transformer to compute the loss on a held-out validation dataset.
- We switch from a flat ``C``-dimensional feature space to a ``W x H`` feature space to demonstrate multi-dimensional logistic regression.

The corresponding Jupyter notebook is available [here](https://github.com/NervanaSystems/ngraph/blob/master/examples/walk_through/Logistic_Regression_Part_2.ipynb).

In [None]:
import ngraph as ng
import ngraph.transformers as ngt
import gendata

The axes creation is conceptually the same as before, except that we now add a new axis ``H`` to represent the second dimension of the feature space.

In [None]:
ax_W = ng.make_axis(length=2)
ax_H = ng.make_axis(length=2)  # new axis added.
ax_N = ng.make_axis(length=128, name='N')

### Building the graph
Our model, as in the previous example, has three placeholders: ``X``, ``Y``, and ``alpha``. But now, the input ``X`` has shape ``(W, H, N)``:

In [None]:
alpha = ng.placeholder(())
X = ng.placeholder([ax_W, ax_H, ax_N])  # now X has shape (W, H, N)
Y = ng.placeholder([ax_N])

Similarly, the weight matrix is now multi-dimensional, with shape ``(W, H)``, and we add a new scalar bias variable ``b``. Because both axes of the weight matrix ``W`` also appear in ``X``, ``ng.dot(W, X)`` reduces over both of them (an element-wise product followed by a summation), leaving only the batch axis ``N``.

In [None]:
W = ng.variable([ax_W, ax_H], initial_value=0).named('W')  # now the Weight Matrix W has shape (W, H)
b = ng.variable((), initial_value=0).named('b')
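As an illustration outside ngraph, the NumPy sketch below performs the same kind of reduction on plain arrays: ``np.tensordot`` contracts the two shared axes of a ``(2, 2)`` weight array with a ``(2, 2, 128)`` input array, leaving one value per example along ``N``. The random values are placeholders, and the NumPy calls only stand in for ngraph's axis-based ``ng.dot``. The last check shows that the result matches flattening ``W x H`` into a single feature axis of length 4.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder arrays standing in for the ngraph tensors:
# w has axes (W, H), x has axes (W, H, N).
w = rng.normal(size=(2, 2))
x = rng.normal(size=(2, 2, 128))

# Contract the two shared axes (W and H); only the batch axis N survives.
z = np.tensordot(w, x, axes=([0, 1], [0, 1]))
print(z.shape)  # (128,)

# The same computation written as an explicit element-wise product and sum.
z_einsum = np.einsum("wh,whn->n", w, x)

# The same computation after flattening (W, H) into one feature axis of length 4.
z_flat = w.reshape(4) @ x.reshape(4, 128)

assert np.allclose(z, z_einsum)
assert np.allclose(z, z_flat)
```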

Our predicted output now includes the bias ``b``. Note that the ``+`` operation implicitly broadcasts the scalar ``b`` along the batch axis ``N``, the only axis of ``Y_hat``:

In [None]:
Y_hat = ng.sigmoid(ng.dot(W, X) + b)
L = ng.cross_entropy_binary(Y_hat, Y, out_axes=()) / ng.batch_size(Y_hat)
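For reference, here is a NumPy sketch of the same forward pass and batch-averaged loss. ``z + b`` broadcasts the scalar bias across the batch axis, like the ``+`` above. The explicit cross-entropy formula, the ``eps`` clipping, and the random data are assumptions of this sketch; ngraph's ``cross_entropy_binary`` may compute the loss differently internally (for example, in a more numerically stable form).

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(w, b, x):
    """Predicted probabilities, one per example along the batch axis N."""
    z = np.tensordot(w, x, axes=([0, 1], [0, 1]))  # shape (N,)
    return sigmoid(z + b)  # the scalar b broadcasts over N

def mean_bce(y_hat, y, eps=1e-12):
    """Binary cross-entropy summed over the batch, divided by the batch size."""
    y_hat = np.clip(y_hat, eps, 1.0 - eps)  # avoid log(0)
    total = -np.sum(y * np.log(y_hat) + (1.0 - y) * np.log(1.0 - y_hat))
    return total / y.shape[-1]

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 2, 128))
y = rng.integers(0, 2, size=128).astype(float)

# With zero weights and bias every prediction is 0.5, so the loss is log(2).
w0, b0 = np.zeros((2, 2)), 0.0
loss0 = mean_bce(forward(w0, b0, x), y)
print(loss0)
```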

For the parameter updates, instead of explicitly specifying the variables ``W`` and ``b``, we can call ``L.variables()`` to retrieve all the variables that the loss function depends on:

In [None]:
print([var.name for var in L.variables()])
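Conceptually, a method like ``variables()`` walks the graph backwards from ``L`` and collects every variable it reaches. The toy sketch below illustrates that idea with a hypothetical ``Node`` class; it is not ngraph's implementation, and the class, its fields, and the order in which variables are found are assumptions.

```python
class Node:
    """Hypothetical minimal graph node: an operation with input nodes."""

    def __init__(self, name, inputs=(), is_variable=False):
        self.name = name
        self.inputs = list(inputs)
        self.is_variable = is_variable

    def variables(self):
        """Return every variable this node depends on, each exactly once."""
        found, seen, stack = [], set(), [self]
        while stack:
            node = stack.pop()
            if id(node) in seen:
                continue
            seen.add(id(node))
            if node.is_variable:
                found.append(node)
            stack.extend(node.inputs)
        return found

# Mirror the model's structure: the loss depends on the prediction and the labels.
w_node = Node("W", is_variable=True)
b_node = Node("b", is_variable=True)
x_node, y_node = Node("X"), Node("Y")
y_hat_node = Node("sigmoid", [Node("add", [Node("dot", [w_node, x_node]), b_node])])
loss_node = Node("loss", [y_hat_node, y_node])

print(sorted(v.name for v in loss_node.variables()))  # ['W', 'b']
```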

For complicated graphs, the ``variables()`` method makes it easy to iterate over all the variables a computation depends on. Our new parameter update is then:

In [None]:
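# One update per variable that L depends on (here W and b).
# L is already divided by the batch size, so the extra division below
# effectively scales the learning rate alpha by 1 / N.
updates = [ng.assign(v, v - alpha * ng.deriv(L, v) / ng.batch_size(Y_hat))
           for v in L.variables()]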

Note that this time we embed the call to the gradient computation directly in the definition of each update. As in the previous example, ``ng.deriv`` computes the gradient by backpropagation using autodiff. Each update computes a new value for its variable and assigns it back, and ``ng.doall`` groups the list of updates into a single op that executes all of them:

In [None]:
all_updates = ng.doall(updates)
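The gradients that ``ng.deriv`` produces here can also be written out by hand. For the batch-averaged loss, ``dL/dW = sum_n (y_hat_n - y_n) * x_n / N`` and ``dL/db = mean(y_hat - y)``. The NumPy sketch below (random data, with a finite-difference check) computes both gradients from the current values and only then updates ``w`` and ``b``. It uses the plain gradient of the averaged loss, without the additional ``/ ng.batch_size(Y_hat)`` factor in ``updates``, and it makes no claim about how ngraph schedules the grouped assignments.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def mean_bce(w, b, x, y):
    """Binary cross-entropy averaged over the batch axis (last axis of x)."""
    y_hat = sigmoid(np.tensordot(w, x, axes=([0, 1], [0, 1])) + b)
    return -np.mean(y * np.log(y_hat) + (1.0 - y) * np.log(1.0 - y_hat))

def grads(w, b, x, y):
    """Hand-derived gradients of mean_bce with respect to w and b."""
    n = y.shape[-1]
    err = sigmoid(np.tensordot(w, x, axes=([0, 1], [0, 1])) + b) - y  # shape (N,)
    grad_w = np.tensordot(x, err, axes=([2], [0])) / n  # shape (W, H)
    grad_b = np.sum(err) / n
    return grad_w, grad_b

rng = np.random.default_rng(1)
x = rng.normal(size=(2, 2, 128))
y = rng.integers(0, 2, size=128).astype(float)
w0, b0 = rng.normal(size=(2, 2)), 0.1

# Check one entry of the analytic gradient against a central finite difference.
h = 1e-6
w_plus, w_minus = w0.copy(), w0.copy()
w_plus[0, 1] += h
w_minus[0, 1] -= h
numeric = (mean_bce(w_plus, b0, x, y) - mean_bce(w_minus, b0, x, y)) / (2 * h)
grad_w, grad_b = grads(w0, b0, x, y)
print(abs(numeric - grad_w[0, 1]))

# Both gradients come from the old values; then w and b are updated together.
lr = 0.5
w1, b1 = w0 - lr * grad_w, b0 - lr * grad_b
print(mean_bce(w0, b0, x, y), mean_bce(w1, b1, x, y))
```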

### Computation

We have our update computation as before, but we also add an evaluation computation that computes the loss on a separate dataset without performing the updates. Since the evaluation computation performs no updates, it does not need the learning rate ``alpha`` as an input.

In [None]:
transformer = ngt.make_transformer()

update_fun = transformer.computation([L, W, b, all_updates], alpha, X, Y)
eval_fun = transformer.computation(L, X, Y)
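As a plain-Python analogy (not the transformer API), the two computations behave like two functions sharing the same parameter storage: one takes ``alpha`` and modifies the parameters, the other only reads them. The ``make_functions`` helper, its dict-based state, and the choice to return the post-update parameters are all hypothetical details of this sketch.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def make_functions():
    """Hypothetical analogue of an update/eval pair sharing parameter state."""
    params = {"W": np.zeros((2, 2)), "b": 0.0}

    def forward_loss(x, y):
        # Batch-averaged binary cross-entropy and the per-example error y_hat - y.
        y_hat = sigmoid(np.tensordot(params["W"], x, axes=([0, 1], [0, 1])) + params["b"])
        p = np.clip(y_hat, 1e-12, 1.0 - 1e-12)  # avoid log(0)
        loss = -np.mean(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))
        return loss, y_hat - y

    def update(alpha, x, y):
        # Needs alpha: computes the loss, then takes one gradient step in place.
        loss, err = forward_loss(x, y)
        n = y.shape[-1]
        params["W"] = params["W"] - alpha * np.tensordot(x, err, axes=([2], [0])) / n
        params["b"] = params["b"] - alpha * np.sum(err) / n
        return loss, params["W"], params["b"]

    def evaluate(x, y):
        # No alpha: only reads the parameters and returns the loss.
        return forward_loss(x, y)[0]

    return update, evaluate

update_step, evaluate = make_functions()
rng = np.random.default_rng(2)
x_batch = rng.normal(size=(2, 2, 128))
y_batch = (x_batch.sum(axis=(0, 1)) > 0).astype(float)  # toy labels

before = evaluate(x_batch, y_batch)
update_step(1.0, x_batch, y_batch)
after = evaluate(x_batch, y_batch)
print(before, after)
```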

For convenience, we define a function that computes the average cost across the validation set.

In [None]:
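def avg_loss(xs, ys):
    # Each eval_fun call returns the loss already averaged over one batch,
    # so dividing the sum by the number of batches gives the overall mean.
    total_loss = 0
    for x, y in zip(xs, ys):
        loss_val = eval_fun(x, y)
        total_loss += loss_val
    return total_loss / len(xs)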
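Since every validation batch has the same number of examples and each ``eval_fun`` call returns a loss already averaged over its batch, the mean over the whole validation set equals the sum of the batch losses divided by the number of batches. A quick NumPy check of that identity, using made-up per-example losses:

```python
import numpy as np

rng = np.random.default_rng(3)

# Made-up per-example losses for 4 validation batches of 128 examples each.
per_example = rng.random(size=(4, 128))

batch_means = per_example.mean(axis=1)  # one averaged loss per batch
avg_over_batches = batch_means.sum() / len(batch_means)

print(avg_over_batches, per_example.mean())
```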

We then generate our training and evaluation sets and perform the updates with the same technique that we used in the previous example, printing the average validation loss before training and after each epoch. Note that because the axes ``W`` and ``H`` both have length 2 (the previous example had a single axis of length 4), the weight matrix has 2 x 2 = 4 entries, the same number of weights as in the previous example.

In [None]:
g = gendata.MixtureGenerator([.5, .5], (ax_W.length, ax_H.length))
XS, YS = g.gen_data(ax_N.length, 10)
EVAL_XS, EVAL_YS = g.gen_data(ax_N.length, 4)

print("Starting avg loss: {}".format(avg_loss(EVAL_XS, EVAL_YS)))
for i in range(10):
    for xs, ys in zip(XS, YS):
        loss_val, w_val, b_val, _ = update_fun(5.0 / (1 + i), xs, ys)
    print("After epoch %d: W: %s, b: %s, avg loss %s" % (i, w_val.T, b_val, avg_loss(EVAL_XS, EVAL_YS)))
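To experiment without ngraph, the whole procedure can be mirrored in NumPy. This is a sketch under assumptions: ``gen_batches`` is a hypothetical stand-in for ``gendata.MixtureGenerator`` (two equally likely Gaussian clusters, which may not match that module's distribution), and the gradients are written by hand instead of derived by autodiff. The learning-rate schedule keeps the ``1 / (1 + i)`` shape but uses a smaller base value, because the sketch applies the plain gradient of the averaged loss.

```python
import numpy as np

rng = np.random.default_rng(4)

def gen_batches(n_batches, batch_size=128, shape=(2, 2)):
    """Hypothetical two-class Gaussian mixture, returned as (W, H, N) batches."""
    xs, ys = [], []
    for _ in range(n_batches):
        labels = rng.integers(0, 2, size=batch_size)
        centers = np.where(labels == 1, 1.0, -1.0)  # class mean +1 or -1 per feature
        xs.append(rng.normal(size=shape + (batch_size,)) + centers)
        ys.append(labels.astype(float))
    return xs, ys

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def loss_and_grads(w, b, x, y):
    """Batch-averaged binary cross-entropy and its gradients w.r.t. w and b."""
    y_hat = sigmoid(np.tensordot(w, x, axes=([0, 1], [0, 1])) + b)
    p = np.clip(y_hat, 1e-12, 1.0 - 1e-12)  # avoid log(0)
    loss = -np.mean(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))
    err = y_hat - y
    grad_w = np.tensordot(x, err, axes=([2], [0])) / y.shape[-1]
    grad_b = np.mean(err)
    return loss, grad_w, grad_b

def mean_loss(w, b, xs, ys):
    """Average of the per-batch mean losses (all batches have equal size)."""
    return sum(loss_and_grads(w, b, x, y)[0] for x, y in zip(xs, ys)) / len(xs)

train_xs, train_ys = gen_batches(10)
eval_xs, eval_ys = gen_batches(4)
w_np, b_np = np.zeros((2, 2)), 0.0

start_loss = mean_loss(w_np, b_np, eval_xs, eval_ys)
print("Starting avg loss: {}".format(start_loss))
for i in range(10):
    lr = 0.5 / (1 + i)
    for x, y in zip(train_xs, train_ys):
        _, grad_w, grad_b = loss_and_grads(w_np, b_np, x, y)
        w_np, b_np = w_np - lr * grad_w, b_np - lr * grad_b
    print("After epoch %d: W: %s, b: %s, avg loss %s"
          % (i, w_np.ravel(), b_np, mean_loss(w_np, b_np, eval_xs, eval_ys)))

final_loss = mean_loss(w_np, b_np, eval_xs, eval_ys)
```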
    