## Setup

#### Imports

In [None]:
!pip install Divergence-Free-Interpolant
!pip install scipy

In [None]:
import torch
import numpy as np
import math
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import Divergence_Free_Interpolant as dfi
import sys
import os

sys.path.insert(0, os.path.abspath('../pytorch-physics/'))
from boid import Flock

#### Helper-Functions

In [None]:
def histogram(xs, bins):
    """Compute an equal-width histogram over the value range of ``xs``.

    Args:
        xs: Non-empty tensor of values; ``xs.min()`` raises on an empty tensor.
        bins: Number of equal-width bins.

    Returns:
        Tuple ``(counts, boundaries)``: ``counts`` is the ``torch.histc``
        result with ``bins`` entries, ``boundaries`` holds the ``bins + 1``
        evenly spaced edges from ``xs.min()`` to ``xs.max()``, created on the
        same device as ``xs``.
    """
    # Plain Python scalars under names that do not shadow the min/max builtins.
    lo, hi = xs.min().item(), xs.max().item()
    counts = torch.histc(xs, bins, min=lo, max=hi)
    # Build the edges on the input's device so both results live together.
    boundaries = torch.linspace(lo, hi, bins + 1, device=xs.device)
    return counts, boundaries

def positions_to_grid(observations, grid, box_top):
    """Count how many positions fall into each cell of a ``grid`` x ``grid`` raster.

    Args:
        observations: Tensor whose column 0 holds x and column 1 holds y.
            Values are clamped to ``[0, box_top]`` before binning.
        grid: Number of cells per axis.
        box_top: Upper bound of the box on both axes (lower bound is 0).

    Returns:
        Integer tensor of shape ``(grid, grid)``; entry ``[row, col]`` is the
        number of positions in y-bin ``row`` and x-bin ``col``.
    """
    cell = box_top / grid
    clipped = torch.clamp(observations, 0, box_top)

    # Positions exactly on the upper edge would land in bin `grid`, so the
    # bin indices are clamped back into the valid range.
    cols = torch.clamp((clipped[:, 0] / cell).long(), 0, grid - 1)
    rows = torch.clamp((clipped[:, 1] / cell).long(), 0, grid - 1)

    # Row-major flat index so bincount can tally all cells in one pass.
    flat = rows * grid + cols
    return torch.bincount(flat, minlength=grid * grid).reshape(grid, grid)

#### Visualizations

In [None]:
def visualize_boids(list_of_birds_pos, list_of_birds_vel, box_top):
    """Animate boid positions (dots) and velocities (arrows) as an HTML5 video.

    Args:
        list_of_birds_pos: Sequence of per-frame position tensors; column 0 is
            used as x and column 1 as y.
        list_of_birds_vel: Sequence of per-frame velocity tensors, indexed the
            same way and aligned frame by frame with the positions.
        box_top: Upper axis limit on both axes (the lower limit is 0).

    Returns:
        ``IPython.display.HTML`` object wrapping the rendered animation.
    """
    # Close leftover figures so repeated calls do not accumulate them.
    plt.close('all')

    pos_frames = [frame.cpu().numpy() for frame in list_of_birds_pos]
    vel_frames = [frame.cpu().numpy() for frame in list_of_birds_vel]

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xlim(0, box_top)
    ax.set_ylim(0, box_top)
    ax.set_xlabel('X-axis')
    ax.set_ylabel('Y-axis')

    dots = ax.scatter([], [], c='orange', s=10)

    # The quiver is built from the first frame; later frames only move it.
    first_pos, first_vel = pos_frames[0], vel_frames[0]
    arrows = ax.quiver(first_pos[:, 0], first_pos[:, 1],
                       first_vel[:, 0], first_vel[:, 1],
                       angles='xy', scale_units='xy', scale=1, color='blue')

    def draw_frame(idx):
        # Move both artists to the state recorded for frame `idx`.
        xy = pos_frames[idx]
        uv = vel_frames[idx]

        dots.set_offsets(xy)
        arrows.set_offsets(xy)
        arrows.set_UVC(uv[:, 0], uv[:, 1])

        return dots, arrows

    animation = FuncAnimation(fig, draw_frame, frames=len(pos_frames), interval=50, blit=True)
    # Close the static figure so only the video is shown in the notebook.
    plt.close(fig)

    return HTML(animation.to_html5_video())

