In [None]:
# Simulated Object class

from dataclasses import dataclass
from typing import Optional

from src.dynamics.solver import DynamicsSolver

import torch

@dataclass
class SimulatedObject():
    """A point mass whose state is advanced by a ``DynamicsSolver``.

    State is stored as plain tuples with one entry per spatial dimension.
    Every ``update`` first records the previous
    ``(coordinates, velocity, acceleration)`` triple in ``past``. ``past``
    keeps at most ``save_past`` entries, oldest first. With
    ``save_past == 0`` and no ``past`` supplied, no history is kept.
    """
    mass: float
    coordinates: tuple  # current position
    velocity: tuple  # current velocity
    acceleration: tuple  # last applied force divided by mass
    solver: DynamicsSolver  # its apply_force(x, v, force, dt) returns the updated (x, v)
    save_past: int = 0  # maximum number of past states kept in `past`
    past: Optional[list] = None  # (coordinates, velocity, acceleration) triples, oldest first

    def update(self, dt: float, force: Optional[tuple] = None):
        """Advance the object by one solver step of length ``dt``.

        Args:
            dt: time step forwarded to the solver.
            force: force vector with one component per dimension. ``None``
                means no force (a zero vector as long as ``velocity``).
        """
        if force is None:
            force = (0.0,) * len(self.velocity)

        if self.past is None and self.save_past > 0:
            self.past = []

        if self.past is not None:
            self.past.append((self.coordinates, self.velocity, self.acceleration))
            # Drop the oldest entries so that at most `save_past` remain.
            excess = len(self.past) - self.save_past
            if excess > 0:
                del self.past[:excess]

        self.acceleration = tuple(f / self.mass for f in force)

        # Build floating-point tensors even when every component is a Python int;
        # for float inputs this matches torch.tensor's default inference.
        dtype = torch.get_default_dtype()
        # The solver is fed 3-D tensors (view(1, 1, -1)), and its result is read
        # back at [0, -1, :]. Presumably the axes are (batch, time, dim) — TODO confirm.
        x, v = self.solver.apply_force(
            x=torch.tensor(self.coordinates, dtype=dtype).view(1, 1, -1),
            v=torch.tensor(self.velocity, dtype=dtype).view(1, 1, -1),
            force=torch.tensor(force, dtype=dtype).view(1, 1, -1),
            dt=dt,
        )
        self.coordinates = tuple(x[0, -1, :].tolist())
        self.velocity = tuple(v[0, -1, :].tolist())

    def _history(self, field: int, current: tuple, steps: Optional[int]) -> list:
        """Return the last ``steps`` values of ``past[k][field]``, oldest first, then ``current``.

        ``steps=None`` selects the whole stored history. Other values are
        clamped to ``[0, len(past)]``.
        """
        if not self.past:  # no history kept (None) or nothing recorded yet
            return [current]
        available = len(self.past)
        count = available if steps is None else max(0, min(int(steps), available))
        return [entry[field] for entry in self.past[available - count:]] + [current]

    def get_past_coordinates(self, dt: Optional[int] = None) -> list:
        """Positions over the recent history, oldest first, ending with the current one.

        Args:
            dt: how many stored past steps to include, clamped to what is
                available. ``None`` includes the whole stored history.

        Returns:
            A list of ``min(dt, len(past)) + 1`` coordinate tuples.
        """
        return self._history(0, self.coordinates, dt)

    def get_past_velocity(self, dt: Optional[int] = None) -> list:
        """Velocities over the recent history, oldest first, ending with the current one.

        Args:
            dt: how many stored past steps to include, clamped to what is
                available. ``None`` includes the whole stored history.

        Returns:
            A list of ``min(dt, len(past)) + 1`` velocity tuples.
        """
        return self._history(1, self.velocity, dt)

    

In [None]:
# Agent policy: a model-driven force function

from src.model.dynamics_model import DYNAMIC_MODELS

# Which registered dynamics model to load, and the checkpoint to restore its weights from.
model_type = "dynamical_lstm"
model_save = "lightning_logs/version_0/checkpoints/epoch=9-step=8790.ckpt"

