In [2]:
from manim import *
from manim_slides import Slide

# Notebook-wide manim settings shared by every %%manim cell below.
config.background_color = BLACK  # all slides are designed on a black canvas
config.verbosity = "WARNING"     # keep render logs quiet
# Embed rendered media in the notebook output at 65% width.
config.media_embed = True
config.media_width = "65%"

In [None]:
%%manim -v WARNING --disable_caching -qm intro

# Section headings for the table-of-contents slide, in talk order.
toc = Group(
    *[
        Tex(heading)
        for heading in (
            "1.~Puzzles Concerning Place Fields",
            "2.~Back to the First Principles",
            "3.~Check in with Reality",
            "4.~Optimality Considerations",
        )
    ]
).arrange(DOWN, aligned_edge=LEFT, buff=0.5).move_to(ORIGIN)

# Title card: talk title with the speaker's name underneath.
intro_words = Title("""
            A Unified Theory of Place Fields
        """)
name = Text("Nischal Mainali", font_size=24)
name.next_to(intro_words, DOWN)

# Heading the title card morphs into when section 1 starts.
one = Title("""
            Puzzles Concerning Place Fields
        """)

# Background image shown behind the title card.
bg = ImageMobject("img/bg.png")
bg.scale(1.3)
bg.shift(0.75 * RIGHT)

# class intro(Scene):
class intro(Slide):
    """Title card, then a guided tour of the table of contents."""

    def construct(self):
        # Title card over the background image.
        self.add(bg)
        self.play(Write(intro_words), Write(name))
        self.pause()
        self.play(Unwrite(name), FadeOut(bg))
        self.pause()

        # Table of contents: spotlight each section in turn.
        self.play(FadeIn(toc))
        self.pause()

        spotlight = 1.2
        previous = None
        for entry in toc:
            # Enlarge/highlight the current entry; restore the one before it.
            animations = [entry.animate.scale(spotlight).set_color(YELLOW)]
            if previous is not None:
                animations.append(
                    previous.animate.scale(1 / spotlight).set_color(WHITE)
                )
            self.play(*animations)
            self.pause()
            previous = entry

        # Return the final entry to normal before leaving the TOC.
        self.play(previous.animate.scale(1 / spotlight).set_color(WHITE))
        self.pause()

        # Morph the talk title into the section 1 heading.
        self.play(Transform(intro_words, one), FadeOut(toc))
        self.pause()

        self.wait()


In [None]:
%%manim -v WARNING --disable_caching -qm part_one


def _figure(path, height, invert=True):
    """Load ``path`` as an ImageMobject and set its height in scene units."""
    image = ImageMobject(path, invert=invert)
    image.height = height
    return image


def _caption(label):
    """24-pt caption pinned to the bottom edge of the frame."""
    return Text(label, font_size=24).to_corner(DOWN)


# Experimental figures; invert=True inverts the image colours
# (presumably light-background plots shown on the black canvas).
cell2 = _figure("img/cell2.png", 5)
cell2_text = _caption("Ulanovsky Lab Place Cell in the Bat Hippocampus")

log_normal = _figure("img/log_normal.png", 5)
log_normal_text = _caption("Log Normal Distribution")

expo = _figure("img/expo.png", 5)
expo_text = _caption("Distance between Receptive Fields")

mouse_cell = _figure("img/mouse_cell.png", 4.5)
mouse_cell_text = _caption("Fellous Lab: Place Cells in the Mouse Hippocampus")

mouse_rf = _figure("img/mouse_rf.png", 4)
mouse_rf_text = _caption("Mouse Receptive Fields")

mouse_rf_dist = _figure("img/mouse_rf_dist.png", 4)
mouse_rf_dist_text = _caption("Mouse Receptive Field Distance Distribution")

# Interlude image, shown with its original colours.
meme = _figure("img/meme.png", 4, invert=False)

