In [57]:
import os
import random
import h5py
import pickle
import torch
import networkx as nx
from torch_geometric.data import Data
import numpy as np

def one_hot_encode(sequences, gene_present, gene_length, alphabet=['A','C','T','G','-']):
    """
    sequences: List of strings (DNA sequences)
    gene_present: np.array(bool) oder Torch Tensor, gleiche L√§nge wie sequences
    gene_length: int, fixe L√§nge f√ºr das Hot-Encoding
    alphabet: list, Zeichenalphabet
    """
    num_samples = len(sequences)
    num_chars = len(alphabet)
    char_to_idx = {c:i for i,c in enumerate(alphabet)}
    sequences_str = [s.decode('utf-8') for s in sequences]
    gene_present = np.array(gene_present, dtype=bool)
    
    # 1Ô∏è‚É£ Leere Batch-Matrix vorbereiten: (num_samples, gene_length, num_chars)
    batch = np.zeros((num_samples, gene_length, num_chars), dtype=np.float32)
    
    # 2Ô∏è‚É£ Hot-Encode alle Sequenzen
    for i, seq in enumerate(sequences_str):
        if gene_present[i]:
            L = min(len(seq), gene_length)  # abschneiden
            for j, c in enumerate(seq[:L]):
                if c in char_to_idx:
                    batch[i, j, char_to_idx[c]] = 1.0
        elif gene_present[i] == 0:
            batch[i, :, :] = -1.0
            
    # 3Ô∏è‚É£ Zuf√§llige, aber konsistente Spaltenpermutation
    perm = np.random.permutation(gene_length)
    batch = batch[:, perm, :]
    
    # 4Ô∏è‚É£ Optional: Flatten zu Vektor (num_samples, gene_length*num_chars)
    batch_flat = batch.reshape(num_samples, -1)
    
    return torch.tensor(batch_flat)  # shape: (num_samples, gene_length*num_chars)

