In [None]:
# Closing recap cell: prints dataset sizes, model size, training outcome,
# validation metrics and the artifacts written by the earlier cells.
# It only reads names created by those cells (X_train, history, best_val_loss,
# omega_m_true, ...), so it must run after them.
banner = "=" * 70
print(banner)
print("BASELINE CNN MODEL - SUMMARY")
print(banner)

print("\n📊 Dataset:")
print(f"  Training samples: {X_train.shape[0]:,}")
print(f"  Validation samples: {X_val.shape[0]:,}")
print(f"  Test samples: {total_test_samples:,}")
# Image size is read from the training array rather than hard-coded.
print(f"  Image dimensions: {X_train.shape[1]} × {X_train.shape[2]} pixels")

print("\n🏗️ Model Architecture:")
print("  Type: Convolutional Neural Network (CNN)")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

print("\n📈 Training:")
print(f"  Epochs trained: {len(history['train_loss'])}")
print(f"  Best validation loss: {best_val_loss:.4f}")
print("  Loss function: KL Divergence")
print(f"  Optimizer: Adam (lr={learning_rate})")

print("\n🎯 Validation Performance:")
print(f"  Ω_m RMSE: {np.sqrt(mean_squared_error(omega_m_true, omega_m_pred)):.6f}")
print(f"  S_8 RMSE: {np.sqrt(mean_squared_error(s_8_true, s_8_pred)):.6f}")
print(f"  Ω_m R²: {r2_score(omega_m_true, omega_m_pred):.4f}")
print(f"  S_8 R²: {r2_score(s_8_true, s_8_pred):.4f}")

print("\n💾 Output Files:")
print("  ✓ Best model: best_model.pth")
print("  ✓ Submission (NumPy): submission_baseline_cnn.npy")
print("  ✓ Submission (CSV): submission_baseline_cnn.csv")
print("  ✓ Training history plot: training_history.png")
print("  ✓ Performance plot: model_performance.png")
print("  ✓ Test predictions plot: test_predictions.png")

print("\n🚀 Next Steps for Improvement:")
print("  1. Data augmentation (rotations, flips, crops)")
print("  2. Deeper network architecture (ResNet, EfficientNet)")
print("  3. Ensemble multiple models")
print("  4. Incorporate nuisance parameters in training")
print("  5. Use attention mechanisms or Vision Transformers")
print("  6. Experiment with different loss functions")
print("  7. Add power spectrum features as additional input")
print("  8. Apply test-time augmentation")
print("  9. Hyperparameter tuning (learning rate, architecture)")
print("  10. Cross-validation for more robust evaluation")

print("\n" + banner)
print("Baseline model training and evaluation complete!")
print(banner)

## 14. Summary and Next Steps

In [None]:
# Six-panel overview of the test-set predictions: histograms of the point
# estimates and of the predicted standard deviations, plus two scatter views.
fig, axes = plt.subplots(2, 3, figsize=(16, 10))

# (axis, values, bar colour, x-label, title) for the four histogram panels.
# A colour of None keeps matplotlib's default colour.
histogram_panels = [
    (axes[0, 0], test_omega_m, None, 'Predicted Ω_m', 'Distribution of Predicted Ω_m'),
    (axes[0, 1], test_s_8, 'orange', 'Predicted S_8', 'Distribution of Predicted S_8'),
    (axes[1, 0], test_sigma_omega_m, 'green', 'Predicted σ_Ω_m', 'Distribution of Ω_m Uncertainties'),
    (axes[1, 1], test_sigma_s_8, 'red', 'Predicted σ_S_8', 'Distribution of S_8 Uncertainties'),
]
for hist_ax, hist_values, hist_colour, hist_xlabel, hist_title in histogram_panels:
    hist_ax.hist(hist_values, bins=50, alpha=0.7, edgecolor='black', color=hist_colour)
    hist_ax.set_xlabel(hist_xlabel)
    hist_ax.set_ylabel('Frequency')
    hist_ax.set_title(hist_title)
    hist_ax.grid(True, alpha=0.3, axis='y')

# Joint distribution; points are coloured by their position in the test set
axes[0, 2].scatter(test_omega_m, test_s_8, alpha=0.5, s=20, c=range(len(test_omega_m)), cmap='viridis')
axes[0, 2].set_xlabel('Predicted Ω_m')
axes[0, 2].set_ylabel('Predicted S_8')
axes[0, 2].set_title('Joint Distribution of Predictions')
axes[0, 2].grid(True, alpha=0.3)

