In [None]:
# ------------------------------------------------
# Konfiguration
# ------------------------------------------------
# Path of the TEI-XML input file; passed to read_xml_into_df().
datei = 'test.xml'
# Base name of the generated HTML file ('.html' is appended when the graph is shown).
dateiname = "graph"
# Node threshold: copied into min_unique_labelx, the minimum number of distinct
# labelx values a node needs in order to be drawn.
knotenwert = 1
# Edge threshold: minimum weight an edge needs in order to be drawn.
kantenwert = 1
gesuchter_knoten = "a300"  # <-- enter the id of the node to search for and highlight here
# Set to False to show only type numbers (e.g. a300) as node labels.
# NOTE(review): despite its name ("only type number"), this value is copied
# into show_full_labels, so True displays the full label text — confirm that
# this is the intended meaning.
Nur_Typnummer = True

In [None]:
# -*- coding: utf-8 -*-
# ============================================
# Märchenanalyse mit Graph-Erstellung (pyvis)
# und Hervorhebung eines bestimmten Knotens
# ============================================

import xml.etree.ElementTree as ET
import pandas as pd
from pyvis.network import Network
from itertools import combinations
from collections import defaultdict

# ------------------------------------------------
# TEI-XML-Datei einlesen und df erzeugen
# ------------------------------------------------
def read_xml_into_df(datei: str) -> pd.DataFrame:
    """Read a TEI-XML file into a DataFrame of (labela, labelx) pairs.

    For every ``tei:text`` element below a ``tei:teiCorpus/tei:TEI`` element,
    the ``tei:seg`` elements inside the ``tei:p`` elements of its
    ``tei:body`` are inspected. Each segment whose ``ana`` attribute
    (namespace ``www.dglab.uni-jena.de/vmf/a``) starts with ``'a'``
    contributes one row. Segments without that attribute are skipped, just
    like segments whose value does not start with ``'a'``.

    Args:
        datei: Path of the TEI-XML file to parse.

    Returns:
        A DataFrame with the columns ``labela`` (value of the ``ana``
        attribute) and ``labelx`` (``xml:id`` of the enclosing ``tei:text``),
        with duplicate rows removed.

    Raises:
        KeyError: If a ``tei:text`` element has no ``xml:id`` attribute.
    """
    ns = {'tei': 'http://www.tei-c.org/ns/1.0'}
    xml_id_attr = '{http://www.w3.org/XML/1998/namespace}id'
    ana_attr = '{www.dglab.uni-jena.de/vmf/a}ana'

    root_node = ET.parse(datei).getroot()

    labelas = []
    labelxs = []

    for corp in root_node.findall(".//tei:teiCorpus", ns):
        for tei in corp.findall("tei:TEI", ns):
            for ganze in tei.findall(".//tei:text", ns):
                # Every text must be identifiable; a missing xml:id is an error.
                labelx = ganze.attrib[xml_id_attr]
                for body in ganze.findall(".//tei:body", ns):
                    for absatz in body.findall(".//tei:p", ns):
                        for phrase in absatz.findall(".//tei:seg", ns):
                            # A segment without the attribute carries no
                            # label and is ignored instead of aborting the
                            # whole import with a KeyError.
                            labela = phrase.get(ana_attr)
                            if labela is not None and labela.startswith('a'):
                                labelas.append(labela)
                                labelxs.append(labelx)

    df = pd.DataFrame({"labela": labelas, "labelx": labelxs})
    return df.drop_duplicates()

# ------------------------------------------------
# Graph-Daten aus XML laden
# ------------------------------------------------

# Load the (labela, labelx) table and make sure it holds no duplicate rows.
df = read_xml_into_df(datei)
unique_df = df.drop_duplicates()

# Thresholds. min_frequency is currently not read by the code below.
min_frequency = 1
min_unique_labelx = knotenwert

# Accumulators that are filled while the co-occurrences are counted.
edges = dict()
self_loops = dict()
node_frequency = defaultdict(set)

# Empty pyvis network using the ForceAtlas2-based layout.
net = Network(notebook=True)
net.force_atlas_2based()

# ------------------------------------------------
# XSD einlesen für Bezeichnungs-Ersetzungen
# ------------------------------------------------
xsd_path = 'kf/vmf_a.xsd'
ns = {'xs': 'http://www.w3.org/2001/XMLSchema'}
tree = ET.parse(xsd_path)
root = tree.getroot()

# Map the first word of every enumeration value that contains a space to the
# complete value; values without a space are not recorded.
replacement_dict = {}
for enum_node in root.iterfind(".//xs:enumeration", ns):
    full_text = enum_node.get('value', '')
    type_code, separator, _description = full_text.partition(' ')
    if separator:
        replacement_dict[type_code] = full_text

# ------------------------------------------------
# Option: vollständige Labels ein- oder ausschalten
# ------------------------------------------------
show_full_labels = Nur_Typnummer


def get_display_label(labela):
    """Return the text to display as the label of the node *labela*.

    When ``show_full_labels`` is truthy, the entry for *labela* in
    ``replacement_dict`` is returned, falling back to *labela* itself if
    there is no entry. Otherwise *labela* is returned unchanged.
    """
    return replacement_dict.get(labela, labela) if show_full_labels else labela

