# Feature performance comparison

The purpose of this notebook is to compare the classification performance of the individual features, and of their combination, for *E. coli*.

# Imports

In [1]:
# from sklearnex import patch_sklearn
# patch_sklearn()
import os
import sys
from IPython.display import display

sys.path.append('../src')
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.svm import SVC, LinearSVC
from sklearn.feature_selection import SelectKBest
from sklearn.decomposition import PCA, KernelPCA
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn.feature_selection import SelectKBest, RFE, VarianceThreshold
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import f1_score, classification_report, confusion_matrix
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, cross_val_score, cross_val_predict, GridSearchCV
from sklearn.linear_model import SGDClassifier
from sklearn.base import clone
from scipy.stats import shapiro
import matplotlib.pyplot as plt
from sklearn.base import BaseEstimator,TransformerMixin

from yellowbrick.features import ParallelCoordinates
from yellowbrick.features import Rank1D, Rank2D

import pandas as pd
import numpy as np
import seaborn as sns

from dataset.transporter_dataset import create_dataset
from dataset.cluster_fasta import cd_hit
from features.labels import fasta_to_labels
from features.compositions import calculate_composition_feature
from features.pssm import calculate_pssm_feature
from features.coexp import calculate_coexp_feature
from models.eval import nested_crossval
from visualization.feature_plots import create_plot

# Globals

In [2]:
# Organism short name (used in file names) and the taxonomy ID used to filter the dataset.
ORGANISM = "ecoli"
TAX_ID = 83333

# Single organism-specific log file shared by the dataset and clustering steps.
LOG_FILE = f"../logs/{ORGANISM}_amino_sugar.log"

# Thread count handed to the external tools (cd-hit, psiblast).
N_THREADS = 16

# Sequence identity threshold passed to cd-hit; also part of the output file names.
IDENTITY_THRESHOLD = 70

# Protein identifiers excluded from the dataset (passed as `outliers` to create_dataset).
OUTLIERS = [
    "P0AAG8", "P04983", "P10346", "P28635", "P14175", "P69797", "P76773",
    "P37388", "P30750", "Q47706", "P10907", "P37774", "P0AAF3", "P64550",
    "P02943", "P75733", "P0AAF6", "P68187", "P69856", "P07109", "Q6BEX0",
]



# Dataset

In [3]:
# Reset the log of a previous run: if the file exists it is truncated to zero
# length (it is emptied, not removed), so this run starts with a clean log.
if os.path.exists(LOG_FILE):
    with open(LOG_FILE, "w"):
        pass

# Build the amino-acid vs. sugar transporter dataset for TAX_ID from the
# Swiss-Prot table. The TSV and FASTA files written here are read by the
# clustering and annotation cells below; the call is the last expression of the
# cell, so its return value is shown as the cell output.
create_dataset(
    keywords_substrate_filter=["Amino-acid transport", "Sugar transport"],
    keywords_component_filter=["Membrane"],
    keywords_transport_filter=["Transport"],
    input_file="../data/raw/swissprot/uniprot-reviewed_yes.tab.gz",
    multi_substrate="remove",
    outliers=OUTLIERS,
    verbose=True,
    tax_ids_filter=[TAX_ID],
    output_tsv=f"../data/datasets/{ORGANISM}_amino_sugar.tsv",
    output_fasta=f"../data/datasets/{ORGANISM}_amino_sugar.fasta",
    output_log=LOG_FILE,
)

