In [None]:
# If installed from pip, import lostruct as ls will work
import lostruct.lostruct as ls
import pickle
import random

# PCoA from skbio.stats is the best implementation of R's MDS algorithm
from skbio.stats.ordination import pcoa

# Much of the output from CyVCF2 and lostruct are numpy arrays
import numpy as np

import pandas as pd
import plotly.express as px
from sklearn.manifold import MDS
import plotly.io as pio
# pio.renderers.default = "notebook_connected"
pio.renderers.default = "plotly_mimetype"

import polars as pl

In [None]:
# Show the NumPy version used by this kernel.
np.__version__

In [None]:
# Load the sample metadata. The file carries a .csv extension but is parsed
# as tab-separated text.
metadata = pl.read_csv(
    "../Hoiho_Genomes_24Feb2024_JGG_3Pops.csv", separator="\t"
).with_columns(
    # Rewrite two specific IDs: "P29 " (trailing space) and "C101/CE9".
    # Presumably this makes them match the sample names used elsewhere in
    # the notebook -- verify against the BCF header.
    pl.col("ID").replace({"P29 ": "P29", "C101/CE9": "CE9"}).alias("ID")
)

In [None]:
# Path of the BCF file that every later lostruct call reads.
bcf_file = "../merged.a9.filtered.qual99_fmissing0.2.maf0.05.biallelic.bcf"
# Sample names as reported by lostruct for that file.
samples = ls.get_samples(bcf_file)

In [None]:
# One row per BCF sample, indexed by the sample ID (as strings).
pop_weights = pd.DataFrame(samples, columns=["ID"], dtype=str).set_index("ID")

# pandas copy of the metadata, indexed by string ID so it lines up with the
# index built above.
metadata_pd = metadata.to_pandas().set_index("ID")
metadata_pd.index = metadata_pd.index.astype(str)

# Left join: every BCF sample is kept and metadata columns are attached where
# an ID matches (the BCF is expected to hold fewer samples than the metadata).
pop_weights = pop_weights.join(metadata_pd, on="ID", how="left")

In [None]:
# Number of samples in each population (Population3 column).
pop_weights_count = pop_weights.groupby("Population3").size()

# Share of all counted samples that falls in each population. Computed once
# here instead of once per row.
pop_proportions = pop_weights_count / pop_weights_count.sum()

# A sample with no Population3 value (e.g. no metadata row matched in the
# join) cannot be given a weight; fail with a message that names the samples.
missing_population = pop_weights["Population3"].isna()
if missing_population.any():
    raise ValueError(
        "Samples without a Population3 value: "
        + ", ".join(map(str, pop_weights.index[missing_population]))
    )

# Per-sample weight = proportion of samples belonging to that sample's
# population.
# NOTE(review): this gives samples from larger populations a larger weight;
# confirm that this (rather than an inverse-frequency weight) is intended.
pop_weights["PopWeight"] = pop_weights["Population3"].map(pop_proportions)

# Display the per-population proportions as the cell output.
pop_proportions


In [None]:
# Get weights as a numpy array
# Per-sample weights as a NumPy array, in the row order of `pop_weights`;
# passed to ls.eigen_windows in the window loop.
weights = pop_weights["PopWeight"].to_numpy()

In [None]:
landmarks = ls.get_landmarks(bcf_file)

# Number of SNPs requested per window.
window_size = 512
# Uncomment to replace the per-sample weight vector with a scalar:
# weights = 1

# Per-window accumulators. Entry k of each list describes the same window.
results = []            # output of ls.eigen_windows for each kept window
snp_positions = []      # positions belonging to each kept window
windows_positions = []  # [landmark, positions] pairs for each kept window
chrs = []               # landmark of each kept window

# Debugging aid: after the loop this holds the last window that was processed.
last_window = False

for landmark in landmarks:
    # len(positions) is the total number of windows, not the total number of SNPs
    windows, positions = ls.parse_vcf(bcf_file, landmark, window_size)

    for i, window in enumerate(windows):
        window_snps = positions[i]
        # Stop processing this landmark at the first window that holds fewer
        # than window_size positions.
        if len(window_snps) < window_size:
            break
        windows_positions.append([landmark, window_snps])
        chrs.append(landmark)
        last_window = window
        results.append(ls.eigen_windows(window, 10, weights))
        snp_positions.append(window_snps)

In [None]:
# Debugging: show the last window processed by the loop (False if none was).
last_window

In [None]:
# Debugging cell: step through a centering + covariance + eigendecomposition
# on the last processed window, checking for NaN/Inf along the way.
# NOTE(review): presumably this mirrors what ls.eigen_windows does internally
# -- confirm against the lostruct source.
snps = last_window.todense()
# Number of rows in the dense window.
n = len(snps)
# [any NaN?, any Inf?] in the raw window
print([np.any(np.isnan(snps)), np.any(np.isinf(snps))])
# Mean of each row, ignoring NaNs, reshaped to a column so it can be
# subtracted row-wise.
rowmeans = np.nanmean(snps, axis=1)
rowmeans = np.reshape(rowmeans, (n, 1))
subtracted = np.array(snps - rowmeans, dtype=np.float64)
# Covariance with NaN entries masked out; rowvar=False treats columns as the
# variables.
covmat = np.ma.cov(np.ma.array(subtracted, mask=np.isnan(subtracted)), rowvar=False)
# [any NaN?, any Inf?] in the covariance matrix
print([np.any(np.isnan(covmat)), np.any(np.isinf(covmat))])

