In [1]:
from random import randint
import random
import sys, os
from joblib import Parallel, delayed
from Bio import SeqIO
from p_bloom_filter import encode
from p_database import dotproduct, magnitude
import time
from phe import paillier
import numpy as np

In [2]:
from seal import ChooserEvaluator,     \
                 Ciphertext,           \
                 Decryptor,            \
                 Encryptor,            \
                 EncryptionParameters, \
                 Evaluator,            \
                 IntegerEncoder,       \
                 KeyGenerator,         \
                 MemoryPoolHandle,     \
                 Plaintext

In [3]:
# Directory holding the FASTA files that act as the searchable database.
# Alternative location, kept for reference (path without the /SEAL prefix):
#d = '/local_data/atitus/data/bacterial_genomes/bacteria'
d = '/SEAL/local_data/atitus/data/bacterial_genomes/bacteria'

# Seed the module-level generator that `randint` (imported from `random`)
# draws from, so the index chosen below is repeatable.
random.seed(3)

# NOTE: os.listdir() returns entries in arbitrary order, so the same index can
# map to a different file on another machine or filesystem.
files = os.listdir(d)
# NOTE(review): the upper bound 8 is hard-coded; this raises IndexError below
# if the directory holds fewer than 9 entries.
query_file_ind = randint(0, 8)

f = files[query_file_ind]

# `%` binds tighter than `+`, so only the last literal is formatted.
print('Selected "' + f + '" from file number %s' % str(query_file_ind))

# The random pick above is only reported: `f` is immediately overridden with a
# fixed genome file (the trailing `#f)` is the disabled "use the random pick"
# variant). From here on `f` is the full path of the query FASTA file and is
# read as a global by the Querier and Database cells.
f = os.path.join(d, 'GCF_000446565.2_ASM44656v2_genomic.fasta')#f)
print(f)


Selected "addgenePartial982.fasta" from file number 3
/SEAL/local_data/atitus/data/bacterial_genomes/bacteria/GCF_000446565.2_ASM44656v2_genomic.fasta


# PHE
Paillier Homomorphic Encryption

In [17]:
# End-to-end Paillier run: build the query, encrypt it, score it against the
# database, decrypt the scores, and record wall-clock timings for each stage.
# NOTE: Parameters, Querier and Database are defined in cells further down this
# notebook; those cells must have been executed before this one.
start = time.time()

############
# Set up environment
############
# NOTE(review): `H = hash` is Python's built-in hash. String hashes are
# randomised per interpreter process unless PYTHONHASHSEED is set; if `encode`
# applies H to k-mer strings and joblib workers run in separate processes, the
# query and database encodings may not be comparable -- confirm.
parameters = Parameters(seq_len = 10000, 
                        LSH_size = 50000, 
                        num_cores = 48, 
                        kmer_size = 8, 
                        H = hash, 
                        hash_max = sys.maxsize + 1,
                        data_dir = d, 
                        search_n_entries = 100,
                        comparison = 'pe',
                        scheme = 'paillier')


############
# Initialize query to be passed
############
# `f` is the query FASTA path set in the file-selection cell.
investigator = Querier(parameters)
investigator.load_query_from_fasta(f)
investigator.encode_query(investigator.query)

# Timed stage 1: key generation plus encryption of the query encoding.
enc_start = time.time()
investigator.generate_keys()
investigator.encrypt_LSH(investigator.LSH)
enc_end = time.time()

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

############
# Search database
############
# With comparison='pe', pass_query() hands over the encrypted encoding.
query = investigator.pass_query()
participant = Database(query = query, Parameters = parameters)

# Timed stage 2: per-entry dot products against the (encrypted) query.
query_start = time.time()
participant.gen_database_scores()
query_end = time.time()


############
# Calculate scores
############
enc_results = participant.pass_results()

# Timed stage 3: decrypt each intersection and compute the IoU-style scores.
score_start = time.time()
investigator.calc_scores(enc_results)
score_end = time.time()

# All four values are durations in seconds; `end_time` is the total elapsed
# time of this cell, not a timestamp.
end_time = time.time() - start
enc_time = enc_end - enc_start
query_time = query_end - query_start
score_time = score_end - score_start

In [18]:
# Report the stage durations (seconds converted to minutes, 2 decimals).
timing_report = [
    'Algorithm took %s minutes' % str(round(end_time/60, 2)),
    'Encryption took %s minutes' % str(round(enc_time/60, 2)),
    'Query took %s minutes' % str(round(query_time/60, 2)),
    'Calculating scores took %s minutes\n' % str(round(score_time/60, 2)),
]
print('\n'.join(timing_report))

