In [None]:
import pickle
import numpy as np
from gensim.models import Word2Vec
import random
from sklearn.preprocessing import normalize
from collections import defaultdict, deque
import json
from Bio import SeqIO
import re

# Section banner for the FASTA extraction stage.
print("=" * 80, "EXTRACTING SEQUENCES FROM BLAST DATABASES", "=" * 80, sep="\n")

# A bracketed organism name, e.g. "... [Genus species]": capture the genus.
_BRACKET_GENUS_RE = re.compile(r'\[([A-Z][a-z]+)\s+')
# Fallback: a capitalised, all-letter word of at least three characters.
_GENUS_WORD_RE = re.compile(r'^[A-Z][a-z]{2,}$')


def _genusFromHeader(header, recordId):
    """Best-effort genus name for one FASTA record.

    Tries, in order: the first capitalised word right after a ``[`` in the
    header, the first capitalised all-letter word (>= 3 chars) in the header,
    and finally the record id cut at the first ``_`` and then the first ``.``.
    May return an empty string if the id itself is empty.
    """
    bracketMatch = _BRACKET_GENUS_RE.search(header)
    if bracketMatch:
        return bracketMatch.group(1)
    for word in header.split():
        if _GENUS_WORD_RE.match(word):
            return word
    return recordId.split('_')[0].split('.')[0]


def getSeqs(fastaPath):
    """Read a FASTA file and group cleaned sequences by genus.

    Each sequence is upper-cased; records shorter than 100 characters are
    skipped, then every character other than A/T/C/G/U/N is dropped and the
    result is skipped again if it is shorter than 100 characters.

    Args:
        fastaPath: path to a FASTA file.

    Returns:
        A plain ``dict`` mapping genus name -> list of cleaned sequences.
    """
    seqsByGenus = defaultdict(list)

    with open(fastaPath, 'r') as f:
        for record in SeqIO.parse(f, "fasta"):
            seq = str(record.seq).upper()
            if len(seq) < 100:
                continue

            cleaned = ''.join(c for c in seq if c in 'ATCGUN')
            if len(cleaned) < 100:
                continue

            genus = _genusFromHeader(record.description, record.id)
            if genus:
                seqsByGenus[genus].append(cleaned)

    return dict(seqsByGenus)

lsuGenus2Seq = getSeqs("LSUs/LSU_eukaryote_rRNA.fasta")
print(f"Got {len(lsuGenus2Seq)} LSU genera")

ssuGenus2Seq = getSeqs("SSUs/SSU_eukaryote_rRNA.fasta")
print(f"Got {len(ssuGenus2Seq)} SSU genera")

# Sorted so the genus order fed to clustering is reproducible: iterating a set
# of strings yields a different order on each run under hash randomisation.
commonGenera = sorted(lsuGenus2Seq.keys() & ssuGenus2Seq.keys())
print(f"Common genera: {len(commonGenera)}")

if not commonGenera:
    print("No common genera found!")
    # SystemExit does not depend on the `exit` helper injected by `site`
    # (which, inside a Jupyter kernel, shuts the kernel down).
    raise SystemExit(1)

# save for later
with open("genus_to_lsu_sequences.pkl", "wb") as f:
    pickle.dump(lsuGenus2Seq, f)
with open("genus_to_ssu_sequences.pkl", "wb") as f:
    pickle.dump(ssuGenus2Seq, f)

def makeKmerSentences(sequences, k=6):
    """Turn each sequence into a "sentence" of overlapping k-mers.

    Returns one list per input sequence (same order) containing every
    length-``k`` window, stepping by one; a sequence shorter than ``k``
    yields an empty list.
    """
    def _windows(seq):
        # Only full-length windows are emitted.
        for start in range(len(seq) - k + 1):
            piece = seq[start:start + k]
            if len(piece) == k:
                yield piece

    return [list(_windows(seq)) for seq in sequences]

