# Definition-based Gene Ontology Term Embedding<a id='top'></a>
**Sections:**<br>
[0) Description](#0)<br>
[1) Importing Modules and Packages](#1)<br>
[2) Configuration](#2)<br>
[3) GO Term Definitions](#3)<br>
[4) MEDLINE Bigrams](#4)<br>
[5) GO term Embeddings Computation](#5)<br>


## Description<a id='0'></a>

**Aim:** This Jupyter notebook computes Gene Ontology (GO) term embeddings from the terms' textual definitions.<br><br>
The stopword file and the bigram file built from PubMed abstracts are already provided with this project. Feel free to replace them if you prefer other stopword or bigram lists.<br>

_* Requirement:_ This notebook needs a large amount of memory. Most of the pipeline uses sparse matrices, but the PMI step converts the full second-order matrix into a dense one.<br>


[back to top](#top)<br>

## Importing Modules and Packages<a id='1'></a>
[back to top](#top)<br>

In [None]:
import os
import requests
import re
import collections
import numpy as np
import pickle
import linecache

from scipy.sparse import csr_matrix, lil_matrix, coo_matrix, dok_matrix, vstack
from scipy.sparse.linalg import svds
from scipy.spatial.distance import cdist

import operator
import easydict
import pprint
import gzip

import datetime
from pytz import timezone

# Shared pretty-printer, used below to inspect a parsed GO term record.
pp = pprint.PrettyPrinter(indent=4)

# Time zone for the "Time started/current" progress messages of the long-running cells.
tz = timezone('US/Eastern')

## Configuration<a id='2'></a>
[back to top](#top)<br>

In [None]:
# Notebook-wide settings; EasyDict allows attribute access (e.g. args.go_dir).
# NOTE: despite their names, "stopwords_dir" and "bigramwords_dir" are file paths, not directories.
args = easydict.EasyDict({
    "go_dir": 'gene_ontology/raw/',             # directory holding (and receiving) the Gene Ontology 'go.obo' file
    "go_embedding_dir": 'gene_ontology/embedding/',     # output directory for the GO_<BP|CC|MF>_Embeddings_<dim>D.emb files
    "stopwords_dir": 'pubmed/Stoplist.dat',     # stopword list used to drop uninformative definition words and bigrams
    "bigramwords_dir": 'pubmed/bigramWords.gz', # gzip-compressed bigram counts from PubMed abstracts
    "embedding_dims": [50, 100, 150, 200, 300], # embedding sizes; one SVD and one set of output files per entry
    "download_gene_ontology": True              # if True, download the latest go.obo release before parsing
})

In [None]:
# Each stoplist line is split on the literal two-character marker "\b"; the piece
# after the first marker is taken as the stopword (assumes lines look like
# "\bword\b" - TODO confirm). Every word becomes an end-anchored alternative
# "word$", so pattern.match(token) succeeds only when a stopword spans the whole token.
with open(args.stopwords_dir, "r") as stoplist_file:
    stopwords = [entry.split("\\b")[1] + "$" for entry in stoplist_file]
pattern = re.compile("|".join(stopwords))

#### Utility

In [None]:
def remove_zero_columns(sparse_matrix):
    """Return ``sparse_matrix`` restricted to the columns that hold a nonzero entry.

    Accepts anything ``np.nonzero`` understands, e.g. a scipy sparse matrix or a
    2-D numpy array. The kept columns stay in their original left-to-right order.
    """
    nonzero_coordinates = np.nonzero(sparse_matrix)
    # Column indices of all nonzero cells, sorted and de-duplicated.
    occupied_columns = np.unique(nonzero_coordinates[1])
    return sparse_matrix[:, occupied_columns]

## GO Term Definitions<a id='3'></a>
[back to top](#top)<br>

#### Updating the go.obo file (default: True)

In [None]:
go_obo_path = os.path.join(args.go_dir, 'go.obo')
if args.download_gene_ontology:
    os.makedirs(args.go_dir, exist_ok=True)  # create the target folder if it does not exist yet
    print("Downloading the latest version of Gene Ontology into '{}'...".format(args.go_dir))
    url = 'http://current.geneontology.org/ontology/go.obo'
    # The timeout keeps a stalled connection from hanging the notebook forever.
    r = requests.get(url, allow_redirects=True, timeout=300)
    # Fail loudly on HTTP errors instead of silently writing an error page into go.obo.
    r.raise_for_status()
    with open(go_obo_path, 'wb') as fw:
        fw.write(r.content)
    # linecache memoizes file contents; discard any stale copy of a previous go.obo.
    linecache.checkcache(go_obo_path)

# Print line 2 of the OBO header (presumably the "data-version: ..." line - TODO confirm).
print("Gene Ontology {}".format(linecache.getline(go_obo_path, 2)))

#### reading go.obo file

In [None]:
"""Reading Gene Ontology to extract Terms and their Descriptive Names"""
with open("{}/go.obo".format(args.go_dir)) as f:
    content = f.readlines()
content = "".join([x for x in content])
content = content.split("[Typedef]")[0].split("[Term]")
print(content[-1])

#### list of unique attributes for one GO term

In [None]:
# Collect every tag (the text before the first ": ") that appears on any line of
# any stanza, to see which OBO attributes the parser below has to handle.
s = set()
for stanza in content:
    for line in stanza.split("\n"):
        tag, separator = line.partition(": ")[:2]
        if separator:
            s.add(tag)
print(s)

#### Parsing GO terms (id, alt_id, name, synonym, namespace, def, is_a, relationships, is_obsolete)

In [None]:
# Maps an OBO "relationship:" type to the go_term_dict key it is stored under.
# positively_regulates / negatively_regulates are folded into plain "regulates";
# relationship types not listed here are ignored.
relationship_key_map = {
    "part_of": "part_of",
    "has_part": "has_part",
    "regulates": "regulates",
    "positively_regulates": "regulates",
    "negatively_regulates": "regulates",
    "happens_during": "happens_during",
    "ends_during": "ends_during",
    "occurs_in": "occurs_in",
}


def _tag_value(line, tag):
    """Return the text following the leading '<tag>: ' of an OBO line.

    Slicing off the prefix, rather than ``line.split(tag + ": ")[1]``, keeps values
    that themselves contain the tag text (e.g. a definition mentioning "def: ") intact.
    """
    return line[len(tag) + 2:]


# Parse each [Term] stanza into go_term_dict[GO id] -> {attribute: value or list of values}.
# Stanzas without an "id: GO:" line (e.g. the file header) only touch the '' id if they
# carry one of the tags below, which would raise KeyError as in the original parser.
go_term_dict = {}
for stanza in content:
    go_id = ''
    for line in stanza.split("\n"):
        if line.startswith("id: GO:"):
            go_id = _tag_value(line, "id")
            go_term_dict[go_id] = {}
        elif line.startswith("alt_id:"):
            go_term_dict[go_id].setdefault("alt_id", []).append(_tag_value(line, "alt_id"))
        elif line.startswith("name:"):
            go_term_dict[go_id]["name"] = _tag_value(line, "name")
        elif line.startswith("synonym:"):
            # keep only the quoted synonym text
            go_term_dict[go_id].setdefault("synonym", []).append(_tag_value(line, "synonym").split('"')[1])
        elif line.startswith("namespace:"):
            go_term_dict[go_id]["namespace"] = _tag_value(line, "namespace")
        elif line.startswith("def:"):
            # drop the opening quote and everything from the closing '" [' (the xref list) on
            go_term_dict[go_id]["def"] = _tag_value(line, "def").split('" [')[0][1:]
        elif line.startswith("is_a:"):
            # keep the parent id, drop the trailing " ! name" comment
            go_term_dict[go_id].setdefault("is_a", []).append(_tag_value(line, "is_a").split(" !")[0])
        elif line.startswith("relationship:"):
            relation, separator, target = _tag_value(line, "relationship").partition(" ")
            key = relationship_key_map.get(relation)
            if separator and key is not None:
                go_term_dict[go_id].setdefault(key, []).append(target.split(" !")[0])
        elif line.startswith("is_obsolete:"):
            # stored as the raw string (e.g. "true")
            go_term_dict[go_id]["is_obsolete"] = _tag_value(line, "is_obsolete")

#### Adding inverse relations (children via `subsumes`, `regulated_by`, and symmetric `part_of`/`has_part`)

In [None]:
# Invert "is_a": every parent term records the child terms it subsumes.
for go_id, term in go_term_dict.items():
    for parent in term.get('is_a', []):
        go_term_dict[parent].setdefault("subsumes", []).append(go_id)

In [None]:
# Invert "regulates" (which also holds positively/negatively_regulates targets):
# each regulated term records the terms that regulate it.
for go_id, term in go_term_dict.items():
    for regulated_go in term.get('regulates', []):
        go_term_dict[regulated_go].setdefault("regulated_by", []).append(go_id)

In [None]:
# Make "part_of" cover the inverse of "has_part": if X has_part Y, then Y part_of X
# (added only when not already listed).
for go_id, term in go_term_dict.items():
    for part in term.get('has_part', []):
        wholes = go_term_dict[part].setdefault("part_of", [])
        if go_id not in wholes:
            wholes.append(go_id)

In [None]:
# Make "has_part" cover the inverse of "part_of": if X part_of Y, then Y has_part X
# (added only when not already listed).
for go_id, term in go_term_dict.items():
    for whole in term.get('part_of', []):
        components = go_term_dict[whole].setdefault("has_part", [])
        if go_id not in components:
            components.append(go_id)

In [None]:
# Sanity check: pretty-print the parsed record of one GO term.
pp.pprint(go_term_dict['GO:0007127'])

#### Definition extension

In [None]:
definition_extension = {}
relationships = ['is_a', 'subsumes', 'part_of', 'has_part', 'regulates', 'regulated_by', 'happens_during', 'ends_during', 'occurs_in']

# For every non-obsolete GO term, build the list: its own definition, followed by the
# definitions of its directly related terms, grouped by relationship in the order above.
# Lists are filed as definition_extension[namespace][go_id].
for go_id, term in go_term_dict.items():
    if term.get('is_obsolete', False):
        continue
    namespace = term['namespace']
    extended = [term['def']]
    for relationship in relationships:
        extended.extend(go_term_dict[related_go_term]['def'] for related_go_term in term.get(relationship, []))
    definition_extension.setdefault(namespace, {})[go_id] = extended
                #print(go_term_dict[related_go_term]['def'])

In [None]:
# Spot check: number of cellular_component terms, and the extended definition list of one of them.
print(len(definition_extension['cellular_component']))
print(definition_extension['cellular_component']['GO:0000015'])

In [None]:
# Runs of punctuation (plus trailing spaces) are replaced by a single space.
punctuation_re = re.compile(r"[|\/\\\'\"%^*\[\](){}_~,.:;@#?!&$=<>+\-]+\ *")
# Digit runs followed by th/nd/rd/s/d/h/m (unanchored, so also inside longer tokens), and
# " . " where '.' is the regex wildcard, i.e. any single character between two spaces.
numeric_suffix_re = re.compile(r"\d+th|\d+nd|\d+rd|\d+s|\d+d|\d+h|\d+m| . ")

# Turn each extended definition into a bag of words: {token: occurrence count}.
# definition_extension_terms collects the vocabulary over all sub-ontologies.
definition_extension_info = {}
definition_extension_terms = set()
for sub_ontology in list(definition_extension):
    definition_extension_info[sub_ontology] = {}
    print(sub_ontology.upper())
    for go_id, definitions in definition_extension[sub_ontology].items():
        definition = " ".join(definitions)
        definition = punctuation_re.sub(" ", definition)  # replacing punctuations
        definition = numeric_suffix_re.sub(" ", definition)
        # Lower-case, split on whitespace, and drop tokens fully matched by the stopword pattern.
        tokens = [term for term in definition.lower().split() if not pattern.match(term)]
        # One O(n) counting pass; calling list.count() per distinct token was O(n^2)
        # on long extended definitions.
        definition = dict(collections.Counter(tokens))
        definition_extension_info[sub_ontology][go_id] = definition
        definition_extension_terms.update(definition)

## MEDLINE Bigrams<a id='4'></a>

Link to PubMed dump: https://www.nlm.nih.gov/databases/download/data_distrib_main.html

[back to top](#top)<br>

In [None]:
index = 0
term_index_dict = {}
bigrams_dict = {}
frequency_cutoff = 0
start_time = datetime.datetime.now(tz)
former_iteration_endpint = start_time
print("Be patient, it may take some time ... (ETA: 20 minutes)")
print("Time started: {}".format(start_time.strftime("%Y-%m-%d %H:%M:%S")))
# Build the PubMed vocabulary (term -> index, in first-seen order) and keep the
# frequencies of the bigrams relevant to the GO definitions.
# Each line is parsed as "<frequency>|<term1> <term2>".
with gzip.open(args.bigramwords_dir, "rt") as fr:
    # Iterate lazily instead of fr.readlines(): the decompressed bigram file is
    # large, and holding it all as one list needlessly raises peak memory.
    for line in fr:
        fields = line.split("|")
        frequency = int(fields[0])
        if frequency <= frequency_cutoff:
            continue
        terms = fields[1].split()
        # Keep only bigrams that share at least one word with the GO definitions ...
        if terms[0] not in definition_extension_terms and terms[1] not in definition_extension_terms:
            continue
        # ... and in which neither word is a stopword.
        if pattern.match(terms[0]) or pattern.match(terms[1]):
            continue
        for term in terms[:2]:
            if term not in term_index_dict:
                term_index_dict[term] = index
                index += 1
        # A later line for the same bigram overwrites the earlier frequency.
        bigrams_dict["{} {}".format(terms[0], terms[1])] = frequency
current_time = datetime.datetime.now(tz)
time_elapsed = current_time - start_time
print("Time current: {}".format(current_time.strftime("%Y-%m-%d %H:%M:%S")))
print("Time elapsed: {}".format(str(time_elapsed).split(".")[0]))

In [None]:
print("Number of bigrams:", len(bigrams_dict))
print("Number of unique terms:", len(term_index_dict))

In [None]:
start_time = datetime.datetime.now(tz)
former_iteration_endpint = start_time
print("Be patient, it may take some time ... (ETA: 6 minutes)")
print("Time started: {}".format(start_time.strftime("%Y-%m-%d %H:%M:%S")))
# Collect (row, column, frequency) triplets and build the sparse matrix in one step.
# Assigning element by element into a lil_matrix is far slower. The bigram keys are
# unique "t1 t2" strings, so every (row, column) pair occurs once and the COO
# constructor's duplicate summing never applies.
rows, cols, frequencies = [], [], []
for bigram, frequency in bigrams_dict.items():
    t1, t2 = bigram.split()
    rows.append(term_index_dict[t1])
    cols.append(term_index_dict[t2])
    frequencies.append(frequency)
vocabulary_size = len(term_index_dict)
pubmed_bigram_matrix = coo_matrix(
    (np.asarray(frequencies, dtype=np.float32),
     (np.asarray(rows, dtype=np.int64), np.asarray(cols, dtype=np.int64))),
    shape=(vocabulary_size, vocabulary_size),
).tocsr()
del(rows, cols, frequencies)

# Symmetrize: co-occurrence counts ignore word order (a "t t" bigram doubles on the diagonal).
pubmed_co_occurrence_matrix = pubmed_bigram_matrix + pubmed_bigram_matrix.transpose()
del(pubmed_bigram_matrix)
current_time = datetime.datetime.now(tz)
time_elapsed = current_time - start_time
print("Time current: {}".format(current_time.strftime("%Y-%m-%d %H:%M:%S")))
print("Time elapsed: {}".format(str(time_elapsed).split(".")[0]))

## GO Term Embedding Computation<a id='5'></a>

[back to top](#top)<br>

In [None]:
def pmi(matrix, smooth_val=1):
    """Return the pointwise mutual information (base-10 log) of a count matrix.

    Each cell becomes log10(P(i, j) / (P(i) * P(j))), where P(i, j) is the cell's
    share of the matrix total and P(i), P(j) are its row and column marginals.

    Args:
        matrix: scipy sparse matrix of non-negative counts.
        smooth_val: additive smoothing constant. When positive, the matrix is
            densified with ``todense()`` and ``smooth_val`` is added to every cell.
            Otherwise the input is used unchanged, and zero cells give ``-inf``.

    Returns:
        The PMI matrix. With smoothing and scipy sparse-matrix input it is a dense
        ``np.matrix`` of the same shape.
    """
    if smooth_val > 0:
        counts = matrix.todense() + smooth_val
    else:
        counts = matrix
    joint = counts / np.sum(counts)
    col_marginals = np.sum(joint, 0)
    row_marginals = np.sum(joint, 1)
    # With np.matrix operands, (n x 1) * (1 x m) is the outer product of the marginals.
    expected = row_marginals * col_marginals
    return np.log10(joint / expected)

In [None]:
"""biological_process"""
print("FOC computation to do...")
first_order_matrix = lil_matrix((len(definition_extension_info['biological_process']), len(term_index_dict)), dtype=np.float32)
for row_num, go_id in enumerate(definition_extension_info['biological_process']):
    definition = definition_extension_info['biological_process'][go_id]
    #print(definition)
    num_of_content_words = np.sum(list(definition.values()))  # used for normalization
    for content_term in definition:
        if content_term in term_index_dict:
            first_order_matrix[row_num, term_index_dict[content_term]] = definition[content_term]/num_of_content_words
print("SOC computation to do...")
second_order_matrix_BP = first_order_matrix*pubmed_co_occurrence_matrix
print("second_order_matrix_BP shape:", second_order_matrix_BP.shape)

In [None]:
"""cellular_component"""
print("FOC computation to do...")
first_order_matrix = lil_matrix((len(definition_extension_info['cellular_component']), len(term_index_dict)), dtype=np.float32)
for row_num, go_id in enumerate(definition_extension_info['cellular_component']):
    definition = definition_extension_info['cellular_component'][go_id]
    #print(definition)
    num_of_content_words = np.sum(list(definition.values()))  # used for normalization
    for content_term in definition:
        if content_term in term_index_dict:
            first_order_matrix[row_num, term_index_dict[content_term]] = definition[content_term]/num_of_content_words
print("SOC computation to do...")
second_order_matrix_CC = first_order_matrix*pubmed_co_occurrence_matrix
print("second_order_matrix_CC shape:", second_order_matrix_CC.shape)

In [None]:
"""molecular_function"""
print("FOC computation to do...")
first_order_matrix = lil_matrix((len(definition_extension_info['molecular_function']), len(term_index_dict)), dtype=np.float32)
for row_num, go_id in enumerate(definition_extension_info['molecular_function']):
    definition = definition_extension_info['molecular_function'][go_id]
    #print(definition)
    num_of_content_words = np.sum(list(definition.values()))  # used for normalization
    for content_term in definition:
        if content_term in term_index_dict:
            first_order_matrix[row_num, term_index_dict[content_term]] = definition[content_term]/num_of_content_words
print("SOC computation to do...")
second_order_matrix_MF = first_order_matrix*pubmed_co_occurrence_matrix
print("second_order_matrix_MF shape:", second_order_matrix_MF.shape)

In [None]:
# Stack the three sub-ontology matrices row-wise in BP, CC, MF order. The embedding
# cell relies on this order when it slices the rows back out per sub-ontology.
second_order_matrix_full = vstack([second_order_matrix_BP, second_order_matrix_CC, second_order_matrix_MF])

In [None]:
# Free the intermediate matrices before the dense PMI step.
del first_order_matrix, pubmed_co_occurrence_matrix, second_order_matrix_BP, second_order_matrix_CC, second_order_matrix_MF

print("PMI to do... (Full SOC matrix size: {})".format(second_order_matrix_full.shape))
second_order_matrix_full = pmi(second_order_matrix_full)

print("Positive PMI to do...")
# Clip negative (including -inf) scores to zero -> positive PMI.
second_order_matrix_full[second_order_matrix_full < 0.0] = 0

print("Shrinking PMI to do (removing all-zero features)...")
column_index = np.flatnonzero(np.asarray(np.sum(second_order_matrix_full, 0)).ravel() != 0)
second_order_matrix_full = second_order_matrix_full[:, column_index]

# Row span of each sub-ontology inside the stacked matrix (stacked in BP, CC, MF order).
sub_ontologies = ['biological_process', 'cellular_component', 'molecular_function']
row_spans = {}
row_start = 0
for sub_ontology in sub_ontologies:
    row_end = row_start + len(definition_extension_info[sub_ontology])
    row_spans[sub_ontology] = (row_start, row_end)
    row_start = row_end

for embedding_dim in args.embedding_dims:
    print("SVD to do... (dimension: {})".format(embedding_dim))
    U, S, Vt = svds(second_order_matrix_full, embedding_dim)
    LSA1 = U*S  # left singular vectors scaled by their singular values
    os.makedirs(args.go_embedding_dir, exist_ok=True)
    for sub_ontology in sub_ontologies:
        row_start, row_end = row_spans[sub_ontology]
        LSA = LSA1[row_start:row_end, :]
        abbreviation = "".join(word[0].upper() for word in sub_ontology.split("_"))  # e.g. BP
        embedding_path = '{}/GO_{}_Embeddings_{}D.emb'.format(args.go_embedding_dir, abbreviation, embedding_dim)
        # One line per GO term: "<GO id> <v1> <v2> ...", rows in definition_extension order.
        with open(embedding_path, "w") as fw:
            fw.writelines(
                "{} {}\n".format(go_term_it, " ".join(str(value) for value in embedding))
                for go_term_it, embedding in zip(definition_extension[sub_ontology], LSA)
            )


---