In [None]:
import torch

# --- Hyperparameters ---
subset_hyperparam = None  # None means all samples
learning_rate = 1e-3
batch_size = 128
epochs = 30
# NOTE(review): StellaratorNet as written contains no dropout layer, so this
# value is not used anywhere visible — TODO wire it into the model or remove it.
dropout_prob = 0.2
weight_decay = 1e-4
SEED = 42

# --- Reproducibility ---
# Seed torch's generators (all devices) up front so weight init and shuffling
# are reproducible under Restart & Run All.
torch.manual_seed(SEED)

# --- Device selection ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Using device:", device)

Using device: cuda


In [None]:
import torch.nn as nn

class StellaratorNet(nn.Module):
    """
    Small fully connected regressor: input_dim -> 128 -> 64 -> 1.

    Hidden layers use ReLU and the final Softplus keeps the single output
    positive. All layers live in ``self.model`` (an ``nn.Sequential``), so the
    state_dict keys are ``model.0.*``, ``model.2.*`` and ``model.4.*``.
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_widths = (128, 64)

        layers = []
        fan_in = input_dim
        for width in hidden_widths:
            layers += [nn.Linear(fan_in, width), nn.ReLU()]
            fan_in = width
        # Single regression output, squashed to be positive.
        layers += [nn.Linear(fan_in, 1), nn.Softplus()]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x`` (last dimension ``input_dim``); output's last dimension is 1."""
        return self.model(x)

In [None]:
import numpy as np
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import math


def is_valid(example):
    """
    Check whether an example has every field the model needs, with usable values.

    Scalar fields must be finite ints/floats. List fields (the boundary
    coefficient arrays) must be non-empty lists of non-empty lists whose
    entries are all finite ints/floats. Any other case (missing key, None,
    NaN, +/-inf, strings, empty rows, non-numeric entries) is invalid.

    Args:
        example (dict): one raw record, keyed by flattened column names.

    Returns:
        bool: True if the example can be used for training/optimization.
    """
    required_metrics = [
        'metrics.max_elongation',
        'metrics.average_triangularity',
        'metrics.edge_rotational_transform_over_n_field_periods',
        'boundary.n_field_periods',
        'metrics.aspect_ratio',
        'boundary.z_sin',
        'boundary.r_cos'
    ]
    for key in required_metrics:
        val = example.get(key, None)
        if val is None:
            return False
        if isinstance(val, (int, float)):
            # isfinite rejects NaN *and* +/-inf; either would poison the MSE loss.
            if not math.isfinite(val):
                return False
        elif isinstance(val, list):
            if not val:
                return False
            if not all(isinstance(row, list) and len(row) > 0 for row in val):
                return False
            # Coefficient arrays are fed to the network directly, so every
            # entry must be a finite number (not None / NaN / inf / str).
            if not all(
                isinstance(x, (int, float)) and math.isfinite(x)
                for row in val
                for x in row
            ):
                return False
        else:
            return False
    return True


def load_constellaration_dataset(subset_hyperparam=None, split_ratio=0.8, seed=42):
    """
    Download the ConStellaration "train" split, drop invalid rows and split it randomly.

    Args:
        subset_hyperparam (int | None): if given, keep only the first N valid
            examples (before shuffling).
        split_ratio (float): fraction of examples assigned to the train part.
        seed (int): seed for the shuffle.

    Returns:
        (train_dataset, test_dataset): HuggingFace datasets selected by index.

    Side effect:
        Calls ``torch.manual_seed(seed)``, which reseeds torch's global RNG.
    """
    valid = load_dataset("proxima-fusion/constellaration", "default")["train"].filter(is_valid)

    if subset_hyperparam is not None:
        valid = valid.select(range(min(subset_hyperparam, len(valid))))

    n_total = len(valid)
    torch.manual_seed(seed)
    # A random permutation of 0..n-1 is already a shuffled index list.
    order = torch.randperm(n_total)

    n_train = int(split_ratio * n_total)
    train_part = valid.select(order[:n_train].tolist())
    test_part = valid.select(order[n_train:].tolist())
    return train_part, test_part


