# Predictive Maintenance – Deutsche Bahn
## Visualized ML notebook with logistic regression & random forest

This notebook analyzes synthetic predictive-maintenance data for ICE trains and evaluates it with logistic regression and a random forest. Its structure follows a classic logistic-regression example with extensive visualization (EDA, confusion matrix, ROC curves).

## 1. Import libraries

In [None]:
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    roc_auc_score,
    roc_curve
)

%matplotlib inline

sns.set_style("whitegrid")


## 2. Load the data and take a first look

In [None]:
# Load the CSV file (must be in the same folder as the notebook)
df = pd.read_csv("predictive_maintenance_db.csv")

# First look at the data
df.head()


In [None]:
df.info()

In [None]:
df.describe()

In [None]:
# Distribution of the target variable (failure within 30 days)
df['failure_within_30d'].value_counts(normalize=True)
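The class shares returned by `value_counts(normalize=True)` also give a quick reference point: a model that always predicts the majority class reaches an accuracy equal to the largest share. The sketch below illustrates this on a made-up target series. The 90/10 split and the names `target_demo`, `class_share`, and `baseline_accuracy` are assumptions for illustration, not values from the actual dataset.

```python
import pandas as pd

# Hypothetical target: 90 samples without failure, 10 with failure.
target_demo = pd.Series([0] * 90 + [1] * 10, name="failure_within_30d")

# Relative frequency of each class, as in the cell above.
class_share = target_demo.value_counts(normalize=True)

# Accuracy of a trivial model that always predicts the majority class.
baseline_accuracy = class_share.max()

print(class_share)
print(f"Majority-class baseline accuracy: {baseline_accuracy:.2f}")
```

When the positive class is rare, plain accuracy can look good for a useless model, which is one reason to lean on the classification report and ROC-AUC used later in the notebook.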


## 3. Exploratory data analysis (EDA)

### 3.1 Histograms of selected features

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(10, 8))

df['temperature'].hist(bins=30, ax=axes[0, 0])
axes[0, 0].set_title('Temperature')

df['vibration_level'].hist(bins=30, ax=axes[0, 1])
axes[0, 1].set_title('Vibration Level')

df['operating_hours'].hist(bins=30, ax=axes[1, 0])
axes[1, 0].set_title('Operating Hours')

df['days_since_last_maintenance'].hist(bins=30, ax=axes[1, 1])
axes[1, 1].set_title('Days Since Last Maintenance')

plt.tight_layout()


### 3.2 Jointplot: temperature vs. vibration level

In [None]:
sns.jointplot(
    data=df.sample(min(2000, len(df)), random_state=42),
    x='temperature',
    y='vibration_level',
    kind='scatter'
)


### 3.3 Pairplot for selected numeric variables

In [None]:
subset_cols = [
    'temperature',
    'vibration_level',
    'operating_hours',
    'days_since_last_maintenance',
    'failure_within_30d'
]

sns.pairplot(
    df[subset_cols].sample(min(1000, len(df)), random_state=42),
    hue='failure_within_30d',
    diag_kind='hist'
)


### 3.4 Correlation matrix & heatmap

In [None]:
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()

corr = df[numeric_cols].corr()

plt.figure(figsize=(10, 8))
sns.heatmap(corr, annot=False)
plt.title('Correlation matrix (numeric features)')
plt.tight_layout()


## 4. Data preparation & train/test split

In [None]:
# Define features (X) and target variable (y)
X = df.drop('failure_within_30d', axis=1)
y = df['failure_within_30d']

# Define categorical and numeric columns
# (every column not listed as categorical is treated as numeric)
categorical = ['component_type', 'weekday']
numeric = [col for col in X.columns if col not in categorical]

# Preprocessing pipeline (one-hot encoding for categorical variables)
preprocess = ColumnTransformer(
    transformers=[
        ('cat', OneHotEncoder(handle_unknown='ignore'), categorical),
        ('num', 'passthrough', numeric)
    ]
)

# Train/test split (stratified because of the class distribution)
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42,
    stratify=y
)

X_train.shape, X_test.shape
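The preprocessing above passes numeric columns through unscaled. Logistic regression with an iterative solver tends to converge more reliably when features are on a comparable scale, so one possible variant (not the notebook's method) is to swap `'passthrough'` for `StandardScaler`. The sketch below shows this on a made-up toy frame. The column names follow the notebook; the values, the category labels, and the names `toy`, `categorical_demo`, `numeric_demo`, and `preprocess_scaled` are assumptions for illustration.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Hypothetical toy data; values and category labels are made up.
rng = np.random.default_rng(42)
n = 200
toy = pd.DataFrame({
    "component_type": rng.choice(["brake", "door", "engine"], size=n),
    "weekday": rng.choice(["Mon", "Tue", "Wed"], size=n),
    "temperature": rng.normal(70, 10, size=n),
    "operating_hours": rng.uniform(0, 20000, size=n),
})

categorical_demo = ["component_type", "weekday"]
numeric_demo = ["temperature", "operating_hours"]

# Same layout as the notebook's ColumnTransformer, but numeric columns
# are standardized instead of passed through.
preprocess_scaled = ColumnTransformer(
    transformers=[
        ("cat", OneHotEncoder(handle_unknown="ignore"), categorical_demo),
        ("num", StandardScaler(), numeric_demo),
    ],
    sparse_threshold=0,  # always return a dense array for easy inspection
)

transformed = preprocess_scaled.fit_transform(toy)

# The last two columns are the scaled numeric features.
print(transformed[:, -2:].mean(axis=0))
print(transformed[:, -2:].std(axis=0))
```

Tree-based models such as the random forest do not depend on feature scale, so this variant mainly matters for the logistic-regression pipeline.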


## 5. Logistic regression

In [None]:
log_reg = Pipeline(steps=[
    ('preprocess', preprocess),
    ('lr', LogisticRegression(max_iter=500, class_weight='balanced'))
])

log_reg.fit(X_train, y_train)

y_pred_lr = log_reg.predict(X_test)
y_proba_lr = log_reg.predict_proba(X_test)[:, 1]

print('### Classification report – Logistic Regression')
print(classification_report(y_test, y_pred_lr))

print('ROC-AUC (Logistic Regression):', roc_auc_score(y_test, y_proba_lr))


### 5.1 Confusion matrix (logistic regression)

In [None]:
cm_lr = confusion_matrix(y_test, y_pred_lr)

plt.figure(figsize=(4, 3))
sns.heatmap(cm_lr, annot=True, fmt='d')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix – Logistic Regression')
plt.tight_layout()
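To connect the heatmap with the classification report: scikit-learn's `confusion_matrix` puts true labels in rows and predicted labels in columns, so for a binary problem the flattened matrix reads `tn, fp, fn, tp`. The sketch below derives precision and recall for the positive class from these four counts. The label arrays `y_true_demo` and `y_pred_demo` are made up for illustration.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score

# Hypothetical labels: 6 negatives and 4 positives.
y_true_demo = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1])
y_pred_demo = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0])

# Rows are true classes, columns are predicted classes.
tn, fp, fn, tp = confusion_matrix(y_true_demo, y_pred_demo).ravel()

precision = tp / (tp + fp)  # share of predicted failures that were real
recall = tp / (tp + fn)     # share of real failures that were caught

print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")
print(f"precision={precision:.2f}, recall={recall:.2f}")

# Cross-check against scikit-learn's own metrics.
assert np.isclose(precision, precision_score(y_true_demo, y_pred_demo))
assert np.isclose(recall, recall_score(y_true_demo, y_pred_demo))
```

In a maintenance setting, false negatives (missed failures) and false positives (unnecessary inspections) usually carry different costs, so it helps to read both numbers rather than a single score.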


### 5.2 ROC curve (logistic regression)

In [None]:
fpr_lr, tpr_lr, _ = roc_curve(y_test, y_proba_lr)

plt.figure()
plt.plot(fpr_lr, tpr_lr, label='Logistic Regression')
plt.plot([0, 1], [0, 1], linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve – Logistic Regression')
plt.legend()
plt.tight_layout()


## 6. Random forest model

In [None]:
rf = Pipeline(steps=[
    ('preprocess', preprocess),
    ('rf', RandomForestClassifier(
        n_estimators=200,
        random_state=42,
        class_weight='balanced'
    ))
])

rf.fit(X_train, y_train)

y_pred_rf = rf.predict(X_test)
y_proba_rf = rf.predict_proba(X_test)[:, 1]

print('### Classification report – Random Forest')
print(classification_report(y_test, y_pred_rf))

print('ROC-AUC (Random Forest):', roc_auc_score(y_test, y_proba_rf))


### 6.1 Confusion matrix (random forest)

In [None]:
cm_rf = confusion_matrix(y_test, y_pred_rf)

plt.figure(figsize=(4, 3))
sns.heatmap(cm_rf, annot=True, fmt='d')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix – Random Forest')
plt.tight_layout()


### 6.2 ROC curve (random forest)

In [None]:
fpr_rf, tpr_rf, _ = roc_curve(y_test, y_proba_rf)

plt.figure()
plt.plot(fpr_rf, tpr_rf, label='Random Forest')
plt.plot([0, 1], [0, 1], linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve – Random Forest')
plt.legend()
plt.tight_layout()


## 7. Comparing the models via ROC curves

In [None]:
plt.figure()
plt.plot(fpr_lr, tpr_lr, label='Logistic Regression')
plt.plot(fpr_rf, tpr_rf, label='Random Forest')
plt.plot([0, 1], [0, 1], linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curves – Model Comparison')
plt.legend()
plt.tight_layout()
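Overlaid ROC curves can be hard to rank by eye when they cross or lie close together, so a numeric side-by-side of the ROC-AUC values can complement the plot. The following is a minimal, self-contained sketch under stated assumptions: it uses synthetic data from `make_classification` in place of the notebook's CSV and omits the preprocessing step. The names `X_demo`, `y_demo`, `models`, and `auc_scores` and the 90/10 class weights are assumptions for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced stand-in for the maintenance data (assumption).
X_demo, y_demo = make_classification(
    n_samples=1000, n_features=8, weights=[0.9, 0.1], random_state=42
)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=42, stratify=y_demo
)

# Same model settings as in the notebook, without the preprocessing step.
models = {
    "Logistic Regression": LogisticRegression(
        max_iter=500, class_weight="balanced"
    ),
    "Random Forest": RandomForestClassifier(
        n_estimators=200, random_state=42, class_weight="balanced"
    ),
}

# Fit each model and score it on the same held-out test set.
auc_scores = {}
for name, model in models.items():
    model.fit(X_tr, y_tr)
    auc_scores[name] = roc_auc_score(y_te, model.predict_proba(X_te)[:, 1])

for name, auc in auc_scores.items():
    print(f"{name}: ROC-AUC = {auc:.3f}")
```

With the notebook's own objects, the same comparison would use `y_test`, `y_proba_lr`, and `y_proba_rf` directly.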
