In [None]:
import numpy as np
import scipy
import sys
import torch
import torch.nn.functional as F
import random
from collections import deque
from functools import reduce

# Log the interpreter and library versions this run is using, so results can be
# tied back to a specific environment.
print("===VERSIONS===")
print(f"Python: {sys.version}")
print(f"numpy: {np.__version__}")
print(f"PyTorch: {torch.__version__}")
print(f"Scipy: {scipy.__version__}")

In [None]:
# Hyperparameters shared by the agents defined below.
PARAM_EPSILON = 0.1   # probability that an epsilon-greedy step is exploration-only
PARAM_ALPHA = 0.002   # learning rate handed to torch.optim.SGD
PARAM_BETA = 1        # multiplier on the count-based exploration bonus
PARAM_C = 0.01        # added to the pseudo-count before raising it to the power -0.5
# Discount on the bootstrapped next-state value used by SplitNeuralSGDAgent.
# The name is used as a default argument there but was never defined, so the
# class definition raised NameError. 0.0 switches the bootstrap term off.
PARAM_GAMMA = 0.0

In [None]:
class NormalizedTiling:
    """One shifted grid over an ``ndim``-dimensional space, with a visit counter per tile.

    Each axis is cut into ``ksteps + 1`` tiles of width ``1 / ksteps``; the grid
    is displaced by ``offset`` (strictly negative, at most one tile width).
    """

    def __init__(self, ndim, ksteps, offset):
        self.ndim = ndim
        self.k = ksteps
        self.offset = offset
        assert self.offset < 0, "Offset must be less than 0"
        assert self.offset >= -1 / self.k, f"Offset cannot be above {-1/self.k}"
        # One counter per tile: (k + 1) tiles along each of the ndim axes.
        n_tiles = np.power(self.k + 1, self.ndim)
        self.counts = np.zeros(n_tiles)

    def tile_index(self, state):
        """Return the flat index into ``counts`` of the tile containing ``state``.

        The per-axis tile number is ``floor((state - offset) * k)``, clamped to
        ``k``; axis ``j`` contributes that number times ``(k + 1) ** j``.
        """
        assert len(state) == self.ndim, f"Expected state of dimension {self.ndim}, received {len(state)}"
        shift = (state - self.offset * np.ones(self.ndim)) * self.k
        base = self.k + 1
        flat_index = 0
        for axis in range(self.ndim):
            if shift[axis] < self.k:
                digit = np.floor(shift[axis])
            else:
                digit = self.k
            flat_index += int(np.power(base, axis) * digit)
        return flat_index

class TilingDensity:
    """Visit-count density estimate built from several offset ``NormalizedTiling`` grids."""

    def __init__(self, ndim, ntiles, ksteps):
        # Tiling i is shifted by -(i + 1) / (ksteps * ntiles), so the offsets
        # are spread over one tile width.
        tilings = []
        for i in range(ntiles):
            shift = -(i + 1)/(ksteps * ntiles)
            tilings.append(NormalizedTiling(ndim, ksteps, shift))
        self.tiles = tilings
        self.total_count = 0

    def count(self, x):
        """Record one visit to ``x`` in every tiling."""
        for grid in self.tiles:
            grid.counts[grid.tile_index(x)] += 1
        self.total_count += 1

    def density(self, x):
        """Return the summed tile counts at ``x`` divided by the number of recorded visits.

        The counts are summed over the tilings, not averaged. Returns 0 while
        nothing has been counted.
        """
        if not (self.total_count > 0):
            return 0
        hits = 0
        for grid in self.tiles:
            hits = hits + grid.counts[grid.tile_index(x)]
        return hits / self.total_count

In [None]:
class BasicNN(torch.nn.Module):
    """Small tanh MLP mapping a ``2 * ndim`` feature vector to one value in (-1, 1).

    Two hidden layers of ``round(8 * ndim / 3)`` units. All weight matrices are
    drawn from N(0, 0.1^2); biases keep the torch default initialisation. The
    module is converted to float64.
    """

    def __init__(self, ndim):
        super(BasicNN, self).__init__()
        hls = round(8*ndim / 3) # chosen by vibes
        self.fc1 = torch.nn.Linear(2*ndim, hls)
        self.fc2 = torch.nn.Linear(hls, hls)
        self.fc3 = torch.nn.Linear(hls, 1)

        # The original initialised fc1 twice and never touched fc2, which left
        # the middle layer on torch's default initialisation.
        torch.nn.init.normal_(self.fc1.weight, std=0.1)
        torch.nn.init.normal_(self.fc2.weight, std=0.1)
        torch.nn.init.normal_(self.fc3.weight, std=0.1)

        self.double()

    def forward(self, x):
        """Apply the three tanh layers to ``x`` (last dimension ``2 * ndim``, float64)."""
        x = F.tanh(self.fc1(x))
        x = F.tanh(self.fc2(x))
        x = F.tanh(self.fc3(x))
        return x

