In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, grad, vmap, jit, value_and_grad
from functools import partial
import matplotlib.pyplot as plt
from tqdm import trange

# Import Flax and Optax
from flax import linen as nn
from flax.training import train_state
import optax

# Import attention utilities
from flax.linen.attention import make_causal_mask

# Define the exact solution of the underdamped harmonic oscillator
def oscillator(d, w0, t):
    w = jnp.sqrt(w0 ** 2 - d ** 2)
    phi = jnp.arctan(-d / w)
    A = 1.0 / (2.0 * jnp.cos(phi))
    cos_term = jnp.cos(phi + w * t)
    exp_term = jnp.exp(-d * t)
    x = exp_term * 2 * A * cos_term
    return x

# Define the Transformer model
class TransformerLayer(nn.Module):
    d_model: int
    num_heads: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, *, training: bool = True):
        # Create causal mask
        seq_len = x.shape[1]
        causal_mask = make_causal_mask(
            jnp.ones((x.shape[0], seq_len)), dtype=x.dtype
        )

        # Multi-head self-attention with causal mask
        attn = nn.SelfAttention(
            num_heads=self.num_heads,
            use_bias=True,
            broadcast_dropout=False,
            dropout_rate=self.dropout_rate,
            deterministic=not training,
            dtype=x.dtype,
        )(x, mask=causal_mask)

        # Add & Norm
        x = x + attn
        x = nn.LayerNorm()(x)

        # Feed-forward network
        ff = nn.Dense(self.d_model * 4)(x)
        ff = nn.tanh(ff)
        ff = nn.Dropout(rate=self.dropout_rate)(ff, deterministic=not training)
        ff = nn.Dense(self.d_model)(ff)

        # Add & Norm
        x = x + ff
        x = nn.LayerNorm()(x)
        return x

class TransformerPINN(nn.Module):
    d_model: int
    num_heads: int
    num_layers: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, t, *, training: bool = True):
        # Input embedding
        x = t.reshape(t.shape[0], -1, 1)
        x = nn.Dense(self.d_model)(x)
        x = nn.tanh(x)

        # Transformer layers
        for _ in range(self.num_layers):
            x = TransformerLayer(
                self.d_model, self.num_heads, self.dropout_rate
            )(x, training=training)

        # Output layer
        x = nn.Dense(1)(x)
        x = x.squeeze(-1)
        return x

def create_train_state(rng, learning_rate, model, params):
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx
    )

def compute_residual_per_t(t_single, x_single, params):
    t_scalar = t_single[0]
    x_scalar = x_single[0]

    def net_single(t_scalar, params):
        t_input = t_scalar.reshape(1, 1)
        x = model.apply({'params': params}, t_input, training=False)
        x_value = x[0, 0]
        return x_value

    # Compute x_t and x_tt
    x_t = grad(lambda t: net_single(t, params))(t_scalar)
    x_tt = grad(lambda t: grad(lambda t_inner: net_single(t_inner, params))(t))(t_scalar)

    # Compute residual
    residual = x_tt + mu * x_t + k * x_scalar

    # Compute gradient of x_tt w.r.t params
    grad_x_tt = grad(lambda params: x_tt)(params)

    # Compute the norm of the gradient
    grad_x_tt_norm = jax.tree_util.tree_reduce(
        lambda x, y: x + jnp.sum(y ** 2), grad_x_tt, initializer=0.0
    )
    grad_x_tt_norm = jnp.sqrt(grad_x_tt_norm)

    return residual, grad_x_tt_norm

def compute_residual(t, x_pred, params):
    # residuals and grad norms are arrays of shape (sequence_length,)
    residuals, grad_norms = vmap(
        compute_residual_per_t, in_axes=(1, 1, None)
    )(t, x_pred, params)
    return residuals, grad_norms

