In [2]:
# custom stuff
import numpy as np
import torch
from torch_geometric.data import HeteroData, Batch

class HeteroDataReplayBuffer(object):
    def __init__(self, batch_size, buffer_size, device):
        self.batch_size = batch_size
        self.max_size = int(buffer_size)
        self.device = device

        # pointer where to add data
        self.ptr = 0
        # current size of the data
        self.crt_size = 0

        self.state = []
        self.action = np.zeros((self.max_size, 1))
        self.next_state = []
        self.reward = np.zeros((self.max_size, 1))
        self.not_done = np.zeros((self.max_size, 1))

    def add(self, state, action, next_state, reward, done):
        if len(self.state) < self.max_size:
            self.state.append(state)
            self.next_state.append(next_state)
        else:
            self.state[self.ptr] = state
            self.next_state[self.ptr] = next_state

        self.action[self.ptr] = action
        self.reward[self.ptr] = reward
        self.not_done[self.ptr] = 1. - done

        self.ptr = (self.ptr + 1) % self.max_size
        self.crt_size = min(self.crt_size + 1, self.max_size)

    def sample(self):
        ind = np.random.randint(0, self.crt_size, size=self.batch_size)
        
        state_batch = Batch.from_data_list([self.state[i] for i in ind])
        next_state_batch = Batch.from_data_list([self.next_state[i] for i in ind])

        return (
            state_batch.to(self.device),
            torch.LongTensor(self.action[ind]).to(self.device),
            next_state_batch.to(self.device),
            torch.FloatTensor(self.reward[ind]).to(self.device),
            torch.FloatTensor(self.not_done[ind]).to(self.device)
        )

    def save(self, save_folder):
        torch.save(self.state[:self.crt_size], f"{save_folder}_state.pt")
        np.save(f"{save_folder}_action.npy", self.action[:self.crt_size])
        torch.save(self.next_state[:self.crt_size], f"{save_folder}_next_state.pt")
        np.save(f"{save_folder}_reward.npy", self.reward[:self.crt_size])
        np.save(f"{save_folder}_not_done.npy", self.not_done[:self.crt_size])
        np.save(f"{save_folder}_ptr.npy", self.ptr)

    def load(self, save_folder, size=-1):
        reward_buffer = np.load(f"{save_folder}_reward.npy")
        
        # Adjust crt_size if we're using a custom size
        size = min(int(size), self.max_size) if size > 0 else self.max_size
        self.crt_size = min(reward_buffer.shape[0], size)

        self.state = torch.load(f"{save_folder}_state.pt")[:self.crt_size]
        self.action[:self.crt_size] = np.load(f"{save_folder}_action.npy")[:self.crt_size]
        self.next_state = torch.load(f"{save_folder}_next_state.pt")[:self.crt_size]
        self.reward[:self.crt_size] = reward_buffer[:self.crt_size]
        self.not_done[:self.crt_size] = np.load(f"{save_folder}_not_done.npy")[:self.crt_size]

        print(f"Replay Buffer loaded with {self.crt_size} elements.")


In [3]:
batch_size = 64
# we want that all the past experiences fit there as we don't interract with any en
train_test_ratio = 0.8
number_of_trajectories = 19000
buffer_size = train_test_ratio * 19000
replay_buffer = HeteroDataReplayBuffer(batch_size=batch_size, buffer_size=buffer_size, device)

SyntaxError: positional argument follows keyword argument (3886481327.py, line 6)

In [None]:
# choose 80% of positive trajectories and save them in the list. use the analysis notebook for that.
# choose 80% of negative trajectories and save
# shuffle train trajectories
# iterate over each train trajectory
#     query neo4j to get trajectory graph
#     set history of 1 extra node -> iterate over every timestep to create 

In [None]:
from utils import get_trajectory_from_neo4j, Neo4jConnection

# Initialize the Neo4j driver
neo4j_connection = Neo4jConnection()
driver = neo4j_connection.get_driver()

In [None]:
ids = [6, 7]

In [None]:
def create_subgraph(step_idx):
    start_idx = max(0, step_idx - history)
        end_idx = step_idx
    pass

In [None]:
def populate_replay_buffer(replay_buffer, history=1):

for traj in ids:
    patient, time_steps, actions = get_trajectory_from_neo4j(driver=driver, traj=traj)
    
    n_time_steps = len(time_steps)
    next_obs = None

    for step_idx in range(n_time_steps):
        # Applying ReLu basically to make sure that history
        # substracting does not result into negative indice
        
        
        obs = next_obs if next_obs is not None else create_subgraph(patient, time_steps, step_idx)
        action = pass
        done = pass
        reward = pass
        if not done:
            next_obs = create_subgraph(patient, time_steps, step_idx)
        replay_buffer.add(obs, action, reward, next_obs)
        