# Paper Visualizations

Here, we produce the visualizations for my M.S. thesis.

In [1]:
import plotly.express as px 
import plotly.graph_objects as go
import numpy as np
import scipy as sp
import pandas as pd 
import anndata as an
import matplotlib.pyplot as plt 

import sys 
sys.path.append('../src')

from model import *
from lightning_train import *
from data import *

# Label Distributions

Let's visualize the label distributions for all datasets we've been testing our model with.

In [76]:
# Label file for each dataset, keyed by the display name that is used in the
# figure titles and in the exported file names.
labelfiles = {
    'Human Cortical Subclass': '../data/benchmark/human_labels_clean.csv',
    'Mouse Cortical Subclass': '../data/benchmark/mouse_labels_clean.csv',
    'Human Dental': '../data/dental/labels_human_dental.tsv',
    'Retinal': '../data/retina/retina_labels_numeric.csv'
}

# Column to count in each label file, listed in the same order as `labelfiles`.
label_cols = [
    'subclass_label',
    'categorical_subclass_label',
    'cluster_celltype',
    'CellType',
]

# zip() silently stops at the shorter input, which would drop a dataset without
# any error if the two containers ever drift out of sync.
assert len(labelfiles) == len(label_cols), 'each label file needs exactly one label column'

for (name, file), col in zip(labelfiles.items(), label_cols):
    # sep=None asks pandas to sniff the delimiter (the inputs mix .csv and .tsv).
    # Only the python engine can sniff, so request it explicitly instead of
    # relying on the C engine's fallback, which emits a ParserWarning per file.
    labels = pd.read_csv(file, sep=None, engine='python')

    # Number of rows per distinct label, most frequent first.
    vals = labels.loc[:, col].value_counts()

    # Raw counts: exported and displayed inline.
    fig = go.Figure(
        data=[go.Bar(x=vals.index, y=vals.values)],
        layout=go.Layout(
            title=f'Distribution of Labels in {name} Dataset',
        )
    )

    fig.write_image(f"../../ms-thesis/images/label_distributions/{name}_label_dist.pdf", scale=3)
    fig.show()

    # Natural log of the counts: exported only, not displayed inline.
    fig = go.Figure(
        data=[go.Bar(x=vals.index, y=np.log(vals.values))],
        layout=go.Layout(
            title=f'Log-Distribution of Labels in {name} Dataset',
        )
    )

    fig.write_image(f"../../ms-thesis/images/label_distributions/{name}_log_label_dist.pdf", scale=3)

















# Feature Masks for our Pretrained Models

Let's visualize the feature masks for the pretrained models so far, both by cell type and by overall features. This will help us understand which features the model is using for prediction!

## Feature Mask for Retina Model

In [66]:
# Values of the `gene` column from the retina AnnData variable annotations;
# they label the input features in the retina plot below.
cols = (
    an.read_h5ad('../data/retina/retina_T.h5ad')
    .var['gene']
    .values
)
cols

Variable names are not unique. To make them unique, call `.var_names_make_unique`.


array(['ENSG00000000003|TSPAN6', 'ENSG00000000419|DPM1',
       'ENSG00000000457|SCYL3', ..., 'ENSG00000284744|AL591163.1',
       'ENSG00000284747|AL034417.4', 'ENSG00000284748|AL513220.1'],
      dtype=object)

In [56]:
# Settings for the retina data pipeline, gathered in one place.
retina_data_config = dict(
    datafiles=['../data/retina/retina_T.h5ad'],
    labelfiles=['../data/retina/retina_labels_numeric.csv'],
    class_label='class_label',
    index_col='cell',
    batch_size=4,
    num_workers=0,
    shuffle=True,
    drop_last=True,
    normalize=True,
    deterministic=True,
)

module = DataModule(**retina_data_config)
module.setup()

# Hyperparameters handed to the checkpoint loader.
# NOTE(review): presumably these must match the values used at training time -- confirm.
retina_model_config = dict(
    input_dim=37475,
    output_dim=13,
    n_d=32,
    n_a=32,
    n_steps=10,
)

model = TabNetLightning.load_from_checkpoint(
    '../checkpoints/checkpoint-80-desc-retina.ckpt',
    **retina_model_config,
)

# Run the model's explain step over the test loader.
loader = module.testloader
mask = model.explain(loader)

Creating train/val/test DataLoaders...


Variable names are not unique. To make them unique, call `.var_names_make_unique`.


Done, continuing to training.
Calculating weights
Initializing network
Initializing explain matrix


Variable names are not unique. To make them unique, call `.var_names_make_unique`.


In [74]:
# Collapse the first explain() output along axis 0 to get one total per column,
# then keep only the columns whose total exceeds 1.
sums = mask[0].sum(axis=0)
idx = np.where(sums > 1)
relevant = sums[idx]

# One row per retained feature, labelled with its gene name, largest total first.
features = pd.DataFrame(
    {'weights': relevant},
    index=cols[idx],
).sort_values(by='weights', ascending=False)

