In [1]:
import os
import random
import h5py
import pickle
import torch
import networkx as nx
from torch_geometric.data import Data
import numpy as np

def one_hot_encode(sequences, gene_present, gene_length, alphabet=['A','C','T','G','-']):
    """
    sequences: List of strings (DNA sequences)
    gene_present: np.array(bool) oder Torch Tensor, gleiche Länge wie sequences
    gene_length: int, fixe Länge für das Hot-Encoding
    alphabet: list, Zeichenalphabet
    """
    num_samples = len(sequences)
    num_chars = len(alphabet)
    char_to_idx = {c:i for i,c in enumerate(alphabet)}
    sequences_str = [s.decode('utf-8') for s in sequences]
    gene_present = np.array(gene_present, dtype=bool)
    
    # 1️⃣ Leere Batch-Matrix vorbereiten: (num_samples, gene_length, num_chars)
    batch = np.zeros((num_samples, gene_length, num_chars), dtype=np.float32)
    
    # 2️⃣ Hot-Encode alle Sequenzen
    for i, seq in enumerate(sequences_str):
        if gene_present[i]:
            L = min(len(seq), gene_length)  # abschneiden
            for j, c in enumerate(seq[:L]):
                if c in char_to_idx:
                    batch[i, j, char_to_idx[c]] = 1.0
        elif gene_present[i] == 0:
            batch[i, :, :] = -1.0
            
    # 3️⃣ Zufällige, aber konsistente Spaltenpermutation
    perm = np.random.permutation(gene_length)
    batch = batch[:, perm, :]
    
    # 4️⃣ Optional: Flatten zu Vektor (num_samples, gene_length*num_chars)
    batch_flat = batch.reshape(num_samples, -1)
    
    return torch.tensor(batch_flat)  # shape: (num_samples, gene_length*num_chars)

def load_file(file):

    gene_length = 300
    #nucleotide_mutation_rate = 0.1
    
    with h5py.File(file, "r") as f:
            grp = f["results"]
            # Load graph_properties (pickle stored in dataset)
            graph_properties = pickle.loads(grp["graph_properties"][()])
    
            # Unpack graph properties
            nodes = torch.tensor(graph_properties[0])                # [num_nodes]
            edges = torch.tensor(graph_properties[1], dtype=torch.long)  # [2, num_edges]
            coords = torch.tensor(graph_properties[2].T)             # [2, num_nodes]
    
            # Load datasets instead of attrs
            gene_absence_presence_matrix = grp["gene_absence_presence_matrix"][()]
            nucleotide_sequences = grp["nucleotide_sequences"][()]
            #children_gene_nodes_loss_events = grp["children_gene_nodes_loss_events"][()]
        
            # Load HGT events (simplified)
            hgt_events = {}
            hgt_grp_simpl = grp["nodes_hgt_events_simplified"]
            for site_id in hgt_grp_simpl.keys():
                hgt_events[int(site_id)] = hgt_grp_simpl[site_id][()]

            hot_encoded_nucleotide_sequences = one_hot_encode(nucleotide_sequences, gene_absence_presence_matrix, gene_length)

            # Fill the remaining nodes with zeros.
            pad_rows = len(nodes) - len(nucleotide_sequences)
            pad = torch.zeros((pad_rows, hot_encoded_nucleotide_sequences.shape[1]), dtype=hot_encoded_nucleotide_sequences.dtype)
            hot_encoded_nucleotide_sequences = torch.cat([hot_encoded_nucleotide_sequences, pad], dim=0)
        
            G = nx.DiGraph()
            
            # Füge Knoten hinzu (optional mit Koordinaten als Attribut)
            for i, node_id in enumerate(nodes.tolist()):
                G.add_node(node_id, node_time = coords[:, i].tolist()[5])
            
            # Füge Kanten hinzu
            edge_list = edges.tolist()
            for src, dst in zip(edge_list[0], edge_list[1]):
                G.add_edge(src, dst)
            
            # Collect all recipient_parent_nodes from all sites
            recipient_parent_nodes = set()
            for site_id in hgt_grp_simpl.keys():
                arr = hgt_grp_simpl[site_id][()]  # load dataset as numpy structured array
                recipient_parent_nodes.update(arr["recipient_child_node"].tolist())
            
            # Build theta_gains: 1 if node is in recipient_parent_nodes, else 0
            theta_gains = torch.tensor(
                [1 if node in recipient_parent_nodes else 0 for node in range(len(G.nodes))],
                dtype=torch.long
            )

            level = {n: 0 for n in G.nodes}  # Leaves haben Level 0
            
            # 3. Topologische Sortierung (damit Kinder vor Eltern behandelt werden)
            for node in reversed(list(nx.topological_sort(G))):
                successors = list(G.successors(node))
                if successors:
                    level[node] = 1 + max(level[s] for s in successors)
            
            # 4. Level als Attribut setzen
            nx.set_node_attributes(G, level, "level")

            ### Add candidate egdes, i.e. potential hgt edges:
        
            # Hole alle Knoten und ihre Zeiten
            node_times = {n: G.nodes[n]['node_time'] for n in G.nodes}
            sorted_nodes = sorted(node_times.keys(), key=lambda n: node_times[n])

            # Sort the level and node_times from 0 to max_node_id and not in G.nodes order:
            node_times = torch.tensor([node_times[n] for n in sorted_nodes], dtype=torch.float)
            node_levels = torch.tensor([level[n] for n in sorted_nodes], dtype=torch.float)

            # Füge Kanten hinzu
            existing_edges = set(zip(edges[0].tolist(), edges[1].tolist()))
            candidate_edges = []
            for i, src in enumerate(sorted_nodes):
                t_src = node_times[src]
                for dst in sorted_nodes[i+1:]:  # nur spätere Knoten
                    t_dst = node_times[dst]
                    if t_dst > t_src:
                        if (dst, src) not in existing_edges:  # vermeidet doppelte
                            #G.add_edge(src, dst)
                            candidate_edges.append((dst, src))
            if candidate_edges:  
                candidate_edges = torch.tensor(candidate_edges, dtype=torch.long).t()
            else:
                candidate_edges = torch.empty((2,0), dtype=torch.long)
                    
            data = Data(
                nucleotide_sequences = nucleotide_sequences,
                hot_encoded_nucleotide_sequences = hot_encoded_nucleotide_sequences,       # Node Features [num_nodes, 2]
                edge_index = edges[[1, 0], :],        # Edge Index [2, num_edges]
                candidate_edges = candidate_edges[[1, 0], :],
                y = theta_gains,            # Labels [num_nodes]
                file = file,
                G = G,
                recipient_parent_nodes = recipient_parent_nodes,
                gene_absence_presence_matrix = gene_absence_presence_matrix,
                node_times = node_times,
                node_levels = node_levels
            )

    return data


