diff --git a/mixtape/core/templates/core/insights.html b/mixtape/core/templates/core/insights.html
index 3d329cf..061e790 100644
--- a/mixtape/core/templates/core/insights.html
+++ b/mixtape/core/templates/core/insights.html
@@ -5,9 +5,11 @@
{% block content %}
{{ plot_data|json_script:"plot_data_json" }}
+{{ step_data|json_script:"step_data_json" }}
Episode #{{ episode.pk }}
- {% for step in steps %}
-
Step {{ step.number }}
-
-
-

+
Step
+
+
+
- {% endfor %}
diff --git a/mixtape/core/views.py b/mixtape/core/views.py
index 0ea8191..8304b8e 100644
--- a/mixtape/core/views.py
+++ b/mixtape/core/views.py
@@ -2,49 +2,68 @@
from itertools import accumulate
from typing import Any
-from django.db.models import Count, Sum
+from django.db.models import Sum
from django.http import HttpRequest, HttpResponse
from django.shortcuts import get_object_or_404, render
-from mixtape.core.models.agent_step import AgentStep
from mixtape.core.models.episode import Episode
-from mixtape.core.models.step import Step
from mixtape.environments.mappings import action_maps
def insights(request: HttpRequest, episode_pk: int) -> HttpResponse:
+ # Prefetch all related data in a single query
episode = get_object_or_404(
- Episode.objects.select_related('inference_request__checkpoint__training_request'),
+ Episode.objects.select_related(
+ 'inference_request__checkpoint__training_request'
+ ).prefetch_related('steps', 'steps__agent_steps'),
pk=episode_pk,
)
- steps = Step.objects.prefetch_related('agent_steps').filter(episode=episode)
- agent_steps = AgentStep.objects.filter(step__episode_id=episode_pk)
- agent_steps_aggregation = agent_steps.annotate(
- total_rewards=Sum('reward'),
- reward_frequency=Count('reward'),
- action_frequency=Count('action'),
- ).order_by('action', 'reward')
+ # Prepare step data
+ step_data = {
+ step.number: {
+ 'image_url': step.image.url,
+ 'agent_steps': [
+ {
+ 'agent': agent_step.agent,
+ 'action': agent_step.action_string,
+ 'reward': agent_step.reward,
+ }
+ for agent_step in step.agent_steps.all()
+ ],
+ }
+ for step in episode.steps.all()
+ }
+
+ # Prepare plot data
env_name = episode.inference_request.checkpoint.training_request.environment
plot_data: dict[str, Any] = {
# dict mapping agent (str) to action (str) to total reward (float)
'action_v_reward': defaultdict(lambda: defaultdict(float)),
# all reward values received over the episode (list of floats)
- 'reward_histogram': [a.reward for a in agent_steps],
+ 'reward_histogram': [
+ a.reward for step in episode.steps.all() for a in step.agent_steps.all()
+ ],
# dict mapping action (str) to freuency of action (int)
'action_v_frequency': defaultdict(int),
}
-
action_map = action_maps.get(env_name, {})
- for entry in agent_steps_aggregation:
- action = action_map.get(int(entry.action), f'{entry.action}')
- plot_data['action_v_reward'][entry.agent][action] += entry.total_rewards
- plot_data['action_v_frequency'][action] += entry.action_frequency
+ for step in episode.steps.all():
+ for agent_step in step.agent_steps.all():
+ action = action_map.get(int(agent_step.action), f'{agent_step.action}')
+ plot_data['action_v_reward'][agent_step.agent][action] += agent_step.reward
+ plot_data['action_v_frequency'][action] += 1
- key_steps = steps.annotate(total_rewards=Sum('agent_steps__reward', default=0)).order_by(
- 'number'
+ key_steps = (
+ episode.steps.all()
+ .annotate(total_rewards=Sum('agent_steps__reward', default=0))
+ .order_by('number')
)
+ # list of cumulative rewards over time
plot_data['rewards_over_time'] = list(accumulate(ks.total_rewards for ks in key_steps))
+ # TODO: Revist this. This exists more as a placeholder, timeline
+ # should represent points of interest with more meaning.
+ # list of steps with rewards greater than 0
timeline_steps = key_steps.filter(total_rewards__gt=0)
return render(
@@ -52,9 +71,10 @@ def insights(request: HttpRequest, episode_pk: int) -> HttpResponse:
'core/insights.html',
{
'episode': episode,
- 'steps': steps,
+ 'steps': episode.steps.all(),
'plot_data': plot_data,
'timeline_steps': timeline_steps,
+ 'step_data': step_data,
},
)