In [2]:
import jupyter_manim
from manim import *

# Grover's Algorithm Visualized
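* A bar graph of the real amplitudes of the basis kets, with each ket's measurement probability shown above its bar.
* A dashed line at the mean value of the amplitudes.
* When flipping an amplitude, fade out its bar while drawing a vector from the bar's base to its tip, rotate the vector about the base, and then fade the vector out and draw the new amplitude bar. The oracle flips the target amplitude about zero; the diffusion step flips every amplitude about the mean.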

In [31]:
%%manim --disable_caching GroverSAT4

class GroverSAT4(Scene):
    """Animate one Grover iteration over the four 2-qubit basis states.

    The scene draws a bar chart of the real amplitudes of |00>, |01>, |10>, |11>,
    a dashed yellow line at their mean, and each state's measurement probability.
    It then applies the oracle (a sign flip of the target amplitude) followed by
    the diffusion step (every amplitude reflected about the mean).
    """

    def construct(self):
        myTemplate = TexTemplate()
        myTemplate.add_to_preamble(r'\usepackage{amsmath}')
        myTemplate.add_to_preamble(r'\usepackage{amssymb}')
        myTemplate.add_to_preamble(r'\usepackage{braket}')  # provides \ket

        # Index of the basis state the oracle marks (3 -> |11>).
        TARGET = 3

        # Start state: equal amplitude on all four basis states.
        amps = np.array([0.5, 0.5, 0.5, 0.5])
        # Current mean amplitude. The nested flip helpers rebind it through
        # `nonlocal`; `global` would instead read/write a notebook-level name
        # and leave this variable stale.
        mean = np.mean(amps)

        # Constructing the graph's skeleton
        self.x_scale = 1
        self.y_scale = 2

        x_len = 4.5*self.x_scale
        y_len = 1.25*self.y_scale

        x_axis = Line([0,0,0],[x_len,0,0])
        y_axis = Line([0,-y_len,0],[0,y_len,0])
        axes = VGroup(x_axis,y_axis)

        TICK_LEN = 0.2
        ticks = VGroup(
            Line([0,self.y_scale,0],[-TICK_LEN,self.y_scale,0]),    # amplitude +1
            Line([0,-self.y_scale,0],[-TICK_LEN,-self.y_scale,0]),  # amplitude -1
            Line([0,0,0],[-TICK_LEN,0,0])                           # amplitude 0
        )
        y_lab_size = 36
        y_labels = VGroup(
            Tex('1.00', font_size=y_lab_size).next_to(ticks[0],LEFT),
            Tex('-1.00',font_size=y_lab_size).next_to(ticks[1],LEFT),
            Tex('0.00', font_size=y_lab_size).next_to(ticks[2],LEFT)
        )
        y_name = Tex('Amplitudes', font_size=30).rotate(np.pi/2).next_to(y_labels, LEFT)

        # One invisible anchor point per basis state along the x axis.
        x_points = Group(*[ Point(location=[self.x_scale * i, 0, 0]) for i in range(1,5) ])
        x_lab_size = 24
        # Ket labels sit this far above their anchor point.
        label_offset = np.array([0, y_len+0.5, 0])

        def ket_label(tex, i):
            """Return a MathTex ket label `tex` positioned above basis state i."""
            return MathTex(tex, font_size=x_lab_size, tex_template=myTemplate).move_to(
                x_points[i].get_center() + label_offset
            )

        x_labels = VGroup(*[ket_label(r'\ket{{{:02b}}}'.format(i), i) for i in range(4)])
        x_name = Tex('Basis States', font_size=22).next_to(x_labels, 1.5*LEFT)

        probs = VGroup(*[
            DecimalNumber().set_value(amps[i]**2).scale(0.5).next_to(x_labels[i], 1.5*UP) for i in range(4)
        ])
        prob_label = Tex('Measurement\n\nProbabilities', font_size=22).next_to(probs, 1.5*LEFT)

        # Group the graph skeleton and move it to the center of the screen
        graph_group = Group(axes, ticks, y_labels, x_points, x_labels,x_name, y_name, probs, prob_label)
        graph_group.move_to([0,0,0])

        # Draw the graph
        self.play(Create(axes), Create(ticks), Write(y_labels), Write(y_name),lag_ratio=0.0)

        # The helpers below read the axes' positions at call time, i.e. after centering.
        def make_amp_bar(i, value):
            """Return the vertical bar for basis state i at amplitude `value`."""
            base = x_points[i].get_center()
            return Line(base, base + self.y_scale*value*UP, stroke_width=10)

        def make_mean_line(value):
            """Return the dashed horizontal line spanning the x axis at amplitude `value`."""
            x0 = y_axis.get_center()[0]
            y = self.y_scale*value + x_axis.get_center()[1]
            return DashedLine([x0, y, 0], [x0 + x_len, y, 0],
                              dash_length=x_len/16, dashed_ratio=0.8, color='yellow', stroke_width=2)

        # Amplitude Bars
        amp_bars = VGroup(*[make_amp_bar(i, amps[i]) for i in range(4)])
        mean_line = make_mean_line(mean)
        mean_number = DecimalNumber().set_value(mean).set_color(YELLOW).scale(0.5)
        # Keep the readout beside the mean line and equal to the amplitude at the line's height.
        mean_number.add_updater(lambda number: number.next_to(mean_line, LEFT))
        mean_number.add_updater(lambda number: number.set_value( (mean_line.get_center()[1] - x_axis.get_center()[1])/self.y_scale ))
        mean_label = Tex('Mean Amplitude', font_size=24).set_color(YELLOW).next_to(mean_line, RIGHT)
        mean_label.add_updater(lambda label: label.next_to(mean_line, RIGHT))

        self.play(Write(x_name), Write(x_labels), lag_ratio=0.0)
        self.play(Create(amp_bars), Write(probs), Write(prob_label), lag_ratio=0.0)
        self.play(Create(mean_line), Write(mean_number), Write(mean_label), lag_ratio=0.0)
        # Append a second ket to each label that reads |1> only for the target state.
        # Raw string: a plain '\k' is an invalid escape sequence (SyntaxWarning on 3.12+).
        self.play(*[
            ReplacementTransform(
                x_labels[i],
                ket_label(r'\ket{{{:02b}}}\ket{{{:01b}}}'.format(i, int(i == TARGET)), i),
            )
            for i in range(4)
        ])

        def get_mean_transition(amps, target, val):
            """Return the mean of `amps` with entry `target` replaced by `val` (amps is not modified)."""
            b = amps.copy()
            b[target] = val
            return np.mean(b)

        def flip_target(target):
            """Oracle step: negate amps[target] in place, animating the bar rotating about its base."""
            nonlocal mean
            amp = amps[target]
            p1 = Dot().move_to(x_points[target].get_center())
            p2 = Dot(radius=0.05).move_to(p1).shift(amp*self.y_scale*UP).set_color(RED)
            arrow = Line(p1.get_center(), p2.get_center(), buff=0).set_color(RED)

            theta = ValueTracker(0)

            # Tip height follows amp*cos(theta); sweeping theta 0 -> pi takes amp -> -amp.
            p2.add_updater(lambda p: p.set_y(p1.get_y() + amp * self.y_scale * np.cos(theta.get_value())))
            arrow.add_updater(lambda l: l.become(Line(p1.get_center(), p2.get_center(), buff=0).set_color(RED)))
            # The mean line tracks the mean the amplitudes would have with the
            # target at its current interpolated value.
            mean_line.add_updater(lambda l: l.become(
                make_mean_line(get_mean_transition(amps, target, amp*np.cos(theta.get_value())))
            ))

            self.play(FadeOut(amp_bars[target]), Create(arrow), Create(p2))
            self.play(theta.animate.increment_value(np.pi), rate_func=linear, run_time=1)
            p2.clear_updaters()
            arrow.clear_updaters()

            amps[target] = -amp
            amp_bars[target].become(make_amp_bar(target, amps[target]))
            self.play(FadeOut(arrow), FadeOut(p2), Create(amp_bars[target]))

            mean_line.clear_updaters()
            mean = np.mean(amps)

        def flip_about_mean(amps):
            """Diffusion step: reflect every amplitude about the current mean.

            Mutates `amps` in place (a -> 2*mean - a), updates `mean`, and
            returns the same array.
            """
            nonlocal mean
            p1 = [Dot().move_to(x_points[i].get_center() + mean*self.y_scale*UP) for i in range(4)]
            p2 = [Dot(radius=0.05).move_to(p1[i]).shift((amps[i] - mean)*self.y_scale*UP).set_color(RED) for i in range(4)]
            arrows = [Line(p1[i].get_center(), p2[i].get_center(), buff=0).set_color(RED) for i in range(4)]

            theta = ValueTracker(0)
            for i in range(4):
                # Default-arg binding of i avoids the late-binding closure pitfall.
                p2[i].add_updater(lambda p, i=i: p.set_y(p1[i].get_y() + (amps[i]-mean) * self.y_scale * np.cos(theta.get_value())))
                arrows[i].add_updater(lambda l, i=i: l.become(Line(p1[i].get_center(), p2[i].get_center(), buff=0).set_color(RED)))

            self.play(FadeOut(amp_bars), *[Create(arrow) for arrow in arrows], *[Create(p) for p in p2])
            self.play(theta.animate.increment_value(np.pi),
                      *[probs[i].animate.set_value((2*mean - amps[i])**2) for i in range(4)],
                      rate_func=linear, run_time=1)
            for i in range(4):
                p2[i].clear_updaters()
                arrows[i].clear_updaters()
            # calculate new mean and amplitudes
            amps *= -1
            amps += 2*mean
            mean = np.mean(amps)
            for i in range(4):
                amp_bars[i].become(make_amp_bar(i, amps[i]))

            self.play(*[FadeOut(arrow) for arrow in arrows],
                      *[FadeOut(p) for p in p2],
                      *[Create(amp_bar) for amp_bar in amp_bars],
                      mean_line.animate.move_to(x_axis.get_center() + self.y_scale*mean*UP),
                     )
            return amps

        flip_target(TARGET)
        amps = flip_about_mean(amps)

        self.wait(1)
        

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

                                                                                                                       

In [22]:
# Preview of the two-ket labels used in GroverSAT4: the second ket is |1> only
# for index 3 (|11>). Raw string avoids the invalid '\k' escape sequence
# (SyntaxWarning on Python 3.12+).
for i in range(4):
    s = r'\ket{{{:02b}}}\ket{{{:01b}}}'.format(i, int(i == 3))
    print(s)

\ket{00}\ket{0}
\ket{01}\ket{0}
\ket{10}\ket{0}
\ket{11}\ket{1}


In [136]:
# A list of four flags with only index 2 set to True.
# NOTE(review): the saved "TypeError: only integer scalar arrays ..." output
# below is stale -- plain-list code like this cannot raise it; clear and re-run.
a = [idx == 2 for idx in range(4)]


TypeError: only integer scalar arrays can be converted to a scalar index