In [None]:
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from pathlib import Path
from collections import defaultdict
import urllib.request

In [None]:
# Location of the unpacked MaveDB dump -- edit this to point at your local copy.
path_to_mavedb_data = Path("/path/to/mavedb-dump.20241114101443")
# main.json is read as one JSON document; later cells use its 'experimentSets' and 'asOf' keys.
# JSON text is UTF-8, so say so explicitly: the default encoding of open() follows the
# platform locale and can fail (or mis-decode) on non-UTF-8 systems.
with open(path_to_mavedb_data / "main.json", encoding="utf-8") as handle:
    api_data = json.load(handle)

In [None]:
# How many experiment sets does the dump contain?
experiment_sets = api_data['experimentSets']
len(experiment_sets)

In [None]:
# Keep only the text before the first 'T' of the dump's 'asOf' value (presumably an
# ISO-8601 timestamp, so this is the date); it is used to tag the output file names.
dump_date = api_data['asOf'].partition('T')[0]
dump_date

In [None]:
# Total number of experiments across all experiment sets.
total_experiments = sum(
    len(experiment_set['experiments']) for experiment_set in api_data['experimentSets']
)
total_experiments

In [None]:
# Count experiments per NCBI taxonomy ID. Each experiment is represented by the
# taxonomy of the first target gene of its first score set.
tax_id_counts = defaultdict(int)
for eset in api_data['experimentSets']:
    for exp in eset['experiments']:
        score_sets = exp['scoreSets']
        try:
            t = score_sets[0]['targetGenes'][0]['targetSequence']['taxonomy']['taxId']
        except IndexError:
            # No score set or no target gene. Name the score set when there is one, otherwise
            # the experiment itself -- re-indexing score_sets[0] here would raise again.
            urn = score_sets[0]['urn'] if score_sets else exp.get('urn')
            print("no taxId found for", urn, "- substituting 9606 (human)")
            tax_id_counts[9606] += 1  # the one dataset missing a target somehow is human (PAX6)
        else:
            tax_id_counts[t] += 1

In [None]:
# Every experiment adds exactly one count (including the human fallback), so the
# per-taxon counts must add up to the experiment total computed earlier.
classified_experiments = sum(tax_id_counts.values())
assert classified_experiments == total_experiments, (classified_experiments, total_experiments)
classified_experiments

In [None]:
# Fetch taxonomy records (including lineage) from NCBI for every tax_id seen above, so each
# one can be classified by lineage. All IDs are sent in a single comma-separated request.
# NOTE(review): this is the v2alpha Datasets endpoint; NCBI also publishes a v2 API --
# confirm the response schema ('taxonomy_nodes') before switching.
ncbi_taxonomy_url = (
    "https://api.ncbi.nlm.nih.gov/datasets/v2alpha/taxonomy/taxon/"
    + ",".join(str(t) for t in tax_id_counts.keys())
)
# Without a timeout a stalled connection would block this cell (and Run All) indefinitely.
with urllib.request.urlopen(ncbi_taxonomy_url, timeout=60) as response:
    ncbi_taxonomy_data = json.load(response)

In [None]:
# Plot categories keyed by NCBI taxonomy ID. The order of the pairs below matters twice:
#   * classification picks the FIRST entry found in a taxon's lineage, so broad groups
#     (e.g. "other eukaryotes") must come after the narrower groups they contain;
#   * the donut chart draws its wedges in this same order.
# NOTE(review): 33208 is labelled "invertebrates" only because vertebrates (7742) and human
# (9606) are matched before it -- presumably 33208 is the animal clade; confirm.
aggregation_tax_ids = dict([
    (9606, "human"),
    (7742, "other vertebrates"),
    (33208, "invertebrates"),
    (33090, "plants"),
    (4751, "fungi"),
    (2759, "other eukaryotes"),
    (2, "bacteria"),
    (2157, "archaea"),
    (10239, "viruses"),
    (81077, "artificial sequence"),
])

In [None]:
# Classify each taxon returned by NCBI into one plot category. Record every classified taxon,
# with its MaveDB experiment count, in a TSV for later inspection.
tax_id_plot_categories = dict()
with open(f"tax_id_classifications_{dump_date}.tsv", "w", encoding="utf-8") as handle:
    print("tax_id", "organism_name", "blast_name", "category", "count", sep="\t", file=handle)
    for node in ncbi_taxonomy_data['taxonomy_nodes']:
        taxon = node.get('taxonomy')
        if taxon is None:
            # NOTE(review): presumably NCBI reports unresolvable IDs as nodes without a
            # 'taxonomy' entry -- confirm; surface them instead of failing with KeyError.
            print("NCBI returned no taxonomy for node:", node)
            continue
        tax_id = taxon['tax_id']
        if tax_id in aggregation_tax_ids:  # the tax_id is itself a category (e.g. human, artificial sequence)
            category = aggregation_tax_ids[tax_id]
        else:
            # First category, in aggregation_tax_ids priority order, that appears in the lineage.
            lineage = taxon.get('lineage', [])
            category = next(
                (name for ancestor, name in aggregation_tax_ids.items() if ancestor in lineage),
                None,
            )
        if category is None:
            print(f"failed to classify {tax_id} ({taxon['organism_name']})")
            continue
        tax_id_plot_categories[tax_id] = category
        # .get() so an ID absent from MaveDB's counts is not inserted into the defaultdict.
        print(tax_id, taxon['organism_name'], taxon.get('blast_name', ''), category,
              tax_id_counts.get(tax_id, 0), sep="\t", file=handle)

