In [1]:
import os
import sys
from itertools import groupby

import numpy as np

import torch

# Print numpy floats with 6 decimals using fixed-point (non-scientific) notation.
np.set_printoptions(suppress=True, precision=6)
# Do not write .pyc / __pycache__ files; set before the my_library import below.
sys.dont_write_bytecode = True

from my_library import Database, ESM_Model, read_fasta, gzip_tensor, validate_fasta, validate_database


# Allocate resources

- **THREADS** : this variable will be passed to `torch`
- **DEVICE** : this will be passed to `torch`; allowed options include "cpu" and "cuda"

If using CUDA, we highly recommend monitoring GPU memory usage while running this code, e.g. with `nvidia-smi -l 1`.

In [2]:
# Thread count handed to the encoder (passed through as `threads=` when embedding).
THREADS = 2
# Prefer the GPU, but fall back to the CPU when no CUDA device is visible so the
# notebook still runs (slowly) instead of failing at the embedding step.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
if DEVICE != 'cuda':
    print('CUDA is not available; using the CPU instead (embedding will be slow).')

# Define input and output files

- **FASTA_FILE** : (input) fasta file containing protein sequences
- **DB_FILE** : (output) SQLite database for storing protein sequence embeddings


In [3]:
# Input FASTA and output database paths for each available dataset.
# Switch datasets by changing DATASET instead of commenting/uncommenting paths.
DATASETS = {
    'phosphatase':    ('datasets/phosphatase/phosphatase.fa', 'datasets/phosphatase/phosphatase.db'),
    'protein_kinase': ('datasets/protein_kinase/kinase.fa',   'datasets/protein_kinase/kinase.db'),
    'radical_sam':    ('datasets/radical_sam/radicalsam.fa',  'datasets/radical_sam/radicalsam.db'),
}
DATASET = 'phosphatase'

FASTA_FILE, DB_FILE = DATASETS[DATASET]

# Check the FASTA file

- All sequence headers and accessions must be unique.
- To avoid issues with the Newick tree format, sequence headers cannot contain parentheses, quotes, colons, or semicolons.
- If using the ESM language model, sequences cannot be longer than 1022 residues.

In [4]:
# validate_fasta prints its findings and returns True when the file passes.
# NOTE(review): presumably it returns a falsy value (rather than raising) on
# failure — stop here in that case so Run All does not go on to spend GPU time
# embedding a FASTA file that breaks the rules above.
fasta_is_valid = validate_fasta(FASTA_FILE)
assert fasta_is_valid, f'FASTA validation failed for "{FASTA_FILE}"; fix the file before generating embeddings'
fasta_is_valid

validate_fasta : found 204 sequences in "datasets/phosphatase/phosphatase.fa"
validate_fasta : passed


True

# Load protein language model

In [5]:
# Loading the language model is expensive, so re-running this cell reuses an
# encoder already present in the kernel. The model name is tracked alongside it
# so that editing ESM_MODEL_NAME forces a reload instead of silently keeping a
# model loaded by an earlier run.
ESM_MODEL_NAME = 'esm1b_t33_650M_UR50S'

if 'encoder' not in globals() or globals().get('_encoder_model_name') != ESM_MODEL_NAME:
    encoder = ESM_Model(ESM_MODEL_NAME)
    _encoder_model_name = ESM_MODEL_NAME

# Generate sequence embeddings

In [6]:
# Guard against reusing an existing database file.
# NOTE(review): confirm whether validate_database raises or only reports/returns
# a status when DB_FILE already exists; its return value is not checked here.
validate_database(DB_FILE)

# Create a SQLite database with one row per sequence: its FASTA header, the raw
# sequence, and the compressed embedding stored as a BLOB.
db = Database(DB_FILE)
db.create_table(columns=[
    ('header',    'TEXT'),
    ('sequence',  'TEXT'),
    ('embedding', 'BLOB'),
])


def func_encode(s):
    """Embed one protein sequence and gzip the result for storage.

    The tensor returned by ``encoder.encode`` (run on DEVICE with THREADS) is
    cast to float16 — presumably to reduce storage size — and then compressed
    with ``gzip_tensor``; the result is what goes into the ``embedding`` column.
    """
    embedding = encoder.encode(s, device=DEVICE, threads=THREADS)
    return gzip_tensor(embedding.type(torch.float16))


# Lazy generator: each sequence is embedded only when db.add_rows pulls the next
# row, rather than embedding everything up front (assuming add_rows consumes
# the iterator incrementally).
queue = ((h, s, func_encode(s)) for h, s in read_fasta(FASTA_FILE))
db.add_rows(('header', 'sequence', 'embedding'), queue)


ADDING ENTRIES : 204
