In [1]:
import types
from collections import Counter

from qiime.sdk import Artifact
from qiime.plugins import feature_classifier
import pandas as pd
from sklearn import (pipeline, naive_bayes, feature_selection, 
                     grid_search, cross_validation, metrics, feature_extraction)
import skbio

from q2_feature_classifier._skl import _extract_labels, _extract_features

In [2]:
reference = Artifact.load('85_ref_feat.qza')

In [3]:
read_length = 150 # need to adjust for actual MOCK-6 data
f_primer = 'GTGCCAGCMGCCGCGGTAA' # from mockrobiota/data/mock-6/sample-metadata.tsv
r_primer = 'GGACTACHVGGGTWTCTAAT'
extract_reads = feature_classifier.methods.extract_reads
reads = extract_reads(reference, read_length, f_primer, r_primer, method='position')

In [4]:
def get_seq_id(read):
    if isinstance(read, skbio.DNA):
        return read.metadata['id']
    else:
        return read[0].metadata['id']

In [5]:
word_length = 8
taxonomy_separator = '; '
taxonomy_depth = 6
multioutput = False
cv = 2
read_seqs = list(reads.view(types.GeneratorType))
taxonomy = reference.view(pd.Series)

seq_ids = [get_seq_id(s) for s in read_seqs]
labels = [taxonomy.get(s, 'unknown') for s in seq_ids]
labels = _extract_labels(labels, taxonomy_separator, taxonomy_depth, multioutput)
counted_labels = Counter(labels)
ok_labels = {l for l in counted_labels if counted_labels[l] >= 2*cv}

filtered = [(s, l) for s, l in zip(read_seqs, labels) if l in ok_labels]
read_seqs, y = zip(*filtered)
dummy, X = _extract_features(read_seqs, word_length)

# hold some back for validation
X_train, X_test, y_train, y_test = cross_validation.train_test_split(
    X, y, test_size=0.5, random_state=0, stratify=y)

In [6]:
vectorizer = feature_extraction.DictVectorizer()
classifier = naive_bayes.MultinomialNB()
selector = feature_selection.SelectPercentile()
steps = [('vec', vectorizer), ('sel', selector), ('cls', classifier)]
pipeline = pipeline.Pipeline(steps)
grid_params = {'cls__alpha': [1., 0.01, 0.001],
               'sel__percentile': [100, 10, 1]}
grid = grid_search.GridSearchCV(pipeline, grid_params, cv=cv, n_jobs=4)
grid.fit(X_train, y_train)

GridSearchCV(cv=2, error_score='raise',
       estimator=Pipeline(steps=[('vec', DictVectorizer(dtype=<class 'numpy.float64'>, separator='=', sort=True,
        sparse=True)), ('sel', SelectPercentile(percentile=10,
         score_func=<function f_classif at 0x1114d71e0>)), ('cls', MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True))]),
       fit_params={}, iid=True, n_jobs=4,
       param_grid={'sel__percentile': [100, 10, 1], 'cls__alpha': [1.0, 0.01, 0.001]},
       pre_dispatch='2*n_jobs', refit=True, scoring=None, verbose=0)

In [7]:
print(grid.best_params_)

{'sel__percentile': 100, 'cls__alpha': 0.01}


In [8]:
y_true, y_pred = y_test, grid.predict(X_test)
report = metrics.classification_report(y_true, y_pred).split('\n')
print(report[0])
print(report[-2])

             precision    recall  f1-score   support
avg / total       0.53      0.47      0.48      1864


  'precision', 'predicted', average, warn_for)


In [9]:
opposite_params = {'cls__alpha': 1., 'sel__percentile': 1}
pipeline.set_params(**opposite_params)
pipeline.fit(X_train, y_train)
y_true, y_pred = y_test, pipeline.predict(X_test)
report = metrics.classification_report(y_true, y_pred).split('\n')
print(report[0])
print(report[-2])

             precision    recall  f1-score   support
avg / total       0.05      0.11      0.05      1864


  'precision', 'predicted', average, warn_for)


In [11]:
X

({'AAAAACCC': 1,
  'AAAACCCG': 1,
  'AAACCCGA': 1,
  'AAACTATC': 1,
  'AAAGGGCG': 1,
  'AACCCGAG': 1,
  'AACTATCT': 1,
  'AACTCTGT': 1,
  'AACTTGGG': 1,
  'AAGATAAG': 1,
  'AAGCGTTA': 1,
  'AAGGGCGC': 1,
  'AAGTCAGA': 1,
  'AATACAGA': 1,
  'AATCGGAT': 1,
  'ACAGAGGG': 1,
  'ACCCGAGC': 1,
  'ACTATCTC': 1,
  'ACTCTGTG': 1,
  'ACTGCATT': 1,
  'ACTGGGCG': 1,
  'ACTTGGGG': 1,
  'AGAGGGTG': 1,
  'AGATAAGT': 1,
  'AGATGTTA': 1,
  'AGCAGCCG': 1,
  'AGCCGCGG': 1,
  'AGCGTTAA': 1,
  'AGCTCAAC': 1,
  'AGGCGGTA': 1,
  'AGGGCGCG': 1,
  'AGGGTGCA': 1,
  'AGTCAGAT': 1,
  'ATAAGTCA': 1,
  'ATACAGAG': 1,
  'ATCGGATT': 1,
  'ATCTCACT': 1,
  'ATGTTAAA': 1,
  'ATTGACTG': 1,
  'ATTTGAAA': 1,
  'CAACTTGG': 1,
  'CAAGCGTT': 1,
  'CAGAGGGT': 1,
  'CAGATGTT': 1,
  'CAGCAGCC': 1,
  'CAGCCGCG': 1,
  'CATTTGAA': 1,
  'CCAGCAGC': 1,
  'CCCGAGCT': 1,
  'CCGAGCTC': 1,
  'CCGCGGTA': 1,
  'CGAGCTCA': 1,
  'CGCGGTAA': 1,
  'CGCGTAGG': 1,
  'CGGATTGA': 1,
  'CGGTAAGA': 1,
  'CGGTAATA': 1,
  'CGTAAAGG': 1,
  'CGTAGGCG': 