In [None]:
% load_ext autoreload
% autoreload 2

In [None]:
import itertools
import logging
import pickle
from multiprocessing import Pool
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from gensim.models import Word2Vec
from kneed import KneeLocator
from sklearn.decomposition import NMF
from sklearn.ensemble import IsolationForest
from sklearn.feature_extraction.text import TfidfVectorizer
from tqdm.notebook import tqdm

In [None]:
# Emit INFO-level records as "timestamp : level : message" (e.g. gensim training progress).
logging.basicConfig(level=logging.INFO, format='%(asctime)s : %(levelname)s : %(message)s')

In [None]:
# Input/output paths, grouped per dataset:
#   *_s_file      "<key>\t<space-separated shapemers>" lines (parsed by the cells below)
#   *_i_file      index files (not read in the cells shown here)
#   *_wv_file     "<key>\t<vector>\t<isolation-forest score>" lines written by the Word2Vec section
#   *_topic_file  pickled (keys, tf-idf matrix, document-topic matrix) from the topic-modelling section
# NOTE(review): the topic files point at 400-topic outputs, while the topic-modelling cells use
# num_topics = 1000 and write *_nmf_1000 files — confirm which run is intended.
uniref_s_file, uniref_i_file, uniref_wv_file, uniref_topic_file = (
    "data/uniref50_10_70_95_shapemers.txt",
    "data/uniref50_10_70_95_indices.txt",
    "data/uniref50_10_70_95_word2vec.txt",
    "data/uniref50_casp12_nmf_400.pkl",
)

pdb_s_file, pdb_i_file, pdb_wv_file, pdb_topic_file = (
    "data/pdb_chain_10_shapemers.txt",
    "data/pdb_chain_10_indices.txt",
    "data/pdb_chain_10_word2vec.txt",
    "data/pdb_casp12_nmf_400.pkl",
)

swissprot_s_file, swissprot_i_file, swissprot_wv_file = (
    "data/v2_shapemers/swissprot_10_shapemers.txt",
    "data/v2_shapemers/swissprot_10_indices.txt",
    "data/swissprot_10_word2vec.txt",
)

# Topic modelling

In [None]:
corpus_file = Path("data/v2_shapemers/casp12_10_shapemers.txt")
# Keep only lines that split into exactly two tab-separated fields, "<key>\t<shapemers>".
# Reading inside `with` closes the file. Plain lists are used because keys are needed in
# full before the corpus is consumed, so an itertools.tee split would buffer every line anyway.
keys, corpus = [], []
with open(corpus_file) as f:
    for line in tqdm(f):
        parts = line.strip().split("\t")
        if len(parts) == 2:
            keys.append(parts[0])
            corpus.append(parts[1])

In [None]:
# min_df=1 keeps every term occurring in at least one document. That is the same vocabulary
# min_df=0 produced, but recent scikit-learn releases reject an int min_df below 1.
# NOTE(review): the default token_pattern r"(?u)\b\w\w+\b" drops one-character tokens and
# lowercase=True folds case — confirm no shapemer token is affected (Word2Vec below splits
# on whitespace only).
vectorizer = TfidfVectorizer(min_df=1)
tfidf_matrix = vectorizer.fit_transform(corpus)

In [None]:
num_topics = 1000
# NOTE(review): NMF's `alpha` was deprecated in scikit-learn 1.0 (slated for removal in 1.2) in
# favour of alpha_W / alpha_H, which are scaled differently — reproducing these models
# presumably needs an older scikit-learn; confirm the pinned version.
nmf_settings = dict(
    n_components=num_topics,
    random_state=42,
    solver="cd",
    tol=1e-3,
    alpha=0.1,
    l1_ratio=0.5,
    verbose=1,
)
topic_model = NMF(**nmf_settings)
# Document-topic weights of the CASP12 training corpus: one row per document, one column per topic.
w_matrix = topic_model.fit_transform(tfidf_matrix)

In [None]:
# Persist the fitted vectorizer together with the NMF model so other corpora can be projected later.
nmf_model_path = f"data/casp12_nmf_{num_topics}_model.pkl"
with open(nmf_model_path, mode="wb") as model_file:
    pickle.dump((vectorizer, topic_model), model_file)

In [None]:
uniref50_corpus_file = uniref_s_file
# Keep only lines that split into exactly two tab-separated fields, "<key>\t<shapemers>".
# Two streaming passes (keys first, then documents) avoid holding every document in memory.
# Both passes close the file. Applying the identical filter in both passes keeps
# uniref50_keys aligned with the matrix rows.
with open(uniref50_corpus_file) as f:
    uniref50_keys = [fields[0]
                     for fields in (line.strip().split("\t") for line in tqdm(f))
                     if len(fields) == 2]
