<a href="https://colab.research.google.com/github/GAlonzoS/AI-Institute-June-2023-Workshop/blob/main/RL_Small_Batch_Self_Assembly.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports and dependencies

In [None]:
-%%capture
# Install dependents libraries
!sudo apt-get update
!sudo apt-get install -y xvfb ffmpeg freeglut3-dev
!pip install 'imageio==2.4.0'
#!pip install pyvirtualdisplay
#!pip install tf-agents[reverb]
#!pip install pyglet
#!pip install tensorflow

!pip install -q git+https://www.github.com/google/jax-md
!pip install --upgrade pip
!pip install --upgrade "jax[cpu]"
!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


In [None]:
#Assign Library Calls

import time
import functools
import freud
import base64
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image

from jax.config import config
config.update('jax_enable_x64', True)
import jax.numpy as jnp
from jax.scipy import ndimage
from jax import jit, grad, value_and_grad, tree_map, vmap, random, lax, ops, remat
from jax_md import space, minimize, simulate, energy, quantity, smap, dataclasses, rigid_body
from jax.example_libraries import stax
from jax.example_libraries import optimizers
from jax_md.util import *
from jax_md import dataclasses
from functools import partial
from typing import Optional, Tuple, Dict, Callable, List, Union
import jax
from jax_md import util

from jax_md.colab_tools import renderer
Array = util.Array
f64 = util.f64
import argparse
import os
from pathlib import Path
from google.colab import files

## Check hardware and software

In [None]:
!nvidia-smi

/bin/bash: nvidia-smi: command not found


In [None]:
def format_plot(x, y):
  """Label both axes of the current matplotlib figure.

  Args:
    x: Text for the x-axis label.
    y: Text for the y-axis label.
  """
  label_size = 20
  for set_label, text in ((plt.xlabel, x), (plt.ylabel, y)):
    set_label(text, fontsize=label_size)

def finalize_plot(shape=(1, 1)):
  """Resize the current figure and tighten its layout.

  Both the new width and the new height are derived from the figure's
  current *height* (index 1 of ``get_size_inches()``), scaled by 1.5 and by
  the matching entry of ``shape``.

  Args:
    shape: ``(columns, rows)`` multipliers for the figure size.
  """
  fig = plt.gcf()
  # Read the height once, before resizing, so both dimensions use the same base.
  base_height = fig.get_size_inches()[1]
  new_width = shape[0] * 1.5 * base_height
  new_height = shape[1] * 1.5 * base_height
  fig.set_size_inches(new_width, new_height)
  plt.tight_layout()

# Short aliases for the NumPy floating-point types.
# NOTE(review): `f64` is also bound to `util.f64` in the import cell; this
# assignment rebinds it to `np.float64`.
f32 = np.float32
f64 = np.float64

def draw_system(R, box_size, marker_size, color=None):
  """Draw 2D particle positions as open circles inside a periodic box.

  The particles are plotted in the primary box and in all eight neighbouring
  periodic images, so markers that straddle any edge or corner of the box are
  completed from the opposite side. The view is then clipped to
  ``[0, box_size]`` on both axes and the axes are hidden.

  Args:
    R: Array-like of shape ``(n, 2)`` holding particle positions.
    box_size: Side length of the square periodic box.
    marker_size: Marker size before scaling; the plotted marker size is
      ``marker_size / box_size``.
    color: Optional matplotlib colour. Defaults to a dark grey.
  """
  # `is None` rather than `== None`: an array-valued colour would otherwise be
  # compared elementwise and fail the truth-value test.
  if color is None:
    color = [64 / 256] * 3
  ms = marker_size / box_size

  R = jnp.array(R)

  marker_style = dict(
      linestyle='none',
      markeredgewidth=3,
      marker='o',
      markersize=ms,
      color=color,
      fillstyle='none')

  # Primary box plus every periodic image; this includes the two diagonal
  # images (+x, -y) and (-x, +y).
  for shift_x in (-1, 0, 1):
    for shift_y in (-1, 0, 1):
      plt.plot(R[:, 0] + shift_x * box_size,
               R[:, 1] + shift_y * box_size,
               **marker_style)

  plt.xlim([0, box_size])
  plt.ylim([0, box_size])
  plt.axis('off')