class PhyloW2V:
    """Word2Vec model over k-mer sentences, plus sequence/genus embedding helpers.

    ``sentences`` is kept so the model can be rebuilt later by ``retrain``.
    """

    def __init__(self, sentences, vecSize=128, window=5):
        self.sentences = sentences
        self.vecSize = vecSize
        self.window = window
        self.model = Word2Vec(
            sentences,
            vector_size=vecSize,
            window=window,
            min_count=1,
            workers=4,
        )

    def embedSeq(self, seq, k=6):
        """Mean vector of the in-vocabulary k-mers of ``seq``.

        Returns a zero vector of length ``vecSize`` when no k-mer is known.
        """
        vocab = self.model.wv
        vectors = []
        for start in range(len(seq) - k + 1):
            token = seq[start:start + k]
            if len(token) == k and token in vocab:
                vectors.append(vocab[token])
        if not vectors:
            return np.zeros(self.vecSize)
        return np.mean(vectors, axis=0)

    def embedGenus(self, genusSeqs, k=6, maxSamples=20):
        """L2-normalised mean embedding of up to ``maxSamples`` random sequences.

        All-zero sequence embeddings are ignored; returns a zero vector when
        nothing usable remains.
        """
        picked = random.sample(genusSeqs, min(maxSamples, len(genusSeqs)))
        nonzero = []
        for seq in picked:
            vec = self.embedSeq(seq, k)
            if np.any(vec):
                nonzero.append(vec)
        if not nonzero:
            return np.zeros(self.vecSize)
        centroid = np.mean(nonzero, axis=0)
        return normalize([centroid], norm='l2')[0]

    def retrain(self, kmerWeights):
        """
        Retrain with specific k-mer weights for THIS model.

        Every k-mer is repeated ``max(1, int(2 * weight))`` times (weight
        defaults to 1.0, i.e. two copies, for k-mers missing from
        ``kmerWeights``), and a fresh skip-gram model (sg=1, alpha=0.01) is
        trained from scratch, replacing ``self.model``.
        """
        weightedSents = []
        for sent in self.sentences:
            expanded = [
                kmer
                for kmer in sent
                for _ in range(max(1, int(kmerWeights.get(kmer, 1.0) * 2)))
            ]
            if expanded:
                weightedSents.append(expanded)

        self.model = Word2Vec(
            weightedSents,
            vector_size=self.vecSize,
            window=self.window,
            min_count=1,
            workers=4,
            alpha=0.01,
            sg=1,
        )

print("\nTraining Word2Vec models...")
# One k-mer sentence per sequence, pooled over every genus of each marker.
lsuSents = [sentence
            for seqList in lsuGenus2Seq.values()
            for sentence in makeKmerSentences(seqList)]
ssuSents = [sentence
            for seqList in ssuGenus2Seq.values()
            for sentence in makeKmerSentences(seqList)]

lsuW2v, ssuW2v = PhyloW2V(lsuSents), PhyloW2V(ssuSents)

class TreeNode:
    """Binary tree node over a list of genera.

    ``path`` is the string of 'L'/'R' moves from the root; ``score`` is the
    split score assigned when the node is split (0.0 otherwise).
    """

    def __init__(self, genera, path=""):
        self.genera = genera
        self.path = path
        self.left = self.right = None
        self.score = 0.0

    def isLeaf(self):
        """True when the node has neither child."""
        return all(child is None for child in (self.left, self.right))

