In [1]:
import torch
import jax

# Report what each framework can see before any heavy work starts.
print(
    "PyTorch sees:",
    torch.cuda.device_count(),
    "GPUs,",
    torch.cuda.get_device_name(0) if torch.cuda.is_available() else "none",
)
print("JAX sees:", jax.devices())

PyTorch sees: 2 GPUs, NVIDIA L40S
JAX sees: [CudaDevice(id=0), CudaDevice(id=1)]


In [2]:
import sys
!{sys.executable} -m pip install dill --user

  pid, fd = os.forkpty()




In [3]:
import os
import dill
import numpy as np
import random
from sklearn.neighbors import NearestNeighbors
from scipy.sparse.csgraph import connected_components
import matplotlib.pyplot as plt
from collections import Counter
import pickle
from tqdm import tqdm
import sys
import jax
import jax.numpy as jnp
from jax import jit, vmap, pmap
from functools import partial
import cupy as cp  # For additional GPU acceleration
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing as mp

# Enable 64-bit (float64/int64) arrays in JAX; by default JAX works in 32-bit.
jax.config.update('jax_enable_x64', True)
print(f"Available JAX devices: {jax.devices()}")
try:
    num_gpus = len(jax.devices('gpu'))
except RuntimeError:
    # jax.devices('gpu') raises instead of returning [] when no GPU backend is present.
    num_gpus = 0
print(f"Number of GPUs: {num_gpus}")

