<a href="https://colab.research.google.com/github/Stimulatedsyn/Create/blob/main/experiment%20with%20dynamic%20solvers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [81]:
!pip install torchdiffeq torchsde torch torchvision matplotlib



In [82]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchdiffeq
import torchsde


In [83]:
# Deterministic dataset: Linear trajectory
def generate_deterministic_data(n_samples, input_dim):
    x = torch.linspace(0, 1, n_samples).unsqueeze(-1)
    y = 3 * x + 2  # Linear function
    return x, y

# Stochastic dataset: Brownian motion
def generate_stochastic_data(n_samples, input_dim):
    x = torch.linspace(0, 1, n_samples).unsqueeze(-1)
    y = torch.cumsum(torch.randn(n_samples, 1), dim=0)  # Brownian motion
    return x, y


In [84]:
class ControllerNetwork(nn.Module):
    """
    Neural network for controlling gating and discrete solver selection.
    Outputs:
    - Gate decision: Binary decision to use SDE or ODE.
    """
    def __init__(self, input_dim, hidden_dim):
        super(ControllerNetwork, self).__init__()
        self.gate_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()  # Outputs probability for SDE
        )

    def forward(self, x):
        gate_prob = self.gate_net(x)
        return gate_prob


In [85]:
class UnifiedDynamicSolver(nn.Module):
    """
    Unified solver capable of either SDE or ODE dynamics, controlled by a discrete gate.
    """
    def __init__(self, input_dim, hidden_dim):
        super(UnifiedDynamicSolver, self).__init__()
        # Linear layer to project input to hidden_dim
        self.input_projection = nn.Linear(input_dim, hidden_dim)

        # Drift (shared by both solvers)
        self.drift = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # Diffusion (used only in the SDE solver)
        self.diffusion = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, x, t_span, gate_decision):
        # Project input to hidden_dim
        x_projected = self.input_projection(x)

        # Define the SDE function
        class SDEFunction(nn.Module):
            noise_type = "diagonal"
            sde_type = "ito"

            def __init__(self, drift, diffusion):
                super(SDEFunction, self).__init__()
                self.drift = drift
                self.diffusion = diffusion

            def f(self, t, y):
                """Drift term."""
                return self.drift(y)

            def g(self, t, y):
                """Diffusion term."""
                return self.diffusion(y)

        # Define the ODE function
        class ODEFunction(nn.Module):
            def __init__(self, drift):
                super(ODEFunction, self).__init__()
                self.drift = drift

            def forward(self, t, y):
                """Deterministic dynamics."""
                return self.drift(y)

        # Initialize solvers
        sde_func = SDEFunction(self.drift, self.diffusion)
        ode_func = ODEFunction(self.drift)

        # Solve using SDE and ODE for all inputs
        sde_solution = torchsde.sdeint(sde_func, x_projected, t_span, method='euler')[-1]
        ode_solution = torchdiffeq.odeint(ode_func, x_projected, t_span)[-1]

        # Select based on gate_decision
        combined_output = gate_decision * sde_solution + (1 - gate_decision) * ode_solution
        return combined_output


In [86]:
class AdvancedACMDN(nn.Module):
    """
    Model integrating dynamic solver switching with discrete gate decisions.
    """
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(AdvancedACMDN, self).__init__()
        self.controller = ControllerNetwork(input_dim, hidden_dim)
        self.unified_solver = UnifiedDynamicSolver(input_dim, hidden_dim)  # Pass both arguments
        self.output_layer = nn.Linear(hidden_dim, output_dim * 2)  # Mean and log variance

    def forward(self, x, t_span):
        # Get gate decision (soft during training, hard during inference)
        gate_prob = self.controller(x)
        gate_decision = (gate_prob > 0.5).float()  # Binary decision

        # Solve with the selected solver
        solver_output = self.unified_solver(x, t_span, gate_decision)

        # Final output (mean and log variance for NLL computation)
        output = self.output_layer(solver_output).view(-1, 2)
        return output, gate_prob, gate_decision



In [87]:
import matplotlib.pyplot as plt

def visualize_gate_decisions(x, gate_probs, true_labels=None):
    """
    Visualize gate probabilities and their alignment with the task type.
    Args:
        x: Input data points (1D or 2D).
        gate_probs: Gate probabilities (0 to 1).
        true_labels: Ground truth labels (0 for deterministic, 1 for stochastic), optional.
    """
    plt.figure(figsize=(10, 6))
    plt.scatter(x.squeeze(), gate_probs.detach().numpy(), c='blue', label='Gate Probabilities')
    if true_labels is not None:
        plt.scatter(x.squeeze(), true_labels, c='red', alpha=0.5, label='True Labels')
    plt.axhline(0.5, color='gray', linestyle='--', label='Decision Threshold')
    plt.xlabel('Input')
    plt.ylabel('Gate Probability')
    plt.title('Gate Decisions')
    plt.legend()
    plt.show()


In [88]:
# NLL Loss
def nll_loss(predicted, target):
    mean, log_var = predicted[:, 0], predicted[:, 1]
    var = torch.exp(log_var)
    nll = 0.5 * torch.log(2 * torch.pi * var) + (target - mean) ** 2 / (2 * var)
    return torch.mean(nll)

