In [1]:
import jupyter_manim
from manim import *

# Grover's Algorithm Visualized
* A graph plotting the real amplitudes of the basis kets.
* A dashed line at the mean value of the amplitudes.
* When flipping an amplitude: draw a vector from the base to the tip, fade out the amplitude bars, flip the vectors about the base, and then fade in the new amplitude bars.

In [2]:
%%manim --disable_caching GroverSAT4

class GroverSAT4(Scene):
    """Animate one Grover iteration on the four basis states of two qubits.

    A bar chart shows the real amplitude of each basis ket, a dashed yellow
    line marks the mean amplitude, and the measurement probabilities
    (amplitude squared) are printed above the bars.  The scene then plays the
    sign flip of basis state 3 (|11>) followed by the reflection of every
    amplitude about the mean.
    """

    def construct(self):
        myTemplate = TexTemplate()
        myTemplate.add_to_preamble(r'\usepackage{amsmath}')
        myTemplate.add_to_preamble(r'\usepackage{amssymb}')
        myTemplate.add_to_preamble(r'\usepackage{braket}')  # provides \ket{...}

        # Start state: equal amplitude 0.5 on each of the four basis kets
        amps = np.array([0.5, 0.5, 0.5, 0.5])
        mean = np.mean(amps)

        # Constructing the graph's skeleton; scene units per x step / per unit amplitude
        self.x_scale = 1
        self.y_scale = 2

        x_len = 4.5*self.x_scale
        y_len = 1.25*self.y_scale

        x_axis = Line([0, 0, 0], [x_len, 0, 0])
        y_axis = Line([0, -y_len, 0], [0, y_len, 0])
        axes = VGroup(x_axis, y_axis)

        # Ticks at amplitude +1, -1 and 0
        TICK_LEN = 0.2
        ticks = VGroup(
            Line([0, self.y_scale, 0], [-TICK_LEN, self.y_scale, 0]),
            Line([0, -self.y_scale, 0], [-TICK_LEN, -self.y_scale, 0]),
            Line([0, 0, 0], [-TICK_LEN, 0, 0]),
        )
        y_lab_size = 36
        y_labels = VGroup(
            Tex('1.00', font_size=y_lab_size).next_to(ticks[0], LEFT),
            Tex('-1.00', font_size=y_lab_size).next_to(ticks[1], LEFT),
            Tex('0.00', font_size=y_lab_size).next_to(ticks[2], LEFT),
        )
        y_name = Tex('Amplitudes', font_size=30).rotate(np.pi/2).next_to(y_labels, LEFT)

        # Invisible anchors on the x axis where the four amplitude bars stand
        x_points = Group(*[Point(location=[self.x_scale * i, 0, 0]) for i in range(1, 5)])
        x_lab_size = 24
        # Column headers \ket{00}, \ket{01}, \ket{10}, \ket{11}
        x_labels = VGroup(*[
            MathTex(r'\ket{{{:02b}}}'.format(i), font_size=x_lab_size, tex_template=myTemplate)
            .move_to([x_points[i].get_center()[0], y_len+0.5, 0])
            for i in range(4)
        ])
        x_name = Tex('Basis States', font_size=22).next_to(x_labels, 1.5*LEFT)

        probs = VGroup(*[
            DecimalNumber().set_value(amps[i]**2).scale(0.5).next_to(x_labels[i], 1.5*UP) for i in range(4)
        ])
        prob_label = Tex('Measurement\n\nProbabilities', font_size=22).next_to(probs, 1.5*LEFT)

        # Group the graph skeleton and move it to the center of the screen
        graph_group = Group(axes, ticks, y_labels, x_points, x_labels, x_name, y_name, probs, prob_label)
        graph_group.move_to([0, 0, 0])

        # Draw the graph
        self.play(Create(axes), Create(ticks), Write(y_labels), Write(y_name), lag_ratio=0.0)

        def make_mean_line(level):
            """Return a dashed yellow line spanning the x axis at amplitude ``level``."""
            y = self.y_scale*level + x_axis.get_center()[1]
            x0 = y_axis.get_center()[0]
            return DashedLine([x0, y, 0], [x0 + x_len, y, 0],
                              dash_length=x_len/16, dashed_ratio=0.8, color='yellow', stroke_width=2)

        # Amplitude Bars
        amp_bars = VGroup(*[
            Line(x_points[i].get_center(), x_points[i].get_center() + self.y_scale*amps[i]*UP, stroke_width=10)
            for i in range(4)
        ])
        mean_line = make_mean_line(mean)
        mean_number = DecimalNumber().set_value(np.mean(amps)).set_color(YELLOW).scale(0.5)
        # The read-out follows the mean line and shows its height in amplitude units
        mean_number.add_updater(lambda number: number.next_to(mean_line, LEFT))
        mean_number.add_updater(
            lambda number: number.set_value((mean_line.get_center()[1] - x_axis.get_center()[1])/self.y_scale))
        mean_label = Tex('Mean Amplitude', font_size=24).set_color(YELLOW).next_to(mean_line, RIGHT)
        mean_label.add_updater(lambda label: label.next_to(mean_line, RIGHT))

        self.play(Write(x_name), Write(x_labels), lag_ratio=0.0)
        self.play(Create(amp_bars), lag_ratio=0.0)
        self.play(Write(probs), Write(prob_label), lag_ratio=0.0)
        self.play(Create(mean_line), Write(mean_number), Write(mean_label), lag_ratio=0.0)
        # Append a second ket to every label: \ket{1} for basis state 3 (flipped below), \ket{0} otherwise.
        # Raw string: '\k' is not a valid escape sequence in a normal literal.
        self.play(*[
            ReplacementTransform(
                x_labels[i],
                MathTex(r'\ket{{{:02b}}}\ket{{{:01b}}}'.format(i, int(i == 3)),
                        font_size=x_lab_size, tex_template=myTemplate)
                .move_to(x_points[i].get_center() + np.array([0, y_len+0.5, 0])))
            for i in range(4)
        ])

        def get_mean_transition(amps, target, val):
            """Mean of ``amps`` with entry ``target`` replaced by ``val`` (``amps`` is not modified)."""
            b = amps.copy()
            b[target] = val
            return np.mean(b)

        def flip_target(target):
            """Negate ``amps[target]``, animating its bar as a vector swinging through 180 degrees.

            Updates the enclosing ``mean`` afterwards.
            """
            # nonlocal (not global): rebind construct()'s `mean` instead of leaking a module global
            nonlocal mean
            amp = amps[target]
            p1 = Dot().move_to(x_points[target].get_center())
            p2 = Dot(radius=0.05).move_to(p1).shift(amp*self.y_scale*UP).set_color(RED)
            arrow = Line(p1.get_center(), p2.get_center(), buff=0).set_color(RED)

            # theta sweeps 0 -> pi, so the tip height amp*cos(theta) goes from +amp to -amp
            theta = ValueTracker(0)

            p2.add_updater(lambda p: p.set_y(p1.get_y() + amp * self.y_scale * np.cos(theta.get_value())))
            arrow.add_updater(lambda line: line.become(Line(p1.get_center(), p2.get_center(), buff=0).set_color(RED)))
            # The mean line tracks the mean of the amplitudes while the target swings
            mean_line.add_updater(lambda line: line.become(
                make_mean_line(get_mean_transition(amps, target, amp*np.cos(theta.get_value())))))

            self.play(FadeOut(amp_bars[target]), Create(arrow), Create(p2))
            self.play(theta.animate.increment_value(np.pi), rate_func=linear, run_time=1)
            p2.clear_updaters()
            arrow.clear_updaters()

            amps[target] = -amp
            amp_bars[target].become(Line(x_points[target].get_center(),
                                         x_points[target].get_center() + self.y_scale*amps[target]*UP,
                                         stroke_width=10))
            self.play(FadeOut(arrow), FadeOut(p2), Create(amp_bars[target]))

            mean_line.clear_updaters()
            mean = np.mean(amps)

        def flip_about_mean(amps):
            """Reflect every amplitude about the current mean (a -> 2*mean - a), with animation.

            Mutates ``amps`` in place, updates the enclosing ``mean`` and returns ``amps``.
            """
            nonlocal mean
            p1 = [Dot().move_to(x_points[i].get_center() + mean*self.y_scale*UP) for i in range(4)]
            p2 = [Dot(radius=0.05).move_to(p1[i]).shift((amps[i] - mean)*self.y_scale*UP).set_color(RED)
                  for i in range(4)]
            arrows = [Line(p1[i].get_center(), p2[i].get_center(), buff=0).set_color(RED) for i in range(4)]

            theta = ValueTracker(0)
            for i in range(4):
                # i=i binds the loop index at definition time (avoids late-binding closures)
                p2[i].add_updater(
                    lambda p, i=i: p.set_y(p1[i].get_y() + (amps[i]-mean) * self.y_scale * np.cos(theta.get_value())))
                arrows[i].add_updater(
                    lambda line, i=i: line.become(Line(p1[i].get_center(), p2[i].get_center(), buff=0).set_color(RED)))

            self.play(FadeOut(amp_bars), *[Create(arrow) for arrow in arrows], *[Create(p) for p in p2])
            self.play(theta.animate.increment_value(np.pi),
                      *[probs[i].animate.set_value((2*mean - amps[i])**2) for i in range(4)],
                      rate_func=linear, run_time=1)
            for i in range(4):
                p2[i].clear_updaters()
                arrows[i].clear_updaters()
            # calculate new mean and amplitudes
            amps *= -1
            amps += 2*mean
            mean = np.mean(amps)
            for i in range(4):
                amp_bars[i].become(Line(x_points[i].get_center(), p2[i].get_center(), stroke_width=10))

            self.play(*[FadeOut(arrow) for arrow in arrows],
                      *[FadeOut(p) for p in p2],
                      *[Create(amp_bar) for amp_bar in amp_bars],
                      mean_line.animate.move_to(x_axis.get_center() + self.y_scale*mean*UP),
                      )
            return amps

        flip_target(3)
        amps = flip_about_mean(amps)

        self.wait(1)
        

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