# Best-match scores: IoU, intersection/query magnitude, intersection/result magnitude.
best_scores = (investigator.max_iou, investigator.max_ioLquery, investigator.max_ioLresult)
print(*best_scores)

# Whether the best-scoring database sequence is identical to the query sequence.
best_is_query = investigator.best_seq == investigator.query
print(best_is_query)

Algorithm took 2.22 minutes
Encryption took 0.31 minutes
Query took 1.31 minutes
Calculating scores took 0.6 minutes

0.19446125907990314 0.3181874458338492 0.33337657283694383
False


# FHE
Fully Homomorphic Encryption

In [None]:
# End-to-end FHE (SEAL) run: same pipeline as the Paillier cell, with a much
# smaller query (seq_len=100, LSH_size=500) and 700 database entries.
# NOTE: Parameters, Querier and Database are defined in cells further down this
# notebook; those cells must have been executed before this one.
start = time.time()

############
# Set up environment
############
# NOTE(review): `H = hash` is Python's built-in hash, whose string hashes are
# randomised per interpreter process unless PYTHONHASHSEED is set -- confirm
# this is acceptable for how `encode` uses it.
parameters = Parameters(seq_len = 100, 
                        LSH_size = 500, 
                        num_cores = 48, 
                        kmer_size = 8, 
                        H = hash, 
                        hash_max = sys.maxsize + 1,
                        data_dir = d, 
                        search_n_entries = 700,
                        comparison = 'pe',
                        scheme = 'FHE')


############
# Initialize query to be passed
############
# `f` is the query FASTA path set in the file-selection cell.
investigator = Querier(parameters)
investigator.load_query_from_fasta(f)
investigator.encode_query(investigator.query)

# Timed stage 1: key generation plus encryption of the query encoding.
enc_start = time.time()
investigator.generate_keys()
investigator.encrypt_LSH(investigator.LSH)
enc_end = time.time()

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

############
# Search database
############
# With comparison='pe', pass_query() hands over the encrypted encoding.
query = investigator.pass_query()
participant = Database(query = query, Parameters = parameters)

# Timed stage 2: per-entry homomorphic additions (run sequentially for FHE).
query_start = time.time()
participant.gen_database_scores()
query_end = time.time()


############
# Calculate scores
############
enc_results = participant.pass_results()

# Timed stage 3: decrypt each intersection and compute the IoU-style scores.
score_start = time.time()
investigator.calc_scores(enc_results)
score_end = time.time()

# All four values are durations in seconds; `end_time` is the total elapsed
# time of this cell, not a timestamp.
end_time = time.time() - start
enc_time = enc_end - enc_start
query_time = query_end - query_start
score_time = score_end - score_start

In [None]:
# Report the stage durations (seconds converted to minutes, 2 decimals).
timing_report = [
    'Algorithm took %s minutes' % str(round(end_time/60, 2)),
    'Encryption took %s minutes' % str(round(enc_time/60, 2)),
    'Query took %s minutes' % str(round(query_time/60, 2)),
    'Calculating scores took %s minutes\n' % str(round(score_time/60, 2)),
]
print('\n'.join(timing_report))

# Best-match scores: IoU, intersection/query magnitude, intersection/result magnitude.
best_scores = (investigator.max_iou, investigator.max_ioLquery, investigator.max_ioLresult)
print(*best_scores)

# Whether the best-scoring database sequence is identical to the query sequence.
best_is_query = investigator.best_seq == investigator.query
print(best_is_query)

In [None]:
# Inspection only: number of encrypted entries in the querier's encoding.
len(investigator.enc_LSH)