retina_bar = go.Bar(x=features.index, y=features['weights'])

# update_layout returns the figure itself, so the calls can be chained.
fig = go.Figure(data=retina_bar).update_layout(
    title='Retina Model',
    xaxis_title='Input Feature',
    yaxis_title='Sum of Feature Masks',
)

fig.show()

In [75]:
# Export the retina feature-weight bar chart to the thesis image folder.
fig.write_image(
    '../../ms-thesis/images/retina_feature_weights.pdf',
    scale=3,
)

In [80]:
# Sanity check: dimensions of the first element returned by model.explain().
# NOTE(review): the recorded output's second dimension equals input_dim, so the
# columns are presumably the input features -- confirm what the rows represent.
mask[0].shape

(2624, 37475)

## Feature Mask for Dental Model

In [10]:
# Variable names of the dental AnnData; they label the features in the plot below.
dental_cols = an.read_h5ad('../data/dental/human_dental_T.h5ad').var.index

# Settings for the dental data pipeline (the label file is tab-separated).
dental_data_config = dict(
    datafiles=['../data/dental/human_dental_T.h5ad'],
    labelfiles=['../data/dental/labels_human_dental.tsv'],
    class_label='cell_type',
    sep='\t',
    batch_size=16,
    num_workers=0,
)

module = DataModule(**dental_data_config)
module.setup()

# Input and output sizes are read from the data module rather than hard-coded.
model = TabNetLightning.load_from_checkpoint(
    '../checkpoints/checkpoint-80-desc-dental.ckpt',
    input_dim=module.num_features,
    output_dim=module.num_labels,
)

loader = module.testloader

# Run the explain step over the test loader, then total its first output
# along axis 0.
dental_mask = model.explain(loader)
dental_sum = dental_mask[0].sum(axis=0)

Creating train/val/test DataLoaders...
Done, continuing to training.
Calculating weights
Initializing network
Initializing explain matrix


100%|████████████████████████████████████| 417/417 [06:31<00:00,  1.07it/s]


In [72]:
# Keep only the columns whose summed mask exceeds 20.
idx = np.where(dental_sum > 20)
relevant = dental_sum[idx]

# One row per retained feature, labelled with its variable name, largest total first.
features = pd.DataFrame(
    {'weights': relevant},
    index=dental_cols[idx],
).sort_values(by='weights', ascending=False)

dental_bar = go.Bar(x=features.index, y=features['weights'])

# update_layout returns the figure itself, so the calls can be chained.
fig2 = go.Figure(data=dental_bar).update_layout(
    xaxis_title='Input Feature',
    yaxis_title='Sum of Feature Masks',
)

fig2.show()

In [73]:
# Export the dental feature-weight bar chart to the thesis image folder.
fig2.write_image(
    '../../ms-thesis/images/dental_feature_weights.pdf',
    scale=3,
)

In [None]:
# np.where(dental_mask > 0.01).shape

## Mouse Inhibitory Model 

In [None]:
# Variable names of the mouse AnnData, used to label the plotted features.
# The original cell indexed `cols`, which holds the *retina* gene names, so the
# bars would have been labelled with genes from the wrong dataset.
mouse_cols = an.read_h5ad('../data/benchmark/mouse_clipped.h5ad').var.index

# NOTE(review): sep='\t' is applied to a .csv label file and looks copied from
# the dental cell -- confirm the delimiter of mouse_labels_clean.csv.
module = DataModule(
    datafiles=['../data/benchmark/mouse_clipped.h5ad'],
    labelfiles=['../data/benchmark/mouse_labels_clean.csv'],
    class_label='subclass_label',
    sep='\t',
    batch_size=16,
    num_workers=0,
)

module.setup()

# Input and output sizes are read from the data module rather than hard-coded.
model = TabNetLightning.load_from_checkpoint(
    '../checkpoints/checkpoint-280-desc-mouse.ckpt',
    input_dim=module.num_features,
    output_dim=module.num_labels,
)

loader = module.testloader

# Stored under its own name so the dental results computed earlier are not
# overwritten by the mouse results.
mouse_mask = model.explain(loader)

# Total the first explain() output along axis 0 to get one value per column,
# then keep only the columns whose total exceeds 0.01.
sums = mouse_mask[0].sum(axis=0)
idx = np.where(sums > 0.01)
relevant = sums[idx]

features = pd.DataFrame(index=mouse_cols[idx])
features['weights'] = relevant

features = features.sort_values(by='weights', ascending=False)

# `yaxis` must be a dict (or YAxis object); the original passed a bare string,
# which go.Layout rejects. The title now names the mouse model, not the dental one.
fig2 = go.Figure(
    data=go.Bar(x=features.index, y=features['weights']),
    layout=go.Layout(
        title='Global Feature Weights for Mouse Model',
        xaxis=dict(title='Input feature'),
        yaxis=dict(title='Aggregate mask weight'),
    )
)