In [56]:
import jax
# import jax_metrics as jm
import jax.numpy as jnp
from jax import grad, jit, vmap
from functools import partial
from jax import random
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn import preprocessing
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

from sklearn.metrics import accuracy_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
# Switch off the cache 
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

class Logistic_Regression():
    """
    Basic Model + Quasi Newton Methods
    """
    def __init__(self, regularization='l2', method_opt='classic_model', lmbda=1.0):
        self.regularization = regularization
        self.method_opt = method_opt
        self.error_gradient = 0.001
        self.key = random.PRNGKey(0)
        # You need to add some variables
        self.W = None

    @partial(jit, static_argnums=(0,))
    def model(self, W:jnp, X:jnp, Y_hot:jnp)->jnp:
        """
        Logistic Model
        """
        W = jnp.reshape(W, self.sh)
        terms = self.logistic_exp(W, X)
        sum_terms = self.logistic_sum(terms)
        matrix = self.logit_matrix(terms, sum_terms)
        
        if self.regularization == "l1":
            # Regularization l1
            reg = self.lmbda/(X.shape[1]) * jnp.sum(jnp.abs(W))
        else:
            # Regularization l2
            reg = self.lmbda/(X.shape[1]) * jnp.trace(jnp.transpose(W)@W)

        return jnp.sum(jnp.sum(jnp.log(matrix)*Y_hot, axis=0), axis=0) + reg

    def regularized_model(self, W:jnp, X:jnp, Y_hot:jnp, alpha:float=1e-2, tol:float=1e-3)->jnp:
        """
        The regularized version of the logistic regression
        """
        n, m = W.shape 
        self.sh = (n, m)
        alpha = 0.5
        Grad = jax.grad(self.model, argnums=0)(jnp.ravel(W), X, Y_hot)
        loss = self.model(jnp.ravel(W), X, Y_hot)
        cnt = 0

        while cnt < 200:
            Hessian = jax.hessian(self.model, argnums=0)(jnp.ravel(W), X, Y_hot)
            W = W - alpha * jnp.reshape((jnp.linalg.inv(Hessian)@Grad) , self.sh)
            Grad =  jax.grad(self.model, argnums=0)(jnp.ravel(W), X, Y_hot)
            old_loss = loss
            loss = self.model(jnp.ravel(W), X, Y_hot)

            if cnt%30 == 0:
                print(f'{self.model(jnp.ravel(W), X, Y_hot)}')
                
            if  jnp.abs(old_loss - loss) < tol:
                break

            cnt +=1

        return W

    @staticmethod
    @jit
    def logistic_exp(W:jnp, X:jnp)->jnp:
        """
        Generate all the w^T@x values 
        args:
            W is a k-1 x d + 1
            X is a d x N
        """
        return jnp.exp(W@X)

    @staticmethod
    @jit
    def logistic_sum(exTerms: jnp)->jnp:        
        """
        Generate all the w^T@x values 
        args:
            W is a k-1 x d 
            X is a d x N
        """
        temp = jnp.sum(exTerms, axis=0)
        n = temp.shape[0]
        return jnp.reshape(1.0+temp, newshape=(1, n))

    @staticmethod
    @jit
    def logit_matrix(Terms: jnp, sum_terms: jnp)->jnp:
        """
        Generate matrix
        """
        divisor = 1/sum_terms
        n, _ = Terms.shape
        replicate = jnp.repeat(divisor, repeats=n, axis=0 )
        logits = Terms*replicate
        return jnp.vstack([logits, divisor])
    
    @partial(jit, static_argnums=(0,))
    def model(self, W:jnp, X:jnp, Y_hot:jnp)->jnp:
        """
        Logistic Model
        """
        W = jnp.reshape(W, self.sh)
        terms = self.logistic_exp(W, X)
        sum_terms = self.logistic_sum(terms)
        matrix = self.logit_matrix(terms, sum_terms)
        return jnp.sum(jnp.sum(jnp.log(matrix)*Y_hot, axis=0), axis=0)
    
    @staticmethod
    def one_hot(Y: jnp):
        """
        One_hot matrix
        """
        numclasses = len(jnp.unique(Y))
        return jnp.transpose(jax.nn.one_hot(Y, num_classes=numclasses))
    
    def generate_w(self, k_classes:int, dim:int)->jnp:
        """
        Use the random generator at Jax to generate a random generator to instanciate
        the augmented values
        """
        key = random.PRNGKey(0)
        keys = random.split(key, 1)
        return jnp.array(random.normal(keys[0], (k_classes, dim)))

    @staticmethod
    def augment_x(X: jnp)->jnp:
        """
        Augmenting samples of a dim x N matrix
        """
        N = X.shape[1]
        return jnp.vstack([X, jnp.ones((1, N))])
     
   
    def fit(self, X: jnp, Y:jnp)->None:
        """
        The fit process
        """
        nclasses = len(jnp.unique(Y))
        X = self.augment_x(X)
        dim = X.shape[0]
        W = self.generate_w(nclasses-1, dim)
        Y_hot = self.one_hot(Y)
        self.W = getattr(self, self.method_opt, lambda W, X, Y_hot: self.error() )(W, X, Y_hot)
    
    @staticmethod
    def error()->None:
        """
        Only Print Error
        """
        raise Exception("Opt Method does not exist")
    
    def classic_model(self, W:jnp, X:jnp, Y_hot:jnp, alpha:float=1e-2,  tol:float=1e-3)->jnp:
        """
        The naive version of the logistic regression
        """
        n, m = W.shape 
        self.sh = (n, m)
        alpha = 0.5
        Grad = jax.grad(self.model, argnums=0)(jnp.ravel(W), X, Y_hot)
        loss = self.model(jnp.ravel(W), X, Y_hot)
        cnt = 0
        while True:
            Hessian = jax.hessian(self.model, argnums=0)(jnp.ravel(W), X, Y_hot)
            W = W - alpha*jnp.reshape((jnp.linalg.inv(Hessian)@Grad), self.sh)
            Grad =  jax.grad(self.model, argnums=0)(jnp.ravel(W), X, Y_hot)
            old_loss = loss
            loss = self.model(jnp.ravel(W), X, Y_hot)
            if cnt%30 == 0:
                print(f'{self.model(jnp.ravel(W), X, Y_hot)}')
            if  jnp.abs(old_loss - loss) < tol:
                break
            cnt +=1
        return W
    
    def estimate_prob(self, X:jnp)->jnp:
        """
        Estimate Probability
        """
        X = self.augment_x(X)
        terms = self.logistic_exp(self.W, X)
        sum_terms = self.logistic_sum(terms)
        matrix = self.logit_matrix(terms, sum_terms)
        return matrix
    
    def estimate(self, X:jnp)->jnp:
        """
        Estimation
        """
        X = self.augment_x(X)
        terms = self.logistic_exp(self.W, X)
        sum_terms = self.logistic_sum(terms)
        matrix = self.logit_matrix(terms, sum_terms)
        return jnp.argmax(matrix, axis=0)
    
    def precision(self, y, y_hat):
        """
        Precision
        args:
            y: Real Labels
            y_hat: estimated labels
        return TP/(TP+FP)
        """
        TP = sum(y_hat == y)
        FP = sum(y_hat != y)
        return (TP/(TP+FP)).tolist()

        
    