# Combined Loss with Solver-Specific Loss Selection
def combined_loss(output, target, gate_decision):
    """
    Compute the combined loss for MSE and NLL based on gate decisions.
    Args:
        output: Model predictions (mean and log variance for NLL).
        target: Ground truth values.
        gate_decision: Binary mask for selecting between SDE and ODE losses.
    Returns:
        Scalar loss value.
    """
    # Flatten the target to match the shape of output[:, 0]
    target = target.view(-1)

    # MSE Loss
    mse = nn.MSELoss(reduction='none')(output[:, 0], target)  # No reduction

    # NLL Loss
    nll = nll_loss(output, target)  # Already scalar per sample

    # Combine losses based on gate_decision
    loss = gate_decision * nll + (1 - gate_decision) * mse
    return loss.mean()  # Reduce to scalar



In [89]:
def visualize_solver_outputs(x, y_true, y_pred, gate_decisions):
    """
    Visualize the solver outputs for deterministic and stochastic data.
    Args:
        x: Input data points.
        y_true: True outputs.
        y_pred: Predicted outputs.
        gate_decisions: Gate decisions (binary: 0 for ODE, 1 for SDE).
    """
    plt.figure(figsize=(12, 6))

    # Separate deterministic and stochastic regions
    deterministic_mask = (gate_decisions == 0).squeeze()
    stochastic_mask = (gate_decisions == 1).squeeze()

    # Plot deterministic outputs
    plt.scatter(x[deterministic_mask].squeeze(), y_true[deterministic_mask], c='blue', label='True (Deterministic)')
    plt.scatter(x[deterministic_mask].squeeze(), y_pred[deterministic_mask, 0], c='cyan', label='Predicted (ODE)')

    # Plot stochastic outputs
    plt.scatter(x[stochastic_mask].squeeze(), y_true[stochastic_mask], c='red', label='True (Stochastic)')
    plt.scatter(x[stochastic_mask].squeeze(), y_pred[stochastic_mask, 0], c='orange', label='Predicted (SDE Mean)')

    plt.xlabel('Input')
    plt.ylabel('Output')
    plt.title('Solver Outputs (ODE vs SDE)')
    plt.legend()
    plt.show()


In [90]:
def visualize_loss_landscape(model, x, y_true, t_span):
    """
    Visualize the loss landscape by varying gate probabilities.
    Args:
        model: The trained model.
        x: Input data points.
        y_true: Ground truth outputs.
        t_span: Time span for solvers.
    """
    gate_probs = torch.linspace(0, 1, 100)  # Gate probabilities from 0 to 1
    losses = []

    for gate_prob in gate_probs:
        gate_decision = (gate_prob > 0.5).float()
        with torch.no_grad():
            output, _, _ = model(x, t_span)
            loss = combined_loss(output, y_true, gate_decision)
            losses.append(loss.item())

    # Plot the loss landscape
    plt.figure(figsize=(10, 6))
    plt.plot(gate_probs.numpy(), losses, label='Loss')
    plt.xlabel('Gate Probability')
    plt.ylabel('Loss')
    plt.title('Loss Landscape Across Gate Decisions')
    plt.axvline(0.5, color='gray', linestyle='--', label='Decision Threshold')
    plt.legend()
    plt.show()


In [None]:
# Training parameters
input_dim = 1
hidden_dim = 64
output_dim = 1
n_samples = 100
batch_size = 16
epochs = 50

# Generate data
x_det, y_det = generate_deterministic_data(n_samples, input_dim)
x_sto, y_sto = generate_stochastic_data(n_samples, input_dim)
x = torch.cat([x_det, x_sto], dim=0)
y = torch.cat([y_det, y_sto], dim=0)

# Instantiate model
model = AdvancedACMDN(input_dim, hidden_dim, output_dim)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop
# Training loop with visualizations
for epoch in range(epochs):
    permutation = torch.randperm(x.size(0))  # Shuffle data
    x_shuffled = x[permutation]
    y_shuffled = y[permutation]

    for i in range(0, x.size(0), batch_size):
        x_batch = x_shuffled[i:i + batch_size]
        y_batch = y_shuffled[i:i + batch_size]

        # Time span for solver
        t_span = torch.linspace(0, 1, 100)

        # Forward pass
        output, gate_prob, gate_decision = model(x_batch, t_span)

        # Compute combined loss
        gate_decision = (gate_prob > 0.5).float()  # Binary mask
        loss = combined_loss(output, y_batch, gate_decision)  # Scalar

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Log progress
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}")


    # Visualizations
    if (epoch + 1) % 10 == 0:  # Visualize every 10 epochs
        print("Visualizing Gate Decisions...")
        visualize_gate_decisions(x, gate_prob, true_labels=None)

        print("Visualizing Solver Outputs...")
        visualize_solver_outputs(x, y, output, gate_decision)

        print("Visualizing Loss Landscape...")
        visualize_loss_landscape(model, x, y, t_span)


Epoch 1/50, Loss: 6.0069
Epoch 2/50, Loss: 2.8892
Epoch 3/50, Loss: 2.8304
Epoch 4/50, Loss: 2.7476
Epoch 5/50, Loss: 2.5074
Epoch 6/50, Loss: 2.6108