class ConstellarationDataset(Dataset):
    """
    In-memory PyTorch dataset built from ConStellaration-style records.

    Each record is converted once, at construction time, into float32 numpy
    arrays. These are the flattened ``boundary.r_cos`` and ``boundary.z_sin``
    coefficients, which later optimization adjusts. There is also a 3-vector of
    fixed metrics: aspect ratio, average triangularity, and edge rotational
    transform over n field periods. The regression target is
    ``metrics.max_elongation``.
    """

    def __init__(self, hf_dataset, separate_optimizable=False):
        """
        Args:
            hf_dataset: iterable of dict-like records (e.g. a HuggingFace dataset).
            separate_optimizable: if True, items are
                (optimizable_features, fixed_features, target); if False,
                items are (all_features, target) as used for training.

        Raises:
            ValueError: if no record could be converted.
        """
        self.separate_optimizable = separate_optimizable
        self.data = []

        for record in hf_dataset:
            try:
                entry = self._record_to_arrays(record)
            except Exception as e:
                # Best effort: a malformed record is reported and skipped.
                print(f"Error processing sample: {e}")
                continue
            self.data.append(entry)

        if not self.data:
            raise ValueError("No valid samples found in dataset")

        # Feature widths are read from the first stored sample.
        first = self.data[0]
        self.r_cos_dim = len(first['r_cos'])
        self.z_sin_dim = len(first['z_sin'])
        self.fixed_dim = len(first['fixed'])

        print(f"Successfully loaded {len(self.data)} samples")

    @staticmethod
    def _record_to_arrays(record):
        """Convert one raw record into the dict of float32 arrays kept in ``self.data``."""
        r_cos = np.array(record["boundary.r_cos"], dtype=np.float32).flatten()
        z_sin = np.array(record["boundary.z_sin"], dtype=np.float32).flatten()
        fixed = np.array(
            [
                float(record["metrics.aspect_ratio"]),
                float(record["metrics.average_triangularity"]),
                float(record["metrics.edge_rotational_transform_over_n_field_periods"]),
            ],
            dtype=np.float32,
        )
        target = float(record["metrics.max_elongation"])
        return {'r_cos': r_cos, 'z_sin': z_sin, 'fixed': fixed, 'target': target}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        label = torch.tensor(entry['target'], dtype=torch.float32)

        if not self.separate_optimizable:
            # Training layout: [r_cos | z_sin | fixed] in one vector.
            joined = np.concatenate([entry['r_cos'], entry['z_sin'], entry['fixed']])
            return torch.from_numpy(joined), label

        # Optimization layout: coefficients and fixed metrics kept apart.
        coefficients = np.concatenate([entry['r_cos'], entry['z_sin']])
        return torch.from_numpy(coefficients), torch.from_numpy(entry['fixed']), label


def collate_fn(batch):
    """
    Stack a list of dataset items into batch tensors.

    Three-element items (optimizable, fixed, target) produce three stacked
    tensors. Any other item produces (features, targets). Targets are always
    reshaped to (batch, 1).
    """
    three_part = len(batch[0]) == 3
    columns = tuple(zip(*batch))

    if three_part:
        optimizable, fixed, targets = columns
        return (
            torch.stack(optimizable),
            torch.stack(fixed),
            torch.stack(targets).reshape(-1, 1),
        )

    features, targets = columns
    return torch.stack(features), torch.stack(targets).reshape(-1, 1)