In [None]:
class DeepNN(torch.nn.Module):
    """Four-layer tanh MLP: ``2 * ndim`` inputs, three hidden layers of ``20 * ndim`` units, one output.

    Every weight matrix is drawn from N(0, 0.1^2) and the module is converted
    to float64.
    """

    def __init__(self, ndim):
        super().__init__()
        n_inputs = 2 * ndim
        hidden = 10 * n_inputs
        self.fc1 = torch.nn.Linear(n_inputs, hidden)
        self.fc2 = torch.nn.Linear(hidden, hidden)
        self.fc3 = torch.nn.Linear(hidden, hidden)
        self.fc4 = torch.nn.Linear(hidden, 1)

        # Same initialisation for every layer, applied in layer order.
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            torch.nn.init.normal_(layer.weight, std=0.1)

        self.double()

    def forward(self, x):
        """Pass ``x`` through all four layers, applying tanh after each one."""
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = F.tanh(layer(x))
        return x

In [None]:
class ReLUNet(torch.nn.Module):
    """ReLU MLP mapping an ``ndim`` state to one value per action (``2 * ndim`` outputs).

    Two hidden layers of 100 units. All weight matrices are drawn from
    N(0, 0.1^2); the module is converted to float64. ``action_size`` holds the
    number of outputs.
    """

    def __init__(self, ndim):
        super(ReLUNet, self).__init__()
        hls = 100
        self.fc1 = torch.nn.Linear(ndim, hls)
        self.fc2 = torch.nn.Linear(hls, hls)
        self.fc3 = torch.nn.Linear(hls, 2 * ndim)
        self.action_size = 2 * ndim

        # The original initialised fc1 twice and skipped fc2.
        torch.nn.init.normal_(self.fc1.weight, std=0.1)
        torch.nn.init.normal_(self.fc2.weight, std=0.1)
        torch.nn.init.normal_(self.fc3.weight, std=0.1)

        self.double()

    def forward(self, x):
        """Return the unbounded per-action outputs for ``x`` (last dimension ``ndim``, float64)."""
        x = torch.sub(x, 0.5) # TODO does this help/hurt?
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [None]:
class DTAMERLoss(torch.nn.Module):
    """Squared error between a target and the network output of the action actually taken.

    ``forward`` keeps its intermediate tensors on the instance (``one_hot``,
    ``q``, ``error``), as the original did.
    """

    def __init__(self, actionsize):
        super(DTAMERLoss, self).__init__()
        self.actionsize = actionsize

    def forward(self, result, target, action, weights=None):
        """Return the (optionally weighted) mean squared error for the chosen actions.

        Args:
            result: network output, one row of ``actionsize`` values per sample.
            target: one target value per sample; ``(N,)`` or ``(N, 1)``.
            action: integer action index per sample; ``(N,)`` or ``(N, 1)``.
            weights: optional per-sample weights; ``(N,)`` or ``(N, 1)``.

        ``target``, ``action`` and ``weights`` are flattened to one entry per
        sample first. Without this, an ``(N, 1)`` action tensor produced an
        ``(N, 1, A)`` one-hot that broadcast against the ``(N, A)`` result and
        summed over the batch axis, mixing samples together.
        """
        action = action.reshape(-1).to(torch.int64)
        target = target.reshape(-1)
        self.one_hot = F.one_hot(action, num_classes=self.actionsize).to(torch.float64)
        # The one-hot mask zeroes every column except the chosen action.
        self.q = torch.mul(self.one_hot, result).sum(axis=1)
        squared = torch.square(target - self.q)
        if weights is not None:
            if torch.is_tensor(weights):
                weights = weights.reshape(-1)
            self.error = (weights * squared).mean()
        else:
            self.error = squared.mean()
        return self.error

In [None]:
#def idenfeatures(states, actions):
#    return np.array([*states, *actions])

#def idenstate(states, actions):
#    return np.array([*states])

#get_features = idenstate

def and_op(x, y):
    """Function form of ``x and y``: returns ``x`` when it is falsy, otherwise ``y``."""
    return y if x else x

In [None]:
import time
from scipy.stats import gamma


class LinearGammaCrediter:
    """Folds the recent feature history into one credit-weighted feature vector.

    Feedback delay is modelled with a gamma distribution (shape ``k``, scale
    ``theta``, shifted by ``delay`` seconds). Each stored vector is weighted by
    the probability mass between its own age and the age of the next newer entry.
    """

    def __init__(self, ndims):
        # ``ndims`` is accepted for a uniform crediter constructor but is not used.
        self._history = []
        # From TAMER
        self.k = 2.0
        self.theta = 0.28
        self.delay = 0.20 # seconds

    def add_index(self, feature_vec):
        """Store ``feature_vec`` together with the current wall-clock time."""
        self._history.append((feature_vec, time.time()))

    def credit(self):
        """Return the credit-weighted sum of the stored feature vectors.

        Entries older than the 99.9th percentile of the delay distribution are
        dropped first.

        Raises:
            Exception: if no entry is recent enough to receive credit.
        """
        now = time.time()
        # Loop-invariant: compute the pruning horizon once, not per entry.
        horizon = gamma.ppf(0.999, self.k, self.delay, self.theta)
        self._history = [entry for entry in self._history if now - entry[1] < horizon]
        self._history.sort(key=lambda entry: entry[1], reverse=True)
        if len(self._history) == 0:
            raise Exception("Empty history array - cannot assign credit")
        # The CDF must be evaluated on the entry's age, as GammaCrediter does.
        # The original passed the absolute timestamp, for which the CDF is
        # always 1, so the newest entry received all the credit.
        total = 0
        previous_cdf = 0.0
        for feature_vec, stamp in self._history:
            current_cdf = gamma.cdf(now - stamp, self.k, self.delay, self.theta)
            total = total + (current_cdf - previous_cdf) * feature_vec
            previous_cdf = current_cdf
        return total