class PhyloBuilder:
    """Build a binary genus tree from agreement between LSU and SSU embeddings.

    Each split clusters the genera once per marker, aligns and combines the
    two labelings into a consensus bisection, and accumulates per-k-mer
    feedback (``lsuKmerFeedback`` / ``ssuKmerFeedback``) that the training
    loop uses to re-weight Word2Vec retraining.

    Reads the module-level ``lsuGenus2Seq`` / ``ssuGenus2Seq`` dicts and,
    as the default for ``buildTree``, ``commonGenera``.
    """

    def __init__(self, lsuModel, ssuModel, maxLeaf=6, minCluster=2):
        self.lsu = lsuModel
        self.ssu = ssuModel
        self.maxLeaf = maxLeaf
        self.minCluster = minCluster  # stored, but not read by any method here
        self.lsuKmerFeedback = defaultdict(float)
        self.ssuKmerFeedback = defaultdict(float)

    def getEmbs(self, genera):
        """Return ``(lsuEmbs, ssuEmbs)`` arrays with one row per genus, in order."""
        lsuEmbs = []
        ssuEmbs = []
        for g in genera:
            lsuEmbs.append(self.lsu.embedGenus(lsuGenus2Seq[g]))
            ssuEmbs.append(self.ssu.embedGenus(ssuGenus2Seq[g]))
        return np.array(lsuEmbs), np.array(ssuEmbs)

    def calcAgreement(self, labels1, labels2):
        """Adjusted Rand index between two labelings (1.0 = identical partitions)."""
        from sklearn.metrics import adjusted_rand_score
        return adjusted_rand_score(labels1, labels2)

    @staticmethod
    def _alignLabels(refLabels, otherLabels):
        """Return 0/1 ``otherLabels``, swapped if that better matches ``refLabels``.

        Cluster ids from two independent clusterings are arbitrary. ARI cannot
        detect a swap (it is invariant to relabelling), so the decision uses
        the raw per-genus match rate instead.
        """
        refLabels = np.asarray(refLabels)
        otherLabels = np.asarray(otherLabels)
        if np.mean(refLabels == otherLabels) < 0.5:
            return 1 - otherLabels
        return otherLabels

    @staticmethod
    def _fallbackIfDegenerate(labels, reduced):
        """If ``labels`` has a single cluster, split at the 30th percentile of PC1."""
        if len(np.unique(labels)) >= 2:
            return labels
        firstPc = reduced[:, 0]
        return (firstPc > np.percentile(firstPc, 30)).astype(int)

    def split(self, genera, tier=0):
        """Bisect ``genera`` using LSU/SSU spectral clusterings.

        Returns ``(grp0, grp1, score)``, or ``(None, None, 0.0)`` when the node
        has at most ``maxLeaf`` genera (or fewer than 2) or the consensus puts
        every genus on one side. Side effect: accumulates k-mer feedback.
        """
        if len(genera) <= self.maxLeaf or len(genera) < 2:
            return None, None, 0.0

        import warnings
        from sklearn.cluster import SpectralClustering
        from sklearn.decomposition import PCA

        lsuEmbs, ssuEmbs = self.getEmbs(genera)

        nComp = min(64, len(genera))
        lsuReduced = PCA(n_components=nComp).fit_transform(lsuEmbs)
        ssuReduced = PCA(n_components=nComp).fit_transform(ssuEmbs)

        def _spectral():
            # Spectral clustering handles non-spherical clusters better than k-means.
            return SpectralClustering(
                n_clusters=2, affinity='rbf', gamma=1.0,
                assign_labels='discretize', random_state=42,
            )

        # Warnings are emitted while fitting, so the filter must cover fit_predict.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            lsuLabels = _spectral().fit_predict(lsuReduced)
            ssuLabels = _spectral().fit_predict(ssuReduced)

        lsuLabels = self._fallbackIfDegenerate(lsuLabels, lsuReduced)
        ssuLabels = self._fallbackIfDegenerate(ssuLabels, ssuReduced)

        # Align cluster ids before the numeric consensus below; ARI is unaffected.
        ssuLabels = self._alignLabels(lsuLabels, ssuLabels)
        agreement = self.calcAgreement(lsuLabels, ssuLabels)

        # Strong agreement: equal weights (a 1 from either marker wins ties).
        # Otherwise 0.65 for LSU, which means LSU alone decides each genus.
        consensusWeight = 0.5 if agreement > 0.8 else 0.65
        softConsensus = lsuLabels * consensusWeight + ssuLabels * (1 - consensusWeight)
        consensus = (softConsensus >= 0.5).astype(int)

        grp0 = [g for g, c in zip(genera, consensus) if c == 0]
        grp1 = [g for g, c in zip(genera, consensus) if c == 1]
        if not grp0 or not grp1:
            return None, None, 0.0

        sepScore = self._calcSeparation(lsuReduced, consensus)
        finalScore = agreement * (0.7 + 0.3 * sepScore)

        self._updateKmers(genera, grp0, grp1, finalScore, tier)

        return grp0, grp1, finalScore

    def _calcSeparation(self, embeddings, labels):
        """Silhouette score rescaled to 0-1; 0.5 when it cannot be computed."""
        from sklearn.metrics import silhouette_score
        try:
            return (silhouette_score(embeddings, labels) + 1) / 2
        except ValueError:
            # Raised for invalid label counts (e.g. one cluster); treat as neutral.
            return 0.5

    @staticmethod
    def _kmerStats(seqs, k=6):
        """Count every k-mer in ``seqs`` and score its positional conservation.

        Returns ``(counts, scores)``: ``counts`` maps k-mer -> occurrences;
        ``scores`` maps k-mer -> ``exp(-5 * var(relative positions)) *
        sqrt(frequency)``, or 0.1 for k-mers seen once. Relative position
        (0-1) makes sequences of different lengths comparable.
        """
        import math

        kmerCounts = defaultdict(int)
        kmerPositions = defaultdict(list)
        totalKmers = 0

        for seq in seqs:
            denom = max(1, len(seq) - k)
            for i in range(len(seq) - k + 1):
                kmer = seq[i:i + k]
                kmerCounts[kmer] += 1
                kmerPositions[kmer].append(i / denom)
                totalKmers += 1

        kmerScores = {}
        for kmer, positions in kmerPositions.items():
            if len(positions) > 1:
                # Low position variance = highly conserved; sharp exponential decay.
                conservation = math.exp(-np.var(positions) * 5)
                frequency = kmerCounts[kmer] / max(1, totalKmers)
                kmerScores[kmer] = conservation * math.sqrt(frequency)
            else:
                kmerScores[kmer] = 0.1

        return kmerCounts, kmerScores

    def _accumulateFeedback(self, feedback, leftSeqs, rightSeqs, score, tier):
        """Add a discrimination + conservation weight for every k-mer to ``feedback``.

        weight = (entropy * |pLeft - pRight| + mean conservation)
                 * score * exp(-0.1 * tier)
        K-mers seen fewer than twice across both sides are skipped.
        """
        import math

        leftCounts, leftScores = self._kmerStats(leftSeqs)
        rightCounts, rightScores = self._kmerStats(rightSeqs)
        tierMultiplier = math.exp(-tier * 0.1)  # earlier (shallower) splits count more

        for kmer in set(leftCounts) | set(rightCounts):
            leftFreq = leftCounts.get(kmer, 0)
            rightFreq = rightCounts.get(kmer, 0)
            total = leftFreq + rightFreq
            if total < 2:  # skip rare k-mers
                continue

            leftProb = leftFreq / total
            rightProb = rightFreq / total
            if leftProb > 0 and rightProb > 0:
                entropy = -(leftProb * math.log(leftProb) + rightProb * math.log(rightProb))
            else:
                entropy = 0

            discrimination = abs(leftProb - rightProb)
            conservationBonus = (leftScores.get(kmer, 0) + rightScores.get(kmer, 0)) / 2

            feedback[kmer] += (entropy * discrimination + conservationBonus) * score * tierMultiplier

    def _updateKmers(self, allGenera, leftGrp, rightGrp, score, tier):
        """Accumulate LSU and SSU k-mer feedback for one accepted split.

        Samples up to 50 genera per side and up to 20 sequences per genus.
        ``allGenera`` is accepted for interface compatibility but unused.
        """
        markers = (
            (lsuGenus2Seq, self.lsuKmerFeedback),
            (ssuGenus2Seq, self.ssuKmerFeedback),
        )
        for genus2Seq, feedback in markers:
            leftSeqs = [s for g in leftGrp[:50] for s in genus2Seq[g][:20]]
            rightSeqs = [s for g in rightGrp[:50] for s in genus2Seq[g][:20]]
            self._accumulateFeedback(feedback, leftSeqs, rightSeqs, score, tier)

    def buildTree(self, genera=None, maxDepth=8):
        """Grow the tree breadth-first from ``genera`` (default ``commonGenera``).

        Returns ``(root, avgScore)``, where ``avgScore`` is the mean score of
        all accepted splits (0.0 if there were none).
        """
        if genera is None:
            genera = commonGenera

        print(f"\nBuilding tree for {len(genera)} genera...")

        root = TreeNode(genera, "")
        queue = deque([(root, 0)])

        totalScore = 0.0
        nSplits = 0

        while queue:
            node, tier = queue.popleft()

            # NOTE(review): split() already rejects nodes with <= maxLeaf genera,
            # so this only skips nodes split() would refuse anyway; maxDepth does
            # not cap depth for nodes that are still larger than maxLeaf.
            if tier >= maxDepth and len(node.genera) <= self.maxLeaf:
                continue

            left, right, score = self.split(node.genera, tier)
            if left and right:
                node.left = TreeNode(left, node.path + "L")
                node.right = TreeNode(right, node.path + "R")
                node.score = score

                totalScore += score
                nSplits += 1

                queue.append((node.left, tier + 1))
                queue.append((node.right, tier + 1))

        avgScore = totalScore / max(nSplits, 1)
        print(f"Avg agreement: {avgScore:.3f}")
        return root, avgScore