# Testing Image Mobjects to see if I can animate sprites

In [49]:
%%manim --disable_caching ImageTest

class ImageTest(Scene):
    """Scratch scene: flip between two pixel-art frames on a sprite that orbits the origin."""

    def construct(self):
        image_paths = ['./images/01.png', './images/02.png']

        # Load each frame once; nearest-neighbour resampling keeps upscaled pixels crisp.
        images = []
        for path in image_paths:
            img = ImageMobject(path)
            img.set_resampling_algorithm(RESAMPLING_ALGORITHMS['nearest'])
            img.scale(10)
            images.append(img)

        sprite = images[0].copy()

        # Coordinate read-outs (updated every frame but not added to the scene;
        # the frame counter below is positioned relative to x_readout).
        x_readout = DecimalNumber().set_value(sprite.get_center()[0]).next_to(sprite, UP)
        y_readout = DecimalNumber().set_value(sprite.get_center()[1]).next_to(x_readout, RIGHT)
        x_readout.add_updater(lambda num: num.set_value(sprite.get_center()[0]))
        y_readout.add_updater(lambda num: num.set_value(sprite.get_center()[1]))

        self.FRAME_RATE = 10
        self.clock = 0

        def advance_clock(counter, dt):
            # Time-based updater: accumulate elapsed seconds, display them as a frame count
            self.clock += dt
            counter.set_value(self.clock * self.FRAME_RATE)

        def orbit(d):
            # Unit circle, one quarter turn per second of self.clock
            t = 0.5 * np.pi * self.clock
            d.move_to([np.cos(t), np.sin(t), 0])

        dot = Dot().move_to(sprite.get_center())
        dot.add_updater(orbit)

        frame = Integer().set_value(0).next_to(x_readout, 2 * UP)
        frame.add_updater(advance_clock)
        frame.update()
        clock_readout = DecimalNumber().set_value(0).next_to(frame, UP)
        clock_readout.add_updater(lambda num: num.set_value(self.clock))

        def show_current_image(s):
            s.become(images[frame.get_value() % len(images)])

        sprite.add_updater(show_current_image)
        sprite.add_updater(lambda s: s.move_to(dot.get_center()))

        def point_at_sprite(a):
            a.become(Arrow(ORIGIN, sprite.get_center(), buff=0)).set_color(YELLOW)

        arrow = Arrow(ORIGIN, sprite.get_center(), buff=0).set_color(YELLOW)
        arrow.add_updater(point_at_sprite)

        self.add(sprite, frame, dot, clock_readout, arrow)
        self.wait(3)
        

                                                                                                                       

