In [60]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from tqdm import tqdm
from dynamics import linear_dynamics,linear_1d
import seaborn as sns
from scipy.stats import special_ortho_group


class Args:
    batch_size = 64
    learning_rate = 0.001
    optimizer = 'Adam'
    layer_dims = (16,8,1)
    dt = 0.1
    num_epochs = 20000
    lr_scheduler_step_size = 50
    lr_scheduler_gamma = 1.0
    dynamics_function = linear_1d
    state_dim = 1  # Define the dimension of the state here
    margin = 5
    init_range = (-1,1)
    seed = 4 #TODO: add seed in random
    num_steps_test = 500
    weight_n = 0.8

class EnergyPredictor(nn.Module):
    def __init__(self):
        super(EnergyPredictor, self).__init__()
        layers = [
            nn.Linear(Args.state_dim, Args.layer_dims[0]),
            nn.ReLU()
        ]
        for i in range(len(Args.layer_dims) - 1):
            layers.append(nn.Linear(Args.layer_dims[i], Args.layer_dims[i + 1]))
            layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return (self.layers(x))**2

def contrastive_loss(model, x_t, x_tp1, x_neg):
    """Computes the contrastive loss.

    Args:
        model (nn.Module): The neural network model.
        x_t (torch.Tensor): The current state tensor.
        x_tp1 (torch.Tensor): The next state tensor.
        x_neg (torch.Tensor): The negative samples tensor.

    Returns:
        torch.Tensor: The computed contrastive loss.
    """
    E_xt = model(x_t)
    
    E_xtp1 = model(x_tp1)
    E_xneg = model(x_neg)
    lp1 =  torch.relu(Args.margin - torch.mean((E_xt - E_xtp1)**2 / Args.dt))  #
#     lp2 = torch.mean(E_xtp1)
    np1 = torch.mean(((-E_xt + E_xneg)**2) / Args.dt)
#     np2 = -torch.mean(((E_xneg)))
    #lp1 =  torch.mean(E_xtp1)
    
    #np1 = torch.relu(Args.margin - torch.mean(E_xneg))
    print(lp1,np1)
    loss = lp1 + (np1) * Args.weight_n

    return loss

def train(model, loss_fn):
    """Trains the model on the given dataset.

    Args:
        model (nn.Module): The neural network model.
        dynamics_fn (function): The dynamics function.
        loss_fn (function): The loss function.
    """
    if Args.optimizer == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=Args.learning_rate)
    else:
        raise ValueError("Invalid optimizer specified in Args.")

    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=Args.lr_scheduler_step_size, gamma=Args.lr_scheduler_gamma)

    for epoch in tqdm(range(Args.num_epochs)):
    
        x_t = torch.tensor(np.random.uniform(Args.init_range[0],Args.init_range[1],(Args.batch_size, 1,Args.state_dim)), dtype=torch.float32)
        x_tp1 = Args.dynamics_function(x_t, Args.dt)

        #noise = np.random.normal(0, Args.negative_sample_variance, (Args.batch_size, Args.state_dim))
        if Args.state_dim > 1:
            randm = torch.tensor([special_ortho_group.rvs(dim=Args.state_dim) for _ in range(Args.batch_size)],dtype=torch.float32)
            x_neg = torch.bmm(x_t,randm)
        else:
            randm = -torch.ones(x_t.shape)
            x_neg = torch.multiply(x_t,randm)

        optimizer.zero_grad()
        loss = loss_fn(model, x_t, x_tp1, x_neg)
        loss.backward()
        optimizer.step()
        scheduler.step()

        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {loss.item()}')

def evaluate(model, x_t):
    """Evaluates the model on given dataset.

    Args:
        model (nn.Module): The neural network model.
        x_t (torch.Tensor): The current state tensor.

    Returns:
        np.ndarray: Array of prediction errors.
    """
    errors = []
    preds, actuals = [],[]
    grads = []
    for _ in range(1,Args.num_steps_test):
        x_t.requires_grad_(True)  # Ensure that x_t has requires_grad=True
        E_xt = model(x_t)
        E_xt.backward(torch.ones_like(E_xt))
        grad_E_xt = x_t.grad
        grads.append(grad_E_xt.detach().numpy())
        if _ == 1:
            x_t_inferred = x_t - Args.dt * grad_E_xt
        else:
            x_t_inferred = x_t_inferred - Args.dt * grad_E_xt
        x_t_actual = Args.dynamics_function(x_t.detach(), Args.dt)
        error = torch.norm(x_t_inferred.detach() - x_t_actual, dim=-1).numpy()
        errors.append(error)
        x_t = x_t_actual
        preds.append(x_t_inferred.detach().numpy())
        actuals.append(x_t_actual.detach().numpy())

    return np.stack(errors),np.array(preds).squeeze(),np.array(actuals).squeeze(),np.array(grads).squeeze()

