In [22]:
import os
import random
import h5py
import pickle
import torch
import networkx as nx
from torch_geometric.data import Data
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, Sequential, ReLU, Dropout, BatchNorm1d
from typing import List, Tuple, Dict, Any

def one_hot_encode(sequences, gene_present, gene_length, alphabet=['A','C','T','G','-']):
    """
    sequences: List of strings (DNA sequences)
    gene_present: np.array(bool) oder Torch Tensor, gleiche Länge wie sequences
    gene_length: int, fixe Länge für das Hot-Encoding
    alphabet: list, Zeichenalphabet
    """
    num_samples = len(sequences)
    num_chars = len(alphabet)
    char_to_idx = {c:i for i,c in enumerate(alphabet)}
    sequences_str = [s.decode('utf-8') for s in sequences]
    gene_present = np.array(gene_present, dtype=bool)
    
    # 1️⃣ Leere Batch-Matrix vorbereiten: (num_samples, gene_length, num_chars)
    batch = np.zeros((num_samples, gene_length, num_chars), dtype=np.float32)
    
    # 2️⃣ Hot-Encode alle Sequenzen
    for i, seq in enumerate(sequences_str):
        if gene_present[i]:
            L = min(len(seq), gene_length)  # abschneiden
            for j, c in enumerate(seq[:L]):
                if c in char_to_idx:
                    batch[i, j, char_to_idx[c]] = 1.0
        elif gene_present[i] == 0:
            batch[i, :, :] = -1.0
            
    # 3️⃣ Zufällige, aber konsistente Spaltenpermutation
    perm = np.random.permutation(gene_length)
    batch = batch[:, perm, :]
    
    # 4️⃣ Optional: Flatten zu Vektor (num_samples, gene_length*num_chars)
    batch_flat = batch.reshape(num_samples, -1)
    
    return torch.tensor(batch_flat)  # shape: (num_samples, gene_length*num_chars)