In [None]:
def plot_observations_and_velocities_to_grid(observation, velocities, box_top):
    """Plot points and their velocity arrows inside the square ``[0, box_top]^2``.

    Args:
        observation: Array-like of positions; column 0 is x, column 1 is y.
        velocities: Array-like of velocity vectors, one row per position;
            column 0 is the x component, column 1 the y component.
        box_top: Upper axis limit on both axes (the lower limit is 0).

    Returns:
        None. The figure is shown and then closed.
    """
    # Close leftover figures so repeated calls do not accumulate them.
    plt.close('all')

    positions = np.asarray(observation)
    arrows = np.asarray(velocities)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xlim(0, box_top)
    ax.set_ylim(0, box_top)
    ax.set_xlabel('X-axis')
    ax.set_ylabel('Y-axis')

    # Create both artists directly at the observation coordinates so the
    # quiver's own X/Y data agree with where the arrows are drawn.
    ax.scatter(positions[:, 0], positions[:, 1], c='orange', s=10)
    ax.quiver(positions[:, 0], positions[:, 1],
              arrows[:, 0], arrows[:, 1],
              angles='xy', scale_units='xy', scale=1, color='blue')

    plt.show()
    plt.close(fig)

In [None]:
def plot_observations_to_grid(observation, box_top):
    """Scatter-plot a tensor of positions inside the square ``[0, box_top]^2``.

    Args:
        observation: Tensor of positions; it is moved to the CPU and converted
            to a NumPy array before being handed to the scatter artist.
        box_top: Upper axis limit on both axes (the lower limit is 0).

    Returns:
        None. The figure is shown and then closed.
    """
    fig, ax = plt.subplots()
    ax.set(xlim=(0, box_top), ylim=(0, box_top),
           xlabel='X-axis', ylabel='Y-axis')

    points = observation.cpu().numpy()
    dots = ax.scatter([], [], c='orange', s=10)
    dots.set_offsets(points)

    plt.show()
    plt.close(fig)

In [None]:
def visualize_vectorfield_and_interp_vectorfield(X, Y, U, S, V, XX, YY, UU, SS, VV):
    """Show sampled vectors as a quiver plot, then the interpolated field as a streamplot.

    In both plots the vector components are divided by the magnitude (0 where
    the magnitude is 0) and the magnitude is used for coloring.

    Args:
        X, Y: Sample coordinates.
        U, V: Vector components at the samples (note the argument order
            ``U, S, V``).
        S: Vector magnitudes at the samples.
        XX, YY: Coordinates of the interpolation grid.
        UU, VV: Interpolated vector components on the grid.
        SS: Interpolated magnitudes on the grid.

    Returns:
        None. Two figures are shown one after the other.
    """
    def safe_ratio(numer, denom):
        # Element-wise numer / denom; entries with denom == 0 stay 0.
        return np.divide(numer, denom, out=np.zeros_like(denom), where=denom != 0)

    # Sampled vectors: unit-length arrows colored by magnitude.
    sample_fig, sample_ax = plt.subplots(1, 1)
    sample_arrows = sample_ax.quiver(X, Y, safe_ratio(U, S), safe_ratio(V, S), S, cmap='autumn')
    sample_ax.set_aspect('equal')
    sample_ax.set_xlim(X.min(), X.max())
    sample_ax.set_ylim(Y.min(), Y.max())
    sample_fig.colorbar(sample_arrows)
    plt.show()
    plt.close()

    # Interpolated field: every grid array is transposed before plotting
    # (the callers in this notebook build XX/YY with np.mgrid).
    interp_fig, interp_ax = plt.subplots(1, 1)
    streams = interp_ax.streamplot(XX.T, YY.T,
                                   safe_ratio(UU, SS).T, safe_ratio(VV, SS).T,
                                   color=SS.T, density=1, cmap='autumn')
    interp_fig.colorbar(streams.lines)
    interp_ax.set_aspect('equal')
    plt.show()
    plt.close()

In [None]:
def plot_histogram(hist, box_top):
    """Show a 2D histogram tensor as an image scaled by its maximum value.

    Args:
        hist: 2D tensor of counts; row index maps to the y axis and column
            index to the x axis (``origin='lower'``).
        box_top: Extent of the image on both axes, i.e. the image covers
            ``[0, box_top] x [0, box_top]``.

    Returns:
        None. The figure is shown and then closed.
    """
    fig, ax = plt.subplots()

    peak = hist.max()
    if peak == 0:
        # An all-zero histogram would otherwise be divided by 0 and produce a
        # NaN image; show it as a uniform zero image instead.
        hist_normalized = torch.zeros_like(hist, dtype=torch.float)
    else:
        hist_normalized = hist / peak

    ax.imshow(hist_normalized.cpu().numpy(), cmap='viridis',
              extent=[0, box_top, 0, box_top], origin='lower', aspect='equal')

    ax.set_xlabel('X-axis')
    ax.set_ylabel('Y-axis')
    plt.show()
    plt.close(fig)

