# Lens refraction demo

Shows the refraction of light rays in a biconvex lens.
Use the slider to change the radius of curvature of the lens surfaces.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Arc, ConnectionPatch
from ipywidgets import widgets

In [2]:
def find_circle_line_intersection_point(p1, p2, c, r, solution=1):
    """Find the coordinates of an intersection point between a circle and a line

    Parameters:
        p1 (Point): first point on the line
        p2 (Point): second point on the line; must differ from p1
        c (Point): center of the circle
        r (float): radius of the circle
        solution (1 or 2): which root of the quadratic equation to return.
            Solution 1 takes the "+" branch of the MathWorld formula and
            solution 2 the "-" branch. For a non-horizontal line, solution 1
            is the intersection with the larger y. For a horizontal line, it is
            the intersection further along the p1 -> p2 direction.

    Returns:
        Point: intersection point of the circle and line, or
        Point(nan, nan) if the line does not intersect the circle.

    Raises:
        ValueError: if ``solution`` is not 1 or 2, or if p1 and p2 coincide
            (two equal points do not define a line).

    See http://mathworld.wolfram.com/Circle-LineIntersection.html
    """
    # Validate up front so a bad argument is reported even when the line
    # misses the circle.
    if solution not in (1, 2):
        raise ValueError('solution must be 1 or 2')

    # Work in a coordinate system where the circle center is at (0, 0).
    x1, y1 = p1.x - c.x, p1.y - c.y
    x2, y2 = p2.x - c.x, p2.y - c.y
    dx = x2 - x1
    dy = y2 - y1
    dr2 = dx**2 + dy**2
    if dr2 == 0:
        # The formula divides by dr2; identical points give no direction.
        raise ValueError('p1 and p2 must be distinct points')
    D = x1*y2 - x2*y1

    discriminant = r**2 * dr2 - D**2
    if discriminant < 0:
        return Point(np.nan, np.nan)  # no intersection

    sqrt_discriminant = np.sqrt(discriminant)

    # MathWorld's sgn*(dy): like sign(), but maps 0 to +1 so that a
    # horizontal line still yields two distinct x solutions.
    dy_sign = -1 if dy < 0 else 1

    if solution == 1:
        x = (D*dy + dy_sign*dx*sqrt_discriminant) / dr2
        y = (-D*dx + abs(dy)*sqrt_discriminant) / dr2
    else:
        x = (D*dy - dy_sign*dx*sqrt_discriminant) / dr2
        y = (-D*dx - abs(dy)*sqrt_discriminant) / dr2

    # Shift back to the original coordinate system.
    return Point(x + c.x, y + c.y)


def find_refraction_angle(incidence_angle, n1, n2):
    """Return the refraction angle given by Snell's law.

    Solves ``n1 * sin(incidence_angle) = n2 * sin(refraction_angle)`` for the
    refraction angle.

    Parameters:
        incidence_angle (float or array): angle of incidence in radians
        n1 (float): refractive index of the medium the ray comes from
        n2 (float): refractive index of the medium the ray enters

    Returns:
        float or array: refraction angle in radians. If
        ``n1/n2 * sin(incidence_angle)`` exceeds 1 in magnitude (total
        internal reflection), ``np.arcsin`` yields NaN.
    """
    sin_refracted = n1/n2 * np.sin(incidence_angle)
    return np.arcsin(sin_refracted)