# Sum of the squared entries of the covariance matrix.
total_variance = np.sum(np.power(covmat, 2).flatten())
# General (non-symmetric) eigendecomposition of the covariance matrix.
vals, vectors = np.linalg.eig(covmat)

In [None]:
# Show the NumPy version again (same check as near the top of the notebook).
np.__version__

In [None]:
# Does `snps` contain any NaN or Inf entries? Printed as [any NaN, any Inf].
print([np.any(np.isnan(snps)), np.any(np.isinf(snps))])
# Print the indices of the Inf entries (the NaN variant is left disabled)
#print(np.argwhere(np.isnan(snps)))
print(np.argwhere(np.isinf(snps)))




In [None]:
# Peek at the start of the dense window.
# NOTE(review): if `snps` is a numpy.matrix, snps[0][0] is still a whole row
# rather than a single value; snps[0, 0] would give the scalar -- confirm.
snps[0][0]

In [None]:
# Does covmat contain NaN or Inf? Displayed as [any NaN, any Inf].
[np.any(np.isnan(covmat)), np.any(np.isinf(covmat))]

In [None]:
# Number of windows kept by the loop (one eigen_windows result per window).
len(results)

In [None]:
# NOTE(review): `windows` and `positions` are whatever the window loop assigned
# last, i.e. they only cover the final landmark, not the whole genome. Confirm
# that saving just that landmark is intended.
windows_path = f"windows_n{window_size}.npy"
positions_path = f"positions_n{window_size}.npy"
np.save(windows_path, windows)
np.save(positions_path, positions)

# The per-window results are pickled rather than stored as a .npy file.
results_path = f"results_n{window_size}.pkl"
with open(results_path, "wb") as results_file:
    pickle.dump(results, results_file)

# To reload the results later (only unpickle files you created yourself):
# with open(f"results_n{window_size}.pkl", "rb") as results_file:
#     results = pickle.load(results_file)


In [None]:
# Stacking the results into one array is left disabled:
# results = np.vstack(results)

# Pairwise distance matrix between the per-window results; this is the input
# to the PCoA further down.
pc_dists = ls.get_pc_dists(results, fastmath=True)

In [None]:
# Persist the distance matrix, with the window size in the file name.
np.save(f"pc_dists_n{window_size}.npy", pc_dists, allow_pickle=True)

In [None]:
# PCoA of the window-by-window distances, keeping 4 dimensions (PC1..PC4 are
# the columns used by the plots below).
# NOTE(review): inplace=True asks skbio to centre the distance matrix in
# place, so `pc_dists` may no longer hold the raw distances after this cell;
# recompute or reload it before re-running -- confirm against the skbio docs.
# NOTE(review): method="fsvd" is an approximate solver; check that results are
# stable between runs.
mds = pcoa(pc_dists, method="fsvd", inplace=True, number_of_dimensions=4)

In [None]:
# Write the ordination coordinates and eigenvalues as CSV; index=False drops
# the row labels from both files.
mds.samples.to_csv("mds.csv", index=False)
mds.eigvals.to_csv("eigvals.csv", index=False)
# Also write the ordination result object itself via its own writer.
mds.write("mds_pcoa_output")

In [None]:
# Eigenvalues of the ordination.
mds.eigvals

In [None]:
# Proportion explained by each retained axis, as reported by the ordination.
mds.proportion_explained

In [None]:
# Add each window's landmark (chromosome) to the MDS coordinates.
# `mds_samples` is bound to `mds.samples` itself, not a copy, so -- presumably --
# the new column is what the color="chr" plots below pick up from `mds.samples`.
# `chrs` holds one entry per window kept by the loop, in the same order.
mds_samples = mds.samples
mds_samples["chr"] = chrs


In [None]:
# PC1 of every window, coloured by landmark. No x is given, so plotly places
# the points along the row index.
fig = px.scatter(mds.samples, y="PC1", color="chr")
fig

In [None]:
# Find all windows whose PC1 lies strictly within 1 standard deviation of the
# mean PC1 (the bounds below use 1 * sd).
mean = np.mean(mds.samples["PC1"])
sd = np.std(mds.samples["PC1"])
pc1 = mds.samples["PC1"]
inliers = mds.samples[(pc1 > mean - 1 * sd) & (pc1 < mean + 1 * sd)]

# Choose up to 10 random windows from the inliers. The fixed random_state makes
# a re-run pick the same windows; min() avoids an error when fewer than 10
# inliers exist.
chosen = inliers.sample(min(10, len(inliers)), random_state=42)

