# Задание

* Реализовать самостоятельно логистическую регрессию
* Обучить ее методу градиентного спуска
* Методу nesterov momentum
* Методу rmsprop

###### Дополнительное задание *
В качестве dataset’а взять Iris, оставив 2 класса:
Iris Versicolor
Iris Virginica

- Загрузка библиотек

In [125]:
from sklearn import datasets
import pandas as pd
import numpy as np

- Загрузка датасета Iris
- Переводим его в датафрейм
- Целевые значения также заносим в датафрейм
- Фильтруем согласно заданию 

In [96]:
data = datasets.load_iris()
df = pd.DataFrame(data.data, columns=data.feature_names)
df['target'] = data.target
df_0_1 = df.loc[df['target']>0]
df_0_1

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
50,7.0,3.2,4.7,1.4,1
51,6.4,3.2,4.5,1.5,1
52,6.9,3.1,4.9,1.5,1
53,5.5,2.3,4.0,1.3,1
54,6.5,2.8,4.6,1.5,1
...,...,...,...,...,...
145,6.7,3.0,5.2,2.3,2
146,6.3,2.5,5.0,1.9,2
147,6.5,3.0,5.2,2.0,2
148,6.2,3.4,5.4,2.3,2


In [114]:
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline # используем пайплайны для удобства
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split 

In [115]:
model = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))

In [116]:
X = df_0_1[['sepal length (cm)', 'sepal width (cm)','petal length (cm)', 'petal width (cm)']]
y = df_0_1['target']

X = X.values # Для корректной обработки переводим таблицу в массив
y = y.values # Для корректной обработки переводим таблицу в массив

In [117]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

In [118]:
model.fit(X_train, y_train)
predictions = model.predict(X_test)

In [119]:
model.predict_proba(X_test)

array([[0.5382958 , 0.4617042 ],
       [0.08708639, 0.91291361],
       [0.01447809, 0.98552191],
       [0.98609062, 0.01390938],
       [0.95472464, 0.04527536],
       [0.9555808 , 0.0444192 ],
       [0.48962664, 0.51037336],
       [0.01586297, 0.98413703],
       [0.98936451, 0.01063549],
       [0.93932158, 0.06067842],
       [0.56297653, 0.43702347],
       [0.9868671 , 0.0131329 ],
       [0.26196894, 0.73803106],
       [0.34771474, 0.65228526],
       [0.00876315, 0.99123685],
       [0.8126588 , 0.1873412 ],
       [0.36216695, 0.63783305],
       [0.40831253, 0.59168747],
       [0.98068838, 0.01931162],
       [0.99370844, 0.00629156],
       [0.00281964, 0.99718036],
       [0.4660448 , 0.5339552 ],
       [0.81838646, 0.18161354],
       [0.97918875, 0.02081125],
       [0.25718622, 0.74281378],
       [0.9611881 , 0.0388119 ],
       [0.94115773, 0.05884227],
       [0.0954939 , 0.9045061 ],
       [0.96251114, 0.03748886],
       [0.00227164, 0.99772836]])

In [120]:
model.score(X_train, y_train) 

0.9857142857142858

In [121]:
model.score(X_test, y_test)

0.9

### Классический метод градиентного спуска

In [112]:
EPOCHS = 1000
LEARNING_RATE = 0.0001

def cost_function(X, y, theta0, theta1):
    total_cost = 0
    for i in range(len(X)):
        total_cost += (theta0 + theta1*X[i] - y[i]) ** 2
    return total_cost / (2 * len(X))

def der_theta0(X, y, theta0, theta1):
    total_cost = 0
    for i in range(len(X)):
        total_cost += (theta0 + theta1*X[i] - y[i])
    return total_cost / (len(X))    

def der_theta1(X, y, theta0, theta1):
    total_cost = 0
    for i in range(len(X)):
        total_cost += (theta0 + theta1*X[i] - y[i]) * X[i]
    return total_cost / (len(X))  

theta0 = 1
theta1 = 1
for _ in range(EPOCHS):
    dt0 = der_theta0(X, y, theta0, theta1)
    dt1 = der_theta1(X, y, theta0, theta1)
     
    theta0 = theta0 - LEARNING_RATE * dt0
    theta1 -= LEARNING_RATE * dt1
    
    print("t0:", theta0, "t1:", theta1, "cost:", cost_function(X, y, theta0, theta1))

