In [1]:
import os
import random
import h5py
import pickle
import torch
import networkx as nx
from torch_geometric.data import Data
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, Sequential, ReLU, Dropout, BatchNorm1d
from typing import List, Tuple, Dict, Any

def one_hot_encode(sequences, gene_present, gene_length, alphabet=['A','C','T','G','-']):
    """
    One-hot encode a batch of DNA sequences into a flat float32 tensor.

    Args:
        sequences: Sequence of DNA strings, one per sample. Elements may be
            ``bytes`` (decoded as UTF-8) or ``str``.
        gene_present: Per-sample presence flags (array-like of bool / 0-1, or
            a torch tensor) with the same length as ``sequences``.
        gene_length: Fixed number of positions to encode. Longer sequences
            are truncated; shorter ones leave the remaining positions zero.
        alphabet: Symbols that get their own channel. Characters outside the
            alphabet are encoded as an all-zero position.

    Returns:
        torch.Tensor of shape ``(num_samples, gene_length * len(alphabet))``.
        Rows whose gene is absent are filled entirely with ``-1.0``.

    Note:
        Positions are shuffled with a single ``np.random.permutation`` that
        is shared by every row of this call. A fresh permutation is drawn on
        each call, so seed NumPy's global RNG when reproducibility matters.
    """
    num_samples = len(sequences)
    num_chars = len(alphabet)
    char_to_idx = {c: i for i, c in enumerate(alphabet)}
    # Byte strings are decoded; plain ``str`` input is accepted as-is instead
    # of crashing on a missing ``.decode``.
    sequences_str = [s.decode('utf-8') if hasattr(s, 'decode') else str(s) for s in sequences]
    gene_present = np.array(gene_present, dtype=bool)

    # 1) Empty batch: (num_samples, gene_length, num_chars)
    batch = np.zeros((num_samples, gene_length, num_chars), dtype=np.float32)

    # 2) Encode every sequence; absent genes are marked with -1 everywhere
    for i, seq in enumerate(sequences_str):
        if not gene_present[i]:
            batch[i, :, :] = -1.0
            continue
        for j, c in enumerate(seq[:gene_length]):  # slicing truncates long sequences
            idx = char_to_idx.get(c)
            if idx is not None:
                batch[i, j, idx] = 1.0

    # 3) Random position permutation, identical for all rows of this call
    perm = np.random.permutation(gene_length)
    batch = batch[:, perm, :]

    # 4) Flatten; the explicit width (instead of -1) also works for an empty batch
    batch_flat = batch.reshape(num_samples, gene_length * num_chars)

    return torch.tensor(batch_flat)  # shape: (num_samples, gene_length*num_chars)

def load_file(file, gene_length=300):
    """
    Load one simulation result from an HDF5 file and build a PyG ``Data`` object.

    The file must contain a group ``results`` with the datasets
    ``graph_properties`` (a pickled sequence: node ids, a 2-row edge array and
    a coordinate array), ``gene_absence_presence_matrix`` and
    ``nucleotide_sequences``. A sub-group ``nodes_hgt_events_simplified``
    (one structured array per site) is optional.

    Args:
        file: Path of the HDF5 file.
        gene_length: Number of sequence positions passed to ``one_hot_encode``.

    Returns:
        ``torch_geometric.data.Data`` carrying, among others:
          * ``hot_encoded_nucleotide_sequences`` - encoded sequences, padded
            with zero rows up to the number of nodes,
          * ``edge_index`` - the stored edges with their two rows swapped,
          * ``candidate_edges`` - node pairs ``(earlier, later)`` by node time
            whose reverse pair is not an existing edge,
          * ``y`` - 1 for node indices that occur as recipient child in any
            HGT event, else 0 (assumes node ids are ``0..N-1``),
          * ``node_times`` / ``node_levels`` - ordered by ascending node id,
          * ``G`` - ``networkx.DiGraph`` with ``node_time`` and ``level``
            node attributes,
          * ``recip_label_map`` / ``donor_map`` - per node with exactly two
            successors: ``(label, recipient_child)`` and the donor parent.
    """
    NODE_TIME_ROW = 5  # row of the transposed coordinate array that is used as node time

    def _first_field(names, *candidates):
        """Return the first name from ``candidates`` contained in ``names``, else None."""
        for name in candidates:
            if name in names:
                return name
        return None

    # ------------------------------------------------------------------
    # Read everything from disk first, then release the file handle.
    # ------------------------------------------------------------------
    with h5py.File(file, "r") as f:
        grp = f["results"]
        # SECURITY: pickle.loads can execute arbitrary code; only open files
        # that come from a trusted source.
        graph_properties = pickle.loads(grp["graph_properties"][()])

        gene_absence_presence_matrix = grp["gene_absence_presence_matrix"][()]
        nucleotide_sequences = grp["nucleotide_sequences"][()]

        # HGT events (simplified): one in-memory array per site id
        hgt_events = {}
        hgt_grp_simpl = grp.get("nodes_hgt_events_simplified", None)
        if hgt_grp_simpl is not None:
            for site_id in hgt_grp_simpl.keys():
                hgt_events[int(site_id)] = hgt_grp_simpl[site_id][()]

    # Unpack graph properties
    nodes = torch.tensor(graph_properties[0])                    # [num_nodes]
    edges = torch.tensor(graph_properties[1], dtype=torch.long)  # [2, num_edges]
    coords = torch.tensor(graph_properties[2].T)                 # one column per node

    # ------------------------------------------------------------------
    # Node features: encoded sequences, zero-padded for the remaining nodes
    # ------------------------------------------------------------------
    hot_encoded_nucleotide_sequences = one_hot_encode(
        nucleotide_sequences, gene_absence_presence_matrix, gene_length
    )
    pad_rows = len(nodes) - len(nucleotide_sequences)
    pad = torch.zeros(
        (pad_rows, hot_encoded_nucleotide_sequences.shape[1]),
        dtype=hot_encoded_nucleotide_sequences.dtype,
    )
    hot_encoded_nucleotide_sequences = torch.cat([hot_encoded_nucleotide_sequences, pad], dim=0)

    # ------------------------------------------------------------------
    # Directed graph with node times
    # ------------------------------------------------------------------
    G = nx.DiGraph()
    node_time_values = coords[NODE_TIME_ROW].tolist()
    for i, node_id in enumerate(nodes.tolist()):
        G.add_node(node_id, node_time=node_time_values[i])

    edge_pairs = list(zip(edges[0].tolist(), edges[1].tolist()))
    for src, dst in edge_pairs:
        G.add_edge(src, dst)

    # ------------------------------------------------------------------
    # Node labels: recipient children of any HGT event
    # ------------------------------------------------------------------
    recipient_child_nodes = set()
    for arr in hgt_events.values():
        names = arr.dtype.names or ()
        child_field = _first_field(names, 'recipient_child_node', 'recipient_child')
        if child_field is not None:
            recipient_child_nodes.update(arr[child_field].tolist())

    theta_gains = torch.tensor(
        [1 if node in recipient_child_nodes else 0 for node in range(len(G.nodes))],
        dtype=torch.long
    )

    # ------------------------------------------------------------------
    # Levels: nodes without successors are level 0, every other node is one
    # above its highest successor. Reverse topological order guarantees that
    # successors are finished before their predecessors.
    # ------------------------------------------------------------------
    level = {n: 0 for n in G.nodes}
    for node in reversed(list(nx.topological_sort(G))):
        successors = list(G.successors(node))
        if successors:
            level[node] = 1 + max(level[s] for s in successors)
    nx.set_node_attributes(G, level, "level")

    # ------------------------------------------------------------------
    # Candidate edges, i.e. potential HGT edges
    # ------------------------------------------------------------------
    # Times stay in a dict keyed by node id. The previous code replaced this
    # dict by a time-sorted tensor and then indexed that tensor with node
    # ids, which compares the wrong times whenever ids are not time-sorted.
    time_by_node = {n: G.nodes[n]['node_time'] for n in G.nodes}
    nodes_by_time = sorted(time_by_node, key=time_by_node.get)

    # Exported tensors are ordered by ascending node id (not G.nodes order).
    nodes_by_id = sorted(G.nodes)
    node_times = torch.tensor([time_by_node[n] for n in nodes_by_id], dtype=torch.float)
    node_levels = torch.tensor([level[n] for n in nodes_by_id], dtype=torch.float)

    existing_edges = set(edge_pairs)
    candidate_edges = []
    for i, src in enumerate(nodes_by_time):
        t_src = time_by_node[src]
        for dst in nodes_by_time[i+1:]:  # only nodes later in time order
            # strictly later, and not the reverse of an existing edge
            if time_by_node[dst] > t_src and (dst, src) not in existing_edges:
                candidate_edges.append((dst, src))
    if candidate_edges:
        candidate_edges = torch.tensor(candidate_edges, dtype=torch.long).t()
    else:
        candidate_edges = torch.empty((2, 0), dtype=torch.long)

    data = Data(
        hot_encoded_nucleotide_sequences = hot_encoded_nucleotide_sequences,  # node features
        edge_index = edges[[1, 0], :],                  # [2, num_edges], rows swapped
        candidate_edges = candidate_edges[[1, 0], :],   # [2, num_candidates], rows swapped
        y = theta_gains,                                # labels [num_nodes]
        file = file,
        G = G,
        gene_absence_presence_matrix = gene_absence_presence_matrix,
        node_times = node_times,
        node_levels = node_levels,
        hgt_events = hgt_events,
    )

    # ------------------------------------------------------------------
    # Label maps for nodes with exactly two successors
    # ------------------------------------------------------------------
    recip_label_map = {}
    donor_map = {}
    for n in G.nodes:
        if len(list(G.successors(n))) == 2:
            recip_label_map[n] = (0, None)
            donor_map[n] = None

    for arr in hgt_events.values():
        try:
            len(arr)
        except TypeError:
            # zero-dimensional array: nothing to iterate over
            continue
        names = arr.dtype.names or ()
        # Resolve the (possibly differently named) fields once per array
        parent_field = _first_field(names, 'recipient_parent_node', 'recipient_parent')
        if parent_field is None:
            continue
        child_field = _first_field(names, 'recipient_child_node', 'recipient_child')
        donor_field = _first_field(names, 'donor_parent_node', 'donor_parent')

        for row in arr:
            rec_parent = int(row[parent_field])
            rec_child = int(row[child_field]) if child_field is not None else None
            donor_parent = int(row[donor_field]) if donor_field is not None else None
            # Later events for the same recipient parent overwrite earlier ones
            if rec_parent in recip_label_map:
                recip_label_map[rec_parent] = (1, rec_child)
                donor_map[rec_parent] = donor_parent

    # attach precomputed label maps to data
    data.recip_label_map = recip_label_map
    data.donor_map = donor_map

    return data

