In [56]:
from typing import Optional, Any

import gymnasium as gym
from gymnasium.wrappers import TransformAction, TransformObservation

import cst_python as cst
from cst_python.core.entities import Memory, MemoryObject

In [None]:
class GymCodelet(cst.Codelet):
    """Codelet that exposes a Gymnasium environment through CST memories.

    On construction it creates one memory per entry of the environment's
    observation and action spaces (or a single ``"observation"`` /
    ``"action"`` memory when the space is not a ``gym.spaces.Dict``), plus six
    bookkeeping memories: ``reward``, ``reset``, ``terminated``, ``truncated``,
    ``info`` and ``seed``.

    The codelet registers itself as a memory observer of every action memory
    and of the reset memory. Each ``proc`` call either resets the environment
    (when the reset memory was written since the last reset) or steps it with
    the action currently stored in the action memories, and publishes the
    results to the observation and bookkeeping memories.

    Memory names are kept unique across all ``GymCodelet`` instances: the first
    memory for a base name uses the bare name, later ones get a numeric suffix
    (``reward``, ``reward1``, ``reward2``, ...).
    """

    # Last index handed out for each memory base name. This is class-level
    # state shared (and mutated) by every instance; -1 means "not used yet".
    _last_indexes = {"reward":-1, "reset":-1, "terminated":-1, "truncated":-1, "info":-1, "seed":-1}

    def __init__(self, mind:cst.Mind, env:gym.Env):
        """Create the memories for ``env`` inside ``mind`` and observe the inputs.

        Args:
            mind: Mind in which all memory objects are created.
            env: Gymnasium environment driven by this codelet.
        """
        super().__init__()

        self.env = env

        self.observation_memories = self.space_to_memories(mind, env.observation_space)
        self.action_memories = self.space_to_memories(mind, env.action_space, action=True)

        # Bookkeeping memories and their initial contents. The dict is built
        # here (not at class level) so each instance gets its own ``info`` dict.
        initial_values = {
            "reward": 0.0,
            "reset": False,
            "terminated": False,
            "truncated": False,
            "info": {},
            "seed": None,
        }
        self._common_memories : dict[str, MemoryObject] = {}
        for name, initial_value in initial_values.items():
            memory = mind.create_memory_object(self._next_memory_name(name))
            memory.set_info(initial_value)
            self._common_memories[name] = memory

        # Observe the input memories (actions and reset).
        self.is_memory_observer = True
        for memory in self.action_memories.values():
            memory.add_memory_observer(self)
        self._common_memories["reset"].add_memory_observer(self)

        # Baseline timestamp: only later writes to the reset memory count as a
        # reset request.
        self._last_reset = self._common_memories["reset"].get_timestamp()

    @property
    def reward_memory(self) -> MemoryObject:
        """Memory holding the reward of the last step (0.0 after a reset)."""
        return self._common_memories["reward"]

    @property
    def reset_memory(self) -> MemoryObject:
        """Memory whose writes request an environment reset."""
        return self._common_memories["reset"]

    @property
    def terminated_memory(self) -> MemoryObject:
        """Memory holding the ``terminated`` flag of the last step."""
        return self._common_memories["terminated"]

    @property
    def truncated_memory(self) -> MemoryObject:
        """Memory holding the ``truncated`` flag of the last step."""
        return self._common_memories["truncated"]

    @property
    def info_memory(self) -> MemoryObject:
        """Memory holding the ``info`` value returned by the last reset/step."""
        return self._common_memories["info"]

    @property
    def seed_memory(self) -> MemoryObject:
        """Memory whose content is passed as ``seed`` to ``env.reset``."""
        return self._common_memories["seed"]

    def access_memory_objects(self):
        """No-op: every memory is created and stored in ``__init__``."""
        pass

    def calculate_activation(self):
        """No-op: this codelet does not compute an activation."""
        pass

    def proc(self):
        """Reset or step the environment and publish the outcome to memories.

        A reset is performed when the reset memory's timestamp is newer than
        the one recorded at the previous reset; the value written to that
        memory is not inspected. Otherwise the environment is stepped with the
        action assembled from the action memories.

        Raises:
            ValueError: If the action memories do not hold a valid element of
                the environment's action space.
        """
        # Read the timestamp once so the comparison and the stored baseline
        # refer to the same write.
        reset_timestamp = self.reset_memory.get_timestamp()

        if self._last_reset < reset_timestamp:
            self._last_reset = reset_timestamp

            observation, info = self.env.reset(seed=self.seed_memory.get_info())
            # Float, to match the initial reward value and the step rewards.
            reward = 0.0
            terminated = False
            truncated = False

        else:
            action = self.memories_to_space(self.action_memories, self.env.action_space)
            observation, reward, terminated, truncated, info = self.env.step(action)

            # Trace of the raw observation returned by the step.
            print("Observation", observation)

        self.reward_memory.set_info(reward)
        self.terminated_memory.set_info(terminated)
        self.truncated_memory.set_info(truncated)
        self.info_memory.set_info(info)

        self.sample_to_memories(observation, self.observation_memories)

    @classmethod
    def _next_memory_name(cls, base_name:str) -> str:
        """Reserve and return the next unique memory name for ``base_name``.

        The first request for a base name returns it unchanged; every later
        request returns the base name followed by its running index.
        """
        index = cls._last_indexes.get(base_name, -1) + 1
        cls._last_indexes[base_name] = index
        return base_name if index == 0 else f"{base_name}{index}"

    @classmethod
    def space_to_memories(cls, mind:cst.Mind, 
                          space:gym.Space,
                          action:bool=False) -> dict[str, cst.MemoryObject]:
        """Create one memory per component of ``space``.

        Args:
            mind: Mind in which the memory objects are created.
            space: Space to mirror. A ``gym.spaces.Dict`` yields one memory per
                key; any other space yields a single memory.
            action: For non-Dict spaces, selects the key of the single memory:
                ``"action"`` if true, ``"observation"`` otherwise.

        Returns:
            Mapping from space key to the created memory. Each memory is
            initialised with a random sample of its (sub)space and may carry a
            numeric suffix in its CST name to keep names unique.
        """
        if isinstance(space, gym.spaces.Dict):
            subspaces = {space_name: space[space_name] for space_name in space}
        else:
            subspaces = {"action" if action else "observation": space}

        memories = {}
        for space_name, subspace in subspaces.items():
            name = cls._next_memory_name(space_name)
            memories[space_name] = mind.create_memory_object(name, subspace.sample())

        return memories

    @classmethod
    def sample_to_memories(cls, sample:dict[str, Any]|Any, memories:dict[str, Memory]) -> None:
        """Write a space sample into the matching memories.

        Args:
            sample: Either a dict keyed like ``memories`` or a single value.
            memories: Target memories. For a non-dict ``sample`` the first
                memory in iteration order receives the value.

        Raises:
            KeyError: If a dict ``sample`` has a key missing from ``memories``.
        """
        if isinstance(sample, dict):
            for name, element in sample.items():
                memories[name].set_info(element)
        else:
            memory = memories[next(iter(memories))]
            memory.set_info(sample)

    @classmethod
    def memories_to_space(cls, memories:dict[str, Memory], space:gym.Space) -> dict[str, Any]|Any:
        """Assemble an element of ``space`` from the contents of ``memories``.

        Args:
            memories: Source memories, keyed by space key.
            space: Space the result must belong to. For a ``gym.spaces.Dict``
                the result is a dict with one entry per memory; otherwise it is
                the content of the first memory in iteration order.

        Returns:
            The assembled sample.

        Raises:
            ValueError: If the assembled sample is not contained in ``space``.
        """
        if isinstance(space, gym.spaces.Dict):
            sample = {memory_name: memory.get_info() for memory_name, memory in memories.items()}
        else:
            sample = memories[next(iter(memories))].get_info()

        if not space.contains(sample):
            raise ValueError("Memories do not correspond to an element of the Space.")

        return sample

