In [1]:
import xgboost as xgb
import shap
import joblib
import numpy as np
import pandas as pd
from sklearn.calibration import CalibratedClassifierCV
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.pipeline import Pipeline
from google.adk.agents import Agent
from google.adk.tools import FunctionTool
import uuid
import matplotlib.pyplot as plt


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 1. Model Training & Calibration ==============================================
def train_and_save_model(X_train, y_train, numeric_features, categorical_features):
    # Preprocessing pipeline
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', StandardScaler(), numeric_features),
            ('cat', OneHotEncoder(handle_unknown='ignore', sparse_output=False), categorical_features)
        ]
    )

    # Base pipeline
    base_pipeline = Pipeline([
        ('preprocessor', preprocessor),
        ('classifier', xgb.XGBClassifier(
            objective='binary:logistic',
            eval_metric='logloss',
            scale_pos_weight=(y_train == 0).sum() / (y_train == 1).sum()
        ))
    ])
    
    # Train and calibrate
    base_pipeline.fit(X_train, y_train)
    calibrated_pipeline = CalibratedClassifierCV(base_pipeline, method='isotonic', cv='prefit')
    calibrated_pipeline.fit(X_train, y_train)
    
    # Save model
    joblib.dump(calibrated_pipeline, 'insurance_risk_model.joblib')
    print("Model trained and saved successfully!")

In [3]:
class InsuranceRiskAgents:
    def __init__(self):
        # Load model and background data
        model_data = joblib.load("calibrated_risk_model.joblib")
        self.model = model_data['calibrated_model']
        self.background_data = model_data['background_data']
        
        # Access base pipeline through correct attribute
        self.base_pipeline = self.model.calibrated_classifiers_[0].estimator
        self.preprocessor = self.base_pipeline.named_steps['preprocessor']
        self.classifier = self.base_pipeline.named_steps['classifier']
        
        # Initialize SHAP explainer with proper data
        self.feature_names = self._get_feature_names()
        self.explainer = self._create_shap_explainer()
        self.model_version = "1.0.20240615"

    def _get_feature_names(self):
        """Get feature names from preprocessor"""
        numeric_features = self.preprocessor.transformers_[0][2]
        categorical_features = list(
            self.preprocessor.named_transformers_['cat']
            .get_feature_names_out(self.preprocessor.transformers_[1][2])
        )
        return numeric_features + categorical_features

    def _create_shap_explainer(self):
        """Initialize SHAP explainer with correct data types"""
        return shap.TreeExplainer(
            self.classifier,
            data=self.background_data,
            feature_perturbation="interventional",
            model_output="probability"
        )

    def predict_risk(self, input_data: dict) -> dict:
        """Generate risk prediction with calibrated probabilities"""
        try:
            df = pd.DataFrame([input_data])
            processed = self.preprocessor.transform(df)
            proba = self.model.predict_proba(processed)[0, 1]
            return {
                'risk_score': float(proba),
                'confidence': float(abs(proba - 0.5) * 2),
                'model_version': self.model_version
            }
        except Exception as e:
            return {
                'error': f"Prediction failed: {str(e)}",
                'model_version': self.model_version
            }

    def explain_risk(self, input_data: dict, tool_context=None) -> dict:
        """Generate regulator-ready explanation with audit trail"""
        try:
            # Handle session state if context exists
            if tool_context:
                history = tool_context.state.get("conversation_history", [])
                history.append({
                    "input": input_data,
                    "timestamp": datetime.now().isoformat()
                })
                tool_context.state["conversation_history"] = history

            # Generate SHAP explanation
            df = pd.DataFrame([input_data])
            processed = self.preprocessor.transform(df)
            shap_values = self.explainer.shap_values(processed)[0]
            
            # Format output
            top_factors = self._format_shap_factors(shap_values)
            plot_path = self._generate_shap_plot(shap_values)
            audit_id = str(uuid.uuid4())

            # Add audit log if context exists
            if tool_context:
                audit_log = tool_context.state.get("audit_log", [])
                audit_log.append({
                    "audit_id": audit_id,
                    "input": input_data,
                    "result": top_factors,
                    "timestamp": datetime.now().isoformat()
                })
                tool_context.state["audit_log"] = audit_log

            return {
                'factors': top_factors,
                'plot_path': plot_path,
                'compliance_note': "Explanation compliant with EU AI Act Article 13",
                'model_version': self.model_version,
                'audit_id': audit_id
            }
        except Exception as e:
            return {
                'error': f"Explanation failed: {str(e)}",
                'model_version': self.model_version
            }

    def _format_shap_factors(self, shap_values):
        """Format top 3 risk factors with impact scores"""
        top_idx = np.argsort(-np.abs(shap_values))[:3]
        return [
            f"{self.feature_names[i]} ({shap_values[i]:+.2f} impact)"
            for i in top_idx
        ]

    def _generate_shap_plot(self, shap_values):
        """Generate and save SHAP decision plot"""
        plot_id = str(uuid.uuid4())
        plt.figure(figsize=(10, 5))
        shap.decision_plot(
            self.explainer.expected_value,
            shap_values,
            self.feature_names,
            show=False
        )
        plt.title("Risk Decision Breakdown", fontsize=12)
        plt.tight_layout()
        plot_path = f"shap_plot_{plot_id}.png"
        plt.savefig(plot_path, bbox_inches='tight')
        plt.close()
        return plot_path

