In [None]:
import math

import torch
config = {
    "batch_size": 140,
    "n_theta": 16,
    "n_phi": 16,
    "mlp_input_dim": 3,
    "hidden_dim": 64,
    "num_outputs": 1,
    "learning_rate": 1e-3,
    "num_epochs": 5,
    "weight_decay": 1e-3,
    # Seed for torch's RNG (weight init, DataLoader shuffling) so that
    # "Restart & Run All" reproduces the reported losses.
    "seed": 42,
}

# Seed explicitly and early; torch.manual_seed seeds the generators of all devices.
torch.manual_seed(config["seed"])

# --- Device selection ---
# Prefer CUDA, then Apple MPS, then CPU. Older torch builds have no
# `torch.backends.mps`, so look the attribute up defensively.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Using device:", device)

Using device: cuda


In [None]:
import torch.nn as nn
import torch.nn.functional as F

class PeriodicConv2d(nn.Module):
    """
    Conv2d whose spatial padding wraps around (circular) on both axes.

    Meant for maps sampled on a doubly periodic (θ, φ) grid: at a border the
    kernel sees the cells from the opposite edge instead of zeros.

    Args:
        in_channels, out_channels, kernel_size, stride, bias: forwarded to the
            wrapped ``nn.Conv2d`` (which is always built with ``padding=0``).
        padding: number of cells wrapped onto each of the four sides before
            convolving; a value ``<= 0`` means no padding at all.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()
        # The inner conv never pads: wrapping is done by hand in forward().
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            bias=bias,
        )
        self.padding = padding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        p = self.padding
        # F.pad takes (left, right, top, bottom); the same width goes on every side.
        padded = x if p <= 0 else F.pad(x, (p,) * 4, mode='circular')
        return self.conv(padded)

class BoundaryCNN(nn.Module):
    """
    CNN for processing 2D boundary maps with multi-scale pooling.
    Takes a tuple of (R, Z) boundary tensors as input.
    """
    def __init__(self, n_theta: int, n_phi: int, num_outputs: int = 1, dropout: float = 0.2):
        super().__init__()
        self.conv1 = PeriodicConv2d(2, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = PeriodicConv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        self.conv3 = PeriodicConv2d(64, 128, kernel_size=3, padding=1, stride=2)
        self.bn3 = nn.BatchNorm2d(128)

        self.conv4 = PeriodicConv2d(128, 256, kernel_size=3, padding=1, stride=2)
        self.bn4 = nn.BatchNorm2d(256)

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.gmp = nn.AdaptiveMaxPool2d(1)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(256 * 2, num_outputs)

    def forward(self, boundary_tuple: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        R, Z = boundary_tuple
        if R.dim() == 2:
            R = R.unsqueeze(0)
            Z = Z.unsqueeze(0)

        x = torch.stack([R, Z], dim=1)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))

        x_avg = self.gap(x).view(x.size(0), -1)
        x_max = self.gmp(x).view(x.size(0), -1)
        x = torch.cat([x_avg, x_max], dim=1)
        x = self.dropout(x)
        return self.fc(x)

class MLP(nn.Module):
    """
    Four-layer perceptron: input_dim → hidden_dim → 2*hidden_dim → hidden_dim
    → output_dim.

    Each hidden layer is Linear → BatchNorm1d → ReLU → Dropout; the output
    layer is a plain Linear (no activation). Expects a 2-D ``(B, input_dim)``
    batch, as required by BatchNorm1d here.
    """
    def __init__(self, input_dim: int, hidden_dim: int = 64, output_dim: int = 1, dropout: float = 0.3):
        super().__init__()
        wide = hidden_dim * 2

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, wide)
        self.bn2 = nn.BatchNorm1d(wide)
        self.fc3 = nn.Linear(wide, hidden_dim)
        self.bn3 = nn.BatchNorm1d(hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, output_dim)
        # A single Dropout module is shared by all three hidden layers.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden_layers = ((self.fc1, self.bn1), (self.fc2, self.bn2), (self.fc3, self.bn3))
        for linear, norm in hidden_layers:
            x = self.dropout(F.relu(norm(linear(x))))
        return self.fc4(x)

class CombinedModel(nn.Module):
    """
    Two-branch regressor fusing boundary geometry with scalar features.

    A ``BoundaryCNN`` encodes the (R, Z) maps and an ``MLP`` encodes the
    scalar features; each branch emits ``hidden_dim`` values per sample. The
    concatenation (``2 * hidden_dim``) goes through
    Linear → ReLU → Dropout(0.3) → Linear → ReLU → Dropout(0.3) → Linear,
    narrowing to ``hidden_dim``, then ``hidden_dim // 2``, then
    ``num_outputs``. A softplus on the result keeps predictions positive.
    """
    def __init__(self, n_theta: int, n_phi: int, mlp_input_dim: int = 3, hidden_dim: int = 128, num_outputs: int = 1):
        super().__init__()
        # Both encoders produce `hidden_dim` features.
        self.boundary_cnn = BoundaryCNN(n_theta, n_phi, hidden_dim)
        self.mlp = MLP(mlp_input_dim, hidden_dim, hidden_dim)

        fused_dim = 2 * hidden_dim
        narrow_dim = hidden_dim // 2
        head_dropout = 0.3
        self.final_fc = nn.Sequential(
            nn.Linear(fused_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(head_dropout),
            nn.Linear(hidden_dim, narrow_dim),
            nn.ReLU(),
            nn.Dropout(head_dropout),
            nn.Linear(narrow_dim, num_outputs),
        )

    def forward(self, boundary_tuple: tuple[torch.Tensor, torch.Tensor], mlp_features: torch.Tensor) -> torch.Tensor:
        # CNN branch is evaluated before the MLP branch (tuple elements are
        # evaluated left to right).
        fused = torch.cat((self.boundary_cnn(boundary_tuple), self.mlp(mlp_features)), dim=1)
        raw = self.final_fc(fused)
        return F.softplus(raw)

In [None]:
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset



def _is_finite_value(val) -> bool:
    """True for a finite number / non-list value, or a non-empty (nested) list of such values."""
    if val is None:
        return False
    if isinstance(val, (list, tuple)):
        # Coefficient tables are nested lists: every entry must be present and finite.
        return len(val) > 0 and all(_is_finite_value(v) for v in val)
    if isinstance(val, float):
        return math.isfinite(val)
    return True


def is_valid(sample: dict) -> bool:
    """
    Return True when every field the pipeline reads is present and usable.

    Checked keys: the regression target ``metrics.max_elongation``, the three
    MLP input metrics, ``boundary.n_field_periods`` and the Fourier tables
    ``boundary.r_cos`` / ``boundary.z_sin``.

    A field is rejected when it is missing or ``None``, an empty list, a NaN
    or ±inf float, or a (possibly nested) list containing any of those — a
    single NaN coefficient would otherwise turn the whole training loss NaN.
    """
    required_keys = [
        "metrics.max_elongation",
        "metrics.average_triangularity",
        "metrics.edge_rotational_transform_over_n_field_periods",
        "metrics.aspect_ratio",
        "boundary.n_field_periods",
        "boundary.r_cos",
        "boundary.z_sin"
    ]
    return all(_is_finite_value(sample.get(key)) for key in required_keys)




def load_constellaration_dataset(subset_hyperparam: int | None = None, split_ratio: float = 0.8, seed: int = 42):
    """
    Load the ConStellaration dataset, filter invalid examples, optionally
    subsample, and split into training and testing sets.

    Args:
        subset_hyperparam: if not None, keep at most this many rows. Rows are
            drawn from the shuffled order, so the subset is a random sample of
            the valid rows rather than the first N rows of the dataset.
        split_ratio: fraction of the kept rows assigned to the training set.
        seed: seed for the shuffle. A private ``torch.Generator`` is used, so
            this function does not reset the global torch RNG.

    Returns:
        tuple: (train_dataset, test_dataset) as HuggingFace datasets

    Raises:
        ValueError: if ``split_ratio`` is outside [0, 1].
    """
    if not 0.0 <= split_ratio <= 1.0:
        raise ValueError(f"split_ratio must be in [0, 1], got {split_ratio}")

    dataset = load_dataset("proxima-fusion/constellaration", "default")
    data = dataset["train"].filter(is_valid)

    # Local generator: same permutation as seeding the global RNG with `seed`,
    # without the hidden side effect on later random calls.
    generator = torch.Generator().manual_seed(seed)
    indices = torch.randperm(len(data), generator=generator)

    # Subsample *after* shuffling so the subset is not biased by row order.
    if subset_hyperparam is not None:
        indices = indices[:max(0, subset_hyperparam)]

    split = int(split_ratio * len(indices))
    train_idx, test_idx = indices[:split], indices[split:]

    train_dataset = data.select(train_idx.tolist())
    test_dataset = data.select(test_idx.tolist())
    return train_dataset, test_dataset

def extract_mlpfeatures(sample: dict, device: torch.device = None) -> torch.Tensor | None:
    """
    Extracts selected MLP features from a dataset sample.
    """
    features = [
        sample.get("metrics.aspect_ratio"),
        sample.get("metrics.average_triangularity"),
        sample.get("metrics.edge_rotational_transform_over_n_field_periods"),
    ]
    if any(f is None for f in features):
        return None

    x = torch.tensor(features, dtype=torch.float32)
    if device:
        x = x.to(device)
    return x

def create_boundary(sample: dict, n_theta: int = 15, n_phi: int = 15, device: torch.device = None) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Evaluate the boundary surface R(θ, φ), Z(θ, φ) on a regular periodic grid
    from the stellarator-symmetric Fourier coefficients stored in the sample:

        R = Σ_{m,n} r_cos[m, n] · cos(m θ − n N_fp φ)
        Z = Σ_{m,n} z_sin[m, n] · sin(m θ − n N_fp φ)

    Coefficient tables have shape (M, N): rows are poloidal modes
    m = 0 .. M-1, columns are *centred* toroidal modes
    n = -(N-1)//2 .. (N-1)//2 (ConStellaration / VMEC layout), so column
    (N-1)//2 holds n = 0.

    The grid excludes the 2π endpoint (θ_i = 2π i / n_theta, likewise φ), so
    the maps are truly periodic: wrapping past the last column lands on the
    first one without duplicating the θ = 0 / φ = 0 point. This matters for
    the circular padding used by PeriodicConv2d.

    Args:
        sample: row providing 'boundary.n_field_periods', 'boundary.r_cos'
            and 'boundary.z_sin'.
        n_theta, n_phi: number of grid points along θ and φ.
        device: optional device to move R and Z to.

    Returns:
        (R, Z): float32 tensors of shape (n_theta, n_phi).

    Raises:
        ValueError: if the coefficient tables differ in shape or are not 2-D.
    """
    n_fp = float(sample['boundary.n_field_periods'])
    r_cos = torch.as_tensor(sample['boundary.r_cos'], dtype=torch.float32)
    z_sin = torch.as_tensor(sample['boundary.z_sin'], dtype=torch.float32)
    if r_cos.shape != z_sin.shape:
        raise ValueError("Coefficient shapes must match")
    if r_cos.dim() != 2:
        raise ValueError("Coefficient tables must be 2-D (poloidal x toroidal)")
    M, N = r_cos.shape

    # Periodic sampling of [0, 2π): no duplicated endpoint.
    theta_1d = torch.arange(n_theta, dtype=torch.float32) * (2 * torch.pi / n_theta)
    phi_1d = torch.arange(n_phi, dtype=torch.float32) * (2 * torch.pi / n_phi)
    theta, phi = torch.meshgrid(theta_1d, phi_1d, indexing='ij')  # (n_theta, n_phi)

    m = torch.arange(M, dtype=torch.float32).view(M, 1, 1, 1)
    # Toroidal mode numbers are centred on zero: -(N-1)//2 .. (N-1)//2.
    n = (torch.arange(N, dtype=torch.float32) - (N - 1) // 2).view(1, N, 1, 1)

    angle = m * theta - n * n_fp * phi  # (M, N, n_theta, n_phi)
    R = (r_cos.view(M, N, 1, 1) * torch.cos(angle)).sum(dim=(0, 1))
    Z = (z_sin.view(M, N, 1, 1) * torch.sin(angle)).sum(dim=(0, 1))

    if device:
        R = R.to(device)
        Z = Z.to(device)

    return R, Z

def extract_target(sample: dict, device: torch.device = None) -> torch.Tensor | None:
    """
    Extracts the target variable (max_elongation) from a sample.
    """
    target = sample.get("metrics.max_elongation")
    if target is None:
        return None
    y = torch.tensor(target, dtype=torch.float32)
    if device:
        y = y.to(device)
    return y

class ConstellerationCNNDataset(Dataset):
    """
    PyTorch ``Dataset`` view over a HuggingFace ConStellaration split.

    Each item is ``((R, Z), mlp_features, target)``:
      * ``(R, Z)`` from ``create_boundary`` on an ``n_theta x n_phi`` grid,
      * ``mlp_features`` from ``extract_mlpfeatures``,
      * ``target`` from ``extract_target``.
    All tensors are placed on ``device`` when one is given.

    Args:
        hf_dataset: indexable split, e.g. from ``load_constellaration_dataset``.
        device: optional device for every returned tensor.
        n_theta, n_phi: boundary grid resolution. The defaults (15) match
            ``create_boundary``'s own defaults; pass ``config["n_theta"]`` /
            ``config["n_phi"]`` to use the configured grid.
    """
    def __init__(self, hf_dataset, device: torch.device = None, n_theta: int = 15, n_phi: int = 15):
        self.data = hf_dataset
        self.device = device
        self.n_theta = n_theta
        self.n_phi = n_phi

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int):
        sample = self.data[idx]
        boundary = create_boundary(sample, n_theta=self.n_theta, n_phi=self.n_phi, device=self.device)
        mlp_features = extract_mlpfeatures(sample, device=self.device)
        target = extract_target(sample, device=self.device)
        # default_collate cannot batch None; fail with a clear message instead.
        if mlp_features is None or target is None:
            raise ValueError(
                f"Sample {idx} is missing MLP features or the target; filter the dataset with is_valid first"
            )
        return boundary, mlp_features, target

