# In [None]:
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import LabelEncoder
import warnings
warnings.filterwarnings("ignore")

def load_data(train_path="train.csv", test_path="test.csv", target_col="target"):
    """Load the train and test sets and label them for adversarial validation.

    Args:
        train_path: Path of the training CSV file.
        test_path: Path of the test CSV file.
        target_col: Name of the supervised target column to discard, or
            ``None`` to keep every column.

    Returns:
        A ``(train, test)`` tuple of DataFrames. Each frame has an extra
        ``is_test`` column, set to 0 for every train row and 1 for every
        test row.
    """
    train = pd.read_csv(train_path)
    test = pd.read_csv(test_path)

    if target_col is not None:
        # The target must be removed from BOTH frames. If it survived in only
        # one of them, the concatenated frame would hold NaN for the other
        # origin and the classifier could separate the sets trivially.
        train = train.drop(columns=[target_col], errors="ignore")
        test = test.drop(columns=[target_col], errors="ignore")

    # Origin label that the adversarial classifier will try to predict.
    train["is_test"] = 0
    test["is_test"] = 1

    return train, test

def preprocess(train, test):
    """Stack train and test and turn the result into an all-numeric frame.

    Object, string and categorical columns are integer-encoded: missing
    values become the literal category ``"NA"``, every value is cast to
    ``str`` and the codes follow the sorted order of the distinct strings.
    Remaining missing values in the other columns are replaced with -999.

    Args:
        train: Training DataFrame.
        test: Test DataFrame.

    Returns:
        A new DataFrame holding the train rows followed by the test rows,
        with a fresh 0..n-1 index. The input frames are not modified.
    """
    combined = pd.concat([train, test], axis=0, ignore_index=True)

    for col in combined.columns:
        dtype = combined[col].dtype
        is_categorical = (
            pd.api.types.is_object_dtype(dtype)
            or pd.api.types.is_string_dtype(dtype)
            or isinstance(dtype, pd.CategoricalDtype)
        )
        if not is_categorical:
            continue

        # Go through object dtype so "NA" can be filled in even when the
        # column is categorical, then cast to str so that columns mixing
        # value types (e.g. booleans next to the "NA" filler) stay sortable.
        values = combined[col].astype(object).fillna("NA").astype(str)
        # sort=True assigns codes by sorted distinct value, which keeps the
        # encoding independent of row order.
        codes, _ = pd.factorize(values, sort=True)
        combined[col] = codes

    # Sentinel for missing values in the remaining (numeric) columns.
    return combined.fillna(-999)

def run_adversarial_validation(combined, *, label_col="is_test", test_size=0.3,
                               n_estimators=100, random_state=42):
    """Train a classifier to tell train rows from test rows and score it.

    Args:
        combined: Numeric DataFrame containing the features and the origin
            label column.
        label_col: Name of the origin label column.
        test_size: Fraction of rows held out for scoring the classifier.
        n_estimators: Number of trees in the random forest.
        random_state: Seed used for both the split and the forest.

    Returns:
        ROC AUC of the classifier on the held-out rows.

    Raises:
        ValueError: If the label column holds fewer than two distinct values.
    """
    y = combined[label_col]
    if y.nunique() < 2:
        # Without both origins there is nothing to discriminate; fail here
        # with a clear message instead of deep inside the scoring code.
        raise ValueError(
            f"Column {label_col!r} must contain both train and test rows "
            "to run adversarial validation."
        )
    X = combined.drop(columns=[label_col])

    # Stratify so the held-out part keeps the same train/test proportion.
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=y
    )

    clf = RandomForestClassifier(
        n_estimators=n_estimators, random_state=random_state, n_jobs=-1
    )
    clf.fit(X_train, y_train)

    # Column 1 holds the predicted probability of the "test" class.
    val_preds = clf.predict_proba(X_val)[:, 1]
    auc = roc_auc_score(y_val, val_preds)
    print(f"\nAdversarial Validation AUC: {auc:.4f}")

    return auc

def main():
    """Run the adversarial-validation pipeline and write a short report.

    Loads the data, preprocesses it, scores the train-vs-test classifier and
    writes the AUC plus a one-line interpretation to
    ``adversarial_validation_result.txt``.
    """
    print("Loading data...")
    train_df, test_df = load_data()

    print("Preprocessing...")
    features = preprocess(train_df, test_df)

    print("Running adversarial validation...")
    score = run_adversarial_validation(features)

    # Map the score to a verdict first so the report is written in one go.
    if score > 0.75:
        verdict = "=> Significant data drift detected between train and test datasets.\n"
    elif score < 0.55:
        verdict = "=> Minimal or no data drift detected.\n"
    else:
        verdict = "=> Moderate data drift detected.\n"

    report_lines = [f"Adversarial Validation AUC: {score:.4f}\n", verdict]
    with open("adversarial_validation_result.txt", "w") as report:
        report.writelines(report_lines)

# Run the pipeline only when the file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()