# Convenience factory: shuffled DataLoader in training layout
def create_training_dataloader(train_dataset, batch_size=32):
    """
    Wrap a raw split in a training-mode ConstellarationDataset and a shuffled DataLoader.

    Batches are (features, targets) with r_cos, z_sin and fixed metrics concatenated.
    """
    wrapped = ConstellarationDataset(train_dataset, separate_optimizable=False)
    loader = DataLoader(wrapped, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    return loader


# Convenience factory: ordered DataLoader in optimization layout
def create_optimization_dataloader(test_dataset, batch_size=32):
    """
    Wrap a raw split in an optimization-mode ConstellarationDataset and an unshuffled DataLoader.

    Batches are (optimizable, fixed, targets), so the coefficients can be optimized
    while the fixed metrics are held constant.
    """
    wrapped = ConstellarationDataset(test_dataset, separate_optimizable=True)
    loader = DataLoader(wrapped, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    return loader

In [None]:
import torch.optim as optim
def train_model(model=None, train_loader=None, test_loader=None, num_epochs=epochs, lr=learning_rate,
                save_path="best_model.pth", weight_decay=weight_decay):
    """
    Train a regression model with MSE loss and Adam, keeping the best test-loss weights.

    Whenever the average test loss improves, the weights are written to
    ``save_path`` so later cells can reload them, and an in-memory copy is kept.
    That copy is restored at the end. The function therefore never reloads a
    stale checkpoint left over from an earlier run.

    Args:
        model (nn.Module | None): model to train. If None, a StellaratorNet is
            created, sized from the first item of ``train_loader.dataset``.
        train_loader (DataLoader): yields (features, targets) training batches.
        test_loader (DataLoader): yields (features, targets) evaluation batches.
        num_epochs (int): number of passes over ``train_loader``.
        lr (float): Adam learning rate.
        save_path (str): file the best weights are saved to.
        weight_decay (float): Adam weight decay; defaults to the configuration
            value of the same name (bound when this cell runs, like ``lr``).

    Returns:
        nn.Module: the model with its best-test-loss weights loaded. If no epoch
        produced an improvement (e.g. every test loss was NaN, or
        ``num_epochs`` was 0), the final weights are returned unchanged.

    Raises:
        ValueError: if ``train_loader`` or ``test_loader`` is missing.
    """
    if train_loader is None or test_loader is None:
        raise ValueError("train_loader and test_loader must both be provided")

    if model is None:
        # Size the network from the loader's own data instead of re-downloading
        # and re-filtering the whole dataset just to read one feature vector.
        input_dim = train_loader.dataset[0][0].shape[0]
        model = StellaratorNet(input_dim)
    # Batches are moved to `device` below, so the model must live there too.
    model = model.to(device)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    best_test_loss = float("inf")
    best_state = None

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        for features, targets in train_loader:
            features, targets = features.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch value is a per-sample mean.
            train_loss += loss.item() * features.size(0)

        train_loss /= len(train_loader.dataset)

        # --- Evaluation ---
        model.eval()
        test_loss = 0.0
        with torch.no_grad():
            for features, targets in test_loader:
                features, targets = features.to(device), targets.to(device)
                outputs = model(features)
                loss = criterion(outputs, targets)
                test_loss += loss.item() * features.size(0)

        test_loss /= len(test_loader.dataset)

        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f}")

        # NaN never compares less than anything, so a diverged epoch is never kept.
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            # state_dict() tensors share storage with the live parameters; clone them.
            best_state = {name: t.detach().clone() for name, t in model.state_dict().items()}
            torch.save(model.state_dict(), save_path)
            print(f"Saved best model with test loss {best_test_loss:.4f} at epoch {epoch+1}")

    if best_state is None:
        print("No improving epoch was recorded; returning the final weights unchanged")
        return model

    # Restore the best weights from memory rather than from disk.
    model.load_state_dict(best_state)
    return model


def get_dataloaders(batch_size=batch_size, separate_optimizable=False,
                    subset_hyperparam=subset_hyperparam, split_ratio=0.8, seed=42):
    """
    Get train and test dataloaders for training or optimization.

    Args:
        batch_size (int): batch size for both loaders.
        separate_optimizable (bool): if False, items are concatenated features
            for training. If True, items keep optimizable and fixed features
            separate, for optimization.
        subset_hyperparam (int | None): keep only the first N valid examples
            before splitting. Defaults to the configuration-cell value (bound
            when this cell runs, like ``batch_size``); None means all samples.
        split_ratio (float): fraction of examples placed in the train split.
        seed (int): seed used for the train/test shuffle.

    Returns:
        train_loader, test_loader: DataLoaders. The train loader is shuffled;
        the test loader keeps a fixed order.
    """
    train_data, test_data = load_constellaration_dataset(
        subset_hyperparam=subset_hyperparam, split_ratio=split_ratio, seed=seed
    )

    train_dataset = ConstellarationDataset(train_data, separate_optimizable=separate_optimizable)
    test_dataset = ConstellarationDataset(test_data, separate_optimizable=separate_optimizable)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    return train_loader, test_loader


def get_feature_dimensions():
    """
    Load a small sample of the data and report its feature widths.

    Only the first stored sample's dimensions are inspected, so a 100-example
    subset is enough.

    Returns:
        dict: keys r_cos_dim, z_sin_dim, fixed_dim and total_dim.
    """
    probe_split = load_constellaration_dataset(subset_hyperparam=100)[0]
    probe = ConstellarationDataset(probe_split, separate_optimizable=False)
    return get_feature_dimensions_from_dataset(probe)


def get_feature_dimensions_from_dataset(dataset):
    """
    Read feature widths off an already-built dataset, with no extra loading.

    Args:
        dataset: any object exposing ``r_cos_dim``, ``z_sin_dim`` and
            ``fixed_dim`` (e.g. a ConstellarationDataset).

    Returns:
        dict: r_cos_dim, z_sin_dim, fixed_dim and their sum as total_dim.
    """
    r_dim, z_dim, f_dim = dataset.r_cos_dim, dataset.z_sin_dim, dataset.fixed_dim
    return dict(
        r_cos_dim=r_dim,
        z_sin_dim=z_dim,
        fixed_dim=f_dim,
        total_dim=r_dim + z_dim + f_dim,
    )

In [None]:


# Build train/test loaders in training layout (features concatenated).
train_loader, test_loader = get_dataloaders()

# Size the network from the width of one training feature vector.
first_sample = train_loader.dataset[0]
input_dim = first_sample[0].shape[0]
model = StellaratorNet(input_dim).to(device)

# Fit; the best-test-loss weights are returned (and written to best_model.pth).
trained_model = train_model(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    num_epochs=epochs,
)

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00003.parquet:   0%|          | 0.00/252M [00:00<?, ?B/s]

data/train-00001-of-00003.parquet:   0%|          | 0.00/204M [00:00<?, ?B/s]

data/train-00002-of-00003.parquet:   0%|          | 0.00/155M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/182222 [00:00<?, ? examples/s]

Filter:   0%|          | 0/182222 [00:00<?, ? examples/s]

Successfully loaded 126948 samples
Successfully loaded 31737 samples
Epoch 1/30 | Train Loss: 2.1151 | Test Loss: 0.9890
Saved best model with test loss 0.9890 at epoch 1
Epoch 2/30 | Train Loss: 1.1409 | Test Loss: 0.7500
Saved best model with test loss 0.7500 at epoch 2
Epoch 3/30 | Train Loss: 0.9824 | Test Loss: 0.6399
Saved best model with test loss 0.6399 at epoch 3
Epoch 4/30 | Train Loss: 0.8785 | Test Loss: 0.5385
Saved best model with test loss 0.5385 at epoch 4
Epoch 5/30 | Train Loss: 0.7892 | Test Loss: 0.5136
Saved best model with test loss 0.5136 at epoch 5
Epoch 6/30 | Train Loss: 0.7201 | Test Loss: 0.4040
Saved best model with test loss 0.4040 at epoch 6
Epoch 7/30 | Train Loss: 0.6697 | Test Loss: 0.3582
Saved best model with test loss 0.3582 at epoch 7
Epoch 8/30 | Train Loss: 0.6353 | Test Loss: 0.3341
Saved best model with test loss 0.3341 at epoch 8
Epoch 9/30 | Train Loss: 0.5977 | Test Loss: 0.3159
Saved best model with test loss 0.3159 at epoch 9
Epoch 10/30 |

In [None]:
# --- Feature optimization: hyperparameters ---
iterations = 20   # Adam steps per test batch
lr_opti = 1e-3    # Adam learning rate for the boundary coefficients



def optimize_features(model_path, num_iterations=10, lr=0.01, batch_size=32, log_every=100):
    """
    Optimize r_cos and z_sin features to minimize the predicted target (max_elongation)
    while keeping the other input features fixed.

    The trained network is frozen. For each test batch, only a copy of the
    coefficients is updated by Adam, and the objective is the batch-mean prediction.

    NOTE(review): candidates are judged only by the surrogate network. Large
    predicted drops may come from pushing coefficients outside the training
    distribution. The "fixed" metrics are also not recomputed for the modified
    geometry (they presumably depend on it). Confirm results with an
    independent evaluation.

    Args:
        model_path (str): path to a StellaratorNet state_dict (as saved by train_model).
        num_iterations (int): number of Adam steps per batch.
        lr (float): learning rate for feature optimization.
        batch_size (int): number of test samples optimized together.
        log_every (int): print progress every this many iterations (0/None disables).

    Returns:
        list[dict]: one dict per batch with keys
            'r_cos', 'z_sin'       optimized coefficients (CPU tensors)
            'fixed'                the unchanged fixed features
            'original_target'      dataset max_elongation, shape (batch, 1)
            'initial_prediction'   batch-mean prediction before optimization (float)
            'final_prediction'     batch-mean prediction after the last step (float)
            'initial_predictions'  per-sample predictions before optimization
            'final_predictions'    per-sample predictions after the last step
    """
    # Same split (default seed/ratio) as training, in optimization layout.
    _, test_data = load_constellaration_dataset()
    test_dataset = ConstellarationDataset(test_data, separate_optimizable=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    r_cos_dim = test_dataset.r_cos_dim
    total_dim = test_dataset.r_cos_dim + test_dataset.z_sin_dim + test_dataset.fixed_dim

    # Load trained model
    model = StellaratorNet(total_dim).to(device)
    # map_location lets a checkpoint written on GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    # Freeze the weights. Otherwise every backward() also accumulates gradients
    # into the model parameters, which this optimizer never zeroes.
    model.requires_grad_(False)

    def predict(coefficients, fixed_features):
        # The model expects [r_cos | z_sin | fixed], matching the training layout.
        return model(torch.cat([coefficients, fixed_features], dim=1))

    all_results = []

    for batch_idx, (optimizable, fixed, targets) in enumerate(test_loader):
        optimizable = optimizable.to(device)
        fixed = fixed.to(device)
        targets = targets.to(device)

        # Leaf copy of the coefficients: the only tensor Adam updates.
        optimizable_params = optimizable.clone().detach().requires_grad_(True)
        optimizer = optim.Adam([optimizable_params], lr=lr)

        with torch.no_grad():
            initial_predictions = predict(optimizable_params, fixed)
        initial_loss = initial_predictions.mean().item()

        for iteration in range(num_iterations):
            optimizer.zero_grad()
            loss = predict(optimizable_params, fixed).mean()
            loss.backward()
            optimizer.step()

            if log_every and (iteration + 1) % log_every == 0:
                print(f"Batch {batch_idx+1}, Iteration {iteration+1}/{num_iterations}, Loss: {loss.item():.6f}")

        # Re-evaluate after the final step; the loop's `loss` was computed before it.
        with torch.no_grad():
            final_predictions = predict(optimizable_params, fixed)
        final_loss = final_predictions.mean().item()

        print(f"Batch {batch_idx+1} - Initial: {initial_loss:.6f}, Final: {final_loss:.6f}, "
              f"Improvement: {(initial_loss - final_loss):.6f}")

        optimized = optimizable_params.detach().cpu()
        all_results.append({
            'r_cos': optimized[:, :r_cos_dim],
            'z_sin': optimized[:, r_cos_dim:],
            'fixed': fixed.cpu(),
            'original_target': targets.cpu(),
            'initial_prediction': initial_loss,
            'final_prediction': final_loss,
            'initial_predictions': initial_predictions.cpu(),
            'final_predictions': final_predictions.cpu(),
        })

    return all_results



In [None]:
results = optimize_features(
    model_path="best_model.pth",
    num_iterations=iterations,
    lr=lr_opti,
    batch_size=batch_size,
)
# Persist the per-batch results for later analysis.
torch.save(results, "optimized_features.pt")
print("Saved optimized features to optimized_features.pt")

Successfully loaded 31737 samples
Batch 1 - Initial: 4.655895, Final: 0.632839, Improvement: 4.023057
Batch 2 - Initial: 5.056080, Final: 0.663935, Improvement: 4.392145
Batch 3 - Initial: 4.895375, Final: 0.674048, Improvement: 4.221327
Batch 4 - Initial: 4.889223, Final: 0.633327, Improvement: 4.255896
Batch 5 - Initial: 4.822328, Final: 0.616691, Improvement: 4.205637
Batch 6 - Initial: 4.919339, Final: 0.709233, Improvement: 4.210106
Batch 7 - Initial: 4.856818, Final: 0.693722, Improvement: 4.163096
Batch 8 - Initial: 4.537639, Final: 0.648044, Improvement: 3.889595
Batch 9 - Initial: 4.703805, Final: 0.685797, Improvement: 4.018008
Batch 10 - Initial: 4.811751, Final: 0.582422, Improvement: 4.229330
Batch 11 - Initial: 4.897969, Final: 0.675318, Improvement: 4.222652
Batch 12 - Initial: 4.668969, Final: 0.630481, Improvement: 4.038488
Batch 13 - Initial: 4.891409, Final: 0.666493, Improvement: 4.224915
Batch 14 - Initial: 4.763196, Final: 0.670775, Improvement: 4.092421
Batch 15 