In [None]:
import phylustrator as ph
import phylustrator.alerax_parser as alx

# 1. Parse AleRax output
# -----------------------------
# Directory that AleRax wrote its results into.
path_to_alerax = "./alerax_output_dir"
data = alx.parse_alerax(path_to_alerax)

# Sanity-check what the parser returned before plotting anything.
n_raw_records = len(data.events_df)
print(f"Loaded {n_raw_records} raw records.")
print("Aggregated Stats (First 5 rows):")
stats_preview = data.aggregated_stats.head()
display(stats_preview)  # IPython/Jupyter rich display

# 2. Set up the tree
# -----------------------------
# Prefer the species tree returned by the AleRax parser; if it has none,
# fall back to a reference Newick file on disk.
reference_tree_path = "ReferenceTree.nwk"
if data.species_tree:
    t = data.species_tree
else:
    # Imported lazily: ete3 is only needed for this fallback path.
    import ete3

    t = ete3.Tree(reference_tree_path)

# Optional: call t.convert_to_ultrametric() here for cleaner plotting.

# 3. Set up the radial tree drawer
# -----------------------------
# mode='r' presumably selects the radial layout (matches RadialTreeDrawer);
# node_size=0 presumably suppresses node markers -- confirm in phylustrator docs.
style = ph.TreeStyle(
    mode='r',
    radius=500,
    node_size=0,
    branch_size=3,
)
drawer = ph.RadialTreeDrawer(t, style=style)

# 4. Map data to colours
# -----------------------------
# Colour every branch by the TOTAL number of transfers (threshold bins).

# Per-species lookup table keyed by "species_label":
# {species_label: {column_name: value, ...}}
stats_map = data.aggregated_stats.set_index("species_label").to_dict(orient="index")


def _transfer_color(node_stats):
    """Return the branch colour for one node's stats row (or None if absent).

    Nodes without stats are grey; otherwise the 'transfers' count (default 0)
    is binned: > 50 red, > 10 orange, else black. A NaN count fails both
    comparisons and therefore falls through to black.
    """
    if not node_stats:
        return "grey"
    total_transfers = node_stats.get('transfers', 0)
    if total_transfers > 50:
        return "#D62728"  # red: transfer hot spot
    if total_transfers > 10:
        return "orange"
    return "black"


# Lookup is by exact node name; names missing from stats_map map to grey.
branch2color = {n: _transfer_color(stats_map.get(n.name)) for n in t.traverse()}

# 5. Draw
# -----------------------------
drawer.draw(branch2color=branch2color)

# 6. Add a heatmap of duplications / transfers / losses
# -----------------------------
# One row per leaf that has aggregated stats:
#   {leaf_name: {"D": duplications, "L": losses, "T": transfers}}
# Missing counts default to 0, matching how branch colouring reads 'transfers'.
heatmap_data = {}
for leaf in t.get_leaves():
    leaf_stats = stats_map.get(leaf.name)
    if not leaf_stats:
        continue  # leaves absent from the aggregated stats are left out
    heatmap_data[leaf.name] = {
        "D": leaf_stats.get('duplications', 0),
        "L": leaf_stats.get('losses', 0),
        "T": leaf_stats.get('transfers', 0),
    }

def heat_cmap(val):
    """Map a heatmap cell value to a fill colour.

    Values strictly above 100 are ``"darkblue"``; values strictly above 20
    (up to and including 100) are ``"skyblue"``; everything else, including
    NaN (which fails every comparison), is light grey ``"#eee"``.
    """
    for threshold, colour in ((100, "darkblue"), (20, "skyblue")):
        if val > threshold:
            return colour
    return "#eee"

# Heatmap columns: Duplications, Transfers, Losses (keys of each heatmap_data row).
heatmap_columns = ["D", "T", "L"]
drawer.add_heatmap_matrix(heatmap_data, columns=heatmap_columns, cmap=heat_cmap)
display(drawer.d)  # render the finished drawing in the notebook