def load_file(file):

    gene_length = 300
    #nucleotide_mutation_rate = 0.1
    
    with h5py.File(file, "r") as f:
            grp = f["results"]
            # Load graph_properties (pickle stored in dataset)
            graph_properties = pickle.loads(grp["graph_properties"][()])
    
            # Unpack graph properties
            nodes = torch.tensor(graph_properties[0])                # [num_nodes]
            edges = torch.tensor(graph_properties[1], dtype=torch.long)  # [2, num_edges]
            coords = torch.tensor(graph_properties[2].T)             # [2, num_nodes]
    
            # Load datasets instead of attrs
            gene_absence_presence_matrix = grp["gene_absence_presence_matrix"][()]
            nucleotide_sequences = grp["nucleotide_sequences"][()]
            #children_gene_nodes_loss_events = grp["children_gene_nodes_loss_events"][()]
        
            # Load HGT events (simplified)
            hgt_events = {}
            hgt_grp_simpl = grp["nodes_hgt_events_simplified"]
            for site_id in hgt_grp_simpl.keys():
                hgt_events[int(site_id)] = hgt_grp_simpl[site_id][()]

            hot_encoded_nucleotide_sequences = one_hot_encode(nucleotide_sequences, gene_absence_presence_matrix, gene_length)

            # Fill the remaining nodes with zeros.
            pad_rows = len(nodes) - len(nucleotide_sequences)
            pad = torch.zeros((pad_rows, hot_encoded_nucleotide_sequences.shape[1]), dtype=hot_encoded_nucleotide_sequences.dtype)
            hot_encoded_nucleotide_sequences = torch.cat([hot_encoded_nucleotide_sequences, pad], dim=0)
        
            G = nx.DiGraph()
            
            # F√ºge Knoten hinzu (optional mit Koordinaten als Attribut)
            for i, node_id in enumerate(nodes.tolist()):
                G.add_node(node_id, node_time = coords[:, i].tolist()[5])
            
            # F√ºge Kanten hinzu
            edge_list = edges.tolist()
            for src, dst in zip(edge_list[0], edge_list[1]):
                G.add_edge(src, dst)
            
            # Collect all recipient_parent_nodes from all sites
            recipient_parent_nodes = set()
            for site_id in hgt_grp_simpl.keys():
                arr = hgt_grp_simpl[site_id][()]  # load dataset as numpy structured array
                recipient_parent_nodes.update(arr["recipient_child_node"].tolist())
            
            # Build theta_gains: 1 if node is in recipient_parent_nodes, else 0
            theta_gains = torch.tensor(
                [1 if node in recipient_parent_nodes else 0 for node in range(len(G.nodes))],
                dtype=torch.long
            )

            level = {n: 0 for n in G.nodes}  # Leaves haben Level 0
            
            # 3. Topologische Sortierung (damit Kinder vor Eltern behandelt werden)
            for node in reversed(list(nx.topological_sort(G))):
                successors = list(G.successors(node))
                if successors:
                    level[node] = 1 + max(level[s] for s in successors)
            
            # 4. Level als Attribut setzen
            nx.set_node_attributes(G, level, "level")

            ### Add candidate egdes, i.e. potential hgt edges:
        
            # Hole alle Knoten und ihre Zeiten
            node_times = {n: G.nodes[n]['node_time'] for n in G.nodes}
            sorted_nodes = sorted(node_times.keys(), key=lambda n: node_times[n])

            # Sort the level and node_times from 0 to max_node_id and not in G.nodes order:
            node_times = torch.tensor([node_times[n] for n in sorted_nodes], dtype=torch.float)
            node_levels = torch.tensor([level[n] for n in sorted_nodes], dtype=torch.float)

            # F√ºge Kanten hinzu
            existing_edges = set(zip(edges[0].tolist(), edges[1].tolist()))
            candidate_edges = []
            for i, src in enumerate(sorted_nodes):
                t_src = node_times[src]
                for dst in sorted_nodes[i+1:]:  # nur sp√§tere Knoten
                    t_dst = node_times[dst]
                    if t_dst > t_src:
                        if (dst, src) not in existing_edges:  # vermeidet doppelte
                            #G.add_edge(src, dst)
                            candidate_edges.append((dst, src))
            if candidate_edges:  
                candidate_edges = torch.tensor(candidate_edges, dtype=torch.long).t()
            else:
                candidate_edges = torch.empty((2,0), dtype=torch.long)
                    
            data = Data(
                nucleotide_sequences = nucleotide_sequences,
                hot_encoded_nucleotide_sequences = hot_encoded_nucleotide_sequences,       # Node Features [num_nodes, 2]
                edge_index = edges[[1, 0], :],        # Edge Index [2, num_edges]
                candidate_edges = candidate_edges[[1, 0], :],
                y = theta_gains,            # Labels [num_nodes]
                file = file,
                G = G,
                recipient_parent_nodes = recipient_parent_nodes,
                gene_absence_presence_matrix = gene_absence_presence_matrix,
                node_times = node_times,
                node_levels = node_levels
            )

    return data


folder = "/mnt/c/Users/uhewm/Desktop/ProjectHGT/simulation_chunks"

files = [os.path.join(folder, f) for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]

examples = load_file(random.choice(files))

  edges = torch.tensor(graph_properties[1], dtype=torch.long)  # [2, num_edges]
  coords = torch.tensor(graph_properties[2].T)             # [2, num_nodes]


In [59]:
# Ben√∂tigte Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GCNConv, TransformerConv, global_mean_pool
from torch_geometric.loader import DataLoader
import numpy as np

import os
from torch.utils.data import random_split