folder = "/mnt/c/Users/uhewm/Desktop/ProjectHGT/simulation_chunks"

# Full paths of all regular files directly inside the folder
files = list(filter(os.path.isfile, (os.path.join(folder, name) for name in os.listdir(folder))))

# Work on a random subset of at most 100 files
files = random.sample(files, 100) if len(files) > 100 else files

# Load file by file; a failing file is reported and skipped
list_of_Data = []
for f in files:
    try:
        d = load_file(f)
    except Exception as e:
        print(f"Fehler beim Laden von {f}: {e}")
    else:
        list_of_Data.append(d)

# Single-file smoke test, if needed: example = load_file(random.choice(files))

print(f"{len(list_of_Data)} Dateien erfolgreich geladen.")

  machar = _get_machar(dtype)
  edges = torch.tensor(graph_properties[1], dtype=torch.long)  # [2, num_edges]
  coords = torch.tensor(graph_properties[2].T)             # [2, num_nodes]


100 Dateien erfolgreich geladen.


In [10]:
import os
import random
import math
import time
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_
import networkx as nx

# -------------------------
# Model components
# -------------------------

def make_mlp(input_dim, hidden_dims=[256,128], output_dim=None, dropout=0.1):
    """
    Build an MLP that uses LayerNorm instead of BatchNorm.
    LayerNorm is stable for batch size 1.

    Args:
        input_dim: Size of the input feature vector.
        hidden_dims: Widths of the hidden layers; any sequence (list, tuple)
            is accepted and it is never modified. May be empty.
        output_dim: If given, a final ``nn.Linear`` to this size is appended
            (without normalisation, activation or dropout).
        dropout: Dropout probability applied after every hidden layer.

    Returns:
        ``nn.Sequential`` of ``Linear -> LayerNorm -> ReLU -> Dropout`` per
        hidden layer, followed by the optional output layer.
    """
    # Unpacking works for any sequence type, unlike ``[input_dim] + hidden_dims``
    dims = [input_dim, *hidden_dims]
    layers = []
    for in_dim, out_dim in zip(dims[:-1], dims[1:]):
        layers.extend([
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),  # normalises over the feature dimension
            nn.ReLU(),
            nn.Dropout(dropout),
        ])
    if output_dim is not None:
        layers.append(nn.Linear(dims[-1], output_dim))
    return nn.Sequential(*layers)


class RecipientFinder(nn.Module):
    """
    Scores a pair of child vectors for a transfer event.

    ``forward`` returns a tuple:
      - hgt_logit: raw logit (no sigmoid), shape [B]
      - which_logits: raw logits for left/right, shape [B,2]
    """
    def __init__(self, seq_vec_dim, aux_dim=3, hidden=[512,256], dropout=0.1):
        super().__init__()
        self.seq_vec_dim = seq_vec_dim
        self.aux_dim = aux_dim
        # Input = left, right, |left - right|, left * right (D each) + aux features
        trunk_in = 4 * seq_vec_dim + aux_dim
        self.mlp = make_mlp(trunk_in, hidden_dims=hidden, output_dim=None, dropout=dropout)
        trunk_out = hidden[-1]
        self.hgt_head = nn.Linear(trunk_out, 1)    # single logit
        self.which_head = nn.Linear(trunk_out, 2)  # two logits (left / right)

    def forward(self, left_vec, right_vec, aux_feats):
        # left_vec / right_vec: [B, D]; aux_feats is concatenated along dim 1
        pair_features = [
            left_vec,
            right_vec,
            (left_vec - right_vec).abs(),
            left_vec * right_vec,
            aux_feats,
        ]
        trunk = self.mlp(torch.cat(pair_features, dim=1))
        # squeeze drops the trailing size-1 dimension -> [B]
        return self.hgt_head(trunk).squeeze(-1), self.which_head(trunk)

class DonorFinder(nn.Module):
    """
    Rates every candidate donor vector against one recipient vector.
    ``forward`` returns the raw scores (logits), shape [M].
    """
    def __init__(self, seq_vec_dim, aux_dim=3, hidden=[512,256], dropout=0.1):
        super().__init__()
        # Input = recipient, donor, |difference|, product (D each) + aux features
        in_features = 4 * seq_vec_dim + aux_dim
        self.score_mlp = make_mlp(in_features, hidden_dims=hidden, output_dim=1, dropout=dropout)

    def forward(self, recipient_vec, donor_vecs, aux_feats):
        num_donors = donor_vecs.shape[0]
        # Accept [D] as well as [1,D]; afterwards repeat the recipient per donor
        rec_2d = recipient_vec.unsqueeze(0) if recipient_vec.dim() == 1 else recipient_vec
        rec_rows = rec_2d.expand(num_donors, -1)
        features = torch.cat(
            [rec_rows, donor_vecs, (rec_rows - donor_vecs).abs(), rec_rows * donor_vecs, aux_feats],
            dim=1,
        )
        return self.score_mlp(features).squeeze(-1)  # [M]

class HGTDetector(nn.Module):
    """
    Orchestrator that contains RecipientFinder and DonorFinder and provides
    the utility ``compute_aggregates``.

    ``hgt_threshold`` and ``topk_donors`` are only stored on the instance;
    nothing in this class reads them.
    """
    def __init__(self, seq_vec_dim, aux_recipient_dim=3, aux_donor_dim=3,
                 rec_hidden=[16], donor_hidden=[16], hgt_threshold=0.5, topk_donors=1):
        super().__init__()
        self.recipient_finder = RecipientFinder(seq_vec_dim, aux_dim=aux_recipient_dim, hidden=rec_hidden)
        self.donor_finder = DonorFinder(seq_vec_dim, aux_dim=aux_donor_dim, hidden=donor_hidden)
        self.hgt_threshold = hgt_threshold
        self.topk_donors = topk_donors

    @staticmethod
    def compute_aggregates(G: nx.DiGraph, node_seq_matrix: torch.Tensor):
        """
        Propagate sequence vectors bottom-up through the graph.

        Every node's row becomes the element-wise maximum of its own row and
        the (already aggregated) rows of all its successors. Note that this
        is a maximum, not a sum.

        Args:
            G: Directed acyclic graph; successors are treated as children.
            node_seq_matrix: One row per node, in ``G.nodes`` iteration order.
                It is not modified.

        Returns:
            A new tensor with the same shape, dtype and device. Row ``i``
            belongs to the ``i``-th node of ``G.nodes``; indexing the result
            by node id is therefore only valid when the node ids are
            ``0..N-1`` in that order.

        Raises:
            RuntimeError: If the number of rows differs from the number of nodes.
        """
        nodes = list(G.nodes)
        N = len(nodes)
        if node_seq_matrix.shape[0] != N:
            raise RuntimeError("node_seq_matrix rows != number of nodes; please supply node-ordered matrix.")

        # Position of each node in G.nodes order (identity for ids 0..N-1)
        node_to_idx = {n: i for i, n in enumerate(nodes)}

        # Work on a copy so the caller's matrix stays untouched
        agg = node_seq_matrix.clone()

        # Reverse topological order: successors are finished before their parent
        for node in reversed(list(nx.topological_sort(G))):
            i = node_to_idx[node]
            for child in G.successors(node):
                agg[i] = torch.max(agg[i], agg[node_to_idx[child]])
        return agg

# -------------------------
# Utilities for training
# -------------------------
def tree_distance(G: nx.DiGraph, node_a, node_b, max_penalty=None):
    """
    Number of edges on the shortest path between two nodes when edge
    directions are ignored.

    Args:
        G: Directed graph.
        node_a: Start node.
        node_b: End node.
        max_penalty: Value returned when no distance can be determined;
            defaults to the number of nodes in ``G``.

    Returns:
        The path length, or the fallback value if the nodes are not
        connected or one of them is not part of ``G``.
    """
    # A read-only undirected view is sufficient for a path query and avoids
    # copying the whole graph on every call.
    undirected = G.to_undirected(as_view=True)
    try:
        return nx.shortest_path_length(undirected, source=node_a, target=node_b)
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        # Only the two expected lookup failures trigger the fallback; any
        # other error is a programming mistake and should surface.
        return max_penalty if max_penalty is not None else len(G.nodes)