In [None]:
def plot_histogram_as_point_plot(histogram, box_top):
    """Draw one point for every non-zero cell of a 2D histogram tensor.

    A non-zero cell ``[row, col]`` is drawn at ``(col * cell, row * cell)``
    with ``cell = box_top / histogram.shape[0]``, i.e. at the lower edge of
    the cell on both axes.

    Args:
        histogram: 2D tensor; only whether a cell is non-zero matters.
        box_top: Upper axis limit on both axes (the lower limit is 0).

    Returns:
        None. The figure is shown and then closed.
    """
    fig, ax = plt.subplots()
    ax.set(xlim=(0, box_top), ylim=(0, box_top),
           xlabel='X-axis', ylabel='Y-axis')

    cell = box_top / histogram.shape[0]
    rows, cols = torch.nonzero(histogram, as_tuple=True)
    xs = (cols.float() * cell).cpu().numpy()
    ys = (rows.float() * cell).cpu().numpy()
    ax.scatter(xs, ys, c='orange', s=10)

    plt.show()
    plt.close(fig)

In [None]:
def visualize_boids_as_histograms(list_of_birds_pos, grid_size, box_top):
    """Animate flock positions after snapping them to a ``grid_size`` x ``grid_size`` grid.

    For every frame the positions are binned with ``positions_to_grid`` and
    each occupied cell is drawn as a single point at the cell's lower edge.

    Args:
        list_of_birds_pos: Sequence of per-frame position tensors.
        grid_size: Number of grid cells per axis.
        box_top: Upper axis limit on both axes (the lower limit is 0).

    Returns:
        ``IPython.display.HTML`` object wrapping the rendered animation.
    """
    # Close leftover figures so repeated calls do not accumulate them.
    plt.close('all')

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xlim(0, box_top)
    ax.set_ylim(0, box_top)
    ax.set_xlabel('X-axis')
    ax.set_ylabel('Y-axis')

    dots = ax.scatter([], [], c='orange', s=10)

    def draw_frame(idx):
        # Bin this frame and turn occupied cells back into coordinates.
        occupancy = positions_to_grid(list_of_birds_pos[idx], grid_size, box_top)
        rows, cols = torch.nonzero(occupancy, as_tuple=True)

        cell_x = cols.float() * (box_top / grid_size)
        cell_y = rows.float() * (box_top / grid_size)
        dots.set_offsets(torch.stack((cell_x, cell_y), dim=1).cpu().numpy())
        return dots,

    animation = FuncAnimation(fig, draw_frame, frames=len(list_of_birds_pos), interval=50, blit=True)
    # Close the static figure so only the video is shown in the notebook.
    plt.close(fig)

    return HTML(animation.to_html5_video())

#### Alternative Configurations

In [None]:
# Preset: small box (box_top = 30), slow boids (speeds 0.02-1) and a very
# small cohesion factor (0.0005).
# NOTE: every configuration cell in this notebook assigns the same names
# (device, N, box_top, flock_args, boid_args); when the cells run top to
# bottom, the last configuration cell executed is the one that takes effect.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # first CUDA GPU if available, else CPU
N = 400
box_top = 30

# Keyword arguments unpacked into Flock(...) together with boid_args.
flock_args = {
    'D': 2,
    'N': N,
    'box_top': box_top,
    'pass_through_edges': True,
    'bouncy_edges': False,
    'device': device,
}

# NOTE(review): the meaning of each key is defined by boid.Flock, which is
# not visible in this notebook - confirm there before tuning.
boid_args = {
    'init_speed': None,
    'min_speed': 0.02,
    'max_speed': 1,
    'max_acc': 0.1,
    
    'view_radius': 3,
    'view_angle': None,
    
    'avoid_radius': 5,  # larger than view_radius in this preset
    'avoid_view': True,
    
    'sep_factor': 0.05,    # avoidfactor
    'align_factor': 0.05,  # matchingfactor
    'cohe_factor': 0.0005,  # centeringfactor
    'bias_factor': 0.005,
    'edge_factor': 0.05,
    
    'is_debug': False
}

