<a href="https://colab.research.google.com/github/Naaao9999/shap_industries/blob/main/SHAP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install japanize_matplotlib -q

In [None]:
import logging
import pandas as pd
import numpy as np
import lightgbm as lgb
import shap
import matplotlib.pyplot as plt
import japanize_matplotlib
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, roc_curve, auc as sklearn_auc
)
from typing import Tuple, Dict

In [None]:
# Logging: configure the root logger at INFO and take a logger for this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Run-wide settings read by the functions below.
# An "INDUSTRIES" entry is added at run time by main().
CONFIG = dict(
    # File I/O
    INPUT_FILE="投入用.csv",
    RESULT_FILE="result.csv",
    ENCODING="shift-jis",  # used for every CSV this notebook reads and writes
    # Train/test split and random seeds
    TEST_SIZE=0.2,
    RANDOM_STATE=37,
    # Probability cutoff that turns model outputs into 0/1 labels
    THRESHOLD=0.5,
    # LightGBM training
    EARLY_STOPPING_ROUNDS=100,
    NUM_BOOST_ROUND=1000,
    # Model input columns; all other columns except the first two are treated as industries
    FEATURE_COLUMNS=["最短港湾", "最短空港", "賃金", "市場距離"],
)


def load_and_split_data() -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Read the input CSV and split it into train and test sets.

    The CSV at CONFIG["INPUT_FILE"] is read with CONFIG["ENCODING"], ',' as the
    thousands separator, and its second column (position 1) as the index.
    The split uses CONFIG["TEST_SIZE"] and CONFIG["RANDOM_STATE"].

    Returns:
        (train_set, test_set)

    Raises:
        Any exception from reading or splitting; it is logged first, then re-raised.
    """
    try:
        frame = pd.read_csv(
            CONFIG["INPUT_FILE"],
            index_col=1,
            thousands=',',
            encoding=CONFIG["ENCODING"],
        )
        train_part, test_part = train_test_split(
            frame,
            random_state=CONFIG["RANDOM_STATE"],
            test_size=CONFIG["TEST_SIZE"],
        )
        logger.info(f"データを読み込みました: {frame.shape}")
        return train_part, test_part
    except Exception as err:
        logger.error(f"データ読み込みエラー: {err}")
        raise

def update_industries_from_file(config: dict) -> list:
    """Derive the industry (target) column names from the input file's header row.

    Only the header is read (nrows=0). Every column except the first two
    (municipality code and name, per the file layout) and the entries of
    config["FEATURE_COLUMNS"] is treated as an industry.

    Args:
        config: settings dict; "INPUT_FILE", "ENCODING" and "FEATURE_COLUMNS" are read.

    Returns:
        Industry column names in file order, or [] if anything fails
        (the error is logged, not raised).
    """
    try:
        header = pd.read_csv(
            config["INPUT_FILE"],
            encoding=config["ENCODING"],
            nrows=0,  # header only, no data rows
        )
        column_names = header.columns.tolist()

        excluded = column_names[:2] + config["FEATURE_COLUMNS"]
        excluded_lookup = set(excluded)
        industries = [name for name in column_names if name not in excluded_lookup]

        logger.info(f"ファイルから{len(industries)}個の産業名を取得しました")
        logger.info(f"除外されたカラム: {excluded}")
        print("取得した産業名:", industries)
    except Exception as err:
        logger.error(f"産業名の取得でエラーが発生しました: {err}")
        # No predefined INDUSTRIES setting exists to fall back on, so return an empty list.
        return []
    else:
        return industries



def prepare_data(train_set: pd.DataFrame, test_set: pd.DataFrame,
                industry: str) -> Tuple[pd.DataFrame, pd.Series, pd.DataFrame, pd.Series]:
    """Select the feature matrix and the target column for one industry.

    Features are the CONFIG["FEATURE_COLUMNS"] columns; the target is the
    column named `industry`.

    Returns:
        (X_train, y_train, X_test, y_test)
    """
    features = CONFIG["FEATURE_COLUMNS"]

    def _features_and_target(frame: pd.DataFrame) -> Tuple[pd.DataFrame, pd.Series]:
        return frame.loc[:, features], frame[industry]

    X_train, y_train = _features_and_target(train_set)
    X_test, y_test = _features_and_target(test_set)
    return X_train, y_train, X_test, y_test


def get_lgb_params() -> Dict:
    """Return LightGBM training parameters: binary objective, GBDT, log-loss, silent, seeded."""
    params = dict(
        objective='binary',
        boosting_type='gbdt',
        metric='binary_logloss',
        verbosity=-1,
    )
    # Seed from the shared random state so runs are repeatable.
    params['seed'] = CONFIG["RANDOM_STATE"]
    return params


def train_model(X_train: pd.DataFrame, y_train: pd.Series,
               X_test: pd.DataFrame, y_test: pd.Series) -> lgb.Booster:
    """Fit a LightGBM booster with early stopping.

    X_train/y_train are the training data. X_test/y_test form the validation
    set that early stopping monitors (patience CONFIG["EARLY_STOPPING_ROUNDS"],
    at most CONFIG["NUM_BOOST_ROUND"] rounds).

    NOTE(review): the data passed as X_test/y_test influences the chosen
    iteration, so metrics later computed on that same data are optimistically
    biased. Pass a validation split rather than the final test set.

    Returns:
        The trained lgb.Booster.
    """
    train_data = lgb.Dataset(X_train, y_train)
    valid_data = lgb.Dataset(X_test, y_test, reference=train_data)

    return lgb.train(
        get_lgb_params(),
        train_set=train_data,
        valid_sets=[valid_data],
        num_boost_round=CONFIG["NUM_BOOST_ROUND"],
        callbacks=[
            lgb.early_stopping(CONFIG["EARLY_STOPPING_ROUNDS"]),
            lgb.log_evaluation(0),  # period 0: no per-iteration evaluation log
        ],
    )


def generate_predictions(model: lgb.Booster, X_test: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
    """Score X_test with the booster's best iteration and threshold the scores.

    A score below CONFIG["THRESHOLD"] becomes label 0; anything else becomes 1.

    Returns:
        (scores, labels)
    """
    scores = model.predict(X_test, num_iteration=model.best_iteration)
    cutoff = CONFIG["THRESHOLD"]
    labels = np.where(scores < cutoff, 0, 1)
    return scores, labels


def _positive_class_shap(shap_values):
    """Reduce TreeExplainer output for a binary model to the positive class.

    shap versions differ in what they return for binary models: a single
    (n_samples, n_features) array, a [negative, positive] list, or a
    (n_samples, n_features, 2) array. The last two are reduced to the positive
    class; anything else is returned unchanged.
    """
    if isinstance(shap_values, list) and len(shap_values) == 2:
        return shap_values[1]
    as_array = np.asarray(shap_values)
    if as_array.ndim == 3 and as_array.shape[-1] == 2:
        return as_array[..., 1]
    return shap_values


def _save_show_close(path: str, fig=None) -> None:
    """Lay out, save (300 dpi), display and close `fig` (default: the current figure)."""
    fig = plt.gcf() if fig is None else fig
    fig.tight_layout()
    fig.savefig(path, dpi=300, bbox_inches='tight')
    plt.show()  # displays inline in Colab/Jupyter
    plt.close(fig)


def generate_shap_plots(model: lgb.Booster, X_train: pd.DataFrame, industry: str) -> None:
    """Create SHAP bar, summary and dependence plots for one industry's model.

    Up to 500 rows of X_train (sampled with CONFIG["RANDOM_STATE"]) are explained
    with shap.TreeExplainer. The plots are saved as "<industry>_bar.png",
    "<industry>_violin.png" and "<industry>_dependence.png" (300 dpi) and shown
    inline. The dependence plot uses the feature with the largest mean |SHAP|.

    Failures are logged as warnings and never raised, so one bad plot does not
    abort the per-industry loop. Every figure opened during the call is closed
    before returning, including when a plot fails part-way.
    """
    if X_train.empty:
        logger.warning(f"SHAP可視化をスキップしました ({industry}): 学習データが空です")
        return

    figures_before = set(plt.get_fignums())
    try:
        explainer = shap.TreeExplainer(model)

        sample_size = min(500, len(X_train))
        X_sample = X_train.sample(n=sample_size, random_state=CONFIG["RANDOM_STATE"])
        shap_vals = _positive_class_shap(explainer.shap_values(X_sample))

        # Bar plot (mean |SHAP| per feature)
        plt.figure(figsize=(8, 6))
        shap.summary_plot(shap_vals, X_sample, plot_type="bar",
                          feature_names=X_sample.columns, show=False)
        plt.title(f"{industry} - 特徴量重要度")
        _save_show_close(f"{industry}_bar.png")

        # Summary (dot/beeswarm) plot
        plt.figure(figsize=(8, 6))
        shap.summary_plot(shap_vals, X_sample, feature_names=X_sample.columns, show=False)
        plt.title(f"{industry} - SHAP要約")
        _save_show_close(f"{industry}_violin.png")

        # Dependence plot for the most important feature
        try:
            top = int(np.abs(shap_vals).mean(axis=0).argmax())

            fig, ax = plt.subplots(figsize=(8, 6))
            # Draw into our own axes: without `ax`, dependence_plot opens a new
            # figure itself, leaving a pre-created figure empty and open.
            shap.dependence_plot(top, shap_vals, X_sample,
                                 feature_names=X_sample.columns, ax=ax, show=False)
            ax.set_title(f"{industry} - {X_sample.columns[top]}の依存関係")
            _save_show_close(f"{industry}_dependence.png", fig)

        except Exception as e:
            logger.warning(f"依存関係図の作成でエラーが発生しました ({industry}): {e}")

    except Exception as e:
        logger.warning(f"SHAP可視化でエラーが発生しました ({industry}): {e}")

    finally:
        # Close anything this call opened and did not close (e.g. after an error),
        # so figures do not pile up across the per-industry loop.
        for number in set(plt.get_fignums()) - figures_before:
            plt.close(number)


def calculate_metrics(y_test: pd.Series, y_pred: np.ndarray, y_pred_prob: np.ndarray) -> Dict[str, float]:
    """Compute accuracy, precision, recall, F1 and ROC AUC for one industry.

    Precision, recall and F1 use zero_division=0, so an undefined ratio
    (e.g. no positive predictions) is reported as 0 without a warning.

    ROC AUC is undefined when y_test contains only one class. In that case
    sklearn's roc_curve does not raise: it warns and returns NaN rates, which
    would make the AUC NaN. That case is therefore checked explicitly, and it,
    a ValueError from the ROC computation, or a non-finite AUC all fall back to 0.5.

    Returns:
        dict with keys "accuracy", "precision", "recall", "f1", "auc".
    """
    metrics = {
        "accuracy": accuracy_score(y_test, y_pred),
        "precision": precision_score(y_test, y_pred, zero_division=0),
        "recall": recall_score(y_test, y_pred, zero_division=0),
        "f1": f1_score(y_test, y_pred, zero_division=0),
    }

    auc_score = 0.5  # fallback when AUC is undefined
    if np.unique(np.asarray(y_test)).size < 2:
        logger.warning("テストデータが単一クラスのためAUCを計算できません。デフォルト値0.5を使用します。")
    else:
        try:
            fpr, tpr, _ = roc_curve(y_test, y_pred_prob)
            computed = float(sklearn_auc(fpr, tpr))  # renamed import avoids clashing with local names
        except ValueError:
            logger.warning("AUC計算でエラーが発生しました。デフォルト値0.5を使用します。")
        else:
            if np.isfinite(computed):
                auc_score = computed
            else:
                logger.warning("AUC計算でエラーが発生しました。デフォルト値0.5を使用します。")

    metrics["auc"] = auc_score
    return metrics


def save_predictions(y_test: pd.Series, y_pred: np.ndarray, industry: str) -> None:
    """Write true and predicted labels side by side to "<industry>.csv".

    The CSV keeps y_test's index and is encoded with CONFIG["ENCODING"].
    """
    output_path = f"{industry}.csv"
    comparison = pd.DataFrame(dict(target=y_test, target_pred=y_pred))
    comparison.to_csv(output_path, encoding=CONFIG["ENCODING"])


def _split_for_early_stopping(X_train: pd.DataFrame, y_train: pd.Series) -> list:
    """Hold out part of the training data as the early-stopping validation set.

    The held-out fraction is CONFIG.get("VALID_SIZE", 0.2). The split is
    stratified on the target when sklearn allows it. Otherwise, for example
    when a class has fewer than two rows, it falls back to an unstratified split.

    Returns:
        [X_fit, X_valid, y_fit, y_valid], as returned by train_test_split.
    """
    options = dict(
        test_size=CONFIG.get("VALID_SIZE", 0.2),
        random_state=CONFIG["RANDOM_STATE"],
    )
    try:
        return train_test_split(X_train, y_train, stratify=y_train, **options)
    except ValueError:
        # Stratification needs at least two rows per class.
        return train_test_split(X_train, y_train, **options)


def classify_industry(train_set: pd.DataFrame, test_set: pd.DataFrame, industry: str) -> Dict[str, float]:
    """Train, explain and evaluate the classifier for a single industry.

    The test split is used only for final predictions and metrics. Early
    stopping monitors a validation subset carved out of the training split, so
    the reported metrics are not biased by choosing the iteration on test data.

    Returns:
        The metric dict from calculate_metrics, or all-zero metrics if any step
        fails (the failure is logged together with its traceback).
    """
    logger.info(f"分析開始: {industry}")

    try:
        # Data preparation
        X_train, y_train, X_test, y_test = prepare_data(train_set, test_set, industry)

        # Training: early stopping selects the best iteration on its validation
        # set, so that set must not be the test set used for evaluation below.
        X_fit, X_valid, y_fit, y_valid = _split_for_early_stopping(X_train, y_train)
        model = train_model(X_fit, y_fit, X_valid, y_valid)

        # Predictions on the untouched test set
        y_pred_prob, y_pred = generate_predictions(model, X_test)

        # SHAP visualisation
        generate_shap_plots(model, X_train, industry)

        # Persist per-row predictions
        save_predictions(y_test, y_pred, industry)

        # Evaluation metrics
        result = calculate_metrics(y_test, y_pred, y_pred_prob)

        logger.info(f"分析完了: {industry} - AUC: {result['auc']:.3f}")
        return result

    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception(f"産業 {industry} の分析でエラーが発生しました: {e}")
        return {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0, "auc": 0.0}


def load_or_create_result_df() -> pd.DataFrame:
    """Load the existing results table, or build an empty one if the file is missing.

    An existing CONFIG["RESULT_FILE"] is read with its first column as the index.
    If the file does not exist, a new frame is created with one row per entry in
    CONFIG["INDUSTRIES"] and the five metric columns, all empty.
    """
    metric_columns = ["Accuracy", "Precision", "Recall", "F値", "AUC"]
    try:
        return pd.read_csv(CONFIG["RESULT_FILE"], encoding=CONFIG["ENCODING"], index_col=0)
    except FileNotFoundError:
        return pd.DataFrame(index=CONFIG["INDUSTRIES"], columns=metric_columns)


def run_all_industries(train_set: pd.DataFrame, test_set: pd.DataFrame) -> pd.DataFrame:
    """Classify every industry in CONFIG["INDUSTRIES"] and collect the metrics.

    Starts from load_or_create_result_df(), fills one row of metrics per
    industry while printing progress, then saves the table to
    CONFIG["RESULT_FILE"] in CONFIG["ENCODING"].

    Returns:
        The metrics DataFrame.
    """
    result_df = load_or_create_result_df()

    industries = CONFIG["INDUSTRIES"]
    total = len(industries)
    logger.info(f"=== 分析開始: 合計 {total} 産業 ===")

    metric_keys = ("accuracy", "precision", "recall", "f1", "auc")
    separator = "=" * 60

    for position, industry in enumerate(industries, start=1):
        percent = position / total * 100
        print(f"\n{separator}")
        print(f"進捗: {position}/{total} ({percent:.1f}%)")
        print(f"処理中: {industry}")
        print(separator)

        metrics = classify_industry(train_set, test_set, industry)
        # Column order matches Accuracy, Precision, Recall, F値, AUC.
        result_df.loc[industry] = [metrics[key] for key in metric_keys]

        print(f"完了: {industry} - AUC: {metrics['auc']:.3f}")

    result_df.to_csv(CONFIG["RESULT_FILE"], encoding=CONFIG["ENCODING"])
    logger.info("=== 全ての分析が完了しました ===")

    return result_df


def main():
    """Run the full pipeline: discover industries, load and split data, analyse every industry.

    Errors are logged together with their traceback instead of propagating.
    """
    try:
        # Take the industry (target) column names from the input CSV's header row.
        CONFIG["INDUSTRIES"] = update_industries_from_file(CONFIG)
        logger.info(f"取得した産業名: {CONFIG['INDUSTRIES']}")

        if not CONFIG["INDUSTRIES"]:
            # update_industries_from_file returns [] on failure or when no industry
            # columns remain, so there is nothing to analyse.
            logger.error("産業名を取得できなかったため、分析を中止します")
            return

        # Load the data and split it into train/test
        train_set, test_set = load_and_split_data()

        # Analyse every industry
        results = run_all_industries(train_set, test_set)

        print("分析結果:")
        print(results)

    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception(f"実行中にエラーが発生しました: {e}")



# Entry point: run the pipeline when executed directly
# (inside a Jupyter/Colab kernel, __name__ is also '__main__').
if __name__ == '__main__':
    main()
