In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, DirGNNConv, GraphConv, GATConv
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
import random, h5py, pickle, glob, os
import networkx as nx
from copy import deepcopy
from pathlib import Path
import ImportData
import numpy as np

import importlib
importlib.reload(ImportData)

### WINDOWS:
#output_dir = r"C:\Users\uhewm\Desktop\ProjectHGT\simulation_chunks(4)"
output_dir = r"C:\Users\uhewm\Desktop\ProjectHGT\simulation_chunks"
all_files = sorted(glob.glob(os.path.join(output_dir, "*.h5")))

### LINUX:
from pathlib import Path
import glob

output_dir = Path("/mnt/c/Users/uhewm/Desktop/ProjectHGT/simulation_chunks")
all_files = sorted(glob.glob(str(output_dir / "*.h5")))

data = []

for file in random.sample(all_files, 100):
#for file in all_files:
    single_data = ImportData.load_file(file)
    data.append(single_data)

random.shuffle(data)
split_idx = int(0.8 * len(data))
train_data = data[:split_idx]
test_data = data[split_idx:]

train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16)
evaluate_loader = DataLoader(data, batch_size=16)


  edges = torch.tensor(graph_properties[1], dtype=torch.long)  # [2, num_edges]
  coords = torch.tensor(graph_properties[2].T)           # [2, num_nodes]


ValueError: not enough values to unpack (expected 2, got 1)

In [25]:
#### FUNKTIONIERT!

class GCNClassifier(nn.Module):
    def __init__(self, in_channels, hidden_channels, dropout=0.3):
        super().__init__()
        self.conv1 = DirGNNConv(GCNConv(in_channels, hidden_channels), alpha = 0)
        self.conv2 = DirGNNConv(GCNConv(hidden_channels, hidden_channels), alpha = 0)
        self.lin = nn.Linear(hidden_channels, 1)
        self.dropout = dropout

    def forward(self, x, edge_index):
        #edge_index =  torch.cat([edge_index, edge_index.flip(0)], dim=1)
        #edge_index =  edge_index.flip(0)
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.lin(x)
        return x.view(-1)

class NodeMLPClassifier(nn.Module):
    def __init__(self, in_channels, hidden_channels, dropout=0.3):
        super().__init__()
        self.fc1 = nn.Linear(in_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, hidden_channels)
        self.fc3 = nn.Linear(hidden_channels, hidden_channels)
        self.fc4 = nn.Linear(hidden_channels, hidden_channels)
        self.out = nn.Linear(hidden_channels, 1)
        self.dropout = dropout

    def forward(self, x, edge_index):
        """
        x: Tensor mit Shape [num_nodes, in_channels]
        """
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.fc2(x))
        #x = F.dropout(x, p=self.dropout, training=self.training)
        #x = F.relu(self.fc3(x))
        #x = F.dropout(x, p=self.dropout, training=self.training)
        #x = F.relu(self.fc4(x))
        x = self.out(x)
        return x.view(-1)

