In [None]:
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD
from hdbscan import HDBSCAN
from umap import UMAP
from common import data_dir, GMT

In [None]:
# Raw RummageneXRummaGEO overlap table; the next cells derive GMT/CSV files from it.
rummage_csv_path = "../data/rummagenexrummageo.csv"
df = pd.read_csv(rummage_csv_path)
df

In [None]:
import json
from common import data_dir, cached_urlretrieve, maybe_tqdm

from urllib.parse import quote

enrichr_dir = data_dir/'Enrichr'
enrichr_dir.mkdir(parents=True, exist_ok=True)

# Index of all Enrichr libraries (names + category ids), fetched via cached_urlretrieve
cached_urlretrieve(
  'https://maayanlab.cloud/Enrichr/datasetStatistics',
  enrichr_dir/'datasetStatistics.json'
)
with (enrichr_dir/'datasetStatistics.json').open('r') as fr:
  datasetStatistics = json.load(fr)

# Fetch every library as a GMT text file named after the library
for library in maybe_tqdm(datasetStatistics['statistics'], desc='Downloading Enrichr database...'):
  library_name = library['libraryName']
  # URL-encode the name so characters such as spaces, '&' or '#' cannot
  # corrupt the query string (plain names like GO_Biological_Process_2023 are unchanged)
  cached_urlretrieve(
    f"https://maayanlab.cloud/Enrichr/geneSetLibrary?mode=text&libraryName={quote(library_name)}",
    enrichr_dir/(library_name+'.gmt')
  )

In [None]:
from common import data_dir, GMT, maybe_tqdm


# Concatenate every downloaded Enrichr library into one GMT whose first column
# is the library name (file stem) and whose second column is the term.
# Libraries are visited in sorted order: Path.glob order depends on the
# filesystem, and the row order written here carries through to
# enrichr-clean.gmt and to the meta/vector rows built from it.
with (data_dir/'enrichr.gmt').open('w') as fw:
  gene_set_libraries = sorted((data_dir/'Enrichr').glob('*.gmt'))
  for gene_set_library in maybe_tqdm(gene_set_libraries, desc='Processing enrichr libraries...'):
    for (term, _desc), genes in maybe_tqdm(GMT.reader(gene_set_library), desc=f"Processing {gene_set_library}..."):
      print(
        gene_set_library.stem,
        term,
        *genes,
        sep='\t',
        file=fw,
      )

In [None]:
import json
import pandas as pd
from common import data_dir, cached_urlretrieve, maybe_tqdm

organism = 'Mammalia/Homo_sapiens'

def maybe_split(record):
  ''' Split an NCBI gene_info field into a set of values.

  NCBI stores nulls as '-' and lists as '|'-delimited strings. A value that
  is not a string at all (e.g. a float NaN produced by pandas' NA parsing)
  is also treated as null instead of failing on ``.split``.

  :param record: raw field value from a gene_info row.
  :returns: set of the individual values (empty for null fields).
  '''
  if not isinstance(record, str) or record in {'', '-'}:
    return set()
  return set(record.split('|'))

def supplement_dbXref_prefix_omitted(ids):
  ''' Yield every NCBI dbXref id as-is and, for ids of the form DB:ID,
  additionally the part after the first ':' (most datasets omit the DB prefix).
  '''
  for xref in ids:
    yield xref
    _prefix, colon, unprefixed = xref.partition(':')
    if colon:
      yield unprefixed

cached_urlretrieve(
  f"ftp://ftp.ncbi.nih.gov/gene/DATA/GENE_INFO/{organism}.gene_info.gz",
  data_dir/f"{organism}.gene_info.gz"
)
# NCBI marks missing values with '-', so pandas' default NA detection can only
# misfire here: a field that is literally e.g. 'NA' or 'null' would become a
# float NaN, which `maybe_split` cannot split and json.dump would write as NaN.
ncbi_genes = pd.read_csv(
  data_dir/f"{organism}.gene_info.gz",
  sep='\t',
  compression='gzip',
  keep_default_na=False,
)

# Every identifier that may name a gene: symbols, GeneID, synonyms,
# designations, locus tag and dbXrefs (with and without their DB: prefix).
ncbi_genes['All_synonyms'] = [
  set.union(
    maybe_split(gene_info['Symbol']),
    maybe_split(gene_info['Symbol_from_nomenclature_authority']),
    maybe_split(str(gene_info['GeneID'])),
    maybe_split(gene_info['Synonyms']),
    maybe_split(gene_info['Other_designations']),
    maybe_split(gene_info['LocusTag']),
    set(supplement_dbXref_prefix_omitted(maybe_split(gene_info['dbXrefs']))),
  )
  for _, gene_info in maybe_tqdm(ncbi_genes.iterrows())
]

