In [44]:
import sys
import polars as pl
import plotly.express as px
sys.path.insert(0, '..')
from fs_thesis import sql, show

# Create Overview for Tables Patients and Admissions

In [45]:
# Check coverage of BMI vs Weight/Height
# If BMI is present for most patients who have Weight/Height, we can use it.
df_coverage = sql("""
    SELECT 
        result_name, 
        COUNT(*) as total_measurements,
        COUNT(DISTINCT subject_id) as unique_patients
    FROM hosp.omr 
    WHERE result_name IN ('BMI (kg/m2)', 'Weight (Lbs)', 'Height (Inches)')
    GROUP BY result_name
""")
show(df_coverage, limit=False)

Unnamed: 0,result_name,total_measurements,unique_patients
0,BMI (kg/m2),1901496,153725
1,Weight (Lbs),2145353,166872
2,Height (Inches),814964,148359


In [46]:
df = sql("SELECT * from hosp.patients")
show(df, limit=True)

Unnamed: 0,subject_id,gender,anchor_age,anchor_year,anchor_year_group,dod
0,10000032,F,52,2180,2014 - 2016,2180-09-09
1,10000048,F,23,2126,2008 - 2010,NaT
2,10000058,F,33,2168,2020 - 2022,NaT
3,10000068,F,19,2160,2008 - 2010,NaT
4,10000084,M,72,2160,2017 - 2019,2161-02-13
...,...,...,...,...,...,...
364622,19999828,F,46,2147,2017 - 2019,NaT
364623,19999829,F,28,2186,2008 - 2010,NaT
364624,19999840,M,58,2164,2008 - 2010,2164-09-17
364625,19999914,F,49,2158,2017 - 2019,NaT


In [47]:
df = sql("SELECT * from hosp.admissions")
show(df, limit=True)

Unnamed: 0,subject_id,hadm_id,admittime,dischtime,deathtime,admission_type,admit_provider_id,admission_location,discharge_location,insurance,language,marital_status,race,edregtime,edouttime,hospital_expire_flag
0,10000032,22595853,2180-05-06 22:23:00,2180-05-07 17:15:00,NaT,URGENT,P49AFC,TRANSFER FROM HOSPITAL,HOME,Medicaid,English,WIDOWED,WHITE,2180-05-06 19:17:00,2180-05-06 23:30:00,0
1,10000032,22841357,2180-06-26 18:27:00,2180-06-27 18:49:00,NaT,EW EMER.,P784FA,EMERGENCY ROOM,HOME,Medicaid,English,WIDOWED,WHITE,2180-06-26 15:54:00,2180-06-26 21:31:00,0
2,10000032,25742920,2180-08-05 23:44:00,2180-08-07 17:50:00,NaT,EW EMER.,P19UTS,EMERGENCY ROOM,HOSPICE,Medicaid,English,WIDOWED,WHITE,2180-08-05 20:58:00,2180-08-06 01:44:00,0
3,10000032,29079034,2180-07-23 12:35:00,2180-07-25 17:55:00,NaT,EW EMER.,P06OTX,EMERGENCY ROOM,HOME,Medicaid,English,WIDOWED,WHITE,2180-07-23 05:54:00,2180-07-23 14:00:00,0
4,10000068,25022803,2160-03-03 23:16:00,2160-03-04 06:26:00,NaT,EU OBSERVATION,P39NWO,EMERGENCY ROOM,,,English,SINGLE,WHITE,2160-03-03 21:55:00,2160-03-04 06:26:00,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
546023,19999828,25744818,2149-01-08 16:44:00,2149-01-18 17:00:00,NaT,EW EMER.,P13JMH,TRANSFER FROM HOSPITAL,HOME HEALTH CARE,Medicaid,English,SINGLE,WHITE,2149-01-08 09:11:00,2149-01-08 18:12:00,0
546024,19999828,29734428,2147-07-18 16:23:00,2147-08-04 18:10:00,NaT,EW EMER.,P38XL8,PHYSICIAN REFERRAL,HOME HEALTH CARE,Medicaid,English,SINGLE,WHITE,2147-07-17 17:18:00,2147-07-18 17:34:00,0
546025,19999840,21033226,2164-09-10 13:47:00,2164-09-17 13:42:00,2164-09-17 13:42:00,EW EMER.,P33612,EMERGENCY ROOM,DIED,Private,English,WIDOWED,WHITE,2164-09-10 11:09:00,2164-09-10 14:46:00,1
546026,19999840,26071774,2164-07-25 00:27:00,2164-07-28 12:15:00,NaT,EW EMER.,P036NA,EMERGENCY ROOM,HOME,Private,English,WIDOWED,WHITE,2164-07-24 21:16:00,2164-07-25 01:20:00,0