dynamical_predictor = DYNAMIC_MODELS[model_type].load_from_checkpoint(checkpoint_path=model_save)
# The model is only queried for inference below. Switch to eval mode so layers such as
# dropout are disabled. Assumes a torch nn.Module / LightningModule — TODO confirm.
dynamical_predictor.eval()

def dynamical_agent(state: SimulatedObject, force_scale: float = 1e-3, predictor=None) -> tuple:
    """Query the dynamics model for the force to apply to ``state`` at this step.

    Args:
        state: individual whose position history (``get_past_coordinates()``)
            is fed to the model.
        force_scale: factor applied to the raw model output. The default is
            the 0.001 used originally.
        predictor: model to query. Defaults to the module-level
            ``dynamical_predictor``.

    Returns:
        The scaled model output at the last time step, as a tuple of floats.
    """
    if predictor is None:
        predictor = dynamical_predictor

    # History of positions -> tensor of shape (1, T, D).
    history = torch.tensor(state.get_past_coordinates()).unsqueeze(0).to(predictor.device)
    if history.size(1) < 2:
        # Only the current position is available: duplicate it so the model sees 2 steps.
        history = history.repeat(1, 2, 1)

    # Inference only: avoid building an autograd graph on every simulation step.
    with torch.no_grad():
        raw = predictor(history)

    # Assumes the model returns (batch, time, dim); take the last time step — TODO confirm.
    return tuple((raw * force_scale)[0, -1, :].tolist())
    


In [None]:
# Simulation parameters

# Arena bounds: initial positions are drawn uniformly in [minX, maxX] x [minY, maxY],
# and the walls sit on these lines.
minX = -10
maxX = 10
minY = -10
maxY = 10

# Range for each initial velocity component.
minV = -0.2
maxV = 0.2

# Number of animation frames; each frame advances every individual by one step.
maxT = 400

# Population, by kind: plain objects, objects under radial gravity, model-driven agents.
nbObjects = 0
nbGravObjects = 0
nbAgents = 2

mass = 1                      # mass shared by every individual
gravity_field = -9.81 * 1e-3  # radial-gravity strength; the negative sign points it toward the arena centre
tau = 5                       # number of past states each individual keeps (save_past)

In [None]:
# Set forces
import math

def bounceWall(object: SimulatedObject, dimension: int, is_max: bool, wall_coordinate: int, dt: float = 1):
    """Force that reflects ``object`` off an axis-aligned wall.

    The wall lies at ``wall_coordinate`` along axis ``dimension``.
    ``is_max=True`` means the wall bounds that axis from above (the object
    should stay below it); ``False`` means it bounds the axis from below.

    A force is produced only when the object is beyond the wall *and* still
    moving outward. It is sized to reverse that velocity component over one
    step of length ``dt``: delta_v = -2 v, so F = m * delta_v / dt. An object
    that is already heading back inside gets no extra push.

    Returns:
        A list with one force component per coordinate of ``object``. All
        components are zero except possibly ``dimension``.
    """
    force = [0] * len(object.coordinates)
    position = object.coordinates[dimension]
    speed = object.velocity[dimension]
    beyond_max = is_max and position > wall_coordinate and speed > 0
    beyond_min = (not is_max) and position < wall_coordinate and speed < 0
    if beyond_max or beyond_min:
        force[dimension] = -2 * object.mass * speed / dt
    return force

def gravity(object: SimulatedObject, gravity_field: float):
    """Uniform field along the y axis: returns ``(0, m * g, 0)``."""
    weight = object.mass * gravity_field
    return 0, weight, 0

def radialGravity(object: SimulatedObject, gravity_field: float):
    """Constant-magnitude force along the line through the arena centre and ``object``.

    The direction is the unit vector from the centre of
    ``[minX, maxX] x [minY, maxY]`` to the object, scaled by
    ``gravity_field * mass``. A negative ``gravity_field`` therefore pulls
    toward the centre. The z component is always 0.
    """
    centre_x = (maxX + minX) / 2
    centre_y = (maxY + minY) / 2
    pos_x = object.coordinates[0]
    pos_y = object.coordinates[1]
    heading = math.atan2(pos_y - centre_y, pos_x - centre_x)
    fx = math.cos(heading) * gravity_field * object.mass
    fy = math.sin(heading) * gravity_field * object.mass
    return (fx, fy, 0)