class GPULayerMonteCarloKNNOptimizer:
    """Monte Carlo search for the k-NN connectivity that reproduces a target Betti-0.

    Each trial draws a random subset of a layer's representations. It then builds a
    symmetric k-nearest-neighbour graph for k = 1, 2, ... and records the smallest k
    whose graph has exactly ``target_b0`` connected components. The distribution of
    those k values is analysed and plotted for the network's first and last layers.
    """

    def __init__(self, data_root='/project2/alvinjin_1630/results/D1', input_b0=9, output_b0=1,
                 use_gpu=True, seed=None):
        """
        Args:
            data_root: Directory containing ``only_0.dill``.
            input_b0: Target number of connected components for ``layers[0]``.
            output_b0: Target number of connected components for ``layers[-1]``.
            use_gpu: Use JAX for the neighbour search when a GPU backend is available.
            seed: Optional non-negative integer. When given, trial ``i`` draws its subset
                from ``np.random.default_rng((seed, i))``, so results do not depend on
                thread scheduling. When ``None``, the global ``random`` module is used.
        """
        self.data_root = data_root
        self.input_b0 = input_b0
        self.output_b0 = output_b0
        self.seed = seed
        self.layers = None
        self.input_layer_data = None
        self.output_layer_data = None
        n_gpus = self._count_gpus()
        self.use_gpu = use_gpu and n_gpus > 0
        self.num_devices = n_gpus if self.use_gpu else 1

        print(f"GPU Acceleration: {'Enabled' if self.use_gpu else 'Disabled'}")
        print(f"Using {self.num_devices} device(s)")

    @staticmethod
    def _count_gpus():
        """Return the number of JAX GPU devices, or 0 when no GPU backend exists."""
        try:
            return len(jax.devices('gpu'))
        except RuntimeError:
            # jax.devices('gpu') raises rather than returning [] on CPU-only installs.
            return 0

    def load_network_and_data(self):
        """Load the trained network and extract its layer representations.

        Sets ``self.layers`` to the result of ``Trainer.get_layers``. Sets
        ``self.input_layer_data`` / ``self.output_layer_data`` to ``layers[0]`` /
        ``layers[-1]``, as JAX arrays in GPU mode and NumPy arrays otherwise.

        Raises:
            Any exception from unpickling the dataset or from ``Trainer.get_layers``
            is printed and re-raised.
        """
        print("Loading neural network and extracting layer representations...")

        # Trainer lives in the results directory; avoid growing sys.path on every call.
        results_dir = '/project2/alvinjin_1630/results'
        if results_dir not in sys.path:
            sys.path.append(results_dir)
        from Trainer import Trainer

        trainer = Trainer(
            dataset='D1',
            hidden_dims=[30] * 4 + [10] * 4,
            act_fn=jax.nn.relu,
            study_name='30x4_10x4_relu',
            residual=False
        )

        dataset_path = os.path.join(self.data_root, 'only_0.dill')
        print(f"Loading dataset from: {dataset_path}")

        try:
            # NOTE: dill (like pickle) can execute arbitrary code; only load trusted files.
            with open(dataset_path, 'rb') as f:
                only_0 = dill.load(f)
            print("✅ Successfully loaded only_0.dill dataset")
        except Exception as e:
            print(f"❌ Error loading dataset: {e}")
            raise

        try:
            self.layers = trainer.get_layers(
                dataset=only_0,
                model_id=0,
                input_layer=True,
                return_labels=False
            )
            print(f"✅ Successfully extracted {len(self.layers)} layers")
            print(f"Input layer shape: {self.layers[0].shape}")
            print(f"Output layer shape: {self.layers[-1].shape}")

            # Keep only the two layers the analysis uses, in the backend's array type.
            if self.use_gpu:
                self.input_layer_data = jnp.array(self.layers[0])
                self.output_layer_data = jnp.array(self.layers[-1])
                print("🚀 Data moved to GPU")
            else:
                self.input_layer_data = np.array(self.layers[0])
                self.output_layer_data = np.array(self.layers[-1])
        except Exception as e:
            print(f"❌ Error extracting layers: {e}")
            raise

    # The jitted kernels are static methods: jax.jit cannot trace `self`, and the
    # static arguments are named so they stay static when passed by keyword.
    @staticmethod
    @partial(jax.jit, static_argnames=('batch_size',))
    def gpu_pairwise_distances(X, Y, batch_size=1000):
        """Squared Euclidean distances between every row of ``X`` and every row of ``Y``.

        Args:
            X: (n, d) array.
            Y: (m, d) array.
            batch_size: Rows of ``X`` per chunk (static). Bounds the size of the
                intermediate (batch, m, d) difference tensor.

        Returns:
            (n, m) array of squared distances.
        """
        n_samples = X.shape[0]
        if n_samples == 0:
            return jnp.zeros((0, Y.shape[0]), dtype=jnp.result_type(X, Y))
        chunks = []
        for start in range(0, n_samples, batch_size):
            x_batch = X[start:start + batch_size]
            diff = x_batch[:, None, :] - Y[None, :, :]
            chunks.append(jnp.sum(diff ** 2, axis=2))
        return jnp.concatenate(chunks, axis=0)

    @staticmethod
    @partial(jax.jit, static_argnames=('k',))
    def _gpu_knn_indices(data, k):
        """(n, k) indices of each row's k nearest *other* rows of ``data``, nearest first.

        Assumes k < n. The diagonal is set to +inf so that a point is never its own neighbour.
        """
        n = data.shape[0]
        distances = GPULayerMonteCarloKNNOptimizer.gpu_pairwise_distances(data, data, batch_size=500)
        distances = distances.at[jnp.diag_indices(n)].set(jnp.inf)
        # lax.top_k selects the largest entries (sorted descending), so negate distances.
        _, knn_indices = jax.lax.top_k(-distances, k)
        return knn_indices

    @staticmethod
    @partial(jax.jit, static_argnames=('k',))
    def gpu_knn_graph(data, k):
        """Dense symmetric k-NN adjacency matrix: bool (n, n), no self-loops (assumes k < n)."""
        n = data.shape[0]
        knn_indices = GPULayerMonteCarloKNNOptimizer._gpu_knn_indices(data, k)
        rows = jnp.repeat(jnp.arange(n), k)
        adjacency = jnp.zeros((n, n), dtype=bool).at[rows, knn_indices.reshape(-1)].set(True)
        # i ~ j when either point is among the other's k nearest neighbours.
        return adjacency | adjacency.T

    def gpu_connected_components(self, adjacency_matrix):
        """Count the connected components of a dense adjacency matrix (SciPy, on the CPU)."""
        from scipy.sparse import csr_matrix
        # np.asarray copies a JAX device array to the host and is a no-op for NumPy input.
        adj_cpu = np.asarray(adjacency_matrix)
        n_components, _ = connected_components(csgraph=csr_matrix(adj_cpu.astype(float)), directed=False)
        return n_components

    def _knn_indices(self, subset, k):
        """(n, k) NumPy array of each point's k nearest other points, nearest first (k < n)."""
        if self.use_gpu:
            return np.asarray(self._gpu_knn_indices(jnp.asarray(subset), k))
        # Called without X, kneighbors() excludes each training point from its own list.
        neighbors = NearestNeighbors(n_neighbors=k).fit(np.asarray(subset))
        return neighbors.kneighbors(return_distance=False)

    @staticmethod
    def _count_components_from_knn(knn_indices):
        """Count the components of the undirected graph with edges i -- knn_indices[i, j]."""
        from scipy.sparse import csr_matrix
        n, k = knn_indices.shape
        rows = np.repeat(np.arange(n), k)
        graph = csr_matrix((np.ones(n * k), (rows, knn_indices.ravel())), shape=(n, n))
        # directed=False follows an edge stored in either direction: no explicit symmetrisation.
        n_components, _ = connected_components(csgraph=graph, directed=False)
        return n_components

    def find_k_for_subset_gpu(self, subset, target_b0, max_k=100):
        """Find the smallest k in [1, max_k] whose k-NN graph on ``subset`` has ``target_b0`` components.

        Neighbour lists are computed once for the largest admissible k. The graph for
        each smaller k uses the first k columns, instead of recomputing all pairwise
        distances per k. Adding edges can only merge components, so the count is
        non-increasing in k, and the search stops once it drops below the target.

        Args:
            subset: (n, d) array (NumPy or JAX).
            target_b0: Desired number of connected components.
            max_k: Largest k to try (inclusive); also capped at n - 1.

        Returns:
            The smallest matching k, or None if no k in range matches.
        """
        k_max = min(max_k, len(subset) - 1)
        if k_max < 1:
            return None
        knn_indices = self._knn_indices(subset, k_max)
        for k in range(1, k_max + 1):
            n_components = self._count_components_from_knn(knn_indices[:, :k])
            if n_components == target_b0:
                return k
            if n_components < target_b0:
                return None
        return None

    def create_random_subset(self, data, fraction=0.25, rng=None):
        """Draw ``int(len(data) * fraction)`` distinct rows of ``data`` without replacement.

        Args:
            data: NumPy or JAX array, sampled along its first axis.
            fraction: Fraction of rows to keep.
            rng: Optional ``np.random.Generator``. When None, the global ``random`` module is used.

        Returns:
            ``(subset, subset_indices)``, where ``subset_indices`` is the list of chosen row indices.
        """
        total_points = len(data)
        subset_size = int(total_points * fraction)
        if rng is None:
            subset_indices = random.sample(range(total_points), subset_size)
        else:
            subset_indices = rng.choice(total_points, size=subset_size, replace=False).tolist()
        # JAX arrays reject indexing with a Python list; an integer array works for both backends.
        subset = data[np.asarray(subset_indices, dtype=np.intp)]
        return subset, subset_indices

    def _run_trial(self, args):
        """Run one Monte Carlo trial and return its optimal k (or None). Exceptions propagate.

        ``args`` is ``(trial_id, data, target_b0, subset_fraction, max_k)``.
        """
        trial_id, data, target_b0, subset_fraction, max_k = args
        seed = getattr(self, 'seed', None)
        rng = None if seed is None else np.random.default_rng((seed, trial_id))
        subset, _ = self.create_random_subset(data, subset_fraction, rng=rng)
        return self.find_k_for_subset_gpu(subset, target_b0, max_k)

    def monte_carlo_worker(self, args):
        """Best-effort single trial: return the optimal k, or None if there is no match or an error.

        ``args`` is ``(trial_id, data, target_b0, subset_fraction, max_k)``.
        Errors are reported with ``warnings.warn`` rather than silently discarded.
        """
        import warnings
        try:
            return self._run_trial(args)
        except Exception as exc:
            warnings.warn(f"Monte Carlo trial {args[0]} failed: {exc!r}")
            return None

    def monte_carlo_k_optimization_gpu(self, layer_name, data, target_b0, n_trials=1000,
                                      subset_fraction=0.25, max_k=100, n_workers=None):
        """Run ``n_trials`` independent subset/k-search trials in a thread pool.

        Args:
            layer_name: Label used in printed output and the progress bar.
            data: Layer representations (NumPy or JAX array), sampled along axis 0.
            target_b0: Desired number of connected components.
            n_trials: Number of random subsets to evaluate.
            subset_fraction: Fraction of rows in each subset.
            max_k: Largest k to try per subset (inclusive).
            n_workers: Thread count. Defaults to 2 per GPU, or the CPU count in CPU mode.

        Returns:
            List of optimal k values from the trials that found one, in completion order.
            Trials that raise are counted as failed, and the first error is printed.
        """
        print(f"\n=== GPU Monte Carlo optimization for {layer_name} ===")
        print(f"  Target B0: {target_b0}")
        print(f"  Data shape: {data.shape}")
        print(f"  Number of trials: {n_trials}")
        print(f"  Subset fraction: {subset_fraction}")
        print(f"  Max k to test: {max_k}")
        print(f"  GPU acceleration: {self.use_gpu}")

        total_points = len(data)
        subset_size = int(total_points * subset_fraction)
        print(f"  Total data points: {total_points}")
        print(f"  Subset size ({subset_fraction*100}%): {subset_size}")

        if n_workers is None:
            n_workers = self.num_devices * 2 if self.use_gpu else mp.cpu_count()
        n_workers = max(1, n_workers)  # ThreadPoolExecutor requires at least one worker
        print(f"  Using {n_workers} parallel workers")

        if subset_size < 50:
            print(f"Warning: Subset size {subset_size} is very small.")

        successful_ks = []
        failed_trials = 0
        errored_trials = 0
        first_error = None

        # Threads (not processes) share the JAX arrays and the jit cache.
        # NOTE(review): jnp arrays are placed on JAX's default device, so all trials
        # presumably run on one GPU; the pool mainly overlaps host-side work.
        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            futures = [
                executor.submit(self._run_trial, (trial_id, data, target_b0, subset_fraction, max_k))
                for trial_id in range(n_trials)
            ]
            for future in tqdm(as_completed(futures), total=n_trials,
                               desc=f"GPU Monte Carlo ({layer_name})"):
                try:
                    k = future.result()
                except Exception as exc:
                    errored_trials += 1
                    if first_error is None:
                        first_error = exc
                    k = None
                if k is None:
                    failed_trials += 1
                else:
                    successful_ks.append(k)

        success_rate = len(successful_ks) / n_trials * 100 if n_trials else 0.0
        print(f"\n{layer_name} Results:")
        print(f"  Successful trials: {len(successful_ks)}")
        print(f"  Failed trials: {failed_trials}")
        if errored_trials:
            print(f"  Trials that raised an exception: {errored_trials} (first: {first_error!r})")
        print(f"  Success rate: {success_rate:.1f}%")

        return successful_ks

    def analyze_k_distribution(self, successful_ks, layer_name):
        """Print summary statistics of the successful k values.

        Returns:
            ``(mean_k, std_k)``, or ``(None, None)`` when ``successful_ks`` is empty.
        """
        if not successful_ks:
            print(f"No successful k values to analyze for {layer_name}!")
            return None, None

        k_array = np.array(successful_ks)
        mean_k = np.mean(k_array)
        std_k = np.std(k_array)
        median_k = np.median(k_array)
        mode_k = Counter(successful_ks).most_common(1)[0][0]

        print(f"\n{layer_name} - K Distribution Analysis:")
        print(f"  Mean k: {mean_k:.2f}")
        print(f"  Median k: {median_k:.2f}")
        print(f"  Mode k: {mode_k}")
        print(f"  Standard deviation: {std_k:.2f}")
        print(f"  Min k: {min(successful_ks)}")
        print(f"  Max k: {max(successful_ks)}")

        return mean_k, std_k

    def run_full_analysis_gpu(self, n_trials=1000, subset_fraction=0.25, max_k=100, n_workers=None):
        """Load the data, run the Monte Carlo k search on both layers, then analyse and plot.

        Returns:
            ``{'input_layer': {...}, 'output_layer': {...}}``, each entry holding
            ``'successful_ks'``, ``'mean_k'`` and ``'std_k'``.
        """
        print("🚀 Starting GPU-accelerated layer analysis...")

        self.load_network_and_data()

        search_kwargs = dict(n_trials=n_trials, subset_fraction=subset_fraction,
                             max_k=max_k, n_workers=n_workers)
        input_ks = self.monte_carlo_k_optimization_gpu(
            layer_name="Input Layer (layer[0])",
            data=self.input_layer_data,
            target_b0=self.input_b0,
            **search_kwargs
        )
        output_ks = self.monte_carlo_k_optimization_gpu(
            layer_name="Output Layer (layer[-1])",
            data=self.output_layer_data,
            target_b0=self.output_b0,
            **search_kwargs
        )

        print("\n" + "="*60)
        print("FINAL GPU ANALYSIS RESULTS")
        print("="*60)

        input_mean, input_std = self.analyze_k_distribution(input_ks, "Input Layer")
        output_mean, output_std = self.analyze_k_distribution(output_ks, "Output Layer")

        self.plot_comparison(input_ks, output_ks, n_trials=n_trials)

        return {
            'input_layer': {
                'successful_ks': input_ks,
                'mean_k': input_mean,
                'std_k': input_std
            },
            'output_layer': {
                'successful_ks': output_ks,
                'mean_k': output_mean,
                'std_k': output_std
            }
        }

    @staticmethod
    def _plot_k_histogram(ax, ks, color, title):
        """Draw a histogram of integer k values (one bar per integer) with a dashed mean line."""
        mean_k = np.mean(ks)
        # Edges at half-integers so every integer k gets its own bar.
        bins = np.arange(min(ks) - 0.5, max(ks) + 1.5)
        ax.hist(ks, bins=bins, alpha=0.7, color=color, edgecolor='black')
        ax.axvline(mean_k, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_k:.2f}')
        ax.set_title(title)
        ax.set_xlabel('k value')
        ax.set_ylabel('Frequency')
        ax.legend()
        ax.grid(True, alpha=0.3)

    def plot_comparison(self, input_ks, output_ks, n_trials=None):
        """Plot histograms, a box plot and a summary panel for both layers' k distributions.

        Args:
            input_ks, output_ks: Successful k values per layer (panels are skipped when empty).
            n_trials: Optional trial count, used to show success rates as ``successes/n_trials``.
        """
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        input_tag = f'B0={self.input_b0}'
        output_tag = f'B0={self.output_b0}'

        if input_ks:
            self._plot_k_histogram(axes[0, 0], input_ks, 'blue', f'Input Layer ({input_tag}) - k Distribution')
        if output_ks:
            self._plot_k_histogram(axes[0, 1], output_ks, 'green', f'Output Layer ({output_tag}) - k Distribution')

        if input_ks and output_ks:
            axes[1, 0].boxplot([input_ks, output_ks])
            # set_xticklabels avoids boxplot's `labels` kwarg (renamed in newer matplotlib).
            axes[1, 0].set_xticklabels([f'Input ({input_tag})', f'Output ({output_tag})'])
            axes[1, 0].set_title('Box Plot Comparison')
            axes[1, 0].set_ylabel('k value')
            axes[1, 0].grid(True, alpha=0.3)

        axes[1, 1].axis('off')
        lines = [
            "GPU ACCELERATION SUMMARY",
            "",
            f"GPU Enabled: {self.use_gpu}",
            f"Devices Used: {self.num_devices}",
            "",
        ]
        for name, tag, ks in (("Input Layer", input_tag, input_ks), ("Output Layer", output_tag, output_ks)):
            if ks:
                rate = f"{len(ks)}/{n_trials}" if n_trials else f"{len(ks)} successful"
                lines += [
                    f"{name} ({tag}):",
                    f"  Mean k: {np.mean(ks):.2f}",
                    f"  Std k: {np.std(ks):.2f}",
                    f"  Success rate: {rate}",
                    "",
                ]
        axes[1, 1].text(0.1, 0.9, "\n".join(lines), transform=axes[1, 1].transAxes,
                        fontsize=12, verticalalignment='top', fontfamily='monospace')

        plt.tight_layout()
        plt.show()

