In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning) ## suppress annoying deprecation warnings

import pandas as pd
import seaborn.objects as so
import seaborn as sns
import matplotlib.pyplot as plt


from sklearn import tree
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import StratifiedKFold, cross_validate

# Part 3: Machine-learning insights through data visualization

In [None]:
# Rename the raw (abbreviated) weather columns to descriptive names with units,
# so that plot axes and model feature names are self-explanatory.
# Naming convention: Title_Case words joined by '_', unit as the last component.

col_rename = {
    'tavg': 'Temp_Avg_°C',
    'tmax': 'Temp_Max_°C',
    'tmin': 'Temp_Min_°C',
    'rhum': 'Rel_Humidity_%',
    'coco': 'Condition',           # weather condition annotation (used as the ML target below)
    'wspd': 'Wind_Speed_kmh',
    'prcp': 'Precipitation_mm',    # fixed typo: was 'Precipation_mm'
    'wdir': 'Wind_Direction_°',
    'pres': 'Air_Pressure_hPa',    # capitalized consistently with the other labels
    'dwpt': 'Dew_Point_°C',        # capitalized consistently with the other labels
}

In [None]:
## Reload the raw data and prepare it for plotting, as one non-mutating chain (no inplace=True),
## so re-running this cell always yields the same weather_df.
weather_df = (
    pd.read_csv('global_weather.csv', parse_dates=['time'], dtype={'wmo': str, 'station': str})
    .dropna()                     ## Drop every row that has a missing value in ANY column
    .rename(columns=col_rename)   ## Descriptive column names for the plot axes
    .assign(Continent=lambda df: df['timezone'].str.split('/').str[0])  ## Continent = timezone part before the first '/'
)

## Plotting feature correlation: first try

In [None]:
weather_corr = weather_df.select_dtypes(include='number').corr() ## Calculate (Pearson) correlation for all numerical features in dataframe

f, ax = plt.subplots(figsize=(11, 9))
sns.heatmap(weather_corr, ax=ax,  ## seaborn.objects has no heatmap function yet -> use the classic seaborn API
            cbar_kws={"label": "Pearson correlation coefficient", "shrink": 0.6}  ## Label adjustments
            );  ## trailing ';' suppresses the Axes repr

## Problems with the standard correlation heatmap: the colormap is not suited to correlations (it is not centered at 0), and the features are not ordered by similarity

In [None]:
cluster_grid = sns.clustermap(
    weather_corr,                                           ## Clustermap reorders the features by similarity
    cmap='vlag', center=0, vmin=-1, vmax=1,                 ## Diverging colormap: correlations span -1..+1 with a fixed midpoint at 0
    cbar_kws={"label": "Pearson correlation coefficient"},  ## Color legend description
    cbar_pos=(0.05, 0.45, 0.03, 0.2),                       ## Move the colorbar away from its default position
)
## The correlation matrix is symmetric, so both dendrograms are identical: hide the row one
cluster_grid.ax_row_dendrogram.set_visible(False)

## What can be seen in the correlation heatmap?
(1) features with (almost) identical information (redundancy) <br>
(2) association of learned embeddings (t-SNE, PCA) with features

## ML showcase: predict the manual text annotation of the weather from the numerical features

In [None]:
weather_df.Condition.value_counts()  ## Frequency of each weather condition, most frequent first

## For simplicity, reduce the data to the most frequent categories

In [None]:
N_TOP_CONDITIONS = 6  ## Number of most frequent weather conditions to keep

## value_counts() is sorted by frequency (descending); .head() selects by POSITION regardless of the
## index dtype (plain [0:6] slicing becomes label-based if the condition codes are floats).
top_conditions = weather_df['Condition'].value_counts().head(N_TOP_CONDITIONS).index

weather_df_red = weather_df[weather_df['Condition'].isin(top_conditions)]  ## Keep only rows of the top categories

weather_df_red

In [None]:
## Features: all numerical columns EXCEPT the target (guards against leakage in case 'Condition' is stored as a numeric code)
X = weather_df_red.select_dtypes(include='number').drop(columns='Condition', errors='ignore')
y = weather_df_red['Condition']  ## Target variable: the weather condition annotation

clf = tree.DecisionTreeClassifier(max_depth=5, class_weight="balanced", max_leaf_nodes=6,
                                  random_state=42)  ## Simple decision tree; seeded so tie-breaking between splits is reproducible

clf = clf.fit(X, y)  ## training