In [None]:
# Preset: large box (box_top = 300), speeds 3-6, wide view radius (20) and
# equal separation / alignment / cohesion factors (0.05 each).
# NOTE: this cell reassigns device, N, box_top, flock_args and boid_args;
# configuration cells executed later overwrite these values again.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # first CUDA GPU if available, else CPU
N = 400
box_top = 300

# Keyword arguments unpacked into Flock(...) together with boid_args.
flock_args = {
    'D': 2,
    'N': N,
    'box_top': box_top,
    'pass_through_edges': True,
    'bouncy_edges': False,
    'device': device,
}

boid_args = {
    'init_speed': None,
    'min_speed': 3,
    'max_speed': 6,
    'max_acc': 0.1,
    
    'view_radius': 20,
    'view_angle': None,
    
    'avoid_radius': 15,
    'avoid_view': True,
    
    'sep_factor': 0.05,    # avoidfactor
    'align_factor': 0.05,  # matchingfactor
    'cohe_factor': 0.05,  # centeringfactor
    'bias_factor': 0.005,
    'edge_factor': 0.05,
    
    'is_debug': False
}

In [None]:
# Preset: box_top = 300, speeds 3-6, view radius 10 / avoid radius 8, with a
# stronger separation factor (0.3) than alignment (0.05) or cohesion (0.01).
# NOTE: this cell reassigns device, N, box_top, flock_args and boid_args;
# configuration cells executed later overwrite these values again.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # first CUDA GPU if available, else CPU
N = 400
box_top = 300

# Keyword arguments unpacked into Flock(...) together with boid_args.
flock_args = {
    'D': 2,
    'N': N,
    'box_top': box_top,
    'pass_through_edges': True,
    'bouncy_edges': False,
    'device': device,
}

boid_args = {
    'init_speed': None,
    'min_speed': 3,
    'max_speed': 6,
    'max_acc': 0.1,
    
    'view_radius': 10,
    'view_angle': None,
    
    'avoid_radius': 8,
    'avoid_view': True,
    
    'sep_factor': 0.3,    # avoidfactor
    'align_factor': 0.05,  # matchingfactor
    'cohe_factor': 0.01,  # centeringfactor
    'bias_factor': 0.005,
    'edge_factor': 0.05,
    
    'is_debug': False
}

In [None]:
# Preset: box_top = 300, speeds 3-6, view radius 7 / avoid radius 8, with
# separation and alignment factors both at 0.5 and max_acc = 0.1.
# NOTE: this cell reassigns device, N, box_top, flock_args and boid_args;
# configuration cells executed later overwrite these values again.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # first CUDA GPU if available, else CPU
N = 400
box_top = 300

# Keyword arguments unpacked into Flock(...) together with boid_args.
flock_args = {
    'D': 2,
    'N': N,
    'box_top': box_top,
    'pass_through_edges': True,
    'bouncy_edges': False,
    'device': device,
}

boid_args = {
    'init_speed': None,
    'min_speed': 3,
    'max_speed': 6,
    'max_acc': 0.1,
    
    'view_radius': 7,
    'view_angle': None,
    
    'avoid_radius': 8,
    'avoid_view': True,
    
    'sep_factor': 0.5,    # avoidfactor
    'align_factor': 0.5,  # matchingfactor
    'cohe_factor': 0.01,  # centeringfactor
    'bias_factor': 0.005,
    'edge_factor': 0.05,
    
    'is_debug': False
}

In [None]:
# Preset "circular behavior" (author's observation of the resulting motion).
# Same values as the preceding preset except max_acc = 0.3 instead of 0.1.
# NOTE: this cell reassigns device, N, box_top, flock_args and boid_args;
# configuration cells executed later overwrite these values again.

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # first CUDA GPU if available, else CPU
N = 400
box_top = 300

# Keyword arguments unpacked into Flock(...) together with boid_args.
flock_args = {
    'D': 2,
    'N': N,
    'box_top': box_top,
    'pass_through_edges': True,
    'bouncy_edges': False,
    'device': device,
}

