In [1]:
import sys
import polars as pl
import plotly.express as px
sys.path.insert(0, '..')
from fs_thesis import sql, show

# Create Overview for Tables Patients and Admissions

In [2]:
# Check coverage of BMI vs Weight/Height
# If BMI is present for most patients who have Weight/Height, we can use it.
df_coverage = sql("""
    SELECT 
        result_name, 
        COUNT(*) as total_measurements,
        COUNT(DISTINCT subject_id) as unique_patients
    FROM hosp.omr 
    WHERE result_name IN ('BMI (kg/m2)', 'Weight (Lbs)', 'Height (Inches)')
    GROUP BY result_name
""")
show(df_coverage, limit=False)

Unnamed: 0,result_name,total_measurements,unique_patients
0,Weight (Lbs),2145353,166872
1,Height (Inches),814964,148359
2,BMI (kg/m2),1901496,153725


In [3]:
df = sql("SELECT * from hosp.patients")
show(df, limit=True)

Unnamed: 0,subject_id,gender,anchor_age,anchor_year,anchor_year_group,dod
0,10000032,F,52,2180,2014 - 2016,2180-09-09
1,10000048,F,23,2126,2008 - 2010,NaT
2,10000058,F,33,2168,2020 - 2022,NaT
3,10000068,F,19,2160,2008 - 2010,NaT
4,10000084,M,72,2160,2017 - 2019,2161-02-13
...,...,...,...,...,...,...
364622,19999828,F,46,2147,2017 - 2019,NaT
364623,19999829,F,28,2186,2008 - 2010,NaT
364624,19999840,M,58,2164,2008 - 2010,2164-09-17
364625,19999914,F,49,2158,2017 - 2019,NaT


In [4]:
df = sql("SELECT * from hosp.admissions")
show(df, limit=True)

Unnamed: 0,subject_id,hadm_id,admittime,dischtime,deathtime,admission_type,admit_provider_id,admission_location,discharge_location,insurance,language,marital_status,race,edregtime,edouttime,hospital_expire_flag
0,10000032,22595853,2180-05-06 22:23:00,2180-05-07 17:15:00,NaT,URGENT,P49AFC,TRANSFER FROM HOSPITAL,HOME,Medicaid,English,WIDOWED,WHITE,2180-05-06 19:17:00,2180-05-06 23:30:00,0
1,10000032,22841357,2180-06-26 18:27:00,2180-06-27 18:49:00,NaT,EW EMER.,P784FA,EMERGENCY ROOM,HOME,Medicaid,English,WIDOWED,WHITE,2180-06-26 15:54:00,2180-06-26 21:31:00,0
2,10000032,25742920,2180-08-05 23:44:00,2180-08-07 17:50:00,NaT,EW EMER.,P19UTS,EMERGENCY ROOM,HOSPICE,Medicaid,English,WIDOWED,WHITE,2180-08-05 20:58:00,2180-08-06 01:44:00,0
3,10000032,29079034,2180-07-23 12:35:00,2180-07-25 17:55:00,NaT,EW EMER.,P06OTX,EMERGENCY ROOM,HOME,Medicaid,English,WIDOWED,WHITE,2180-07-23 05:54:00,2180-07-23 14:00:00,0
4,10000068,25022803,2160-03-03 23:16:00,2160-03-04 06:26:00,NaT,EU OBSERVATION,P39NWO,EMERGENCY ROOM,,,English,SINGLE,WHITE,2160-03-03 21:55:00,2160-03-04 06:26:00,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
546023,19999828,25744818,2149-01-08 16:44:00,2149-01-18 17:00:00,NaT,EW EMER.,P13JMH,TRANSFER FROM HOSPITAL,HOME HEALTH CARE,Medicaid,English,SINGLE,WHITE,2149-01-08 09:11:00,2149-01-08 18:12:00,0
546024,19999828,29734428,2147-07-18 16:23:00,2147-08-04 18:10:00,NaT,EW EMER.,P38XL8,PHYSICIAN REFERRAL,HOME HEALTH CARE,Medicaid,English,SINGLE,WHITE,2147-07-17 17:18:00,2147-07-18 17:34:00,0
546025,19999840,21033226,2164-09-10 13:47:00,2164-09-17 13:42:00,2164-09-17 13:42:00,EW EMER.,P33612,EMERGENCY ROOM,DIED,Private,English,WIDOWED,WHITE,2164-09-10 11:09:00,2164-09-10 14:46:00,1
546026,19999840,26071774,2164-07-25 00:27:00,2164-07-28 12:15:00,NaT,EW EMER.,P036NA,EMERGENCY ROOM,HOME,Private,English,WIDOWED,WHITE,2164-07-24 21:16:00,2164-07-25 01:20:00,0


# Patient Demographics & Time-to-Event Extraction

Extracts patient demographic data and calculates time-to-event metrics for survival analysis.

## Data Structure
- **Unique Identifier**: `subject_id`
- **Demographics**: gender, age, insurance, language, marital status, race
- **Baseline**: First admission timestamp (`t0_time`)
- **Event**: ICD-9/10 diagnosis code occurrence
- **Censoring**: Date of death (`dod`) or end of follow-up

## Output
Generates a dataset with:
- Demographic features
- Time to event (days from baseline)
- Event indicator (occurred/censored)
- Target classes: Early event (<1yr), late event (>1yr), or censored

In [5]:
df_baseline = sql("""SELECT 
    p.subject_id,
    p.gender,
    p.anchor_age,
    p.dod,
    -- Daten aus der ersten Aufnahme (Baseline)
    first_admit.admittime AS t0_time,
    first_admit.insurance,
    first_admit.language,
    first_admit.marital_status,
    first_admit.race,
    first_admit.admission_type
FROM hosp.patients p
INNER JOIN (
    -- Subquery, um nur die zeitlich erste Aufnahme pro Patient zu finden
    SELECT 
        subject_id, 
        admittime, 
        insurance, 
        language, 
        marital_status, 
        race, 
        admission_type,
        ROW_NUMBER() OVER (PARTITION BY subject_id ORDER BY admittime ASC) as row_num
    FROM hosp.admissions
) first_admit ON p.subject_id = first_admit.subject_id
WHERE first_admit.row_num = 1""")
show(df_baseline,limit=True)

