In [None]:
import numpy as np

def soft_thresholding(a, b):
    return np.sign(a) * np.maximum(np.abs(a) - b, 0)

def one_hot_encode(y, num_classes=None):
    y = np.array(y)
    if num_classes is None:
        num_classes = np.max(y) + 1
    
    one_hot = np.zeros((len(y), num_classes))
    one_hot[np.arange(len(y)), y] = 1
    return one_hot


class logisitic_regression():
    def __init__(self):
        pass
    
    def set_std_mean(self,X, epsilon = 1e-8):
        self.mean = np.mean(X, axis = 0)
        std = np.std(X,axis = 0)
        self.std = np.where(std > 0, std, epsilon)

    def standarize(self, X):
        return (X-self.mean)/self.std
        

    def fit(self, X, y, a, lambd, times = 10, extinction = 0.1):
        n, p = X.shape
        X, y = np.array(X), np.array(y)
        self.set_std_mean(X)
        X = self.standarize(X)
        y = one_hot_encode(y)
        g = y.shape[1]
        self.B = np.zeros((p,g))
        for i in range(times):
            #lambd *= extinction
            for k in range(g):
                for j in range(p):
                    xj = X[:,j]
                    #X = (X.T[self.B]).T jest tam opcja by olewać obliczenia dla zerowych bet to zrobić mniej więcej tak albo trzymać listę indeksów i na samych nich to robić
                    #X = X[:,non_zero_beta]
                    sum = (xj@(y[:,k])) - xj@X@(self.B[:,k]) + self.B[j,k]
                    self.B[j,k]= soft_thresholding(sum/n,lambd*a)/(1+lambd*(1-a))

    def predict_proba(self, X):
        X = np.array(X)
        X = self.standarize(X)
        X = np.exp(X@self.B)
        return X / X.sum(axis=1, keepdims=True)
        

In [30]:
from ucimlrepo import fetch_ucirepo
import numpy as np

rice_cammeo_and_osmancik = fetch_ucirepo(id=545) 
X_1 = np.array(rice_cammeo_and_osmancik.data.features) 
y_1 = (np.array(rice_cammeo_and_osmancik.data.targets).T[0]== "Cammeo").astype(int)

In [52]:
one_hot_encode(y_1)

array([[0., 1.],
       [0., 1.],
       [0., 1.],
       ...,
       [1., 0.],
       [1., 0.],
       [1., 0.]])

In [53]:
lr_test = logisitic_regression()
lr_test.fit(X_1,y_1,1 ,0.1)

In [54]:
lr_test.B

array([[-0.03315889,  0.03315889],
       [-0.06549224,  0.06549224],
       [-0.02104434,  0.02104434],
       [-0.        ,  0.        ],
       [-0.01468034,  0.01468034],
       [-0.0674756 ,  0.0674756 ],
       [ 0.        , -0.        ]])

In [55]:
lr_test.predict_proba(X_1)

  return X / X.sum(axis=1, keepdims=True)


array([[ 1.06486668e+14, -1.06486668e+14],
       [ 6.74720214e+13, -6.74720214e+13],
       [ 8.94526007e+13, -8.94526007e+13],
       ...,
       [ 8.98611605e+13, -8.98611605e+13],
       [ 7.95162300e+13, -7.95162300e+13],
       [ 8.18281267e+13, -8.18281267e+13]])

In [23]:
X = np.random.randint(1, 10, (15,3))
B = np.random.randint(0, 2, (3,4))
X,B

(array([[1, 2, 6],
        [9, 7, 2],
        [2, 4, 9],
        [4, 9, 4],
        [4, 4, 5],
        [3, 8, 6],
        [7, 7, 6],
        [4, 6, 3],
        [2, 4, 4],
        [8, 6, 7],
        [3, 4, 2],
        [5, 5, 2],
        [8, 2, 5],
        [2, 6, 3],
        [5, 1, 3]]),
 array([[1, 0, 1, 0],
        [1, 1, 0, 1],
        [1, 1, 0, 1]]))

In [26]:
W = X@B
W

array([[ 9,  8,  1,  8],
       [18,  9,  9,  9],
       [15, 13,  2, 13],
       [17, 13,  4, 13],
       [13,  9,  4,  9],
       [17, 14,  3, 14],
       [20, 13,  7, 13],
       [13,  9,  4,  9],
       [10,  8,  2,  8],
       [21, 13,  8, 13],
       [ 9,  6,  3,  6],
       [12,  7,  5,  7],
       [15,  7,  8,  7],
       [11,  9,  2,  9],
       [ 9,  4,  5,  4]])

In [28]:
W/W.sum(axis=1, keepdims = True)

array([[0.34615385, 0.30769231, 0.03846154, 0.30769231],
       [0.4       , 0.2       , 0.2       , 0.2       ],
       [0.34883721, 0.30232558, 0.04651163, 0.30232558],
       [0.36170213, 0.27659574, 0.08510638, 0.27659574],
       [0.37142857, 0.25714286, 0.11428571, 0.25714286],
       [0.35416667, 0.29166667, 0.0625    , 0.29166667],
       [0.37735849, 0.24528302, 0.13207547, 0.24528302],
       [0.37142857, 0.25714286, 0.11428571, 0.25714286],
       [0.35714286, 0.28571429, 0.07142857, 0.28571429],
       [0.38181818, 0.23636364, 0.14545455, 0.23636364],
       [0.375     , 0.25      , 0.125     , 0.25      ],
       [0.38709677, 0.22580645, 0.16129032, 0.22580645],
       [0.40540541, 0.18918919, 0.21621622, 0.18918919],
       [0.35483871, 0.29032258, 0.06451613, 0.29032258],
       [0.40909091, 0.18181818, 0.22727273, 0.18181818]])