# class part_one(Slide):
class part_one(Scene):
    """Section 1: survey of experimental place-field figures.

    Each figure is shown full-size with its caption, then shrunk and parked
    in a corner; none are removed, so all six stay on screen at the end.
    """

    def _show_then_park(self, figure, caption, corner):
        """Show ``figure`` with ``caption`` for a while, then shrink it into ``corner``."""
        self.add(figure, caption)
        self.pause(4)
        self.play(figure.animate.scale(0.4).to_corner(corner), FadeOut(caption))
        self.pause()

    def construct(self):
        self._show_then_park(cell2, cell2_text, UL)
        self._show_then_park(log_normal, log_normal_text, LEFT)

        # Interlude: the meme is shown and removed (no pause after the fade).
        self.add(meme)
        self.pause(4)
        self.play(FadeOut(meme))

        for figure, caption, corner in (
            (expo, expo_text, DL),
            (mouse_cell, mouse_cell_text, UR),
            (mouse_rf, mouse_rf_text, RIGHT),
            (mouse_rf_dist, mouse_rf_dist_text, DR),
        ):
            self._show_then_park(figure, caption, corner)

        # Takeaway shown over the parked figures.
        dim_group = VGroup(
            MarkupText("Phenomenological Differences"),
            MarkupText("by Dimension"),
        ).arrange(DOWN)
        self.play(FadeIn(dim_group.scale(0.7)))
        self.pause(5)




In [7]:
%%manim -v WARNING --disable_caching -qm check_in

# Draft experimental assets for an experiment-vs-model comparison.
# Disabled; kept for reference.
# cell2= ImageMobject("img/cell2.png", invert=True)
# cell2_text = Text("Example Place Cell activity in the Bat Hippocampus", font_size=24).to_corner(DOWN)
# cell2.height = 5

# log_normal = ImageMobject("img/log_normal.png", invert=True)
# log_normal_text = Text("Ulanovsky's Log Normal Distribution", font_size=24).to_corner(DOWN)
# log_normal.height = 5

# expo = ImageMobject("img/expo.png", invert=True)
# expo_text = Text("Distance between Receptive Fields", font_size=24).to_corner(DOWN)
# expo.height = 5

# mouse_rf = ImageMobject("img/mouse_rf.png", invert=True)
# mouse_rf_text = Text("Mouse Receptive Fields", font_size=24).to_corner(DOWN)
# mouse_rf.height = 4

###

# Model-side figures (images shown with their original colours).
example_cell = ImageMobject("img/example_cell.png", invert=False)
example_cell.height = 4
example_cell_text = Text("Example Place Cell from the RF model", font_size=24).to_corner(DOWN)

k = ImageMobject("img/k.png", invert=False)
k.scale(0.5)
k.height = 4  # setting the height afterwards fixes the final size at 4 units
k_text = Text("Rayleigh Distribution", font_size=24).to_corner(DOWN)

# k_bar / k_dim2 keep their native size x 0.5 (explicit heights were left disabled:
# k_bar.height = 5, k_dim2.height = 5).
k_bar = ImageMobject("img/k_bar.png", invert=False)
k_bar.scale(0.5)
k_bar_text = Text("Exponential Distribution", font_size=24).to_corner(DOWN)

k_dim2 = ImageMobject("img/k_dim2.png", invert=False)
k_dim2.scale(0.5)
k_dim2_text = Text("Exponential Distribution", font_size=24).to_corner(DOWN)

class check_in(Scene):
    """Section 3: show RF-model figures one at a time."""

    def construct(self):
        # (figure, caption, seconds held) - each is shown, held, then fully removed.
        for figure, caption, hold in (
            (example_cell, example_cell_text, 2),
            (k, k_text, 4),
        ):
            self.add(figure, caption)
            self.pause(hold)
            self.play(FadeOut(figure), FadeOut(caption))

        # Disabled draft steps that interleaved the experimental figures with
        # their model counterparts (they need the commented-out assets above).
        # self.play(cell2.animate.scale(2.2).move_to(ORIGIN), Write(cell2_text))
        # self.pause(2)
        # self.play(cell2.animate.scale(0.3).to_corner(UL), FadeOut(cell2_text))
        # self.pause()

        # self.play(log_normal.animate.scale(2.2).move_to(ORIGIN), Write(log_normal_text))
        # self.pause(4)
        # self.play(log_normal.animate.scale(0.3).to_corner(UL), FadeOut(log_normal_text))

        # self.play(expo.animate.scale(2.2).move_to(ORIGIN), Write(expo_text))
        # self.pause(4)
        # self.play(expo.animate.scale(0.3).to_corner(UL), FadeOut(expo_text))

        # self.add(k_bar)
        # self.pause(4)
        # self.play(FadeOut(k_bar), FadeOut(expo))

        # self.play(mouse_rf.animate.scale(2.2).move_to(ORIGIN), Write(mouse_rf_text))
        # self.pause(4)
        # self.play(mouse_rf.animate.scale(0.3).to_corner(UL), FadeOut(mouse_rf_text))
        # self.pause()

        # self.add(k_dim2)
        # self.pause(4)
        # self.play(FadeOut(k_dim2), FadeOut(mouse_rf))








                                                                                         