folder = "/mnt/c/Users/uhewm/Desktop/ProjectHGT/simulation_chunks"

files = [os.path.join(folder, f) for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]

examples = load_file(random.choice(files))

  machar = _get_machar(dtype)
  edges = torch.tensor(graph_properties[1], dtype=torch.long)  # [2, num_edges]
  coords = torch.tensor(graph_properties[2].T)             # [2, num_nodes]


In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax
from torch_geometric.data import Data
import math

# ==========================================================

# 1️⃣ Attention-Layer mit lernbaren Kantenwahrscheinlichkeiten

# ==========================================================

class HGTAttentionLayer(MessagePassing):
    def __init__(self, in_dim, out_dim, heads=2, dropout=0.1):
        super().__init__(aggr='add')
        self.heads = heads
        self.out_dim = out_dim
        self.W = nn.Linear(in_dim, heads * out_dim, bias=False)
        self.att_src = nn.Parameter(torch.Tensor(heads, out_dim))
        self.att_dst = nn.Parameter(torch.Tensor(heads, out_dim))
        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.W.weight)
        nn.init.xavier_uniform_(self.att_src)
        nn.init.xavier_uniform_(self.att_dst)
    
    def forward(self, x, edge_index, edge_probs):
        # x: [N, in_dim]
        # edge_probs: [E]
        x_proj = self.W(x)                                # [N, H*D]
        N = x_proj.size(0)
        return self.propagate(edge_index, x=x_proj, edge_attr=edge_probs, size=(N, N))

    def message(self, x_j, x_i, edge_attr, index, ptr, size_i):
        # x_j, x_i: [E, H*D]  (reshape back)
        # edge_attr: [E]
        H = self.heads
        D = self.out_dim

        # reshape node features to [E, H, D]
        x_j = x_j.view(-1, H, D)
        x_i = x_i.view(-1, H, D)

        ep = edge_attr
        if ep.dim() == 1:
            ep = ep
        elif ep.dim() == 2 and ep.shape[1] == 1:
            ep = ep.squeeze(-1)
        else:
            ep = ep.view(-1)

        # compute attention per head
        alpha = (x_i * self.att_dst + x_j * self.att_src).sum(dim=-1) / math.sqrt(D)
        alpha = alpha + torch.log(ep.unsqueeze(-1) + 1e-8)
        alpha = F.leaky_relu(alpha, 0.2)
        alpha = softmax(alpha, index)
        alpha = self.dropout(alpha)

        # scale messages: x_j * alpha * edge_prob
        return (x_j * alpha.unsqueeze(-1) * ep.unsqueeze(-1).unsqueeze(-1)).view(-1, H * D)

# ==========================================================

# 2️⃣ Vollständiges Modell

# ==========================================================