# Unique (synonym, official symbol) pairs. Built with a column zip rather than
# a second iterrows() pass, and sorted so the lookup -- and the JSON written
# below -- is the same from run to run (set order of strings is randomised).
synonym_symbol_pairs = sorted({
  (synonym, symbol)
  for symbol, all_synonyms in zip(ncbi_genes['Symbol'], ncbi_genes['All_synonyms'])
  for synonym in all_synonyms
})
synonyms, symbols = zip(*synonym_symbol_pairs)
ncbi_lookup = pd.Series(symbols, index=synonyms)

# A synonym shared by several genes is ambiguous and dropped -- unless the
# synonym is that gene's own symbol, in which case the self-mapping is kept.
index_values = ncbi_lookup.index.value_counts()
ambiguous = index_values[index_values > 1].index
is_own_symbol = ncbi_lookup.index.to_numpy() == ncbi_lookup.to_numpy()
ncbi_lookup_disambiguated = ncbi_lookup[is_own_symbol | ~ncbi_lookup.index.isin(ambiguous)]
ncbi_lookup = ncbi_lookup_disambiguated.to_dict()

with (data_dir / 'lookup.json').open('w') as fw:
  json.dump(ncbi_lookup, fw)

In [None]:
import sys
from common import GMT, maybe_tqdm, gene_lookup

from common import data_dir

# Gene sets with fewer genes than this after symbol mapping are dropped.
min_mapped_genes = 5

# Resolve both files under data_dir -- where the concatenation cell writes
# enrichr.gmt and where the loading cell reads enrichr-clean.gmt -- instead of
# paths relative to the current working directory.
input_file = data_dir/'enrichr.gmt'
output_file = data_dir/'enrichr-clean.gmt'

with output_file.open('w') as outfile:
    # GMT.reader is handed a path, the same way the concatenation cell calls it.
    for (term, desc), genes in maybe_tqdm(GMT.reader(input_file), desc='Cleaning gmt...'):
        # Map genes to their canonical forms, dropping unmapped (falsy) results
        genes_mapped = {mapped for mapped in map(gene_lookup, genes) if mapped}
        if len(genes_mapped) < min_mapped_genes:
            continue

        # Sorted so the file is identical across runs (iteration order of a
        # set of strings changes with hash randomisation).
        outfile.write(f"{term}\t{desc}\t" + "\t".join(sorted(genes_mapped)) + "\n")


In [None]:
from tqdm import tqdm

def csv_to_gmt(
    df,
    batch_file="data/rummageogene_500k.gmt",
    top_csv="data/rummageogene_top_500k.csv",
    top_n=500000,
):
    """Write the top-ranked Rummagene x RummaGEO overlaps as a GMT file.

    Rows of ``df`` are ranked by ascending ``p-value`` and, within ties, by
    descending ``odds``; the first ``top_n`` rows are kept. The kept rows are
    saved to ``top_csv`` (a later cell reads ``species`` back from it), and
    each becomes one GMT line::

        <rummagene>;<rummageo> <TAB> N/A <TAB> gene1 <TAB> gene2 ...

    where the genes are the ``;``-delimited ``overlaps`` column.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have ``rummagene``, ``rummageo``, ``p-value``, ``odds`` and
        ``overlaps`` columns.
    batch_file : str or path-like
        Output GMT path (default is the previously hard-coded location).
    top_csv : str or path-like
        Output CSV path for the kept rows (default is the previously
        hard-coded location).
    top_n : int
        Number of top-ranked rows to keep.
    """
    top = df.sort_values(by=["p-value", "odds"], ascending=[True, False]).head(top_n)
    top.to_csv(top_csv, index=False)
    with open(batch_file, 'w') as gmtfile:
        # Column-wise zip instead of DataFrame.iterrows(): same values, without
        # materialising a Series per row (up to top_n rows).
        for rummagene, rummageo, overlaps in zip(top["rummagene"], top["rummageo"], top["overlaps"]):
            identifier = rummagene + ";" + rummageo
            genes = overlaps.split(";")
            gmtfile.write('\t'.join([identifier, "N/A", *genes]) + '\n')