# Predicted uncertainty as a function of the predicted value, both parameters
axes[1, 2].scatter(test_omega_m, test_sigma_omega_m, alpha=0.4, s=15, label='Ω_m')
axes[1, 2].scatter(test_s_8, test_sigma_s_8, alpha=0.4, s=15, label='S_8', color='orange')
axes[1, 2].set_xlabel('Predicted Value')
axes[1, 2].set_ylabel('Predicted Uncertainty (σ)')
axes[1, 2].set_title('Predictions vs Uncertainties')
axes[1, 2].legend()
axes[1, 2].grid(True, alpha=0.3)

plt.suptitle('Test Set Predictions', fontsize=14, y=1.00)
plt.tight_layout()
plt.savefig(os.path.join(data_dir, 'test_predictions.png'), dpi=150, bbox_inches='tight')
plt.show()

print("Test prediction plots saved!")

## 13. Visualize Test Predictions

In [None]:
# Assemble the submission table (one row per test sample) and write it both
# as a NumPy array and as a CSV file inside data_dir.
# `pd` comes from the notebook's import cell.
submission_columns = ['Omega_m', 'S_8', 'sigma_Omega_m', 'sigma_S_8']
submission = np.column_stack([test_omega_m, test_s_8, test_sigma_omega_m, test_sigma_s_8])

print(f"Submission shape: {submission.shape}")
print(f"Submission columns: [{', '.join(submission_columns)}]")

# Save submission
submission_file = os.path.join(data_dir, 'submission_baseline_cnn.npy')
np.save(submission_file, submission)
print(f"\nSubmission saved to: {submission_file}")

# Also save as CSV for easy viewing
submission_df = pd.DataFrame(submission, columns=submission_columns)
submission_csv = os.path.join(data_dir, 'submission_baseline_cnn.csv')
submission_df.to_csv(submission_csv, index=False)
print(f"Submission CSV saved to: {submission_csv}")

# Display first few rows
print("\nFirst 10 predictions:")
print(submission_df.head(10))

## 12. Save Submission File

In [None]:
# Run the network over test_loader in eval mode and collect, per sample,
# the two point estimates and the two predicted standard deviations.
model.eval()
test_predictions = []
test_uncertainties = []

print("Generating test set predictions...")
with torch.no_grad():
    # The loader also yields the dummy targets; they are ignored here.
    for batch_idx, (data, _dummy_target) in enumerate(test_loader):
        data = data.to(device)
        output = model(data)

        predictions = output[:, :2].cpu().numpy()  # [Omega_m, S_8]
        log_vars = output[:, 2:].cpu().numpy()  # [log_var_Omega_m, log_var_S_8]
        uncertainties = np.sqrt(np.exp(log_vars))  # log-variance -> std dev

        test_predictions.append(predictions)
        test_uncertainties.append(uncertainties)

        if (batch_idx + 1) % 50 == 0:
            # Cap the counter: the last batch may hold fewer than batch_size samples
            n_done = min((batch_idx + 1) * batch_size, total_test_samples)
            print(f"  Processed {n_done}/{total_test_samples} samples")

# Stack the per-batch arrays into (n_samples, 2) arrays
test_predictions = np.vstack(test_predictions)
test_uncertainties = np.vstack(test_uncertainties)

print(f"\nTest predictions shape: {test_predictions.shape}")
print(f"Test uncertainties shape: {test_uncertainties.shape}")

# Extract parameters
test_omega_m = test_predictions[:, 0]
test_s_8 = test_predictions[:, 1]
test_sigma_omega_m = test_uncertainties[:, 0]
test_sigma_s_8 = test_uncertainties[:, 1]

print("\nTest Set Predictions Summary:")
print(f"Ω_m: mean={test_omega_m.mean():.4f}, std={test_omega_m.std():.4f}, range=[{test_omega_m.min():.4f}, {test_omega_m.max():.4f}]")
print(f"S_8: mean={test_s_8.mean():.4f}, std={test_s_8.std():.4f}, range=[{test_s_8.min():.4f}, {test_s_8.max():.4f}]")
print(f"σ_Ω_m: mean={test_sigma_omega_m.mean():.6f}, median={np.median(test_sigma_omega_m):.6f}")
print(f"σ_S_8: mean={test_sigma_s_8.mean():.6f}, median={np.median(test_sigma_s_8):.6f}")

In [None]:
# Bring the raw test array into the same (N, H, W) image layout as the
# training data and wrap it in a non-shuffled DataLoader.
print(f"Test data shape: {kappa_test.shape}")

# A 3-D array has its first two axes merged into one sample axis
# (presumably cosmology x realisation, as for the training data — confirm).
# Any other shape is used as-is, one flattened map per row.
if kappa_test.ndim == 3:
    n_test_cosmo, n_test_samples, n_pixels = kappa_test.shape
    total_test_samples = n_test_cosmo * n_test_samples
    X_test_flat = kappa_test.reshape(total_test_samples, -1)