# -------------------------
# Training loop (fully adjusted)
# -------------------------
def train_hgt_detector(model: HGTDetector, list_of_Data, epochs=10, val_frac=0.1, lr=1e-3,
                       lambda_donor=0.1, hgt_threshold=0.5, max_candidates=300,
                       device=None, clip_grad=5.0):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

    N = len(list_of_Data)
    idxs = list(range(N))
    random.shuffle(idxs)
    n_val = max(1, int(val_frac * N))
    val_idx = set(idxs[:n_val])
    train_idx = idxs[n_val:]

    # Estimate pos_weight for BCEWithLogitsLoss on training set (clamped)
    pos_count = 0
    neg_count = 0
    for i in train_idx:
        d = list_of_Data[i]
        recip_map = d.recip_label_map
        _ = d.donor_map
        for p, (lbl, _) in recip_map.items():
            if lbl == 1:
                pos_count += 1
            else:
                neg_count += 1
    pos_count = max(pos_count, 1)
    neg_count = max(neg_count, 1)
    ratio = neg_count / pos_count
    pos_weight_val = max(1.0, min(ratio, 50.0))  # clamp between 1 and 50
    pos_weight = torch.tensor([pos_weight_val], device=device)
    bce_loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    print(f"Training on {len(train_idx)} graphs | Validation on {len(val_idx)} graphs | pos_weight={pos_weight_val:.3f}")

    for epoch in range(1, epochs+1):
        t0 = time.time()
        model.train()
        train_loss = 0.0
        train_rec_loss = 0.0
        train_don_loss = 0.0
        train_samples = 0

        # metrics
        tp = 0; fp = 0; fn = 0
        donor_expected_distance_accum = 0.0
        donor_expected_count = 0

        random.shuffle(train_idx)
        for idx in train_idx:
            data = list_of_Data[idx]
            # move seqs to device once
            data.hot_encoded_nucleotide_sequences = data.hot_encoded_nucleotide_sequences.to(device)

            recip_map = data.recip_label_map
            donor_map = data.donor_map
            parents = [n for n in data.G.nodes if len(list(data.G.successors(n))) == 2]
            if len(parents) == 0:
                continue

            # compute aggregates
            agg = model.compute_aggregates(data.G, data.hot_encoded_nucleotide_sequences)  # [N,D]

            # collect batch lists
            L_list = []
            R_list = []
            aux_list = []
            true_label_list = []
            true_rec_child_list = []
            true_donor_parent_list = []
            parent_nodes_list = []

            for parent in parents:
                children = list(data.G.successors(parent))
                left, right = children[0], children[1]
                L_list.append(agg[left].unsqueeze(0))
                R_list.append(agg[right].unsqueeze(0))
                # aux: node_time, level, gene_frac
                node_time = float(data.G.nodes[parent].get('node_time', 0.0))
                node_level = float(data.G.nodes[parent].get('level', 0.0))
                # avoid divide by zero
                parent_sum = agg[parent].sum().clamp(min=1.0)
                gene_frac = ((agg[left].sum() + agg[right].sum()) / parent_sum).unsqueeze(0).unsqueeze(1)
                aux = torch.tensor([[node_time, node_level, float(gene_frac.item())]], device=device, dtype=torch.float32)
                aux_list.append(aux)
                lbl, true_rec = recip_map.get(parent, (0, None))
                true_label_list.append(lbl)
                true_rec_child_list.append(true_rec)
                true_donor_parent_list.append(donor_map.get(parent, None))
                parent_nodes_list.append(parent)

            left_vecs = torch.cat(L_list, dim=0)      # [P, D]
            right_vecs = torch.cat(R_list, dim=0)     # [P, D]
            aux_feats = torch.cat(aux_list, dim=0)    # [P, 3]
            true_labels = torch.tensor(true_label_list, dtype=torch.float32, device=device)  # [P]

            # Recipient forward
            hgt_logits, which_logits = model.recipient_finder(left_vecs, right_vecs, aux_feats)
            rec_loss = bce_loss_fn(hgt_logits, true_labels)

            # compute predicted probabilities & which
            probs = torch.sigmoid(hgt_logits).detach().cpu().numpy()
            which_pred = torch.argmax(which_logits.detach().cpu(), dim=1).tolist()

            # Donor loss (differentiable expected distance) aggregated for events
            donor_loss_total = torch.tensor(0.0, device=device)
            donor_events = 0
            # For each parent with true_label==1, check whether recipient predicted correctly
            for i_parent, parent in enumerate(parent_nodes_list):
                lbl = int(true_labels[i_parent].item())
                if lbl != 1:
                    continue
                true_rec_child = true_rec_child_list[i_parent]
                true_donor_parent = true_donor_parent_list[i_parent]
                # predicted recipient child:
                pred_rec_child = list(data.G.successors(parent))[which_pred[i_parent]]
                pred_prob = float(probs[i_parent])
                # condition: only add donor loss if recipient correctly identified
                if (pred_prob >= hgt_threshold) and (true_rec_child is not None) and (pred_rec_child == true_rec_child):
                    # build candidate list: nodes with node_time < recipient_time
                    rec_node = pred_rec_child
                    rec_time = float(data.G.nodes[rec_node].get('node_time', 0.0))
                    candidates = [n for n in data.G.nodes if float(data.G.nodes[n].get('node_time',0.0)) < rec_time and n != rec_node]
                    # ensure true donor parent is among candidates - if not, append it (keeps learning signal)
                    if true_donor_parent is not None and true_donor_parent not in candidates:
                        candidates.append(true_donor_parent)
                    if len(candidates) == 0:
                        # no candidates -> penalize by max distance (normalized)
                        max_pen = len(data.G.nodes)
                        donor_loss_total = donor_loss_total + (torch.tensor(float(max_pen), device=device) / float(max_pen))
                        donor_events += 1
                        continue

                    # sample candidates if too many
                    if len(candidates) > max_candidates:
                        random.shuffle(candidates)
                        candidates = candidates[:max_candidates]
                    cand_idx_tensor = torch.tensor(candidates, dtype=torch.long, device=device)

                    donor_vecs = agg[cand_idx_tensor]  # [M, D]
                    donor_times = torch.tensor([float(data.G.nodes[n].get('node_time',0.0)) for n in candidates], device=device)
                    donor_levels = torch.tensor([float(data.G.nodes[n].get('level',0.0)) for n in candidates], device=device)
                    time_diff = (rec_time - donor_times).unsqueeze(1)              # [M,1]
                    level_diff = (data.G.nodes[rec_node].get('level',0.0) - donor_levels).unsqueeze(1)
                    donor_frac = (donor_vecs.sum(dim=1).unsqueeze(1) / (agg.sum(dim=1).clamp(min=1.0)[cand_idx_tensor].unsqueeze(1))).clamp(0.0,1.0)
                    donor_aux = torch.cat([time_diff, level_diff, donor_frac], dim=1)

                    # get raw scores (no softmax yet)
                    scores = model.donor_finder(agg[rec_node].to(device), donor_vecs, donor_aux)  # [M]

                    # compute distances vector to true donor
                    # ensure we have a numeric dist for each candidate
                    und = data.G.to_undirected()
                    max_pen = len(data.G.nodes)
                    dists = []
                    for cand in candidates:
                        if true_donor_parent is None:
                            # if no true donor known, give zero distance (no loss)
                            dists.append(0.0)
                        else:
                            try:
                                dd = nx.shortest_path_length(und, source=cand, target=true_donor_parent)
                                dists.append(float(dd))
                            except Exception:
                                dists.append(float(max_pen))
                    dists = torch.tensor(dists, dtype=torch.float32, device=device)  # [M]

                    # Softmax probabilities over scores (temperature can be used)
                    probs_soft = torch.softmax(scores, dim=0)  # [M]
                    # Expected distance (differentiable)
                    expected_dist = torch.dot(probs_soft, dists)  # scalar
                    # Normalize by max_pen to keep scale ~[0,1]
                    expected_dist = expected_dist / float(max_pen)

                    donor_loss_total = donor_loss_total + expected_dist
                    donor_events += 1

                    # accumulate for reporting expected donor distance (as float)
                    donor_expected_distance_accum += float((expected_dist * float(max_pen)).item())
                    donor_expected_count += 1

            # Normalize donor loss by donor_events (if >0)
            if donor_events > 0:
                donor_loss = donor_loss_total / donor_events
            else:
                donor_loss = torch.tensor(0.0, device=device)

            loss = rec_loss + lambda_donor * donor_loss

            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(model.parameters(), clip_grad)
            optimizer.step()

            train_loss += loss.item()
            train_rec_loss += rec_loss.item()
            train_don_loss += float(donor_loss.item())
            train_samples += 1

            # update metrics for recipient detection
            for i_parent, parent in enumerate(parent_nodes_list):
                true_lbl = int(true_labels[i_parent].item())
                pred_prob = float(probs[i_parent])
                pred_class = 1 if pred_prob >= hgt_threshold else 0
                if pred_class == 1 and true_lbl == 1:
                    tp += 1
                if pred_class == 1 and true_lbl == 0:
                    fp += 1
                if pred_class == 0 and true_lbl == 1:
                    fn += 1

        # End epoch training stats
        avg_train_loss = train_loss / max(1, train_samples)
        avg_train_rec = train_rec_loss / max(1, train_samples)
        avg_train_don = train_don_loss / max(1, train_samples)
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        avg_expected_donor_distance = (donor_expected_distance_accum / donor_expected_count) if donor_expected_count > 0 else 0.0

        # Validation pass
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            val_rec_loss = 0.0
            val_don_loss = 0.0
            val_samples = 0
            v_tp = v_fp = v_fn = 0
            v_donor_expected_dist_acc = 0.0
            v_donor_expected_cnt = 0

            for idx in val_idx:
                data = list_of_Data[idx]
                data.hot_encoded_nucleotide_sequences = data.hot_encoded_nucleotide_sequences.to(device)
                recip_map = data.recip_label_map
                donor_map = data.donor_map
                parents = [n for n in data.G.nodes if len(list(data.G.successors(n))) == 2]
                if len(parents) == 0:
                    continue
                agg = model.compute_aggregates(data.G, data.hot_encoded_nucleotide_sequences)
                # collect lists
                L_list, R_list, aux_list = [], [], []
                true_labels = []
                true_rec_child_list = []
                true_donor_parent_list = []
                parent_nodes_list = []
                for parent in parents:
                    left, right = list(data.G.successors(parent))[0], list(data.G.successors(parent))[1]
                    L_list.append(agg[left].unsqueeze(0))
                    R_list.append(agg[right].unsqueeze(0))
                    node_time = float(data.G.nodes[parent].get('node_time', 0.0))
                    node_level = float(data.G.nodes[parent].get('level', 0.0))
                    parent_sum = agg[parent].sum().clamp(min=1.0)
                    gene_frac = ((agg[left].sum() + agg[right].sum()) / parent_sum).unsqueeze(0).unsqueeze(1)
                    aux = torch.tensor([[node_time, node_level, float(gene_frac.item())]], device=device, dtype=torch.float32)
                    aux_list.append(aux)
                    lbl, true_rec = recip_map.get(parent, (0, None))
                    true_labels.append(lbl)
                    true_rec_child_list.append(true_rec)
                    true_donor_parent_list.append(donor_map.get(parent, None))
                    parent_nodes_list.append(parent)
                left_vecs = torch.cat(L_list, dim=0)
                right_vecs = torch.cat(R_list, dim=0)
                aux_feats = torch.cat(aux_list, dim=0)
                true_labels = torch.tensor(true_labels, dtype=torch.float32, device=device)

                hgt_logits, which_logits = model.recipient_finder(left_vecs, right_vecs, aux_feats)
                rec_loss = bce_loss_fn(hgt_logits, true_labels)

                # donor loss computation: same expected-distance approach but without gradients
                probs = torch.sigmoid(hgt_logits).cpu().numpy()
                which_pred = torch.argmax(which_logits, dim=1).tolist()

                donor_loss_total = 0.0
                donor_events = 0
                for i_parent, parent in enumerate(parent_nodes_list):
                    lbl = int(true_labels[i_parent].item())
                    if lbl != 1:
                        continue
                    true_rec_child = true_rec_child_list[i_parent]
                    true_donor_parent = true_donor_parent_list[i_parent]
                    pred_rec_child = list(data.G.successors(parent))[which_pred[i_parent]]
                    pred_prob = float(probs[i_parent])
                    if (pred_prob >= hgt_threshold) and (true_rec_child is not None) and (pred_rec_child == true_rec_child):
                        rec_node = pred_rec_child
                        rec_time = float(data.G.nodes[rec_node].get('node_time', 0.0))
                        candidates = [n for n in data.G.nodes if float(data.G.nodes[n].get('node_time',0.0)) < rec_time and n != rec_node]
                        if true_donor_parent is not None and true_donor_parent not in candidates:
                            candidates.append(true_donor_parent)
                        if len(candidates) == 0:
                            donor_loss_total += 1.0
                            donor_events += 1
                            continue
                        if len(candidates) > max_candidates:
                            random.shuffle(candidates)
                            candidates = candidates[:max_candidates]
                        cand_idx_tensor = torch.tensor(candidates, dtype=torch.long, device=device)
                        donor_vecs = agg[cand_idx_tensor]
                        donor_times = torch.tensor([float(data.G.nodes[n].get('node_time',0.0)) for n in candidates], device=device)
                        donor_levels = torch.tensor([float(data.G.nodes[n].get('level',0.0)) for n in candidates], device=device)
                        time_diff = (rec_time - donor_times).unsqueeze(1)
                        level_diff = (data.G.nodes[rec_node].get('level',0.0) - donor_levels).unsqueeze(1)
                        donor_frac = (donor_vecs.sum(dim=1).unsqueeze(1) / (agg.sum(dim=1).clamp(min=1.0)[cand_idx_tensor].unsqueeze(1))).clamp(0.0,1.0)
                        donor_aux = torch.cat([time_diff, level_diff, donor_frac], dim=1)
                        scores = model.donor_finder(agg[rec_node].to(device), donor_vecs, donor_aux)
                        # distances
                        und = data.G.to_undirected()
                        max_pen = len(data.G.nodes)
                        dists = []
                        for cand in candidates:
                            if true_donor_parent is None:
                                dists.append(0.0)
                            else:
                                try:
                                    dd = nx.shortest_path_length(und, source=cand, target=true_donor_parent)
                                    dists.append(float(dd))
                                except Exception:
                                    dists.append(float(max_pen))
                        dists = torch.tensor(dists, dtype=torch.float32, device=device)
                        probs_soft = torch.softmax(scores, dim=0)
                        expected_dist = float(torch.dot(probs_soft, dists).item()) / float(max_pen)
                        donor_loss_total += expected_dist
                        donor_events += 1
                        v_donor_expected_dist_acc += (expected_dist * float(max_pen))
                        v_donor_expected_cnt += 1

                donor_loss = (donor_loss_total / donor_events) if donor_events > 0 else 0.0
                loss = rec_loss.item() + lambda_donor * donor_loss

                val_loss += loss
                val_rec_loss += rec_loss.item()
                val_don_loss += donor_loss
                val_samples += 1

                # update val metrics for recipient detection
                for i_parent, parent in enumerate(parent_nodes_list):
                    true_lbl = int(true_labels[i_parent].item())
                    pred_prob = float(probs[i_parent])
                    pred_class = 1 if pred_prob >= hgt_threshold else 0
                    if pred_class == 1 and true_lbl == 1:
                        v_tp += 1
                    if pred_class == 1 and true_lbl == 0:
                        v_fp += 1
                    if pred_class == 0 and true_lbl == 1:
                        v_fn += 1

            avg_val_loss = val_loss / max(1, val_samples)
            avg_val_rec = val_rec_loss / max(1, val_samples)
            avg_val_don = val_don_loss / max(1, val_samples)
            v_precision = v_tp / (v_tp + v_fp) if (v_tp + v_fp) > 0 else 0.0
            v_recall = v_tp / (v_tp + v_fn) if (v_tp + v_fn) > 0 else 0.0
            v_f1 = 2 * v_precision * v_recall / (v_precision + v_recall) if (v_precision + v_recall) > 0 else 0.0
            avg_v_donor_expected = (v_donor_expected_dist_acc / v_donor_expected_cnt) if v_donor_expected_cnt > 0 else 0.0

        # scheduler step (use validation loss)
        scheduler.step(avg_val_loss)

        t1 = time.time()
        print(f"Epoch {epoch}/{epochs} | time {t1-t0:.1f}s")
        print(f" TRAIN loss {avg_train_loss:.4f} (rec {avg_train_rec:.4f}, don {avg_train_don:.4f}) "
              f"Prec/Rec/F1 {precision:.3f}/{recall:.3f}/{f1:.3f} | avg_expected_donor_dist {avg_expected_donor_distance:.3f}")
        print(f" VAL   loss {avg_val_loss:.4f} (rec {avg_val_rec:.4f}, don {avg_val_don:.4f}) "
              f"Prec/Rec/F1 {v_precision:.3f}/{v_recall:.3f}/{v_f1:.3f} | avg_expected_donor_dist {avg_v_donor_expected:.3f}")
        print("-" * 120)

    return model


