In [4]:

import pickle
import numpy as np
from gensim.models import Word2Vec
import random
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.cluster import KMeans
from sklearn.preprocessing import normalize
import matplotlib.pyplot as plt
from collections import defaultdict, deque
import json
import subprocess
import os
from Bio import SeqIO
from io import StringIO
import re

# Section banner for the sequence-extraction stage (rule, title, rule).
print("=" * 80, "EXTRACTING SEQUENCES FROM BLAST DATABASES", "=" * 80, sep="\n")

def getSeqs(fastaPath):
    """Group cleaned sequences from a FASTA file by a guessed genus name.

    Records shorter than 100 characters, before or after dropping every
    character outside ``ATCGUN``, are skipped. The genus is taken from, in
    order: the first capitalised word inside ``[...]`` that is followed by
    whitespace; the first standalone capitalised word (3+ letters) in the
    description; the record id up to its first ``_`` or ``.``.

    Args:
        fastaPath: Path of the FASTA file to read.

    Returns:
        A plain dict mapping genus name to a list of cleaned, upper-cased
        sequence strings.
    """
    bracketedGenus = re.compile(r'\[([A-Z][a-z]+)\s+')
    capitalisedWord = re.compile(r'^[A-Z][a-z]{2,}$')
    allowedBases = set('ATCGUN')
    grouped = defaultdict(list)

    with open(fastaPath, 'r') as handle:
        for rec in SeqIO.parse(handle, "fasta"):
            raw = str(rec.seq).upper()
            if len(raw) < 100:
                continue

            bases = ''.join(filter(allowedBases.__contains__, raw))
            if len(bases) < 100:
                continue

            # genus is usually in brackets or first word
            desc = rec.description
            hit = bracketedGenus.search(desc)
            if hit is not None:
                name = hit.group(1)
            else:
                name = next(
                    (tok for tok in desc.split() if capitalisedWord.match(tok)),
                    None,
                )

            # Last resort: leading part of the record id.
            if not name:
                name = rec.id.split('_')[0].split('.')[0]

            if name:
                grouped[name].append(bases)

    return dict(grouped)

# Parse both rRNA FASTA collections into {genus: [sequences]} maps.
lsuGenus2Seq = getSeqs("LSUs/LSU_eukaryote_rRNA.fasta")
print(f"Got {len(lsuGenus2Seq)} LSU genera")

ssuGenus2Seq = getSeqs("SSUs/SSU_eukaryote_rRNA.fasta")
print(f"Got {len(ssuGenus2Seq)} SSU genera")

# Only genera present in both markers can be compared. The intersection is
# sorted because set iteration order over strings depends on the per-process
# hash seed; an unsorted list would change the row order fed to the
# clustering (and the `[:12]` genus slices used for k-mer feedback) from one
# run to the next.
commonGenera = sorted(lsuGenus2Seq.keys() & ssuGenus2Seq.keys())
print(f"Common genera: {len(commonGenera)}")

if not commonGenera:
    print("No common genera found!")
    # `exit` is a helper injected by the `site` module for interactive use;
    # raising SystemExit is the supported way to stop a program with status 1.
    raise SystemExit(1)

# save for later
with open("genus_to_lsu_sequences.pkl", "wb") as f:
    pickle.dump(lsuGenus2Seq, f)
with open("genus_to_ssu_sequences.pkl", "wb") as f:
    pickle.dump(ssuGenus2Seq, f)

def makeKmerSentences(sequences, k=6):
    """Turn each sequence into a "sentence" of overlapping k-mers.

    Args:
        sequences: Iterable of sequence strings.
        k: Window length; consecutive windows start one character apart.

    Returns:
        One list of k-length substrings per input sequence, in input order.
        A sequence shorter than ``k`` yields an empty list.
    """
    def tokenize(seq):
        windows = (seq[start:start + k] for start in range(len(seq) - k + 1))
        return [w for w in windows if len(w) == k]

    return [tokenize(seq) for seq in sequences]