else:
    X_test_flat = kappa_test
    total_test_samples = X_test_flat.shape[0]
print(f"Total test samples: {total_test_samples}")

print(f"X_test_flat shape: {X_test_flat.shape}")

# Reconstruct 2D images for test set; the image size follows the mask
# instead of being hard-coded.
print(f"Reconstructing {total_test_samples} test images...")
test_height, test_width = mask.shape
X_test_2d = np.zeros((total_test_samples, test_height, test_width), dtype=np.float32)
for i in range(total_test_samples):
    X_test_2d[i] = reconstruct_image(X_test_flat[i], mask)
    if (i + 1) % 1000 == 0:
        print(f"  Processed {i + 1}/{total_test_samples} images")

print(f"Test data 2D shape: {X_test_2d.shape}")

# The dataset class requires targets; zeros are passed as placeholders.
# shuffle=False keeps predictions in the same order as the test rows.
test_dataset = ConvergenceMapDataset(X_test_2d, np.zeros((total_test_samples, 2)))  # Dummy targets
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

print(f"Test batches: {len(test_loader)}")

## 11. Generate Predictions for Test Set

In [None]:
# Six-panel diagnostic figure for the validation set: true-vs-predicted
# scatters, joint distribution, error histograms and an uncertainty check.
fig, axes = plt.subplots(2, 3, figsize=(16, 10))

# Omega_m: True vs Predicted (red dashed line marks perfect agreement)
axes[0, 0].scatter(omega_m_true, omega_m_pred, alpha=0.5, s=20)
axes[0, 0].plot([omega_m_true.min(), omega_m_true.max()],
                [omega_m_true.min(), omega_m_true.max()], 'r--', lw=2)
axes[0, 0].set_xlabel('True Ω_m')
axes[0, 0].set_ylabel('Predicted Ω_m')
axes[0, 0].set_title('Ω_m: True vs Predicted')
axes[0, 0].grid(True, alpha=0.3)

# S_8: True vs Predicted
axes[0, 1].scatter(s_8_true, s_8_pred, alpha=0.5, s=20, color='orange')
axes[0, 1].plot([s_8_true.min(), s_8_true.max()],
                [s_8_true.min(), s_8_true.max()], 'r--', lw=2)
axes[0, 1].set_xlabel('True S_8')
axes[0, 1].set_ylabel('Predicted S_8')
axes[0, 1].set_title('S_8: True vs Predicted')
axes[0, 1].grid(True, alpha=0.3)

# Joint distribution
axes[0, 2].scatter(omega_m_true, s_8_true, alpha=0.3, s=20, label='True', color='blue')
axes[0, 2].scatter(omega_m_pred, s_8_pred, alpha=0.3, s=20, label='Predicted', color='red')
axes[0, 2].set_xlabel('Ω_m')
axes[0, 2].set_ylabel('S_8')
axes[0, 2].set_title('Joint Distribution')
axes[0, 2].legend()
axes[0, 2].grid(True, alpha=0.3)

# Omega_m: Error distribution.
# The title uses a real newline ("\n"); the escaped "\\n" used before was
# drawn as the two literal characters backslash + n.
axes[1, 0].hist(error_omega_m, bins=50, alpha=0.7, edgecolor='black')
axes[1, 0].axvline(0, color='red', linestyle='--', lw=2)
axes[1, 0].set_xlabel('Error (Predicted - True)')
axes[1, 0].set_ylabel('Frequency')
axes[1, 0].set_title(f'Ω_m Error Distribution\nMean: {error_omega_m.mean():.6f}, Std: {error_omega_m.std():.6f}')
axes[1, 0].grid(True, alpha=0.3, axis='y')

# S_8: Error distribution
axes[1, 1].hist(error_s_8, bins=50, alpha=0.7, edgecolor='black', color='orange')
axes[1, 1].axvline(0, color='red', linestyle='--', lw=2)
axes[1, 1].set_xlabel('Error (Predicted - True)')
axes[1, 1].set_ylabel('Frequency')
axes[1, 1].set_title(f'S_8 Error Distribution\nMean: {error_s_8.mean():.6f}, Std: {error_s_8.std():.6f}')
axes[1, 1].grid(True, alpha=0.3, axis='y')