## Sprite Controller
Now that I know sprite animations work, I'm going to build a controller for Mario and Luigi that handles their movement and animation.

**Sprites:**
* Idle - velocity = 0, duck flag = False
* Running - horizontal velocity != 0 and vertical velocity = 0
* Jumping - vertical velocity > 0
* Falling - vertical velocity < 0
* Ducking - velocity = 0, duck flag = True

**Variables:**
* Direction - Left/Right to see which direction the sprite should face
* Position
* Velocity
* Acceleration
* Duck flag
* Image directory (so the same controller works for both Mario and Luigi)
* Boundaries - the box that keeps the character from leaving the screen

In [2]:
RUN_FRAMERATE = 10  # ticks per second of the running-animation cycle


class Sprite_Controller:
    """Simple platformer physics plus frame selection for a pixel-art character.

    The controller keeps a 2-D position/velocity/acceleration, clamps the
    character inside an axis-aligned box and, on every ``update``, swaps the
    displayed image (idle / run / jump / fall / duck, facing left or right).

    Args:
        img_path: prefix prepended to every frame filename (``idle.png``,
            ``walk-1..3.png``, ``jump.png``, ``fall.png``, ``duck.png`` and
            their ``-r`` variants).
        bounds: box limits ``[LEFT, RIGHT, DOWN, UP]``.
        sprite: the on-screen mobject; its centre and size seed the state.
        v: walking speed.
        j: jump take-off speed.
        g: gravitational acceleration.
        scale: scale factor applied to every loaded frame.
    """

    def __init__(self, img_path, bounds, sprite, v, j, g, scale):
        self.v = v  # Walking speed
        self.j = j  # Jumping (take-off) speed
        self.g = g  # Gravitational acceleration

        # Kinematic state (x, y) in scene coordinates
        self.pos = sprite.get_center()[0:2]
        self.vel = np.zeros(2)
        self.acc = np.zeros(2)

        # Sprite height/width, used as the collision box against `bounds`
        self.h = sprite.get_top()[1] - sprite.get_bottom()[1]
        self.w = sprite.get_right()[0] - sprite.get_left()[0]

        self.direction = 1  # +1 for right, -1 for left
        self.duck = False
        self.running = False

        self.img_path = img_path

        # Left-facing frames (files without a suffix)
        self.idle_frame_l = self._load_frame('idle.png', scale)
        self.run_frames_l = [self._load_frame('walk-{}.png'.format(i + 1), scale) for i in range(3)]
        self.jump_frame_l = self._load_frame('jump.png', scale)
        self.fall_frame_l = self._load_frame('fall.png', scale)
        self.duck_frame_l = self._load_frame('duck.png', scale)

        # Right-facing frames ('-r' suffix)
        self.idle_frame_r = self._load_frame('idle-r.png', scale)
        self.run_frames_r = [self._load_frame('walk-{}-r.png'.format(i + 1), scale) for i in range(3)]
        self.jump_frame_r = self._load_frame('jump-r.png', scale)
        self.fall_frame_r = self._load_frame('fall-r.png', scale)
        self.duck_frame_r = self._load_frame('duck-r.png', scale)

        self.bounds = bounds  # 4 vec [LEFT, RIGHT, DOWN, UP]

        # Run-cycle bookkeeping: tick index before/after the latest update and
        # the current step (0..3) of the 0, 1, 2, 1 walk cycle
        self.last_frame = 0
        self.current_frame = 0
        self.running_frame = 0

    def _load_frame(self, filename, scale):
        """Load ``img_path + filename`` with nearest-neighbour resampling and scale it."""
        frame = ImageMobject(self.img_path + filename)
        frame.set_resampling_algorithm(RESAMPLING_ALGORITHMS['nearest'])
        frame.scale(scale)
        return frame

    def update(self, sprite: "Mobject", dt: float, clock: float) -> float:
        """Advance the simulation by ``dt`` seconds and restyle/move ``sprite``.

        Args:
            sprite: mobject to ``become`` the selected frame and move to ``pos``.
            dt: elapsed time since the previous update.
            clock: running clock value before this step.

        Returns:
            The clock advanced by ``dt``.
        """
        self.last_frame = int(clock * RUN_FRAMERATE) % 4
        clock += dt
        self.current_frame = int(clock * RUN_FRAMERATE) % 4

        # Explicit Euler step
        new_pos = self.pos + self.vel*dt
        new_vel = self.vel + self.acc*dt

        self._resolve_x(new_pos[0], new_vel[0])
        self._resolve_y(new_pos[1], new_vel[1])

        sprite.become(self._select_frame())
        sprite.move_to([self.pos[0], self.pos[1], 0.0])
        return clock

    def _resolve_x(self, new_x, new_vel_x):
        """Accept the new x state unless it hits a side wall; then clamp and stop."""
        if new_x - self.w/2 <= self.bounds[0]:
            self.pos[0] = self.bounds[0] + self.w/2 + 1e-5
            self.vel[0] = 0.0
        elif new_x + self.w/2 >= self.bounds[1]:
            self.pos[0] = self.bounds[1] - self.w/2 - 1e-5
            self.vel[0] = 0.0
        else:
            self.pos[0] = new_x
            self.vel[0] = new_vel_x

    def _resolve_y(self, new_y, new_vel_y):
        """Land on the floor, bump the ceiling, or fly freely under gravity."""
        if new_y - self.h/2 <= self.bounds[2]:
            # Floor: snap exactly onto it; cancel downward motion and gravity
            self.pos[1] = self.bounds[2] + self.h/2
            if self.vel[1] < 0.0:
                self.vel[1] = 0.0
                self.acc[1] = 0.0
        elif new_y + self.h/2 >= self.bounds[3]:
            # Ceiling: clamp just below it and cancel upward motion
            self.pos[1] = self.bounds[3] - self.h/2 - 1e-5
            if self.vel[1] > 0:
                self.vel[1] = 0.0
        else:
            self.pos[1] = new_y
            self.vel[1] = new_vel_y
            self.acc[1] = -self.g

    def _select_frame(self):
        """Pick the image for the current state; advances the run cycle as a side effect."""
        if self.direction < 0:
            idle, duck, run, jump, fall = (self.idle_frame_l, self.duck_frame_l, self.run_frames_l,
                                           self.jump_frame_l, self.fall_frame_l)
        else:
            idle, duck, run, jump, fall = (self.idle_frame_r, self.duck_frame_r, self.run_frames_r,
                                           self.jump_frame_r, self.fall_frame_r)

        if self.vel[1] != 0.0:
            return jump if self.vel[1] > 0.0 else fall
        if not self.running:
            return duck if self.duck else idle
        # Step the walk cycle whenever a RUN_FRAMERATE tick boundary was crossed;
        # step 3 reuses frame 1 to give the 0, 1, 2, 1 ping-pong.
        if self.current_frame != self.last_frame:
            self.running_frame = (self.running_frame + 1) % 4
        return run[1] if self.running_frame == 3 else run[self.running_frame]

    def _on_ground(self):
        """True when the sprite rests on the floor (within float tolerance)."""
        # The initial position comes from sprite.get_center() and need not equal
        # bounds[2] + h/2 bit-for-bit until update() snaps it onto the floor.
        return np.isclose(self.pos[1], self.bounds[2] + self.h/2)

    def move_left(self):
        """Run left at speed ``v`` (also stands up from a duck)."""
        self.vel[0] = -self.v
        self.running = True
        self.direction = -1
        self.duck = False

    def move_right(self):
        """Run right at speed ``v`` (also stands up from a duck)."""
        self.vel[0] = self.v
        self.running = True
        self.direction = 1
        self.duck = False

    def jump(self):
        """Launch upward at speed ``j``; ignored while airborne."""
        if self._on_ground():
            self.vel[1] = self.j

    def stop(self):
        """Stop horizontal motion."""
        self.vel[0] = 0.0
        self.running = False

    def duck_(self):
        """Stop and crouch; the walk cycle restarts from its first step afterwards."""
        self.stop()
        self.duck = True
        self.running_frame = 0

    def flip(self):
        """Turn around, reversing the horizontal velocity."""
        self.direction *= -1
        self.vel[0] *= -1