In [4]:
class Parameters(object):
    """
    Configuration shared by the Querier and the Database.

    Both sides read their settings from one instance of this class through the
    getter methods, so the query and the database entries are encoded with the
    same sizes and hash function.

    Attributes:
        seq_len: number of leading sequence characters kept per FASTA file.
        LSH_size: passed as `size` to `encode`.
        num_cores: passed as `n_jobs` to joblib's Parallel.
        kmer_size: passed as `k` to `encode`.
        H: hash callable, passed as `h` to `encode`.
        H_max: passed as `HASH_MAX` to `encode`.
        search_n_entries: number of directory entries the database scores.
        data_dir: directory holding the database FASTA files.
        comparison: 'pe' (query passed encrypted) or 'pp' (query passed plain).
        scheme: 'paillier' or 'FHE'.
        fhe_params, memorypool, encoder: SEAL objects; None unless
            scheme == 'FHE'.
    """
    def __init__(self, seq_len, LSH_size, num_cores, 
                 kmer_size, H, hash_max, search_n_entries, 
                 data_dir, comparison, scheme):
        """
        Store the settings and, for scheme == 'FHE', build the SEAL context.
        """
        self.seq_len = seq_len
        self.LSH_size = LSH_size
        self.num_cores = num_cores
        self.kmer_size = kmer_size
        self.H = H
        self.H_max = hash_max
        self.search_n_entries = search_n_entries
        self.data_dir = data_dir
        self.comparison = comparison
        self.scheme = scheme

        # Always define the FHE-only attributes so the matching getters return
        # None for other schemes instead of raising AttributeError.
        self.fhe_params = None
        self.memorypool = None
        self.encoder = None

        if self.scheme == 'FHE':
            # Polynomial modulus x^2048 + 1 with the library's default
            # coefficient modulus for that degree; plain modulus 1 << 8 = 256.
            self.fhe_params = EncryptionParameters()
            self.fhe_params.set_poly_modulus('1x^2048 + 1')
            self.fhe_params.set_coeff_modulus(ChooserEvaluator.default_parameter_options()[2048])
            self.fhe_params.set_plain_modulus(1 << 8)
            self.fhe_params.validate()
            self.memorypool = MemoryPoolHandle.acquire_global()
            # NOTE(review): the literal 2 is presumably the encoder base --
            # confirm against the PySEAL IntegerEncoder signature.
            self.encoder = IntegerEncoder(self.fhe_params.plain_modulus(), 2, self.memorypool)


    def get_num_cores(self):
        """
        Return the number of parallel jobs to use.
        """
        return(self.num_cores)


    def get_seq_len(self):
        """
        Return the number of sequence characters kept per file.
        """
        return(self.seq_len)


    def get_LSH_size(self):
        """
        Return the `size` argument used for `encode`.
        """
        return(self.LSH_size)


    def get_kmer_size(self):
        """
        Return the k-mer size used for `encode`.
        """
        return(self.kmer_size)


    def get_hash_func(self):
        """
        Return the hash callable used for `encode`.
        """
        return(self.H)


    def get_hash_max(self):
        """
        Return the `HASH_MAX` value used for `encode`.
        """
        return(self.H_max)


    def get_search_size(self):
        """
        Return how many database entries are searched.
        """
        return(self.search_n_entries)


    def get_data_dir(self):
        """
        Return the database directory.
        """
        return(self.data_dir)


    def get_comparison(self):
        """
        Return the comparison mode ('pe' or 'pp').
        """
        return(self.comparison)


    def get_enc_scheme(self):
        """
        Return the encryption scheme name ('paillier' or 'FHE').
        """
        return(self.scheme)


    def get_fhe_params(self):
        """
        Return the SEAL EncryptionParameters, or None when scheme != 'FHE'.
        """
        return(self.fhe_params)


    def get_mempool(self):
        """
        Return the SEAL memory pool handle, or None when scheme != 'FHE'.
        """
        return(self.memorypool)


    def get_fhe_encoder(self):
        """
        Return the SEAL IntegerEncoder, or None when scheme != 'FHE'.
        """
        return(self.encoder)

