In [82]:
import pandas as pd 
import numpy as np 
import evoVAE.utils.metrics as mt 
import evoVAE.utils.seq_tools as st
from numba import njit, prange, jit
import random



In [15]:
# Load the pickled GB1 ancestor/extant alignment, keep one row per distinct
# sequence and discard the columns this notebook does not use.
# NOTE(review): absolute, machine-specific path; pd.read_pickle can execute
# arbitrary code, so only load files from a trusted source.
# To iterate faster, subsample here with `.sample(frac=0.2)`.
aln: pd.DataFrame = (
    pd.read_pickle("/Users/sebs_mac/uni_OneDrive/honours/data/gb1/encoded_weighted/gb1_ancestors_extants_encoded_weighted_no_dupes.pkl")
    .drop_duplicates(subset=["sequence"])
    .drop(columns=["encoding", "weights"])
)
print(aln.shape)
aln.head()

# Convert the alignment to a NumPy array; the helper also returns two lookup
# objects (seq_key, key_label) -- presumably id/label mappings, TODO confirm.
msa, seq_key, key_label = st.convert_msa_numpy_array(aln)
msa.shape

(14276, 2)
Sequence weight numpy array created with shape (num_seqs, columns):  (14276, 448)


(14276, 448)

In [41]:
# Export the de-duplicated alignment as FASTA in the working directory.
# NOTE(review): presumably the input to the external clustering step whose
# result is read back from "gb1_an_ex_cluster.tsv" -- confirm.
st.write_fasta_file("gb1_ancestors_extants_no_dupes.fasta", aln)

In [33]:
# Two-column, header-less TSV: cluster representative id and member sequence id.
results = pd.read_csv("gb1_an_ex_cluster.tsv", sep="\t", header=None)
results.columns = ["cluster", "sequence"]


def mark_ancestors(x):
    """Return 1 when the sequence id contains "tree", otherwise 0."""
    return int("tree" in x)


# Ids containing "tree" are flagged as ancestors (1); all others as extants (0).
is_ancestor = results["sequence"].map(mark_ancestors)
results["is_ancestor"] = is_ancestor

results.head()

Unnamed: 0,cluster,sequence,is_ancestor
0,N21_gb1_tree_1,N21_gb1_tree_1,1
1,N21_gb1_tree_1,N22_gb1_tree_1,1
2,N21_gb1_tree_1,N23_gb1_tree_1,1
3,N21_gb1_tree_1,N24_gb1_tree_1,1
4,N21_gb1_tree_1,N25_gb1_tree_1,1


In [21]:
# Distinct cluster representative ids, in order of first appearance.
representative_ids = pd.unique(results["cluster"])
representative_ids.shape

(55,)

In [25]:
# One DataFrame per cluster; clusters[i] holds the rows whose representative is representative_ids[i].
clusters = [results[results["cluster"].eq(cluster_id)] for cluster_id in representative_ids]

In [77]:
from typing import List, Tuple

def sample_clusters(clusters: List[pd.DataFrame], sample_ids: set, clusters_seen: int, num_clusters: int,
                    cluster_obs: dict, is_ancestor: int, current_size: int) -> Tuple[int, int]:
    """Try to draw one not-yet-sampled sequence id from the next cluster in turn.

    Clusters are visited round-robin: the cluster used is
    ``clusters[clusters_seen % num_clusters]``. Only rows whose ``is_ancestor``
    column equals ``is_ancestor`` are eligible.

    Parameters
    ----------
    clusters : list of DataFrames with at least ``sequence`` and ``is_ancestor`` columns.
    sample_ids : set of sequence ids chosen so far; updated in place.
    clusters_seen : number of cluster visits made so far (selects the cluster).
    num_clusters : number of clusters to cycle through.
    cluster_obs : per-cluster count of ids drawn for this sequence type, keyed by
        cluster index; updated in place.
    is_ancestor : value of the ``is_ancestor`` column to sample from.
    current_size : number of ids sampled so far.

    Returns
    -------
    (current_size, clusters_seen)
        ``clusters_seen`` is always advanced by one. ``current_size`` grows by one
        only if an id was drawn; it is unchanged when the cluster has no eligible
        id left.
    """
    cluster_idx = clusters_seen % num_clusters
    members = clusters[cluster_idx]
    candidates = members.loc[members["is_ancestor"] == is_ancestor, "sequence"]

    # Draw only from ids that are still unsampled. Retrying random picks until an
    # unseen id turns up would never terminate if every eligible id is already in
    # sample_ids while cluster_obs says otherwise (or the cluster has duplicate ids).
    available = [seq_id for seq_id in candidates if seq_id not in sample_ids]

    # The visit counts whether or not a sample could be drawn.
    clusters_seen += 1

    # Nothing of the requested type is left in this cluster.
    if not available:
        return current_size, clusters_seen

    chosen = random.choice(available)
    sample_ids.add(chosen)
    cluster_obs[cluster_idx] += 1

    return current_size + 1, clusters_seen
    