# Seed Python's `random` module and PyTorch so that the sample choice below and
# the model's weight initialisation are repeatable. NumPy's RNG is not seeded here.
random.seed(42)
torch.manual_seed(42)

# The model only needs the feature width of the encoded sequences, so a single
# graph is drawn from list_of_Data and the second dimension of its
# hot_encoded_nucleotide_sequences matrix is read off as seq_vec_dim.
sample_data = random.choice(list_of_Data)
seq_vec_dim = sample_data.hot_encoded_nucleotide_sequences.shape[1]
model = HGTDetector(seq_vec_dim=seq_vec_dim, hgt_threshold=0.5, topk_donors=1)

# Run training with a 10% validation fraction; the device is CUDA when
# torch reports it as available and CPU otherwise.
trained_model = train_hgt_detector(model, list_of_Data,
                                  epochs=4,
                                  val_frac=0.1,
                                  lr=1e-3,
                                  lambda_donor=0.1,
                                  hgt_threshold=0.5,
                                  max_candidates=300,
                                  device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                                  clip_grad=5.0)


Training on 90 graphs | Validation on 10 graphs | pos_weight=13.400
Epoch 1/4 | time 0.2s
 TRAIN loss 1.2180 (rec 1.2176, don 0.0036) Prec/Rec/F1 0.135/0.200/0.161 | avg_expected_donor_dist 0.981
 VAL   loss 1.0983 (rec 1.0983, don 0.0000) Prec/Rec/F1 0.000/0.000/0.000 | avg_expected_donor_dist 0.000
------------------------------------------------------------------------------------------------------------------------
Epoch 2/4 | time 0.2s
 TRAIN loss 0.8702 (rec 0.8699, don 0.0036) Prec/Rec/F1 0.387/0.480/0.429 | avg_expected_donor_dist 0.555
 VAL   loss 1.1179 (rec 1.1179, don 0.0000) Prec/Rec/F1 0.000/0.000/0.000 | avg_expected_donor_dist 0.000
------------------------------------------------------------------------------------------------------------------------
Epoch 3/4 | time 0.2s
 TRAIN loss 0.5224 (rec 0.5222, don 0.0019) Prec/Rec/F1 0.611/0.880/0.721 | avg_expected_donor_dist 0.341
 VAL   loss 1.4188 (rec 1.4188, don 0.0000) Prec/Rec/F1 0.000/0.000/0.000 | avg_expected_donor

In [9]:
import os
import random
import math
import time
import h5py
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_
import networkx as nx

# -------------------------
# Model components
# -------------------------

def make_mlp(input_dim, hidden_dims=[32], output_dim=None, dropout=0.1):
    """
    Build an MLP that uses LayerNorm instead of BatchNorm.
    LayerNorm is stable for batch size 1.

    Args:
        input_dim: width of the input features.
        hidden_dims: iterable of hidden-layer widths (list, tuple, ...). An empty
            iterable yields no hidden layers. The default list is only read,
            never mutated.
        output_dim: if not None, a final nn.Linear projecting the last width
            (or input_dim when there are no hidden layers) to output_dim is appended.
        dropout: dropout probability used after every hidden activation.

    Returns:
        nn.Sequential made of [Linear, LayerNorm, ReLU, Dropout] per hidden width,
        optionally followed by the output Linear.
    """
    # Unpacking copies the widths into a fresh list, so tuples and other
    # iterables are accepted (list + tuple concatenation would raise TypeError).
    widths = [input_dim, *hidden_dims]

    layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        layers.append(nn.Linear(fan_in, fan_out))
        # LayerNorm normalises over the feature dimension of each sample
        layers.append(nn.LayerNorm(fan_out))
        layers.append(nn.ReLU())
        layers.append(nn.Dropout(dropout))

    if output_dim is not None:
        layers.append(nn.Linear(widths[-1], output_dim))
    return nn.Sequential(*layers)