def loss_fn(model, params, t_r, t_data, x_data, rng1, rng2):
    # Forward pass for residual points
    x_pred_r = model.apply({'params': params}, t_r, training=True, rngs={'dropout': rng1})
    (residual, grad_norms) = compute_residual(t_r, x_pred_r, params)
    loss_res = jnp.mean(residual ** 2)
    grad_norm_mean = jnp.mean(grad_norms)

    # Forward pass for data points
    x_pred_data = model.apply({'params': params}, t_data, training=True, rngs={'dropout': rng2})
    loss_data = jnp.mean((x_pred_data - x_data) ** 2)

    # Initial condition loss
    x0_pred = x_pred_data[0, 0]
    def net_single(t0):
        t_input = t0.reshape(1, 1)
        x = model.apply({'params': params}, t_input, training=False)
        x_value = x[0, 0]
        return x_value
    x0_t_pred = grad(net_single)(jnp.array(0.0))
    loss_ic = (x0_pred - 1.0) ** 2 + (x0_t_pred - 0.0) ** 2

    # Total loss
    total_loss = loss_res + 0.0*loss_data + loss_ic

    return total_loss, grad_norm_mean

@jit
def train_step(state, t_r, t_data, x_data, rng):
    params = state.params

    rng1, rng2 = random.split(rng)

    def loss_and_grad(params):
        loss_value, grad_norm_mean = loss_fn(model, params, t_r, t_data, x_data, rng1, rng2)
        return loss_value, grad_norm_mean

    (loss_value, grad_norm_mean), grads = value_and_grad(loss_and_grad, has_aux=True)(params)
    state = state.apply_gradients(grads=grads)
    return state, loss_value, grad_norm_mean

# Constants for the underdamped harmonic oscillator
d = 2.0
w0 = 20.0
mu = 2 * d
k = w0 ** 2

# Generate data
t_r = jnp.linspace(0.0, 1.0, 100).reshape(1, -1)
t_data = jnp.linspace(0.0, 0.5, 25).reshape(1, -1)
x_data = oscillator(d, w0, t_data).reshape(1, -1)

# Initialize model and training state
rng = random.PRNGKey(0)
model = TransformerPINN(d_model=128, num_heads=4, num_layers=2, dropout_rate=0.1)
params = model.init(rng, t_r)['params']
state = create_train_state(rng, learning_rate=1e-3, model=model, params=params)

# Training loop
loss_log = []
nIter = 10000
pbar = trange(nIter)
for it in pbar:
    rng, step_rng = random.split(rng)
    state, loss_value, grad_norm_mean = train_step(state, t_r, t_data, x_data, step_rng)
    if it % 100 == 0:
        loss_log.append(loss_value)
        pbar.set_postfix({'Loss': loss_value, 'GradNorm_x_tt': grad_norm_mean})

# Evaluation
params = state.params
t_test = jnp.linspace(0.0, 2.0, 200).reshape(1, -1)
x_pred = model.apply({'params': params}, t_test, training=False)
x_pred = x_pred.flatten()
t_test_flat = t_test.flatten()
x_exact = oscillator(d, w0, t_test_flat)

# Plot results
plt.figure()
plt.plot(t_test_flat, x_exact, label='Exact Solution')
plt.plot(t_test_flat, x_pred, '--', label='Transformer PINN Prediction')
plt.scatter(t_data.flatten(), x_data.flatten(), color='red', label='Training Data')
plt.legend()
plt.xlabel('Time $t$')
plt.ylabel('Displacement $x(t)$')
plt.title('Under-damped Harmonic Oscillator with Causal Attention')
plt.show()

# Plot training loss
plt.figure()
plt.plot(np.arange(len(loss_log))*100, loss_log)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.yscale('log')
plt.title('Training Loss')
plt.show()


  0%|          | 0/10000 [00:00<?, ?it/s]

In [None]:
# Before training
t0 = jnp.array([[0.0]])  # Shape: (1, 1)
x0_pred = model.apply({'params': params}, t0, training=False)
print("Initial prediction t(0):", x0_pred)


