In [281]:
import os
import random
import h5py
import pickle
import torch
import networkx as nx
from torch_geometric.data import Data
import numpy as np

def one_hot_encode(sequences, gene_present, gene_length, alphabet=['A','C','T','G','-']):
    """
    sequences: List of strings (DNA sequences)
    gene_present: np.array(bool) oder Torch Tensor, gleiche Länge wie sequences
    gene_length: int, fixe Länge für das Hot-Encoding
    alphabet: list, Zeichenalphabet
    """
    num_samples = len(sequences)
    num_chars = len(alphabet)
    char_to_idx = {c:i for i,c in enumerate(alphabet)}
    sequences_str = [s.decode('utf-8') for s in sequences]
    gene_present = np.array(gene_present, dtype=bool)
    
    # 1️⃣ Leere Batch-Matrix vorbereiten: (num_samples, gene_length, num_chars)
    batch = np.zeros((num_samples, gene_length, num_chars), dtype=np.float32)
    
    # 2️⃣ Hot-Encode alle Sequenzen
    for i, seq in enumerate(sequences_str):
        if gene_present[i]:
            L = min(len(seq), gene_length)  # abschneiden
            for j, c in enumerate(seq[:L]):
                if c in char_to_idx:
                    batch[i, j, char_to_idx[c]] = 1.0
        elif gene_present[i] == 0:
            batch[i, :, :] = -1.0
            
    # 3️⃣ Zufällige, aber konsistente Spaltenpermutation
    perm = np.random.permutation(gene_length)
    batch = batch[:, perm, :]
    
    # 4️⃣ Optional: Flatten zu Vektor (num_samples, gene_length*num_chars)
    batch_flat = batch.reshape(num_samples, -1)
    
    return torch.tensor(batch_flat)  # shape: (num_samples, gene_length*num_chars)

def load_file(file):

    gene_length = 300
    #nucleotide_mutation_rate = 0.1
    
    with h5py.File(file, "r") as f:
            grp = f["results"]
            # Load graph_properties (pickle stored in dataset)
            graph_properties = pickle.loads(grp["graph_properties"][()])
    
            # Unpack graph properties
            nodes = torch.tensor(graph_properties[0])                # [num_nodes]
            edges = torch.tensor(graph_properties[1], dtype=torch.long)  # [2, num_edges]
            coords = torch.tensor(graph_properties[2].T)             # [2, num_nodes]
    
            # Load datasets instead of attrs
            gene_absence_presence_matrix = grp["gene_absence_presence_matrix"][()]
            nucleotide_sequences = grp["nucleotide_sequences"][()]
            #children_gene_nodes_loss_events = grp["children_gene_nodes_loss_events"][()]
        
            # Load HGT events (simplified)
            hgt_events = {}
            hgt_grp_simpl = grp["nodes_hgt_events_simplified"]
            for site_id in hgt_grp_simpl.keys():
                hgt_events[int(site_id)] = hgt_grp_simpl[site_id][()]

            hot_encoded_nucleotide_sequences = one_hot_encode(nucleotide_sequences, gene_absence_presence_matrix, gene_length)

            # Fill the remaining nodes with zeros.
            pad_rows = len(nodes) - len(nucleotide_sequences)
            pad = torch.zeros((pad_rows, hot_encoded_nucleotide_sequences.shape[1]), dtype=hot_encoded_nucleotide_sequences.dtype)
            hot_encoded_nucleotide_sequences = torch.cat([hot_encoded_nucleotide_sequences, pad], dim=0)
        
            G = nx.DiGraph()
            
            # Füge Knoten hinzu (optional mit Koordinaten als Attribut)
            for i, node_id in enumerate(nodes.tolist()):
                G.add_node(node_id, node_time = coords[:, i].tolist()[5])
            
            # Füge Kanten hinzu
            edge_list = edges.tolist()
            for src, dst in zip(edge_list[0], edge_list[1]):
                G.add_edge(src, dst)
            
            # Collect all recipient_parent_nodes from all sites
            recipient_parent_nodes = set()
            for site_id in hgt_grp_simpl.keys():
                arr = hgt_grp_simpl[site_id][()]  # load dataset as numpy structured array
                recipient_parent_nodes.update(arr["recipient_child_node"].tolist())
            
            # Build theta_gains: 1 if node is in recipient_parent_nodes, else 0
            theta_gains = torch.tensor(
                [1 if node in recipient_parent_nodes else 0 for node in range(len(G.nodes))],
                dtype=torch.long
            )

            level = {n: 0 for n in G.nodes}  # Leaves haben Level 0
            
            # 3. Topologische Sortierung (damit Kinder vor Eltern behandelt werden)
            for node in reversed(list(nx.topological_sort(G))):
                successors = list(G.successors(node))
                if successors:
                    level[node] = 1 + max(level[s] for s in successors)
            
            # 4. Level als Attribut setzen
            nx.set_node_attributes(G, level, "level")

            ### Add candidate egdes, i.e. potential hgt edges:
        
            # Hole alle Knoten und ihre Zeiten
            node_times = {n: G.nodes[n]['node_time'] for n in G.nodes}
            sorted_nodes = sorted(node_times.keys(), key=lambda n: node_times[n])
            
            # Füge Kanten hinzu
            existing_edges = set(zip(edges[0].tolist(), edges[1].tolist()))
            candidate_edges = []
            for i, src in enumerate(sorted_nodes):
                t_src = node_times[src]
                for dst in sorted_nodes[i+1:]:  # nur spätere Knoten
                    t_dst = node_times[dst]
                    if t_dst > t_src:
                        if (dst, src) not in existing_edges:  # vermeidet doppelte
                            #G.add_edge(src, dst)
                            candidate_edges.append((dst, src))
            if candidate_edges:  
                candidate_edges = torch.tensor(candidate_edges, dtype=torch.long).t()
            else:
                candidate_edges = torch.empty((2,0), dtype=torch.long)
                    
            data = Data(
                nucleotide_sequences = nucleotide_sequences,
                hot_encoded_nucleotide_sequences = hot_encoded_nucleotide_sequences,       # Node Features [num_nodes, 2]
                edge_index = edges[[1, 0], :],        # Edge Index [2, num_edges]
                candidate_edges = candidate_edges[[1, 0], :],
                y = theta_gains,            # Labels [num_nodes]
                file = file,
                G = G,
                recipient_parent_nodes = recipient_parent_nodes,
                gene_absence_presence_matrix = gene_absence_presence_matrix
            )

    return data