# --------------------------
# 1) Sequence Encoder (1D-CNN; alternative: tiny Transformer)
# --------------------------
class SeqEncoderCNN(nn.Module):
    def __init__(self, input_dim, emb_dim=128, kernel_sizes=[3,5,7], dropout=0.1):
        super().__init__()
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=input_dim, out_channels=emb_dim, kernel_size=k, padding=k//2)
            for k in kernel_sizes
        ])
        self.pool = nn.AdaptiveMaxPool1d(1)
        self.fc = nn.Linear(len(kernel_sizes)*emb_dim, emb_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x_seq_flat):
        # x_seq_flat: [num_nodes, gene_length * alphabet_size]  (wie dein hot_encoded batch_flat)
        # Wir m√ºssen es wieder in (B, C, L)
        # Du musst gene_length und alphabet_size kennen (hier Annahme: 300, 5). Passen falls n√∂tig.
        B = x_seq_flat.size(0)
        # Beispielwerte, pass an:
        gene_length = 300
        alphabet_size = x_seq_flat.size(1) // gene_length
        x = x_seq_flat.view(B, gene_length, alphabet_size).permute(0,2,1)  # (B, C, L)
        conv_outs = []
        for conv in self.convs:
            y = F.relu(conv(x))
            y = self.pool(y).squeeze(-1)  # (B, out_channels)
            conv_outs.append(y)
        y = torch.cat(conv_outs, dim=1)
        y = self.drop(F.relu(self.fc(y)))
        return y  # (B, emb_dim)

