# PerturbAtlas Harmonizome Processing Script
This notebook contains the scripts used to process the PerturbAtlas Perturbation Gene Expression Profiles dataset.
The list of perturbations, conditions, and metadata was downloaded on 10/08/2024 through the PerturbAtlas API. This information was then used to retrieve the DEG download for each perturbation.

In [None]:
import pandas as pd
import datetime
import numpy as np
import scipy.spatial.distance as dist
import seaborn as sns
import os
import sys
from io import BytesIO
import json
import requests
import scanpy as sc
from tqdm import tqdm
from collections import OrderedDict

from sklearn.feature_extraction.text import TfidfVectorizer
import anndata
from collections import OrderedDict

# Bokeh
from bokeh.io import output_notebook
from bokeh.plotting import figure, show, save, output_file
from bokeh.models import HoverTool, ColumnDataSource
from bokeh.palettes import Category20
output_notebook()

from IPython.display import display, HTML, Markdown
sys.setrecursionlimit(100000)

## Load and Pre-process Data
Perturbation metadata and associations can be retrieved from the [PerturbAtlas API](https://perturbatlas.kratoss.site/#/help). We first gathered a list of human perturbations, and then queried the API for differentially expressed genes for each perturbation, keeping associations with adjusted P-value < 0.05.

In [None]:
perturbatlasurl = 'https://perturbatlas.kratoss.site/api/download'
perturbations = []

# Page through the human perturbation listing, 200 records per request.
# NOTE(review): the page range is hard-coded; presumably it matched the number
# of pages available when the listing was retrieved — confirm before re-running.
for offset in tqdm(range(1, 44)):
    response = requests.get(
        f'{perturbatlasurl}?offset={offset}&length=200&eq=scientific_name:Homo%20sapiens',
        timeout=120,
    )
    # Stop on an HTTP error here instead of failing later on the JSON body.
    response.raise_for_status()
    perts = response.json()['data']
    # Extend in place rather than rebuilding the whole list for every page.
    perturbations.extend(perts)

In [None]:
# Number of perturbation records collected from the API (before any filtering).
len(perturbations)

In [None]:
# Snapshot the raw listing to disk so it can be reloaded without querying the API again.
all_perturbations = pd.DataFrame(perturbations)
all_perturbations.to_csv('perturbatlasperturbations.tsv', sep='\t')

# Keep only rows whose 'degs' flag is set, one row per perturb_id.
has_degs = all_perturbations['degs']
pertframe = all_perturbations.loc[has_degs].drop_duplicates(subset='perturb_id')
pertframe

In [None]:
def _strip_target_annotations(raw_target):
    """Cut a target string at the first '(' or ')' and remove '+', '-', '*' and '@'."""
    cleaned = raw_target.split('(')[0].split(')')[0]
    for marker in ('+', '-', '*', '@'):
        cleaned = cleaned.replace(marker, '')
    return cleaned


# Reload the saved listing and apply the same DEG / perturb_id filter as when it was written.
pertframe = pd.read_csv('perturbatlasperturbations.tsv', sep='\t', index_col=0)
pertframe = pertframe.loc[pertframe['degs']].drop_duplicates(subset='perturb_id')
pertframe['target'] = pertframe['target'].apply(_strip_target_annotations)
pertframe

In [None]:
download_url='https://perturbatlas.kratoss.site/api/download?kind=degs'
pertids = pertframe['perturb_id'].unique().tolist()
pertids.reverse()

# to_csv does not create missing directories; without this every write below
# would raise and be reported as a processing error.
os.makedirs('degdir', exist_ok=True)

for pert_id in tqdm(pertids):
    deg_path = f'degdir/{pert_id}_degs.csv'
    # Already downloaded on a previous run.
    # Perturbations with no significant DEGs never get a file, so they are re-requested each run.
    if os.path.exists(deg_path):
        continue
    response = requests.get(f'{download_url}&id={pert_id}')
    if response.status_code != 200:
        print(f'Failed to get DEGs for {pert_id}, HTTP Status: {response.status_code}')
        continue
    try:
        # The endpoint returns an Excel workbook; keep only significant genes.
        degs = pd.read_excel(BytesIO(response.content))
        degs = degs[degs['padj']<0.05].reset_index(drop=True)
        if len(degs)>0:
            degs.to_csv(deg_path)
    except Exception as e:
        # Best effort: report the cause and carry on with the next perturbation.
        print(f'Error processing {pert_id}: {e!r}')

In [None]:
# Single perturbation used to exercise the DEG download by hand:
# https://perturbatlas.kratoss.site/api/download?kind=degs&id=Perturb_7768
pert_id='Perturb_7768'

In [None]:
# Manual download of the DEG table for the single perturbation in `pert_id`.
response = requests.get(f'https://perturbatlas.kratoss.site/api/download?kind=degs&id={pert_id}')
if response.status_code == 200:
    try:
        degs = pd.read_excel(BytesIO(response.content))
        degs = degs[degs['padj']<0.05].reset_index(drop=True)
        if len(degs)>0:
            # Must be an f-string so pert_id is interpolated into the file name.
            degs.to_csv(f'degdir/{pert_id}_degs.csv')
    except Exception as e:
        print(f'Error processing {pert_id}: {e!r}')
else:
    print(f'Failed to get DEGs for {pert_id}, HTTP Status: {response.status_code}')

In [None]:
# Re-parse the response already held in memory and save its significant DEGs.
# file_name is defined here so the cell does not depend on leftover kernel state.
file_name = f'degdir/{pert_id}_degs.csv'
degs = pd.read_excel(BytesIO(response.content))
degs = degs[degs['padj']<0.05].reset_index(drop=True)
if len(degs)>0:
    degs.to_csv(file_name)

In [None]:
deg_columns = ['perturb_id', 'gene', 'baseMean', 'log2FoldChange', 'lfcSE', 'stat', 'pvalue', 'padj']

# Only the per-perturbation files: degdir also receives All_Perturbations_degs.csv.gz,
# which must not be folded back in when this cell is re-run.
# Sorted so the row order does not depend on the filesystem listing order.
deg_files = sorted(name for name in os.listdir('degdir') if name.endswith('_degs.csv'))

# Read everything first and concatenate once, rather than growing a frame inside the loop.
pert_frames = [pd.read_csv(f'degdir/{pert}', index_col=0) for pert in tqdm(deg_files)]

if pert_frames:
    perturbatlas = pd.concat(pert_frames)
    # Expected columns first (created as NaN if absent), then any extras found in the files.
    extra_columns = [col for col in perturbatlas.columns if col not in deg_columns]
    perturbatlas = perturbatlas.reindex(columns=deg_columns + extra_columns)
else:
    perturbatlas = pd.DataFrame(columns=deg_columns)

perturbatlas

In [None]:
# Persist the combined DEG table as a gzip-compressed CSV.
# NOTE(review): this file lands in degdir, the same directory listed when the per-perturbation files are combined.
perturbatlas.to_csv('degdir/All_Perturbations_degs.csv.gz', compression='gzip')

In [None]:
# Reload the combined DEG table from disk (first CSV column is the index).
perturbatlas = pd.read_csv('degdir/All_Perturbations_degs.csv.gz', compression='gzip', index_col=0)
perturbatlas

In [None]:
# Mean number of non-null gene entries per perturbation.
perturbatlas.groupby('perturb_id')['gene'].count().mean()

In [None]:
# Mean DEGs per perturbation times the number of perturbations, with the mean
# taken from the data rather than pasted in as a rounded constant.
perturbatlas.groupby('perturb_id')['gene'].count().mean()*len(perturbatlas['perturb_id'].unique())

In [None]:
# Number of distinct gene identifiers across all perturbations.
len(perturbatlas['gene'].unique())

### Clean and Harmonize Associations

In [None]:
# Synonym -> official symbol lookup; both sides upper-cased so matching ignores case.
mappingfile = pd.read_csv('../../mapping/mappingFiles/mappingFile_2024.tsv', sep='\t', index_col=0)
mappingfile = mappingfile.assign(
    Synonyms=mappingfile['Synonyms'].astype(str).map(str.upper),
    Symbol=mappingfile['Symbol'].map(str.upper),
)

# Restrict to human (tax id 9606); for a repeated synonym the last row wins.
human_rows = mappingfile.loc[mappingfile['#tax_id'] == 9606]
mapping = dict(zip(human_rows['Synonyms'], human_rows['Symbol']))
mappingfile

In [None]:
gene_info = pd.read_csv('../../mapping/source_files/human_gene_info', sep='\t')
# One combined mask (instead of chained boolean indexing) keeps the mask aligned with the frame;
# .copy() makes the Symbol assignment below write to an independent frame.
is_human_protein_coding = (gene_info['#tax_id']==9606) & (gene_info['type_of_gene']=='protein-coding')
gene_info = gene_info[is_human_protein_coding].copy()
gene_info['Symbol'] = gene_info['Symbol'].apply(str.upper)
gene_info

In [None]:
# Expand the '|'-separated dbXrefs field to one cross-reference per row.
xref_lists = [xrefs.split('|') for xrefs in gene_info['dbXrefs']]
ensembl = gene_info.assign(dbXrefs=xref_lists).explode('dbXrefs').reset_index()

# Keep the Ensembl cross-references, first occurrence of each.
is_ensembl_xref = ensembl['dbXrefs'].map(lambda xref: 'Ensembl' in xref)
ensembl = ensembl.loc[is_ensembl_xref].drop_duplicates(subset='dbXrefs', keep='first')
ensembl['dbXrefs'] = [xref.replace('Ensembl:', '') for xref in ensembl['dbXrefs']]

# Ensembl gene ID -> upper-cased symbol.
ensembldict = {
    ensembl_id: symbol.upper()
    for ensembl_id, symbol in zip(ensembl['dbXrefs'], ensembl['Symbol'])
}
ensembl

In [None]:
# Translate the gene column through ensembldict; identifiers without an entry become NaN.
# dropna() then discards every row that has a missing value in any column.
perturbatlas = perturbatlas.assign(gene=perturbatlas['gene'].map(ensembldict)).dropna()
perturbatlas

In [None]:
# Normalise symbols through the synonym -> official symbol mapping.
perturbatlas['gene'] = perturbatlas['gene'].map(mapping)
# Symbols absent from the mapping come back as NaN; drop them so later lookups
# keyed on the gene symbol (e.g. geneiddict[gene]) do not fail on a missing value.
perturbatlas = perturbatlas.dropna(subset=['gene']).drop_duplicates(subset=['gene', 'perturb_id'])
perturbatlas

In [None]:
# (mean non-null genes per perturbation, number of distinct gene values) after symbol mapping.
perturbatlas.groupby('perturb_id')['gene'].count().mean(), len(perturbatlas['gene'].unique())

In [None]:
def threshold(val):
    """Return the sign of ``val`` as 1, -1 or 0.

    Values that are neither greater nor less than zero (zero itself, NaN) give 0.
    """
    return 1 if val > 0 else (-1 if val < 0 else 0)

In [None]:
# Signed significance score: sign of the fold change times -log10(adjusted p-value).
direction = perturbatlas['log2FoldChange'].apply(threshold)
log10_padj = perturbatlas['padj'].apply(np.log10)
perturbatlas['score'] = direction * log10_padj * -1

# padj == 0 gives +/-inf, which is pinned to +/-20.
# NOTE(review): a finite score can exceed 20 (padj < 1e-20) and would then rank above these rows.
perturbatlas['score'] = perturbatlas['score'].replace({np.inf: 20, -np.inf: -20})

# Within each perturbation: highest score first, ties broken by larger baseMean.
perturbatlas = perturbatlas.sort_values(
    by=['perturb_id', 'score', 'baseMean'],
    ascending=[True, False, False],
)
perturbatlas['threshold'] = perturbatlas['score'].apply(threshold)
perturbatlas

In [None]:
# Rows are ordered by descending score within each perturbation, so the first 100
# positive rows and the last 100 negative rows are the strongest in each direction.
up_rows = perturbatlas.loc[perturbatlas['threshold'] == 1]
down_rows = perturbatlas.loc[perturbatlas['threshold'] == -1]
top = up_rows.groupby('perturb_id').head(100)
bottom = down_rows.groupby('perturb_id').tail(100)
perturbatlas = pd.concat([top, bottom], ignore_index=True)
perturbatlas

In [None]:
# (distinct perturbations, distinct genes, total associations) after the top/bottom-100 cut.
len(perturbatlas['perturb_id'].unique()), len(perturbatlas['gene'].unique()), len(perturbatlas)

## Prepare Dataset for Harmonizome DB

### Resource

In [None]:
# Resource record for this dataset, written out by hand; the fields follow the column list below.
#(id, name, long_description, short_description, url, num_attributes, num_datasets)
(115, 'PerturbAtlas', 'a comprehensive atlas of public genetic perturbation bulk RNA-seq datasets', 'a genetic perturbation atlas of bulk RNA-seq datasets', 'https://perturbatlas.kratoss.site/#/', 7418, 1)

### Dataset

In [None]:
# Dataset record, written out by hand; the fields follow the column list below.
#(id, name, name_without_resource, description, association, gene_set_description, gene_sets_description, attribute_set_description, positive_association, negative_association, is_signed, is_continuous_valued, last_updated, directory, num_page_views, resource_fk, measurement_fk, dataset_group_fk, attribute_type_fk, attribute_group_fk, evidence_type, evidence_group, measurement_bias, attribute_type_plural)
(159, 'PerturbAtlas Signatures of Differentially Expressed Genes for Gene Perturbations', 'Signatures of Differentially Expressed Genes for Gene Perturbations', 'Gene expression profiles for cell lines, cell types, tissues, and models following genetic perturbation (knockdown, knockout, knockin, over-expression, mutation, and multi-condition)', 'gene-gene perturbation associations by differential expression of gene A following perturbation of gene B', 'genes differentially expressed following the {0} gene perturbation from the PerturbAtlas Signatures of Differentially Expressed Genes for Gene Perturbations dataset.', 'sets of genes differentially expressed following gene perturbation from the PerturbAtlas Signatures of Differentially Expressed Genes for Gene Perturbations dataset.', 'gene perturbations changing expression of {0} gene from the PerturbAtlas Signatures of Differentially Expressed Genes for Gene Perturbations dataset.', 'increased expression', 'decreased expression', 1, 1, '2024-10-22', 'perturbatlas', 0, 115, 16, 7, 27, 5, 'gene expression by RNA-seq', 'curated experimental data', 'high throughput, data driven', 'gene perturbations')

### Publication

In [None]:
# Publication record followed by the row linking it to the dataset.
# Only the last tuple in the cell is displayed as output.
#(id, long_citation, short_citation, url, pmid, pubmed_url, first_author_last_name, first_author_initials, journal_abbreviation, year, title, volume, pages)
(157, 'Zhang et al. (2024) PerturbAtlas: a comprehensive atlas of public genetic perturbation bulk RNA-seq datasets. Nucleic Acids Res. gkae851', 'Zhang, Nucleic Acids Res, 2024', 'dx.doi.org/10.1093/nar/gkae851', 39351872, 'https://ncbi.nlm.nih.gov/pubmed/39351872', 'Zhang', 'Y', 'Nucleic Acids Res', 2024, 'PerturbAtlas: a comprehensive atlas of public genetic perturbation bulk RNA-seq datasets', 'gkae851', 'NaN')

#(id, dataset_fk, publication_fk)
(242, 159, 157)

### Genes

In [None]:
# Existing gene table, indexed by Entrez gene ID (first row kept for a repeated ID).
genes = (
    pd.read_csv('../../tables/gene.csv')
    .drop_duplicates(subset='ncbi_entrez_gene_id', keep='first')
    .set_index('ncbi_entrez_gene_id')
)
genes

In [None]:
geneids = pd.read_csv('../../mapping/mappingFiles/GeneSymbolsAndIDS_2024.tsv', sep='\t', index_col=0)
# Human rows only (tax id 9606); .copy() so the Symbol assignment below writes to
# an independent frame rather than a slice of the full table.
geneids = geneids[geneids['#tax_id']==9606].copy()
geneids['Symbol'] = geneids['Symbol'].apply(str.upper)
# Upper-cased symbol -> GeneID.
geneiddict = geneids.set_index('Symbol')['GeneID'].to_dict()
geneids

In [None]:
# First key to assign to a gene that is not yet in the gene table.
# NOTE(review): hand-set value; presumably one past the current maximum id in gene.csv — confirm.
index = 58416
# Entrez gene ID -> existing gene primary key ('id' column of the gene table).
genefks = genes['id'].to_dict()
geneurl = 'https://ncbi.nlm.nih.gov/gene'
# Symbol -> description, from the gene_info table.
genedescs = gene_info.set_index('Symbol')['description'].to_dict()

# For every dataset gene missing from the gene table, print a row tuple
# (id, symbol, Entrez ID, description, URL) and record the key assigned to it
# in genefks so later cells can resolve the gene's foreign key.
for gene in perturbatlas['gene'].unique():
    geneid = geneiddict[gene]
    if geneid not in genes.index:
        print((index, gene, geneid, genedescs[gene], f'{geneurl}/{geneid}'), sep=',', end=',\n')
        genefks[geneid] = index
        index += 1

### Naming Authority

In [None]:
# Naming authority record.
# NOTE(review): this cell has no column list; the last field (157) matches the
# publication id used in the Publication cell, so it is presumably a publication key — confirm.
(109, 'PerturbAtlas', 'a comprehensive atlas of public genetic perturbation bulk RNA-seq datasets', 'https://perturbatlas.kratoss.site/?#/', 157)

### Attributes

In [None]:
# Human-readable name and description for each perturbation, built from its metadata columns.
pertframe['name'] = pertframe.apply(lambda x: f'{x.perturb_id}_{x.condition}_{x.target}_{x.study_accession}', axis=1)
pertframe['description'] = pertframe.apply(lambda x: f'{x.condition} perturbation targeting {x.target} in {x.tissue_name} {x.tissue_type} (Study accession: {x.study_accession})', axis=1)
# perturb_id -> name / description / study accession lookups.
pertnames = pertframe.set_index('perturb_id')['name'].to_dict()
pertdescs = pertframe.set_index('perturb_id')['description'].to_dict()
pertaccessions = pertframe.set_index('perturb_id')['study_accession'].to_dict()
# Attach the display name to every association row; later cells use 'pert' as the attribute label.
perturbatlas['pert'] = perturbatlas['perturb_id'].map(pertnames)
# perturb_id -> attribute primary key assigned below.
attributefks = {}
# First attribute key to assign (hand-set).
index = 423271

# Print one attribute row tuple per perturbation; the trailing 109 is the naming authority id.
for pert in perturbatlas['perturb_id'].unique():
    print((index, pertnames[pert], pert, pertdescs[pert], 109), sep=',', end=',\n')
    attributefks[pert] = index
    index += 1

### Gene Sets

In [None]:
# NOTE(review): the column list below names six fields but the printed tuple has seven
# (the perturb_id is printed between the name and the description) — confirm which is intended.
#(id, name_from_dataset, description_from_dataset, dataset_fk, attribute_type_fk, attribute_fk)
# First gene set key to assign (hand-set).
index = 136000000
# perturb_id -> gene set primary key assigned below.
genesetfks = {}
# Alias of pertframe; not used in this cell.
accessions = pertframe

# Print one gene set row tuple per perturbation; 159 and 27 are the dataset and attribute type keys.
for pert in perturbatlas['perturb_id'].unique():
    print((index, pertnames[pert], pert, pertdescs[pert], 159, 27, attributefks[pert]), sep=',', end=',\n')
    genesetfks[pert] = index
    index += 1

### Associations

In [None]:
# Offset added to the row index so association keys start at this value.
index = 46000000

# Resolve each row's gene and perturbation to their keys and keep the two value columns.
associations = pd.DataFrame({
    'gene_fk': perturbatlas['gene'].map(lambda symbol: genefks[geneiddict[symbol]]),
    'gene_set_fk': perturbatlas['perturb_id'].map(lambda pert: genesetfks[pert]),
    'standardized_value': perturbatlas['score'],
    'threshold_value': perturbatlas['threshold'],
})
associations.index += index
associations.to_csv('../../harmonizome-update/perturbatlas.csv')
associations

## Create Downloads

In [None]:
output_path = 'downloads/'

# Neither to_csv nor open() creates directories, so build the whole download tree
# up front, including the knowledge-graph serialization folders written later.
os.makedirs(f'{output_path}kg_serializations/perturbatlas_tsv', exist_ok=True)

### Gene-Attribute Ternary Matrix

In [None]:
# Gene x perturbation matrix of thresholds (1 up, -1 down, 0 no association).
# aggfunc is given as the string 'max': passing the Python builtin is deprecated in pandas.
ternarymatrix = pd.crosstab(perturbatlas['gene'], perturbatlas['pert'], values=perturbatlas['threshold'], aggfunc='max').fillna(0).astype(int)
ternarymatrixT = ternarymatrix.T
ternarymatrix.to_csv(output_path+'gene_attribute_matrix.txt.gz', sep='\t', compression='gzip')
ternarymatrix

### Gene-Attribute Edge List

In [None]:
# Source column -> published column name, in output order.
edge_columns = {
    'gene': 'Gene',
    'gene ID': 'Gene ID',
    'pert': 'Perturbation',
    'perturb_id': 'Perturbation ID',
    'score': 'Standardized Value',
    'threshold': 'Threshold Value',
}

# Add the gene ID for each symbol, then select and rename the published columns.
edgelist = (
    perturbatlas
    .assign(**{'gene ID': perturbatlas['gene'].map(geneiddict)})
    [list(edge_columns)]
    .rename(columns=edge_columns)
)
edgelist.to_csv(f'{output_path}gene_attribute_edges.txt.gz', sep='\t', compression='gzip')
edgelist

### Gene List

In [None]:
# Distinct (symbol, gene ID) pairs appearing in the edge list, renumbered from 0.
gene_columns = ['Gene', 'Gene ID']
genelist = edgelist.loc[:, gene_columns].drop_duplicates(ignore_index=True)
genelist.to_csv(f'{output_path}gene_list_terms.txt.gz', sep='\t', compression='gzip')
genelist

### Attribute List

In [None]:
# Distinct (perturbation name, perturbation ID) pairs appearing in the edge list, renumbered from 0.
attribute_columns = ['Perturbation', 'Perturbation ID']
attributelist = edgelist.loc[:, attribute_columns].drop_duplicates(ignore_index=True)
attributelist.to_csv(f'{output_path}attribute_list_entries.txt.gz', sep='\t', compression='gzip')
attributelist

### Up Gene Set Library

In [None]:
# One GMT line per perturbation: name, empty description field, then its up-regulated genes.
# Perturbations with fewer than 5 such genes are skipped.
with open(output_path+'gene_set_library_up_crisp.gmt', 'w') as gmt_file:
    for attribute in tqdm(ternarymatrix.columns):
        is_up = ternarymatrix[attribute].to_numpy(dtype=np.int_) == 1
        up_genes = ternarymatrix.index[is_up]
        if len(up_genes) >= 5:
            print(attribute, '', *up_genes, sep='\t', end='\n', file=gmt_file)

### Down Gene Set Library

In [None]:
# One GMT line per perturbation: name, empty description field, then its down-regulated genes.
# Perturbations with fewer than 5 such genes are skipped.
with open(output_path+'gene_set_library_dn_crisp.gmt', 'w') as gmt_file:
    for attribute in tqdm(ternarymatrix.columns):
        is_down = ternarymatrix[attribute].to_numpy(dtype=np.int_) == -1
        down_genes = ternarymatrix.index[is_down]
        if len(down_genes) >= 5:
            print(attribute, '', *down_genes, sep='\t', end='\n', file=gmt_file)

### Up Attribute Set Library

In [None]:
# One GMT line per gene: symbol, empty description field, then the perturbations that up-regulate it.
# Genes with fewer than 5 such perturbations are skipped.
with open(output_path+'attribute_set_library_up_crisp.gmt', 'w') as f:
    arr = ternarymatrixT.to_numpy(dtype=np.int_)
    # Named gene_symbols so the `genes` table loaded from gene.csv is not overwritten.
    gene_symbols = ternarymatrixT.columns

    # Rows of arr are perturbations, columns are genes.
    for i in tqdm(range(arr.shape[1])):
        up_perturbations = ternarymatrixT.index[arr[:, i] == 1]
        if len(up_perturbations) >= 5:
            print(gene_symbols[i], '', *up_perturbations, sep='\t', end='\n', file=f)

### Down Attribute Set Library

In [None]:
# One GMT line per gene: symbol, empty description field, then the perturbations that down-regulate it.
# Genes with fewer than 5 such perturbations are skipped.
with open(output_path+'attribute_set_library_dn_crisp.gmt', 'w') as f:
    arr = ternarymatrixT.to_numpy(dtype=np.int_)
    # Named gene_symbols so the `genes` table loaded from gene.csv is not overwritten.
    gene_symbols = ternarymatrixT.columns

    # Rows of arr are perturbations, columns are genes.
    for i in tqdm(range(arr.shape[1])):
        down_perturbations = ternarymatrixT.index[arr[:, i] == -1]
        if len(down_perturbations) >= 5:
            print(gene_symbols[i], '', *down_perturbations, sep='\t', end='\n', file=f)

### Gene Similarity Matrix

In [None]:
# Pairwise cosine similarity between gene rows: 1 minus the cosine distance,
# expanded from pdist's condensed form to a square matrix.
gene_vectors = ternarymatrix.to_numpy(dtype=np.int_)
gene_cosine_similarity = 1 - dist.squareform(dist.pdist(gene_vectors, 'cosine'))

gene_similarity_matrix = pd.DataFrame(
    data=gene_cosine_similarity,
    index=ternarymatrix.index,
    columns=ternarymatrix.index,
)
# Clear the axis names before writing the file.
gene_similarity_matrix.index.name = None
gene_similarity_matrix.columns.name = None
gene_similarity_matrix.to_csv(output_path+'gene_similarity_matrix_cosine.txt.gz', sep='\t', compression='gzip')
gene_similarity_matrix

### Attribute Similarity Matrix

In [None]:
# Pairwise cosine similarity between perturbation rows: 1 minus the cosine distance,
# expanded from pdist's condensed form to a square matrix.
attribute_vectors = ternarymatrixT.to_numpy(dtype=np.int_)
attribute_cosine_similarity = 1 - dist.squareform(dist.pdist(attribute_vectors, 'cosine'))

attribute_similarity_matrix = pd.DataFrame(
    data=attribute_cosine_similarity,
    index=ternarymatrixT.index,
    columns=ternarymatrixT.index,
)
# Clear the axis names before writing the file.
attribute_similarity_matrix.index.name = None
attribute_similarity_matrix.columns.name = None
attribute_similarity_matrix.to_csv(output_path+'attribute_similarity_matrix_cosine.txt.gz', sep='\t', compression='gzip')
attribute_similarity_matrix

### Gene-Attribute Standardized Matrix

In [None]:
# Gene x perturbation matrix of signed scores; pairs with no association are 0.
# aggfunc is given as the string 'max': passing the Python builtin is deprecated in pandas.
standardizedmatrix = pd.crosstab(perturbatlas['gene'], perturbatlas['pert'], values=perturbatlas['score'], aggfunc='max').fillna(0)
standardizedmatrix.to_csv(output_path+'gene_attribute_matrix_standardized.txt.gz', sep='\t', compression='gzip')
standardizedmatrix

### Knowledge Graph Serialization

In [None]:
nodes = {}
edges = []

# Gene nodes, keyed by integer gene ID.
for _, gene in genelist.iterrows():
    gene_id = int(gene['Gene ID'])
    nodes[gene_id] = {
        "type": "gene",
        "properties": {
            "id": gene_id,
            "label": gene['Gene']
        }}

# Perturbation nodes, keyed by perturbation ID.
for _, term in attributelist.iterrows():
    perturbation_id = term['Perturbation ID']
    nodes[perturbation_id] = {
        "type": 'gene perturbation',
        "properties": {
            "label": term['Perturbation'],
            "id": perturbation_id
        }}

# One directed edge per association. A threshold of 1 is an "increases" edge;
# every other value is written as a "represses" edge.
for _, edge in edgelist.iterrows():
    if edge['Threshold Value'] == 1:
        relation = 'increases expression of'
    else:
        relation = 'represses expression of'
    source_id = edge['Perturbation ID']
    target_id = int(edge['Gene ID'])
    edges.append({
        "source": source_id,
        "relation": relation,
        "target": target_id,
        "properties": {
            "id": source_id+":"+str(edge['Gene ID']),
            "source_id": source_id,
            "source_label": edge['Perturbation'],
            "target_id": target_id,
            "target_label": edge['Gene'],
            "directed": True,
            "standardized": edge['Standardized Value'],
            "threshold": int(edge['Threshold Value'])
        }})

### RDF

In [None]:
# Write the edges as subject / predicate / object triples, one per line, each ending in ' .'.
# NOTE(review): the 'PerturbAtlas:' prefix is used on every subject but its declaration is
# commented out, and the two prefix lines carry no angle brackets or terminator — confirm the
# consumer of this file accepts that form.
with open(f'{output_path}kg_serializations/perturbatlas.rdf', 'w') as f:
    #print('@prefix PerturbAtlas: ', file=f)
    print('@prefix RO: purl.obolibrary.org/RO_', file=f)
    print('@prefix gene: ncbi.nlm.nih.gov/gene/', file=f)
    print('', file=f)
    # Edge relation label -> RO term used as the predicate.
    relations = {'increases expression of':'RO:0003003', 'represses expression of':'RO:0003002'}
    for edge in edges:
        # print's default separator puts a single space between the three terms.
        print(
            'PerturbAtlas:'+edge['properties']['source_id'], 
            relations[edge['relation']], 
            'gene:'+str(edge['properties']['target_id']), end=' .\n', 
        file=f)

### JSON

In [None]:
# Serialize the node and edge collections as one indented JSON document.
# json.dump writes to the file and returns None, so its result is not kept.
with open(f'{output_path}kg_serializations/perturbatlas.json', 'w') as f:
    json.dump(
        {
            "Version":"1", 
            "nodes": nodes,
            "edges": edges
        }, indent=4, fp=f)

### TSV

#### Nodes

In [None]:
# Namespace label written for each node type.
node_namespaces = {'gene':'NCBI Entrez', 'gene perturbation':'PerturbAtlas'}

# One row per node: namespace, id and label, numbered from 0 in insertion order.
nodeframe = pd.DataFrame(
    [
        {
            'namespace': node_namespaces[node['type']],
            'id': node['properties']['id'],
            'label': node['properties']['label'],
        }
        for node in nodes.values()
    ],
    columns=['namespace', 'id', 'label'],
)
nodeframe.to_csv(f'{output_path}kg_serializations/perturbatlas_tsv/nodes.tsv', sep='\t')
nodeframe

#### Edges

In [None]:
# One row per edge: endpoints, relation, and the two values pulled out of the properties dict.
edgeframe = pd.DataFrame(
    [
        (
            edge_record['source'],
            edge_record['relation'],
            edge_record['target'],
            edge_record['properties']['standardized'],
            edge_record['properties']['threshold'],
        )
        for edge_record in edges
    ],
    columns=['source', 'relation', 'target', 'standardized', 'threshold'],
)
edgeframe.to_csv(f'{output_path}kg_serializations/perturbatlas_tsv/edges.tsv', sep='\t')
edgeframe

## Create Visualizations

### Gene-Attribute Clustered Heatmap

In [None]:
# Clustered heatmap of the gene x perturbation ternary matrix; the colour scale is centred on 0.
sns.clustermap(ternarymatrix, cmap='seismic', center=0)

### Gene Similarity Clustered Heatmap

In [None]:
# Clustered heatmap of the gene-gene cosine similarity matrix; the colour scale is centred on 0.
sns.clustermap(gene_similarity_matrix, cmap='seismic', center=0)

### Attribute Similarity Clustered Heatmap

In [None]:
# Clustered heatmap of the perturbation-perturbation cosine similarity matrix; the colour scale is centred on 0.
sns.clustermap(attribute_similarity_matrix, cmap='seismic', center=0)

### UMAP

In [None]:
def load_gmt(file, suffix='_down'):
    """Read a GMT file into an ordered mapping of term -> space-joined gene string.

    Args:
        file: Iterable of tab-separated GMT lines (term, description, genes...).
        suffix: Text appended to every term so that up and down libraries can be
            merged without key collisions. The default matches the definition
            that was previously left in scope after this cell ran.

    Returns:
        OrderedDict mapping ``term + suffix`` to the term's genes joined by spaces.
        Empty fields (such as the blank description column) and repeated genes are
        dropped, and genes keep their order in the file so the output is reproducible.
    """
    gmt = OrderedDict()
    for line in file:
        fields = line.strip().split('\t')
        # Skip blank lines instead of creating an entry with an empty term.
        if fields == ['']:
            continue
        term, *geneset = fields
        members = list(dict.fromkeys(gene for gene in geneset if gene))
        gmt[term+suffix] = ' '.join(members)
    return gmt


# Combine the up and down libraries into one dictionary; context managers close the files.
with open('downloads/gene_set_library_up_crisp.gmt', 'r') as up_file:
    libdict = load_gmt(up_file, suffix='_up')
with open('downloads/gene_set_library_dn_crisp.gmt', 'r') as down_file:
    downlibdict = load_gmt(down_file, suffix='_down')
libdict.update(downlibdict)
scatterdir = 'images/'

In [None]:
def process_scatterplot(libdict, nneighbors=30, mindist=0.1, spread=1.0, maxdf=1.0, mindf=1):
    """Embed gene sets in 2-D with UMAP and label them with Leiden clusters.

    Args:
        libdict: Mapping of term -> space-separated gene string.
        nneighbors: ``n_neighbors`` passed to ``sc.pp.neighbors``.
        mindist: ``min_dist`` passed to ``sc.tl.umap``.
        spread: ``spread`` passed to ``sc.tl.umap``.
        maxdf: ``max_df`` passed to ``TfidfVectorizer``.
        mindf: ``min_df`` passed to ``TfidfVectorizer``.

    Returns:
        DataFrame with columns ``x``, ``y``, ``cluster``, ``term`` and ``genes``,
        one row per term, ordered by cluster label.
    """
    print("\tTF-IDF vectorizing gene set data...")
    vectorizer = TfidfVectorizer(max_df=maxdf, min_df=mindf)
    tfidf_matrix = vectorizer.fit_transform(libdict.values())
    print(tfidf_matrix.shape)
    adata = anndata.AnnData(tfidf_matrix)
    adata.obs.index = libdict.keys()

    print("\tPerforming Leiden clustering...")
    # Neighbour graph on the raw TF-IDF matrix, then clustering and the UMAP layout.
    sc.pp.neighbors(adata, n_neighbors=nneighbors, use_rep='X')
    sc.tl.leiden(adata, resolution=1.0)
    sc.tl.umap(adata, min_dist=mindist, spread=spread, random_state=42)

    # Reorder the observations by cluster label before prefixing the labels.
    terms_by_cluster = adata.obs.sort_values(by='leiden').index.tolist()
    adata = adata[terms_by_cluster, :]
    adata.obs['leiden'] = 'Cluster ' + adata.obs['leiden'].astype('object')

    embedding = pd.DataFrame(adata.obsm['X_umap'])
    embedding.columns = ['x', 'y']

    embedding['cluster'] = adata.obs['leiden'].values
    embedding['term'] = adata.obs.index
    embedding['genes'] = [libdict[term] for term in embedding['term']]

    return embedding

In [None]:
def get_scatter_colors(df):
    """Map each distinct value of ``df['cluster']`` to a Category20 colour.

    Colours are taken from the even palette positions first, then the odd ones,
    and wrap around after 20 clusters.
    """
    palette = list(Category20[20])
    colors = palette[::2] + palette[1::2]
    clusters = pd.unique(df['cluster']).tolist()
    return {cluster: colors[position % 20] for position, cluster in enumerate(clusters)}

def get_scatterplot(scatterdf):
    """Build the Bokeh scatter plot of the UMAP embedding.

    Args:
        scatterdf: DataFrame with at least the columns ``x``, ``y``, ``term``
            and ``cluster``; it is copied, not modified.

    Returns:
        The Bokeh figure, with one point per row coloured by cluster and a hover
        tooltip showing the gene set name, coordinates and cluster.
    """
    df = scatterdf.copy()
    color_mapper = get_scatter_colors(df)
    df['color'] = df['cluster'].apply(lambda x: color_mapper[x])

    # Tooltip template: every inner <div> is closed and the margin carries a unit.
    hover_emb = HoverTool(name="df", tooltips="""
        <div style="margin: 10px">
            <div style="margin: 0 auto; width:300px;">
                <span style="font-size: 12px; font-weight: bold;">Gene Set:</span>
                <span style="font-size: 12px">@gene_set</span>
            </div>
            <div style="margin: 0 auto; width:300px;">
                <span style="font-size: 12px; font-weight: bold;">Coordinates:</span>
                <span style="font-size: 12px">(@x,@y)</span>
            </div>
            <div style="margin: 0 auto; width:300px;">
                <span style="font-size: 12px; font-weight: bold;">Cluster:</span>
                <span style="font-size: 12px">@cluster</span>
            </div>
        </div>
    """)
    tools_emb = [hover_emb, 'pan', 'wheel_zoom', 'reset', 'save']

    plot_emb = figure(
        width=1000,
        height=700,
        tools=tools_emb
    )

    # Column names here are the ones referenced by the tooltip (@gene_set, @x, @y, @cluster)
    # and by the scatter glyph below ('x', 'y', 'colors').
    source = ColumnDataSource(
        data=dict(
            x = df['x'],
            y = df['y'],
            gene_set = df['term'],
            cluster = df['cluster'],
            colors = df['color'],
            label = df['cluster']
        )
    )

    # Hide tick marks and tick labels; the axis titles set below stay visible.
    plot_emb.xaxis.major_tick_line_color = None
    plot_emb.xaxis.minor_tick_line_color = None
    plot_emb.yaxis.major_tick_line_color = None
    plot_emb.yaxis.minor_tick_line_color = None
    plot_emb.xaxis.major_label_text_font_size = '0pt'
    plot_emb.yaxis.major_label_text_font_size = '0pt'

    plot_emb.output_backend = "svg"

    plot_emb.title = 'Gene Sets in the PerturbAtlas Signatures of Differentially Expressed Genes for Gene Perturbations Library'
    plot_emb.xaxis.axis_label = "UMAP_1"
    plot_emb.yaxis.axis_label = "UMAP_2"
    plot_emb.xaxis.axis_label_text_font_style = 'normal'
    plot_emb.xaxis.axis_label_text_font_size = '18px'
    plot_emb.yaxis.axis_label_text_font_size = '18px'
    plot_emb.yaxis.axis_label_text_font_style = 'normal'
    plot_emb.title.align = 'center'
    plot_emb.title.text_font_size = '18px'

    plot_emb.scatter(
        'x',
        'y',
        size = 4,
        source = source,
        color = 'colors'
    )

    return plot_emb

In [None]:
# Compute the embedding with parameters tuned away from the function defaults listed below.
## defaults: nneighbors=30, mindist=0.1, spread=1.0, maxdf=1.0, mindf=1
scatter_df = process_scatterplot(libdict, 
     nneighbors=100,
     mindist=0.01,
     spread=3,
     maxdf=0.9,
     mindf=5,
)

# Display Scatter Plot
plot = get_scatterplot(scatter_df)
# Route Bokeh output to the notebook, then render the figure inline.
output_notebook()
show(plot)

In [None]:
# Create the image directory if it is missing, and build the path with os.path.join
# so the trailing slash in scatterdir does not produce 'images//perturbatlas.html'.
os.makedirs(scatterdir, exist_ok=True)
output_file(filename=os.path.join(scatterdir, 'perturbatlas.html'), title = 'Gene Sets in the PerturbAtlas Signatures of Differentially Expressed Genes for Gene Perturbations Library')
save(plot)