# Problem 0. 8 pts.

__Rules__
* You have to form groups of 2 to 3 people for this problem.
    * Each group should create a repository for this task
    * The link to this repository should be included in the current report
    * The new repository should be private. Only members of the team and the course instructor (Mikhail) should be in the list of collaborators
* Grades will be assigned according to the contributions to the repository (don't work on one PC all together and then push to the master branch).
* In a perfect world, the team would split the problem into subproblems and each member would work only on their own subproblem. That would be good practice for: 1. the project, 2. collaborative work with git on a single code base
* Only the code in the `master` branch will be evaluated. (To avoid a painful merge at the end, work on your own subproblem in a separate branch and merge into master once the work is done.)

__Task__
* You have to implement a sparse tensor slicing library:
    * You need to come up with a cool name for the library
    * You have to write the code for sparse tensor slicing (remember that Python loops are slow; consider outsourcing these parts to `C/C++` using `SWIG`, `pybind` or `cython`)
    * You need to write tests that show that your library actually works correctly. Here we will play my favorite game again: if I find an example where your library fails, I am the winner.
    * `pytest` is the suggested framework for writing the tests
* The most important thing when working with sparse tensors is speed:
    * You have to choose a proper sparse format
    * All team members will get additional **up to 5 Bonus Points** for `the fastest correct implementation`.
    * The speed will be evaluated on a closed set of test sparse tensors (you don't know them or their structure in advance).
* The winner must open-source their implementation (make the repository public).
* The input tensor is in **COO** format
* Your library should export just a single function for slicing the tensor
* The library should be installable with the command `python setup.py install`
* The code should be written in a clean and nice manner (in the best case even with documentation in any of the available formats, see `Sphinx` for example)
* The commit history must be clean, so that it is clear from each commit message what was done in that commit
* It is a good opportunity to make a good use of pull requests

* The team with the best documentation gets **up to 3 Bonus Points**
* The team with the best test coverage gets **up to 3 Bonus Points**
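
To make the task concrete, here is a minimal pure-Python sketch of what slicing a COO tensor means. The function name `slice_coo`, its signature and the list-of-tuples layout are assumptions made for illustration only; they are not the required interface, and a competitive implementation would likely be vectorized or moved to `C/C++`. The sketch only handles slices with a positive step.

```python
def slice_coo(indices, values, shape, slices):
    """Slice a COO tensor (hypothetical helper, for illustration only).

    indices: list of coordinate tuples, one per stored entry
    values:  list of stored values, aligned with indices
    shape:   tuple with the dense shape of the tensor
    slices:  one slice object per dimension (step must be positive)
    """
    # Resolve each slice against its dimension, e.g. slice(None) -> (0, n, 1).
    bounds = [s.indices(n) for s, n in zip(slices, shape)]
    out_indices, out_values = [], []
    for coord, val in zip(indices, values):
        new_coord = []
        for c, (start, stop, step) in zip(coord, bounds):
            # Drop the entry if c is outside the slice or off its step grid.
            if c < start or c >= stop or (c - start) % step != 0:
                break
            new_coord.append((c - start) // step)
        else:
            # No dimension rejected the entry: keep it with shifted coordinates.
            out_indices.append(tuple(new_coord))
            out_values.append(val)
    out_shape = tuple(len(range(*b)) for b in bounds)
    return out_indices, out_values, out_shape


# A 4x4 tensor with four stored entries, sliced like dense[1:4, 0:3].
idx = [(0, 0), (1, 2), (2, 1), (3, 3)]
vals = [1.0, 2.0, 3.0, 4.0]
sub_idx, sub_vals, sub_shape = slice_coo(idx, vals, (4, 4), (slice(1, 4), slice(0, 3)))
```

Tests for the real library could compare such a result against slicing the equivalent dense array.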

# Bonus task:

$\frac{a}{b} + \frac{c}{d} = 1$

$\frac{a}{d} + \frac{c}{b} = 2018$


Solve for integers $a, b, c, d \in [1, \infty)$,
subject to $a + b + c + d \le 10^{17}$.
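
Whatever search strategy is used, a candidate should be verified exactly, since floating-point division is not reliable at this magnitude. A minimal sketch of such a check with exact rational arithmetic follows; the name `is_solution` and the `limit` parameter are hypothetical, and the sketch does not search for a solution.

```python
from fractions import Fraction


def is_solution(a, b, c, d, limit=10**17):
    """Check a candidate (a, b, c, d) exactly, without floating-point error."""
    # All unknowns must be integers >= 1 and respect the bound on the sum.
    if min(a, b, c, d) < 1 or a + b + c + d > limit:
        return False
    first = Fraction(a, b) + Fraction(c, d) == 1
    second = Fraction(a, d) + Fraction(c, b) == 2018
    return first and second
```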

# JAX Intro

Take a look [here](https://safwanahmad.github.io/2018/01/21/Linear-Regression-A-Tale-of-a-Transform.html)

In [None]:
!pip install "jax[cpu]"
!pip install scikit-learn

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import sklearn

# Warming up
For starters, let's implement a Python function that computes the sum of squares of the numbers from 0 to N-1.
* Use numpy or plain Python
* An array of the numbers 0 to N-1: `numpy.arange(N)`
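
A handy reference for checking any implementation is the closed form $\sum_{k=0}^{N-1} k^2 = \frac{(N-1)N(2N-1)}{6}$. The sketch below is not the intended numpy solution, and the name `sum_squares_closed_form` is hypothetical. Note that for N = 10**8 the exact value is about 3.3e23, which does not fit into a 64-bit integer, so an integer array accumulator would overflow and a floating-point result is only approximate.

```python
def sum_squares_closed_form(N):
    # Sum of k^2 for k = 0 .. N-1, using exact Python integer arithmetic.
    return (N - 1) * N * (2 * N - 1) // 6


# Cross-check against a plain Python loop for small N.
for n in [1, 2, 10, 1000]:
    assert sum_squares_closed_form(n) == sum(k * k for k in range(n))
```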

In [None]:
def sum_squares(N):
    return <student.implement_me()>

In [None]:
%%time
sum_squares(10**8)

# Same with jax

In [None]:
def sum_squares_jax(N):
    return <student.implement_me()>

In [None]:
%%time
sum_squares_jax(10**8)

## Practice 1: polar pretzels
_inspired by [this post](https://www.quora.com/What-are-the-most-interesting-equation-plots)_

There are some simple mathematical functions with cool plots. For one, consider this:

$$ x(t) = t - 1.5 \cdot \cos(15 t) $$
$$ y(t) = t - 1.5 \cdot \sin(16 t) $$


In [None]:
t = jnp.linspace(-10, 10, num=10000)

# compute x(t) and y(t) as defined above.
x = ###YOUR CODE
y = ###YOUR CODE

plt.plot(x, y);

# Practice 2: mean squared error
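
For reference, the mean squared error of two vectors $a$ and $b$ of length $n$ is usually defined as follows (this is also the definition the `sklearn` check relies on):

$$ \mathrm{MSE}(a, b) = {1 \over n} \sum_{i=1}^{n} (a_i - b_i)^2 $$

For example, for $a = (1, 2, 3)$ and $b = (1, 2, 5)$ this gives $(0 + 0 + 4) / 3 = 4/3$.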


In [None]:
# Quest #1 - implement a function that computes a mean squared error of two input vectors
# Your function has to take 2 vectors and return a single number

compute_mse = lambda vector1, vector2: <student.define_transformation()>

In [None]:
# Tests
from sklearn.metrics import mean_squared_error
import numpy as np

for n in [1, 5, 10, 10 ** 3]:
    
    elems = [np.arange(n), np.arange(n,0,-1), np.zeros(n),
             np.ones(n), np.random.random(n), np.random.randint(100, size=n)]
    
    for el in elems:
        for el_2 in elems:
            true_mse = np.array(mean_squared_error(el, el_2))
            my_mse = compute_mse(el, el_2)
            if not np.allclose(true_mse, my_mse):
                print('Wrong result:')
                print('mse(%s,%s)' % (el, el_2))
                print("should be: %f, but your function returned %f" % (true_mse, my_mse))
                raise ValueError("Smth went wrong")

print("All tests passed")    

# Autodiff - why graphs matter
* Jax can compute derivatives and gradients automatically using the computation graph
* Gradients are computed as a product of elementary derivatives via chain rule:

$$ {\partial f(g(x)) \over \partial x} = {\partial f(g(x)) \over \partial g(x)}\cdot {\partial g(x) \over \partial x} $$

JAX can differentiate any computation graph as long as it knows how to differentiate each elementary operation in it.
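
The chain rule itself can be sanity-checked without any autodiff library. The sketch below picks $g(x) = x^2$ and $f(u) = \sin(u)$ as arbitrary examples (they are not from the assignment) and compares the hand-derived product of elementary derivatives with a central finite difference.

```python
import math


def g(x):
    return x ** 2


def f(u):
    return math.sin(u)


def chain_rule_derivative(x):
    # df/dg evaluated at g(x), times dg/dx.
    return math.cos(g(x)) * 2 * x


def numerical_derivative(func, x, eps=1e-6):
    # Central finite difference approximation of func'(x).
    return (func(x + eps) - func(x - eps)) / (2 * eps)


x0 = 0.7
analytic = chain_rule_derivative(x0)
numeric = numerical_derivative(lambda x: f(g(x)), x0)
print(analytic, numeric)  # the two values should agree to several digits
```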

In [None]:
square = lambda x: x**2
grad_square = jax.grad(square)
print(grad_square(1.0))

In [None]:
x = jnp.linspace(-3,3)
square = lambda x: (x**2).sum()
grad_square = jax.grad(square)
x_squared, x_squared_der = x**2, grad_square(x)

plt.plot(x, x_squared,label="x^2")
plt.plot(x, x_squared_der, label="derivative")
plt.legend();

In [None]:
#Compute the gradient of the next weird function over my_scalar and my_vector
#warning! Trying to understand the meaning of that function may result in permanent brain damage
weird_psychotic_function = lambda my_vector, my_scalar: ((my_vector + my_scalar)**(1 + jnp.var(my_vector)) + 1./ jnp.arctan(my_scalar)).mean() / (my_scalar**2 + 1) + 0.01 * jnp.sin(2 * my_scalar**1.5) * (sum(my_vector) * my_scalar**2) * jnp.exp((my_scalar - 4)**2) / (1 + jnp.exp((my_scalar - 4)**2)) * (1. - (jnp.exp(-(my_scalar - 4)**2)) / (1 + jnp.exp(-(my_scalar - 4)**2)))**2

der_by_scalar = <student.compute_grad_over_scalar()>
der_by_vector = <student.compute_grad_over_vector()>
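
By default `jax.grad` differentiates with respect to the first argument; the `argnums` parameter selects a different one. A small sketch on a much simpler function, where `toy` is a hypothetical stand-in and not the function above:

```python
import jax
import jax.numpy as jnp


def toy(vec, s):
    # f(vec, s) = s * sum(vec^2) + s^3
    return jnp.sum(vec ** 2) * s + s ** 3


grad_vec = jax.grad(toy, argnums=0)  # gradient w.r.t. vec: 2 * vec * s
grad_s = jax.grad(toy, argnums=1)    # derivative w.r.t. s: sum(vec^2) + 3 * s^2

vec = jnp.array([1.0, 2.0, 3.0])
s = 2.0
g_vec = grad_vec(vec, s)
g_s = grad_s(vec, s)
```

The gradient with respect to the vector has the same shape as the vector, while the derivative with respect to the scalar is a scalar.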

In [None]:
#Plotting your derivative
scalar_space = jnp.linspace(1, 7, 100)

y = [weird_psychotic_function(jnp.array([1, 2, 3.]), x) for x in scalar_space]

plt.plot(scalar_space, y, label='function')

y_der_by_scalar = [der_by_scalar(jnp.array([1, 2, 3.]), x) for x in scalar_space]

plt.plot(scalar_space, y_der_by_scalar, label='derivative')
plt.grid()
plt.legend();

# Almost done - optimizers

While you can perform gradient descent by hand with automatic grads from above, jax also has some optimization methods implemented for you. Recall momentum & rmsprop?
Check out this [example](https://jax.readthedocs.io/en/latest/jax.example_libraries.optimizers.html) for the use of optimizers in jax.
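
As a reminder of what a momentum update does, here is a hand-written sketch in plain Python on the toy objective $f(x) = (x - 3)^2$. The gradient is coded by hand here; in the notebook it would come from `jax.grad`. The names `grad_f` and `momentum_descent` and the hyperparameter values are assumptions chosen for illustration.

```python
def grad_f(x):
    # Gradient of f(x) = (x - 3)^2, written by hand for this sketch.
    return 2.0 * (x - 3.0)


def momentum_descent(x0, lr=0.1, beta=0.9, steps=500):
    x, velocity = x0, 0.0
    for _ in range(steps):
        # The velocity accumulates an exponentially decaying sum of past gradients.
        velocity = beta * velocity - lr * grad_f(x)
        x = x + velocity
    return x


x_min = momentum_descent(0.0)
print(x_min)  # should be close to the minimizer x = 3
```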

# Problem 1. 3 pts. Logistic regression example
Implement the regular logistic regression training algorithm
 
We shall train on a two-class subset of the scikit-learn digits dataset (an MNIST-like set of 8x8 images).

This is a binary classification problem, so we'll train a __Logistic Regression with sigmoid__.
$$P(y_i | X_i) = \sigma(W \cdot X_i + b) ={ 1 \over {1+e^{- [W \cdot X_i + b]}} }$$


The natural choice of loss function is binary cross-entropy (aka log loss, the negative log-likelihood):
$$ L = {1 \over N} \sum_{X_i, y_i} - [ y_i \cdot \log P(y_i | X_i) + (1-y_i) \cdot \log (1-P(y_i | X_i)) ]$$

Mind the minus :)
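
A tiny worked example of the two formulas in plain Python may help with checking a JAX implementation. The logits and labels below are made up for illustration, and the helper names are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def binary_crossentropy(probs, labels):
    # Mean over samples of -[y * log(p) + (1 - y) * log(1 - p)].
    total = 0.0
    for p, y in zip(probs, labels):
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(probs)


logits = [2.0, -1.0, 0.0]  # stand-ins for W . X_i + b
labels = [1, 0, 1]
probs = [sigmoid(z) for z in logits]
loss = binary_crossentropy(probs, labels)
print(loss)
```

A confident correct prediction contributes little to the loss, while a prediction of 0.5 always contributes $\log 2$.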


In [2]:
from sklearn.datasets import load_digits
X, y = load_digits(n_class=2, return_X_y=True)

print("y [shape - %s]:" % (str(y.shape)), y[:10])
print("X [shape - %s]:" % (str(X.shape)))

y [shape - (360,)]: [0 1 0 1 0 1 0 0 1 1]
X [shape - (360, 64)]:


In [None]:
print('X:\n', X[:3,:10])
print('y:\n', y[:10])
plt.imshow(X[0].reshape([8,8]))

In [None]:
# inputs
weights = <student.create_variable()>

In [None]:
predicted_y_proba = lambda input_X, weights: <predicted probabilities for input_X using weights>

loss = lambda predicted_y_proba, input_y: <logistic loss (scalar, mean over sample) between predicted_y_proba and input_y>

train_step = <operator that minimizes loss>

In [None]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

In [None]:
from sklearn.metrics import roc_auc_score

for i in range(5):

    loss_i = ###<YOUR CODE: feed values>

    print("loss at iter %i: %.4f" % (i, loss_i))

    # predicted_y_proba takes both the inputs and the current weights
    print("train auc:", roc_auc_score(y_train, predicted_y_proba(X_train, weights)))
    print("test auc:", roc_auc_score(y_test, predicted_y_proba(X_test, weights)))


print("resulting weights:")
plt.imshow(jnp.asarray(weights).reshape(8, -1))
plt.colorbar();

# Problem 2. 3 pts. my first jax network
Your ultimate task for this week is to build your first neural network [almost] from scratch in pure jax.

This time you will solve the same digit recognition problem, but at a larger scale
* images are now 28x28
* 10 different digits
* 50k samples

Note that you are not required to build 152-layer monsters here. A 2-layer (one hidden, one output) NN should already give you an edge over logistic regression.

__[bonus score]__
If you've already beaten logistic regression with a two-layer net, but enthusiasm still ain't gone, you can try improving the test accuracy even further! The milestones would be 95%/97.5%/98.5% accuracy on test set.

__SPOILER!__
At the end of the notebook you will find a few tips and frequently made mistakes. If you feel mighty enough to shoot yourself in the foot without external assistance, we encourage you to do so, but if you run into any insurmountable issues, please do look there before mailing us.

In [3]:
from sklearn.datasets import load_digits
X, y = load_digits(return_X_y=True)


# SPOILERS!

Recommended pipeline

* Adapt logistic regression from previous assignment to classify some number against others (e.g. zero vs nonzero)
* Generalize it to multiclass logistic regression.
  - Instead of weight vector you'll have to use matrix (feature_id x class_id)
  - softmax (exp over sum of exps) can be implemented manually or with jax.nn.softmax (numerically stable)
  - probably better to use STOCHASTIC gradient descent (minibatch)
    - in which case sample should probably be shuffled (or use random subsamples on each iteration)
* Add a hidden layer. Now your logistic regression uses hidden neurons instead of inputs.
  - Hidden layer uses the same math as output layer (ex-logistic regression), but uses some nonlinearity (sigmoid) instead of softmax
  - You need to train both layers, not just output layer :)
  - Do not initialize layers with zeros (due to symmetry effects). Gaussian noise with a small sigma will do.
  - 50 hidden neurons and a sigmoid nonlinearity will do for a start. Many ways to improve. 
  - In the ideal case this totals 2 .dot's, 1 softmax and 1 sigmoid
  - __make sure this neural network works better than logistic regression__
  
* Now's the time to try improving the network. Consider layers (size, neuron count),  nonlinearities, optimization methods, initialization - whatever you want, but please avoid convolutions for now.
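
On the softmax item in the pipeline above: a manual implementation usually subtracts the maximum logit before exponentiating, which leaves the result unchanged but avoids overflow. A minimal plain-Python sketch of that trick (in the notebook the same idea would be applied to jax arrays, or `jax.nn.softmax` used directly):

```python
import math


def softmax(logits):
    # Subtracting the max does not change the result but keeps exp() in range.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# A naive exp(1000.0) would overflow; the shifted version does not.
probs = softmax([1000.0, 1001.0, 1002.0])
print(probs)
```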