Skip to content

Commit ce20ad7

Browse files
committed
svg: Add ts.svg() support for rendering to SVG
This is substantially more straightforward than using opengl + mpeg + video tag.
1 parent 8d89190 commit ce20ad7

File tree

3 files changed

+275
-1
lines changed

3 files changed

+275
-1
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Versioning](https://semver.org/spec/v2.0.0.html) when it reaches version 1.0.
99
### Added
1010
- All geometry classes implement `len()`, which returns the number of steps in the geometry.
1111
- Convenience methods `Transform.transform_vec` and `Transform.transform_point`
12+
- `ts.svg()`: Render geometries to an animated SVG (no dependencies required(!)).
1213
### Changed
1314
### Removed
1415

tomosipo/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from_astra,
3030
to_astra,
3131
)
32-
32+
from .svg import svg
3333

3434
from . import phantom
3535

tomosipo/svg.py

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""SVG support for tomosipo
4+
5+
"""
6+
import tomosipo as ts
7+
import tomosipo.vector_calc as vc
8+
import base64
9+
import collections
10+
import numpy as np
11+
from pathlib import Path
12+
from functools import singledispatch
13+
from tomosipo.geometry.base_projection import ProjectionGeometry
14+
from tomosipo.geometry.volume import VolumeGeometry
15+
from tomosipo.geometry.volume_vec import VolumeVectorGeometry
16+
17+
###############################################################################
18+
# Showing, saving SVG #
19+
###############################################################################
20+
21+
22+
class SVG:
23+
def __init__(self, svg_str, height=200, width=320):
24+
super().__init__()
25+
self.svg_str = svg_str
26+
self.height = height
27+
self.width = width
28+
29+
def _repr_html_(self):
30+
svg64 = base64.encodebytes(self.svg_str.encode()).decode("ascii")
31+
# https://vecta.io/blog/best-way-to-embed-svg
32+
# TODO: use "<object />"
33+
TAG = r"""<object height="{height}" width="{width}" data="data:image/svg+xml;base64,{image}" />"""
34+
35+
return TAG.format(height=self.height, width=self.width, image=svg64)
36+
37+
def save(self, path):
38+
"""Save svg to disk"""
39+
path = Path(path)
40+
path.write_text(self.svg_str)
41+
42+
def __str__(self):
43+
return self.svg_str
44+
45+
46+
###############################################################################
47+
# Data structures #
48+
###############################################################################
49+
LineItem = collections.namedtuple("LineItem", ["pos", "width", "color"])
50+
51+
52+
def line_item(pos, width=1, color=(0.0, 0.0, 0.0, 1.0)):
53+
"""Create a line item
54+
55+
:param pos: (N,3) array of floats specifying point locations.
56+
:param width: width of the line in UNITS TODO
57+
:param color: tuple of floats in [0.0-1.0] specifying a single color for the entire item in rgba format.
58+
:returns: a LineItem
59+
:rtype: LineItem
60+
61+
"""
62+
63+
pos = np.array(pos, dtype=np.float64, copy=False)
64+
assert pos.ndim == 2
65+
assert pos.shape[1] == 3
66+
67+
assert len(color) == 4
68+
69+
return LineItem(pos=pos, width=float(width), color=tuple(color))
70+
71+
72+
###############################################################################
73+
# From geometries to LineItems #
74+
###############################################################################
75+
@singledispatch
76+
def to_line_items(g, i):
77+
pass
78+
79+
80+
@to_line_items.register(VolumeVectorGeometry)
81+
def vol_vec_to_line_items(vg, i):
82+
N = len(vg)
83+
84+
corners = vg.corners[i % N]
85+
indices = [
86+
(0, 1), # <- Bottom square
87+
(2, 3), # |
88+
(4, 5), # |
89+
(6, 7), # |
90+
(0, 2), # <- Top square
91+
(1, 3), # |
92+
(4, 6), # |
93+
(5, 7), # |
94+
(0, 4), # <- Connecting (vertical) lines
95+
(1, 5), # |
96+
(2, 6), # |
97+
(3, 7), # |
98+
]
99+
100+
return [line_item(pos=np.stack((corners[a], corners[b]))) for a, b in indices]
101+
102+
103+
@to_line_items.register(VolumeGeometry)
104+
def vol_to_line_items(g, i):
105+
return to_line_items(g.to_vec(), i)
106+
107+
108+
@to_line_items.register(ProjectionGeometry)
109+
def pg_to_line_items(pg, i):
110+
pg = pg.to_vec()
111+
i = i % len(pg)
112+
113+
det_curve = line_item(pos=pg.det_pos, width=0.2)
114+
115+
# detector plane
116+
corners = pg.corners[i]
117+
det_plane_indices = [0, 1, 3, 2, 0]
118+
det_plane = line_item(pos=corners[det_plane_indices])
119+
120+
if pg.is_cone:
121+
src_pos = pg.src_pos[i]
122+
rays = [line_item(pos=np.stack((src_pos, c)), width=0.2) for c in corners]
123+
src_curve = line_item(pos=pg.src_pos, width=0.2)
124+
return [src_curve, det_curve, det_plane, *rays]
125+
if pg.is_parallel:
126+
return [det_curve, det_plane]
127+
128+
129+
###############################################################################
130+
# Projection: From 3D LineItems to 2D LineItems #
131+
###############################################################################
132+
133+
134+
def default_camera(height, width, angle=1 / 2.7):
135+
# default camera:
136+
R0 = ts.rotate(pos=0, axis=(1, 0, 0), deg=70)
137+
R1 = ts.rotate(pos=0, axis=(0, 0, 1), deg=-25)
138+
size = (1, 1 * width / height)
139+
140+
good_cone = (
141+
R0 * R1 * ts.cone(cone_angle=angle, shape=(height, width), size=size).to_vec()
142+
)
143+
camera = ts.translate(good_cone.src_pos * 10 * angle) * good_cone
144+
145+
return camera
146+
147+
148+
def project_pos(camera, pos):
149+
assert len(camera) == 1, "Camera must have single viewpoint"
150+
151+
pos = vc.to_vec(ts.utils.to_pos(pos))
152+
# Project onto camera coordinates:
153+
pos = np.array([camera.project_point(p)[0] for p in pos])
154+
# Move (0, 0) from detector center to detector lower-left corner
155+
pos += np.array(camera.det_shape)[None, :] / 2
156+
157+
# Add back z-dimension (equal to zero)
158+
pos = np.concatenate(
159+
[np.zeros((len(pos), 1)), pos],
160+
axis=1,
161+
)
162+
return pos
163+
164+
165+
def project_lines(camera, line_items):
166+
return [
167+
line_item(
168+
pos=project_pos(camera, l.pos),
169+
width=l.width,
170+
color=l.color,
171+
)
172+
for l in line_items
173+
]
174+
175+
176+
###############################################################################
177+
# From LineItems to SVG text #
178+
###############################################################################
179+
180+
181+
def text_svg_frame(
182+
line_items, frame_begin, frame_end, total_duration, height=100, width=100
183+
):
184+
polylines = []
185+
for l in line_items:
186+
# extract y and x coordinates (from zyx to xy)
187+
pos2d = l.pos[:, (2, 1)]
188+
# Correct y to move from bottom to top instead of vice versa
189+
pos2d[:, 1] = height - pos2d[:, 1]
190+
points = " ".join(f"{p:0.2f}" for p in pos2d.ravel())
191+
polyline = f'<polyline points="{points}" stroke="black" fill="none" stroke-width="{l.width:0.2f}"/>'
192+
polylines.append(polyline)
193+
lines_str = "\n".join(polylines)
194+
195+
return f"""<g>
196+
{lines_str}
197+
<animate
198+
attributeName="display"
199+
values="none;inline;none;none"
200+
keyTimes="0;{frame_begin};{frame_end};1"
201+
dur="{total_duration}s"
202+
begin="0s"
203+
repeatCount="indefinite" />
204+
</g>
205+
"""
206+
207+
208+
def text_svg_animation(line_items, duration=10, height=100, width=100):
209+
frame_duration = 1 / len(line_items)
210+
211+
const_opts = dict(total_duration=duration, height=height, width=width)
212+
213+
frames_str = "\n".join(
214+
text_svg_frame(ls, i * frame_duration, (i + 1) * frame_duration, **const_opts)
215+
for i, ls in enumerate(line_items)
216+
)
217+
# We format this string with `replace` because it is littered with {}'s.
218+
JS = r"""
219+
<script type="text/ecmascript"><![CDATA[
220+
function mouse_move(evt) {
221+
var root = document.querySelector('svg');
222+
if (evt.buttons > 0) {
223+
var root = document.querySelector('svg');
224+
root.pauseAnimations();
225+
var new_time = DURATION * evt.clientX / evt.target.getBBox().width;
226+
root.setCurrentTime(new_time);
227+
}
228+
}
229+
function on_click(evt) {
230+
console.log(document.getElementById('frame3'));
231+
var root = document.querySelector('svg');
232+
root.animationsPaused() ? root.unpauseAnimations() : root.pauseAnimations();
233+
}
234+
]]>
235+
</script>
236+
237+
""".replace(
238+
"DURATION", str(duration)
239+
)
240+
241+
return f"""<svg xmlns="http://www.w3.org/2000/svg" height="{height}" width="{width}">
242+
<title>Click to pause/unpause, press and hold to scroll through animation</title>
243+
{frames_str}
244+
{JS}
245+
<rect x="1" y="1" rx="5" ry="5" onmousemove='mouse_move(evt)' onclick='on_click(evt)' width="{width - 1}" height="{height - 1}" stroke="gray" fill="transparent" stroke-width="1"/>
246+
</svg>
247+
"""
248+
249+
250+
###############################################################################
251+
# Tying it all together: from *geometries => svg #
252+
###############################################################################
253+
def svg(*geoms, height=200, width=320, duration=3, camera=None):
254+
num_steps = max(map(len, geoms))
255+
256+
# default camera:
257+
if camera is None:
258+
c = default_camera(height, width)
259+
260+
def geoms2frame(i):
261+
# For each geometry, generate a list of line items
262+
frames_list = (project_lines(c, to_line_items(g, i)) for g in geoms)
263+
# flatten the list
264+
return [f for fs in frames_list for f in fs]
265+
266+
svg_text = text_svg_animation(
267+
[geoms2frame(i) for i in range(num_steps)],
268+
duration=duration,
269+
height=height,
270+
width=width,
271+
)
272+
273+
return SVG(svg_text, height=height, width=width)

0 commit comments

Comments
 (0)