Unnamed: 0,subject_id,gender,anchor_age,dod,t0_time,insurance,language,marital_status,race,admission_type
0,10001472,F,35,NaT,2186-01-10 00:00:00,Private,English,MARRIED,WHITE,URGENT
1,10002804,M,64,NaT,2148-09-28 01:21:00,Private,English,MARRIED,UNKNOWN,EW EMER.
2,10003502,F,86,2169-09-10,2161-06-29 14:34:00,Medicare,Russian,MARRIED,WHITE,EW EMER.
3,10003637,M,57,2150-05-22,2145-01-04 19:56:00,Medicaid,English,DIVORCED,PORTUGUESE,URGENT
4,10004720,M,61,2186-11-17,2186-11-12 18:01:00,Medicare,English,SINGLE,WHITE,EW EMER.
...,...,...,...,...,...,...,...,...,...,...
223447,19994259,F,41,NaT,2128-11-20 00:17:00,Private,English,SINGLE,WHITE,EW EMER.
223448,19995732,F,91,NaT,2115-03-20 02:39:00,Medicare,Russian,WIDOWED,WHITE,EW EMER.
223449,19997072,F,37,NaT,2164-08-31 07:15:00,Private,English,MARRIED,WHITE,SURGICAL SAME DAY ADMISSION
223450,19997471,M,88,NaT,2120-08-19 02:31:00,Medicare,English,MARRIED,WHITE,DIRECT OBSERVATION


In [6]:
# Zählt die eindeutigen Werte
anzahl = len(df["subject_id"].unique())
anzahl


223452

In [7]:
# Prüft auf Unique-Status
ist_unique = len(df["subject_id"].unique()) == len(df["subject_id"])
ist_unique

False

# Event Time

Extracts the first occurrence of the target diagnosis (ICD-50 code, Beispiel: Herzinsuffizienz) for each patient.

Returns:
- `subject_id`: Patient identifier
- `event_time`: Admission timestamp of first diagnosis occurrence

In [8]:
data = sql("""SELECT d.*, i.long_title
FROM hosp.diagnoses_icd d
LEFT JOIN hosp.d_icd_diagnoses i ON d.icd_code = i.icd_code
WHERE d.icd_code LIKE 'I50%'""")
show(data, limit=True)

Unnamed: 0,subject_id,hadm_id,seq_num,icd_code,icd_version,long_title
0,10912602,23708919,8,I5032,10,Chronic diastolic (congestive) heart failure
1,10912800,24003151,2,I5032,10,Chronic diastolic (congestive) heart failure
2,10912800,24683894,3,I5032,10,Chronic diastolic (congestive) heart failure
3,10913302,20242491,5,I5022,10,Chronic systolic (congestive) heart failure
4,10913302,20317008,8,I5022,10,Chronic systolic (congestive) heart failure
...,...,...,...,...,...,...
43908,19996673,29017569,6,I5022,10,Chronic systolic (congestive) heart failure
43909,19997448,23560173,3,I5030,10,Unspecified diastolic (congestive) heart failure
43910,19997473,27787494,3,I5023,10,Acute on chronic systolic (congestive) heart f...
43911,19997752,29452285,5,I5030,10,Unspecified diastolic (congestive) heart failure


In [9]:
df_event = sql("""SELECT 
    d.subject_id,
    MIN(a.admittime) AS event_time
FROM hosp.diagnoses_icd d
JOIN hosp.admissions a ON d.hadm_id = a.hadm_id
WHERE d.icd_code LIKE 'I50%' 
GROUP BY d.subject_id
                         """)
show(df_event, limit=True)

Unnamed: 0,subject_id,event_time
0,10913779,2205-04-21 23:23:00
1,10918704,2129-10-30 01:50:00
2,10921049,2141-01-04 19:29:00
3,10925427,2128-05-30 00:00:00
4,10930646,2181-11-09 18:11:00
...,...,...
18885,19690756,2136-08-11 18:30:00
18886,19693734,2163-08-27 01:39:00
18887,19695463,2137-11-07 23:27:00
18888,19696359,2131-06-24 23:13:00


# Time-to-Event Calculation & Target Definition

Calculates time intervals and creates target classes based on event occurrence and timing.

## Steps:
1. **Join datasets**: Merges baseline demographics with event times (left join)
2. **Calculate time differences**: 
   - `t_event`: Days from baseline to diagnosis
   - `t_death`: Days from baseline to death
3. **Determine duration**: Uses event time if available, otherwise death time, fallback to 2000 days
4. **Handle errors**: Clips negative durations to 0
5. **Create target classes**:
   - Class 0: Early event (<365 days)
   - Class 1: Late event (>365 days)
   - Class 2: Censored (no event)

In [10]:
import polars as pl

df_final = (
    df_baseline.join(df_event, on="subject_id", how="left")
    # 1. Zeitdifferenzen und Event-Indikator berechnen
    .with_columns([
        ((pl.col("event_time") - pl.col("t0_time")).dt.total_days()).alias("t_event"),
        ((pl.col("dod") - pl.col("t0_time")).dt.total_days()).alias("t_death"),
        pl.col("event_time").is_not_null().cast(pl.Int32).alias("event_occurred")
    ])
    # 2. Duration festlegen (Priorität: Event > Tod > Fallback) und Clip
    .with_columns(
        pl.coalesce([
            pl.col("t_event"),
            pl.col("t_death"),
            pl.lit(2000)
        ])
        .clip(lower_bound=0)
        .alias("duration")
    )
    # 3. Zielvariable (Target) für TabPFN definieren
    .with_columns(
        pl.when((pl.col("event_occurred") == 1) & (pl.col("duration") <= 365))
        .then(0)     # Klasse 0: Früher Ausbruch (< 1 Jahr)
        .when((pl.col("event_occurred") == 1) & (pl.col("duration") > 365))
        .then(1)     # Klasse 1: Später Ausbruch (> 1 Jahr)
        .otherwise(2) # Klasse 2: Zensiert / Gesund / Kein Event
        .alias("target")
    )
)

# Ergebnis anzeigen
show(df_final, limit=True)

# Kurze Kontrolle der Verteilung
print(df_final["target"].value_counts())