In [None]:
# Smoke test: load and split the dataset once, then report the split sizes.
train_hf, test_hf = load_constellaration_dataset()
print(f"Train size: {len(train_hf)}")
print(f"Test size: {len(test_hf)}")

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00003.parquet:   0%|          | 0.00/252M [00:00<?, ?B/s]

data/train-00001-of-00003.parquet:   0%|          | 0.00/204M [00:00<?, ?B/s]

data/train-00002-of-00003.parquet:   0%|          | 0.00/155M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/182222 [00:00<?, ? examples/s]

Filter:   0%|          | 0/182222 [00:00<?, ? examples/s]

Train size: 126948
Test size: 31737


In [None]:
from torch import  optim

def _batch_to_device(batch, device: torch.device):
    """Unpack a collated ``((R, Z), mlp_features, target)`` batch and move it to ``device``."""
    (R, Z), mlp_features, target = batch
    if target.dim() == 1:
        target = target.unsqueeze(1)  # (B,) -> (B, 1) to match model outputs
    return (R.to(device), Z.to(device)), mlp_features.to(device), target.to(device)


def _mean_loss(model: nn.Module, loader: DataLoader, criterion: nn.Module, device: torch.device) -> float:
    """Per-sample mean of ``criterion`` over ``loader`` in eval mode; NaN if the loader is empty."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for batch in loader:
            boundary, mlp_features, target = _batch_to_device(batch, device)
            total += criterion(model(boundary, mlp_features), target).item() * target.size(0)
            count += target.size(0)
    return total / count if count else float("nan")


def train_model(
    model: nn.Module,
    train_dataset: torch.utils.data.Dataset,
    test_dataset: torch.utils.data.Dataset,
    weight_decay: float,
    batch_size: int = 32,
    num_epochs: int = 3,
    lr: float = 1e-3,
    device: torch.device = None,
    save_path: str = "best_model.pth",
) -> nn.Module:
    """
    Train a PyTorch model using the provided training and testing datasets.

    Uses MSE loss and AdamW. After every epoch the model is evaluated on the
    test set; whenever the test loss improves, the weights are saved to
    ``save_path``. At the end, the best weights are loaded back into ``model``.

    Args:
        model (nn.Module): The PyTorch model to train. Called as
            ``model((R, Z), mlp_features)``.
        train_dataset (Dataset): Dataset for training; items are
            ``((R, Z), mlp_features, target)``.
        test_dataset (Dataset): Dataset for evaluation, same item layout.
        weight_decay (float): Weight decay for the AdamW optimizer.
        batch_size (int, optional): Number of samples per batch. Defaults to 32.
        num_epochs (int, optional): Number of training epochs. Defaults to 3.
        lr (float, optional): Learning rate for the optimizer. Defaults to 1e-3.
        device (torch.device, optional): Device to run training on. Defaults to
            CUDA if available, else CPU.
        save_path (str, optional): Path to save the best model checkpoint.
            Defaults to "best_model.pth".

    Returns:
        nn.Module: ``model`` with the weights of the epoch with the lowest test
        loss (the same weights written to ``save_path``). If no epoch produced
        a finite test loss (e.g. an empty test set), no checkpoint is written
        and the last-epoch weights are returned.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)

    # BatchNorm layers raise in train mode on a batch of a single sample, so
    # drop a trailing singleton batch rather than crash at the end of an epoch.
    drop_last = len(train_dataset) % batch_size == 1
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_test_loss = float("inf")
    best_state = None

    for epoch in range(num_epochs):
        model.train()
        train_loss, n_seen = 0.0, 0

        for batch in train_loader:
            boundary, mlp_features, target = _batch_to_device(batch, device)

            optimizer.zero_grad()
            outputs = model(boundary, mlp_features)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * target.size(0)
            n_seen += target.size(0)

        # Average over the samples actually used (drop_last may skip one).
        train_loss /= max(n_seen, 1)

        test_loss = _mean_loss(model, test_loader, criterion, device)
        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f}")

        if test_loss < best_test_loss:
            best_test_loss = test_loss
            # Snapshot (not a live reference) so later epochs don't overwrite it.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, save_path)

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