In [83]:

# NOTE(review): these two module-level values are not used by the call to
# sample_extant_ancestors further down in this cell, which passes the literals
# 1000 and 0.1 -- so the proportion here (0.2) differs from the one actually
# used. Confirm which value is intended.
sample_size = 1000
extant_proportion = 0.2

def sample_extant_ancestors(clusters: List[pd.DataFrame], sample_size: int, extant_proportion: float):
    """Draw ``sample_size`` sequence ids, cycling round-robin over the clusters.

    Extants (``is_ancestor == 0``) are drawn first, until the fraction
    ``current_size / sample_size`` reaches ``extant_proportion``; the remainder
    is filled with ancestors (``is_ancestor == 1``). Each draw is delegated to
    ``sample_clusters``, which visits one cluster per call.

    Parameters
    ----------
    clusters : list of per-cluster DataFrames, as expected by ``sample_clusters``.
    sample_size : total number of ids to draw; values <= 0 yield an empty set.
    extant_proportion : fraction of the sample that should be extants, in [0, 1].

    Returns
    -------
    set
        The sampled sequence ids.

    Raises
    ------
    ValueError
        If ``extant_proportion`` is outside [0, 1], if ``clusters`` is empty, or
        if the clusters do not hold enough extants / ancestors to reach the
        requested size (previously this case looped forever).
    """
    sample_ids = set()
    if sample_size <= 0:
        return sample_ids

    if not 0.0 <= extant_proportion <= 1.0:
        raise ValueError(f"extant_proportion must be within [0, 1], got {extant_proportion}")

    num_clusters = len(clusters)
    if num_clusters == 0:
        raise ValueError("cannot draw samples from an empty list of clusters")

    EXTANT = 0
    ANCESTOR = 1

    # Per-cluster counters, kept separately for each sequence type.
    cluster_an_obs = {i: 0 for i in range(num_clusters)}
    cluster_ex_obs = {i: 0 for i in range(num_clusters)}

    current_size = 0
    clusters_seen = 0

    # Two phases sharing the same round-robin position: extants up to the
    # requested proportion, then ancestors up to the full sample size.
    phases = (
        ("extant", EXTANT, cluster_ex_obs,
         lambda size: (size / sample_size) < extant_proportion),
        ("ancestor", ANCESTOR, cluster_an_obs,
         lambda size: size < sample_size),
    )

    for label, kind, cluster_obs, needs_more in phases:
        consecutive_misses = 0
        while needs_more(current_size):
            new_size, clusters_seen = sample_clusters(clusters, sample_ids,
                                                      clusters_seen, num_clusters,
                                                      cluster_obs, kind,
                                                      current_size)
            if new_size > current_size:
                consecutive_misses = 0
            else:
                # Every call visits the next cluster, so num_clusters misses in
                # a row means no cluster has an unsampled sequence of this type.
                consecutive_misses += 1
                if consecutive_misses >= num_clusters:
                    raise ValueError(
                        f"not enough {label} sequences to reach the requested sample: "
                        f"stopped at {new_size} of {sample_size}"
                    )
            current_size = new_size

    return sample_ids

# Seed the stdlib RNG used by the sampler so that re-running the notebook
# yields the same subset.
random.seed(42)