Initial prediction t(0): [[0.9999174]]


In [None]:
print("t_r shape:", t_r.shape)
print("t_data shape:", t_data.shape)
print("x_data shape:", x_data.shape)


t_r shape: (1, 100)
t_data shape: (1, 25)
x_data shape: (1, 25)


In [5]:
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt

# Seed for reproducibility
key = random.PRNGKey(0)

# 1. Generate a linearly separable dataset
def generate_linearly_separable_data(key, num_samples=100, dim=2):
    key, subkey = random.split(key)
    # Random hyperplane parameters (optimal weights and bias)
    w_star = random.normal(subkey, (dim,))
    b_star = random.normal(subkey)
    
    # Scale w_star and b_star so that the margin gamma = 1
    gamma = jnp.min(jnp.abs(jnp.dot(w_star, w_star) + b_star))
    w_star = w_star / gamma
    b_star = b_star / gamma

    key, subkey = random.split(key)
    # Generate random points
    X = random.normal(subkey, (num_samples, dim))
    # Labels determined by the hyperplane
    y = jnp.sign(jnp.dot(X, w_star) + b_star)
    return X, y, w_star, b_star

X, y, w_star, b_star = generate_linearly_separable_data(key, num_samples=1000, dim=20)

# 2. Initialize the Perceptron parameters
w = jnp.zeros_like(w_star)
b = 1.0

# For tracking purposes
t = 0  # Number of updates
max_iterations = 1000  # To prevent infinite loops in case of errors
R = jnp.max(jnp.linalg.norm(X, axis=1))
norm_w_list = []
wt_wstar_list = []

# 3. Perceptron Learning Algorithm
while t < max_iterations:
    errors = 0
    for i in range(len(X)):
        if y[i] * (jnp.dot(w, X[i]) + b) <= 0:
            # Update rule
            w = w + y[i] * X[i]
            b = b + y[i]
            t += 1
            errors += 1
            
            # Track the norm of w and the inner product with w_star
            norm_w = jnp.linalg.norm(w)
            norm_w_list.append(norm_w)
            wt_wstar = jnp.dot(w, w_star)
            wt_wstar_list.append(wt_wstar)
            
            # Print the key steps corresponding to the proof
            print(f"Update {t}:")
            print(f"  w^T * w_star = {wt_wstar}")
            print(f"  ||w||^2 = {norm_w ** 2}")
            print(f"  Bound on ||w||^2: {t * R ** 2}")
            print(f"  Minimum required w^T * w_star: {t}")
            break  # Go back to the beginning after an update
    if errors == 0:
        print("Perceptron has converged.")
        break

if t == max_iterations:
    print("Reached maximum iterations without convergence.")

# 4. Plotting the increase in alignment and the growth of the weight vector norm
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(wt_wstar_list, label=r'$w_t^\top w^*$')
plt.plot([i for i in range(1, t+1)], label='Number of updates t')
plt.xlabel('Update step')
plt.ylabel(r'$w_t^\top w^*$')
plt.title('Increase in Alignment with $w^*$')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot([nw ** 2 for nw in norm_w_list], label=r'$||w_t||^2$')
plt.plot([i * R ** 2 for i in range(1, t+1)], label=r'Upper bound $t R^2$')
plt.xlabel('Update step')
plt.ylabel(r'$||w_t||^2$')
plt.title('Growth of the Weight Vector Norm')
plt.legend()

plt.tight_layout()
plt.show()


Update 1:
  w^T * w_star = -0.028189411386847496
  ||w||^2 = 16.195280075073242
  Bound on ||w||^2: 44.102630615234375
  Minimum required w^T * w_star: 1
Update 2:
  w^T * w_star = 0.2931363582611084
  ||w||^2 = 28.372013092041016
  Bound on ||w||^2: 88.20526123046875
  Minimum required w^T * w_star: 2