with open(uniref50_corpus_file) as f:
    uniref50_corpus = (fields[1]
                       for fields in (line.strip().split("\t") for line in tqdm(f))
                       if len(fields) == 2)
    uniref50_tfidf_matrix = vectorizer.transform(uniref50_corpus)
assert uniref50_tfidf_matrix.shape[0] == len(uniref50_keys), "keys and tf-idf rows are misaligned"
uniref50_w_matrix = topic_model.transform(uniref50_tfidf_matrix)
with open(f"data/uniref50_casp12_nmf_{num_topics}.pkl", "wb") as f:
    pickle.dump((uniref50_keys, uniref50_tfidf_matrix, uniref50_w_matrix), f)

In [None]:
pdb_corpus_file = pdb_s_file
# Keep only lines that split into exactly two tab-separated fields, "<key>\t<shapemers>".
# Two streaming passes (keys first, then documents) avoid holding every document in memory.
# Both passes close the file. Applying the identical filter in both passes keeps
# pdb_keys aligned with the matrix rows.
with open(pdb_corpus_file) as f:
    pdb_keys = [fields[0]
                for fields in (line.strip().split("\t") for line in tqdm(f))
                if len(fields) == 2]
with open(pdb_corpus_file) as f:
    pdb_corpus = (fields[1]
                  for fields in (line.strip().split("\t") for line in tqdm(f))
                  if len(fields) == 2)
    pdb_tfidf_matrix = vectorizer.transform(pdb_corpus)
assert pdb_tfidf_matrix.shape[0] == len(pdb_keys), "keys and tf-idf rows are misaligned"
pdb_w_matrix = topic_model.transform(pdb_tfidf_matrix)
with open(f"data/pdb_casp12_nmf_{num_topics}.pkl", "wb") as f:
    pickle.dump((pdb_keys, pdb_tfidf_matrix, pdb_w_matrix), f)

In [None]:
# Reload the saved per-dataset topic outputs: (keys, tf-idf matrix, document-topic matrix).
with open(pdb_topic_file, mode="rb") as topic_file:
    t_pdb_keys, _pdb_tfidf_matrix, pdb_w_matrix = pickle.load(topic_file)
with open(uniref_topic_file, mode="rb") as topic_file:
    t_uniref_keys, _uniref_tfidf_matrix, uniref_w_matrix = pickle.load(topic_file)

# PDB rows first, then UniRef50 rows, so t_keys[i] labels row i of the stacked matrix.
t_keys = t_pdb_keys + t_uniref_keys
with open("data/pdb_uniref50_topic_w_matrix.pkl", mode="wb") as out_file:
    pickle.dump((t_keys, np.concatenate([pdb_w_matrix, uniref_w_matrix], axis=0)), out_file)

In [None]:
def get_sorted_proteins(keys, idx, matrix, S=2, plot=False):
    """Return the rows scoring above the knee of column ``idx``, best first.

    The column ``matrix[:, idx]`` is sorted in decreasing order and a
    ``KneeLocator`` (convex curve, decreasing direction) locates the elbow
    of that curve. Every row whose value is strictly greater than the knee
    value is kept.

    Args:
        keys: Row labels; ``keys[i]`` names row ``i`` of ``matrix``.
        idx: Column (topic) index to rank by.
        matrix: Dense 2-D array of per-row weights.
        S: Sensitivity passed to ``KneeLocator``.
        plot: If True, plot the knee. Then print the number of selected rows,
            how many of their keys lack a "-", and how many contain one
            (presumably separating UniRef50 from PDB identifiers — confirm).

    Returns:
        ``(selected_keys, selected_scores)``: a list of keys and a 1-D array of
        their values, both in decreasing order. Returns ``([], np.zeros(0))``
        when no knee is found.
    """
    values = matrix[:, idx]
    kn = KneeLocator(np.arange(matrix.shape[0]),
                     np.sort(values)[::-1],
                     S=S,
                     curve='convex',
                     direction='decreasing')
    if plot:
        kn.plot_knee()
        plt.show()
    if kn.knee_y is None:
        return [], np.zeros(0)
    order = np.argsort(values)[::-1]
    # Vectorised filter over all rows; boolean masking preserves the descending order.
    indices = order[values[order] > kn.knee_y]
    if plot:
        n_with_dash = sum(1 for i in indices if "-" in keys[i])
        print(len(indices), len(indices) - n_with_dash, n_with_dash)
    return [keys[i] for i in indices], values[indices]

