In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random
import sys
import os
import scipy
import math
from sklearn.metrics import r2_score, mean_squared_error

from pathlib import Path
import subprocess

from Bio.SeqIO import read, parse
import orthoani

from multiprocess import Pool
import pandas as pd

max_pool = 16

sys.path.append('../')
sys.path.append('../../')


In [2]:
from db_sketching.parallel_helpers import check_genome_files, compute_kmer_ani_parallel, compute_ortho_ani_parallel, compute_estimator_kmer_sketches_parallel
from db_sketching.kmer_set import EstimatorFracMinHash

In [3]:
# Helper function to obtain genome files
def get_genome_files(data_home_dir, analysis_type, analysis_data_dir):
    genome_files = []
    collection_path = os.path.join(data_home_dir,analysis_type,analysis_data_dir)
    for filename in os.listdir(collection_path):
        full_filename = os.path.join(collection_path,filename)
        genome_files.append(full_filename)
    return genome_files
    

In [4]:
def generate_polycond(degree, mod, c, seed_val):
    rng = np.random.default_rng(seed_val)
    coeffs = np.random.randint(
        low = 0,
        high = mod,
        size = degree+1
    )

    def poly_cond(kmer_hash):
        hash_mod = kmer_hash % mod
        end_val = 0
        for coeff in reversed(coeffs):
            end_val *= hash_mod
            end_val += coeff
            end_val %= mod

        return end_val < (mod / c)
    
    return poly_cond

def all_cond(kmer_hash):
    return True

In [5]:
def median_of_means_generator(subsample):
    def median_of_means(sample):
        blocks = np.array_split(sample, subsample)
        block_means = np.array(np.mean(block) for block in blocks)
        return np.median(block_means)
    
    return median_of_means

In [6]:
ANALYSIS_TYPE = "analysis_type"
ANALYSIS_DATA_DIR = "analysis_data_dir"
GENOME_FILE_1 = "genome_file_1"
GENOME_FILE_2 = "genome_file_2"
ORTHOANI_VAL = "orthoani_val"
KMER_LENGTH = "kmer_length"
C_VAL = "c_val"
CANONICAL = "canonical"
NUM_CONDITIONS = "num_conditions"
SUBSAMPLE = "subsample"
KMER_ANI_VAL = "kmer_ani_val"
PEARSON_COEFF = "pearson_coeff"
PEARSON_COEFF_PVAL = "pearson_coeff_pval"
SPEARMAN_COEFF = "spearman_coeff"
SPEARMAN_COEFF_PVAL = "spearman_coeff_pval"
R2_VAL_RAW = "r2_val_raw"
R2_VAL_LIN = "r2_val_lin"
DROPPED_ZEROS = "dropped_zeros"
MEAN_SQUARED_ERROR = "mean_squared_error"


In [7]:
data_home_dir = "../../data_temp"

species_analysis_type = "Single-Species"
species_analysis_data_dirs = {
    # "Escherichia coli": (10,100),
    # "Lactobacillus helveticus": (10,100),
    # "Staphylococcus hominis": (10,100),
    # "Mycoplasmoides pneumoniae": (10,100),
    # "Brucella melitensis": (10,100),
    # "Xanthomonas oryzae": (10,100)
}

genus_analysis_type = "Single-Genus"
genus_analysis_data_dirs = {
    # "Pectobacterium": (10,35),
    # "Morganella": (10,35),
    "Xylella": (20,35),
}

family_analysis_type = "Single-Family"
family_analysis_data_dirs = {
    "Enterobacteriaceae": (10,35),
    "Cyanobiaceae": (10,35),
    "Rhizobiaceae": (10,35)
}

family_genus_analysis_type = "Single-Family-Multi-Genus"
family_genus_analysis_data_dirs = {
    "31989": (10,35),
    "49546": (10,35),
    "186803": (10,35),
    "186817": (10,35),
}

genus_species_analysis_type = "Single-Genus-Multi-Species"
genus_species_analysis_data_dirs = {
    "Psychrobacter": (10,35),
    "Aeromonas": (10,35),
    "Rathayibacter": (10,35),
}