print("\n" + "="*80)
print("LSU/SSU ADVERSARIAL REFINEMENT")
print("="*80)


def _normalizeFeedback(feedback, low=0.5, high=2.0):
    """Map raw k-mer feedback onto retraining weights in ``[low, high]``.

    Values below the 20th percentile get ``low`` (downweight uninformative
    k-mers), values above the 80th get ``high`` (boost discriminative ones),
    and values in between are interpolated linearly. Percentiles make this
    robust to outliers. The interpolation divides by the true percentile span
    (guarding only against zero); clamping the divisor to >= 1 made the
    mapping jump discontinuously at p80 whenever the span was below 1.
    """
    weights = np.array(list(feedback.values()))
    p20, p80 = np.percentile(weights, [20, 80])
    span = p80 - p20

    normalized = {}
    for kmer, v in feedback.items():
        if v < p20:
            normalized[kmer] = low
        elif v > p80:
            normalized[kmer] = high
        elif span > 0:
            normalized[kmer] = low + (high - low) * (v - p20) / span
        else:
            normalized[kmer] = low
    return normalized


builder = PhyloBuilder(lsuW2v, ssuW2v, maxLeaf=6)
# Start at -inf (not 0.0) so the first tree and its models are always
# checkpointed. Otherwise a run whose scores never exceed MIN_IMPROVE (ARI can
# even be negative) never writes best_*.model / best_tree.pkl, and reloading
# them after the loop fails.
bestScore = float("-inf")
bestTree = None

