# In [None]:
import pandas as pd
import jieba
import matplotlib.pyplot as plt
import seaborn as sns
import os
import warnings
from joblib import dump
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn import svm
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder

# Suppress every warning whose category is UserWarning (or a subclass of it).
warnings.filterwarnings("ignore", category=UserWarning)

# Input CSV locations and the directory that receives the fitted models.
# NOTE(review): TRAIN_PATH has no directory component while TEST_PATH points
# under ./data — confirm this asymmetry is intended.
TRAIN_PATH = r"totaltr.csv"
TEST_PATH = r"./data/totalts.csv"  # please replace with your actual path
MODEL_SAVE_PATH = r"./output/models"  # please replace with your actual path

# Runs at import time: create the model directory, doing nothing if it exists.
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

def chinese_text_preprocessor(text):
    """Segment *text* with jieba and join the surviving tokens with spaces.

    Args:
        text: Raw text. Anything that is not a ``str`` is treated as empty.

    Returns:
        A space-separated string of the tokens produced by ``jieba.cut`` that
        are at least two characters long, or ``""`` for non-string input.
    """
    if isinstance(text, str):
        # Tokens shorter than two characters are dropped.
        long_tokens = filter(lambda token: len(token) > 1, jieba.cut(text))
        return ' '.join(long_tokens)
    return ""

def load_local_data(train_path, test_path):
    """Read the training and test CSV files and print a short summary of each.

    Args:
        train_path: Path of the training CSV, passed to ``pd.read_csv``.
        test_path: Path of the test CSV, passed to ``pd.read_csv``.

    Returns:
        A ``(train_df, test_df)`` tuple of DataFrames.
    """
    # Both files are read (training first) before anything is printed.
    train_df, test_df = (pd.read_csv(path) for path in (train_path, test_path))

    summary = (
        "✅ Training set loaded successfully!",
        f"Training set shape: {train_df.shape}",
        f"Training columns: {train_df.columns}",
        "\nTraining sample:",
        train_df.head(),
        "\n✅ Test set loaded successfully!",
        f"Test set shape: {test_df.shape}",
        "\nTest sample:",
        test_df.head(),
    )
    for entry in summary:
        print(entry)

    return train_df, test_df

def preprocess_chinese_data(df, text_col='input', label_col='output', label_encoder=None):
    """Clean a raw DataFrame and return its segmented text and labels.

    Prints the per-column missing-value counts, drops rows whose text or label
    is missing, segments the text column with ``chinese_text_preprocessor``
    and integer-encodes the labels when they are stored as strings/objects.

    Args:
        df: DataFrame holding at least ``text_col`` and ``label_col``. It is
            not modified.
        text_col: Name of the column containing the raw text.
        label_col: Name of the column containing the class labels.
        label_encoder: Optional ``LabelEncoder``. An already fitted encoder
            (one that has ``classes_``) is applied with ``transform`` so that
            several calls can share one label-to-integer mapping; an unfitted
            one is fitted here. When omitted, a fresh encoder is fitted on
            this frame only, so codes from separate calls line up only if
            both frames contain the same set of labels.

    Returns:
        A ``(texts, labels)`` pair of Series indexed like the retained rows.

    Raises:
        KeyError: If ``text_col`` or ``label_col`` is not a column of ``df``.
    """
    print("\nMissing value statistics:")
    print(df.isnull().sum())

    # Only the two columns that are actually used decide whether a row is
    # dropped; a NaN in an unrelated column must not discard usable samples.
    rows = df.dropna(subset=[text_col, label_col])

    # Build new Series instead of assigning into the dropna() result, which
    # avoids chained-assignment warnings and leaves the caller's frame alone.
    texts = rows[text_col].apply(chinese_text_preprocessor)
    labels = rows[label_col]

    # Covers both the legacy ``object`` dtype and pandas' dedicated string
    # dtypes, which do not compare equal to 'object'.
    needs_encoding = (
        pd.api.types.is_object_dtype(labels)
        or pd.api.types.is_string_dtype(labels)
    )
    if needs_encoding:
        encoder = LabelEncoder() if label_encoder is None else label_encoder
        if hasattr(encoder, 'classes_'):
            codes = encoder.transform(labels)
        else:
            codes = encoder.fit_transform(labels)
        labels = pd.Series(codes, index=labels.index, name=label_col)
        print("\nLabel encoding applied.")

    return texts, labels

def train_and_evaluate_chinese(X_train, X_test, y_train, y_test):
    """Fit four TF-IDF + classifier pipelines and score each on the test data.

    Args:
        X_train: Training texts (whitespace-separated tokens).
        X_test: Test texts in the same form.
        y_train: Training labels.
        y_test: Test labels.

    Returns:
        A dict keyed by model display name. Each value is a dict with the
        fitted pipeline under ``'model'``, plus ``'accuracy'``, ``'report'``
        (the text of ``classification_report``) and ``'confusion_matrix'``.
    """
    tfidf_params = {
        'max_features': 10000,
        'ngram_range': (1, 2),
        # Accept tokens of one or more word characters.
        'token_pattern': r'(?u)\b\w+\b',
    }

    models = {
        'Logistic Regression': LogisticRegression(max_iter=1000),
        'Naive Bayes': MultinomialNB(),
        'SVM': svm.LinearSVC(max_iter=1000, dual=False),
        'Decision Tree': DecisionTreeClassifier(max_depth=5)
    }

    results = {}

    for name, model in models.items():
        # Each pipeline gets its own vectorizer. Pipeline fits the step
        # objects it is given rather than copies, so a single shared instance
        # would be refitted by every model and aliased by every stored
        # pipeline.
        pipeline = Pipeline([
            ('tfidf', TfidfVectorizer(**tfidf_params)),
            ('clf', model)
        ])

        print(f"\nTraining {name}...")
        pipeline.fit(X_train, y_train)
        y_pred = pipeline.predict(X_test)

        accuracy = accuracy_score(y_test, y_pred)
        report = classification_report(y_test, y_pred)
        cm = confusion_matrix(y_test, y_pred)

        results[name] = {
            'model': pipeline,
            'accuracy': accuracy,
            'report': report,
            'confusion_matrix': cm
        }

        print(f"{name} Accuracy: {accuracy:.4f}")
        print(f"Classification Report:\n{report}")

    return results