class RecipientFinder(nn.Module):
    """
    Per-parent binary HGT classifier.

    The aggregated vectors of a parent's two children are combined with their
    element-wise absolute difference, their element-wise product and a block of
    auxiliary features; the concatenation goes through an MLP trunk and a linear
    head that emits one raw logit (no sigmoid) per input row.
    """
    def __init__(self, seq_vec_dim, aux_dim=3, hidden=[32], dropout=0.1):
        super().__init__()
        self.seq_vec_dim = seq_vec_dim
        self.aux_dim = aux_dim
        # Trunk input = four D-wide pieces (left, right, |left-right|, left*right) + aux features.
        trunk_in = 4 * seq_vec_dim + aux_dim
        self.mlp = make_mlp(trunk_in, hidden_dims=hidden, output_dim=None, dropout=dropout)
        # The head reads the last hidden width and produces a single logit.
        self.hgt_head = nn.Linear(hidden[-1], 1)

    def forward(self, left_vec, right_vec, aux_feats):
        """
        Score each (left, right) child pair.

        left_vec / right_vec are concatenated with aux_feats along dim=1, so all
        three must be at least 2-D and agree on the leading dimension
        (intended as [B, D], [B, D] and [B, aux_dim] -- TODO confirm with callers).

        Returns the head output with its trailing singleton dimension squeezed.
        """
        pair_features = (
            left_vec,
            right_vec,
            (left_vec - right_vec).abs(),
            left_vec * right_vec,
            aux_feats,
        )
        hidden_repr = self.mlp(torch.cat(pair_features, dim=1))
        return self.hgt_head(hidden_repr).squeeze(-1)


class DonorFinder(nn.Module):
    """
    Procedural (parameter-free) donor matcher.

    For each of a parent's two daughters it looks at every other node whose
    node_time is strictly greater than the daughter's node_time and picks the one
    whose scalar feature sum is closest (absolute difference) to the daughter's
    own sum. The daughter with the smaller best difference is reported as the
    recipient, together with the matched node and the difference.
    """
    def __init__(self):
        # No learnable parameters; only the nn.Module bookkeeping is initialised.
        super().__init__()

    @staticmethod
    def _closest_sum_match(G, node_times, node_seq_sums, daughter, daughter_time):
        """Return (node, abs_diff) of the candidate whose sum is nearest to the daughter's, or (None, inf)."""
        pool = [n for n in G.nodes if n != daughter and node_times[n] > daughter_time]
        if not pool:
            return None, float('inf')
        target = float(node_seq_sums[daughter].item())
        # Node ids are used directly as row indices into node_seq_sums.
        gaps = torch.abs(node_seq_sums[torch.tensor(pool, dtype=torch.long)] - target)
        pos = torch.argmin(gaps).item()
        return pool[pos], float(gaps[pos].item())

    @staticmethod
    def find_donor_and_which(agg_tensor, G, parent, left, right, node_seq_sums=None):
        """
        Decide which daughter is the predicted recipient and which node is its donor.

        Args:
            agg_tensor: tensor whose rows are summed over dim=1 when node_seq_sums
                is not given; unused otherwise.
            G: graph object exposing `nodes` (iterable of ids, subscriptable to an
                attribute dict with an optional 'node_time' entry, default 0.0).
            parent: accepted for interface symmetry; not read by this method.
            left, right: ids of the two daughters; used as indices into the sums.
            node_seq_sums: optional precomputed 1-D tensor of per-node sums.

        Returns:
            (which, donor_parent, donor_score): which is 0 for left and 1 for right
            (left wins ties), donor_parent is the matched node id or None when the
            chosen daughter had no candidates, donor_score is the absolute sum
            difference (inf when there were no candidates).
        """
        if node_seq_sums is None:
            node_seq_sums = agg_tensor.sum(dim=1)

        t_left = float(G.nodes[left].get('node_time', 0.0))
        t_right = float(G.nodes[right].get('node_time', 0.0))
        node_times = {n: float(G.nodes[n].get('node_time', 0.0)) for n in G.nodes}

        left_match = DonorFinder._closest_sum_match(G, node_times, node_seq_sums, left, t_left)
        right_match = DonorFinder._closest_sum_match(G, node_times, node_seq_sums, right, t_right)

        # Smaller best difference wins; `<=` keeps the left daughter on ties.
        left_wins = left_match[1] <= right_match[1]
        which = 0 if left_wins else 1
        donor_parent, donor_score = left_match if left_wins else right_match
        return which, donor_parent, donor_score


class HGTDetector(nn.Module):
    """
    Top-level module: owns the learnable RecipientFinder and offers
    compute_aggregates for building per-node subtree feature vectors.
    The threshold and top-k values are stored as plain attributes.
    """
    def __init__(self, seq_vec_dim, aux_recipient_dim=3, hgt_threshold=0.5, topk_donors=1):
        super().__init__()
        self.hgt_threshold = hgt_threshold
        self.topk_donors = topk_donors
        self.recipient_finder = RecipientFinder(seq_vec_dim, aux_dim=aux_recipient_dim, hidden=[32])

    def compute_aggregates(self, G: nx.DiGraph, node_seq_matrix: torch.Tensor, gene_presence_mask: torch.Tensor = None):
        """
        Build one aggregated feature row per node of G.

        Rows are laid out in list(G.nodes) order. Every node that has successors
        gets the element-wise maximum of its children's rows, computed children
        first (reverse topological order), so its own input row is overwritten.

        Args:
            G: directed graph with parent -> child edges.
            node_seq_matrix: 2-D tensor [rows, D].
                - rows == number of nodes: cloned as-is, i.e. row i is taken to
                  belong to the i-th node of list(G.nodes).
                - rows > number of nodes: rows are looked up by node id for nodes
                  whose id is an `int` smaller than `rows`; all other nodes start
                  from zeros.
                - rows < number of nodes: RuntimeError.
            gene_presence_mask: optional tensor; applied as a per-row multiplier
                only when its first dimension equals the number of nodes,
                otherwise it is ignored (all rows kept).

        Returns:
            (agg, node_to_idx): the aggregated tensor and the dict mapping each
            node id to its row position in agg.

        NOTE(review): the result is a 2-tuple, not a bare tensor. Call sites that
        only need the tensor have to unpack it (e.g. `agg, _ = ...`) before
        indexing or reducing.
        """
        ordered_nodes = list(G.nodes)
        n_nodes = len(ordered_nodes)
        n_rows, feat_dim = node_seq_matrix.shape
        dev = node_seq_matrix.device
        node_to_idx = {node: pos for pos, node in enumerate(ordered_nodes)}

        # Seed the per-node rows from the input matrix.
        if n_rows < n_nodes:
            raise RuntimeError("node_seq_matrix rows != number of nodes; please supply node-ordered matrix or extend it.")
        if n_rows == n_nodes:
            agg = node_seq_matrix.clone().to(dev)
        else:
            # More rows than nodes: copy by node id where that is possible; the rest stay zero.
            agg = torch.zeros((n_nodes, feat_dim), dtype=node_seq_matrix.dtype, device=dev)
            for node, pos in node_to_idx.items():
                if isinstance(node, int) and node < n_rows:
                    agg[pos] = node_seq_matrix[node].to(dev)

        # Per-row multiplier: the supplied mask when it lines up with the nodes, else all ones.
        if gene_presence_mask is not None and gene_presence_mask.shape[0] == n_nodes:
            row_weights = gene_presence_mask.to(dev).to(dtype=torch.float32)
        else:
            row_weights = torch.ones(n_nodes, dtype=torch.float32, device=dev)
        agg = agg * row_weights.unsqueeze(1)

        # Walk the DAG children-first so each parent sees finished child rows.
        for node in reversed(list(nx.topological_sort(G))):
            kids = list(G.successors(node))
            if not kids:
                continue
            stacked = torch.stack([agg[node_to_idx[k]] for k in kids], dim=0)  # [C, D]
            agg[node_to_idx[node]] = stacked.max(dim=0).values

        return agg, node_to_idx


# -------------------------
# Utilities for training
# -------------------------
def tree_distance(G: nx.DiGraph, node_a, node_b, max_penalty=None):
    """
    Time-based path length between node_a and node_b through their closest
    common ancestor.

    For every common ancestor c (each node counts as its own ancestor) the
    candidate length is

        |t(c) - t(a)| + |t(c) - t(b)|

    and the smallest candidate is returned. When node_time changes monotonically
    along every root-to-leaf path the minimiser is the MRCA, and if times grow
    towards the root the value equals 2 * t_MRCA - t(a) - t(b).

    Why not "common ancestor with maximum node_time" combined with
    2 * t - t(a) - t(b): if times grow towards the root that picks the root
    instead of the MRCA, and if times grow towards the leaves the expression is
    never positive (it would clamp to 0). The formulation above does not depend
    on the direction in which node_time runs.

    Args:
        G: directed graph with parent -> child edges and a 'node_time' node attribute.
        node_a, node_b: node ids present in G.
        max_penalty: accepted for call-site compatibility; not used. Failures are
            raised, not converted into a penalty value.

    Returns:
        Non-negative float distance.

    Raises:
        RuntimeError: "no common ancestor" if the two ancestor sets are disjoint;
            "invalid times" if node_a, node_b, or every common ancestor lacks a
            finite node_time.
        Errors raised by nx.ancestors (e.g. for a node that is not in G) propagate.
    """
    def node_time(n):
        return float(G.nodes[n].get('node_time', float('-inf')))

    # nx.ancestors excludes the node itself, so add it back explicitly.
    common = (nx.ancestors(G, node_a) | {node_a}) & (nx.ancestors(G, node_b) | {node_b})
    if not common:
        raise RuntimeError("no common ancestor")

    t_a = node_time(node_a)
    t_b = node_time(node_b)
    if not (math.isfinite(t_a) and math.isfinite(t_b)):
        raise RuntimeError("invalid times")

    # Shortest time-path over all common ancestors that carry a usable time.
    best = None
    for anc in common:
        t_anc = node_time(anc)
        if not math.isfinite(t_anc):
            continue
        span = abs(t_anc - t_a) + abs(t_anc - t_b)
        if best is None or span < best:
            best = span

    if best is None:
        raise RuntimeError("invalid times")
    return float(best)