In [58]:
# Base Blackjack environment; its observation is indexed below as a 3-element tuple.
env = gym.make("Blackjack-v1")

# Re-expose the tuple observation as a Dict with one named entry per component.
tuple_observation_space = env.observation_space
dict_observation_space = gym.spaces.Dict({
    "player_sum": tuple_observation_space[0],
    "dealer_card": tuple_observation_space[1],
    "usable_ace": tuple_observation_space[2],
})
env = TransformObservation(
    env,
    lambda obs: {"player_sum": obs[0], "dealer_card": obs[1], "usable_ace": obs[2]},
    dict_observation_space,
)

# Accept actions as a one-key Dict and unwrap the "hit" entry for the wrapped env.
dict_action_space = gym.spaces.Dict({"hit": env.action_space})
env = TransformAction(
    env,
    lambda action: action["hit"],
    dict_action_space,
)



In [59]:
# Build a mind and attach a GymCodelet wrapping the Dict-based environment.
mind = cst.Mind()
gym_codelet = GymCodelet(mind, env)
mind.insert_codelet(gym_codelet)

# Start the mind, then request the first reset. The seed is written BEFORE the
# reset memory because GymCodelet.proc reads the seed memory at the moment it
# performs the reset.
mind.start()
gym_codelet.seed_memory.set_info(42)
# Any write to the reset memory advances its timestamp, which is what proc
# checks; toggling the value is just a convenient way to write something.
gym_codelet.reset_memory.set_info(not gym_codelet.reset_memory.get_info())