Update 3:
  w^T * w_star = 0.3195192813873291
  ||w||^2 = 45.266761779785156
  Bound on ||w||^2: 132.30789184570312
  Minimum required w^T * w_star: 3
Update 4:
  w^T * w_star = 0.41354697942733765
  ||w||^2 = 49.0037841796875
  Bound on ||w||^2: 176.4105224609375
  Minimum required w^T * w_star: 4
Update 5:
  w^T * w_star = 0.9777462482452393
  ||w||^2 = 78.6429443359375
  Bound on ||w||^2: 220.51315307617188
  Minimum required w^T * w_star: 5
Update 6:
  w^T * w_star = 1.2408562898635864
  ||w||^2 = 96.3493423461914
  Bound on ||w||^2: 264.61578369140625
  Minimum required w^T * w_star: 6
Update 7:
  w^T * w_star = 1.3473656177520752
  ||w||^2 = 112.26011657714844
  Bound on ||w||^2: 3

KeyboardInterrupt: 

In [15]:
import jax
import jax.numpy as jnp
from jax import random, jit
import matplotlib.pyplot as plt

# Seed for reproducibility
key = random.PRNGKey(0)

# 1. Define the target function
def f(x):
    return jnp.sin(2 * jnp.pi * x)

# 2. Define the activation function (sigmoid)
def sigmoid(z):
    return 1 / (1 + jnp.exp(-z))

# 3. Neural network approximation
def neural_network(params, x):
    W, b, alpha = params['W'], params['b'], params['alpha']
    hidden_layer = sigmoid(jnp.dot(x, W.T) + b)
    output = jnp.dot(hidden_layer, alpha)
    return output

# 4. Training the neural network
def train_network(key, num_neurons=10, num_epochs=10000, learning_rate=0.01):
    # Initialize parameters
    key, subkey = random.split(key)
    W = random.normal(subkey, (num_neurons, 1))  # Shape: (num_neurons, input_dim)
    b = random.normal(subkey, (num_neurons,))    # Shape: (num_neurons,)
    alpha = random.normal(subkey, (num_neurons,))  # Output weights

    params = {'W': W, 'b': b, 'alpha': alpha}

    # Generate training data
    x = jnp.linspace(0, 1, 100).reshape(-1, 1)
    y_true = f(x).reshape(-1)

    # Training loop
    for epoch in range(num_epochs):
        # Forward pass
        y_pred = neural_network(params, x).reshape(-1)
        loss = jnp.mean((y_pred - y_true) ** 2)

        # Compute gradients
        def loss_fn(params):
            y_pred = neural_network(params, x).reshape(-1)
            return jnp.mean((y_pred - y_true) ** 2)

        grads = jax.grad(loss_fn)(params)

        # Update parameters
        params['W'] = params['W'] - learning_rate * grads['W']
        params['b'] = params['b'] - learning_rate * grads['b']
        params['alpha'] = params['alpha'] - learning_rate * grads['alpha']

        # Optional: Print loss every 1000 epochs
        if epoch % 1000 == 0:
            print(f"Epoch {epoch}, Loss: {loss}")

    return params

# 5. Train the network
trained_params = train_network(key, num_neurons=10000, num_epochs=10000, learning_rate=0.01)

# 6. Plot the results
x_test = jnp.linspace(0, 1, 200).reshape(-1, 1)
y_test = f(x_test)
y_pred = neural_network(trained_params, x_test)

plt.figure(figsize=(8, 6))
plt.plot(x_test, y_test, label='Target Function $f(x) = \sin(2\pi x)$')
plt.plot(x_test, y_pred, label='Neural Network Approximation')
plt.xlabel('$x$')
plt.ylabel('$f(x)$')
plt.title('Universal Approximation Theorem Demonstration')
plt.legend()
plt.show()


Epoch 0, Loss: 6802215.5
Epoch 1000, Loss: nan


KeyboardInterrupt: 