<a id=top></a>
# DEV: Archetype Classification

This notebook was used for testing and hyperparameter optimization of the morphological archetype classification.

## Prep

In [None]:
### Import modules

# External, general
from __future__ import division
import os, sys, pickle, warnings
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# External, specific
from sklearn import model_selection, metrics
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import sklearn.svm as svm
from sklearn.metrics import confusion_matrix
from mpl_toolkits.mplot3d import Axes3D

# Internal
import katachi.utilities.loading as ld
import katachi.utilities.atlas_helpers as atlas

In [None]:
### Settings

# Specify data source
# `target_dir` is handed to the data loader; `fspace_type` is spliced into the
# feature-space file name that gets loaded ("_shape_<fspace_type>_raw_measured.tsv").
target_dir = r'data/experimentA/extracted_measurements/'
fspace_type = "TFOR"
#fspace_type = "CFOR"

# Cross validation
# Used both as the number of shuffle splits and as the number of parallel jobs.
num_CVs = 5

# Archetype annotation
# Integer class label -> archetype name
celltype_decodedict = {0 : 'unclassified',
                       1 : 'innerRosetteCells',
                       2 : 'outerRosetteCells',
                       3 : 'betweenRosetteCells',
                       4 : 'leaderCells'}
# Inverse mapping: archetype name -> integer class label.
# `.items()` (rather than the Python-2-only `.iteritems()`) keeps this cell
# runnable on both Python 2 and Python 3.
celltype_encodedict = {name:key for key,name in celltype_decodedict.items()}

# Visualization options
# Integer class label -> matplotlib color name used in the scatter plots
cellcolors = {0 : 'lightgray',
              1 : 'royalblue',
              2 : 'limegreen',
              3 : 'cyan',
              4 : 'orangered'}

In [None]:
### Load data

# Prep loader
# NOTE(review): `recurse` / `verbose` are presumably "search sub-directories"
# and "print progress" -- confirm against katachi.utilities.loading.
loader = ld.DataLoaderIDR(target_dir, recurse=True, verbose=True)

# Load manual archetype annotations
# `load_dataset` returns a 3-tuple; the first element is the data array, the
# second is reused below as the `IDs` argument for the feature-space load.
archetypes, prim_IDs, fspace_idx = loader.load_dataset("_archetype_manual_annotations.tsv")
# Single-argument print(...) behaves identically on Python 2 and Python 3.
print("Imported manual archetype annotations of shape: {}".format(archetypes.shape))

# Load corresponding feature spaces
# Only the data array (element 0 of the returned tuple) is kept.
fspace = loader.load_dataset("_shape_"+fspace_type+"_raw_measured.tsv", IDs=prim_IDs)[0]
print("Imported feature space of shape: {}".format(fspace.shape))

In [None]:
### Remove cells that are not annotated

# Build the mask once, while `archetypes` still holds the unfiltered labels;
# NaN entries mark cells without a manual annotation.
annotated_mask = ~np.isnan(archetypes)

# Guard against silently misaligned inputs: the mask indexes the rows of fspace.
assert annotated_mask.shape[0] == fspace.shape[0], \
    "archetypes and fspace must have one entry per cell"

# Select
fspace     = fspace[annotated_mask, :]
archetypes = archetypes[annotated_mask]

# Report
# Single-argument print(...) behaves identically on Python 2 and Python 3.
print("Reduced fspace to shape:     {}".format(fspace.shape))
print("Reduced archetypes to shape:  {}".format(archetypes.shape))

## Classification

In [None]:
### Prep for cross-validation

# Stratified shuffle-split generator: `num_CVs` random splits, each holding
# out 30% of the samples for testing; the seed is fixed for reproducibility.
cv_sets = model_selection.StratifiedShuffleSplit(random_state=42,
                                                 test_size=0.3,
                                                 n_splits=num_CVs)

# Scorers for cross-validation: every report key maps to the scorer of the same name
scoring = {metric : metric
           for metric in ('accuracy', 'f1_macro', 'f1_micro', 'f1_weighted')}

In [None]:
### Build pipeline from preprocessing and classification

# Classifier: RBF-kernel support vector classifier; probability estimates are
# switched on at construction time
svc = svm.SVC(probability=True, kernel='rbf')

# Pipeline: standardize -> PCA -> standardize again -> SVC.
# The step names are the keys addressed by the hyperparameter grid.
pip = Pipeline([('Standardize',   StandardScaler()),
                ('PCA',           PCA()),
                ('Restandardize', StandardScaler()),
                ('SVC',           svc)])