MAX_ITER = 30
PATIENCE = 5
MIN_IMPROVE = 0.001

scores = []
patience = 0

for it in range(MAX_ITER):
    print(f"\n--- Iteration {it+1} ---")

    tree, score = builder.buildTree(maxDepth=8)
    scores.append(score)

    improve = score - bestScore

    if improve > MIN_IMPROVE:
        bestScore = score
        lsuW2v.model.save("best_lsu.model")
        ssuW2v.model.save("best_ssu.model")
        bestTree = tree
        patience = 0
        print(f"New best: {bestScore:.3f}")

        with open("best_tree.pkl", "wb") as f:
            pickle.dump(bestTree, f)
    else:
        patience += 1
        print(f"No improvement, patience: {patience}/{PATIENCE}")

    if patience >= PATIENCE:
        print(f"\nConverged after {it+1} iterations")
        break

    # Stop when the last three scores are essentially flat.
    if len(scores) >= 3 and np.std(scores[-3:]) < MIN_IMPROVE / 2:
        print(f"\nPlateaued after {it+1} iterations")
        break

    # LSU and SSU are normalised and retrained independently.
    if builder.lsuKmerFeedback:
        normLsuFb = _normalizeFeedback(builder.lsuKmerFeedback)
        print(f"Retraining LSU with {len(normLsuFb)} weighted k-mers...")
        lsuW2v.retrain(normLsuFb)

    if builder.ssuKmerFeedback:
        normSsuFb = _normalizeFeedback(builder.ssuKmerFeedback)
        print(f"Retraining SSU with {len(normSsuFb)} weighted k-mers...")
        ssuW2v.retrain(normSsuFb)

    # Fresh builder so the next iteration's feedback starts from zero.
    builder = PhyloBuilder(lsuW2v, ssuW2v, maxLeaf=6)