def load_file(file):

    gene_length = 300
    #nucleotide_mutation_rate = 0.1
    
    with h5py.File(file, "r") as f:
            grp = f["results"]
            # Load graph_properties (pickle stored in dataset)
            graph_properties = pickle.loads(grp["graph_properties"][()])
    
            # Unpack graph properties
            nodes = torch.tensor(graph_properties[0])                # [num_nodes]
            edges = torch.tensor(graph_properties[1], dtype=torch.long)  # [2, num_edges]
            coords = torch.tensor(graph_properties[2].T)             # [2, num_nodes]
    
            # Load datasets instead of attrs
            gene_absence_presence_matrix = grp["gene_absence_presence_matrix"][()]
            nucleotide_sequences = grp["nucleotide_sequences"][()]
            #children_gene_nodes_loss_events = grp["children_gene_nodes_loss_events"][()]
        
            # Load HGT events (simplified)
            hgt_events = {}
            hgt_grp_simpl = grp.get("nodes_hgt_events_simplified", None)
            if hgt_grp_simpl is not None:
                for site_id in hgt_grp_simpl.keys():
                    hgt_events[int(site_id)] = hgt_grp_simpl[site_id][()]
            else:
                hgt_events = {}

            hot_encoded_nucleotide_sequences = one_hot_encode(nucleotide_sequences, gene_absence_presence_matrix, gene_length)

            # Fill the remaining nodes with zeros.
            pad_rows = len(nodes) - len(nucleotide_sequences)
            pad = torch.zeros((pad_rows, hot_encoded_nucleotide_sequences.shape[1]), dtype=hot_encoded_nucleotide_sequences.dtype)
            hot_encoded_nucleotide_sequences = torch.cat([hot_encoded_nucleotide_sequences, pad], dim=0)
        
            G = nx.DiGraph()
            
            # Füge Knoten hinzu (optional mit Koordinaten als Attribut)
            node_id_list = nodes.tolist()
            for i, node_id in enumerate(node_id_list):
                # coords[:, i] exists; using index 5 as in your code for node_time
                G.add_node(node_id, node_time = coords[:, i].tolist()[5])
            
            # Füge Kanten hinzu
            edge_list = edges.tolist()
            for src, dst in zip(edge_list[0], edge_list[1]):
                G.add_edge(src, dst)
            
            # Collect all recipient_parent_nodes from all sites (if present)
            recipient_parent_nodes = set()
            if hgt_grp_simpl is not None:
                for site_id in hgt_grp_simpl.keys():
                    arr = hgt_grp_simpl[site_id][()]  # load dataset as numpy structured array
                    # try robust field name for recipient child
                    if 'recipient_child_node' in arr.dtype.names:
                        recipient_parent_nodes.update(arr["recipient_child_node"].tolist())
                    elif 'recipient_child' in arr.dtype.names:
                        recipient_parent_nodes.update(arr["recipient_child"].tolist())

            # Build theta_gains: 1 if node is in recipient_parent_nodes, else 0
            theta_gains = torch.tensor(
                [1 if node in recipient_parent_nodes else 0 for node in range(len(G.nodes))],
                dtype=torch.long
            )

            level = {n: 0 for n in G.nodes}  # Leaves haben Level 0
            
            # 3. Topologische Sortierung (damit Kinder vor Eltern behandelt werden)
            for node in reversed(list(nx.topological_sort(G))):
                successors = list(G.successors(node))
                if successors:
                    level[node] = 1 + max(level[s] for s in successors)
            
            # 4. Level als Attribut setzen
            nx.set_node_attributes(G, level, "level")

            ### Add candidate egdes, i.e. potential hgt edges:
        
            # Hole alle Knoten und ihre Zeiten
            node_times = {n: G.nodes[n]['node_time'] for n in G.nodes}
            sorted_nodes = sorted(node_times.keys(), key=lambda n: node_times[n])

            # Sort the level and node_times from 0 to max_node_id and not in G.nodes order:
            node_times = torch.tensor([node_times[n] for n in sorted_nodes], dtype=torch.float)
            node_levels = torch.tensor([level[n] for n in sorted_nodes], dtype=torch.float)

            # Füge Kanten hinzu
            existing_edges = set(zip(edges[0].tolist(), edges[1].tolist()))
            candidate_edges = []
            for i, src in enumerate(sorted_nodes):
                t_src = node_times[src]
                for dst in sorted_nodes[i+1:]:  # nur spätere Knoten
                    t_dst = node_times[dst]
                    if t_dst > t_src:
                        if (dst, src) not in existing_edges:  # vermeidet doppelte
                            #G.add_edge(src, dst)
                            candidate_edges.append((dst, src))
            if candidate_edges:  
                candidate_edges = torch.tensor(candidate_edges, dtype=torch.long).t()
            else:
                candidate_edges = torch.empty((2,0), dtype=torch.long)
                    
            data = Data(
                #nucleotide_sequences = nucleotide_sequences,
                hot_encoded_nucleotide_sequences = hot_encoded_nucleotide_sequences,       # Node Features [num_nodes, 2]
                edge_index = edges[[1, 0], :],        # Edge Index [2, num_edges]
                candidate_edges = candidate_edges[[1, 0], :],
                y = theta_gains,            # Labels [num_nodes]
                file = file,
                G = G,
                #recipient_parent_nodes = recipient_parent_nodes,
                gene_absence_presence_matrix = gene_absence_presence_matrix,
                node_times = node_times,
                node_levels = node_levels,
                hgt_events = hgt_events,
            )

            # -----------------------------
            # Compute label maps directly here (no external call)
            # -----------------------------
            recip_label_map = {}
            donor_map = {}
            # initialize maps for internal nodes with exactly 2 children
            for n in G.nodes:
                children = list(G.successors(n))
                if len(children) == 2:
                    recip_label_map[n] = (0, None)
                    donor_map[n] = None

            # parse hgt_events and fill maps
            if hgt_events:
                for site_id, arr in hgt_events.items():
                    # arr is expected to be a numpy structured array
                    try:
                        ln = len(arr)
                    except Exception:
                        continue
                    for i in range(ln):
                        # robust field extraction
                        rec_parent = None
                        rec_child = None
                        donor_parent = None
                        if 'recipient_parent_node' in arr.dtype.names:
                            rec_parent = int(arr[i]['recipient_parent_node'])
                        elif 'recipient_parent' in arr.dtype.names:
                            rec_parent = int(arr[i]['recipient_parent'])
                        if 'recipient_child_node' in arr.dtype.names:
                            rec_child = int(arr[i]['recipient_child_node'])
                        elif 'recipient_child' in arr.dtype.names:
                            rec_child = int(arr[i]['recipient_child'])
                        if 'donor_parent_node' in arr.dtype.names:
                            donor_parent = int(arr[i]['donor_parent_node'])
                        elif 'donor_parent' in arr.dtype.names:
                            donor_parent = int(arr[i]['donor_parent'])

                        if rec_parent is None:
                            continue
                        if rec_parent in recip_label_map:
                            recip_label_map[rec_parent] = (1, rec_child)
                            donor_map[rec_parent] = donor_parent

            # attach precomputed label maps to data
            data.recip_label_map = recip_label_map
            data.donor_map = donor_map

    return data

folder = "/mnt/c/Users/uhewm/Desktop/ProjectHGT/simulation_chunks"

files = [os.path.join(folder, f) for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]

if len(files) > 100:
    files = random.sample(files, 100)

list_of_Data = []
for f in files:
    try:
        d = load_file(f)
        list_of_Data.append(d)
    except Exception as e:
        print(f"Fehler beim Laden von {f}: {e}")

# example = load_file(random.choice(files))

print(f"{len(list_of_Data)} Dateien erfolgreich geladen.")

  edges = torch.tensor(graph_properties[1], dtype=torch.long)  # [2, num_edges]
  coords = torch.tensor(graph_properties[2].T)             # [2, num_nodes]


Fehler beim Laden von /mnt/c/Users/uhewm/Desktop/ProjectHGT/simulation_chunks/simulation_186274.h5: cannot reshape array of size 0 into shape (0,newaxis)
99 Dateien erfolgreich geladen.


In [24]:
import os
import random
import math
import time
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_
import networkx as nx
import copy

# -------------------------
# Model components
# -------------------------

