In [50]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.utils import Bunch

### Выбираем категории и загружаем dataset

In [51]:
categories = [
    'alt.atheism',
    'soc.religion.christian',
    'comp.graphics',
    'sci.med'
]
twenty_train: Bunch = fetch_20newsgroups(
    subset='train',
    categories=categories,
    shuffle=True,
    random_state=42
) # type: ignore

In [52]:
twenty_train.target_names

['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']

### Смотрим на размер выборки

In [53]:
len(twenty_train.data)
len(twenty_train.filenames)

2257

In [54]:
print("\n".join(twenty_train.data[0].split("\n")[:3]))
print(twenty_train.target_names[twenty_train.target[0]])

From: sd345@city.ac.uk (Michael Collier)
Subject: Converting images to HP LaserJet III?
Nntp-Posting-Host: hampton
comp.graphics


In [55]:
twenty_train.target[:10]

array([1, 1, 3, 3, 3, 3, 3, 2, 2, 2])

In [56]:
for t in twenty_train.target[:10]:
    print(twenty_train.target_names[t])

comp.graphics
comp.graphics
soc.religion.christian
soc.religion.christian
soc.religion.christian
soc.religion.christian
soc.religion.christian
sci.med
sci.med
sci.med


### Выполняем предварительную обработку признаков

In [57]:
from sklearn.feature_extraction.text import CountVectorizer
count_vect = CountVectorizer()
X_train_counts = count_vect.fit_transform(twenty_train.data)
X_train_counts.shape

(2257, 35788)

In [58]:
count_vect.vocabulary_.get(u'algorithm')

4690

In [59]:
from sklearn.feature_extraction.text import TfidfTransformer

In [60]:
tfidf_transformer = TfidfTransformer()
X_train_tfidf = tfidf_transformer.fit_transform(X_train_counts)
X_train_tfidf.shape

(2257, 35788)

### Разделяем выборку на train и test

In [61]:
X = X_train_tfidf
y = twenty_train.target

In [62]:
from sklearn.model_selection import train_test_split

In [63]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

### Обучаем модель

In [64]:
from sklearn.tree import DecisionTreeClassifier

In [65]:
from sklearn.ensemble import RandomForestClassifier

In [66]:
clf = DecisionTreeClassifier(max_depth=10).fit(X_train, y_train)

In [67]:
clf_rf = RandomForestClassifier(max_depth=10).fit(X_train, y_train)

### Оцениваем качество предсказания модели

In [68]:
predicted = clf.predict_proba(X_test)

In [69]:
from sklearn.metrics import classification_report

In [70]:
print(classification_report(y_pred=predicted.argmax(axis=1), y_true=y_test)) # type: ignore

              precision    recall  f1-score   support

           0       0.82      0.67      0.74       143
           1       0.46      0.95      0.62       184
           2       0.89      0.49      0.63       208
           3       0.93      0.61      0.74       210

    accuracy                           0.67       745
   macro avg       0.78      0.68      0.68       745
weighted avg       0.78      0.67      0.68       745



In [71]:
predicted_rf = clf_rf.predict_proba(X_test)
print(classification_report(y_pred=predicted_rf.argmax(axis=1), y_true=y_test)) # type: ignore

              precision    recall  f1-score   support

           0       1.00      0.72      0.84       143
           1       0.68      0.99      0.80       184
           2       0.96      0.75      0.84       208
           3       0.86      0.87      0.86       210

    accuracy                           0.84       745
   macro avg       0.88      0.83      0.84       745
weighted avg       0.87      0.84      0.84       745



### Обработываем новые данные

In [72]:
docs_new = ['God is love', 'OpenGL on the GPU is fast']
X_new_counts = count_vect.transform(docs_new)
X_new_tfidf = tfidf_transformer.transform(X_new_counts)

predicted = clf.predict(X_new_tfidf)

for doc, category in zip(docs_new, predicted):
    print('%r => %s' % (doc, twenty_train.target_names[category]))

'God is love' => soc.religion.christian
'OpenGL on the GPU is fast' => comp.graphics


In [73]:
predicted = clf_rf.predict(X_new_tfidf)

for doc, category in zip(docs_new, predicted):
    print('%r => %s' % (doc, twenty_train.target_names[category]))

'God is love' => comp.graphics
'OpenGL on the GPU is fast' => comp.graphics