def visualize_results(results):
    """Show an accuracy bar chart and one confusion-matrix heatmap per model.

    Args:
        results: Mapping of model name to a dict that holds at least
            ``'accuracy'`` and ``'confusion_matrix'``.
    """
    # Materialise names and scores as lists so the same ordering is used for
    # the bars and for their value labels.
    names = list(results)
    scores = [results[name]['accuracy'] for name in names]

    plt.figure(figsize=(10, 6))
    plt.bar(names, scores)
    plt.title('Model Accuracy Comparison')
    plt.ylabel('Accuracy')
    plt.ylim(0, 1)
    for position, score in enumerate(scores):
        plt.text(position, score + 0.02, f"{score:.4f}", ha='center')
    plt.show()

    # Two heatmaps per row. At least two rows keeps the layout for up to four
    # models exactly as before (2x2 grid, 15x10 figure); additional rows are
    # added for larger result sets so the subplot index never exceeds the grid.
    n_cols = 2
    n_rows = max(2, -(-len(names) // n_cols))
    plt.figure(figsize=(15, 5 * n_rows))
    for slot, name in enumerate(names, 1):
        plt.subplot(n_rows, n_cols, slot)
        sns.heatmap(results[name]['confusion_matrix'], annot=True, fmt='d', cmap='Blues')
        plt.title(f'{name} Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
    plt.tight_layout()
    plt.show()

def save_all_models(results, save_dir=MODEL_SAVE_PATH):
    """Write every fitted model in *results* to its own sub-directory.

    Args:
        results: Mapping of model name to a dict with the fitted estimator
            under ``'model'``.
        save_dir: Root directory; each model is written to
            ``<save_dir>/model_<name>/model.joblib``.
    """
    for label in results:
        estimator = results[label]['model']
        target_dir = os.path.join(save_dir, f'model_{label}')
        os.makedirs(target_dir, exist_ok=True)
        target_file = os.path.join(target_dir, 'model.joblib')
        dump(estimator, target_file)
        print(f"✅ Model {label} saved to: {target_file}")

def main():
    """Run the whole workflow: load, preprocess, train, plot and save.

    Each stage is guarded so that a failure is printed instead of raised.
    Plotting and saving are best-effort; every other stage aborts the run.

    Returns:
        tuple: ``(best_model, model_path)``, where ``best_model`` is the
        fitted pipeline with the highest test accuracy and ``model_path`` is
        the file it was saved to (``None`` if saving failed). Returns
        ``(None, None)`` when loading, preprocessing, training or model
        selection fails.
    """
    # Assigning rcParams does not check that a font is installed, so a
    # try/except fallback never triggers. Listing both candidates lets
    # matplotlib pick the first one that is available.
    plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    try:
        train_df, test_df = load_local_data(TRAIN_PATH, TEST_PATH)
    except Exception as e:
        print(e)
        return None, None

    # NOTE(review): each preprocess_chinese_data call below fits its own
    # LabelEncoder when labels are strings, so train/test codes only agree if
    # both files contain the same set of labels — confirm for your data.
    try:
        X_train, y_train = preprocess_chinese_data(train_df)
        print("\nTraining data preprocessing completed.")
    except Exception as e:
        print(f"\nTraining data preprocessing failed: {str(e)}")
        return None, None

    try:
        X_test, y_test = preprocess_chinese_data(test_df)
        print("\nTest data preprocessing completed.")
    except Exception as e:
        print(f"\nTest data preprocessing failed: {str(e)}")
        return None, None

    print("\nTraining label distribution:")
    print(pd.Series(y_train).value_counts(normalize=True))

    print("\nTest label distribution:")
    print(pd.Series(y_test).value_counts(normalize=True))

    try:
        results = train_and_evaluate_chinese(X_train, X_test, y_train, y_test)
    except Exception as e:
        print(e)
        return None, None

    # Plotting is optional: a display/backend problem must not lose the models.
    try:
        visualize_results(results)
    except Exception as e:
        print(e)

    try:
        best_model_name = max(results, key=lambda k: results[k]['accuracy'])
        best_model = results[best_model_name]['model']
        best_accuracy = results[best_model_name]['accuracy']
        print(f"\nBest Model: {best_model_name} | Accuracy: {best_accuracy:.4f}")
    except Exception as e:
        print(e)
        return None, None

    model_path = None
    try:
        save_all_models(results)
    except Exception as e:
        print(e)
    else:
        # Same layout save_all_models uses with its default save directory.
        model_path = os.path.join(
            MODEL_SAVE_PATH, f'model_{best_model_name}', 'model.joblib'
        )

    # The caller unpacks two values, so the success path must return a pair
    # too (falling off the end would return None and break the unpacking).
    return best_model, model_path

# Script entry point: run main() and unpack its result into two names.
if __name__ == "__main__":
    best_model, model_path = main()
