In [6]:
import torch
import time
import numpy as np

##############################################################################
# A) UNIVERSAL CONSTANTS
##############################################################################
PERIOD      = 100.0
A2          = -3.2
SIN_SCALE   = 3.0
A4          = 0.1
A6          = 1.0
ALL_SCALE   = 1.0 

def a3_func(t):
    return SIN_SCALE * torch.sin(2.0 * torch.pi * t / PERIOD)

def a5_func(t):
    return -(3.0/5.0) * a3_func(t)

##############################################################################
# B) DERIVATIVES OF THE POTENTIAL
##############################################################################

def d_poly__d_x_torch(x, t):
    a3_val = a3_func(t)
    a5_val = a5_func(t)

    return ALL_SCALE * (
         A6*6.0*x**5
       + a5_val*5.0*x**4
       + A4*4.0*x**3
       + a3_val*3.0*x**2
       + A2*2.0*x
    )

def d2_poly__d_x2_torch(x, t):
    a3_val = a3_func(t)
    a5_val = a5_func(t)

    return ALL_SCALE * (
         A6*30.0*x**4
       + a5_val*20.0*x**3
       + A4*12.0*x**2
       + a3_val*6.0*x
       + A2*2.0
    )

##############################################################################
# C) DRIFT-IMPLICIT UPDATE
##############################################################################

def drift_implicit_update_torch(x_n, t_n, dt, sqrt_epsilon, dW_n,
                                max_iter=20, tol=1e-10):
    """
    Implicit eqn: X_{n+1} + dt*U'(X_{n+1}, t_{n+1}) = x_n + sqrt_epsilon*dW_n.
    We solve via Newton's method with derivative 1 + dt*U''(x,t_{n+1}).
    """
    t_next = t_n + dt
    A_n = x_n + sqrt_epsilon*dW_n  

    # Initial guess: explicit Euler step
    x_new = x_n - dt*d_poly__d_x_torch(x_n, t_n) + sqrt_epsilon*dW_n

    for _ in range(max_iter):
        f_val = x_new + dt*d_poly__d_x_torch(x_new, t_next) - A_n
        df_val = 1.0 + dt*d2_poly__d_x2_torch(x_new, t_next)

        # Clamp derivative to prevent blow-ups
        df_val = torch.clamp(df_val, min=1e-14, max=1e14)
        step = f_val / df_val
        x_new_next = x_new - step

        if torch.max(torch.abs(step)) < tol:
            x_new = x_new_next
            break

        x_new = x_new_next

    return x_new, t_next

##############################################################################
# D) EXPLICIT EULER-MARUYAMA UPDATE
##############################################################################

def euler_maruyama_update_torch(x_n, t_n, dt, sqrt_epsilon, dW_n):
    """
    Explicit Euler–Maruyama step:
        dX_t = - U'(X_t, t) dt + sqrt(eps)*dW_t
    =>  X_{n+1} = X_n - dt*U'(X_n, t_n) + sqrt(eps)*dW_n
    """
    drift_val = - d_poly__d_x_torch(x_n, t_n)
    x_next = x_n + drift_val * dt + sqrt_epsilon * dW_n
    t_next = t_n + dt
    return x_next, t_next

##############################################################################
# E) EXAMPLE FUNCTION: SIMULATE SDE (DRIFT-IMPLICIT) WITH SHARED NOISE
##############################################################################

def simulate_sde_same_noise(X0, t0, dt, N, eps,
                            device='cpu', seed=1234):
    """
    Each time step uses one scalar noise for *all* trajectories.
    CPU vs GPU will share the same increments if we fix the seed.
    """
    X0 = X0.to(device)
    batch_size = X0.shape[0]

    X_all = torch.zeros(N+1, batch_size, device=device)
    t_all = torch.zeros(N+1, device=device)
    X_all[0] = X0
    t_all[0] = t0

    sqrt_epsilon = torch.sqrt(torch.tensor(eps, device=device))

    # Pre-generate random increments on CPU for cross-device consistency
    gen = torch.Generator(device='cpu').manual_seed(seed)
    dWs_cpu = torch.randn(N, generator=gen)*np.sqrt(dt)

    # Copy to device if needed
    dWs = dWs_cpu if device == 'cpu' else dWs_cpu.to(device)

    for n in range(N):
        x_n = X_all[n]
        t_n = t_all[n].expand_as(x_n)

        # Single scalar => expand to entire batch
        dW_n = dWs[n].expand(batch_size)

        # >>> DRIFT-IMPLICIT STEP:
        x_new, t_new = drift_implicit_update_torch(
            x_n, t_n, dt, sqrt_epsilon, dW_n
        )

        X_all[n+1] = x_new
        t_all[n+1] = t_all[n] + dt

    return X_all, t_all

def simulate_sde_euler_explicit_same_noise(
    X0, t0, dt, N, eps, device='cpu', seed=1234
):
    X0 = X0.to(device)
    batch_size = X0.shape[0]
    X_all = torch.zeros(N+1, batch_size, device=device)
    t_all = torch.zeros(N+1, device=device)

    X_all[0] = X0
    t_all[0] = t0

    sqrt_epsilon = torch.sqrt(torch.tensor(eps, device=device))

    # Pre-generate random increments on CPU for cross-device consistency
    gen = torch.Generator(device='cpu').manual_seed(seed)
    dWs_cpu = torch.randn(N, generator=gen)*np.sqrt(dt)

    # Copy to device if needed
    dWs = dWs_cpu if device=='cpu' else dWs_cpu.to(device)

    for n in range(N):
        x_n = X_all[n]
        t_n = t_all[n].expand_as(x_n)

        dW_n = dWs[n].expand(batch_size)  # share same increment
        x_next, t_next = euler_maruyama_update_torch(x_n, t_n, dt, sqrt_epsilon, dW_n)

        X_all[n+1] = x_next
        t_all[n+1] = t_all[n] + dt

    return X_all, t_all