class GammaCrediter:
    """
    Unlike the older LinearGammaCrediter, this version does not combine the
    history into one feature vector. Instead, each vector is returned alongside
    a weight. The weights telescope to the gamma CDF at the age of the oldest
    retained entry, so they sum to at most 1 (not necessarily exactly 1).
    """
    def __init__(self, ndims):
        # ``ndims`` is accepted for a uniform crediter constructor but is not used.
        self._history = []
        self.k = 2.0
        self.theta = 0.28
        self.delay = 0.20

    def add_index(self, feature_vec):
        """Store ``feature_vec`` together with the current wall-clock time."""
        self._history.append((feature_vec, time.time()))

    def _weighted_history(self):
        """Prune and sort the history, then return ``(features, weights)``.

        ``features`` stacks the retained vectors newest first. ``weights`` is
        an ``(N, 1)`` column holding, for each entry, the probability mass
        between its age and the age of the next newer entry.

        Raises:
            Exception: if no entry is recent enough to receive credit.
        """
        current_time = time.time()
        # Loop-invariant: compute the pruning horizon once, not per entry.
        horizon = gamma.ppf(0.999, self.k, self.delay, self.theta)
        self._history = [x for x in self._history if current_time - x[1] < horizon]
        self._history.sort(key=lambda x: x[1], reverse=True)
        if len(self._history) == 0:
            raise Exception("Empty history array - cannot assign credit")
        ages = np.array([current_time - x[1] for x in self._history])
        cumulative = gamma.cdf(ages, self.k, self.delay, self.theta)
        # Difference of consecutive CDF values; the newest entry is measured from 0.
        weights = np.diff(cumulative, prepend=0.0).reshape((len(self._history), 1))
        return (np.array([x[0] for x in self._history]), weights)

    def credit(self):
        """Return ``(features, weights)`` for the retained history."""
        return self._weighted_history()

    def credit2(self):
        """Return ``(features, weights, next)``; ``next`` is a zero vector of length N.

        No successor features are tracked here, so the third element is only a
        placeholder that keeps the three-value return shape.
        """
        features, weights = self._weighted_history()
        return (features, weights, np.zeros(len(weights)))

class UniformCrediter:
    """
    Gives equal credit to every history entry whose age lies between
    ``_low`` (0.1 s) and ``_high`` (2.0 s) at the time of the call.

    ``credit`` normalises the weights so they sum to 1; ``credit2`` and
    ``credit3`` return a weight of 1 for every eligible entry.
    """
    def __init__(self, ndims):
        # ``ndims`` is accepted for a uniform crediter constructor but is not used.
        self._history = []
        self._low = 0.1
        self._high = 2.0

    def add_index(self, feature_vec):
        """Store ``feature_vec`` together with the current wall-clock time."""
        self._history.append((feature_vec, time.time()))

    def credit(self):
        """Return ``(features, weights)`` with ``weights`` an ``(N, 1)`` column summing to 1.

        Raises:
            Exception: if no entry falls inside the credit window.
        """
        # Prune old times
        call_time = time.time()
        self._history = [x for x in self._history if call_time - x[1] <= self._high]
        eligible = [x for x in self._history if call_time - x[1] >= self._low]
        if len(eligible) == 0:
            raise Exception("No eligible history - cannot assign credit")
        return (np.array([x[0] for x in eligible]), np.ones((len(eligible), 1)) / len(eligible))

    def credit2(self):
        """Return ``(features, weights, next_features)`` for the eligible entries.

        ``next_features[i]`` is the entry stored right after ``features[i]``.
        If the last stored entry is itself eligible, it is used as its own
        successor. Weights are all 1.

        Raises:
            Exception: if no entry falls inside the credit window.
        """
        call_time = time.time()
        self._history = [x for x in self._history if call_time - x[1] <= self._high] # Limit to those eligible to receive credit
        values = []
        successors = []
        last_index = len(self._history) - 1
        # History is in insertion order (oldest first), so the first entry that
        # is too recent ends the eligible prefix.
        for i, (features, stamp) in enumerate(self._history):
            if call_time - stamp < self._low:
                break
            values.append(features)
            if i < last_index:
                successors.append(self._history[i + 1][0])
            else:
                # This case shouldn't happen
                print("Entire history eligible for credit, this shouldn't typically happen!")
                successors.append(features) # Can't give dummy value, assuming that the person wants a no-op as much as possible
        assert len(values) == len(successors), f"{len(values)} values eligible, but {len(successors)} next state-actions found"
        if len(values) == 0:
            raise Exception("No eligible history - cannot assign credit")
        weights = np.ones((len(values), 1)) # / len(values)
        return (np.array(values), weights, np.array(successors))

    def credit3(self):
        """Return ``(states, weights, actions)`` for the eligible entries.

        Every stored entry is expected to be a ``(state, action)`` pair.
        ``actions`` and ``weights`` are ``(N, 1)`` columns; the weights are all 1.

        Raises:
            Exception: if no entry falls inside the credit window. The original
                returned empty arrays here, unlike ``credit`` and ``credit2``,
                and the caller then failed later when it fed them to a network.
        """
        call_time = time.time()
        self._history = [x for x in self._history if call_time - x[1] <= self._high]
        values = [i[0] for i in self._history if call_time - i[1] > self._low]
        if len(values) == 0:
            raise Exception("No eligible history - cannot assign credit")
        states = np.array([i[0] for i in values])
        actions = np.array([[i[1]] for i in values])
        return (states, np.ones(actions.shape), actions)
        

