In [None]:
import torch
import torch.nn as nn
import numpy as np
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score, classification_report

# 1. SETUP FEATURE EXTRACTOR
# Turn the existing network into a feature extractor: with `fc` replaced by
# nn.Identity, the forward pass returns whatever used to be fed into that
# final layer instead of class scores.
# NOTE(review): the "512" in the message below presumes the layer feeding
# `fc` is 512 wide (true for ResNet-18/34) — confirm against the model
# defined earlier in the notebook.
model.fc = nn.Identity()

# nn.Module.eval() returns the module itself, so the device move is chained
# onto it; the order (eval first, then move) is unchanged.
model.eval().to(device)

print("Feature Extractor Ready. Model will output 512-dimensional vectors.")

# 2. HELPER FUNCTION TO EXTRACT FEATURES
def extract_features(loader, model, device):
    """Run every batch of ``loader`` through ``model`` and collect the outputs.

    Args:
        loader: Sized iterable yielding ``(images, labels)`` pairs of tensors.
            ``len(loader)`` is used for the progress message.
        model: Callable used as the feature extractor. It is not moved to
            ``device`` here, so it must already be placed there by the caller.
            If it is an ``nn.Module`` it is switched to eval mode for the
            duration of the call and its previous mode is restored afterwards.
        device: Device the image batches are moved to before the forward pass.

    Returns:
        Tuple ``(X, y)`` of NumPy arrays built by concatenating the per-batch
        model outputs and labels along axis 0. Model outputs with more than
        two dimensions are flattened to ``(batch, -1)`` so ``X`` is a 2-D
        feature matrix.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    features_list = []
    labels_list = []

    print(f"Extracting features from {len(loader)} batches...")

    # Dropout / BatchNorm behave differently in training mode (and BatchNorm
    # keeps updating its running statistics even under no_grad), so force eval
    # mode while extracting. Remember exactly which sub-modules were training
    # so the caller's configuration can be restored untouched.
    if isinstance(model, torch.nn.Module):
        training_modules = [m for m in model.modules() if m.training]
    else:
        training_modules = []
    if training_modules:
        model.eval()

    try:
        with torch.no_grad():
            for images, labels in loader:
                features = model(images.to(device))

                # Collapse any trailing dimensions (e.g. N x C x 1 x 1) so
                # each sample becomes a single feature vector.
                if features.dim() > 2:
                    features = features.flatten(start_dim=1)

                features_list.append(features.cpu().numpy())
                labels_list.append(labels.cpu().numpy())
    finally:
        # Set the flag directly instead of calling train(): train() recurses
        # into children and could flip sub-modules that were in eval mode.
        for module in training_modules:
            module.training = True

    if not features_list:
        raise ValueError("loader yielded no batches; cannot extract features")

    # Concatenate all batches into one large array
    X = np.concatenate(features_list, axis=0)
    y = np.concatenate(labels_list, axis=0)
    return X, y

# 3. EXTRACT DATA
# Each loader is iterated once in full and every batch goes through the
# network, so this step can take a while on large datasets.
print("--- Processing Training Data ---")
X_train, y_train = extract_features(
    loader=train_loader,
    model=model,
    device=device,
)

print("--- Processing Test Data ---")
X_test, y_test = extract_features(
    loader=test_loader,
    model=model,
    device=device,
)

# Quick sanity check on the shape of the training feature array.
print(f"Data Loaded! Training Shape: {X_train.shape}")

# 4. TRAIN XGBOOST
# Fit a gradient-boosted tree classifier on the extracted feature vectors
# (X_train) against their labels (y_train).
print("--- Training XGBoost Classifier ---")
xgb_model = XGBClassifier(
    n_estimators=100,             # Number of trees (boosting rounds)
    learning_rate=0.1,            # Step size (shrinkage per tree)
    max_depth=5,                  # Depth of trees
    objective='binary:logistic',  # For Real vs Fake
    # `use_label_encoder` is intentionally not passed: recent XGBoost releases
    # dropped the built-in label encoder and no longer accept the argument
    # (2.x only reports it as an unused parameter). The binary objective
    # expects labels that are already 0/1.
    eval_metric='logloss',
)

xgb_model.fit(X_train, y_train)
print("Training Complete!")

# 5. EVALUATE
# Predict on the held-out features and compare against the true labels.
y_pred = xgb_model.predict(X_test)
acc = accuracy_score(y_test, y_pred)

print(f"\nFinal XGBoost Accuracy: {acc*100:.2f}%")
print("\nClassification Report:")
# labels=[0, 1] pins the report rows to both classes. Without it,
# classification_report infers the classes from the data and raises
# ValueError when only one class shows up in y_test / y_pred, because the two
# target_names no longer match.
# NOTE(review): this assumes label 0 means "Real" and label 1 means "Fake".
# Folder-based datasets usually assign indices alphabetically, which would put
# "fake" at 0 — confirm against the dataset's class-to-index mapping.
print(
    classification_report(
        y_test,
        y_pred,
        labels=[0, 1],
        target_names=['Real', 'Fake'],
    )
)