In [None]:
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.model_selection import KFold, cross_val_score, GridSearchCV
from sklearn.linear_model import SGDClassifier
import copy

In [None]:
# Path to the raw news corpus; a relative path is resolved against the
# notebook's current working directory.
DATA_PATH = 'Downloads/News.json'

data = pd.read_json(DATA_PATH)
# Independent copy that keeps the 'tag' column intact for the tag-dropping
# experiments; DataFrame.copy() is a deep copy by default.
data_for_dropping = data.copy()

In [None]:
def drop_tag(data, tag):
    """Split `data` into features and labels, excluding every row labelled `tag`.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame with a ``'tag'`` column holding the class label.
    tag
        Label value whose rows are removed.

    Returns
    -------
    tuple
        ``(features, labels)``: the remaining rows without the ``'tag'``
        column, and the ``'tag'`` Series for those rows. The original index
        is kept (not reset), so the two stay aligned.
    """
    keep = data['tag'] != tag
    remaining = data.loc[keep]
    return remaining.drop(columns=['tag']), remaining['tag']

In [None]:
def transform_data(data):
    """Convert the ``'text'`` column of `data` into a TF-IDF weighted term matrix.

    A fresh CountVectorizer and TfidfTransformer are fitted on `data` on every
    call, so the vocabulary and IDF weights come from exactly these rows.
    Calling this before cross-validation therefore lets the held-out folds
    influence the features.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame with a ``'text'`` column of documents.

    Returns
    -------
    The document-term matrix produced by ``TfidfTransformer.fit_transform``,
    with one row per row of `data`.
    """
    term_counts = CountVectorizer().fit_transform(data['text'])
    return TfidfTransformer().fit_transform(term_counts)

In [None]:
# Split the full data set into labels and features for the main model.
# Both are derived from the untouched copy so this cell can be re-run:
# dropping 'tag' from `data` itself would fail on the second run because
# the column would already be gone.
y_main = data_for_dropping['tag']
data = data_for_dropping.drop(columns=['tag'])
X_main = transform_data(data)

In [None]:
# Baseline linear classifier on the full TF-IDF matrix.
# penalty='elasticnet' is required for l1_ratio to have any effect:
# SGDClassifier's default penalty is 'l2', which ignores l1_ratio.
# random_state makes SGD's internal shuffling reproducible (same seed as KFold).
sgd = SGDClassifier(alpha=0.001, class_weight='balanced', l1_ratio=0.25,
                    penalty='elasticnet', random_state=42)
sgd.fit(X_main, y_main)

In [None]:
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# For each tag, remove its rows and cross-validate the classifier on the rest.
# Iterate over the tags actually present instead of a hard-coded range.
# NOTE(review): transform_data fits the vocabulary/IDF on all remaining rows
# before CV, so test folds leak into the features; a Pipeline
# (vectorizer -> tfidf -> classifier) inside cross_val_score would avoid this.
drop_results = []
for tag in sorted(data_for_dropping['tag'].unique()):
    X, y = drop_tag(data_for_dropping, tag)
    X = transform_data(X)
    # Pass the splitter itself; cross_val_score calls kf.split internally
    # (a kf.split(...) generator can only be consumed once).
    scores = cross_val_score(estimator=sgd, X=X, y=y, cv=kf, scoring='accuracy')
    drop_results.append({
        'dropped_tag': tag,
        'scores': scores,
        'mean_accuracy': scores.mean(),
    })

pd.DataFrame(drop_results)

In [None]:
# Hyper-parameter search for the main model on the full data set.
parameters = {
    # l1_ratio is only used by SGDClassifier when penalty='elasticnet'
    # (the default is 'l2'), so pin the penalty to make the sweep meaningful.
    'penalty': ('elasticnet',),
    'l1_ratio': (0.05, 0.1, 0.15, 0.2, 0.25, 0.3),
    'alpha': (1e-1, 1e-2, 1e-3, 1e-4),
    'class_weight': (None, 'balanced')
}

gs_cv = GridSearchCV(estimator=sgd, param_grid=parameters, n_jobs=-1)
# Use X_main / y_main explicitly: a bare X, y here would be whatever the
# tag-dropping loop left behind (the data set with its last tag removed).
gs_cv.fit(X_main, y_main)
gs_cv.best_params_, gs_cv.best_score_

In [None]:
# Обучение на дропнутом наборе, проверка на нём же (trained and evaluated on the data set with one tag dropped)
#0 [ 0.72566372  0.73451327  0.67256637  0.7079646   0.70535714] 0.709213021492
#1 [ 0.70175439  0.62280702  0.71929825  0.72807018  0.71681416] 0.697748796771
#2 [ 0.62608696  0.74561404  0.65789474  0.71929825  0.75438596] 0.700655987796
#3 [ 0.62365591  0.68817204  0.64516129  0.72826087  0.61956522] 0.660963066854
#4 [ 0.5862069   0.6637931   0.75        0.71551724  0.68103448] 0.679310344828
#5 [ 0.625       0.74107143  0.72321429  0.71428571  0.75892857] 0.7125
#6 [ 0.74545455  0.79090909  0.74545455  0.79816514  0.80733945] 0.777464553795
#7 [ 0.61290323  0.69892473  0.67391304  0.67391304  0.65217391] 0.662365591398
#8 [ 0.7079646   0.69911504  0.66371681  0.69642857  0.74107143] 0.701659292035

In [None]:
# Обучение на всех, проверка на дропнутом (trained on all tags, evaluated on the data set with one tag dropped)
#0 [ 0.71681416  0.75221239  0.69911504  0.7079646   0.6875    ] 0.712721238938
#1 [ 0.70175439  0.66666667  0.74561404  0.71929825  0.69026549] 0.704719764012
#2 [ 0.68695652  0.73684211  0.6754386   0.73684211  0.75438596] 0.718093058734
#3 [ 0.6344086   0.70967742  0.67741935  0.76086957  0.65217391] 0.686909770921
#4 [ 0.62931034  0.65517241  0.72413793  0.71551724  0.65517241] 0.675862068966
#5 [ 0.625       0.73214286  0.71428571  0.70535714  0.77678571] 0.710714285714
#6 [ 0.79090909  0.8         0.73636364  0.82568807  0.77981651] 0.786555462886
#7 [ 0.61290323  0.65591398  0.68478261  0.65217391  0.67391304] 0.655937353904
#8 [ 0.7079646   0.69911504  0.67256637  0.69642857  0.74107143] 0.70342920354