In [None]:
class BaseAgent:
    """Agent moving through a normalized (0-1 per dimension) state space in fixed steps.

    There are ``2 * ndims`` discrete actions: index ``i`` adds ``step`` to
    dimension ``i``, and index ``ndims + i`` subtracts ``step`` from it.
    """

    def __init__(self, ndims, step, *args):
        self._ndims = ndims
        self._step = step
        # Start at the origin with the full [0, 1] range allowed everywhere.
        self.state = np.zeros(self._ndims)
        self.state_lows = np.zeros(self._ndims)
        self.state_highs = np.ones(self._ndims)
        positive_moves = np.eye(self._ndims) * self._step
        self._actions = np.concatenate((positive_moves, -positive_moves))
        self._actions_index = list(range(len(self._actions)))
        self._exclude_dims = set()
        self._rng = np.random.default_rng()

    @staticmethod
    def _normalized(values):
        """Truthy when every element of ``values`` lies in [0, 1]."""
        return reduce(and_op, [x >= 0 for x in values], True) and reduce(and_op, [x <= 1 for x in values], True)

    def set_state(self, state, *, lows=None, highs=None, action=None, history=True):
        """Replace the current state (and optionally its bounds); return the previous state.

        All values should be normalized to 0-1 for each dimension; violations
        fail an assertion. ``action`` and ``history`` are accepted but unused
        at this level.
        """
        state = np.array(state)
        # Check range and update
        if lows is not None:
            assert len(lows) == self._ndims, f"Expected lows to contain {self._ndims} elements, not {len(lows)}"
            assert self._normalized(lows), f"Low range not normalized: {lows}"
            self.state_lows = lows
        if highs is not None:
            assert len(highs) == self._ndims, f"Expected highs to contain {self._ndims} elements, not {len(highs)}"
            assert self._normalized(highs), f"High range not normalized: {highs}"
            self.state_highs = highs
        assert self._normalized(state), f"State out of bounds {state}"
        previous, self.state = self.state, state
        return previous

    def to_action(self, action):
        """Return the displacement vector for action index ``action``."""
        return self._actions[action]

    def _check_bounds(self, state):
        """True when ``state`` lies inside [0, 1] and inside the configured lows/highs."""
        lower = np.array(self.state_lows)
        upper = np.array(self.state_highs)
        inside = (state >= 0) & (state <= 1) & (state >= lower) & (state <= upper)
        return inside.all(0)

    def apply_action(self, action):
        """Move by action index ``action``; raise if the result would be out of bounds."""
        next_state = self.state + self.to_action(action)
        if not self._check_bounds(next_state):
            raise Exception(f"Tried to transition to an invalid state {next_state}.")
        self.set_state(next_state, action=action)

    def update_activation(self, dimension, activation):
        """Enable (truthy ``activation``) or disable movement along ``dimension``."""
        if not activation:
            self._exclude_dims.add(dimension)
            return
        self._exclude_dims.discard(dimension)

    def _included_actions(self):
        """Return the action indices that do not move along any disabled dimension."""
        allowed = []
        for act in self._actions_index:
            stays_put = [self.to_action(act)[dim] == 0 for dim in self._exclude_dims]
            if reduce(lambda x, y: x and y, stays_put, True):
                allowed.append(act)
        return np.array(allowed)

In [None]:
losses = []  # (wall-clock time, loss value) pairs, appended after every optimisation step