Unnamed: 0,subject_id,gender,anchor_age,dod,t0_time,insurance,language,marital_status,race,admission_type,event_time,t_event,t_death,event_occurred,duration,target
0,10001472,F,35,NaT,2186-01-10 00:00:00,Private,English,MARRIED,WHITE,URGENT,NaT,,,0,2000,2
1,10002804,M,64,NaT,2148-09-28 01:21:00,Private,English,MARRIED,UNKNOWN,EW EMER.,NaT,,,0,2000,2
2,10003502,F,86,2169-09-10,2161-06-29 14:34:00,Medicare,Russian,MARRIED,WHITE,EW EMER.,2169-08-26 16:14:00,2980.0,2994.0,1,2980,1
3,10003637,M,57,2150-05-22,2145-01-04 19:56:00,Medicaid,English,DIVORCED,PORTUGUESE,URGENT,2146-01-22 23:08:00,383.0,1963.0,1,383,1
4,10004720,M,61,2186-11-17,2186-11-12 18:01:00,Medicare,English,SINGLE,WHITE,EW EMER.,NaT,,4.0,0,4,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
223447,19994259,F,41,NaT,2128-11-20 00:17:00,Private,English,SINGLE,WHITE,EW EMER.,NaT,,,0,2000,2
223448,19995732,F,91,NaT,2115-03-20 02:39:00,Medicare,Russian,WIDOWED,WHITE,EW EMER.,NaT,,,0,2000,2
223449,19997072,F,37,NaT,2164-08-31 07:15:00,Private,English,MARRIED,WHITE,SURGICAL SAME DAY ADMISSION,NaT,,,0,2000,2
223450,19997471,M,88,NaT,2120-08-19 02:31:00,Medicare,English,MARRIED,WHITE,DIRECT OBSERVATION,NaT,,,0,2000,2


shape: (3, 2)
┌────────┬────────┐
│ target ┆ count  │
│ ---    ┆ ---    │
│ i32    ┆ u32    │
╞════════╪════════╡
│ 0      ┆ 10742  │
│ 2      ┆ 204562 │
│ 1      ┆ 8148   │
└────────┴────────┘


# Feature & Target Extraction

Prepares data for machine learning by selecting demographic features and target variable.

## Output:
- **X**: Feature matrix with demographics (pandas DataFrame)
  - gender, anchor_age, insurance, language, marital_status, race, admission_type
- **y**: Target vector with class labels (numpy array)
  - 0: Early event (<365 days)
  - 1: Late event (>365 days)
  - 2: Censored (no event)

In [11]:
# Feature Engineering: BMI
# Strategie: Wir nutzen den vorgefertigten BMI Wert, da er eine höhere Abdeckung hat als (Weight & Height) kombiniert.
# Aggregation: Wir nehmen den Median pro Patient, um Ausreißer zu minimieren.

df_bmi = sql("""
    SELECT 
        subject_id,
        MEDIAN(CAST(result_value AS FLOAT)) as bmi
    FROM hosp.omr
    WHERE result_name = 'BMI (kg/m2)'
    -- Einfache Bereinigung: Nur plausible Werte (z.B. 10 bis 100)
    AND TRY_CAST(result_value AS FLOAT) BETWEEN 10 AND 100
    GROUP BY subject_id
""")

# Check distribution
show(df_bmi, limit=True)

Unnamed: 0,subject_id,bmi
0,10256360,24.000000
1,10257865,28.850000
2,10259262,25.000000
3,10260792,31.500000
4,10260836,28.400000
...,...,...
153536,19967684,25.799999
153537,19968774,31.000000
153538,19969139,28.049999
153539,19972623,28.100000


In [12]:
# Join to clean dataset
# Wir machen einen Left Join, da nicht alle Patienten einen BMI haben (Missing Values werden später von TabPFN handled!)
df_final = df_final.join(
    df_bmi, 
    on="subject_id", 
    how="left"
)

# Prüfen wie viele Missings wir haben
print(f"Missing BMI: {df_final['bmi'].null_count()} of {len(df_final)}")
show(df_final.select(['subject_id', 'anchor_age', 'bmi', 'target']), limit=True)

Missing BMI: 98306 of 223452


Unnamed: 0,subject_id,anchor_age,bmi,target
0,10001472,35,31.200001,2
1,10002804,64,29.250000,2
2,10003502,86,25.400000,1
3,10003637,57,29.200001,1
4,10004720,61,21.549999,2
...,...,...,...,...
223447,19994259,41,,2
223448,19995732,91,,2
223449,19997072,37,32.400002,2
223450,19997471,88,,2


In [13]:
import polars as pl
from sklearn.model_selection import train_test_split

# 1. Konvertierung für Scikit-Learn (Split-Logik)
df_pd = df_final.to_pandas()

# Erster Split: Trenne das finale Test-Set ab (20% Hold-out)
df_train_val_raw, df_test = train_test_split(
    df_pd, 
    test_size=0.2, 
    stratify=df_pd["target"], 
    random_state=42
)

# Zweiter Split: Trenne den Rest in Training und Validierung (z.B. 80/20 vom Rest)
df_train_raw, df_val = train_test_split(
    df_train_val_raw, 
    test_size=0.2, 
    stratify=df_train_val_raw["target"], 
    random_state=42
)

# --- BACK TO POLARS: Balancing nur für das TRAINING ---
# Wir nehmen 3.000 Samples pro Klasse, um innerhalb des TabPFN-Limits zu bleiben
df_train_pl = pl.from_pandas(df_train_raw)
n_samples = 3000

df_balanced_train = pl.concat([
    df_train_pl.filter(pl.col("target") == 0).sample(n=min(n_samples, df_train_pl.filter(pl.col("target")==0).height), seed=42),
    df_train_pl.filter(pl.col("target") == 1).sample(n=min(n_samples, df_train_pl.filter(pl.col("target")==1).height), seed=42),
    df_train_pl.filter(pl.col("target") == 2).sample(n=n_samples, seed=42)
]).sample(fraction=1.0, shuffle=True, seed=42)

# --- FINALE FEATURES & TARGETS ---
feature_cols = ["gender", "anchor_age", "insurance", "language", "marital_status", "race", "admission_type", "bmi"]

# Training (Balanced)
X_train = df_balanced_train.select(feature_cols).to_pandas()
y_train = df_balanced_train.select("target").to_series().to_numpy()

# Validation (Realistisch, nicht balanciert)
X_val = df_val[feature_cols]
y_val = df_val["target"].values

# Test (Realistisch, nicht balanciert - Hold-out)
X_test = df_test[feature_cols]
y_test = df_test["target"].values

print(f"Train (balanced): {len(y_train)} | Val (real): {len(y_val)} | Test (real): {len(y_test)}")

KeyboardInterrupt: 

> ℹ️ **INFO**  
> Why Train set is smaller than val and test. The Train set is balanced and gives a total of 9k for the TabPFN Algo. The Test and Val set is larger because it's a % of the total dataset (200k)

# Tab PFN
- training

In [None]:
# 1. Training-Set massiv verkleinern (für Speed & Datenschutz-Konformität)
# Wir nehmen nur 333 Samples pro Klasse = ca. 1000 Total
n_emergency = 333

df_emergency_train = pl.concat([
    df_train_pl.filter(pl.col("target") == 0).sample(n=min(n_emergency, df_train_pl.filter(pl.col("target")==0).height)),
    df_train_pl.filter(pl.col("target") == 1).sample(n=min(n_emergency, df_train_pl.filter(pl.col("target")==1).height)),
    df_train_pl.filter(pl.col("target") == 2).sample(n=n_emergency)
]).sample(fraction=1.0, shuffle=True)

