In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random
import sys
import os
import scipy
import math
from sklearn.metrics import r2_score
from pathlib import Path
import subprocess

from Bio.SeqIO import read, parse
import orthoani

from multiprocess import Pool

max_pool = 16
lim_file = 200

sys.path.append('../')
sys.path.append('../../')


In [2]:
from db_sketching.kmer_set import KMerSet, FracMinHash, TruncatedKMerSet, MeanFracMinHash

In [3]:
def cond(kmer_hash):
    hash = (976369 * kmer_hash + 1982627) % 10000
    if hash < 50:
        return True
    else:
        return False
    



def generate_polycond(degree, mod, c, seed_val):
    rng = np.random.default_rng(seed_val)
    coeffs = np.random.randint(
        low = 0,
        high = mod,
        size = degree+1
    )

    def poly_cond(kmer_hash):
        hash_mod = kmer_hash % mod
        end_val = 0
        for coeff in reversed(coeffs):
            end_val *= hash_mod
            end_val += coeff
            end_val %= mod

        return end_val < (mod / c)
    
    return poly_cond

def all_cond(kmer_hash):
    return True

In [4]:
def compare_ANI(genome_file_1, genome_file_2, kmer_length, kmer_class, sketching_condition):
    """
    Function to compute the estimated ANI using FracMinHash 
    and compare it to the value from OrthoANI
    """
    return compute_ortho_ani(genome_file_1, genome_file_2), compute_kmer_ani(genome_file_1, genome_file_2, kmer_length, kmer_class, sketching_condition)

def compute_kmer_sketches(genome_file, kmer_class, sketching_condition, kmer_length, canonical):
    genome_kmer = kmer_class(sketching_condition, kmer_length, canonical)
    genome_kmer.insert_file(genome_file)
    return genome_kmer

def compute_kmer_sketches_parallel(genome_files, kmer_length, kmer_class, sketching_cond, canonical):
    args = [(g,kmer_class,sketching_cond,kmer_length,canonical) for g in genome_files]
    with Pool(max_pool) as p:
        return p.starmap(compute_kmer_sketches,args)


def compute_kmer_ani(genome_1_kmer, genome_2_kmer):    
    kmer_estimated_ani = genome_1_kmer.ANI_estimation(genome_2_kmer)
    return kmer_estimated_ani


def compute_kmer_ani_parallel(kmer_sketches_1,kmer_sketches_2):
    args = [(s1,s2) for s1,s2 in zip(kmer_sketches_1,kmer_sketches_2)]
    with Pool(max_pool) as p:
        return p.starmap(compute_kmer_ani,args)

def compute_ortho_ani(genome_file_1, genome_file_2):
    try:
        genome_1_read = parse(genome_file_1,"fasta")
        genome_2_read = parse(genome_file_2,"fasta")
        ortho_ani_value = orthoani.orthoani(genome_1_read,genome_2_read)

        return ortho_ani_value
    except:
        return 0




In [5]:
# analysis_type = "Single-Species"
# analysis_data_dir = "Escherichia coli" 
# analysis_data_dir = "Yersinia pestis" 

analysis_type = "Single-Genus"
# analysis_data_dir = "Yersinia" 
analysis_data_dir = "Salmonella"

In [6]:
data_dir = os.path.join(os.path.join("../../data_temp",analysis_type),analysis_data_dir)
genome_files = []


for filename in os.listdir(data_dir):
    full_filename = os.path.join(data_dir,filename)
    genome_files.append(full_filename)


print(genome_files)
print(len(genome_files))