# Build data/rummageogene_500k.gmt and the matching top-500k CSV from `df`
csv_to_gmt(df=df)



In [None]:
random_state = 42  # seed passed to TruncatedSVD and both UMAP fits below

In [None]:
print('Loading Enrichr GMT...')
# Cleaned, symbol-mapped Enrichr libraries (library name is the GMT "term" column)
enrichr_gmt_path = data_dir / 'enrichr-clean.gmt'
enrichr_gmt = GMT.from_file(enrichr_gmt_path)

In [None]:
print('Loading Rummageogene GMT...')
# Top-ranked RummageneXRummaGEO overlaps written by csv_to_gmt
rummageogene_gmt_path = 'data/rummageogene_500k.gmt'
rummageogene_gmt = GMT.from_file(rummageogene_gmt_path)

In [None]:
print('Collecting metadata...')
# One row per gene set: Enrichr first, then RummageneXRummaGEO -- the same
# concatenation order used for the TF-IDF input below. Enrichr terms unpack
# as (library, term) because the concatenation cell wrote the library name
# into the GMT term column.
enrichr_rows = [{'source': library, 'term': term} for library, term in enrichr_gmt.terms]
rummage_rows = [{'source': 'rummagenexrummageo', 'term': term} for term, _desc in rummageogene_gmt.terms]
meta = pd.DataFrame(enrichr_rows + rummage_rows)

In [None]:
print('Computing IDF...')
# Each gene set is passed through unchanged as its own token list, so every
# gene is one term of the TF-IDF vocabulary.
vectorizer = TfidfVectorizer(analyzer=lambda gene_set: gene_set)
all_gene_lists = enrichr_gmt.gene_lists + rummageogene_gmt.gene_lists
vectors = vectorizer.fit_transform(all_gene_lists)

In [None]:
print('Computing SVD...')
# Reduce the TF-IDF matrix to a dense 50-component representation for UMAP
n_svd_components = 50
svd = TruncatedSVD(random_state=random_state, n_components=n_svd_components)
svs = svd.fit_transform(vectors)


In [None]:
print('Computing UMAP...')
# 2-D embedding of the SVD features used for the joint map
umap = UMAP(low_memory=True, random_state=random_state)
embedding = umap.fit_transform(svs)


In [None]:
print('Computing outliers...')
# A point is "in bounds" when each coordinate lies within mean +/- 1.68 std of
# its axis, with the window clamped to the observed [min, max].
outlier_k = 1.68
x, y = embedding[:, 0], embedding[:, 1]

def _clamped_bounds(values, k):
  ''' Return (lo, hi) = (mean - k*std, mean + k*std), clamped to [min, max] of values. '''
  mu, sd = np.mean(values), np.std(values)
  return max(np.min(values), mu - sd*k), min(np.max(values), mu + sd*k)

x_lo, x_hi = _clamped_bounds(x, outlier_k)
y_lo, y_hi = _clamped_bounds(y, outlier_k)
# NOTE: despite its name, `outlier` is True for points INSIDE both windows
# (i.e. inliers); the saving cell stores `~outlier` as the real outlier flag.
outlier = (x >= x_lo) & (x <= x_hi) & (y >= y_lo) & (y <= y_hi)


In [None]:
print('Saving joint-umap...')
meta['UMAP-1'], meta['UMAP-2'] = x, y
# `outlier` holds the in-bounds mask, so out-of-bounds rows are flagged with 1
meta['outlier'] = np.where(outlier, 0, 1)
meta.to_csv(data_dir / 'joint-umap.tsv', index=False, sep='\t')



In [None]:
# Reload the saved joint map so the clustering below can start from disk
joint_umap_path = data_dir / 'joint-umap.tsv'
meta = pd.read_csv(joint_umap_path, sep='\t')
meta

In [None]:
print('Computing Cluster UMAP...')
# Separate 2-D embedding used only as HDBSCAN input (n_neighbors=30,
# min_dist=0.0), distinct from the display embedding above.
cluster_umap = UMAP(n_neighbors=30, min_dist=0.0, n_components=2, random_state=random_state, low_memory=True)
cluster_embedding = cluster_umap.fit_transform(svs)

print('Computing Clusters...')
clusterer = HDBSCAN(min_samples=10, min_cluster_size=500)
labels = clusterer.fit_predict(cluster_embedding)

