In [1]:
import os
import sys
import torch as t
from torch import Tensor
import einops as e
from ipywidgets import interact
import plotly.express as px
from ipywidgets import interact
from pathlib import Path
from IPython.display import display
from jaxtyping import Float, Int, Bool, Shaped, jaxtyped
import typeguard

# Make sure exercises are in the path
# The exercises root is rebuilt from the current working directory: everything in
# os.getcwd() before the first occurrence of `chapter` is taken as the repo root.
# NOTE(review): if the cwd does not contain `chapter`, split() returns the whole cwd,
# so this resolves to <cwd>/chapter0_fundamentals/exercises — confirm that is intended.
chapter = r"chapter0_fundamentals"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "part1_ray_tracing"
# Appended only if absent, so re-running this cell does not grow sys.path; this is what
# makes `plotly_utils` and the `part1_ray_tracing` package below importable.
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow
from part1_ray_tracing.utils import render_lines_with_plotly, setup_widget_fig_ray, setup_widget_fig_triangle
import part1_ray_tracing.tests as tests

# True when this code runs as the top-level module (e.g. directly in a notebook
# kernel or as `python file.py`), False when it is imported from elsewhere.
MAIN = __name__ == "__main__"

In [2]:
def make_rays_1d(num_pixels: int, y_limit: float) -> t.Tensor:
    '''
    num_pixels: The number of pixels in the y dimension. Since there is one ray per pixel, this is also the number of rays.
    y_limit: At x=1, the rays should extend from -y_limit to +y_limit, inclusive of both endpoints.

    Returns: shape (num_pixels, num_points=2, num_dim=3) where the num_points dimension contains (origin, direction) and the num_dim dimension contains xyz.
    Every origin is (0, 0, 0); every direction is (1, y, 0) with y evenly spaced over [-y_limit, y_limit].
    With num_pixels == 1 the single ray points at y = -y_limit (t.linspace returns only its start point);
    with num_pixels == 0 an empty (0, 2, 3) tensor is returned.

    Example of make_rays_1d(9, 1.0): [
        [[0, 0, 0], [1, -1.0, 0]],
        [[0, 0, 0], [1, -0.75, 0]],
        [[0, 0, 0], [1, -0.5, 0]],
        ...
        [[0, 0, 0], [1, 0.75, 0]],
        [[0, 0, 0], [1, 1, 0]],
    ]
    '''
    # Origins (index 0 of the points axis) stay at zero; only directions are filled in.
    rays = t.zeros(num_pixels, 2, 3)
    # x = 1 for every direction, so the y component is exactly the ray's height at x = 1.
    rays[:, 1, 0] = 1.0
    # linspace replaces the per-element Python loop and, unlike the manual
    # `x / (num_pixels - 1)` formula, does not divide by zero when num_pixels == 1.
    rays[:, 1, 1] = t.linspace(-y_limit, y_limit, num_pixels)
    return rays
    

# Nine rays from the origin whose heights at x = 1 span [-10, 10], passed to the
# course helper render_lines_with_plotly for drawing.
rays1d = make_rays_1d(num_pixels=9, y_limit=10.0)
fig = render_lines_with_plotly(rays1d)

In [3]:
fig = setup_widget_fig_ray()
display(fig)


@interact
def response(seed=(0, 10, 1), v=(-2.0, 2.0, 0.01)):
    '''
    Slider callback: draw the line through two random points L_1, L_2 (chosen by
    seeding torch's RNG with `seed`) and mark the point L_1 + v * (L_2 - L_1).

    The (min, max, step) tuples are ipywidgets slider specifications.
    Updates traces 0 (the line), 1 (the L_1 / L_2 markers) and 2 (the point at v) of `fig`.
    '''
    t.manual_seed(seed)
    L_1, L_2 = t.rand(2, 2)

    def point_at(s):
        # Parametric point on the line: s=0 gives L_1, s=1 gives L_2.
        return L_1 + s * (L_2 - L_1)

    # Draw the line well past both defining points (s from -2 to 2).
    line_x, line_y = zip(point_at(-2), point_at(2))
    marker = point_at(v)
    with fig.batch_update():
        fig.data[0].update({"x": line_x, "y": line_y})
        fig.data[1].update({"x": [L_1[0], L_2[0]], "y": [L_1[1], L_2[1]]})
        fig.data[2].update({"x": [marker[0]], "y": [marker[1]]})

FigureWidget({
    'data': [{'type': 'scatter', 'uid': 'dee83445-ca57-42f3-ae60-7ca0bb21c982', 'x': [], 'y': []},
             {'marker': {'size': 12},
              'mode': 'markers',
              'type': 'scatter',
              'uid': '4f793a69-9a36-42fa-91f2-822d7b6f967d',
              'x': [],
              'y': []},
             {'marker': {'size': 12, 'symbol': 'x'},
              'mode': 'markers',
              'type': 'scatter',
              'uid': 'ef4f2c1c-a771-4e1b-9bef-e55f9b02482a',
              'x': [],
              'y': []}],
    'layout': {'height': 500,
               'showlegend': False,
               'template': '...',
               'width': 600,
               'xaxis': {'range': [-1.5, 2.5]},
               'yaxis': {'range': [-1.5, 2.5]}}
})