class PhyloW2V:
    """Word2Vec model over k-mer sentences, plus sequence/genus embedding.

    The training sentences are kept on the instance so that ``retrain`` can
    rebuild the model from a re-weighted copy of the same corpus.
    """

    def __init__(self, sentences, vecSize=128, window=5):
        """Train a Word2Vec model on ``sentences`` (lists of k-mer tokens)."""
        self.sentences = sentences
        self.vecSize = vecSize
        self.window = window
        self.model = Word2Vec(
            sentences,
            vector_size=vecSize,
            window=window,
            min_count=1,
            workers=4,
        )

    def embedSeq(self, seq, k=6):
        """Return the mean vector of the sequence's in-vocabulary k-mers.

        A zero vector of length ``vecSize`` is returned when no k-mer of the
        sequence is in the model vocabulary.
        """
        vocab = self.model.wv
        vectors = []
        for start in range(len(seq) - k + 1):
            token = seq[start:start + k]
            if len(token) == k and token in vocab:
                vectors.append(vocab[token])

        if not vectors:
            return np.zeros(self.vecSize)
        return np.mean(vectors, axis=0)

    def embedGenus(self, genusSeqs, k=6, maxSamples=20):
        """Embed a genus as the L2-normalised mean of sampled sequence vectors.

        Up to ``maxSamples`` sequences are drawn with ``random.sample``;
        all-zero sequence embeddings are discarded before averaging. Returns
        a zero vector when nothing usable remains.
        """
        sampleSize = min(maxSamples, len(genusSeqs))
        picked = random.sample(genusSeqs, sampleSize)

        usable = []
        for seq in picked:
            vec = self.embedSeq(seq, k)
            if np.any(vec):
                usable.append(vec)

        if not usable:
            return np.zeros(self.vecSize)

        centroid = np.mean(usable, axis=0)
        return normalize([centroid], norm='l2')[0]

    def retrain(self, kmerWeights):
        """Rebuild the model on a corpus where weighted k-mers are repeated.

        Each token of the stored sentences is emitted
        ``max(1, int(weight * 2))`` times, where ``weight`` comes from
        ``kmerWeights`` and defaults to 1.0 for k-mers without an entry (so
        unlisted k-mers are emitted twice). Repetition is how a higher
        weight makes a k-mer count for more during training. The previous
        model is replaced by one trained from scratch on that corpus.

        Args:
            kmerWeights: Mapping from k-mer string to a numeric weight.
        """
        boosted = []
        for sent in self.sentences:
            expanded = [
                kmer
                for kmer in sent
                for _ in range(max(1, int(kmerWeights.get(kmer, 1.0) * 2)))
            ]
            if expanded:
                boosted.append(expanded)

        self.model = Word2Vec(
            boosted,
            vector_size=self.vecSize,
            window=self.window,
            min_count=1,
            workers=4,
        )

# Initial training: one Word2Vec model per marker.
print("\nTraining Word2Vec models...")
# Flatten every genus' sequences into one corpus of k-mer sentences per marker
# (all genera are used here, not only those shared by both markers).
lsuSents = [s for genus in lsuGenus2Seq.values() for s in makeKmerSentences(genus)]
ssuSents = [s for genus in ssuGenus2Seq.values() for s in makeKmerSentences(genus)]

# Default PhyloW2V settings apply: vecSize=128, window=5.
lsuW2v = PhyloW2V(lsuSents)
ssuW2v = PhyloW2V(ssuSents)

class TreeNode:
    """One node of the binary genus tree.

    Attributes are set in ``__init__``: ``genera`` (the genera under this
    node), ``path`` (string of ``L``/``R`` steps from the root), ``left`` and
    ``right`` (children, ``None`` until the node is split) and ``score``
    (0.0 until a split assigns one).
    """

    def __init__(self, genera, path=""):
        self.genera, self.path = genera, path
        self.left = self.right = None
        self.score = 0.0

    def isLeaf(self):
        """Return True when the node has neither a left nor a right child."""
        return all(child is None for child in (self.left, self.right))