boid_args = {
    'init_speed': None,
    'min_speed': 3,
    'max_speed': 6,
    'max_acc': 0.3,
    
    'view_radius': 7,
    'view_angle': None,
    
    'avoid_radius': 8,
    'avoid_view': True,
    
    'sep_factor': 0.5,    # avoidfactor
    'align_factor': 0.5,  # matchingfactor
    'cohe_factor': 0.01,  # centeringfactor
    'bias_factor': 0.005,
    'edge_factor': 0.05,
    
    'is_debug': False
}

In [None]:
# Preset (author's observation): interesting for a while, but the boids then
# end up moving as one unit.
# Values: max_acc 0.2, view radius 10, alignment factor 2, cohesion 0.005.
# NOTE: this cell reassigns device, N, box_top, flock_args and boid_args;
# configuration cells executed later overwrite these values again.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # first CUDA GPU if available, else CPU
N = 400
box_top = 300

# Keyword arguments unpacked into Flock(...) together with boid_args.
flock_args = {
    'D': 2,
    'N': N,
    'box_top': box_top,
    'pass_through_edges': True,
    'bouncy_edges': False,
    'device': device,
}

boid_args = {
    'init_speed': None,
    'min_speed': 3,
    'max_speed': 6,
    'max_acc': 0.2,
    
    'view_radius': 10,
    'view_angle': None,
    
    'avoid_radius': 8,
    'avoid_view': True,
    
    'sep_factor': 0.5,    # avoidfactor
    'align_factor': 2,  # matchingfactor
    'cohe_factor': 0.005,  # centeringfactor
    'bias_factor': 0.005,
    'edge_factor': 0.05,
    
    'is_debug': False
}

In [None]:
# Preset "circular-ish but stable" (author's observation of the motion).
# Values: max_acc 0.3, view radius 7, alignment factor 3, cohesion 0.005.
# NOTE: this cell reassigns device, N, box_top, flock_args and boid_args;
# configuration cells executed later overwrite these values again.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # first CUDA GPU if available, else CPU
N = 400
box_top = 300

# Keyword arguments unpacked into Flock(...) together with boid_args.
flock_args = {
    'D': 2,
    'N': N,
    'box_top': box_top,
    'pass_through_edges': True,
    'bouncy_edges': False,
    'device': device,
}

boid_args = {
    'init_speed': None,
    'min_speed': 3,
    'max_speed': 6,
    'max_acc': 0.3,
    
    'view_radius': 7,
    'view_angle': None,
    
    'avoid_radius': 8,
    'avoid_view': True,
    
    'sep_factor': 0.5,    # avoidfactor
    'align_factor': 3,  # matchingfactor
    'cohe_factor': 0.005,  # centeringfactor
    'bias_factor': 0.005,
    'edge_factor': 0.05,
    
    'is_debug': False
}

In [None]:
# Preset (author's observation): stable (circular) formation floating through
# the screen.
# The only alternative preset with 'pass_through_edges' False and
# 'bouncy_edges' True; also uses max_acc 0.5, view radius 10, alignment 2.
# NOTE: this cell reassigns device, N, box_top, flock_args and boid_args;
# configuration cells executed later overwrite these values again.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # first CUDA GPU if available, else CPU
N = 400
box_top = 300

# Keyword arguments unpacked into Flock(...) together with boid_args.
flock_args = {
    'D': 2,
    'N': N,
    'box_top': box_top,
    'pass_through_edges': False,
    'bouncy_edges': True,
    'device': device,
}

boid_args = {
    'init_speed': None,
    'min_speed': 3,
    'max_speed': 6,
    'max_acc': 0.5,
    
    'view_radius': 10,
    'view_angle': None,
    
    'avoid_radius': 8,
    'avoid_view': True,
    
    'sep_factor': 0.5,    # avoidfactor
    'align_factor': 2,  # matchingfactor
    'cohe_factor': 0.005,  # centeringfactor
    'bias_factor': 0.005,
    'edge_factor': 0.05,
    
    'is_debug': False
}

#### Config-Values

In [None]:
# Active configuration: this is the last configuration cell, so when the
# notebook runs top to bottom these are the values Flock(...) receives.
# Small setup (N = 100, box_top = 20) with pass-through edges.
# Author's note on the behavior: interesting for a while, but the boids then
# end up moving as one unit.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # first CUDA GPU if available, else CPU
N = 100
box_top = 20

# Keyword arguments unpacked into Flock(...) together with boid_args.
flock_args = {
    'D': 2,
    'N': N,
    'box_top': box_top,
    'pass_through_edges': True,
    'bouncy_edges': False,
    'device': device,
}