# Taxa counted from MaveDB but not classified (not returned by NCBI, returned under a different
# ID, or unclassifiable) would otherwise drop out of the plot without any notice.
for tax_id in sorted(tax_id_counts.keys() - tax_id_plot_categories.keys()):
    print(f"tax_id {tax_id} was not classified; its {tax_id_counts[tax_id]} experiment(s) will not be plotted")


In [None]:
# Total experiment counts per plot category. Taxa missing from tax_id_plot_categories
# (i.e. not classified above) contribute nothing here.
tax_id_counts_by_category = defaultdict(int)
for tax_id, category in tax_id_plot_categories.items():
    tax_id_counts_by_category[category] += tax_id_counts[tax_id]

In [None]:
# Display the per-category experiment totals that feed the donut chart.
tax_id_counts_by_category

In [None]:
# function that spreads the labels out to avoid overplotting
# source: https://stackoverflow.com/a/68779745
def fix_labels(mylabels, tooclose=0.1, sepfactor=2):
    """Push apart pairs of text labels whose positions are closer than ``tooclose``.

    Every unordered pair (i, j) with i < j is visited once, in increasing order. When the
    Euclidean distance between their current positions is below ``tooclose``, label i is
    moved by ``sepfactor * (pos_i - pos_j)`` and label j by the opposite amount, so both
    move away from each other along the line joining them. Positions are re-read for every
    pair, so earlier moves affect later comparisons. Labels are modified in place; nothing
    is returned.

    Parameters
    ----------
    mylabels : sequence of objects providing ``get_position``, ``set_x`` and ``set_y``
        (e.g. the label text objects returned by ``pie``).
    tooclose : float
        Distance below which a pair of labels is separated.
    sepfactor : float
        Multiplier applied to the offset between the two labels.
    """
    n_labels = len(mylabels)
    for i in range(n_labels - 1):
        label_i = mylabels[i]
        for j in range(i + 1, n_labels):
            label_j = mylabels[j]
            pos_i = np.array(label_i.get_position(), dtype=float)
            pos_j = np.array(label_j.get_position(), dtype=float)
            offset = pos_i - pos_j
            if np.linalg.norm(offset) < tooclose:
                shift = sepfactor * offset
                moved_i = pos_i + shift
                moved_j = pos_j - shift
                label_i.set_x(moved_i[0])
                label_i.set_y(moved_i[1])
                label_j.set_x(moved_j[0])
                label_j.set_y(moved_j[1])

In [None]:
# One slice per plot category, in aggregation_tax_ids order. Categories with no experiments
# are kept with a count of 0.
category_names = list(aggregation_tax_ids.values())
data = [tax_id_counts_by_category[name] for name in category_names]
# Wedge labels of the form "Title Cased Category (count)".
labels = [f"{name.title()} ({count})" for name, count in zip(category_names, data)]

In [None]:
# Set the font via matplotlib rcParams (applies to every figure created afterwards).
font = {'family': 'Lato',
        'weight': 'normal',
        'size': 15}
mpl.rc('font', **font)

# Create the figure and draw on its single Axes explicitly, rather than through
# pyplot's implicit "current figure/axes" state.
fig, ax = plt.subplots(figsize=(10, 6))

# Wedges in aggregation_tax_ids order, drawn clockwise starting at 0 degrees;
# white edges separate adjacent wedges.
wedges, label_texts = ax.pie(
    data,
    labels=labels,
    wedgeprops={'linewidth': 1, 'edgecolor': 'white'},
    colors=plt.cm.tab10.colors,
    startangle=0,
    counterclock=False,
)
# Cover the middle with a white disc of radius 0.66 (pie radius defaults to 1) to make a donut.
center = plt.Circle((0, 0), 0.66, color='white')
ax.add_artist(center)
# Nudge apart category labels that would otherwise overlap.
fix_labels(label_texts, tooclose=0.15, sepfactor=2)
fig.tight_layout()
fig.savefig(f"taxid_donut_{dump_date}.pdf")
fig.savefig(f"taxid_donut_{dump_date}.png")