In [8]:
analysis_types = {
    species_analysis_type: species_analysis_data_dirs,
    genus_analysis_type: genus_analysis_data_dirs,
    family_analysis_type: family_analysis_data_dirs,
    family_genus_analysis_type: family_genus_analysis_data_dirs,
    genus_species_analysis_type: genus_species_analysis_data_dirs
}

In [9]:
c_val_range = (200,)

In [10]:
ortho_ani_filename = "../../data_temp/ortho_ani_values.csv"
ortho_ani_dataframe = pd.read_csv(ortho_ani_filename,index_col=0)
print(ortho_ani_dataframe)

                                         genome_file_1  \
0       ../../data_temp/Single-Genus/Salmonella/64.fna   
1       ../../data_temp/Single-Genus/Salmonella/32.fna   
2       ../../data_temp/Single-Genus/Salmonella/99.fna   
3       ../../data_temp/Single-Genus/Salmonella/43.fna   
4        ../../data_temp/Single-Genus/Salmonella/7.fna   
..                                                 ...   
494  ../../data_temp/Single-Family/Enterobacteriace...   
495  ../../data_temp/Single-Family/Enterobacteriace...   
496  ../../data_temp/Single-Family/Enterobacteriace...   
497  ../../data_temp/Single-Family/Enterobacteriace...   
498  ../../data_temp/Single-Family/Enterobacteriace...   

                                         genome_file_2  orthoani_val  \
0       ../../data_temp/Single-Genus/Salmonella/32.fna      0.985312   
1       ../../data_temp/Single-Genus/Salmonella/99.fna      0.984921   
2       ../../data_temp/Single-Genus/Salmonella/43.fna      0.983598   
3        ../../

In [11]:
kmer_sketch_ani_filename = "../../data_temp/kmer_estimator_sketch_ani_values.csv"
try:
    kmer_sketch_dataframe = pd.read_csv(kmer_sketch_ani_filename,index_col=0)
except:
    kmer_sketch_dataframe = pd.DataFrame(
        columns=[
            ANALYSIS_TYPE,
            ANALYSIS_DATA_DIR,
            GENOME_FILE_1,
            GENOME_FILE_2,
            KMER_LENGTH,
            C_VAL,
            CANONICAL,
            NUM_CONDITIONS,
            SUBSAMPLE,
            KMER_ANI_VAL,
        ]
    )
print(kmer_sketch_dataframe)

Empty DataFrame
Columns: [analysis_type, analysis_data_dir, genome_file_1, genome_file_2, kmer_length, c_val, canonical, num_conditions, subsample, kmer_ani_val]
Index: []


In [12]:
corrleation_filename = "../../data_temp/estimator-correlation_values.csv"
try:
    correlation_dataframe = pd.read_csv(corrleation_filename,index_col=0)
except:
    correlation_dataframe = pd.DataFrame(
        columns=[
            ANALYSIS_TYPE,
            ANALYSIS_DATA_DIR,
            C_VAL,
            CANONICAL,
            NUM_CONDITIONS,
            SUBSAMPLE,
            PEARSON_COEFF,
            PEARSON_COEFF_PVAL,
            SPEARMAN_COEFF,
            SPEARMAN_COEFF_PVAL,
            R2_VAL_RAW,
            R2_VAL_LIN,
            DROPPED_ZEROS,
        ]
    )
print(correlation_dataframe)

Empty DataFrame
Columns: [analysis_type, analysis_data_dir, c_val, canonical, num_conditions, subsample, pearson_coeff, pearson_coeff_pval, spearman_coeff, spearman_coeff_pval, r2_val_raw, r2_val_lin, dropped_zeros]
Index: []


In [13]:
num_conditions = 1
subsample = 1