Available JAX devices: [CudaDevice(id=0), CudaDevice(id=1)]
Number of GPUs: 2


In [None]:
# Initialize the GPU-accelerated optimizer
gpu_optimizer = GPULayerMonteCarloKNNOptimizer(
    data_root='/project2/alvinjin_1630/results/D1',
    input_b0=9,   # Target B0 for input layer
    output_b0=1,  # Target B0 for output layer
    use_gpu=True  # Enable GPU acceleration
)

# Run the complete GPU-accelerated analysis
results = gpu_optimizer.run_full_analysis_gpu(
    n_trials=5000,        # More trials since it's faster with GPU
    subset_fraction=0.25, # Use 25% of data each trial
    max_k=50,             # Maximum k to test
    n_workers=4           # Number of parallel workers (auto-detected if None)
)

# Access results: show the first few optimal k values found for each layer
print("🎯 Final Results:")
print(f"Input layer optimal k values: {results['input_layer']['successful_ks'][:10]}")
print(f"Output layer optimal k values: {results['output_layer']['successful_ks'][:10]}")


🚀 GPU Acceleration: Enabled
📟 Using 2 device(s)
🚀 Starting GPU-accelerated layer analysis...
Loading neural network and extracting layer representations...


  pid, fd = os.forkpty()


Note: you may need to restart the kernel to use updated packages.
/project2/alvinjin_1630/pcx
[30;43mSkipping virtualenv creation, as specified in config file.[39;49m
[34mInstalling dependencies from lock file[39m

[39;1mPackage operations[39;22m: [34m50[39m installs, [34m6[39m updates, [34m0[39m removals

  [34;1m-[39;22m [39mInstalling [39m[36mnumpy[39m[39m ([39m[39;1m2.2.4[39;22m[39m)[39m: [34mPending...[39m
[1A[0J  [34;1m-[39;22m [39mInstalling [39m[36mnumpy[39m[39m ([39m[39;1m2.2.4[39;22m[39m)[39m: [34mInstalling...[39m
[33mInstalling /home1/jli99757/.local/bin/f2py over existing file[39m
[33mInstalling /home1/jli99757/.local/bin/numpy-config over existing file[39m
[33mInstalling /home1/jli99757/.local/lib/python3.13/site-packages/numpy-2.2.4.dist-info/METADATA over existing file[39m
[33mInstalling /home1/jli99757/.local/lib/python3.13/site-packages/numpy-2.2.4.dist-info/LICENSE.txt over existing file[39m
[33mInstalling /home1/jli