In [1]:
# default_exp glp.clustering

%reload_ext autoreload
%autoreload 2

# glp.clustering



## Embedding Proteins



In [2]:
import sys
sys.path.append('../')

In [15]:
#hide
#export

from itertools import islice
import pandas as pd
import numpy as np

from transformers import AutoModel, AutoTokenizer
from umap import UMAP
from fastai.text.all import *
import hdbscan

from justenough.nlp.core import *
from justenough.explain.core import *

### Loading Data

For this example we're going to use a small dataset of 2000 HIV tat proteins with labels for tissue of isolation and co-receptor binding.
We'll see if the clusters generated by topic modeling will correspond to these groups.

In [16]:
# Read the example table and trim '*' characters from both ends of every sequence.
df = (
    pd.read_csv('../tutorials/HIV_tat_example.csv')
      .assign(sequence = lambda frame: frame['sequence'].str.strip('*'))
)
df.head()

Unnamed: 0,accession,sample_tissue,coreceptor,sequence
0,M17449,PBMC,CXCR4,MEPVDPRLEPWKHPGSQPKTACTTCYCKKCCFHCQVCFTKKALGISYGRKKRRQRRRAPEDSQTHQVSLPKQPAPQFRGDPTGPKESKKKVERETETHPVD
1,M26727,PBMC,CCR5,MEPVDPRLEPWKHPGSQPKTASNNCYCKRCCLHCQVCFTKKGLGISYGRKKRRQRRRAPQDSKTHQVSLSKQPASQPRGDPTGPKESKKKVERETETDPED
2,M17451,PBMC,CCR5|CXCR4,MEPVDPRLEPWKHPGSQPKTACNNCYCKKCCYHCQVCFLTKGLGISYGRKKRRQRRGPPQGSQTHQVSLSKQPTSQPRGDPTGPKESKEKVERETETDPAVQ
3,K02007,PBMC,CCR5|CXCR4,MEPVDPNLEPWKHPGSQPRTACNNCYCKKCCFHCYACFTRKGLGISYGRKKRRQRRRAPQDSQTHQASLSKQPASQSRGDPTGPTESKKKVERETETDPFD
4,M62320,blood,,MEPVDPNLEPWKHPGSQPTTACSNCYCKVCCWHCQLCFLKKGLGISYGKKKRKPRRGPPQGSKDHQTLIPKQPLPQSQRVSAGQEESKKKVESKAKTDRFA


Great, now we have an easy to use `DataFrame` of our HIV tat sequences and their labels.
Let's get into the deep learning.

### Pipeline

In [17]:
# Build the topic-modeling interface from the 'Rostlab/prot_bert' model name.
tmi = TopicModelingInterface(model_name = 'Rostlab/prot_bert')

In [18]:
# Run the 'sequence' column through the interface; the call hands back a
# per-row cluster table and the embedding data.
cluster_info, emb_data = tmi.process_df(df, col = 'sequence')
cluster_info.head()

Unnamed: 0,cluster,X,Y,label,d0,d1,d2,d3,d4,d5,d6,d7,d8,d9
0,2,13.392046,3.501023,2,12.251259,3.45661,5.252148,4.470291,5.588035,3.107115,6.219954,4.726951,4.389,2.776044
1,2,10.992587,4.044052,2,12.339379,4.051465,5.208279,4.696549,5.899921,2.954771,6.186308,5.194275,4.74995,3.058691
2,2,13.214004,2.636483,2,11.905041,2.852295,4.529519,4.956338,5.160463,3.50931,6.128185,4.714086,5.1291,3.607634
3,2,11.296652,3.913122,2,12.441009,3.151718,5.157255,5.394715,5.221587,3.30023,6.42404,4.640056,5.061777,3.219835
4,2,13.337667,2.834853,2,12.317655,3.328009,5.412022,4.715364,5.127587,2.995078,6.269528,4.415875,4.608104,2.736147


In [19]:
# Put the cluster table and the original columns side by side (axis=1 joins on the index).
full_df = pd.concat([cluster_info, df], axis=1)
full_df.head()

Unnamed: 0,cluster,X,Y,label,d0,d1,d2,d3,d4,d5,d6,d7,d8,d9,accession,sample_tissue,coreceptor,sequence
0,2,13.392046,3.501023,2,12.251259,3.45661,5.252148,4.470291,5.588035,3.107115,6.219954,4.726951,4.389,2.776044,M17449,PBMC,CXCR4,MEPVDPRLEPWKHPGSQPKTACTTCYCKKCCFHCQVCFTKKALGISYGRKKRRQRRRAPEDSQTHQVSLPKQPAPQFRGDPTGPKESKKKVERETETHPVD
1,2,10.992587,4.044052,2,12.339379,4.051465,5.208279,4.696549,5.899921,2.954771,6.186308,5.194275,4.74995,3.058691,M26727,PBMC,CCR5,MEPVDPRLEPWKHPGSQPKTASNNCYCKRCCLHCQVCFTKKGLGISYGRKKRRQRRRAPQDSKTHQVSLSKQPASQPRGDPTGPKESKKKVERETETDPED
2,2,13.214004,2.636483,2,11.905041,2.852295,4.529519,4.956338,5.160463,3.50931,6.128185,4.714086,5.1291,3.607634,M17451,PBMC,CCR5|CXCR4,MEPVDPRLEPWKHPGSQPKTACNNCYCKKCCYHCQVCFLTKGLGISYGRKKRRQRRGPPQGSQTHQVSLSKQPTSQPRGDPTGPKESKEKVERETETDPAVQ
3,2,11.296652,3.913122,2,12.441009,3.151718,5.157255,5.394715,5.221587,3.30023,6.42404,4.640056,5.061777,3.219835,K02007,PBMC,CCR5|CXCR4,MEPVDPNLEPWKHPGSQPRTACNNCYCKKCCFHCYACFTRKGLGISYGRKKRRQRRRAPQDSQTHQASLSKQPASQSRGDPTGPTESKKKVERETETDPFD
4,2,13.337667,2.834853,2,12.317655,3.328009,5.412022,4.715364,5.127587,2.995078,6.269528,4.415875,4.608104,2.736147,M62320,blood,,MEPVDPNLEPWKHPGSQPTTACSNCYCKVCCWHCQLCFLKKGLGISYGKKKRKPRRGPPQGSKDHQTLIPKQPLPQSQRVSAGQEESKKKVESKAKTDRFA


In [20]:
from justenough.explain.core import *

In [21]:
from bokeh.models import ColumnDataSource
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource
from bokeh.transform import factor_cmap
from bokeh.palettes import viridis
    
    
class BokehFigureExplanation(DataFrameExplanation):
    """Base class for explanations that are drawn as a Bokeh figure.

    ``plot`` is the public entry point: it runs ``setup``, then
    ``generate_bokeh``, then ``polish_bokeh`` and finally returns
    ``for_show()``. Subclasses are expected to provide ``_setup`` and
    ``generate_bokeh``.
    """

    # Data source handed to the Bokeh glyphs; subclasses set it in _setup.
    bokeh_source = None

    # Hover-tooltip definitions and the columns they are derived from.
    tooltips = None
    tip_cols = None

    # Column used for categorical colouring, plus its colour map and levels.
    factor_col = None
    factor_cmap = None
    factors = None

    # Palette callable; it is invoked with the number of colours needed.
    pallette = viridis

    # The figure produced by generate_bokeh.
    fig = None

    # For reference, do not overload
    def plot(self, **polish_kwargs):
        """Build the figure and return the object to pass to Bokeh's ``show``.

        Any keyword arguments are forwarded unchanged to ``polish_bokeh``.
        """

        # Setup the parent class info if needed.
        self.setup()  # which calls self._setup

        self.generate_bokeh()

        self.polish_bokeh(**polish_kwargs)

        return self.for_show()

    def _setup(self):
        """Subclass to setup explanation specific processes.

        Should at the very least set self.bokeh_source"""
        raise NotImplementedError

    def generate_bokeh(self):
        """Subclass to setup specific plotting"""
        raise NotImplementedError

    def polish_bokeh(self, **polish_kwargs):
        """Useful to subclass for adding labels, annotations, etc. to the plot.

        Called by ``plot`` after ``setup`` and ``generate_bokeh`` have run.
        The base implementation accepts and ignores whatever keyword
        arguments ``plot`` received, so passing options to ``plot`` does not
        fail when a subclass has not overridden this hook.
        """
        pass

    def for_show(self):
        """Returns an object ready to show with Bokeh.

        Useful to subclass if there's anything beyond returning self.fig"""

        return self.fig

    ##### Below this are utility functions that should be left alone.

    def setup(self):
        """Run the parent class setup, then this class's ``_setup`` hook."""

        super().setup()
        self._setup()

    def setup_bokeh(self):
        """Run the optional builders, then ``setup_source``.

        A builder only runs when its trigger attribute is not None.

        NOTE(review): ``_build_factors``, ``_build_tips`` and ``setup_source``
        are not defined on this class, so they have to come from a subclass
        or the parent class -- confirm before relying on this method.
        ``plot`` does not call it.
        """

        triggers = [
            (self.factor_col, self._build_factors),
            (self.tip_cols, self._build_tips)
                    ]

        for trig, func in triggers:
            if trig is not None:
                func()

        self.setup_source()
        assert self.bokeh_source is not None

    def save_png(self, path, fig_kw = None, export_kw = None):
        """Render the figure and write it to ``path`` as a PNG.

        Parameters
        ----------
        path : where the image is written.
        fig_kw : optional dict of keyword arguments for ``generate_figure``.
        export_kw : optional dict of extra keyword arguments for Bokeh's
            ``export_png``.
        """
        # Imported here so the method does not depend on a module-level
        # import that this file never declares.
        from bokeh.io import export_png

        fig_kw = {} if fig_kw is None else fig_kw
        export_kw = {} if export_kw is None else export_kw

        # NOTE(review): generate_figure is not defined on this class;
        # presumably it is inherited from DataFrameExplanation -- confirm.
        fig = self.generate_figure(**fig_kw)

        # Pass the target by keyword: it is keyword-only in newer Bokeh releases.
        export_png(fig, filename = path, **export_kw)
            
    
def _build_tips(tip_cols):
    """Turn column names into ``(label, '@column')`` tooltip pairs.

    Names that contain a space are left out of the result.
    """
    tips = []
    for name in tip_cols:
        if ' ' in name:
            continue
        tips.append((name, '@' + name))
    return tips


def _build_factor_cmap(data, factor_col, pallette):
    """Build the level list and colour map for a categorical column.

    The column's values are converted to strings and de-duplicated; the
    palette callable is asked for one colour per distinct value.

    Returns
    -------
    (factors, cmap) : the list of distinct string values and the result of
        ``factor_cmap`` for ``factor_col`` over those values.
    """
    as_text = data[factor_col].map(str)
    levels = as_text.unique().tolist()
    colours = pallette(len(levels))
    return levels, factor_cmap(factor_col, colours, levels)
    
    

In [22]:
from bokeh.plotting import output_notebook, show
output_notebook()    


In [23]:
#export

from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, CDSView, BooleanFilter
from bokeh.transform import factor_cmap
from bokeh.palettes import viridis


class ClusteringBokehExplanation(BokehFigureExplanation):
    """Scatter plot of clustered points coloured by a categorical column.

    Parameters
    ----------
    cluster_data : table holding the coordinate columns, the factor column
        and any tooltip columns; it backs the Bokeh data source.
    xy_cols : pair of column names used for the x and y coordinates.
    factor_col : column whose values drive the colour map and the legend.
    tooltips : optional ready-made tooltip definitions. When given they are
        used as-is; otherwise tooltips are derived from ``tip_cols``.
    tip_cols : columns to show on hover. Falls back to ``['Cluster']``.
        NOTE(review): the cluster tables shown in this notebook use a
        lowercase 'cluster' column, so the fallback may not match -- confirm.
    pallette : palette callable, invoked with the number of colours needed.
    plot_missing : when False (the default) rows whose ``factor_col`` value
        is null are filtered out of the view; when True every row is drawn.
    extra_figure_args : optional dict of extra keyword arguments for ``figure``.
    """

    # Tooltips supplied by the caller; None means "derive them from tip_cols".
    _user_tooltips = None

    def __init__(self, cluster_data,
                 xy_cols = ('X', 'Y'),
                 factor_col = 'label',
                 tooltips = None,
                 tip_cols = None,
                 pallette = viridis,
                 plot_missing = False,
                 extra_figure_args = None):

        self.cluster_data = cluster_data
        self.x, self.y = xy_cols
        self.factor_col = factor_col

        # Kept separately so _setup can tell caller-supplied tooltips apart
        # from ones it generated on an earlier call.
        self._user_tooltips = tooltips

        if tip_cols is None:
            self.tip_cols = ['Cluster']
        else:
            self.tip_cols = tip_cols

        self.pallette = pallette
        self.plot_missing = plot_missing
        self.extra_figure_args = {} if extra_figure_args is None else extra_figure_args

    def generate_bokeh(self):
        """Draw the scatter plot and store it on ``self.fig``."""

        if self.plot_missing:
            # Draw everything, including rows without a factor value.
            view = CDSView(source=self.bokeh_source)
        else:
            # The mask comes from the same table that backs the data source,
            # so its length always matches the source.
            mask = self.cluster_data[self.factor_col].notnull()
            view = CDSView(source=self.bokeh_source, filters=[BooleanFilter(mask.tolist())])

        fig = figure(tooltips = self.tooltips, **self.extra_figure_args)

        color = 'black' if self.factor_cmap is None else self.factor_cmap

        fig.scatter(x = self.x, y = self.y,
                    source = self.bokeh_source, view = view,
                    #legend_group = self.factor_col,
                    legend_field = self.factor_col,
                    color = color)

        # Clicking a legend entry toggles the matching points.
        fig.legend.click_policy="hide"

        self.fig = fig

    def _setup(self):
        """Create the data source, the colour map and the tooltips."""

        self.bokeh_source = ColumnDataSource(self.cluster_data)
        self.factors, self.factor_cmap = _build_factor_cmap(self.cluster_data,
                                                            self.factor_col,
                                                            self.pallette)
        if self._user_tooltips is not None:
            self.tooltips = self._user_tooltips
        else:
            self.tooltips = _build_tips(self.tip_cols)
        

In [24]:
# Colour the points by tissue of isolation and show the listed columns on hover.
cl_exp = ClusteringBokehExplanation(full_df,
                                    factor_col= 'sample_tissue',
                                    tip_cols= ['accession', 'cluster', 
                                               'coreceptor', 'sample_tissue'])

# plot() runs setup and drawing, then returns the object to hand to show().
fig = cl_exp.plot()
show(fig)

Cool. There are definitely clusters. Unfortunately they look pretty "intermixed" between our phenotypes. 
This is often quantified with a `silhouette` score. For each sample, the score compares the mean distance to the other members of its own group (a) with the mean distance to the members of the nearest other group (b): 
`(b - a) / max(a, b)`

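A score near +1 means the sample sits well inside its own group, a score near 0 means the groups overlap, and a negative score means the sample is closer to a different group than to its own. https://scikit-learn.org/stable/modules/generated/sklearn.metrics.silhouette_samples.html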

This `SilhoutteBokehExplanation` automates the process of calculating this across multiple columns and generating a useful figure.

In [25]:
from sklearn.metrics import silhouette_samples

from bokeh.models import FactorRange

class SilhoutteBokehExplanation(BokehFigureExplanation):
    """Bar chart of the mean silhouette score for each label column.

    Parameters
    ----------
    cluster_data : table holding both the feature and the label columns.
    feature_cols : columns passed to ``silhouette_samples`` as the features.
    label_cols : columns to score; each one is used in turn as the labels.
    """

    def __init__(self, cluster_data, feature_cols, label_cols):
        self.cluster_data = cluster_data
        self.feature_cols = feature_cols
        self.label_cols = label_cols

    def _setup(self):
        """Compute the scores and build the x-range and the data source."""
        self._build_tips()
        self.calc_silhouttes()

        means = self.silhoutte_means
        # One categorical slot on the x axis per scored column.
        self.x_range = FactorRange(*means.index)
        bar_data = {
            'feature_col': list(means.index),
            'silhoutte_mean': means.tolist(),
            'color': viridis(len(means)),
        }
        self.bokeh_source = ColumnDataSource(data=bar_data)

    def calc_silhouttes(self):
        """Score every label column and keep per-sample scores and their means.

        Rows whose label is null are left out of that column's calculation.
        Stores ``silhoutte_scores`` (re-indexed to the rows of
        ``cluster_data``) and ``silhoutte_means`` on the instance, and
        returns the per-sample scores.
        """
        frame = self.cluster_data
        per_column = {}
        for label_col in self.label_cols:
            present = frame[label_col].notnull()
            scores = silhouette_samples(frame.loc[present, self.feature_cols],
                                        frame.loc[present, label_col])
            per_column[label_col] = pd.Series(scores, index = frame.index[present])

        self.silhoutte_scores = pd.DataFrame(per_column).reindex(frame.index, axis=0)

        column_means = self.silhoutte_scores.mean()
        column_means.name = 'silhoutte_mean'
        column_means.index.name = 'feature_col'
        self.silhoutte_means = column_means

        return self.silhoutte_scores

    def generate_bokeh(self):
        """Draw the bar chart and store it on ``self.fig``."""
        chart = figure(x_range = self.x_range,
                       plot_height=250, tooltips = self.tooltips)
        chart.vbar(x = 'feature_col', top = 'silhoutte_mean',
                   source = self.bokeh_source, color = 'color',
                   width=0.3)

        # Fix the y axis to the -1..1 interval.
        chart.y_range.start = -1
        chart.y_range.end = 1
        chart.xgrid.grid_line_color = None
        chart.yaxis.axis_label = 'Silhoutte Score'

        self.fig = chart

    def _build_tips(self):
        """Set the hover tooltips: the scored column and its mean score."""
        self.tooltips = [('Feature', '@feature_col'),
                         ('Score', '@silhoutte_mean')]
    
    

In [26]:
# Add sequence length as one more candidate label.
full_df['seq_len'] = full_df['sequence'].map(len)
# Feature columns are picked by name: every column starting with 'd'.
# NOTE(review): presumably these are only the d0..d9 embedding columns -- any other
# column whose name starts with 'd' would be picked up too.
feats = [col for col in full_df.columns if col.startswith('d')]
sil_exp = SilhoutteBokehExplanation(full_df, feats, 
                                    ['coreceptor', 'sample_tissue', 'seq_len'])

fig = sil_exp.plot()

show(fig)




From these scores we can see that none of these are particularly good explanations for the clusters we see.
In future tutorials we'll train this model to better represent these different features. Right now all we can say is that the model IS grouping sequences by something ... but these are not the features we're looking for.

For now, we'll just package all of this logic into an easy-to-use pipeline that generates these explanations automagically.

## Pipeline


First we'll need to create an AutoPipeline to contain common features of all pipelines. Right now, it'll be pretty shallow.

In [15]:
# export

class AutoPipeline:
    """Interface shared by all pipelines; every method here is a placeholder."""

    def save(self):
        """Placeholder: does nothing and returns None."""
        return None

    def load(self):
        """Placeholder: does nothing and returns None."""
        return None

    def fit(self, data):
        """Meant to fit the model to ``data``; the base version does nothing."""
        return None

    def fit_transform(self, data):
        """Meant to fit and transform in one call; the base version does nothing."""
        return None

    def transform(self, data):
        """Meant to return an AutoPipelineResult; the base version returns None."""
        return None
    
    
    
class AutoResult:
    """Interface shared by all pipeline results; every method is a placeholder."""

    def explain(self, explanations = 'all'):
        """Placeholder: ignores ``explanations`` and returns None."""
        return None

    def save(self):
        """Placeholder: does nothing and returns None."""
        return None

    def load(self):
        """Placeholder: does nothing and returns None."""
        return None
    
    

In [16]:
#export

class GLPClusteringPipeline(AutoPipeline):
    """Pipeline that runs a table of sequences through a TopicModelingInterface.

    Parameters
    ----------
    tmi : an existing TopicModelingInterface. When given, the remaining
        model arguments are not used.
    tokenizer, model, model_name, bs, cluster_dim, viz_dim, device,
    min_cluster_size : forwarded to ``TopicModelingInterface`` when ``tmi``
        is None.
    defaults : optional dict of fallback values; the keys 'sequence_col'
        and 'feature_cols' are consulted by ``fit_transform``.
    """

    full_embedding = None
    cluster_embedding = None
    vis_embessing = None

    def __init__(self, tmi = None,
                 tokenizer = None, model = None,
                 model_name = None, bs=8,
                 cluster_dim = 10, viz_dim = 2,
                 device = 'cuda',
                 min_cluster_size = 5,
                 defaults = None):

        if tmi is None:
            tmi = TopicModelingInterface(tokenizer = tokenizer, model = model, model_name = model_name, bs=bs,
                                         cluster_dim = cluster_dim, viz_dim = viz_dim, device = device,
                                         min_cluster_size = min_cluster_size)
        self.tmi = tmi

        self.model = self.tmi.model

        self.defaults = {} if defaults is None else defaults

    def _or_default(self, value, key, fallback):
        """Return ``value`` unless it is None, else ``self.defaults[key]``, else ``fallback``."""
        if value is not None:
            return value
        try:
            return self.defaults[key]
        except KeyError:
            return fallback

    def fit(self, glp_data):
        """Fit on ``glp_data``; the result object is discarded."""
        self.fit_transform(glp_data, fit = True)

    def fit_transform(self, glp_data,
                      sequence_col = None,
                      feature_cols = None,
                      fit = True):
        """Process ``glp_data`` and wrap the output in a GLPClusteringResult.

        ``sequence_col`` and ``feature_cols`` fall back to the pipeline
        defaults, and then to 'sequence' and an empty list. ``fit`` is
        forwarded to ``process_df``.
        """
        sequence_col = self._or_default(sequence_col, 'sequence_col', 'sequence')
        feature_cols = self._or_default(feature_cols, 'feature_cols', [])

        # Only rows that actually have a sequence contribute feature data.
        present_seqs = glp_data[sequence_col].dropna()
        cluster_data, raw_embedding = self.tmi.process_df(glp_data, col = sequence_col, fit = fit)
        feature_data = glp_data.loc[present_seqs.index, feature_cols]

        return GLPClusteringResult(feature_data, raw_embedding, cluster_data)

    def transform(self, glp_data):
        """Same as ``fit_transform`` but with ``fit=False``."""
        return self.fit_transform(glp_data, fit=False)

    def save(self, path):
        """Placeholder: not implemented yet."""
        pass

    @staticmethod
    def load(path):
        """Placeholder: not implemented yet."""
        pass
             
    
class GLPClusteringResult(AutoResult):
    """Output of a GLPClusteringPipeline run.

    Holds the feature columns, the raw embedding and the cluster table, and
    can turn them into Bokeh explanations.
    """

    raw_embedding = None
    cluster_data = None
    feature_data = None
    explanations = ['cluster_figure', 'silhoutte_figure']

    def __init__(self, feature_data, raw_embedding, cluster_data):
        self.feature_data = feature_data
        self.raw_embedding = raw_embedding
        self.cluster_data = cluster_data

    def explain(self, factor_col = 'cluster', tip_cols = None):
        """Build the explanation objects for this result.

        Parameters
        ----------
        factor_col : column used to colour the cluster scatter plot.
        tip_cols : hover columns; defaults to every feature column.

        Returns a dict with the keys 'cluster_figure' and 'silhoutte_figure'.
        """
        feature_names = self.feature_data.columns
        if tip_cols is None:
            tip_cols = list(feature_names)

        # Columns of the cluster table whose name starts with 'd' are used as
        # the silhouette features.
        embedding_cols = [name for name in self.cluster_data.columns
                          if name.startswith('d')]

        merged = pd.concat([self.cluster_data, self.feature_data], axis=1)

        silhouette_view = SilhoutteBokehExplanation(merged, embedding_cols,
                                                    feature_names)
        cluster_view = ClusteringBokehExplanation(merged,
                                                  factor_col = factor_col,
                                                  tip_cols = tip_cols)

        return {'cluster_figure': cluster_view, 'silhoutte_figure': silhouette_view}
        
    

Let's reload the data just to bring everything back around.

In [17]:
# Reload the table, trim '*' from the sequences and record each sequence's length.
df = (
    pd.read_csv('../tutorials/HIV_tat_example.csv')
      .assign(sequence = lambda frame: frame['sequence'].str.strip('*'))
      .assign(seq_length = lambda frame: frame['sequence'].map(len))
)
df.head()

Unnamed: 0,accession,sample_tissue,coreceptor,sequence,seq_length
0,M17449,PBMC,CXCR4,MEPVDPRLEPWKHPGSQPKTACTTCYCKKCCFHCQVCFTKKALGISYGRKKRRQRRRAPEDSQTHQVSLPKQPAPQFRGDPTGPKESKKKVERETETHPVD,101
1,M26727,PBMC,CCR5,MEPVDPRLEPWKHPGSQPKTASNNCYCKRCCLHCQVCFTKKGLGISYGRKKRRQRRRAPQDSKTHQVSLSKQPASQPRGDPTGPKESKKKVERETETDPED,101
2,M17451,PBMC,CCR5|CXCR4,MEPVDPRLEPWKHPGSQPKTACNNCYCKKCCYHCQVCFLTKGLGISYGRKKRRQRRGPPQGSQTHQVSLSKQPTSQPRGDPTGPKESKEKVERETETDPAVQ,102
3,K02007,PBMC,CCR5|CXCR4,MEPVDPNLEPWKHPGSQPRTACNNCYCKKCCFHCYACFTRKGLGISYGRKKRRQRRRAPQDSQTHQASLSKQPASQSRGDPTGPTESKKKVERETETDPFD,101
4,M62320,blood,,MEPVDPNLEPWKHPGSQPTTACSNCYCKVCCWHCQLCFLKKGLGISYGKKKRKPRRGPPQGSKDHQTLIPKQPLPQSQRVSAGQEESKKKVESKAKTDRFA,101


As we create the pipeline we can define some defaults for future calls. This makes it easy to train the pipeline on a random subset.

In [18]:
# Columns handed to the pipeline as default feature columns for later calls.
ANALYSIS_DEFAULTS = {'feature_cols': ['coreceptor', 'sample_tissue', 'seq_length']}
clustering_pipeline = GLPClusteringPipeline(model_name = 'Rostlab/prot_bert',
                                            defaults = ANALYSIS_DEFAULTS)

# Fixed seed so the same 500-row training subset is drawn on every run.
train_data = df.sample(500, random_state = 42)

clustering_pipeline.fit(train_data)

And then running it on the whole dataset.

In [40]:
# Apply the pipeline to the whole table; transform() calls fit_transform with fit=False.
result = clustering_pipeline.transform(df)
result

<__main__.GLPClusteringResult at 0x7f1e224f7a30>

In [33]:
# Number of rows in the result's cluster table.
len(result.cluster_data)

2136

In [34]:
# Build both explanation objects, colouring the cluster plot by co-receptor.
explanation = result.explain(factor_col = 'coreceptor')
explanation

{'cluster_figure': <__main__.ClusteringBokehExplanation at 0x7f1e2e04b700>,
 'silhoutte_figure': <__main__.SilhoutteBokehExplanation at 0x7f1e2353fdc0>}

In [35]:
# NOTE(review): no `generate_figure` method or `skip_missing` option is defined for
# ClusteringBokehExplanation in this notebook (it defines `plot()` and `plot_missing`).
# This call presumably relies on something inherited from the imported base class -- confirm,
# or switch to `plot()`.
fig = explanation['cluster_figure'].generate_figure(skip_missing = True)
fig

In [36]:
# Display the figure built in the previous cell.
show(fig)