boid_args = {
    'init_speed': None,
    # Speed and acceleration limits are the values 3 / 6 / 0.5 used with the
    # larger box, divided by 5.
    # NOTE(review): view_radius and avoid_radius below are not scaled down
    # although box_top is only 20 here - confirm this is intended.

    'min_speed': 3/5,
    'max_speed': 6/5,
    'max_acc': 0.5/5,
    
    'view_radius': 10,
    'view_angle': None,
    
    'avoid_radius': 8,
    'avoid_view': True,
    
    'sep_factor': 0.5,    # avoidfactor
    'align_factor': 2,  # matchingfactor
    'cohe_factor': 0.005,  # centeringfactor
    'bias_factor': 0.005,
    'edge_factor': 0.05,
    
    'is_debug': False
}

## Simulation and Visualization

##### Helper

In [None]:
def normalize_vectors(vectors, desired_length=0.1):
    """Rescale every row of ``vectors`` to the Euclidean length ``desired_length``.

    Rows whose norm is below ``1e-10`` are divided by 1 instead of their
    norm, so (near-)zero vectors stay (near-)zero instead of producing NaNs.

    Args:
        vectors: 2D array with one vector per row.
        desired_length: Target length of each non-zero row.

    Returns:
        New array with the same shape as ``vectors``; the input is not modified.
    """
    lengths = np.linalg.norm(vectors, axis=1, keepdims=True)
    # Swap vanishing norms for 1 to avoid dividing by (almost) zero.
    safe_lengths = np.where(lengths < 1e-10, 1, lengths)
    return vectors / safe_lengths * desired_length

##### Experiments

In [None]:
# Build the flock from the active configuration and record one position /
# velocity snapshot per simulation step.
flock = Flock(
    **flock_args,
    **boid_args
)

list_of_birds_pos = []
list_of_birds_vel = []

iterations = 150

for _ in range(iterations):
    flock.update()
    # Store copies: if Flock.update() modifies pos/vel in place, appending the
    # tensors themselves would make every recorded frame alias the final state.
    list_of_birds_pos.append(flock.pos.clone())
    list_of_birds_vel.append(flock.vel.clone())

# Animate only the last `vid_len` recorded frames. Negative slicing keeps this
# correct even if the list length and `iterations` ever get out of sync.
vid_len = 50
visualize_boids(list_of_birds_pos[-vid_len:], list_of_birds_vel[-vid_len:], box_top)

In [None]:
# Plot the current flock state with positions, velocities and the box size
# all divided by the same factor.
scaling = 20
flock_pos, flock_vel = (
    state.cpu().numpy() / scaling for state in (flock.pos, flock.vel)
)
plot_observations_and_velocities_to_grid(flock_pos, flock_vel, box_top / scaling)

In [None]:
# Fit the dfi interpolant to the current flock velocities and plot the
# interpolated field on a regular grid as arrows of fixed length.
scaling = 20
# Positions are divided by `scaling`; velocities are left unscaled here.
flock_pos = flock.pos.cpu().numpy() / scaling
flock_vel = flock.vel.cpu().numpy()

X, Y = flock_pos[:, 0], flock_pos[:, 1]
UV = flock_vel
# Rearranged to shape (2, n): row 0 = x components, row 1 = y components.
UV = np.array([UV[:, 0], UV[:, 1]])
U, V = UV[0], UV[1]  # not used further in this cell

# interp = interpolant(dim=2)
initialized_interpolant = dfi.interpolant(nu = 5, k = 3, dim = 2)
# Author's list of `method` options: penrose, lstsq, linsolve
# NOTE(review): the meaning of nu, k and of the positional argument 1 is
# defined by the dfi package and not visible here - check its documentation.
initialized_interpolant.condition(np.array([X, Y]).T, UV.T, 1, method="linsolve")

# Bounding box of the scaled positions; not used further in this cell.
x_min, x_max = flock_pos[:, 0].min(), flock_pos[:, 0].max()
y_min, y_max = flock_pos[:, 1].min(), flock_pos[:, 1].max()

# Regular evaluation grid over the whole (scaled) box. X and Y are reused
# here and no longer hold the flock coordinates afterwards.
grid_size = 100
x = np.linspace(0, box_top / scaling, grid_size)
y = np.linspace(0, box_top / scaling, grid_size)
X, Y = np.meshgrid(x, y)