# Patient Demographics & Time-to-Event Extraction

Extracts patient demographic data and calculates time-to-event metrics for survival analysis.

## Data Structure
- **Unique Identifier**: `subject_id`
- **Demographics**: gender, age, insurance, language, marital status, race
- **Baseline**: First admission timestamp (`t0_time`)
- **Event**: ICD-9/10 diagnosis code occurrence
- **Censoring**: Date of death (`dod`) or end of follow-up

## Output
Generates a dataset with:
- Demographic features
- Time to event (days from baseline)
- Event indicator (occurred/censored)
- Target classes: Early event (<1yr), late event (>1yr), or censored

In [48]:
df_baseline = sql("""SELECT 
    p.subject_id,
    p.gender,
    p.anchor_age,
    p.dod,
    -- Daten aus der ersten Aufnahme (Baseline)
    first_admit.admittime AS t0_time,
    first_admit.insurance,
    first_admit.language,
    first_admit.marital_status,
    first_admit.race,
    first_admit.admission_type
FROM hosp.patients p
INNER JOIN (
    -- Subquery, um nur die zeitlich erste Aufnahme pro Patient zu finden
    SELECT 
        subject_id, 
        admittime, 
        insurance, 
        language, 
        marital_status, 
        race, 
        admission_type,
        ROW_NUMBER() OVER (PARTITION BY subject_id ORDER BY admittime ASC) as row_num
    FROM hosp.admissions
) first_admit ON p.subject_id = first_admit.subject_id
WHERE first_admit.row_num = 1""")
show(df_baseline,limit=True)

Unnamed: 0,subject_id,gender,anchor_age,dod,t0_time,insurance,language,marital_status,race,admission_type
0,13585638,M,59,2195-07-25,2187-08-24 14:37:00,Private,English,MARRIED,WHITE - OTHER EUROPEAN,AMBULATORY OBSERVATION
1,13585896,F,78,NaT,2177-02-08 18:57:00,Medicare,Spanish,WIDOWED,OTHER,OBSERVATION ADMIT
2,13586115,F,19,NaT,2154-12-20 03:49:00,Private,English,SINGLE,WHITE,EU OBSERVATION
3,13588863,F,66,2152-01-09,2145-02-13 19:39:00,Medicare,English,WIDOWED,WHITE,EW EMER.
4,13589449,M,56,NaT,2161-07-01 19:11:00,Private,English,SINGLE,WHITE,OBSERVATION ADMIT
...,...,...,...,...,...,...,...,...,...,...
223447,19996366,M,65,NaT,2128-08-18 23:47:00,Medicare,Spanish,,UNKNOWN,EW EMER.
223448,19996902,F,35,NaT,2156-09-22 01:18:00,Other,English,SINGLE,WHITE,EW EMER.
223449,19996903,M,43,NaT,2124-05-17 18:23:00,Medicaid,English,MARRIED,OTHER,EU OBSERVATION
223450,19997293,M,76,2124-02-20,2123-10-12 13:46:00,Medicare,English,SINGLE,WHITE,URGENT


In [49]:
# Zählt die eindeutigen Werte
anzahl = len(df["subject_id"].unique())
anzahl


223452

In [50]:
# Prüft auf Unique-Status
ist_unique = len(df["subject_id"].unique()) == len(df["subject_id"])
ist_unique

False

# Event Time

Extracts the first occurrence of the target diagnosis (ICD-50 code, Beispiel: Herzinsuffizienz) for each patient.

Returns:
- `subject_id`: Patient identifier
- `event_time`: Admission timestamp of first diagnosis occurrence

In [51]:
data = sql("""SELECT d.*, i.long_title
FROM hosp.diagnoses_icd d
LEFT JOIN hosp.d_icd_diagnoses i ON d.icd_code = i.icd_code
WHERE d.icd_code LIKE 'I50%'""")
show(data, limit=True)