# -------------------------
# Training loop (modified to apply structure corrections on a copy)
# -------------------------
def train_hgt_detector(model: HGTDetector, list_of_Data, epochs=10, val_frac=0.1, lr=1e-3,
                       lambda_donor=0.1, hgt_threshold=0.5, max_candidates=300,
                       device=None, clip_grad=5.0):
    """
    Training loop that:
      - processes each internal parent node in time order (bottom-up)
      - runs RecipientFinder (binary)
      - if recipient predicted (p >= threshold) -> run DonorFinder to get which and donor_parent
      - if recipient prediction matches ground truth and donor_parent matches ground truth:
          * perform local structure correction on a deepcopy of G
          * recompute aggregates for the modified graph
          * continue processing from the parent node (the loop continues sequentially)
      - computes losses:
          - rec_loss: BCEWithLogitsLoss over recipient predictions (batched per graph)
          - if recipient correctly predicted:
              + add penalty if daughter-edge (which) is wrong (0 if correct, 1 if wrong)
              + add (normalized) distance to true donor (time-based)
        The donor components are scaled by lambda_donor (same as before).
    """

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

    N = len(list_of_Data)
    idxs = list(range(N))
    random.shuffle(idxs)
    n_val = max(1, int(val_frac * N))
    val_idx = set(idxs[:n_val])
    train_idx = idxs[n_val:]

    # Estimate pos_weight for BCEWithLogitsLoss on training set (clamped)
    pos_count = 0
    neg_count = 0
    for i in train_idx:
        d = list_of_Data[i]
        recip_map = d.recip_label_map
        for p, (lbl, _) in recip_map.items():
            if lbl == 1:
                pos_count += 1
            else:
                neg_count += 1
    pos_count = max(pos_count, 1)
    neg_count = max(neg_count, 1)
    ratio = neg_count / pos_count
    pos_weight_val = max(1.0, min(ratio, 50.0))  # clamp between 1 and 50
    pos_weight = torch.tensor([pos_weight_val], device=device)
    bce_loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    print(f"Training on {len(train_idx)} graphs | Validation on {len(val_idx)} graphs | pos_weight={pos_weight_val:.3f}")

    for epoch in range(1, epochs+1):
        t0 = time.time()
        model.train()
        train_loss = 0.0
        train_rec_loss = 0.0
        train_don_loss = 0.0
        train_samples = 0

        # metrics
        tp = 0; fp = 0; fn = 0
        donor_expected_distance_accum = 0.0
        donor_expected_count = 0

        random.shuffle(train_idx)
        for idx in train_idx:
            data = list_of_Data[idx]
            # move seqs to device once
            # Assumption: data.hot_encoded_nucleotide_sequences rows correspond to node ids 0..M-1
            data.hot_encoded_nucleotide_sequences = data.hot_encoded_nucleotide_sequences.to(device)

            # optional gene presence mask (1D length M) that indicates which leaves count
            gene_presence_mask = getattr(data, 'gene_absence_presence_matrix', None)
            if gene_presence_mask is not None:
                # ensure it's a tensor on device
                if not isinstance(gene_presence_mask, torch.Tensor):
                    gene_presence_mask = torch.tensor(gene_presence_mask, dtype=torch.float32, device=device)
                else:
                    gene_presence_mask = gene_presence_mask.to(device)

            recip_map = data.recip_label_map
            donor_map = data.donor_map

            # We will iterate parents in increasing node_time order (bottom-up)
            parents = [n for n in data.G.nodes if len(list(data.G.successors(n))) == 2]
            if len(parents) == 0:
                continue

            # sort parents by node_time ascending (so we process earliest events first)
            parents_sorted = sorted(parents, key=lambda n: float(data.G.nodes[n].get('node_time', 0.0)))

            # compute aggregates for the original graph (may be recomputed later if we modify structure)
            agg = model.compute_aggregates(data.G, data.hot_encoded_nucleotide_sequences, gene_presence_mask)  # [N,D]
            # precompute scalar sums per node for DonorFinder
            node_seq_sums = agg.sum(dim=1)  # [N]

            # We'll process parents sequentially; collect losses per graph
            graph_rec_logits = []
            graph_rec_labels = []
            graph_donor_penalties = []  # penalties for daughter-edge miss (0/1)
            graph_donor_distances = []  # normalized distances

            # For reproducibility of structure edits, we won't mutate original G. When needed we create a modified copy.
            G_current = data.G  # original graph object (not mutated)
            agg_current = agg
            node_sums_current = node_seq_sums
            node_seq_matrix_current = data.hot_encoded_nucleotide_sequences  # refer to original (will extend when editing)

            # We'll keep an index map telling how many nodes exist currently (to append new nodes if needed)
            current_max_node_id = max([n for n in data.G.nodes]) if len(data.G.nodes) > 0 else -1

            # iterate parents in time order
            for parent in parents_sorted:
                # children (two)
                children = list(G_current.successors(parent))
                if len(children) != 2:
                    # structure may have been altered previously; skip if not binary anymore
                    continue
                left, right = children[0], children[1]

                # build aux features for this parent
                node_time = float(G_current.nodes[parent].get('node_time', 0.0))
                node_level = float(G_current.nodes[parent].get('level', 0.0))
                parent_sum = agg_current[parent].sum().clamp(min=1.0)
                gene_frac = ((agg_current[left].sum() + agg_current[right].sum()) / parent_sum).unsqueeze(0).unsqueeze(1)
                aux = torch.tensor([[node_time, node_level, float(gene_frac.item())]], device=device, dtype=torch.float32)

                # left/right vectors
                left_vec = agg_current[left].unsqueeze(0).to(device)   # [1, D]
                right_vec = agg_current[right].unsqueeze(0).to(device) # [1, D]

                # Recipient forward (single sample)
                hgt_logit = model.recipient_finder(left_vec, right_vec, aux)  # [1]
                graph_rec_logits.append(hgt_logit.squeeze(0))
                lbl, true_rec_child = recip_map.get(parent, (0, None))
                graph_rec_labels.append(float(lbl))

                prob = torch.sigmoid(hgt_logit).item()
                rec_pred = 1 if prob >= hgt_threshold else 0

                # only attempt donor-finding & potential structure correction if model predicts HGT
                # and ground truth label indicates HGT (we will still compute donors to compute loss terms)
                if rec_pred == 1:
                    # DonorFinder decides which daughter is recipient and predicts donor_parent
                    which_pred, donor_parent_pred, donor_score = DonorFinder.find_donor_and_which(
                        agg_current, G_current, parent, left, right, node_seq_sums=node_sums_current
                    )
                    # map which_pred to predicted recipient child id
                    pred_rec_child = children[which_pred]

                    # Now compute donor-related losses/metrics based on ground truth
                    if lbl == 1:
                        # ground truth recipient child and donor parent
                        true_rec_child = true_rec_child  # from recip_map
                        true_donor_parent = donor_map.get(parent, None)  # may be None

                        # 1) daughter-edge correctness penalty: 0 if predicted recipient child == true_rec_child else 1
                        daughter_edge_penalty = 0.0
                        if true_rec_child is not None:
                            if pred_rec_child != true_rec_child:
                                daughter_edge_penalty = 1.0
                        else:
                            # if no true_rec_child provided, we don't penalize (set penalty 0)
                            daughter_edge_penalty = 0.0
                        graph_donor_penalties.append(daughter_edge_penalty)

                        # 2) donor distance: compute time-based distance between predicted donor_parent and true_donor_parent
                        if (true_donor_parent is None) or (donor_parent_pred is None):
                            normalized_dist = 1.0  # maximal penalty if missing
                        else:
                            # time-based distance using tree_distance but using G_current
                            max_pen = len(G_current.nodes)
                            raw_dist = tree_distance(G_current, donor_parent_pred, true_donor_parent, max_penalty=max_pen)
                            # normalize by max_pen to keep in [0,1]
                            normalized_dist = float(raw_dist) / float(max_pen) if max_pen > 0 else float(raw_dist)
                        graph_donor_distances.append(normalized_dist)

                        # Accumulate reporting
                        donor_expected_distance_accum += normalized_dist * float(len(G_current.nodes))
                        donor_expected_count += 1

                        # If both recipient daughter and donor parent are correctly identified,
                        # perform structure correction on a deepcopy of G (so original G remains unchanged for next epoch)
                        if (pred_rec_child == true_rec_child) and (donor_parent_pred == true_donor_parent) and (donor_parent_pred is not None):
                            # perform structure correction on a COPY of the graph
                            G_mod = copy.deepcopy(G_current)
                            # extend node_seq_matrix_current by a zero-row to represent the new internal node
                            M_old = node_seq_matrix_current.shape[0]
                            D = node_seq_matrix_current.shape[1]
                            new_row = torch.zeros((1, D), dtype=node_seq_matrix_current.dtype, device=node_seq_matrix_current.device)
                            node_seq_matrix_current = torch.cat([node_seq_matrix_current, new_row], dim=0)
                            # new node id
                            current_max_node_id += 1
                            new_node = current_max_node_id
                            # compute new_node time as per your formula:
                            rec_child_time = float(G_mod.nodes[pred_rec_child].get('node_time', 0.0))
                            donor_parent_time = float(G_mod.nodes[donor_parent_pred].get('node_time', 0.0))
                            # donor child of the donor edge: pick the child of donor_parent that is on the donor edge to be split.
                            donor_children = list(G_mod.successors(donor_parent_pred))
                            # if donor_parent_pred had multiple children, we need to choose the child that is on the donor edge.
                            # For safety, pick the child with smallest difference in aggregate sums to the recipient sum.
                            if len(donor_children) == 0:
                                # cannot split; skip structural edit
                                pass
                            else:
                                # pick donor_child as the one whose subtree sum is closest to pred_rec_child subtree sum
                                rec_sum = float(agg_current[pred_rec_child].sum().item())
                                cand_sums = [(c, float(agg_current[c].sum().item())) for c in donor_children if c < node_seq_sums.shape[0]]
                                if len(cand_sums) == 0:
                                    donor_child = donor_children[0]
                                else:
                                    donor_child = min(cand_sums, key=lambda x: abs(x[1] - rec_sum))[0]

                                donor_child_time = float(G_mod.nodes[donor_child].get('node_time', 0.0))

                                t1 = rec_child_time + donor_parent_time
                                t2 = donor_child_time + donor_parent_time
                                new_time = 0.5 * max(t1, t2) * 0.5  # as per your "half of maximum of sums" (two halves)
                                # Note: you've described "half of the maximum of 1) sum(rec_child, donor_parent) OR 2) sum(donor_child, donor_parent)".
                                # Implementation sets new_time = 0.5 * max(t1, t2) * 0.5 => effectively 0.25 * max(...).
                                # To be consistent with the verbal spec (half of the max), set:
                                # new_time = 0.5 * max(t1, t2)
                                # Use that instead. (I'll set to the literal half)
                                new_time = 0.5 * max(t1, t2)

                                # add new node and set attributes
                                G_mod.add_node(new_node)
                                G_mod.nodes[new_node]['node_time'] = new_time
                                G_mod.nodes[new_node]['level'] = (G_mod.nodes[pred_rec_child].get('level', 0.0) + G_mod.nodes[donor_parent_pred].get('level', 0.0))/2.0

                                # Rewire recipient edge: remove (recipient_parent -> pred_rec_child), insert (recipient_parent -> new_node) and (new_node -> pred_rec_child)
                                try:
                                    G_mod.remove_edge(parent, pred_rec_child)
                                except Exception:
                                    # maybe edge already changed; continue gracefully
                                    pass
                                G_mod.add_edge(parent, new_node)
                                G_mod.add_edge(new_node, pred_rec_child)

                                # Split donor edge: remove (donor_parent_pred -> donor_child) and add (donor_parent_pred -> new_node) and (new_node -> donor_child)
                                try:
                                    G_mod.remove_edge(donor_parent_pred, donor_child)
                                except Exception:
                                    pass
                                G_mod.add_edge(donor_parent_pred, new_node)
                                G_mod.add_edge(new_node, donor_child)

                                # After structural change, recompute aggregates for the modified graph.
                                agg_current = model.compute_aggregates(G_mod, node_seq_matrix_current, gene_presence_mask)
                                node_sums_current = agg_current.sum(dim=1)
                                # update G_current to modified version for further upward processing
                                G_current = G_mod
                                G_test = G_mod
                                print(G_mod.nodes)
                                print(G_mod.edges)
                                # continue loop (the parent list was computed from original G; we continue on)
                                # Note: parents_sorted may not reflect the new internal node; but per spec we continue from current place.
                                # No extra insertion into parents_sorted needed.

                    else:
                        # ground truth label is 0 but model predicted 1: we can still compute donor penalties as neutral (no add)
                        # append neutral penalties to keep lists aligned
                        graph_donor_penalties.append(0.0)
                        graph_donor_distances.append(0.0)
                else:
                    # model did not predict HGT at this parent; no donor penalties
                    graph_donor_penalties.append(0.0)
                    graph_donor_distances.append(0.0)

            # Now we have per-parent lists (graph_rec_logits, graph_rec_labels, graph_donor_penalties, graph_donor_distances)
            if len(graph_rec_logits) == 0:
                continue

            rec_logits_tensor = torch.stack(graph_rec_logits, dim=0)  # [P]
            rec_labels_tensor = torch.tensor(graph_rec_labels, dtype=torch.float32, device=device)  # [P]

            rec_loss = bce_loss_fn(rec_logits_tensor, rec_labels_tensor)

            # donor losses: average over events where rec label ==1
            donor_penalty_tensor = torch.tensor(graph_donor_penalties, dtype=torch.float32, device=device)
            donor_dist_tensor = torch.tensor(graph_donor_distances, dtype=torch.float32, device=device)

            # Only include donor penalties where ground-truth recipient label == 1
            gt_mask = (rec_labels_tensor == 1.0).to(dtype=torch.float32)
            if gt_mask.sum() > 0:
                avg_daughter_penalty = (donor_penalty_tensor * gt_mask).sum() / gt_mask.sum()
                avg_donor_distance = (donor_dist_tensor * gt_mask).sum() / gt_mask.sum()
                donor_loss = avg_daughter_penalty + avg_donor_distance
            else:
                avg_daughter_penalty = torch.tensor(0.0, device=device)
                avg_donor_distance = torch.tensor(0.0, device=device)
                donor_loss = torch.tensor(0.0, device=device)

            loss = rec_loss + lambda_donor * donor_loss

            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(model.parameters(), clip_grad)
            optimizer.step()

            train_loss += float(loss.item())
            train_rec_loss += float(rec_loss.item())
            train_don_loss += float(donor_loss.item())
            train_samples += 1

            # update metrics for recipient detection (on original per-parent predictions)
            probs = torch.sigmoid(rec_logits_tensor).detach().cpu().numpy().tolist()
            for i_parent, lbl in enumerate(graph_rec_labels):
                true_lbl = int(lbl)
                pred_prob = float(probs[i_parent])
                pred_class = 1 if pred_prob >= hgt_threshold else 0
                if pred_class == 1 and true_lbl == 1:
                    tp += 1
                if pred_class == 1 and true_lbl == 0:
                    fp += 1
                if pred_class == 0 and true_lbl == 1:
                    fn += 1

        # End epoch training stats
        avg_train_loss = train_loss / max(1, train_samples)
        avg_train_rec = train_rec_loss / max(1, train_samples)
        avg_train_don = train_don_loss / max(1, train_samples)
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        avg_expected_donor_distance = (donor_expected_distance_accum / donor_expected_count) if donor_expected_count > 0 else 0.0

        # Validation pass (keeps previous logic but adapted to new RecipientFinder)
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            val_rec_loss = 0.0
            val_don_loss = 0.0
            val_samples = 0
            v_tp = v_fp = v_fn = 0
            v_donor_expected_dist_acc = 0.0
            v_donor_expected_cnt = 0

            for idx in val_idx:
                data = list_of_Data[idx]
                data.hot_encoded_nucleotide_sequences = data.hot_encoded_nucleotide_sequences.to(device)
                gene_presence_mask = getattr(data, 'gene_absence_presence_matrix', None)
                if gene_presence_mask is not None:
                    if not isinstance(gene_presence_mask, torch.Tensor):
                        gene_presence_mask = torch.tensor(gene_presence_mask, dtype=torch.float32, device=device)
                    else:
                        gene_presence_mask = gene_presence_mask.to(device)

                recip_map = data.recip_label_map
                donor_map = data.donor_map
                parents = [n for n in data.G.nodes if len(list(data.G.successors(n))) == 2]
                if len(parents) == 0:
                    continue
                parents_sorted = sorted(parents, key=lambda n: float(data.G.nodes[n].get('node_time', 0.0)))
                agg, node_to_idx = model.compute_aggregates(data.G, data.hot_encoded_nucleotide_sequences, gene_presence_mask)
                node_seq_sums = agg.sum(dim=1)

                graph_rec_logits = []
                graph_rec_labels = []
                graph_donor_penalties = []
                graph_donor_distances = []

                G_current = data.G
                agg_current = agg
                node_sums_current = node_sums
                node_seq_matrix_current = data.hot_encoded_nucleotide_sequences
                current_max_node_id = max([n for n in data.G.nodes]) if len(data.G.nodes) > 0 else -1

                for parent in parents_sorted:
                    children = list(G_current.successors(parent))
                    if len(children) != 2:
                        continue
                    left, right = children[0], children[1]
                    node_time = float(G_current.nodes[parent].get('node_time', 0.0))
                    node_level = float(G_current.nodes[parent].get('level', 0.0))
                    parent_sum = agg_current[parent].sum().clamp(min=1.0)
                    gene_frac = ((agg_current[left].sum() + agg_current[right].sum()) / parent_sum).unsqueeze(0).unsqueeze(1)
                    aux = torch.tensor([[node_time, node_level, float(gene_frac.item())]], device=device, dtype=torch.float32)
                    left_vec = agg_current[left].unsqueeze(0).to(device)
                    right_vec = agg_current[right].unsqueeze(0).to(device)

                    hgt_logit = model.recipient_finder(left_vec, right_vec, aux)
                    graph_rec_logits.append(hgt_logit.squeeze(0))
                    lbl, true_rec_child = recip_map.get(parent, (0, None))
                    graph_rec_labels.append(float(lbl))
                    prob = torch.sigmoid(hgt_logit).item()
                    rec_pred = 1 if prob >= hgt_threshold else 0

                    if rec_pred == 1:
                        which_pred, donor_parent_pred, donor_score = DonorFinder.find_donor_and_which(
                            agg_current, G_current, parent, left, right, node_seq_sums=node_sums_current
                        )
                        pred_rec_child = children[which_pred]
                        if lbl == 1:
                            true_rec_child = true_rec_child
                            true_donor_parent = donor_map.get(parent, None)
                            daughter_edge_penalty = 0.0
                            if true_rec_child is not None:
                                if pred_rec_child != true_rec_child:
                                    daughter_edge_penalty = 1.0
                            graph_donor_penalties.append(daughter_edge_penalty)
                            if (true_donor_parent is None) or (donor_parent_pred is None):
                                normalized_dist = 1.0
                            else:
                                max_pen = len(G_current.nodes)
                                raw_dist = tree_distance(G_current, donor_parent_pred, true_donor_parent, max_penalty=max_pen)
                                normalized_dist = float(raw_dist) / float(max_pen) if max_pen > 0 else float(raw_dist)
                            graph_donor_distances.append(normalized_dist)
                            v_donor_expected_dist_acc += normalized_dist * float(len(G_current.nodes))
                            v_donor_expected_cnt += 1
                            if (pred_rec_child == true_rec_child) and (donor_parent_pred == true_donor_parent) and (donor_parent_pred is not None):
                                # apply same structural edit on a deepcopy (kept local to validation)
                                G_mod = copy.deepcopy(G_current)
                                M_old = node_seq_matrix_current.shape[0]
                                D = node_seq_matrix_current.shape[1]
                                new_row = torch.zeros((1, D), dtype=node_seq_matrix_current.dtype, device=node_seq_matrix_current.device)
                                node_seq_matrix_current = torch.cat([node_seq_matrix_current, new_row], dim=0)
                                current_max_node_id += 1
                                new_node = current_max_node_id
                                rec_child_time = float(G_mod.nodes[pred_rec_child].get('node_time', 0.0))
                                donor_parent_time = float(G_mod.nodes[donor_parent_pred].get('node_time', 0.0))
                                donor_children = list(G_mod.successors(donor_parent_pred))
                                if len(donor_children) == 0:
                                    pass
                                else:
                                    rec_sum = float(agg_current[pred_rec_child].sum().item())
                                    cand_sums = [(c, float(agg_current[c].sum().item())) for c in donor_children if c < node_sums.shape[0]]
                                    if len(cand_sums) == 0:
                                        donor_child = donor_children[0]
                                    else:
                                        donor_child = min(cand_sums, key=lambda x: abs(x[1] - rec_sum))[0]
                                    donor_child_time = float(G_mod.nodes[donor_child].get('node_time', 0.0))
                                    t1 = rec_child_time + donor_parent_time
                                    t2 = donor_child_time + donor_parent_time
                                    new_time = 0.5 * max(t1, t2)
                                    G_mod.add_node(new_node)
                                    G_mod.nodes[new_node]['node_time'] = new_time
                                    G_mod.nodes[new_node]['level'] = (G_mod.nodes[pred_rec_child].get('level', 0.0) + G_mod.nodes[donor_parent_pred].get('level', 0.0))/2.0
                                    try:
                                        G_mod.remove_edge(parent, pred_rec_child)
                                    except Exception:
                                        pass
                                    G_mod.add_edge(parent, new_node)
                                    G_mod.add_edge(new_node, pred_rec_child)
                                    try:
                                        G_mod.remove_edge(donor_parent_pred, donor_child)
                                    except Exception:
                                        pass
                                    G_mod.add_edge(donor_parent_pred, new_node)
                                    G_mod.add_edge(new_node, donor_child)
                                    agg_current = model.compute_aggregates(G_mod, node_seq_matrix_current, gene_presence_mask)
                                    node_sums_current = agg_current.sum(dim=1)
                                    G_current = G_mod
                        else:
                            graph_donor_penalties.append(0.0)
                            graph_donor_distances.append(0.0)
                    else:
                        graph_donor_penalties.append(0.0)
                        graph_donor_distances.append(0.0)

                if len(graph_rec_logits) == 0:
                    continue

                rec_logits_tensor = torch.stack(graph_rec_logits, dim=0)
                rec_labels_tensor = torch.tensor(graph_rec_labels, dtype=torch.float32, device=device)
                rec_loss = bce_loss_fn(rec_logits_tensor, rec_labels_tensor)

                donor_penalty_tensor = torch.tensor(graph_donor_penalties, dtype=torch.float32, device=device)
                donor_dist_tensor = torch.tensor(graph_donor_distances, dtype=torch.float32, device=device)

                gt_mask = (rec_labels_tensor == 1.0).to(dtype=torch.float32)
                if gt_mask.sum() > 0:
                    avg_daughter_penalty = (donor_penalty_tensor * gt_mask).sum() / gt_mask.sum()
                    avg_donor_distance = (donor_dist_tensor * gt_mask).sum() / gt_mask.sum()
                    donor_loss = avg_daughter_penalty + avg_donor_distance
                else:
                    avg_daughter_penalty = 0.0
                    avg_donor_distance = 0.0
                    donor_loss = 0.0

                loss = float(rec_loss.item()) + lambda_donor * float(donor_loss)

                val_loss += loss
                val_rec_loss += float(rec_loss.item())
                val_don_loss += float(donor_loss)
                val_samples += 1

                probs = torch.sigmoid(rec_logits_tensor).cpu().numpy().tolist()
                for i_parent, lbl in enumerate(graph_rec_labels):
                    true_lbl = int(lbl)
                    pred_prob = float(probs[i_parent])
                    pred_class = 1 if pred_prob >= hgt_threshold else 0
                    if pred_class == 1 and true_lbl == 1:
                        v_tp += 1
                    if pred_class == 1 and true_lbl == 0:
                        v_fp += 1
                    if pred_class == 0 and true_lbl == 1:
                        v_fn += 1

            avg_val_loss = val_loss / max(1, val_samples)
            avg_val_rec = val_rec_loss / max(1, val_samples)
            avg_val_don = val_don_loss / max(1, val_samples)
            v_precision = v_tp / (v_tp + v_fp) if (v_tp + v_fp) > 0 else 0.0
            v_recall = v_tp / (v_tp + v_fn) if (v_tp + v_fn) > 0 else 0.0
            v_f1 = 2 * v_precision * v_recall / (v_precision + v_recall) if (v_precision + v_recall) > 0 else 0.0
            avg_v_donor_expected = (v_donor_expected_dist_acc / v_donor_expected_cnt) if v_donor_expected_cnt > 0 else 0.0

        # scheduler step (use validation loss)
        scheduler.step(avg_val_loss)

        t1 = time.time()
        print(f"Epoch {epoch}/{epochs} | time {t1-t0:.1f}s")
        print(f" TRAIN loss {avg_train_loss:.6f} (rec {avg_train_rec:.6f}, don {avg_train_don:.6f}) "
              f"Prec/Rec/F1 {precision:.3f}/{recall:.3f}/{f1:.3f} | avg_expected_donor_dist {avg_expected_donor_distance:.3f}")
        print(f" VAL   loss {avg_val_loss:.6f} (rec {avg_val_rec:.6f}, don {avg_val_don:.6f}) "
              f"Prec/Rec/F1 {v_precision:.3f}/{v_recall:.3f}/{v_f1:.3f} | avg_expected_donor_dist {avg_v_donor_expected:.3f}")
        print("-" * 120)

    return model