x, y = cluster_embedding[:, 0], cluster_embedding[:, 1]
meta['UMAP-1'], meta['UMAP-2'], meta['cluster'] = x, y, labels
# Written WITH the default integer index; the loading cell's index_col=1
# depends on that extra leading column.
meta.to_csv(data_dir / 'joint-umap-cluster.tsv', sep='\t')

In [None]:
import json
import glasbey
import numpy as np
import pandas as pd
import pathlib
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

from common import data_dir

random_state = 42

# Destination for the saved figures; created on first run, reused afterwards
fig_dir = pathlib.Path('figures')
fig_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# joint-umap-cluster.tsv was written from a reload of joint-umap.tsv, so both
# hold the same rows in the same order; in both files `source` ends up as the
# index (column 0 in the first, column 1 after the saved integer index in the
# second). `source` is heavily duplicated, so rather than letting
# pd.concat(axis=1) align on it -- which only works while the two indexes
# happen to be identical -- verify the rows line up and attach the labels
# positionally.
joint_umap = pd.read_csv(data_dir / 'joint-umap.tsv', sep='\t', index_col=0)
joint_clusters = pd.read_csv(data_dir / 'joint-umap-cluster.tsv', sep='\t', index_col=1)
assert joint_umap.index.equals(joint_clusters.index), 'source order differs between joint-umap files'
assert joint_umap['term'].equals(joint_clusters['term']), 'term order differs between joint-umap files'
meta = joint_umap.assign(cluster=joint_clusters['cluster'].to_numpy())
meta

In [None]:
# term -> species for the RummageneXRummaGEO gene sets, using the same
# "<rummagene>;<rummageo>" identifier that csv_to_gmt wrote into the GMT.
df = (
    pd.read_csv("data/rummageogene_top_500k.csv", usecols=["rummagene", "rummageo", "species"])
    .assign(term=lambda frame: frame["rummagene"] + ";" + frame["rummageo"])
    .loc[:, ["term", "species"]]
)
df

In [None]:
# Attach each RummageneXRummaGEO term's species; Enrichr terms are expected to
# find no match and get NaN. `meta` itself is not rebound, so re-running this
# cell is idempotent (`meta = meta.reset_index()` would add a spurious 'index'
# column on every rerun).
merged_df = pd.merge(meta.reset_index(), df, on='term', how='left')
merged_df

In [None]:

# Suffix the source with the species where one is known
# (e.g. 'rummagenexrummageo' -> 'rummagenexrummageo-human'); rows without a
# species keep their source. Vectorised instead of a row-wise apply, and built
# without mutating merged_df so re-running the cell cannot append the species twice.
has_species = merged_df['species'].notna()
species_source = merged_df['source'] + '-' + merged_df['species'].astype(str)
meta = (
    merged_df
    .assign(source=merged_df['source'].where(~has_species, species_source))
    .drop(columns='species')
)
meta

In [None]:
# Enrichr library/category index downloaded earlier
datasetStatistics = json.loads((data_dir/'Enrichr'/'datasetStatistics.json').read_text())

In [None]:
# categoryId -> category name, then library -> category name
categories = {entry['categoryId']: entry['name'] for entry in datasetStatistics['categories']}
library_categories = {
    lib['libraryName']: categories[lib['categoryId']]
    for lib in datasetStatistics['statistics']
}
# Pseudo-libraries for the RummageneXRummaGEO gene sets (plain and per-species)
library_categories.update({
    'rummagenexrummageo': 'RummagenexRummaGEO',
    'rummagenexrummageo-human': 'RummageneXhumanRummaGEO',
    'rummagenexrummageo-mouse': 'RummageneXmouseRummaGEO',
})

# Sources missing from the mapping get None
meta['category'] = meta['source'].map(library_categories.get)
meta

In [None]:
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Categories contributed by RummageneXRummaGEO gene sets (see library_categories);
# every other category comes from an Enrichr library.
rummage_categories = ['RummagenexRummaGEO', 'RummageneXhumanRummaGEO', 'RummageneXmouseRummaGEO']
cat = 'category'

# One palette over ALL categories, shared by both panels, so a category has the
# same colour on the left and on the right (only the right panel has a legend).
_ = meta
cats = _[cat].unique()
color_pallete = dict(zip(cats, glasbey.create_palette(len(cats))))


def _scatter_by_category(ax, frame):
    ''' Scatter frame's UMAP-1/UMAP-2 on `ax`, one colour per `cat` value. '''
    for label, data in frame.groupby(cat):
        ax.scatter(
            x=data['UMAP-1'],
            y=data['UMAP-2'],
            s=0.1,
            color=color_pallete[label],
            alpha=0.1,
            rasterized=True,
        )


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8), dpi=300)  # 1 row, 2 columns