# Get the positions of the chosen windows.
# Assumes each index label converts to the window's integer position in
# `windows_positions` -- TODO confirm.
chosen_positions = [windows_positions[int(i)] for i in chosen.index]

# Print one tab-separated line per window: landmark, first position, last
# position.
# NOTE(review): BED intervals are 0-based and half-open; if these positions are
# 1-based (as in VCF) the start needs adjusting for strict BED -- confirm.
# The loop variable is named `chrom` so the builtin chr() is not shadowed.
for chrom, window_snps in chosen_positions:
    start = window_snps[0]
    end = window_snps[-1]
    print(f"{chrom}\t{start}\t{end}")

In [None]:
# Index labels of the randomly chosen inlier windows.
chosen.index

In [None]:
# PC2 of every window, coloured by landmark.
fig = px.scatter(mds.samples, y="PC2", color="chr")
fig

In [None]:
# PC3 of every window, coloured by landmark.
fig = px.scatter(mds.samples, y="PC3", color="chr")
fig

In [None]:
# PC4 of every window, coloured by landmark.
fig = px.scatter(mds.samples, y="PC4", color="chr")
fig

In [None]:
# PC1 against PC2 for every window, coloured by landmark.
fig = px.scatter(mds.samples, x="PC1", y="PC2", color="chr")
fig

In [None]:
# Landmark and positions of window 311.
# NOTE(review): 311 is a hard-coded window index; it is also used below when
# picking `window_pca` -- confirm it is still the window of interest after any
# change to the input or window size.
windows_positions[311]

In [None]:
# Element 2 of every window's eigen_windows result, stacked one row per window.
# NOTE(review): the following cells read these as variance explained, which
# suggests they are eigenvalues rather than eigenvectors despite the variable
# name -- confirm against the return order of ls.eigen_windows.
eigenvectors = np.array([window_result[2] for window_result in results])

# Rescale each row so that it sums to 1 (per-window proportions).
eigenvectors = eigenvectors / eigenvectors.sum(axis=1, keepdims=True)
eigenvectors

In [None]:
# For each column, the index of the window (row) with the largest proportion.
max_explained = np.argmax(eigenvectors, axis=0)
max_explained

In [None]:
# Proportions for the first window. The rows were already scaled to sum to 1
# in the cell that builds `eigenvectors`, so this division changes little.
eigenvectors[0] / np.sum(eigenvectors[0])

In [None]:
# Element 3 of every window's eigen_windows result, one entry per window.
pcas = [x[3] for x in results]

In [None]:
# Average `pcas` over axis 1, transpose, and plot column 0 against column 1.
# NOTE(review): assumes every entry of `pcas` has the same shape so they stack
# into one array; which quantities end up on the two axes depends on that
# shape -- TODO confirm.
px.scatter(np.mean(pcas, axis=1).T, x=0, y=1)

In [None]:
# Element 3 of the eigen_windows result for the hard-coded window 311.
window_pca = results[311][3]

In [None]:
# Reload the sample metadata (same file and same ID clean-up as the load near
# the top of the notebook). The file carries a .csv extension but is parsed as
# tab-separated text.
metadata = pl.read_csv(
    "../Hoiho_Genomes_24Feb2024_JGG_3Pops.csv", separator="\t"
).with_columns(
    # Rewrite two specific IDs: "P29 " (trailing space) and "C101/CE9".
    pl.col("ID").replace({"P29 ": "P29", "C101/CE9": "CE9"}).alias("ID")
)

In [None]:
# One row per sample with columns PC1..PC10 taken from the transposed window
# PCA, plus the sample ID in BCF order.
pc_columns = [f"PC{k}" for k in range(1, 11)]
window_pca_metadata = pd.DataFrame(window_pca.T, columns=pc_columns)
window_pca_metadata["ID"] = samples


In [None]:
# Attach the metadata columns (including Population3) to each sample by ID.
# The left merge keeps every sample row even if no metadata matches.
# Re-running this cell on its own merges the metadata a second time, so re-run
# the cell that builds window_pca_metadata first.
window_pca_metadata = window_pca_metadata.merge(
    metadata.to_pandas(), on="ID", how="left"
)


In [None]:
# PC1 against PC2 for the selected window, one point per sample, coloured by
# population.
px.scatter(window_pca_metadata, x="PC1", y="PC2", color="Population3")

In [None]:
# Compare body weight between samples on either side of a PC1 threshold.
# Both columns are cast to float first.
window_pca_metadata["Weight"] = window_pca_metadata["Weight (kg)"].astype(float)
window_pca_metadata["PC1"] = window_pca_metadata["PC1"].astype(float)

# True for samples whose PC1 is above the threshold. The cut is at -0.05,
# not at 0.
pc1_threshold = -0.05
window_pca_metadata["RightOfPC1"] = window_pca_metadata["PC1"].gt(pc1_threshold)

# Box plot of weight for each side of the threshold, split by population.
px.box(window_pca_metadata, x="RightOfPC1", y="Weight", color="Population3")