In [None]:
from fastai.tabular.all import *
from codecarbon import EmissionsTracker
import pandas as pd
import os
import shap
import matplotlib.pyplot as plt

# Load Iris data: reuse the locally cached CSV when it exists, otherwise
# download it once and keep a copy for later runs.
iris_path = 'data/iris/iris.csv'
if os.path.exists(iris_path):
    df = pd.read_csv(iris_path)
else:
    # First run: make sure the cache folder exists, then fetch and store the file
    url = "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv"
    os.makedirs(os.path.dirname(iris_path), exist_ok=True)
    df = pd.read_csv(url)
    df.to_csv(iris_path, index=False)

# Encode each species name as an integer class id and store it as a
# categorical column named 'target'.
# NOTE(review): presumably the categorical dtype is what makes the downstream
# loader treat this as classification rather than regression - confirm.
species_to_id = {'setosa': 0, 'versicolor': 1, 'virginica': 2}
df['target'] = pd.Categorical(df['species'].map(species_to_id))

# Sanity checks: first rows, class balance, and column dtypes
for summary in (df.head(), df['target'].value_counts(), df.dtypes):
    print(summary)

# Input columns: four continuous measurements and no categorical features
cat_names = []
cont_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']

# Hold out 20% of the rows for validation; the fixed seed keeps the split
# reproducible between runs.
row_ids = range_of(df)
split_rows = RandomSplitter(valid_pct=0.2, seed=42)

# Build the DataLoaders: normalised continuous inputs, 'target' as the label,
# mini-batches of 10 rows.
dls = TabularDataLoaders.from_df(
    df,
    procs=[Normalize],
    cat_names=cat_names,
    cont_names=cont_names,
    y_names='target',
    splits=split_rows(row_ids),
    bs=10,
)

# Start energy tracking.
# The output folder is created up front: CodeCarbon writes its emissions CSV
# into `output_dir` and errors out when that directory is missing.
os.makedirs("emissions", exist_ok=True)
tracker = EmissionsTracker(project_name="iris_classifier_shap", output_dir="emissions")
tracker.start()

# Build model: two hidden layers (50 and 25 units), accuracy reported as metric
learn = tabular_learner(dls, layers=[50, 25], metrics=accuracy)

# Find learning rate (diagnostic only: its result is not fed into the
# training call below, which uses a fixed lr_max)
learn.lr_find()

# Train model: 10 epochs of one-cycle training with a peak learning rate of 1e-2
learn.fit_one_cycle(10, lr_max=1e-2)

# Get validation accuracy: index 1 of validate()'s result is the first (and
# only) metric configured on the learner, i.e. accuracy
acc = learn.validate()[1]
print(f"Validation accuracy: {acc:.4f} ({acc*100:.2f}%)")

# Misclassification analysis
preds, targs = learn.get_preds(ds_idx=1)

# Tabular targets can come back with shape (n, 1). Comparing that against the
# (n,) argmax result would broadcast to an (n, n) matrix and report bogus,
# duplicated indices, so flatten the targets to (n,) first.
targs = targs.flatten()
pred_classes = preds.argmax(dim=1)
misclassified = (pred_classes != targs).nonzero(as_tuple=True)[0]
print("Misclassified indices (test group):", misclassified)
if len(misclassified) > 0:
    # NOTE(review): assumes `dls.valid_ds.items` is a DataFrame whose index
    # carries the row labels of `df`, in the same order as `get_preds`
    # returns predictions - confirm against the fastai version in use.
    valid_indices = dls.valid_ds.items.index
    # Translate positions within the validation set into labels of `df`;
    # a plain list of ints is used because pandas cannot index with a tensor.
    misclassified_df = df.loc[valid_indices[misclassified.tolist()]]
    print("Tricky flowers:", misclassified_df)
    print("Guesses (probabilities):", preds[misclassified])
    print("Real answers:", targs[misclassified])
else:
    print("No tricky flowers! Perfect score!")

# SHAP analysis
X_valid, y_valid = dls.valid.xs, dls.valid.ys


def _predict_proba(x_cont):
    """Return class probabilities for a 2-D array of continuous features.

    KernelExplainer needs a plain callable mapping an (n_samples, n_features)
    array to an (n_samples, n_classes) array. The trained network is a torch
    module without a `predict` method, so this wrapper does the conversion.

    NOTE(review): assumes the model is called as `model(x_cat, x_cont)` and
    returns raw class scores (logits) - confirm for the fastai version in use.
    """
    import torch  # function-scope import keeps this block self-contained

    model = learn.model
    model.eval()  # inference behaviour for BatchNorm/Dropout, also for 1-row inputs
    device = next(model.parameters()).device
    cont = torch.as_tensor(
        getattr(x_cont, "values", x_cont), dtype=torch.float32, device=device
    )
    # There are no categorical features, so pass an empty (n, 0) placeholder
    cat = torch.empty((cont.shape[0], 0), dtype=torch.long, device=device)
    with torch.no_grad():
        logits = model(cat, cont)
    return torch.softmax(logits, dim=1).cpu().numpy()


