In [1]:
import numpy as np
import random

In [None]:
# `a` is a permutation of the row indices 0..5; `b` is a 6x4 table whose
# row i is filled with the value i (handy for eyeballing index operations).
a = np.array([0, 1, 4, 3, 2, 5])
b = np.repeat(np.arange(6)[:, np.newaxis], 4, axis=1)

In [2]:
# Ten int32 ones, used below to look at integer bitwise negation.
a = np.full(10, 1, dtype=np.int32)

In [6]:
# Display the current contents of `a`.
a

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [7]:
# `~` on a signed integer array is a bitwise NOT, not a logical negation:
# every 1 becomes -2 (see the output below).
~a

array([-2, -2, -2, -2, -2, -2, -2, -2, -2, -2])

In [17]:
# Integer-array (fancy) indexing: row k of the result is b[a[k]].
# NOTE(review): the recorded output matches a = [0, 1, 4, 3, 2, 5]; with the
# all-ones `a` assigned in the cell above the displays, every row would be b[1].
# The cells were presumably executed out of order - re-run top to bottom to confirm.
b[a]

array([[0, 0, 0, 0],
       [1, 1, 1, 1],
       [4, 4, 4, 4],
       [3, 3, 3, 3],
       [2, 2, 2, 2],
       [5, 5, 5, 5]])

In [167]:
import numpy as np

class DenseStackedObservations:
    """Per-agent buffer holding the last ``stack_size`` sampled observations.

    Each agent owns one row of ``self.stack``. A row is the concatenation of
    ``stack_size`` observations of ``observation_shape`` values each, oldest
    first and newest last. An agent's first observation after construction or
    ``reset`` fills its whole row; afterwards the row is shifted left by one
    observation and the new one is appended, but only on calls where the
    agent's step counter is a multiple of ``step``.

    Attributes:
        stack: float32 array of shape (n_agents, observation_shape * stack_size).
        step_counter: int32 array of shape (n_agents,), counting the
            ``add_observation`` calls per agent since its last reset.
    """

    def __init__(self, stack_size, step, observation_shape, n_agents=1):
        """Allocate a zero-filled buffer.

        Args:
            stack_size: Number of observations kept per agent (>= 1).
            step: Sampling period; an observation is stored every ``step``
                calls (>= 1).
            observation_shape: Number of values in one flat observation (int).
            n_agents: Number of rows (agents) in the buffer.

        Raises:
            ValueError: If ``stack_size`` or ``step`` is smaller than 1.
        """
        # Reject values that would otherwise only fail later, inside NumPy:
        # step == 0 makes `counter % step` a division by zero, and
        # stack_size == 0 produces a zero-width stack that cannot be written to.
        if stack_size < 1:
            raise ValueError("stack_size must be >= 1.")
        if step < 1:
            raise ValueError("step must be >= 1.")

        self.stack_size = stack_size
        self.observation_shape = observation_shape
        self.step = step
        self.n_agents = n_agents

        # One flat row per agent plus an integer call counter per agent.
        self.stack = np.zeros((n_agents, observation_shape * stack_size), dtype=np.float32)
        self.step_counter = np.zeros((n_agents,), dtype=np.int32)

    def reset(self, observation_ids):
        """Mark the given agents as "not yet initialised".

        Only the counters are zeroed. The stack rows keep their previous
        content until the next ``add_observation`` call for those agents
        overwrites them.

        Args:
            observation_ids: Agent index/indices (anything NumPy accepts as an
                index into a 1-D array; an empty list is a no-op).
        """
        self.step_counter[observation_ids] = 0

    def add_observation(self, observation_stack, observation_ids):
        """Feed one observation per listed agent.

        Args:
            observation_stack: Array-like of shape
                (len(observation_ids), observation_shape); a single flat
                observation is accepted for a single agent.
            observation_ids: Agent index or sequence of agent indices. Ids are
                expected to be unique: NumPy fancy-index assignment applies
                ``+= 1`` only once per repeated id and keeps only the last
                written row.

        Raises:
            ValueError: If more ids than agents are given, or if
                ``observation_stack`` does not have exactly one row of
                ``observation_shape`` values per id.
        """
        # Normalise inputs so scalars and plain lists index consistently.
        observation_ids = np.atleast_1d(np.array(observation_ids, dtype=np.int32))
        observation_stack = np.atleast_2d(np.array(observation_stack, dtype=np.float32))

        if len(observation_ids) > self.n_agents:
            raise ValueError("observation_ids length exceeds number of agents.")

        # Nothing to do for an empty batch (np.atleast_2d turns an empty
        # observation list into shape (1, 0), so this must precede the shape check).
        if observation_ids.size == 0:
            return

        # Validate up front, before any state is touched. Without this a
        # mismatch is only detected on calls where some agent is initialised
        # or updated, and slips through silently on skipped steps (step > 1).
        expected_shape = (observation_ids.shape[0], self.observation_shape)
        if observation_stack.shape != expected_shape:
            raise ValueError(
                f"observation_stack has shape {observation_stack.shape}, "
                f"expected {expected_shape} (one row per observation id)."
            )

        counters = self.step_counter[observation_ids]

        # 1. First observation since construction/reset: fill the whole row with it.
        mask_init = counters == 0
        if np.any(mask_init):
            self.stack[observation_ids[mask_init]] = np.tile(
                observation_stack[mask_init], (1, self.stack_size)
            )

        # 2. Every `step` calls (excluding agents initialised just above):
        #    drop the oldest observation and append the new one.
        mask_update = (counters % self.step == 0) & ~mask_init
        if np.any(mask_update):
            ids_to_update = observation_ids[mask_update]
            self.stack[ids_to_update] = np.roll(
                self.stack[ids_to_update], -self.observation_shape, axis=1
            )
            self.stack[ids_to_update, -self.observation_shape:] = observation_stack[mask_update]

        # 3. Count this call for every listed agent, stored or not.
        self.step_counter[observation_ids] += 1

    def get_stacked_observations(self, observation_ids):
        """Return a copy of the stack rows for the given agents.

        Args:
            observation_ids: Agent index/indices; rows are returned in the
                order given.

        Returns:
            A float32 array that does not alias ``self.stack``.
        """
        return self.stack[observation_ids].copy()

