In [1]:
import gradio as gr
import pandas as pd

In [2]:
# Demo table: first 100 rows of test.csv with the `fraud_bool` column dropped.
path = "data/baf/preprocess/test.csv"
data = (
    pd.read_csv(path, index_col=0)
    .drop(columns=['fraud_bool'])
    .iloc[:100]
)

In [3]:
data  # display the demo frame; pandas truncates the rendered rows/columns

Unnamed: 0,income,name_email_similarity,prev_address_months_count,current_address_months_count,customer_age,days_since_request,intended_balcon_amount,payment_type,zip_count_4w,velocity_6h,...,has_other_cards,proposed_credit_limit,foreign_request,source,session_length_in_minutes,device_os,keep_alive_session,device_distinct_emails_8w,device_fraud_count,month
795010,0.9,0.115829,-1,176,50,0.006259,-1.046460,AB,2393,5516.425324,...,0,1000.0,0,INTERNET,8.994525,linux,1,1,0,6
795076,0.4,0.875184,-1,301,60,0.024319,-1.500530,AB,509,4976.761277,...,0,2000.0,0,INTERNET,3.024375,linux,1,1,0,6
795115,0.8,0.253221,-1,120,40,0.021385,38.223542,AA,2129,3275.277790,...,0,200.0,0,INTERNET,2.843119,other,1,1,0,6
795159,0.7,0.883076,-1,69,40,0.007541,-1.177743,AC,1392,5266.460512,...,1,990.0,1,INTERNET,9.258009,linux,0,2,0,6
795203,0.4,0.493291,-1,39,30,0.003067,-1.480972,AC,534,1386.970887,...,0,200.0,0,INTERNET,2.600723,other,0,1,0,6
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
801290,0.9,0.881922,-1,294,50,0.014847,-0.724758,AB,1512,5075.606854,...,0,1000.0,0,INTERNET,6.084658,windows,0,1,0,6
801379,0.9,0.280653,-1,119,20,0.001136,-0.497909,AC,1058,5186.247819,...,1,1000.0,0,INTERNET,2.430285,macintosh,0,1,0,6
801398,0.1,0.370128,-1,100,50,0.006196,-1.102605,AB,891,686.810133,...,1,1500.0,1,INTERNET,3.153512,linux,0,1,0,6
801549,0.7,0.717302,-1,260,60,0.014053,-0.901252,AC,3001,2929.789399,...,1,1500.0,0,INTERNET,3.074278,other,1,2,0,6


In [16]:
def search_index(index, df=None):
    """Look up rows of the demo table by integer index label.

    Args:
        index: Index label as entered in the UI (string or int). It is parsed
            with ``int()``, so surrounding whitespace is tolerated.
        df: Frame to search. Defaults to the module-level ``data`` frame,
            resolved at call time.

    Returns:
        A DataFrame slice with the matching row(s) on success. Otherwise one of
        the message strings "Index not found" or
        "Please enter a valid integer index".
    """
    # NOTE(review): `data` is rebound to a sklearn Bunch in a later cell; a demo
    # still running from this cell would then search the wrong object.
    frame = data if df is None else df
    try:
        label = int(index)
    except (TypeError, ValueError):
        # ValueError: non-numeric text (including ""); TypeError: None.
        return "Please enter a valid integer index"
    # `.loc[[label]]` raises KeyError for a missing label rather than returning
    # an empty frame, so test membership explicitly.
    if label not in frame.index:
        return "Index not found"
    # Double brackets keep the result a DataFrame rather than a Series.
    return frame.loc[[label]]