Unnamed: 0,subject_id,hadm_id,seq_num,icd_code,icd_version,long_title
0,10000980,20897796,2,I5033,10,Acute on chronic diastolic (congestive) heart ...
1,10000980,25911675,2,I5023,10,Acute on chronic systolic (congestive) heart f...
2,10000980,29659838,1,I5023,10,Acute on chronic systolic (congestive) heart f...
3,10001667,22672901,2,I5030,10,Unspecified diastolic (congestive) heart failure
4,10001843,21728396,2,I5031,10,Acute diastolic (congestive) heart failure
...,...,...,...,...,...,...
43908,19697088,23862507,2,I5023,10,Acute on chronic systolic (congestive) heart f...
43909,19697088,29365088,3,I5022,10,Chronic systolic (congestive) heart failure
43910,19697640,27428442,2,I5022,10,Chronic systolic (congestive) heart failure
43911,19698359,27953498,6,I5033,10,Acute on chronic diastolic (congestive) heart ...


In [52]:
df_event = sql("""SELECT 
    d.subject_id,
    MIN(a.admittime) AS event_time
FROM hosp.diagnoses_icd d
JOIN hosp.admissions a ON d.hadm_id = a.hadm_id
WHERE d.icd_code LIKE 'I50%' 
GROUP BY d.subject_id
                         """)
show(df_event, limit=True)

Unnamed: 0,subject_id,event_time
0,12227418,2187-09-04 17:36:00
1,12238304,2114-10-06 10:14:00
2,12238622,2121-05-18 06:26:00
3,12244155,2146-08-24 04:08:00
4,12245501,2153-05-02 17:56:00
...,...,...
18885,19231666,2178-03-11 00:00:00
18886,19240056,2164-01-07 02:10:00
18887,19241223,2142-03-12 16:35:00
18888,19245405,2186-04-20 21:32:00


# Time-to-Event Calculation & Target Definition

Calculates time intervals and creates target classes based on event occurrence and timing.

## Steps:
1. **Join datasets**: Merges baseline demographics with event times (left join)
2. **Calculate time differences**: 
   - `t_event`: Days from baseline to diagnosis
   - `t_death`: Days from baseline to death
3. **Determine duration**: Uses event time if available, otherwise death time, fallback to 2000 days
4. **Handle errors**: Clips negative durations to 0
5. **Create target classes**:
   - Class 0: Early event (<365 days)
   - Class 1: Late event (>365 days)
   - Class 2: Censored (no event)

In [53]:
import polars as pl

df_final = (
    df_baseline.join(df_event, on="subject_id", how="left")
    # 1. Zeitdifferenzen und Event-Indikator berechnen
    .with_columns([
        ((pl.col("event_time") - pl.col("t0_time")).dt.total_days()).alias("t_event"),
        ((pl.col("dod") - pl.col("t0_time")).dt.total_days()).alias("t_death"),
        pl.col("event_time").is_not_null().cast(pl.Int32).alias("event_occurred")
    ])
    # 2. Duration festlegen (Priorität: Event > Tod > Fallback) und Clip
    .with_columns(
        pl.coalesce([
            pl.col("t_event"),
            pl.col("t_death"),
            pl.lit(2000)
        ])
        .clip(lower_bound=0)
        .alias("duration")
    )
    # 3. Zielvariable (Target) für TabPFN definieren
    .with_columns(
        pl.when((pl.col("event_occurred") == 1) & (pl.col("duration") <= 365))
        .then(0)     # Klasse 0: Früher Ausbruch (< 1 Jahr)
        .when((pl.col("event_occurred") == 1) & (pl.col("duration") > 365))
        .then(1)     # Klasse 1: Später Ausbruch (> 1 Jahr)
        .otherwise(2) # Klasse 2: Zensiert / Gesund / Kein Event
        .alias("target")
    )
)

# Ergebnis anzeigen
show(df_final, limit=True)

# Kurze Kontrolle der Verteilung
print(df_final["target"].value_counts())