class PhyloBuilder:
    """Build a binary genus tree from paired LSU and SSU genus embeddings.

    Every split clusters the same genera twice (once per marker) with
    2-means, scores how well the two partitions agree, and records k-mers
    whose frequencies differ across the resulting groups in
    ``kmerFeedback`` so the embedding models can be retrained on them.

    The class reads the module-level ``lsuGenus2Seq`` / ``ssuGenus2Seq``
    maps (and ``commonGenera`` as the default genus list).
    """

    def __init__(self, lsuModel, ssuModel, maxLeaf=6, minCluster=2):
        """Store the two embedding models and the split limits.

        Args:
            lsuModel: Object exposing ``embedGenus(seqs)`` for LSU sequences.
            ssuModel: Object exposing ``embedGenus(seqs)`` for SSU sequences.
            maxLeaf: Nodes with at most this many genera are never split.
            minCluster: Minimum number of genera allowed on either side.
        """
        self.lsu = lsuModel
        self.ssu = ssuModel
        self.maxLeaf = maxLeaf
        self.minCluster = minCluster
        # k-mer -> accumulated feedback weight, filled in by _updateKmers.
        self.kmerFeedback = defaultdict(float)

    def getEmbs(self, genera):
        """Return ``(lsuEmbs, ssuEmbs)`` arrays with one row per genus."""
        lsuEmbs = []
        ssuEmbs = []

        # The two embedGenus calls stay interleaved per genus so the shared
        # `random` state is consumed in the same order as before.
        for g in genera:
            lsuEmbs.append(self.lsu.embedGenus(lsuGenus2Seq[g]))
            ssuEmbs.append(self.ssu.embedGenus(ssuGenus2Seq[g]))

        return np.array(lsuEmbs), np.array(ssuEmbs)

    @staticmethod
    def _pairCount(m):
        """Number of unordered pairs that can be drawn from ``m`` items."""
        return m * (m - 1) // 2

    def calcAgreement(self, labels1, labels2):
        """Jaccard index over co-clustered item pairs of two labelings.

        A pair of items counts for a labeling when both items carry the same
        label. The result is ``|both| / |either|`` over those pair sets.
        Because only "same label or not" matters, the value does not change
        when the labels of either input are renamed or swapped, so no
        separate "flipped" comparison is needed.

        The pair counts are derived from cluster sizes and the label
        contingency table, so the cost is linear in the number of items
        instead of quadratic.

        Args:
            labels1: Sequence of hashable cluster labels.
            labels2: Sequence of labels for the same items, at least as long
                as ``labels1``.

        Returns:
            1.0 when neither labeling has any co-clustered pair, 0.0 when
            exactly one has none, otherwise the Jaccard index as a float.
        """
        sizes1 = defaultdict(int)
        sizes2 = defaultdict(int)
        cells = defaultdict(int)
        for i, a in enumerate(labels1):
            b = labels2[i]
            sizes1[a] += 1
            sizes2[b] += 1
            cells[(a, b)] += 1

        same1 = sum(self._pairCount(c) for c in sizes1.values())
        same2 = sum(self._pairCount(c) for c in sizes2.values())

        if same1 == 0 and same2 == 0:
            return 1.0
        if same1 == 0 or same2 == 0:
            return 0.0

        # Pairs co-clustered in both labelings sit in one contingency cell.
        both = sum(self._pairCount(c) for c in cells.values())
        either = same1 + same2 - both
        return both / either

    def split(self, genera, tier=0):
        """Try to split ``genera`` into two groups.

        Args:
            genera: List of genus names to partition.
            tier: Depth of the node being split; deeper tiers soften the
                penalty applied to the agreement score.

        Returns:
            ``(smallerGroup, largerGroup, score)``, or ``(None, None, 0.0)``
            when the node is too small or either side would hold fewer than
            ``minCluster`` genera.
        """
        if len(genera) <= self.maxLeaf:
            return None, None, 0.0

        if len(genera) < 2 * self.minCluster:
            return None, None, 0.0

        lsuEmbs, ssuEmbs = self.getEmbs(genera)

        km1 = KMeans(n_clusters=2, n_init=10)
        km2 = KMeans(n_clusters=2, n_init=10)

        lsuLabels = km1.fit_predict(lsuEmbs)
        ssuLabels = km2.fit_predict(ssuEmbs)

        agreement = self.calcAgreement(lsuLabels, ssuLabels)

        # Cluster ids 0/1 are arbitrary per KMeans run, so align the SSU ids
        # with the LSU ids before voting. The decision uses the direct match
        # rate: the pair-based agreement is unchanged by a flip and therefore
        # cannot tell aligned from anti-aligned labelings.
        if np.mean(lsuLabels == ssuLabels) < 0.5:
            ssuLabels = 1 - ssuLabels

        # tier weighting: exponent runs from 1.0 at tier 0 towards 0.5 deeper
        tierWeight = np.exp(-tier * 0.2)
        weightedAgr = agreement ** (0.5 + 0.5 * tierWeight)

        # NOTE(review): consensusWeight is > 0.5 whenever agreement > 0, and
        # with two binary voters the thresholded vote then equals lsuLabels:
        # the SSU labels affect the score but not the partition. Confirm this
        # is intended.
        consensusWeight = 0.5 + 0.1 * agreement
        consensus = (lsuLabels * consensusWeight + ssuLabels * (1 - consensusWeight))
        consensus = (consensus >= 0.5).astype(int)

        grp0 = [g for g, side in zip(genera, consensus) if side == 0]
        grp1 = [g for g, side in zip(genera, consensus) if side == 1]

        if len(grp0) < self.minCluster or len(grp1) < self.minCluster:
            return None, None, 0.0

        # Smaller group first, so "left" is always the smaller child.
        if len(grp0) > len(grp1):
            grp0, grp1 = grp1, grp0

        self._updateKmers(genera, grp0, grp1, weightedAgr)

        return grp0, grp1, weightedAgr

    @staticmethod
    def _kmerCounts(seqs, k=6):
        """Count every overlapping k-mer across ``seqs``."""
        counts = defaultdict(int)
        for seq in seqs:
            for i in range(len(seq) - k + 1):
                counts[seq[i:i + k]] += 1
        return counts

    def _addFeedback(self, countsA, totalA, countsB, totalB, score):
        """Add ``|freqA - freqB| * score * 0.5`` for k-mers differing by > 0.001."""
        for kmer in set(countsA) | set(countsB):
            diff = abs(countsA.get(kmer, 0) / totalA - countsB.get(kmer, 0) / totalB)
            if diff > 0.001:
                self.kmerFeedback[kmer] += diff * score * 0.5

    def _updateKmers(self, allGenera, leftGrp, rightGrp, score):
        """Accumulate feedback for k-mers that separate the two groups.

        Relative 6-mer frequencies are compared cross-modally: LSU sequences
        of the left group against SSU sequences of the right group, and SSU
        of the left against LSU of the right. Only the first 12 genera per
        group and the first 6 sequences per genus are sampled. Nothing is
        recorded when any of the four sequence samples is empty or yields no
        6-mers.

        Args:
            allGenera: Genera of the node being split (currently unused).
            leftGrp: Genera assigned to the left child.
            rightGrp: Genera assigned to the right child.
            score: Split score used to scale the feedback.
        """
        leftLsuSeqs = [seq for g in leftGrp[:12] for seq in lsuGenus2Seq.get(g, [])[:6]]
        rightLsuSeqs = [seq for g in rightGrp[:12] for seq in lsuGenus2Seq.get(g, [])[:6]]
        leftSsuSeqs = [seq for g in leftGrp[:12] for seq in ssuGenus2Seq.get(g, [])[:6]]
        rightSsuSeqs = [seq for g in rightGrp[:12] for seq in ssuGenus2Seq.get(g, [])[:6]]

        if not (leftLsuSeqs and rightSsuSeqs and leftSsuSeqs and rightLsuSeqs):
            return

        leftLsuKmers = self._kmerCounts(leftLsuSeqs)
        rightLsuKmers = self._kmerCounts(rightLsuSeqs)
        leftSsuKmers = self._kmerCounts(leftSsuSeqs)
        rightSsuKmers = self._kmerCounts(rightSsuSeqs)

        totalLeftLsu = sum(leftLsuKmers.values())
        totalRightLsu = sum(rightLsuKmers.values())
        totalLeftSsu = sum(leftSsuKmers.values())
        totalRightSsu = sum(rightSsuKmers.values())

        # Guard the divisions below; also keeps the update all-or-nothing.
        if min(totalLeftLsu, totalRightSsu, totalLeftSsu, totalRightLsu) == 0:
            return

        self._addFeedback(leftLsuKmers, totalLeftLsu, rightSsuKmers, totalRightSsu, score)
        self._addFeedback(leftSsuKmers, totalLeftSsu, rightLsuKmers, totalRightLsu, score)

    def buildTree(self, genera=None, maxDepth=10):
        """Split breadth-first from the root and return ``(root, avgScore)``.

        Args:
            genera: Genera for the root; defaults to the module-level
                ``commonGenera``.
            maxDepth: Soft depth limit. At or beyond it, only nodes holding
                more than ``2 * maxLeaf`` genera are still split.

        Returns:
            The root ``TreeNode`` and the mean score over all splits made
            (0.0 when no split happened).
        """
        if genera is None:
            genera = commonGenera

        print(f"\nBuilding tree for {len(genera)} genera...")

        root = TreeNode(genera, "")
        q = deque([(root, 0)])

        totalScore = 0.0
        nSplits = 0

        while q:
            node, tier = q.popleft()

            shouldSplit = len(node.genera) > self.maxLeaf

            if tier >= maxDepth and len(node.genera) <= self.maxLeaf * 2:
                shouldSplit = False

            if not shouldSplit:
                continue

            left, right, score = self.split(node.genera, tier)

            if left and right:
                node.left = TreeNode(left, node.path + "L")
                node.right = TreeNode(right, node.path + "R")
                node.score = score

                totalScore += score
                nSplits += 1

                q.append((node.left, tier + 1))
                q.append((node.right, tier + 1))

                print(f"T{tier}: {len(node.genera)} → {len(left)}L | {len(right)}R (agr: {score:.3f})")

        avgScore = totalScore / max(nSplits, 1)
        print(f"Avg agreement: {avgScore:.3f}")
        return root, avgScore