def plot_energy(model, x_ranges, axes=None):
    """
    Plots the energy landscape of the model. If the system has one dimension, a 2D plot is
    generated. For systems with more than one dimension, a 3D plot is created, with the option to
    specify which axes to plot.

    Args:
        model (nn.Module): The neural network model.
        x_ranges (list of np.ndarray): A list of 1D NumPy arrays representing the ranges of each
            dimension in the input space. The length of the list should match the number of
            dimensions in the input space.
        axes (tuple of int, optional): A tuple containing the indices of the dimensions to plot on
            the x and y axes. Only applicable when the system has more than one dimension. If not
            provided, the first two dimensions (0, 1) will be plotted by default.
    """
    if len(x_ranges) == 1:
        x = torch.tensor(np.array(x_ranges), dtype=torch.float32)
        energy = model(x.T).detach().numpy().squeeze()
        fig,ax = plt.subplots(1,1)#,figsize=(10,5))
        plt.plot(x_ranges[0], energy)
        plt.xlabel("x")
        plt.ylabel("Energy")
        plt.tight_layout()
    else:
        if axes is None:
            axes = (0, 1)

        x_meshgrid = np.meshgrid(*[x_range for i, x_range in enumerate(x_ranges) if i in axes])
        x = np.stack(x_meshgrid, axis=-1)
        x_full = np.zeros((x.shape[:-1] + (len(x_ranges),)), dtype=np.float32)
        x_full[..., axes] = x
        x_full = torch.tensor(x_full, dtype=torch.float32)

        energy = model(x_full).detach().numpy().squeeze()

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.plot_surface(*x_meshgrid, energy, cmap=cm.viridis)
        plt.tight_layout()

def plot_errors(errors, dt_values):
    colors = cm.viridis(np.linspace(0, 1, len(dt_values)))

    for dt, error, color in zip(dt_values, errors, colors):
        mean = np.mean(error, axis=1)
        std_error = np.std(error, axis=1) / np.sqrt(error.shape[1])
        plt.plot(range(1,len(mean)+1), mean, color=color, label=f'dt = {dt}')
        plt.fill_between(range(1,len(mean)+1), mean - std_error, mean + std_error, color=color, alpha=0.3)

    plt.xlabel("Steps")
    plt.ylabel("Prediction Error")
    plt.legend()

def plot_dynamics():
    '''
    plot the dynamics to make sure we're not insane
    '''
    x_t = np.random.uniform(Args.init_range[0],Args.init_range[1],(Args.state_dim))
    xs = np.zeros((Args.num_steps_test,x_t.shape[-1]))
    xs[0] = x_t
    for i in range(1,Args.num_steps_test):
        xs[i] = Args.dynamics_function(xs[i-1],Args.dt)
    for i in range(Args.state_dim):
        plt.plot(xs[:,i],label=f"dimension {i}")
    plt.xlabel("time")
    plt.ylabel("value")
    plt.legend()



In [None]:
plot_dynamics()
model = EnergyPredictor()
train(model, contrastive_loss)


  0%|                                       | 49/20000 [00:00<00:41, 483.11it/s]

tensor(5., grad_fn=<ReluBackward0>) tensor(1.1264e-06, grad_fn=<MeanBackward0>)
Epoch 0, Loss: 5.000000953674316
tensor(5., grad_fn=<ReluBackward0>) tensor(6.6589e-07, grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(3.6096e-07, grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(2.1413e-07, grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(1.0246e-07, grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(4.3931e-08, grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(2.1921e-08, grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(1.3202e-08, grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(4.5637e-09, grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(2.5088e-09, grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(6.2604e-10, grad_fn=<MeanBackward0>)
Epoch 10, Loss: 5.0
tensor(5., grad_fn=<ReluBackward0>) tensor(2.1434e-10, grad_fn=<Mea

  1%|▎                                     | 153/20000 [00:00<00:39, 508.31it/s]

tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
Epoch 110, Loss: 5.0
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(

  1%|▍                                     | 259/20000 [00:00<00:38, 517.64it/s]

tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
Epoch 210, Loss: 5.0
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(

  2%|▋                                     | 364/20000 [00:00<00:38, 516.40it/s]

Epoch 310, Loss: 5.0
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
Epoch 320, Loss: 5.0
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<Re

  2%|▉                                     | 470/20000 [00:00<00:37, 520.15it/s]

tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
Epoch 420, Loss: 5.0
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(

  3%|█▏                                    | 630/20000 [00:01<00:36, 525.00it/s]

tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
Epoch 530, Loss: 5.0
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(0., grad_fn=<MeanBackward0>)
tensor(5., grad_fn=<ReluBackward0>) tensor(

In [None]:
x_ranges = [np.linspace(Args.init_range[0], Args.init_range[1], 1000) for _ in range(Args.state_dim)]

plot_energy(model, x_ranges)


In [None]:
num_samples = 1000
x_t = torch.tensor(np.random.uniform(Args.init_range[0],Args.init_range[1],(num_samples, Args.state_dim)), dtype=torch.float32, requires_grad=True)

errors,preds,actuals,grads = evaluate(model, x_t)
plt.plot(errors.mean(axis=1))
#plot_errors(errors, [Args.dt])

In [None]:
ind = 500
plt.plot(preds[:,ind],label="inferred")
plt.plot(actuals[:,ind],label="real dynamics")
plt.legend()

In [None]:
plt.scatter(x=actuals.reshape(-1),y=grads.reshape(-1),s=1)