In [None]:
# Rank proteins per topic on the stacked PDB + UniRef50 document-topic matrix. Its rows are in
# the same order as t_keys (PDB first, then UniRef50). The CASP12 training matrix `w_matrix`
# has one row per CASP12 document, so its row indices do not correspond to t_keys.
t_w_matrix = np.vstack((pdb_w_matrix, uniref_w_matrix))
assert t_w_matrix.shape[0] == len(t_keys), "topic matrix rows must align with t_keys"

topics_to_proteins = {}
for topic_id in tqdm(range(t_w_matrix.shape[1])):
    topics_to_proteins[topic_id] = []
    if np.sum(t_w_matrix[:, topic_id]) == 0:
        continue  # topic unused by every protein: nothing to rank
    protein_ids, scores = get_sorted_proteins(t_keys, topic_id,
                                              t_w_matrix,
                                              S=4, plot=False)
    if scores.sum() > 0:
        topics_to_proteins[topic_id] = list(zip(protein_ids, scores))

In [None]:
# Persist the topic id -> [(protein key, topic weight), ...] mapping.
with Path("data/topics_to_proteins.pkl").open("wb") as out_file:
    pickle.dump(topics_to_proteins, out_file)

# Word2Vec

In [None]:
corpus_folder = Path("data/corpus_10")
# parents=True also creates data/ on a fresh checkout instead of raising FileNotFoundError.
corpus_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Reuse the PDB shapemer path constant from the configuration cell instead of repeating the literal.
pdb_corpus_file = pdb_s_file

In [None]:
# Word2Vec inputs: one shapemer sentence per line, and on the same line number of the keys file, its key.
corpus_sentences_file = corpus_folder.joinpath("pdb_chain_corpus_sentences.txt")
corpus_keys_file = corpus_folder.joinpath("pdb_chain_corpus_keys.txt")

In [None]:
# Split "<key>\t<shapemers>" lines into two parallel files, skipping lines without exactly two fields.
with open(corpus_sentences_file, "w") as s_f, \
        open(corpus_keys_file, "w") as k_f, \
        open(pdb_corpus_file) as f:
    for line in tqdm(f):
        fields = line.strip().split("\t")
        if len(fields) == 2:
            key, shapemers = fields
            k_f.write(f"{key}\n")
            s_f.write(f"{shapemers}\n")

In [None]:
# 1024-dimensional Word2Vec over the sentence file; words seen fewer than twice are ignored.
sentences_path = str(corpus_sentences_file)
word2vec = Word2Vec(vector_size=1024, window=16, min_count=2, workers=32)
word2vec.build_vocab(corpus_file=sentences_path)
word2vec.train(
    corpus_file=sentences_path,
    total_examples=word2vec.corpus_count,
    total_words=word2vec.corpus_total_words,
    epochs=word2vec.epochs,
)
word2vec.save(str(corpus_folder / "pdb_chain_word2vec_1024.model"))

# Isolation Forest

In [None]:
# mmap='r' memory-maps the large arrays saved alongside the model read-only instead of loading them.
w2v_model_path = corpus_folder / "pdb_chain_word2vec_1024.model"
word2vec = Word2Vec.load(str(w2v_model_path), mmap='r')

In [None]:
# One key per line of the PDB shapemer file, in file order (one entry per embedded row below).
# partition() tolerates a line without a tab: str.strip() removes the tab of an empty shapemer
# field, and a strict two-way unpacking would raise ValueError there.
with open(pdb_corpus_file) as f:
    pdb_keys = [line.strip().partition("\t")[0] for line in tqdm(f)]

In [None]:
# Unit-length (L2-normalised) vocabulary vectors; row i belongs to the word w with word2vec.wv.key_to_index[w] == i.
word_vectors = word2vec.wv.get_normed_vectors()


def infer_vector_worker(document):
    """Embed a document as the mean of its words' unit-normalised Word2Vec vectors.

    Words missing from the vocabulary (e.g. dropped by ``min_count``) are
    ignored. If no word is in the vocabulary, a zero vector of the embedding
    width is returned. Averaging an empty selection would give a NaN scalar,
    which cannot be stacked with the other rows into a 2-D array.

    Args:
        document: Iterable of word (shapemer) strings.

    Returns:
        1-D array with one entry per embedding dimension.

    Reads the module-level ``word2vec`` and ``word_vectors`` globals. Pool
    workers inherit them from the parent process (assumes the fork start
    method — TODO confirm on non-Linux hosts).
    """
    key_to_index = word2vec.wv.key_to_index
    rows = [key_to_index[word] for word in document if word in key_to_index]
    if not rows:
        return np.zeros(word_vectors.shape[1], dtype=word_vectors.dtype)
    return word_vectors[rows].mean(axis=0)