In [14]:
def update_orthoani_dataframe(ortho_ani_dataframe, analysis_type, analysis_data_dir, checked_genome_files):
    checked_genome_files_off1 = checked_genome_files[1:]

    ortho_ani_condition = (
        ortho_ani_dataframe[GENOME_FILE_1].isin(checked_genome_files)
    )

    filtered_ortho_ani_dataframe = ortho_ani_dataframe[ortho_ani_condition]
    complement_ortho_ani_dataframe = ortho_ani_dataframe[~ortho_ani_condition]

    if len(filtered_ortho_ani_dataframe) == 0:
        print("Computing OrthoANI values ...")
        ortho_ani_vals = compute_ortho_ani_parallel(checked_genome_files,checked_genome_files_off1)
        print("Computed OrthoANI values")
    else:
        print("Using previously computed OrthoANI values")
        ortho_ani_vals = filtered_ortho_ani_dataframe[ORTHOANI_VAL]

    computed_ortho_ani_dataframe = pd.DataFrame(
        {
            ANALYSIS_TYPE: analysis_type,
            ANALYSIS_DATA_DIR: analysis_data_dir,
            GENOME_FILE_1: checked_genome_files[:-1],
            GENOME_FILE_2: checked_genome_files_off1,
            ORTHOANI_VAL: ortho_ani_vals
        }
    )
    return pd.concat([complement_ortho_ani_dataframe,computed_ortho_ani_dataframe]), ortho_ani_vals

In [15]:
def update_kmer_estimator_dataframe(
        analysis_type,
        analysis_data_dir,
        checked_genome_files,
        kmer_length,
        c_val,
        canon,
        sketching_conditions,
        kmer_sketch_dataframe,
        num_conditions,
        subsample,
):
    checked_genome_files_off1 = checked_genome_files[1:]
    
    kmer_sketch_ani_condition = (
        (kmer_sketch_dataframe[ANALYSIS_TYPE] == analysis_type)
        & (kmer_sketch_dataframe[ANALYSIS_DATA_DIR] == analysis_data_dir)
        & (kmer_sketch_dataframe[GENOME_FILE_1].isin(checked_genome_files))
        & (kmer_sketch_dataframe[KMER_LENGTH] == kmer_length)
        & (kmer_sketch_dataframe[C_VAL] == c_val)
        & (kmer_sketch_dataframe[CANONICAL] == canon)
    )

    filtered_kmer_sketch_ani_dataframe = kmer_sketch_dataframe[kmer_sketch_ani_condition]
    complement_kmer_sketch_ani_dataframe = kmer_sketch_dataframe[~kmer_sketch_ani_condition]

    if len(filtered_kmer_sketch_ani_dataframe) == 0:
        print("Computing sketches ...")
        kmer_sketches = compute_estimator_kmer_sketches_parallel(
            genome_files=checked_genome_files,
            conditions=sketching_conditions,
            kmer_template=kmer_length,
            canonical=canon,
            multiplicity=False,
            estimator=median_of_means_generator(subsample)
        )
        kmer_sketches_off1 = kmer_sketches[1:]
        print("Sketching complete, computing ani ...")
        kmer_ani_vals = compute_kmer_ani_parallel(kmer_sketches,kmer_sketches_off1)
        print("ANI computation complete")
    else:
        print("Using previously computed ANI values")
        kmer_ani_vals = filtered_kmer_sketch_ani_dataframe[KMER_ANI_VAL]


    num_entries = len(kmer_ani_vals)
    
    computed_kmer_sketch_ani_dataframe = pd.DataFrame(
        {
            ANALYSIS_TYPE: analysis_type,
            ANALYSIS_DATA_DIR: analysis_data_dir,
            GENOME_FILE_1: checked_genome_files[:-1],
            GENOME_FILE_2: checked_genome_files_off1,
            NUM_CONDITIONS: num_conditions,
            SUBSAMPLE: subsample,
            KMER_LENGTH: [kmer_length for _ in range(num_entries)],
            C_VAL: [c_val for _ in range(num_entries)],
            CANONICAL: [canon for _ in range(num_entries)],
            KMER_ANI_VAL: kmer_ani_vals
        }
    )


    return pd.concat(
        [complement_kmer_sketch_ani_dataframe,computed_kmer_sketch_ani_dataframe]
    ), kmer_ani_vals