Unnamed: 0_level_0,keywords_transport,keywords_location,keywords_transport_related,gene_names,protein_names,tcdb_id,organism_id,sequence
Uniprot,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
P69801,Sugar transport,Cell inner membrane;Cell membrane;Membrane;Tra...,Transport,manY pel ptsP b1818 JW1807,PTS system mannose-specific EIIC component (EI...,4.A.6.1.1,83333,MEITTLQIVLVFIVACIAGMGSILDEFQFHRPLIACTLVGIVLGDM...
P36672,Sugar transport,Cell inner membrane;Cell membrane;Membrane;Tra...,Transport,treB b4240 JW4199,PTS system trehalose-specific EIIBC component ...,4.A.1.2.4,83333,MMSKINQTDIDRLIELVGGRGNIATVSHCITRLRFVLNQPANARPK...
P56580,Sugar transport,Cell inner membrane;Cell membrane;Membrane;Tra...,Transport,srlE gutA gutE b2703 JW5430,PTS system glucitol/sorbitol-specific EIIB com...,4.A.4.1.1,83333,MTHIRIEKGTGGWGGPLELKATPGKKIVYITAGTRPAIVDKLAQLT...
P0AA47,Amino-acid transport,Cell inner membrane;Cell membrane;Membrane;Tra...,Transport,plaP yeeF b2014 JW5330,Low-affinity putrescine importer PlaP,2.A.3.1.14,83333,MSHNVTPNTSRVELRKTLTLVPVVMMGLAYMQPMTLFDTFGIVSGL...
P08722,Sugar transport,Cell inner membrane;Cell membrane;Membrane;Tra...,Transport,bglF bglC bglS b3722 JW3700,PTS system beta-glucoside-specific EIIBCA comp...,4.A.1.2.2,83333,MTELARKIVAGVGGADNIVSLMHCATRLRFKLKDESKAQAEVLKKT...
...,...,...,...,...,...,...,...,...
P19642,Sugar transport,Cell inner membrane;Cell membrane;Membrane;Tra...,Transport,malX b1621 JW1613,PTS system maltose-specific EIICB component [I...,4.A.1.1.3,83333,MTAKTAPKVTLWEFFQQLGKTFMLPVALLSFCGIMLGIGSSLSSHD...
P0AAD4,Amino-acid transport,Cell inner membrane;Cell membrane;Membrane;Tra...,Transport,tyrP b1907 JW1895,Tyrosine-specific transport protein (Tyrosine ...,2.A.42.1.1,83333,MKNRTLGSVFIVAGTTIGAGMLAMPLAAAGVGFSVTLILLIGLWAL...
P23173,Amino-acid transport,Cell inner membrane;Cell membrane;Membrane;Tra...,Transport,tnaB trpP b3709 JW5619/JW5622,Low affinity tryptophan permease,2.A.42.1.3,83333,MTDQAEKKHSAFWGVMVIAGTVIGGGMFALPVDLAGAWFFWGAFIL...
P33361,Amino-acid transport,Cell inner membrane;Cell membrane;Membrane;Tra...,Transport,yehY b2130 JW2118,Glycine betaine uptake system permease protein...,3.A.1.12.15,83333,MTYFRINPVLALLLLLTAIAAALPFISYAPNRLVSGEGRHLWQLWP...


## Clustering

In [4]:
# Cluster the dataset sequences with cd-hit at IDENTITY_THRESHOLD. The clustered
# FASTA file written here is the input of every feature-generation cell below.
cd_hit(
    executable_location="cd-hit",
    input_fasta=f"../data/datasets/{ORGANISM}_amino_sugar.fasta",
    output_fasta=f"../data/datasets/{ORGANISM}_amino_sugar_cluster{IDENTITY_THRESHOLD}.fasta",
    log_file=LOG_FILE,
    identity_threshold=IDENTITY_THRESHOLD,
    n_threads=N_THREADS,
    # NOTE(review): presumably the cd-hit memory limit in MB — confirm against cd_hit().
    memory=4096,
    verbose=True,
)

## Annotations

In [5]:
# Reload the annotation table written by create_dataset (unclustered dataset);
# the first column of the TSV becomes the index.
df_annotations = pd.read_table(f"../data/datasets/{ORGANISM}_amino_sugar.tsv", index_col=0)
df_annotations.head()

Unnamed: 0_level_0,keywords_transport,keywords_location,keywords_transport_related,gene_names,protein_names,tcdb_id,organism_id,sequence
Uniprot,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
P69801,Sugar transport,Cell inner membrane;Cell membrane;Membrane;Tra...,Transport,manY pel ptsP b1818 JW1807,PTS system mannose-specific EIIC component (EI...,4.A.6.1.1,83333,MEITTLQIVLVFIVACIAGMGSILDEFQFHRPLIACTLVGIVLGDM...
P36672,Sugar transport,Cell inner membrane;Cell membrane;Membrane;Tra...,Transport,treB b4240 JW4199,PTS system trehalose-specific EIIBC component ...,4.A.1.2.4,83333,MMSKINQTDIDRLIELVGGRGNIATVSHCITRLRFVLNQPANARPK...
P56580,Sugar transport,Cell inner membrane;Cell membrane;Membrane;Tra...,Transport,srlE gutA gutE b2703 JW5430,PTS system glucitol/sorbitol-specific EIIB com...,4.A.4.1.1,83333,MTHIRIEKGTGGWGGPLELKATPGKKIVYITAGTRPAIVDKLAQLT...
P0AA47,Amino-acid transport,Cell inner membrane;Cell membrane;Membrane;Tra...,Transport,plaP yeeF b2014 JW5330,Low-affinity putrescine importer PlaP,2.A.3.1.14,83333,MSHNVTPNTSRVELRKTLTLVPVVMMGLAYMQPMTLFDTFGIVSGL...
P08722,Sugar transport,Cell inner membrane;Cell membrane;Membrane;Tra...,Transport,bglF bglC bglS b3722 JW3700,PTS system beta-glucoside-specific EIIBCA comp...,4.A.1.2.2,83333,MTELARKIVAGVGGADNIVSLMHCATRLRFKLKDESKAQAEVLKKT...


# Feature generation

## Labels

In [6]:
# Derive the label table from the clustered FASTA file and write it as TSV ...
fasta_to_labels(
    input_fasta=f"../data/datasets/{ORGANISM}_amino_sugar_cluster{IDENTITY_THRESHOLD}.fasta",
    output_tsv=f"../data/features/{ORGANISM}_amino_sugar_cluster{IDENTITY_THRESHOLD}_labels.tsv",
)
# ... then load it back, using the first column as the index.
df_labels = pd.read_table(
    f"../data/features/{ORGANISM}_amino_sugar_cluster{IDENTITY_THRESHOLD}_labels.tsv",
    index_col=0,
)
# Class balance of the clustered dataset.
df_labels.labels.value_counts()

Amino-acid transport    49
Sugar transport         46
Name: labels, dtype: int64

## AAC, PAAC

In [7]:
# Compute one composition feature table per type ("aac", "paac") from the
# clustered FASTA file; each table is written to its own TSV, which is read
# back in the "Reading dataframes" cell.
for composition_type in ["aac", "paac"]:
    calculate_composition_feature(
        input_fasta=f"../data/datasets/{ORGANISM}_amino_sugar_cluster{IDENTITY_THRESHOLD}.fasta",
        output_tsv=f"../data/features/{ORGANISM}_amino_sugar_cluster{IDENTITY_THRESHOLD}_{composition_type}.tsv",
        feature_type=composition_type,
    )

## PSSM

In [8]:
# Compute the PSSM feature for every combination of UniRef database (50, 90)
# and number of psiblast iterations (1, 3): four TSV files in total.
for uniref_cluster_threshold in [50, 90]:
    for psiblast_iterations in [1, 3]:
        calculate_pssm_feature(
            input_fasta=f"../data/datasets/{ORGANISM}_amino_sugar_cluster{IDENTITY_THRESHOLD}.fasta",
            output_tsv="../data/features/{}_amino_sugar_cluster{}_pssm_ur{}_{}it.tsv".format(
                ORGANISM, IDENTITY_THRESHOLD, uniref_cluster_threshold, psiblast_iterations
            ),
            # The intermediate folder depends only on the parameter combination,
            # not on ORGANISM or IDENTITY_THRESHOLD.
            # NOTE(review): presumably a cache shared between notebooks — confirm.
            tmp_folder="../data/intermediate/blast/pssm_uniref{}_{}it".format(
                uniref_cluster_threshold, psiblast_iterations
            ),
            blast_db="../data/raw/uniref/uniref{}/uniref{}.fasta".format(
                uniref_cluster_threshold, uniref_cluster_threshold
            ),
            iterations=psiblast_iterations,
            psiblast_executable="psiblast",
            psiblast_threads=N_THREADS,
            verbose=False,
        )

## COEXP

In [9]:
# TODO: optimize parameters first

## Reading dataframes

In [10]:

# Load every feature table written by the cells above. All files share the same
# path pattern and use their first column as the index, so they are read in one
# pass and unpacked in the same order as the file-name suffixes below.
(
    df_aac,
    df_paac,
    df_pssm_50_1it,
    df_pssm_50_3it,
    df_pssm_90_1it,
    df_pssm_90_3it,
) = (
    pd.read_table(
        f"../data/features/{ORGANISM}_amino_sugar_cluster{IDENTITY_THRESHOLD}_{feature_file}.tsv",
        index_col=0,
    )
    for feature_file in (
        "aac",
        "paac",
        "pssm_ur50_1it",
        "pssm_ur50_3it",
        "pssm_ur90_1it",
        "pssm_ur90_3it",
    )
)

## Combining dataframes

In [11]:
# Join all feature tables column-wise. The four PSSM tables have identical
# column names, so each column is tagged as PSSM_<name>_<uniref>_<iterations>
# to keep them apart (PSSMSelector relies on this naming scheme).
df_all = pd.concat(
    [
        df_aac,
        df_paac,
        df_pssm_50_1it.add_prefix("PSSM_").add_suffix("_50_1"),
        df_pssm_50_3it.add_prefix("PSSM_").add_suffix("_50_3"),
        df_pssm_90_1it.add_prefix("PSSM_").add_suffix("_90_1"),
        df_pssm_90_3it.add_prefix("PSSM_").add_suffix("_90_3"),
    ],
    axis=1,
)

df_all

Unnamed: 0_level_0,A,C,D,E,F,G,H,I,K,L,...,PSSM_VL_90_3,PSSM_VK_90_3,PSSM_VM_90_3,PSSM_VF_90_3,PSSM_VP_90_3,PSSM_VS_90_3,PSSM_VT_90_3,PSSM_VW_90_3,PSSM_VY_90_3,PSSM_VV_90_3
Uniprot,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
P69801,0.150376,0.007519,0.030075,0.022556,0.037594,0.093985,0.015038,0.120301,0.011278,0.101504,...,0.496933,0.487730,0.503067,0.558282,0.450920,0.472393,0.475460,0.490798,0.592025,0.481595
P36672,0.093023,0.012685,0.023256,0.023256,0.044397,0.103594,0.016913,0.105708,0.029598,0.120507,...,0.436330,0.421348,0.436330,0.556180,0.370787,0.464419,0.423221,0.503745,0.644195,0.436330
P56580,0.097179,0.012539,0.040752,0.034483,0.034483,0.122257,0.015674,0.100313,0.040752,0.100313,...,0.410876,0.389728,0.407855,0.504532,0.371601,0.413897,0.404834,0.444109,0.555891,0.398792
P0AA47,0.106195,0.006637,0.028761,0.026549,0.077434,0.070796,0.013274,0.075221,0.030973,0.110619,...,0.464052,0.416122,0.472767,0.640523,0.350763,0.420479,0.424837,0.570806,0.838780,0.461874
P08722,0.100800,0.012800,0.038400,0.033600,0.062400,0.100800,0.016000,0.092800,0.035200,0.108800,...,0.443131,0.437223,0.454948,0.514032,0.410635,0.438700,0.440177,0.472674,0.584934,0.447563
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
P19642,0.090566,0.013208,0.026415,0.024528,0.066038,0.105660,0.024528,0.084906,0.032075,0.120755,...,0.480253,0.461295,0.481833,0.532385,0.454976,0.464455,0.464455,0.519747,0.593997,0.473934
P0AAD4,0.116625,0.007444,0.019851,0.012407,0.064516,0.106700,0.019851,0.062035,0.017370,0.171216,...,0.525449,0.494012,0.519461,0.622754,0.479042,0.510479,0.517964,0.574850,0.718563,0.508982
P23173,0.084337,0.016867,0.024096,0.012048,0.093976,0.096386,0.014458,0.113253,0.040964,0.120482,...,0.546237,0.464516,0.548387,0.608602,0.455914,0.516129,0.511828,0.531183,0.752688,0.529032
P33361,0.135065,0.012987,0.025974,0.007792,0.044156,0.083117,0.012987,0.067532,0.015584,0.197403,...,0.547009,0.527066,0.539886,0.574074,0.521368,0.531339,0.532764,0.571225,0.589744,0.544160


## Custom Transformer to try all parameters

New version: Ignores all additional features that don't start with "PSSM"

In [12]:
class PSSMSelector(BaseEstimator, TransformerMixin):
    """Column selector that keeps one PSSM parameter combination.

    PSSM columns are expected to be named ``PSSM_<name>_<uniref>_<iterations>``
    (for example ``PSSM_AV_50_3``), which is the scheme used when the feature
    tables are combined in this notebook. Columns whose name does not start
    with ``"PSSM"`` are always kept, so additional feature types pass through
    unchanged.

    Exposing the two PSSM parameters as estimator parameters lets a grid search
    treat them as hyperparameters.

    Parameters
    ----------
    feature_names : array-like of str
        Name of every column of ``X``, in column order.
    uniref_threshold : {50, 90, "all"}, default="all"
        UniRef database to keep PSSM columns for; ``"all"`` keeps both.
    iterations : {1, 3, "all"}, default="all"
        Number of psiblast iterations to keep PSSM columns for; ``"all"``
        keeps both.

    Attributes
    ----------
    mask_ : ndarray of bool, shape (n_features,)
        Set by ``fit``; True for every column that ``transform`` keeps.
        Also available as ``mask`` for backward compatibility.
    """

    def __init__(self, feature_names, uniref_threshold="all", iterations="all"):
        # Parameters are stored untouched so that sklearn's clone()/get_params() work.
        self.feature_names = feature_names
        self.uniref_threshold = uniref_threshold
        self.iterations = iterations

    def fit(self, X, y=None):
        """Compute the column mask from the feature names; ``X`` and ``y`` are not read.

        Raises
        ------
        ValueError
            If ``uniref_threshold`` or ``iterations`` has an unsupported value.
        """
        # np.char functions need a string-dtype array; coercing here also
        # accepts lists and object-dtype containers such as a pandas Index.
        names = np.asarray(self.feature_names, dtype=str)
        keep_all = np.ones(names.shape, dtype=bool)

        if self.uniref_threshold in {50, 90}:
            # Match the "_<uniref>_" token instead of a bare substring so that
            # digits elsewhere in a name can never be mistaken for the threshold.
            has_uniref = np.char.find(names, f"_{self.uniref_threshold}_") >= 0
        elif self.uniref_threshold == "all":
            has_uniref = keep_all
        else:
            raise ValueError(f"Incorrect uniref threshold {self.uniref_threshold}")

        if self.iterations in {1, 3}:
            # The iteration count is the last token of the column name.
            has_iterations = np.char.endswith(names, f"_{self.iterations}")
        elif self.iterations == "all":
            has_iterations = keep_all
        else:
            raise ValueError(f"Incorrect iteration count: {self.iterations}")

        # Non-PSSM columns are kept regardless of the two PSSM parameters.
        not_pssm_feature = ~np.char.startswith(names, "PSSM")
        self.mask_ = (has_uniref & has_iterations) | not_pssm_feature
        # Alias under the original attribute name for existing callers.
        self.mask = self.mask_
        return self

    def transform(self, X, y=None):
        """Return ``X`` as an array restricted to the columns selected in ``fit``."""
        X = np.array(X)
        return X[:, self.mask_]


# Functions

In [13]:
def get_feature_stats(df_features, df_labels_, labels=["Amino-acid transport", "Sugar transport"]):
    """Summarise every feature column and rank the features by class separation.

    Parameters
    ----------
    df_features : DataFrame with one row per sample and one column per feature.
    df_labels_ : DataFrame with a ``labels`` column, indexed like ``df_features``.
    labels : the two class names; ``labels[0]`` is class 0 and ``labels[1]`` class 1.

    Returns
    -------
    DataFrame with one row per feature and the columns ``corr`` (correlation
    with the 0/1 class indicator), ``mean``, ``std``, ``sum`` (row sum of the
    previous three), ``corr_abs``, ``mean0``/``mean1``, ``median0``/``median1``
    (per-class statistics) and ``mediandiff``, sorted by ``mediandiff`` in
    descending order.
    """
    # Class indicator: 1.0 for labels[1], 0.0 for everything else.
    class_indicator = (df_labels_.labels == labels[1]).astype(float)

    stats = pd.concat(
        {
            "corr": df_features.corrwith(class_indicator),
            "mean": df_features.mean(),
            "std": df_features.std(),
        },
        axis=1,
    )
    stats["sum"] = stats.sum(axis=1)
    stats["corr_abs"] = stats["corr"].abs()

    # Feature rows of each class, selected through the label table's index.
    per_class_rows = [
        df_features.loc[df_labels_.loc[df_labels_.labels == labels[class_id]].index]
        for class_id in (0, 1)
    ]
    for class_id, class_rows in enumerate(per_class_rows):
        stats[f"mean{class_id}"] = class_rows.mean()
    for class_id, class_rows in enumerate(per_class_rows):
        stats[f"median{class_id}"] = class_rows.median()

    stats["mediandiff"] = (stats["median0"] - stats["median1"]).abs()
    return stats.sort_values("mediandiff", ascending=False)

In [14]:
def get_independent_test_set(
    df_features,
    df_labels_,
    labels=["Amino-acid transport", "Sugar transport"],
    test_size=0.2,
    random_state=42,
):
    """Split features and labels into a stratified training and test part.

    Parameters
    ----------
    df_features : DataFrame with one row per sample and one column per feature.
    df_labels_ : DataFrame with a ``labels`` column, indexed by the same sample
        identifiers as ``df_features`` (row order may differ).
    labels : the two class names; samples labelled ``labels[1]`` become class 1,
        all others class 0.
    test_size : passed to ``train_test_split``.
    random_state : seed passed to ``train_test_split``; the default reproduces
        the previously hard-coded value.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test, sample_names_train,
        sample_names_test, feature_names)``.

    Raises
    ------
    KeyError
        If a sample of ``df_features`` has no entry in ``df_labels_``.
    """
    # Look the labels up by the feature table's index, so that X[i] and y[i]
    # always describe the same sample even if the two tables are ordered
    # differently or the label table contains additional samples.
    labels_aligned = df_labels_.labels.loc[df_features.index]

    X = df_features.to_numpy()
    y = np.where(labels_aligned == labels[1], 1, 0)
    feature_names = df_features.columns.to_numpy(dtype=str)
    sample_names = df_features.index.to_numpy(dtype=str)
    (
        X_train,
        X_test,
        y_train,
        y_test,
        sample_names_train,
        sample_names_test,
    ) = train_test_split(
        X,
        y,
        sample_names,
        stratify=y,
        random_state=random_state,
        shuffle=True,
        test_size=test_size,
    )
    return (
        X_train,
        X_test,
        y_train,
        y_test,
        sample_names_train,
        sample_names_test,
        feature_names,
    )


In [15]:
def print_validation_results(y_true_, y_pred_, labels=["Amino", "Sugar"]):
    """Build a classification report and a confusion matrix as DataFrames.

    Despite its name the function prints nothing; the caller displays the
    returned frames.

    Parameters
    ----------
    y_true_ : true class ids (0 or 1).
    y_pred_ : predicted class ids (0 or 1).
    labels : display names for class 0 and class 1, in that order.

    Returns
    -------
    report_df : DataFrame
        Metrics from ``classification_report`` with one column per class plus
        ``"Macro"`` and ``"Weighted"`` averages.
    confusion_matrix_df : DataFrame
        2x2 confusion matrix with ``labels`` as row and column names.
    """
    # Pin the class ids explicitly: without this both sklearn functions infer
    # the classes from the data, so a class missing from y_true_ and y_pred_
    # would remove its report entry and shrink the confusion matrix.
    class_ids = [0, 1]

    report_dict = classification_report(
        y_true=y_true_, y_pred=y_pred_, labels=class_ids, output_dict=True
    )
    report_dict = {
        labels[0]: report_dict["0"],
        labels[1]: report_dict["1"],
        "Macro": report_dict["macro avg"],
        "Weighted": report_dict["weighted avg"],
    }
    report_df = pd.DataFrame.from_dict(report_dict)
    confusion_matrix_df = pd.DataFrame(
        confusion_matrix(y_true_, y_pred_, labels=class_ids),
        columns=labels,
        index=labels,
    )
    return report_df, confusion_matrix_df

# Combined Features

### Stats, Plots

Only three of the top 30 features come from PAAC; all others come from PSSM, mostly with UniRef50 and 3 iterations. This is also the best PSSM parameter combination that was found in the PSSM notebook.

In [16]:
# Per-feature statistics over the combined feature table.
df_stats = get_feature_stats(df_all, df_labels)
# Show the 30 features with the largest absolute correlation to the class indicator.
df_stats.sort_values("corr_abs", ascending=False).head(30)

Unnamed: 0,corr,mean,std,sum,corr_abs,mean0,mean1,median0,median1,mediandiff
PSSM_NF_50_3,-0.529604,0.539118,0.077356,0.08687,0.529604,0.578603,0.497059,0.568075,0.498067,0.070008
PSSM_NV_50_3,-0.521937,0.54751,0.073979,0.099553,0.521937,0.584725,0.507869,0.575824,0.524589,0.051235
PSSM_NC_50_3,-0.521722,0.522518,0.075725,0.076522,0.521722,0.560595,0.481958,0.548263,0.490975,0.057288
PSSM_NP_50_3,-0.521382,0.568139,0.071179,0.117937,0.521382,0.603907,0.530039,0.60452,0.538767,0.065752
PSSM_NL_50_3,-0.518716,0.538332,0.078501,0.098117,0.518716,0.577577,0.496527,0.567347,0.50934,0.058007
PSSM_IL_50_1,-0.516235,0.913574,0.088909,0.486248,0.516235,0.95781,0.866453,1.0,0.863242,0.136758
PSSM_FR_50_3,-0.514193,0.595827,0.065741,0.147375,0.514193,0.628406,0.561122,0.629344,0.566384,0.06296
PSSM_NY_50_3,-0.513372,0.550437,0.070385,0.107449,0.513372,0.585262,0.51334,0.583673,0.517405,0.066269
PSSM_NI_50_3,-0.512207,0.541947,0.074344,0.104084,0.512207,0.578648,0.502853,0.571429,0.516219,0.05521
PSSM_SS_50_1,-0.506774,0.679742,0.068003,0.240971,0.506774,0.712956,0.644361,0.713636,0.637024,0.076613


### Independent test set

In [17]:
# Hold out 20% of the samples (stratified by class) as an independent test set.
# Model selection and tuning below only use the training part.
(
    X_train,
    X_test,
    y_train,
    y_test,
    sample_names_train,
    sample_names_test,
    feature_names,
) = get_independent_test_set(df_all, df_labels, test_size=0.2)

### Model selection

Linear SVC could be a good choice to avoid overfitting.

In [18]:
# Baseline comparison: every candidate classifier with default hyperparameters
# (with and without class weighting where supported), evaluated by
# cross-validated macro F1 on the training set.
# max_iter is given as an int: recent scikit-learn versions validate parameter
# types and reject the float 1e6.
for estimator in [
    LinearSVC(max_iter=1_000_000, class_weight="balanced", random_state=0),
    SVC(class_weight="balanced"),
    RandomForestClassifier(class_weight="balanced", random_state=0),
    LinearSVC(max_iter=1_000_000, random_state=0),
    SVC(),
    RandomForestClassifier(random_state=0),
    GaussianNB(),
    KNeighborsClassifier(),
    SGDClassifier(random_state=0),
]:
    # Scaling is part of the pipeline, so it is fit on each training fold only.
    pipe = make_pipeline(StandardScaler(), estimator)
    scores = cross_val_score(pipe, X_train, y_train, scoring="f1_macro")
    print("### ", str(estimator))
    print(f"CV folds: {scores.round(3)}")
    print(f"Mean: {scores.mean().round(3)}")
    print(f"Std: {scores.std().round(3)}")


###  LinearSVC(class_weight='balanced', max_iter=1000000.0, random_state=0)
CV folds: [0.875 0.732 0.932 0.933 0.7  ]
Mean: 0.835
Std: 0.1
###  SVC(class_weight='balanced')
CV folds: [0.686 0.732 0.796 0.932 0.796]
Mean: 0.789
Std: 0.083
###  RandomForestClassifier(class_weight='balanced', random_state=0)
CV folds: [0.686 0.732 0.732 0.932 0.732]
Mean: 0.763
Std: 0.086
###  LinearSVC(max_iter=1000000.0, random_state=0)
CV folds: [0.875 0.732 0.932 0.933 0.7  ]
Mean: 0.835
Std: 0.1
###  SVC()
CV folds: [0.686 0.732 0.796 0.932 0.796]
Mean: 0.789
Std: 0.083
###  RandomForestClassifier(random_state=0)
CV folds: [0.686 0.732 0.796 0.932 0.722]
Mean: 0.774
Std: 0.087
###  GaussianNB()
CV folds: [0.435 0.661 0.464 0.732 0.533]
Mean: 0.565
Std: 0.114
###  KNeighborsClassifier()
CV folds: [0.619 0.796 0.667 0.8   0.796]
Mean: 0.736
Std: 0.077
###  SGDClassifier(random_state=0)
CV folds: [0.937 0.866 0.785 0.933 0.661]
Mean: 0.836
Std: 0.104


### Parameter tuning

In [19]:
# Grid search for the linear SVM. The PSSM parameter combination is tuned
# together with the classifier's hyperparameters through PSSMSelector.
# max_iter is given as an int: recent scikit-learn versions validate parameter
# types and reject the float 1e6.
gsearch = GridSearchCV(
    estimator=make_pipeline(
        PSSMSelector(feature_names=feature_names),
        StandardScaler(),
        LinearSVC(max_iter=1_000_000, random_state=0),
    ),
    param_grid={
        "pssmselector__uniref_threshold": [50, 90, "all"],
        "pssmselector__iterations": [1, 3, "all"],
        "linearsvc__class_weight": ["balanced", None],
        "linearsvc__C": [1, 10, 100],
        "linearsvc__dual": [True, False],
    },
    cv=5,
    scoring="f1_macro",
    n_jobs=-1,
    return_train_score=True,
)
gsearch.fit(X_train, y_train)
print(gsearch.best_params_)
print(gsearch.best_score_)
# Keep the winning pipeline; `gsearch` is reused by the following cells.
best_estimator_lsvc = gsearch.best_estimator_


{'linearsvc__C': 10, 'linearsvc__class_weight': 'balanced', 'linearsvc__dual': False, 'pssmselector__iterations': 1, 'pssmselector__uniref_threshold': 'all'}
0.8346396250808017


In [20]:
# Grid search for the kernel SVM (default kernel), again tuning the PSSM
# parameter combination together with the classifier's hyperparameters.
# max_iter is given as an int: recent scikit-learn versions validate parameter
# types and reject the float 1e6.
gsearch = GridSearchCV(
    estimator=make_pipeline(
        PSSMSelector(feature_names=feature_names),
        StandardScaler(),
        SVC(max_iter=1_000_000),
    ),
    param_grid={
        "pssmselector__uniref_threshold": [50, 90, "all"],
        "pssmselector__iterations": [1, 3, "all"],
        "svc__class_weight": ["balanced", None],
        "svc__C": [1, 10, 100],
        "svc__gamma": ["scale", 0.1, 0.01, 0.001],
    },
    cv=5,
    scoring="f1_macro",
    n_jobs=-1,
    return_train_score=True,
)
gsearch.fit(X_train, y_train)
print(gsearch.best_params_)
print(gsearch.best_score_)
# Keep the winning pipeline; `gsearch` is reused by the following cells.
best_estimator_svc = gsearch.best_estimator_


{'pssmselector__iterations': 'all', 'pssmselector__uniref_threshold': 90, 'svc__C': 10, 'svc__class_weight': 'balanced', 'svc__gamma': 'scale'}
0.8508110680169505


### Dimensionality reduction

In [21]:
# Rough estimate of how many principal components are needed.
# NOTE(review): PCA is fit on the raw X_train here, whereas the pipelines below
# apply StandardScaler before PCA, so the component counts are not directly
# comparable — confirm this is intended.
pca = PCA()
pca.fit(X_train)
# Cumulative share of variance explained by the first k components.
csum = np.cumsum(pca.explained_variance_ratio_)
# argmax on the boolean array gives the first index reaching 97%; +1 turns it into a count.
print("Number of components to explain 97% of variance:", np.argmax(csum >= 0.97) + 1)

Number of components to explain 97% of variance: 39


In [22]:
# Grid search for the linear SVM on PCA-reduced features. A float
# pca__n_components is the share of variance to retain; the second
# StandardScaler rescales the components before classification.
# max_iter is given as an int: recent scikit-learn versions validate parameter
# types and reject the float 1e6.
gsearch = GridSearchCV(
    estimator=make_pipeline(
        PSSMSelector(feature_names=feature_names),
        StandardScaler(),
        PCA(),
        StandardScaler(),
        LinearSVC(max_iter=1_000_000, random_state=0),
    ),
    param_grid={
        "pssmselector__uniref_threshold": [50, 90, "all"],
        "pssmselector__iterations": [1, 3, "all"],
        "linearsvc__class_weight": ["balanced", None],
        "linearsvc__C": [0.01, 0.1, 1],
        "linearsvc__dual": [True, False],
        "pca__n_components": np.linspace(0.8, 0.99, 20),
    },
    cv=5,
    scoring="f1_macro",
    n_jobs=-1,
    return_train_score=True,
)
gsearch.fit(X_train, y_train)
print(gsearch.best_params_)
print(gsearch.best_score_)
# Keep the winning pipeline; `gsearch` is reused by the following cell.
best_estimator_lsvc_pca = gsearch.best_estimator_

{'linearsvc__C': 0.1, 'linearsvc__class_weight': 'balanced', 'linearsvc__dual': True, 'pca__n_components': 0.9, 'pssmselector__iterations': 1, 'pssmselector__uniref_threshold': 'all'}
0.8653777921424982


In [23]:
# Grid search for the kernel SVM on PCA-reduced features. A float
# pca__n_components is the share of variance to retain; the second
# StandardScaler rescales the components before classification.
# max_iter is given as an int: recent scikit-learn versions validate parameter
# types and reject the float 1e6.
gsearch = GridSearchCV(
    estimator=make_pipeline(
        PSSMSelector(feature_names=feature_names),
        StandardScaler(),
        PCA(),
        StandardScaler(),
        SVC(max_iter=1_000_000),
    ),
    param_grid={
        "pssmselector__uniref_threshold": [50, 90, "all"],
        "pssmselector__iterations": [1, 3, "all"],
        "svc__class_weight": ["balanced", None],
        "svc__C": [0.1, 1, 10],
        "svc__gamma": ["scale", 0.01, 0.1, 1],
        "pca__n_components": np.linspace(0.8, 0.99, 20),
    },
    cv=5,
    scoring="f1_macro",
    n_jobs=-1,
    return_train_score=True,
)
gsearch.fit(X_train, y_train)
print(gsearch.best_params_)
print(gsearch.best_score_)
# Keep the winning pipeline for the validation section.
best_estimator_svc_pca = gsearch.best_estimator_

{'pca__n_components': 0.8200000000000001, 'pssmselector__iterations': 'all', 'pssmselector__uniref_threshold': 'all', 'svc__C': 10, 'svc__class_weight': None, 'svc__gamma': 0.01}
0.8787204625439919


### Validation

In [24]:
# Validate the tuned SVC + PCA pipeline.
best_estimator = best_estimator_svc_pca
# Cross-validate an unfitted copy on the training set. Note that the value
# printed as "Train scores" is therefore a cross-validation score (macro F1),
# not the score on the data the model was fit on.
best_scores = cross_val_score(
    estimator=clone(best_estimator), X=X_train, y=y_train, scoring="f1_macro"
)
print(f"Train scores: {best_scores.mean().round(3)}+-{best_scores.std().round(3)}")

# Predict the held-out test set with the pipeline returned by the grid search
# (GridSearchCV refits the best pipeline on the full training set by default).
y_pred = best_estimator.predict(X_test)
y_true = y_test.copy()

report_df, confusion_matrix_df = print_validation_results(y_true, y_pred, labels=["Amino", "Sugar"])
display(report_df.round(3))
display(confusion_matrix_df)

Train scores: 0.879+-0.067


Unnamed: 0,Amino,Sugar,Macro,Weighted
precision,0.833,1.0,0.917,0.912
recall,1.0,0.778,0.889,0.895
f1-score,0.909,0.875,0.892,0.893
support,10.0,9.0,19.0,19.0


Unnamed: 0,Amino,Sugar
Amino,10,0
Sugar,2,7


In [25]:
# Validate the tuned LinearSVC + PCA pipeline (same procedure as for SVC + PCA).
best_estimator = best_estimator_lsvc_pca
# Cross-validate an unfitted copy on the training set. Note that the value
# printed as "Train scores" is therefore a cross-validation score (macro F1),
# not the score on the data the model was fit on.
best_scores = cross_val_score(
    estimator=clone(best_estimator), X=X_train, y=y_train, scoring="f1_macro"
)
print(f"Train scores: {best_scores.mean().round(3)}+-{best_scores.std().round(3)}")

# Predict the held-out test set with the pipeline returned by the grid search
# (GridSearchCV refits the best pipeline on the full training set by default).
y_pred = best_estimator.predict(X_test)
y_true = y_test.copy()

report_df, confusion_matrix_df = print_validation_results(y_true, y_pred, labels=["Amino", "Sugar"])
display(report_df.round(3))
display(confusion_matrix_df)

Train scores: 0.865+-0.045


Unnamed: 0,Amino,Sugar,Macro,Weighted
precision,0.9,0.889,0.894,0.895
recall,0.9,0.889,0.894,0.895
f1-score,0.9,0.889,0.894,0.895
support,10.0,9.0,19.0,19.0


Unnamed: 0,Amino,Sugar
Amino,9,1
Sugar,1,8