def make_mlp(input_dim, hidden_dims=[256,128], output_dim=None, dropout=0.1):
    """
    Build an MLP that uses LayerNorm instead of BatchNorm.
    LayerNorm is stable for batch size 1.
    """
    layers = []
    dims = [input_dim] + hidden_dims
    for i in range(len(hidden_dims)):
        layers.append(nn.Linear(dims[i], dims[i+1]))
        # use LayerNorm over feature dimension
        layers.append(nn.LayerNorm(dims[i+1]))
        layers.append(nn.ReLU())
        layers.append(nn.Dropout(dropout))
    if output_dim is not None:
        layers.append(nn.Linear(dims[-1], output_dim))
    return nn.Sequential(*layers)


class RecipientFinder(nn.Module):
    """
    Returns:
      - hgt_logit: raw logit (no sigmoid) shape [B]
      - which_logits: raw logits for left/right shape [B,2]
    """
    def __init__(self, seq_vec_dim, aux_dim=3, hidden=[512,256], dropout=0.1):
        super().__init__()
        self.seq_vec_dim = seq_vec_dim
        self.aux_dim = aux_dim
        mlp_input = seq_vec_dim * 4 + aux_dim
        self.mlp = make_mlp(mlp_input, hidden_dims=hidden, output_dim=None, dropout=dropout)
        self.hgt_head = nn.Linear(hidden[-1], 1)    # logit
        self.which_head = nn.Linear(hidden[-1], 2)  # logits

    def forward(self, left_vec, right_vec, aux_feats):
        # left_vec/right_vec: [B, D]
        absdiff = torch.abs(left_vec - right_vec)
        prod = left_vec * right_vec
        x = torch.cat([left_vec, right_vec, absdiff, prod, aux_feats], dim=1)
        h = self.mlp(x)
        hgt_logit = self.hgt_head(h).squeeze(-1)         # [B]
        which_logits = self.which_head(h)                # [B,2]
        return hgt_logit, which_logits


class DonorFinder(nn.Module):
    """
    Scores candidate donors relative to a recipient vector.
    Returns raw scores (logits) of shape [M]
    """
    def __init__(self, seq_vec_dim, aux_dim=3, hidden=[512,256], dropout=0.1):
        super().__init__()
        mlp_input = seq_vec_dim * 4 + aux_dim
        self.score_mlp = make_mlp(mlp_input, hidden_dims=hidden, output_dim=1, dropout=dropout)

    def forward(self, recipient_vec, donor_vecs, aux_feats):
        # recipient_vec: [D] or [1,D]
        if recipient_vec.dim() == 1:
            r = recipient_vec.unsqueeze(0).expand(donor_vecs.shape[0], -1)
        else:
            r = recipient_vec.expand(donor_vecs.shape[0], -1)
        absdiff = torch.abs(r - donor_vecs)
        prod = r * donor_vecs
        x = torch.cat([r, donor_vecs, absdiff, prod, aux_feats], dim=1)
        scores = self.score_mlp(x).squeeze(-1)  # [M]
        return scores


class HGTDetector(nn.Module):
    """
    Orchestrator that contains RecipientFinder and DonorFinder and provides utility compute_aggregates.
    """
    def __init__(self, seq_vec_dim, aux_recipient_dim=3, aux_donor_dim=3,
                 rec_hidden=[16], donor_hidden=[16], hgt_threshold=0.5, topk_donors=1):
        super().__init__()
        self.recipient_finder = RecipientFinder(seq_vec_dim, aux_dim=aux_recipient_dim, hidden=rec_hidden)
        self.donor_finder = DonorFinder(seq_vec_dim, aux_dim=aux_donor_dim, hidden=donor_hidden)
        self.hgt_threshold = hgt_threshold
        self.topk_donors = topk_donors

    @staticmethod
    def compute_aggregates(G: nx.DiGraph, node_seq_matrix: torch.Tensor):
        """
        Compute aggregates such that agg[node_id] is valid for arbitrary integer node ids
        (i.e., agg's first dimension has size max_node_id+1). Aggregation is elementwise MAX
        over children (not sum), as requested.

        Returns:
            agg: Tensor [max_node_id+1, D] where rows for node ids not present are zeros.
        """
        # node ids may not be contiguous or start at 0. We'll allocate by max node id.
        nodes = list(G.nodes)
        if len(nodes) == 0:
            return node_seq_matrix.clone()

        max_node_id = int(max(nodes))
        D = node_seq_matrix.shape[1]
        device = node_seq_matrix.device
        dtype = node_seq_matrix.dtype

        # allocate agg with rows 0..max_node_id
        agg = torch.zeros((max_node_id + 1, D), dtype=dtype, device=device)

        # copy available rows from node_seq_matrix into agg by node id when possible.
        n_rows = node_seq_matrix.shape[0]
        for n in nodes:
            if isinstance(n, int) and n < n_rows:
                agg[n] = node_seq_matrix[n].to(device)
            else:
                agg[n] = agg[n]  # noop

        # bottom-up aggregation: children before parent
        topo = list(nx.topological_sort(G))
        for node in reversed(topo):
            children = list(G.successors(node))
            if len(children) == 0:
                continue
            child_vecs = []
            for c in children:
                if not isinstance(c, int) or c > max_node_id:
                    continue
                child_vecs.append(agg[c].unsqueeze(0))
            if len(child_vecs) == 0:
                continue
            child_stack = torch.cat(child_vecs, dim=0)  # [C, D]
            agg_val, _ = torch.max(child_stack, dim=0)
            agg[node] = agg_val

        return agg

