In [8]:
from manim import *
import numpy as np

In [13]:
%%manim -ql FMRIVisualization


class FMRIVisualization(Scene):
    """Animated explainer of a go/stop fMRI block design and its GLM analysis.

    The scene plays in four phases:

    1. Live view: a "subject's screen" alternating Go/Stop, a brain slice with a
       task-related and a non-task voxel, and scrolling stimulus / voxel time
       series showing the most recent ``WINDOW_SIZE`` seconds of data.
    2. Full view: all ``TOTAL_DURATION`` seconds of the simulated data.
    3. Timing annotations: the stop interval (T_stop) and the sampling
       interval (T_R) on the task-related voxel.
    4. GLM: the HRF, the model equation, the stimulus convolved with the HRF,
       and a "model fitting" arrow.

    All data are simulated; the noise uses a fixed seed so re-renders match.
    """

    # --- Configuration (seconds unless stated otherwise) ---
    ANIMATION_DURATION = 15  # length of the live, scrolling phase
    TOTAL_DURATION = 40      # length of the simulated recording
    GO_DURATION = 1          # length of each "Go" block
    STOP_DURATION = 7        # length of each "Stop" block
    WINDOW_SIZE = 10         # seconds of data visible at once in the live view
    SAMPLE_DT = 0.1          # sampling interval of the simulated signals
    HRF_WINDOW = 8           # HRF support used for simulation, plotting and convolution
    RANDOM_SEED = 0          # seed for the simulated noise (reproducible renders)
    BRAIN_IMAGE = "assets/brain-sagittal.png"

    # Shared (min, max, step) y-ranges so axes and onset markers stay in sync
    # between the live and the full view.
    TRIGGER_Y_RANGE = (0, 1.5, 0.5)
    VOXEL1_Y_RANGE = (-0.2, 1.2, 0.2)
    VOXEL2_Y_RANGE = (-0.4, 0.6, 0.2)

    @staticmethod
    def _hrf(t, scale_factor=1.0):
        """Evaluate a compressed canonical hemodynamic response function.

        The response has an initial dip, a main positive peak and a
        post-stimulus undershoot. The canonical ~25 s time course is compressed
        so the main response fits in roughly 7 s.

        Parameters
        ----------
        t : float or numpy.ndarray
            Time(s) since stimulus onset, in seconds.
        scale_factor : float
            Overall amplitude multiplier.

        Returns
        -------
        float or numpy.ndarray
            A Python float for scalar input, otherwise an array shaped like ``t``.
        """
        compression = 7.0 / 25.0

        single_input = np.isscalar(t)
        t = np.asarray(t, dtype=float)

        # Initial dip: a narrow negative Gaussian, only before the main response.
        dip_center = 1.0 * compression
        dip_width = 0.8 * compression
        dip_amplitude = 0.25
        peak_start = 2.0 * compression
        dip = -dip_amplitude * np.exp(-((t - dip_center) ** 2) / (2 * dip_width ** 2))
        dip = np.where((t > 0) & (t < peak_start), dip, 0.0)

        # Main positive response. (t/p)**6 * exp(-(t-p)/d) is maximal at t = 6*d,
        # which equals peak_time here, where it evaluates to exactly 1, so no
        # further normalisation is needed.
        peak_time = 6.0 * compression
        peak_dispersion = 1.0 * compression
        peak_amplitude = 1.0
        positive = (
            peak_amplitude
            * (t / peak_time) ** 6
            * np.exp(-(t - peak_time) / peak_dispersion)
        )
        positive = np.where(t > 0.8 * peak_start, positive, 0.0)

        # Post-stimulus undershoot, switched on only after the peak.
        undershoot_time = 16.0 * compression
        undershoot_dispersion = 1.0 * compression
        undershoot_amplitude = 0.2
        negative = (
            undershoot_amplitude
            * (t / undershoot_time) ** 12
            * np.exp(-(t - undershoot_time) / undershoot_dispersion)
        )
        negative = np.where(t > peak_time, negative, 0.0)

        hrf_values = (dip + positive - negative) * scale_factor
        return hrf_values.item() if single_input else hrf_values

    def _simulate_signals(self):
        """Simulate the stimulus and both voxel time series on a uniform grid.

        Returns
        -------
        tuple of numpy.ndarray
            ``(time_points, trigger_values, trigger_times, voxel1_values,
            voxel2_values)``. ``trigger_times`` holds the onset time of each Go
            block; the other four arrays share one length.
        """
        dt = self.SAMPLE_DT
        # An integer sample count avoids the float-endpoint ambiguity of
        # np.arange(0, T + dt, dt); sample i lies at exactly i * dt.
        n_samples = int(round(self.TOTAL_DURATION / dt)) + 1
        time_points = np.arange(n_samples) * dt

        cycle = self.GO_DURATION + self.STOP_DURATION
        trigger_values = ((time_points % cycle) < self.GO_DURATION).astype(float)

        # Onsets are the samples where the stimulus switches from 0 to 1
        # (including t = 0 when the recording starts in a Go block).
        rising = np.diff(trigger_values, prepend=0.0) > 0
        trigger_times = time_points[rising]

        rng = np.random.default_rng(self.RANDOM_SEED)

        # Task-related voxel: linear superposition of one HRF per onset, each
        # truncated to HRF_WINDOW seconds, plus white noise.
        task_response = np.zeros(n_samples)
        for onset in trigger_times:
            since = time_points - onset
            active = (since > 0) & (since <= self.HRF_WINDOW)
            task_response[active] += self._hrf(since[active])
        voxel1_values = task_response + 0.05 * rng.standard_normal(n_samples)

        # Non-task voxel: baseline plus nuisance oscillations and white noise.
        slow_drift = 0.02 * np.sin(0.05 * time_points)
        cardiac = 0.05 * np.sin(2 * np.pi * 1.2 * time_points)      # 1.2 Hz ~ 72 beats/min
        respiratory = 0.07 * np.sin(2 * np.pi * 0.2 * time_points)  # 0.2 Hz ~ 12 breaths/min
        white_noise = 0.03 * rng.standard_normal(n_samples)
        voxel2_values = 0.1 + slow_drift + cardiac + respiratory + white_noise

        return time_points, trigger_values, trigger_times, voxel1_values, voxel2_values

    @staticmethod
    def _make_axes(x_range, y_range, x_length, y_length, center):
        """Create tip-less Axes of the given size and centre them at ``center``."""
        axes = Axes(
            x_range=list(x_range),
            y_range=list(y_range),
            axis_config={"include_tip": False},
            x_length=x_length,
            y_length=y_length,
        )
        axes.move_to(np.array(center, dtype=float))
        return axes

    @staticmethod
    def _onset_line(axes, x, y_range, **style):
        """Dashed vertical line at data-x ``x`` spanning ``y_range[0]..y_range[1]``."""
        return DashedLine(
            axes.c2p(x, y_range[0]),
            axes.c2p(x, y_range[1]),
            dash_length=0.05,
            **style,
        )

    def construct(self):
        """Build and play the whole scene (manim entry point)."""
        animation_duration = self.ANIMATION_DURATION
        total_duration = self.TOTAL_DURATION
        go_duration = self.GO_DURATION
        stop_duration = self.STOP_DURATION
        window_size = self.WINDOW_SIZE
        dt = self.SAMPLE_DT
        cycle = go_duration + stop_duration
        hrf = self._hrf

        # ==============================================================
        # Phase 1 - live view
        # ==============================================================

        # Stimulus display screen
        screen = Rectangle(height=2, width=3)
        screen.set_stroke(WHITE, 2)
        screen.move_to(np.array([-5, 2.5, 0]))
        self.add(screen)

        # Stimulus text; replaced every frame by the updater registered below
        stimulus_text = Text("stop", color=RED).scale(0.7)
        stimulus_text.move_to(screen.get_center())
        self.add(stimulus_text)

        screen_label = Text("Subject's Screen", font_size=20)
        screen_label.next_to(screen, UP, buff=0.2)
        self.add(screen_label)

        # Brain slice image, with a drawn placeholder when the asset cannot be
        # loaded (manim raises OSError for a missing/unreadable image file).
        try:
            brain_slice = ImageMobject(self.BRAIN_IMAGE)
        except OSError:
            brain_slice = Rectangle(height=3, width=3, color=GRAY, fill_opacity=0.5)
            mock_brain = Circle(radius=1.3, color=LIGHT_GRAY, fill_opacity=0.5)
            mock_brain.move_to(brain_slice.get_center())
            brain_slice.add(mock_brain)
        else:
            brain_slice.scale_to_fit_height(3)
        brain_slice.move_to(np.array([-5, -1, 0]))
        self.add(brain_slice)

        # Two voxels of interest on the slice
        voxel1 = Square(side_length=0.2, color=YELLOW, fill_opacity=0.5)
        voxel1.move_to(brain_slice.get_center() + RIGHT * 0.6 + UP * 0.3)
        self.add(voxel1)

        voxel2 = Square(side_length=0.2, color=GREEN, fill_opacity=0.5)
        voxel2.move_to(brain_slice.get_center() + LEFT * 0.7 + UP * 0.5)
        self.add(voxel2)

        # Windowed axes: stimulus (top), non-task voxel (middle), task voxel (bottom)
        live_x_range = (0, window_size, 5)
        trigger_axes = self._make_axes(live_x_range, self.TRIGGER_Y_RANGE, 6, 2, (2, 2.5, 0))
        self.add(trigger_axes)
        voxel2_axes = self._make_axes(live_x_range, self.VOXEL2_Y_RANGE, 6, 2, (2, 0, 0))
        self.add(voxel2_axes)
        voxel1_axes = self._make_axes(live_x_range, self.VOXEL1_Y_RANGE, 6, 2, (2, -2.5, 0))
        self.add(voxel1_axes)

        # Labels
        trigger_label = Text("Stimulus", font_size=20)
        trigger_label.next_to(trigger_axes, UP, buff=0.2)
        self.add(trigger_label)

        voxel1_label = Text("Task-related Voxel", font_size=20, color=YELLOW)
        voxel1_label.next_to(voxel1_axes, LEFT, buff=0.2)
        self.add(voxel1_label)

        voxel2_label = Text("Non-task Voxel", font_size=20, color=GREEN)
        voxel2_label.next_to(voxel2_axes, LEFT, buff=0.2)
        self.add(voxel2_label)

        # Dashed connectors from each voxel to its graph label
        connecting_line1 = DashedLine(
            voxel1.get_corner(UR), voxel1_label.get_center() + UP * 0.2, color=YELLOW
        )
        self.add(connecting_line1)

        connecting_line2 = DashedLine(
            voxel2.get_corner(UR), voxel2_label.get_center() + DOWN * 0.2, color=GREEN
        )
        self.add(connecting_line2)

        # Simulated data for the whole recording
        (time_points, trigger_values, trigger_times,
         voxel1_values, voxel2_values) = self._simulate_signals()
        n_samples = len(time_points)

        # Graphs start empty and are rebuilt every frame by updaters
        trigger_graph = VMobject(color=BLUE)
        voxel1_graph = VMobject(color=YELLOW)
        voxel2_graph = VMobject(color=GREEN)
        self.add(trigger_graph, voxel1_graph, voxel2_graph)

        trigger_lines_group = VGroup()
        self.add(trigger_lines_group)

        # Scene time (seconds of simulated data shown so far)
        time_tracker = ValueTracker(0)

        def window_graph(axes, values, color, t):
            """Plot ``values`` over the window [max(0, t - window_size), t]."""
            start_time = max(0.0, t - window_size)
            # ceil/floor with a small tolerance keep every plotted sample inside
            # the window despite float error; plain int() truncation could put
            # a point just left of the axes origin.
            start_idx = int(np.ceil(start_time / dt - 1e-9))
            end_idx = min(int(np.floor(t / dt + 1e-9)), n_samples - 1)
            if start_idx >= end_idx:
                return VMobject()
            visible = slice(start_idx, end_idx + 1)
            return axes.plot_line_graph(
                x_values=time_points[visible] - start_time,  # window-relative x
                y_values=values[visible],
                line_color=color,
                add_vertex_dots=False,
            )

        def window_trigger_lines(t):
            """Onset markers within the current window, clipped to each axes' y-range."""
            start_time = max(0.0, t - window_size)
            lines = VGroup()
            for onset in trigger_times:
                if start_time <= onset <= t:
                    x_pos = onset - start_time
                    lines.add(
                        self._onset_line(trigger_axes, x_pos, self.TRIGGER_Y_RANGE, color=WHITE),
                        self._onset_line(voxel2_axes, x_pos, self.VOXEL2_Y_RANGE, color=WHITE),
                        self._onset_line(voxel1_axes, x_pos, self.VOXEL1_Y_RANGE, color=WHITE),
                    )
            return lines

        # Pre-built stimulus texts; copied on each switch instead of re-rendered
        go_text = Text("Go", color=GREEN).scale(0.7).move_to(screen.get_center())
        stop_text = Text("Stop", color=RED).scale(0.7).move_to(screen.get_center())

        def graph_updater(axes, values, color):
            """Return an updater that redraws ``values`` for the current window."""
            def update(mob):
                mob.become(window_graph(axes, values, color, time_tracker.get_value()))
            return update

        def update_trigger_lines(mob):
            mob.become(window_trigger_lines(time_tracker.get_value()))

        def update_stimulus_text(mob):
            is_go = (time_tracker.get_value() % cycle) < go_duration
            mob.become((go_text if is_go else stop_text).copy())

        def update_voxel1_brightness(mob):
            """Brighten/whiten the task voxel with the current simulated signal."""
            idx = min(int(np.floor(time_tracker.get_value() / dt + 1e-9)), n_samples - 1)
            value = voxel1_values[max(idx, 0)]
            opacity = float(np.clip(0.5 + 0.5 * value, 0.0, 1.0))
            # Clamp to [0, 0.2]: negative values (dip/noise) would otherwise give
            # a negative interpolation alpha and extrapolate past YELLOW.
            whiten = float(np.clip(0.2 * value, 0.0, 0.2))
            mob.set_fill(interpolate_color(YELLOW, WHITE, whiten), opacity=opacity)

        trigger_graph.add_updater(graph_updater(trigger_axes, trigger_values, BLUE))
        voxel1_graph.add_updater(graph_updater(voxel1_axes, voxel1_values, YELLOW))
        voxel2_graph.add_updater(graph_updater(voxel2_axes, voxel2_values, GREEN))
        trigger_lines_group.add_updater(update_trigger_lines)
        stimulus_text.add_updater(update_stimulus_text)
        voxel1.add_updater(update_voxel1_brightness)

        # Explicitly set camera to show everything
        self.camera.frame_width = 16
        self.camera.frame_height = 9

        timeseries_creation_group = Group(
            brain_slice, voxel1, voxel2, screen, stimulus_text,
            screen_label, trigger_label, voxel1_label, voxel2_label,
            connecting_line1, connecting_line2, trigger_axes, voxel1_axes, voxel2_axes,
            trigger_graph, voxel1_graph, voxel2_graph, trigger_lines_group,
        )

        # Scroll through the first `animation_duration` seconds of data in real time
        self.play(
            time_tracker.animate.set_value(animation_duration),
            run_time=animation_duration,
            rate_func=linear,
        )

        # Freeze the live view at its final state
        for mob in (trigger_graph, voxel1_graph, voxel2_graph,
                    trigger_lines_group, stimulus_text, voxel1):
            mob.clear_updaters()

        self.wait(0.5)

        # ==============================================================
        # Phase 2 - full view of the whole recording
        # ==============================================================
        full_x_range = (0, total_duration, 10)
        full_trigger_axes = self._make_axes(full_x_range, self.TRIGGER_Y_RANGE, 10, 2, (0, 2.5, 0))
        full_voxel1_axes = self._make_axes(full_x_range, self.VOXEL1_Y_RANGE, 10, 2, (0, -2.5, 0))
        full_voxel2_axes = self._make_axes(full_x_range, self.VOXEL2_Y_RANGE, 10, 2, (0, 0, 0))

        def full_graph(axes, values, color):
            return axes.plot_line_graph(
                x_values=time_points,
                y_values=values,
                line_color=color,
                add_vertex_dots=False,
            )

        full_trigger_graph = full_graph(full_trigger_axes, trigger_values, BLUE)
        full_voxel1_graph = full_graph(full_voxel1_axes, voxel1_values, YELLOW)
        full_voxel2_graph = full_graph(full_voxel2_axes, voxel2_values, GREEN)

        # Onset markers across all three panels
        full_trigger_lines = VGroup()
        for onset in trigger_times:
            if onset <= total_duration:
                full_trigger_lines.add(
                    self._onset_line(full_trigger_axes, onset, self.TRIGGER_Y_RANGE, color=WHITE),
                    self._onset_line(full_voxel1_axes, onset, self.VOXEL1_Y_RANGE, color=WHITE),
                    self._onset_line(full_voxel2_axes, onset, self.VOXEL2_Y_RANGE, color=WHITE),
                )

        full_trigger_label = Text("Stimulus", font_size=20)
        full_trigger_label.next_to(full_trigger_axes, UP, buff=0.2)

        full_voxel1_label = Text("Task-related Voxel", font_size=20, color=YELLOW)
        full_voxel1_label.next_to(full_voxel1_axes, UP, buff=0.2)

        full_voxel2_label = Text("Non-task Voxel", font_size=20, color=GREEN)
        full_voxel2_label.next_to(full_voxel2_axes, UP, buff=0.2)

        full_view_group = VGroup(
            full_trigger_axes, full_voxel1_axes, full_voxel2_axes,
            full_trigger_graph, full_voxel1_graph, full_voxel2_graph,
            full_trigger_lines,
            full_trigger_label, full_voxel1_label, full_voxel2_label,
        )

        # Transition from the live view to the full view
        if isinstance(brain_slice, ImageMobject):
            # The image (and voxels) are faded via opacity; the rest via FadeOut
            self.play(
                brain_slice.animate.set_opacity(0),
                voxel1.animate.set_opacity(0),
                voxel2.animate.set_opacity(0),
                FadeOut(screen, stimulus_text, screen_label, trigger_label,
                        voxel1_label, voxel2_label, connecting_line1, connecting_line2,
                        trigger_axes, voxel1_axes, voxel2_axes,
                        trigger_graph, voxel1_graph, voxel2_graph,
                        trigger_lines_group),
                FadeIn(full_view_group),
                run_time=2,
            )
        else:
            # Placeholder brain (Rectangle + Circle)
            self.play(
                FadeOut(brain_slice, voxel1, voxel2, screen, stimulus_text,
                        screen_label, trigger_label, voxel1_label, voxel2_label,
                        connecting_line1, connecting_line2, trigger_axes, voxel1_axes, voxel2_axes,
                        trigger_graph, voxel1_graph, voxel2_graph, trigger_lines_group),
                FadeIn(full_view_group),
                run_time=2,
            )

        self.remove(timeseries_creation_group)

        self.wait(3)

        # ==============================================================
        # Phase 3 - timing annotations (T_stop, T_R)
        # ==============================================================

        # Drop the non-task panel and move the remaining panels left. Each graph
        # is shifted by the same vector as its axes: moving both to the same
        # centre would misalign them (their bounding boxes differ), and every
        # later overlay is positioned via the axes' c2p.
        trigger_shift = np.array([-2.5, 1.5, 0]) - full_trigger_axes.get_center()
        voxel1_shift = np.array([-2.5, -1.5, 0]) - full_voxel1_axes.get_center()
        self.play(
            FadeOut(full_voxel2_graph, full_voxel2_axes, full_voxel2_label,
                    full_trigger_label, full_voxel1_label),
            FadeOut(full_trigger_lines),
            full_trigger_axes.animate.shift(trigger_shift),
            full_trigger_graph.animate.shift(trigger_shift),
            full_voxel1_axes.animate.shift(voxel1_shift),
            full_voxel1_graph.animate.shift(voxel1_shift),
            run_time=2,
        )

        # Brace and label for the stopping time (T_stop) after the third onset.
        # NOTE(review): the brace ends 1 s before the next onset
        # (stop_duration - 1); confirm that gap is intentional.
        stopping_time_start = trigger_times[2] + go_duration
        stopping_time_end = stopping_time_start + stop_duration - 1
        stop_brace = BraceBetweenPoints(
            full_trigger_axes.c2p(stopping_time_start, 0),
            full_trigger_axes.c2p(stopping_time_end, 0),
            direction=DOWN,
        )
        stop_label = MathTex("T_{stop}").next_to(stop_brace, DOWN, buff=0.1)

        # One sample dot + faint guide line per TR on the task-voxel graph
        tr_value = 1
        tr_sample_points = []
        tr_sample_lines = VGroup()
        for x_coord in np.arange(0, total_duration, tr_value):
            y_coord = np.interp(x_coord, time_points, voxel1_values)
            tr_sample_points.append(
                Dot(full_voxel1_axes.c2p(x_coord, y_coord), color=RED, radius=0.05)
            )
            tr_sample_lines.add(
                self._onset_line(full_voxel1_axes, x_coord, self.VOXEL1_Y_RANGE,
                                 color=GREY, stroke_opacity=0.5)
            )
        tr_sample_points_group = VGroup(*tr_sample_points)

        # Brace and label for one TR (Repetition Time)
        tr_brace = BraceBetweenPoints(
            full_voxel1_axes.c2p(8, 0),
            full_voxel1_axes.c2p(8 + tr_value, 0),
            direction=DOWN,
        )
        tr_label = MathTex("T_R").next_to(tr_brace, DOWN, buff=0.1)

        # Explanatory text
        tstop_explanation_text = VGroup(
            Tex(r"Stopping Time ($T_{\text{stop}}$): Long time to account for",
                font_size=24),
            Tex(r"the time it takes the hemodynamic response",
                font_size=24),
            Tex(r"to finish (around 18 seconds).",
                font_size=24),
        ).arrange(DOWN, aligned_edge=LEFT, buff=0.2)
        tstop_explanation_text.move_to(np.array([-1, 3, 0]))

        tr_explanation_text = VGroup(
            Tex(r"Repetition Time ($T_R$): The time between consecutive RF pulses",
                font_size=24),
            Tex(r"that determines the temporal resolution of the fMRI scan.",
                font_size=24),
        ).arrange(DOWN, aligned_edge=LEFT, buff=0.2)
        tr_explanation_text.move_to(np.array([-1, -3.5, 0]))

        self.play(
            FadeIn(tr_sample_lines, lag_ratio=0.1),
            FadeIn(tr_sample_points_group, lag_ratio=0.1),
        )
        self.play(GrowFromCenter(stop_brace), Write(stop_label))
        self.play(GrowFromCenter(tr_brace), Write(tr_label))
        self.play(Write(tstop_explanation_text), Write(tr_explanation_text))

        self.wait(4)

        self.play(FadeOut(tr_explanation_text, tstop_explanation_text, stop_brace,
                          stop_label, tr_brace, tr_label, tr_sample_lines,
                          tr_sample_points_group))

        # ==============================================================
        # Phase 4 - GLM analysis
        # ==============================================================
        glm_title = Text("General Linear Model (GLM) Analysis", font_size=32)
        glm_title.to_edge(UP, buff=0.3)
        self.play(Write(glm_title), run_time=1)

        # HRF panel on the right
        hrf_time = np.linspace(0, self.HRF_WINDOW, 80)
        hrf_values = hrf(hrf_time)

        hrf_axes = self._make_axes((0, self.HRF_WINDOW, 2), (-0.3, 1.1, 0.2), 4, 3, (5.5, 1, 0))
        hrf_graph = hrf_axes.plot_line_graph(
            x_values=hrf_time,
            y_values=hrf_values,
            line_color=PURPLE,
            add_vertex_dots=False,
        )

        hrf_title = Text("Hemodynamic Response Function", font_size=20, color=PURPLE)
        hrf_title.next_to(hrf_axes, UP, buff=0.2)

        hrf_group = VGroup(hrf_axes, hrf_graph, hrf_title)

        # GLM equation and parameter explanations
        glm_equation = MathTex(
            r"Y(t) = \beta_0 + \beta_1 X(t) + \varepsilon",
            font_size=36,
        )
        glm_equation.move_to(np.array([5.5, -1.5, 0]))

        y_exp = Text("Y(t): BOLD signal (measured)", font_size=16, color=YELLOW)
        y_exp.next_to(glm_equation, DOWN, buff=0.4, aligned_edge=LEFT)

        x_exp = Text("X(t): Model prediction (stimulus ⊗ HRF)", font_size=16, color=RED)
        x_exp.next_to(y_exp, DOWN, buff=0.1, aligned_edge=LEFT)

        beta0_exp = Text("β₀: Baseline signal", font_size=16)
        beta0_exp.next_to(x_exp, DOWN, buff=0.1, aligned_edge=LEFT)

        beta1_exp = Text("β₁: Neural activity coefficient", font_size=16, color=BLUE)
        beta1_exp.next_to(beta0_exp, DOWN, buff=0.1, aligned_edge=LEFT)

        epsilon_exp = Text("ε: Noise", font_size=16, color=GREEN)
        epsilon_exp.next_to(beta1_exp, DOWN, buff=0.1, aligned_edge=LEFT)

        param_explanations = VGroup(y_exp, x_exp, beta0_exp, beta1_exp, epsilon_exp)

        self.play(
            FadeIn(hrf_group),
            Write(glm_equation),
            FadeIn(param_explanations),
            run_time=2,
        )

        self.wait(1)

        # Predicted BOLD = stimulus boxcar convolved with the HRF, on a finer grid
        high_res_time = np.linspace(0, total_duration, 400)
        high_res_stim = ((high_res_time % cycle) < go_duration).astype(float)
        conv_dt = high_res_time[1] - high_res_time[0]
        hrf_kernel = hrf(np.arange(0, self.HRF_WINDOW, conv_dt))
        # Scale by dt to approximate the continuous convolution integral, and
        # keep the causal part aligned with the stimulus.
        predicted_bold = (np.convolve(high_res_stim, hrf_kernel) * conv_dt)[:len(high_res_stim)]

        conv_graph = full_trigger_axes.plot_line_graph(
            x_values=high_res_time,
            y_values=predicted_bold,
            line_color=RED,
            add_vertex_dots=False,
        )

        conv_text = Text("Convolution: Stimulus ⊗ HRF", font_size=20, color=RED)
        conv_text.next_to(full_trigger_axes, UP, buff=0.2)

        # Semi-transparent onset markers on the two remaining panels
        new_trigger_lines = VGroup()
        for onset in trigger_times:
            if onset <= total_duration:
                new_trigger_lines.add(
                    self._onset_line(full_trigger_axes, onset, self.TRIGGER_Y_RANGE,
                                     color=WHITE, stroke_opacity=0.7),
                    self._onset_line(full_voxel1_axes, onset, self.VOXEL1_Y_RANGE,
                                     color=WHITE, stroke_opacity=0.7),
                )

        self.play(
            FadeIn(new_trigger_lines),
            Write(conv_text),
            run_time=1,
        )

        self.play(
            Create(conv_graph),
            run_time=2,
        )

        self.wait(1)

        # Model fitting: arrow from the regressor panel to the data panel
        fit_arrow = Arrow(
            full_trigger_axes.get_bottom() + DOWN * 0.3,
            full_voxel1_axes.get_top() + UP * 0.3,
            buff=0.1,
            color=BLUE,
            max_tip_length_to_length_ratio=0.25,
            stroke_width=3.5,
        )

        fit_text = Text("Model Fitting:\nEstimate β parameters to\nminimize error",
                        font_size=18, color=BLUE)
        fit_text.next_to(fit_arrow, RIGHT, buff=0.2)

        self.play(
            GrowArrow(fit_arrow),
            Write(fit_text),
            run_time=2.5,
        )

        # Briefly highlight the equation whose β parameters are being fitted
        highlight_beta = SurroundingRectangle(glm_equation, color=BLUE_B, buff=0.1)
        self.play(
            Create(highlight_beta),
            run_time=1.5,
        )

        self.play(
            FadeOut(highlight_beta),
            run_time=0.5,
        )

        self.wait(2)



                                                                                                                                                

                                                                                                                                                

                                                                                                                                                

                                                                                                                                                

                                                                                                                                                

                                                                                                                                                

                                                                                                                                                

                                                                                                                                                

                                                                                                                                                

                                                                                                                                                

                                                                                                                                                

                                                                                                                                                

                                                                                                                                                

                                                                                                                                                

                                                                                                                                                