# --------------------------
# 2) Node Feature encoder: kombiniert seq-emb + numeric features
# --------------------------
class NodeFeatureEncoder(nn.Module):
    def __init__(self, node_emb_dim=128):
        super().__init__()
        self.seq_encoder = SeqEncoderCNN(input_dim=5, emb_dim=node_emb_dim//2)
        self.mlp = nn.Sequential(
            nn.Linear(node_emb_dim//2 + 3, node_emb_dim),  # 3 numerische features: time, level, gene_present
            nn.ReLU(),
            nn.LayerNorm(node_emb_dim)
        )
    def forward(self, hot_seq_flat, node_times, node_levels, gene_present):
        seq_emb = self.seq_encoder(hot_seq_flat)  # (N, D1)
        # Stack numeric features (unsqueeze falls n√∂tig)
        numeric = torch.stack([
            node_times.float(),
            node_levels.float(),
            gene_present.float(),
            #theta_gain.float()
        ], dim=1)
        x = torch.cat([seq_emb, numeric], dim=1)
        return self.mlp(x)  # (N, node_emb_dim)

# --------------------------
# 3) GNN Encoder (operates on existing tree edges)
# --------------------------
class GNNEncoder(nn.Module):
    def __init__(self, in_dim, hidden_dim=128, num_layers=3, use_transformer=False):
        super().__init__()
        self.convs = nn.ModuleList()
        for i in range(num_layers):
            in_c = in_dim if i==0 else hidden_dim
            if use_transformer:
                conv = TransformerConv(in_c, hidden_dim//1, heads=4, concat=False)
            else:
                conv = GATConv(in_c, hidden_dim//4, heads=4, concat=True)  # outputs hidden_dim
            self.convs.append(conv)
            self.bn = nn.LayerNorm(hidden_dim)
        self.out_bn = nn.LayerNorm(hidden_dim)
    def forward(self, x, edge_index):
        h = x
        for conv in self.convs:
            h = conv(h, edge_index)
            h = F.relu(h)
            h = self.bn(h)
        return self.out_bn(h)

# --------------------------
# 4) Edge classifier
# --------------------------
class EdgeClassifier(nn.Module):
    def __init__(self, node_emb_dim, hidden=32):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(node_emb_dim*2 + 3, hidden),  # src||dst||3 pair-features
            nn.ReLU(),
            nn.Linear(hidden, 1)
        )
    def forward(self, h_src, h_dst, pair_feats):
        # h_src, h_dst: (E, D)
        x = torch.cat([h_src, h_dst, pair_feats], dim=1)
        return self.mlp(x).squeeze(-1)  # logits (E,)


# --------------------------
# 5) Full Model wrapper
# --------------------------
class TreeEdgeCorrectionModel(nn.Module):
    def __init__(self, node_emb_dim=32, gnn_layers=3, use_transformer=False):
        super().__init__()
        self.node_encoder = NodeFeatureEncoder(node_emb_dim=node_emb_dim)
        self.gnn = GNNEncoder(node_emb_dim, hidden_dim=node_emb_dim, num_layers=gnn_layers, use_transformer=use_transformer)
        self.edge_clf = EdgeClassifier(node_emb_dim, hidden=node_emb_dim)
    def forward(self, data):
        # Unpack data (assume fields from your load_file)
        hot = data.hot_encoded_nucleotide_sequences  # (N, gene_len*alpha)
        node_times = data.node_times
        node_levels = data.node_levels
        # gene presence: from your matrix (first len(nucleotide_sequences) True or False)
        # we assume data.gene_absence_presence_matrix is numpy array length = num_leaves (but pad rows were added)
        """
        if isinstance(data.gene_absence_presence_matrix, np.ndarray):
            gene_present = torch.tensor(data.gene_absence_presence_matrix, dtype=torch.float, device=hot.device)
            # pad to full nodes if necessary:
            if gene_present.shape[0] < hot.shape[0]:
                pad = torch.zeros(hot.shape[0] - gene_present.shape[0], dtype=torch.float, device=hot.device)
                gene_present = torch.cat([gene_present, pad], dim=0)
        else:
            # fallback zeros
            gene_present = torch.zeros(hot.shape[0], device=hot.device)
        """
        gene_present = torch.tensor(data.gene_absence_presence_matrix, dtype=torch.float, device=hot.device)
        # theta_gain (you already stored)
        theta = data.y.float().to(hot.device)
        # Node embeddings
        h0 = self.node_encoder(hot, node_times, node_levels, gene_present)  # (N, D)
        # GNN over existing tree edges
        h = self.gnn(h0, data.edge_index.to(hot.device))
        # Prepare candidate edges
        cand = data.candidate_edges.to(hot.device)  # shape [2, E_cand]
        src_idx = cand[0]
        dst_idx = cand[1]
        h_src = h[src_idx]
        h_dst = h[dst_idx]
        # Pair features: time_diff_abs, time_sign, phylo_distance_placeholder
        # time difference and sign:
        time_src = node_times[src_idx]
        time_dst = node_times[dst_idx]
        time_diff = (time_dst - time_src).unsqueeze(1)       # (E,1) positive means dst is later (younger)
        time_abs = torch.abs(time_diff)
        time_sign = (time_diff > 0).float()
        # optional phylogenetic distance: here we don't compute real tree distance; put zeros or compute shortest_path in G
        # for speed, we use absolute level difference as proxy
        level_src = node_levels[src_idx].unsqueeze(1)
        level_dst = node_levels[dst_idx].unsqueeze(1)
        level_diff = torch.abs(level_dst - level_src)
        pair_feats = torch.cat([time_abs, time_sign, level_diff], dim=1)
        logits = self.edge_clf(h_src, h_dst, pair_feats)
        return logits  # (E_cand,)

# --------------------------
# 6) Edge labels construction (aus deinen HGT-Ereignissen)
# --------------------------
def build_edge_labels(data):
    # data.candidate_edges: [2, E]
    cand = data.candidate_edges.numpy().T.tolist()  # list of (src, dst)
    cand_set = set((int(s), int(t)) for s,t in cand)
    # Extract true HGT edges from your hgt_grp_simpl (data.recipient_parent_nodes only contains nodes, not full pairs)
    # If you have raw hgt event arrays accessible somewhere in Data object, use them; else we try reconstruct entries from file
    # Here we assume data has attribute 'G' and also stored raw events maybe in data.recipient_parent_nodes is set of recipients.
    # We will create a label per candidate: positive if dst in recipient_parent_nodes (heuristic).
    labels = []
    recipients = getattr(data, 'recipient_parent_nodes', set())
    for s,t in cand:
        # Heuristic: if the child (t) is present in recipients set -> positive
        if t in recipients:
            labels.append(1)
        else:
            labels.append(0)
    return torch.tensor(labels, dtype=torch.float)

# --------------------------
# 8) Beispiel: Initialisierung
# --------------------------
"""
class TreeEdgeCorrectionModel(nn.Module):
    def __init__(self, seq_input_dim, node_emb_dim=128, gnn_layers=3):
        super().__init__()
        self.input_lin = nn.Linear(seq_input_dim, node_emb_dim)
        self.convs = nn.ModuleList([GCNConv(node_emb_dim, node_emb_dim) for _ in range(gnn_layers)])
        self.output_lin = nn.Linear(node_emb_dim, 1)

    def forward(self, data):
        x, edge_index = data.hot_encoded_nucleotide_sequences, data.edge_index
        x = torch.relu(self.input_lin(x))
        for conv in self.convs:
            x = torch.relu(conv(x, edge_index))
        out = self.output_lin(x).squeeze(-1)   # [num_nodes]
        return out

"""
# =============================
# üîπ Train/Val Split & Training
# =============================

def train_val_split(dataset, val_ratio=0.2, seed=42):
    """Teilt Datensatz in Training und Validierung."""
    n_total = len(dataset)
    n_val = int(n_total * val_ratio)
    n_train = n_total - n_val
    generator = torch.Generator().manual_seed(seed)
    train_ds, val_ds = random_split(dataset, [n_train, n_val], generator=generator)
    return train_ds, val_ds


def evaluate(model, dataloader, device="cpu"):
    """Berechnet den Loss auf dem Validierungsdatensatz."""
    model.eval()
    criterion = torch.nn.BCEWithLogitsLoss()
    total_loss = 0.0
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            out = model(batch)
            y = batch.y.float()
            min_len = min(out.shape[0], y.shape[0])
            loss = criterion(out[:min_len], y[:min_len])
            total_loss += loss.item()
    return total_loss / len(dataloader)


def train_with_val(model, dataset, optimizer, device="cpu", epochs=20, val_ratio=0.2, batch_size=8):
    # ----------------------------
    # 1) Train/Val Split
    # ----------------------------
    
    train_ds, val_ds = train_val_split(dataset, val_ratio)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    
    # ----------------------------
    # 2) Training Loop
    # ----------------------------
    model.to(device)
    criterion = torch.nn.BCEWithLogitsLoss()

    for epoch in range(1, epochs + 1):
        model.train()
        total_train_loss = 0
        for batch_list in train_loader:
            data = batch_list[0]  # Data-Objekt
            hot = data.hot_encoded_nucleotide_sequences.to(device)
            edge_index = data.edge_index.to(device)
            candidate_edges = data.candidate_edges.to(device)
            y_edges = build_edge_labels(data).to(device)
            y = data.y.to(device)
            
            optimizer.zero_grad()
            out = model(data)  # model.forward nutzt data.G
            loss = criterion(out, y_edges)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)
        val_loss = evaluate(model, val_loader, device=device)
        print(f"Epoch {epoch:03d} | Train-Loss: {avg_train_loss:.4f} | Val-Loss: {val_loss:.4f}")

    # ----------------------------
    # 3) Endg√ºltiger Validierungsfehler
    # ----------------------------
    final_val_loss = evaluate(model, val_loader, device=device)
    print(f"\n‚úÖ Endg√ºltiger Validierungs-Loss: {final_val_loss:.4f}")
    return final_val_loss

# =============================
# üîπ Dateien laden
# =============================
folder = "/mnt/c/Users/uhewm/Desktop/ProjectHGT/simulation_chunks"

files = [os.path.join(folder, f) for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]

"""
list_of_Data = []
for f in files:
    try:
        d = load_file(f)
        list_of_Data.append(d)
    except Exception as e:
        print(f"Fehler beim Laden von {f}: {e}")
"""

print(f"{len(list_of_Data)} Dateien erfolgreich geladen.")

# =============================
# üîπ Model, Optimizer
# =============================
gene_length = 300
seq_input_dim = gene_length * 5
model = TreeEdgeCorrectionModel(node_emb_dim=128, gnn_layers=3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# =============================
# üîπ Training starten
# =============================
device = "cuda" if torch.cuda.is_available() else "cpu"
final_val_loss = train_with_val(model, list_of_Data, optimizer, device=device, epochs=10)

5508 Dateien erfolgreich geladen.


RuntimeError: stack expects each tensor to be equal size, but got [9] at entry 0 and [5] at entry 2

  hot_encoded_nucleotide_sequences = torch.tensor(d.hot_encoded_nucleotide_sequences, dtype=torch.float),
  edge_index = torch.tensor(d.edge_index, dtype=torch.long),
  candidate_edges = torch.tensor(d.candidate_edges, dtype=torch.long),
  y = torch.tensor(d.y, dtype=torch.float)


In [31]:
model

TreeEdgeCorrectionModel(
  (input_lin): Linear(in_features=1500, out_features=128, bias=True)
  (convs): ModuleList()
  (output_lin): Linear(in_features=128, out_features=1, bias=True)
)

In [26]:
import torch
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

def evaluate_predictions(model, dataloader, device="cpu", threshold=0.5):
    model.eval()
    y_true_all = []
    y_pred_all = []
    
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            out = model(batch)                 # Logits pro Node
            probs = torch.sigmoid(out)         # Wahrscheinlichkeiten [0,1]
            preds = (probs > threshold).long() # 1 = vorhergesagt als "hgt"
            #print(torch.round(probs * 100) / 100)
            
            # Sicherstellen, dass L√§ngen passen
            min_len = min(len(batch.y), len(preds))
            y_true_all.extend(batch.y[:min_len].cpu().numpy())
            y_pred_all.extend(preds[:min_len].cpu().numpy())
    
    y_true_all = torch.tensor(y_true_all)
    y_pred_all = torch.tensor(y_pred_all)
    
    acc = accuracy_score(y_true_all, y_pred_all)
    prec = precision_score(y_true_all, y_pred_all, zero_division=0)
    rec = recall_score(y_true_all, y_pred_all, zero_division=0)
    f1 = f1_score(y_true_all, y_pred_all, zero_division=0)
    
    print(f"Accuracy:  {acc:.3f}")
    print(f"Precision: {prec:.3f}")
    print(f"Recall:    {rec:.3f}")
    print(f"F1-Score:  {f1:.3f}")
    
    return y_true_all, y_pred_all

y_true, y_pred = evaluate_predictions(model, dataloader, device)


Accuracy:  0.992
Precision: 0.986
Recall:    0.787
F1-Score:  0.875


In [12]:
for i in dataloader:
    print(i)

DataBatch(edge_index=[2, 32], y=[36], nucleotide_sequences=[4], hot_encoded_nucleotide_sequences=[36, 1500], candidate_edges=[8, 18], file=[4], G=[4], recipient_parent_nodes=[4], gene_absence_presence_matrix=[4], batch=[36], ptr=[5])
DataBatch(edge_index=[2, 32], y=[36], nucleotide_sequences=[4], hot_encoded_nucleotide_sequences=[36, 1500], candidate_edges=[8, 18], file=[4], G=[4], recipient_parent_nodes=[4], gene_absence_presence_matrix=[4], batch=[36], ptr=[5])
DataBatch(edge_index=[2, 32], y=[36], nucleotide_sequences=[4], hot_encoded_nucleotide_sequences=[36, 1500], candidate_edges=[8, 18], file=[4], G=[4], recipient_parent_nodes=[4], gene_absence_presence_matrix=[4], batch=[36], ptr=[5])
DataBatch(edge_index=[2, 32], y=[36], nucleotide_sequences=[4], hot_encoded_nucleotide_sequences=[36, 1500], candidate_edges=[8, 18], file=[4], G=[4], recipient_parent_nodes=[4], gene_absence_presence_matrix=[4], batch=[36], ptr=[5])
DataBatch(edge_index=[2, 32], y=[36], nucleotide_sequences=[4], 