# Uncertainty calibration: |error| against predicted sigma, with the
# y = x reference line drawn up to the largest predicted sigma.
sigma_max = max(sigma_omega_m.max(), sigma_s_8.max())
axes[1, 2].scatter(sigma_omega_m, np.abs(error_omega_m), alpha=0.5, s=20, label='Ω_m')
axes[1, 2].scatter(sigma_s_8, np.abs(error_s_8), alpha=0.5, s=20, label='S_8', color='orange')
axes[1, 2].plot([0, sigma_max], [0, sigma_max], 'r--', lw=2, label='Perfect calibration')
axes[1, 2].set_xlabel('Predicted Uncertainty (σ)')
axes[1, 2].set_ylabel('Absolute Error')
axes[1, 2].set_title('Uncertainty Calibration')
axes[1, 2].legend()
axes[1, 2].grid(True, alpha=0.3)

plt.suptitle('Model Performance on Validation Set', fontsize=14, y=1.00)
plt.tight_layout()
plt.savefig(os.path.join(data_dir, 'model_performance.png'), dpi=150, bbox_inches='tight')
plt.show()

print("Performance plots saved!")

## 10. Visualize Predictions

In [None]:
# Split the stacked validation outputs into per-parameter vectors and print
# regression metrics for each parameter.
# mean_squared_error, mean_absolute_error and r2_score come from the
# notebook's import cell.
omega_m_pred = all_predictions[:, 0]
s_8_pred = all_predictions[:, 1]
omega_m_true = all_targets[:, 0]
s_8_true = all_targets[:, 1]
sigma_omega_m = all_uncertainties[:, 0]
sigma_s_8 = all_uncertainties[:, 1]

# Signed errors (predicted - true); reused by the plotting cell
error_omega_m = omega_m_pred - omega_m_true
error_s_8 = s_8_pred - s_8_true


def _report_parameter(title, y_true, y_pred, sigma):
    """Print MSE, RMSE, MAE, R² and the mean/median predicted sigma for one parameter.

    Args:
        title: Heading printed above the metrics.
        y_true: 1-D array of true values.
        y_pred: 1-D array of predicted values.
        sigma: 1-D array of predicted standard deviations.
    """
    mse = mean_squared_error(y_true, y_pred)  # computed once, reused for the RMSE
    print(f"\n{title}:")
    print(f"  MSE:  {mse:.6f}")
    print(f"  RMSE: {np.sqrt(mse):.6f}")
    print(f"  MAE:  {mean_absolute_error(y_true, y_pred):.6f}")
    print(f"  R²:   {r2_score(y_true, y_pred):.6f}")
    print(f"  Mean predicted uncertainty: {sigma.mean():.6f}")
    print(f"  Median predicted uncertainty: {np.median(sigma):.6f}")


print("="*70)
print("VALIDATION SET PERFORMANCE")
print("="*70)

_report_parameter("Ω_m (Matter Density Fraction)", omega_m_true, omega_m_pred, sigma_omega_m)
_report_parameter("S_8 (Matter Fluctuation Amplitude)", s_8_true, s_8_pred, sigma_s_8)

print("="*70)

In [None]:
# Restore the best checkpoint written by the training loop and collect the
# model outputs for the whole validation set.
# map_location=device lets a checkpoint saved on a GPU be loaded on a
# CPU-only machine (and vice versa).
checkpoint = torch.load(os.path.join(data_dir, 'best_model.pth'), map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# 'epoch' is stored zero-based by the training loop, hence the +1
print(f"Loaded best model from epoch {checkpoint['epoch']+1}")

# Get predictions on validation set
model.eval()
all_predictions = []
all_targets = []
all_uncertainties = []

with torch.no_grad():
    for data, target in val_loader:
        data = data.to(device)
        output = model(data)

        predictions = output[:, :2].cpu().numpy()  # [Omega_m, S_8]
        log_vars = output[:, 2:].cpu().numpy()  # [log_var_Omega_m, log_var_S_8]
        uncertainties = np.sqrt(np.exp(log_vars))  # log-variance -> std dev

        all_predictions.append(predictions)
        all_targets.append(target.numpy())  # targets were never moved off the CPU
        all_uncertainties.append(uncertainties)

# Stack the per-batch arrays into (n_val, 2) arrays
all_predictions = np.vstack(all_predictions)
all_targets = np.vstack(all_targets)
all_uncertainties = np.vstack(all_uncertainties)

print(f"Predictions shape: {all_predictions.shape}")
print(f"Targets shape: {all_targets.shape}")
print(f"Uncertainties shape: {all_uncertainties.shape}")

## 9. Evaluate Model Performance

In [None]:
# Two side-by-side panels built from `history`: the training objective on the
# left and the plain MSE on the right, each with train and validation curves.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (train key, validation key, y-axis label, title) for each panel
curve_panels = [
    ('train_loss', 'val_loss', 'KL Divergence Loss', 'Training and Validation Loss'),
    ('train_mse', 'val_mse', 'Mean Squared Error', 'Training and Validation MSE'),
]
for curve_ax, (train_key, val_key, y_label, panel_title) in zip(axes, curve_panels):
    curve_ax.plot(history[train_key], label='Train', marker='o')
    curve_ax.plot(history[val_key], label='Validation', marker='s')
    curve_ax.set_xlabel('Epoch')
    curve_ax.set_ylabel(y_label)
    curve_ax.set_title(panel_title)
    curve_ax.legend()
    curve_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(data_dir, 'training_history.png'), dpi=150, bbox_inches='tight')
