In [10]:
import jax
import jax_metrics as jm
import jax.numpy as jnp
from jax import grad, jit, vmap
from functools import partial


from jax import random
import os
import numpy as np
import matplotlib.pyplot as plt# Switch off the cache 
from sklearn.metrics import confusion_matrix 
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

In [11]:
#tenemos la matriz de pesos. Como tenemos la capa de input y una hidden, por eso tenemos dos pessos y dos bias

    
@staticmethod
def random_layer_params(m, n, key, scale=2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,1))

    # Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
@staticmethod
def one_hot(y, k_clases):
    """Create a one-hot encoding of y of size k_clases."""
    return jnp.array(y[:, None] == jnp.arange(k_clases))
@staticmethod
def activacion(x):
    return jnp.maximum(0, x)#jnp.tanh(x)

def softmax(Z):
  A = jnp.exp(Z) / jnp.sum(jnp.exp(Z),axis=0)
  return A

def forward(params,x):
    #input to hidden layers
    activations=x
    for w,b in params[:-1]:
        outputs=jnp.dot(w,activations) + b #size -> (hidden,hidden anterior)
        activations=activacion(outputs)
        
    #last hidden to output
    #we use softmax for the last one
    w_last, b_last = params[-1]
    logits = jnp.dot(w_last, activations) + b_last
    soft=softmax(logits) #size -> (classes,samples) ****
    return soft
 
def loss_function(params,x,y_hot):
    soft=forward(params,x)
    loss=jnp.mean(-y_hot*jnp.log(soft))
    
    return jnp.mean(-y_hot*jnp.log(soft))
    
@partial(jit, static_argnums=(0,))
def update( params, x, y,learning_rate):
    lr=learning_rate #learning rate
    grads = grad(loss_function)(params, x, y)
    
    return [(w - lr * dw, b - lr* db)
                    for (w, b), (dw, db) in zip(params, grads)]

def get_accuracy(predictions, Y):
    #print(predictions, Y)
    return jnp.sum(predictions == Y) / Y.size

def dloss(params, x, y):
        return grad(loss_function)(params, x, y)

def prediction(soft): 
    return jnp.argmax(soft,axis=0)

def get_pr(k_classes,samples,clases,y0,y_hat):
    FP=0
    FN=0
    TP=0
    recall_list=[]
    precision_list=[]
    for k in range(k_classes):
        for i in range (samples):
            if y0[i]==clases[k] and y_hat[i]==clases[k]:
                TP+=1
            if y0[i]!=clases[k] and y_hat[i] == clases[k]:
                FP+=1
            if y0[i]==clases[k] and y_hat[i] != clases[k]:
                FN+=1
        if FP+TP!=0:
            precision_list.append(TP/(TP+FP))
        else:
            precision_list.append(0)
        if FN+TP!=0:
            recall_list.append(TP/(TP+FN))
        else:
            recall_list.append(0)

        
        
   
    precision=sum(precision_list)/k_classes
    recall=sum(recall_list)/k_classes
    return precision,recall








def get_accuracy(predictions, Y):
    #print(predictions, Y)
    return jnp.sum(predictions == Y) / Y.size


def modelo(sizes,key,max_steps,x,y,y_hot,learning_rate,k_clases,samples,clases,stop):
    #y_hot=one_hot(y, k_clases)
    #the initial parameters
    params=init_network_params(sizes,random.PRNGKey(0))
    print(params)
    precision_list=[]
    recall_list=[]
    loss_list=[]
    loss=10
    
    for i in range(max_steps):
        old_loss=loss
        loss=loss_function(params,x,y_hot)
        g=dloss(params, x, y)
        #print(g)
        params=update(params, x, y_hot,learning_rate)
        if i%100==0:
            soft=forward(params,x)
            y_hat=prediction(soft)
            ac=get_accuracy(y_hat, y)
            precision,recall=get_pr(k_clases,samples,clases,y,y_hat)
            precision_list.append(precision)
            recall_list.append(recall)
            print(loss,i,precision,recall,ac)
        if  jnp.abs(loss-old_loss)<stop:
            break
        #print(i,params)
        #print(loss)
    
        
    return loss

def graficar_pr(recall_list,precision_list):
    plt.style.use('rose-pine')
    plt.plot(recall_list,precision_list,color='#fb9f9f',marker='*')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall curve')
    plt.show()