In [None]:
%%manim -v WARNING --disable_caching -ql part_two

import numpy as np
# from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage import gaussian_filter
from scipy import interpolate
def GausField(sigma, size=1000, rng=None):
    """Sample a periodic 1-D Gaussian random field with unit RMS amplitude.

    White Gaussian noise is smoothed with a Gaussian kernel (periodic
    boundary, ``mode='wrap'``) and rescaled so that ``mean(field**2) == 1``.

    Parameters
    ----------
    sigma : float
        Standard deviation of the smoothing kernel, in samples. Larger
        values give smoother, more slowly varying fields.
    size : int, optional
        Number of samples. The default of 1000 matches the
        ``np.arange(-5, 5, 0.01)`` grid the scenes interpolate over.
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted, the legacy global ``np.random``
        state is used (unchanged behaviour). Pass a seeded generator for
        reproducible renders.

    Returns
    -------
    numpy.ndarray
        1-D array of length ``size``.
    """
    if rng is None:
        noise = np.random.normal(0, 1, size)
    else:
        noise = rng.normal(0, 1, size)
    field = gaussian_filter(noise, sigma=sigma, mode='wrap')
    return field / np.sqrt(np.mean(field ** 2))

# Heading for section 2; part_two writes it first and later fades it out.
two = Title("""
            Back to the First Principles
        """)