In [3]:
%%manim --disable_caching TestController
class TestController(Scene):
    """Drive a Luigi sprite through a scripted run/jump/duck sequence inside a box.

    Live read-outs of the controller's position, velocity and acceleration are
    shown in the upper-left corner.
    """

    def construct(self):
        bounds = [-3, 3, -2, 2]  # [LEFT, RIGHT, DOWN, UP]
        corners = [
            [bounds[0], bounds[2], 0],
            [bounds[0], bounds[3], 0],
            [bounds[1], bounds[3], 0],
            [bounds[1], bounds[2], 0],
        ]
        box = Polygon(*corners, color=WHITE)

        scale = 5
        sprite = ImageMobject('./images/luigi_sprites/idle-r.png')
        sprite.set_resampling_algorithm(RESAMPLING_ALGORITHMS['nearest'])
        sprite.scale(scale)
        # Stand the sprite on the floor of the box
        height = sprite.get_top()[1] - sprite.get_bottom()[1]
        sprite.move_to([0, bounds[2] + height/2, 0])

        # Walk speed, jump take-off speed and gravity (all halved)
        v_walk, v_jump, gravity = np.array([3.7, 17.36, 33.9]) / 2
        controller = Sprite_Controller('./images/luigi_sprites/', bounds, sprite,
                                       v=v_walk, j=v_jump, g=gravity, scale=scale)

        pos_lab = Tex('Position: ', font_size=28).to_corner(UL)
        vel_lab = Tex('Velocity: ', font_size=28).next_to(pos_lab, DOWN)
        acc_lab = Tex('Acceleration: ', font_size=28).next_to(vel_lab, DOWN)

        def readout(anchor, attr, axis):
            """Signed number right of ``anchor`` that tracks ``controller.<attr>[axis]``."""
            number = DecimalNumber(include_sign=True).scale(0.5).next_to(anchor, RIGHT)
            number.add_updater(lambda n: n.set_value(getattr(controller, attr)[axis]))
            return number

        pos_x = readout(pos_lab, 'pos', 0)
        pos_y = readout(pos_x, 'pos', 1)
        vel_x = readout(vel_lab, 'vel', 0)
        vel_y = readout(vel_x, 'vel', 1)
        acc_x = readout(acc_lab, 'acc', 0)
        acc_y = readout(acc_x, 'acc', 1)

        log = Group(pos_lab, vel_lab, acc_lab, pos_x, pos_y, vel_x, vel_y, acc_x, acc_y)

        self.clock = 0.0

        def update_sprite(mob, dt):
            self.clock = controller.update(mob, dt, self.clock)

        sprite.add_updater(update_sprite)

        self.add(box, sprite, log)

        # Scripted input: fire the actions back-to-back, then hold for the given seconds
        script = (
            ((controller.move_right,), 1),
            ((controller.jump,), 2),
            ((controller.move_left,), 1),
            ((controller.jump,), 2.5),
            ((controller.stop, controller.flip), 0.1),
            ((controller.duck_,), 1),
        )
        for actions, hold in script:
            for action in actions:
                action()
            self.wait(hold)
        

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