In [16]:
def update_correlation_dataframe(
        analysis_type,
        analysis_data_dir,
        num_conditions,
        subsample,
        c_val,
        canon,
        ortho_ani_vals,
        kmer_ani_vals,
        kmer_length
):
    correlation_condition = (
        (correlation_dataframe[ANALYSIS_TYPE] == analysis_type) 
        & (correlation_dataframe[ANALYSIS_DATA_DIR] == analysis_data_dir) 
        & (correlation_dataframe[NUM_CONDITIONS] == num_conditions) 
        & (correlation_dataframe[SUBSAMPLE] == subsample) 
        & (correlation_dataframe[KMER_LENGTH] == kmer_length) 
        & (correlation_dataframe[C_VAL] == c_val) 
        & (correlation_dataframe[CANONICAL] == canon)
    )

    filtered_correlation_dataframe = correlation_dataframe[~correlation_condition] # remove to recompute

    print("Computing correlations...")


    filter_o, filter_k = zip(*[(o,k) for o,k in zip(ortho_ani_vals,kmer_ani_vals) if (k > 0 and o > 0)])

    dropped_zeros = len(ortho_ani_vals) - len(filter_o)
    pearson_coeff = scipy.stats.pearsonr(filter_o,filter_k)
    spearman_coeff = scipy.stats.spearmanr(filter_o,filter_k)

    lin_fit = np.poly1d(np.polyfit(filter_o, filter_k, 1))
    r2_val_raw = r2_score(filter_o,filter_k)
    r2_val_lin = r2_score(filter_o,[lin_fit(o) for o in filter_o])
    mse = mean_squared_error(filter_o,filter_k)


    return pd.concat([filtered_correlation_dataframe,
        pd.DataFrame({
            ANALYSIS_TYPE: analysis_type,
            ANALYSIS_DATA_DIR: analysis_data_dir,
            KMER_LENGTH: [kmer_length],
            C_VAL: [c_val],
            CANONICAL: [canon],
            PEARSON_COEFF : [pearson_coeff.statistic],
            PEARSON_COEFF_PVAL : [pearson_coeff.pvalue],
            SPEARMAN_COEFF : [spearman_coeff.statistic],
            SPEARMAN_COEFF_PVAL : [spearman_coeff.pvalue],
            R2_VAL_RAW: [r2_val_raw],
            R2_VAL_LIN: [r2_val_lin],
            DROPPED_ZEROS: dropped_zeros,
            MEAN_SQUARED_ERROR: mse
        })],
    )

In [17]:
def generate_conditions(iteration, num_conditions, degree, mod, c_val):
    return [
        generate_polycond(
            degree = degree,
            mod = mod,
            c = c_val,
            seed_val = iteration * num_conditions + idx
        )
        for idx in range(num_conditions)
    ]

In [18]:
iteration = 0

for analysis_type in analysis_types:
    for analysis_data_dir in analysis_types[analysis_type]:
        genome_files = get_genome_files(data_home_dir,analysis_type,analysis_data_dir)
        checked_genome_files = check_genome_files(genome_files)

        print(f"{analysis_type=}, {analysis_data_dir=}")
        print(f"Using genome files {checked_genome_files}")
        
        ortho_ani_dataframe, ortho_ani_vals = update_orthoani_dataframe(ortho_ani_dataframe, analysis_type, analysis_data_dir, checked_genome_files)

        min_kmer_length, max_kmer_length = analysis_types[analysis_type][analysis_data_dir]
        
        for kmer_length in range(min_kmer_length,max_kmer_length+1):
            for c_val in c_val_range:
                canon = True

                iteration += 1

                print(f"ITERATION {iteration} : {kmer_length=}, {c_val=}, {canon=}")


                sketching_conditions = generate_conditions(
                    iteration=iteration,
                    num_conditions=num_conditions,
                    degree=3,
                    mod=1000000007,
                    c_val=c_val
                )

                kmer_sketch_dataframe, kmer_ani_vals = update_kmer_estimator_dataframe(
                        analysis_type=analysis_type,
                        analysis_data_dir=analysis_data_dir,
                        checked_genome_files= checked_genome_files,
                        kmer_length=kmer_length,
                        c_val=c_val,
                        canon=canon,
                        sketching_conditions=sketching_conditions,
                        kmer_sketch_dataframe=kmer_sketch_dataframe,
                        num_conditions=num_conditions,
                        subsample=subsample
                )

                correlation_dataframe = update_correlation_dataframe(
                    analysis_type=analysis_type,
                    analysis_data_dir=analysis_data_dir,
                    num_conditions=num_conditions,
                    subsample=subsample,
                    c_val=c_val,
                    canon=canon,
                    ortho_ani_vals=ortho_ani_vals,
                    kmer_ani_vals=kmer_ani_vals,
                    kmer_length=kmer_length
                )

                