class Tools():
    """
    Tools
    """
    def __init__(self):
        """
        Basic Init
        """
        self.key = random.PRNGKey(0)
    
    def GenerateData(self, n_samples: int, n_classes: int, dim: int):
        """
        Data Generation
        """
        Total_Data = [] 
        Total_Y = []
        for idx in range(n_classes):
            keys = random.split(self.key, 1)
            X = random.normal(keys[0], (dim, n_samples)) + idx*5*jnp.ones((dim, 1))
            Y = idx*jnp.ones(n_samples)
            Total_Data.append(X)
            Total_Y.append(Y)
        return jnp.hstack(Total_Data), jnp.hstack(Total_Y)
    
    @staticmethod
    def plot_classes(X: jnp, Y: jnp, n_classes: int)-> None:
        """
        Plot the classes
        """
        symbols = ['ro', 'bx', 'go', 'rx']
        plt.figure()
        for idx in range(n_classes):
            mask = idx == Y
            X_p = X[:, mask]
            plt.plot(X_p[0,:], X_p[1,:], symbols[idx])
        

In [57]:
tools = Tools()

In [58]:
X, Y = tools.GenerateData(n_samples=200, n_classes=3, dim=2)
X_val, Y_val = tools.GenerateData(n_samples=50, n_classes=3, dim=2)

In [59]:
model = Logistic_Regression()

In [60]:
model.fit(X, Y)

-676.3424072265625
-7.675201416015625
-7.383645057678223
-7.134308815002441
-6.927130699157715
-6.749755382537842
-6.593540668487549
-6.462610721588135
-6.343696117401123
-6.236150741577148
-6.135331153869629
-6.041781425476074
-5.958338260650635
-5.883556365966797
-5.813255310058594
-5.746197700500488
-5.680989742279053
-5.619110584259033
-5.559559345245361
-5.50145959854126
-5.444943904876709
-5.3913655281066895
-5.340021133422852
-5.289865016937256
-5.243139266967773
-5.198108673095703
-5.155386447906494
-5.114518642425537
-5.075049877166748
-5.03630256652832
-4.998791217803955
-4.962948322296143
-4.92769193649292
-4.893009662628174
-4.858890056610107
-4.825308322906494
-4.792853355407715
-4.761292457580566
-4.730490684509277


