In [163]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define the Lorenz attractor nonlinear flow F(x)
def F(state, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    """Evaluate the Lorenz vector field at one or more states.

    Args:
        state: Tensor whose last dimension holds the coordinates (x, y, z)
            in positions 0, 1 and 2. Any number of leading batch dimensions
            is accepted, including none (a single state of shape ``(3,)``).
        sigma: Lorenz parameter multiplying ``(y - x)`` in dx/dt.
        rho: Lorenz parameter appearing in dy/dt.
        beta: Lorenz parameter multiplying ``z`` in dz/dt.

    Returns:
        Tensor of time derivatives (dx/dt, dy/dt, dz/dt) stacked along the
        last dimension, with the same leading shape as ``state``.
    """
    # Index the last axis with an ellipsis so (3,), (B, 3) and (..., 3)
    # inputs are all handled by the same code path.
    x = state[..., 0]
    y = state[..., 1]
    z = state[..., 2]

    dxdt = sigma * (y - x)
    dydt = x * (rho - z) - y
    dzdt = x * y - beta * z

    return torch.stack([dxdt, dydt, dzdt], dim=-1)

# Define the neural network for a(x) with high-dimensional output
class MappingNet(nn.Module):
    """Feed-forward network producing the lifted representation a(x).

    Layout: BatchNorm1d(input_dim) -> Linear(input_dim, 128) -> ReLU
    -> Linear(128, 64) -> ReLU -> Linear(64, 64) -> ReLU
    -> Linear(64, output_dim). The final layer has no activation.

    Args:
        input_dim: Number of features in each input row.
        output_dim: Size of the returned feature vector (defaults to 32).
    """

    def __init__(self, input_dim, output_dim=32):
        super().__init__()
        # Attribute names double as state_dict keys, and the creation order
        # determines how the RNG is consumed during weight initialisation.
        self.normalizer = nn.BatchNorm1d(input_dim)
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 64)
        self.fco = nn.Linear(64, output_dim)

    def forward(self, x):
        """Normalize ``x``, apply three ReLU hidden layers, then the linear head."""
        hidden = self.normalizer(x)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = torch.relu(layer(hidden))
        return self.fco(hidden)

# Define the learned linear operator A and offset b
class LinearParams(nn.Module):
    """Learned affine map ``a -> A a + b`` acting on the feature space.

    The map is square: ``A`` has shape ``(input_dim, input_dim)`` and ``b``
    has shape ``(input_dim,)``, so the output has the same feature size as
    the input.

    Args:
        input_dim: Size of the feature vectors the map is applied to.
    """

    def __init__(self, input_dim):
        super(LinearParams, self).__init__()
        # Both parameters start from standard-normal samples (A first, then b).
        self.A = nn.Parameter(torch.randn(input_dim, input_dim))
        self.b = nn.Parameter(torch.randn(input_dim))

    def forward(self, a):
        """Apply the affine map to the last dimension of ``a``.

        Args:
            a: Tensor of shape ``(..., input_dim)``.

        Returns:
            Tensor of the same shape as ``a`` holding ``A a + b`` for every
            feature vector.
        """
        # linear(a, A, b) computes a @ A.T + b, i.e. A applied to each row,
        # and works for any number of leading batch dimensions.
        return nn.functional.linear(a, self.A, self.b)

def compute_jacobian(a, x):
    """Build the Jacobian of ``a`` with respect to ``x`` one output column at a time.

    NOTE(review): a second ``compute_jacobian(network, x)`` is defined later
    in this file; being defined afterwards, it replaces this version.

    Args:
        a: 2-D tensor of outputs that was computed from ``x``.
        x: Tensor that ``a`` was computed from; must be part of ``a``'s graph.

    Returns:
        Tensor of shape ``(batch_size, a_dim, input_dim)``. Slice ``[:, i, :]``
        is the gradient of the batch-summed ``a[:, i]`` with respect to ``x``,
        which equals the per-sample Jacobian row when each row of ``a``
        depends only on the matching row of ``x``.
    """

    def _column_gradient(col):
        # One-hot selector over the output columns, shared by all batch rows.
        selector = torch.zeros_like(a)
        selector[:, col] = 1.0
        grads = torch.autograd.grad(a, x, grad_outputs=selector, create_graph=True)
        return grads[0]

    per_output = [_column_gradient(col) for col in range(a.shape[1])]
    return torch.stack(per_output, dim=1)  # Shape: (batch_size, a_dim, input_dim)