interactive(children=(IntSlider(value=5, description='seed', max=10), FloatSlider(value=0.0, description='v', …

In [4]:
# Three segments in the z = 0 plane, each given as (L_1, L_2) endpoints,
# drawn together with the rays from the previous cell.
segments = t.tensor(
    [
        [[1.0, -12.0, 0.0], [1.0, -6.0, 0.0]],
        [[0.5, 0.1, 0.0], [0.5, 1.15, 0.0]],
        [[2.0, 12.0, 0.0], [2.0, 21.0, 0.0]],
    ]
)
fig = render_lines_with_plotly(t.cat([segments, rays1d], dim=0))

In [5]:
# Runtime shape checking can be switched on with
# @jaxtyped(typechecker=typeguard.typechecked) — jaxtyping >= 0.2.24 requires the keyword form.
def intersect_ray_1d(ray: Float[Tensor, "2 3"], segment: Float[Tensor, "2 3"]) -> bool:
    '''
    ray: shape (n_points=2, n_dim=3)  # O, D points
    segment: shape (n_points=2, n_dim=3)  # L_1, L_2 points

    Return True if the ray intersects the segment.

    Only x and y are used. Solves O + u*D = L_1 + v*(L_2 - L_1) for (u, v); the ray hits
    the segment when u >= 0 (in front of the origin) and 0 <= v <= 1 (between the endpoints).
    Parallel ray/segment pairs (singular system) are reported as no intersection.
    '''
    # Drop the z coordinate: the problem is 2D.
    O, D = ray[:, :2]
    L_1, L_2 = segment[:, :2]
    # Rearranged system: u*D + v*(L_1 - L_2) = L_1 - O, i.e. [D | L_1 - L_2] @ [u, v] = L_1 - O.
    mat = t.stack((D, L_1 - L_2), dim=1)
    try:
        u, v = t.linalg.solve(mat, L_1 - O).tolist()
    except t.linalg.LinAlgError:
        # Only a singular matrix (parallel lines) means "no intersection"; any other
        # RuntimeError (bad dtype, bad shape) now propagates instead of returning False.
        return False
    return u >= 0 and 0 <= v <= 1


# Run both the general and the parallel/special-case checks against the single-ray solver.
for check in (tests.test_intersect_ray_1d, tests.test_intersect_ray_1d_special_case):
    check(intersect_ray_1d)

All tests in `test_intersect_ray_1d` passed!
All tests in `test_intersect_ray_1d_special_case` passed!



As of jaxtyping version 0.2.24, jaxtyping now prefers the syntax
```
from jaxtyping import jaxtyped
# Use your favourite typechecker: usually one of the two lines below.
from typeguard import typechecked as typechecker
from beartype import beartype as typechecker

@jaxtyped(typechecker=typechecker)
def foo(...):
```
and the old double-decorator syntax
```
@jaxtyped
@typechecker
def foo(...):
```
should no longer be used. (It will continue to work as it did before, but the new approach will produce more readable error messages.)
In particular note that `typechecker` must be passed via keyword argument; the following is not valid:
```
@jaxtyped(typechecker)
def foo(...):
```




In [6]:
# Sanity-check einops.repeat: a new trailing axis `c` of size 4 should hold
# four identical copies of x.
x = t.randn(2, 3)
x_repeated = e.repeat(x, 'a b -> a b c', c=4)

assert x_repeated.shape == (2, 3, 4)
for c in range(x_repeated.shape[-1]):
    t.testing.assert_close(x, x_repeated[..., c])

In [50]:
def intersect_rays_1d(rays: Float[Tensor, "nrays 2 3"], segments: Float[Tensor, "nsegments 2 3"], eps: float = 1e-6) -> Bool[Tensor, "nrays"]:
    '''
    For each ray, return True if it intersects any segment.

    rays: (nrays, 2, 3), each row (origin O, direction D); only x and y are used.
    segments: (nsegments, 2, 3), each row (L_1, L_2) endpoints; only x and y are used.
    eps: ray/segment pairs whose 2x2 system has |det| < eps are treated as parallel
        and never count as an intersection.

    Returns a (nrays,) bool tensor.
    '''
    # Keep x, y and add singleton axes so rays vary along dim 0 and segments along dim 1;
    # broadcasting then pairs every ray with every segment without materialising copies.
    O = rays[:, 0, :2].unsqueeze(1)        # (nrays, 1, 2)
    D = rays[:, 1, :2].unsqueeze(1)        # (nrays, 1, 2)
    L_1 = segments[:, 0, :2].unsqueeze(0)  # (1, nsegments, 2)
    L_2 = segments[:, 1, :2].unsqueeze(0)  # (1, nsegments, 2)

    # Per pair: [D | L_1 - L_2] @ [u, v] = L_1 - O  (same system as intersect_ray_1d).
    D, seg_dir = t.broadcast_tensors(D, L_1 - L_2)
    mat = t.stack((D, seg_dir), dim=-1)  # (nrays, nsegments, 2, 2)
    rhs = L_1 - O                        # (nrays, nsegments, 2)

    # Swap (near-)singular matrices for the identity so the batched solve cannot fail;
    # their solutions are discarded via `parallel` below.
    parallel = t.linalg.det(mat).abs() < eps
    identity = t.eye(2, dtype=mat.dtype, device=mat.device)
    mat = t.where(parallel[..., None, None], identity, mat)

    u, v = t.linalg.solve(mat, rhs).unbind(dim=-1)
    hits = (u >= 0) & (v >= 0) & (v <= 1) & ~parallel
    return hits.any(dim=-1)


# Run both the general and the parallel/special-case checks against the batched solver.
for check in (tests.test_intersect_rays_1d, tests.test_intersect_rays_1d_special_case):
    check(intersect_rays_1d)

All tests in `test_intersect_rays_1d` passed!
All tests in `test_intersect_rays_1d_special_case` passed!