In [61]:
Y_hat = model.estimate(X_val)

In [62]:
Y_hat

DeviceArray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
             1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
             1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
             2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
             2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
             2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int32)

In [63]:
model.precision(Y_val, Y_hat)

1.0

In [64]:
X_prob = model.estimate_prob(X)

In [65]:
X_prob[:, 402]

DeviceArray([8.1802198e-15, 8.2354492e-04, 9.9917644e-01], dtype=float32)

In [66]:
X_prob

DeviceArray([[9.9998546e-01, 9.9993360e-01, 9.9926937e-01, ...,
              3.4069217e-14, 1.4104338e-12, 2.3718556e-14],
             [1.4573496e-05, 6.6376946e-05, 7.3067157e-04, ...,
              1.5517577e-03, 8.0707679e-03, 1.3268435e-03],
             [4.3063693e-15, 6.5353803e-14, 4.8829512e-12, ...,
              9.9844825e-01, 9.9192929e-01, 9.9867320e-01]],            dtype=float32)

In [67]:
lmbdas = [0, 0.1, 1, 3, 5, 10, 20, 50]
precision_list_l1 = []
precision_list_l2 = []

for l in lmbdas:
    # For l1 Regularization
    model = Logistic_Regression(
        regularization = "l1",
        method_opt = 'regularized_model',
        lmbda = l)
    model.fit(X, Y)
    Y_hat = model.estimate(X_val)
    precision_list_l1.append(model.precision(Y_val, Y_hat))

    # For l2 Regularization
    model = Logistic_Regression(
        regularization = "l2",
        method_opt = 'regularized_model',
        lmbda = l,
    )
    model.fit(X, Y)
    Y_hat = model.estimate(X_val)
    precision_list_l2.append( model.precision(Y_val, Y_hat))

-676.3424072265625
-7.675201416015625
-7.383645057678223
-7.134308815002441
-6.927130699157715
-6.749755382537842
-6.593540668487549
-676.3424072265625
-7.675201416015625
-7.383645057678223
-7.134308815002441
-6.927130699157715
-6.749755382537842
-6.593540668487549
-676.3424072265625
-7.675201416015625
-7.383645057678223
-7.134308815002441
-6.927130699157715
-6.749755382537842
-6.593540668487549
-676.3424072265625
-7.675201416015625
-7.383645057678223
-7.134308815002441
-6.927130699157715
-6.749755382537842
-6.593540668487549
-676.3424072265625
-7.675201416015625
-7.383645057678223
-7.134308815002441
-6.927130699157715
-6.749755382537842
-6.593540668487549
-676.3424072265625
-7.675201416015625
-7.383645057678223
-7.134308815002441
-6.927130699157715
-6.749755382537842
-6.593540668487549
-676.3424072265625
-7.675201416015625
-7.383645057678223
-7.134308815002441
-6.927130699157715
-6.749755382537842
-6.593540668487549
-676.3424072265625
-7.675201416015625
-7.383645057678223
-7.134308815

In [68]:
print(f'L1 precission: {precision_list_l1}\tL2 precission: {precision_list_l2}')

L1 precission: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]	L2 precission: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]


The implementation follows the previous code cells. What could be not working?

In [74]:
url = "https://www.openml.org/data/download/1595261/phpMawTba"

df = pd.read_csv(url, header = None, skiprows=19, na_values=' ?')
le = preprocessing.LabelEncoder()
df[14] = le.fit_transform(df[14])
df = df.dropna()
df = pd.get_dummies(df)

y = df[14].values
X = df.drop([14], axis = 1).values

scaler = MinMaxScaler()
X = scaler.fit_transform(X)

X_train, X_test, y_train, y_test = train_test_split(X, y,
    test_size=0.30,
    random_state=42,
    stratify = y
)

X_train = np.swapaxes(X_train,0,1)
X_test = np.swapaxes(X_test,0,1)

lmbdas = [0, 1]
acc_list_l2 = []
rec_list_l2 = []

model = Logistic_Regression(
    regularization = "l2",
    method_opt = 'regularized_model',
    lmbda = 0.01,
)

model.fit(X_train[i * len(X_train):(i + 1)*len(X_train)], y_train[i * len(y_train):(i + 1)*len(y_train)])
Y_hat = model.estimate(X_test)
acc_list_l2.append(accuracy_score(y_test, Y_hat))
rec_list_l2.append(recall_score(y_test, Y_hat, average='binary'))


Batch: 0
-40635.04296875


KeyboardInterrupt: ignored