# Indexed below as velocities[i, j] and reshaped to (-1, 2), i.e. one
# 2-component vector per grid node.
velocities = initialized_interpolant(X, Y)

# visualization: pick subset_size evenly spaced grid indices per axis
subset_size = 40
x_subset = np.linspace(0, grid_size-1, subset_size, dtype=int)
y_subset = np.linspace(0, grid_size-1, subset_size, dtype=int)
X_subset, Y_subset = np.meshgrid(x_subset, y_subset)

# Extract positions and velocities for the subset. Both use the same index
# pair (X_subset, Y_subset), so each position stays paired with its velocity.
point_pos = np.column_stack((X[X_subset, Y_subset].ravel(), Y[X_subset, Y_subset].ravel()))  # shape (subset_size**2, 2) = (1600, 2)
point_vel = velocities[X_subset, Y_subset].reshape(-1, 2)

# Draw all arrows with the same length so only the direction is shown.
normalized_velocities = normalize_vectors(point_vel, desired_length=0.02)
plot_observations_and_velocities_to_grid(point_pos, normalized_velocities, box_top / scaling)

In [None]:
# Fit the dfi interpolant to the current flock velocities, evaluate it on a
# grid spanning the flock's bounding box, and show samples next to the
# interpolated field.
scaling = 20
flock_pos = flock.pos.cpu().numpy() / scaling
flock_vel = flock.vel.cpu().numpy()

# Sample coordinates, velocity components (stacked as shape (2, n)) and speeds.
X, Y = flock_pos[:, 0], flock_pos[:, 1]
UV = np.array([flock_vel[:, 0], flock_vel[:, 1]])
U, V = UV
S = (U**2 + V**2)**0.5

# Author's list of `method` options: penrose, lstsq, linsolve
initialized_interpolant = dfi.interpolant(nu=5, k=3, dim=2)
initialized_interpolant.condition(np.column_stack((X, Y)), UV.T, 1, method="linsolve")

# Evaluation grid: _n x _m points between the min/max sample coordinates.
_n, _m = 100, 100
x_min, x_max = X.min(), X.max()
y_min, y_max = Y.min(), Y.max()
XX, YY = np.mgrid[x_min:x_max:_n*1j, y_min:y_max:_m*1j]

# UV is reused for the interpolated field; the last axis holds the components.
UV = initialized_interpolant(XX, YY)
UU, VV = UV[:, :, 0], UV[:, :, 1]
SS = (UU**2 + VV**2)**0.5

visualize_vectorfield_and_interp_vectorfield(X, Y, U, S, V, XX, YY, UU, SS, VV)

In [None]:
# Sanity check of the interpolant on an analytic vector field sampled at
# random points in the unit square.
np.random.seed(69)


def div(n, d):
    """Element-wise ``n / d`` that yields 0 wherever ``d`` is 0."""
    return np.divide(n, d, out=np.zeros_like(d), where=d != 0)


def vector_field(x, y):
    """Analytic test field ``(x + 1, y + 1)``."""
    return np.array([x + 1, y + 1])

# Alternative fields tried by the author:
# vector_field = lambda x, y: np.array([-2*x**3 * y, 3*x**2 * y**2])
# vector_field = lambda x, y: np.array([x*0 + 1, y*0 + 1])
# NOTE(review): (x + 1, y + 1) has divergence 2. If dfi.interpolant only
# represents divergence-free fields, as the package name suggests, it cannot
# reproduce this field exactly - confirm this is the intended test case.

# Dedicated name for the sample count, so the notebook-global flock size `N`
# set in the configuration cell is not overwritten.
n_samples = 250
X, Y = np.random.rand(n_samples), np.random.rand(n_samples)
UV = vector_field(X, Y)
U, V = UV[0], UV[1]
S = (U**2 + V**2)**0.5

initialized_interpolant = dfi.interpolant(nu = 5, k = 3, dim = 2)
initialized_interpolant.condition(np.array([X, Y]).T, UV.T, 1)

# Evaluate on a regular _n x _m grid over the unit square.
_n, _m = 100, 100
XX, YY = np.mgrid[0:1:_n*1j, 0:1:_m*1j]
UV = initialized_interpolant(XX, YY)
UU = UV[:,:,0]
VV = UV[:,:,1]
SS = (UU**2 + VV**2)**0.5

visualize_vectorfield_and_interp_vectorfield(X, Y, U, S, V, XX, YY, UU, SS, VV)

##### Longer Iterations