# -------------------------
# Utilities for training
# -------------------------
def tree_distance(G: nx.DiGraph, node_a, node_b, max_penalty=None):
    """
    Shortest path length in undirected tree, fallback to max_penalty or len(G.nodes)
    """
    und = G.to_undirected()
    try:
        return nx.shortest_path_length(und, source=node_a, target=node_b)
    except Exception:
        return max_penalty if max_penalty is not None else len(G.nodes)


# -------------------------
# Training loop (fully adjusted)
# -------------------------
def train_hgt_detector(model: HGTDetector, list_of_Data, epochs=10, val_frac=0.1, lr=1e-3,
                       lambda_donor=0.1, hgt_threshold=0.5, max_candidates=300,
                       device=None, clip_grad=5.0):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

    N = len(list_of_Data)
    idxs = list(range(N))
    random.shuffle(idxs)
    n_val = max(1, int(val_frac * N))
    val_idx = set(idxs[:n_val])
    train_idx = idxs[n_val:]

    # Estimate pos_weight for BCEWithLogitsLoss on training set (clamped)
    pos_count = 0
    neg_count = 0
    for i in train_idx:
        d = list_of_Data[i]
        recip_map = d.recip_label_map
        _ = d.donor_map
        for p, (lbl, _) in recip_map.items():
            if lbl == 1:
                pos_count += 1
            else:
                neg_count += 1
    pos_count = max(pos_count, 1)
    neg_count = max(neg_count, 1)
    ratio = neg_count / pos_count
    pos_weight_val = max(1.0, min(ratio, 50.0))  # clamp between 1 and 50
    pos_weight = torch.tensor([pos_weight_val], device=device)
    bce_loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    print(f"Training on {len(train_idx)} graphs | Validation on {len(val_idx)} graphs | pos_weight={pos_weight_val:.3f}")

    for epoch in range(1, epochs+1):
        t0 = time.time()
        model.train()
        train_loss = 0.0
        train_rec_loss = 0.0
        train_don_loss = 0.0
        train_samples = 0

        # metrics
        tp = 0; fp = 0; fn = 0
        donor_expected_distance_accum = 0.0
        donor_expected_count = 0

        random.shuffle(train_idx)
        for idx in train_idx:
            data = list_of_Data[idx]
            # move seqs to device once
            data.hot_encoded_nucleotide_sequences = data.hot_encoded_nucleotide_sequences.to(device)

            recip_map = data.recip_label_map
            donor_map = data.donor_map
            parents = [n for n in data.G.nodes if len(list(data.G.successors(n))) == 2]
            if len(parents) == 0:
                continue

            # compute aggregates
            agg = model.compute_aggregates(data.G, data.hot_encoded_nucleotide_sequences)  # agg shape [max_id+1, D]

            # collect batch lists
            L_list = []
            R_list = []
            aux_list = []
            true_label_list = []
            true_rec_child_list = []
            true_donor_parent_list = []
            parent_nodes_list = []

            for parent in parents:
                children = list(data.G.successors(parent))
                left, right = children[0], children[1]
                # left/right exist as node ids; agg is indexed by node id
                L_list.append(agg[left].unsqueeze(0))
                R_list.append(agg[right].unsqueeze(0))
                # aux: node_time, level, gene_frac
                node_time = float(data.G.nodes[parent].get('node_time', 0.0))
                node_level = float(data.G.nodes[parent].get('level', 0.0))
                # avoid divide by zero
                parent_sum = agg[parent].sum().clamp(min=1.0)
                gene_frac = ((agg[left].sum() + agg[right].sum()) / parent_sum).unsqueeze(0).unsqueeze(1)
                aux = torch.tensor([[node_time, node_level, float(gene_frac.item())]], device=device, dtype=torch.float32)
                aux_list.append(aux)
                lbl, true_rec = recip_map.get(parent, (0, None))
                true_label_list.append(lbl)
                true_rec_child_list.append(true_rec)
                true_donor_parent_list.append(donor_map.get(parent, None))
                parent_nodes_list.append(parent)

            if len(L_list) == 0:
                continue

            left_vecs = torch.cat(L_list, dim=0)      # [P, D]
            right_vecs = torch.cat(R_list, dim=0)     # [P, D]
            aux_feats = torch.cat(aux_list, dim=0)    # [P, 3]
            true_labels = torch.tensor(true_label_list, dtype=torch.float32, device=device)  # [P]

            # Recipient forward
            hgt_logits, which_logits = model.recipient_finder(left_vecs, right_vecs, aux_feats)
            rec_loss = bce_loss_fn(hgt_logits, true_labels)

            # compute predicted probabilities & which
            probs = torch.sigmoid(hgt_logits).detach().cpu().numpy()
            which_pred = torch.argmax(which_logits.detach().cpu(), dim=1).tolist()

            # Donor loss (differentiable expected distance) aggregated for events
            donor_loss_total = torch.tensor(0.0, device=device)
            donor_events = 0
            # For each parent with true_label==1, check whether recipient predicted correctly
            for i_parent, parent in enumerate(parent_nodes_list):
                lbl = int(true_labels[i_parent].item())
                if lbl != 1:
                    continue
                true_rec_child = true_rec_child_list[i_parent]
                true_donor_parent = true_donor_parent_list[i_parent]
                # predicted recipient child:
                pred_rec_child = list(data.G.successors(parent))[which_pred[i_parent]]
                pred_prob = float(probs[i_parent])
                # condition: only add donor loss if recipient correctly identified
                if (pred_prob >= hgt_threshold) and (true_rec_child is not None) and (pred_rec_child == true_rec_child):
                    # build candidate list: nodes with node_time < recipient_time
                    rec_node = pred_rec_child
                    rec_time = float(data.G.nodes[rec_node].get('node_time', 0.0))
                    candidates = [n for n in data.G.nodes if float(data.G.nodes[n].get('node_time',0.0)) < rec_time and n != rec_node]
                    # ensure true donor parent is among candidates - if not, append it (keeps learning signal)
                    if true_donor_parent is not None and true_donor_parent not in candidates:
                        candidates.append(true_donor_parent)
                    if len(candidates) == 0:
                        # no candidates -> penalize by max distance (normalized)
                        max_pen = len(data.G.nodes)
                        donor_loss_total = donor_loss_total + (torch.tensor(float(max_pen), device=device) / float(max_pen))
                        donor_events += 1
                        continue

                    # sample candidates if too many
                    if len(candidates) > max_candidates:
                        random.shuffle(candidates)
                        candidates = candidates[:max_candidates]
                    cand_idx_tensor = torch.tensor(candidates, dtype=torch.long, device=device)

                    donor_vecs = agg[cand_idx_tensor]  # [M, D]
                    donor_times = torch.tensor([float(data.G.nodes[n].get('node_time',0.0)) for n in candidates], device=device)
                    donor_levels = torch.tensor([float(data.G.nodes[n].get('level',0.0)) for n in candidates], device=device)
                    time_diff = (rec_time - donor_times).unsqueeze(1)              # [M,1]
                    level_diff = (data.G.nodes[rec_node].get('level',0.0) - donor_levels).unsqueeze(1)
                    donor_frac = (donor_vecs.sum(dim=1).unsqueeze(1) / (agg.sum(dim=1).clamp(min=1.0)[cand_idx_tensor].unsqueeze(1))).clamp(0.0,1.0)
                    donor_aux = torch.cat([time_diff, level_diff, donor_frac], dim=1)

                    # get raw scores (no softmax yet)
                    scores = model.donor_finder(agg[rec_node].to(device), donor_vecs, donor_aux)  # [M]

                    # compute distances vector to true donor
                    und = data.G.to_undirected()
                    max_pen = len(data.G.nodes)
                    dists = []
                    for cand in candidates:
                        if true_donor_parent is None:
                            # if no true donor known, give zero distance (no loss)
                            dists.append(0.0)
                        else:
                            try:
                                dd = nx.shortest_path_length(und, source=cand, target=true_donor_parent)
                                dists.append(float(dd))
                            except Exception:
                                dists.append(float(max_pen))
                    dists = torch.tensor(dists, dtype=torch.float32, device=device)  # [M]

                    # Softmax probabilities over scores (temperature can be used)
                    probs_soft = torch.softmax(scores, dim=0)  # [M]
                    # Expected distance (differentiable)
                    expected_dist = torch.dot(probs_soft, dists)  # scalar
                    # Normalize by max_pen to keep scale ~[0,1]
                    expected_dist = expected_dist / float(max_pen)

                    donor_loss_total = donor_loss_total + expected_dist
                    donor_events += 1

                    # accumulate for reporting expected donor distance (as float)
                    donor_expected_distance_accum += float((expected_dist * float(max_pen)).item())
                    donor_expected_count += 1

                    # -----------------------
                    # HERE: perform the structure correction exactly when:
                    # recipient is predicted correctly (pred_rec_child == true_rec_child) AND
                    # donor parent predicted equals true donor parent (donor selected by something),
                    # We check whether the highest-scoring candidate equals true_donor_parent.
                    # -----------------------
                    _, max_idx = torch.max(scores, dim=0)
                    pred_donor_parent_candidate = candidates[int(max_idx.item())] if len(candidates) > 0 else None

                    # Only perform *safe* structural edit when predicted donor parent matches truth.
                    if (pred_donor_parent_candidate == true_donor_parent) and (pred_donor_parent_candidate is not None):
                        # do a safe structural edit on a COPY of data.G (work on copy to be safe)
                        G_mod = copy.deepcopy(data.G)
                        node_seq_matrix_current = data.hot_encoded_nucleotide_sequences.clone().to(device)
                        current_max_node_id = int(max(list(G_mod.nodes))) if len(G_mod.nodes) > 0 else -1
                        new_node = current_max_node_id + 1

                        # extend node_seq_matrix_current by zeros if needed
                        rows_needed = new_node + 1 - node_seq_matrix_current.shape[0]
                        if rows_needed > 0:
                            pad = torch.zeros((rows_needed, node_seq_matrix_current.shape[1]), dtype=node_seq_matrix_current.dtype, device=node_seq_matrix_current.device)
                            node_seq_matrix_current = torch.cat([node_seq_matrix_current, pad], dim=0)

                        # compute rec_child_time and donor_parent_time ON THE COPY
                        rec_child_time = float(G_mod.nodes[pred_rec_child].get('node_time', 0.0))
                        donor_parent_time = float(G_mod.nodes[pred_donor_parent_candidate].get('node_time', 0.0))

                        # Use the predicted donor parent candidate variable consistently here:
                        donor_children = list(G_mod.successors(pred_donor_parent_candidate))
                        if len(donor_children) == 0:
                            # cannot split; skip structural edit
                            pass
                        else:
                            # pick donor_child heuristically: closest subtree-sum to rec subtree-sum
                            try:
                                agg_mod = model.compute_aggregates(G_mod, node_seq_matrix_current)
                            except Exception:
                                agg_mod = agg  # fallback to previously computed agg if compute fails

                            rec_sum = float(agg_mod[pred_rec_child].sum().item()) if pred_rec_child < agg_mod.shape[0] else 0.0
                            cand_sums = []
                            for c in donor_children:
                                if c < agg_mod.shape[0]:
                                    cand_sums.append((c, float(agg_mod[c].sum().item())))
                                else:
                                    cand_sums.append((c, 0.0))
                            if len(cand_sums) == 0:
                                donor_child = donor_children[0]
                            else:
                                donor_child = min(cand_sums, key=lambda x: abs(x[1] - rec_sum))[0]

                            donor_child_time = float(G_mod.nodes[donor_child].get('node_time', 0.0))

                            t1 = rec_child_time + donor_parent_time
                            t2 = donor_child_time + donor_parent_time
                            new_time = 0.5 * max(t1, t2)

                            # add new node and attributes to G_mod
                            G_mod.add_node(new_node)
                            G_mod.nodes[new_node]['node_time'] = new_time
                            G_mod.nodes[new_node]['level'] = (G_mod.nodes[pred_rec_child].get('level', 0.0) + G_mod.nodes[pred_donor_parent_candidate].get('level', 0.0))/2.0

                            # Rewire recipient edge: remove (parent -> pred_rec_child), insert (parent -> new_node) and (new_node -> pred_rec_child)

                            try:
                                G_mod.remove_edge(pred_donor_parent_candidate, donor_child)
                                G_mod.add_edge(pred_donor_parent_candidate, new_node)
                                G_mod.add_edge(new_node, donor_child)
                                
                                G_mod.remove_edge(parent, pred_rec_child)
                                G_mod.add_edge(new_node, pred_rec_child)
    
                                parent_of_recipient_parental_node = list(G_mod.predecessors(parent))[0]
                                other_child_of_recipient_parental_node = list(G_mod.successors(parent))[0]
    
                                G_mod.remove_edge(parent_of_recipient_parental_node, parent)
                                G_mod.remove_edge(parent, other_child_of_recipient_parental_node)
                                G_mod.add_edge(parent_of_recipient_parental_node, other_child_of_recipient_parental_node)
                            except:
                                print(data.G.edges)
                                print(data.hgt_events)
                                print("Predicted recipient:", parent, pred_rec_child)

                            #G_mod.remove_edge(pred_donor_parent_candidate, donor_child)
                            #G_mod.add_edge(pred_donor_parent_candidate, new_node)
                            

                            #G_mod.remove_edge(parent_of_recipient_parental_node, other_child_of_parent)
                            #G_mod.add_edge(parent_of_recipient_parental_node, other_child_of_parent)
                            

                            # Split donor edge: remove (pred_donor_parent_candidate -> donor_child) and add (pred_donor_parent_candidate -> new_node) and (new_node -> donor_child)
                            #if G_mod.has_edge(pred_donor_parent_candidate, donor_child):
                            #G_mod.add_edge(new_node, donor_child)

                            print(data.G.edges)
                            print(data.hgt_events)
                            print(G_mod.edges)

                            # Safety check: avoid creating cycles
                            if nx.has_path(G_mod, donor_child, parent):
                                # rollback: remove new node and its edges
                                if new_node in G_mod.nodes:
                                    G_mod.remove_node(new_node)
                                # don't change original data.G
                            else:
                                if nx.is_directed_acyclic_graph(G_mod):
                                    # Accept modification: update data.G so subsequent parents see change
                                    data.G = G_mod
                                    # Recompute agg so next parents in loop see updated aggregates
                                    try:
                                        agg = model.compute_aggregates(data.G, node_seq_matrix_current)
                                    except Exception:
                                        pass
                                else:
                                    # rollback due to cycle
                                    if new_node in G_mod.nodes:
                                        G_mod.remove_node(new_node)
                                    # leave original graph alone

            # end for each parent donor loss / possible structure edits

            # Normalize donor loss by donor_events (if >0)
            if donor_events > 0:
                donor_loss = donor_loss_total / donor_events
            else:
                donor_loss = torch.tensor(0.0, device=device)

            loss = rec_loss + lambda_donor * donor_loss

            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(model.parameters(), clip_grad)
            optimizer.step()

            train_loss += loss.item()
            train_rec_loss += rec_loss.item()
            train_don_loss += float(donor_loss.item())
            train_samples += 1

            # update metrics for recipient detection
            for i_parent, parent in enumerate(parent_nodes_list):
                true_lbl = int(true_labels[i_parent].item())
                pred_prob = float(probs[i_parent])
                pred_class = 1 if pred_prob >= hgt_threshold else 0
                if pred_class == 1 and true_lbl == 1:
                    tp += 1
                if pred_class == 1 and true_lbl == 0:
                    fp += 1
                if pred_class == 0 and true_lbl == 1:
                    fn += 1

        # End epoch training stats
        avg_train_loss = train_loss / max(1, train_samples)
        avg_train_rec = train_rec_loss / max(1, train_samples)
        avg_train_don = train_don_loss / max(1, train_samples)
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        avg_expected_donor_distance = (donor_expected_distance_accum / donor_expected_count) if donor_expected_count > 0 else 0.0

        # Validation pass (keeps previous validation logic; structure edits are not applied during validation)
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            val_rec_loss = 0.0
            val_don_loss = 0.0
            val_samples = 0
            v_tp = v_fp = v_fn = 0
            v_donor_expected_dist_acc = 0.0
            v_donor_expected_cnt = 0

            for idx in val_idx:
                data = list_of_Data[idx]
                data.hot_encoded_nucleotide_sequences = data.hot_encoded_nucleotide_sequences.to(device)
                recip_map = data.recip_label_map
                donor_map = data.donor_map
                parents = [n for n in data.G.nodes if len(list(data.G.successors(n))) == 2]
                if len(parents) == 0:
                    continue
                agg = model.compute_aggregates(data.G, data.hot_encoded_nucleotide_sequences)
                # collect lists
                L_list, R_list, aux_list = [], [], []
                true_labels = []
                true_rec_child_list = []
                true_donor_parent_list = []
                parent_nodes_list = []
                for parent in parents:
                    left, right = list(data.G.successors(parent))[0], list(data.G.successors(parent))[1]
                    L_list.append(agg[left].unsqueeze(0))
                    R_list.append(agg[right].unsqueeze(0))
                    node_time = float(data.G.nodes[parent].get('node_time', 0.0))
                    node_level = float(data.G.nodes[parent].get('level', 0.0))
                    parent_sum = agg[parent].sum().clamp(min=1.0)
                    gene_frac = ((agg[left].sum() + agg[right].sum()) / parent_sum).unsqueeze(0).unsqueeze(1)
                    aux = torch.tensor([[node_time, node_level, float(gene_frac.item())]], device=device, dtype=torch.float32)
                    aux_list.append(aux)
                    lbl, true_rec = recip_map.get(parent, (0, None))
                    true_labels.append(lbl)
                    true_rec_child_list.append(true_rec)
                    true_donor_parent_list.append(donor_map.get(parent, None))
                    parent_nodes_list.append(parent)
                if len(L_list) == 0:
                    continue
                left_vecs = torch.cat(L_list, dim=0)
                right_vecs = torch.cat(R_list, dim=0)
                aux_feats = torch.cat(aux_list, dim=0)
                true_labels = torch.tensor(true_labels, dtype=torch.float32, device=device)

                hgt_logits, which_logits = model.recipient_finder(left_vecs, right_vecs, aux_feats)
                rec_loss = bce_loss_fn(hgt_logits, true_labels)

                # donor loss computation: same expected-distance approach but without gradients
                probs = torch.sigmoid(hgt_logits).cpu().numpy()
                which_pred = torch.argmax(which_logits, dim=1).tolist()

                donor_loss_total = 0.0
                donor_events = 0
                for i_parent, parent in enumerate(parent_nodes_list):
                    lbl = int(true_labels[i_parent].item())
                    if lbl != 1:
                        continue
                    true_rec_child = true_rec_child_list[i_parent]
                    true_donor_parent = true_donor_parent_list[i_parent]
                    pred_rec_child = list(data.G.successors(parent))[which_pred[i_parent]]
                    pred_prob = float(probs[i_parent])
                    if (pred_prob >= hgt_threshold) and (true_rec_child is not None) and (pred_rec_child == true_rec_child):
                        rec_node = pred_rec_child
                        rec_time = float(data.G.nodes[rec_node].get('node_time', 0.0))
                        candidates = [n for n in data.G.nodes if float(data.G.nodes[n].get('node_time',0.0)) < rec_time and n != rec_node]
                        if true_donor_parent is not None and true_donor_parent not in candidates:
                            candidates.append(true_donor_parent)
                        if len(candidates) == 0:
                            donor_loss_total += 1.0
                            donor_events += 1
                            continue
                        if len(candidates) > max_candidates:
                            random.shuffle(candidates)
                            candidates = candidates[:max_candidates]
                        cand_idx_tensor = torch.tensor(candidates, dtype=torch.long, device=device)
                        donor_vecs = agg[cand_idx_tensor]
                        donor_times = torch.tensor([float(data.G.nodes[n].get('node_time',0.0)) for n in candidates], device=device)
                        donor_levels = torch.tensor([float(data.G.nodes[n].get('level',0.0)) for n in candidates], device=device)
                        time_diff = (rec_time - donor_times).unsqueeze(1)
                        level_diff = (data.G.nodes[rec_node].get('level',0.0) - donor_levels).unsqueeze(1)
                        donor_frac = (donor_vecs.sum(dim=1).unsqueeze(1) / (agg.sum(dim=1).clamp(min=1.0)[cand_idx_tensor].unsqueeze(1))).clamp(0.0,1.0)
                        donor_aux = torch.cat([time_diff, level_diff, donor_frac], dim=1)
                        scores = model.donor_finder(agg[rec_node].to(device), donor_vecs, donor_aux)
                        # distances
                        und = data.G.to_undirected()
                        max_pen = len(data.G.nodes)
                        dists = []
                        for cand in candidates:
                            if true_donor_parent is None:
                                dists.append(0.0)
                            else:
                                try:
                                    dd = nx.shortest_path_length(und, source=cand, target=true_donor_parent)
                                    dists.append(float(dd))
                                except Exception:
                                    dists.append(float(max_pen))
                        dists = torch.tensor(dists, dtype=torch.float32, device=device)
                        probs_soft = torch.softmax(scores, dim=0)
                        expected_dist = float(torch.dot(probs_soft, dists).item()) / float(max_pen)
                        donor_loss_total += expected_dist
                        donor_events += 1
                        v_donor_expected_dist_acc += (expected_dist * float(max_pen))
                        v_donor_expected_cnt += 1

                donor_loss = (donor_loss_total / donor_events) if donor_events > 0 else 0.0
                loss = rec_loss.item() + lambda_donor * donor_loss

                val_loss += loss
                val_rec_loss += rec_loss.item()
                val_don_loss += donor_loss
                val_samples += 1

                # update val metrics for recipient detection
                for i_parent, parent in enumerate(parent_nodes_list):
                    true_lbl = int(true_labels[i_parent].item())
                    pred_prob = float(probs[i_parent])
                    pred_class = 1 if pred_prob >= hgt_threshold else 0
                    if pred_class == 1 and true_lbl == 1:
                        v_tp += 1
                    if pred_class == 1 and true_lbl == 0:
                        v_fp += 1
                    if pred_class == 0 and true_lbl == 1:
                        v_fn += 1

            avg_val_loss = val_loss / max(1, val_samples)
            avg_val_rec = val_rec_loss / max(1, val_samples)
            avg_val_don = val_don_loss / max(1, val_samples)
            v_precision = v_tp / (v_tp + v_fp) if (v_tp + v_fp) > 0 else 0.0
            v_recall = v_tp / (v_tp + v_fn) if (v_tp + v_fn) > 0 else 0.0
            v_f1 = 2 * v_precision * v_recall / (v_precision + v_recall) if (v_precision + v_recall) > 0 else 0.0
            avg_v_donor_expected = (v_donor_expected_dist_acc / v_donor_expected_cnt) if v_donor_expected_cnt > 0 else 0.0

        # scheduler step (use validation loss)
        scheduler.step(avg_val_loss)

        t1 = time.time()
        print(f"Epoch {epoch}/{epochs} | time {t1-t0:.1f}s")
        print(f" TRAIN loss {avg_train_loss:.4f} (rec {avg_train_rec:.4f}, don {avg_train_don:.4f}) "
              f"Prec/Rec/F1 {precision:.3f}/{recall:.3f}/{f1:.3f} | avg_expected_donor_dist {avg_expected_donor_distance:.3f}")
        print(f" VAL   loss {avg_val_loss:.4f} (rec {avg_val_rec:.4f}, don {avg_val_don:.4f}) "
              f"Prec/Rec/F1 {v_precision:.3f}/{v_recall:.3f}/{v_f1:.3f} | avg_expected_donor_dist {avg_v_donor_expected:.3f}")
        print("-" * 120)

    return model