# Scurto used alpha = 0.002
class NeuralSGDAgent(BaseAgent):
    """Agent that learns one value per discrete action from human reward.

    A ``ReLUNet`` maps the state to ``2 * ndims`` action values and is trained
    with ``DTAMERLoss`` and plain SGD. Action selection adds a count-based
    exploration bonus ``beta * (pseudo_count + c) ** -0.5``, where the
    pseudo-count comes from a ``TilingDensity`` over visited states.

    Replay records are ``(state, target, action_index)`` triples.
    """

    def __init__(self, ndims, step, epsilon=PARAM_EPSILON, alpha=PARAM_ALPHA, crediter=UniformCrediter, replay=True):
        BaseAgent.__init__(self, ndims, step)
        self.crediter = crediter(self._ndims)
        self._net = ReLUNet(self._ndims)
        self._criterion = DTAMERLoss(2 * self._ndims)
        self._optimizer = torch.optim.SGD(self._net.parameters(), lr=alpha)
        self._epsilon = epsilon
        self._alpha = alpha
        self._beta = PARAM_BETA
        self._c = PARAM_C
        n_tiles = int(2 + np.ceil(np.log2(self._ndims)))
        k_tile = int(np.ceil(1 / (4 * self._step)))
        self.tiling = TilingDensity(self._ndims, n_tiles, k_tile)
        self.replay = replay
        self._replay_batch = 32
        self._history = deque(maxlen=700)

    def set_state(self, state, *, lows=None, highs=None, action=None, history=True):
        """Update the state; with ``history`` set, also count the visit and log the action.

        Returns the previous state, as ``BaseAgent.set_state`` does.
        """
        # Check range and update
        old_state = BaseAgent.set_state(self, state, lows=lows, highs=highs, action=action, history=history)
        if history:
            self.tiling.count(old_state)
            if action is not None:
                self.crediter.add_index((old_state, action))
        return old_state

    def _valid_actions(self):
        """Enabled action indices whose resulting state stays in bounds."""
        return [action for action in self._included_actions() if self._check_bounds(self.state + self.to_action(action))]

    def _explore_values(self, actions):
        """Exploration bonus of the state reached by each action in ``actions``."""
        return np.array([
            self._beta * np.power(
                self.tiling.density(self.state + self.to_action(action)) * self.tiling.total_count + self._c,
                -0.5
            )
            for action in actions
        ])

    def _fit(self, states, targets, actions):
        """Run one SGD step on a batch and return the loss tensor.

        Targets and actions are flattened to one entry per sample before they
        reach the loss, because the loss selects the chosen-action value per row.
        """
        states_t = torch.from_numpy(np.asarray(states, dtype=np.float64))
        targets_t = torch.from_numpy(np.asarray(targets, dtype=np.float64).reshape(-1))
        actions_t = torch.from_numpy(np.asarray(actions, dtype=np.int64).reshape(-1))
        self._optimizer.zero_grad()
        error = self._criterion(self._net(states_t), targets_t, actions_t)
        error.backward()
        self._optimizer.step()
        losses.append((time.time(), error.detach().numpy()))
        return error

    def _remember(self, states, targets, actions):
        """Append one ``(state, target, action_index)`` record per sample to the replay buffer."""
        flat_targets = np.asarray(targets, dtype=np.float64).reshape(-1)
        flat_actions = np.asarray(actions, dtype=np.int64).reshape(-1)
        self._history.extend(zip(np.asarray(states, dtype=np.float64), flat_targets, flat_actions))

    def _select_action(self):
        """Return the valid action index maximising value plus exploration bonus, or None."""
        start_time = time.time()
        valid_actions = self._valid_actions()
        if len(valid_actions) == 0:
            print("No valid actions!")
            return None
        state_t = torch.from_numpy(np.asarray(self.state, dtype=np.float64))
        action_values = self._net(state_t).detach().numpy()[valid_actions]
        valid_values = action_values + self._explore_values(valid_actions)
        max_ind = np.argmax(valid_values)
        end_time = time.time()
        # ``timings`` is a module-level list defined in the server cell.
        timings.append({"key": "step_greedy", "start": start_time, "end": end_time})
        return valid_actions[max_ind]

    def select_epsilon_greedy_action(self):
        """With probability epsilon pick the most novel action, otherwise use ``_select_action``."""
        if self._rng.random() >= self._epsilon:
            return self._select_action()
        start_time = time.time()
        # Exploration-only action
        valid_actions = self._valid_actions()
        if len(valid_actions) == 0:
            print("No valid actions!")
            return None
        max_ind = np.argmax(self._explore_values(valid_actions))
        end_time = time.time()
        timings.append({"key": "step_rand", "start": start_time, "end": end_time})
        return valid_actions[max_ind]

    def replay_from_history(self):
        """Train on a random replay batch once at least two batches of records are stored."""
        if len(self._history) < 2 * self._replay_batch:
            return
        sample = random.sample(self._history, self._replay_batch)
        self._fit(
            [x[0] for x in sample],
            [x[1] for x in sample],
            [x[2] for x in sample],
        )
        print("Replayed from history")

    def process_guiding_reward(self, reward):
        """Apply a human reward to the recently taken (state, action) pairs.

        Uses ``self.crediter.credit3()`` and skips the update, with a printed
        error, when that call fails or yields nothing.
        """
        try:
            states, credit_weight, actions = self.crediter.credit3()
            print(states.shape)
            credit_weight = credit_weight * reward
        except Exception as e:
            print(f"ERROR: {e}")
            print("Not applying reward...")
            return
        if len(states) == 0:
            # An empty batch would fail inside the network, outside the try above.
            print("ERROR: No eligible history - cannot assign credit")
            print("Not applying reward...")
            return
        print("determined guidance")
        error = self._fit(states, credit_weight, actions)
        print(f"Error: {error}")

        print("updated model")
        if self.replay:
            self._remember(states, credit_weight, actions)

    def process_zone_reward(self, reward):
        """Shape the values around the current state from a zone reward.

        Positive reward - apply towards this point - negative reward - apply away from this point.
        Directions include those disabled so we properly encode the zone here.

        For every action and up to ZONE_STEPS steps along it, the step away
        from the current state gets target ``-reward`` and the opposite step
        back gets ``+reward``.

        The original built ``get_features`` vectors (a helper that no longer
        exists) and called the loss without action indices, so it could not
        run; this is the same scheme expressed as (state, action index) samples.
        """
        SCALE_FACTOR = 1
        ZONE_STEPS = 3 # Arbitrarily chosen, TODO scale to match number of divisions in a dimension. Note that Scurto et al. effectively used 5.
        target = float(SCALE_FACTOR * reward)
        n_actions = len(self._actions)
        states, targets, actions = [], [], []
        for outward in self._actions_index:
            # _actions is concatenate((a, -a)), so the opposite move is ndims entries away.
            inward = (outward + self._ndims) % n_actions
            delta = self.to_action(outward)
            state = self.state
            for _step in range(ZONE_STEPS):
                tmp = state + delta
                if not self._check_bounds(tmp):
                    break
                # Leaving ``state`` in the outward direction.
                states.append(state)
                targets.append(-target)
                actions.append(outward)
                # Coming back from ``tmp`` towards the zone centre.
                states.append(tmp)
                targets.append(target)
                actions.append(inward)
                state = tmp
        if len(states) == 0:
            print("No valid zone steps - not applying reward...")
            return
        error = self._fit(states, targets, actions)
        print(f"Error: {error}")
        if self.replay:
            self._remember(states, targets, actions)