# NOTE(review): the literals below (1000, 0.1) are what is actually used; the
# `sample_size` / `extant_proportion` variables defined at the top of this cell
# are not passed in -- confirm which values are intended.
sample_ids = sample_extant_ancestors(clusters, 1000, 0.1)

# Rows of the cluster table that belong to the sampled ids.
final = results.loc[results["sequence"].isin(sample_ids)]



In [84]:
# Display the rows selected by the sampler (pandas truncates the rendering).
final

Unnamed: 0,cluster,sequence,is_ancestor
0,N21_gb1_tree_1,N21_gb1_tree_1,1
1,N21_gb1_tree_1,N22_gb1_tree_1,1
2,N21_gb1_tree_1,N23_gb1_tree_1,1
3,N21_gb1_tree_1,N24_gb1_tree_1,1
4,N21_gb1_tree_1,N25_gb1_tree_1,1
...,...,...,...
14271,N1356_gb1_tree_1,UniRef100_L7ZAY0/5-397,0
14272,N852_gb1_tree_8,N852_gb1_tree_8,1
14273,N852_gb1_tree_8,N851_gb1_tree_8,1
14274,N852_gb1_tree_8,N955_gb1_tree_11,1


In [16]:

# Small alignment used to sanity-check adj_matrix before running it on the full MSA.
test = st.read_aln_file("../data/pair_test.aln")
# Only the array is needed here; the two other return values are discarded.
test_msa, _, _ = st.convert_msa_numpy_array(test)

@njit(parallel=True)
def adj_matrix(msa) -> np.ndarray:
    """Build the symmetric pairwise similarity matrix of an alignment.

    Entry ``[i, j]`` is ``1 - hamming_distance(msa[i], msa[j]) / seq_len``, where
    ``seq_len`` is the length of the first row; the diagonal is 1. Note that the
    values are similarities (1 means identical rows), not distances.

    Assumes ``mt.hamming_distance`` returns the number of mismatching positions
    and can be called from numba-compiled code -- TODO confirm.

    Parameters
    ----------
    msa : array indexed as ``msa[i]`` per sequence, with ``msa.shape[0]`` rows.

    Returns
    -------
    np.ndarray
        ``(n_seqs, n_seqs)`` array of pairwise similarities.
    """
    num_seqs = msa.shape[0]
    seq_len = len(msa[0])
    sim_matrix = np.ones((num_seqs, num_seqs))

    # Only the outer loop is a prange: numba runs nested prange loops serially
    # anyway, so a plain range states what actually happens.
    for i in prange(num_seqs):
        for j in range(i + 1, num_seqs):
            similarity = 1 - (mt.hamming_distance(msa[i], msa[j]) / seq_len)
            # Fill both triangles; each (i, j) pair is written by row i only.
            sim_matrix[i, j] = similarity
            sim_matrix[j, i] = similarity

    return sim_matrix

# Sanity check on the small test alignment; the result is shown as the cell output.
adj_matrix(test_msa)


Reading the alignment: ../data/pair_test.aln
Checking for bad characters: ['B', 'J', 'X', 'Z', 'U']
Performing one hot encoding
Number of seqs: 3
Sequence weight numpy array created with shape (num_seqs, columns):  (3, 4)


array([[1.  , 1.  , 0.25],
       [1.  , 1.  , 0.25],
       [0.25, 0.25, 1.  ]])

In [17]:
# Pairwise similarity matrix for the full alignment. The result is a dense
# (n_seqs, n_seqs) float64 array -- about 1.6 GB for the 14276 sequences loaded above.
mat = adj_matrix(msa)

In [33]:
from sklearn.cluster import AgglomerativeClustering

# With metric="precomputed" scikit-learn expects a pairwise *distance* matrix
# (0 on the diagonal, larger = further apart). `mat` holds similarities
# (1 = identical), so convert it before fitting; fitting on `mat` directly
# would treat the most similar sequences as the most distant ones.
# Note: `1 - mat` allocates a second dense matrix of the same size as `mat`.
model = AgglomerativeClustering(metric="precomputed", linkage="complete")
clustering = model.fit(1 - mat)

In [38]:
# Number of leaves in the fitted hierarchy (one per input sequence).
clustering.n_leaves_

14276