In [None]:
# list_of_birds_pos = []
# list_of_birds_vel = []

# iterations = 500

# for _ in range(iterations):
#     flock.update()
#     list_of_birds_pos.append(flock.pos)
#     list_of_birds_vel.append(flock.vel)

# vid_len = 50
# visualize_boids(list_of_birds_pos[iterations-vid_len:], list_of_birds_vel[iterations-vid_len:], box_top)

In [None]:
# list_of_birds_pos = []
# list_of_birds_vel = []

# iterations = 3000

# for _ in range(iterations):
#     flock.update()
#     list_of_birds_pos.append(flock.pos)
#     list_of_birds_vel.append(flock.vel)

# vid_len = 50
# visualize_boids(list_of_birds_pos[iterations-vid_len:], list_of_birds_vel[iterations-vid_len:], box_top)

In [None]:
# list_of_birds_pos = []
# list_of_birds_vel = []

# iterations = 5000

# for _ in range(iterations):
#     flock.update()
#     list_of_birds_pos.append(flock.pos)
#     list_of_birds_vel.append(flock.vel)

# vid_len = 50
# visualize_boids(list_of_birds_pos[iterations-vid_len:], list_of_birds_vel[iterations-vid_len:], box_top)

## Graph-Network

#### Data to Graph

In [None]:
# Two consecutive recorded position snapshots (indices 100 and 101).
a, b = list_of_birds_pos[100], list_of_birds_pos[101]

In [None]:
# NOTE: this redefines the `histogram` helper already defined in the
# Helper-Functions section near the top of the notebook; the later definition
# is the one in effect once this cell has run.
def histogram(xs, bins):
    """Compute an equal-width histogram over the value range of ``xs``.

    Args:
        xs: Non-empty tensor of values; ``xs.min()`` raises on an empty tensor.
        bins: Number of equal-width bins.

    Returns:
        Tuple ``(counts, boundaries)``: ``counts`` is the ``torch.histc``
        result with ``bins`` entries, ``boundaries`` holds the ``bins + 1``
        evenly spaced edges from ``xs.min()`` to ``xs.max()``, created on the
        same device as ``xs``.
    """
    # Plain Python scalars under names that do not shadow the min/max builtins.
    lo, hi = xs.min().item(), xs.max().item()
    counts = torch.histc(xs, bins, min=lo, max=hi)
    # Build the edges on the input's device so both results live together.
    boundaries = torch.linspace(lo, hi, bins + 1, device=xs.device)
    return counts, boundaries

# NOTE: this redefines the `positions_to_grid` helper already defined in the
# Helper-Functions section near the top of the notebook; the later definition
# is the one in effect once this cell has run.
def positions_to_grid(observations, grid, box_top):
    """Count how many positions fall into each cell of a ``grid`` x ``grid`` raster.

    Args:
        observations: Tensor whose column 0 holds x and column 1 holds y.
            Values are clamped to ``[0, box_top]`` before binning.
        grid: Number of cells per axis.
        box_top: Upper bound of the box on both axes (lower bound is 0).

    Returns:
        Integer tensor of shape ``(grid, grid)``; entry ``[row, col]`` is the
        number of positions in y-bin ``row`` and x-bin ``col``.
    """
    cell = box_top / grid
    clipped = torch.clamp(observations, 0, box_top)

    # Positions exactly on the upper edge would land in bin `grid`, so the
    # bin indices are clamped back into the valid range.
    cols = torch.clamp((clipped[:, 0] / cell).long(), 0, grid - 1)
    rows = torch.clamp((clipped[:, 1] / cell).long(), 0, grid - 1)

    # Row-major flat index so bincount can tally all cells in one pass.
    flat = rows * grid + cols
    return torch.bincount(flat, minlength=grid * grid).reshape(grid, grid)

In [None]:
# Bin snapshot `a` on a grid with `cells_per_unit` cells per unit of box
# length and show the occupied cells as points.
cells_per_unit = 1
grid_size = box_top * cells_per_unit
hist_observations = positions_to_grid(a, grid_size, box_top)

plot_histogram_as_point_plot(hist_observations, box_top)

In [None]:
# Animate the grid-snapped positions for the frames from index
# `iterations - vid_len` onward.
vid_len = 50
grid_size = box_top * 1
recent_frames = list_of_birds_pos[iterations - vid_len:]
visualize_boids_as_histograms(recent_frames, grid_size, box_top)

#### Graph-Neural Networks