In [None]:
# Reload edited local modules (e.g. `scripts.plotting`) automatically before each
# cell executes, so helper changes are picked up without restarting the kernel.
% load_ext autoreload
% autoreload 2

In [None]:
import pickle
from collections import defaultdict, Counter
from pathlib import Path

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import prody as pd
import pynndescent
from gensim.models import Word2Vec
from sklearn.metrics.pairwise import paired_distances
from tqdm.notebook import tqdm

from scripts import plotting

In [None]:
# Root directory for every input/output of this notebook, plus the sub-folder that
# holds the mmCIF structure files parsed further down.
paper_dir = Path("data")
paper_dir.mkdir(exist_ok=True)
protein_dir = paper_dir / "proteins"

In [None]:
# Per-dataset input files: shapemer ids, residue indices, word2vec embeddings and,
# where available, the pickled NMF topic results.
uniref_prefix = "uniref50_10_70_95"
uniref_s_file = paper_dir / f"{uniref_prefix}_shapemers.txt"
uniref_i_file = paper_dir / f"{uniref_prefix}_indices.txt"
uniref_wv_file = paper_dir / f"{uniref_prefix}_word2vec.txt"
uniref_topic_file = paper_dir / "uniref50_casp12_nmf_400.pkl"

pdb_prefix = "pdb_chain_10"
pdb_s_file = paper_dir / f"{pdb_prefix}_shapemers.txt"
pdb_i_file = paper_dir / f"{pdb_prefix}_indices.txt"
pdb_wv_file = paper_dir / f"{pdb_prefix}_word2vec.txt"
pdb_topic_file = paper_dir / "pdb_casp12_nmf_400.pkl"

swissprot_prefix = "swissprot_10"
swissprot_s_file = paper_dir / f"{swissprot_prefix}_shapemers.txt"
swissprot_i_file = paper_dir / f"{swissprot_prefix}_indices.txt"
swissprot_wv_file = paper_dir / f"{swissprot_prefix}_word2vec.txt"

In [None]:
# Pre-trained models. NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with (paper_dir / "pdb_chain_word2vec_isolation_forest.pkl").open("rb") as handle:
    forest = pickle.load(handle)

with (paper_dir / "casp12_nmf_400_model.pkl").open("rb") as handle:
    vectorizer, topic_model = pickle.load(handle)

# mmap='r' memory-maps the stored arrays read-only instead of reading them into RAM.
word2vec_path = paper_dir / "pdb_chain_word2vec_1024.model"
word2vec = Word2Vec.load(str(word2vec_path), mmap='r')

In [None]:
# Combined PDB + UniRef50 embedding table: keys, embedding vectors and per-key scores
# (a negative score is treated as "outlier" further down).
embeddings_path = paper_dir / "pdb_uniref50_word2vec_embeddings_scores.pkl"
with embeddings_path.open("rb") as handle:
    wv_keys, wv_embeddings, wv_scores = pickle.load(handle)

In [None]:
# Per-UniRef50-representative metadata from the AFDB v3 cluster table, restricted to keys
# that have a word2vec embedding. Columns are addressed by position (as used below):
# 0 -> cluster, 5 -> darkness, -4 -> member count, -1 -> accession.
# A set gives O(1) membership; `key in wv_keys` scanned the whole key list for every row.
wv_key_set = set(wv_keys)
uniref50_id_to_darkness = {}
uniref50_id_to_num = {}
uniref50_id_to_cluster = {}
with open(paper_dir / "AFDBv3_UniRef50.csv") as f:
    next(f, None)  # skip the header row
    for line in tqdm(f):
        parts = line.strip().split(",")
        key = f"{parts[-1]}-F1"
        if key not in wv_key_set:
            continue
        uniref50_id_to_darkness[key] = max(0., float(parts[5]))  # clamp negatives to 0
        uniref50_id_to_num[key] = int(parts[-4])
        uniref50_id_to_cluster[key] = parts[0]

# Repeat proteins

In [None]:
# Map each shapemer id to the set of UniRef50 keys in which that single shapemer makes up
# more than half of all shapemers (a proxy for repeat proteins). Despite the name, the
# values are sets of keys, not counts.
repeat_protein_counts = defaultdict(set)
with open(uniref_s_file) as f:
    for line in tqdm(f):
        key, shapemer_field = line.strip().split("\t")
        shapemer_ids = [int(token) for token in shapemer_field.split()]
        dominant, dominant_count = Counter(shapemer_ids).most_common(1)[0]
        share = dominant_count / len(shapemer_ids)
        if share > 0.5:
            repeat_protein_counts[dominant].add(key)

# SIFTS comparison

In [None]:
# Approximate SwissProt model length: highest shapemer residue index + 8.
# NOTE(review): the UniRef "top" lengths further down add 16 to the same kind of
# indices file -- confirm which offset matches the shapemer window.
swissprot_lengths = {}
with open(swissprot_i_file) as f:
    for line in tqdm(f):
        key, index_field = line.strip().split("\t")
        swissprot_lengths[key] = max(int(token) for token in index_field.split()) + 8

In [None]:
# Approximate PDB chain length: number of shapemers + 16.
pdb_lengths = {}
with open(pdb_s_file) as f:
    for line in tqdm(f):
        key, shapemer_field = line.strip().split("\t")
        n_shapemers = len(shapemer_field.split())
        pdb_lengths[key] = n_shapemers + 16

In [None]:
def read_word2vec_file(path):
    """Read a tab-separated `key, vector, score` word2vec file.

    Returns the keys in file order and a numpy array with one parsed embedding per
    line. The score column must be present but is discarded.
    """
    keys, vectors = [], []
    with open(path) as handle:
        for line in tqdm(handle):
            key, vector, _score = line.strip().split("\t")
            keys.append(key)
            vectors.append([float(value) for value in vector.split()])
    return keys, np.array(vectors)


pdb_keys, pdb_embeddings = read_word2vec_file(pdb_wv_file)
swissprot_keys, swissprot_embeddings = read_word2vec_file(swissprot_wv_file)

In [None]:
# Row lookup into each embedding matrix, plus set views for fast membership tests.
pdb_key_to_index = {key: row for row, key in enumerate(pdb_keys)}
swissprot_key_to_index = {key: row for row, key in enumerate(swissprot_keys)}
swissprot_keys_set = set(swissprot_keys)
pdb_keys_set = set(pdb_keys)

In [None]:
# Pair PDB chains with the SwissProt AF models SIFTS maps them to. Only pairs whose
# approximate lengths differ by fewer than 10 (near full-length matches) are kept, and
# rows whose PDB or SwissProt key has no embedding are skipped.
pdb_to_uniprots = defaultdict(list)
with open(paper_dir / "uniprot_segments_observed.tsv") as f:
    next(f, None)  # skip the first line, as before
    for line in tqdm(f):
        parts = line.strip().split("\t")
        pdb_key = f"{parts[0]}_{parts[1]}"
        swissprot_key = f"{parts[2]}-F1"
        if pdb_key not in pdb_keys_set or swissprot_key not in swissprot_keys_set:
            continue
        length_gap = abs(swissprot_lengths[swissprot_key] - pdb_lengths[pdb_key])
        if length_gap < 10:
            pdb_to_uniprots[pdb_key].append(swissprot_key)

In [None]:
# Positive pairs: (PDB row, SwissProt row) for every SIFTS match found above.
# Negative pairs: the same PDB rows, each paired with a uniformly random SwissProt row.
# The negatives are drawn from a seeded generator so the comparison is reproducible.
pair_rng = np.random.default_rng(0)
indices_same = {(pdb_key_to_index[k], swissprot_key_to_index[s])
                for k, uniprots in pdb_to_uniprots.items() for s in uniprots}
# reshape keeps a (0, 2) array when there are no matches, so column slicing still works.
indices_same = np.array(sorted(indices_same), dtype=int).reshape(-1, 2)
random_targets = pair_rng.integers(len(swissprot_keys), size=len(indices_same))
indices_different = np.column_stack([indices_same[:, 0], random_targets])
indices_same.shape

In [None]:
def _euclidean_between(pairs):
    """Euclidean distance between the PDB and SwissProt embeddings named by each (pdb_row, swissprot_row) pair."""
    return paired_distances(pdb_embeddings[pairs[:, 0]],
                            swissprot_embeddings[pairs[:, 1]],
                            metric="euclidean")


distances_same = _euclidean_between(indices_same)
distances_different = _euclidean_between(indices_different)

In [None]:
# Overlay the embedding-distance distributions of matching vs. random PDB/AF pairs.
# The vertical line marks the distance threshold used below.
threshold = 0.15
figures_dir = paper_dir / "figures"
figures_dir.mkdir(parents=True, exist_ok=True)  # savefig fails if the folder is missing
# NOTE(review): 'ipynb' is not a built-in matplotlib style; it must exist in the user's stylelib.
with plt.style.context('ipynb'):
    fig, ax = plt.subplots(figsize=(10, 7))
    ax.hist(distances_different, bins=70, alpha=0.5, label="PDB chain and random AF structure")
    ax.hist(distances_same, bins=70, alpha=0.5, label="PDB chain and matching AF structure")
    # axvline spans the full y-range instead of a hard-coded 0..40000 segment.
    ax.axvline(threshold, color="black", label=f"distance={threshold}")
    ax.set_xlabel("Euclidean distance between word2vec embeddings")
    ax.set_ylabel("Number of pairs")
    ax.legend()
    fig.savefig(figures_dir / "embedding_distance.png")

In [None]:
# Fraction of matching vs. random pairs whose distance falls below the threshold.
(np.count_nonzero(distances_same < threshold) / len(distances_same),
 np.count_nonzero(distances_different < threshold) / len(distances_different))

# Topic modelling

In [None]:
# Pickled NMF results per dataset: keys, TF-IDF matrix and W matrix (columns = topics).
with pdb_topic_file.open("rb") as handle:
    pdb_topic_keys, pdb_tfidf_matrix, pdb_w_matrix = pickle.load(handle)

with uniref_topic_file.open("rb") as handle:
    uniref50_topic_keys, uniref50_tfidf_matrix, uniref50_w_matrix = pickle.load(handle)

In [None]:
# topic -> list of (protein key, score) pairs.
topics_to_proteins_path = paper_dir / "topics_to_proteins.pkl"
with topics_to_proteins_path.open("rb") as handle:
    topics_to_proteins = pickle.load(handle)

In [None]:
# Invert topics_to_proteins: per protein, its (topic, score) pairs, highest score first.
key_to_topics = defaultdict(list)
for topic in tqdm(topics_to_proteins):
    for protein_key, topic_score in topics_to_proteins[topic]:
        key_to_topics[protein_key].append((topic, topic_score))
for topic_pairs in tqdm(key_to_topics.values()):
    topic_pairs.sort(key=lambda pair: pair[1], reverse=True)

In [None]:
# Dataset sizes. Keys containing "-" are UniRef50/AF entries ("<acc>-F1"); the others
# are PDB chains ("<pdb>_<chain>"). num_af sums the AFDB CSV count column over the
# UniRef50 keys; num_pdb counts distinct PDB entries.
uniref_wv_keys = [k for k in wv_keys if "-" in k]
pdb_chain_wv_keys = [k for k in wv_keys if "-" not in k]
num_uniref = len(uniref_wv_keys)
num_af = sum(uniref50_id_to_num[k] for k in uniref_wv_keys)
num_pdb_chains = len(pdb_chain_wv_keys)
num_pdb = len({k.split("_")[0] for k in pdb_chain_wv_keys})
num_uniref, num_af, num_pdb_chains, num_pdb

In [None]:
# Per-topic frequencies, each normalised by the matching dataset size from above.
topics = list(range(uniref50_w_matrix.shape[1]))
uniref_freqs, af_freqs, pdb_chain_freqs, pdb_freqs = [], [], [], []
for topic_id in topics:
    members = [protein_key for protein_key, _ in topics_to_proteins[topic_id]]
    uniref_members = [m for m in members if "-" in m]
    pdb_members = [m for m in members if "-" not in m]
    uniref_freqs.append(len(uniref_members) / num_uniref)
    af_freqs.append(sum(uniref50_id_to_num.get(m, 0) for m in members) / num_af)
    pdb_chain_freqs.append(len(pdb_members) / num_pdb_chains)
    pdb_freqs.append(len({m.split("_")[0] for m in pdb_members}) / num_pdb)
uniref_freqs = np.array(uniref_freqs)
af_freqs = np.array(af_freqs)
pdb_chain_freqs = np.array(pdb_chain_freqs)
pdb_freqs = np.array(pdb_freqs)

In [None]:
# Rank topics by the gap between AF and PDB-chain frequency, largest gap first.
diff = abs(af_freqs - pdb_chain_freqs)
sort_idx = np.flip(np.argsort(diff))

In [None]:
sorted_topics = list(map(topics.__getitem__, sort_idx))  # topic ids, largest AF/PDB gap first

In [None]:
# For the topics with the largest AF-vs-PDB gap, pick example proteins evenly spread
# over each topic's best-scoring members (highest score first).
n_example_topics = 3       # how many top-ranked topics to showcase
n_top_members = 50         # only consider this many best-scoring proteins per topic
n_examples_per_topic = 4   # examples sampled evenly from those
example_keys_per_topic = {}
for topic in sorted_topics[:n_example_topics]:  # slicing tolerates fewer topics
    protein_scores = sorted(topics_to_proteins[topic], key=lambda x: x[1], reverse=True)[:n_top_members]
    if not protein_scores:
        continue
    # np.unique drops the repeated positions linspace produces for very small topics;
    # linspace is non-decreasing, so the order is unchanged.
    positions = np.unique(np.linspace(0, len(protein_scores) - 1, n_examples_per_topic, dtype=int))
    example_keys_per_topic[topic] = [protein_scores[p][0] for p in positions]

In [None]:
def _read_int_columns(path, wanted):
    """Return {key: [int, ...]} for every key in `wanted` found in a `key<TAB>ints` file.

    `wanted` is consumed in place: each found key is removed, and the scan stops as
    soon as it is empty (so an empty set means the file is never opened).
    """
    found = {}
    if not wanted:
        return found
    with open(path) as handle:
        for line in handle:
            key, values = line.strip().split("\t")
            if key not in wanted:
                continue
            found[key] = list(map(int, values.split()))
            wanted.remove(key)
            if not wanted:
                break
    return found


def get_shapemer_indices(query_keys, uniref_files=None, pdb_files=None):
    """Map each query key to {residue index: shapemer id}.

    The UniRef50 files are searched first. The PDB files are searched only for keys
    still missing afterwards. Keys without both an indices and a shapemer entry are
    omitted from the result.

    Args:
        query_keys: iterable of protein keys (UniRef50 "<acc>-F1" or PDB "<id>_<chain>").
        uniref_files: optional (shapemers_path, indices_path); defaults to
            (uniref_s_file, uniref_i_file).
        pdb_files: optional (shapemers_path, indices_path); defaults to
            (pdb_s_file, pdb_i_file).

    Returns:
        dict keyed by query key, in query order.
    """
    if uniref_files is None:
        uniref_files = (uniref_s_file, uniref_i_file)
    if pdb_files is None:
        pdb_files = (pdb_s_file, pdb_i_file)
    query_keys = list(query_keys)
    missing_values = set(query_keys)
    missing_indices = set(query_keys)
    per_key_values = _read_int_columns(uniref_files[0], missing_values)
    per_key_indices = _read_int_columns(uniref_files[1], missing_indices)
    if missing_values or missing_indices:
        per_key_values.update(_read_int_columns(pdb_files[0], missing_values))
        per_key_indices.update(_read_int_columns(pdb_files[1], missing_indices))
    # Require both files to have the key; a key with indices but no shapemers used to raise KeyError.
    return {k: dict(zip(per_key_indices[k], per_key_values[k]))
            for k in query_keys if k in per_key_indices and k in per_key_values}

In [None]:
# Vectorizer vocabulary re-keyed by integer shapemer id -> column of topic_model.components_.
shapemer_to_index = dict((int(token), column) for token, column in vectorizer.vocabulary_.items())

In [None]:
def get_shapemer_topic_scores(query_keys, topic):
    """Per-residue weights of `topic` for each query protein.

    For every key that `get_shapemer_indices` finds, map each residue index to the
    weight of its shapemer in `topic_model.components_[topic]`. Shapemers absent from
    the vectorizer vocabulary are skipped. Keys not found in any shapemer file are
    omitted instead of raising KeyError.
    """
    shapemer_indices = get_shapemer_indices(query_keys)
    topic_weights = topic_model.components_[topic]  # hoisted out of the per-residue loop
    return {k: {i: topic_weights[shapemer_to_index[s]]
                for i, s in shapemer_indices[k].items() if s in shapemer_to_index}
            for k in query_keys if k in shapemer_indices}

In [None]:
# Residue-level topic weights for each showcased topic's example proteins.
shapemer_topic_scores = {
    topic_id: get_shapemer_topic_scores(example_keys, topic_id)
    for topic_id, example_keys in tqdm(example_keys_per_topic.items())
}

In [None]:
# Output folder for per-topic structure files annotated with topic weights.
topic_dir = paper_dir / "topic_proteins"
topic_dir.mkdir(exist_ok=True)

In [None]:
# Write every example structure, passed through plotting.get_topic_scores with its
# residue -> topic-weight map, to topic_dir/<topic>/<key>.pdb.
# UniRef50 keys ("<acc>-F1") load the AF model; PDB keys ("<id>_<chain>") load one chain.
# Note: `pd` is ProDy here, not pandas. Input paths come from protein_dir instead of
# a duplicated "data/proteins/" literal.
for t in example_keys_per_topic:
    directory = topic_dir / f"{t}"
    directory.mkdir(exist_ok=True)
    for k in shapemer_topic_scores[t]:
        if "-" in k:
            accession = k.split("-")[0]
            protein = pd.parseMMCIF(str(protein_dir / f"{accession}-AF-v3.cif"))
        else:
            pdb_id, chain = k.split("_")
            protein = pd.parseMMCIF(str(protein_dir / f"{pdb_id}.cif"), chain=chain)
        protein = plotting.get_topic_scores(protein, shapemer_topic_scores[t][k])
        pd.writePDB(str(directory / f"{k}.pdb"), protein)

# Word2Vec

In [None]:
# Approximate k-NN index over all PDB + UniRef50 word2vec embeddings, pickled for reuse.
# n_jobs=-1 uses every available thread. A fixed n_jobs larger than numba's thread pool
# (NUMBA_NUM_THREADS) makes numba.set_num_threads raise, so 100 crashed on smaller machines.
nn_index = pynndescent.NNDescent(wv_embeddings,
                                 n_jobs=-1,
                                 verbose=True,
                                 low_memory=True)
with open(paper_dir / "pdb_uniref50_word2vec_embeddings_nn_index.pkl", "wb") as f:
    pickle.dump(nn_index, f)

In [None]:
knn_graph = nn_index.neighbor_graph  # (neighbour indices, neighbour distances), one row per embedding
neighbor_indices = knn_graph[0]
neighbor_distances = knn_graph[1]

In [None]:
key_to_index = {wv_key: row for row, wv_key in enumerate(wv_keys)}  # row of each key in wv_keys

In [None]:
def annotate_graph(graph, node_keys):
    """Attach per-protein attributes to nodes of `graph` and return the same graph.

    `node_keys` are row indices into `wv_keys`. Each is translated to its key, which
    must already be a node of `graph`. Keys without a darkness value (e.g. PDB chains)
    get darkness 100.

    Attributes: darkness, outlier_score, isdark (darkness <= 5), isbright
    (darkness >= 99), isoutlier (score < 0), ispdb (key contains no "-").
    """
    for node_index in tqdm(node_keys):
        node_key = wv_keys[node_index]
        darkness = uniref50_id_to_darkness.get(node_key, 100)
        score = wv_scores[node_key]
        graph.nodes[node_key].update(
            darkness=float(darkness),
            outlier_score=float(score),
            isdark=int(darkness <= 5),
            isbright=int(darkness >= 99),
            isoutlier=int(score < 0),
            ispdb=int("-" not in node_key),
        )
    return graph

# Most populated UniRef50 clusters

In [None]:
# Members of the most populated UniRef50 clusters, keeping rows whose first column is >= 95.
# The cluster id is the last column and the accession is the second one.
# NOTE(review): unlike the other inputs, this path is relative to the working directory,
# not paper_dir -- confirm the file location.
per_cluster = defaultdict(list)
with open("AFDBv3_UniRef50_top_most_populated_clusters.csv") as f:
    next(f, None)  # header row
    for row in f:
        fields = row.strip().split(",")
        if float(fields[0]) >= 95:
            per_cluster[fields[-1]].append(f"{fields[1]}-F1")

In [None]:
# Cluster positions ordered from most to fewest members, plus member -> cluster lookup.
clusters = [*per_cluster]
cluster_counts = np.array([len(members) for members in per_cluster.values()])
cluster_indices = np.flip(np.argsort(cluster_counts))

# If a key appeared in several clusters, the last one visited (the smaller) would win.
key_to_cluster = {
    member: clusters[position]
    for position in cluster_indices
    for member in per_cluster[clusters[position]]
}

In [None]:
# Word2vec vectors of the selected cluster members, grouped by cluster. Row r of
# per_cluster_word2vec[c] belongs to per_cluster_keys[c][r]. Keys outside those
# clusters are skipped.
per_cluster_word2vec = defaultdict(list)
per_cluster_keys = defaultdict(list)
with open(paper_dir / 'uniref50_10_top_word2vec.txt') as f:
    for line in tqdm(f):
        key, vector = line.strip().split("\t")
        cluster_id = key_to_cluster.get(key)
        if cluster_id is None:
            continue
        per_cluster_keys[cluster_id].append(key)
        per_cluster_word2vec[cluster_id].append([float(value) for value in vector.split()])
for cluster_id, rows in tqdm(per_cluster_word2vec.items()):
    per_cluster_word2vec[cluster_id] = np.array(rows)

In [None]:
def get_components(keys, matrix, n_neighbors_kept=30, max_distance=0.1, n_jobs=-1):
    """Group keys whose embeddings are near-duplicates into connected components.

    Builds an approximate k-NN index over the rows of `matrix` (row i belongs to
    keys[i]). Each key is linked to those of its first `n_neighbors_kept` neighbours
    that lie closer than `max_distance`.

    Args:
        keys: sequence of node keys, aligned with the rows of `matrix`.
        matrix: embedding matrix, one row per key.
        n_neighbors_kept: how many neighbours per key to consider (was hard-coded 30).
        max_distance: exclusive distance cut-off for an edge (was hard-coded 0.1).
        n_jobs: passed to pynndescent. -1 uses all threads; the old fixed 100 can
            exceed numba's thread limit and raise.

    Returns:
        (graph, list of connected components as key sets).
    """
    u_graph = nx.Graph()
    u_graph.add_nodes_from(keys)
    if len(keys) < 2:
        # Nothing to link; avoid building a neighbour index on a single point.
        return u_graph, list(nx.connected_components(u_graph))
    u_nn_index = pynndescent.NNDescent(matrix,
                                       n_jobs=n_jobs,
                                       verbose=False,
                                       low_memory=True)
    u_neighbor_indices, u_neighbor_distances = u_nn_index.neighbor_graph
    for i, key in enumerate(keys):
        for j, distance in zip(u_neighbor_indices[i][:n_neighbors_kept],
                               u_neighbor_distances[i][:n_neighbors_kept]):
            if j == i or distance >= max_distance:  # skip self-matches and distant pairs
                continue
            u_graph.add_edge(key, keys[j])
    return u_graph, list(nx.connected_components(u_graph))

In [None]:
# Approximate length of each "top" UniRef50 protein: highest shapemer residue index + 16.
# NOTE(review): the SwissProt lengths earlier add 8 to the same kind of indices file --
# confirm which offset is intended.
id_to_length = {}
with open(paper_dir / 'uniref50_10_top_indices.txt') as f:
    for line in tqdm(f):
        key, index_field = line.strip().split("\t")
        id_to_length[key] = max(int(token) for token in index_field.split()) + 16

In [None]:
# For each cluster (largest first), keep members longer than max(100, longest - 200),
# group them into embedding-space components, and report clusters that split:
# cluster id, number of components and % of nodes in the largest one, then per
# component its size, one example key, and the mean/std of member lengths.
for index in cluster_indices:
    cluster = clusters[index]
    member_keys = per_cluster_keys[cluster]
    if not member_keys:
        # No member has a word2vec embedding; .max() on an empty array would raise.
        continue
    length_range = np.array([id_to_length[k] for k in member_keys])
    min_length = max(100, length_range.max() - 200)
    keep_indices = np.flatnonzero(length_range > min_length)
    if not len(keep_indices):
        continue
    u_graph, u_c = get_components([member_keys[i] for i in keep_indices],
                                  per_cluster_word2vec[cluster][keep_indices])
    if len(u_c) > 1:
        max_n = max(len(x) for x in u_c)
        print(cluster, len(u_c), f"{100 * max_n / len(u_graph):.2f}")
        for x in u_c:
            lengths = [id_to_length[k] for k in x]
            print(len(x), next(iter(x)),
                  f"{np.mean(lengths):.2f}",
                  f"{np.std(lengths):.2f}")
        print()


## Structural outliers

In [None]:
# Neighbourhoods of structural outliers (negative outlier score). Each outlier is linked
# to those of its first `num` nearest neighbours closer than 0.15. Members of the repeat
# set for shapemer 370 are excluded on both ends.
num = 4
excluded_repeats = repeat_protein_counts[370]
outlier_edge_keys = set()
outlier_node_keys = set()
outlier_keys = []
for i, key in tqdm(enumerate(wv_keys)):
    if key in excluded_repeats or not wv_scores[key] < 0:
        continue
    outlier_keys.append(key)
    neighbours = zip(neighbor_indices[i][:num], neighbor_distances[i][:num])
    for j, distance in neighbours:
        if j == i or distance >= 0.15 or wv_keys[j] in excluded_repeats:
            continue
        outlier_edge_keys.add((i, j, distance))
        outlier_node_keys.update((i, j))

In [None]:
# Undirected outlier-neighbourhood graph; edge weight = embedding distance.
outlier_graph = nx.Graph()
outlier_graph.add_nodes_from(wv_keys[n] for n in outlier_node_keys)
outlier_graph.add_weighted_edges_from((wv_keys[i], wv_keys[j], d) for i, j, d in outlier_edge_keys)

In [None]:
# Annotate nodes, then split into connected components and score node centrality.
outlier_graph = annotate_graph(outlier_graph, outlier_node_keys)
outlier_components = [*nx.connected_components(outlier_graph)]
centralities = nx.degree_centrality(outlier_graph)

In [None]:
# One representative per outlier component: its most central node (first one on ties),
# the component size, the number of PDB chains in it, and the node's darkness (100 if unknown).
representatives = []
for c in outlier_components:
    key = max(c, key=centralities.__getitem__)
    # Distinct local name: the previous `num_pdb` overwrote the dataset-wide total used
    # by the topic frequency cell, breaking it on re-run.
    component_num_pdb = sum(1 for k in c if "-" not in k)
    representatives.append((key, len(c), component_num_pdb, uniref50_id_to_darkness.get(key, 100)))

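Choose example proteins manually from `representatives`. Each entry describes one outlier component: its most central protein, the component size, the number of PDB chains in it, and that protein's darkness.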

# Dark proteins

In [None]:
# Neighbourhoods of dark proteins (darkness <= 5; keys without a value count as 100).
# Each dark protein is linked to those of its first `num` nearest neighbours closer than
# 0.15. Members of the repeat set for shapemer 370 are excluded on both ends.
num = 4
excluded_repeats = repeat_protein_counts[370]
edge_keys = set()
node_keys = set()
dark_keys = []
for i, key in tqdm(enumerate(wv_keys)):
    if key in excluded_repeats or not uniref50_id_to_darkness.get(key, 100) <= 5:
        continue
    dark_keys.append(key)
    neighbours = zip(neighbor_indices[i][:num], neighbor_distances[i][:num])
    for j, distance in neighbours:
        if j == i or distance >= 0.15 or wv_keys[j] in excluded_repeats:
            continue
        edge_keys.add((i, j, distance))
        node_keys.update((i, j))

In [None]:
# Unweighted dark-protein neighbourhood graph.
graph = nx.Graph()
graph.add_nodes_from(wv_keys[n] for n in node_keys)
graph.add_edges_from((wv_keys[i], wv_keys[j]) for i, j, _ in edge_keys)

In [None]:
graph = annotate_graph(graph, node_keys)
# Each connected component as an independent graph copy, and their positions by decreasing size.
components = [graph.subgraph(nodes).copy() for nodes in nx.connected_components(graph)]
component_sizes = [len(component) for component in components]
component_indices = np.flip(np.argsort(component_sizes))

In [None]:
# Merge the largest component with every other component of more than
# `min_component_size` nodes. Nodes and edges are inserted in the order repeated
# nx.compose calls would use, but into a single graph, instead of re-copying the
# accumulated graph once per component. With no components, the result is empty.
min_component_size = 100
subgraph = nx.Graph()
if len(component_indices):
    selected = [component_indices[0]] + [n for n in component_indices[1:]
                                         if len(components[n]) > min_component_size]
    for n in selected:
        component = components[n]
        subgraph.graph.update(component.graph)
        subgraph.add_nodes_from(component.nodes(data=True))
        subgraph.add_edges_from(component.edges(data=True))

In [None]:
# Export the merged dark-protein graph as GML next to the notebook's other outputs in paper_dir.
nx.write_gml(subgraph, str(paper_dir / "word2vec_dark_graph.gml"))