In [4]:
class InsuranceRiskAgents:
    def __init__(self):
        # Load model and background data
        model_data = joblib.load("calibrated_risk_model.joblib")
        self.model = model_data['calibrated_model']
        self.background_data = model_data['background_data']
        
        # Access base pipeline through correct attribute
        self.base_pipeline = self.model.calibrated_classifiers_[0].estimator
        self.preprocessor = self.base_pipeline.named_steps['preprocessor']
        self.classifier = self.base_pipeline.named_steps['classifier']
        
        # Track original feature names (critical fix)
        self.original_feature_names = (
            self.preprocessor.transformers_[0][2] +  # numeric
            self.preprocessor.transformers_[1][2]    # categorical
        )
        
        # Initialize SHAP explainer
        self.feature_names = self._get_feature_names()
        self.explainer = self._create_shap_explainer()
        self.model_version = "1.0.20240615"

    def _get_feature_names(self):
        """Get transformed feature names from preprocessor"""
        numeric = self.preprocessor.transformers_[0][2]
        categorical = list(
            self.preprocessor.named_transformers_['cat']
            .get_feature_names_out(self.preprocessor.transformers_[1][2])
        )
        return numeric + categorical

    def _create_shap_explainer(self):
        """Initialize SHAP explainer with correct data types"""
        return shap.TreeExplainer(
            self.classifier,
            data=self.background_data,
            feature_perturbation="interventional",
            model_output="probability"
        )

    def _validate_input(self, input_data: dict) -> bool:
        """Ensure input contains all required raw features"""
        return all(feat in input_data for feat in self.original_feature_names)

    def predict_risk(self, input_data: dict) -> dict:
        """Generate risk prediction with calibrated probabilities"""
        if not self._validate_input(input_data):
            return {
                'error': f"Missing features. Required: {self.original_feature_names}",
                'model_version': self.model_version
            }
            
        try:
            df = pd.DataFrame([input_data])
            proba = self.model.predict_proba(df)[0, 1]
            return {
                'risk_score': float(proba),
                'confidence': float(abs(proba - 0.5) * 2),
                'model_version': self.model_version
            }
        except Exception as e:
            return {
                'error': f"Prediction failed: {str(e)}",
                'model_version': self.model_version
            }

    def explain_risk(self, input_data: dict, tool_context=None) -> dict:
        """Generate regulator-ready explanation with audit trail"""
        if not self._validate_input(input_data):
            return {
                'error': f"Missing features. Required: {self.original_feature_names}",
                'model_version': self.model_version
            }

        try:
            # Handle session state if context exists
            if tool_context:
                history = tool_context.state.get("conversation_history", [])
                history.append({
                    "input": input_data,
                    "timestamp": datetime.now().isoformat()
                })
                tool_context.state["conversation_history"] = history

            # Generate SHAP explanation
            df = pd.DataFrame([input_data])
            processed = self.preprocessor.transform(df)
            shap_values = self.explainer.shap_values(processed)[0]
            
            # Format output
            top_factors = self._format_shap_factors(shap_values)
            plot_path = self._generate_shap_plot(shap_values)
            audit_id = str(uuid.uuid4())

            # Add audit log if context exists
            if tool_context:
                audit_log = tool_context.state.get("audit_log", [])
                audit_log.append({
                    "audit_id": audit_id,
                    "input": input_data,
                    "result": top_factors,
                    "timestamp": datetime.now().isoformat()
                })
                tool_context.state["audit_log"] = audit_log

            return {
                'factors': top_factors,
                'plot_path': plot_path,
                'compliance_note': "Explanation compliant with EU AI Act Article 13",
                'model_version': self.model_version,
                'audit_id': audit_id
            }
        except Exception as e:
            return {
                'error': f"Explanation failed: {str(e)}",
                'model_version': self.model_version
            }

    def _format_shap_factors(self, shap_values):
        """Format top 3 risk factors with impact scores"""
        top_idx = np.argsort(-np.abs(shap_values))[:3]
        return [
            f"{self.feature_names[i]} ({shap_values[i]:+.2f} impact)"
            for i in top_idx
        ]

    def _generate_shap_plot(self, shap_values):
        """Generate and save SHAP decision plot"""
        plot_id = str(uuid.uuid4())
        plt.figure(figsize=(10, 5))
        shap.decision_plot(
            self.explainer.expected_value,
            shap_values,
            self.feature_names,
            show=False
        )
        plt.title("Risk Decision Breakdown", fontsize=12)
        plt.tight_layout()
        plot_path = f"shap_plot_{plot_id}.png"
        plt.savefig(plot_path, bbox_inches='tight')
        plt.close()
        return plot_path


