Skip to content

Commit

Permalink
Adds init files everywhere
Browse files Browse the repository at this point in the history
removes  readthedocs reqs because conda handles it now
  • Loading branch information
Justin Sybrandt committed Apr 29, 2020
1 parent 305861a commit ba1e5b1
Show file tree
Hide file tree
Showing 12 changed files with 148 additions and 39 deletions.
11 changes: 5 additions & 6 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,17 @@ sphinx:
# Optionally build your docs in additional formats such as PDF and ePub
formats: all

# Optionally set the version of Python and requirements required to build your docs
# The conda env will bring along protobuf and python version
conda:
environment: docs/environment.yml

# Version managed by conda
python:
version: 3.8
install:
- requirements: docs/requirements.txt
- method: pip
path: .

# We do not require git submodules to process this
submodules:
exclude: all

# The conda env will bring along protobuf
conda:
environment: docs/environment.yml
4 changes: 1 addition & 3 deletions agatha/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
__VERSION__ = '20.04.27'

from . import ml
__VERSION__ = "2020.04.29"
Empty file added agatha/construct/__init__.py
Empty file.
Empty file.
Empty file added agatha/ml/__init__.py
Empty file.
Empty file added agatha/ml/util/__init__.py
Empty file.
Empty file added agatha/topic_query/__init__.py
Empty file.
Empty file added agatha/util/__init__.py
Empty file.
63 changes: 55 additions & 8 deletions agatha/util/sqlite3_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,16 @@ def create_lookup_table(
_DEFAULT_KEY_COLUMN_NAME="key"
_DEFAULT_VALUE_COLUMN_NAME="value"
class Sqlite3LookupTable():
"""
Gets values from an Sqlite3 Table called where keys are
strings and values are json encoded.
f"""Dict-like interface for Sqlite3 key-value tables
Assumes that the provided sqlite3 path has a table containing string keys and
json-encoded string values. By default, the table name is
{_DEFAULT_TABLE_NAME}, with columns {_DEFAULT_KEY_COLUMN_NAME} and
{_DEFAULT_VALUE_COLUMN_NAME}.
This interface is pickle-able, and provides caching and preloading. Note that
instances of this object that are recovered from pickles will _NOT_ retain the
preloading or caching information from the original.
"""
def __init__(
self,
Expand All @@ -117,6 +124,21 @@ def __init__(
value_column_name:str=_DEFAULT_VALUE_COLUMN_NAME,
disable_cache:bool=False
):
"""A Dict-like interface for Sqlite3 key-value tables
Creates an Sqlite3LookupTable cheaply. Caching starts empty and preloading
can be enabled later with `.preload()`. Constructor establishes database
file handler, and establishes optimized read-only access.
Args:
db_path: The file-system location of the Sqlite3 file.
table_name: The sql table name to find within `db_path`.
key_column_name: The string column of `table_name`. Performance of the
Sqlite3LookupTable will depend on whether an index has been created on
`key_column_name`.
value_column_name: The json-encoded string column of `table_name`
"""
self.table_name = table_name
self.key_column_name = key_column_name
self.value_column_name = value_column_name
Expand Down Expand Up @@ -149,6 +171,7 @@ def __setstate__(self, state):
self._connect()

def is_preloaded(self)->bool:
"True if database has been loaded to memory."
if not self.connected():
return False
"""
Expand All @@ -163,8 +186,12 @@ def is_preloaded(self)->bool:
return self._cursor.execute("PRAGMA database_list").fetchone()[2] == ""

def preload(self)->None:
"""
Copies the content of the database to memory
"""Copies the database to memory.
This is done by dumping the contents of disk into ram, and _does not_
perform any json parsing. This improves performance because now sqlite3
calls do not have to travel to storage.
"""
assert self.connected()
if not self.is_preloaded():
Expand All @@ -176,6 +203,7 @@ def preload(self)->None:
self._set_db_flags()

def connected(self)->bool:
"True if the database connection has been made."
return self._connection is not None

def _assert_schema(self)->None:
Expand Down Expand Up @@ -249,13 +277,16 @@ def _connect(self)->None:
self._set_db_flags()

def clear_cache(self)->None:
"Removes contents of internal cache"
self._cache.clear()

def disable_cache(self)->None:
"Disables the use of internal cache"
self.clear_cache()
self._use_cache = False

def enable_cache(self)->None:
"Enables the use of internal cache"
self._use_cache = True

def _query(self, key:str)->Optional[Any]:
Expand Down Expand Up @@ -293,6 +324,14 @@ def __contains__(self, key:str)->bool:
return value_or_none is not None

def keys(self)->Set[str]:
"""Get all keys from the Sqlite3 Table.
Recalls _all_ keys from the connected database. This operation may be slow
or even infeasible for larger tables.
Returns:
The set of all keys from the connected database.
"""
assert self.connected(), "Attempting to operate on closed db."
query = self._cursor.execute(
f"""
Expand All @@ -303,6 +342,7 @@ def keys(self)->Set[str]:
return set(r[0] for r in query.fetchall())

def __len__(self)->int:
"Returns the number of entries in the connected database."
if self._len is None:
assert self.connected(), "Attempting to operate on closed db."
self._len = self._cursor.execute(
Expand All @@ -318,10 +358,12 @@ def __len__(self)->int:
## SPECIAL CASES ###############################################################
################################################################################

# These special cases are added for backwards compatibility. Custom table, key
# and column names are potentially used on old data sources.

class Sqlite3Bow(Sqlite3LookupTable):
"""
for backwards compatibility, Sqlite3Bow allows for alternate default table,
key, and value names. However, newer tables following the default
Sqlite3LookupTable schema will still work.
"""
def __init__(
self,
db_path:Path,
Expand All @@ -338,6 +380,11 @@ def __init__(
)

class Sqlite3Graph(Sqlite3LookupTable):
"""
for backwards compatibility, Sqlite3Graph allows for alternate default table,
key, and value names. However, newer tables following the default
Sqlite3LookupTable schema will still work.
"""
def __init__(
self,
db_path:Path,
Expand Down
89 changes: 85 additions & 4 deletions docs/help/train_agatha.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,87 @@
How to train the Agatha Deep-Learning Model
-------------------------------------------
How to Train Agatha
===================

Training the Agatha deep learning model is the last step to generating
hypotheses after you've already processed all necessary information using
`agatha.construct`. This process uses [PyTorch][1] and [PyTorch-Lightning][2] to
efficiently manage the distributed training of the predicate ranking model
stored in `agatha.ml.hypothesis_predictor`.

This is a quick tutorial on how to train the Agatha deep-learning model.
Hopefully, this will be relatively painless to write...
## Background

The Agatha deep learning model learns to rank entity-pairs. To learn this
ranking, we will be comparing existing predicates found within our dataset
against randomly sampled entity-pairs. Of course, if a predicate exists in our
database, it should receive a higher model output than many random pairs.

A `positive sample` is an entity-pair that actually occurs in our dataset. A
`negative sample` is one of those non-existent randomly sampled pairs. We will
use the [margin ranking loss][3] criteria to learn to associate higher values
with positive samples. To do this, we will compare one positive sample to a high
number of negative samples. This is the `negative-sampling rate`.

A single sample, be it positive or negative, consists of four parts:

1. Term 1 (the subject).
2. Term 2 (the object).
3. Predicates associated with term 1 (but not term 2).
4. Predicates associated with term 2 (but not term 1).

This as a whole is referred to as a `sample`. Generating samples is the primary
bottleneck in the training process. This is because we have many millions of
terms and predicates. As a result, the Agatha deep learning framework comes
along with a number of utilities to make managing the large datasets easier.

## Datasets

In order to begin training you will need the following data:

1. Embeddings for all entities and predicates, stored as a directory of `.h5`
files.
2. Entity metadata, stored as a `.sqlite3` file.
3. The subgraph containing all entity-predicate edges, stored as a `.sqlite3`
file.

The network construction process will produce these datasets as `sqlite3` files.
[Sqlite][4] is an embedded database, meaning that we can load the database from
storage and don't need to spin up a whole server. Additionally, because we are
only going to _read_ and never going to _write_ to these databases during
training, each machine in our distributed training cluster can have independent
access to the same data very efficiently.

All of the sqlite3 databases managed by Agatha are stored in a simple format
that enables easy python access through the
`agatha.util.sqlite3_lookup.Sqlite3LookupTable` object. This provides read-only
access to the database through a dictionary-like interface.

For instance, if we want to get the neighbors for the node `l:noun:cancer`, we
can simply write this code:

```python3
from agatha.util.sqlite3_lookup import Sqlite3LookupTable
graph = Sqlite3LookupTable("./data/releases/2020/graph.sqlite3")
graph["l:noun:cancer"]
# Returns:
# ... [
# ... < List of all neighbors >
# ... ]
```

This process works by first making an sqlite3 connection to the graph database
file. By default, we assume that this database contains a table called
`lookup_table` that has the schema: `(key:str, value:str)`. Values in this
database are all json-encoded. This means that calling `graph[foo]` looks up
the value associated with "foo" in the database, and parses whatever it finds
through `json.loads(...)`.

This process is slow compared to most other operations in the training pipeline.
Each query has to check against the sqlite `key` index, which is stored on disk,
load the `value`, also stored on disk, and then parse the string. There are two
optimizations that make this faster: preloading and caching. Look into the API
documentation for more details.


[1]:https://pytorch.org/
[2]:https://github.com/PytorchLightning/pytorch-lightning
[3]:https://pytorch.org/docs/stable/nn.html#torch.nn.MarginRankingLoss
[4]:https://www.sqlite.org/index.html
16 changes: 0 additions & 16 deletions docs/requirements.txt

This file was deleted.

4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from distutils.command.build_py import build_py as _build_py
from distutils.command.clean import clean as _clean
from distutils.spawn import find_executable
from agatha import __VERSION__
from setuptools import setup, Extension, find_packages
from setuptools.command.install import install as _install
import os
import subprocess
import sys
import agatha


proto_src_files = [
Expand Down Expand Up @@ -84,7 +84,7 @@ def run(self):

setup(
name='Agatha',
version=__VERSION__,
version=agatha.__VERSION__,
author="Justin Sybrandt",
author_email="jsybran@clemson.edu",
description=("Automatic Graph-mining And Transformer based "
Expand Down

0 comments on commit ba1e5b1

Please sign in to comment.