Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions mixtape/core/templates/core/insights.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
{% block content %}

{{ plot_data|json_script:"plot_data_json" }}
{{ step_data|json_script:"step_data_json" }}

<div x-data="{
plotData: JSON.parse(document.getElementById('plot_data_json').textContent),
stepData: JSON.parse(document.getElementById('step_data_json').textContent),
currentStep: 0,
playing: false,
maxSteps: {{ steps|length|add:'-1' }},
Expand Down Expand Up @@ -41,11 +43,10 @@ <h1 class="text-center text-2xl font-bold">Episode #{{ episode.pk }}</h1>
</div>
<div class="grid grid-cols-12">
<div class="col-span-6">
{% for step in steps %}
<h1 x-cloak x-show="currentStep === {{ forloop.counter0 }}" class="text-xl font-bold text-center">Step {{ step.number }}</h1>
<div x-cloak x-show="currentStep === {{ forloop.counter0 }}" class="grid grid-cols-12 h-[45vh] {% if step.image %}content-center{% endif %}">
<div x-show="{{ step.image|yesno:'true,false' }}" class="col-span-9 content-center m-2 h-[inherit]">
<img class="max-h-4/5 m-auto" src="{{ step.image.url }}" alt="Step Image">
<h1 class="text-xl font-bold text-center">Step <span x-text="currentStep"></span></h1>
<div class="grid grid-cols-12 h-[45vh] {% if step.image %}content-center{% endif %}">
<div x-show="{{ step.image|yesno:'false,true' }}" class="col-span-9 content-center m-2 h-[inherit]">
<img class="max-h-4/5 m-auto" :src="stepData[currentStep].image_url" alt="Step Image">
<div class="flex w-1/2 justify-evenly justify-self-center border border-black rounded-lg my-2 mx-5 py-1">
<button
class="cursor-pointer text-blue-500 hover:text-blue-800 focus:ring-3 focus:ring-blue-800 focus:text-blue-800 rounded-lg p-1 text-center disabled:text-gray-300"
Expand Down Expand Up @@ -84,21 +85,19 @@ <h1 x-cloak x-show="currentStep === {{ forloop.counter0 }}" class="text-xl font-
</button>
</div>
</div>
<div class="flex justify-evenly {% if step.image %}col-span-3 flex-col overflow-y-auto h-9/10{% else %}col-span-12 flex-row h-fit{% endif %}">
{% for agent_step in step.agent_steps.all|dictsortreversed:"reward" %}
<div class="col-span-3 flex flex-col justify-evenly overflow-y-auto h-9/10">
<template x-for="agent_step in stepData[currentStep].agent_steps || []">
<div class="border border-gray-900 rounded-lg m-1 p-1">
<p class="flex justify-evenly">
<p><strong>Agent:</strong> {{ agent_step.agent }}</p>
<p><strong>Action:</strong> {{ agent_step.action_string }}</p>
<p><strong>Reward:</strong> {{ agent_step.reward|floatformat }}</p>
<p><strong>Agent:</strong> <span x-text="agent_step.agent"></span></p>
<p><strong>Action:</strong> <span x-text="agent_step.action"></span></p>
<p><strong>Reward:</strong> <span x-text="agent_step.reward"></span></p>
</p>
</div>
{% empty %}
<p>End of episode</p>
{% endfor %}
</template>
<p x-show="!stepData[currentStep]?.agent_steps?.length" class="text-center">End of episode</p>
</div>
</div>
{% endfor %}
<div class="bg-gray-100 rounded-lg">
<ol class="items-center flex p-2">
<div class="flex w-5 bg-gray-700 h-6 ml-2"></div>
Expand Down
60 changes: 40 additions & 20 deletions mixtape/core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,59 +2,79 @@
from itertools import accumulate
from typing import Any

from django.db.models import Count, Sum
from django.db.models import Sum
from django.http import HttpRequest, HttpResponse
from django.shortcuts import get_object_or_404, render

from mixtape.core.models.agent_step import AgentStep
from mixtape.core.models.episode import Episode
from mixtape.core.models.step import Step
from mixtape.environments.mappings import action_maps


def insights(request: HttpRequest, episode_pk: int) -> HttpResponse:
# Load the episode with its related data up front (select_related joins in the same query; prefetch_related adds one query per relation)
episode = get_object_or_404(
Episode.objects.select_related('inference_request__checkpoint__training_request'),
Episode.objects.select_related(
'inference_request__checkpoint__training_request'
).prefetch_related('steps', 'steps__agent_steps'),
pk=episode_pk,
)
steps = Step.objects.prefetch_related('agent_steps').filter(episode=episode)
agent_steps = AgentStep.objects.filter(step__episode_id=episode_pk)
agent_steps_aggregation = agent_steps.annotate(
total_rewards=Sum('reward'),
reward_frequency=Count('reward'),
action_frequency=Count('action'),
).order_by('action', 'reward')

# Prepare step data
step_data = {
step.number: {
'image_url': step.image.url,
'agent_steps': [
{
'agent': agent_step.agent,
'action': agent_step.action_string,
'reward': agent_step.reward,
}
for agent_step in step.agent_steps.all()
],
}
for step in episode.steps.all()
}

# Prepare plot data
env_name = episode.inference_request.checkpoint.training_request.environment
plot_data: dict[str, Any] = {
# dict mapping agent (str) to action (str) to total reward (float)
'action_v_reward': defaultdict(lambda: defaultdict(float)),
# all reward values received over the episode (list of floats)
'reward_histogram': [a.reward for a in agent_steps],
'reward_histogram': [
a.reward for step in episode.steps.all() for a in step.agent_steps.all()
],
# dict mapping action (str) to frequency of action (int)
'action_v_frequency': defaultdict(int),
}

action_map = action_maps.get(env_name, {})
for entry in agent_steps_aggregation:
action = action_map.get(int(entry.action), f'{entry.action}')
plot_data['action_v_reward'][entry.agent][action] += entry.total_rewards
plot_data['action_v_frequency'][action] += entry.action_frequency
for step in episode.steps.all():
for agent_step in step.agent_steps.all():
action = action_map.get(int(agent_step.action), f'{agent_step.action}')
plot_data['action_v_reward'][agent_step.agent][action] += agent_step.reward
plot_data['action_v_frequency'][action] += 1

key_steps = steps.annotate(total_rewards=Sum('agent_steps__reward', default=0)).order_by(
'number'
key_steps = (
episode.steps.all()
.annotate(total_rewards=Sum('agent_steps__reward', default=0))
.order_by('number')
)
# list of cumulative rewards over time
plot_data['rewards_over_time'] = list(accumulate(ks.total_rewards for ks in key_steps))
# TODO: Revisit this. It exists mainly as a placeholder; the timeline
# should represent points of interest with more meaning.
# list of steps with rewards greater than 0
timeline_steps = key_steps.filter(total_rewards__gt=0)

return render(
request,
'core/insights.html',
{
'episode': episode,
'steps': steps,
'steps': episode.steps.all(),
'plot_data': plot_data,
'timeline_steps': timeline_steps,
'step_data': step_data,
},
)

Expand Down