def compute_jacobian(network, x):
    """Evaluate ``network`` on a detached copy of ``x`` and return d network / d x.

    The caller's ``x`` is left untouched: a detached clone with
    ``requires_grad=True`` is fed through ``network``. Gradients are taken
    with ``create_graph=True``, so the result stays differentiable with
    respect to the network's parameters.

    NOTE(review): each slice ``[:, i, :]`` is the gradient of the batch-summed
    i-th output. If ``network`` mixes samples within a batch (for example a
    BatchNorm layer in training mode, as ``MappingNet`` uses), the result
    includes cross-sample terms and is not the pure per-sample Jacobian.

    Args:
        network: Callable mapping a ``(batch_size, input_dim)`` tensor to a
            ``(batch_size, a_dim)`` tensor.
        x: Input tensor of shape ``(batch_size, input_dim)``.

    Returns:
        Tensor of shape ``(batch_size, a_dim, input_dim)``.
    """
    # Fresh leaf tensor so gradients are taken w.r.t. this copy only.
    x_leaf = x.detach().clone().requires_grad_(True)
    outputs = network(x_leaf)

    def _column_gradient(col):
        # Select output column ``col`` for every sample in the batch.
        selector = torch.zeros_like(outputs)
        selector[:, col] = 1.0
        grads = torch.autograd.grad(outputs=outputs, inputs=x_leaf,
                                    grad_outputs=selector, create_graph=True)
        return grads[0]

    per_output = [_column_gradient(col) for col in range(outputs.shape[1])]

    # New axis 1 indexes the output component.
    return torch.stack(per_output, dim=1)  # Shape: (batch_size, a_dim, input_dim)


# Initialize dimensions
input_dim = 3     # Lorenz system state: (x, y, z)
a_dim = 64        # Dimension of the lifted representation a(x)
output_dim = 3    # State dimension (not referenced elsewhere in this cell)
batch_size = 256  # Number of random states sampled per epoch

# Initialize models
mapping_net = MappingNet(input_dim, output_dim=a_dim)
linear_params = LinearParams(input_dim=a_dim)

# Separate optimizers so the two parameter groups can be stepped on different schedules
optimizer_a = optim.Adam(mapping_net.parameters(), lr=1e-3)         # weights of a(x)
optimizer_linear = optim.Adam(linear_params.parameters(), lr=1e-3)  # A and b

# Loss function
mse_loss = nn.MSELoss()

# Training loop
num_epochs = 3000
# Number of optimizer_a (mapping network) steps taken per epoch before the
# single optimizer_linear step.
# NOTE(review): the name suggests the inner loop may have been meant to update
# A and b instead -- confirm the intended schedule.
linear_steps = 5