class GraphHGTModel(nn.Module):
    def __init__(self, in_dim, hidden=32, out_dim=1, heads=2, layers=1):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(layers):
            self.layers.append(HGTAttentionLayer(in_dim if i==0 else hidden, hidden, heads=heads))
        self.node_out = nn.Linear(hidden * heads, out_dim)
        
    def forward(self, data, edge_logits):

        x = data.hot_encoded_nucleotide_sequences          # (N, F)
        if not hasattr(data, "num_nodes") or data.num_nodes != x.shape[0]:
            data.num_nodes = x.shape[0]
        
        base_edges = data.edge_index.to(x.device)
        cand_edges = data.candidate_edges.to(x.device)
    
        # combine edge sets
        if cand_edges.numel() > 0:
            edge_index = torch.cat([base_edges, cand_edges], dim=1)
            edge_probs = torch.cat([
                torch.ones(base_edges.shape[1], device=x.device),
                torch.sigmoid(edge_logits)
            ])
        else:
            edge_index = base_edges
            edge_probs = torch.ones(base_edges.shape[1], device=x.device)
    
        h = x
        for layer in self.layers:
            h = F.relu(layer(h, edge_index, edge_probs))
        out_nodes = self.node_out(h).squeeze(-1)
        return out_nodes, edge_probs


# ==========================================================

# 3️⃣ Training-Funktion

# ==========================================================

def train_hgt(model, data_list, device="cpu", epochs=20, lr=2e-3):
    model.to(device)
    # initialize edge logits
    global_edge_logits = {}
    for data in data_list:
        if data.candidate_edges.numel() > 0:
            E = data.candidate_edges.shape[1]
            global_edge_logits[data.file] = nn.Parameter(torch.full((E,), -3.0, device=device))
        else:
            global_edge_logits[data.file] = None
    params = list(model.parameters()) + [p for p in global_edge_logits.values() if p is not None]
    optim = torch.optim.Adam(params, lr=lr, weight_decay=1e-5)
    bce = nn.BCEWithLogitsLoss()
        
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for data in data_list:
            data = data.to(device)
            edge_logits = global_edge_logits[data.file]
            y_true = data.y.float().to(device)
            out_nodes, edge_probs = model(data, edge_logits)
    
            # split probabilities
            if edge_logits is not None and data.candidate_edges.numel() > 0:
                cand_scores = edge_probs[data.edge_index.shape[1]:]
                target_scores = y_true[data.candidate_edges[1]]
    
                loss_node = bce(out_nodes, y_true)
                loss_edge = bce(cand_scores, target_scores)
                loss_sparsity = 0.01 * cand_scores.sum()
                loss = loss_node + 0.5 * loss_edge + loss_sparsity
            else:
                loss = bce(out_nodes, y_true)
    
            optim.zero_grad()
            loss.backward()
            optim.step()
            total_loss += loss.item()
    
        print(f"Epoch {epoch:03d} | Avg Loss: {total_loss/len(data_list):.4f}")
    
    # output learned candidate edge probabilities
    for data in data_list:
        edge_logits = global_edge_logits[data.file]
        if edge_logits is not None:
            p = torch.sigmoid(edge_logits).detach().cpu().numpy()
            print(f"File {data.file} | Candidate edge probs: {p[:10]}")

# ==========================================================

# 4️⃣ Anwendung

# ==========================================================

folder = "/mnt/c/Users/uhewm/Desktop/ProjectHGT/simulation_chunks"

files = [os.path.join(folder, f) for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]

"""
list_of_Data = []
for f in files:
    try:
        d = load_file(f)
        list_of_Data.append(d)
    except Exception as e:
        print(f"Fehler beim Laden von {f}: {e}")

print(f"{len(list_of_Data)} Dateien erfolgreich geladen.")
"""

model = GraphHGTModel(in_dim=list_of_Data[0].hot_encoded_nucleotide_sequences.shape[1])

train_hgt(model, list_of_Data, device="cpu")


Epoch 000 | Avg Loss: 0.5649
Epoch 001 | Avg Loss: 0.5201
Epoch 002 | Avg Loss: 0.5178
Epoch 003 | Avg Loss: 0.5166
Epoch 004 | Avg Loss: 0.5240
Epoch 005 | Avg Loss: 0.5170


KeyboardInterrupt: 

In [6]:
list_of_Data[0]

Data(edge_index=[2, 8], y=[9], nucleotide_sequences=[5], hot_encoded_nucleotide_sequences=[9, 1500], candidate_edges=[2, 18], file='/mnt/c/Users/uhewm/Desktop/ProjectHGT/simulation_chunks/simulation_0.h5', G=DiGraph with 9 nodes and 8 edges, recipient_parent_nodes=[0], gene_absence_presence_matrix=[5], node_times=[9], node_levels=[9])