# -------------------------
# Example usage snippet (unchanged seeds + instantiation)
# -------------------------
random.seed(42)
torch.manual_seed(42)

# choose a sample data to get seq_vec_dim
sample_data = random.choice(list_of_Data)
seq_vec_dim = sample_data.hot_encoded_nucleotide_sequences.shape[1]
model = HGTDetector(seq_vec_dim=seq_vec_dim, hgt_threshold=0.5, topk_donors=1)

trained_model = train_hgt_detector(model, list_of_Data,
                                  epochs=4,
                                  val_frac=0.1,
                                  lr=1e-3,
                                  lambda_donor=0.1,
                                  hgt_threshold=0.5,
                                  max_candidates=300,
                                  device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                                  clip_grad=5.0)


Training on 90 graphs | Validation on 9 graphs | pos_weight=21.500
Epoch 1/4 | time 0.2s
 TRAIN loss 1.2130 (rec 1.2130, don 0.0000) Prec/Rec/F1 0.077/0.062/0.069 | avg_expected_donor_dist 0.000
 VAL   loss 1.8585 (rec 1.8585, don 0.0000) Prec/Rec/F1 0.000/0.000/0.000 | avg_expected_donor_dist 0.000
------------------------------------------------------------------------------------------------------------------------
[(5, 3), (5, 4), (8, 0), (8, 7), (7, 1), (7, 6), (6, 2), (6, 5)]
{0: array([(7, 1, 1, 7, 6)],
      dtype=[('recipient_parent_node', '<i4'), ('recipient_child_node', '<i4'), ('leaf', '<i4'), ('donor_parent_node', '<i4'), ('donor_child_node', '<i4')])}
Predicted recipient: 7 1
[(5, 3), (5, 4), (8, 0), (8, 7), (7, 1), (7, 6), (6, 2), (6, 5)]
{0: array([(7, 1, 1, 7, 6)],
      dtype=[('recipient_parent_node', '<i4'), ('recipient_child_node', '<i4'), ('leaf', '<i4'), ('donor_parent_node', '<i4'), ('donor_child_node', '<i4')])}
[(5, 3), (5, 4), (8, 0), (8, 7), (7, 6), (7, 9), 