analysis_type='Single-Genus', analysis_data_dir='Xylella'
Using genome files ['../../data_temp/Single-Genus/Xylella/64.fna', '../../data_temp/Single-Genus/Xylella/32.fna', '../../data_temp/Single-Genus/Xylella/99.fna', '../../data_temp/Single-Genus/Xylella/43.fna', '../../data_temp/Single-Genus/Xylella/7.fna', '../../data_temp/Single-Genus/Xylella/98.fna', '../../data_temp/Single-Genus/Xylella/58.fna', '../../data_temp/Single-Genus/Xylella/90.fna', '../../data_temp/Single-Genus/Xylella/14.fna', '../../data_temp/Single-Genus/Xylella/41.fna', '../../data_temp/Single-Genus/Xylella/67.fna', '../../data_temp/Single-Genus/Xylella/73.fna', '../../data_temp/Single-Genus/Xylella/17.fna', '../../data_temp/Single-Genus/Xylella/56.fna', '../../data_temp/Single-Genus/Xylella/50.fna', '../../data_temp/Single-Genus/Xylella/2.fna', '../../data_temp/Single-Genus/Xylella/46.fna', '../../data_temp/Single-Genus/Xylella/52.fna', '../../data_temp/Single-Genus/Xylella/65.fna', '../../data_temp/Single-Genus/X

TypeError: object of type 'generator' has no len()

In [None]:
print(ortho_ani_dataframe)

In [None]:
print(kmer_sketch_dataframe)

In [None]:
print(correlation_dataframe)

In [None]:
ortho_ani_dataframe.to_csv(ortho_ani_filename)
kmer_sketch_dataframe.to_csv(kmer_sketch_ani_filename)
correlation_dataframe.to_csv(corrleation_filename)

In [None]:
def plot_correlation(
        analysis_type, 
        correlation_dataframe,
        c_val_range,
        correlation_string,
    ):

    

    plt.figure(figsize=(10,10),dpi=300)
    plt.xlabel("K-mer Length")
    plt.ylabel(correlation_string)
    plt.title(f"{correlation_string} of {EstimatorFracMinHash.__name__} ANI Esimtation against K-mer Length on {analysis_type}") 
    for analysis_data_dir in analysis_types[analysis_type]:
        analysis_correlation_dataframe = correlation_dataframe[
            (correlation_dataframe[ANALYSIS_TYPE] == analysis_type) 
            & (correlation_dataframe[ANALYSIS_DATA_DIR] == analysis_data_dir)
        ]
        min_kmer_length, max_kmer_length = analysis_types[analysis_type][analysis_data_dir]
        kmer_lengths = [kmer_length for kmer_length in range(min_kmer_length,max_kmer_length+1)]
        canonical = True

        filtered_correlation_dataframe = analysis_correlation_dataframe[
            analysis_correlation_dataframe[CANONICAL] == canonical
        ]
        
        
        for c_val in c_val_range:
            correlation = list(
                filtered_correlation_dataframe[
                    (filtered_correlation_dataframe[C_VAL] == c_val) &
                    (filtered_correlation_dataframe[KMER_LENGTH].isin(set(kmer_lengths)))
                ][correlation_string]
            )
            print(len(kmer_lengths))
            print(len(correlation))
            # plt.scatter(kmer_lengths,correlation,label=f"{analysis_data_dir}",s=[3 for _ in range(len(kmer_lengths))])
            plt.plot(kmer_lengths,correlation,label=f"{analysis_data_dir}")

    plt.legend()
    plt.plot()
    plt.savefig(f"../../plots/estimator/{analysis_type}-{EstimatorFracMinHash.__name__}-{canonical=}-{correlation_string}.png")
    plt.close()


In [None]:
for analysis_type in analysis_types:
    for correlation_string in ["pearson_coeff","spearman_coeff","r2_val_raw","r2_val_lin","dropped_zeros"]:
        plot_correlation(
            analysis_type=analysis_type,
            correlation_dataframe=correlation_dataframe,
            c_val_range=c_val_range,
            correlation_string=correlation_string
        )
    



        