diff --git a/mixtape/core/templates/core/insights.html b/mixtape/core/templates/core/insights.html index 3d329cf..061e790 100644 --- a/mixtape/core/templates/core/insights.html +++ b/mixtape/core/templates/core/insights.html @@ -5,9 +5,11 @@ {% block content %} {{ plot_data|json_script:"plot_data_json" }} +{{ step_data|json_script:"step_data_json" }}
Episode #{{ episode.pk }}
- {% for step in steps %} -

Step {{ step.number }}

-
-
- Step Image +

Step

+
+
+ Step Image
- {% endfor %}
    diff --git a/mixtape/core/views.py b/mixtape/core/views.py index 0ea8191..8304b8e 100644 --- a/mixtape/core/views.py +++ b/mixtape/core/views.py @@ -2,49 +2,68 @@ from itertools import accumulate from typing import Any -from django.db.models import Count, Sum +from django.db.models import Sum from django.http import HttpRequest, HttpResponse from django.shortcuts import get_object_or_404, render -from mixtape.core.models.agent_step import AgentStep from mixtape.core.models.episode import Episode -from mixtape.core.models.step import Step from mixtape.environments.mappings import action_maps def insights(request: HttpRequest, episode_pk: int) -> HttpResponse: + # Load related data up front (select_related joins; prefetch_related runs one extra query per relation) episode = get_object_or_404( - Episode.objects.select_related('inference_request__checkpoint__training_request'), + Episode.objects.select_related( + 'inference_request__checkpoint__training_request' + ).prefetch_related('steps', 'steps__agent_steps'), pk=episode_pk, ) - steps = Step.objects.prefetch_related('agent_steps').filter(episode=episode) - agent_steps = AgentStep.objects.filter(step__episode_id=episode_pk) - agent_steps_aggregation = agent_steps.annotate( - total_rewards=Sum('reward'), - reward_frequency=Count('reward'), - action_frequency=Count('action'), - ).order_by('action', 'reward') + # Prepare step data + step_data = { + step.number: { + 'image_url': step.image.url, + 'agent_steps': [ + { + 'agent': agent_step.agent, + 'action': agent_step.action_string, + 'reward': agent_step.reward, + } + for agent_step in step.agent_steps.all() + ], + } + for step in episode.steps.all() + } + + # Prepare plot data env_name = episode.inference_request.checkpoint.training_request.environment plot_data: dict[str, Any] = { # dict mapping agent (str) to action (str) to total reward (float) 'action_v_reward': defaultdict(lambda: defaultdict(float)), # all reward values received over the episode (list of floats) 'reward_histogram': [a.reward for a in 
agent_steps], + 'reward_histogram': [ + a.reward for step in episode.steps.all() for a in step.agent_steps.all() + ], # dict mapping action (str) to freuency of action (int) 'action_v_frequency': defaultdict(int), } - action_map = action_maps.get(env_name, {}) - for entry in agent_steps_aggregation: - action = action_map.get(int(entry.action), f'{entry.action}') - plot_data['action_v_reward'][entry.agent][action] += entry.total_rewards - plot_data['action_v_frequency'][action] += entry.action_frequency + for step in episode.steps.all(): + for agent_step in step.agent_steps.all(): + action = action_map.get(int(agent_step.action), f'{agent_step.action}') + plot_data['action_v_reward'][agent_step.agent][action] += agent_step.reward + plot_data['action_v_frequency'][action] += 1 - key_steps = steps.annotate(total_rewards=Sum('agent_steps__reward', default=0)).order_by( - 'number' + key_steps = ( + episode.steps.all() + .annotate(total_rewards=Sum('agent_steps__reward', default=0)) + .order_by('number') ) + # list of cumulative rewards over time plot_data['rewards_over_time'] = list(accumulate(ks.total_rewards for ks in key_steps)) + # TODO: Revisit this. This exists more as a placeholder, timeline + # should represent points of interest with more meaning. + # list of steps with rewards greater than 0 timeline_steps = key_steps.filter(total_rewards__gt=0) return render( @@ -52,9 +71,10 @@ def insights(request: HttpRequest, episode_pk: int) -> HttpResponse: 'core/insights.html', { 'episode': episode, - 'steps': steps, + 'steps': episode.steps.all(), 'plot_data': plot_data, 'timeline_steps': timeline_steps, + 'step_data': step_data, }, )