def plot_pos_now(state):
  """Scatter-plot the particle positions stored in a simulation state.

  Args:
    state: Object with a ``position`` attribute indexable as ``[:, 0]`` and
      ``[:, 1]`` after conversion to a NumPy array.
  """
  base_marker_size = 65

  positions = np.array(state.position)
  xs = positions[:, 0]
  ys = positions[:, 1]

  plt.plot(xs, ys, 'o', markersize=base_marker_size * 0.5)

  # The view spans from the origin to the largest coordinate on each axis.
  plt.xlim([0, np.max(xs)])
  plt.ylim([0, np.max(ys)])
  plt.axis('off')

  finalize_plot((1, 1))

## Functions for Shapes

In [None]:
def square_lattice(N, box_size):
  """Place N particles on a square lattice filling a box of side `box_size`.

  Args:
    N: Number of particles; must be a positive perfect square.
    box_size: Side length of the box. The lattice spacing is
      ``box_size / sqrt(N)``.

  Returns:
    A NumPy array of shape ``(N, 2)`` with the lattice coordinates, ordered
    with the second coordinate varying fastest.

  Raises:
    ValueError: If `N` is not a positive perfect square.
  """
  Nx = int(np.sqrt(N)) if N > 0 else 0
  if Nx == 0:
    raise ValueError('Particle count should be a square. Found {}.'.format(N))
  Ny, ragged = divmod(N, Nx)
  if Ny != Nx or ragged:
    # This used to be `assert ValueError(...)`, which never fails because an
    # exception instance is truthy; raise so bad counts are actually rejected.
    raise ValueError('Particle count should be a square. Found {}.'.format(N))
  length_scale = box_size / Nx
  return np.array([[i * length_scale, j * length_scale]
                   for i in range(Nx)
                   for j in range(Ny)])

def make_square_shape(center_rad):
  """Return the four corners of a square with side ``2 * center_rad``.

  Args:
    center_rad: Half of the square's side length.

  Returns:
    A ``(4, 2)`` array with corners in the order
    ``(0, 0), (0, s), (s, 0), (s, s)`` where ``s = 2 * center_rad``.
  """
  side = 2*center_rad
  # The second corner was written `[0. 2*center_rad]` (missing comma), which
  # is a syntax error.
  square_shape = jnp.array([[0., 0.],
                            [0., side],
                            [side, 0.],
                            [side, side]])
  return square_shape

def make_square_circle(center_rad):
  """Return eight points on the perimeter of a square ring.

  The points are the corners and edge midpoints of a square with side
  ``4 * center_rad``, listed from the origin along the bottom edge, up the
  right edge, back along the top edge and down the left edge.

  Args:
    center_rad: Scale of the ring; adjacent points are ``2 * center_rad`` apart.

  Returns:
    An ``(8, 2)`` array of point coordinates.
  """
  unit_ring = jnp.array([
      [0., 0.], [1., 0.], [2., 0.],   # bottom edge
      [2., 1.], [2., 2.],             # right edge
      [1., 2.], [0., 2.],             # top edge
      [0., 1.],                       # left edge
  ])
  return 2*center_rad * unit_ring

def make_triangle(center_rad):
  """Return the vertices of an equilateral triangle with side ``2 * center_rad``.

  Args:
    center_rad: Half of the triangle's side length.

  Returns:
    A ``(3, 2)`` array: two base vertices on the x-axis followed by the apex.
  """
  apex_height = jnp.sqrt(3)/2.
  unit_triangle = jnp.array([[0., 0.],
                             [1., 0.],
                             [0.5, apex_height]])
  return 2*center_rad * unit_triangle



In [None]:
def my_render(box_size, states):
  """Render `states` as disks with the JAX-MD Colab renderer at 700x700.

  Args:
    box_size: Box size forwarded to ``renderer.render``.
    states: Positions forwarded to ``renderer.Disk``.
  """
  disk_geometry = renderer.Disk(states)
  renderer.render(box_size, disk_geometry, resolution=(700, 700))
def plot_average_over_time(data, t=None):
    """Plot the running (cumulative) mean of `data`.

    Args:
      data: Iterable of numeric samples.
      t: Optional x-axis values, one per sample. When omitted, a
        notebook-level variable named ``t`` is used if one exists (the
        previous behaviour); otherwise the sample index is used.
    """
    averages = []
    cumulative_sum = 0

    for count, value in enumerate(data, start=1):
        cumulative_sum += value
        averages.append(cumulative_sum / count)  # mean of the first `count` samples

    # The function used to read an undefined global `t`, which raised
    # NameError unless some other cell happened to define it.
    if t is None:
        t = globals().get('t')
    if t is None:
        t = np.arange(len(averages))

    # Plotting
    plt.plot(t, averages, linewidth = 4)