In [None]:
### Perform hyperparameter optimization

# Param grid
# `gd` is the reciprocal of the number of raw feature columns; the gamma
# candidates span four orders of magnitude around it. Note that the PCA
# options below change the dimensionality the SVC actually sees.
gd = 1.0 / fspace.shape[1]
# Whole pipeline steps are searched too: `None` disables a step.
param_grid = [ {'Standardize'   : [None, StandardScaler()],
                'PCA'           : [None, PCA(15), PCA(30), PCA(50)],
                'Restandardize' : [None, StandardScaler()],
                'SVC__C'        : [0.01, 0.1, 1.0, 10.0, 100.0],
                'SVC__gamma'    : [gd*100.0, gd*10.0, gd*1.0, gd*0.1, gd*0.01]} ]

# Run grid search
# The number of CV splits is reused as the number of parallel jobs.
clf = model_selection.GridSearchCV(pip, param_grid, cv=cv_sets, n_jobs=num_CVs, verbose=2)
clf.fit(fspace, archetypes)

# Available outputs
# Single-argument print(...) behaves identically on Python 2 and Python 3.
print("\nOutputs:")
print(sorted(clf.cv_results_.keys()))

# Key results
print("\nResults:")
print(clf.best_estimator_)
print(clf.best_score_)

# Use best estimator
best = clf.best_estimator_

In [None]:
### Perform cross-validation on best predictor

# Run CV
# Single-argument print(...) behaves identically on Python 2 and Python 3.
print("Performing cross-validation...")
# Same splitter and data as the grid search; train scores are returned
# alongside the test scores for every metric in `scoring`.
scores = model_selection.cross_validate(best, fspace, archetypes, cv=cv_sets,
                                        scoring=scoring, return_train_score=True,
                                        n_jobs=num_CVs)

# Report scores
# NOTE(review): presumably prints a per-metric summary -- confirm against
# katachi.utilities.atlas_helpers.
atlas.report_cv_scores(scores)

In [None]:
### Create confusion matrix

# Axis renaming for publication
publ_rename_dict = {'innerRosetteCells' : 'central',
                    'outerRosetteCells' : 'peri',
                    'betweenRosetteCells' : 'inter',
                    'leaderCells' : 'leader'}

# Grab a specific train-test split
# Only the first split is needed, so take it straight from the generator
# instead of materializing all of them.
split_indices = next(cv_sets.split(fspace, archetypes))
train_idx, test_idx = split_indices
fspace_train, archetypes_train = fspace[train_idx, :], archetypes[train_idx]
fspace_test,  archetypes_test  = fspace[test_idx, :],  archetypes[test_idx]

# Fit and predict
# Note: this refits `best` in place on the training part of the split only.
best.fit(fspace_train, archetypes_train)
archetypes_train_pred = best.predict(fspace_train)
archetypes_test_pred  = best.predict(fspace_test)

# Compute confusion matrices
# Passing `labels` pins the row/column order to `best.classes_`, which is the
# order used for the tick labels when plotting; without it the matrix would
# shrink if a class happened to be absent from a split.
cm_train = metrics.confusion_matrix(archetypes_train, archetypes_train_pred,
                                    labels=best.classes_)
cm_test  = metrics.confusion_matrix(archetypes_test, archetypes_test_pred,
                                    labels=best.classes_)

# Prep plot
fig, ax = plt.subplots(1, 2, figsize=(8,8), sharey=True)

# Function for creating and styling a confusion matrix
def confmat(ax, cm):
    """Draw the confusion matrix `cm` as an annotated heat map on `ax`.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on; modified in place.
    cm : 2D numpy array
        Confusion matrix. Rows are drawn along the y axis and columns along
        the x axis, with each cell's value written at its center.

    Returns
    -------
    None
    """

    # Handle axis boundaries...
    # Older matplotlib needs 'box-forced' for image axes that share an axis;
    # newer releases reject that value (ValueError), and plain 'box' is the
    # replacement there.
    try:
        ax.set_adjustable('box-forced')
    except ValueError:
        ax.set_adjustable('box')

    # Show image and text
    ax.imshow(cm, interpolation='nearest', cmap='Blues')
    for (i, j), z in np.ndenumerate(cm):
        # (i, j) is (row, column), so the text goes to x=j, y=i
        ax.text(j, i, z, ha='center', va='center', fontsize=15)

    # Adjust ticks
    # One tick per column on x and one per row on y
    ax.set_xticks(range(cm.shape[1]))
    ax.set_yticks(range(cm.shape[0]))