X_train_fast = df_emergency_train.select(feature_cols).to_pandas()
y_train_fast = df_emergency_train.select("target").to_series().to_numpy()

# 2. Classifier neu initialisieren (CPU ist bei 1000 Samples oft stabiler als ein buggy MPS-Backend)
from tabpfn import TabPFNClassifier
classifier = TabPFNClassifier(device='cpu') # Zurück auf CPU, da N jetzt klein ist

# 3. Fit (Sekundensache)
classifier.fit(X_train_fast, y_train_fast)

# 4. Vorhersage für 500 Test-Patienten (ohne Batches, einfach direkt)
X_test_fast = X_val.iloc[:500]
y_test_fast = y_val[:500]

print("Starte Express-Vorhersage...")
y_val_pred_fast = classifier.predict(X_test_fast)
print("Fertig!")

Consider using a GPU or the tabpfn-client API: https://github.com/PriorLabs/tabpfn-client
  _validate_num_samples_for_cpu(


Starte Express-Vorhersage...
Fertig!


# Evaluation der TabPFN Analyse

In [None]:
import plotly.figure_factory as ff
from sklearn.metrics import classification_report, confusion_matrix
import pandas as pd

# 1. Classification Report ausgeben (Text-Form für die harten Fakten)
print("EXPRESS-CHECK REPORT (N_train=1000):")
report_dict = classification_report(y_test_fast, y_val_pred_fast, 
                                   target_names=['Früh (<1J)', 'Spät (1-3J)', 'Gesund'],
                                   output_dict=True,
                                   zero_division=0)
print(classification_report(y_test_fast, y_val_pred_fast, 
                           target_names=['Früh (<1J)', 'Spät (1-3J)', 'Gesund']))

import plotly.figure_factory as ff
import numpy as np

# 1. Normalisierung der Confusion Matrix (Zeilenweise)
# Wie viel % der tatsächlichen Klasse wurden wie vorhergesagt?
cm = confusion_matrix(y_test_fast, y_val_pred_fast)
cm_perc = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

labels = ['Früh (<1J)', 'Spät (1-3J)', 'Gesund']

# 2. Text für die Boxen (Absolute Zahl + Prozent)
annot_text = [
    [f"<b>{val}</b><br>({perc:.1%})" for val, perc in zip(row_val, row_perc)]
    for row_val, row_perc in zip(cm, cm_perc)
]

# 3. Interaktive Heatmap
fig = ff.create_annotated_heatmap(
    cm_perc, 
    x=labels, 
    y=labels, 
    annotation_text=annot_text, 
    colorscale='Reds' # 'Reds' hebt die Treffer besser hervor
)

fig.update_layout(
    title='Zusammenhang: Soziale Faktoren & Wiederaufnahme (Normalisiert)',
    xaxis_title="Vorhersage des Modells",
    yaxis_title="Tatsächlicher Verlauf (MIMIC-Daten)",
    template="plotly_white",
    height=600
)

fig.show()

EXPRESS-CHECK REPORT (N_train=1000):
              precision    recall  f1-score   support

  Früh (<1J)       0.10      0.52      0.17        21
 Spät (1-3J)       0.08      0.87      0.15        15
      Gesund       0.98      0.50      0.66       464

    accuracy                           0.51       500
   macro avg       0.39      0.63      0.33       500
weighted avg       0.92      0.51      0.63       500



In [None]:
import plotly.graph_objects as go

# Daten aus deiner Confusion Matrix extrahieren
# Reihenfolge: [Früh, Spät, Gesund]
cm = confusion_matrix(y_test_fast, y_val_pred_fast)

# Labels für die Knoten
label_list = [
    "Tatsächlich: Früh", "Tatsächlich: Spät", "Tatsächlich: Gesund", # Quellen (Links)
    "Vorhergesagt: Früh", "Vorhergesagt: Spät", "Vorhergesagt: Gesund" # Ziele (Rechts)
]

# Definition der Flüsse (Sankey-Struktur)
source = [0, 0, 0, 1, 1, 1, 2, 2, 2] # Index der Quellen
target = [3, 4, 5, 3, 4, 5, 3, 4, 5] # Index der Ziele
value = cm.flatten() # Die Zahlen aus deiner Matrix

# Farben definieren (Grün für korrekt, Rot für Fehler)
color_link = [
    'rgba(31, 119, 180, 0.4)', 'rgba(31, 119, 180, 0.2)', 'rgba(31, 119, 180, 0.1)', # Von Früh
    'rgba(255, 127, 14, 0.2)', 'rgba(255, 127, 14, 0.4)', 'rgba(255, 127, 14, 0.1)', # Von Spät
    'rgba(44, 160, 44, 0.1)', 'rgba(44, 160, 44, 0.1)', 'rgba(44, 160, 44, 0.4)'    # Von Gesund
]

fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 15, thickness = 20, line = dict(color = "black", width = 0.5),
      label = label_list, color = "blue"
    ),
    link = dict(
      source = source, target = target, value = value, color = color_link
  ))])

fig.update_layout(title_text="Patienten-Fluss: Realität vs. KI-Vorhersage", font_size=12)
fig.show()

In [None]:
from sklearn.inspection import permutation_importance
import pandas as pd
import plotly.express as px

print("Berechne Feature Importance (kann ca. 1-2 Min. dauern oder 5)...")

# Wir nutzen unser Express-Testset für die Berechnung
result = permutation_importance(
    classifier, X_test_fast, y_test_fast, n_repeats=10, random_state=42, n_jobs=1
)

# Ergebnisse in ein DataFrame gießen
importance_df = pd.DataFrame({
    'Feature': feature_cols,
    'Importance': result.importances_mean,
    'Std_Dev': result.importances_std
}).sort_values(by='Importance', ascending=True) # Aufsteigend für horizontalen Plot

# Plotly Bar Chart
fig = px.bar(
    importance_df, 
    x='Importance', 
    y='Feature', 
    orientation='h',
    title='Welche sozialen Faktoren beeinflussen die Wiederaufnahme?',
    labels={'Importance': 'Wichtigkeit (Mean Decrease Accuracy)'}, # Wissenschaftlicher Label
    error_x='Std_Dev', # Zeigt die Variabilität der Wichtigkeit
    template="plotly_white",
    color='Importance',
    color_continuous_scale='Reds'
)

# NEU: Prozent-Formatierung für bessere Lesbarkeit
fig.update_layout(height=500, xaxis_tickformat='.1%')
fig.show()

