In [1]:
import sys

# Replace this with the path to the repository
# data_path = "/home/bensonlzl/Desktop/UROP/GIS-2024/coding/data_temp/Single-Species/Staphylococcus hominis"
# data_path = "/home/bensonlzl/Desktop/UROP/GIS-2024/coding/data_temp/Single-Species/Mycoplasmoides pneumoniae"
# data_path = "/home/bensonlzl/Desktop/UROP/GIS-2024/coding/data_temp/Single-Family-Cross-Species/Salmonella-diarizonae"
data_path = "/home/bensonlzl/Desktop/UROP/GIS-2024/coding/data_temp/Single-Family/Aeromonadaceae"
sys.path.append(data_path)
sys.path.append("/home/bensonlzl/Desktop/UROP/GIS-2024/coding/ETFMH/")

In [2]:
from db_sketching.genome_selection import GenomeFiltering
from db_sketching.kmer_set import KMerSet, FracMinHash


def cond(kmer_hash):
    hash = (976369 * kmer_hash + 1982627) % 10000
    if hash < 100:
        return True
    else:
        return False

kmer_set = FracMinHash(cond, 12, True)
gf = GenomeFiltering(kmer_set)

In [3]:
from glob import glob


for f in glob(data_path+"/*.fna"):
    gf.insert_genome(f)
    print(f)

/home/bensonlzl/Desktop/UROP/GIS-2024/coding/data_temp/Single-Family/Aeromonadaceae/479.fna
/home/bensonlzl/Desktop/UROP/GIS-2024/coding/data_temp/Single-Family/Aeromonadaceae/306.fna
/home/bensonlzl/Desktop/UROP/GIS-2024/coding/data_temp/Single-Family/Aeromonadaceae/228.fna
/home/bensonlzl/Desktop/UROP/GIS-2024/coding/data_temp/Single-Family/Aeromonadaceae/404.fna
/home/bensonlzl/Desktop/UROP/GIS-2024/coding/data_temp/Single-Family/Aeromonadaceae/112.fna
/home/bensonlzl/Desktop/UROP/GIS-2024/coding/data_temp/Single-Family/Aeromonadaceae/164.fna
/home/bensonlzl/Desktop/UROP/GIS-2024/coding/data_temp/Single-Family/Aeromonadaceae/429.fna
/home/bensonlzl/Desktop/UROP/GIS-2024/coding/data_temp/Single-Family/Aeromonadaceae/367.fna
/home/bensonlzl/Desktop/UROP/GIS-2024/coding/data_temp/Single-Family/Aeromonadaceae/64.fna
/home/bensonlzl/Desktop/UROP/GIS-2024/coding/data_temp/Single-Family/Aeromonadaceae/491.fna
/home/bensonlzl/Desktop/UROP/GIS-2024/coding/data_temp/Single-Family/Aeromonadace

In [4]:
import numpy as np

genome_list = list(gf.genome_dict.keys())
distance_matrix = np.zeros((len(genome_list), len(genome_list)))
for i in range(len(genome_list)):
    for j in range(i+1, len(genome_list)):
        distance = gf.genome_dict[genome_list[i]].resemblence(gf.genome_dict[genome_list[j]])
        distance_matrix[i][j] = distance_matrix[j][i] = 1 - distance

KeyboardInterrupt: 

In [None]:
import seaborn

seaborn.heatmap(distance_matrix)

In [None]:
hc = gf.hierarchical_clustering(genome_list, distance_matrix, 0.1)
hc

In [None]:
# Naive way of selecting signature
from collections import Counter

naive_counter = Counter()

for f in genome_list:
    naive_counter.update(list(gf.genome_dict[f].set))


In [None]:
seaborn.histplot([i for i in list(naive_counter.values()) if i >= 0], bins=200)


In [None]:
# After hierarchical clustering
import random
hc_counter = Counter()

for j in hc:
    sampled_genome = random.sample(j, 1)
    print(sampled_genome)
    hc_counter.update(list(gf.genome_dict[genome_list[sampled_genome[0]]].set))

In [None]:
hc_counts = [i for i in list(hc_counter.values()) if i >= 0]
m1_hc_counts = [x-1 for x in hc_counts]
n = len(hc) - 1

In [None]:
# seaborn.histplot(hc_counts, bins=200)
seaborn.histplot(m1_hc_counts, bins=200)

In [None]:
from scipy.special import gammaln

def zinb_pmf(k, mu, theta, p):
    term1 = np.exp(gammaln(k + theta) - gammaln(k + 1) - gammaln(theta))
    term2 = (theta / (theta + mu)) ** theta
    term3 = (mu / (theta + mu)) ** k
    nb_pmf = term1 * term2 * term3

    if k == 0:
        return p + (1 - p) * nb_pmf
    else:
        return (1 - p) * nb_pmf
    
def combined_zinb_pmf(k, n, mu1, theta1, p1, mu2, theta2, p2, w1):
    w2 = 1 - w1
    pmf1 = w1 * zinb_pmf(k, mu1, theta1, p1)
    pmf2 = w2 * zinb_pmf(n-k, mu2, theta2, p2) # reflected across the multiplicities
    return pmf1 + pmf2


In [None]:
from scipy.optimize import minimize

# Define the negative log-likelihood function
def neg_log_likelihood(params, data):
    mu1, theta1, p1, mu2, theta2, p2, w1 = params
    likelihoods = [combined_zinb_pmf(k, n, mu1, theta1, p1, mu2, theta2, p2, w1) for k in data]
    return -np.sum(np.log(likelihoods))

# Initial parameter guesses
initial_params = [1, 1, 0.3, 1, 1, 0.4, 0.5]

# Set parameter bounds
bounds = [(0.1, None), (0.1, None), (0, 1), (0.1, None), (0.1, None), (0, 1), (0, 1)]

# Perform the optimization
result = minimize(neg_log_likelihood, initial_params, args=(m1_hc_counts,), method='L-BFGS-B', bounds=bounds)
fitted_params = result.x

print(f"Fitted parameters: {fitted_params}")


In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10,10))
mu1, theta1, p1, mu2, theta2, p2, w1 = fitted_params
print(mu1,theta1,p1)
print(mu2,theta2,p2)
print(w1)

target_data = data_path.split('/')[-1]
print(target_data)
ax.set_title(f"Mixed ZINB Regression on {target_data} k-mer multiplicities")
ax.set_xlabel("Number of species kmer appears in")
ax.set_ylabel("Number of kmers")
ax.plot(list(range(0,n+1)), [combined_zinb_pmf(k, n, mu1, theta1, p1, mu2, theta2, p2, w1)*len(m1_hc_counts) for k in range(0,n+1)], label="ZINB Regression")
ax.hist(m1_hc_counts, bins=200, label="Empirical kmer multiplicities")
ax.legend()

In [None]:
print(sum([combined_zinb_pmf(k, n, mu1, theta1, p1, mu2, theta2, p2, w1) for k in range(0,n+1)]))
print(sum([zinb_pmf(k, mu1, theta1, p1) for k in range(0,n+1)]))