# Morse Potential Simulation

In [None]:
#Sim Setup param
# Global simulation configuration read by the cells below.
N = 4 #number of particles
dimension = 2
density = .2
sigma = .5
vol = quantity.particle_volume(radii = sigma, spatial_dimension = dimension, particle_count = N)
# NOTE(review): `vol` cancels algebraically here, so this reduces to
# sqrt(N * pi * 0.5**2 / density); the 0.5 is hard-coded rather than derived
# from `sigma` -- confirm that is intended.
box_size  = np.sqrt(N*(np.pi *.5**2 / vol)/(density/vol))

dt = 5e-4
displacement, shift = space.periodic(box_size)
kT = .4
# 2 for every off-diagonal pair, 0 on the diagonal.
epsilon = 2*(1. - jnp.eye(N)) #interaction matrix

batch_size = 12
filename = 'square_opt'
# NOTE(review): `floor` is not imported by name in the import cell; presumably
# it arrives via `from jax_md.util import *` -- confirm. If it is the
# mathematical floor, the values above give box_size ~= 3.96 and therefore
# center_radius == 0, which collapses the target square to a point -- confirm.
center_radius = floor(box_size/5)

# `steps` is the number of integration steps per run; positions and order
# parameters are logged once every `write_every` steps.
steps = 100000
write_every = 1000

print(box_size)

In [None]:
# Fixed seed for the JAX PRNG; `split` is a subkey derived from it.
key = random.PRNGKey(0)
key, split = random.split(key)

In [None]:
def create_species(N,k):
  """Build a sorted array of species labels drawn from ``0 .. k-1``.

  Args:
    N: Number of particles.
    k: Number of distinct species.

  Returns:
    A sorted integer array of length ``max(N, k)`` in which labels are
    assigned round-robin before sorting.
  """
  label_count = max(N, k)
  labels = np.arange(label_count) % k
  labels.sort()  # in-place sort of the freshly created array
  return labels

# Species label per particle (4 species), passed to the pair energy functions.
species = create_species(N,4)

In [None]:
# Reference square; its pairwise distances are used as the loss target (REF_DISTS).
target_shape = make_square_shape(center_radius)

In [None]:
# TODO: implement this later
def make_patch_particle(patch_num, patch_loc, patch_epsilon):
  """Placeholder for building a patchy particle; currently does nothing.

  Args:
    patch_num: Unused.
    patch_loc: Unused.
    patch_epsilon: Unused.

  Returns:
    None.
  """
  return None


In [None]:
# Global soft-sphere pair energy built from the module-level interaction
# matrix `epsilon`, plus a Nose-Hoover NVT integrator for it.
energy_fn = energy.soft_sphere_pair(displacement,species=species, sigma=sigma, epsilon = epsilon)
init, gapply = simulate.nvt_nose_hoover(energy_fn, shift, dt, kT)  # same construction as inside initial_setup / mydynamics
# Reference state: particles on a square lattice; `gstate.position` is read by
# the optimisation cells.
Rcheck = square_lattice(N, box_size)
gstate = init(key,Rcheck)


In [None]:
def get_BOO(displacement_all, k=6):
  """Build a function that computes a k-fold angular order parameter.

  For each pair displacement the returned function evaluates
  ``exp(1j * k * theta)``, where ``theta`` is the angle of the displacement,
  averages these per particle with a smooth distance-based weight, and
  returns the magnitude of the mean over particles. Presumably this is the
  bond-orientational order parameter the name refers to -- confirm.

  Args:
    displacement_all: Function ``(R, R) -> ds`` returning all pairwise
      displacement vectors.
    k: Symmetry order used in the complex exponential.

  Returns:
    A function mapping positions ``R`` to a scalar order parameter.
  """
  imag_unit = 1j

  def weight(r, r0=1.1, alpha=10):
    # Sigmoid cutoff around r0; near-zero distances (self pairs) get weight 0.
    sigmoid_cutoff = 1.0/(1 + jnp.exp(alpha*(r - r0)))
    return jnp.where(r<1e-7, 0., sigmoid_cutoff)

  def get_ylms(dR_ij):
    tiny = 0.00001 #avoids nan in derivative
    safe_dR = jnp.where(dR_ij==0.0, tiny, dR_ij)
    angle = jnp.arctan2(safe_dR[1], safe_dR[0])
    return jnp.exp(imag_unit*k*angle)

  pairwise_ylms = vmap(vmap(get_ylms))

  def calculate_order_param(R):
    ds = displacement_all(R, R)
    w = weight(space.distance(ds))
    weighted_sum = jnp.sum(pairwise_ylms(ds)*w, axis=0)
    # The small constant guards against division by zero when all weights vanish.
    q_6m = weighted_sum/(jnp.sum(w, axis=0)+0.00001)
    return jnp.abs(jnp.mean(q_6m))

  return calculate_order_param