def graficar_rc(loss_list):
    plt.style.use('rose-pine')
    plt.plot(loss_list,color='#fb9f9f')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Rate of Convergency')
    plt.show()

In [12]:
#before trying with MNIST lets try with something simpller
import sklearn.datasets
x,y=sklearn.datasets.make_moons(200,noise=0.15) #x are inputs and y are the expected labels
print(y)
x=jnp.transpose(x) #because we wrote the equations thinking that sizeX=(features,samples)
y=jnp.transpose(y)


import pandas as pd
df=pd.DataFrame(y,columns=['y'])
renglones=df['y'].shape[0]
df_categorical=df[['y']]
columnas=df_categorical.shape[1]
one_hot=np.zeros(renglones)

col=0
df2=pd.DataFrame()
df3=pd.DataFrame()
df3.insert(0,'1',one_hot)
for k in range(0,columnas):
    name=df_categorical.columns[k]
    print(name)
    lista=[]
    lista=df_categorical[name].values.tolist()
    clases=df_categorical[name].unique() #un array de las distintas clases
    size_clases=len(clases)
    
    
    for i in range(size_clases):
        clase=clases[i]
        print(clase)

        for j in range(0,renglones):
            if lista[j]==clase:
                one_hot[j]=1  
        
        df2.insert(i,clase,one_hot)
        
        one_hot=np.zeros(renglones)
    df3=df3.join(df2)
    df2=pd.DataFrame()
del df3[df3.columns[0]]
y_hot=df3.to_numpy()
#
y_hot=jnp.transpose(y_hot)
print(y_hot.shape)


[0 0 1 0 1 0 0 1 0 0 0 0 0 1 1 1 1 0 1 1 0 1 1 1 1 0 0 1 0 1 1 1 1 1 0 0 0
 1 1 0 0 0 0 1 0 1 0 1 1 0 0 1 1 1 1 1 1 1 0 0 1 1 0 1 1 0 0 1 0 1 0 1 0 0
 1 1 1 0 0 1 1 0 0 1 1 0 1 0 0 1 0 1 0 1 1 0 0 0 1 0 0 1 0 1 1 1 0 1 0 0 0
 0 0 1 1 0 0 1 0 1 1 0 0 1 0 1 1 0 1 0 0 1 0 1 1 0 1 1 1 0 1 0 0 1 0 0 0 0
 1 0 0 0 1 0 1 0 1 1 0 1 1 0 1 0 0 0 0 0 0 1 0 0 1 1 0 1 1 1 1 0 0 0 0 0 0
 1 1 1 1 1 1 1 0 1 1 1 0 0 1 0]
y
0
1
(2, 200)


In [13]:
sizes=[2,3,4,2]
print(len(sizes))

max_steps=10000
key=2
learning_rate=0.28
k_clases=2
samples=200
clases=jnp.array([0,1])
stop=0.00001
modelo(sizes,key,max_steps,x,y,y_hot,learning_rate,k_clases,samples,clases,stop)

4
[(DeviceArray([[-0.82923746, -1.146248  ],
             [-1.170297  ,  1.1766202 ],
             [ 1.3146174 ,  2.3568673 ]], dtype=float32), DeviceArray([[-0.5948747],
             [ 5.230382 ],
             [-0.0607711]], dtype=float32)), (DeviceArray([[ 0.08676989,  0.6301467 , -1.4263141 ],
             [ 1.6186938 ,  0.04880301, -1.0032021 ],
             [-2.0201273 ,  0.25526375, -2.33574   ],
             [ 4.728775  ,  2.3069553 , -0.6993406 ]], dtype=float32), DeviceArray([[-1.7650905 ],
             [-0.20067427],
             [-0.61008555],
             [-2.3439033 ]], dtype=float32)), (DeviceArray([[ 1.1595801 ,  3.2079782 ,  1.6459509 , -1.0763928 ],
             [-0.09025478, -3.2137861 ,  1.0516778 , -0.28516877]],            dtype=float32), DeviceArray([[4.5517006],
             [1.2412996]], dtype=float32))]
1.0844892 0 0.5 0.75 0.5
0.32005876 100 0.6821428571428572 0.575 0.65
0.12938456 200 0.8991019417475729 0.9125000000000001 0.905
0.12633243 300 0.89910194174757

KeyboardInterrupt: 