# Refinement driver: build a tree, score it, retrain both models on the
# k-mer feedback gathered during that build, and repeat.
print("\n" + "="*80)
print("ADVERSARIAL REFINEMENT")
print("="*80)

builder = PhyloBuilder(lsuW2v, ssuW2v, maxLeaf=6)
bestScore = 0.0
bestTree = None

MAX_ITER = 30        # hard cap on build/retrain rounds
PATIENCE = 5         # consecutive non-improving rounds tolerated
MIN_IMPROVE = 0.001  # gain over bestScore needed to count as an improvement

scores = []   # average agreement of every round, in order
patience = 0  # current count of consecutive non-improving rounds

for it in range(MAX_ITER):
    print(f"\n--- Iteration {it+1} ---")
    

    # maxDepth=12 overrides buildTree's default of 10.
    tree, score = builder.buildTree(maxDepth=12)
    scores.append(score)
    
    improve = score - bestScore
    
    if improve > MIN_IMPROVE:
        bestScore = score
        # Persist the models that produced this tree together with the tree.
        # NOTE(review): these three files are written only on an improving
        # round; the code after the loop loads them unconditionally.
        lsuW2v.model.save("best_lsu.model")
        ssuW2v.model.save("best_ssu.model")
        bestTree = tree
        patience = 0
        print(f"New best: {bestScore:.3f}")
        
        with open("best_tree.pkl", "wb") as f:
            pickle.dump(bestTree, f)
    else:
        patience += 1
        print(f"No improvement, patience: {patience}/{PATIENCE}")
    
    if patience >= PATIENCE:
        print(f"\nConverged after {it+1} iterations")
        break
    
    # check plateau: stop when the last three scores barely vary
    if len(scores) >= 3:
        if np.std(scores[-3:]) < MIN_IMPROVE / 2:
            print(f"\nPlateaued after {it+1} iterations")
            break
    
    if builder.kmerFeedback:
        # Rescale feedback relative to the largest entry, which maps to 2.2.
        # NOTE(review): a maxFb of 0 would raise ZeroDivisionError here.
        maxFb = max(builder.kmerFeedback.values())
        normFb = {k: 0.4 + 1.8 * v/maxFb for k, v in builder.kmerFeedback.items()}  
        
        print("Retraining with feedback...")
        # Both models are retrained with the same weight table.
        lsuW2v.retrain(normFb)
        ssuW2v.retrain(normFb)
        
        # A fresh builder starts the next round with empty kmerFeedback.
        builder = PhyloBuilder(lsuW2v, ssuW2v, maxLeaf=6)

