In [18]:
import mmh3
import pandas as pd
import numpy as np
from pybloom_live import BloomFilter
import time

In [66]:
import time
import numpy as np
import mmh3
from collections import defaultdict

def generate_kmers(sequence, k):
    """Generate k-mers from a given sequence."""
    for i in range(len(sequence) - k + 1):
        yield sequence[i:i + k]

def create_optimized_mash_sketch(sequence, k, sketch_size, coverage_threshold=1):
    """Create an optimized Mash-like sketch using a two-pass approach with known k-mer counts."""
    kmer_counts = defaultdict(int)
    sketch = []

    # First pass: Count k-mer occurrences
    for kmer in generate_kmers(sequence, k):
        kmer_hash = mmh3.hash(kmer)
        kmer_counts[kmer_hash] += 1
        
    print(sum(kmer_counts.values()) / len(kmer_counts))

    # Pre-filter k-mer hashes based on the coverage threshold
    filtered_hashes = [kmer_hash for kmer_hash, count in kmer_counts.items() if count >= coverage_threshold]

    # Sort the filtered k-mer hashes and select the smallest ones for the sketch
    sketch = sorted(filtered_hashes)[:sketch_size]

    return np.array(sketch)

In [4]:
data = pd.read_parquet('../../data/processed/mock_data.parquet', engine='pyarrow')  # You can use 'fastparquet' as the engine
data

Unnamed: 0,Accession ID,Lineage,Sequence,Coverage,Train
0,EPI_ISL_15104785,BA.5.1,TTGTAGATCTGTTCTCTAAACGAACTTTAAAATCTGTGTGGCTGTC...,99.213260,2
1,EPI_ISL_3411570,AY.19,TGTTCTCTAAACGAACTTTAAAATCTGTGTGGCTGTCACTCGGCTG...,99.939512,2
2,EPI_ISL_2433815,C.1,ATACCTTCCTAGGTAACAAACCAACCAACTTTCGATCTCTTGTAGA...,95.359497,2
2,EPI_ISL_1715397,L.3,CTTGTAGATCTGTTCTCTAAACGAACTTTAAAATCTGTGTGGCTGT...,99.775462,1
3,EPI_ISL_14795073,BA.4.6,ACCAGTGGCTTACCGCAAGGTTCTTCTTCGTAAGAACGGTAATAAA...,99.972625,0
...,...,...,...,...,...
691,EPI_ISL_5099144,AY.46,TACCTTCCCAGGTAACAAACCAACCAACTTTCGATCTCTTGTAGAT...,100.000000,0
691,EPI_ISL_9695375,BA.1.21,ACCAACCAACTTTCGATCTCTTGTAGATCTGTTCTCTAAACGAACT...,99.160820,1
692,EPI_ISL_602627,B.1.1.84,CTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAATCTGTGTGGCT...,100.000000,0
693,EPI_ISL_17231201,BQ.1.1.38,TATACCTTCCCAGGTAACAAACCAACCAACTTTCGATCTCTTGTAG...,99.979875,0


In [5]:
data['Sequence'] = data['Sequence'].str.replace('[^ACTG]', '', regex=True)

In [71]:
start = time.time()
sketch = create_optimized_mash_sketch(data["Sequence"][3], 13, 1000)
end = time.time()
print(end-start)

1.0017837541163557
0.04586172103881836


In [5]:
import time

k = 22
sketch_size = 1000

start = time.time()
sketch_array = []
for genome in data["Sequence"]:
    sketch_array.append(create_mash_sketch(genome, k, sketch_size))
end = time.time()

print(end-start)

1095.9188392162323


In [6]:
mash_data = pd.DataFrame(sketch_array)
mash_data["Target"] = data["Lineage"].tolist()
mash_data["Test"] = data["Test"].tolist()
mash_data.to_parquet('../../data/features/mash.parquet', engine='pyarrow')

  table = self.api.Table.from_pandas(df, **from_pandas_kwargs)