plt.figure(figsize=(30, 12))  ## Plot the full decision tree
anno = tree.plot_tree(
    clf,                                           ## Decision tree
    feature_names=clf.feature_names_in_.tolist(),  ## Feature names
    class_names=[str(c) for c in clf.classes_],    ## Weather condition labels (plot_tree expects strings)
    filled=True,                                   ## Colored by class decision
    impurity=False,                                ## For simplicity exclude impurity values at splits
    precision=1,                                   ## Decimal precision
    fontsize=12,                                   ## Font size for readability
)

## Insights into more complex ML models: Random Forest
### Plotting results: confusion matrix
For confusion-matrix visualization it is important to show both a color and a number for each entry: the color helps to spot patterns, the number gives the precise value.

In [None]:
clf_rf = RandomForestClassifier(n_estimators=100, class_weight="balanced", max_leaf_nodes=20, random_state=42) ## What happens if we neglect class imbalance?
#!# clf_rf = RandomForestClassifier(n_estimators=100,  class_weight=None, max_leaf_nodes=20, random_state=42) ## What happens if we neglect class imbalance?

## For a classifier, cv=5 means StratifiedKFold(n_splits=5) without shuffling. Passing the splitter
## explicitly gives identical folds and lets us recover which rows each model never saw during training.
cv_splitter = StratifiedKFold(n_splits=5)
output = cross_validate(clf_rf, X, y, cv=cv_splitter, scoring='accuracy', return_estimator=True) ## RandomForest is non-deterministic ML -> Cross-validation for more robust results

clf_rf_1 = output['estimator'][0]  ## Select the RF learned on the first cross-validation split

## Evaluate only on that split's held-out rows; predicting on all of X would include the
## model's own training rows and make the confusion matrix overly optimistic.
first_split_test_idx = next(cv_splitter.split(X, y))[1]
X_test, y_test = X.iloc[first_split_test_idx], y.iloc[first_split_test_idx]

y_pred = clf_rf_1.predict(X_test)
labels = clf_rf_1.classes_.tolist()
cm = confusion_matrix(y_test, y_pred, labels=labels)
fig, ax = plt.subplots(figsize=(12, 10))
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels).plot(ax=ax)
ax.set_title("Random forest, CV split 1: predictions on held-out rows");

## Getting explainability from feature importance
### Extract the feature importances from the models learned in all cross-validation folds, for robustness

In [None]:
## Wide table: one row per feature, one 'importance_cv<k>' column per cross-validation fold
## (each fold's estimator is a complete random forest, not a single tree).
## Series are aligned on each model's feature names, then put back into the original column order of X.
feature_importances = pd.DataFrame(
    {
        f'importance_cv{fold}': pd.Series(estimator.feature_importances_, index=estimator.feature_names_in_)
        for fold, estimator in enumerate(output['estimator'], start=1)
    }
).reindex(X.columns)
feature_importances.insert(0, 'feature', feature_importances.index)  ## wide_to_long needs the id variable as a column

## Long format (one row per feature and fold) is what grammar-of-graphics plotting typically needs;
## reset_index turns the (feature, cv) index into ordinary columns.
feature_importances_long = pd.wide_to_long(
    feature_importances, stubnames="importance_cv", i="feature", j="cv"
).reset_index()

feature_importances_long.head()


## Show importance as bar chart including errorbars
### Trick: swap the x and y axes for readability

In [None]:
(
    so.Plot(feature_importances_long, x='importance_cv', y='feature')  ## Importance on x, feature names on y (readable labels)
    .add(so.Bar(), so.Agg())                  ## Bar plot showing the average
    .add(so.Range(), so.Est(errorbar="sd"))   ## Whiskers showing the standard deviation
    .label(x="Feature importance (random forest)", y="Feature")
    .layout(size=(8, 6))
)

## There are problems with bar charts and errorbars: https://doi.org/10.1371/journal.pbio.1002128
### Better: show a point for every trained model (sample) together with summary statistics

In [None]:
(
    so.Plot(feature_importances_long, x='importance_cv', y='feature')
    ## One dot per trained model. Jitter spreads the dots to avoid overplotting;
    ## Shift(y=0) is currently a no-op -- raise it to offset the dots from the red summary marks.
    .add(so.Dot(pointsize=5), so.Shift(y=.0), so.Jitter(.5))
    .add(so.Dash(color="red"), so.Agg())                 ## Show a dash with the average
    .add(so.Range(color="red"), so.Est(errorbar="sd"))   ## Show the range of the standard deviation
    .label(x="Feature importance (random forest)", y="Feature")
    .layout(size=(8, 6))
)