## RUN: Archetype Prediction

This notebook runs the archetype prediction with the shape space as features and manual annotations as target.

Hyperparameter optimization was done in `DEV_Archetypes.ipynb`.

## Notes
----

- The paper uses slightly different nomenclature than what is used here:

| Paper | Code |
| :---: | :--: |
| `central rosette cells` | `innerRosetteCells` |
| `peripheral rosette cells` | `outerRosetteCells` |
| `inter-organ cells` | `betweenRosetteCells` |
| `leader cells` | `leaderCells` |


- Selection criteria for manual cell annotation
    - `innerRosetteCells`: cells directly adjacent to a lumen without contact to the outside of the tissue
    - `outerRosetteCells`: cells to the left or right of a lumen with a large area of contact to the outside of the tissue
    - `betweenRosetteCells`: cells between two rosettes, both within and on the outside of the tissue
    - `leaderCells`: the first few cells from the front, in particular those without a clear apical backwards-polarity
    
 
- Manual annotation was originally done with the Fiji `Point` tool and `Ctrl+M` to get measurements
    - Results were stored in a separate `.csv` file for each class
    - These files were converted into a single numpy vector (see code below)
    - The numeric encoding is as follows:
        - `0 : 'unclassified',`
        - `1 : 'innerRosetteCells',`
        - `2 : 'outerRosetteCells',`
        - `3 : 'betweenRosetteCells',`
        - `4 : 'leaderCells'`
    - This is also stored in the metadata as `archetype_decodedict` (and the inverse as `archetype_encodedict`)


- On the IDR, the manual annotations are stored as `.tsv` files like everything else
    - The code below allows this data to be loaded and used here

### Prep

In [None]:
### Imports

# Generic
from __future__ import division
import os, sys, pickle
import numpy as np
import matplotlib.pyplot as plt

# Specific
from sklearn.decomposition import PCA
import sklearn.svm as svm
from sklearn.metrics import confusion_matrix

# Modules
import katachi.utilities.loading as ld

In [None]:
### Prep encoding and decoding dict

archetype_decodedict = {0 : 'unclassified',
                       1 : 'innerRosetteCells',
                       2 : 'outerRosetteCells',
                       3 : 'betweenRosetteCells',
                       4 : 'leaderCells'}

# Inverse mapping (name -> numeric code); `.items()` works in Python 2 and 3
archetype_encodedict = {name: key for key, name
                        in archetype_decodedict.items()}

### Conversion from Fiji Output

Conversion of manual annotation files from Fiji into a clean numpy vector.

When using data from the IDR, this is not necessary.

The expected filename is `<sample_ID>_<archetype>_manual.csv`, where `<archetype>` is one of `innerRosetteCells`, `outerRosetteCells`, `betweenRosetteCells`, or `leaderCells`. One file must be present for each archetype.

The resulting numpy array is saved as `<sample_ID>_archetype_manual.npy`. It contains all four archetypes in numeric encoding, so there is a single file per sample rather than one file per archetype.
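
As a rough illustration of the idea, the snippet below sketches how per-class point annotations could be merged into one encoded vector. It assumes that the Fiji points are already available as pixel coordinates and that the cell under each point can be looked up in a labeled segmentation image. The helper name `encode_annotations`, the `(row, col)` coordinate order, the 2D toy image, and the ordering of cells by sorted label ID are all assumptions made for this example, not the pipeline's actual implementation.

```python
import numpy as np

# Same encoding as `archetype_encodedict` above
encodedict = {'unclassified': 0,
              'innerRosetteCells': 1,
              'outerRosetteCells': 2,
              'betweenRosetteCells': 3,
              'leaderCells': 4}

def encode_annotations(seg, points_by_class, encodedict):
    """Hypothetical helper: merge per-class point lists into one label vector.

    seg             : 2D integer array, 0 = background, >0 = cell label IDs
    points_by_class : dict mapping class name -> list of (row, col) pixels
    Returns one integer code per cell, ordered by sorted cell label ID.
    """
    cell_ids = np.unique(seg[seg > 0])
    codes = np.zeros(cell_ids.size, dtype=np.uint8)  # 0 = 'unclassified'
    for name, points in points_by_class.items():
        for row, col in points:
            label = seg[row, col]
            if label == 0:
                continue  # point fell on background; skip it
            codes[np.searchsorted(cell_ids, label)] = encodedict[name]
    return codes

# Toy example: three cells (label IDs 1, 2, 5) in a 4x4 segmentation
seg = np.array([[1, 1, 2, 2],
                [1, 1, 2, 2],
                [0, 0, 5, 5],
                [0, 0, 5, 5]])
points = {'innerRosetteCells': [(0, 0)],
          'leaderCells': [(3, 3), (2, 0)]}
codes = encode_annotations(seg, points, encodedict)
```

In this toy case, cell 1 is encoded as `innerRosetteCells`, cell 2 stays `unclassified`, and cell 5 becomes `leaderCells`; the point at `(2, 0)` lies on background and is ignored. The resulting vector could then be written with `np.save` under the `<sample_ID>_archetype_manual.npy` name described above.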

### Archetype Prediction

In [None]:
### Load shape space data from IDR

# Path to the data
dirpath = r'data/experimentA/extracted_measurements/'

# Set whether to use TFOR or CFOR
fspace_type = 'TFOR'
#fspace_type = 'CFOR'

# Load corresponding shape space
loader = ld.DataLoaderIDR(dirpath, recurse=True)
suffix = 'shape_'+fspace_type+'_raw_measured.tsv'
fspace, prim_IDs, fspace_idx = loader.load_dataset(suffix)
print "Imported shape space of shape:", fspace.shape

In [None]:
### Load annotations from IDR

archetypes, _, _ = loader.load_dataset("archetype_manual_annotations.tsv")

# Map missing annotations (NaN) to 0 ('unclassified')
archetypes[np.isnan(archetypes)] = 0
print "Imported manual archetype annotations of shape:", archetypes.shape

In [None]:
### Preprocess data

# Reduce the training data to the actually annotated samples
fspace_train   = fspace[archetypes!=0]
fspace_predict = fspace
archetypes_rdy = archetypes[archetypes!=0]

# Standardization (beneficial for TFOR according to tests)
if fspace_type == 'TFOR':
    train_means, train_stds = (np.mean(fspace_train, axis=0), np.std(fspace_train, axis=0))
    fspace_train   = (fspace_train   - train_means) / train_stds
    fspace_predict = (fspace_predict - train_means) / train_stds

# PCA (beneficial for CFOR according to tests)
if fspace_type == 'CFOR':
    pca = PCA().fit(fspace_train)
    fspace_train   = pca.transform(fspace_train)
    fspace_predict = pca.transform(fspace_predict)

# Report
print "Final shapes:"
print "  archetypes_rdy: ", archetypes_rdy.shape
print "  fspace_train:   ", fspace_train.shape
print "  fspace_predict: ", fspace_predict.shape
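
The standardization in the cell above uses statistics from the annotated training samples only and applies them to both arrays. It also divides by the per-feature standard deviation, which would yield `inf` or `nan` for any feature that is constant across the annotated cells. Whether that can happen in the shape spaces used here is not known. The following is a minimal sketch of the same step with one possible guard, using a hypothetical helper `standardize` and toy data.

```python
import numpy as np

def standardize(train, predict):
    """Hypothetical helper: z-score both arrays with training statistics."""
    means = np.mean(train, axis=0)
    stds = np.std(train, axis=0)
    # Replace zero standard deviations by 1 so constant features map to 0
    stds_safe = np.where(stds > 0, stds, 1.0)
    return (train - means) / stds_safe, (predict - means) / stds_safe

# Toy data: the second feature is constant in the training samples
train = np.array([[1.0, 5.0],
                  [3.0, 5.0]])
predict = np.array([[2.0, 5.0],
                    [5.0, 6.0]])
train_z, predict_z = standardize(train, predict)
```

If other datasets are to be standardized in the same way later on, the means and standard deviations of the training samples would need to be kept alongside the fitted classifier.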

In [None]:
### Run prediction

# Prep classifier (hyperparameters optimized in `TEST_archetypes`)
if fspace_type == 'TFOR':               
    svc = svm.SVC(probability=True, kernel='rbf',
                  C=1.0, gamma=1.0 / fspace_train.shape[1]) 
if fspace_type == 'CFOR':
    svc = svm.SVC(probability=True, kernel='rbf',
                  C=10.0, gamma=1.0 / fspace_train.shape[1]) 

# Fit
print "Fitting..."
svc.fit(fspace_train, archetypes_rdy)

# Predict
print "Predicting..."
archetypes_pred = svc.predict(fspace_predict)
archetypes_prob = svc.predict_proba(fspace_predict)

# Report
print "\nPredicted archetypes_pred of shape:", archetypes_pred.shape
print   "Predicted archetypes_prob of shape:", archetypes_prob.shape

In [None]:
### Quick & Simple confusion matrix evaluation on training data

# Repredict just the training samples
archetypes_train_pred = svc.predict(fspace_train)

# Compute confusion matrix
cm = confusion_matrix(archetypes_rdy, archetypes_train_pred)

# Prep plot
fig, ax = plt.subplots(1, 1, figsize=(3,3))

# Function for creating and styling a confusion matrix
def confmat(ax, cm):
    
    # Handle axis boundaries...
    ax.set_adjustable('box-forced')
    
    # Show image and text
    ax.imshow(cm, interpolation='none', cmap='Blues')
    for (i, j), z in np.ndenumerate(cm):
        ax.text(j, i, z, ha='center', va='center')
    
    # Adjust ticks
    ax.set_xticks(range(cm.shape[0]))
    ax.set_yticks(range(cm.shape[0]))

# Plot cm
confmat(ax, cm)
ax.set_title("Quick Check on Training Data", fontsize=10)
ax.set_yticklabels([archetype_decodedict[i] for i in range(1, cm.shape[0]+1)], 
                    rotation=45, va='top', fontsize=8)
ax.set_xticklabels([archetype_decodedict[i] for i in range(1, cm.shape[0]+1)], 
                    rotation=45, ha='right', fontsize=8)
ax.set_ylabel("Ground Truth", labelpad=10)
ax.set_xlabel("Prediction", labelpad=10)

# Done
plt.show()

In [None]:
### Writing the prediction results
#   Note: This writes the results as numpy arrays associated with individual
#         samples. The original results provided as a tsv in the IDR are not 
#         overwritten by this.

# Report
print "Saving output..."

# Find samples in the image_data folder
rawpath = r'data/experimentA/image_data'
rawloader = ld.DataLoader(rawpath, recurse=True)

# Load metadata
# Note: The raw sample IDs are kept separate from `prim_IDs` (loaded together
#       with the shape space), because `fspace_idx` indexes into the latter.
meta, raw_prim_IDs, _ = rawloader.load_dataset("stack_metadata.pkl")

# For each prim...
for prim_ID in raw_prim_IDs:
    
    # Select the cells belonging to this prim
    prim_mask = fspace_idx == prim_IDs.index(prim_ID)
    
    # Find the segmentation path...
    for fpath in rawloader.data[prim_ID]:
        if fpath.endswith('_seg.tif'):
    
            # Construct prediction output path
            outpath = fpath[:-4] + '_archetypes_' + fspace_type + '-PREDICTED.npy'
            
            # Save prediction data
            np.save(outpath, archetypes_pred[prim_mask])
            
            # Construct probabilities output path
            outpath = fpath[:-4] + '_archetypes_' + fspace_type + '-PROBA.npy'
            
            # Save probability data
            np.save(outpath, archetypes_prob[prim_mask])
            
    # Add proba classes to metadata (just in case)
    meta[prim_ID]['archetype_probaclasses'] = svc.classes_

    # Save metadata
    for fpath in rawloader.data[prim_ID]:
        if '_stack_metadata.pkl' in fpath:
            with open(fpath, 'wb') as metafile:
                pickle.dump(meta[prim_ID], metafile, pickle.HIGHEST_PROTOCOL)

# Save the fitted SVC separately for predictions on other datasets
with open('other/archetype_svc_'+fspace_type+'.pkl', 'wb') as outfile:
    pickle.dump(svc, outfile, pickle.HIGHEST_PROTOCOL)

# Report
print "Done!"
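
The columns of the saved `-PROBA.npy` arrays follow the order of `svc.classes_`, which is why that array is written to the metadata as `archetype_probaclasses`. The snippet below is a minimal sketch of how such an array could be mapped back to archetype names. The probabilities are made up, and the class array is an assumed stand-in for `svc.classes_` containing the four annotated classes.

```python
import numpy as np

archetype_decodedict = {0: 'unclassified',
                        1: 'innerRosetteCells',
                        2: 'outerRosetteCells',
                        3: 'betweenRosetteCells',
                        4: 'leaderCells'}

# Assumed stand-ins for `svc.classes_` and for one sample's -PROBA.npy content
probaclasses = np.array([1.0, 2.0, 3.0, 4.0])
proba = np.array([[0.7, 0.1, 0.1, 0.1],
                  [0.2, 0.2, 0.5, 0.1],
                  [0.1, 0.1, 0.1, 0.7]])

# Most probable class per cell, mapped back to its numeric code and name
best_codes = probaclasses[np.argmax(proba, axis=1)].astype(int)
best_names = [archetype_decodedict[int(code)] for code in best_codes]
```

Note that for an `SVC` fitted with `probability=True`, the most probable class according to `predict_proba` is not guaranteed to match the output of `predict`, so the `-PREDICTED.npy` arrays remain the reference for hard class assignments.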