print(f"\nFinal score: {bestScore:.3f}")

# save final models
# Reload the best-scoring checkpoints written by the refinement loop and copy
# them to their final file names.
# NOTE(review): the checkpoints exist only if at least one iteration improved
# on the initial bestScore; otherwise these loads fail.
from gensim.models import Word2Vec as W2V
finalLsu = W2V.load("best_lsu.model")
finalSsu = W2V.load("best_ssu.model")

finalLsu.save("lsu_phylo.model")
finalSsu.save("ssu_phylo.model")

# Read the best tree back from disk instead of reusing the in-memory bestTree.
with open("best_tree.pkl", "rb") as f:
    finalTree = pickle.load(f)

def tree2dict(node):
    """Convert a tree node and everything below it into nested dicts.

    Leaves carry ``path``, ``genera`` and ``size``. Internal nodes carry
    ``path``, ``size``, ``score`` and the converted ``left`` / ``right``
    children (``None`` for a missing child).
    """
    summary = {"path": node.path}

    if node.isLeaf():
        summary["genera"] = node.genera
        summary["size"] = len(node.genera)
        return summary

    summary["size"] = len(node.genera)
    summary["score"] = node.score
    for side, child in (("left", node.left), ("right", node.right)):
        summary[side] = tree2dict(child) if child else None
    return summary

# Export the final tree twice: as readable JSON and as a pickle of the
# TreeNode objects (the pickle is what later code reloads as "tree.pkl").
treeDict = tree2dict(finalTree)
with open("tree.json", "w") as f:
    json.dump(treeDict, f, indent=2)

with open("tree.pkl", "wb") as f:
    pickle.dump(finalTree, f)