Unnamed: 0,subject_id,gender,anchor_age,dod,t0_time,insurance,language,marital_status,race,admission_type,event_time,t_event,t_death,event_occurred,duration,target
0,13585638,M,59,2195-07-25,2187-08-24 14:37:00,Private,English,MARRIED,WHITE - OTHER EUROPEAN,AMBULATORY OBSERVATION,NaT,,2891.0,0,2891,2
1,13585896,F,78,NaT,2177-02-08 18:57:00,Medicare,Spanish,WIDOWED,OTHER,OBSERVATION ADMIT,NaT,,,0,2000,2
2,13586115,F,19,NaT,2154-12-20 03:49:00,Private,English,SINGLE,WHITE,EU OBSERVATION,NaT,,,0,2000,2
3,13588863,F,66,2152-01-09,2145-02-13 19:39:00,Medicare,English,WIDOWED,WHITE,EW EMER.,NaT,,2520.0,0,2520,2
4,13589449,M,56,NaT,2161-07-01 19:11:00,Private,English,SINGLE,WHITE,OBSERVATION ADMIT,NaT,,,0,2000,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
223447,19996366,M,65,NaT,2128-08-18 23:47:00,Medicare,Spanish,,UNKNOWN,EW EMER.,2128-08-18 23:47:00,0.0,,1,0,0
223448,19996902,F,35,NaT,2156-09-22 01:18:00,Other,English,SINGLE,WHITE,EW EMER.,NaT,,,0,2000,2
223449,19996903,M,43,NaT,2124-05-17 18:23:00,Medicaid,English,MARRIED,OTHER,EU OBSERVATION,NaT,,,0,2000,2
223450,19997293,M,76,2124-02-20,2123-10-12 13:46:00,Medicare,English,SINGLE,WHITE,URGENT,NaT,,130.0,0,130,2


shape: (3, 2)
┌────────┬────────┐
│ target ┆ count  │
│ ---    ┆ ---    │
│ i32    ┆ u32    │
╞════════╪════════╡
│ 0      ┆ 10742  │
│ 1      ┆ 8148   │
│ 2      ┆ 204562 │
└────────┴────────┘


# Feature & Target Extraction

Prepares data for machine learning by selecting demographic features and target variable.

## Output:
- **X**: Feature matrix with demographics (pandas DataFrame)
  - gender, anchor_age, insurance, language, marital_status, race, admission_type
- **y**: Target vector with class labels (numpy array)
  - 0: Early event (<365 days)
  - 1: Late event (>365 days)
  - 2: Censored (no event)

In [54]:
# Feature Engineering: BMI
# Strategie: Wir nutzen den vorgefertigten BMI Wert, da er eine höhere Abdeckung hat als (Weight & Height) kombiniert.
# Aggregation: Wir nehmen den Median pro Patient, um Ausreißer zu minimieren.

df_bmi = sql("""
    SELECT 
        subject_id,
        MEDIAN(CAST(result_value AS FLOAT)) as bmi
    FROM hosp.omr
    WHERE result_name = 'BMI (kg/m2)'
    -- Einfache Bereinigung: Nur plausible Werte (z.B. 10 bis 100)
    AND TRY_CAST(result_value AS FLOAT) BETWEEN 10 AND 100
    GROUP BY subject_id
""")

# Check distribution
show(df_bmi, limit=True)

Unnamed: 0,subject_id,bmi
0,10524717,31.799999
1,10528130,36.099998
2,10529136,31.400000
3,10529338,35.200001
4,10529536,26.400000
...,...,...
153536,19698886,36.950001
153537,19709184,32.250000
153538,19723403,25.100000
153539,19727639,25.750000


In [55]:
# Join to clean dataset
# Wir machen einen Left Join, da nicht alle Patienten einen BMI haben (Missing Values werden später von TabPFN handled!)
df_final = df_final.join(
    df_bmi, 
    on="subject_id", 
    how="left"
)

# Prüfen wie viele Missings wir haben
print(f"Missing BMI: {df_final['bmi'].null_count()} of {len(df_final)}")
show(df_final.select(['subject_id', 'anchor_age', 'bmi', 'target']), limit=True)

Missing BMI: 98306 of 223452


Unnamed: 0,subject_id,anchor_age,bmi,target
0,13585638,59,26.600000,2
1,13585896,78,24.950001,2
2,13586115,19,,2
3,13588863,66,19.500000,2
4,13589449,56,,2
...,...,...,...,...
223447,19996366,65,,0
223448,19996902,35,32.799999,2
223449,19996903,43,,2
223450,19997293,76,,2


In [56]:
import polars as pl
from sklearn.model_selection import train_test_split

# 1. Konvertierung für Scikit-Learn (Split-Logik)
df_pd = df_final.to_pandas()

# Erster Split: Trenne das finale Test-Set ab (20% Hold-out)
df_train_val_raw, df_test = train_test_split(
    df_pd, 
    test_size=0.2, 
    stratify=df_pd["target"], 
    random_state=42
)

# Zweiter Split: Trenne den Rest in Training und Validierung (z.B. 80/20 vom Rest)
df_train_raw, df_val = train_test_split(
    df_train_val_raw, 
    test_size=0.2, 
    stratify=df_train_val_raw["target"], 
    random_state=42
)

