In [1]:
import pandas as pd
import numpy as np

from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, plot_roc_curve
from sklearn.metrics import make_scorer, precision_recall_fscore_support

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier

from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.model_selection import GridSearchCV

import seaborn as sns
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

In [2]:
path = 'https://raw.githubusercontent.com/s-a-nersisyan/HSE_bioinformatics_2021/master/seminar15/BRCA_pam50.tsv'
df = pd.read_csv(path, sep="\t", index_col=0)
X = df.iloc[:, :-1].to_numpy()
y = df["Subtype"].to_numpy()

In [3]:
def make_model(X, y, params=None):
    model = KNeighborsClassifier()

    params = {
        "n_neighbors": [*range(1, 20)],
        "weights": ["uniform", "distance"],
        "p": [1, 2]
    } if not params else params

    cv = GridSearchCV(
        model, params,
        scoring=make_scorer(accuracy_score),
        cv=RepeatedStratifiedKFold(n_repeats=10, n_splits=20)
    )
    cv.fit(X, y)
    
    print('best params:', cv.best_params_)
    print('best score:', cv.best_score_)
    print('\n')

    model = KNeighborsClassifier().set_params(**cv.best_params_)
    
    X_train, X_test, y_train, y_test = train_test_split(
            X, y, stratify=y, test_size=0.35, random_state=17
        )
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    
    print(classification_report(y_test, y_pred))
print(X.shape)
make_model(X, y)

(915, 50)
best params: {'n_neighbors': 5, 'p': 1, 'weights': 'distance'}
best score: 0.9067874396135267


                 precision    recall  f1-score   support

  HER2-enriched       0.94      0.77      0.85        22
        Healthy       0.95      1.00      0.97        35
      Luminal A       0.85      0.99      0.91       144
      Luminal B       0.88      0.71      0.79        65
    Normal-like       1.00      0.12      0.22         8
Triple-negative       1.00      0.98      0.99        47

       accuracy                           0.89       321
      macro avg       0.94      0.76      0.79       321
   weighted avg       0.90      0.89      0.88       321



In [4]:
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)
print(X_pca.shape)
make_model(X_pca, y)

(915, 2)
best params: {'n_neighbors': 15, 'p': 2, 'weights': 'uniform'}
best score: 0.8623091787439614


                 precision    recall  f1-score   support

  HER2-enriched       0.82      0.64      0.72        22
        Healthy       0.88      0.86      0.87        35
      Luminal A       0.82      0.94      0.88       144
      Luminal B       0.75      0.69      0.72        65
    Normal-like       0.00      0.00      0.00         8
Triple-negative       1.00      0.98      0.99        47

       accuracy                           0.84       321
      macro avg       0.71      0.68      0.70       321
   weighted avg       0.82      0.84      0.83       321



In [5]:
tsne = TSNE(n_components=2)
X_tsne = tsne.fit_transform(X)
print(X_tsne.shape)
make_model(X_tsne, y)

(915, 2)
best params: {'n_neighbors': 10, 'p': 1, 'weights': 'distance'}
best score: 0.8912391304347825


                 precision    recall  f1-score   support

  HER2-enriched       1.00      0.82      0.90        22
        Healthy       0.95      1.00      0.97        35
      Luminal A       0.86      0.94      0.90       144
      Luminal B       0.83      0.75      0.79        65
    Normal-like       0.50      0.12      0.20         8
Triple-negative       1.00      1.00      1.00        47

       accuracy                           0.89       321
      macro avg       0.86      0.77      0.79       321
   weighted avg       0.88      0.89      0.88       321