#displacement_all = vmap(vmap(displacement, (0, None), 0), (None, 0), 0)
#order_param_fn = get_BOO(displacement_all, k=6)

In [None]:
def initial_setup(e, key):
  """Draw random initial positions and allocate empty simulation logs.

  Args:
    e: Interaction matrix passed as ``epsilon`` to the soft-sphere energy.
    key: JAX PRNG key.

  Returns:
    A tuple ``(position, log)`` where ``position`` is the position stored in
    the freshly initialised Nose-Hoover state and ``log`` is a dict of
    zero-filled NumPy arrays keyed by 'kT', 'H', 'position', 'Psi_K' and
    'Avg_Psi_K'.
  """
  # Periodic space for the global box.
  displacement, shift = space.periodic(box_size)

  # Uniform random positions inside the box.
  key, split = random.split(key)
  R = box_size * random.uniform(split, (N,dimension), dtype=np.float64)

  # Build the integrator only to obtain an initial state.
  energy_fn = energy.soft_sphere_pair(displacement, species=species, sigma=sigma, epsilon = e)
  init, apply = simulate.nvt_nose_hoover(energy_fn, shift, dt, kT)
  state = init(key,R)

  # Per-step logs have `steps` rows; the rest get one row per written frame.
  n_frames = steps // write_every
  log = {
      'kT': np.zeros((steps,)),
      'H': np.zeros((steps,)),
      'position': np.zeros((n_frames,) + R.shape),
      'Psi_K': np.zeros((n_frames, N)),
      'Avg_Psi_K': np.zeros((n_frames,))
  }

  return state.position,  log

In [None]:
def mydynamics(R,e, log, key):
  """Run `steps` Nose-Hoover NVT steps from positions `R` and fill `log`.

  Args:
    R: Initial particle positions.
    e: Interaction matrix passed as ``epsilon`` to the soft-sphere energy.
    log: Dict of pre-allocated arrays with keys 'kT', 'H', 'position',
      'Psi_K' and 'Avg_Psi_K'. 'kT' and 'H' are written every step; the other
      three are written every `write_every` steps.
    key: JAX PRNG key used to initialise the integrator state.

  Returns:
    A tuple ``(state, log)`` with the final integrator state and filled logs.
  """
  #Set up simulation parameters
  displacement, shift = space.periodic(box_size)

  #Compile the dynamics
  energy_fn = energy.soft_sphere_pair(displacement, species=species, sigma=sigma, epsilon = e)
  init, apply = simulate.nvt_nose_hoover(energy_fn, shift, dt, kT)
  key, split = random.split(key)
  apply = jit(apply)
  state = init(split, R)

  #Compute the order parameter
  displacement_all = vmap(vmap(displacement, (0, None), 0), (None, 0), 0)
  order_param_fn = get_BOO(displacement_all, k=6)

  def step_fn(i, state_and_log):
    state, log = state_and_log

    # Per-step diagnostics: temperature and the Nose-Hoover invariant.
    T = quantity.temperature(momentum=state.momentum)
    log['kT'] = log['kT'].at[i].set(T)

    H = simulate.nvt_nose_hoover_invariant(energy_fn,state,kT)
    log['H'] = log['H'].at[i].set(H)

    # Frame-level quantities are only stored on every write_every'th step.
    is_write_step = i % write_every == 0
    frame = i // write_every

    def write_frame(buffer, value):
      # Store `value` in row `frame` on write steps; otherwise leave the buffer as is.
      return lax.cond(is_write_step,
                      lambda b: b.at[frame].set(value),
                      lambda b: b,
                      buffer)

    log['position'] = write_frame(log['position'], state.position)

    # NOTE(review): get_BOO returns a scalar, so it is broadcast across the row
    # of 'Psi_K' and the mean below equals the same scalar -- confirm whether a
    # per-particle value was intended.
    _psi_k = order_param_fn(state.position.reshape(N,dimension))
    log['Psi_K'] = write_frame(log['Psi_K'], _psi_k)
    log['Avg_Psi_K'] = write_frame(log['Avg_Psi_K'], jnp.mean(_psi_k))

    #Take a simulation step
    state = apply(state, kT=kT)
    return state, log

  #Run the dynamics
  state, log = lax.fori_loop(0, steps, step_fn, (state, log))

  return state, log