['../../data_temp/Single-Genus/Salmonella/64.fna', '../../data_temp/Single-Genus/Salmonella/32.fna', '../../data_temp/Single-Genus/Salmonella/99.fna', '../../data_temp/Single-Genus/Salmonella/43.fna', '../../data_temp/Single-Genus/Salmonella/7.fna', '../../data_temp/Single-Genus/Salmonella/98.fna', '../../data_temp/Single-Genus/Salmonella/58.fna', '../../data_temp/Single-Genus/Salmonella/90.fna', '../../data_temp/Single-Genus/Salmonella/14.fna', '../../data_temp/Single-Genus/Salmonella/41.fna', '../../data_temp/Single-Genus/Salmonella/67.fna', '../../data_temp/Single-Genus/Salmonella/73.fna', '../../data_temp/Single-Genus/Salmonella/17.fna', '../../data_temp/Single-Genus/Salmonella/56.fna', '../../data_temp/Single-Genus/Salmonella/50.fna', '../../data_temp/Single-Genus/Salmonella/2.fna', '../../data_temp/Single-Genus/Salmonella/46.fna', '../../data_temp/Single-Genus/Salmonella/52.fna', '../../data_temp/Single-Genus/Salmonella/65.fna', '../../data_temp/Single-Genus/Salmonella/95.fna', '

In [7]:
checked_genome_files = []
for file in genome_files:
    try:
        parsed_file = parse(file,"fasta")
        assert(len([record for record in parsed_file]) > 0)
        checked_genome_files.append(file)
    except:
        print(f"File {file} is damaged / invalid")


In [8]:
print(checked_genome_files)
print(len(checked_genome_files))

['../../data_temp/Single-Genus/Salmonella/64.fna', '../../data_temp/Single-Genus/Salmonella/32.fna', '../../data_temp/Single-Genus/Salmonella/99.fna', '../../data_temp/Single-Genus/Salmonella/43.fna', '../../data_temp/Single-Genus/Salmonella/7.fna', '../../data_temp/Single-Genus/Salmonella/98.fna', '../../data_temp/Single-Genus/Salmonella/58.fna', '../../data_temp/Single-Genus/Salmonella/90.fna', '../../data_temp/Single-Genus/Salmonella/14.fna', '../../data_temp/Single-Genus/Salmonella/41.fna', '../../data_temp/Single-Genus/Salmonella/67.fna', '../../data_temp/Single-Genus/Salmonella/73.fna', '../../data_temp/Single-Genus/Salmonella/17.fna', '../../data_temp/Single-Genus/Salmonella/56.fna', '../../data_temp/Single-Genus/Salmonella/50.fna', '../../data_temp/Single-Genus/Salmonella/2.fna', '../../data_temp/Single-Genus/Salmonella/46.fna', '../../data_temp/Single-Genus/Salmonella/52.fna', '../../data_temp/Single-Genus/Salmonella/65.fna', '../../data_temp/Single-Genus/Salmonella/95.fna', '

In [9]:
checked_genome_files = checked_genome_files[:lim_file]

In [10]:

def get_genome_length(genome_file):
    genome_kmer = KMerSet(20)
    genome_kmer.insert_file(genome_file)
    return genome_kmer.length

def compute_length_parallel(genome_files):
    with Pool(max_pool) as p:
        return p.starmap(get_genome_length,([(g_file,) for g_file in genome_files]))

    
checked_genome_lengths = compute_length_parallel(checked_genome_files)
print(checked_genome_lengths)


[4685839, 5036721, 4799691, 4978368, 4637671, 4552575, 4967096, 5200148, 4788855, 4718240, 5007181, 4696875, 4940173, 4721303, 4924305, 4942132, 4873783, 4892783, 4689836, 4803273, 4848196, 4681839, 4640885, 4862054, 4882017, 5008354, 4934564, 4736114, 4743041, 4800499, 4828175, 4789704, 4932116, 4738752, 4893261, 4637665, 4941618, 4685641, 4716475, 4757477, 4703315, 4781453, 4572954, 5042943, 4713509, 4799202, 4690202, 4803735, 4792764, 5235241, 5062839, 4574588, 4856496, 4971732, 4666409, 4730243, 4699542, 4735543, 4872872, 4741712, 4763382, 5171680, 4826246, 4913446, 4732567, 4723984, 4715231, 4841076, 4763990, 5073979, 4691084, 5141052, 4700938, 4690427, 4795704, 4689188, 4800611, 4935416, 4793608, 4976814, 4860068, 4802782, 5021235, 5092352, 4798856, 4930531, 4763586, 5018183, 4972820, 4933723, 4807729, 4738021, 4678571, 4762840, 5072429, 4915191, 4579778, 5040224, 5050625, 4794639]


In [11]:
ortho_vals = [] 

def compute_ortho_ani_parallel(genome_files_1,genome_files_2):
    args = [(g1,g2) for g1,g2 in zip(genome_files_1,genome_files_2)]
    with Pool(max_pool) as p:
        return p.starmap(compute_ortho_ani,args)

def compute_pairwise_ortho(genome_files):
    genome_files_1, genome_files_2 = zip(*[(g1,g2) for g1 in genome_files for g2 in genome_files])
    return compute_ortho_ani_parallel(genome_files_1,genome_files_2)

# ortho_vals = compute_pairwise_ortho(checked_genome_files)
ortho_vals = compute_ortho_ani_parallel(checked_genome_files,checked_genome_files[1:])
print(ortho_vals)
print(len(ortho_vals))

[0.9853118065949485, 0.9849208189655172, 0.9835980379453957, 0.9841536637523629, 0.9869974640605297, 0.9864424823310601, 0.9961290425531915, 0.98775828991691, 0.9868050339658, 0.9860891503038803, 0.9834794534313726, 0.9816253180169286, 0.9818818963429402, 0.9823691789422845, 0.9835908051257254, 0.9834941638795986, 0.9959086694952133, 0.9880790658959537, 0.9822055607361964, 0.9825560815197788, 0.9883523763066202, 0.9818191887019231, 0.9802555357995226, 0.9871353675450762, 0.9875803285181733, 0.9881910712622264, 0.9820627477042049, 0.982460306673209, 0.9876580955811719, 0.9868489462365592, 0.9859157323350491, 0.9822902078801063, 0.9990454437869822, 0.9828252870126764, 0.9864292223489167, 0.9867554811320755, 0.9893138052095131, 0.9818410542975408, 0.9836246980354111, 0.9986813500545256, 0.9886282980209546, 0.9842468844836025, 0.983321622796426, 0.9876491418685122, 0.9979614176496286, 0.9876466588675714, 0.986047608049394, 0.9880173304083699, 0.9824242713450292, 0.9848925357607282, 0.98278

In [13]:
kmer_class = MeanFracMinHash
canonical = True

mod = 1000000007
degree = 3
c = 2000
num_seeds = 1

In [14]:


kmer_vals = [[]] * 40
for kmer_length in range(10,39+1):
    
    poly_conds = [generate_polycond(degree,mod,c,seed_val) for seed_val in range(num_seeds)]

    kmer_sketches = compute_kmer_sketches_parallel(checked_genome_files, kmer_length, kmer_class, poly_conds, canonical=canonical)
    # kmer_sketches_1,kmer_sketches_2 = zip(*[(s1,s2) for s1 in kmer_sketches[kmer_length] for s2 in kmer_sketches[kmer_length]])
    kmer_sketches_1,kmer_sketches_2 = zip(*[(s1,s2) for s1,s2 in zip(kmer_sketches,kmer_sketches[1:])])
    kmer_vals[kmer_length] = compute_kmer_ani_parallel(kmer_sketches_1,kmer_sketches_2)
    print(kmer_vals[kmer_length])    

    print(len(kmer_vals[kmer_length]))
    



print(kmer_vals)

[0.9996055999670472, 0.9988265072940381, 0.9992190511257478, 0.9984447589885809, 0.9984264037062179, 0.9992097944262852, 1.0, 0.9976859755638828, 0.9988218989313793, 0.9996071493720347, 1.0, 0.9984567597789843, 1.0, 0.9976770160964453, 0.9988172542310758, 0.9980675732611322, 0.9996132271381231, 0.9984386883269164, 0.9996071493720347, 0.9984567597789843, 0.998044904075339, 0.9988125727616297, 0.9984325700866473, 0.9988218989313793, 0.9996055999670472, 0.9980675732611322, 0.999611725395462, 0.9976770160964453, 1.0, 0.9988265072940381, 0.9984201886152007, 1.0, 1.0, 0.9980750132815104, 0.998044904075339, 0.9996071493720347, 0.9972638526743827, 0.9992190511257478, 0.9972850780541653, 0.9992190511257478, 0.9992129042793826, 0.999611725395462, 0.9980675732611322, 0.998044904075339, 1.0, 0.9972531150325543, 0.9992129042793826, 0.9984386883269164, 0.9996055999670472, 0.9980750132815104, 0.9964833288348736, 0.9992035004094181, 0.9992220886858091, 0.9992190511257478, 0.9984264037062179, 0.9996071

In [15]:
Path(f"../../plots/{analysis_type}-{analysis_data_dir}/").mkdir(parents=True, exist_ok=True)

for kmer_length in range(10,39+1):
    fig = plt.figure(figsize=(10,10),dpi=300)
    ax = fig.add_subplot()
    filter_o, filter_k = zip(*[(o,k) for o,k in zip(ortho_vals,kmer_vals[kmer_length]) if (k > 0 and o > 0)])
    lin_fit = np.poly1d(np.polyfit(filter_o, filter_k, 1))
    ax.plot(filter_o,filter_o,"r-")
    ax.scatter(filter_o,filter_k,marker="^",s=[2 for _ in filter_o])
    # ax.scatter(filter_o,filter_o,marker="o",s=[2 for _ in filter_o])
    ax.plot(np.unique(filter_o), lin_fit(np.unique(filter_o)))
    plt.figtext(0.7,0.15,f"Zeros dropped : {len(ortho_vals) - len(filter_o)}")
    pearson_coeff = scipy.stats.pearsonr(filter_o,filter_k)

    plt.figtext(0.7,0.17,f"Pearson Corr. Coeff.: {pearson_coeff.statistic:.3f}")
    plt.figtext(0.7,0.19,f"Pearson Corr. p.: {pearson_coeff.pvalue:.3e}")

    ax.legend(["OrthoANI",f"{kmer_length}-mer dots",f"{kmer_length}-mer line : {lin_fit}"])
    plt.plot()
    plt.savefig(f"../../plots/{analysis_type}-{analysis_data_dir}/{analysis_type}-{analysis_data_dir}-{len(ortho_vals)}genomes-{kmer_length}mer-{c=}-{canonical=}-estimated-ANI.png")
    plt.close()

    