def create_demo():
    """Build the index-lookup demo: an index textbox, a Search button and a table.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    def _on_search(index):
        # search_index reports failures as message strings. A gr.DataFrame
        # output treats a str value as a CSV file to load, so the message would
        # not render; surface it as a Gradio error popup instead.
        result = search_index(index)
        if isinstance(result, str):
            raise gr.Error(result)
        return result

    with gr.Blocks() as demo:
        index_input = gr.Textbox(label="Enter index")
        output = gr.DataFrame()

        bttn = gr.Button("Search")
        bttn.click(fn=_on_search, inputs=index_input, outputs=output)
        # Pressing Enter in the textbox runs the same lookup.
        index_input.submit(fn=_on_search, inputs=index_input, outputs=output)
    return demo

# Shut down servers left over from previous runs of this notebook; otherwise
# every re-run starts one more server on the next free port.
# NOTE(review): share=True publishes the fraud-data table through a public
# *.gradio.live link with no authentication — confirm this is intended.
gr.close_all()
create_demo().launch(inbrowser=True, share=True)

Running on local URL:  http://127.0.0.1:7866
Running on public URL: https://a4f0137d8a015b8bdf.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




In [17]:

# Expected working directory: <pnpxai repo>/tutorials/finance (relative paths such as "monochrome.json" resolve from here)
# File: <pnpxai repo>/tutorials/finance/gradio_test.ipynb
import os
import numpy as np
import pandas as pd
import xgboost as xgb
import shap
import gradio as gr
import plotly.graph_objs as go
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Load the scikit-learn breast-cancer benchmark and hold out 20% for testing.
# NOTE(review): this rebinds `data` (the fraud DataFrame from the earlier cell)
# to a sklearn Bunch; anything still calling `search_index` will break.
data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data,
    data.target,
    test_size=0.2,
    random_state=42,
)

# Fit the classifier, then attribute its test-set predictions with TreeSHAP.
model = xgb.XGBClassifier(random_state=42)
model.fit(X_train, y_train)
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)

def get_model_explanation():
    """Return static HTML (Korean) describing the model and how to read TreeSHAP output.

    The SHAP sign refers to the classifier's positive class. In scikit-learn's
    breast-cancer dataset ``target_names == ['malignant', 'benign']``, so class 1
    is benign. Positive SHAP values therefore push toward a benign prediction
    and negative values toward malignant.
    """
    return """
    <h2>모델 설명 알고리즘</h2>
    <p>이 XGBoost 모델은 다양한 특징을 기반으로 유방암을 예측합니다. 모델의 결정을 TreeSHAP을 사용하여 설명하며, 각 특징이 예측에 미치는 영향을 이해하는 데 도움을 줍니다.</p>
    <h3>TreeSHAP 작동 방식:</h3>
    <ol>
        <li><strong>SHAP 값:</strong> 각 특징이 예측에 기여하는 정도를 나타냅니다.</li>
        <li><strong>특징 중요도:</strong> 인스턴스별 SHAP 값을 집계하여 결정됩니다.</li>
        <li><strong>로컬 설명:</strong> 개별 인스턴스 수준의 통찰을 제공합니다.</li>
        <li><strong>글로벌 통찰:</strong> 전체 모델의 행동을 보여줍니다.</li>
    </ol>
    <h3>결과 해석:</h3>
    <ul>
        <li>양수 SHAP 값: 양성(클래스 1) 예측 쪽으로 영향을 줍니다.</li>
        <li>음수 SHAP 값: 악성(클래스 0) 예측 쪽으로 영향을 줍니다.</li>
        <li>SHAP 값의 크기: 특징의 영향 강도를 나타냅니다.</li>
    </ul>
    """

def get_dataset_info(dataset=None):
    """Return a Markdown summary (Korean) of the breast-cancer dataset.

    Args:
        dataset: Object exposing ``data`` (2-D array), ``feature_names`` and
            ``target`` (non-negative integer labels), e.g. the Bunch returned by
            ``load_breast_cancer()``. Defaults to the module-level ``data``,
            resolved at call time.

    Returns:
        Markdown text: a heading line followed by a bullet list.
    """
    if dataset is None:
        dataset = data
    lines = [
        "데이터셋 정보:",
        "",
        f"- 샘플 수: {len(dataset.data)}",
        f"- 특징 수: {len(dataset.feature_names)}",
        f"- 특징 이름: {', '.join(dataset.feature_names)}",
        # scikit-learn's breast-cancer target_names are ['malignant', 'benign'],
        # i.e. 0 = malignant (악성), 1 = benign (양성).
        "- 대상: 암 진단 (0: 악성, 1: 양성)",
        f"- 클래스 분포: {np.bincount(dataset.target)}",
        f"- 데이터 형태: {dataset.data.shape}",
    ]
    # No leading indentation: 4-space-indented lines would render as a Markdown
    # code block in gr.Markdown instead of a bullet list.
    return "\n".join(lines)


def plot_feature_importance(top_n=10):
    """Horizontal bar chart of the ``top_n`` features by mean |SHAP| value.

    Reads the module-level ``shap_values`` (computed on ``X_test``) and
    ``data.feature_names``.

    Args:
        top_n: Number of features to show (default 10).

    Returns:
        A plotly ``go.Figure`` with the most important feature at the top.
    """
    # Global importance = average magnitude of each feature's per-sample attribution.
    mean_shap = np.abs(shap_values).mean(axis=0)
    feature_importance = pd.DataFrame({'feature': data.feature_names, 'importance': mean_shap})
    top_features = feature_importance.nlargest(top_n, 'importance')

    fig = go.Figure(go.Bar(
        x=top_features['importance'],
        y=top_features['feature'],
        orientation='h',
    ))
    fig.update_layout(
        title=f'Top {top_n} Features by SHAP Value Magnitude',
        xaxis_title='Mean |SHAP Value|',
        yaxis_title='Feature',
    )
    # Plotly stacks horizontal-bar categories bottom-up in input order; reverse
    # the axis so the largest (first, from nlargest) bar is drawn at the top.
    fig.update_yaxes(autorange='reversed')
    return fig


# Page header shown at the top of the interface.
# NOTE(review): the title names a financial-fraud model while the description
# (and the rest of this cell) uses the breast-cancer demo — confirm which is intended.
title = "비정상 금융 거래 탐지 모델 (Plug & Play XAI)"
description = "이 인터페이스는 XGBoost 모델을 사용하여 유방암 데이터를 예측하고, SHAP을 사용하여 모델을 설명합니다."

title_html = "<h1 style='text-align: center;'>{}</h1>".format(title)
description_html = "<p style='text-align: center;'>{}</p>".format(description)

# The theme JSON is resolved relative to the kernel's working directory; fall back
# to Gradio's built-in Monochrome theme when it is missing so the cell still runs.
THEME_PATH = "monochrome.json"
my_theme = gr.Theme.load(THEME_PATH) if os.path.exists(THEME_PATH) else gr.themes.Monochrome()

def click_button():
    """Debug callback: write a click notice to stdout (currently not wired to any button)."""
    message = "Button clicked!"
    print(message)

def test():
    """Placeholder upload handler: always returns the fixed text "sample"."""
    placeholder = "sample"
    return placeholder
def create_gradio_interface():
    """Assemble the multi-tab XAI demo UI (not launched here).

    Tabs, in order:
        - model input
        - model basics
        - global explanation (empty placeholder)
        - local explanation
        - explainer comparison (empty placeholder)

    Relies on the module-level ``my_theme``, ``title_html`` and
    ``description_html``, and on the helper functions defined above.

    Returns:
        The ``gr.Blocks`` interface.
    """
    # The importance chart takes no inputs, so build it once and reuse it for
    # every plot slot instead of recomputing the same figure four times.
    importance_fig = plot_feature_importance()

    with gr.Blocks(theme=my_theme) as interface:
        # gr.HTML("<img src='./logo.png' style='width: 100%; max-width: 400px; margin: 0 auto;'/>")

        gr.HTML(title_html + description_html)
        with gr.Tab("모델 입력"):
            gr.Label("유방암 데이터를 사용하여 모델을 설명합니다.")
            file = gr.File(label="모델 파일을 업로드 해주세요.")
            gr.HTML("<p>이 데이터셋은 유방암 진단에 대한 30가지 특징을 포함하며, 이를 사용하여 악성과 양성 종양을 예측합니다.</p>")
            gr.HTML("<p>아래의 표는 데이터셋의 처음 5개 행을 보여줍니다.</p>")
            data_output = gr.Label("데이터셋 미리보기")
            # Placeholder wiring: `test` just returns the string "sample".
            file.upload(fn=test, outputs=data_output)
            # gr.DataFrame(data.data[:5], headers=data.feature_names)

        with gr.Tab("모델 기본 정보"):
            gr.HTML(get_model_explanation())
            gr.Markdown(get_dataset_info())
        with gr.Tab("모델 전반적 설명(Global Explanation)"):
            pass  # TODO: global explanation view not implemented yet
        with gr.Tab("개별 모델 의사결정 설명(Local Explanation)"):
            # The selection widgets below are not yet connected to any callback.
            gr.Radio(choices=["XGBoost", "Random Forest", "Logistic Regression"], label="모델 선택")
            gr.Radio(choices=["SHAP", "LIME", "LRP", "IG"], label="설명 알고리즘 선택")
            gr.CheckboxGroup(choices=["Feature Importance", "Local Explanation"], label="설명 타입 선택")
            gr.Dropdown(choices=["Top 10 Features", "All Features"], label="설명 범위 선택")
            # gr.FileExplorer(label="데이터셋 업로드")
            gr.File(label="데이터셋 업로드")
            # gr.Dataset(label="데이터셋 선택", samples=["a", "b", "c"])
            gr.DataFrame(label="데이터셋 미리보기")
            bttn = gr.Button("설명 보기")
            # bttn.click(click_button)
            gr.HTML("<p>아래의 막대 차트는 절대 SHAP 값에 따른 상위 10개 특징을 보여주며, 이는 예측에서의 전반적인 중요도를 나타냅니다.</p>")
            with gr.Row():
                for _ in range(4):
                    gr.Plot(importance_fig)
        with gr.Tab("설명 알고리즘 간 비교 평가"):
            pass  # TODO: explainer comparison not implemented yet

    return interface

# Create and launch the Gradio interface.
# Close servers from earlier runs first. This also stops the index-lookup demo,
# which would otherwise keep serving a `data` that this cell has rebound.
gr.close_all()
create_gradio_interface().launch(inbrowser=True, share=True)

Running on local URL:  http://127.0.0.1:7867
Running on public URL: https://a0263944f62927479f.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