In [None]:
  #might not be needed
  def my_loss_function(x_k, x_goal):
    """Compare the order parameter of a configuration against a goal.

    Returns ``sqrt(mean(psi(x_k))**2 - mean(psi(x_goal))**2)`` where ``psi`` is
    ``order_param_fn``.

    Args:
      x_k: Current particle positions.
      x_goal: Goal particle positions.

    Returns:
      The scalar loss as a JAX array.
    """
    # NOTE(review): `order_param_fn` is only defined locally inside
    # `mydynamics` (the module-level definition is commented out), so it must
    # be defined at notebook level before this is called -- confirm.
    # NOTE(review): the argument of the square root is negative whenever the
    # goal's order parameter exceeds the current one, which yields NaN --
    # confirm whether an absolute difference was intended.
    # jnp.sqrt (not np.sqrt) keeps the result traceable under jit/grad.
    myloss = jnp.sqrt(jnp.mean(order_param_fn(x_k))**2 - jnp.mean(order_param_fn(x_goal))**2)
    return myloss

The dynamics function (`mydynamics`) already computes the order parameter, so the loss function above is probably not needed.

In [None]:
target_rate = 100.0
# The interaction matrix is `epsilon`.
# NOTE(review): the meaning of `endpoints` is unconfirmed; presumably the
# particle positions at the two ends of the path.
@jit
def compute_path(interaction_matrix, endpoints):
  """Run DNEB between `endpoints` and return the resulting path.

  Args:
    interaction_matrix: Interaction matrix forwarded to ``run_DNEB``.
    endpoints: Path endpoints forwarded to ``run_DNEB``.

  Returns:
    A tuple ``(path_positions, path_energies)`` as produced by ``run_DNEB``.
  """
  # NOTE(review): `run_DNEB` is not defined in the visible part of the notebook.
  path_positions, path_energies = run_DNEB(endpoints, interaction_matrix)
  return path_positions, path_energies

@jit
def compute_loss(interaction_matrix, target_rate, path_positions):
  """Squared error between the target and the measured transition rate.

  Args:
    interaction_matrix: Interaction matrix applied as ``epsilon`` when
      evaluating the energy along the path.
    target_rate: Desired transition rate.
    path_positions: Batch of configurations along the path (leading axis is
      the path index).

  Returns:
    ``(target_rate - measured_rate) ** 2``.
  """
  # The pair energy function takes positions as its only positional argument;
  # parameters such as `epsilon` are overridden by keyword. Passing the matrix
  # positionally, as before, raises TypeError.
  def energy_at(positions):
    return energy_fn(positions, epsilon=interaction_matrix)

  path_energies = vmap(energy_at)(path_positions)
  # NOTE(review): `calc_transition_rate` is not defined in the visible part of
  # the notebook.
  measured_rate = calc_transition_rate(path_energies, kT)
  return (target_rate - measured_rate)**2

In [None]:
# Returns (loss, gradient), differentiating with respect to argument 0
# (the interaction matrix) only.
grad_compute_loss = value_and_grad(compute_loss, 0)

In [None]:
opt_steps = 100
lr_steps = opt_steps // 3
# Piecewise-constant schedule: 0.01 for the first third of the run, then 0.001.
# The second segment is sized as the remainder so the table has exactly
# `opt_steps` entries. It used to be 2*lr_steps long (99 entries for 100
# steps), so the last step relied on JAX clamping an out-of-range index.
learning_rates = jnp.array([0.01] * lr_steps + [0.001] * (opt_steps - lr_steps))
learning_rate = lambda t: learning_rates[t]

# Choose an optimizer (or write your own!)
opt_init, opt_update, get_params = optimizers.adam(step_size=learning_rate)

def opt_step(i, opt_state, endpoints):
    '''
    Given the current state of an optimizer and the gradient, update the parameters

    Args:
      i: Optimisation step index, forwarded to ``opt_update``.
      opt_state: Current optimizer state.
      endpoints: Not used by this function.

    Returns:
      A tuple ``(new_opt_state, params, lossval)`` where ``params`` are the
      parameters the loss and gradient were evaluated at.
    '''
    params = get_params(opt_state)
    # The loss is evaluated at the global reference state's positions.
    lossval, gradval = grad_compute_loss(params, target_rate, gstate.position)
    print("Loss: {}".format(lossval))
    new_opt_state = opt_update(i, gradval, opt_state)
    return new_opt_state, params, lossval