In [5]:
class Querier(object):
    """
    The party that owns the query sequence and the private key.

    Typical use: load_query_from_fasta() -> encode_query() -> generate_keys()
    -> encrypt_LSH() -> pass_query() -> (database scores it) -> calc_scores().

    Attributes:
        seq_len: length of query sequence
        LSH_size, kmer_size, H, H_max: arguments forwarded to `encode`.
        num_cores: `n_jobs` for joblib's Parallel.
        comparison: 'pe' to pass the encrypted encoding, 'pp' for the plain one.
        scheme: 'paillier' or 'FHE'.
        enc_LSH: encrypted encoding; None until encrypt_LSH() has run.
    """
    def __init__(self, Parameters):
        """
        Copy the settings this side needs from a Parameters instance.
        """
        self.seq_len = Parameters.get_seq_len()
        self.LSH_size = Parameters.get_LSH_size()
        self.num_cores = Parameters.get_num_cores()
        self.kmer_size = Parameters.get_kmer_size()
        self.H = Parameters.get_hash_func()
        self.H_max = Parameters.get_hash_max()
        self.comparison = Parameters.get_comparison()
        self.scheme = Parameters.get_enc_scheme()
        self.enc_LSH = None

        if self.scheme == 'FHE':
            self.fhe_params = Parameters.get_fhe_params()
            self.memorypool = Parameters.get_mempool()
            self.encoder = Parameters.get_fhe_encoder()


    def get_num_cores(self):
        """
        Return the number of parallel jobs to use.
        """
        return(self.num_cores)


    def get_seq_len(self):
        """
        Return the query length limit.
        """
        return(self.seq_len)


    def load_query_from_fasta(self, file_loc):
        """
        Read a FASTA file and keep its first `seq_len` characters as the query.

        All records in the file are concatenated in file order before
        truncation.

        Args:
            file_loc: path of the FASTA file to read.

        Returns:
            0 on completion.
        """
        # Read from the path that was passed in (previously this opened the
        # notebook-level global `f` and ignored the argument).
        with open(file_loc, "r") as handle:
            seq = ''.join(str(record.seq) for record in SeqIO.parse(handle, "fasta"))

        self.query = seq[:self.seq_len]
        return(0)


    def load_query_from_database(self, database):
        """
        In progress
        """
        return(1)

    def print_query_seq(self):
        """
        Return the query sequence (despite the name, nothing is printed).
        """
        return(self.query)


    def encode_query(self, query_seq):
        """
        Encode a sequence with `encode` and store the result and its magnitude.

        Sets `self.LSH` and `self.query_mag`.
        """

        self.LSH = encode(query_seq, 
                          size=self.LSH_size, 
                          k=self.kmer_size, 
                          h=self.H,
                          HASH_MAX=self.H_max)

        self.query_mag = magnitude(self.LSH)


    def get_LSH(self):
        """
        Return the plain (unencrypted) encoding of the query.
        """
        return(self.LSH)


    def generate_keys(self):
        """
        Create the key pair for the configured scheme.

        Sets `self.public_key` and `self.private_key`; for 'FHE' also
        `self.encryptor`.

        Returns:
            None on success, or the string 'Wrong encryption scheme call...'
            when the scheme is neither 'paillier' nor 'FHE'.
        """
        if self.scheme == 'paillier':
            self.public_key, self.private_key = paillier.generate_paillier_keypair()

        elif self.scheme == 'FHE':
            # For detailed information, see PySEAL example script
            generator = KeyGenerator(self.fhe_params, self.memorypool)
            # NOTE(review): the argument 0 is presumably the number of
            # evaluation keys to generate -- confirm against PySEAL.
            generator.generate(0)

            self.public_key = generator.public_key()
            self.private_key = generator.secret_key()

            self.encryptor = Encryptor(self.fhe_params, self.public_key, self.memorypool)

        else:
            return('Wrong encryption scheme call...')


    def get_pub_key(self):
        """
        Return the public key created by generate_keys().
        """
        return(self.public_key)


    def get_priv_key(self):
        """
        Return the private key created by generate_keys().
        """
        return(self.private_key)


    def _FHE_encode_int(self, integer):
        """
        Encode an integer with the SEAL encoder, then encrypt it.
        """
        encoded = self.encoder.encode(integer)
        encrypted = self.encryptor.encrypt(encoded)

        return(encrypted)


    def encrypt_LSH(self, LSH):
        """
        Encrypt every element of `LSH` and store the list in `self.enc_LSH`.

        Paillier encryption runs in parallel over `num_cores` jobs; FHE
        encryption runs sequentially.

        Returns:
            None on success, or the string 'Wrong encryption scheme call...'
            when the scheme is neither 'paillier' nor 'FHE'.
        """
        num_cores = self.num_cores

        if self.scheme == 'paillier':
            self.enc_LSH = Parallel(n_jobs=num_cores)(delayed(self.public_key.encrypt)(x) for x in LSH)

        elif self.scheme == 'FHE':
            self.enc_LSH = [self._FHE_encode_int(x) for x in LSH]

        else:
            return('Wrong encryption scheme call...')


    def get_enc_query(self):
        """
        Return the encrypted encoding (None before encrypt_LSH()).
        """
        return(self.enc_LSH)


    def pass_query(self):
        """
        Return what is handed to the database for the comparison mode.

        'pe' returns the encrypted encoding, 'pp' the plain one; any other
        mode prints a message and returns 1.
        """
        if self.comparison == 'pe':
            return(self.enc_LSH)
        elif self.comparison == 'pp':
            return(self.LSH)
        else:
            print('Wrong comparison parameter')
            return(1)


    ####################
    # Calculate the Intersection over X
    ####################
    def ioX(self, intersection, data_mag):
        """
        Compute the three overlap ratios for one database entry.

        Args:
            intersection: decrypted dot product of query and entry encodings.
            data_mag: magnitude of the entry encoding.

        Returns:
            (intersection/union, intersection/query_mag, intersection/data_mag)
            where union = data_mag + query_mag - intersection.
        """
        union = (data_mag + self.query_mag) - intersection

        iou = intersection/union
        max_ioLquery = intersection/self.query_mag
        max_ioLresult = intersection/data_mag

        return iou, max_ioLquery, max_ioLresult


    ####################
    # Calculate all IoXs
    ####################
    def calc_ioX(self, id_):
        """
        Decrypt one database result (if needed) and score it.

        Args:
            id_: sequence of (intersection or ciphertext, entry magnitude,
                entry sequence), as produced by Database.gen_scores.

        Returns:
            (Iou, IoLquery, IoLresult, entry sequence, entry magnitude)
        """
        if self.comparison == 'pe':
            if self.scheme == 'paillier':
                intersection = self.private_key.decrypt(id_[0])
            else:
                self.decryptor = Decryptor(self.fhe_params, self.private_key, self.memorypool)
                poly_intersection = self.decryptor.decrypt(id_[0])
                intersection = self.encoder.decode_int32(poly_intersection)
        else:
            # 'pp' mode: the database already returned a plain intersection.
            intersection = id_[0]

        Iou, IoLquery, IoLresult = self.ioX(intersection, id_[1])

        return Iou, IoLquery, IoLresult, id_[2], id_[1]


    ####################
    # Generate all scores and find max based on IoU
    ####################
    def calc_scores(self, enc_results):
        """
        Score every database result and remember the one with the highest IoU.

        Sets `result_scores`, `max_iou`, `max_ioLquery`, `max_ioLresult`,
        `best_seq` and `result_mag`. On ties the later entry wins (`>=`).

        Args:
            enc_results: iterable of result tuples from Database.pass_results().
        """
        # Initialise every "best match" attribute up front so they exist even
        # when enc_results is empty.
        self.max_iou = 0
        self.max_ioLquery = 0
        self.max_ioLresult = 0
        self.best_id = 0
        self.best_seq = ''
        self.result_mag = 0

        if self.scheme == 'paillier':
            self.result_scores = Parallel(n_jobs=self.num_cores)(delayed(self.calc_ioX)(id_) for id_ in enc_results)
        else:
            self.result_scores = [self.calc_ioX(id_) for id_ in enc_results]

        for score_set in self.result_scores:
            if score_set[0] >= self.max_iou: 
                self.max_iou = score_set[0]
                self.max_ioLquery = score_set[1]
                self.max_ioLresult = score_set[2]  
                self.best_seq = score_set[3]
                self.result_mag = score_set[4]

        

