In [None]:
import swarmrl as srl

from swarmrl.observables.top_down_image import TopDownImage
import numpy as np
import matplotlib.pyplot as plt
from swarmrl.components import Colloid
import open3d as o3d
import logging
import flax.linen as nn
from swarmrl.tasks.dummy_task import DummyTask
import optax
from swarmrl.actions.mpi_action import MPIAction
from swarmrl.engine.gaurav_sim import *
from swarmrl.trainers.global_continuous_trainer import GlobalContinuousTrainer as Trainer
import pint


In [None]:
# Configure root logging at WARNING level.
logging.basicConfig(level=logging.WARNING)

# Number of pixels along each axis of the top-down image observable.
resolution = 128

# Triangle mesh read from "rafts.stl"; passed to the observable as its custom mesh.
rafts = o3d.io.read_triangle_mesh("rafts.stl")

# Image observable for particles of type 0.
# NOTE(review): the first positional argument presumably is the box size
# (x, y, z) in the simulation's length unit - confirm against TopDownImage.
obs = TopDownImage(
    np.array([10000.0, 10000.0, 0.1]),
    image_resolution=np.array([resolution, resolution]),
    particle_type=0,
    custom_mesh=rafts,
    is_2D=True,
)

# Task object handed to the agent further down.
task = DummyTask()


class ActoCriticNet(nn.Module):
    """Small dense actor-critic network with separate actor and critic branches.

    The input is flattened to one dimension and passed through two parallel
    stacks of ``Dense(2) -> ReLU -> Dense(12) -> ReLU``. The actor branch ends
    in ``Dense(32) -> ReLU``; the critic branch ends in a single linear
    ``Dense(1)`` unit.

    The layers are declared in a fixed, interleaved order (critic layer first,
    then the matching actor layer) so that the auto-generated sub-module names
    ``Dense_0`` ... ``Dense_5`` keep a stable mapping to the two branches.
    """

    @nn.compact
    def __call__(self, x):
        """Run a forward pass.

        Parameters
        ----------
        x : array
            Observation. It is flattened completely, so any leading axes are
            merged into the feature axis.

        Returns
        -------
        tuple
            ``(actor_out, critic_out)`` where ``actor_out`` has 32 features
            (after ReLU) and ``critic_out`` has a single feature.
        """
        x = x.flatten()

        # First hidden layer of each branch; both read the flattened input.
        y = nn.Dense(features=2)(x)
        x = nn.Dense(features=2)(x)
        x = nn.relu(x)
        y = nn.relu(y)

        # Second hidden layer of each branch.
        y = nn.Dense(features=12)(y)
        x = nn.Dense(features=12)(x)
        x = nn.relu(x)
        y = nn.relu(y)

        # The critic head must consume the critic branch (y). Feeding it the
        # actor features instead would leave the critic branch's layers
        # without any influence on the output, i.e. untrained dead weights.
        y = nn.Dense(features=1)(y)  # Critic
        x = nn.Dense(features=32)(x)  # Actor
        x = nn.relu(x)
        return x, y


# Exploration policy, configured with probability 0.1.
exploration_policy = srl.exploration_policies.RandomExploration(
    probability=0.1
)

# Strategy used to sample continuous actions from the network output.
sampling_strategy = srl.sampling_strategies.ContinuousGaussianDistribution()

# Value function used by the loss.
# NOTE(review): gamma=0.1 is an unusually small value if it is the discount
# factor - confirm this is intended.
value_function = srl.value_functions.GlobalExpectedReturns(
    gamma=0.1,
    standardize=True,
)

# Instantiate the Flax actor-critic module.
actor_critic = ActoCriticNet()



In [None]:
ureg = pint.UnitRegistry()
Q_ = ureg.Quantity

# Side length of the square simulation box in micrometres. Shared by the
# simulation parameters and the random initial raft positions so the two
# cannot drift apart.
box_length_um = 10000
# Number of rafts added to the simulation.
n_rafts = 2

# Define the simulation parameters as unit-tagged pint quantities.
# NOTE(review): `pathlib` is not imported explicitly in this notebook; it
# presumably arrives through the wildcard import from
# swarmrl.engine.gaurav_sim - confirm, or add an explicit import.
params = GauravSimParams(
    ureg=ureg,
    box_length=Q_(box_length_um, "micrometer"),
    time_step=Q_(1e-3, "second"),
    time_slice=Q_(1, "second"),
    snapshot_interval=Q_(0.002, "second"),
    raft_radius=Q_(150, "micrometer"),
    raft_repulsion_strength=Q_(1e-7, "newton"),
    dynamic_viscosity=Q_(1e-3, "Pa * s"),
    fluid_density=Q_(1000, "kg / m**3"),
    lubrication_threshold=Q_(15, "micrometer"),
    magnetic_constant=Q_(4 * np.pi * 1e-7, "newton /ampere**2"),
    capillary_force_data_path=pathlib.Path(
        "/work/clohrmann/mpi_collab/capillaryForceAndTorque_sym6"
    ),  # TODO: hard-coded absolute path; make this configurable
)

# Initialize the simulation system
system_runner = GauravSim(
    params=params,
    out_folder="./",
    with_precalc_capillary=True,
    save_h5=False,
)

# Magnetic moment assigned to every raft.
mag_mom = Q_(1e-8, "ampere * meter**2")

# Place each raft at a uniformly random position inside the box with a random
# orientation. The random draws happen in the order x, y, alpha.
for i in range(n_rafts):
    x_um = np.random.rand() * box_length_um
    y_um = np.random.rand() * box_length_um
    system_runner.add_colloids(
        pos=[x_um, y_um, 0] * ureg.micrometer,
        alpha=np.random.rand() * 2 * np.pi,
        magnetic_moment=mag_mom,
    )



In [None]:

# Adam optimizer used to update the network parameters.
adam_optimizer = optax.adam(learning_rate=0.01)

# Wrap the Flax module together with its optimizer, sampling strategy and
# exploration policy. The input shape matches the square image observable.
network = srl.networks.ContinuousFlaxModel(
    flax_model=actor_critic,
    optimizer=adam_optimizer,
    input_shape=(resolution, resolution),
    sampling_strategy=sampling_strategy,
    exploration_policy=exploration_policy,
    number_of_gaussians=2,
    action_dimension=8,
)

# Policy-gradient loss built on the value function defined earlier.
loss = srl.losses.GlobalPolicyGradientLoss(value_function=value_function)

# Agent that ties network, task, observable and loss together for particles
# of type 0.
protocol = srl.agents.MPIActorCriticAgent(
    particle_type=0,
    network=network,
    task=task,
    observable=obs,
    loss=loss,
)

# The trainer takes a list of agents; only one is used here.
rl_trainer = Trainer([protocol])

In [None]:
# Start RL training on the simulation engine.
# NOTE(review): the two positional 2s are presumably the number of episodes
# and the episode length - confirm against Trainer.perform_rl_training.
rl_trainer.perform_rl_training(system_runner, 2, 2)