Permalink
Fetching contributors…
Cannot retrieve contributors at this time
546 lines (457 sloc) 16.9 KB
// HELPER FUNCTIONS
// True when agentParams has both a 'utility' and a 'priorBelief' entry and
// the utility is callable. Only the utility is type-checked; priorBelief
// merely has to be present.
var hasUtilityPriorBelief = function(agentParams) {
  var hasBothKeys = _.has(agentParams, 'utility') && _.has(agentParams, 'priorBelief');
  return hasBothKeys ? _.isFunction(agentParams.utility) : false;
};
// True when world has both 'transition' and 'observe' entries and both are
// functions.
var hasTransitionObserve = function(world) {
  if (!_.has(world, 'transition') || !_.has(world, 'observe')) {
    return false;
  }
  return _.isFunction(world.transition) && _.isFunction(world.observe);
};
// Test whether belief has the right form for fastUpdateBelief, i.e. it is not
// a Dist but an object with 'manifestState' and 'latentStateDist' properties.
// Returns false when the shape does not match; otherwise returns the
// latentStateDist's `sample` member (truthy when it is present), so callers
// should treat the result as a truthy/falsy flag.
var isDistOverLatent = function(belief) {
  if (!_.has(belief, 'manifestState') || !_.has(belief, 'latentStateDist')) {
    return false;
  }
  // An own 'support' property means this is not the factored form.
  if (_.has(belief, 'support')) {
    return false;
  }
  return belief.latentStateDist.sample;
};
// Convert a belief from a Dist over full states into the factored form
// { manifestState, latentStateDist }. The manifest state is read off the
// first element of the belief's support; the latent Dist is obtained by
// enumerating the latentState of states sampled from the belief.
var toDistOverLatent = function(belief) {
  var firstState = belief.support()[0];
  var manifestState = firstState.manifestState;
  var drawLatent = function() {
    var fullState = sample(belief);
    return fullState.latentState;
  };
  return {
    manifestState: manifestState,
    latentStateDist: Infer({
      method: 'enumerate'
    }, drawLatent)
  };
};
// Construct the functions needed for running the POMDP agent.
// Most of the code is for the manifest-latent optimization; worlds without
// that structure get a plain enumerate-based belief update.
//
// params: reads `priorBelief` and the optional `useManifestLatent` flag.
// world:  reads `transition` and `observe`, plus `manifestStateToActions`
//         (manifest-latent case) or `beliefToActions` (general case).
// Returns { sampleBelief, useManifestLatent, observe, beliefToActions,
//           updateBelief, shouldTerminate }.
var getPOMDPFunctions = function(params, world) {

  // Sample a state from belief: works for general states and for the
  // factored { manifestState, latentStateDist } form.
  var sampleBelief = function(belief) {
    if (!isDistOverLatent(belief)) {
      return sample(belief);
    }
    return {
      manifestState: belief.manifestState,
      latentState: sample(belief.latentStateDist)
    };
  };

  // ----------------------------
  // Manifest-latent Optimization
  // Helper functions

  // Lift an action function over manifest states to one over beliefs.
  var getBeliefToActions = function(manifestStateToActions) {
    return function(belief) {
      if (isDistOverLatent(belief)) {
        return manifestStateToActions(belief.manifestState);
      }
      return manifestStateToActions(sample(belief).manifestState);
    };
  };

  // Pair the latent observation with the manifest state of the same state.
  var getFullObserve = function(observeLatent) {
    return function(state) {
      return {
        manifestState: state.manifestState,
        observation: observeLatent(state)
      };
    };
  };

  // Agent gets an observation in the starting state without having taken an
  // action. (Loose equality is kept as in the original comparison.)
  var isNullAction = function(action) {
    return action == 'noAction';
  };

  // Optimized belief update for manifest/latent POMDPs
  var getUpdateBeliefManifestLatent = function(transition, observe) {
    var isDeltaDist = function(dist) {
      return dist.support().length === 1;
    };

    return dp.cache(function(rawBelief, observation, action) {
      // A "Dist over {manifest:, latent:}" belief is first converted to the
      // "{manifest:, Dist over latent}" form.
      var belief = isDistOverLatent(rawBelief) ? rawBelief : toDistOverLatent(rawBelief);

      // The manifest state is set to the observed manifest state.
      var observedManifestState = observation.manifestState;

      // Don't update the latent when there is nothing to learn: either no
      // observation was made or the latent Dist has a single support element.
      var latentIsSettled = observation.observation === 'noObservation' ||
          isDeltaDist(belief.latentStateDist);
      if (latentIsSettled) {
        return extend(belief, {
          manifestState: observedManifestState
        });
      }

      // Sample latent, update state, compare the observation predicted from
      // that state to the actual observation.
      var posteriorOverLatent = Infer({
        method: 'enumerate'
      }, function() {
        var latentState = sample(belief.latentStateDist);
        var state = {
          manifestState: belief.manifestState,
          latentState: latentState
        };
        var predictedNextState = isNullAction(action) ? state : transition(state, action);
        condition(_.isEqual(observe(predictedNextState), observation));
        return latentState;
      });

      return {
        manifestState: observedManifestState,
        latentStateDist: posteriorOverLatent
      };
    });
  };

  // updateBelief function for general POMDPs (no manifest/latent structure)
  var getUpdateBeliefSimple = function(transition, observe) {
    return dp.cache(function(belief, observation, action) {
      return Infer({
        method: 'enumerate'
      }, function() {
        var state = sample(belief);
        var predictedNextState = isNullAction(action) ? state : transition(state, action);
        condition(_.isEqual(observe(predictedNextState), observation));
        return predictedNextState;
      });
    });
  };

  // An explicit params.useManifestLatent wins; otherwise the form of a state
  // sampled from the prior decides.
  var useManifestLatent = _.isUndefined(params.useManifestLatent) ?
      stateHasManifestLatent(sampleBelief(params.priorBelief)) :
      params.useManifestLatent;

  if (!useManifestLatent) {
    return {
      sampleBelief: sampleBelief,
      useManifestLatent: false,
      observe: world.observe,
      beliefToActions: world.beliefToActions,
      updateBelief: getUpdateBeliefSimple(world.transition, world.observe),
      shouldTerminate: function(state) {
        return state.terminateAfterAction;
      }
    };
  }

  assert.ok(stateHasManifestLatent(sampleBelief(params.priorBelief)),
            'state is not manifestLatent in form');

  var observe = getFullObserve(world.observe);
  return {
    sampleBelief: sampleBelief,
    useManifestLatent: true,
    observe: observe,
    beliefToActions: getBeliefToActions(world.manifestStateToActions),
    updateBelief: getUpdateBeliefManifestLatent(world.transition, observe),
    shouldTerminate: function(state) {
      return state.manifestState.terminateAfterAction;
    }
  };
};
// --------------------------------------------------------------------
// Optimal POMDP agent
//
// params: must contain `priorBelief` and a `utility(state, action)` function
//         (asserted). Optional: `alpha` (default 1000), which multiplies the
//         expected utility passed to `factor` when choosing actions, and
//         `recurseOnStateOrBelief` (default 'belief'), which selects one of
//         the two recursions below.
// world:  must contain `transition` and `observe` functions (asserted).
// Returns { act, expectedUtilityBelief, params, updateBelief, POMDPFunctions },
// where `params` is the input merged over the defaults.
// Fails an assertion if params/world are malformed or if utility returns a
// non-finite value during the recursion.
var makePOMDPAgentOptimal = function(params, world) {
  assert.ok(hasUtilityPriorBelief(params) && hasTransitionObserve(world),
            'makePOMDPAgent params and world');

  // set defaults
  var defaults = {
    recurseOnStateOrBelief: 'belief',
    alpha: 1000
  };
  var params = extend(defaults, params);

  var utility = params.utility;
  var transition = world.transition;

  var POMDPFunctions = getPOMDPFunctions(params, world);
  var observe = POMDPFunctions.observe;
  var beliefToActions = POMDPFunctions.beliefToActions;
  var updateBelief = POMDPFunctions.updateBelief;
  var shouldTerminate = POMDPFunctions.shouldTerminate;
  var sampleBelief = POMDPFunctions.sampleBelief;

  // RECURSE ON BELIEF (BELLMAN STYLE)
  // The two functions below only call each other, so they do not depend on
  // which recursion is finally exported.
  var act_recBelief = dp.cache(
    function(belief) {
      return Infer({
        method: 'enumerate'
      }, function() {
        var action = uniformDraw(beliefToActions(belief));
        var eu = expectedUtilityBelief_recBelief(belief, action);
        factor(params.alpha * eu);
        return action;
      });
    });

  var expectedUtilityBelief_recBelief = dp.cache(
    function(belief, action) {
      return expectation(
        Infer({
          method: 'enumerate'
        }, function() {
          var state = sampleBelief(belief);
          // Evaluate utility once and reuse it for the sanity check.
          var u = utility(state, action);
          assert.ok(_.isFinite(u), 'utility is not finite. state: ' +
                    JSON.stringify(state));
          if (shouldTerminate(state)) {
            return u;
          } else {
            var nextState = transition(state, action);
            var nextObservation = observe(nextState);
            var nextBelief = updateBelief(belief, nextObservation, action);
            var nextAction = sample(act_recBelief(nextBelief));
            var futureU = expectedUtilityBelief_recBelief(nextBelief, nextAction);
            return u + futureU;
          }
        }));
    });

  // RECURSE ON STATE (FIXES THE LATENT STATE)
  var act_recState = dp.cache(
    function(belief) {
      return Infer({
        method: 'enumerate'
      }, function() {
        var action = uniformDraw(beliefToActions(belief));
        var eu = expectedUtilityBelief_recState(belief, action);
        factor(params.alpha * eu);
        return action;
      });
    });

  var expectedUtilityBelief_recState = dp.cache(
    function(belief, action) {
      return expectation(
        Infer({
          method: 'enumerate'
        }, function() {
          var state = sampleBelief(belief);
          return expectedUtilityState_recState(belief, state, action);
        }));
    });

  var expectedUtilityState_recState = dp.cache(
    function(belief, state, action) {
      return expectation( // need this for caching
        Infer({
          method: 'enumerate'
        }, function() {
          var u = utility(state, action);
          assert.ok(_.isFinite(u), 'utility is not finite. state: ' +
                    JSON.stringify(state));
          if (shouldTerminate(state)) {
            return u;
          } else {
            var nextState = transition(state, action);
            var nextObservation = observe(nextState);
            var nextBelief = updateBelief(belief, nextObservation, action);
            // Stay inside the state recursion when picking the next action,
            // as makePOMDPAgentDelay does, rather than going through a name
            // that is rebound further down.
            var nextAction = sample(act_recState(nextBelief));
            var futureU = expectedUtilityState_recState(nextBelief, nextState, nextAction);
            return u + futureU;
          }
        }));
    });

  // Export whichever recursion was requested.
  var recurseOnBelief = params.recurseOnStateOrBelief === 'belief';

  return {
    act: recurseOnBelief ? act_recBelief : act_recState,
    expectedUtilityBelief: recurseOnBelief ?
      expectedUtilityBelief_recBelief : expectedUtilityBelief_recState,
    params: params,
    updateBelief: updateBelief,
    POMDPFunctions: POMDPFunctions
  };
};
// --------------------------------------------------------------------
// Sub-optimal POMDP agent
//
// params: must contain `priorBelief` and a `utility(state, action)` function
//         (asserted). Optional entries and their defaults:
//           alpha: 1000                      multiplies expected utility in `factor`
//           recurseOnStateOrBelief: 'belief' which recursion is exported
//           discount: 0                      used by the default discount function
//           discountFunction                 replaces 1 / (1 + discount * delay)
//           sophisticatedOrNaive: 'naive'    how delay is passed on when
//                                            simulating future actions
//           noDelays: true                   when true every delay is 0
//           updateMyopic: false              or an object with a `bound` field
//           rewardMyopic: false              or an object with a `bound` field
//         rewardMyopic / updateMyopic require noDelays === false and a naive
//         agent, and at most one of them may be set (both asserted).
// world:  must contain `transition` and `observe` functions (asserted).
// Returns { act, updateBelief, expectedUtility, params, POMDPFunctions };
// `act(belief, delay)` and `expectedUtility(belief, action, delay)` take the
// current delay as their last argument.
var makePOMDPAgentDelay = function(params, world) {
  assert.ok(hasUtilityPriorBelief(params) && hasTransitionObserve(world),
            'makePOMDPAgent params and world');

  var defaults = {
    alpha: 1000,
    recurseOnStateOrBelief: 'belief',
    discount: 0,
    sophisticatedOrNaive: 'naive',
    noDelays: true,
    updateMyopic: false,
    rewardMyopic: false
  };
  var params = extend(defaults, params);

  if (params.rewardMyopic || params.updateMyopic) {
    assert.ok(params.noDelays === false && params.sophisticatedOrNaive === 'naive',
              'rewardMyopic and updateMyopic require Naive agent with delays');
  }
  assert.ok(params.rewardMyopic === false || params.updateMyopic === false,
            'one of rewardMyopic and updateMyopic must be false');

  // Variables for methods
  var transition = world.transition;
  var utility = params.utility;

  var POMDPFunctions = getPOMDPFunctions(params, world);
  var observe = POMDPFunctions.observe;
  var beliefToActions = POMDPFunctions.beliefToActions;
  var sampleBelief = POMDPFunctions.sampleBelief;
  var _updateBelief = POMDPFunctions.updateBelief;

  // Belief update that, for an updateMyopic agent past its bound, ignores the
  // observation and only pushes the belief through the transition function.
  var updateBelief = function(belief, observation, action, delay) {
    if (params.updateMyopic && (delay > params.updateMyopic.bound)) {
      // update manifestState (assuming no possibility that isNullAction(action))
      var nextBelief = Infer({
        method: 'enumerate'
      }, function() {
        var state = sampleBelief(belief);
        return transition(state, action);
      });
      return POMDPFunctions.useManifestLatent ? toDistOverLatent(nextBelief) : nextBelief;
    } else {
      return _updateBelief(belief, observation, action);
    }
  };

  // Update the *delay* parameter in *expectedUtility* for sampling
  // actions and future utilities
  var transformDelay = function(delay) {
    var table = {
      naive: delay + 1,
      sophisticated: 0
    };
    return params.noDelays ? 0 : table[params.sophisticatedOrNaive];
  };

  var incrementDelay = function(delay) {
    return params.noDelays ? 0 : delay + 1;
  };

  // Define the discount function to be used
  var discountFunction = params.discountFunction ? params.discountFunction :
      function(delay) {
        return 1.0 / (1 + params.discount * delay);
      };

  // Termination condition for *expectedUtility*: the state says so, or a
  // rewardMyopic agent has reached its bound.
  var shouldTerminate = function(state, delay) {
    var terminateAfterAction = POMDPFunctions.useManifestLatent ?
        state.manifestState.terminateAfterAction :
        state.terminateAfterAction;
    if (terminateAfterAction) {
      return true;
    }
    if (params.rewardMyopic) {
      return delay >= params.rewardMyopic.bound;
    }
    return false;
  };

  // RECURSE ON BELIEF (BELLMAN STYLE)
  // The two functions below only call each other, so they do not depend on
  // which recursion is finally exported.
  var act_recBelief = dp.cache(
    function(belief, delay) {
      assert.ok(_.isFinite(delay), 'act: delay non-finite. delay: ' + delay);
      return Infer({
        method: 'enumerate'
      }, function() {
        var action = uniformDraw(beliefToActions(belief));
        var eu = expectedUtilityBelief_recBelief(belief, action, delay);
        factor(params.alpha * eu);
        return action;
      });
    });

  var expectedUtilityBelief_recBelief = dp.cache(
    function(belief, action, delay) {
      return expectation(
        Infer({
          method: 'enumerate'
        }, function() {
          var state = sampleBelief(belief);
          // Evaluate utility once; the raw value is checked, then discounted.
          var rawUtility = utility(state, action);
          assert.ok(_.isFinite(rawUtility), 'utility is not finite. state: ' +
                    JSON.stringify(state));
          var u = discountFunction(delay) * rawUtility;
          if (shouldTerminate(state, delay)) {
            return u;
          } else {
            var nextState = transition(state, action);
            var nextObservation = observe(nextState);
            var transformedDelay = transformDelay(delay);
            var nextBelief = updateBelief(belief, nextObservation, action, transformedDelay);
            var nextAction = sample(act_recBelief(nextBelief, transformedDelay));
            var futureU = expectedUtilityBelief_recBelief(nextBelief, nextAction,
                                                          incrementDelay(delay));
            return u + futureU;
          }
        }));
    });

  // RECURSE ON STATE (FIXES THE LATENT STATE)
  var act_recState = dp.cache(
    function(belief, delay) {
      assert.ok(_.isFinite(delay), 'act: delay non-finite. delay: ' + delay);
      return Infer({
        method: 'enumerate'
      }, function() {
        var action = uniformDraw(beliefToActions(belief));
        var eu = expectedUtilityBelief_recState(belief, action, delay);
        factor(params.alpha * eu);
        return action;
      });
    });

  var expectedUtilityBelief_recState = dp.cache(
    function(belief, action, delay) {
      return expectation(
        Infer({
          method: 'enumerate'
        }, function() {
          var state = sampleBelief(belief);
          return expectedUtilityState_recState(belief, state, action, delay);
        }));
    });

  var expectedUtilityState_recState = dp.cache(
    function(belief, state, action, delay) {
      return expectation( // need this for caching
        Infer({
          method: 'enumerate'
        }, function() {
          var rawUtility = utility(state, action);
          assert.ok(_.isFinite(rawUtility), 'utility is not finite. state: ' +
                    JSON.stringify(state));
          var u = discountFunction(delay) * rawUtility;
          if (shouldTerminate(state, delay)) {
            return u;
          } else {
            var nextState = transition(state, action);
            var nextObservation = observe(nextState);
            var transformedDelay = transformDelay(delay);
            var nextBelief = updateBelief(belief, nextObservation, action,
                                          transformedDelay);
            var nextAction = sample(act_recState(nextBelief, transformedDelay));
            var futureU = expectedUtilityState_recState(nextBelief, nextState,
                                                        nextAction,
                                                        incrementDelay(delay));
            return u + futureU;
          }
        }));
    });

  // Export whichever recursion was requested.
  var recurseOnBelief = params.recurseOnStateOrBelief === 'belief';

  return {
    act: recurseOnBelief ? act_recBelief : act_recState,
    updateBelief: updateBelief,
    expectedUtility: recurseOnBelief ?
      expectedUtilityBelief_recBelief : expectedUtilityBelief_recState,
    params: params,
    POMDPFunctions: POMDPFunctions
  };
};
// General POMDP wrappers
// Decide whether agentParams describe the optimal agent. An explicit
// `optimal` entry is returned as-is; otherwise the params count as optimal
// when none of 'noDelays', 'discount' or 'myopic-observations' is present.
// NOTE(review): 'myopic-observations' is not read by makePOMDPAgentDelay in
// the code visible here (it reads updateMyopic / rewardMyopic) -- confirm the
// key is still intended.
var isOptimalPOMDPAgent = function(agentParams) {
  if (!_.isUndefined(agentParams.optimal)) {
    return agentParams.optimal;
  }
  return !_.has(agentParams, 'noDelays') &&
    !_.has(agentParams, 'discount') &&
    !_.has(agentParams, 'myopic-observations');
};
// True when agentParams already looks like a constructed agent: it has
// 'act', 'updateBelief' and 'params' entries, and params has a 'priorBelief'.
var isManualPOMDPAgent = function(agentParams) {
  var hasAgentMethods = _.has(agentParams, 'act') && _.has(agentParams, 'updateBelief');
  if (!hasAgentMethods || !_.has(agentParams, 'params')) {
    return false;
  }
  return _.has(agentParams.params, 'priorBelief');
};
// Entry point for building a POMDP agent.
// - A manually constructed agent is returned with a POMDPFunctions entry
//   built from its own params.
// - Otherwise the optimal or the delay agent constructor is used, as decided
//   by isOptimalPOMDPAgent.
var makePOMDPAgent = function(params, world) {
  if (isManualPOMDPAgent(params)) {
    var pomdpFunctions = getPOMDPFunctions(params.params, world);
    return extend(params, {
      POMDPFunctions: pomdpFunctions
    });
  }
  if (isOptimalPOMDPAgent(params)) {
    return makePOMDPAgentOptimal(params, world);
  }
  return makePOMDPAgentDelay(params, world);
};