In [None]:
class SplitNeuralSGDAgent(NeuralSGDAgent):
    """Agent with the state split into two modalities, each with its own network and crediter.

    The first ``ndims1`` dimensions belong to modality 1 and the remaining
    ``ndims2`` to modality 2. Each ``DeepNN`` scores a feature vector
    ``[state part, action-displacement part]`` with one scalar, trained with
    mean squared error; the total value is the sum of the two scores.

    Replay records are ``(features, target, modality)`` triples.

    NOTE(review): as in the original, crediting pairs the state *before* the
    move with the action, while action selection and zone rewards pair the
    state *after* the move with it. Confirm which convention is intended.
    """

    def __init__(self, ndims1, ndims2, step, epsilon=PARAM_EPSILON, alpha=PARAM_ALPHA, crediter=UniformCrediter, gamma=globals().get("PARAM_GAMMA", 0.0), replay=True):
        # ``PARAM_GAMMA`` may not be defined; evaluating it directly as a
        # default raised NameError when the class was created. 0.0 disables
        # the bootstrap term.
        # The parent takes no ``gamma`` argument, so pass the rest by keyword.
        NeuralSGDAgent.__init__(self, ndims1 + ndims2, step, epsilon=epsilon, alpha=alpha, crediter=crediter, replay=replay)
        self._gamma = gamma
        self._split_index = ndims1
        self.crediter = None
        self.crediter1 = crediter(ndims1)
        self.crediter2 = crediter(ndims2)
        self._net = None
        self._net1 = DeepNN(ndims1)
        self._net2 = DeepNN(ndims2)
        # The nets output one scalar per sample, so the parent's per-action
        # loss does not apply here.
        self._criterion = torch.nn.MSELoss()
        self._optimizer = None
        self._optimizer1 = torch.optim.SGD(self._net1.parameters(), lr=alpha)
        self._optimizer2 = torch.optim.SGD(self._net2.parameters(), lr=alpha)

    @staticmethod
    def _features(state, action):
        """Concatenate a state part and an action-displacement part into one float64 vector."""
        return np.array([*state, *action], dtype=np.float64)

    def _delta(self, action):
        """Return the displacement vector for ``action`` (an action index or already a vector)."""
        if np.ndim(action) == 0:
            return self.to_action(action)
        return np.asarray(action)

    def _train_net(self, net, optimizer, features, targets):
        """Run one SGD step of ``net`` on ``features`` against ``targets``; return the loss."""
        inputs = torch.from_numpy(np.asarray(features, dtype=np.float64))
        wanted = torch.from_numpy(np.asarray(targets, dtype=np.float64).reshape(-1, 1))
        optimizer.zero_grad()
        error = self._criterion(net(inputs), wanted)
        error.backward()
        optimizer.step()
        return error

    def split(self, value):
        """Split a full-length vector into its (modality 1, modality 2) parts."""
        assert len(value) == self._ndims, f"Expected vector of size {self._ndims}, received {len(value)}"
        return (value[:self._split_index], value[self._split_index:])

    def set_state(self, state, *, lows=None, highs=None, action=None, history=True):
        """Update the state; with ``history`` set, count the visit and log features per modality.

        Returns the previous state.
        """
        old_state = BaseAgent.set_state(self, state, lows=lows, highs=highs, action=action, history=history)
        if history:
            self.tiling.count(old_state)
            if action is not None:
                state1, state2 = self.split(old_state)
                # ``apply_action`` passes an action index; the features need its vector.
                action1, action2 = self.split(self._delta(action))
                self.crediter1.add_index(self._features(state1, action1))
                self.crediter2.add_index(self._features(state2, action2))
        return old_state

    def _get_value(self, state, action):
        """Return the summed score of both networks for ``state`` and ``action``."""
        state1, state2 = self.split(state)
        action1, action2 = self.split(self._delta(action))
        value1 = self._net1(torch.from_numpy(self._features(state1, action1))).item()
        value2 = self._net2(torch.from_numpy(self._features(state2, action2))).item()
        return value1 + value2

    def _select_action(self):
        """Return the valid action index maximising summed score plus exploration bonus, or None."""
        start_time = time.time()
        # ``_included_actions`` yields indices, so convert before adding to the state.
        valid_actions = [action for action in self._included_actions() if self._check_bounds(self.state + self.to_action(action))]
        if len(valid_actions) == 0:
            print("No valid actions!")
            return None
        cut = self._split_index
        deltas = [self.to_action(action) for action in valid_actions]
        features1 = np.array([self._features(self.state[:cut] + d[:cut], d[:cut]) for d in deltas])
        features2 = np.array([self._features(self.state[cut:] + d[cut:], d[cut:]) for d in deltas])
        feature_reward_values = self._net1(torch.from_numpy(features1)).detach().numpy() + self._net2(torch.from_numpy(features2)).detach().numpy()
        feature_explore_values = np.array([[self._beta * np.power(self.tiling.density(self.state + d) * self.tiling.total_count + self._c, -0.5)] for d in deltas])
        feature_values = feature_reward_values + feature_explore_values
        max_ind = np.argmax(feature_values)
        end_time = time.time()
        # ``timings`` is a module-level list defined in the server cell.
        timings.append({"key": "step_greedy", "start": start_time, "end": end_time})
        return valid_actions[max_ind]

    def replay_from_history(self):
        """Train each network on its share of a random replay batch."""
        if len(self._history) < 2 * self._replay_batch:
            return
        sample = random.sample(self._history, self._replay_batch)
        for modality, net, optimizer in (
            (1, self._net1, self._optimizer1),
            (2, self._net2, self._optimizer2),
        ):
            rows = [x for x in sample if x[2] == modality]
            if len(rows) > 0:
                self._train_net(net, optimizer, [x[0] for x in rows], [x[1] for x in rows])
        print("Replayed from history")

    def process_guiding_reward(self, reward, modality):
        """Apply a human reward to the recent history of one modality (1, otherwise 2).

        With ``gamma > 0`` the target becomes ``reward + gamma * Q(next)``.
        The original added that term to the feature matrix instead of the
        target, corrupting the network input.
        """
        assert modality is not None, "No modality specified"
        if modality == 1:
            crediter, net, optimizer = self.crediter1, self._net1, self._optimizer1
        else:
            crediter, net, optimizer = self.crediter2, self._net2, self._optimizer2
        try:
            credit_x, credit_y, next_x = crediter.credit2()
            credit_y = credit_y * reward
            if self._gamma > 0:
                credit_y = credit_y + self._gamma * net(torch.from_numpy(next_x)).detach().numpy()
        except Exception as e:
            print(f"ERROR: {e}")
            print("Not applying reward...")
            return
        error = self._train_net(net, optimizer, credit_x, credit_y)
        print(f"Error: {error}")

        if self.replay:
            tag = 1 if modality == 1 else 2
            self._history.extend((x, y, tag) for x, y in zip(credit_x, credit_y))

    def process_zone_reward(self, reward):
        """Shape both networks around the current state from a zone reward.

        Positive reward - apply towards this point - negative reward - apply away from this point.
        Directions include those disabled so we properly encode the zone here.

        NOTE(review): every sample is fed to both networks, as in the
        original. For an action confined to one modality the other network
        sees the same features with both ``-reward`` and ``+reward`` targets.
        """
        SCALE_FACTOR = 1
        ZONE_STEPS = 5 # Arbitrarily chosen, TODO scale to match number of divisions in a dimension. Note that Scurto et al. effectively used 5.
        features1 = []
        features2 = []
        weights = []
        for action in self._actions:
            state = self.state
            action1, action2 = self.split(action)
            # Iterate through steps and apply 1) reward in direction we want to move 2) -reward in direction we do not want to move
            for _step in range(1, ZONE_STEPS + 1):
                tmp = state + action
                if not self._check_bounds(tmp):
                    break
                tmp1, tmp2 = self.split(tmp)
                state1, state2 = self.split(state)
                features1.append(self._features(tmp1, action1))
                features2.append(self._features(tmp2, action2))
                weights.append(np.array([-float(SCALE_FACTOR * reward)]))
                features1.append(self._features(state1, -action1))
                features2.append(self._features(state2, -action2))
                weights.append(np.array([float(SCALE_FACTOR * reward)]))

                state = tmp
        if len(weights) == 0:
            # An empty batch would fail inside the networks.
            print("No valid zone steps - not applying reward...")
            return
        error = self._train_net(self._net1, self._optimizer1, features1, weights)
        print(f"Error: {error}")
        self._train_net(self._net2, self._optimizer2, features2, weights)

        if self.replay:
            # Same (features, target, modality) layout replay_from_history reads.
            self._history.extend((f, w, 1) for f, w in zip(features1, weights))
            self._history.extend((f, w, 2) for f, w in zip(features2, weights))