for epoch in range(num_epochs):

    # Draw a fresh batch of random states (x, y, z) from a standard normal
    x = torch.randn((batch_size, input_dim), requires_grad=True)

    for _ in range(linear_steps):
        # Step 1: update the mapping network a(x). A and b are not stepped
        # here; the gradients they accumulate are cleared before Step 2.
        optimizer_a.zero_grad()

        # Forward pass through the network
        a = mapping_net(x)

        # Jacobian of a(x) with respect to x (uses its own forward pass)
        Ja = compute_jacobian(mapping_net, x)

        # Both sides of the condition  J_a(x) F(x) = A a(x) + b
        left_side = torch.einsum('bij,bj->bi', Ja, F(x))
        right_side = linear_params(a)

        loss = mse_loss(left_side, right_side)
        # The graph is rebuilt on every iteration, so it need not be retained.
        loss.backward()
        optimizer_a.step()

    # Step 2: update A and b with the mapping network held fixed for this step
    optimizer_linear.zero_grad()

    # Forward pass through the (just updated) network
    a = mapping_net(x)

    # Jacobian of a(x) with respect to x
    Ja = compute_jacobian(mapping_net, x)

    # Both sides of the condition  J_a(x) F(x) = A a(x) + b
    left_side = torch.einsum('bij,bj->bi', Ja, F(x))
    right_side = linear_params(a)

    # Loss measured before this epoch's update of A and b
    loss_a = mse_loss(left_side, right_side)
    loss_a.backward()
    optimizer_linear.step()

    # Print loss every 100 epochs
    if epoch % 100 == 0:
        print(f'Epoch {epoch}, Loss (a): {loss_a.item()}')


Epoch 0, Loss (a): 0.737454891204834
Epoch 100, Loss (a): 0.0012763711856678128
Epoch 200, Loss (a): 0.00027541464078240097
Epoch 300, Loss (a): 9.676031913841143e-05
Epoch 400, Loss (a): 4.28910534537863e-05
Epoch 500, Loss (a): 3.9558606658829376e-05
Epoch 600, Loss (a): 6.744704296579584e-05
Epoch 700, Loss (a): 1.4464032574323937e-05
Epoch 800, Loss (a): 1.6896747183636762e-05
Epoch 900, Loss (a): 8.137214172165841e-05
Epoch 1000, Loss (a): 3.5528380976757035e-05
Epoch 1100, Loss (a): 5.8680379879660904e-06
Epoch 1200, Loss (a): 6.668429705314338e-06
Epoch 1300, Loss (a): 5.373057138058357e-05
Epoch 1400, Loss (a): 9.832109753915574e-06
Epoch 1500, Loss (a): 2.1224175725365058e-05
Epoch 1600, Loss (a): 7.894050213508308e-06
Epoch 1700, Loss (a): 0.0005194611730985343
Epoch 1800, Loss (a): 0.0001251824724022299
Epoch 1900, Loss (a): 2.4196753656724468e-05
Epoch 2000, Loss (a): 5.1929091569036245e-06
Epoch 2100, Loss (a): 1.0573543477221392e-05
Epoch 2200, Loss (a): 2.129108543158509

In [172]:
# Per-sample residual of  J_a(x) F(x) - (A a(x) + b)  on the last training batch
res = (
    torch.einsum('bij,bj->bi', compute_jacobian(mapping_net, x), F(x))
    - linear_params(mapping_net(x))
)

In [173]:
# Display the residual tensor computed in the previous cell
res

tensor([[ 5.5608e-04,  4.8520e-05,  5.4945e-05,  ...,  2.3366e-04,
          7.4381e-05,  8.4932e-06],
        [ 8.9772e-04, -3.0435e-04, -2.4129e-04,  ..., -7.5167e-04,
         -5.2047e-04,  2.0322e-04],
        [ 6.4170e-04,  3.9267e-05, -2.0168e-05,  ...,  2.3327e-06,
          2.2807e-05, -2.5406e-04],
        ...,
        [ 6.5218e-04, -1.9912e-05, -3.8731e-05,  ..., -5.1633e-05,
         -5.6250e-05, -1.2657e-04],
        [ 6.9821e-04,  3.8606e-05, -6.2327e-05,  ..., -8.8867e-05,
         -7.5784e-05, -1.5930e-04],
        [ 6.5161e-04,  2.3288e-05, -3.0990e-05,  ..., -3.4820e-06,
         -4.2515e-05, -1.0845e-04]], grad_fn=<SubBackward0>)