<a href="https://colab.research.google.com/github/Nishant-Ramakuru/Inference-based-GNNS/blob/main/Predicting_Collective_Dynamics_w_GNN_Simulator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Simulator used to generate collective dynamics

This simulator uses the JAX-MD library to create simple, fully differentiable simulations of the collective dynamics of active agents on GPU hardware. The first model uses a simple short-range repulsive potential so that agents maintain a minimal distance from each other, and each agent has an individual self-propulsion force. Additionally, agents move stochastically: their orientation performs a random walk (rotational diffusion), so the direction of self-propulsion changes randomly over time.

## TODO: 

- Output neighbour lists (as adjacency lists)
- Extend to include alignment
- Extend to include populations of agents
- Extend to include novel interaction dynamics 

In [1]:
!pip install jax
!pip install jax-md
!pip install tqdm
!pip install iteration-utilities

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting jax-md
  Downloading jax_md-0.2.5-py2.py3-none-any.whl (144 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m144.5/144.5 KB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting e3nn-jax
  Downloading e3nn_jax-0.15.0-py3-none-any.whl (126 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m126.1/126.1 KB[0m [31m17.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting flax
  Downloading flax-0.6.4-py3-none-any.whl (204 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m204.3/204.3 KB[0m [31m25.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting dm-haiku
  Downloading dm_haiku-0.0.9-py3-none-any.whl (352 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m352.1/352.1 KB[0m [31m36.6 MB/s[0m eta [36m0:00:00[0m
[?25hC

In [48]:
import pickle

import numpy as onp

from jax.config import config ; config.update('jax_enable_x64', True) 
import jax.numpy as np 
from jax import random 
from jax import jit  
from jax import lax  
from jax import vmap 
import random as rn

import time 
from tqdm import tqdm
from jax_md import space, smap, energy, minimize, quantity, simulate, partition,dataclasses
from jax_md import util
from collections import namedtuple 
from functools import partial 
from typing import Any, Callable, TypeVar, Union, Tuple, Dict, Optional

In [3]:
#@title Imports & Utils

# Imports

!pip install -q git+https://www.github.com/google/jax-md

import numpy as onp

from jax.config import config ; config.update('jax_enable_x64', True)
import jax.numpy as np
from jax import random
from jax import jit
from jax import vmap
from jax import lax
# Convenience alias for jax.numpy.vectorize; not referenced in the visible cells.
vectorize = np.vectorize

from functools import partial

from collections import namedtuple
import base64

import IPython
from google.colab import output

import os

from jax_md import space, smap, energy, minimize, quantity, simulate, partition, util
from jax_md.util import f32

# Plotting

import matplotlib.pyplot as plt
import seaborn as sns
  
sns.set_style(style='white')

# Shared plot palette: RGB triples in [0, 1] plus the colour used for axes/text.
dark_color = [56 / 256 for _ in range(3)]
light_color = [213 / 256 for _ in range(3)]
axis_color = 'white'

def format_plot(x='', y='', grid=True):
  """Apply the notebook's dark theme to the current axes and label them.

  Args:
    x: Label for the x axis.
    y: Label for the y axis.
    grid: Whether grid lines are drawn.
  """
  ax = plt.gca()

  # Recolour every frame edge, then tick marks, then axis labels.
  for side in ('bottom', 'top', 'right', 'left'):
    ax.spines[side].set_color(axis_color)

  for axis_name in ('x', 'y'):
    ax.tick_params(axis=axis_name, colors=axis_color)
  for axis in (ax.yaxis, ax.xaxis):
    axis.label.set_color(axis_color)
  ax.set_facecolor(dark_color)

  plt.grid(grid)
  plt.xlabel(x, fontsize=20)
  plt.ylabel(y, fontsize=20)
  
def finalize_plot(shape=(1, 1)):
  """Give the current figure a dark background and resize it to a panel grid.

  The new size is `shape[0]` x `shape[1]` panels, each 1.5 times the figure's
  current *height* on a side (the current width is not used), followed by
  `tight_layout`.
  """
  fig = plt.gcf()
  fig.patch.set_facecolor(dark_color)
  fig.set_size_inches(
    shape[0] * 1.5 * fig.get_size_inches()[1],
    shape[1] * 1.5 * fig.get_size_inches()[1])
  plt.tight_layout()

# Progress Bars

from IPython.display import HTML, display
import time


def ProgressIter(iter_fun, iter_len=0):
  """Yield the items of `iter_fun` while updating an HTML progress bar.

  Args:
    iter_fun: Iterable to wrap.
    iter_len: Total item count. When falsy, `len(iter_fun)` is used instead,
      so it must be given explicitly for iterables without a length.
  """
  total = iter_len if iter_len else len(iter_fun)
  handle = display(progress(0, total), display_id=True)
  for done, item in enumerate(iter_fun, start=1):
    yield item
    handle.update(progress(done, total))

def progress(value, max):
    """Return an HTML ``<progress>`` element showing ``value`` out of ``max``.

    Args:
      value: Number of completed items.
      max: Total number of items. The name is kept for backward compatibility
        with keyword callers, even though it shadows the builtin here.

    Returns:
      An ``IPython.display.HTML`` object.
    """
    # No comma after the max attribute: the old "max='...'," markup was a parse
    # error that made the browser create a bogus attribute named ",".
    return HTML("""
        <progress
            value='{value}'
            max='{max}'
            style='width: 45%'
        >
            {value}
        </progress>
    """.format(value=value, max=max))

def normalize(v):
  """Scale each row of `v` to unit Euclidean length."""
  row_lengths = np.linalg.norm(v, axis=1, keepdims=True)
  return v / row_lengths

# Rendering

# HTML/JS snippet that animates a simulation with WebGL inside the Colab output
# frame. It pulls data from the kernel through the `notebook.*` callbacks that
# `render()` registers (GetSimulationInfo, GetObstacles, GetBoidStates,
# GetPredators) and draws each agent as a triangle pointing along `theta`.
# The number of frames loaded comes from GetSimulationInfo rather than a fixed
# cap, so long simulations are shown in full.
renderer_code = IPython.display.HTML('''
<canvas id="canvas"></canvas>
<script>
  Rg = null;
  Ng = null;

  var current_scene = {
      R: null,
      N: null,
      is_loaded: false,
      frame: 0,
      frame_count: 0,
      boid_vertex_count: 0,
      boid_buffer: [],
      predator_vertex_count: 0,
      predator_buffer: [],
      disk_vertex_count: 0,
      disk_buffer: null,
      box_size: 0
  };

  google.colab.output.setIframeHeight(0, true, {maxHeight: 5000});

  async function load_simulation() {
    var buffer_size = 400;

    // Ask for the simulation size first. The frame count bounds the loading
    // loop (a hard-coded cap silently dropped later frames), and box_size must
    // be known before the first frame is drawn.
    var result = await google.colab.kernel.invokeFunction(
        'notebook.GetSimulationInfo', [], {});
    var info = result.data['application/json'];
    var max_frame = info.frames;
    current_scene.box_size = info.box_size;

    result = await google.colab.kernel.invokeFunction(
        'notebook.GetObstacles', [], {});
    var data = result.data['application/json'];

    if(data.hasOwnProperty('Disk')) {
      current_scene = put_obstacle_disk(current_scene, data.Disk);
    }

    // Fetch frames in chunks of buffer_size to keep each kernel reply small.
    for (var i = 0 ; i < max_frame ; i += buffer_size) {
      result = await google.colab.kernel.invokeFunction(
          'notebook.GetBoidStates', [i, i + buffer_size], {});

      data = result.data['application/json'];
      current_scene = put_boids(current_scene, data);
    }
    current_scene.is_loaded = true;

    result = await google.colab.kernel.invokeFunction(
        'notebook.GetPredators', [], {});
    data = result.data['application/json'];
    if (data.hasOwnProperty('R'))
      current_scene = put_predators(current_scene, data);
  }

  function initialize_gl() {
    const canvas = document.getElementById("canvas");
    canvas.width = 640;
    canvas.height = 640;

    const gl = canvas.getContext("webgl2");

    if (!gl) {
        alert('Unable to initialize WebGL.');
        return;
    }

    gl.viewport(0, 0, gl.drawingBufferWidth, gl.drawingBufferHeight);
    gl.clearColor(0.2, 0.2, 0.2, 1.0);
    gl.enable(gl.DEPTH_TEST);

    const shader_program = initialize_shader(
        gl, VERTEX_SHADER_SOURCE_2D, FRAGMENT_SHADER_SOURCE_2D);
    const shader = {
      program: shader_program,
      attribute: {
          vertex_position: gl.getAttribLocation(shader_program, 'vertex_position'),
      },
      uniform: {
          screen_position: gl.getUniformLocation(shader_program, 'screen_position'),
          screen_size: gl.getUniformLocation(shader_program, 'screen_size'),
          color: gl.getUniformLocation(shader_program, 'color'),
      },
    };
    gl.useProgram(shader_program);

    const half_width = 200.0;

    gl.uniform2f(shader.uniform.screen_position, half_width, half_width);
    gl.uniform2f(shader.uniform.screen_size, half_width, half_width);
    gl.uniform4f(shader.uniform.color, 0.9, 0.9, 1.0, 1.0);

    return {gl: gl, shader: shader};
  }

  var loops = 0;

  function update_frame() {
    gl.clear(gl.COLOR_BUFFER_BIT | gl.DEPTH_BUFFER_BIT);

    if (!current_scene.is_loaded) {
      window.requestAnimationFrame(update_frame);
      return;
    }

    var half_width = current_scene.box_size / 2.;
    gl.uniform2f(shader.uniform.screen_position, half_width, half_width);
    gl.uniform2f(shader.uniform.screen_size, half_width, half_width);

    if (current_scene.frame >= current_scene.frame_count) {
      if (!current_scene.is_loaded) {
        window.requestAnimationFrame(update_frame);
        return;
      }
      loops++;
      current_scene.frame = 0;
    }

    gl.enableVertexAttribArray(shader.attribute.vertex_position);

    gl.bindBuffer(gl.ARRAY_BUFFER, current_scene.boid_buffer[current_scene.frame]);
    gl.uniform4f(shader.uniform.color, 0.0, 0.35, 1.0, 1.0);
    gl.vertexAttribPointer(
      shader.attribute.vertex_position,
      2,
      gl.FLOAT,
      false,
      0,
      0
    );
    gl.drawArrays(gl.TRIANGLES, 0, current_scene.boid_vertex_count);

    if(current_scene.predator_buffer.length > 0)  {
      gl.bindBuffer(gl.ARRAY_BUFFER, current_scene.predator_buffer[current_scene.frame]);
      gl.uniform4f(shader.uniform.color, 1.0, 0.35, 0.35, 1.0);
      gl.vertexAttribPointer(
        shader.attribute.vertex_position,
        2,
        gl.FLOAT,
        false,
        0,
        0
      );
      gl.drawArrays(gl.TRIANGLES, 0, current_scene.predator_vertex_count);
    }
    
    if(current_scene.disk_buffer) {
      gl.bindBuffer(gl.ARRAY_BUFFER, current_scene.disk_buffer);
      gl.uniform4f(shader.uniform.color, 0.9, 0.9, 1.0, 1.0);
      gl.vertexAttribPointer(
        shader.attribute.vertex_position,
        2,
        gl.FLOAT,
        false,
        0,
        0
      );
      gl.drawArrays(gl.TRIANGLES, 0, current_scene.disk_vertex_count);
    }

    current_scene.frame++;
    if ((current_scene.frame_count > 1 && loops < 5) || 
        (current_scene.frame_count == 1 && loops < 240))
      window.requestAnimationFrame(update_frame);
    
    if (current_scene.frame_count > 1 && loops == 5 && current_scene.frame < current_scene.frame_count - 1)
      window.requestAnimationFrame(update_frame);
  }

  function put_boids(scene, boids) {
    const R = decode(boids['R']);
    const R_shape = boids['R_shape'];
    const theta = decode(boids['theta']);
    const theta_shape = boids['theta_shape'];

    function index(i, b, xy) {
      return i * R_shape[1] * R_shape[2] + b * R_shape[2] + xy; 
    }

    var steps = R_shape[0];
    var boids = R_shape[1];
    var dimensions = R_shape[2];

    if(dimensions != 2) {
      alert('Can only deal with two-dimensional data.')
    }

    // First flatten the data.
    var buffer_data = new Float32Array(boids * 6);
    var size = 8.0;
    for (var i = 0 ; i < steps ; i++) {
      var buffer = gl.createBuffer();
      for (var b = 0 ; b < boids ; b++) {
        var xi = index(i, b, 0);
        var yi = index(i, b, 1);
        var ti = i * boids + b;
        var Nx = size * Math.cos(theta[ti]); //N[xi];
        var Ny = size * Math.sin(theta[ti]); //N[yi];
        buffer_data.set([
          R[xi] + Nx, R[yi] + Ny,
          R[xi] - Nx - 0.5 * Ny, R[yi] - Ny + 0.5 * Nx,
          R[xi] - Nx + 0.5 * Ny, R[yi] - Ny - 0.5 * Nx,             
        ], b * 6);
      }
      gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
      gl.bufferData(gl.ARRAY_BUFFER, buffer_data, gl.STATIC_DRAW);

      scene.boid_buffer.push(buffer);
    }
    scene.boid_vertex_count = boids * 3;
    scene.frame_count += steps;
    return scene;
  }

  function put_predators(scene, boids) {
    // TODO: Unify this with the put_boids function.
    const R = decode(boids['R']);
    const R_shape = boids['R_shape'];
    const theta = decode(boids['theta']);
    const theta_shape = boids['theta_shape'];

    function index(i, b, xy) {
      return i * R_shape[1] * R_shape[2] + b * R_shape[2] + xy; 
    }

    var steps = R_shape[0];
    var boids = R_shape[1];
    var dimensions = R_shape[2];

    if(dimensions != 2) {
      alert('Can only deal with two-dimensional data.')
    }

    // First flatten the data.
    var buffer_data = new Float32Array(boids * 6);
    var size = 18.0;
    for (var i = 0 ; i < steps ; i++) {
      var buffer = gl.createBuffer();
      for (var b = 0 ; b < boids ; b++) {
        var xi = index(i, b, 0);
        var yi = index(i, b, 1);
        var ti = theta_shape[1] * i + b;
        var Nx = size * Math.cos(theta[ti]);
        var Ny = size * Math.sin(theta[ti]);
        buffer_data.set([
          R[xi] + Nx, R[yi] + Ny,
          R[xi] - Nx - 0.5 * Ny, R[yi] - Ny + 0.5 * Nx,
          R[xi] - Nx + 0.5 * Ny, R[yi] - Ny - 0.5 * Nx,             
        ], b * 6);
      }
      gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
      gl.bufferData(gl.ARRAY_BUFFER, buffer_data, gl.STATIC_DRAW);

      scene.predator_buffer.push(buffer);
    }
    scene.predator_vertex_count = boids * 3;
    return scene;
  }

  function put_obstacle_disk(scene, disk) {
    const R = decode(disk.R);
    const R_shape = disk.R_shape;
    const radius = decode(disk.D);
    const radius_shape = disk.D_shape;

    const disk_count = R_shape[0];
    const dimensions = R_shape[1];
    if (dimensions != 2) {
        alert('Can only handle two-dimensional data.');
    }
    if (radius_shape[0] != disk_count) {
        alert('Inconsistent disk radius count found.');
    }
    const segments = 32;

    function index(o, xy) {
        return o * R_shape[1] + xy;
    }

    // TODO(schsam): Use index buffers here.
    var buffer_data = new Float32Array(disk_count * segments * 6);
    for (var i = 0 ; i < disk_count ; i++) {
      var xi = index(i, 0);
      var yi = index(i, 1);
      for (var s = 0 ; s < segments ; s++) {
        const th = 2 * s / segments * Math.PI;
        const th_p = 2 * (s + 1) / segments * Math.PI;
        const rad = radius[i] * 0.8;
        buffer_data.set([
          R[xi], R[yi],
          R[xi] + rad * Math.cos(th), R[yi] + rad * Math.sin(th),
          R[xi] + rad * Math.cos(th_p), R[yi] + rad * Math.sin(th_p),
        ], i * segments * 6 + s * 6);
      }
    }
    var buffer = gl.createBuffer();
    gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
    gl.bufferData(gl.ARRAY_BUFFER, buffer_data, gl.STATIC_DRAW);
    scene.disk_vertex_count = disk_count * segments * 3;
    scene.disk_buffer = buffer;
    return scene;
  }

  // SHADER CODE

  const VERTEX_SHADER_SOURCE_2D = `
    // Vertex Shader Program.
    attribute vec2 vertex_position;
    
    uniform vec2 screen_position;
    uniform vec2 screen_size;

    void main() {
      vec2 v = (vertex_position - screen_position) / screen_size;
      gl_Position = vec4(v, 0.0, 1.0);
    }
  `;

  const FRAGMENT_SHADER_SOURCE_2D = `
    precision mediump float;

    uniform vec4 color;

    void main() {
      gl_FragColor = color;
    }
  `;

  function initialize_shader(
    gl, vertex_shader_source, fragment_shader_source) {

    const vertex_shader = compile_shader(
      gl, gl.VERTEX_SHADER, vertex_shader_source);
    const fragment_shader = compile_shader(
      gl, gl.FRAGMENT_SHADER, fragment_shader_source);

    const shader_program = gl.createProgram();
    gl.attachShader(shader_program, vertex_shader);
    gl.attachShader(shader_program, fragment_shader);
    gl.linkProgram(shader_program);

    if (!gl.getProgramParameter(shader_program, gl.LINK_STATUS)) {
      alert(
        'Unable to initialize shader program: ' + 
        gl.getProgramInfoLog(shader_program)
        );
        return null;
    }
    return shader_program;
  }

  function compile_shader(gl, type, source) {
    const shader = gl.createShader(type);
    gl.shaderSource(shader, source);
    gl.compileShader(shader);

    if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {
      alert('An error occured compiling shader: ' + gl.getShaderInfoLog(shader));
      gl.deleteShader(shader);
      return null;
    }

    return shader;
  }

  // SERIALIZATION UTILITIES
  function decode(sBase64, nBlocksSize) {
    var chrs = atob(atob(sBase64));
    var array = new Uint8Array(new ArrayBuffer(chrs.length));

    for(var i = 0 ; i < chrs.length ; i++) {
      array[i] = chrs.charCodeAt(i);
    }

    return new Float32Array(array.buffer);
  }

  // RUN CELL

  load_simulation();
  gl_and_shader = initialize_gl();
  var gl = gl_and_shader.gl;
  var shader = gl_and_shader.shader;
  update_frame();
</script>
''')

def encode(R):
  """Serialize `R` as base64-encoded float32 bytes (returned as `bytes`)."""
  raw_bytes = onp.array(R, onp.float32).tobytes()
  return base64.b64encode(raw_bytes)

def render(box_size, states, obstacles=None, predators=None):
  """Register the Colab callbacks that feed a simulation to `renderer_code`.

  Args:
    box_size: Side length of the square simulation box. The renderer uses it to
      map positions to screen coordinates.
    states: Sequence of per-frame `(R, theta)` pairs. `R` has shape `[n, 2]` and
      `theta` has shape `[n]`.
    obstacles: Optional object with attributes `R` (disk centres) and `D`
      (disk radii).
    predators: Optional sequence of per-frame tuples whose first two entries are
      `(R, theta)`. Any further entries are ignored.

  Returns:
    The `renderer_code` HTML/JS object. Display it to start the animation.
  """
  R, theta = zip(*states)
  R = onp.stack(R)
  theta = onp.stack(theta)

  # Stack predators for any non-None input. This is the same condition that
  # get_predators checks. Previously only lists were stacked, so a tuple
  # raised NameError when the renderer asked for predators.
  if predators is not None:
    R_predators, theta_predators, *_ = zip(*predators)
    R_predators = onp.stack(R_predators)
    theta_predators = onp.stack(theta_predators)

  def get_boid_states(start, end):
    R_, theta_ = R[start:end], theta[start:end]
    return IPython.display.JSON(data={
        "R_shape": R_.shape,
        "R": encode(R_), 
        "theta_shape": theta_.shape,
        "theta": encode(theta_)
        })
  output.register_callback('notebook.GetBoidStates', get_boid_states)

  def get_obstacles():
    if obstacles is None:
      return IPython.display.JSON(data={})
    else:
      return IPython.display.JSON(data={
          'Disk': {
              'R': encode(obstacles.R),
              'R_shape': obstacles.R.shape,
              'D': encode(obstacles.D),
              'D_shape': obstacles.D.shape
          }
      })
  output.register_callback('notebook.GetObstacles', get_obstacles)

  def get_predators():
    if predators is None:
      return IPython.display.JSON(data={})
    else:
      return IPython.display.JSON(data={
          'R': encode(R_predators),
          'R_shape': R_predators.shape,
          'theta': encode(theta_predators),
          'theta_shape': theta_predators.shape
      })
  output.register_callback('notebook.GetPredators', get_predators)

  def get_simulation_info():
    # float() so array-valued scalars (e.g. a 0-d device array) are JSON-safe.
    return IPython.display.JSON(data={
        'frames': R.shape[0],
        'box_size': float(box_size)
        })
  output.register_callback('notebook.GetSimulationInfo', get_simulation_info)

  return renderer_code

  Preparing metadata (setup.py) ... [?25l[?25hdone


In [4]:
def interaction_potential(dR, J_avoid, D_avoid, alpha):
    """Soft repulsive pair energy for a displacement vector `dR`.

    With r = |dR| / D_avoid, the energy is J_avoid / alpha * (1 - r) ** alpha
    for r < 1 and zero otherwise. D_avoid is therefore the interaction range.
    """
    scaled_distance = space.distance(dR) / D_avoid
    overlap = 1 - scaled_distance
    within_range = scaled_distance < 1.
    return np.where(within_range,
                    J_avoid / alpha * overlap ** alpha,
                    0.)

def energy_fn(state, displacement=None, J_avoid=25., D_avoid=30., alpha=3.):
    """Total soft-repulsion energy of a particle configuration.

    Sums `interaction_potential` over all distinct pairs of particles.

    Args:
      state: Particle positions, an array of shape `[n, spatial_dimension]`.
      displacement: Pairwise displacement function `d(Ra, Rb)`. Defaults to the
        notebook-global `displacement_fn`, which is looked up at call time and
        is defined in the simulation cell. This keeps the original
        one-argument call `energy_fn(R)` working.
      J_avoid: Strength of the repulsion.
      D_avoid: Range of the repulsion.
      alpha: Exponent of the repulsion. The defaults reproduce the values
        that were previously hard-coded.

    Returns:
      Scalar energy.
    """
    if displacement is None:
        displacement = displacement_fn

    pair_energy = partial(interaction_potential,
                          J_avoid=J_avoid, D_avoid=D_avoid, alpha=alpha)
    pair_energy = vmap(vmap(pair_energy))

    dR = space.map_product(displacement)(state, state)
    E = pair_energy(dR)

    # The i == j "pairs" have zero separation (r = 0 < 1). Each would add a
    # spurious constant J_avoid / alpha to the total, so mask them out and
    # count only distinct pairs.
    self_pair = np.eye(state.shape[0], dtype=bool)
    E = np.where(self_pair, 0., E)

    # Every unordered pair appears twice in the [n, n] matrix.
    return 0.5 * np.sum(E)

In [5]:
@dataclasses.dataclass
class ActiveBrownianState:
    """A tuple containing state information for active Brownian dynamics.

    Attributes:
    position: The current position of the particles. An ndarray of floats with
      shape `[n, spatial_dimension]`.
    theta: The current orientation angle of each particle, in radians. An
      ndarray of floats with shape `[n]`. The self-propulsion direction is
      `(cos(theta), sin(theta))`.
    rng: The current state of the random number generator (a JAX PRNG key).
    """
    position: util.Array  # particle positions, [n, spatial_dimension]
    theta: util.Array  # heading angles in radians, [n]
    rng: util.Array  # PRNG key, split and advanced by every integration step


In [6]:
# Generic type variable standing for the simulation state type.
T = TypeVar('T')
# Signature of an `init_fn`: arbitrary arguments -> initial state.
InitFn = Callable[..., T]
# Signature of an `apply_fn`: state -> next state.
# NOTE(review): the apply_fn returned by `activeBrownian` actually takes
# (i, state, **kwargs) so it can be used as a lax.fori_loop body.
ApplyFn = Callable[[T], T]

def activeBrownian(energy_or_force: Callable[..., util.Array],
                   shift: space.ShiftFn,
                   dt: float,
                   tau: float,
                   v0: float=0.1) -> Tuple[InitFn, ApplyFn]:
    """Create an integrator for active Brownian particles in two dimensions.

    The dynamics are overdamped with unit mobility and self-propulsion. Each step
    moves every particle by `dt * (v0 * (cos(theta), sin(theta)) + F)`, where `F`
    is the force derived from `energy_or_force`. It then adds zero-mean Gaussian
    noise of variance `2 * dt / tau` to the heading `theta`. No translational
    noise is applied.

    Args:
      energy_or_force: Function of the positions (an ndarray of shape
        `[n, spatial_dimension]`) returning either a scalar energy or a force.
      shift: Displacement function `shift(R, dR, **kwargs)`. `R` and `dR` both
        have shape `[n, spatial_dimension]`.
      dt: Integration time step.
      tau: Persistence time of the heading (rotational diffusion constant 1/tau).
      v0: Self-propulsion speed.

    Returns:
      A pair `(init_fn, apply_fn)`. `init_fn(R, theta, key)` wraps its inputs in
      an `ActiveBrownianState`. `apply_fn(i, state, **kwargs)` returns the state
      one step later. The ignored first argument matches the `lax.fori_loop`
      body signature.
    """
    # Convert the scalar parameters with JAX MD's util.static_cast.
    dt, tau, v0 = util.static_cast(dt, tau, v0)
    force_fn = quantity.canonicalize_force(energy_or_force)

    def init_fn(R, theta, key):
        return ActiveBrownianState(R, theta, key)  # pytype: disable=wrong-arg-count

    @vmap
    def heading(angle):
        # Unit vector along one particle's orientation; vmap lifts it to [n, 2].
        return np.array([np.cos(angle), np.sin(angle)])

    def apply_fn(_, state, **kwargs):
        positions, angles, rng_key = dataclasses.astuple(state)

        # Advance the PRNG and draw one standard-normal kick per particle.
        rng_key, noise_key = random.split(rng_key)
        kick = random.normal(noise_key, angles.shape, angles.dtype)

        # Overdamped update: displacement = dt * (self-propulsion + interaction).
        interaction = force_fn(positions)
        velocity = v0 * heading(angles) + interaction
        positions = shift(positions, dt * velocity, **kwargs)

        # Rotational diffusion of the heading.
        angles = angles + (dt * util.f32(2) / tau) ** (1 / 2) * kick

        return ActiveBrownianState(positions, angles, rng_key)

    return init_fn, apply_fn

In [7]:
# Create RNG state to draw random numbers
# Master RNG key. The simulation cell below draws its own per-run keys.
key = random.PRNGKey(0)

# --- Simulation parameters ---
dim = 2               # spatial dimension (the renderer only handles 2D)
Nparticles = 10       # agents per simulation
box_size = 2000.0     # side length of the periodic box
poly = 0.3            # NOTE(review): not referenced by any visible cell
dt = 1e-2             # integration time step
tau = 20              # persistence time of the heading (rotational diffusion)
v0 = 10               # self-propulsion speed


In [49]:
# Scratch cell: draws (and discards) one integer from Python's unseeded global
# `random` module. No later cell reads the value, although the call does
# advance that module's state.
rn.randint(0,100000)

84070

In [51]:
# Run N_SIMS independent simulations from random initial conditions and record
# N_FRAMES snapshots of each, taken every STEPS_PER_FRAME integrator steps.
# Frames are appended to `state_buffer` back-to-back, so frame k belongs to
# simulation k // N_FRAMES. Each entry is a (position, theta) tuple.
N_SIMS = 100
N_FRAMES = 50
STEPS_PER_FRAME = 50
SIM_SEED = 0  # fixed seed: rerunning the cell reproduces the same data

# Neither the boundary conditions nor the integrator depends on the run, so
# build them once. `energy_fn` looks up the global `displacement_fn` at call time.
displacement_fn, shift_fn = space.periodic(box_size)
init_fn, apply_fn = activeBrownian(energy_fn, shift_fn, dt, tau, v0)


@jit
def run_frame(state):
  """Advance `state` by STEPS_PER_FRAME integrator steps (compiled once)."""
  return lax.fori_loop(0, STEPS_PER_FRAME, apply_fn, state)


# One independent key per run, derived from a single seed. Python's unseeded
# `random` was used before: it was not reproducible and could repeat seeds.
sim_keys = random.split(random.PRNGKey(SIM_SEED), N_SIMS)

state_buffer = []
t0 = time.time()
for sim_key in tqdm(sim_keys):
  # Split into separate streams for the dynamics noise, positions and headings.
  # Only the split-off noise_key goes into the state; reusing sim_key would
  # correlate the dynamics noise with the initial conditions.
  noise_key, R_rng, theta_rng = random.split(sim_key, 3)

  # Uniform positions in the sub-square [0.3 L, 0.6 L)^dim of the box.
  R = box_size * 0.3 * random.uniform(R_rng, (Nparticles, dim)) + box_size * 0.3
  theta = random.uniform(theta_rng, (Nparticles,), maxval=2. * np.pi)

  state = init_fn(R, theta, noise_key)
  for _ in range(N_FRAMES):
    state = run_frame(state)
    state_buffer.append((state.position, state.theta))
tend = time.time()  # tend - t0 is the wall time for all runs

100%|██████████| 100/100 [00:47<00:00,  2.09it/s]


In [52]:
# Position of particle 0 in the first recorded frame. Each state_buffer entry
# is a (position, theta) tuple.
state_buffer[0][0][0]

DeviceArray([612.46150912, 671.22848774], dtype=float64)

In [53]:
# Particle 0 at frame 50, which is the first recorded frame of the second
# independent simulation (each run appends 50 frames).
state_buffer[50][0][0]

DeviceArray([ 879.30735117, 1128.70646243], dtype=float64)

In [None]:
# Animate the recorded frames in the Colab output. `state_buffer` concatenates
# independent runs, so the animation jumps to a new initial condition every
# 50 frames.
display(render(box_size,state_buffer))

In [55]:
import pandas as pd

# One row per recorded frame. Device arrays are converted to plain nested Python
# lists of floats. `to_csv` writes object cells with str(), and for arrays that
# prints only NumPy's default 8 decimals plus embedded newlines; compare the
# DeviceArray repr of state_buffer[0][0][0] with the full values in df.head().
# repr() of Python floats is exact, so the CSV keeps full precision.
df = pd.DataFrame(
    [(onp.asarray(pos).tolist(), onp.asarray(th).tolist())
     for pos, th in state_buffer],
    columns=['position', 'theta'],
)

In [56]:
import itertools
from iteration_utilities import deepflatten

# Flatten each frame into one feature vector that interleaves every particle's
# (x, y, theta): [x_0, y_0, theta_0, x_1, y_1, theta_1, ...].
# The column is built as a list and assigned once. The previous per-row chained
# assignment (df['trajectory'][i] = ...) triggers SettingWithCopy warnings and
# is a silent no-op under pandas copy-on-write. It also left 0-d device arrays
# in the cells instead of plain floats.
df['trajectory'] = pd.Series(
    [onp.column_stack([onp.asarray(pos), onp.asarray(th)]).ravel().tolist()
     for pos, th in zip(df['position'], df['theta'])],
    index=df.index,
    dtype=object,
)

In [57]:
# Preview the first few frames.
df.head()

Unnamed: 0,position,theta,trajectory
0,"[[612.4615091166613, 671.2284877436346], [835....","[3.8789045101581965, 5.094119628502906, 2.4057...","[612.4615091166613, 671.2284877436346, 3.87890..."
1,"[[608.2784306963905, 668.5606044022409], [837....","[3.641553513473107, 4.900716582029103, 2.36850...","[608.2784306963905, 668.5606044022409, 3.64155..."
2,"[[603.6809883716359, 666.6126129587349], [837....","[3.4823290865418266, 4.838578481566317, 2.1151...","[603.6809883716359, 666.6126129587349, 3.48232..."
3,"[[598.8747251914682, 665.2576338706056], [837....","[3.2916996862129553, 4.804415418578076, 2.3005...","[598.8747251914682, 665.2576338706056, 3.29169..."
4,"[[594.0605258096285, 663.9555719254632], [838....","[3.4664108920571457, 4.786894881710787, 2.3705...","[594.0605258096285, 663.9555719254632, 3.46641..."


In [58]:
# (rows, columns): one row per recorded frame.
df.shape

(5000, 3)

In [59]:
from google.colab import drive
# Mount Google Drive (Colab only) so the CSV below can be written to MyDrive.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [60]:
# Persist the per-frame table to Drive. Object-valued cells (position, theta,
# trajectory) are written as their str() text and must be parsed when read back.
# NOTE(review): to_csv does not create directories, so this assumes
# /content/drive/MyDrive/GNNs already exists.
df.to_csv('/content/drive/MyDrive/GNNs/boids_buffer.csv')