In [None]:
from pythonosc.dispatcher import Dispatcher
from pythonosc.osc_server import ThreadingOSCUDPServer
from pythonosc.udp_client import SimpleUDPClient
from threading import Thread
import time

# Module-level state shared by the OSC handlers below.
manualMode = True  # manual/autonomous flag; nothing in this file reads it
agents = {}        # element identifier -> agent instance

# Which agent class ``init`` creates: "joint", "split" or "random".
agentType = "joint"
# Size of the first modality handed to SplitNeuralSGDAgent by ``init``.
haptic_dims = 6

ip = "127.0.0.1" # localhost
port = 8080      # port this process listens on
destPort = 8081  # port outgoing messages are sent to

# Outgoing OSC client, used by ``step`` to publish the agent's new state.
client = SimpleUDPClient(ip, destPort)
# Timing records: dicts with "key", "start" and "end" (time.time() values).
timings = []

def default_handler(address, *args):
    """Fallback OSC handler: print the address and arguments of an unmapped message."""
    message = f"DEFAULT {address}: {args}"
    print(message)

def auto_switch_handler(address, state, *args):
    """OSC handler that stores ``state`` in the module-level ``manualMode`` flag.

    Also appends a "switch" record to ``timings``.
    """
    # Without this declaration the assignment below created a local variable
    # and the module-level flag never changed.
    global manualMode
    start_time = time.time()
    print(f"Is Manual {state}")
    manualMode = state
    end_time = time.time()
    timings.append({"key": "switch", "start": start_time, "end": end_time})