# --- BACK TO POLARS: Balancing nur für das TRAINING ---
# Wir nehmen 3.000 Samples pro Klasse, um innerhalb des TabPFN-Limits zu bleiben
df_train_pl = pl.from_pandas(df_train_raw)
n_samples = 3000

df_balanced_train = pl.concat([
    df_train_pl.filter(pl.col("target") == 0).sample(n=min(n_samples, df_train_pl.filter(pl.col("target")==0).height), seed=42),
    df_train_pl.filter(pl.col("target") == 1).sample(n=min(n_samples, df_train_pl.filter(pl.col("target")==1).height), seed=42),
    df_train_pl.filter(pl.col("target") == 2).sample(n=n_samples, seed=42)
]).sample(fraction=1.0, shuffle=True, seed=42)

# --- FINALE FEATURES & TARGETS ---
feature_cols = ["gender", "anchor_age", "insurance", "language", "marital_status", "race", "admission_type", "bmi"]

# Training (Balanced)
X_train = df_balanced_train.select(feature_cols).to_pandas()
y_train = df_balanced_train.select("target").to_series().to_numpy()

# Validation (Realistisch, nicht balanciert)
X_val = df_val[feature_cols]
y_val = df_val["target"].values

# Test (Realistisch, nicht balanciert - Hold-out)
X_test = df_test[feature_cols]
y_test = df_test["target"].values

print(f"Train (balanced): {len(y_train)} | Val (real): {len(y_val)} | Test (real): {len(y_test)}")

Train (balanced): 9000 | Val (real): 35753 | Test (real): 44691


> ℹ️ **INFO**  
> Why Train set is smaller than val and test. The Train set is balanced and gives a total of 9k for the TabPFN Algo. The Test and Val set is larger because it's a % of the total dataset (200k)

# Tab PFN
- training

In [57]:
# 1. Training-Set massiv verkleinern (für Speed & Datenschutz-Konformität)
# Wir nehmen nur 333 Samples pro Klasse = ca. 1000 Total
n_emergency = 333

df_emergency_train = pl.concat([
    df_train_pl.filter(pl.col("target") == 0).sample(n=min(n_emergency, df_train_pl.filter(pl.col("target")==0).height)),
    df_train_pl.filter(pl.col("target") == 1).sample(n=min(n_emergency, df_train_pl.filter(pl.col("target")==1).height)),
    df_train_pl.filter(pl.col("target") == 2).sample(n=n_emergency)
]).sample(fraction=1.0, shuffle=True)

X_train_fast = df_emergency_train.select(feature_cols).to_pandas()
y_train_fast = df_emergency_train.select("target").to_series().to_numpy()

# 2. Classifier neu initialisieren (CPU ist bei 1000 Samples oft stabiler als ein buggy MPS-Backend)
from tabpfn import TabPFNClassifier
classifier = TabPFNClassifier(device='cpu') # Zurück auf CPU, da N jetzt klein ist

# 3. Fit (Sekundensache)
classifier.fit(X_train_fast, y_train_fast)

# 4. Vorhersage für 500 Test-Patienten (ohne Batches, einfach direkt)
X_test_fast = X_val.iloc[:500]
y_test_fast = y_val[:500]

print("Starte Express-Vorhersage...")
y_val_pred_fast = classifier.predict(X_test_fast)
print("Fertig!")


Running on CPU with more than 200 samples may be slow.
Consider using a GPU or the tabpfn-client API: https://github.com/PriorLabs/tabpfn-client



Starte Express-Vorhersage...
Fertig!


# Evaluation

In [58]:
import plotly.figure_factory as ff
from sklearn.metrics import classification_report, confusion_matrix
import pandas as pd

# 1. Classification Report ausgeben (Text-Form für die harten Fakten)
print("EXPRESS-CHECK REPORT (N_train=1000):")
report_dict = classification_report(y_test_fast, y_val_pred_fast, 
                                   target_names=['Früh (<1J)', 'Spät (1-3J)', 'Gesund'],
                                   output_dict=True,
                                   zero_division=0)
print(classification_report(y_test_fast, y_val_pred_fast, 
                           target_names=['Früh (<1J)', 'Spät (1-3J)', 'Gesund']))

import plotly.figure_factory as ff
import numpy as np

# 1. Normalisierung der Confusion Matrix (Zeilenweise)
# Wie viel % der tatsächlichen Klasse wurden wie vorhergesagt?
cm = confusion_matrix(y_test_fast, y_val_pred_fast)
cm_perc = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