class Lens:
    """Symmetric biconvex lens built from two circular surfaces of equal radius.

    The optical axis is the x axis. The left lens surface is an arc of the
    circle centered at ``c1``, and the right surface an arc of the circle
    centered at ``c2`` (the mirror image of ``c1`` through the origin).

    Parameters:
        height (float): full height (aperture) of the lens along y
        min_thickness (float): minimum thickness along the optical axis. When
            arcs meeting at the lens edge would be thinner than this, the
            curvature centers are moved towards the origin and the lens gets
            flat top and bottom edges.
        r (float): radius of curvature of both surfaces (clamped to >= height/2)
        n (float): refractive index of the lens (the surroundings use n=1)
        show_debug_info (bool): print incidence/refraction angles while tracing
    """

    def __init__(self, height, min_thickness, r, n=1.5, show_debug_info=False):
        self.height = height
        self.min_thickness = min_thickness
        self.n = n
        # Honor the caller's choice (it used to be forced to False).
        self.show_debug_info = show_debug_info

        self.r = None
        self.c1 = None
        self.c2 = None
        self.set_r(r)  # also computes c1 and c2

        self.rays = []
        self.init_rays()

    def init_rays(self, x0=-4, max_x=4, max_y=1, ray_count=3):
        """Initialize rays

        Parameters:
            x0 (float): starting x coordinate of rays
            max_x (float): maximum x coordinate of rays
            max_y (float): maximum starting y coordinate of rays (minimum is -max_y)
            ray_count (int): number of rays
        """
        # Rays start parallel to the optical axis, evenly spaced in y.
        self.rays = [Ray(x0, y0, inclination_angle=0, max_x=max_x)
                     for y0 in np.linspace(-max_y, max_y, ray_count)]

    def set_r(self, r):
        """Set radius of curvature of lens surfaces and update their centers.

        Radii not greater than height/2 are replaced by height/2, because a
        circle cannot span a chord longer than its diameter.
        """
        self.r = r if r > self.height / 2 else self.height / 2
        self.update_centers()

    def get_x(self):
        """Get x coordinate of the center of curvature of the lens surface
        for which the two arcs meet exactly at the lens edge (y = ±height/2).
        """
        return self.r * np.cos(self.get_theta())

    def get_max_x(self):
        """Get maximum x coordinate of the center of curvature of the lens surface
        taking into account the minimum thickness parameter.
        """
        # The central thickness is 2 * (r - x), so x <= r - min_thickness/2.
        return self.r - self.min_thickness / 2.0

    def update_centers(self):
        """Update the center of curvature points of the lens surfaces.
        """
        x = min(self.get_x(), self.get_max_x())
        self.c1 = Point(x, 0)
        self.c2 = Point(-x, 0)

    def get_theta(self):
        """Get angle between the main optical axis and the line
        connecting the lens edge and center of curvature.
        """
        return np.arcsin(self.height / (2 * self.r))

    def get_patches(self):
        """Get lens contours as matplotlib patches for plotting.

        Returns two arcs (the lens surfaces) and, when the minimum thickness
        forces the centers inwards, two straight lines for the flat edges.
        """
        x = self.get_x()
        max_x = self.get_max_x()
        theta_degrees = np.degrees(self.get_theta())
        diameter = 2 * self.r
        patches = [
            # Left surface: arc on the -x side of c1 (rotated by 180 degrees).
            Arc(xy=(self.c1.x, self.c1.y), width=diameter, height=diameter,
                angle=180, theta1=-theta_degrees, theta2=theta_degrees),
            # Right surface: arc on the +x side of c2.
            Arc(xy=(self.c2.x, self.c2.y), width=diameter, height=diameter,
                angle=0, theta1=-theta_degrees, theta2=theta_degrees),
        ]
        if x > max_x:
            # Centers were clamped, so the arcs end at x = ±dx instead of
            # meeting at x = 0; join their ends with flat edges.
            h_2 = self.height / 2
            dx = max_x - x
            patches.append(ConnectionPatch(xyA=(-dx, h_2), xyB=(dx, h_2), coordsA='data'))
            patches.append(ConnectionPatch(xyA=(-dx, -h_2), xyB=(dx, -h_2), coordsA='data'))
        return patches

    def trace_ray(self, ray):
        """Trace ``ray`` through both lens surfaces, rebuilding its vertices.

        The ray is treated as starting parallel to the optical axis
        (``ray.inclination_angle`` is not read). Afterwards ``ray.xs``/``ray.ys``
        hold the start point, then the entry and exit points on the lens
        (only if the ray hits it), and finally the end point at ``ray.max_x``.
        """
        # Reset ray points, keeping only the starting point.
        ray.xs = [ray.xs[0]]
        ray.ys = [ray.ys[0]]
        x0, y0 = ray.xs[0], ray.ys[0]

        # Find intersection point with first lens surface. For this horizontal
        # line, solution 2 is the intersection with the smaller x.
        hit_point = find_circle_line_intersection_point(
            p1=Point(x0, y0), p2=Point(x0 + 1, y0), c=self.c1, r=self.r, solution=2)

        # Ray does not hit the lens: the line misses the circle entirely (NaN
        # result; NaN never compares equal, so test it with isnan) or it hits
        # the circle outside the lens aperture.
        if np.isnan(hit_point.x) or abs(hit_point.y) > self.height / 2:
            ray.xs.append(ray.max_x)
            ray.ys.append(y0)
            return

        # Ray hits the lens -- add hit point
        ray.xs.append(hit_point.x)
        ray.ys.append(hit_point.y)

        # Calculate refraction on first lens surface (air -> lens).
        inclination_angle = 0.0
        hit_point_polar_angle = np.arcsin(hit_point.y / self.r)
        incidence_angle = hit_point_polar_angle + inclination_angle
        refraction_angle = find_refraction_angle(incidence_angle, n1=1, n2=self.n)
        inclination_angle += refraction_angle - incidence_angle
        if self.show_debug_info:
            print('incidence angle 1', np.degrees(incidence_angle))
            print('refraction angle 1', np.degrees(refraction_angle))

        # Find intersection point with second lens surface. The solution is
        # chosen so that the intersection with the larger x is returned.
        line_point2 = hit_point + Point(1., np.tan(inclination_angle))
        solution = 2 if inclination_angle < 0 else 1
        hit_point = find_circle_line_intersection_point(
            p1=hit_point, p2=line_point2, c=self.c2, r=self.r, solution=solution)
        ray.xs.append(hit_point.x)
        ray.ys.append(hit_point.y)

        # Calculate refraction on second lens surface (lens -> air). This may
        # be NaN when the arcsin argument exceeds 1 (total internal reflection).
        hit_point_polar_angle = np.arcsin(hit_point.y / self.r)
        incidence_angle = hit_point_polar_angle - inclination_angle
        refraction_angle = find_refraction_angle(incidence_angle, n1=self.n, n2=1)
        inclination_angle -= refraction_angle - incidence_angle
        if self.show_debug_info:
            print('incidence angle 2', np.degrees(incidence_angle))
            print('refraction angle 2', np.degrees(refraction_angle))
            print()

        # Extend the outgoing ray in a straight line up to x = ray.max_x.
        ray.xs.append(ray.max_x)
        ray.ys.append(hit_point.y + np.tan(inclination_angle) * (ray.max_x - hit_point.x))

    def trace_rays(self):
        """Trace every ray in ``self.rays`` through the lens."""
        for ray in self.rays:
            self.trace_ray(ray)

    def draw(self, ax=None):
        """Draw the lens outline and the traced rays.

        Any patches and lines already on ``ax`` are removed first, so the
        method can be called repeatedly to redraw on the same axes.

        Parameters:
            ax (matplotlib Axes, optional): axes to draw on; a new figure with
                equal aspect ratio is created if omitted.
        """
        if ax is None:
            _, ax = plt.subplots()
            ax.set_aspect('equal', 'box')

        # Draw lens. Assigning to ax.patches is not supported by recent
        # matplotlib (it is a read-only view), so remove artists one by one.
        for artist in list(ax.patches):
            artist.remove()
        for p in self.get_patches():
            ax.add_patch(p)

        # Draw rays
        self.trace_rays()
        for artist in list(ax.lines):
            artist.remove()
        for ray in self.rays:
            ray.draw(ax=ax)
            
    
