# In [None]:
# adversarial_validation.py

# Step 1: Import Required Libraries
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns

# Steps 2-3: Read both datasets and tag every row with its origin.
# is_test is 0 for rows that came from the training file and 1 for rows
# that came from the test file.
train = pd.read_csv('path/to/train.csv').assign(is_test=0)  # Update with actual path
test = pd.read_csv('path/to/test.csv').assign(is_test=1)    # Update with actual path

# Step 4: Keep only the columns present in BOTH frames.
# The origin flag is excluded from the shared list and re-appended explicitly
# so that it always ends up as the last column of each frame.
common_cols = [col for col in train.columns if col != 'is_test' and col in test.columns]
selected_cols = common_cols + ['is_test']
train = train[selected_cols]
test = test[selected_cols]

# Step 5: Stack the two frames row-wise into one table with a fresh 0..n-1 index.
combined = pd.concat([train, test], axis=0, ignore_index=True)

# Step 6: Discard every column that contains at least one missing value
# (imputing with fillna would be an alternative to dropping).
combined = combined.dropna(axis='columns', how='any')

# Step 7: One-hot encode the categorical columns; the integer is_test flag
# is numeric and therefore passes through unchanged.
combined = pd.get_dummies(combined)

# Step 8: The origin flag is the label; everything else is a feature.
y = combined['is_test']
X = combined.drop(columns=['is_test'])

# Step 9: Split for Training and Validation
# Hold out 20% of the combined rows for scoring. The split is stratified on
# the origin flag so both partitions keep the same train/test row ratio;
# without it, an imbalanced pair of datasets can leave the validation set
# with very few (or zero) rows of one origin, which makes the AUC noisy or
# undefined.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Step 10: Train the Adversarial Classifier
# The forest learns to tell training rows (0) from test rows (1); the fixed
# random_state keeps the fitted model reproducible between runs.
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Step 11: Score the held-out rows and interpret the result.
# Column 1 of predict_proba is the predicted probability of the positive
# class, i.e. of a row having is_test == 1.
y_pred = clf.predict_proba(X_val)[:, 1]
auc_score = roc_auc_score(y_val, y_pred)
print(f"Adversarial Validation AUC Score: {auc_score:.4f}")

# Verdict bands: above 0.6 -> significant drift, below 0.55 -> none,
# anything else (0.55 to 0.6 inclusive) -> slight drift.
verdict = ">> Slight drift may be present."
if auc_score > 0.6:
    verdict = ">> Significant data drift detected between train and test datasets."
elif auc_score < 0.55:
    verdict = ">> No significant data drift detected."
print(verdict)

# Step 12: Plot Feature Importances
# Pair each importance with its feature name and keep the 20 largest.
importances = pd.Series(clf.feature_importances_, index=X.columns)
top_features = importances.sort_values(ascending=False).head(20)

# A horizontal bar chart draws the first row at the BOTTOM of the axis, so
# the selection is reversed before plotting to put the most important
# feature at the top of the chart.
top_features.iloc[::-1].plot(kind='barh', figsize=(10, 6))
plt.title('Top 20 Important Features Indicating Drift')
plt.xlabel('Importance')
plt.tight_layout()
plt.show()