print(f"\nFinal score: {bestScore:.3f}")

# Reload the best-scoring checkpoints and re-save them under their final names.
from gensim.models import Word2Vec as W2V
finalLsu = W2V.load("best_lsu.model")
finalSsu = W2V.load("best_ssu.model")
for _model, _target in ((finalLsu, "lsu_phylo.model"), (finalSsu, "ssu_phylo.model")):
    _model.save(_target)

# Load the best tree from its pickle checkpoint.
with open("best_tree.pkl", mode="rb") as treeFile:
    finalTree = pickle.load(treeFile)

def tree2dict(node):
    """Convert a tree rooted at ``node`` into nested, JSON-serialisable dicts.

    Leaves become ``{"path", "genera", "size"}``; internal nodes become
    ``{"path", "size", "score", "left", "right"}``, where a missing child
    is ``None``. Key order is fixed, so dumped JSON is stable.
    """
    if not node.isLeaf():
        children = {
            side: (tree2dict(child) if child else None)
            for side, child in (("left", node.left), ("right", node.right))
        }
        return {"path": node.path, "size": len(node.genera), "score": node.score, **children}
    return {"path": node.path, "genera": node.genera, "size": len(node.genera)}

# Persist the final tree as indented JSON and as a pickle.
treeDict = tree2dict(finalTree)

with open("tree.json", mode="w") as f:
    f.write(json.dumps(treeDict, indent=2))

with open("tree.pkl", mode="wb") as f:
    f.write(pickle.dumps(finalTree))

class Navigator:
    """Interactive console browser for a tree of nodes.

    Tracks the current node and the list of 'L'/'R' moves taken from the root.
    """

    def __init__(self, root):
        self.root = root
        self.current = root
        self.history = []

    def go(self, dir):  # `dir` shadows the builtin; kept for caller compatibility
        """Move to the left ('L') or right ('R') child (case-insensitive).

        Returns True on success; False (position unchanged) when the direction
        is not L/R or that child does not exist.
        """
        dir = dir.upper()
        if dir == 'L' and self.current.left:
            self.current = self.current.left
            self.history.append('L')
            return True
        if dir == 'R' and self.current.right:
            self.current = self.current.right
            self.history.append('R')
            return True
        return False

    def show(self, showGenera=False):
        """Print the current position, size, score and children.

        With ``showGenera``, also list up to 20 genera in sorted order.
        """
        path = ''.join(self.history) if self.history else "ROOT"
        print(f"\n{'='*60}")
        print(f"Position: [{path}]")
        print(f"Genera: {len(self.current.genera)}")

        if self.current.score > 0:
            print(f"Score: {self.current.score:.3f}")

        if not self.current.isLeaf():
            if self.current.left:
                print(f"  Left (L): {len(self.current.left.genera)} genera")
            if self.current.right:
                print(f"  Right (R): {len(self.current.right.genera)} genera")
        else:
            print("  [LEAF]")

        if showGenera:
            print("\nGenera:")
            glist = sorted(self.current.genera)[:20]
            for i, g in enumerate(glist, 1):
                print(f"  {i:3d}. {g}")
            if len(self.current.genera) > 20:
                print(f"  ... +{len(self.current.genera) - 20} more")
        print(f"{'='*60}")

    def reset(self):
        """Return to the root and clear the move history."""
        self.current = self.root
        self.history = []

    def run(self):
        """Read-eval loop over stdin: ``goto L/R``, ``show``, ``new``, ``quit``.

        Returns on quit/exit/q, Ctrl-C, or end of input.
        """
        print("\n" + "="*80)
        print("TREE NAVIGATOR")
        print("="*80)
        print("\nCommands: goto L/R, show, new, quit")

        self.show()

        while True:
            try:
                cmd = input("\n> ").strip().lower()

                if cmd.startswith("goto"):
                    parts = cmd.split()
                    if len(parts) == 2 and parts[1].upper() in ['L', 'R']:
                        if self.go(parts[1]):
                            self.show()
                        else:
                            print("Can't go there")
                    else:
                        print("Use: goto L or goto R")

                elif cmd == "show":
                    self.show(showGenera=True)

                elif cmd == "new":
                    self.reset()
                    print("\nBack to root")
                    self.show()

                elif cmd in ["quit", "exit", "q"]:
                    break

                else:
                    print("Unknown command")

            except (KeyboardInterrupt, EOFError):
                # Ctrl-C, Ctrl-D or exhausted stdin. EOFError used to fall into
                # the generic handler below and loop forever.
                break
            except Exception as e:
                print(f"Error: {e}")