# ------------------------------------------------
# Kanten- und Knotenwert berechnen
# ------------------------------------------------
# Walk through the table one labelx at a time and record which labela values
# occur together under it.
for text_id, rows in unique_df.groupby('labelx'):
    types_in_text = rows['labela'].unique()

    if len(types_in_text) == 1:
        # A single labela for this labelx is recorded as a self-loop.
        lone_type = types_in_text[0]
        loop_entry = self_loops.setdefault(lone_type, {'weight': 0, 'labelx_values': []})
        loop_entry['weight'] += 1
        loop_entry['labelx_values'].append(text_id)
        node_frequency[lone_type].add(text_id)
        continue

    # Several labela values: every unordered pair becomes (or strengthens)
    # an edge, keyed by the sorted pair.
    for type_a, type_b in combinations(types_in_text, 2):
        edge_key = tuple(sorted((type_a, type_b)))
        edge_entry = edges.setdefault(edge_key, {'weight': 0, 'labelx_values': []})
        edge_entry['weight'] += 1
        edge_entry['labelx_values'].append(text_id)
        node_frequency[type_a].add(text_id)
        node_frequency[type_b].add(text_id)

# From here on node_frequency maps each node to the number of distinct labelx
# values it was seen with (a plain dict of counts instead of sets).
node_frequency = {node: len(text_ids) for node, text_ids in node_frequency.items()}

# ------------------------------------------------
# Knotengröße skalieren
# ------------------------------------------------
def get_node_size(frequency, min_size=10, max_size=50, max_frequency=None):
    """Scale a node frequency linearly into a display size.

    Args:
        frequency: Frequency of the node to be sized.
        min_size: Size that corresponds to a frequency of 0.
        max_size: Size that corresponds to ``max_frequency``.
        max_frequency: Frequency that maps to ``max_size``. When omitted, the
            largest value in the module-level ``node_frequency`` mapping is
            used (or 1 if that mapping is empty). Passing a precomputed
            maximum avoids rescanning the mapping on every call.

    Returns:
        ``min_size + (max_size - min_size) * frequency / max_frequency``.
    """
    if max_frequency is None:
        max_frequency = max(node_frequency.values()) if node_frequency else 1
    return min_size + (max_size - min_size) * (frequency / max_frequency)

# ------------------------------------------------
# Knoten & Kanten hinzufügen
# ------------------------------------------------
# Add every co-occurrence edge that passes both thresholds, together with its
# two end nodes. Each node and edge is submitted to the network exactly once
# per qualifying pair.
for (labela1, labela2), data in edges.items():
    # Edge threshold: how often must the two types occur together?
    if data['weight'] < kantenwert:
        continue
    # Node threshold: both end nodes must be frequent enough.
    if node_frequency[labela1] < min_unique_labelx or node_frequency[labela2] < min_unique_labelx:
        continue

    title1 = replacement_dict.get(labela1, labela1)
    title2 = replacement_dict.get(labela2, labela2)
    size1 = get_node_size(node_frequency[labela1])
    size2 = get_node_size(node_frequency[labela2])

    # Choose the displayed label according to the label setting.
    display_label1 = get_display_label(labela1)
    display_label2 = get_display_label(labela2)

    net.add_node(labela1, label=display_label1, title=f'{title1} ({node_frequency[labela1]})', size=size1)
    net.add_node(labela2, label=display_label2, title=f'{title2} ({node_frequency[labela2]})', size=size2)
    net.add_edge(labela1, labela2, value=data['weight'], title=f'{data["weight"]}')

# Add the self-loops: nodes that were the only labela of some labelx.
for loop_type, loop_data in self_loops.items():
    occurrences = node_frequency[loop_type]
    if occurrences < min_unique_labelx:
        # Node is below the node threshold and is not drawn.
        continue

    full_name = replacement_dict.get(loop_type, loop_type)
    net.add_node(
        loop_type,
        label=get_display_label(loop_type),  # label depends on the label setting
        title=f'{full_name} ({occurrences})',
        size=get_node_size(occurrences),
    )
    net.add_edge(loop_type, loop_type, value=loop_data['weight'], title=f'{loop_data["weight"]}')

# ------------------------------------------------
# Gesuchten Knoten hervorheben
# ------------------------------------------------
# Look up the first node whose id equals the searched id; None if absent.
hit = next((entry for entry in net.nodes if entry["id"] == gesuchter_knoten), None)
found = hit is not None

if found:
    # Colour the node red and extend its tooltip text.
    hit["color"] = "red"
    hit["title"] += " "
    print(f"Knoten '{gesuchter_knoten}' gefunden und hervorgehoben.")
else:
    print(f"Knoten '{gesuchter_knoten}' wurde im Graphen nicht gefunden.")

# ------------------------------------------------
# Graph anzeigen/speichern
# ------------------------------------------------
# Hand the network to pyvis' show() with the output file name '<dateiname>.html'.
net.show(dateiname + '.html')