# Wine_Claude_Enhanced: Advanced Gradient Selection Methods

This notebook tests the enhanced gradient selection methods:
- **ContrastiveSelector**: Uses BOTH correct AND incorrect partial data
- **EnsembleSelector**: Combines multiple anomaly detection methods
- **CentroidDistanceSelector**: Distance-based selection
- **AdaptiveSelector**: Adjusts strictness over epochs

Also includes **diagnostic tools** to understand gradient behavior and plan improvements.

In [None]:
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import sys
sys.path.insert(0, '../')
sys.path.insert(0, '../GGH')

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.metrics import r2_score, silhouette_score
from scipy.stats import ks_2samp

from GGH.data_ops import DataOperator
from GGH.selection_algorithms import AlgoModulators, compute_individual_grads_nothread
from GGH.models import initialize_model, load_model
from GGH.train_val_loop import TrainValidationManager
from GGH.inspector import Inspector, get_gradarrays_n_labels, get_label

# Import enhanced modules
from GGH.gradient_diagnostics import (
    GradientDiagnostics,
    EnrichedVectorBuilder,
    compute_gradient_statistics,
    visualize_gradient_space
)
from GGH.enhanced_selection import (
    ContrastiveSelector,
    EnsembleSelector,
    CentroidDistanceSelector,
    AdaptiveSelector
)

import warnings
warnings.filterwarnings('ignore')

def set_to_deterministic(rand_state):
    """Seed every RNG this notebook relies on and request deterministic torch ops.

    Args:
        rand_state: Seed applied to Python's ``random`` module, NumPy's global
            generator and torch's default generator.
    """
    import random

    # The same seed goes to all three generators.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(rand_state)

    # Restrict torch to one thread and ask it for deterministic algorithms.
    torch.set_num_threads(1)
    torch.use_deterministic_algorithms(True)


print("Imports successful!")

In [None]:
# Data configuration
# All paths are relative, so the notebook must be run from its own directory.
data_path = '../data/wine/red_wine.csv'
results_path = "../saved_results/Red Wine Enhanced"
inpt_vars = ['volatile acidity', 'total sulfur dioxide', 'citric acid']  # input columns handed to DataOperator
target_vars = ['quality']  # column(s) the model predicts
miss_vars = ['alcohol']  # presumably the column whose value is only partially observed -- confirm in DataOperator
# One inner list per entry of miss_vars; presumably the candidate values
# hypothesised for that variable -- confirm against DataOperator.
hypothesis = [[9.35, 10, 11.5, 15]]

# Model parameters
hidden_size = 32
batch_size = 100 * len(hypothesis[0])  # 100 x 4 hypothesis values = 400
partial_perc = 0.025  # Start with 2.5%

# Create directories
# NOTE(review): this import sits mid-notebook; consider moving it to the imports cell.
import os
os.makedirs(results_path, exist_ok=True)

# Initialize
# Both helpers receive results_path; GradientDiagnostics takes it as save_path.
INSPECT = Inspector(results_path, hidden_size)
diagnostics = GradientDiagnostics(save_path=results_path)

print(f"Results will be saved to: {results_path}")

## Part 1: Diagnostic Analysis of Gradient Behavior

Before testing new methods, let's understand how gradients from correct vs incorrect hypotheses differ.

In [None]:
# Train a single model and collect gradients for analysis
# Seed everything first so the DataOperator built below is reproducible.
rand_state = 0
set_to_deterministic(rand_state)

# DO is the shared data handle used by every cell in Part 1 and Part 2.
DO = DataOperator(data_path, inpt_vars, target_vars, miss_vars, hypothesis,
                  partial_perc, rand_state, device='cpu')
DO.problem_type = 'regression'

# NOTE(review): this branch only prints a warning; the following cells still run
# with this DO even when coverage is reported as insufficient.
if DO.lack_partial_coverage:
    print("WARNING: Insufficient partial coverage. Try different random state.")
else:
    print(f"Partial data rows: {len(DO.partial_rows_id)}")
    print(f"Total hypothesis combinations: {DO.num_hyp_comb}")
    print(f"Training data shape: {DO.df_train_hypothesis.shape}")

In [None]:
# Train model with standard settings
# NOTE: num_epochs is reassigned later (Part 3), so re-run this cell before
# re-running the diagnostic cells if the notebook is executed out of order.
num_epochs = 30

# Selection hyper-parameters for the standard algorithm.
AM = AlgoModulators(DO, lr=0.002, nu=0.1, normalize_grads_contx=False,
                   use_context=True, freqperc_cutoff=0.25)
# The same mode string ('use hypothesis') is passed to the dataloader and to the manager.
dataloader = DO.prep_dataloader('use hypothesis', batch_size)
model = initialize_model(DO, dataloader, hidden_size, rand_state, dropout=0.05)

TVM = TrainValidationManager('use hypothesis', num_epochs, dataloader, batch_size,
                             rand_state, results_path, final_analysis=False)
TVM.train_model(DO, AM, model, final_analysis=False)

# Reports the smallest entry of TVM.valid_errors_epoch
# (presumably one validation error per epoch -- confirm in TrainValidationManager).
print(f"Training complete. Best validation loss: {min(TVM.valid_errors_epoch):.6f}")

In [None]:
# Extract gradients from the last epoch for analysis
hyp_class = 2  # Analyze one hypothesis class

# Meaning of each integer found in the 'label' column:
#   0: Incorrect hypothesis, no partial info
#   1: Correct hypothesis, no partial info
#   2: Incorrect hypothesis, has partial info
#   3: Correct hypothesis, has partial info (ground truth)
label_names = {
    0: 'Incorrect (no partial)',
    1: 'Correct (no partial)',
    2: 'Incorrect (partial)',
    3: 'Correct (partial)',
}

# Options forwarded to the extraction helper.
extraction_kwargs = dict(
    layer=-2,
    remov_avg=True,                 # Remove average gradient signal
    include_context=True,
    normalize_grads_context=False,
    num_batches=3,
    epoch=-1,                       # Last epoch
    use_case="hypothesis",
)

# Get gradient arrays with labels
grad_arrays, df_with_labels = get_gradarrays_n_labels(DO, hyp_class, **extraction_kwargs)
labels = df_with_labels['label'].values

print(f"\nGradient array shape: {grad_arrays.shape}")
print(f"Label distribution:")
for label_val in np.unique(labels):
    n_with_label = np.sum(labels == label_val)
    print(f"  {label_names.get(label_val, label_val)}: {n_with_label}")

In [None]:
# Compute separability metrics
# Boolean row masks over grad_arrays, derived from the integer labels
# (1/3 = correct hypothesis, 0/2 = incorrect hypothesis, 3 = correct with partial info).
correct_mask = (labels == 1) | (labels == 3)
incorrect_mask = (labels == 0) | (labels == 2)
partial_correct_mask = labels == 3

correct_grads, incorrect_grads, partial_grads = (
    grad_arrays[row_mask]
    for row_mask in (correct_mask, incorrect_mask, partial_correct_mask)
)

banner = "=" * 60
print("\n" + banner)
print("GRADIENT SEPARABILITY ANALYSIS")
print(banner)

if len(correct_grads) == 0 or len(incorrect_grads) == 0:
    # Both groups are required to compare them.
    print("Insufficient data for metrics computation")
else:
    metrics = diagnostics.compute_separability_metrics(
        correct_grads, incorrect_grads, partial_grads
    )

    # (display title, key in `metrics`, reading hint); missing keys print as 0.
    key_metric_rows = (
        ("Silhouette Score", "silhouette_score", "range -1 to 1, higher=better"),
        ("Centroid Distance", "centroid_distance", "higher=more separated"),
        ("Wasserstein Distance", "wasserstein_distance", "higher=more different"),
        ("KS Statistic", "ks_statistic", "higher=more different distributions"),
    )
    print("\nKey Metrics:")
    for title, metric_key, hint in key_metric_rows:
        print(f"  {title}: {metrics.get(metric_key, 0):.4f} ({hint})")

    # The alignment block is shown only when the ratio key is present in `metrics`.
    if 'partial_alignment_ratio' in metrics:
        print(f"\nPartial Data Alignment:")
        for title, metric_key in (
            ("Correct-to-Partial distance", "correct_to_partial_dist"),
            ("Incorrect-to-Partial distance", "incorrect_to_partial_dist"),
        ):
            print(f"  {title}: {metrics.get(metric_key, 0):.4f}")
        print(f"  Alignment ratio: {metrics.get('partial_alignment_ratio', 0):.4f} (>1 means correct is closer)")

In [None]:
# Visualize gradient distributions
# Left panel: histogram of per-row gradient norms; right panel: 2-D t-SNE embedding.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# 1. Gradient magnitude histogram
ax = axes[0]
grad_magnitudes = np.linalg.norm(grad_arrays, axis=1)

# (label value, colour, legend entry) for each histogram that is drawn.
hist_groups = (
    (1, 'green', 'Correct'),
    (0, 'red', 'Incorrect'),
    (3, 'blue', 'Partial Correct'),
)
for label_val, color, name in hist_groups:
    mask = labels == label_val
    if not np.any(mask):
        continue  # nothing to draw for this label
    ax.hist(grad_magnitudes[mask], alpha=0.5, bins=20, label=name, color=color)

ax.set(xlabel='Gradient Magnitude', ylabel='Count',
       title='Gradient Magnitude Distribution')
ax.legend()
ax.grid(True, alpha=0.3)

# 2. t-SNE visualization
ax = axes[1]
n_vectors = len(grad_arrays)
if n_vectors > 10:
    # Perplexity is capped so it stays below the number of samples.
    perplexity = min(30, n_vectors - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    grad_tsne = tsne.fit_transform(grad_arrays)

    colors_map = {0: 'red', 1: 'green', 2: 'orange', 3: 'blue'}
    label_names = {0: 'Incorrect', 1: 'Correct', 2: 'Inc. Partial', 3: 'Correct Partial'}
    for label_val in np.unique(labels):
        mask = labels == label_val
        ax.scatter(
            grad_tsne[mask, 0], grad_tsne[mask, 1],
            c=colors_map.get(label_val, 'gray'),
            label=label_names.get(label_val),
            alpha=0.6, s=50,
        )

    ax.set(xlabel='t-SNE 1', ylabel='t-SNE 2',
           title='t-SNE of Enriched Gradient Vectors')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{results_path}/gradient_analysis.png', dpi=150)
plt.show()

In [None]:
# Analyze feature importance in enriched vectors
# Collapse the four labels to a binary target: 1 = correct hypothesis (labels 1 and 3), 0 = otherwise.
binary_labels = np.where((labels == 1) | (labels == 3), 1, 0)

importance_df = diagnostics.analyze_feature_importance(
    grad_arrays,
    binary_labels,
    gradient_dim=hidden_size,  # First hidden_size features are gradients
)

banner = "=" * 60
print("\n" + banner)
print("FEATURE IMPORTANCE ANALYSIS")
print(banner)

# head(15) shows the first 15 rows; presumably the frame is already sorted by
# importance -- confirm in analyze_feature_importance.
print("\nTop 15 most discriminative features:")
report_columns = ['feature', 'cohens_d', 'ks_statistic', 'is_gradient']
print(importance_df[report_columns].head(15).to_string())

# Summary
grad_features = importance_df[importance_df['is_gradient'].eq(True)]
context_features = importance_df[importance_df['is_gradient'].eq(False)]

print(f"\nAverage importance (Cohen's d):")
for group_title, group_df in (("Gradient", grad_features), ("Context", context_features)):
    print(f"  {group_title} features: {group_df['cohens_d'].mean():.4f}")

## Part 2: Test Enhanced Selection Methods

Now test different selection algorithms to see which can better separate correct from incorrect hypotheses.

In [None]:
# Compare different selectors on the extracted gradients
# Rows with partial information form the labelled pool the selectors are fitted on:
#   label 3 -> "known correct", label 2 -> "known incorrect".
# Rows without partial information (labels 0 and 1) are the pool to be classified.
no_partial_mask = (labels == 0) | (labels == 1)

partial_correct = grad_arrays[labels == 3]
partial_incorrect = grad_arrays[labels == 2]
unknown = grad_arrays[no_partial_mask]
unknown_labels = labels[no_partial_mask]
unknown_true = (unknown_labels == 1).astype(int)  # 1 if correct, 0 if incorrect

n_unknown_correct = np.sum(unknown_true)
n_unknown_incorrect = len(unknown_true) - n_unknown_correct

print(f"Training data: {len(partial_correct)} correct, {len(partial_incorrect)} incorrect")
print(f"Unknown data: {len(unknown)} samples ({n_unknown_correct} correct, {n_unknown_incorrect} incorrect)")

In [None]:
# Test each selector
# Candidates keyed by the display name used in the report below.
selectors = {
    'ContrastiveSelector (SVM)': ContrastiveSelector(classifier='svm'),
    'ContrastiveSelector (RF)': ContrastiveSelector(classifier='rf'),
    'EnsembleSelector (soft)': EnsembleSelector(voting='soft', threshold=0.5),
    'EnsembleSelector (hard)': EnsembleSelector(voting='hard'),
    'CentroidDistanceSelector': CentroidDistanceSelector(margin=0.0),
    'AdaptiveSelector (nu=0.3->0.1)': AdaptiveSelector(initial_nu=0.3, final_nu=0.1),
}

results = []

banner = "=" * 60
print("\n" + banner)
print("SELECTOR COMPARISON")
print(banner + "\n")

for name, selector in selectors.items():
    # Best effort: a failing selector is reported and the loop moves on to the next one.
    try:
        # Fit on the labelled partial rows, then classify the unlabelled pool.
        selector.fit(partial_correct, partial_incorrect)
        result = selector.predict(unknown)

        # Compare selected row positions with the positions that are truly correct.
        predicted_correct = set(result.selected_indices)
        true_correct = set(np.where(unknown_true == 1)[0])
        n_unknown = len(unknown)

        # Confusion-matrix counts.
        tp = len(predicted_correct & true_correct)
        fp = len(predicted_correct - true_correct)
        fn = len(true_correct - predicted_correct)
        tn = n_unknown - tp - fp - fn

        n_selected = tp + fp
        n_actually_correct = tp + fn
        precision = tp / n_selected if n_selected > 0 else 0
        recall = tp / n_actually_correct if n_actually_correct > 0 else 0
        pr_sum = precision + recall
        f1 = 2 * precision * recall / pr_sum if pr_sum > 0 else 0
        accuracy = (tp + tn) / n_unknown
        selection_rate = len(predicted_correct) / n_unknown

        results.append({
            'Selector': name,
            'Precision': precision,
            'Recall': recall,
            'F1': f1,
            'Accuracy': accuracy,
            'Selection Rate': selection_rate,
        })

        print(f"{name}:")
        print(f"  Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
        print(f"  Accuracy: {accuracy:.4f}, Selection Rate: {selection_rate:.4f}")
        print()

    except Exception as e:
        print(f"{name}: ERROR - {e}\n")

results_df = pd.DataFrame(results)
print("\nSummary Table:")
print(results_df.to_string(index=False))

## Part 3: Full Training with Enhanced Selectors

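Test if enhanced selectors improve final model performance. Note: the enhanced selectors are not yet integrated into the training loop (see the TODO in `train_with_selector`), so the runs below currently use the standard training procedure only.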

In [None]:
def train_with_selector(selector_type, partial_perc, num_epochs, lr, nu, n_runs=5,
                        max_seeds=100, use_info='use hypothesis'):
    """
    Train and evaluate the model over several random seeds and report test R2.

    Seeds 0, 1, 2, ... are tried in order. A seed whose DataOperator reports
    ``lack_partial_coverage`` is skipped. The search stops once ``n_runs``
    runs were collected or ``max_seeds`` seeds were tried.

    Args:
        selector_type: 'standard' (original), 'contrastive', 'ensemble',
            'centroid', 'adaptive'. Currently NOT used by the training code:
            every value runs the standard training procedure (see TODO below).
        partial_perc: Forwarded to ``DataOperator``.
        num_epochs: Forwarded to ``TrainValidationManager``.
        lr: Forwarded to ``AlgoModulators``.
        nu: Forwarded to ``AlgoModulators``.
        n_runs: Number of successful runs to collect. ``n_runs <= 0`` trains nothing.
        max_seeds: Upper bound on the number of seeds tried (previously a
            hard-coded 100).
        use_info: Mode string passed to both ``DO.prep_dataloader`` and
            ``TrainValidationManager`` (previously hard-coded to
            'use hypothesis'). Valid values are defined by the GGH package.

    Returns:
        Tuple ``(mean_r2, std_r2, r2_per_run)``. ``mean_r2`` and ``std_r2`` are
        NaN when no run completed.
    """
    if selector_type != 'standard':
        # Make the silent fallback visible instead of pretending the selector ran.
        print(f"NOTE: selector_type={selector_type!r} is not integrated yet; "
              f"running standard training.")

    results = []

    for r_state in range(max_seeds):
        # Check the budget before training so n_runs <= 0 does no work.
        if len(results) >= n_runs:
            break

        set_to_deterministic(r_state)

        DO = DataOperator(data_path, inpt_vars, target_vars, miss_vars, hypothesis,
                          partial_perc, r_state, device='cpu')
        DO.problem_type = 'regression'

        if DO.lack_partial_coverage:
            continue  # unusable seed, try the next one

        AM = AlgoModulators(DO, lr=lr, nu=nu, normalize_grads_contx=False,
                            use_context=True, freqperc_cutoff=0.25)

        dataloader = DO.prep_dataloader(use_info, batch_size)
        model = initialize_model(DO, dataloader, hidden_size, r_state, dropout=0.05)

        TVM = TrainValidationManager(use_info, num_epochs, dataloader, batch_size,
                                     r_state, results_path, final_analysis=False)

        # Train - for now use standard training
        # TODO: Integrate enhanced selectors into training loop
        TVM.train_model(DO, AM, model, final_analysis=False)

        # Evaluate with the weights stored at TVM.weights_save_path
        # (presumably the best-validation checkpoint -- confirm in TrainValidationManager).
        model.load_state_dict(torch.load(TVM.weights_save_path))
        model.eval()
        with torch.no_grad():  # inference only: no autograd graph needed
            test_pred = model(DO.full_test_input_tensor)
        test_true = DO.df_test[target_vars].values
        results.append(r2_score(test_true, test_pred.detach().numpy()))

    if not results:
        # np.mean([]) would return NaN behind a suppressed RuntimeWarning; be explicit.
        print(f"WARNING: no usable run found within {max_seeds} seed(s); returning NaN.")
        return float('nan'), float('nan'), results

    return np.mean(results), np.std(results), results

print("Training function defined.")

In [None]:
# Quick comparison: baseline vs hypothesis selection
# These assignments overwrite the notebook-level partial_perc / num_epochs set in earlier cells.
partial_perc = 0.025
num_epochs = 60
lr = 0.002
nu = 0.1

print("Running baseline comparison...\n")

# Partial info baseline
# NOTE(review): train_with_selector passes the mode 'use hypothesis' to the
# dataloader and the training manager for every call, so this run differs from
# the one below only in epochs (200 vs 60) and learning rate (0.001 vs 0.002).
# It is not a true "partial info" baseline -- confirm the intended mode.
print("Testing partial info baseline...")
p_mean, p_std, _ = train_with_selector('standard', partial_perc, 200, 0.001, nu, n_runs=5)
print(f"  partial info: {p_mean:.4f} +/- {p_std:.4f}")

# Standard hypothesis selection
print("Testing use hypothesis (standard)...")
h_mean, h_std, _ = train_with_selector('standard', partial_perc, num_epochs, lr, nu, n_runs=5)
print(f"  use hypothesis: {h_mean:.4f} +/- {h_std:.4f}")

# Difference of the two mean R2 scores, scaled by 100.
improvement = (h_mean - p_mean) * 100
print(f"\nImprovement: {improvement:.2f} percentage points")

## Part 4: Gradient Engineering - Improving Separability

If performance is not improving, we need to investigate:
1. What features best separate correct from incorrect?
2. How can we modify the enriched vectors to improve separation?
3. What additional context might help?

In [None]:
# Test different enriched vector configurations
print("\n" + "="*60)
print("ENRICHED VECTOR EXPERIMENTS")
print("="*60)

# Reuses grad_arrays, correct_mask and incorrect_mask from the diagnostic cells above.
# The first hidden_size columns of grad_arrays are the gradient part; the
# remaining columns are the context that was appended to them.
configs = [
    ('Raw gradients only', grad_arrays[:, :hidden_size]),
    ('Gradients + input context', grad_arrays),
    ('Gradient magnitude only', np.linalg.norm(grad_arrays[:, :hidden_size], axis=1, keepdims=True)),
]

for name, vectors in configs:
    # Guarantee a 2-D (n_samples, n_features) array.
    if vectors.ndim == 1:
        vectors = vectors.reshape(-1, 1)

    # Compute separability with these vectors
    correct_v = vectors[correct_mask]
    incorrect_v = vectors[incorrect_mask]

    if len(correct_v) == 0 or len(incorrect_v) == 0:
        # Say why nothing is reported instead of skipping silently.
        print(f"\n{name}: skipped (need at least one correct and one incorrect sample)")
        continue

    # Centroid distance
    centroid_dist = np.linalg.norm(np.mean(correct_v, axis=0) - np.mean(incorrect_v, axis=0))

    # Silhouette needs enough samples; sklearn raises ValueError otherwise.
    combined = np.vstack([correct_v, incorrect_v])
    combined_labels = np.array([1]*len(correct_v) + [0]*len(incorrect_v))
    sil_note = ""
    try:
        sil = silhouette_score(combined, combined_labels)
    except ValueError as err:
        # Report NaN plus the reason: 0 is a valid silhouette value and would
        # be indistinguishable from a real result.
        sil = float('nan')
        sil_note = f" (not computed: {err})"

    print(f"\n{name}:")
    print(f"  Vector dimension: {vectors.shape[1]}")
    print(f"  Centroid distance: {centroid_dist:.4f}")
    print(f"  Silhouette score: {sil:.4f}{sil_note}")

In [None]:
# Investigate why separation might be failing
print("\n" + "="*60)
print("INVESTIGATION: Why might separation be failing?")
print("="*60)

# Reuses grad_arrays, correct_mask, incorrect_mask and partial_correct_mask
# from the diagnostic cells above. Only the first hidden_size columns
# (the gradient part of each vector) are analysed here.
grad_only = grad_arrays[:, :hidden_size]
correct_grads_only = grad_only[correct_mask]
incorrect_grads_only = grad_only[incorrect_mask]

# Per-row gradient norms, computed once and shared by sections 1 and 2.
correct_mags = np.linalg.norm(correct_grads_only, axis=1)
incorrect_mags = np.linalg.norm(incorrect_grads_only, axis=1)

# 1. Check gradient statistics
print("\n1. GRADIENT STATISTICS:")
for group_title, mags in (("Correct", correct_mags), ("Incorrect", incorrect_mags)):
    print(f"   {group_title} hypothesis gradients:")
    if len(mags) > 0:
        print(f"     Mean magnitude: {np.mean(mags):.6f}")
        print(f"     Std magnitude: {np.std(mags):.6f}")
    else:
        print("     (no samples)")

# 2. Check overlap in distributions
print("\n2. DISTRIBUTION OVERLAP:")
if len(correct_mags) == 0 or len(incorrect_mags) == 0:
    # .min()/.max() raise ValueError on empty arrays, so guard explicitly.
    print("   Cannot compute overlap: one of the groups has no samples")
else:
    # Overlap region = intersection of the two [min, max] magnitude ranges.
    overlap_min = max(correct_mags.min(), incorrect_mags.min())
    overlap_max = min(correct_mags.max(), incorrect_mags.max())
    if overlap_max > overlap_min:
        correct_in_overlap = np.sum((correct_mags >= overlap_min) & (correct_mags <= overlap_max)) / len(correct_mags)
        incorrect_in_overlap = np.sum((incorrect_mags >= overlap_min) & (incorrect_mags <= overlap_max)) / len(incorrect_mags)
        print(f"   Overlap region: [{overlap_min:.4f}, {overlap_max:.4f}]")
        print(f"   % correct in overlap: {correct_in_overlap*100:.1f}%")
        print(f"   % incorrect in overlap: {incorrect_in_overlap*100:.1f}%")
    else:
        print("   Magnitude ranges do not overlap")

# 3. Check if partial data is representative
print("\n3. PARTIAL DATA REPRESENTATIVENESS:")
partial_mags = np.linalg.norm(grad_arrays[partial_correct_mask, :hidden_size], axis=1)
print(f"   Partial correct gradients: {len(partial_mags)} samples")
if len(partial_mags) > 0:
    print(f"   Mean magnitude: {np.mean(partial_mags):.6f}")
    print(f"   Range: [{partial_mags.min():.6f}, {partial_mags.max():.6f}]")

In [None]:
# Generate recommendations based on analysis
# NOTE: the text printed below is a fixed checklist. It contains no interpolated
# values, so it does not change with the metrics computed earlier in the notebook.
print("\n" + "="*60)
print("RECOMMENDATIONS FOR IMPROVING GRADIENT SEPARATION")
print("="*60)

print("""
Based on the diagnostic analysis, consider these approaches:

1. GRADIENT NORMALIZATION:
   - If gradient magnitudes overlap significantly, try L2 normalizing
   - This focuses on gradient DIRECTION rather than magnitude

2. CONTEXT WEIGHTING:
   - If context features are more discriminative than gradients,
     increase their weight in the enriched vector
   - Or: use context features ONLY for selection

3. MULTI-EPOCH AGGREGATION:
   - Instead of using gradients from a single epoch,
     aggregate across multiple epochs (e.g., moving average)
   - This can reduce noise in gradient estimates

4. LOSS-WEIGHTED GRADIENTS:
   - Weight gradients by their corresponding loss values
   - Samples with higher loss may have more informative gradients

5. CONTRASTIVE LEARNING:
   - Use BOTH correct AND incorrect partial data for training
   - The ContrastiveSelector should help with this

6. ENSEMBLE METHODS:
   - Combine multiple selection criteria
   - More robust than relying on a single method

7. ADAPTIVE THRESHOLDING:
   - Start permissive (select more hypotheses)
   - Become stricter as training progresses
""")

## Part 5: Custom Experiments

Add your own experiments here based on the diagnostic insights.

In [None]:
# Space for custom experiments
# Based on the diagnostics above, try different approaches:

# Example: Test with normalized gradients (L2-normalize each row; the 1e-10
# term avoids division by zero for all-zero rows)
# normalized_grads = grad_arrays / (np.linalg.norm(grad_arrays, axis=1, keepdims=True) + 1e-10)
# Then rerun selector comparison: rebuild partial_correct, partial_incorrect and
# unknown from normalized_grads before fitting the selectors again.

print("Add your custom experiments here based on diagnostic insights.")