Berechne Feature Importance (kann ca. 1-2 Min. dauern oder 5)...


In [None]:
# Check: Zusammenhang zwischen Alter und Versicherung (Medicare)
# Wir schauen, wie hoch das Durchschnittsalter pro Versicherungsgruppe ist.
import plotly.express as px

df_check = X_test_fast.copy() # Kopie für Analyse
if hasattr(df_check, "to_pandas"):
    df_check = df_check.to_pandas()

# Boxplot zeigt klar: Medicare-Patienten sind fast alle 65+
fig = px.box(df_check, x="insurance", y="anchor_age", 
             title="Warum Alter unwichtig wirkt: Der 'Medicare'-Effekt",
             points="all", 
             color="insurance")
fig.show()

# Detaillierte Risiko-Analyse

Hier untersuchen wir, welche spezifischen Gruppen (Alter, BMI, Geschlecht) das höchste Risiko für einen **frühen Krankheitsausbruch (<1 Jahr)** tragen.

Wir nutzen die vom Modell vorhergesagte Wahrscheinlichkeit (`predict_proba`) für Klasse 0 (Early Event) als **Risk Score**.

In [None]:
df_analyze = X_test_fast.copy()
if hasattr(df_analyze, "to_pandas"):
    df_analyze = df_analyze.to_pandas()

print(f"Analyse basierend auf {len(df_analyze)} Patienten aus dem Test-Set.")

Analyse basierend auf 500 Patienten aus dem Test-Set.


In [None]:
import pandas as pd
import plotly.express as px

# 1. Daten vorbereiten (Scores berechnen)
y_proba_fast = classifier.predict_proba(X_test_fast)
# Wahrscheinlichkeit für Klasse 0 (Early Event)
risk_score = y_proba_fast[:, 0]

df_analyze['risk_score'] = risk_score
df_analyze['true_label'] = y_test_fast


In [None]:
# ---------------------------------------------------------
# 2. BMI Analyse (Vereinfacht: Balkendiagramm)
# ---------------------------------------------------------
# Wir schauen uns nur den DURCHSCHNITT an, das ist einfacher zu lesen als Boxplots.
bins = [0, 18.5, 25, 30, 100]
labels = ['Untergewicht (<18.5)', 'Normal (18.5-25)', 'Übergewicht (25-30)', 'Adipositas (>30)']
df_analyze['bmi_group'] = pd.cut(df_analyze['bmi'], bins=bins, labels=labels)

df_bmi_mean = df_analyze.groupby('bmi_group', observed=True)['risk_score'].mean().reset_index()

fig1 = px.bar(
    df_bmi_mean, 
    x='bmi_group', 
    y='risk_score',
    text_auto='.1%', # Zeigt %-Wert direkt auf dem Balken
    title='Durchschnittliches Risiko nach BMI-Gruppe',
    labels={'risk_score': 'Wahrscheinlichkeit (Früher Ausbruch)', 'bmi_group': 'BMI Gruppe'},
    color='risk_score',
    color_continuous_scale='Reds',
    template="plotly_white"
)
fig1.update_layout(yaxis_tickformat='.0%') # Y-Achse als Prozent formatieren
fig1.show()

In [None]:
# ---------------------------------------------------------
# 3. Alters Analyse (Vereinfacht)
# ---------------------------------------------------------
# Bins erweitert, um alle Altersgruppen zu erfassen (0-10, 10-20, ..., 90+)
df_analyze['age_group'] = pd.cut(
    df_analyze['anchor_age'], 
    bins=[0, 20, 30, 40, 50, 60, 70, 80, 90, 120],  # 0-20, 20-30, ..., 90-120
    labels=['<20', '20-29', '30-39', '40-49', '50-59', '60-69', '70-79', '80-89', '90+']
)
df_analyze['age_group'] = df_analyze['age_group'].astype(str)

df_age_mean = df_analyze.groupby('age_group')['risk_score'].mean().reset_index()

fig2 = px.bar(
    df_age_mean, 
    x='age_group', 
    y='risk_score',
    text_auto='.1%',
    title='Risiko-Entwicklung über das Alter',
    labels={'risk_score': 'Wahrscheinlichkeit (Früher Ausbruch)', 'age_group': 'Altersgruppe'},
    color='risk_score',
    color_continuous_scale='Reds',
    template="plotly_white"
)
fig2.update_layout(yaxis_tickformat='.0%')
fig2.show()

In [None]:
# ---------------------------------------------------------
# 4. Sozio-Ökonomisch (Vereinfacht: Balken nebeneinander)
# ---------------------------------------------------------
# Statt komplizierter Violin-Plots nutzen wir gruppierte Balken.
df_ins_mean = df_analyze.groupby(['insurance', 'gender'])['risk_score'].mean().reset_index()

fig3 = px.bar(
    df_ins_mean, 
    x='insurance', 
    y='risk_score', 
    color='gender', 
    barmode='group', # Männlich/Weiblich nebeneinander
    text_auto='.1%',
    title='Vergleich: Wer hat das höchste Risiko? (Versicherung & Geschlecht)',
    labels={'risk_score': 'Wahrscheinlichkeit (Ø)', 'insurance': 'Versicherung'},
    template="plotly_white"
)
fig3.update_layout(yaxis_tickformat='.0%')
fig3.show()

# Robustness Experiment (Stabilitätstest)

Um zu beweisen, dass die Ergebnisse kein Zufall sind, führen wir das Training **30 Mal** mit unterschiedlichen Seeds durch.

**Ablauf:**
1. Ziehe in jedem Durchlauf ein **neues, balanciertes Training-Set** (Sampling Variation).
2. Trainiere ein frisches TabPFN Modell.
3. Evaluiere auf einem **fixen Validierungs-Set**.
4. Speichere Metriken (F1-Score, Accuracy) in einer CSV und visualisiere die Varianz.

Dies simuliert, wie robust das Modell gegenüber Veränderungen in den Trainingsdaten ist.

In [None]:
import os
import time
import datetime
import pandas as pd
import numpy as np
import polars as pl
import plotly.express as px
import plotly.figure_factory as ff
import plotly.graph_objects as go
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, confusion_matrix
from sklearn.inspection import permutation_importance
from tabpfn import TabPFNClassifier

# 1. Konfiguration
IS_TEST_RUN = True  # KIPPSCHALTER
TIMESTAMP = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

if IS_TEST_RUN:
    N_LOOPS = 1
    BASE_DIR = "../../reports/robustness_experiment_test"
    print(f"⚠️ TEST-LAUF: Starte nur {N_LOOPS} Runde.")
