In [1]:
!pip install --quiet wandb scikit-learn pandas matplotlib

print('Install complete.')

Install complete.


In [2]:
import sys
import sklearn
import pandas as pd
import wandb

# Report the interpreter and library versions this notebook was run with.
version_report = (
    ('Python:', sys.version.splitlines()[0]),
    ('scikit-learn:', sklearn.__version__),
    ('pandas:', pd.__version__),
)
for label, version in version_report:
    print(label, version)

# wandb was imported above without error, so just confirm that.
print('wandb import ok')

Python: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
scikit-learn: 1.6.1
pandas: 2.2.2
wandb import ok


In [3]:
import wandb
# Authenticate against Weights & Biases; wandb prints the account it
# logged into as part of its own output.
login_hint = 'If login succeeded, you will see your W&B username above.'
wandb.login()
print(login_hint)

  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33muma_mahesh_iitpkd[0m ([33muma_mahesh_iitpkd-indian-institute-of-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


If login succeeded, you will see your W&B username above.


In [4]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import numpy as np
import pandas as pd

# Fetch the Iris bunch once; later cells also read `data` (for target names).
data = load_iris()
X, y = data['data'], data['target']
feature_names = data['feature_names']

# Hold out 30% of the samples for evaluation, with a fixed seed for the split.
split = train_test_split(X, y, test_size=0.3, random_state=42)
X_train, X_test, y_train, y_test = split

print('Shapes: X_train', X_train.shape, 'X_test', X_test.shape)

Shapes: X_train (105, 4) X_test (45, 4)


In [7]:
import matplotlib.pyplot as plt
import io

def plot_confusion_matrix(cm, labels):
    """Render a confusion matrix as an annotated heat map.

    Parameters
    ----------
    cm : 2-D array indexable as ``cm[i, j]``
        Confusion-matrix counts. Expected to be square with one row and one
        column per entry of ``labels``. Rows are labelled as true classes and
        columns as predicted classes, which is the layout produced by
        ``sklearn.metrics.confusion_matrix`` -- confirm if ``cm`` comes from
        another source.
    labels : sequence
        Class names used for the tick labels on both axes.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the plot. It is not closed here; the caller is
        responsible for calling ``plt.close(fig)`` when done with it.
    """
    n_classes = len(labels)
    ticks = range(n_classes)

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(cm, interpolation='nearest')
    ax.set_title('Confusion matrix')
    # Without axis labels the reader cannot tell which axis is the ground truth.
    ax.set_xlabel('Predicted label')
    ax.set_ylabel('True label')
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(labels, rotation=45)
    ax.set_yticklabels(labels)

    # Write each count in the centre of its cell (x = column j, y = row i).
    for i in range(n_classes):
        for j in range(n_classes):
            ax.text(j, i, str(cm[i, j]), ha='center', va='center')

    # Lay out this figure explicitly instead of whatever pyplot considers current.
    fig.tight_layout()
    return fig

def train_and_log(run_name, n_estimators=50, random_state=42, simulate_shift=False, entity=None, project='mlops-performance-monitoring'):
    """Train a RandomForest, evaluate it, and log the results to W&B.

    The function reads the notebook-level ``X_train``, ``y_train``, ``X_test``,
    ``y_test`` and ``data`` variables; they must be defined before it is called.

    Parameters
    ----------
    run_name : str
        Display name given to the W&B run.
    n_estimators : int, default 50
        Number of trees in the forest.
    random_state : int or None, default 42
        Seed for the classifier and for the simulated-drift noise, so a given
        seed reproduces the same drifted evaluation set.
    simulate_shift : bool, default False
        If True, Gaussian noise (mean 0.5, std 0.1) is added to ``X_test``
        before prediction to emulate data drift.
    entity : str or None, default None
        W&B entity passed to ``wandb.init``.
    project : str, default 'mlops-performance-monitoring'
        W&B project passed to ``wandb.init``.

    Returns
    -------
    tuple
        ``(acc, run)``: the accuracy on the (possibly shifted) test set and the
        W&B run. On success the run is left open so the caller can keep logging
        to it; the caller must call ``run.finish()``.

    Raises
    ------
    Exception
        Any error raised during training or logging is re-raised after the run
        has been finished with a non-zero exit code.

    Side effects
    ------------
    Writes ``rf_model.joblib`` to the current working directory and uploads it
    as the ``rf-model`` artifact.
    """
    import joblib  # local import: joblib is not imported at notebook level

    run = wandb.init(project=project, name=run_name, entity=entity, reinit=True)
    try:
        run.config.update({'n_estimators': n_estimators, 'random_state': random_state})

        clf = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state)
        clf.fit(X_train, y_train)

        if simulate_shift:
            # Seeded generator: the unseeded global np.random made every
            # drifted run evaluate on different noise.
            rng = np.random.default_rng(random_state)
            X_eval = X_test + rng.normal(loc=0.5, scale=0.1, size=X_test.shape)
        else:
            X_eval = X_test

        preds = clf.predict(X_eval)
        acc = accuracy_score(y_test, preds)
        cr = classification_report(y_test, preds, output_dict=True)
        cm = confusion_matrix(y_test, preds)

        # Collect everything into one dict so it is logged at a single step
        # rather than one step per log call.
        metrics = {'accuracy': acc}
        for k, v in cr.items():
            # Per-class entries have numeric-string keys; skip the aggregate
            # rows such as 'accuracy' and the averages.
            if k.isdigit():
                metrics[f'class_{k}_precision'] = v['precision']
                metrics[f'class_{k}_recall'] = v['recall']
                metrics[f'class_{k}_f1'] = v['f1-score']

        fig = plot_confusion_matrix(cm, labels=data['target_names'])
        try:
            metrics['confusion_matrix'] = wandb.Image(fig)
            run.log(metrics)
        finally:
            # Always release the figure, even if logging fails.
            plt.close(fig)

        # Persist the fitted model and attach it to the run as an artifact.
        artifact = wandb.Artifact('rf-model', type='model')
        joblib.dump(clf, 'rf_model.joblib')
        artifact.add_file('rf_model.joblib')
        run.log_artifact(artifact)
    except Exception:
        # Do not leave a dangling run behind when something goes wrong.
        run.finish(exit_code=1)
        raise

    print(f'Run {run_name} logged. Accuracy = {acc:.4f}')

    return acc, run

In [8]:
# Baseline: evaluate on the untouched test split, then close the W&B run.
baseline_acc, run = train_and_log(
    'baseline-run',
    n_estimators=50,
    random_state=42,
)
run.finish()
baseline_acc

0,1
accuracy,▁
class_0_f1,▁
class_0_precision,▁
class_0_recall,▁
class_1_f1,▁
class_1_precision,▁
class_1_recall,▁
class_2_f1,▁
class_2_precision,▁
class_2_recall,▁

0,1
accuracy,1
class_0_f1,1
class_0_precision,1
class_0_recall,1
class_1_f1,1
class_1_precision,1
class_1_recall,1
class_2_f1,1
class_2_precision,1
class_2_recall,1


Run baseline-run logged. Accuracy = 1.0000


0,1
accuracy,▁
class_0_f1,▁
class_0_precision,▁
class_0_recall,▁
class_1_f1,▁
class_1_precision,▁
class_1_recall,▁
class_2_f1,▁
class_2_precision,▁
class_2_recall,▁

0,1
accuracy,1
class_0_f1,1
class_0_precision,1
class_0_recall,1
class_1_f1,1
class_1_precision,1
class_1_recall,1
class_2_f1,1
class_2_precision,1
class_2_recall,1


1.0

In [9]:
# Drift scenario: same pipeline, but the test features are shifted before scoring.
drifted_acc, run = train_and_log(
    'drifted-run',
    n_estimators=50,
    random_state=99,
    simulate_shift=True,
)
run.finish()
drifted_acc

Run drifted-run logged. Accuracy = 0.6222


0,1
accuracy,▁
class_0_f1,▁
class_0_precision,▁
class_0_recall,▁
class_1_f1,▁
class_1_precision,▁
class_1_recall,▁
class_2_f1,▁
class_2_precision,▁
class_2_recall,▁

0,1
accuracy,0.62222
class_0_f1,0.8125
class_0_precision,1.0
class_0_recall,0.68421
class_1_f1,0.19048
class_1_precision,0.25
class_1_recall,0.15385
class_2_f1,0.7027
class_2_precision,0.54167
class_2_recall,1.0


0.6222222222222222

In [10]:
# Run a drifted evaluation and raise a W&B alert when accuracy is too low.
# NOTE(review): this reuses the run name 'drifted-run' from the earlier cell but
# with the default random_state, so it is a separate run with a different
# model seed -- confirm that is intended.
drifted_acc, run = train_and_log('drifted-run', simulate_shift=True)
threshold = 0.85  # minimum acceptable accuracy before an alert is sent
try:
    if drifted_acc < threshold:
        # Alert on the run we just created rather than the implicit global run.
        run.alert(title='Low accuracy detected', text=f'Accuracy {drifted_acc:.3f} below threshold {threshold}', level=wandb.AlertLevel.WARN)
        print('Alert sent (check W&B)')
    else:
        print('Accuracy OK')
finally:
    # Close the run even if sending the alert fails.
    run.finish()

Run drifted-run logged. Accuracy = 0.7333
Alert sent (check W&B)


0,1
accuracy,▁
class_0_f1,▁
class_0_precision,▁
class_0_recall,▁
class_1_f1,▁
class_1_precision,▁
class_1_recall,▁
class_2_f1,▁
class_2_precision,▁
class_2_recall,▁

0,1
accuracy,0.73333
class_0_f1,0.94444
class_0_precision,1.0
class_0_recall,0.89474
class_1_f1,0.33333
class_1_precision,0.6
class_1_recall,0.23077
class_2_f1,0.72222
class_2_precision,0.56522
class_2_recall,1.0