class part_two(Scene):
    """Section 2 ("Back to the First Principles").

    Motivates encoding a large 1-D space with N neurons via random codes,
    stochastic processes and Gaussian processes; gives a short GP primer
    (sampling a path at finer resolutions, the covariance r and its second
    derivative r''); then shows how thresholding GP samples at 1 and
    rescaling, 2 * (max(f, 1) - 1), yields sparse bump-like responses.
    """

    def construct(self):
        self.play(Write(two))
        self.pause()

        # ---- Motivation: encoding a large 1-D space --------------------------
        x_start = np.array([-2, 0, 0])
        x_end = np.array([2, 0, 0])
        x_axis = Line(x_start, x_end)

        text1 = MarkupText(
            'We want to encode (1D) space, a continuous variable, with N Neurons.'
        ).next_to(x_axis, DOWN, buff=1)
        text2 = MarkupText(
            f'a <span fgcolor="{RED}">BIG</span> space of size L. How?'
        ).next_to(x_axis, DOWN, buff=1)
        text3 = MarkupText(
            f'Shannon taught us to generate codes <span fgcolor="{RED}">RANDOMLY</span>.'
        ).next_to(x_axis, UP)
        text4 = MarkupText(
            f'A random variable indexed on a continuous space is called a <span fgcolor="{BLUE}">Stochastic Process</span>.'
        ).next_to(x_axis, UP)
        text5 = MarkupText(
            f'And we know that the constrained process with maximum entropy is a <span fgcolor="{BLUE}">Gaussian Process</span>.'
        ).next_to(x_axis, UP)
        text6 = MarkupText(
            f'But a crucial ingredient is missing: <span fgcolor="{RED}">Thresholding!</span>'
        ).scale(0.8).to_corner(UP)
        text7 = MarkupText(
            'Ok, but we need a covariance function. Which one?'
        ).scale(0.7).next_to(x_axis, DOWN, buff=2)
        text8 = MarkupText(
            f'Any covariance that <span fgcolor="{BLUE}">decays</span> and has a <span fgcolor="{BLUE}">second derivative</span>.'
        ).scale(0.7).next_to(x_axis, DOWN, buff=2)

        self.add(x_axis, text1.scale(0.5))
        self.pause()

        self.play(ScaleInPlace(x_axis, 2.5), Transform(text1, text2))
        self.pause()

        self.play(FadeOut(text1), x_axis.animate.to_corner(DOWN), FadeOut(two))
        self.pause()

        self.play(Write(text3.scale(0.7)))
        self.pause()

        self.play(Transform(text3, text4.scale(0.5)))
        self.pause()

        # ---- Gaussian-process samples -----------------------------------------
        ax = Axes(x_range=[-5, 5], y_range=[-10, 10])
        # GausField() returns 1000 samples, one per point of this grid.
        grid = np.arange(-5, 5, 0.01)

        # Rough sample (sigma = 1) shown while introducing GPs.
        f0 = interpolate.interp1d(grid, GausField(1))
        gp_graph0 = ax.plot(lambda x: f0(x), color=BLUE_C, x_range=[-4, 4]).next_to(text3, UP, buff=1)

        # Progressively smoother samples (sigma = 3, 15, 30) and their
        # thresholded versions 2 * (max(f, 1) - 1), which are 0 wherever f <= 1.
        f1 = interpolate.interp1d(grid, GausField(3))
        gp_graph1 = ax.plot(lambda x: f1(x), color=BLUE_C, x_range=[-4, 4], use_smoothing=False)
        gp_graph1_thresholded = ax.plot(lambda x: 2*(np.maximum(1,f1(x))-1), color=BLUE_C, x_range=[-4, 4], use_smoothing=False)

        f2 = interpolate.interp1d(grid, GausField(15))
        gp_graph2 = ax.plot(lambda x: f2(x), color=BLUE_D, x_range=[-4, 4], use_smoothing=False)
        gp_graph2_thresholded = ax.plot(lambda x: 2*(np.maximum(1,f2(x))-1), color=BLUE_D, x_range=[-4, 4], use_smoothing=False)

        f3 = interpolate.interp1d(grid, GausField(30))
        gp_graph3 = ax.plot(lambda x: f3(x), color=BLUE_E, x_range=[-4, 4], use_smoothing=False)
        gp_graph3_thresholded = ax.plot(lambda x: 2*(np.maximum(1,f3(x))-1), color=BLUE_E, x_range=[-4, 4], use_smoothing=False)

        # Horizontal threshold line at y = 1.
        threshold = ax.plot(lambda x: 1, color=RED, x_range=[-4.2, 4.2])

        gr = VGroup(gp_graph1, gp_graph2, gp_graph3).arrange(DOWN,buff=0.8)

        # Independent sigma = 1 sample, revealed at every 40th, 20th and 10th grid point.
        y_coords = GausField(1)

        plot = ax.plot_line_graph(grid[1::40], y_coords[1::40])
        plot2 = ax.plot_line_graph(grid[1::20], y_coords[1::20])
        plot3 = ax.plot_line_graph(grid[1::10], y_coords[1::10])

        self.play(
            Create(gp_graph0), Transform(text3, text5.scale(0.5)),
            run_time=3
        )
        self.pause(3)

        # ---- GP primer ----------------------------------------------------------
        gpt = Title("""
            Gaussian Process: A Primer
        """)
        self.play(Write(gpt), FadeOut(text3), gp_graph0.animate.move_to(ORIGIN))
        self.pause()
        self.play(FadeOut(gp_graph0))

        arrow_gp = DoubleArrow(start=ax.coords_to_point(0, 6), end=ax.coords_to_point(0, -6), color=GOLD_B, buff=1.2).to_corner(LEFT)

        # Same path sampled at increasing resolution.
        self.play(Create(plot["vertex_dots"]))
        self.play(Create(plot["line_graph"]))
        self.pause()
        self.play(Transform(plot["line_graph"], plot2["line_graph"]),
            Transform(plot["vertex_dots"], plot2["vertex_dots"]))
        self.pause()
        self.play(Transform(plot["line_graph"], plot3["line_graph"]),
            Transform(plot["vertex_dots"], plot3["vertex_dots"]))
        self.pause()

        # Covariance function r, then its second derivative r''.
        r2 = MathTex(r"r''", fill_color=GOLD_D, ).scale(1.2)
        r = MathTex(r"r", fill_color=GOLD_D, ).scale(1.2)

        self.play(Write(text7))
        self.pause()
        self.play(Transform(text7, text8))
        self.pause()
        self.play(Write(r.next_to(text7, DOWN, buff=0.2)),Create(arrow_gp))
        self.pause()
        self.play(Transform(r, r2.next_to(text7, DOWN, buff=0.2)))

        self.play(FadeOut(plot["line_graph"]), FadeOut(plot["vertex_dots"]), FadeOut(gpt), FadeOut(text7), FadeOut(r), FadeOut(arrow_gp))
        self.pause()

        # ---- Thresholding --------------------------------------------------------
        gr2 = VGroup(gp_graph1_thresholded, gp_graph2_thresholded, gp_graph3_thresholded).arrange(DOWN,buff=0.4)
        # Arrow labelled r'' beside the three samples of increasing sigma.
        arrow = Arrow(start=config.top, end=config.bottom, color=GOLD_B, buff=1.2).to_corner(LEFT)
        r2 = MathTex(r"r''", fill_color=GOLD_D, ).scale(0.9).next_to(arrow.get_center(), RIGHT, buff=0.4)
        self.play(
            Create(gr[0]),Create(gr[1]),Create(gr[2]), Create(arrow), Write(r2),
            run_time=3
        )
        self.pause()

        self.play(Write(text6.scale(0.9)))
        self.pause()
        self.play(FadeOut(text6))
        self.pause()

        # Threshold the sigma = 3 sample step by step: raw -> clipped at 1 -> shifted and rescaled.
        gp_graph0 = ax.plot(lambda x: f1(x), color=BLUE_C, x_range=[-4, 4])
        gp_graph0_thresholded = ax.plot(lambda x: np.maximum(1,f1(x)), color=BLUE_C, x_range=[-4, 4], use_smoothing=False)
        s_gp_graph0_thresholded = ax.plot(lambda x: 2*(np.maximum(1,f1(x))-1), color=BLUE_C, x_range=[-4, 4], use_smoothing=False)

        self.play(
            FadeOut(gr[1]), FadeOut(gr[2]),
            ReplacementTransform(gr[0], gp_graph0),
            run_time=3
        )
        self.play(Create(threshold), run_time=3)
        self.pause()
        self.play(ReplacementTransform(gp_graph0, gp_graph0_thresholded),run_time=3)
        self.pause()
        self.play(ReplacementTransform(gp_graph0_thresholded, s_gp_graph0_thresholded),FadeOut(threshold),run_time=3)
        self.pause()

        # Expand back to all three thresholded samples.
        self.play(ReplacementTransform(s_gp_graph0_thresholded, gr2), run_time=4)
        self.pause()

        ## ADD THRESHOLDING

        

