## logistic regression multi-dimensional data

 logistic regression multi-dimensional data
 
 
 $$ F(X)=X \times W $$
 $$ H(x)= \frac{1}{1+ e ^{-F(x)}} $$
 $$ C= -\frac{1}{n} \sum_{i,j} (Y \odot log(H(x)) + (1-Y) \odot log(1-H(x)) ) $$

$X_{n \times k}$

$W_{k \times p}$

$Y_{n \times p}$

In [1]:
import numpy as np
import random

In [2]:
n, k, p=100, 8, 3 

In [3]:
X=np.random.random([n,k])
W=np.random.random([k,p])

y=np.random.randint(p, size=(1,n))
Y=np.zeros((n,p))
Y[np.arange(n), y]=1

max_itr=5000
alpha=0.01
Lambda=0.01

Gradient is as follows:
$$ X^T (H(x)-Y) + \lambda 2 W$$

In [4]:
# F(x)= w[0]*x + w[1]
def F(X, W):
    return np.matmul(X,W)

def H(F):
    return 1/(1+np.exp(-F))

def cost(Y_est, Y):
    E= - (1/n) * (np.sum(Y*np.log(Y_est) + (1-Y)*np.log(1-Y_est)))  + np.linalg.norm(W,2)
    return E, np.sum(np.argmax(Y_est,1)==y)/n

def gradient(Y_est, Y, X):
    return (1/n) * np.matmul(X.T, (Y_est - Y) ) + Lambda* 2* W

In [5]:
def fit(W, X, Y, alpha, max_itr):
    for i in range(max_itr):
        
        F_x=F(X,W)
        Y_est=H(F_x)
        E, c= cost(Y_est, Y)
        Wg=gradient(Y_est, Y, X)
        W=W - alpha * Wg
        if i%100==0:
            print(E, c)
        
    return W, Y_est

To take into account for the biases, we concatenate X by a 1 column, and increase the number of rows in W by one

In [6]:
X=np.concatenate( (X, np.ones((n,1))), axis=1 ) 
W=np.concatenate( (W, np.random.random((1,p)) ), axis=0 )

W, Y_est = fit(W, X, Y, alpha, max_itr)

6.718996914758343 0.3
4.820471171494192 0.36
4.318936523202406 0.38
4.218780306307605 0.42
4.197294212201086 0.44
4.191208347293338 0.44
4.188063539520591 0.45
4.185373731746892 0.45
4.182709348857521 0.45
4.180046614167144 0.45
4.1774200627910165 0.45
4.174857733252697 0.46
4.172375257807301 0.46
4.169979645990572 0.46
4.167672954582862 0.45
4.1654546110071875 0.45
4.163322716369359 0.46
4.161274740671397 0.46
4.159307884770626 0.45
4.157419266030563 0.46
4.155606012237945 0.46
4.15386530817463 0.48
4.152194417855384 0.48
4.1505906942922 0.49
4.149051582883661 0.5
4.147574621562894 0.51
4.146157439314604 0.5
4.1447977538896685 0.5
4.143493369144091 0.5
4.142242172222593 0.5
4.141042130700821 0.5
4.139891289745262 0.5
4.138787769321565 0.48
4.137729761467164 0.48
4.136715527636316 0.48
4.13574339612153 0.48
4.134811759553116 0.47
4.133919072477272 0.47
4.133063849012423 0.47
4.13224466058305 0.46
4.131460133729893 0.46
4.130708947995285 0.47
4.129989833882121 0.47
4.1293015708849135 0.

In [7]:
np.sum(Y_est,1)

array([1.04766095, 1.00516582, 0.99910737, 1.0435139 , 0.87375826,
       1.05671236, 0.9519994 , 0.94796727, 1.03601181, 0.9502441 ,
       0.94247826, 0.90353089, 0.92597811, 0.98511565, 0.97698083,
       1.02257008, 0.90844781, 1.10405497, 1.07010023, 0.93491063,
       0.98666603, 0.91225622, 0.99214984, 1.00096054, 0.92386857,
       1.03707912, 0.95229589, 0.93376756, 0.84302983, 0.96032026,
       1.01965308, 0.9744561 , 0.94435374, 0.84480364, 0.95262682,
       0.98724069, 0.90484181, 0.95751317, 0.97975727, 0.9673253 ,
       0.92215659, 1.0020763 , 0.97161778, 1.02142198, 0.97336865,
       0.99777586, 1.03726999, 1.07576816, 0.9575923 , 1.03840422,
       0.9733585 , 0.99737341, 0.96482504, 1.0056167 , 0.87698596,
       1.02431512, 0.96108744, 0.97146442, 1.02332517, 0.9055078 ,
       0.99921677, 1.01962354, 0.89424699, 0.95582332, 1.00612214,
       0.97121292, 0.88844944, 0.96024968, 0.93818885, 1.01437035,
       0.85985206, 0.93281064, 1.09641145, 0.95523637, 0.94659

In [8]:
Y_est

array([[0.46542848, 0.27369544, 0.30853704],
       [0.50673792, 0.20788303, 0.29054487],
       [0.44053807, 0.26958853, 0.28898076],
       [0.33309341, 0.34557902, 0.36484147],
       [0.36754939, 0.2915906 , 0.21461827],
       [0.48960444, 0.32859521, 0.23851271],
       [0.5284096 , 0.26961635, 0.15397346],
       [0.36825933, 0.33825752, 0.24145042],
       [0.51509989, 0.30827208, 0.21263984],
       [0.40103175, 0.30352946, 0.24568288],
       [0.47722772, 0.24584378, 0.21940675],
       [0.42793365, 0.23378695, 0.24181029],
       [0.42551066, 0.23640075, 0.2640667 ],
       [0.40719009, 0.31822731, 0.25969825],
       [0.33858396, 0.31925362, 0.31914325],
       [0.37620438, 0.32080387, 0.32556183],
       [0.24074667, 0.37012768, 0.29757346],
       [0.44254694, 0.3095497 , 0.35195834],
       [0.35298124, 0.36738642, 0.34973256],
       [0.32123275, 0.3129106 , 0.30076728],
       [0.46733998, 0.2252248 , 0.29410125],
       [0.29757238, 0.28923839, 0.32544544],
       [0.