LDA : Linear Discriminant Analysis

In [None]:
# %% [markdown]
# # ML Baseline Paris - Application des 7 Méthodes du Cours
# 
# **Objectif** : Tester les 7 méthodes sur le dataset Paris pré-traité
# 
# Les données sont déjà :
# - Nettoyées et normalisées
# - Encodées (one-hot, amenities, etc.)
# - Avec target_class créée (quartiles de prix)

# %%
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.ensemble import BaggingClassifier, RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.utils import resample
import umap.umap_ as umap
import warnings
# Silence library warnings so that cell output stays readable.
# NOTE(review): this also hides useful sklearn warnings (e.g. LDA's
# "Variables are collinear"); consider narrowing to specific categories.
warnings.filterwarnings('ignore')

# Plot and display configuration shared by every figure/table below.
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)
pd.set_option('display.max_columns', None)

print("✅ Bibliothèques importées")

# %% [markdown]
# ## 1. Chargement des Données (Déjà pré-traitées)

# %%
# Load the final dataset (already normalised and encoded upstream).
path = '../data/paris_dataset_final_ready.csv.gz'
data_path = Path(path)
if not data_path.exists():
    raise FileNotFoundError(f"Dataset introuvable : {data_path}")
df = pd.read_csv(data_path, compression='gzip')

print(f"📊 Dataset chargé : {df.shape}")
print("\n🔍 Aperçu des colonnes :")
print(df.columns.tolist())

# Check the target column. Every later step (X/y split, models) needs it,
# so fail here with an explicit message rather than a bare KeyError later.
if 'target_class' not in df.columns:
    raise KeyError("⚠️ 'target_class' non trouvée dans le dataset")
print(f"\n✅ Target trouvée : {df['target_class'].nunique()} classes")
print(df['target_class'].value_counts().sort_index())

# Quick look at the first rows.
print("\n📋 Aperçu des 5 premières lignes :")
print(df.head())

# %% [markdown]
# ## 2. Préparation X/y et Train/Test Split

# %%
print("=" * 60)
print("PRÉPARATION : Séparation X/y et Train/Test Split")
print("=" * 60)

# Separate the features (X) from the target (y).
# price_clean is excluded because target_class is built from price quartiles,
# so keeping it would leak the label into the features; city_label is also
# excluded. errors='ignore' keeps the cell working when either of these two
# optional columns is absent (target_class itself is required: y above
# raises KeyError if it is missing).
y = df['target_class']
X = df.drop(columns=['target_class', 'price_clean', 'city_label'], errors='ignore')

print(f"\nFeatures (X) : {X.shape[1]} colonnes")
print(f"Target (y) : {len(y)} valeurs")
print("\nDistribution des classes :")
print(y.value_counts().sort_index())

# Stratified 80/20 split: both sets keep the class proportions.
# Fixed seed for reproducibility.
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    stratify=y,
    random_state=42,
)
assert len(X_train) + len(X_test) == len(X), "split lost or duplicated rows"

print("\n✅ Split effectué :")
print(f"  Train : {X_train.shape}")
print(f"  Test  : {X_test.shape}")
print("\nDistribution Train :")
print(y_train.value_counts().sort_index())
print("\nDistribution Test :")
print(y_test.value_counts().sort_index())

# %% [markdown]

Let's go

In [None]:
# %% [markdown]
# ## 5. MÉTHODE 3 : LDA - Linear Discriminant Analysis

# %%
print("\n" + "=" * 60)
print("MÉTHODE 3 : LDA (Linear Discriminant Analysis)")
print("=" * 60)

# LDA yields at most min(n_classes - 1, n_features) discriminant axes, so
# derive n_components from the data instead of hard-coding it
# (4 price-quartile classes -> 3 axes).
lda_classes = np.sort(y_train.unique())
lda_n_components = min(len(lda_classes) - 1, X_train.shape[1])
lda = LinearDiscriminantAnalysis(n_components=lda_n_components)
X_train_lda = lda.fit_transform(X_train, y_train)
X_test_lda = lda.transform(X_test)

print(f"Dimensions réduites : {X_train.shape[1]} → {X_train_lda.shape[1]}")
print(f"Variance expliquée : {lda.explained_variance_ratio_}")

# One colour per class, built here so this cell does not rely on a palette
# defined in some other cell.
lda_palette = sns.color_palette('tab10', n_colors=len(lda_classes))
has_ld2 = X_train_lda.shape[1] > 1

# 2D view (LD1 vs LD2). With a single discriminant axis, points sit at y=0.
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

for ax, (X_lda, y_set, title) in zip(
    axes, [(X_train_lda, y_train, 'Train'), (X_test_lda, y_test, 'Test')]
):
    labels = np.asarray(y_set)
    ld2 = X_lda[:, 1] if has_ld2 else np.zeros(len(X_lda))
    for color, classe in zip(lda_palette, lda_classes):
        mask = labels == classe
        # color= (not c=) so a single RGB tuple is never read as per-point values.
        ax.scatter(X_lda[mask, 0], ld2[mask],
                   color=color, label=f'Classe {classe}', alpha=0.5, s=10)

    ax.set_title(f"LDA - {title}")
    ax.set_xlabel(f"LD1 ({lda.explained_variance_ratio_[0]:.1%})")
    if has_ld2:
        ax.set_ylabel(f"LD2 ({lda.explained_variance_ratio_[1]:.1%})")
    ax.legend()
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("✅ LDA trouve les axes discriminants")