labels = ['Früh (<1J)', 'Spät (1-3J)', 'Gesund']

# 2. Text für die Boxen (Absolute Zahl + Prozent)
annot_text = [
    [f"<b>{val}</b><br>({perc:.1%})" for val, perc in zip(row_val, row_perc)]
    for row_val, row_perc in zip(cm, cm_perc)
]

# 3. Interaktive Heatmap
fig = ff.create_annotated_heatmap(
    cm_perc, 
    x=labels, 
    y=labels, 
    annotation_text=annot_text, 
    colorscale='Reds' # 'Reds' hebt die Treffer besser hervor
)

fig.update_layout(
    title='Zusammenhang: Soziale Faktoren & Wiederaufnahme (Normalisiert)',
    xaxis_title="Vorhersage des Modells",
    yaxis_title="Tatsächlicher Verlauf (MIMIC-Daten)",
    template="plotly_white",
    height=600
)

fig.show()

EXPRESS-CHECK REPORT (N_train=1000):
              precision    recall  f1-score   support

  Früh (<1J)       0.11      0.62      0.19        21
 Spät (1-3J)       0.12      1.00      0.21        15
      Gesund       0.99      0.54      0.70       464

    accuracy                           0.56       500
   macro avg       0.41      0.72      0.37       500
weighted avg       0.93      0.56      0.66       500



In [59]:
import plotly.graph_objects as go

# Daten aus deiner Confusion Matrix extrahieren
# Reihenfolge: [Früh, Spät, Gesund]
cm = confusion_matrix(y_test_fast, y_val_pred_fast)

# Labels für die Knoten
label_list = [
    "Tatsächlich: Früh", "Tatsächlich: Spät", "Tatsächlich: Gesund", # Quellen (Links)
    "Vorhergesagt: Früh", "Vorhergesagt: Spät", "Vorhergesagt: Gesund" # Ziele (Rechts)
]

# Definition der Flüsse (Sankey-Struktur)
source = [0, 0, 0, 1, 1, 1, 2, 2, 2] # Index der Quellen
target = [3, 4, 5, 3, 4, 5, 3, 4, 5] # Index der Ziele
value = cm.flatten() # Die Zahlen aus deiner Matrix

# Farben definieren (Grün für korrekt, Rot für Fehler)
color_link = [
    'rgba(31, 119, 180, 0.4)', 'rgba(31, 119, 180, 0.2)', 'rgba(31, 119, 180, 0.1)', # Von Früh
    'rgba(255, 127, 14, 0.2)', 'rgba(255, 127, 14, 0.4)', 'rgba(255, 127, 14, 0.1)', # Von Spät
    'rgba(44, 160, 44, 0.1)', 'rgba(44, 160, 44, 0.1)', 'rgba(44, 160, 44, 0.4)'    # Von Gesund
]

fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 15, thickness = 20, line = dict(color = "black", width = 0.5),
      label = label_list, color = "blue"
    ),
    link = dict(
      source = source, target = target, value = value, color = color_link
  ))])

fig.update_layout(title_text="Patienten-Fluss: Realität vs. KI-Vorhersage", font_size=12)
fig.show()

In [60]:
from sklearn.inspection import permutation_importance
import pandas as pd
import plotly.express as px

print("Berechne Feature Importance (kann ca. 1-2 Min. dauern oder 5)...")

# Wir nutzen unser Express-Testset für die Berechnung
result = permutation_importance(
    classifier, X_test_fast, y_test_fast, n_repeats=10, random_state=42, n_jobs=1
)

# Ergebnisse in ein DataFrame gießen
importance_df = pd.DataFrame({
    'Feature': feature_cols,
    'Importance': result.importances_mean,
    'Std_Dev': result.importances_std
}).sort_values(by='Importance', ascending=True) # Aufsteigend für horizontalen Plot

# Plotly Bar Chart
fig = px.bar(
    importance_df, 
    x='Importance', 
    y='Feature', 
    orientation='h',
    title='Welche sozialen Faktoren beeinflussen die Wiederaufnahme?',
    labels={'Importance': 'Rückgang der Genauigkeit bei Entfernung des Merkmals'},
    error_x='Std_Dev', # Zeigt die Variabilität der Wichtigkeit
    template="plotly_white",
    color='Importance',
    color_continuous_scale='Reds'
)

fig.update_layout(height=500)
fig.show()

Berechne Feature Importance (kann ca. 1-2 Min. dauern)...