# Pfad zum Windows-Ordner (aus Linux/WSL-Sicht)
folder = "/mnt/c/Users/uhewm/Desktop/ProjectHGT/simulation_chunks"

# Alle Dateien im Ordner (nur reguläre Dateien, keine Unterordner)
files = [f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]

file = random.choice(files)
data = load_file(os.path.join(folder, file))


print(data.nucleotide_sequences)
print(data.hot_encoded_nucleotide_sequences)



[b'-------------------------------------------------------------------------------------------------------------------------------------'
 b'CCGC---C-CAGC-G-CGG-CAATTC-ACGGGTG-AG-CCT-C-GGATT-GTTG-CGGTTTGCG-CATTAATAAGC--TC--CGATGGTTGTG--AAT--GACCATG-TACTCC-CTCTGCC-GTCTGACTAT'
 b'CCGACT---TCGACGC--GAGAACTCTACG-ACGTGGACTA-C-G-AGTACTTA-GGATTTGCC-CCTC-A-AAAC--TCC-CGTT-CGTGACC-AATG--TGGC-C-TAAT-CT--CTGCCTGTCTGATCTA'
 b'CACC----GCA-CGGCC-CAGAACGATTCCGACC-AA-GTTACCGGGGGAG-TG-GGAATTGCGATCCC-ATAAACATCC-GACACGCGTGTGCAATTG-GTCCTTC-TTCTCC-C-CGGAC-AACTACTCTT'
 b'CACC-----CA-CGGCC-CAGAACGATACGGACC-AA-GTT---AGGGGAG-TGGGTAATTGCGATCCC-ATAAACATCC-GCCACGCGCATGCATTCG-GTCCTTCA-TCCCA-C-CCAAT-AAGTGCTCTT']