class Ray:
    """A light ray represented as a polyline of (x, y) vertices.

    Parameters:
        x0 (float): starting x coordinate
        y0 (float): starting y coordinate
        inclination_angle (float): initial inclination, stored on the instance
        max_x (float): x coordinate where the traced ray should end
    """

    def __init__(self, x0, y0, inclination_angle, max_x):
        # Vertex lists start with just the starting point; a tracer appends
        # further vertices.
        self.xs, self.ys = [x0], [y0]
        self.inclination_angle = inclination_angle
        self.max_x = max_x

    def draw(self, ax):
        """Plot the ray's polyline in red on the matplotlib axes ``ax``."""
        xs, ys = self.xs, self.ys
        ax.plot(xs, ys, color='r')

        
class Point:
    """Mutable 2D point / vector supporting ``+``, ``-`` and unary ``-``.

    Instances unpack as ``x, y = point`` and convert with ``tuple(point)``.
    Adding or subtracting a non-Point returns NotImplemented, so Python raises
    TypeError.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __add__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return Point(self.x + other.x, self.y + other.y)

    def __sub__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return Point(self.x - other.x, self.y - other.y)

    def __neg__(self):
        return Point(-self.x, -self.y)

    def __repr__(self):
        return "({}, {})".format(self.x, self.y)

    def __iter__(self):
        # Yield exactly (x, y) in that order. Iterating over dir(self) would
        # also pick up any other public attribute or method, changing the
        # length and order of the unpacked values.
        yield self.x
        yield self.y

In [3]:
%matplotlib widget

In [4]:
# Lens with aperture 3.0; its rays start at x=-2 and end at x=6, with
# 5 starting heights evenly spaced in [-1, 1].
lens = Lens(height=3.0, min_thickness=0.5, r=2)
lens.init_rays(x0=-2, max_x=6, max_y=1, ray_count=5)

# A single figure reused by every slider update; the x limits match the
# rays' start (x0) and end (max_x) coordinates.
fig, ax = plt.subplots()
ax.set_aspect('equal', 'box')
ax.set_xlim(-2, 6)
ymax=3
ax.set_ylim(-ymax, ymax)
plt.axis('off')

@widgets.interact(log_r=(0, 6, 0.05))
def update(log_r = 1.5):
    """Redraw the lens for the slider value ``log_r`` (slider range 0..6, step 0.05)."""
    # Slider is logarithmic to cover the r range better
    # r = exp(log_r) + 0.5 spans roughly 1.5 (= height/2) to 404.
    r = np.exp(log_r) + 0.5
    lens.set_r(r)
    lens.draw(ax)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

interactive(children=(FloatSlider(value=1.5, description='log_r', max=6.0, step=0.05), Output()), _dom_classes…