# Test how to make the cat kets

In [4]:
%%manim CatKet

def get_cat_ket(state, scale=1.0):
    """Build a ket whose label is a cat picture.

    Loads ``./images/cats/<state>.png``, places a ``|`` glyph to its left and
    a right angle bracket to its right, then scales the whole group.

    Returns:
        ``Group(bar, cat_image, bracket)`` scaled by ``scale``.
    """
    font_size = 350  # glyph size of the bracket symbols before the group is scaled
    cat_image = ImageMobject('./images/cats/{}.png'.format(state))

    bar = MathTex('|', font_size=font_size)
    bar.next_to(cat_image, 0.1 * LEFT)

    bracket = MathTex(r'\rangle', font_size=font_size)
    bracket.next_to(cat_image, 0.01 * RIGHT)

    ket = Group(bar, cat_image, bracket)
    ket.scale(scale)
    return ket

class CatKet(Scene):
    """Preview the 'dead' and 'alive' cat kets side by side."""

    def construct(self):
        ket_scale = 0.25
        dead_ket = get_cat_ket('dead', ket_scale)
        dead_ket.shift(LEFT)
        alive_ket = get_cat_ket('alive', ket_scale)
        alive_ket.next_to(dead_ket, 1.5 * RIGHT)

        self.add(dead_ket, alive_ket)
        self.wait(1)