else:
    N_LOOPS = 30
    BASE_DIR = "../../reports/robustness_experiment"
    print(f"🚀 PRODUKTIV-LAUF: Starte {N_LOOPS} Runden.")

RESULTS_DIR = os.path.join(BASE_DIR, TIMESTAMP)
os.makedirs(RESULTS_DIR, exist_ok=True)

X_val_robust = X_val.iloc[:1000].copy()
y_val_robust = y_val[:1000].copy()
results_list = []

print(f"🚀 Starte Experiment. Ergebnisse landen in: {RESULTS_DIR}")

# 2. Experiment Loop
for i in range(N_LOOPS):
    start_time = time.time()
    
    # Ordner für diesen Run
    run_dir = os.path.join(RESULTS_DIR, f"run_{i}")
    os.makedirs(run_dir, exist_ok=True)
    
    # A. Resampling
    current_seed = 42 + i
    n_per_class = 333
    
    df_iter_train = pl.concat([
        df_train_pl.filter(pl.col("target") == 0).sample(n=min(n_per_class, df_train_pl.filter(pl.col("target")==0).height), seed=current_seed),
        df_train_pl.filter(pl.col("target") == 1).sample(n=min(n_per_class, df_train_pl.filter(pl.col("target")==1).height), seed=current_seed),
        df_train_pl.filter(pl.col("target") == 2).sample(n=n_per_class, seed=current_seed)
    ]).sample(fraction=1.0, shuffle=True, seed=current_seed)
    
    X_train_iter = df_iter_train.select(feature_cols).to_pandas()
    y_train_iter = df_iter_train.select("target").to_series().to_numpy()
    
    # B. Training
    clf = TabPFNClassifier(device='cpu') 
    clf.fit(X_train_iter, y_train_iter)
    
    # C. Prediction
    y_pred = clf.predict(X_val_robust)
    y_proba = clf.predict_proba(X_val_robust)
    
    # Metriken
    f1_macro = f1_score(y_val_robust, y_pred, average='macro')
    acc = accuracy_score(y_val_robust, y_pred)
    prec_macro = precision_score(y_val_robust, y_pred, average='macro', zero_division=0)
    rec_macro = recall_score(y_val_robust, y_pred, average='macro', zero_division=0)
    f1_per_class = f1_score(y_val_robust, y_pred, average=None)

    # Feature Importance
    perm_result = permutation_importance(clf, X_val_robust, y_val_robust, n_repeats=2, random_state=42, n_jobs=1)
    
    duration = time.time() - start_time

    # --- RISK DIRECTION ANALYSE ---
    # Wir speichern durchschnittliche Risikoscores pro Gruppe in die CSV
    df_risk_check = X_val_robust.copy()
    df_risk_check['risk_score'] = y_proba[:, 0] # Klasse 0 = Early Event

    risk_stats = {}
    
    # 1. Kategorische Features (Gender, Insurance, etc.)
    cat_feats = [c for c in feature_cols if c not in ['bmi', 'anchor_age']]
    for cf in cat_feats:
        means = df_risk_check.groupby(cf, observed=True)['risk_score'].mean()
        for cat_val, val in means.items():
            # CSV-freundliche Spaltennamen
            safe_name = str(cat_val).replace(' ', '_').replace('/', '_').replace('-', '_').lower()[:20]
            risk_stats[f'risk_{cf}_{safe_name}'] = val

    # 2. Numerische Features (Bins)
    # BMI
    df_risk_check['bmi_grp'] = pd.cut(df_risk_check['bmi'], bins=[0, 18.5, 25, 30, 100], labels=['under', 'norm', 'over', 'obese'])
    for cat_val, val in df_risk_check.groupby('bmi_grp', observed=True)['risk_score'].mean().items():
        risk_stats[f'risk_bmi_{cat_val}'] = val
        
    # Age (Dekaden für Übersicht)
    df_risk_check['age_grp'] = pd.cut(df_risk_check['anchor_age'], bins=[0, 30, 50, 70, 90, 120], labels=['u30', '30_50', '50_70', '70_90', '90plus'])
    for cat_val, val in df_risk_check.groupby('age_grp', observed=True)['risk_score'].mean().items():
        risk_stats[f'risk_age_{cat_val}'] = val
    
    # Speichern
    result_dict = {
        'run_id': i, 'seed': current_seed, 'accuracy': acc, 'f1_macro': f1_macro,
        'precision_macro': prec_macro, 'recall_macro': rec_macro,
        'f1_class_0_early': f1_per_class[0], 'f1_class_1_late': f1_per_class[1], 'f1_class_2_healthy': f1_per_class[2],
        'duration_sec': duration
    }
    # Feature Importance hinzufügen
    for feat_name, importance_val in zip(feature_cols, perm_result.importances_mean):
        result_dict[f'imp_{feat_name}'] = importance_val
        
    # Risk-Stats hinzufügen
    result_dict.update(risk_stats)
    results_list.append(result_dict)

    # --- E. PLOTTING PRO RUN ---
    
    # 1. Confusion Matrix
    cm = confusion_matrix(y_val_robust, y_pred)
    cm_perc = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    labels = ['Früh', 'Spät', 'Gesund']
    annot_text = [[f"{val}<br>({perc:.1%})" for val, perc in zip(r, rp)] for r, rp in zip(cm, cm_perc)]
    fig_cm = ff.create_annotated_heatmap(cm_perc, x=labels, y=labels, annotation_text=annot_text, colorscale='Reds')
    fig_cm.update_layout(title=f'Confusion Matrix (Run {i})', height=400, width=500)
    fig_cm.write_image(os.path.join(run_dir, "confusion_matrix.png"))

    # 1b. Sankey Diagram
    label_list = ["Tatsächlich: Früh", "Tatsächlich: Spät", "Tatsächlich: Gesund", 
                  "Vorhergesagt: Früh", "Vorhergesagt: Spät", "Vorhergesagt: Gesund"]
    source = [0, 0, 0, 1, 1, 1, 2, 2, 2] 
    target = [3, 4, 5, 3, 4, 5, 3, 4, 5] 
    value = cm.flatten()
    color_link = [
        'rgba(31, 119, 180, 0.4)', 'rgba(31, 119, 180, 0.2)', 'rgba(31, 119, 180, 0.1)', # Von Früh
        'rgba(255, 127, 14, 0.2)', 'rgba(255, 127, 14, 0.4)', 'rgba(255, 127, 14, 0.1)', # Von Spät
        'rgba(44, 160, 44, 0.1)', 'rgba(44, 160, 44, 0.1)', 'rgba(44, 160, 44, 0.4)'    # Von Gesund
    ]
    fig_sankey = go.Figure(data=[go.Sankey(
        node = dict(pad = 15, thickness = 20, line = dict(color = "black", width = 0.5), label = label_list, color = "blue"),
        link = dict(source = source, target = target, value = value, color = color_link))])
    fig_sankey.update_layout(title_text=f"Patienten-Fluss (Run {i})", font_size=12, height=500)
    fig_sankey.write_image(os.path.join(run_dir, "sankey_flow.png"))
    
    # 2. Risk Plots Data Prep
    df_plot_run = X_val_robust.copy()
    df_plot_run['risk_score'] = y_proba[:, 0]
    
    # a) BMI
    bins_bmi = [0, 18.5, 25, 30, 100]
    labels_bmi = ['Untergewicht (<18.5)', 'Normal (18.5-25)', 'Übergewicht (25-30)', 'Adipositas (>30)']
    df_plot_run['bmi_group'] = pd.cut(df_plot_run['bmi'], bins=bins_bmi, labels=labels_bmi)
    df_bmi_agg = df_plot_run.groupby('bmi_group', observed=True)['risk_score'].mean().reset_index()
    
    fig_bmi = px.bar(df_bmi_agg, x='bmi_group', y='risk_score', text_auto='.1%', 
                     title=f'Risiko nach BMI-Gruppe (Run {i})', 
                     labels={'risk_score': 'Wahrscheinlichkeit', 'bmi_group': 'BMI Gruppe'},
                     template="plotly_white", height=400)
    fig_bmi.update_layout(yaxis_tickformat='.0%')
    fig_bmi.write_image(os.path.join(run_dir, "risk_bmi.png"))
    
    # b) Age
    bins_age = [0, 20, 30, 40, 50, 60, 70, 80, 90, 120]
    labels_age = ['<20', '20-29', '30-39', '40-49', '50-59', '60-69', '70-79', '80-89', '90+']
    df_plot_run['age_group'] = pd.cut(df_plot_run['anchor_age'], bins=bins_age, labels=labels_age)
    df_plot_run['age_group'] = df_plot_run['age_group'].astype(str)
    
    df_age_agg = df_plot_run.groupby('age_group')['risk_score'].mean().reset_index()
    
    fig_age = px.bar(df_age_agg, x='age_group', y='risk_score', text_auto='.1%', 
                     title=f'Risiko-Entwicklung über das Alter (Run {i})', 
                     labels={'risk_score': 'Wahrscheinlichkeit', 'age_group': 'Altersgruppe'},
                     template="plotly_white", height=400)
    fig_age.update_layout(yaxis_tickformat='.0%')
    fig_age.write_image(os.path.join(run_dir, "risk_age.png"))

    # c) Insurance & Gender
    df_ins_agg = df_plot_run.groupby(['insurance', 'gender'])['risk_score'].mean().reset_index()
    fig_ins = px.bar(df_ins_agg, x='insurance', y='risk_score', color='gender', barmode='group',
                     text_auto='.1%', 
                     title=f'Risiko: Versicherung & Geschlecht (Run {i})', 
                     labels={'risk_score': 'Wahrscheinlichkeit', 'insurance': 'Versicherung'},
                     template="plotly_white", height=400)
    fig_ins.update_layout(yaxis_tickformat='.0%')
    fig_ins.write_image(os.path.join(run_dir, "risk_insurance_gender.png"))

    # 3. Feature Importance 
    imp_df_run = pd.DataFrame({
        'Feature': feature_cols,
        'Importance': perm_result.importances_mean,
        'Std_Dev': perm_result.importances_std
    }).sort_values(by='Importance', ascending=True)

    fig_imp = px.bar(imp_df_run, x='Importance', y='Feature', orientation='h',
                     title=f'Feature Importance (Run {i})',
                     labels={'Importance': 'Wichtigkeit (Mean Decrease Accuracy)'}, error_x='Std_Dev',
                     template="plotly_white", color='Importance', color_continuous_scale='Reds')
    # NEU: Auch innerhalb der Runs Prozent-Formatierung für die Bilder
    fig_imp.update_layout(height=400, showlegend=False, xaxis_tickformat='.1%')
    fig_imp.write_image(os.path.join(run_dir, "feature_importance.png"))

    if i == 0 or (i+1) % 5 == 0:
        print(f"   Run {i+1}/{N_LOOPS}: F1={f1_macro:.3f} | Bilder in {run_dir}/")