# === 3. Modell, Optimizer, Loss ===
model = NodeMLPClassifier(in_channels=11, hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Klassengewichte berechnen (gegen Ungleichgewicht)
all_labels = torch.cat([g.y for g in train_data])
ratio = (len(all_labels) - all_labels.sum()) / all_labels.sum()
pos_weight = torch.tensor((ratio**0.3), dtype=torch.float)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

print(f"Pos Weight: {pos_weight.item():.2f}")

# === 4. Training & Evaluation ===
def train():
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        loss = criterion(out, batch.y.float())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

@torch.no_grad()
def evaluate(loader, threshold = 0.5):
    model.eval()
    total_correct = 0
    total_nodes = 0
    tp, fp, fn = 0, 0, 0

    for batch in loader:
        out = model(batch.x, batch.edge_index)
        preds = torch.sigmoid(out) > threshold
        total_correct += (preds == batch.y.bool()).sum().item()
        total_nodes += batch.y.size(0)

        # Metriken für Klasse 1
        tp += ((preds == 1) & (batch.y == 1)).sum().item()
        fp += ((preds == 1) & (batch.y == 0)).sum().item()
        fn += ((preds == 0) & (batch.y == 1)).sum().item()

    acc = total_correct / total_nodes
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    return acc, precision, recall, f1

@torch.no_grad()
def find_best_threshold(loader, thresholds=np.linspace(0, 1, 101)):
    model.eval()
    best_threshold = 0.5
    best_f1 = 0.0

    # Alle Outputs und Labels sammeln, damit man nicht für jeden Threshold neu durch die Daten geht
    all_outs = []
    all_labels = []
    for batch in loader:
        out = model(batch.x, batch.edge_index)
        all_outs.append(torch.sigmoid(out))
        all_labels.append(batch.y)
    all_outs = torch.cat(all_outs)
    all_labels = torch.cat(all_labels)

    for threshold in thresholds:
        preds = all_outs > threshold
        tp = ((preds == 1) & (all_labels == 1)).sum().item()
        fp = ((preds == 1) & (all_labels == 0)).sum().item()
        fn = ((preds == 0) & (all_labels == 1)).sum().item()

        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    return best_threshold, best_f1


# === 5. Training starten ===
for epoch in range(1, 51):
    loss = train()
    acc, prec, rec, f1 = evaluate(test_loader)
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Acc: {acc:.3f} | Prec: {prec:.3f} | Rec: {rec:.3f} | F1: {f1:.3f}")

# Nach Training besten Threshold bestimmen
best_threshold, best_f1 = find_best_threshold(evaluate_loader)
print(f"\nBester Threshold: {best_threshold:.3f} mit F1-Score: {best_f1:.3f}")

# Evaluation mit bestem Threshold
acc, prec, rec, f1 = evaluate(test_loader, threshold=best_threshold)
print(f"Evaluation mit bestem Threshold:")
print(f"Acc: {acc:.3f} | Prec: {prec:.3f} | Rec: {rec:.3f} | F1: {f1:.3f}")

  pos_weight = torch.tensor((ratio**0.3), dtype=torch.float)


Pos Weight: 4.97
Epoch 01 | Loss: 0.1435 | Acc: 0.995 | Prec: 0.475 | Rec: 0.813 | F1: 0.599
Epoch 02 | Loss: 0.0330 | Acc: 0.997 | Prec: 0.601 | Rec: 0.782 | F1: 0.680
Epoch 03 | Loss: 0.0250 | Acc: 0.997 | Prec: 0.622 | Rec: 0.778 | F1: 0.692
Epoch 04 | Loss: 0.0211 | Acc: 0.997 | Prec: 0.668 | Rec: 0.756 | F1: 0.709
Epoch 05 | Loss: 0.0184 | Acc: 0.997 | Prec: 0.677 | Rec: 0.767 | F1: 0.720
Epoch 06 | Loss: 0.0173 | Acc: 0.997 | Prec: 0.704 | Rec: 0.789 | F1: 0.744
Epoch 07 | Loss: 0.0166 | Acc: 0.997 | Prec: 0.608 | Rec: 0.847 | F1: 0.708
Epoch 08 | Loss: 0.0160 | Acc: 0.997 | Prec: 0.590 | Rec: 0.859 | F1: 0.700
Epoch 09 | Loss: 0.0152 | Acc: 0.997 | Prec: 0.652 | Rec: 0.836 | F1: 0.733
Epoch 10 | Loss: 0.0151 | Acc: 0.997 | Prec: 0.679 | Rec: 0.817 | F1: 0.742
Epoch 11 | Loss: 0.0150 | Acc: 0.998 | Prec: 0.777 | Rec: 0.755 | F1: 0.766
Epoch 12 | Loss: 0.0147 | Acc: 0.997 | Prec: 0.624 | Rec: 0.861 | F1: 0.723
Epoch 13 | Loss: 0.0145 | Acc: 0.998 | Prec: 0.738 | Rec: 0.788 | F1: 0

In [33]:
from pathlib import Path

# Basisverzeichnis
project_dir = Path("/mnt/c/Users/uhewm/Desktop/ProjectHGT")

# Neuer Unterordner für Modelle
model_dir = project_dir / "trained_models"
model_dir.mkdir(exist_ok=True)   # legt den Ordner an, falls er noch nicht existiert

# Speicherpfad für das Modell
model_path = model_dir / "node_mlp_classifier_full.pt"

# Speichern
torch.save(model, model_path)

# Laden
#model = torch.load(model_path)
model.eval()


NodeMLPClassifier(
  (fc1): Linear(in_features=11, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=64, bias=True)
  (fc4): Linear(in_features=64, out_features=64, bias=True)
  (out): Linear(in_features=64, out_features=1, bias=True)
)

In [299]:
import numpy as np
from torch_geometric.utils import add_self_loops

# === 2. Modell definieren ===
class GCNClassifier(nn.Module):
    def __init__(self, in_channels, hidden_channels, dropout=0.3):
        super().__init__()
        self.conv1 = DirGNNConv(GCNConv(in_channels, hidden_channels), alpha = 0)
        self.conv2 = DirGNNConv(GCNConv(hidden_channels, hidden_channels), alpha = 0)
        self.conv3 = DirGNNConv(GCNConv(hidden_channels, hidden_channels), alpha = 0)
        self.conv4 = DirGNNConv(GCNConv(hidden_channels, hidden_channels), alpha = 0)
        self.conv5 = DirGNNConv(GCNConv(hidden_channels, hidden_channels), alpha = 0)
        self.lin = nn.Linear(hidden_channels, 1)
        self.dropout = dropout

    def forward(self, x, edge_index):
        #edge_index = edge_index[[1, 0], :]
        #edge_index, _ = add_self_loops(edge_index)
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.conv3(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.conv4(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)


        x = self.conv5(x, edge_index)
        x = F.relu(x)

        x = self.lin(x)
        return x.view(-1)


# === 3. Modell, Optimizer, Loss ===
model = GCNClassifier(in_channels=9, hidden_channels=32)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Klassengewichte berechnen (gegen Ungleichgewicht)
all_labels = torch.cat([g.y for g in train_data])
ratio = (len(all_labels) - all_labels.sum()) / all_labels.sum()
pos_weight = torch.tensor((ratio**0.3), dtype=torch.float)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

print(f"Pos Weight: {pos_weight.item():.2f}")

# === 4. Training & Evaluation ===
def train():
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        loss = criterion(out, batch.y.float())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

@torch.no_grad()
def evaluate(loader, threshold=0.5):
    model.eval()
    total_correct = 0
    total_nodes = 0
    tp, fp, fn = 0, 0, 0

    for batch in loader:
        out = model(batch.x, batch.edge_index)
        preds = torch.sigmoid(out) > threshold
        total_correct += (preds == batch.y.bool()).sum().item()
        total_nodes += batch.y.size(0)

        tp += ((preds == 1) & (batch.y == 1)).sum().item()
        fp += ((preds == 1) & (batch.y == 0)).sum().item()
        fn += ((preds == 0) & (batch.y == 1)).sum().item()

    acc = total_correct / total_nodes
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    return acc, precision, recall, f1

@torch.no_grad()
def find_best_threshold(loader, thresholds=np.linspace(0, 1, 101)):
    model.eval()
    best_threshold = 0.5
    best_f1 = 0.0

    # Alle Outputs und Labels sammeln, damit man nicht für jeden Threshold neu durch die Daten geht
    all_outs = []
    all_labels = []
    for batch in loader:
        out = model(batch.x, batch.edge_index)
        all_outs.append(torch.sigmoid(out))
        all_labels.append(batch.y)
    all_outs = torch.cat(all_outs)
    all_labels = torch.cat(all_labels)

    for threshold in thresholds:
        preds = all_outs > threshold
        tp = ((preds == 1) & (all_labels == 1)).sum().item()
        fp = ((preds == 1) & (all_labels == 0)).sum().item()
        fn = ((preds == 0) & (all_labels == 1)).sum().item()

        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    return best_threshold, best_f1

@torch.no_grad()
def show_some_predictions(loader, n_samples=3, threshold=0.5):
    model.eval()
    all_probs = []
    all_preds = []
    all_labels = []

    for batch in loader:
        out = model(batch.x, batch.edge_index)
        probs = torch.sigmoid(out)
        preds = (probs > threshold).long()
        all_probs.append(probs)
        all_preds.append(preds)
        all_labels.append(batch.y)

    all_probs = torch.cat(all_probs)
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    total_samples = len(all_labels)
    indices = random.sample(range(total_samples), k=min(n_samples, total_samples))

    for i, idx in enumerate(indices):
        print(f"Sample {i + 1}:")
        print(f"  True label:      {all_labels[idx].item()}")
        print(f"  Predicted prob:  {all_probs[idx].item():.4f}")
        print(f"  Predicted label: {all_preds[idx].item()}")

# === 5. Training starten ===
for epoch in range(1, 51):
    loss = train()
    acc, prec, rec, f1 = evaluate(test_loader, threshold=0.5)
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Acc: {acc:.3f} | Prec: {prec:.3f} | Rec: {rec:.3f} | F1: {f1:.3f}")

# Nach Training besten Threshold bestimmen
best_threshold, best_f1 = find_best_threshold(test_loader)
print(f"\nBester Threshold: {best_threshold:.3f} mit F1-Score: {best_f1:.3f}")

# Evaluation mit bestem Threshold
acc, prec, rec, f1 = evaluate(test_loader, threshold=best_threshold)
print(f"Evaluation mit bestem Threshold:")
print(f"Acc: {acc:.3f} | Prec: {prec:.3f} | Rec: {rec:.3f} | F1: {f1:.3f}")

print("\nEin paar Beispielvorhersagen auf Trainingsdaten mit bestem Threshold:")
show_some_predictions(train_loader, n_samples=3, threshold=best_threshold)



Pos Weight: 3.15


  pos_weight = torch.tensor((ratio**0.3), dtype=torch.float)


RuntimeError: mat1 and mat2 shapes cannot be multiplied (3184x13 and 9x32)

In [168]:
print("\nEin paar Beispielvorhersagen auf Trainingsdaten mit bestem Threshold:")
#show_some_predictions(train_loader, n_samples=25, threshold=best_threshold)




Ein paar Beispielvorhersagen auf Trainingsdaten mit bestem Threshold:


In [90]:
import torch
importlib.reload(ImportData)

for file in random.sample(all_files, 1):
    single_data = ImportData.load_file(file)
    
# === 1. Modellvorhersagen berechnen ===
model.eval()
with torch.no_grad():
    logits = model(single_data.x, single_data.edge_index)  # Shape: [num_nodes]
    probs = torch.sigmoid(logits).cpu().numpy()  # Werte zwischen 0 und 1

# Map von Node-ID zu Wahrscheinlichkeit
pred_probs = {i: p > best_threshold for i, p in enumerate(probs)}
pred_nodes = sorted([i for i, flag in pred_probs.items() if flag])

node_to_id = {node: i for i, node in enumerate(single_data.H.nodes)}
print("Predicted Nodes: ", sorted([node_to_id[i] for i, flag in pred_probs.items() if flag]))
print("True Nodes: ", sorted(set(single_data.parental_nodes_hgt_events_corrected)))

{0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0, 11: 0, 12: 0, 13: 0, 14: 0, 15: 0, 16: 0, 17: 0, 18: 0, 19: 0, 20: 0, 21: 0, 22: 0, 23: 0, 24: 0, 25: 0, 26: 0, 27: 0, 28: 0, 29: 0, 30: 0, 31: 0, 32: 0, 33: 0, 34: 0, 35: 0, 36: 0, 37: 0, 38: 0, 39: 0, 40: 0, 41: 0, 42: 0, 43: 0, 44: 0, 45: 0, 46: 0, 47: 0, 48: 0, 49: 0, 50: 0, 51: 0, 52: 0, 53: 0, 54: 0, 55: 0, 56: 0, 57: 0, 58: 0, 59: 0, 60: 0, 61: 0, 62: 0, 63: 0, 64: 0, 65: 0, 66: 0, 67: 0, 68: 0, 69: 0, 70: 0, 71: 0, 72: 0, 73: 0, 74: 0, 75: 0, 76: 0, 77: 0, 78: 0, 79: 0, 80: 0, 81: 0, 82: 0, 83: 0, 84: 0, 85: 0, 86: 0, 87: 0, 88: 0, 89: 0, 90: 0, 91: 0, 92: 0, 93: 0, 94: 0, 95: 0, 96: 0, 97: 0, 98: 0, 99: 0, 100: 1, 101: 1, 102: 1, 103: 2, 104: 1, 105: 1, 106: 1, 107: 1, 108: 1, 109: 1, 110: 3, 111: 1, 112: 1, 113: 1, 114: 1, 115: 1, 116: 2, 117: 1, 118: 1, 119: 1, 120: 2, 121: 1, 122: 2, 123: 3, 124: 2, 125: 2, 126: 2, 127: 2, 128: 3, 129: 1, 130: 1, 131: 1, 134: 1, 133: 1, 132: 1, 137: 2, 136: 2, 135: 1, 140: 

In [32]:
import torch
from pyvis.network import Network
import subprocess

importlib.reload(ImportData)

for file in random.sample(all_files, 1):
    single_data = ImportData.load_file(file)
    
# === 1. Modellvorhersagen berechnen ===
model.eval()
with torch.no_grad():
    logits = model(single_data.x, single_data.edge_index)  # Shape: [num_nodes]
    probs = torch.sigmoid(logits).cpu().numpy()  # Werte zwischen 0 und 1

# Map von Node-ID zu Wahrscheinlichkeit
pred_probs = {i: p > best_threshold for i, p in enumerate(probs)}
pred_nodes = sorted([i for i, flag in pred_probs.items() if flag])

node_to_id = {node: i for i, node in enumerate(single_data.H.nodes)}
print("Predicted Nodes: ", sorted([node_to_id[i] for i, flag in pred_probs.items() if flag]))
print("True Nodes: ", sorted(set(single_data.parental_nodes_hgt_events_corrected)))

pred_probs = {node_to_id[i]: p for i, p in enumerate(probs)}


# --- x/y Koordinaten für Blätter und innere Knoten berechnen ---
x_spacing = 100
y_spacing = 100

node_x = {}
node_y = {}

# Maximaler Level aus single_data.H
max_level = max(single_data.H.nodes[n].get("level", 0) for n in single_data.H.nodes)

# Hilfsfunktion: finde alle Blätter unterhalb eines Knotens
def get_descendant_leaves(G, node):
    """Alle Blätter, die von `node` erreichbar sind (rekursiv)."""
    stack = list(G.predecessors(node))
    reachable_leaves = []
    while stack:
        temp_node = stack.pop()
        children = list(G.predecessors(temp_node))
        if len(children) > 0:
            stack.extend(children)
        else:
            reachable_leaves.append(temp_node)
    return reachable_leaves

# === Blätter (Level 0) oben ===
leaves = [n for n in single_data.H.nodes if single_data.H.nodes[n].get("level", 0) == 0]
for i, node in enumerate(sorted(leaves)):  
    node_x[node] = i * x_spacing
    node_y[node] = (max_level - 0) * y_spacing  # Blätter oben

# === Innere Knoten: levelweise platzieren ===
levels_in_graph = sorted(set(nx.get_node_attributes(single_data.H, "level").values()))
for level in levels_in_graph[1:]:  # 0 schon behandelt
    nodes_in_level = [n for n in single_data.H.nodes if single_data.H.nodes[n].get("level", 0) == level]
    for node in nodes_in_level:
        # Finde alle Blätter unterhalb
        reachable_leaves = get_descendant_leaves(single_data.H, node)
        if reachable_leaves:
            leaf_x = [node_x[l] for l in reachable_leaves if l in node_x]
            node_x[node] = np.mean(leaf_x)
        else:
            node_x[node] = 0
        node_y[node] = (max_level - level) * y_spacing


"""
# === Blätter (Level 0) oben ===
leaves = [n for n in single_data.H.nodes if single_data.H.nodes[n].get("level", 0) == 0]
for i, node in enumerate(sorted(leaves)):  
    node_x[node] = i * x_spacing
    node_y[node] = (max_level - 0) * y_spacing  # Blätter oben

# === Innere Knoten: levelweise platzieren ===
levels_in_graph = sorted(set(nx.get_node_attributes(single_data.H, "level").values()))
for level in levels_in_graph[1:]:  # 0 schon behandelt
    nodes_in_level = [n for n in single_data.H.nodes if single_data.H.nodes[n].get("level", 0) == level]
    for node in nodes_in_level:
        children = list(single_data.H.predecessors(node))
        if children:
            child_x = [node_x[c] for c in children if c in node_x]
            if child_x:  # falls Kinder schon positioniert
                node_x[node] = np.mean(child_x)
            else:
                node_x[node] = 0
        else:
            node_x[node] = 0
        node_y[node] = (max_level - level) * y_spacing
"""
        
# === Netzwerk initialisieren (Hierarchical Layout deaktiviert!) ===
net = Network(height="900px", width="100%", directed=True)

net.set_options("""
{
  "nodes": {
    "shape": "dot",
    "size": 12,
    "font": { "size": 30 }
  },
  "edges": {
    "arrows": {
      "to": { "enabled": true, "scaleFactor": 0.5 }
    }
  },
  "physics": {
    "enabled": false
  }
}
""")
# === Knoten hinzufügen mit festen x/y ===
for node in single_data.H.nodes():
    core = single_data.H.nodes[node].get('core_distance', 0) #
    allele = single_data.H.nodes[node].get('allele_distance', 0)# * 10000
    allele_distance_convolution = single_data.H.nodes[node].get('allele_distance_convolution', 0)
    core_distance_convolution = single_data.H.nodes[node].get('core_distance_convolution', 0)
    leaf_count = single_data.H.nodes[node].get("leaf_count", 0)
    leaf_count_presence_matters = single_data.H.nodes[node].get("leaf_count_presence_matters", 0)
    node_count = single_data.H.nodes[node].get("node_count", 0)
    node_count_presence_matters = single_data.H.nodes[node].get("node_count_presence_matters", 0)
    allele_distance_only_new = single_data.H.nodes[node].get('allele_distance_only_new', 0)
    allele_distances_both_children_polymorph = single_data.H.nodes[node].get('allele_distances_both_children_polymorph', 0)
    true_allele_distance = single_data.H.nodes[node].get('true_allele_distance', 0)
    true_allele_convolution = single_data.H.nodes[node].get('true_allele_convolution', 0)
    node_time = single_data.H.nodes[node].get('node_time', 0)
    pred = pred_probs[node]
    #title = f"Core: {core:.2f}, Allele: {allele:.2f}, leaf_count: {leaf_count}, leaf_count_presence_matters: {leaf_count_presence_matters}, node_count: {node_count}, node_count_presence_matters: {node_count_presence_matters}, HGT_prob: {pred:.2f}"
    #label = f"{node}\n({core:.2f}, {allele:.2f}, {leaf_count}, {leaf_count_presence_matters}, {node_count}, {node_count_presence_matters}, {pred:.2f})"
    title = f"Core: {core:.2f}, Allele: {allele:.2f}, Core_convolution: {core_distance_convolution:.2f}, Allele_only_new: {allele_distance_only_new:.2f}, Both_polymorph: {allele_distances_both_children_polymorph:.2f}, True_allele_distance: {true_allele_distance:.2f}, True_allele_convolution: {true_allele_convolution:.2f}, Allele_convolution: {allele_distance_convolution:.2f}, Time: {node_time:.2f}, Pred: {pred:.2f}"
    label = f"{node}\n({core:.2f}, {allele:.2f}, {core_distance_convolution:.2f}, {allele_distance_only_new:.2f}, {allele_distances_both_children_polymorph:.2f}, {true_allele_distance:.2f}, {true_allele_convolution:.2f}, {allele_distance_convolution:.2f}, {node_time:.2f}, {pred:.2f})"

    # Farbe
    if node in single_data.parental_nodes_hgt_events_corrected and pred > best_threshold:
        color = "green"
    elif node not in single_data.parental_nodes_hgt_events_corrected and pred > best_threshold:
        color = "violet"
    elif node in single_data.parental_nodes_hgt_events_corrected and pred <= best_threshold:
        color = "red"
    elif node < 100 and single_data.gene_absence_presence_matrix[node] == 1:
        color = "orange"
    elif node < 100 and single_data.gene_absence_presence_matrix[node] == 0:
        color = "black"
    elif node not in single_data.parental_nodes_hgt_events_corrected and pred <= best_threshold:
        color = "lightblue"

    net.add_node(node, label=label, title=title, color=color,
                 x=node_x[node], y=node_y[node])

# === Kanten hinzufügen ===
for u, v in single_data.H.edges():
    net.add_edge(u, v)

# === HTML-Datei speichern und direkt in Chrome öffnen ===
html_file = Path("/mnt/c/Users/uhewm/OneDrive/PhD/Project No.2/pangenome/graph.html")
html_file.parent.mkdir(parents=True, exist_ok=True)
net.show(str(html_file), notebook=False)

# WSL-Pfad in Windows-Pfad umwandeln
win_path = subprocess.run(["wslpath", "-w", str(html_file)], capture_output=True, text=True).stdout.strip()

# Direkt in Chrome öffnen
subprocess.run(["cmd.exe", "/C", "start", "chrome", win_path])


  edges = torch.tensor(graph_properties[1], dtype=torch.long)  # [2, num_edges]
  coords = torch.tensor(graph_properties[2].T)           # [2, num_nodes]


Predicted Nodes:  [170, 180, 188, 197]
True Nodes:  [170, 180, 188, 191]
/mnt/c/Users/uhewm/OneDrive/PhD/Project No.2/pangenome/graph.html
[?1l>4;1H[2J[?47l84l[?1h=[m[m[37m[40m[1;1H                                                                                [2;1H                                                                                [3;1H                                                                                [4;1H                                                                                [5;1H                                                                                [6;1H                                                                                [7;1H                                                                                [8;1H                                                                                [9;1H                                                                                [10;1H                                

CompletedProcess(args=['cmd.exe', '/C', 'start', 'chrome', 'C:\\Users\\uhewm\\OneDrive\\PhD\\Project No.2\\pangenome\\graph.html'], returncode=0)

In [278]:
def get_descendant_leaves(G, node):
    """Alle Blätter, die von `node` erreichbar sind (rekursiv)."""
    stack = list(G.predecessors(node))
    reachable_leaves = []
    while stack:
        temp_node = stack.pop()
        children = list(G.predecessors(temp_node))
        if len(children) > 0:
            stack.extend(children)
        else:
            reachable_leaves.append(temp_node)
    return reachable_leaves

node = 117
reachable_leaves = get_descendant_leaves(single_data.H, node)
print(node, reachable_leaves)

117 [1, 0]


In [111]:
with h5py.File(file, "r") as f:
    # in jeder Datei steckt eine Gruppe namens 'results'
    grp = f["results"]
    
    hgt_rate = grp.attrs["hgt_rate"]
    rho = grp.attrs["rho"]
    fitch_score = grp.attrs["fitch_score"]
    gene_number_hgt_events_passed = grp.attrs["gene_number_hgt_events_passed"]

    gene_absence_presence_matrix = grp.attrs["gene_absence_presence_matrix"] 
    gene_number_hgt_events_passed = grp.attrs["gene_number_hgt_events_passed"]
    gene_number_loss_events = grp.attrs["gene_number_loss_events"]
    parental_nodes_hgt_events_corrected = grp.attrs["parental_nodes_hgt_events_corrected"]
    children_gene_nodes_loss_events = grp.attrs["children_gene_nodes_loss_events"]

rho

1.7366938591003418

In [155]:
test = random.choice(test_data)

model.eval()
with torch.no_grad():
    logits = model(test.x, test.edge_index)  # Shape: [num_nodes]
    probs = torch.sigmoid(logits).cpu().numpy()  # Werte zwischen 0 und 1

# Map von Node-ID zu Wahrscheinlichkeit
pred_probs = {i: p > best_threshold for i, p in enumerate(probs)}
pred_nodes = sorted([i for i, flag in pred_probs.items() if flag])

# 2) Alle Werte != 0 im Tensor
true_nodes = (test.y * torch.arange(199, dtype=test.y.dtype)).nonzero(as_tuple=True)[0].tolist()

print(true_nodes)
print(pred_nodes)

[129, 130, 138, 155, 173, 176, 181, 186, 192]
[129, 130, 138, 155, 170, 171, 172, 173, 174, 175, 176, 177, 178, 181, 186, 192]


In [696]:
num_nodes = test.x.shape[0]

# Spalte für true_nodes: 1, wenn der Knoten in true_nodes ist, sonst 0
true_nodes_col = torch.zeros(num_nodes, dtype=test.x.dtype)
true_nodes_col[true_nodes] = 1
true_nodes_col = true_nodes_col.unsqueeze(1)  # [num_nodes, 1]

# Spalte für pred_nodes: 1, wenn der Knoten in pred_nodes ist, sonst 0
pred_nodes_col = torch.zeros(num_nodes, dtype=test.x.dtype)
pred_nodes_col[pred_nodes] = 1
pred_nodes_col = pred_nodes_col.unsqueeze(1)  # [num_nodes, 1]

# Spaltenweise zusammenfügen
matrix = torch.cat([test.x[:, 0:5], true_nodes_col, pred_nodes_col], dim=1)
matrix = matrix[99:]

np.set_printoptions(threshold=np.inf, suppress=True)  # Alles anzeigen, keine Scientific Notation
print(np.round(matrix.cpu().numpy(), 2))  # sollte [num_nodes, 7] sein

[[  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0

In [156]:
total_TP, total_FP, total_FN = 0, 0, 0

for test in test_loader:
    model.eval()
    with torch.no_grad():
        logits = model(test.x, test.edge_index)
        probs = torch.sigmoid(logits).cpu().numpy()

    pred_nodes = {i for i, p in enumerate(probs) if p > best_threshold}
    true_nodes = set((test.y * torch.arange(len(test.y), dtype=test.y.dtype)).nonzero(as_tuple=True)[0].tolist())

    TP = len(true_nodes & pred_nodes)
    FP = len(pred_nodes - true_nodes)
    FN = len(true_nodes - pred_nodes)

    total_TP += TP
    total_FP += FP
    total_FN += FN

print("Gesamt-Statistik:")
print(f"TP = {total_TP}")
print(f"FP = {total_FP}")
print(f"FN = {total_FN}")

Gesamt-Statistik:
TP = 745
FP = 274
FN = 140


In [466]:
leaves = [n for n in H.nodes if H.in_degree(n) == 0]

# Schritt 2: Rekursiv/Subgraph-basiert Anzahl der Leaves pro Knoten zählen
def count_leaves(node, G, gene_presence_matters = False):
    # wenn Blatt → 1
    if node in leaves:
        if not gene_presence_matters or gene_absence_presence_matrix[node] == 1:
            return 1
        else:
            return 0
    # sonst Summe der Leaves aller Kinder
    print(list(count_leaves(child, G) for child in G.predecessors(node)))
    return sum(count_leaves(child, G) for child in G.predecessors(node))

count_leaves(140, H, gene_presence_matters = True)

[1, 1]


2

In [573]:
present_nodes = [n for n in G.nodes if n < len(gene_absence_presence_matrix) and gene_absence_presence_matrix[n] == 1]

def mrca_of_nodes(G, nodes):
    G = G.reverse()
    if not nodes:
        return None
    # Starte mit erstem Knoten
    mrca = nodes[0]
    for node in nodes[1:]:
        mrca = nx.lowest_common_ancestor(G, mrca, node)
        if mrca is None:  # kein gemeinsamer Vorfahre
            return None
    return mrca

mrca_node = mrca_of_nodes(G, present_nodes)
print(mrca_node)

99