In [168]:
# np.intersect1d returns the sorted, unique values found in both inputs as an
# ndarray, so the result can be used directly as an index array.
agent_ids = np.intersect1d([1], [1])
agent_ids

array([1])

In [150]:
# Five agents, each keeping its 4 most recent 3-value observations
# (row width 4 * 3 = 12), sampled on every call (step=1).
a = DenseStackedObservations(stack_size=4, step=1, observation_shape=3, n_agents=5)

In [171]:
# Edge case: resetting with an empty id list selects no counters, so nothing changes.
a.reset([])

In [170]:
# Edge case: an empty id list yields an empty (0, 12) float32 array rather than an error.
a.get_stacked_observations([])

array([], shape=(0, 12), dtype=float32)

In [146]:
# Each call feeds one more agent than the previous one. An agent's first
# observation fills its whole row; every later one shifts the row left by one
# observation and is appended at the end (compare the rows of a.stack below).
a.add_observation([[0,0,0]], observation_ids=[0])
a.add_observation([[1,1,1]]*2, observation_ids=[0,1])
a.add_observation([[2,2,2]]*3, observation_ids=[0,1,2])
a.add_observation([[3,3,3]]*4, observation_ids=[0,1,2,3])

In [147]:
# Raw buffer, one row per agent; agent 4 has not been fed yet and stays all zeros.
a.stack

array([[0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3.],
       [1., 1., 1., 1., 1., 1., 2., 2., 2., 3., 3., 3.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)

In [148]:
# Feed a descending counter (100 .. 1) to agents 0, 2 and 3. Before each feed,
# with 10% probability, reset one randomly chosen agent so its row is re-filled
# the next time it receives an observation.
fed_ids = [0, 2, 3]
for i in range(100, 0, -1):

    if random.random() < 0.1:
        # Named `agent_id`, not `id`: assigning to `id` at notebook top level
        # shadows the builtin for every cell run afterwards.
        # randrange(n_agents) draws from 0 .. n_agents - 1 without hard-coding the bound.
        agent_id = random.randrange(a.n_agents)
        a.reset(observation_ids=[agent_id])
        print("reset id", agent_id)

    a.add_observation(np.full((len(fed_ids), a.observation_shape), i), observation_ids=fed_ids)
    print(a.stack)

[[  1.   1.   1.   2.   2.   2.   3.   3.   3. 100. 100. 100.]
 [  1.   1.   1.   1.   1.   1.   2.   2.   2.   3.   3.   3.]
 [  2.   2.   2.   2.   2.   2.   3.   3.   3. 100. 100. 100.]
 [  3.   3.   3.   3.   3.   3.   3.   3.   3. 100. 100. 100.]
 [  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.]]
[[  2.   2.   2.   3.   3.   3. 100. 100. 100.  99.  99.  99.]
 [  1.   1.   1.   1.   1.   1.   2.   2.   2.   3.   3.   3.]
 [  2.   2.   2.   3.   3.   3. 100. 100. 100.  99.  99.  99.]
 [  3.   3.   3.   3.   3.   3. 100. 100. 100.  99.  99.  99.]
 [  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.]]
[[  3.   3.   3. 100. 100. 100.  99.  99.  99.  98.  98.  98.]
 [  1.   1.   1.   1.   1.   1.   2.   2.   2.   3.   3.   3.]
 [  3.   3.   3. 100. 100. 100.  99.  99.  99.  98.  98.  98.]
 [  3.   3.   3. 100. 100. 100.  99.  99.  99.  98.  98.  98.]
 [  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.]]
[[100. 100. 100.  99.  99.  99.  98.  98.  98.  97. 

In [144]:
# Rows come back in the order the ids are given: agent 4 first (all zeros in
# the output below), then agent 0.
# NOTE(review): the recorded values for agent 0 are 5 apart, which step=1 cannot
# produce - presumably this output was captured from a run with a different
# `step`; re-run the notebook top to bottom to confirm.
a.get_stacked_observations(np.array([4,0]))

array([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [19., 19., 19., 14., 14., 14.,  9.,  9.,  9.,  4.,  4.,  4.]],
      dtype=float32)