def manual_set(address, element, *args):
    """OSC handler: set an agent's state and bounds, recording the change in its history.

    ``args`` is a flat sequence of (value, low, high) triples, one per dimension.
    """
    start_time = time.time()
    # De-interleave the triples: offset 0 = values, 1 = lows, 2 = highs.
    state, low, high = (args[offset::3] for offset in range(3))
    target = agents[element]
    target.set_state(state=state, lows=low, highs=high, history=True)
    end_time = time.time()
    timings.append({"key": "manual_set", "start": start_time, "end": end_time})

def manual_update(address, element, *args):
    """OSC handler: set an agent's state and bounds without recording history.

    ``args`` is a flat sequence of (value, low, high) triples, one per dimension.
    """
    start_time = time.time()
    # De-interleave the triples: offset 0 = values, 1 = lows, 2 = highs.
    state, low, high = (args[offset::3] for offset in range(3))
    target = agents[element]
    target.set_state(state=state, lows=low, highs=high, history=False)
    end_time = time.time()
    timings.append({"key": "manual_update", "start": start_time, "end": end_time})

def step(address, element):
    """OSC handler: let an agent take one epsilon-greedy step.

    When an action is available it is applied, the new state is sent to
    ``/controller/agentSet``, and one replay update is run.
    """
    start_time = time.time()
    agent = agents[element]
    action = agent.select_epsilon_greedy_action()
    if action is None:
        print(f"{element}: All actions excluded! Doing nothing.")
    else:
        agent.apply_action(action)
        payload = [element, *agent.state]
        client.send_message("/controller/agentSet", payload)
        agent.replay_from_history()
    end_time = time.time()
    timings.append({"key": "step", "start": start_time, "end": end_time})

def reward(address, element, reward, modality=None):
    """OSC handler: forward a guiding reward to an agent.

    ``modality`` is passed on only when it is truthy.
    """
    start_time = time.time()
    extra_args = (modality,) if modality else ()
    agents[element].process_guiding_reward(reward, *extra_args)
    end_time = time.time()
    timings.append({"key": "guidance", "start": start_time, "end": end_time})

def zone_reward(address, element, reward):
    """OSC handler: forward a zone reward to an agent and record how long it took."""
    start_time = time.time()
    target = agents[element]
    target.process_zone_reward(reward)
    end_time = time.time()
    record = {"key": "zone", "start": start_time, "end": end_time}
    timings.append(record)
    
def activate(address, element, dimension, activation):
    """OSC handler: enable or disable one dimension of an agent, then print its excluded set."""
    print(f"{element}: Setting dimension {dimension} to {activation}")
    target = agents[element]
    target.update_activation(dimension, activation)
    print(f"{agents[element]._exclude_dims}")

def init(address, element, ndims, step):
    """OSC handler: create (or replace) the agent stored under ``element``.

    The class is chosen by the module-level ``agentType``; any other value
    leaves ``agents`` untouched.

    NOTE(review): ``RandomAgent`` is not defined in the visible part of this
    file; confirm it exists before using agentType "random".
    """
    if element not in agents:
        print(f"New agent {element} with {ndims} dimensions, initial step {step} (norm)")
    else:
        print(f"Replacing agent {element} with fresh. {ndims} dimensions, initial step {step} (norm)")
    if agentType == "joint":
        created = NeuralSGDAgent(ndims, step, crediter=UniformCrediter)
    elif agentType == "split":
        created = SplitNeuralSGDAgent(haptic_dims, ndims - haptic_dims, step, crediter=UniformCrediter)
    elif agentType == "random":
        created = RandomAgent(ndims, step)
    else:
        return
    agents[element] = created

def delete(address, element):
    """Remove the agent stored under ``element``, or report that there is none."""
    if element not in agents:
        print(f"No agent with identifier {element}!")
        return
    print(f"Deleting agent {element} ({agents[element]._ndims} dimensions)")
    del agents[element]

# Route incoming OSC addresses to the handlers defined above; anything
# unmapped falls through to default_handler.
# NOTE(review): the ``delete`` handler defined above is never mapped to an
# address here - confirm whether that is intentional.
dispatcher = Dispatcher()
dispatcher.set_default_handler(default_handler)
dispatcher.map("/uistate/setAutonomous", auto_switch_handler)
dispatcher.map("/controller/manualSet", manual_set)
dispatcher.map("/controller/updateManual", manual_update)
dispatcher.map("/controller/step", step)
dispatcher.map("/controller/reward", reward)
dispatcher.map("/controller/activate", activate)
dispatcher.map("/controller/init", init)
dispatcher.map("/controller/zone_reward", zone_reward)

# Same values as assigned earlier in this cell.
ip = "127.0.0.1" # localhost
port = 8080

# Serve on a worker thread and block this cell until a "/quit" message stops
# the server.
with ThreadingOSCUDPServer((ip, port), dispatcher) as server:
    def quit_func(address, *args):
        # Stops serve_forever, which lets thread.join() below return.
        print("Quit!")
        server.shutdown()
        server.server_close()
    dispatcher.map("/quit", quit_func)
    thread = Thread(target=server.serve_forever)
    thread.start()
    thread.join()
print("And we're out!")