class Navigator:
    """Interactive console browser for a binary genus tree.

    Nodes are expected to expose ``genera``, ``left``, ``right``, ``score``
    and ``isLeaf()``. The navigator tracks the current node and the list of
    ``L``/``R`` steps taken from the root.
    """

    def __init__(self, root, name=None):
        """Start at ``root``.

        Args:
            root: Root node of the tree to browse.
            name: Optional label (for example ``"LSU"``) shown in the
                banner and position header; omitted from output when None.
        """
        self.root = root
        self.name = name
        self.current = root
        self.history = []

    def go(self, dir):
        """Move to the left (``'L'``) or right (``'R'``) child.

        The direction is case-insensitive. Returns True when the move
        happened, False for an unknown direction or a missing child.
        """
        # `dir` shadows the builtin but is the public parameter name; kept.
        step = dir.upper()
        if step == 'L':
            target = self.current.left
        elif step == 'R':
            target = self.current.right
        else:
            return False

        if not target:
            return False

        self.current = target
        self.history.append(step)
        return True

    def show(self, showGenera=False):
        """Print the current position, its size, score and children.

        With ``showGenera=True`` the first 20 genera (sorted) are listed
        too, followed by a count of the remainder.
        """
        path = ''.join(self.history) if self.history else "ROOT"
        print(f"\n{'='*60}")
        if self.name:
            print(f"Tree: {self.name}")
        print(f"Position: [{path}]")
        print(f"Genera: {len(self.current.genera)}")

        if self.current.score > 0:
            print(f"Score: {self.current.score:.3f}")

        if not self.current.isLeaf():
            if self.current.left:
                print(f"  Left (L): {len(self.current.left.genera)} genera")
            if self.current.right:
                print(f"  Right (R): {len(self.current.right.genera)} genera")
        else:
            print("  [LEAF]")

        if showGenera:
            print(f"\nGenera:")
            glist = sorted(self.current.genera)[:20]
            for i, g in enumerate(glist, 1):
                print(f"  {i:3d}. {g}")
            if len(self.current.genera) > 20:
                print(f"  ... +{len(self.current.genera) - 20} more")
        print(f"{'='*60}")

    def reset(self):
        """Return to the root and clear the step history."""
        self.current = self.root
        self.history = []

    def _handle(self, cmd):
        """Execute one already-normalised command; return False to quit."""
        if cmd.startswith("goto"):
            parts = cmd.split()
            if len(parts) == 2 and parts[1].upper() in ['L', 'R']:
                if self.go(parts[1]):
                    self.show()
                else:
                    print("Can't go there")
            else:
                print("Use: goto L or goto R")

        elif cmd == "show":
            self.show(showGenera=True)

        elif cmd == "new":
            self.reset()
            print("\nBack to root")
            self.show()

        elif cmd in ["quit", "exit", "q"]:
            return False

        else:
            print("Unknown command")

        return True

    def run(self):
        """Read commands from stdin until quit, Ctrl-C or end of input."""
        title = f"{self.name} TREE NAVIGATOR" if self.name else "TREE NAVIGATOR"
        print("\n" + "="*80)
        print(title)
        print("="*80)
        print("\nCommands: goto L/R, show, new, quit")

        self.show()

        while True:
            try:
                cmd = input("\n> ").strip().lower()
            except (KeyboardInterrupt, EOFError):
                # End of input must end the session: input() would raise
                # EOFError again on every retry, so reporting it and looping
                # never terminates.
                break

            try:
                if not self._handle(cmd):
                    break
            except KeyboardInterrupt:
                break
            except Exception as e:
                # Best-effort REPL: report the problem and keep the session.
                print(f"Error: {e}")

# Final summary. len(scores) is the number of refinement rounds that ran; the
# message says "Converged" whichever stop condition ended the loop.
print("\n" + "="*80)
print(f"DONE! Converged after {len(scores)} iterations")
print(f"Best score: {bestScore:.3f}")
print("="*80)

# Start the interactive browser on the best tree; run() blocks on input().
nav = Navigator(finalTree)
nav.run()


EXTRACTING SEQUENCES FROM BLAST DATABASES
Got 4640 LSU genera
Got 5637 SSU genera
Common genera: 1598

Training Word2Vec models...

ADVERSARIAL REFINEMENT

--- Iteration 1 ---

Building tree for 1598 genera...
T0: 1598 → 696L | 902R (agr: 0.657)
T1: 696 → 279L | 417R (agr: 0.505)
T1: 902 → 324L | 578R (agr: 0.515)
T2: 279 → 103L | 176R (agr: 0.656)
T2: 417 → 170L | 247R (agr: 0.445)
T2: 324 → 66L | 258R (agr: 0.659)
T2: 578 → 213L | 365R (agr: 0.683)
T3: 103 → 33L | 70R (agr: 0.576)
T3: 176 → 78L | 98R (agr: 0.550)
T3: 170 → 69L | 101R (agr: 0.603)
T3: 247 → 55L | 192R (agr: 0.548)
T3: 66 → 25L | 41R (agr: 0.616)
T3: 258 → 99L | 159R (agr: 0.534)
T3: 213 → 67L | 146R (agr: 0.465)
T3: 365 → 164L | 201R (agr: 0.700)
T4: 33 → 12L | 21R (agr: 0.581)
T4: 70 → 10L | 60R (agr: 0.949)
T4: 78 → 29L | 49R (agr: 0.603)
T4: 98 → 17L | 81R (agr: 0.552)
T4: 69 → 14L | 55R (agr: 0.593)
T4: 101 → 35L | 66R (agr: 0.810)
T4: 55 → 17L | 38R (agr: 0.501)
T4: 192 → 41L | 151R (agr: 0.903)
T4: 25 → 9L | 16R

In [None]:
# Re-open the interactive browser on the combined tree; relies on Navigator
# and finalTree already being defined in the session.
nav = Navigator(finalTree)
nav.run()