plt.show()

print("Training history plots saved!")

## 8. Plot Training History

In [None]:
# Main optimisation loop: one training pass and one validation pass per
# epoch, LR scheduling on the validation loss, checkpointing whenever the
# validation loss improves, and patience-based early stopping.
best_model_path = os.path.join(data_dir, 'best_model.pth')

for epoch in range(num_epochs):
    train_loss, train_mse = train_epoch(model, train_loader, optimizer, device)
    val_loss, val_mse = validate(model, val_loader, device)

    # ReduceLROnPlateau needs the monitored metric
    scheduler.step(val_loss)

    # Save history
    history['train_loss'].append(train_loss)
    history['train_mse'].append(train_mse)
    history['val_loss'].append(val_loss)
    history['val_mse'].append(val_mse)

    # Print progress
    print(f"Epoch {epoch+1:02d}/{num_epochs} | "
          f"Train Loss: {train_loss:.4f} | Train MSE: {train_mse:.6f} | "
          f"Val Loss: {val_loss:.4f} | Val MSE: {val_mse:.6f}")

    # Save best model (epoch is stored zero-based)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
        }, best_model_path)
        print(f"  → Best model saved! (Val Loss: {val_loss:.4f})")
        patience_counter = 0
    else:
        patience_counter += 1

    # Early stopping after `patience` consecutive epochs without improvement
    if patience_counter >= patience:
        print(f"\nEarly stopping triggered after {epoch+1} epochs!")
        break

print("-" * 70)
print("Training completed!")
print(f"Best validation loss: {best_val_loss:.4f}")

In [None]:
def train_epoch(model, train_loader, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    The model is put in training mode and updated once per batch using
    ``kl_divergence_loss``; ``mse_loss`` is tracked for reporting only.

    The returned values are per-sample averages: each batch's mean is weighted
    by its number of samples, so a smaller final batch does not skew the
    epoch metric.

    Args:
        model: Network to train; called as ``model(data)``.
        train_loader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(avg_loss, avg_mse)`` of Python floats.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    mse_sum = 0.0
    n_seen = 0

    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)

        # Compute losses
        loss = kl_divergence_loss(output, target)
        mse = mse_loss(output, target)

        loss.backward()
        optimizer.step()

        # The losses are batch means; scale back to sums before accumulating
        n_batch = data.size(0)
        loss_sum += loss.item() * n_batch
        mse_sum += mse.item() * n_batch
        n_seen += n_batch

    if n_seen == 0:
        raise ValueError("train_loader yielded no samples")

    avg_loss = loss_sum / n_seen
    avg_mse = mse_sum / n_seen

    return avg_loss, avg_mse

