8 changes: 8 additions & 0 deletions mixtape/core/models/agent_step.py
@@ -1,6 +1,7 @@
from django.db import models

from mixtape.core.ray_utils.logger import NumpyJSONEncoder
from mixtape.environments.mappings import action_maps

from .step import Step

@@ -15,3 +16,10 @@ class Meta:
    action = models.FloatField()
    reward = models.FloatField()
    observation_space = models.JSONField(encoder=NumpyJSONEncoder)

    @property
    def action_string(self) -> str:
        # Note: "select_related" should be called on any AgentStep where this is used, otherwise
        # this property can create very inefficient queries
        environment = self.step.episode.inference_request.checkpoint.training_request.environment
        return action_maps[environment].get(int(self.action), f'{self.action}')
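
The lookup performed by `action_string` can be sketched in plain Python. The contents of `action_maps` and the environment name below are made up for illustration (the real mapping lives in `mixtape.environments.mappings` and is not shown in this PR), and `action_label` is a hypothetical helper. Only the `.get(int(action), f'{action}')` fallback pattern is taken from the diff.

```python
# Hypothetical stand-in for mixtape.environments.mappings.action_maps;
# the environment name and labels are illustrative, not the project's data.
action_maps: dict[str, dict[int, str]] = {
    'example_env': {0: 'noop', 1: 'left', 2: 'right'},
}


def action_label(environment: str, action: float) -> str:
    """Map a stored float action to a readable label, else return the raw value."""
    # Using .get() on the outer dict as well (as views.py does) avoids a
    # KeyError for environments that have no mapping at all.
    return action_maps.get(environment, {}).get(int(action), f'{action}')


print(action_label('example_env', 1.0))  # known action
print(action_label('example_env', 7.0))  # unmapped action: raw value as a string
print(action_label('unknown_env', 1.0))  # unmapped environment: raw value as a string
```

The property in the diff indexes `action_maps[environment]` directly, so if `action_maps` is a plain dict, an environment missing from it would raise `KeyError` there. Whether that can happen depends on how environments are registered, which this PR does not show.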
1 change: 1 addition & 0 deletions mixtape/core/templates/core/base.html
@@ -5,6 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<script src="https://unpkg.com/@tailwindcss/browser@4"></script>
<script src="//unpkg.com/alpinejs" defer></script>
<script src="//unpkg.com/plotly.js@3.0.1/dist/plotly.min.js" charset="utf-8"></script>
<title>{% block title %}MIXTAPE{% endblock %}</title>
<style>
[x-cloak] {
228 changes: 191 additions & 37 deletions mixtape/core/templates/core/insights.html
@@ -3,48 +3,202 @@
{% block title %}Inference Results{% endblock %}

{% block content %}
<div x-data="{ currentStep: 0 }">

{{ plot_data|json_script:"plot_data_json" }}

<div x-data="{
plotData: JSON.parse(document.getElementById('plot_data_json').textContent),
currentStep: 0,
playing: false,
maxSteps: {{ steps|length|add:'-1' }},
stepForward() {
if (!this.playing) return;
if (this.currentStep < this.maxSteps) {
setTimeout(() => {
this.currentStep++;
this.stepForward()
}, 100);
} else {
this.playing = false;
this.currentStep = 0;
}
}
}">
<h1 class="text-center text-2xl font-bold">Episode #{{ episode.pk }}</h1>
{% for step in steps %}
<div
x-cloak
x-show="currentStep === {{ forloop.counter0 }}"
class="flex items-start gap-5 text-center h-[90vh]"
>
<div class="content-center w-1/2 h-full min-h-full">
<img class="max-w-full h-auto" src="{{ step.image.url }}" alt="Step Image">
<div class="grid grid-cols-12">
<div class="col-span-3 content-center" x-data="{
Member

Would it help to split each of these plots into its own template and `{% include %}` them? I'm not sure if that would be more readable, but I think this file is getting pretty big.

Collaborator Author

Yes, I think that would make sense! I did not do that in this PR, but my WIP to produce aggregated results across many episodes does include changes to break out components and then include them as needed.

data: [],
layout: {
title: {text: 'Actions VS Reward Frequency'},
xaxis: {title: {text: 'Action'}},
yaxis: {title: {text: 'Reward Frequency'}},
autosize: true,
barmode: 'group',
},
plot: null,
}"
x-init="
data = Object.entries(plotData.action_v_reward).map(([agent, actions]) => ({
x: Object.keys(actions),
y: Object.values(actions),
type: 'bar',
name: agent
}));
plot = Plotly.newPlot($refs.actionsRewards, data, layout)"
>
<div x-ref="actionsRewards"></div>
</div>
<div class="col-span-3 content-center" x-data="{
data: [],
layout: {
title: {text: 'Rewards VS Frequency'},
xaxis: {title: {text: 'Reward'}},
yaxis: {title: {text: 'Frequency'}},
autosize: true
},
plot: null,
}"
x-init="
data = [{x: plotData.reward_histogram, type: 'histogram'}];
plot = Plotly.newPlot($refs.rewardsFrequency, data, layout)"
>
<div x-ref="rewardsFrequency"></div>
</div>
<div class="col-span-3 content-center" x-data="{
data: [],
layout: {
title: {text: 'Actions VS Frequency'},
autosize: true
},
plot: null,
}"
x-init="
data = [{labels: Object.keys(plotData.action_v_frequency), values: Object.values(plotData.action_v_frequency), type: 'pie'}];
plot = Plotly.newPlot($refs.actionsFrequency, data, layout)"
>
<div x-ref="actionsFrequency"></div>
</div>
<div class="border border-gray-300 p-2 w-1/2 m-auto max-h-[80vh] overflow-y-scroll">
<p class="text-xl font-bold">Step {{ step.number }}</p>
{% for agent_step in step.agent_steps.all %}
<div>
<p class="flex justify-evenly">
<span><strong>Agent:</strong> {{ agent_step.agent }}</span>
<span><strong>Action:</strong> {{ agent_step.action }}</span>
<span><strong>Reward:</strong> {{ agent_step.reward }}</span>
</p>
<p><strong>Observation Space</strong></p>
<p>
<pre class="bg-gray-100 p-3 rounded-md font-mono whitespace-pre text-justify text-sm">{{ agent_step.observation_space|pprint }}</pre>
</p>
<hr>
<div class="col-span-3 content-center" x-data="{
data: [],
layout: {
title: {text: 'Rewards Over Time'},
xaxis: {title: {text: 'Time Step'}},
yaxis: {title: {text: 'Cumulative Reward'}},
autosize: true
},
plot: null,
}"
x-init="
data = [{x: plotData.rewards_over_time.map((_, i) => i), y: plotData.rewards_over_time, type: 'scatter'}];
plot = Plotly.newPlot($refs.rewardsOverTime, data, layout)"
>
<div x-ref="rewardsOverTime"></div>
</div>
</div>
<div class="grid grid-cols-12">
<div class="col-span-6">
{% for step in steps %}
Member

Assuming that the step-based views are going to continue, I think we should definitely switch to a more scalable method for rendering this much data. Emitting this much HTML at once may be problematic for many users (I'm worried we're already in "it works on my machine" territory, given that developers have over-powered machines relative to many users).

#19 is one approach, though I'm happy to discuss others.

This change is out of scope for this PR, but it'd be good to get this fixed ASAP, unless we're not planning to allow the selection of individual steps (or things of equivalent cardinality) anymore.

Collaborator Author

I've opened this up as a new issue for follow-up once this PR is in.

<h1 x-cloak x-show="currentStep === {{ forloop.counter0 }}" class="text-xl font-bold text-center">Step {{ step.number }}</h1>
<div x-cloak x-show="currentStep === {{ forloop.counter0 }}" class="grid grid-cols-12 h-[45vh]">
<div class="col-span-9 content-center m-2 h-[inherit]">
<img class="max-h-4/5 m-auto" src="{{ step.image.url }}" alt="Step Image">
<div class="flex w-1/2 justify-evenly justify-self-center border border-black rounded-lg my-2 mx-5 py-1">
<button
class="cursor-pointer text-blue-500 hover:text-blue-800 focus:ring-3 focus:ring-blue-800 focus:text-blue-800 rounded-lg p-1 text-center disabled:text-gray-300"
:disabled="currentStep === 0"
@click="currentStep = 0"
>
<!-- Rewind Icon -->
Member

It might be a lot easier to use an icon library. I've used https://remixicon.com/ for several other server-rendered applications.

You should just need to add

<link href="https://unpkg.com/remixicon/fonts/remixicon.css" rel="stylesheet"/>

then you'll be able to use something like

<i class="ri-rewind-line"></i>

instead of all this SVG.

Collaborator Author

I have a follow-up branch with additional cleanup for styles. I will include this change in that PR.

<svg class="w-6 h-6" aria-hidden="true" xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="none" viewBox="0 0 24 24">
<path stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="m17 16-4-4 4-4m-6 8-4-4 4-4"/>
</svg>
</button>
<button
class="cursor-pointer text-blue-500 hover:text-blue-800 focus:ring-3 focus:ring-blue-800 focus:text-blue-800 rounded-lg p-1 text-center disabled:text-gray-300"
:disabled="currentStep === 0"
@click="currentStep--"
>
<!-- Step Backward Icon -->
<svg class="w-6 h-6" aria-hidden="true" xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="none" viewBox="0 0 24 24">
<path stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="m14 8-4 4 4 4"/>
</svg>
</button>
<button
class="cursor-pointer text-blue-500 hover:text-blue-800 focus:ring-3 focus:ring-blue-800 focus:text-blue-800 rounded-lg p-1 text-center disabled:text-gray-300"
@click="playing = !playing; if (playing) stepForward()"
>
<!-- Play Icon -->
<svg x-show="!playing" class="w-6 h-6" aria-hidden="true" xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="none" viewBox="0 0 24 24">
<path stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M8 18V6l8 6-8 6Z"/>
</svg>
<!-- Pause Icon -->
<svg x-show="playing" class="w-6 h-6" aria-hidden="true" xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="none" viewBox="0 0 24 24">
<path stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 6H8a1 1 0 0 0-1 1v10a1 1 0 0 0 1 1h1a1 1 0 0 0 1-1V7a1 1 0 0 0-1-1Zm7 0h-1a1 1 0 0 0-1 1v10a1 1 0 0 0 1 1h1a1 1 0 0 0 1-1V7a1 1 0 0 0-1-1Z"/>
</svg>
</button>
<button
class="cursor-pointer text-blue-500 hover:text-blue-800 focus:ring-3 focus:ring-blue-800 focus:text-blue-800 rounded-lg p-1 text-center disabled:text-gray-300"
:disabled="currentStep === maxSteps"
@click="currentStep++"
>
<!-- Step Forward Icon -->
<svg class="w-6 h-6" aria-hidden="true" xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="none" viewBox="0 0 24 24">
<path stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="m10 16 4-4-4-4"/>
</svg>
</button>
<button
class="cursor-pointer text-blue-500 hover:text-blue-800 focus:ring-3 focus:ring-blue-800 focus:text-blue-800 rounded-lg p-1 text-center disabled:text-gray-300"
:disabled="currentStep === maxSteps"
@click="currentStep = maxSteps"
>
<!-- Fast Forward Icon -->
<svg class="w-6 h-6" aria-hidden="true" xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="none" viewBox="0 0 24 24">
<path stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="m7 16 4-4-4-4m6 8 4-4-4-4"/>
</svg>
</button>
</div>
</div>
<div class="col-span-3 flex flex-col justify-evenly overflow-y-auto h-9/10">
{% for agent_step in step.agent_steps.all|dictsortreversed:"reward" %}
<div class="border border-gray-900 rounded-lg m-1 p-1">
<p class="flex justify-evenly">
<p><strong>Agent:</strong> {{ agent_step.agent }}</p>
<p><strong>Action:</strong> {{ agent_step.action_string }}</p>
<p><strong>Reward:</strong> {{ agent_step.reward|floatformat }}</p>
</p>
</div>
{% empty %}
<p>End of episode</p>
{% endfor %}
</div>
{% empty %}
<p>End of episode</p>
</div>
{% endfor %}
<div class="bg-gray-100 rounded-lg">
<ol class="items-center flex p-2">
<div class="flex w-5 bg-gray-700 h-6 ml-2"></div>
{% for ts in timeline_steps %}
<li x-show="{{ ts.total_rewards }} > 0" class="relative w-full flex even:flex-col odd:flex-col-reverse even:mb-5 odd:mt-5">
<p class="text-sm text-gray-900 text-center">
{{ ts.total_rewards|floatformat }}
</p>
<div class="flex items-center">
<div class="flex w-full bg-gray-700 h-0.5"></div>
<button
class="cursor-pointer z-10 flex items-center justify-center w-2 h-2 rounded-full ring-4 ring-blue-700 shrink-0 text-blue-500 bg-blue-500 focus:ring-yellow-500"
@click="currentStep = {{ ts.number }}"
>
</button>
<div class="flex w-full bg-gray-700 h-0.5"></div>
</div>
</li>
{% endfor %}
<div class="flex w-5 bg-gray-700 h-6 mr-2"></div>
</ol>
</div>
</div>
<div class="col-span-6 border border-dashed border-gray-400 rounded-lg m-1 content-center">
<p class="text-center text-xl font-bold text-gray-400">PLACEHOLDER</p>
</div>
</div>
{% endfor %}
<div>
<input
type="range"
id="stepsRange"
min="0"
max="{{ steps|length|add:'-1' }}"
value="0"
class="w-full"
x-model.number="currentStep"
>
</div>
</div>
{% endblock %}
58 changes: 54 additions & 4 deletions mixtape/core/views.py
@@ -1,16 +1,66 @@
from collections import defaultdict
from itertools import accumulate
from typing import Any

from django.db.models import Count, Sum
from django.http import HttpRequest, HttpResponse
from django.shortcuts import get_object_or_404, render

from mixtape.core.models.agent_step import AgentStep
from mixtape.core.models.episode import Episode
from mixtape.core.models.step import Step
from mixtape.environments.mappings import action_maps


def insights(request: HttpRequest, episode_pk: int) -> HttpResponse:
    episode = get_object_or_404(Episode, pk=episode_pk)
    steps = Step.objects.prefetch_related('agent_steps').filter(episode=episode_pk)
    return render(request, 'core/insights.html', {'episode': episode, 'steps': steps})
    episode = get_object_or_404(
        Episode.objects.select_related('inference_request__checkpoint__training_request'),
        pk=episode_pk,
    )
    steps = Step.objects.prefetch_related('agent_steps').filter(episode=episode)
    agent_steps = AgentStep.objects.filter(step__episode_id=episode_pk)
    agent_steps_aggregation = agent_steps.annotate(
        total_rewards=Sum('reward'),
        reward_frequency=Count('reward'),
        action_frequency=Count('action'),
    ).order_by('action', 'reward')

    env_name = episode.inference_request.checkpoint.training_request.environment
    plot_data: dict[str, Any] = {
        # dict mapping agent (str) to action (str) to total reward (float)
        'action_v_reward': defaultdict(lambda: defaultdict(float)),
        # all reward values received over the episode (list of floats)
        'reward_histogram': [a.reward for a in agent_steps],
        # dict mapping action (str) to frequency of action (int)
        'action_v_frequency': defaultdict(int),
    }

    action_map = action_maps.get(env_name, {})
    for entry in agent_steps_aggregation:
        action = action_map.get(int(entry.action), f'{entry.action}')
        plot_data['action_v_reward'][entry.agent][action] += entry.total_rewards
        plot_data['action_v_frequency'][action] += entry.action_frequency

    key_steps = steps.annotate(total_rewards=Sum('agent_steps__reward', default=0)).order_by(
Member

Could this be combined with the original steps definition, so we don't need to execute 2 separate queries when steps and key_steps are evaluated? The extra .annotate shouldn't impact the use of steps, but maybe you don't want steps to have a different ordering by number (although that seems like a sensible ordering)?

Member

However, if you think it's better to keep these logically separated to make future refactoring easier, that's a reasonable price to pay for one more query.

Collaborator Author

As you mentioned, these have remained separate queries for the ease of refactoring. This "any reward was received" as a metric for determining points of interest should be dropped altogether in the future in favor of more advanced insights once they've been added.

        'number'
    )
    plot_data['rewards_over_time'] = list(accumulate(ks.total_rewards for ks in key_steps))
Member

I think you can make the database do this for you with a Window function, something like

Window(Sum('total_rewards'), order_by='number')

Collaborator Author

I was only able to get this approach to work with a sub-query (apparently you cannot pass an aggregate into a Window function) and this consistently increased both my query times as well as time to load the page so I did not make this change in this PR. Open to suggestions though if I'm maybe just missing something here...
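
To make the trade-off in this thread concrete, here is a minimal sketch of the two approaches using the standard-library `sqlite3` module rather than the Django ORM. The `agent_step` table, its columns, and the sample rows are assumptions for illustration only. SQLite is used because it ships with Python (window functions need SQLite 3.25 or newer); the project's actual database is not stated in this PR.

```python
import sqlite3
from itertools import accumulate

# Made-up (step number, reward) rows standing in for AgentStep records.
rows = [(0, 0.0), (0, 1.0), (1, 0.0), (2, 2.5), (2, 0.5), (3, 0.0)]

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE agent_step (number INTEGER, reward REAL)')
conn.executemany('INSERT INTO agent_step VALUES (?, ?)', rows)

# Option 1 (what the PR does): aggregate per step in SQL, accumulate in Python.
per_step = [
    total
    for (total,) in conn.execute(
        'SELECT SUM(reward) FROM agent_step GROUP BY number ORDER BY number'
    )
]
in_python = list(accumulate(per_step))

# Option 2: aggregate in a subquery, then run a window function over the result.
in_sql = [
    running
    for (running,) in conn.execute(
        '''
        SELECT SUM(total_rewards) OVER (ORDER BY number)
        FROM (
            SELECT number, SUM(reward) AS total_rewards
            FROM agent_step
            GROUP BY number
        )
        ORDER BY number
        '''
    )
]

print(in_python)
print(in_sql)
```

Both lists hold the same running totals. Which option is faster depends on the database and the data volume, so the timings reported above would need to be measured on the real schema.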

    timeline_steps = key_steps.filter(total_rewards__gt=0)

    return render(
        request,
        'core/insights.html',
        {
            'episode': episode,
            'steps': steps,
            'plot_data': plot_data,
            'timeline_steps': timeline_steps,
        },
    )


def home_page(request: HttpRequest) -> HttpResponse:
    episodes = Episode.objects.all()
    episodes = Episode.objects.select_related(
        'inference_request__checkpoint__training_request'
    ).all()
    return render(request, 'core/home.html', {'episodes': episodes})
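
The aggregation loop in `insights` is easier to follow with concrete data. The sketch below reproduces its nested-`defaultdict` pattern without Django. `Row`, the sample rows, and `action_map` are hypothetical stand-ins for the annotated `AgentStep` queryset and the environment's action mapping.

```python
import json
from collections import defaultdict
from typing import Any, NamedTuple


class Row(NamedTuple):
    """Hypothetical stand-in for one annotated AgentStep row."""

    agent: str
    action: float
    total_rewards: float
    action_frequency: int


# Illustrative data and mapping; not taken from the project.
rows = [
    Row('agent_0', 0.0, 1.0, 1),
    Row('agent_0', 1.0, 0.5, 1),
    Row('agent_1', 1.0, 2.0, 1),
    Row('agent_1', 3.0, 0.0, 1),
]
action_map = {0: 'noop', 1: 'left'}

plot_data: dict[str, Any] = {
    # agent (str) -> action label (str) -> total reward (float)
    'action_v_reward': defaultdict(lambda: defaultdict(float)),
    # action label (str) -> number of times the action was taken (int)
    'action_v_frequency': defaultdict(int),
}

for row in rows:
    # Unmapped actions fall back to their raw value as a string, e.g. '3.0'.
    action = action_map.get(int(row.action), f'{row.action}')
    plot_data['action_v_reward'][row.agent][action] += row.total_rewards
    plot_data['action_v_frequency'][action] += row.action_frequency

# defaultdict is a dict subclass, so json serializes it like a plain dict.
print(json.dumps(plot_data, indent=2))
```

The resulting shape is what the template iterates over with `Object.entries(plotData.action_v_reward)`: one bar trace per agent, keyed by action label.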