TREE NAVIGATOR

Commands: goto L/R, show, new, quit

Position: [ROOT]
Genera: 1598
Score: 0.583
  Left (L): 735 genera
  Right (R): 863 genera

Position: [R]
Genera: 863
Score: 0.724
  Left (L): 128 genera
  Right (R): 735 genera

Position: [RL]
Genera: 128
Score: 0.740
  Left (L): 53 genera
  Right (R): 75 genera

Position: [RLR]
Genera: 75
Score: 0.493
  Left (L): 32 genera
  Right (R): 43 genera

Position: [RLRR]
Genera: 43
Score: 0.522
  Left (L): 14 genera
  Right (R): 29 genera

Position: [RLRRR]
Genera: 29
Score: 0.683
  Left (L): 10 genera
  Right (R): 19 genera

Position: [RLRRRL]
Genera: 10
Score: 0.451
  Left (L): 5 genera
  Right (R): 5 genera

Position: [RLRRRL]
Genera: 10
Score: 0.451
  Left (L): 5 genera
  Right (R): 5 genera

Genera:
    1. Aenigmaticum
    2. Bittacus
    3. Brachypanorpa
    4. Erynia
    5. Oxyuris
    6. Panorpa
    7. Panorpodes
    8. Podura
    9. Sericoderus
   10. Smittium


In [None]:

import pickle
import numpy as np
from gensim.models import Word2Vec as W2V
from sklearn.cluster import KMeans
from sklearn.preprocessing import normalize
from collections import defaultdict, deque
import json
import random
import re
from Bio import SeqIO

# load existing stuff
# Genus -> sequence maps pickled under these names by the extraction step.
with open("genus_to_lsu_sequences.pkl", "rb") as f:
    lsuGenus2Seq = pickle.load(f)
with open("genus_to_ssu_sequences.pkl", "rb") as f:
    ssuGenus2Seq = pickle.load(f)

# Word2Vec models saved under their final names after refinement.
lsuModel = W2V.load("lsu_phylo.model")
ssuModel = W2V.load("ssu_phylo.model")

# reuse kmer func
def makeKmerSentences(sequences, k=6):
    """Split every sequence into its overlapping k-length substrings.

    Returns one token list per input sequence, in order; sequences shorter
    than ``k`` produce an empty list.
    """
    sentences = []
    for seq in sequences:
        tokens = []
        for start in range(len(seq) - k + 1):
            piece = seq[start:start + k]
            if len(piece) == k:
                tokens.append(piece)
        sentences.append(tokens)
    return sentences

class SinglePhyloW2V:
    """Embedding helper around an already-trained Word2Vec model."""

    def __init__(self, model):
        """Wrap ``model``; the vector length is read from ``model.vector_size``."""
        self.model = model
        self.vecSize = model.vector_size

    def embedSeq(self, seq, k=6):
        """Mean vector of the sequence's in-vocabulary k-mers (zeros if none)."""
        vocab = self.model.wv
        vectors = []
        for start in range(len(seq) - k + 1):
            token = seq[start:start + k]
            if len(token) == k and token in vocab:
                vectors.append(vocab[token])

        if not vectors:
            return np.zeros(self.vecSize)
        return np.mean(vectors, axis=0)

    def embedGenus(self, genusSeqs, k=6, maxSamples=20):
        """L2-normalised mean embedding of up to ``maxSamples`` random sequences.

        All-zero sequence embeddings are dropped first; a zero vector is
        returned when none remain.
        """
        sampleSize = min(maxSamples, len(genusSeqs))
        picked = random.sample(genusSeqs, sampleSize)

        usable = []
        for seq in picked:
            vec = self.embedSeq(seq, k)
            if np.any(vec):
                usable.append(vec)

        if not usable:
            return np.zeros(self.vecSize)

        centroid = np.mean(usable, axis=0)
        return normalize([centroid], norm='l2')[0]

class TreeNode:
    """Binary tree node holding a list of genera and its L/R path string."""

    def __init__(self, genera, path=""):
        self.genera, self.path = genera, path
        self.left = self.right = None
        # Single-marker trees never set a score; the attribute exists so
        # these nodes look like the combined tree's nodes.
        self.score = 0.0

    def isLeaf(self):
        """Return True when both child slots are empty."""
        return all(child is None for child in (self.left, self.right))