In [None]:
%%manim -qm optimality

def _error_panel(path):
    """Load a result figure, nudge it 0.5 units down, and size it 5.5 units tall."""
    panel = ImageMobject(path)
    panel.shift(DOWN * 0.5)
    panel.height = 5.5
    return panel


# Result figures shown near the end of the optimality scene.
rf_global_error = _error_panel("img/rf_global_error.png")
rf_threshold_error = _error_panel("img/rf_threshold_error.png")
delta = _error_panel("img/delta.png")

class optimality(Scene):
    """Section 4 ("Optimality Considerations").

    Frames the Gaussian-field code as a noisy channel X -> f(X) + z -> X_hat
    decoded by maximum likelihood, shows the decoder's MSE expression, and
    animates a bump-shaped curve whose width ``a`` is swept while each MSE
    term is boxed. It then switches to the thresholded (Delta) form of the
    formula and shows the corresponding result figures.
    """

    def get_rectangle_corners(self, y_range, par):
        """Corners of the shaded region under the top of the central bump.

        The central bump of the animated curve is ``6 * (1 - x**2 / par**2)``,
        so at height ``y`` its right half-width is ``par * sqrt(1 - y / 6)``.

        Args:
            y_range: ``(y_low, y_high)`` heights on the bump.
            par: current bump half-width ``a``.

        Returns:
            ``[(x_high, y_high), (x_low, y_low), (x_low, 0), (x_high, 0)]`` in
            axis coordinates: two points on the bump, then their drops to y = 0.
        """
        y = y_range[1]
        x = par*(1-y/6)**0.5
        y0 = y_range[0]
        x0 = par*(1-y0/6)**0.5
        return [
            (x, y),
            (x0, y0),
            (x0, 0),
            (x, 0),
        ]

    def construct(self):

        opt = Title("Optimality Considerations")

        text10 = MarkupText(
            f'Is it a <span fgcolor="{RED}">Good</span> Idea?'
        )

        # Channel diagram: X --(Gaussian field code)--> f(X) + z --(ML)--> X_hat.
        x = MathTex(r"\mathbb{X}", fill_color=BLUE_A, ).to_corner(LEFT)
        f = MathTex(r"f(\mathbb{X}) + z", fill_color=BLUE_C).to_corner(UP)
        x_hat = MathTex(r"\hat{\mathbb{X}}", fill_color=BLUE_E).to_corner(RIGHT)

        arrow1 = Arrow(x, f, buff=0.8)
        arrow2 = Arrow(f, x_hat, buff=0.8)

        # text10 is this scene's own opening question.
        self.play(Write(opt), Write(text10.scale(0.7)))
        self.pause()

        self.play(FadeOut(text10), FadeOut(opt), Create(arrow1), Create(arrow2), Write(x), Write(f), Write(x_hat))
        self.pause()

        eta = MathTex(r"z \sim \mathcal{N} (0,\eta^2)", fill_color=GOLD_D, ).scale(0.8).next_to(f, DOWN, buff=0.5)
        gp = Text("Gaussian Field Code", color=BLUE).scale(0.5).next_to(arrow1.get_center(), LEFT, buff=0.4).shift(UP*0.4)
        ml = Text("Maximum Likelihood", color=RED).scale(0.5).next_to(arrow2.get_center(), RIGHT, buff=0.4).shift(UP*0.4)
        self.play(Write(gp), Write(ml), Write(eta))
        self.pause()

        # MSE expression. Submobject indices: [0] left-hand side, [1] first
        # error term (boxed blue), [2] "+", [3] second error term (boxed red).
        # Raw strings keep every LaTeX backslash literal.
        mse_terms = (
            r"\mathbb{E} \left[ \left( \mathbb{X} - \hat{\mathbb{X}} \right)^2 \right] =",
            r"\frac{\eta^2}{N \cdot r''}",
            "+",
            r"\frac{L^3}{2 \pi} \sqrt{\frac{r''}{2r}} \left(1 + \frac{r}{2\eta^2} \right)^{-N/2}",
        )
        mse = MathTex(*mse_terms).to_corner(DOWN)
        mse_up = MathTex(*mse_terms).scale(1.1).to_corner(UP)

        self.play(Write(mse))
        framebox1 = SurroundingRectangle(mse_up[1], buff = .1, color=BLUE)
        framebox2 = SurroundingRectangle(mse_up[3], buff = .1, color=RED)

        self.play(ReplacementTransform(mse, mse_up), FadeOut(gp), FadeOut(ml), FadeOut(arrow1), FadeOut(arrow2), FadeOut(x), FadeOut(f), FadeOut(x_hat), FadeOut(eta))
        self.pause()

        self.play(
            Create(framebox1),
        )
        self.pause()

        # ---- Animation: sweep the bump width a ------------------------------
        a = ValueTracker(2)
        ax = Axes(x_range=[-20, 20, 1], y_range=[0, 6, 1], x_length=14, y_length=3, axis_config={"include_tip": False})

        def bumps(x):
            """Height-6 parabolic bumps on [-a, a], [5a, 7a] and [-7a, -5a]; 0 elsewhere."""
            w = a.get_value()
            return (
                np.maximum(0, -6*(x-w)*(x+w)/(w**2))
                + np.maximum(0, -6*(x-5*w)*(x-7*w)/(w**2))
                + np.maximum(0, -6*(x+5*w)*(x+7*w)/(w**2))
            )

        parabola = ax.plot(bumps, color=RED)
        # Rebuild the curve every frame so it follows the tracker.
        parabola.add_updater(lambda mob: mob.become(ax.plot(bumps, color=RED)))

        a_number = DecimalNumber(
            a.get_value(),
            color=RED,
            num_decimal_places=1,
            show_ellipsis=True
        ).next_to(ax.coords_to_point(a.get_value(), 0), DOWN)
        a_number.add_updater(
            lambda mob: mob.set_value(a.get_value()).next_to(parabola, DOWN)
        )

        def get_rectangle():
            """Redraw the shaded region under the central bump for the current a."""
            polygon = Polygon(
                *[
                    ax.c2p(*i)
                    for i in self.get_rectangle_corners(
                        (5,6), a.get_value()
                    )
                ]
            )
            polygon.stroke_width = 1
            polygon.set_fill(BLUE, opacity=0.5)
            polygon.set_stroke(YELLOW_B)
            return polygon
        polygon = always_redraw(get_rectangle)

        # Shaded side bumps, computed while a == 2 (bumps at |x| in [10, 14]).
        area1 = ax.get_area(graph=parabola, x_range=(10,30), color=RED_A, opacity=0.3)
        area2 = ax.get_area(graph=parabola, x_range=(-30,-10), color=RED_A, opacity=0.3)

        self.play(Create(ax),Create(parabola), Write(a_number))
        self.pause()
        self.play(Create(polygon))
        self.pause()
        self.play(a.animate.set_value(5), run_time=5, rate_func=linear)
        self.pause()
        self.play(a.animate.set_value(2), run_time=5, rate_func=linear)
        self.pause()
        self.play(Create(area1), Create(area2),ReplacementTransform(framebox1,framebox2))
        self.pause()
        self.play(FadeOut(framebox2), FadeOut(area1), FadeOut(area2), FadeOut(polygon), FadeOut(parabola), FadeOut(a_number), FadeOut(ax))
        self.pause()
        self.add(rf_global_error)
        self.pause(3)

        # ---- Optimal thresholding: r in the second term becomes Delta -------
        th = Title("Optimal Thresholding")
        self.play(Write(th),FadeOut(rf_global_error), mse_up.animate.move_to(ORIGIN))
        self.pause()

        mse_delta = MathTex(
            r"\mathbb{E} \left[ \left( \mathbb{X} - \hat{\mathbb{X}} \right)^2 \right] =",
            r"\frac{\eta^2}{N \cdot r''}",
            "+",
            r"\frac{L^3}{2 \pi} \sqrt{\frac{r''}{2r}} \left(1 + \frac{\Delta}{2\eta^2} \right)^{-N/2}",
        ).scale(1.1)

        self.play(FadeOut(th),ReplacementTransform(mse_up, mse_delta))
        self.pause()

        self.play(mse_delta.animate.scale(0.8).to_corner(UP))
        self.pause()

        self.add(delta)
        self.pause(3)

        # NOTE(review): transforming one ImageMobject into another interpolates
        # pixel arrays, which presumably requires both PNGs to have identical
        # dimensions - confirm, or switch to FadeTransform.
        self.play(ReplacementTransform(delta, rf_threshold_error))
        self.pause(3)

        

In [None]:
%%manim -qm trash

class trash(Scene):
    r"""Scratch scene showing the field-count and field-size formulas.

    All LaTeX is written as raw strings: in an ordinary string literal the
    ``\t`` at the start of ``\theta`` is a TAB escape, which would render
    "heta" instead of the Greek letter.
    """

    def construct(self):
        # Field-count formula, placed near the top-left corner.
        fc = MathTex(
            r"\text{Field Count} \sim ",
            r"\mathcal{P} \left[ \frac{2 \phi \left( \frac{\theta}{r}\right)}{\frac{1}{\pi} \sqrt{\frac{r''}{r}} e^{\frac{\theta^2}{2}}} \right]",
        ).to_corner(UL).scale(0.6).shift(0.5*UL)

        # Field-size distribution P(size = k), placed near the top-right corner.
        fs = MathTex(
            r"\mathbb{P} \left( \text{Field Size} = k \right) = ",
            r"\frac{2 \beta}{D}",
            r"k^{\frac{2}{D} -1}",
            r"\exp \left( - \beta k^{\frac{2}{D}} \right)",
        ).to_corner(UR).scale(0.6).shift(0.5*RIGHT)

        self.play(Write(fc), Write(fs))
        self.pause(5)