tensor([[-1., -1., -1.,  ..., -1., -1., -1.],
        [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
        [ 1.,  0.,  0.,  ...,  0.,  0.,  0.],
        ...,
        [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  ...,  0.,  0.,  0.]])


  edges = torch.tensor(graph_properties[1], dtype=torch.long)  # [2, num_edges]
  coords = torch.tensor(graph_properties[2].T)             # [2, num_nodes]


In [231]:
torch.tensor([1 if node in {4, 6} else 0 for node in range(10)],
                dtype=torch.long
            )

tensor([0, 0, 0, 0, 1, 0, 1, 0, 0, 0])

In [238]:
file = random.choice(files)
data = load_file(os.path.join(folder, file))

[(7, 3, 3, 7, 5)]
[3]
[0, 1, 2, 3, 4, 8, 7, 6, 5]


  edges = torch.tensor(graph_properties[1], dtype=torch.long)  # [2, num_edges]
  coords = torch.tensor(graph_properties[2].T)             # [2, num_nodes]


In [287]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax
import math

# ======================

# 1. DAG Positional Encoding

# ======================

class DAGPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding based on node level (depth)."""
    def __init__(self, d_model, max_depth=200):
        super().__init__()
        self.d_model = d_model
        self.max_depth = max_depth
        pe = torch.zeros(max_depth, d_model)
        position = torch.arange(0, max_depth, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    
    def forward(self, node_levels):
        """Return positional encodings for node levels (shape: [num_nodes, d_model])."""
        node_levels = node_levels.clamp(0, self.max_depth - 1)
        return self.pe[node_levels]


# ======================

# 2. Graph Attention Layer with learnable edge weights

# ======================

class WeightedGraphAttention(MessagePassing):
    def __init__(self, in_channels, out_channels, heads=1):
        super().__init__(aggr='add')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads

        self.W = nn.Linear(in_channels, heads * out_channels, bias=False)
        self.attn_src = nn.Parameter(torch.Tensor(heads * out_channels))
        self.attn_dst = nn.Parameter(torch.Tensor(heads * out_channels))
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.W.weight)
        nn.init.xavier_uniform_(self.attn_src.unsqueeze(0))
        nn.init.xavier_uniform_(self.attn_dst.unsqueeze(0))
    
    def forward(self, x, edge_index, edge_probs):
        # x: [N, in_channels]
        x = self.W(x)  # [N, H*C]
        return self.propagate(edge_index, x=x, edge_probs=edge_probs)

    def message(self, x_j, x_i, edge_probs, index, ptr, size_i):
        # Compute attention scores
        alpha = (x_i * self.attn_dst + x_j * self.attn_src).sum(dim=-1)
        alpha = alpha / math.sqrt(self.out_channels)
        
        # Add edge bias (log-prob)
        alpha = alpha + torch.log(edge_probs + 1e-8)
        alpha = F.leaky_relu(alpha, 0.2)
        alpha = softmax(alpha, index, ptr, num_nodes=size_i)
        
        # Scale message directly by edge probability (stronger influence)
        msg = x_j * alpha.unsqueeze(-1) * edge_probs.unsqueeze(-1)
        return msg



# ======================

# 3. DAG Transformer Network

# ======================

class DAGTransformer(nn.Module):
    """Full DAG-based network with learnable edge weights for candidate HGTs."""
    def __init__(self, in_dim, hidden_dim=128, out_dim=1, heads=4, num_layers=2, max_depth=200):
        super().__init__()
        self.pos_enc = DAGPositionalEncoding(hidden_dim, max_depth)
        self.input_proj = nn.Linear(in_dim, hidden_dim)
        self.heads = heads
        self.out_channels = hidden_dim // heads
        self.layers = nn.ModuleList([
            WeightedGraphAttention(hidden_dim, self.out_channels, heads=heads)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(hidden_dim, out_dim)
    
    def forward(self, data, edge_logits):
        x = data.hot_encoded_nucleotide_sequences
        base_edges = data.edge_index
        candidate_edges = data.candidate_edges
        G = data.G

        # Combine base edges and candidate edges
        if candidate_edges.numel() > 0:
            combined_edges = torch.cat([base_edges, candidate_edges], dim=1)
        else:
            combined_edges = base_edges

        # Convert logits to probabilities
        if candidate_edges.numel() > 0:
            edge_probs_candidates = torch.sigmoid(edge_logits)
            edge_probs = torch.cat([
                torch.ones(base_edges.shape[1], device=x.device),
                edge_probs_candidates
            ])
        else:
            edge_probs = torch.ones(base_edges.shape[1], device=x.device)

        # Project input features
        x = self.input_proj(x)

        # Add positional encoding
        node_levels = torch.tensor([G.nodes[n]['level'] for n in G.nodes], dtype=torch.long, device=x.device)
        x = x + self.pos_enc(node_levels)

        # Multi-layer attention propagation
        for layer in self.layers:
            x = layer(x, combined_edges, edge_probs)
            x = F.relu(x)

        # Output prediction per node
        out = self.fc_out(x).squeeze(-1)
        return out, edge_probs
    

# ======================

# 4. Training Example

# ======================
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1️⃣ Modell nur einmal definieren
in_dim = None
for file in files:
    data = load_file(os.path.join(folder, file))
    in_dim = data.hot_encoded_nucleotide_sequences.shape[1]
    break

model = DAGTransformer(
    in_dim=in_dim,
    hidden_dim=128,
    out_dim=1,
    heads=4,
    num_layers=2,
    max_depth=200
).to(device)

# 2️⃣ Candidate-Edge-Logits global für alle Bäume persistent speichern
global_edge_logits = {}
for file in files:
    path = os.path.join(folder, file)
    data = load_file(path)
    if data.candidate_edges.numel() > 0:
        global_edge_logits[file] = nn.Parameter(
            torch.full((data.candidate_edges.shape[1],), -3.0, device=device)
        )
    else:
        global_edge_logits[file] = None

# 3️⃣ Optimizer für Modell + alle Edge-Logits
params = list(model.parameters()) + [p for p in global_edge_logits.values() if p is not None]
optimizer = torch.optim.Adam(params, lr=2e-3, weight_decay=1e-5)

# 4️⃣ Training auf allen Bäumen
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0

    for file in files:
        path = os.path.join(folder, file)
        data = load_file(path)
        data = data.to(device)

        edge_logits = global_edge_logits[file]

        y_true = data.y.float().to(device)

        num_pos = (y_true == 1).sum()
        num_neg = (y_true == 0).sum()

        optimizer.zero_grad()
        out, edge_probs = model(data, edge_logits)

        if edge_logits is not None and data.candidate_edges.numel() > 0:
            candidate_scores = edge_probs[data.edge_index.shape[1]:]

            # ✅ Ziel: Source-Knotenlabel (nicht Zielknoten)
            target_scores = y_true[data.candidate_edges[0]]

            # pos_weight für HGT-Knoten (damit sie stärker gewichtet werden)
            weight = torch.ones_like(target_scores)
            weight[target_scores == 1] = num_neg / (num_pos + 1e-6)
            
            loss_edge = F.binary_cross_entropy(candidate_scores, target_scores, weight=weight)

            # Optional: Node-based Stabilisierung
            loss_node = F.binary_cross_entropy_with_logits(out, y_true)

            # Sparsity-Term: Bestraft zu viele aktive Edges
            loss_sparsity = 0.01 * candidate_scores.sum()

            # Gesamtverlust
            #loss = loss_edge + 0.1 * loss_node + loss_sparsity
            loss = loss_edge
        else:
            loss = F.binary_cross_entropy_with_logits(out, y_true)

        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    print(f"Epoch {epoch:03d} | Avg Loss: {epoch_loss / len(files):.4f}")
    
# 5️⃣ Nach Training: Candidate-Edge-Wahrscheinlichkeiten pro Baum ausgeben
for file in files:
    edge_logits = global_edge_logits[file]
    if edge_logits is not None:
        learned_p = torch.sigmoid(edge_logits).detach().cpu().numpy()
        print(f"File {file} | Learned candidate edge probabilities (top 20): {learned_p[:20]}")


  edges = torch.tensor(graph_properties[1], dtype=torch.long)  # [2, num_edges]
  coords = torch.tensor(graph_properties[2].T)             # [2, num_nodes]


Epoch 000 | Avg Loss: 1.1823
Epoch 001 | Avg Loss: 1.1815
Epoch 002 | Avg Loss: 1.1807
Epoch 003 | Avg Loss: 1.1800
Epoch 004 | Avg Loss: 1.1792
Epoch 005 | Avg Loss: 1.1784
Epoch 006 | Avg Loss: 1.1776
Epoch 007 | Avg Loss: 1.1768
Epoch 008 | Avg Loss: 1.1760
Epoch 009 | Avg Loss: 1.1752
Epoch 010 | Avg Loss: 1.1744
Epoch 011 | Avg Loss: 1.1736
Epoch 012 | Avg Loss: 1.1728
Epoch 013 | Avg Loss: 1.1720
Epoch 014 | Avg Loss: 1.1712
Epoch 015 | Avg Loss: 1.1704
Epoch 016 | Avg Loss: 1.1696
Epoch 017 | Avg Loss: 1.1688
Epoch 018 | Avg Loss: 1.1680
Epoch 019 | Avg Loss: 1.1672
Epoch 020 | Avg Loss: 1.1664
Epoch 021 | Avg Loss: 1.1656
Epoch 022 | Avg Loss: 1.1648
Epoch 023 | Avg Loss: 1.1641
Epoch 024 | Avg Loss: 1.1633
Epoch 025 | Avg Loss: 1.1625
Epoch 026 | Avg Loss: 1.1617
Epoch 027 | Avg Loss: 1.1609
Epoch 028 | Avg Loss: 1.1601
Epoch 029 | Avg Loss: 1.1593
Epoch 030 | Avg Loss: 1.1585
Epoch 031 | Avg Loss: 1.1577
Epoch 032 | Avg Loss: 1.1569
Epoch 033 | Avg Loss: 1.1562
Epoch 034 | Av

In [147]:
data

Data(edge_index=[2, 8], y=[9], nucleotide_sequences=[5], hot_encoded_nucleotide_sequences=[9, 5000], candidate_edges=[2, 18], file='/mnt/c/Users/uhewm/Desktop/ProjectHGT/simulation_chunks/simulation_10.h5', G=DiGraph with 9 nodes and 8 edges, recipient_parent_nodes=[0], gene_absence_presence_matrix=[5])

In [79]:
len(G.nucleotide_sequences[4])*5

635

In [89]:
print(G.gene_absence_presence_matrix)
sum(G.hot_encoded_nucleotide_sequences[4])

[1 1 1 0 0]


tensor(0.)

In [103]:
file

'/mnt/c/Users/uhewm/Desktop/ProjectHGT/simulation_chunks/simulation_10.h5'