# -------------------------
# Example run: fix the RNG seeds, build the detector, launch training
# -------------------------
# Seed Python's `random` and PyTorch before anything below draws from them.
random.seed(42)
torch.manual_seed(42)

# Prerequisite: `list_of_Data` must already exist in the notebook namespace.
# Every entry is expected to provide:
#   - G: networkx.DiGraph
#   - hot_encoded_nucleotide_sequences: torch.Tensor [M_orig, D]
#   - recip_label_map: dict parent -> (label:int, true_rec_child:int or None)
#   - donor_map: dict parent -> true_donor_parent (or None)
#   - optional: gene_absence_presence_matrix: 1D mask aligned with hot_encoded_nucleotide_sequences rows
#
# Node ids of the original (unmodified) tree in G should be the integers 0..M_orig-1,
# with the rows of hot_encoded_nucleotide_sequences in that same order; if they are not,
# an explicit node-to-row mapping has to be added.

# The sequence-vector width D is read off one randomly drawn graph. The draw happens right
# after seeding and consumes state from `random`, so its position in this cell matters for
# any later use of that generator.
sample_data = random.choice(list_of_Data)
seq_vec_dim = sample_data.hot_encoded_nucleotide_sequences.shape[1]
model = HGTDetector(
    seq_vec_dim=seq_vec_dim,
    hgt_threshold=0.5,
    topk_donors=1,
)