##############################################################################
# F) TIMING
##############################################################################

def time_cpu_vs_gpu():
    # Example usage for timing
    batch_size = 500000
    N = 1000
    dt = 0.01
    eps = 0.1
    t0 = 0.0

    # CPU
    X0_cpu = torch.zeros(batch_size, dtype=torch.float32)
    print("==> Running drift-implicit on CPU...")
    start_cpu = time.time()
    X_cpu, t_cpu = simulate_sde_same_noise(
        X0_cpu, t0, dt, N, eps, device='cpu', seed=1234
    )
    cpu_time = time.time() - start_cpu
    print(f"CPU time: {cpu_time:.4f} s")

    # GPU (if available)
    if torch.cuda.is_available():
        print("==> Running drift-implicit on GPU...")
        X0_gpu = X0_cpu.clone()  # same initial state
        start_gpu = time.time()
        X_gpu, t_gpu = simulate_sde_same_noise(
            X0_gpu, t0, dt, N, eps, device='cuda', seed=1234
        )
        torch.cuda.synchronize()
        gpu_time = time.time() - start_gpu
        print(f"GPU time: {gpu_time:.4f} s")
    else:
        print("No CUDA device found; skipping GPU test.")

def time_cpu_vs_gpu_euler_explicit():
    # Problem size
    batch_size = 1000000
    N = 1000
    dt = 0.01
    eps = 0.1
    t0 = 0.0

    # All initial states zero
    X0_cpu = torch.zeros(batch_size, dtype=torch.float32)

    print("==> Running Explicit Euler–Maruyama on CPU...")
    start_cpu = time.time()
    X_cpu, t_cpu = simulate_sde_euler_explicit_same_noise(
        X0=X0_cpu, t0=t0, dt=dt, N=N, eps=eps, device='cpu', seed=1234
    )
    end_cpu = time.time()
    cpu_time = end_cpu - start_cpu
    print(f"CPU time: {cpu_time:.4f} s")

    # GPU if available
    if torch.cuda.is_available():
        print("==> Running Explicit Euler–Maruyama on GPU...")
        X0_gpu = X0_cpu.clone()
        start_gpu = time.time()
        X_gpu, t_gpu = simulate_sde_euler_explicit_same_noise(
            X0=X0_gpu, t0=t0, dt=dt, N=N, eps=eps, device='cuda', seed=1234
        )
        torch.cuda.synchronize()  # ensure completion
        end_gpu = time.time()
        gpu_time = end_gpu - start_gpu
        print(f"GPU time: {gpu_time:.4f} s")
    else:
        print("No CUDA device found. Skipping GPU test.")

##############################################################################
# G) Check CPU == GPU
##############################################################################

def check_cpu_vs_gpu_match():
    # Parameters
    batch_size = 10    
    N = 1000            
    dt = 0.01
    eps = 0.5
    t0 = 0.0

    # CPU version
    X0_cpu = torch.zeros(batch_size, dtype=torch.float32)
    X_cpu, t_cpu = simulate_sde_same_noise(
        X0_cpu, t0, dt, N, eps, device='cpu', seed=1234
    )

    if not torch.cuda.is_available():
        print("No CUDA device found; skipping GPU check.")
        return

    # GPU version
    X0_gpu = X0_cpu.clone()  # same initial condition
    X_gpu, t_gpu = simulate_sde_same_noise(
        X0_gpu, t0, dt, N, eps, device='cuda', seed=1234
    )

    # Bring GPU results to CPU for comparison
    X_gpu_cpu = X_gpu.cpu()
    t_gpu_cpu = t_gpu.cpu()

    # Compare shapes
    if X_cpu.shape != X_gpu_cpu.shape or t_cpu.shape != t_gpu_cpu.shape:
        print("ERROR: CPU and GPU shapes differ!")
        return

    # Compare numerical values
    diff_X = torch.abs(X_cpu - X_gpu_cpu).max().item()
    diff_t = torch.abs(t_cpu - t_gpu_cpu).max().item()

    tol = 1e-6  # example tolerance
    if diff_X < tol and diff_t < tol:
        print(f"SUCCESS: CPU and GPU match within tolerance {tol}.")
        print(f"Max difference in X: {diff_X}, in t: {diff_t}")
    else:
        print("WARNING: CPU and GPU differ.")
        print(f"Max diff in X: {diff_X}, in t: {diff_t}")
        print(f"(You can raise tolerance if small floating errors are okay.)")


if __name__ == "__main__":
    check_cpu_vs_gpu_match()
    time_cpu_vs_gpu()
    time_cpu_vs_gpu_euler_explicit()

SUCCESS: CPU and GPU match within tolerance 1e-06.
Max difference in X: 1.1920928955078125e-07, in t: 0.0
==> Running drift-implicit on CPU...
CPU time: 123.8162 s
==> Running drift-implicit on GPU...
GPU time: 19.6495 s
==> Running Explicit Euler–Maruyama on CPU...
CPU time: 30.2360 s
==> Running Explicit Euler–Maruyama on GPU...
GPU time: 2.9217 s