GymCodelet execution


-1

In [60]:
# Inspect the observation, terminated and reward memories after the initial reset.
gym_codelet.observation_memories, gym_codelet.terminated_memory, gym_codelet.reward_memory

({'dealer_card': MemoryObject [idmemoryobject=0, timestamp=1732659816658, evaluation=0.0, I=2, name=dealer_card],
  'player_sum': MemoryObject [idmemoryobject=1, timestamp=1732659816658, evaluation=0.0, I=15, name=player_sum],
  'usable_ace': MemoryObject [idmemoryobject=2, timestamp=1732659816658, evaluation=0.0, I=0, name=usable_ace]},
 MemoryObject [idmemoryobject=6, timestamp=1732659816658, evaluation=0.0, I=False, name=terminated],
 MemoryObject [idmemoryobject=4, timestamp=1732659816658, evaluation=0.0, I=0, name=reward])

In [61]:
# Write action 1 to the "hit" memory. The codelet observes its action memories,
# and with no newer reset request proc takes the env.step branch.
# NOTE(review): 1 is presumably "hit" (draw a card) in Blackjack-v1 — confirm in the Gymnasium docs.
gym_codelet.action_memories["hit"].set_info(1)

Observation {'player_sum': 25, 'dealer_card': 2, 'usable_ace': 0}
GymCodelet execution


-1

In [62]:
# Inspect the memories after the step taken with action 1.
gym_codelet.observation_memories, gym_codelet.terminated_memory, gym_codelet.reward_memory

({'dealer_card': MemoryObject [idmemoryobject=0, timestamp=1732659816687, evaluation=0.0, I=2, name=dealer_card],
  'player_sum': MemoryObject [idmemoryobject=1, timestamp=1732659816687, evaluation=0.0, I=25, name=player_sum],
  'usable_ace': MemoryObject [idmemoryobject=2, timestamp=1732659816687, evaluation=0.0, I=0, name=usable_ace]},
 MemoryObject [idmemoryobject=6, timestamp=1732659816687, evaluation=0.0, I=True, name=terminated],
 MemoryObject [idmemoryobject=4, timestamp=1732659816687, evaluation=0.0, I=-1.0, name=reward])

In [63]:
# Request another reset. proc compares timestamps only, so the value written is
# irrelevant. The seed memory still holds 42, so the same seed is passed to env.reset.
gym_codelet.reset_memory.set_info(True)

GymCodelet execution


-1