# 3. Speichern
df_results = pd.DataFrame(results_list)
csv_path = os.path.join(RESULTS_DIR, "robustness_metrics_with_importance.csv")
df_results.to_csv(csv_path, index=False)
print(f"✅ Experiment abgeschlossen. Daten gespeichert in: {csv_path}")

⚠️ TEST-LAUF: Starte nur 1 Runde.
🚀 Starte Experiment. Ergebnisse landen in: ../../reports/robustness_experiment_test/2026-02-14_15-35-56



Running on CPU with more than 200 samples may be slow.
Consider using a GPU or the tabpfn-client API: https://github.com/PriorLabs/tabpfn-client



   Run 1/1: F1=0.318 | Bilder in ../../reports/robustness_experiment_test/2026-02-14_15-35-56/run_0/
✅ Experiment abgeschlossen. Daten gespeichert in: ../../reports/robustness_experiment_test/2026-02-14_15-35-56/robustness_metrics_with_importance.csv


In [None]:
# 4. Visualisierung der Robustness (Master-Thesis Style)
# Ziel: Durchschnittliche Performance zeigen + Stabilität (Fehlerbalken) beweisen

# Melt transformiert die Daten für Plotly (Wide -> Long Format)
df_melt = df_results.melt(
    id_vars=['run_id'], 
    value_vars=['f1_class_0_early', 'f1_class_1_late', 'f1_class_2_healthy'],
    var_name='Target Class', 
    value_name='F1 Score'
)

# Namen verschönern
df_melt['Target Class'] = df_melt['Target Class'].replace({
    'f1_class_0_early': 'Früher Ausbruch (<1J)',
    'f1_class_1_late': 'Später Ausbruch (1-3J)', 
    'f1_class_2_healthy': 'Kein Event / Gesund'
})

# Aggegierte Statistiken für Bar-Chart
df_stats = df_melt.groupby('Target Class')['F1 Score'].agg(['mean', 'std']).reset_index()

# Sortierung manuell festlegen (Logische Zeitreihe)
category_order = ['Früher Ausbruch (<1J)', 'Später Ausbruch (1-3J)', 'Kein Event / Gesund']