t0: [0.9994238 0.9997628 0.9995594 0.9998824] t1: [0.99636464 0.9993129  0.99780324 0.99980254] cost: [16.64459678  2.93671877  9.79566572  0.7302719 ]
t0: [0.99884993 0.99952582 0.99911992 0.99976484] t1: [0.99274405 0.99862644 0.99561213 0.99960516] cost: [16.51049018  2.9314474   9.74578897  0.72974422]
t0: [0.99827839 0.99928906 0.99868156 0.99964733] t1: [0.98913818 0.99794063 0.99342666 0.99940786] cost: [16.3774713   2.92618588  9.69616843  0.72921695]
t0: [0.99770917 0.99905253 0.99824432 0.99952987] t1: [0.98554696 0.99725545 0.99124681 0.99921063] cost: [16.24553132  2.9209342   9.64680279  0.7286901 ]
t0: [0.99714225 0.99881621 0.99780819 0.99941245] t1: [0.98197034 0.99657092 0.98907257 0.99901349] cost: [16.11466148  2.91569232  9.59769072  0.72816367]
t0: [0.99657762 0.99858011 0.99737317 0.99929507] t1: [0.97840826 0.99588702 0.98690393 0.99881642] cost: [15.98485311  2.91046024  9.54883094  0.72763765]
t0: [0.99601529 0.99834423 0.99693926 0.99917774] t1: [0.97486064 0.

t0: [0.96612897 0.9845866  0.97282738 0.99221913] t1: [0.78639823 0.95535631 0.86464452 0.98694271] cost: [9.76979563 2.60911668 6.99772074 0.69632201]
t0: [0.96568991 0.98436376 0.97245591 0.9921045 ] t1: [0.78363096 0.95471095 0.86279601 0.98675046] cost: [9.6914496  2.6044574  6.96221657 0.69582108]
t0: [0.96525264 0.98414114 0.97208537 0.99198991] t1: [0.78087494 0.9540662  0.86095226 0.98655828] cost: [9.61373902 2.59980683 6.92689478 0.69532055]
t0: [0.96481713 0.98391871 0.97171578 0.99187536] t1: [0.77813012 0.95342204 0.85911324 0.98636618] cost: [9.53665874 2.59516494 6.89175444 0.69482041]
t0: [0.96438338 0.9836965  0.97134713 0.99176086] t1: [0.77539645 0.95277849 0.85727896 0.98617415] cost: [9.46020365 2.59053172 6.8567946  0.69432066]
t0: [0.96395139 0.98347449 0.97097941 0.9916464 ] t1: [0.77267389 0.95213554 0.85544941 0.98598221] cost: [9.38436868 2.58590717 6.82201434 0.69382131]
t0: [0.96352114 0.98325269 0.97061263 0.99153199] t1: [0.7699624  0.95149319 0.85362455 

t0: [0.92908267 0.96269155 0.9391211  0.98063834] t1: [0.55312397 0.89195639 0.69720096 0.96753781] cost: [4.28536397 2.17219275 4.15242652 0.64674585]
t0: [0.92879339 0.96248911 0.93883514 0.98052812] t1: [0.55130477 0.89137031 0.6957833  0.96735329] cost: [4.2515014  2.16834981 4.13153812 0.64628398]
t0: [0.92850529 0.96228686 0.9385499  0.98041794] t1: [0.54949297 0.89078477 0.69436929 0.96716885] cost: [4.21791349 2.16451405 4.11075702 0.64582248]
t0: [0.92821834 0.96208479 0.93826539 0.9803078 ] t1: [0.54768853 0.89019979 0.69295892 0.96698448] cost: [4.184598   2.16068546 4.09008267 0.64536135]
t0: [0.92793256 0.96188292 0.9379816  0.9801977 ] t1: [0.54589143 0.88961535 0.69155218 0.96680019] cost: [4.15155273 2.15686401 4.06951451 0.64490057]
t0: [0.92764793 0.96168123 0.93769852 0.98008765] t1: [0.54410163 0.88903145 0.69014906 0.96661596] cost: [4.11877548 2.15304971 4.04905202 0.64444016]
t0: [0.92736445 0.96147974 0.93741617 0.97997763] t1: [0.5423191  0.88844811 0.68874954 

t0: [0.9076192  0.94557978 0.91653492 0.97109748] t1: [0.41830748 0.84242375 0.58543624 0.95157893] cost: [2.14935789 1.86008371 2.67057672 0.60746398]
t0: [0.9074165  0.94539328 0.91630605 0.97099089] t1: [0.41703624 0.84188399 0.58430621 0.95140078] cost: [2.13282016 1.8568239  2.65730028 0.60703307]
t0: [0.90721461 0.94520695 0.91607776 0.97088433] t1: [0.41577017 0.84134472 0.58317909 0.9512227 ] cost: [2.11641656 1.85357019 2.64409204 0.6066025 ]
t0: [0.90701353 0.9450208  0.91585005 0.97077782] t1: [0.41450924 0.84080596 0.58205486 0.9510447 ] cost: [2.100146   1.85032256 2.63095165 0.60617227]
t0: [0.90681327 0.94483482 0.9156229  0.97067135] t1: [0.41325344 0.84026771 0.58093353 0.95086676] cost: [2.08400742 1.84708099 2.61787875 0.60574237]
t0: [0.9066138  0.94464901 0.91539634 0.97056491] t1: [0.41200274 0.83972995 0.57981509 0.95068889] cost: [2.06799973 1.84384548 2.60487301 0.60531282]
t0: [0.90641515 0.94446337 0.91517034 0.97045852] t1: [0.41075713 0.83919271 0.57869953 

t0: [0.89727585 0.93522165 0.90436878 0.9650858 ] t1: [0.35351141 0.81244884 0.52545475 0.94153704] cost: [1.39215891 1.68367065 2.01308727 0.58343617]
t0: [0.89711475 0.9350448  0.90417056 0.96498149] t1: [0.35250355 0.8119371  0.5244791  0.9413629 ] cost: [1.38176269 1.68074045 2.0031882  0.58302419]
t0: [0.8969543  0.9348681  0.90397283 0.96487722] t1: [0.35149978 0.81142583 0.52350596 0.94118883] cost: [1.3714508  1.67781573 1.99333999 0.58261253]
t0: [0.8967945  0.93469158 0.9037756  0.96477299] t1: [0.3505001  0.81091505 0.52253533 0.94101483] cost: [1.36122255 1.67489647 1.98354236 0.5822012 ]
t0: [0.89663533 0.93451521 0.90357887 0.9646688 ] t1: [0.34950448 0.81040474 0.5215672  0.94084089] cost: [1.35107726 1.67198267 1.97379505 0.58179019]
t0: [0.89647681 0.93433901 0.90338263 0.96456465] t1: [0.3485129  0.80989491 0.52060156 0.94066703] cost: [1.34101425 1.66907431 1.96409782 0.58137951]
t0: [0.89631893 0.93416298 0.90318689 0.96446053] t1: [0.34752536 0.80938556 0.5196384  

t0: [0.88323778 0.91720186 0.88574727 0.95414621] t1: [0.26592058 0.76031957 0.43410178 0.92329265] cost: [0.64663893 1.39925125 1.19665815 0.54114548]
t0: [0.88313294 0.91704178 0.88559573 0.95404605] t1: [0.26526878 0.75985657 0.43336131 0.9231258 ] cost: [0.6422895  1.39685245 1.19095285 0.54076682]
t0: [0.88302852 0.91688184 0.88544456 0.95394593] t1: [0.26461964 0.75939401 0.43262275 0.922959  ] cost: [0.63797535 1.39445814 1.18527686 0.54038846]
t0: [0.88292451 0.91672206 0.88529377 0.95384585] t1: [0.26397313 0.75893188 0.43188608 0.92279228] cost: [0.63369619 1.3920683  1.17963003 0.5400104 ]
t0: [0.88282092 0.91656242 0.88514336 0.95374581] t1: [0.26332926 0.75847018 0.43115132 0.92262562] cost: [0.62945173 1.38968293 1.1740122  0.53963264]
t0: [0.88271774 0.91640293 0.88499332 0.9536458 ] t1: [0.262688   0.75800892 0.43041844 0.92245903] cost: [0.62524171 1.38730201 1.16842322 0.53925517]
t0: [0.88261497 0.91624359 0.88484366 0.95354583] t1: [0.26204934 0.75754808 0.42968745 

t0: [0.8749319  0.90264022 0.87291309 0.94479869] t1: [0.21447083 0.71821329 0.37162935 0.90773509] cost: [0.35776959 1.19026627 0.7669185  0.50647311]
t0: [0.8748601  0.90249368 0.87279347 0.94470208] t1: [0.21402822 0.71778967 0.37104975 0.90757445] cost: [0.35576311 1.18825794 0.76342061 0.50612177]
t0: [0.87478859 0.90234728 0.87267416 0.9446055 ] t1: [0.21358742 0.71736644 0.37047165 0.90741387] cost: [0.35377291 1.18625336 0.7599407  0.50577071]
t0: [0.87471736 0.90220102 0.87255514 0.94450895] t1: [0.2131484  0.71694361 0.36989503 0.90725335] cost: [0.35179885 1.18425252 0.75647866 0.50541992]
t0: [0.87464642 0.9020549  0.87243641 0.94441245] t1: [0.21271117 0.71652118 0.3693199  0.9070929 ] cost: [0.3498408  1.18225543 0.7530344  0.50506941]
t0: [0.87457576 0.9019089  0.87231798 0.94431598] t1: [0.21227572 0.71609914 0.36874626 0.90693252] cost: [0.34789863 1.18026206 0.74960783 0.50471918]
t0: [0.87450537 0.90176305 0.87219984 0.94421954] t1: [0.21184204 0.7156775  0.36817409 

t0: [0.86856136 0.88744418 0.86146976 0.93447987] t1: [0.17543692 0.67429449 0.3164862  0.89059753] cost: [0.21221067 0.99204247 0.4744044  0.46976146]
t0: [0.86851464 0.88731178 0.86137835 0.93438715] t1: [0.17515306 0.67391194 0.31604866 0.89044373] cost: [0.21138477 0.99040449 0.47240901 0.46943905]
t0: [0.86846811 0.8871795  0.86128715 0.93429448] t1: [0.17487036 0.67352975 0.31561225 0.89029   ] cost: [0.21056557 0.98876958 0.47042387 0.46911688]
t0: [0.86842176 0.88704735 0.86119619 0.93420183] t1: [0.1745888  0.67314791 0.31517697 0.89013633] cost: [0.20975302 0.98713771 0.46844892 0.46879497]
t0: [0.86837559 0.88691531 0.86110544 0.93410923] t1: [0.17430839 0.67276644 0.3147428  0.88998272] cost: [0.20894705 0.9855089  0.46648412 0.46847332]
t0: [0.8683296  0.8867834  0.86101492 0.93401665] t1: [0.17402913 0.67238531 0.31430976 0.88982917] cost: [0.20814762 0.98388313 0.46452941 0.46815192]
t0: [0.86828379 0.88665162 0.86092462 0.93392412] t1: [0.173751   0.67200455 0.31387783 

t0: [0.86465226 0.87467456 0.85327686 0.92529065] t1: [0.1518885  0.63740923 0.2775344  0.87537086] cost: [0.1551173  0.84115085 0.31685619 0.43844735]
t0: [0.86462068 0.87455402 0.85320537 0.92520141] t1: [0.15170045 0.63706118 0.27719726 0.87522316] cost: [0.15475445 0.8397948  0.31567    0.43814961]
t0: [0.86458923 0.8744336  0.85313406 0.9251122 ] t1: [0.15151317 0.63671345 0.27686099 0.87507551] cost: [0.15439454 0.83844127 0.3144899  0.43785209]
t0: [0.86455789 0.8743133  0.85306292 0.92502303] t1: [0.15132665 0.63636604 0.27652559 0.87492793] cost: [0.15403756 0.83709028 0.31331586 0.43755481]
t0: [0.86452667 0.8741931  0.85299195 0.92493389] t1: [0.15114089 0.63601897 0.27619105 0.8747804 ] cost: [0.15368347 0.83574182 0.31214785 0.43725777]
t0: [0.86449558 0.87407302 0.85292115 0.92484478] t1: [0.15095589 0.63567221 0.27585738 0.87463293] cost: [0.15333225 0.83439587 0.31098584 0.43696096]
t0: [0.8644646  0.87395304 0.85285052 0.92475571] t1: [0.15077163 0.63532579 0.27552456 

t0: [0.86240107 0.86504219 0.84789194 0.91799703] t1: [0.1386211  0.60960112 0.25231485 0.86331033] cost: [0.13313655 0.73680109 0.23653168 0.4145158 ]
t0: [0.86237802 0.86493061 0.84783337 0.91791054] t1: [0.13848707 0.60927907 0.25204275 0.86316745] cost: [0.13295197 0.73564    0.23575802 0.4142369 ]
t0: [0.86235507 0.86481913 0.84777493 0.91782408] t1: [0.13835358 0.60895733 0.25177136 0.86302463] cost: [0.13276887 0.73448108 0.23498833 0.41395823]
t0: [0.86233219 0.86470776 0.84771663 0.91773765] t1: [0.13822064 0.60863589 0.25150067 0.86288186] cost: [0.13258727 0.73332432 0.23422259 0.41367977]
t0: [0.86230941 0.86459649 0.84765848 0.91765126] t1: [0.13808823 0.60831475 0.25123068 0.86273916] cost: [0.13240713 0.73216973 0.23346078 0.41340154]
t0: [0.86228671 0.86448532 0.84760046 0.9175649 ] t1: [0.13795637 0.6079939  0.25096138 0.86259651] cost: [0.13222846 0.73101729 0.23270289 0.41312352]
t0: [0.86226409 0.86437425 0.84754257 0.91747857] t1: [0.13782504 0.60767336 0.25069278 

t0: [0.86061813 0.85530117 0.84309108 0.91025832] t1: [0.12841218 0.581495   0.23022155 0.8505398 ] cost: [0.1212238  0.63955258 0.18017588 0.39001535]
t0: [0.86060165 0.85519864 0.84304382 0.91017475] t1: [0.12831973 0.58119924 0.23000648 0.85040203] cost: [0.12113581 0.63857319 0.17969162 0.38975576]
t0: [0.86058524 0.8550962  0.84299668 0.9100912 ] t1: [0.12822767 0.58090376 0.22979196 0.85026432] cost: [0.12104853 0.63759563 0.17920984 0.38949637]
t0: [0.86056889 0.85499385 0.84294964 0.91000769] t1: [0.12813598 0.58060856 0.229578   0.85012666] cost: [0.12096196 0.63661989 0.17873053 0.38923718]
t0: [0.86055259 0.8548916  0.84290271 0.90992421] t1: [0.12804466 0.58031363 0.22936459 0.84998906] cost: [0.12087609 0.63564598 0.17825369 0.3889782 ]
t0: [0.86053635 0.85478945 0.8428559  0.90984076] t1: [0.12795372 0.58001898 0.22915173 0.84985151] cost: [0.12079092 0.63467388 0.17777929 0.38871942]
t0: [0.86052017 0.85468739 0.84280919 0.90975734] t1: [0.12786314 0.57972461 0.22893942 

t0: [0.85926489 0.84578562 0.83898115 0.90229603] t1: [0.12099401 0.55405741 0.21174688 0.83743012] cost: [0.115298   0.55258859 0.14310358 0.36576338]
t0: [0.85925319 0.84569192 0.83894337 0.90221545] t1: [0.12093183 0.55378732 0.21157954 0.8372976 ] cost: [0.11525804 0.55177167 0.14280965 0.36552289]
t0: [0.85924154 0.8455983  0.83890568 0.90213489] t1: [0.1208699  0.55351748 0.21141263 0.83716514] cost: [0.11521841 0.55095629 0.14251723 0.36528259]


In [122]:
cost_function(X,y,10,10)

array([2549.27,  697.76, 1687.21,  326.34])

# Вопрос №1
* подскажите как теперь я могу пользоваться моделью
* что, куда подставлять
* не совсем понимаю, за счет чего модель могла обучиться градиентному спуску
* 

### Метод градиентного спуска, используя библиотеку NumPy

In [126]:
# h(xi) = params[0] + params[1] * xi

EPOCHS = 12
LEARNING_RATE = 0.0001

costs = []
params = []
preds = []
params = np.random.normal(size=(2,))

for _ in range(EPOCHS):
    predictions = params[0] + params[1] * X
    preds.append(predictions)

    cost = np.sum(np.square(predictions - y)) / (2 * len(predictions))
    costs.append(cost)
    
    params[0] -= LEARNING_RATE * np.sum(predictions - y) / len(predictions)
    params[1] -= LEARNING_RATE * np.sum((predictions - y) * X) / len(predictions)    

ValueError: operands could not be broadcast together with shapes (100,4) (100,) 

# Вопрос №2
* насколько я понимаю тут много параметров (отсюда и конфликт)