In [14]:
import numpy as np
import manim as mn
from manim import *

# Width of the rendered media when it is embedded in the notebook output
config.media_width = "75%"
# Keep manim's logging quiet: only WARNING and above is reported while rendering
config.verbosity = "WARNING"

# Record the manim version this notebook was executed with
print(mn.__version__)

0.19.0


# Rescorla Wagner Parameters

## Key Equations
### Associative strength update
$\Delta \vec{S}= \alpha\cdot \vec{c} (T-\sum S)$
### Prediction is the sum of cues times their associative strengths
$S_t = \vec{c^T}\cdot \vec{S}$
<!-- ### Learning rate matrix - update gain times diagonal matrix of cue saliences 
$\vec{A}=\beta \cdot \text{diag}(\alpha)$ -->
### Update Rule
$S_{t+1} = S_t + \Delta S$

## Single Cue Example

In [2]:
# Learning rate; held constant across all trials of this simulation
update_gain = 0.05
# Associative strength of the single cue, starting at 0
S = np.zeros((1,1))
# Per-trial logs: strength after the update, and squared prediction error
s_log = []
error_log = []
trials = 500

# Seeded generator so that Restart & Run All reproduces the same cue sequence
rng = np.random.default_rng(0)

# On each trial the single cue is either shown or withheld
for i in range(trials):
    # randomly generate the cue - present (1) or absent (0)
    cues = rng.choice([1,0], size=(1,1), replace=True)

    # predicted outcome: cue vector times associative strengths -> a (1,1) array
    prediction = np.dot(cues.T,S)

    # the outcome occurs exactly when the cue is present
    T = int(cues[0,0])

    # strength change = gain * cue * prediction error (zero when the cue is absent)
    delta_S = update_gain*cues*(T-prediction)

    S = S + delta_S
    s_log.append(S.ravel().tolist())
    error_log.append(((T-prediction)**2).item())

## Visualization - We see straightforward associative learning

In [10]:
%%manim -qm AnimateGraph

class AnimateGraph(Scene):
    """Animate the logs of the single-cue simulation.

    First draws the associative strength of the cue over trials, fades it
    out, and then draws the per-trial squared prediction error on the same
    axes. Reads ``trials``, ``s_log`` and ``error_log`` from the notebook
    namespace, so the simulation cell has to be run before this one.
    """

    def construct(self):
        # Axes: a tick every 50 trials on x and every 0.25 on y
        axes = Axes(
            x_range=[0, trials-1, 50],  # Min, Max, Step
            y_range=[0, 1.25, 0.25],
            axis_config={"include_numbers": True}
        ).scale(.8)
        x_label = axes.get_x_axis_label("t")
        y_label = axes.get_y_axis_label("S")

        self.add(axes, x_label, y_label)

        # The logs are discrete, one entry per trial. Without an explicit
        # x_range, plot() samples on a grid tied to the axis tick step, so
        # most trials would never be read, and it then smooths between the
        # samples. Sample every trial and disable smoothing so the drawn
        # curve is the log itself.
        per_trial = {"x_range": [0, trials-1, 1], "use_smoothing": False}

        def cue1(x):
            # learned strength of the cue after trial x
            return s_log[int(round(x))][0]

        def error(x):
            # squared prediction error on trial x
            return error_log[int(round(x))]

        graph1 = axes.plot(cue1, color=BLUE, **per_trial)
        graph4 = axes.plot(error, color=WHITE, **per_trial)

        # label shown while the strength curve is being drawn
        text1 = Text("Cue 1", font_size=32, color=BLUE).next_to(graph1, LEFT, buff=0.75)

        # Strength curve
        self.play(Write(text1))
        self.play(Create(graph1),run_time=5)
        self.play(FadeOut(text1))
        self.wait(1)

        # Clear the strength curve and its axis label ...
        self.play(FadeOut(y_label, graph1, shift=LEFT, scale=0.5))

        # ... then relabel the y axis and draw the error curve
        y_label = axes.get_y_axis_label("MS Error")
        self.add(y_label)
        self.play(Create(graph4),run_time=5)
        self.wait(1)

                                                                                                                       

## Multiple Cue Example - Cues 2 and 3 are predictive, 1 is not!

In [21]:
# Learning rate; held constant across all trials of this simulation
update_gain = 0.05
# Associative strengths of the three cues, all starting at 0
S = np.zeros((3,1))
# Per-trial logs: strengths after the update, and squared prediction error
s_log = []
error_log = []
trials = 500

# Seeded generator so that Restart & Run All reproduces the same cue sequence
rng = np.random.default_rng(0)

# On each trial every one of the 3 cues is independently shown or withheld
for i in range(trials):
    # randomly generate three cues - present (1) or absent (0)
    cues = rng.choice([1,0], size=(3,1), replace=True)

    # predicted outcome: summed strength of the present cues -> a (1,1) array
    prediction = np.dot(cues.T,S)

    # the outcome occurs whenever cue 2 or cue 3 is present; cue 1 is ignored.
    # (A probabilistic outcome, e.g. more likely when both cues are present,
    # would be a possible variation; here the rule is deterministic.)
    T = int(cues[1:].any())

    # strength change = gain * cues * prediction error (absent cues get 0)
    delta_S = update_gain*cues*(T-prediction)

    S = S + delta_S
    s_log.append(S.ravel().tolist())
    error_log.append(((T-prediction)**2).item())

## Visualization - learning occurs for cues 2 and 3, but cue 1 keeps a low associative strength due to its lack of predictive power

In [16]:
%%manim -ql AnimateGraph

class AnimateGraph(Scene):
    """Animate the logs of the three-cue simulation.

    Draws the associative strength of each cue in turn, fades all three
    curves out, and then draws the per-trial squared prediction error on the
    same axes. Reads ``trials``, ``s_log`` and ``error_log`` from the
    notebook namespace, so the simulation cell has to be run before this one.
    """

    def construct(self):
        # Axes: a tick every 50 trials on x and every 0.25 on y
        axes = Axes(
            x_range=[0, trials-1, 50],  # Min, Max, Step
            y_range=[0, 1.25, 0.25],
            axis_config={"include_numbers": True}
        ).scale(.8)
        x_label = axes.get_x_axis_label("t")
        y_label = axes.get_y_axis_label("S")

        self.add(axes, x_label, y_label)

        # The logs are discrete, one entry per trial. Without an explicit
        # x_range, plot() samples on a grid tied to the axis tick step, so
        # most trials would never be read, and it then smooths between the
        # samples. Sample every trial and disable smoothing so the drawn
        # curves are the logs themselves.
        per_trial = {"x_range": [0, trials-1, 1], "use_smoothing": False}

        def strength_of(cue_index):
            # function of the trial x giving the logged strength of one cue
            return lambda x: s_log[int(round(x))][cue_index]

        def error(x):
            # squared prediction error on trial x
            return error_log[int(round(x))]

        graph1 = axes.plot(strength_of(0), color=BLUE, **per_trial)
        graph2 = axes.plot(strength_of(1), color=RED, **per_trial)
        graph3 = axes.plot(strength_of(2), color=GREEN, **per_trial)
        graph4 = axes.plot(error, color=WHITE, **per_trial)

        # labels shown while the matching curve is being drawn
        text1 = Text("Cue 1", font_size=32, color=BLUE).next_to(graph1, LEFT, buff=0.75)
        text2 = Text("Cue 2", font_size=32, color=RED).next_to(graph2, LEFT, buff=0.75)
        text3 = Text("Cue 3", font_size=32, color=GREEN).next_to(graph3, LEFT, buff=0.75)

        # Cues 1 and 2: label in, curve drawn, label out
        for label, graph in ((text1, graph1), (text2, graph2)):
            self.play(Write(label))
            self.play(Create(graph),run_time=5)
            self.play(FadeOut(label))
            self.wait(1)

        # Cue 3: its label stays on screen until everything is cleared below
        self.play(Write(text3))
        self.play(Create(graph3),run_time=5)
        self.wait(1)

        # Clear the strength curves, the last label and the axis label ...
        self.play(FadeOut(y_label, text3, graph1, graph2, graph3, shift=LEFT, scale=0.5))

        # ... then relabel the y axis and draw the error curve
        y_label = axes.get_y_axis_label("MS Error")
        self.add(y_label)
        self.play(Create(graph4),run_time=5)
        self.wait(1)

                                                                                                                       

## Blocking Example

Blocking (A trained alone, then A+B compound — we expect B to learn little).

In [17]:
# Shared learning rate, kept fixed for the whole run
update_gain = 0.05
# Trial budget, and the trial index at which cue B joins cue A
trials = 500
additional_cue_t = 250
# Associative strengths of the two cues (A, B), both starting at zero
S = np.zeros((2,1))
# Histories, filled in once per trial
s_log, error_log = [], []

# The US is delivered on every trial, in both phases
T = 1
# Cue patterns: A alone (phase 1) and the A+B compound (phase 2)
a_alone = np.array([[1],[0]])
a_with_b = np.array([[1],[1]])

for i in range(trials):
    cues = a_with_b if i >= additional_cue_t else a_alone

    # prediction is the summed strength of the cues that are present
    prediction = np.dot(cues.T, S)
    surprise = T - prediction

    # only present cues receive a share of the prediction error
    delta_S = update_gain * cues * surprise
    S = S + delta_S

    s_log.append(S.T.flatten().tolist())
    error_log.append((surprise**2).flatten().item())

## Viz - Because Cue 1 sufficiently explains the US, the addition of Cue 2 doesn't change much

In [18]:
%%manim -qm AnimateGraph

class AnimateGraph(Scene):
    """Animate the blocking simulation: one strength curve per cue, then the error.

    Uses ``trials``, ``s_log`` and ``error_log`` from the notebook namespace.
    """

    def construct(self):
        # Coordinate system with numbered ticks, shrunk to leave room for labels
        axes = Axes(
            x_range=[0, trials-1, 50],  # Min, Max, Step
            y_range=[0, 1.25, 0.25],
            axis_config={"include_numbers": True}
        ).scale(.8)
        x_label = axes.get_x_axis_label("t")
        y_label = axes.get_y_axis_label("S")
        self.add(axes, x_label, y_label)

        # One (label, curve) pair per cue; column k of the log holds cue k+1
        strength_curves = []
        for column, (caption, tint) in enumerate([("Cue 1", BLUE), ("Cue 2", RED)]):
            curve = axes.plot(lambda t, k=column: s_log[int(t)][k], color=tint)
            tag = Text(caption, font_size=32, color=tint).next_to(curve, LEFT, buff=0.75)
            strength_curves.append((tag, curve))

        # Error curve, drawn last
        error_curve = axes.plot(lambda t: error_log[int(t)], color=WHITE)

        # For each cue: show its label, draw its curve, remove the label, pause
        for tag, curve in strength_curves:
            self.play(Write(tag))
            self.play(Create(curve), run_time=5)
            self.play(FadeOut(tag))
            self.wait(1)

        # Sweep the strength curves and the old axis label off to the left
        drawn = [curve for _, curve in strength_curves]
        self.play(FadeOut(y_label, *drawn, shift=LEFT, scale=0.5))

        # Swap in the error label and draw the error curve
        y_label = axes.get_y_axis_label("MS Error")
        self.add(y_label)
        self.play(Create(error_curve), run_time=5)
        self.wait(1)

                                                                                                                       

## Latent Inhibition Failure

In [19]:
update_gain = 0.05


def _run_phase(S, target, n_trials, gain, s_log, error_log):
    """Run n_trials with the CS present and a fixed outcome; return the new strength."""
    cues = np.array([[1]])   # CS present on every trial
    for _ in range(n_trials):
        prediction = np.dot(cues.T, S)
        S = S + gain * cues * (target - prediction)
        s_log.append(S.item())
        error_log.append(((target - prediction)**2).item())
    return S


def run_latent_inhibition(pre_expose=False, trials_pre=250, trials_acq=250, gain=None):
    """Simulate acquisition for one cue, optionally preceded by CS-alone pre-exposure.

    Parameters
    ----------
    pre_expose : bool
        If True, run ``trials_pre`` trials with the CS present and no US
        (target 0) before acquisition. If False, those trials are spent on
        acquisition instead, so both conditions log the same number of trials.
    trials_pre : int
        Number of pre-exposure trials (default 250).
    trials_acq : int
        Number of acquisition trials that follow pre-exposure (default 250).
    gain : float or None
        Learning rate; None means use the module-level ``update_gain``.

    Returns
    -------
    (s_log, error_log) : tuple of lists of float
        Associative strength after each trial and the squared prediction
        error of each trial, each of length ``trials_pre + trials_acq``.
    """
    if gain is None:
        gain = update_gain

    S = np.zeros((1,1))
    s_log = []
    error_log = []

    if pre_expose:
        # Phase 1: pre-exposure (CS alone, no US)
        S = _run_phase(S, 0, trials_pre, gain, s_log, error_log)
    else:
        # control group: the whole budget is acquisition
        trials_acq = trials_pre + trials_acq

    # Phase 2: acquisition (CS + US)
    _run_phase(S, 1, trials_acq, gain, s_log, error_log)

    return s_log, error_log

no_pre_s, no_pre_e = run_latent_inhibition(pre_expose=False)
with_pre_s, with_pre_e = run_latent_inhibition(pre_expose=True)


## Viz
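In real experiments, the pre-exposed group learns *slower* (latent inhibition). The Rescorla-Wagner model cannot reproduce this: during pre-exposure both the target and the prediction are 0, so the prediction error is 0 and the associative strength never changes. Acquisition then proceeds exactly as in the control group.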

In [20]:
%%manim -qm AnimateGraph

class AnimateGraph(Scene):
    """Animate the latent-inhibition comparison.

    Draws the control group's strength curve first, then marks the phase
    boundary and draws the pre-exposed group's curve on the same axes.
    Reads ``no_pre_s`` and ``with_pre_s`` from the notebook namespace, so
    the simulation cell has to be run before this one.
    """

    def construct(self):
        # Size the x axis from the logs being plotted rather than from a
        # variable left over from an earlier, unrelated cell.
        n_trials = min(len(no_pre_s), len(with_pre_s))
        # Trial index at which acquisition starts for the pre-exposed group;
        # must match the 250 pre-exposure trials used by run_latent_inhibition.
        phase_boundary = 250

        axes = Axes(
            x_range=[0, n_trials-1, 50],  # Min, Max, Step
            y_range=[0, 1.25, 0.25],
            axis_config={"include_numbers": True}
        ).scale(.8)
        x_label = axes.get_x_axis_label("t")
        y_label = axes.get_y_axis_label("S")

        self.add(axes, x_label, y_label)

        def cue1(x):
            # learned strength in the control group
            return no_pre_s[int(x)]
        def cue2(x):
            # learned strength in the pre-exposed group
            return with_pre_s[int(x)]

        graph1 = axes.plot(cue1, color=GRAY)
        graph2 = axes.plot(cue2, color=WHITE)

        # Control group
        title = Title("Control group (no pre-exposure)", include_underline=False)
        self.add(title)

        self.play(Create(graph1),run_time=5)
        self.wait(1)

        # Pre-exposed group: swap the title, mark where phase 2 begins
        self.play(FadeOut(title))
        title = Title("Pre-exposed group (latent inhibition group)", include_underline=False)

        # vertical marker from the x axis up to S = 1 at the phase boundary
        point = axes.c2p(phase_boundary, 1)
        phase2 = axes.get_vertical_line(point, line_func=Line, color=RED)
        text1 = Text("Phase 1: pre-exposure \n(CS alone, no US)").next_to(title, DOWN).scale(0.5).shift([-2,0,0])
        text2 = Text("Phase 2: acquisition \n(CS + US)").next_to(text1, RIGHT).scale(0.5).shift([-1,0,0])

        self.add(title)
        self.wait(1)
        self.play(Create(phase2))
        self.wait(1)
        self.play(Create(text1))
        self.wait(1)
        self.play(Create(text2))
        self.wait(1)
        self.play(Create(graph2),run_time=5)
        self.wait(1)

                                                                                                                       