# Embed every line of the PDB shapemer file in parallel: one row per line, in file order.
# partition() yields an empty document for a line without a tab instead of raising IndexError.
with open(pdb_corpus_file) as f, Pool(processes=100) as pool:
    documents = (line.strip().partition("\t")[2].split() for line in f)
    pdb_word2vec = list(tqdm(pool.imap(infer_vector_worker, documents),
                             total=len(pdb_keys)))
pdb_word2vec = np.array(pdb_word2vec)

In [None]:
# Fixed seed (same value as the NMF random_state) so the forest and its scores are reproducible.
forest = IsolationForest(n_jobs=100, verbose=True, contamination=0.05, random_state=42)
forest.fit(pdb_word2vec)

In [None]:
# Equivalent to forest.decision_function: score_samples shifted by offset_, so 0 is the
# inlier/outlier threshold and negative values mark outliers.
pdb_scores = forest.score_samples(pdb_word2vec) - forest.offset_

In [None]:
# Persist the fitted forest so later sessions can score new embeddings.
Path("data/pdb_chain_word2vec_isolation_forest.pkl").write_bytes(pickle.dumps(forest))

In [None]:
def write_vector_score_worker(document):
    """Embed one "<key>\t<shapemers>" line and score it with the isolation forest.

    The embedding is the mean of the unit-normalised Word2Vec vectors of the
    in-vocabulary shapemers. A zero vector of the embedding width is used when
    none are in the vocabulary, or the line has no shapemers (``strip()``
    removes the tab of an empty field). This keeps the output shape fixed and
    avoids a NaN scalar that cannot be joined or scored.

    Args:
        document: One raw line of a shapemer file.

    Returns:
        ``(key, vector_text, score)``: the key, the vector as space-separated
        ``str()`` of its components, and ``forest.decision_function`` for it.

    Reads the module-level ``word2vec``, ``word_vectors`` and ``forest``
    globals. Pool workers inherit them from the parent process (assumes the
    fork start method — TODO confirm on non-Linux hosts).
    """
    key, _, shapemers = document.strip().partition("\t")
    key_to_index = word2vec.wv.key_to_index
    rows = [key_to_index[word] for word in shapemers.split() if word in key_to_index]
    if rows:
        vector = word_vectors[rows].mean(axis=0)
    else:
        vector = np.zeros(word_vectors.shape[1], dtype=word_vectors.dtype)
    score = forest.decision_function([vector])[0]
    return key, " ".join(str(s) for s in vector), score


# Embed and score every line of each shapemer file. Output lines are
# "<key>\t<space-separated vector>\t<isolation-forest score>", written to
# data/<file stem up to "_shapemers">_word2vec.txt.
for filename in [uniref_s_file, pdb_s_file, swissprot_s_file]:
    # The path constants are plain str, so wrap in Path to take `.stem`.
    out_file = Path("data") / (Path(filename).stem.split("_shapemers")[0] + "_word2vec.txt")
    with open(filename) as f:
        total = sum(1 for _ in f)
    with Pool(processes=100) as pool, open(filename) as f, open(out_file, 'w') as f1:
        for key, vector, score in tqdm(pool.imap(write_vector_score_worker, f),
                                       total=total):
            f1.write(f"{key}\t{vector}\t{score}\n")

In [None]:
wv_keys = []
wv_embeddings = []
wv_scores = {}

# PDB rows first, then UniRef50 rows. Each line is "<key>\t<vector>\t<score>".
for wv_file in (pdb_wv_file, uniref_wv_file):
    with open(wv_file) as f:
        for line in tqdm(f):
            key, vector, score = line.strip().split("\t")
            wv_keys.append(key)
            # A float64 array per row takes several times less memory than a list of Python
            # floats. The values are still parsed with float(), so the final matrix is identical.
            wv_embeddings.append(np.fromiter(map(float, vector.split()), dtype=np.float64))
            wv_scores[key] = float(score)

with open("data/pdb_uniref50_word2vec_embeddings_scores.pkl", "wb") as f:
    pickle.dump((wv_keys, np.array(wv_embeddings), wv_scores), f)