# Closing summary, then hand control to the interactive tree browser.
_banner = "=" * 80
print(
    "\n" + _banner,
    f"DONE! Converged after {len(scores)} iterations uWu",
    f"Best score: {bestScore:.3f}",
    _banner,
    sep="\n",
)

nav = Navigator(finalTree)
nav.run()

EXTRACTING SEQUENCES FROM BLAST DATABASES
Got 4640 LSU genera
Got 5637 SSU genera
Common genera: 1598

Training Word2Vec models...

LSU/SSU ADVERSARIAL REFINEMENT

--- Iteration 1 ---

Building tree for 1598 genera...




Avg agreement: 0.238
New best: 0.238
Retraining LSU with 269 weighted k-mers...
Retraining SSU with 118 weighted k-mers...

--- Iteration 2 ---

Building tree for 1598 genera...




Avg agreement: 0.207
No improvement, patience: 1/5
Retraining LSU with 258 weighted k-mers...
Retraining SSU with 116 weighted k-mers...

--- Iteration 3 ---

Building tree for 1598 genera...




Avg agreement: 0.236
No improvement, patience: 2/5
Retraining LSU with 258 weighted k-mers...
Retraining SSU with 110 weighted k-mers...

--- Iteration 4 ---

Building tree for 1598 genera...




Avg agreement: 0.227
No improvement, patience: 3/5
Retraining LSU with 260 weighted k-mers...
Retraining SSU with 111 weighted k-mers...

--- Iteration 5 ---

Building tree for 1598 genera...




Avg agreement: 0.222
No improvement, patience: 4/5
Retraining LSU with 260 weighted k-mers...
Retraining SSU with 111 weighted k-mers...

--- Iteration 6 ---

Building tree for 1598 genera...




Avg agreement: 0.233
No improvement, patience: 5/5

Converged after 6 iterations

Final score: 0.238

DONE! Converged after 6 iterations
Best score: 0.238

TREE NAVIGATOR

Commands: goto L/R, show, new, quit

Position: [ROOT]
Genera: 1598
Score: 0.599
  Left (L): 733 genera
  Right (R): 865 genera

Position: [R]
Genera: 865
Score: 0.748
  Left (L): 188 genera
  Right (R): 677 genera

Position: [RR]
Genera: 677
Score: 0.537
  Left (L): 238 genera
  Right (R): 439 genera

Position: [RRR]
Genera: 439
Score: 0.461
  Left (L): 163 genera
  Right (R): 276 genera

Position: [RRRR]
Genera: 276
Score: 0.544
  Left (L): 101 genera
  Right (R): 175 genera

Position: [RRRRL]
Genera: 101
Score: 0.718
  Left (L): 44 genera
  Right (R): 57 genera

Position: [RRRRLL]
Genera: 44
Score: 0.743
  Left (L): 13 genera
  Right (R): 31 genera

Position: [RRRRLLR]
Genera: 31
Score: 0.780
  Left (L): 3 genera
  Right (R): 28 genera

Position: [RRRRLLR]
Genera: 31
Score: 0.780
  Left (L): 3 genera
  Right (R): 2