# Tick labels, in the class order of the fitted estimator.
# Archetypes without a publication alias (e.g. 'unclassified') keep their
# original name instead of raising a KeyError.
class_names = [celltype_decodedict[i] for i in best.classes_]
tick_labels = [publ_rename_dict.get(name, name) for name in class_names]

# Plot cm_train (left panel) and cm_test (right panel) with identical styling
panels = zip(ax,
             (cm_train, cm_test),
             ("train", "test"),
             (fspace_train.shape[0], fspace_test.shape[0]))
for panel, panel_cm, split_name, n_cells in panels:
    confmat(panel, panel_cm)
    panel.set_title(fspace_type + " %s (n=%i)" % (split_name, n_cells),
                    fontsize=17, y=1.02)
    panel.set_yticklabels(tick_labels, rotation=45, va='top', fontsize=14)
    panel.set_xticklabels(tick_labels, rotation=45, ha='right', fontsize=14)

# Labels
# Shared axis captions, placed in figure coordinates
fig.text( 0.545, 0.19, 'prediction',   ha='center', va='center', fontsize=16)
fig.text(-0.01, 0.50, 'ground truth', ha='center', va='center', rotation='vertical', fontsize=16)

# Layout
plt.tight_layout()

# Done
plt.show()

### Probability Embedding

In [None]:
### Create probability embedding

# Predict probabilities
# NOTE(review): if the confusion-matrix cell ran before this one, `best` has
# been refit on the training part of a single split, not on the full data --
# confirm this is the intended model for the embedding.
# Single-argument print(...) behaves identically on Python 2 and Python 3.
print("\nPredicting probabilities...")
archetypes_proba = best.predict_proba(fspace)
print("  Predicted archetypes_proba of shape: {}".format(archetypes_proba.shape))

# PCA of probabilities
# All components are kept; the commented-out line below is an optional filter
# that would drop components explaining 5% or less of the variance.
print("\nEmbedding probabilities...")
embedPCA = PCA()
proba_embedding = embedPCA.fit_transform(archetypes_proba)
#proba_embedding = proba_embedding[:, embedPCA.explained_variance_ratio_ > 0.05]
print("  Expl. var. ratios: {}".format(embedPCA.explained_variance_ratio_))
print("  Created embedding of shape:   {}".format(proba_embedding.shape))

In [None]:
### Plot archetype space in 2D

# Prep plot
plt.figure(figsize=(6,5))

# Create scatter
# One scatter call per archetype that actually occurs in the annotations, so
# that each gets its own color and legend entry. Iterating the sorted keys
# makes the legend order explicit.
for key in sorted(celltype_decodedict):
    mask = archetypes==key
    if np.any(mask):
        # 'none' is the documented way to suppress marker edges; the empty
        # string is rejected as an invalid color by current matplotlib.
        plt.scatter(proba_embedding[mask, 0], proba_embedding[mask, 1],
                    color=cellcolors[key], edgecolor='none',
                    s=10, alpha=0.85, label=celltype_decodedict[key])

# Cosmetics
plt.legend(frameon=False, fontsize=8)
plt.xlabel("PC 0")
plt.ylabel("PC 1")
plt.title("archetype classification probability embedding")
plt.show()

In [None]:
### Plot archetype space in 3D

# Prep plot
fig = plt.figure(figsize=(9,7))
ax  = fig.add_subplot(111, projection='3d')

# Create scatter
# One scatter call per archetype that actually occurs in the annotations;
# the first three embedding components are used as coordinates.
for key in sorted(celltype_decodedict):
    mask = archetypes==key
    if np.any(mask):
        ax.scatter(proba_embedding[mask, 0],
                   proba_embedding[mask, 1],
                   proba_embedding[mask, 2],
                   c=cellcolors[key], edgecolor='face',
                   s=10, alpha=0.85, label=celltype_decodedict[key])

# Switch off gray panes
# `xaxis`/`yaxis`/`zaxis` are used instead of the `w_xaxis`/`w_yaxis`/`w_zaxis`
# aliases, which newer matplotlib releases no longer provide.
for pane_axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    pane_axis.set_pane_color((1.0, 1.0, 1.0, 1.0))

# Cosmetics
ax.legend(frameon=False, fontsize=8)
ax.set_xlabel("PC 0")
ax.set_ylabel("PC 1")
ax.set_zlabel("PC 2")
# NOTE(review): equal aspect on 3D axes is not supported by every matplotlib
# release -- confirm this call works with the installed version.
plt.axis('equal')
plt.show()

----

[Back to Top](#top)