# The validation rows serve both as background data and as explained samples
explainer = shap.KernelExplainer(_predict_proba, X_valid)
shap_values = explainer.shap_values(X_valid)

# Newer SHAP releases return one (n_samples, n_features, n_classes) array
# instead of a list with one array per class; normalise to the list form so
# the summary plot treats the result as multi-class output.
if not isinstance(shap_values, list) and getattr(shap_values, "ndim", 0) == 3:
    shap_values = [shap_values[:, :, k] for k in range(shap_values.shape[-1])]

# SHAP summary plot (feature importance for all classes)
os.makedirs('plots', exist_ok=True)  # savefig does not create missing folders
shap.summary_plot(shap_values, X_valid, feature_names=cont_names, show=False)
plt.savefig('plots/shap_summary.png')
plt.close()

# SHAP force plot for first misclassified sample (if any)
if len(misclassified) > 0:
    # Plain ints are required for positional indexing into pandas/NumPy objects
    sample_idx = int(misclassified[0])
    # Explain the class the model actually (and wrongly) predicted for this
    # sample rather than a hard-coded class id.
    class_idx = int(preds[sample_idx].argmax())

    # Support both SHAP layouts: a list with one (n_samples, n_features) array
    # per class, or a single (n_samples, n_features, n_classes) array.
    if isinstance(shap_values, list):
        sample_shap = shap_values[class_idx][sample_idx]
    else:
        sample_shap = shap_values[sample_idx, :, class_idx]

    os.makedirs('plots', exist_ok=True)  # savefig does not create missing folders
    shap.force_plot(explainer.expected_value[class_idx], sample_shap,
                    X_valid.iloc[sample_idx], feature_names=cont_names,
                    matplotlib=True, show=False)
    plt.savefig('plots/shap_force_misclassified.png')
    plt.close()

# Stop energy tracking
emissions = tracker.stop()
# stop() may hand back None when no measurement could be produced; formatting
# None with ':.6f' would raise a TypeError, so report that case explicitly.
if emissions is None:
    print("CO2 emissions: unavailable (tracker returned no measurement)")
else:
    print(f"CO2 emissions: {emissions:.6f} kg")

[codecarbon INFO @ 05:38:34] Energy consumed for RAM : 0.133193 kWh. RAM Power : 10.0 W
[codecarbon INFO @ 05:38:34] Delta energy consumed for CPU with constant : 0.000125 kWh, power : 30.0 W
[codecarbon INFO @ 05:38:34] Energy consumed for All CPU : 0.399689 kWh
[codecarbon INFO @ 05:38:34] 0.532882 kWh of electricity used since the beginning.
[codecarbon INFO @ 05:38:49] Energy consumed for RAM : 0.133235 kWh. RAM Power : 10.0 W
[codecarbon INFO @ 05:38:49] Delta energy consumed for CPU with constant : 0.000126 kWh, power : 30.0 W
[codecarbon INFO @ 05:38:49] Energy consumed for All CPU : 0.399815 kWh
[codecarbon INFO @ 05:38:49] 0.533050 kWh of electricity used since the beginning.
[codecarbon INFO @ 05:39:04] Energy consumed for RAM : 0.133277 kWh. RAM Power : 10.0 W
[codecarbon INFO @ 05:39:04] Delta energy consumed for CPU with constant : 0.000125 kWh, power : 30.0 W
[codecarbon INFO @ 05:39:04] Energy consumed for All CPU : 0.399940 kWh
[codecarbon INFO @ 05:39:04] 0.533217 kWh 

## Iris Classifier Results
- Default layers, 3 epochs: Accuracy = 0.766667
- layers=[200,100], 5 epochs, lr_max=1e-2: Accuracy = 0.800000
- layers=[100,50], 100 epochs, lr_max=1e-3, bs=64: Accuracy = 0.300000
- layers=[50,25], 15 epochs, target, bs=16: Accuracy = 0.966667
- layers=[50,25], 10 epochs, target, bs=16: Accuracy = 0.933333
- layers=[50,25], 8 epochs, target, bs=8: Accuracy = 0.966667