In [10]:
class Database(object):
    """
    The party that holds the sequence files and scores them against a query.

    For every file in `data_dir` (up to `search_size` entries) the database
    encodes the sequence and computes its dot product with the query encoding
    it was given, which may be encrypted.

    Attributes:
        enc_LSH: the query encoding received from the querier.
        query_file: path of the query's own file, skipped during scoring
            (None means "fall back to the notebook-level global `f`").
        result_scores: list of (dot product, entry magnitude, entry sequence),
            set by gen_database_scores().
    """

    def __init__(self, query, Parameters, query_file=None):
        """
        Copy the settings this side needs from a Parameters instance.

        Args:
            query: query encoding from Querier.pass_query().
            Parameters: shared Parameters instance.
            query_file: optional path of the file the query was taken from;
                that entry is given a zero score instead of being compared.
        """
        self.comparison = Parameters.get_comparison()
        self.data_dir = Parameters.get_data_dir()
        self.search_size = Parameters.get_search_size()    
        self.seq_len = Parameters.get_seq_len()
        self.LSH_size = Parameters.get_LSH_size()
        self.num_cores = Parameters.get_num_cores()
        self.kmer_size = Parameters.get_kmer_size()
        self.H = Parameters.get_hash_func()
        self.H_max = Parameters.get_hash_max()
        self.enc_LSH = query
        self.query_file = query_file

        self.scheme = Parameters.get_enc_scheme()

        if self.scheme == 'FHE':
            self.fhe_params = Parameters.get_fhe_params()
            self.memorypool = Parameters.get_mempool()
            self.encoder = Parameters.get_fhe_encoder()

    def get_data_dir(self):
        """
        Return the database directory.
        """
        return(self.data_dir)


    def gen_scores(self, id_, LSH):
        """
        Score one database file against the query encoding.

        Args:
            id_: file name inside `data_dir`.
            LSH: query encoding (encrypted or plain).

        Returns:
            (dot product, entry magnitude, entry sequence). For the query's own
            file the tuple is (LSH[0]*0, 0.0001, entry sequence), i.e. a zero
            score with a tiny non-zero magnitude.
        """
        seq_file = os.path.join(self.data_dir, id_)

        with open(seq_file, "r") as handle:
            entry_seq = ''.join(str(record.seq) for record in SeqIO.parse(handle, "fasta"))

        entry_seq = entry_seq[ :self.seq_len]

        # Which file to skip: the explicit constructor argument if given,
        # otherwise the legacy notebook-level global `f` when it exists.
        excluded_file = self.query_file
        if excluded_file is None:
            try:
                excluded_file = f
            except NameError:
                excluded_file = None

        if excluded_file is not None and seq_file == excluded_file:
            # NOTE(review): LSH[0]*0 relies on the element type supporting
            # multiplication by an int -- confirm for SEAL ciphertexts.
            return(LSH[0]*0,0.0001, entry_seq)

        entry_LSH = encode(entry_seq, 
                           size=self.LSH_size, 
                           k=self.kmer_size, 
                           h=self.H,
                           HASH_MAX=self.H_max)

        # `magnitude` here is the module-level function imported from
        # p_database, not the static method defined on this class.
        if self.scheme == 'paillier':
            return(self.phe_dotproduct(entry_LSH, LSH), magnitude(entry_LSH), entry_seq)

        else:
            return(self.fhe_dotproduct(entry_LSH, LSH), magnitude(entry_LSH), entry_seq)

    def gen_database_scores(self):
        """
        Score the first `search_size` directory entries and store the results.

        Paillier scoring runs in parallel over `num_cores` jobs; FHE scoring
        runs sequentially. Any other scheme leaves `result_scores` untouched.
        """

        # NOTE: os.listdir() order is arbitrary, so which entries fall inside
        # the first `search_size` is filesystem dependent.
        data = os.listdir(self.data_dir)
        data = data[:self.search_size]

        if self.scheme == 'paillier':
            self.result_scores = Parallel(n_jobs=self.num_cores)(delayed(self.gen_scores)(id_, self.enc_LSH) for id_ in data)

        elif self.scheme == 'FHE':
            self.result_scores = [self.gen_scores(id_, self.enc_LSH) for id_ in data]

    def pass_results(self):
        """
        Return the scores computed by gen_database_scores().
        """
        return(self.result_scores)


    ####################
    # Calculate the dot product between a binary vector and an encrypted vector
    ####################
    def phe_dotproduct(self, v1, v2):
        """
        Return the numpy dot product of `v1` and `v2`.

        Args:
            v1: database entry encoding.
            v2: query encoding (Paillier-encrypted or plain numbers).
        """
        v1_array = np.asarray(v1)
        v2_array = np.asarray(v2)
        dot = np.dot(v1_array, v2_array)

        return dot


    ####################
    # Calculate the dot product between a binary vector and an encrypted vector
    ####################
    def fhe_dotproduct(self, entry_LSH, enc_LSH):
        """
        Homomorphically add the encrypted query elements selected by the entry.

        Every position i with entry_LSH[i] == 1 contributes enc_LSH[i] to the
        running sum, which equals the dot product for a binary `entry_LSH`.

        Returns:
            The accumulated ciphertext, or the int -1 when no position of
            `entry_LSH` equals 1 (kept for backward compatibility).
        """
        self.evaluator = Evaluator(self.fhe_params, self.memorypool)

        # Use None as the "nothing accumulated yet" marker so ciphertext
        # objects are never compared against an integer.
        dot = None
        for i, enc_value in enumerate(enc_LSH):
            if entry_LSH[i] == 1:
                if dot is None:
                    dot = enc_value
                else:
                    dot = self.evaluator.add(dot, enc_value)

        if dot is None:
            return -1
        return dot


    ####################
    # Calculate the magnitude of a binary vector
    ####################
    @staticmethod
    def magnitude(v):
        """Finds the magnitude of a binary vector (array).

        Declared static so it can be called on the class or on an instance.

        Args:
            v: A binary vector (array).

        Returns:
            The magnitude of the vector.
        """

        return sum(v)