# Left panel: Enrichr only. Excludes every RummageneXRummaGEO category, including
# rows whose species was missing and so kept the plain 'RummagenexRummaGEO' label.
_scatter_by_category(ax1, meta[~meta[cat].isin(rummage_categories)])
ax1.set_xlabel('UMAP-1', fontdict=dict(size=24))
ax1.set_ylabel('UMAP-2', fontdict=dict(size=24))
ax1.set_xticks([])
ax1.set_yticks([])
ax1.set_title('Enrichr Only', fontsize=24)

# Right panel: all gene sets
_scatter_by_category(ax2, _)
lgd = ax2.legend(handles=[
    Line2D([0], [0], marker='o', color='w', label=f"{label} ({int((_[cat] == label).sum()):,})",
           markerfacecolor=color_pallete[label], markersize=10)
    for label in cats
], loc='center left', bbox_to_anchor=(1, 0.5), fontsize=16)

ax2.set_xlabel('UMAP-1', fontdict=dict(size=24))
ax2.set_xticks([])
ax2.set_yticks([])
ax2.set_title('Enrichr + RummagenexRummaGEO', fontsize=24)
plt.tight_layout()
# The legend sits outside ax2; include it in the saved bounding box so it is not cropped.
for ext in ('png', 'pdf'):
    fig.savefig(
        fig_dir / f"enrichr_rummageogene_combined.{ext}",
        dpi=300,
        bbox_extra_artists=(lgd,),
        bbox_inches='tight',
    )
plt.show()

In [None]:
# Per-row colour taken from the palette built for the combined figure
meta["color"] = meta['category'].map(color_pallete.get)
meta

In [None]:
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Single-panel view of all gene sets (display only, not saved). Uses `_`,
# `cat`, `cats` and `color_pallete` from the combined-figure cell. Everything
# is drawn on this figure's own axes; the previous version also re-plotted
# every point onto `ax2` of the earlier, already-shown figure.
fig, ax = plt.subplots(figsize=(15, 8), dpi=300)

for label, data in _.groupby(cat):
    ax.scatter(
        x=data['UMAP-1'],
        y=data['UMAP-2'],
        s=0.1,
        color=color_pallete[label],
        alpha=0.1,
        rasterized=True,
    )

ax.legend(handles=[
    Line2D([0], [0], marker='o', color='w', label=f"{label} ({int((_[cat] == label).sum()):,})",
           markerfacecolor=color_pallete[label], markersize=8)
    for label in cats
], loc='center left', bbox_to_anchor=(1, 0.5), fontsize=12)

ax.set_xlabel('UMAP-1', fontdict=dict(size=24))
ax.set_ylabel('UMAP-2', fontdict=dict(size=24))
ax.set_xticks([])
ax.set_yticks([])
ax.set_title('Enrichr + RummagenexRummaGEO', fontsize=24)

fig.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Keep only points whose exact (UMAP-1, UMAP-2) pair occurs once: keep=False
# flags *every* member of a duplicated pair, so all coincident points drop out.
is_stacked = _.duplicated(subset=['UMAP-1', 'UMAP-2'], keep=False)
unique_data = _.loc[~is_stacked]

fig, ax = plt.subplots(figsize=(15, 8), dpi=300)

for label, data in unique_data.groupby(cat):
    ax.scatter(
        x=data['UMAP-1'],
        y=data['UMAP-2'],
        s=0.1,
        color=color_pallete[label],
        alpha=0.8,  # higher than the 0.1 used elsewhere so sparse unique points stay visible
        rasterized=True,
    )

# Legend counts refer to the filtered (unique) points only
legend_handles = [
    Line2D([0], [0], marker='o', color='w',
           label=f"{label} ({int((unique_data[cat] == label).sum()):,})",
           markerfacecolor=color_pallete[label], markersize=8)
    for label in cats
]
ax.legend(handles=legend_handles, loc='center left', bbox_to_anchor=(1, 0.5), fontsize=12)

ax.set_xlabel('UMAP-1', fontdict=dict(size=24))
ax.set_ylabel('UMAP-2', fontdict=dict(size=24))
ax.set_xticks([])
ax.set_yticks([])
ax.set_title('Enrichr + RummagenexRummaGEO', fontsize=24)

fig.tight_layout()
plt.show()
