In [27]:
import sys
# Make the project-local lib directory importable for the `local.*` imports.
# The entry is placed at the front exactly once and every other sys.path entry
# keeps its original relative order. Rebuilding the list through a set (as was
# done before) shuffles the search order arbitrarily, which can change which
# module wins when two locations provide the same name.
_LIB_DIR = "../../lib/"
sys.path[:] = [_LIB_DIR] + [p for p in sys.path if p != _LIB_DIR]
import json
import pandas as pd
import requests
from local.caching import load, save, DictCache, save_exists
from local.web import ncbi_search, chain_get
from local.constants import WORKSPACE_ROOT

In [15]:
# Load the headerless, tab-separated summary table and name its six columns.
df = pd.read_csv(
    "../../data/positive_selection_summary.tsv",
    sep="\t",
    header=None,
)
df.columns = ["org", "p", "org2?", "category", "cog", "desc"]
print(df.shape)

# Keep only the rows that have a description; these are the ones embedded later.
has_desc = df["desc"].notna()
df_notnull = df[has_desc]
df_notnull

(239, 6)


Unnamed: 0,org,p,org2?,category,cog,desc
1,GC_00002175,0.000000**,GC_00002175,Core,COG0025|COG0569,NhaP-type Na+/H+ or K+/H+ antiporter (NhaP) (P...
2,GC_00002333,0.000000**,GC_00002333,Core,COG1337,CRISPR-Cas system type III CSM-effector comple...
3,GC_00002431,0.000000**,GC_00002431,Core,COG3463,Uncharacterized membrane protein
4,GC_00002482,0.000000**,GC_00002482,Core,COG0192,S-adenosylmethionine synthetase (MetK) (PDB:1FUG)
6,GC_00002720,0.000000**,GC_00002720,Core,COG0458,Carbamoylphosphate synthase large subunit (Car...
...,...,...,...,...,...,...
234,GC_00001543,0.046876*,GC_00001543,Core,COG1807,"PMT family glycosyltransferase ArnT/Agl22, inv..."
235,GC_00001870,0.048071*,GC_00001870,Core,COG2211,Na+/melibiose symporter or related transporter...
236,GC_00002653,0.048373*,GC_00002653,Core,COG1058|COG1546,ADP-ribose pyrophosphatase domain of DNA damag...
237,GC_00002850,0.048454*,GC_00002850,Aux_HF,COG0594,RNase P protein component (RnpA) (PDB:1A6F)


In [19]:
# Read the API key from the first line of the workspace secrets file,
# dropping the trailing newline and any surrounding whitespace.
key_path = WORKSPACE_ROOT.joinpath("secrets/openai_key")
with open(key_path) as key_file:
    OPENAI_KEY = key_file.readline().strip()
def get_embedding(entry: str, timeout: float = 60.0):
    """Return the embeddings-API response for ``entry``, using a local cache.

    The text is cut to at most 8191 characters before use, and the (possibly
    truncated) text is the cache key. On a cache miss the OpenAI embeddings
    endpoint is called with the ``text-embedding-ada-002`` model and the parsed
    JSON response is stored in the ``"ada_embeddings"`` DictCache.

    Args:
        entry: Text to embed.
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled connection cannot hang the notebook indefinitely.

    Returns:
        The decoded JSON body of a successful response (cached or fresh).

    Raises:
        RuntimeError: If the API answers with a non-200 status. Previously the
            error payload was returned as if it were a result and only failed
            later, far from the cause, when the embedding was looked up.
        requests.RequestException: On network failures or timeouts.
    """
    # NOTE(review): 8191 looks like the model's token limit, but it is applied
    # to characters here -- confirm this is the intended (conservative) cut-off.
    MAX_L = 8191
    if len(entry) > MAX_L:
        e = entry[:MAX_L]
        print(f"truncated to {MAX_L} from {len(entry)}")
    else:
        e = entry

    with DictCache("ada_embeddings") as cache:
        if e in cache:
            return cache[e]

        r = requests.post(
            url="https://api.openai.com/v1/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {OPENAI_KEY}",
            },
            json={
                "model": "text-embedding-ada-002",
                "input": e,
            },
            timeout=timeout,
        )
        # Only decode and cache successful responses; an error body may not
        # even be valid JSON.
        if r.status_code == 200:
            data = r.json()
            cache[e] = data
            return data

    # Raised after the cache context has been left normally, so the failure
    # does not interfere with however DictCache finalises itself.
    raise RuntimeError(
        f"embedding request failed with HTTP {r.status_code}: {r.text[:500]}"
    )
        
# Collect one (org, API response) pair per row of df_notnull, printing a
# progress counter that overwrites itself on a single line.
embeddings = []
total = len(df_notnull)
for i, (org, desc) in enumerate(zip(df_notnull["org"], df_notnull["desc"])):
    print(f"\r{i} of {total}   ", end="")
    embeddings.append((org, get_embedding(desc)))

182 of 183   

In [26]:
import numpy as np
from sklearn.manifold import TSNE

# Pull the embedding vector out of every API response and stack them into a
# float64 matrix with one row per entry.
vectors = [response["data"][0]["embedding"] for _, response in embeddings]
mappings_x = np.array(vectors, dtype=np.float64)
mappings_x.shape

(183, 1536)

In [28]:
import numpy as np
from sklearn.manifold import TSNE

mname = "cy_pos_sel"

# Cache key built from the run name and the shape of the embedding matrix.
# NOTE(review): the key does not include the t-SNE seed or perplexity, so
# changing those will still load the old cached projection -- confirm intended.
shape_tag = "_".join(map(str, mappings_x.shape))
save_name = f"latent-{mname}-{shape_tag}"
regen = False  # set to True to ignore the cache and recompute

R = 80
if regen or not save_exists(save_name):
    rand_seed = 36
    model = TSNE(n_components=2, random_state=rand_seed, perplexity=30)
    latentx = model.fit_transform(mappings_x)

    # Centre the projection on the midpoint of its bounding box, scale each
    # axis to [-1, 1], then stretch to 90% of the plot half-width R.
    lo, hi = latentx.min(axis=0), latentx.max(axis=0)
    middle = (lo + hi) / 2
    scale = (hi - lo) / 2
    latentx -= middle
    latentx /= scale
    latentx *= R*0.90

    save(save_name, latentx)
else:
    latentx = load(save_name)
latentx.shape

compressing & caching data to [{WORKSPACE}/main/scratch/cache/latent-cy_pos_sel-183_1536.pkl.gz]


(183, 2)

In [32]:
from typing import Any
from plotly import graph_objects as go, subplots as sp

def divide_chunks(l, n):
    """Yield consecutive slices of ``l`` of length ``n``; the last may be shorter."""
    yield from (l[start:start + n] for start in range(0, len(l), n))
        
# Plot settings: faint grey grid and zero lines, invisible axis lines,
# and a fixed 800x800 white canvas with a black serif font.
axis_col = 'rgba(0, 0, 0, 0.15)'
no_col = 'rgba(0, 0, 0, 0)'
axis_desc: dict = {
    "linecolor": no_col,
    "gridcolor": axis_col,
    "zerolinecolor": axis_col,
    "zerolinewidth": 1,
}
layout = {
    "autosize": False,
    "width": 800,
    "height": 800,
    "margin": {"l": 25, "r": 25, "b": 25, "t": 25, "pad": 5},
    "font_family": "Times New Roman",
    "font_color": "black",
    "font_size": 20,
    "plot_bgcolor": 'white',
    # All four axis slots share the same axis description object.
    "xaxis": axis_desc,
    "yaxis": axis_desc,
    "xaxis2": axis_desc,
    "yaxis2": axis_desc,
}

# Single-panel figure; the subplot API is used so traces can be addressed by row/col.
fig = sp.make_subplots(
    rows=1,
    cols=1,
    shared_xaxes=True,
    shared_yaxes=True,
    horizontal_spacing=0.02,
)

def _lines(s: str):
    """Wrap ``s`` for hover labels: at most 12 space-separated words per line, joined by ``<br>``."""
    words = s.split(' ')
    wrapped = [' '.join(group) for group in divide_chunks(words, 12)]
    return '<br>'.join(wrapped)

# Marker size and opacity for the scatter points.
s, o = 5, 0.7

# Split the 2-D projection into its two coordinate columns.
xs, ys = latentx.T

# Hover label per point: the org id on the first line, then the wrapped description.
hover_text = [
    f"{org}<br>{_lines(desc)}"
    for org, desc in zip(df_notnull["org"], df_notnull["desc"])
]

points = go.Scatter(
    x=list(xs),
    y=list(ys),
    mode='markers',
    marker={"size": s, "color": '#3679c6', "opacity": o},
    showlegend=False,
    text=hover_text,
)
fig.add_trace(points, row=1, col=1)

# Final layout: the shared base settings, plus small hover labels and
# untitled axes fixed to the symmetric range [-R, R].
axis_range = (-R, R)
_layout: dict[Any, Any] = {
    **layout,
    "hoverlabel": {"font_size": 12},
    "xaxis": {"title": "", "range": axis_range, **axis_desc},
    "yaxis": {"title": "", "range": axis_range, **axis_desc},
}
fig.update_annotations(font_size=10)
fig.update_layout(go.Layout(**_layout))
fig.show()

SyntaxError: f-string: Generator expression must be parenthesized (2178729157.py, line 47)