In [64]:
# Inspect the memories after the second reset.
gym_codelet.observation_memories, gym_codelet.terminated_memory, gym_codelet.reward_memory

({'dealer_card': MemoryObject [idmemoryobject=0, timestamp=1732659816736, evaluation=0.0, I=2, name=dealer_card],
  'player_sum': MemoryObject [idmemoryobject=1, timestamp=1732659816736, evaluation=0.0, I=15, name=player_sum],
  'usable_ace': MemoryObject [idmemoryobject=2, timestamp=1732659816736, evaluation=0.0, I=0, name=usable_ace]},
 MemoryObject [idmemoryobject=6, timestamp=1732659816736, evaluation=0.0, I=False, name=terminated],
 MemoryObject [idmemoryobject=4, timestamp=1732659816736, evaluation=0.0, I=0, name=reward])

In [65]:
# Write action 0 to the "hit" memory, which leads to one env.step with that action.
# NOTE(review): 0 is presumably "stick" in Blackjack-v1 — confirm in the Gymnasium docs.
gym_codelet.action_memories["hit"].set_info(0)

Observation {'player_sum': 15, 'dealer_card': 2, 'usable_ace': 0}
GymCodelet execution


-1

In [66]:
# Inspect the memories after the step taken with action 0.
gym_codelet.observation_memories, gym_codelet.terminated_memory, gym_codelet.reward_memory

({'dealer_card': MemoryObject [idmemoryobject=0, timestamp=1732659816814, evaluation=0.0, I=2, name=dealer_card],
  'player_sum': MemoryObject [idmemoryobject=1, timestamp=1732659816814, evaluation=0.0, I=15, name=player_sum],
  'usable_ace': MemoryObject [idmemoryobject=2, timestamp=1732659816814, evaluation=0.0, I=0, name=usable_ace]},
 MemoryObject [idmemoryobject=6, timestamp=1732659816814, evaluation=0.0, I=True, name=terminated],
 MemoryObject [idmemoryobject=4, timestamp=1732659816814, evaluation=0.0, I=1.0, name=reward])

In [67]:
# Second scenario: the unwrapped environment, whose spaces are not Dict spaces.
# GymCodelet then creates a single "observation" and a single "action" memory.
env = gym.make("Blackjack-v1")
mind = cst.Mind()

gym_codelet = GymCodelet(mind, env)
mind.insert_codelet(gym_codelet)

# Same start-up sequence as before: the seed is written first, then the reset
# memory, so the reset performed by proc uses seed 42.
mind.start()
gym_codelet.seed_memory.set_info(42)
gym_codelet.reset_memory.set_info(not gym_codelet.reset_memory.get_info())

GymCodelet execution


-1

In [68]:
# Inspect the memories after the reset; the observation is now one tuple-valued memory.
gym_codelet.observation_memories, gym_codelet.terminated_memory, gym_codelet.reward_memory

({'observation': MemoryObject [idmemoryobject=0, timestamp=1732659816913, evaluation=0.0, I=(15, 2, 0), name=observation]},
 MemoryObject [idmemoryobject=4, timestamp=1732659816913, evaluation=0.0, I=False, name=terminated1],
 MemoryObject [idmemoryobject=2, timestamp=1732659816913, evaluation=0.0, I=0, name=reward1])

In [69]:
# Write action 1 to the single "action" memory, which leads to one env.step.
gym_codelet.action_memories["action"].set_info(1)

Observation (25, 2, 0)
GymCodelet execution


-1

In [70]:
# Inspect the memories after the step taken with action 1.
gym_codelet.observation_memories, gym_codelet.terminated_memory, gym_codelet.reward_memory

({'observation': MemoryObject [idmemoryobject=0, timestamp=1732659816947, evaluation=0.0, I=(25, 2, 0), name=observation]},
 MemoryObject [idmemoryobject=4, timestamp=1732659816947, evaluation=0.0, I=True, name=terminated1],
 MemoryObject [idmemoryobject=2, timestamp=1732659816947, evaluation=0.0, I=-1.0, name=reward1])