In [None]:
################ Loss function ##################

def get_desired_dists(ref_shape):
  """Row-sorted pairwise distances between the points of a reference shape.

  Distances are measured with the periodic displacement of the global box.

  Args:
    ref_shape: Array of reference point coordinates.

  Returns:
    A square array of pairwise distances, each row sorted in ascending order.
  """
  displacement, _ = space.periodic(box_size)
  pairwise_displacement = space.map_product(displacement)
  pair_dists = space.distance(pairwise_displacement(ref_shape, ref_shape))
  return jnp.sort(pair_dists)

# Sorted pairwise distances of the target shape; compared against in sys_loss.
REF_DISTS = get_desired_dists(target_shape)

@jit
def sys_loss(R):
  """Loss comparing a configuration's neighbour distances to the reference.

  Each particle's sorted distances to its nearest neighbours are compared
  with every row of ``REF_DISTS``; the best match per particle is summed, and
  the mean distance to the next-nearest neighbours, scaled by
  ``CLOSENESS_PENALTY``, is subtracted.

  Args:
    R: Particle positions for a single system.

  Returns:
    The scalar loss.
  """
  # NOTE(review): REF_SHAPE, REF_SHAPE_SIZE, CLOSENESS_PENALTY_NEIGHBORS and
  # CLOSENESS_PENALTY are not defined in the visible part of the notebook.
  displacement, _ = space.periodic(box_size)
  pairwise_displacement = space.map_product(displacement)
  ds = jnp.sort(space.distance(pairwise_displacement(R, R)))

  def subtract(dist_row, ref_row):
    return dist_row - ref_row

  v_subtract = space.map_product(subtract)
  diffs = v_subtract(ds[:, :len(REF_SHAPE)], REF_DISTS)
  nearest_nbrs_match_ref_dist = jnp.min(jnp.linalg.norm(diffs, axis=-1), axis=0)

  first_other = REF_SHAPE_SIZE
  last_other = CLOSENESS_PENALTY_NEIGHBORS + REF_SHAPE_SIZE
  other_nbrs_far = ds[:, first_other:last_other]

  match_term = jnp.sum(nearest_nbrs_match_ref_dist)
  return match_term - CLOSENESS_PENALTY * jnp.mean(other_nbrs_far)

# sys_loss mapped over a leading batch axis of configurations.
v_loss = vmap(sys_loss)


@jit
def avg_loss(R_batched):
  """Mean of ``sys_loss`` over a batch of configurations.

  Args:
    R_batched: Positions with a leading batch axis.

  Returns:
    The scalar mean loss across the batch.
  """
  return jnp.mean(v_loss(R_batched))

In [None]:
# Optimise the interaction matrix, starting from the module-level `epsilon`.
opt_state = opt_init(epsilon)
endpoints = jnp.array(gstate.position) # Make this function of N points in a square lattice?
params = epsilon
best_params = params
min_loss = jnp.inf

for i in range(opt_steps):
  # Recompute the DNEB path periodically rather than every step
  # to save compute time
  if i % 5 == 0:
      # This line used `,` instead of `=`, so it built and discarded a tuple and
      # `path_positions` was never assigned (NameError).
      path_positions, path_energies = compute_path(params, endpoints)
      endpoints = jnp.array([path_positions[0], path_positions[-1]])
  opt_state, new_params, loss = opt_step(i, opt_state, endpoints)
  # `loss` was evaluated at `params` (the pre-update parameters), so that is
  # what gets recorded as the best candidate.
  if loss < min_loss:
    min_loss = loss
    best_params = params
  params = new_params

  # Restart the optimizer from the best parameters seen so far at the start of
  # each third of the run.
  # NOTE(review): this also fires at i == 0, which discards the very first
  # update -- confirm that is intended.
  if i % (opt_steps // 3) == 0:
    opt_state = opt_init(best_params)
    params = best_params

print('Minimum loss value: ', min_loss)

NameError: ignored

In [None]:
opt_steps = 100
lr_steps = opt_steps // 3
learning_rates = jnp.array([0.01] * lr_steps + [0.001] * 2*lr_steps)