# Force functions per kind of individual. Each entry is called as f(individual, dt)
# and is expected to return an (fx, fy, fz) force. The arena bounds and gravity_field
# are module-level names looked up when the function runs.
object_forces = [
    lambda obj, dt: bounceWall(obj, 0, True, maxX),   # wall at x = maxX
    lambda obj, dt: bounceWall(obj, 0, False, minX),  # wall at x = minX
    lambda obj, dt: bounceWall(obj, 1, True, maxY),   # wall at y = maxY
    lambda obj, dt: bounceWall(obj, 1, False, minY),  # wall at y = minY
]
gravobject_forces = [*object_forces, lambda obj, dt: radialGravity(obj, gravity_field)]
agent_forces = [*object_forces, lambda obj, dt: dynamical_agent(obj)]

In [None]:
%matplotlib notebook

# Run and display simulation

import matplotlib.pyplot as plt
import matplotlib.animation
import numpy as np


from IPython.display import HTML

seed = 0    # fixed RNG seed so Restart & Run All reproduces the same initial conditions
sim_dt = 1  # time step passed to every force function and to SimulatedObject.update
rng = np.random.default_rng(seed)

# Individuals are laid out by kind: [plain objects | gravity objects | agents].
individuals = [
    SimulatedObject(
        mass=mass,
        coordinates=(rng.uniform(minX, maxX), rng.uniform(minY, maxY), 0),
        velocity=(rng.uniform(minV, maxV), rng.uniform(minV, maxV), 0),
        acceleration=(0, 0, 0),
        solver=DynamicsSolver(mass=mass, dimensions=3),
        save_past=tau,
    )
    for _ in range(nbObjects + nbGravObjects + nbAgents)
]

fig, ax = plt.subplots()
ax.axis([minX, maxX, minY, maxY])
ax.set(xlabel="x", ylabel="y", title="Objects (red), gravity objects (blue), agents (green)")
lo, = ax.plot([], [], "ro")  # plain objects
lg, = ax.plot([], [], "bo")  # objects under radial gravity
la, = ax.plot([], [], "go")  # model-driven agents


def _forces_for(index: int) -> list:
    """Return the list of force functions for the individual at ``individuals[index]``."""
    if index < nbObjects:
        return object_forces
    if index < nbObjects + nbGravObjects:
        return gravobject_forces
    if index < nbObjects + nbGravObjects + nbAgents:
        return agent_forces
    raise ValueError("Too many individuals")


def _xy(group) -> tuple:
    """Split a list of individuals into parallel x and y coordinate lists for plotting."""
    return [o.coordinates[0] for o in group], [o.coordinates[1] for o in group]


def animate(frame):
    """FuncAnimation callback: advance every individual one step and redraw.

    ``frame`` is the frame number supplied by FuncAnimation; it is not used.
    """
    for index, individual in enumerate(individuals):
        # Start from float zeros so the force tensor is floating point even if no force fires.
        force = [0.0, 0.0, 0.0]
        for f in _forces_for(index):
            dx, dy, _dz = f(individual, sim_dt)
            # NOTE(review): the z component is discarded, keeping motion in the plotted
            # x-y plane. Confirm this is intended for agent forces too.
            force[0] += dx
            force[1] += dy
        individual.update(sim_dt, tuple(force))

    lo.set_data(*_xy(individuals[:nbObjects]))
    lg.set_data(*_xy(individuals[nbObjects:nbObjects + nbGravObjects]))
    la.set_data(*_xy(individuals[nbObjects + nbGravObjects:]))
    return lo, lg, la


ani = matplotlib.animation.FuncAnimation(fig, animate, frames=maxT, interval=10)

HTML(ani.to_jshtml())