In [72]:
%%manim CatStateSpace
class CatStateSpace(Scene):
    """Rotate a state vector in the plane spanned by the 'alive' and 'dead' cat kets.

    The vector's horizontal (yellow) and vertical (green) components are drawn
    together with a live read-out of cos(theta) and sin(theta).  At the end the
    cat kets in the read-out are replaced by |0> and |1>.
    """

    def construct(self):
        scale = 2
        plane = NumberPlane(faded_line_ratio=1/scale)
        circle = Circle(radius=scale, stroke_width=1.5).set_color(YELLOW)

        ket_scale = 0.1
        ket_font = 350 * ket_scale  # get_cat_ket's glyph size (350) after scaling
        dead = get_cat_ket('dead', scale=ket_scale).move_to(circle.get_edge_center(UP) + 0.4*UL)
        alive = get_cat_ket('alive', scale=ket_scale).move_to(circle.get_edge_center(RIGHT) + 0.4*DR)

        theta = ValueTracker().set_value(0.0)

        def tip():
            """Tip of the state vector on the circle of radius ``scale`` at angle theta."""
            angle = theta.get_value()
            return scale*np.array([np.cos(angle), np.sin(angle)])

        state = Vector(tip(), buff=0)
        state.add_updater(lambda vec: vec.become(Vector(tip(), buff=0)))

        self.add(plane, state, dead, alive)

        h_color = YELLOW
        v_color = GREEN

        # Read-out: |psi> = cos(theta)|alive> + sin(theta)|dead>
        psi = MathTex(r'|\psi\rangle=', font_size=ket_font)
        alive_amp = DecimalNumber(font_size=ket_font).set_value(np.cos(theta.get_value())).set_color(h_color)
        dead_amp = DecimalNumber(font_size=ket_font, include_sign=True).set_value(
            np.sin(theta.get_value())).set_color(v_color)
        alive_amp.next_to(psi, RIGHT)
        aket = get_cat_ket('alive', scale=ket_scale).next_to(alive_amp, 1.3*RIGHT)
        dead_amp.next_to(aket, RIGHT)
        dket = get_cat_ket('dead', scale=ket_scale).next_to(dead_amp, RIGHT)
        dead_amp.add_updater(lambda num: num.set_value(np.sin(theta.get_value())))
        alive_amp.add_updater(lambda num: num.set_value(np.cos(theta.get_value())))

        statevec = Group(psi, dead_amp, alive_amp, dket, aket)
        statevec.next_to(circle, UR)

        def horizontal_component():
            """x-component arrow; collapses to a zero vector when sin(theta) is ~0."""
            angle = theta.get_value()
            if np.abs(np.sin(angle)) <= 1e-5:
                return Vector([0, 0]).set_color(h_color)
            return Vector([scale*np.cos(angle), 0]).set_color(h_color)

        def vertical_component():
            """y-component arrow, drawn from the tip of the x-component."""
            angle = theta.get_value()
            return Vector([0, scale*np.sin(angle)]).shift(scale*np.cos(angle)*RIGHT).set_color(v_color)

        h_vec = Vector([scale*np.cos(theta.get_value()), 0]).set_color(h_color)
        h_vec.add_updater(lambda vec: vec.become(horizontal_component()))
        v_vec = vertical_component()
        v_vec.add_updater(lambda vec: vec.become(vertical_component()))

        components = VGroup(h_vec, v_vec)

        self.add(components, statevec)
        self.play(theta.animate.set_value(2*np.pi), run_time=2)

        # Swap the cat kets in the read-out for computational-basis kets
        ket0 = MathTex(r'|0\rangle', font_size=ket_font).move_to(aket.get_center())
        ket1 = MathTex(r'|1\rangle', font_size=ket_font).move_to(dket.get_center())

        self.play(FadeOut(dket), FadeOut(aket), FadeIn(ket0), FadeIn(ket1))

        self.wait()

                                                                                                                       