class SingleBuilder:
    """Build a binary genus tree by repeated 2-means on one model's embeddings."""

    def __init__(self, w2v, genus2Seq, maxLeaf=6, minCluster=2):
        """Keep the embedder, the genus -> sequences map and the split limits."""
        self.w2v, self.genus2Seq = w2v, genus2Seq
        self.maxLeaf, self.minCluster = maxLeaf, minCluster

    def getEmbs(self, genera):
        """Return an array with one genus embedding per entry of ``genera``."""
        return np.array([self.w2v.embedGenus(self.genus2Seq[g]) for g in genera])

    def split(self, genera, tier=0):
        """Cluster ``genera`` in two; return ``(smaller, larger)`` or ``(None, None)``.

        No split is made when there are at most ``maxLeaf`` genera, fewer
        than ``2 * minCluster`` genera, or when either cluster would end up
        with fewer than ``minCluster`` members. ``tier`` is accepted but not
        used.
        """
        total = len(genera)
        if total <= self.maxLeaf or total < 2 * self.minCluster:
            return None, None

        vectors = self.getEmbs(genera)
        assignment = KMeans(n_clusters=2, n_init=10).fit_predict(vectors)

        first = [g for g, lab in zip(genera, assignment) if lab == 0]
        second = [g for g, lab in zip(genera, assignment) if lab == 1]

        if min(len(first), len(second)) < self.minCluster:
            return None, None

        # Smaller cluster always comes first.
        if len(first) > len(second):
            return second, first
        return first, second

    def buildTree(self, genera, maxDepth=10):
        """Split breadth-first starting from ``genera`` and return the root.

        ``maxDepth`` is a soft limit: at or beyond it, only nodes with more
        than ``2 * maxLeaf`` genera are still split.
        """
        print(f"\nBuilding tree for {len(genera)} genera...")

        root = TreeNode(genera, "")
        pending = deque([(root, 0)])

        while pending:
            node, tier = pending.popleft()
            size = len(node.genera)

            if size <= self.maxLeaf:
                continue
            if tier >= maxDepth and size <= self.maxLeaf * 2:
                continue

            left, right = self.split(node.genera, tier)
            if not (left and right):
                continue

            node.left = TreeNode(left, node.path + "L")
            node.right = TreeNode(right, node.path + "R")
            pending.extend([(node.left, tier + 1), (node.right, tier + 1)])

            print(f"T{tier}: {len(node.genera)} → {len(left)}L | {len(right)}R")

        return root

# build LSU tree
# Uses every genus that has LSU sequences, not only genera shared with SSU.
lsuGenera = list(lsuGenus2Seq.keys())
lsuW2v = SinglePhyloW2V(lsuModel)
lsuBuilder = SingleBuilder(lsuW2v, lsuGenus2Seq)
lsuTree = lsuBuilder.buildTree(lsuGenera, maxDepth=12)

# build SSU tree
# Same procedure with the SSU model and the SSU sequence map.
ssuGenera = list(ssuGenus2Seq.keys())
ssuW2v = SinglePhyloW2V(ssuModel)
ssuBuilder = SingleBuilder(ssuW2v, ssuGenus2Seq)
ssuTree = ssuBuilder.buildTree(ssuGenera, maxDepth=12)

# load the combined tree from previous
# NOTE(review): unpickling presumably needs a compatible TreeNode class to be
# importable under the name stored in the pickle; combinedTree is not used
# again in the code visible here.
with open("tree.pkl", "rb") as f:
    combinedTree = pickle.load(f)

# save single trees
def tree2dict(node):
    """Convert a single-marker tree into nested dicts.

    Leaves carry ``path``, ``genera`` and ``size``; internal nodes carry
    ``path``, ``size`` and the converted ``left`` / ``right`` children
    (``None`` for a missing child). No score is exported.
    """
    summary = {"path": node.path}

    if node.isLeaf():
        summary["genera"] = node.genera
        summary["size"] = len(node.genera)
        return summary

    summary["size"] = len(node.genera)
    for side, child in (("left", node.left), ("right", node.right)):
        summary[side] = tree2dict(child) if child else None
    return summary

# Write each single-marker tree as JSON (nested dicts) ...
with open("lsu_tree.json", "w") as f:
    json.dump(tree2dict(lsuTree), f, indent=2)
with open("ssu_tree.json", "w") as f:
    json.dump(tree2dict(ssuTree), f, indent=2)

# ... and as a pickle of the TreeNode objects.
with open("lsu_tree.pkl", "wb") as f:
    pickle.dump(lsuTree, f)
with open("ssu_tree.pkl", "wb") as f:
    pickle.dump(ssuTree, f)

print("\nDone! Trees saved.")

# NOTE(review): Navigator is neither defined nor imported in this cell, so it
# must already exist in the session, and its constructor must accept a second
# positional label argument for the two calls below to work.
print("\nStarting LSU Tree Navigator...")
lsuNav = Navigator(lsuTree, "LSU")
lsuNav.run()

# The SSU browser starts only after the LSU session ends (run() blocks).
print("\nStarting SSU Tree Navigator...")
ssuNav = Navigator(ssuTree, "SSU")
ssuNav.run()