In [5]:
# Initialize service components
# Create agents with PROPER parameters
risk_service = InsuranceRiskAgents()
predict_risk_tool = FunctionTool(risk_service.predict_risk)
explain_risk_tool = FunctionTool(risk_service.explain_risk)
risk_agent = Agent(
    name="risk_modeling_agent",
    model="gemini-2.0-flash",
    tools=[predict_risk_tool],
    instruction="Provide risk scores with confidence intervals",
    description="Production risk scoring engine"
)

explain_agent = Agent(
    name="compliance_explainer",
    model="gemini-2.0-pro",
    tools=[explain_risk_tool],
    instruction="Generate regulator-friendly explanations with SHAP plots",
    description="Enterprise explainability service"
)

import nest_asyncio
import asyncio
from google.adk.sessions import InMemorySessionService

# Allow nested async in Jupyter/Colab
nest_asyncio.apply()

# Initialize session service
session_service = InMemorySessionService()

# Create session with proper async handling
session = asyncio.get_event_loop().run_until_complete(
    session_service.create_session(
        app_name="insurance_risk",
        user_id="regulator_123",
        session_id="audit_456"
    )
)

# Now access state safely
session.state["model_version"] = risk_service.model_version
session.state["conversation_history"] = []


In [6]:
# Usage Example ===============================================================
if __name__ == "__main__":
    # First train and save model (do once)
    # train_and_save_model(X_train, y_train, numeric_features, categorical_features)
    
    # Initialize agents
    agents = InsuranceRiskAgents()
    
    # Sample prediction
    sample_input = {
    'age': 45,
    'income': 85000,
    'vehicle_age': 3,                # Added missing numeric feature
    'vehicle_value': 35000,
    'premium_amount': 1200,          # Added missing numeric feature
    'vehicle_brand': 'Toyota',       # Added categorical features
    'occupation': 'engineer',
    'city': 'Metropolis',
    'policy_type': 'Comprehensive',
    'gender': 'M'}
    
    prediction = agents.predict_risk(sample_input)
    explanation = agents.explain_risk(sample_input)
    
    print("Risk Prediction:", prediction)
    print("Explanation:", explanation)

Risk Prediction: {'risk_score': 0.9766454100608826, 'confidence': 0.9532908201217651, 'model_version': '1.0.20240615'}
Explanation: {'factors': ['vehicle_age (+0.18 impact)', 'policy_type_Comprehensive (+0.14 impact)', 'age (+0.05 impact)'], 'plot_path': 'shap_plot_de4cb2f0-39d5-4459-882c-bc7a8fd42ebb.png', 'compliance_note': 'Explanation compliant with EU AI Act Article 13', 'model_version': '1.0.20240615', 'audit_id': 'b865ebcd-f222-4ded-ac63-0aa62f054a1d'}


In [7]:
# After running a prediction
explanation = agents.explain_risk(sample_input)
print("Why did I get a high risk score?")
print("Answer:", explanation['factors'])

Why did I get a high risk score?
Answer: ['vehicle_age (+0.18 impact)', 'policy_type_Comprehensive (+0.14 impact)', 'age (+0.05 impact)']


In [8]:
# After running a prediction
explanation = agents.explain_risk(sample_input)
print("Is this a good score?")
print("Answer:", explanation['factors'])

Is this a good score?
Answer: ['vehicle_age (+0.18 impact)', 'policy_type_Comprehensive (+0.14 impact)', 'age (+0.05 impact)']