def validate(model, val_loader, device):
    """Evaluate the model on ``val_loader`` without updating it.

    The model is put in eval mode and gradients are disabled. The returned
    values are per-sample averages: each batch's mean is weighted by its
    number of samples, so a smaller final batch does not skew the metric.

    Args:
        model: Network to evaluate; called as ``model(data)``.
        val_loader: Iterable yielding ``(data, target)`` batches.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(avg_loss, avg_mse)`` of Python floats.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    mse_sum = 0.0
    n_seen = 0

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            loss = kl_divergence_loss(output, target)
            mse = mse_loss(output, target)

            # The losses are batch means; scale back to sums before accumulating
            n_batch = data.size(0)
            loss_sum += loss.item() * n_batch
            mse_sum += mse.item() * n_batch
            n_seen += n_batch

    if n_seen == 0:
        raise ValueError("val_loader yielded no samples")

    avg_loss = loss_sum / n_seen
    avg_mse = mse_sum / n_seen

    return avg_loss, avg_mse

# Optimizer, LR schedule, early-stopping settings and the history container
# consumed by the training loop.
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
# Halve the learning rate after 3 epochs without improvement of the monitored
# value. `verbose=True` is no longer passed: the argument is deprecated in
# recent PyTorch releases. Use scheduler.get_last_lr() if the current
# learning rate needs to be logged.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

# Training configuration
num_epochs = 20
best_val_loss = float('inf')  # any first validation loss counts as an improvement
patience = 7                  # epochs without improvement before early stopping
patience_counter = 0

# Per-epoch metrics, appended to by the training loop
history = {
    'train_loss': [],
    'train_mse': [],
    'val_loss': [],
    'val_mse': []
}

print(f"Starting training for {num_epochs} epochs...")
print(f"Learning rate: {learning_rate}")
print(f"Batch size: {batch_size}")
print("Optimizer: Adam")
print(f"Device: {device}")
print("-" * 70)

## 7. Training Loop

In [None]:
def kl_divergence_loss(predictions, targets):
    """
    Uncertainty-aware regression loss over the two predicted parameters.

    For each parameter the per-sample term is
    ``(pred - true)^2 / var + log(var)`` with ``var = exp(log_var)``, which is
    the Gaussian negative log-likelihood up to a constant factor and offset.
    The two terms are added and averaged over the batch.

    Args:
        predictions: Tensor of shape (batch_size, 4)
                    [Omega_m_pred, S_8_pred, log_var_Omega_m, log_var_S_8]
        targets: Tensor of shape (batch_size, 2) [Omega_m_true, S_8_true]

    Returns:
        loss: Scalar tensor
    """
    def parameter_term(col):
        # Column `col` holds the point estimate, column `col + 2` its log-variance
        log_var = predictions[:, col + 2]
        residual_sq = (predictions[:, col] - targets[:, col]) ** 2
        # The network outputs log-variance; exponentiate to get the variance
        return residual_sq / torch.exp(log_var) + log_var

    per_sample = parameter_term(0) + parameter_term(1)
    return torch.mean(per_sample)

def mse_loss(predictions, targets):
    """Batch mean of the summed squared errors of the two point estimates.

    Only columns 0 and 1 of ``predictions`` and ``targets`` are used; any
    further prediction columns are ignored.
    """
    sq_err_omega_m = (predictions[:, 0] - targets[:, 0]) ** 2
    sq_err_s_8 = (predictions[:, 1] - targets[:, 1]) ** 2
    return torch.mean(sq_err_omega_m + sq_err_s_8)

# Smoke test of the loss on random tensors: 32 rows with 4 prediction columns
# and 2 target columns, matching the shapes documented in kl_divergence_loss.
# The two torch.randn calls draw from the global RNG, so their order matters
# for run-to-run reproducibility.
# NOTE(review): `test_predictions` is reused later in the notebook for the
# real test-set predictions; the name here is unrelated to that data.
test_predictions = torch.randn(32, 4)
test_targets = torch.randn(32, 2)
test_loss = kl_divergence_loss(test_predictions, test_targets)
print(f"Test KL loss: {test_loss.item():.4f}")

## 6. Define KL Divergence Loss Function

In [None]:
class WeakLensingCNN(nn.Module):
    """
    Convolutional regressor for weak lensing convergence maps.

    Four convolutional stages (Conv2d -> BatchNorm2d -> ReLU -> pooling) are
    followed by a three-layer fully connected head. The last stage pools
    adaptively to 4x4, so the head always receives 256 * 4 * 4 features
    regardless of the input image size.

    Outputs: [Omega_m, S_8, log_var_Omega_m, log_var_S_8]
    """

    def __init__(self, input_channels=1):
        super().__init__()

        def make_stage(in_ch, out_ch, kernel, stride, pad, pool):
            # One feature-extraction stage; layer order fixes the state_dict keys
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=pad),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                pool,
            )

        # Stages are created in order conv1..conv4, then the head
        self.conv1 = make_stage(input_channels, 32, 5, 2, 2, nn.MaxPool2d(2))
        self.conv2 = make_stage(32, 64, 3, 1, 1, nn.MaxPool2d(2))
        self.conv3 = make_stage(64, 128, 3, 1, 1, nn.MaxPool2d(2))
        # Adaptive pooling gives a fixed 4x4 spatial output
        self.conv4 = make_stage(128, 256, 3, 1, 1, nn.AdaptiveAvgPool2d((4, 4)))

        # Number of features entering the head: channels * 4 * 4
        self.flat_features = 256 * 4 * 4

        # Regression head with dropout between the dense layers
        head_layers = [
            nn.Linear(self.flat_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 4),  # Output: [Omega_m, S_8, log_var_Omega_m, log_var_S_8]
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map a batch of images (N, C, H, W) to a (N, 4) output tensor."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        flat = x.view(x.size(0), -1)  # Flatten everything but the batch axis
        return self.fc(flat)

# Instantiate the single-channel network on the selected device
model = WeakLensingCNN(input_channels=1).to(device)

# Parameter bookkeeping: overall element count and the subset that
# has requires_grad set
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)

print(f"Model: WeakLensingCNN")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"\nModel architecture:")
print(model)

## 5. Define CNN Architecture

In [None]:
class ConvergenceMapDataset(Dataset):
    """In-memory dataset pairing convergence-map images with regression targets.

    Images and targets are held as float32 tensors. ``torch.as_tensor`` is
    used for the conversion, so an input array that is already float32 is
    wrapped without being copied (the tensor then shares memory with the
    array); other dtypes are converted, which copies.
    """

    def __init__(self, images, targets, transform=None):
        """
        Args:
            images: numpy array of shape (N, H, W)
            targets: numpy array of shape (N, 2) - [Omega_m, S_8]
            transform: optional callable applied to each image tensor
                (shape (1, H, W)) when it is fetched
        """
        # Add channel dim: (N, H, W) -> (N, 1, H, W); unsqueeze returns a view
        self.images = torch.as_tensor(images, dtype=torch.float32).unsqueeze(1)
        self.targets = torch.as_tensor(targets, dtype=torch.float32)
        self.transform = transform

    def __len__(self):
        """Number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``, transform applied if set."""
        image = self.images[idx]
        target = self.targets[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, target

# Hold out 15% of the maps for validation (fixed random_state for a
# repeatable split).
# NOTE(review): the split is made per map, not per cosmology; if maps from one
# cosmology share their labels, validation scores may be optimistic — confirm
# whether a grouped split is wanted.
X_train, X_val, y_train_split, y_val = train_test_split(
    X_train_2d,
    y_train,
    test_size=0.15,
    random_state=42,
)

print(f"Training set: {X_train.shape[0]} samples")
print(f"Validation set: {X_val.shape[0]} samples")

# Wrap both splits as datasets
train_dataset, val_dataset = (
    ConvergenceMapDataset(X_train, y_train_split),
    ConvergenceMapDataset(X_val, y_val),
)

# Loaders share batch size and worker settings; only training data is shuffled
batch_size = 32
loader_options = {'batch_size': batch_size, 'num_workers': 0}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

print(f"\nBatch size: {batch_size}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

# Pull one batch to check the tensor shapes
sample_batch, sample_targets = next(iter(train_loader))
print(f"\nSample batch shape: {sample_batch.shape}")  # (batch_size, 1, 1424, 176)
print(f"Sample targets shape: {sample_targets.shape}")  # (batch_size, 2)

## 4. Create PyTorch Dataset and DataLoader

In [None]:
# The data is currently flattened. We need to reconstruct the 2D images using the mask
# Mask shape is (1424, 176), and flattened valid pixels = 132019

# Count the pixels selected by the mask (mask > 0)
valid_mask = (mask > 0).flatten()
n_valid_pixels = valid_mask.sum()
print(f"Number of valid pixels in mask: {n_valid_pixels}")
print(f"Expected from data: {X_train_flat.shape[1]}")

# Fail here with a clear message rather than later with a shape error when
# the flattened maps are written back into the mask positions.
assert n_valid_pixels == X_train_flat.shape[1], (
    f"Mask selects {n_valid_pixels} pixels but each flattened map has "
    f"{X_train_flat.shape[1]} values"
)

# Create function to reconstruct 2D images from flattened data
def reconstruct_image(flat_data, mask):
    """
    Scatter a flattened map back onto the 2D grid defined by ``mask``.

    flat_data: 1-D array with one value per selected mask pixel
    mask: (H, W) array; pixels with mask > 0 are filled in row-major order
    returns: (H, W) image with flat_data's dtype, zero where the mask is not set
    """
    n_rows, n_cols = mask.shape
    selected = mask > 0
    canvas = np.zeros((n_rows, n_cols), dtype=flat_data.dtype)
    canvas[selected] = flat_data
    return canvas

# Reconstruct a sample image to verify
sample_idx = 0
sample_2d = reconstruct_image(X_train_flat[sample_idx], mask)
print(f"\nReconstructed image shape: {sample_2d.shape}")
print(f"Reconstructed image stats: min={sample_2d.min():.4f}, max={sample_2d.max():.4f}, mean={sample_2d.mean():.4f}")

# Visualize sample.
# The title uses a real newline ("\n"); the escaped "\\n" used before was
# drawn as the two literal characters backslash + n.
plt.figure(figsize=(10, 6))
plt.imshow(sample_2d, cmap='RdBu_r', aspect='auto')
plt.colorbar(label='κ (convergence)')
plt.title(f'Sample Convergence Map\nΩ_m={y_omega_m[sample_idx]:.4f}, S_8={y_s_8[sample_idx]:.4f}')
plt.xlabel('Width (pixels)')
plt.ylabel('Height (pixels)')
plt.tight_layout()
plt.show()

# Reconstruct all training images into one float32 array of shape
# (total_samples, H, W). The image size follows the mask instead of being
# hard-coded.
print(f"\nReconstructing all {total_samples} training images...")
train_height, train_width = mask.shape
X_train_2d = np.zeros((total_samples, train_height, train_width), dtype=np.float32)
for i in range(total_samples):
    X_train_2d[i] = reconstruct_image(X_train_flat[i], mask)
    if (i + 1) % 5000 == 0:
        print(f"  Processed {i + 1}/{total_samples} images")

print(f"Final training data shape: {X_train_2d.shape}")
print(f"Memory usage: {X_train_2d.nbytes / (1024**3):.2f} GB")

## 3. Reshape Data to 2D Images

In [None]:
# Data paths. The folder defaults to the original local location; set the
# WL_DATA_DIR environment variable to run the notebook on another machine
# without editing this cell.
data_dir = os.environ.get('WL_DATA_DIR', r'c:\ML\Challenges\NeurIPS_2025')
label_file = os.path.join(data_dir, 'label.npy')
kappa_file = os.path.join(data_dir, 'WIDE12H_bin2_2arcmin_kappa.npy')
kappa_test_file = os.path.join(data_dir, 'WIDE12H_bin2_2arcmin_kappa_noisy_test.npy')
mask_file = os.path.join(data_dir, 'WIDE12H_bin2_2arcmin_mask.npy')

print("Loading data...")
# Load data (the shapes in the comments are the expected ones; the prints
# below show the actual shapes)
labels = np.load(label_file)  # Shape: (101, 256, 5)
kappa_train = np.load(kappa_file)  # Shape: (101, 256, 132019)
kappa_test = np.load(kappa_test_file)  # Shape: (?, ?, 132019)
mask = np.load(mask_file)  # Shape: (1424, 176)

print(f"Labels shape: {labels.shape}")
print(f"Training kappa shape: {kappa_train.shape}")
print(f"Test kappa shape: {kappa_test.shape}")
print(f"Mask shape: {mask.shape}")

# Pull the two regression targets out of the label array.
# Label layout along the last axis: 0 = Omega_m, 1 = S_8, 2+ = nuisance params
omega_m, s_8 = labels[..., 0], labels[..., 1]  # each (101, 256)

# Merge the cosmology and sample axes into a single sample axis:
# (101, 256) -> (101*256,)
n_cosmologies, n_samples_per_cosmo = omega_m.shape
total_samples = n_cosmologies * n_samples_per_cosmo

X_train_flat = kappa_train.reshape(total_samples, -1)  # (25856, 132019)
y_omega_m, y_s_8 = omega_m.flatten(), s_8.flatten()  # each (25856,)

print(f"\nTotal training samples: {total_samples}")
print(f"X_train shape: {X_train_flat.shape}")
print(f"Y Omega_m shape: {y_omega_m.shape}")
print(f"Y S_8 shape: {y_s_8.shape}")

# One row per sample, columns [Omega_m, S_8]
y_train = np.stack([y_omega_m, y_s_8], axis=1)  # (25856, 2)
print(f"Y_train (combined) shape: {y_train.shape}")

print("\nTarget statistics:")
print(f"Omega_m: mean={y_omega_m.mean():.4f}, std={y_omega_m.std():.4f}, range=[{y_omega_m.min():.4f}, {y_omega_m.max():.4f}]")
print(f"S_8: mean={y_s_8.mean():.4f}, std={y_s_8.std():.4f}, range=[{y_s_8.min():.4f}, {y_s_8.max():.4f}]")

## 2. Load and Preprocess Data

In [None]:
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader
# Suppress all warnings for the rest of the session
warnings.filterwarnings('ignore')

# Seed NumPy and PyTorch with the same fixed value for reproducibility
for seed_fn in (np.random.seed, torch.manual_seed):
    seed_fn(42)

# Use the GPU when one is available (seeding CUDA as well), else the CPU
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")
print(f"NumPy version: {np.__version__}")

## 1. Import Libraries and Setup

# NeurIPS 2025 Weak Lensing Challenge - Baseline CNN Model

This notebook implements a baseline Convolutional Neural Network (CNN) for predicting cosmological parameters ($\Omega_m$, $S_8$) and their uncertainties from weak lensing convergence maps.

## Approach: CNN Direct Prediction with KL Divergence Loss

The model predicts:
- **Point estimates**: ($\hat{\Omega}_m$, $\hat{S}_8$)
- **Uncertainties**: ($\hat{\sigma}_{\Omega_m}$, $\hat{\sigma}_{S_8}$)

Training uses KL divergence loss to optimize both predictions and uncertainties simultaneously.