trained_model = train_hgt_detector(
    model,
    list_of_Data,
    epochs=4,
    val_frac=0.1,
    lr=1e-3,
    lambda_donor=0.1,
    hgt_threshold=0.5,
    max_candidates=300,
    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
    clip_grad=5.0,
)


Training on 90 graphs | Validation on 10 graphs | pos_weight=13.400
Epoch 1/4 | time 0.3s
 TRAIN loss 1.845382 (rec 1.845382, don 0.000000) Prec/Rec/F1 0.000/0.000/0.000 | avg_expected_donor_dist 0.000
 VAL   loss 1.166914 (rec 1.166914, don 0.000000) Prec/Rec/F1 0.000/0.000/0.000 | avg_expected_donor_dist 0.000
------------------------------------------------------------------------------------------------------------------------
[0, 1, 2, 3, 4, 8, 7, 6, 5, 9]
[(8, 7), (8, 9), (7, 2), (7, 5), (6, 0), (6, 1), (5, 4), (5, 3), (9, 6)]
[(5, 3), (5, 4), (8, 0), (8, 9), (7, 6), (7, 9), (6, 2), (6, 5), (9, 1), (9, 7)]
[0, 1, 2, 3, 4, 5, 8, 7, 6, 9]


NetworkXUnfeasible: Graph contains a cycle or graph changed during iteration

In [18]:
# NOTE(review): `G_mod` appears to be assigned only inside the training function
# (e.g. `G_mod = copy.deepcopy(G_current)` in its validation loop), i.e. as a local variable.
# It is therefore not bound at notebook scope, and evaluating it here raises NameError.
# To inspect the modified graph after a failure, it would have to be returned or stored by the
# training code, or examined with a post-mortem debugger — TODO confirm intended workflow.
G_mod

NameError: name 'G_mod' is not defined