In [None]:
def main():
    """
    Load datasets, initialize the CombinedModel, and train it using the train_model utility.

    Reads the notebook-level ``config`` dict and ``device`` from the first
    cell. Returns the model produced by ``train_model``.
    """
    # Load training and testing datasets
    train_hf, test_hf = load_constellaration_dataset(subset_hyperparam=None)
    train_dataset = ConstellerationCNNDataset(train_hf, device=device)
    test_dataset = ConstellerationCNNDataset(test_hf, device=device)

    # Seed right before weight init so model initialisation and DataLoader
    # shuffling are reproducible no matter what earlier cells did to the RNG.
    torch.manual_seed(config.get("seed", 42))

    # Initialize the model
    model = CombinedModel(
        n_theta=config["n_theta"],
        n_phi=config["n_phi"],
        mlp_input_dim=config["mlp_input_dim"],
        hidden_dim=config["hidden_dim"],
        num_outputs=config["num_outputs"]
    )

    # Train the model
    trained_model = train_model(
        model,
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        batch_size=config["batch_size"],
        num_epochs=config['num_epochs'],
        lr=config["learning_rate"],
        device=device,
        weight_decay=config["weight_decay"]
    )

    return trained_model

# In a notebook kernel __name__ is "__main__", so this cell runs training.
if __name__ == "__main__":
    model = main()

Epoch 1/5 | Train Loss: 2.0436 | Test Loss: 0.5170
Epoch 2/5 | Train Loss: 1.3669 | Test Loss: 0.5858
Epoch 3/5 | Train Loss: 1.1525 | Test Loss: 0.2774
Epoch 4/5 | Train Loss: 1.0179 | Test Loss: 0.3035
Epoch 5/5 | Train Loss: 0.9043 | Test Loss: 0.2966