# 1. Bar Chart mit Error Bars (Klassisch wissenschaftlich)
fig = px.bar(
    df_stats, 
    x="Target Class", 
    y="mean", 
    error_y="std", # Zeigt die Standardabweichung als Antenne (Robustheit)
    title=f"Modell-Performance: Durchschnitt & Stabilität ({N_LOOPS} Runs)",
    text_auto='.1%', # Beschriftung direkt am Balken
    labels={'mean': 'Durchschnittlicher F1-Score'},
    color="Target Class", 
    # Pastell Farben wirken professioneller und weniger überladen
    color_discrete_sequence=px.colors.qualitative.Pastel,
    template="plotly_white",
    category_orders={"Target Class": category_order} # Erzwingt die logische zeitliche Reihenfolge
)

# Balken leicht transparent machen (0.8), damit sie nicht zu massiv wirken
fig.update_traces(marker_opacity=0.8, showlegend=False)

# 2. Scatter-Punkte darüber legen (Kontrastfarbe für bessere Sichtbarkeit)
# Zeigt jeden einzelnen Run als Punkt -> Maximale Transparenz der Ergebnisse
scatter_trace = px.strip(
    df_melt, 
    x="Target Class", 
    y="F1 Score", 
    category_orders={"Target Class": category_order}
).data[0]

# Styling der Punkte: Dunkles Kontrast-Blau mit weißem Rand
# Das sorgt für Lesbarkeit sowohl auf hellen als auch dunklen Hintergründen
scatter_trace.marker.color = '#34495e' # "Wet Asphalt" (Dunkles Grau-Blau)
scatter_trace.marker.size = 6
scatter_trace.marker.opacity = 0.8
scatter_trace.marker.line = dict(width=1, color='white') # Weißer Rand lässt Punkte "poppen"
scatter_trace.showlegend = False

fig.add_trace(scatter_trace)

# Achsen formatieren (Prozent statt 0.x)
fig.update_layout(
    yaxis_tickformat='.0%', 
    yaxis_title="F1 Score (Macro)", 
    xaxis_title=None, # X-Achsen Titel ist redundant wegen den Labels
    font=dict(size=14) # Schrift etwas größer für Thesis
)
fig.update_yaxes(range=[0, 1.1]) # Platz für Error Bars lassen

fig.write_image(os.path.join(RESULTS_DIR, "robustness_scientific_bar.png"))
fig.show()

In [None]:
# 4b. Visualisierung der Feature Importance Stabilität
import pandas as pd
import plotly.express as px
import os

# Melt für Feature Wichtigkeit
feature_imp_cols = [c for c in df_results.columns if c.startswith("imp_")]
df_melt_imp = df_results.melt(
    id_vars=['run_id'], 
    value_vars=feature_imp_cols,
    var_name='Feature', 
    value_name='Importance Score'
)

# Prefix "imp_" für schönere Labels entfernen
df_melt_imp['Feature'] = df_melt_imp['Feature'].str.replace('imp_', '')

# Sortierung berechnen: Wir wollen das wichtigste Feature OBEN haben.
# Plotly Boxplots ordern die Y-Achse standardmäßig von unten nach oben.
# Wir sortieren hier nach Median aufsteigend (ascending=True).
# -> Das Feature mit dem kleinsten Median steht am Anfang der Liste (Index 0).
# -> Plotly zeichnet Index 0 unten.
# -> Das Feature mit dem größten Median steht am Ende der Liste.
# -> Plotly zeichnet das Ende oben.
# Falls es bei dir "falsch herum" ist, liegt es oft daran, dass Plotly 'Feature' noch als Faktor hat.
sorted_features = df_melt_imp.groupby('Feature')['Importance Score'].median().sort_values(ascending=False).index.tolist()

fig = px.box(
    df_melt_imp, 
    x="Importance Score", 
    y="Feature", 
    color="Feature", 
    # Wir erzwingen die Sortierung explizit über category_orders
    category_orders={"Feature": sorted_features}, 
    title=f"Feature Wichtigkeit über {N_LOOPS} Runs (Stabilitätstest)",
    points="all", 
    template="plotly_white",
    height=600,
    color_discrete_sequence=px.colors.qualitative.Pastel
)

fig.update_layout(showlegend=True)

# Punkte styling 
fig.update_traces(marker=dict(opacity=0.6, size=3)) # color='#34495e', 

# Achsen formatieren
fig.update_layout(
    xaxis_tickformat='.1%', 
    xaxis_title="Wichtigkeit (Mean Decrease Accuracy)",
    yaxis_title=None,
    font=dict(size=12)
)

# Speichern
fig.write_image(os.path.join(RESULTS_DIR, "robustness_feature_importance.png"))
fig.show()

In [None]:
# 5. Sanity Check / Plausibilitäts-Prüfung
# Hinweis: RESULTS_DIR kommt jetzt aus dem Block oben mit Zeitstempel!

import pandas as pd
import os

print(f"Lade Prüfung von: {RESULTS_DIR}") # Nimmt den aktuellsten Ordner
try:
    csv_load_path = os.path.join(RESULTS_DIR, "robustness_metrics_with_importance.csv")
    check_df = pd.read_csv(csv_load_path)

    print(f"Anzahl durchgeführter Runs: {len(check_df)}")
    print(f"Genutzte Seeds: {check_df['seed'].unique()}")

    # Prüfung 1: Haben wir Variation?
    if len(check_df) > 1:
        std_dev = check_df['f1_macro'].std()
        if std_dev > 0:
            print(f"✅ Test bestanden: Ergebnisse variieren (Std-Dev: {std_dev:.4f}).")
        else:
            print("⚠️ Warnung: Keine Variation! Prüfe den Random Seed!")
    else:
        print("ℹ️ Hinweis: Nur 1 Run durchgeführt (Test-Modus).")

    # Prüfung 2: Sind die Ergebnisse 'zu perfekt'?
    if check_df['f1_macro'].max() >= 0.99:
        print("⚠️ Warnung: F1-Score von fast 1.0 (>=0.99) gefunden. Das ist verdächtig.")
    else:
        print(f"✅ Test bestanden: Realistische Scores (Max: {check_df['f1_macro'].max():.3f})")

    check_df.head()
    
except FileNotFoundError:
    print(f"❌ Fehler: Keine CSV gefunden in {RESULTS_DIR}. Hast du den Loop oben laufen lassen?")

Lade Prüfung von: ../../reports/robustness_experiment_test/2026-02-14_15-35-56
Anzahl durchgeführter Runs: 1
Genutzte Seeds: [42]
ℹ️ Hinweis: Nur 1 Run durchgeführt (Test-Modus).
✅ Test bestanden: Realistische Scores (Max: 0.318)
