## Imports

In [21]:
from manim import Matrix
from manim import *

## Configuration

In [None]:
# Colour theme: choose the background, then a contrasting text colour.
background = WHITE
if background == WHITE:
    font_color = BLACK  # dark text on a white background
else:
    font_color = WHITE  # light text on any other background

config.background_color = background

# Every text-like mobject defaults to the contrasting font colour.
for _text_class in (Text, Tex, MathTex):
    _text_class.set_default(color=font_color)
del _text_class

## Helper functions

In [23]:
def row_n_dots(n):
    r"""Return one matrix row made of ``n`` LaTeX ``\bullet`` entries."""
    return [r"\bullet"] * n


def matrix_elems_n_by_m(n, m):
    r"""Return an ``n`` x ``m`` nested list of ``\bullet`` entries.

    Each row is a fresh list, so mutating one row never affects another.
    """
    return [row_n_dots(m) for _ in range(n)]


def matrix_n_by_m(n, m):
    r"""Build a manim ``Matrix`` with ``n`` rows and ``m`` columns of ``\bullet`` dots.

    The matrix is created with ``element_alignment_corner=DOWN``.
    """
    bullet_rows = matrix_elems_n_by_m(n, m)
    return Matrix(bullet_rows, element_alignment_corner=DOWN)


def diag_matrix_n(n):
    r"""Build an ``n`` x ``n`` manim ``Matrix`` with ``\bullet`` only on the main diagonal.

    Off-diagonal cells are empty strings. The matrix is created with
    ``element_alignment_corner=DOWN``, like ``matrix_n_by_m``.
    """
    cells = [[""] * n for _ in range(n)]
    for k in range(n):
        cells[k][k] = r"\bullet"
    return Matrix(cells, element_alignment_corner=DOWN)


def color_matrix(matrix, color):
    """Recolour every entry of ``matrix`` with ``color``.

    Only the objects returned by ``matrix.get_entries()`` are touched; any
    other submobjects attached to the matrix keep their colour.

    Returns:
        The same ``matrix`` object, so the call can be chained.
    """
    entries = matrix.get_entries()
    for mob in entries:
        mob.set_color(color)
    return matrix

In [24]:
def pick_rows(flat_mat, n, m, row_indices):
    """
    Select whole rows from a flattened N x M matrix (row-major order).

    Parameters:
        flat_mat (Sequence[any]): Flattened matrix of length n*m. Anything that
            supports ``len`` and slicing works (e.g. a list, or the group
            returned by ``Matrix.get_entries()``).
        n (int): Original number of rows.
        m (int): Original number of columns.
        row_indices (Iterable[int]): Row indices to keep, in output order.

    Returns:
        List[any]: Entries of the selected rows, concatenated row by row
        (length ``len(row_indices) * m``).

    Raises:
        ValueError: If ``len(flat_mat) != n * m`` or a row index is out of range.
        TypeError: If a row index is not an integer.
    """
    # Materialise once so a generator is not exhausted by the validation passes.
    row_indices = list(row_indices)

    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if n * m != len(flat_mat):
        raise ValueError("Invalid matrix dimensions")
    # Type check before the range check so non-integers get the intended message.
    if not all(isinstance(r, int) for r in row_indices):
        raise TypeError("Row index must be an integer")
    if not all(0 <= r < n for r in row_indices):
        raise ValueError("Invalid row index")

    # Row r occupies the contiguous slice [r*m, (r+1)*m) in row-major order.
    return [entry for r in row_indices for entry in flat_mat[r * m:(r + 1) * m]]


def pick_columns(flat_mat, n, m, column_indices):
    """
    Select only certain columns from a flattened N x M matrix (row-major order).

    Parameters:
        flat_mat (Sequence[any]): Flattened matrix of length n*m. Anything that
            supports ``len`` and integer indexing works (e.g. a list, or the
            group returned by ``Matrix.get_entries()``).
        n (int): Original number of rows.
        m (int): Original number of columns.
        column_indices (Iterable[int]): Which column indices to keep, in output order.

    Returns:
        List[any]: Flattened (row-major) matrix containing only the selected
        columns. It has ``n`` rows of ``len(column_indices)`` entries each.
        Only the list is returned, not the new dimensions.

    Raises:
        ValueError: If ``len(flat_mat) != n * m`` or a column index is out of range.
        TypeError: If a column index is not an integer.
    """
    # Materialise once so a generator is not exhausted by the validation passes.
    column_indices = list(column_indices)

    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if n * m != len(flat_mat):
        raise ValueError("Invalid matrix dimensions")
    # Type check before the range check so non-integers get the intended message.
    if not all(isinstance(c, int) for c in column_indices):
        raise TypeError("Column index must be an integer")
    if not all(0 <= c < m for c in column_indices):
        raise ValueError("Invalid column index")

    # Walk row by row; entry (row, c) lives at row*m + c in row-major order.
    return [flat_mat[row * m + c] for row in range(n) for c in column_indices]

## Data and configs

In [33]:
# Palette for the animation: display name -> hex colour string.
color_dict = {
    "Eggshell": "#f4f1de",
    "Burnt sienna": "#e07a5f",
    "Delft Blue": "#3d405b",
    "Cambridge blue": "#81b29a",
    "Sunset": "#f2cc8f",
}

# Rebind GREEN / RED / YELLOW (otherwise the names pulled in by
# `from manim import *`) so later cells draw with the palette colours.
GREEN, RED, YELLOW = (
    color_dict[name] for name in ("Cambridge blue", "Burnt sienna", "Sunset")
)

In [34]:
%%manim -qh -v WARNING Method 

class Method(Scene):
    """Animate the per-layer task-matrix merging method in six captioned steps.

    1. Show the per-layer task matrices Delta_i and Delta_j.
    2. Write their SVDs U Sigma V^T.
    3. Drop the last singular triplet of each (low-rank approximation).
    4. Concatenate the kept singular vectors/values into U_ij, Sigma_ij, V_ij^T.
    5. Relabel the concatenated singular vectors as orthogonal.
    6. Write the resulting new task matrix Delta_ij.

    U and V^T are drawn schematically: their entries are hidden and every
    singular vector is drawn as a coloured line (columns of U, rows of V^T).
    Task i is drawn in RED and task j in GREEN.
    """

    @staticmethod
    def _step_text(tex_string, highlight):
        """Return a bottom-of-frame caption with the substring ``highlight`` in YELLOW."""
        return (
            Tex(tex_string, tex_to_color_map={highlight: YELLOW})
            .to_edge(DOWN)
            .shift(UP * 0.5)
            .scale(0.8)
        )

    @staticmethod
    def _add_column_lines(matrix, colors):
        """Overlay one vertical line per column of ``matrix``, coloured by ``colors``.

        Each line is added as a submobject of ``matrix`` (so it moves and scales
        with it) and the lines are returned in column order for later animation.
        """
        lines = []
        for column, color in zip(matrix.get_columns(), colors):
            line = Line(column.get_top(), column.get_bottom()).set_color(color)
            matrix.add(line)
            lines.append(line)
        return lines

    @staticmethod
    def _add_row_lines(matrix, colors):
        """Overlay one horizontal line per row of ``matrix``, coloured by ``colors``.

        Each line is added as a submobject of ``matrix`` (so it moves and scales
        with it) and the lines are returned in row order for later animation.
        """
        lines = []
        for row, color in zip(matrix.get_rows(), colors):
            line = Line(row.get_left(), row.get_right()).set_color(color)
            matrix.add(line)
            lines.append(line)
        return lines

    @staticmethod
    def _hide_entries(matrix):
        """Set every entry of ``matrix`` to opacity 0 (other submobjects untouched)."""
        for entry in matrix.get_entries():
            entry.set_opacity(0)

    def construct(self):
        # Bottom-of-frame captions; each step replaces the previous one.
        step_1_text = self._step_text(
            r"1. Consider \textbf{per-layer task matrices} for task $i$ and $j$",
            "per-layer task matrices",
        )
        step_2_text = self._step_text(r"2. Compute their \textbf{SVD}", "SVD")
        step_3_text = self._step_text(
            r"3. Perform a \textbf{low-rank} approximation", "low-rank"
        )
        step_4_text = self._step_text(
            r"4. Concatenate the \textbf{task singular vectors} and values",
            "task singular vectors",
        )
        step_5_text = self._step_text(
            r"5. Make the task singular vectors \textbf{orthogonal}", "orthogonal"
        )
        step_6_text = self._step_text(
            r"6. Compute the \textbf{new} per-layer task matrix", "new"
        )

        # -----------------------------------------------------
        # Task i: Delta_i = U_i Sigma_i V_i^T (drawn in RED)
        # -----------------------------------------------------
        D_i = matrix_n_by_m(4, 3)
        D_i_label = MathTex(r"\Delta_i").next_to(D_i, DOWN).scale(1.5)

        # "=" is written in step 2 and morphed into "\approx" in step 3.
        equal_sign_i = MathTex("=").next_to(D_i, RIGHT, buff=0.5)
        approx_sign_i = MathTex(r"\approx").next_to(D_i, RIGHT, buff=0.5)

        # U_i (4x3): entries hidden, each column drawn as a vertical line.
        U_i = matrix_n_by_m(4, 3).next_to(equal_sign_i, RIGHT, buff=0.5)
        lines_U_i = self._add_column_lines(U_i, [RED] * 3)
        U_i_label = MathTex(r"U_i").next_to(U_i, DOWN).scale(1.5)
        self._hide_entries(U_i)

        # Sigma_i (3x3): bullets on the diagonal only.
        Sigma_i = diag_matrix_n(3).next_to(U_i, RIGHT, buff=1.0)
        Sigma_i_label = MathTex(r"\Sigma_i").next_to(Sigma_i, DOWN).scale(1.5)

        # V_i^T (3x3): entries hidden, each row drawn as a horizontal line.
        V_i_t = matrix_n_by_m(3, 3).next_to(Sigma_i, RIGHT, buff=1.0)
        lines_V_i_t = self._add_row_lines(V_i_t, [RED] * 3)
        V_i_t_label = MathTex(r"V_i^T").next_to(V_i_t, DOWN).scale(1.5)
        self._hide_entries(V_i_t)

        # Shrink the task-i row and place it in the upper half of the frame.
        group_i = VGroup(
            D_i, D_i_label,
            equal_sign_i,
            U_i, U_i_label,
            Sigma_i, Sigma_i_label,
            V_i_t, V_i_t_label,
            approx_sign_i,
        )
        group_i.scale(0.5)
        group_i.move_to(ORIGIN).shift(UP * 2)

        # -----------------------------------------------------
        # Task j: Delta_j = U_j Sigma_j V_j^T (drawn in GREEN), below task i
        # -----------------------------------------------------
        D_j = matrix_n_by_m(4, 3)
        D_j_label = MathTex(r"\Delta_j").next_to(D_j, DOWN).scale(1.5)

        equal_sign_j = MathTex("=").next_to(D_j, RIGHT, buff=0.5)
        approx_sign_j = MathTex(r"\approx").next_to(D_j, RIGHT, buff=0.5)

        U_j = matrix_n_by_m(4, 3).next_to(equal_sign_j, RIGHT, buff=0.5)
        lines_U_j = self._add_column_lines(U_j, [GREEN] * 3)
        U_j_label = MathTex(r"U_j").next_to(U_j, DOWN).scale(1.5)
        self._hide_entries(U_j)

        Sigma_j = diag_matrix_n(3).next_to(U_j, RIGHT, buff=1.0)
        Sigma_j_label = MathTex(r"\Sigma_j").next_to(Sigma_j, DOWN).scale(1.5)

        V_j_t = matrix_n_by_m(3, 3).next_to(Sigma_j, RIGHT, buff=1.0)
        lines_V_j_t = self._add_row_lines(V_j_t, [GREEN] * 3)
        V_j_t_label = MathTex(r"V_j^T").next_to(V_j_t, DOWN).scale(1.5)
        self._hide_entries(V_j_t)

        group_j = VGroup(
            D_j, D_j_label,
            equal_sign_j,
            U_j, U_j_label,
            Sigma_j, Sigma_j_label,
            V_j_t, V_j_t_label,
            approx_sign_j,
        )
        group_j.scale(0.5)
        group_j.next_to(group_i, DOWN, buff=1.0)

        # Entry colours: task i in RED, task j in GREEN (lines were coloured above).
        for matrix in (D_i, U_i, Sigma_i, V_i_t):
            color_matrix(matrix, RED)
        for matrix in (D_j, U_j, Sigma_j, V_j_t):
            color_matrix(matrix, GREEN)

        # Step 1: only the two task matrices.
        self.add(D_i, D_i_label)
        self.add(D_j, D_j_label)
        self.add(step_1_text)
        self.wait(2)

        # Step 2: write both SVD factorisations.
        self.play(ReplacementTransform(step_1_text, step_2_text))
        self.wait(1)

        self.play(
            Write(equal_sign_i),
            Write(U_i), Write(U_i_label),
            Write(Sigma_i), Write(Sigma_i_label),
            Write(V_i_t), Write(V_i_t_label),
            Write(equal_sign_j),
            Write(U_j), Write(U_j_label),
            Write(Sigma_j), Write(Sigma_j_label),
            Write(V_j_t), Write(V_j_t_label),
        )
        self.wait(4)

        # -----------------------------------------------------------
        # Step 3: highlight and drop the last singular triplet of each task
        # (3rd column of U, 3rd diagonal entry of Sigma, 3rd row of V^T).
        # -----------------------------------------------------------
        u_i_third_col = VGroup(*pick_columns(U_i.get_entries(), 4, 3, [2]))
        sigma_i_third_diag = Sigma_i.get_entries()[2 * 3 + 2]
        v_i_t_third_row = VGroup(*pick_rows(V_i_t.get_entries(), 3, 3, [2]))

        u_j_third_col = VGroup(*pick_columns(U_j.get_entries(), 4, 3, [2]))
        sigma_j_third_diag = Sigma_j.get_entries()[2 * 3 + 2]
        v_j_t_third_row = VGroup(*pick_rows(V_j_t.get_entries(), 3, 3, [2]))

        dropped = [
            u_i_third_col, lines_U_i[2], sigma_i_third_diag,
            v_i_t_third_row, lines_V_i_t[2],
            u_j_third_col, lines_U_j[2], sigma_j_third_diag,
            v_j_t_third_row, lines_V_j_t[2],
        ]

        self.play(ReplacementTransform(step_2_text, step_3_text))

        highlight_color = YELLOW
        self.play(
            *(mob.animate.set_color(highlight_color) for mob in dropped),
            run_time=0.5,
        )

        # Truncation turns "=" into "\approx".
        self.play(
            *(FadeOut(mob) for mob in dropped),
            ReplacementTransform(equal_sign_i, approx_sign_i),
            ReplacementTransform(equal_sign_j, approx_sign_j),
            run_time=0.5,
        )
        self.wait(2)

        ##############################################################################
        # Step 4: concatenate the remaining rank-2 parts of tasks i and j
        ##############################################################################

        # U_ij: 4 x (2+2); Sigma_ij: 4 x 4 block diagonal; V_ij^T: (2+2) x 3.
        U_ij = matrix_n_by_m(4, 4)
        Sigma_ij = diag_matrix_n(4)
        V_ij_t = matrix_n_by_m(4, 3)

        # The first two singular vectors come from task i (RED), the last two from j (GREEN).
        # The U_ij lines must be attached to U_ij itself: attaching them to the
        # on-screen U_i would make them visible before the concatenation step.
        lines_U_ij = self._add_column_lines(U_ij, [RED, RED, GREEN, GREEN])
        lines_V_ij_t = self._add_row_lines(V_ij_t, [RED, RED, GREEN, GREEN])

        for matrix in (U_ij, Sigma_ij, V_ij_t):
            self._hide_entries(matrix)

        U_ij_label = MathTex(r"U_{ij}").scale(1.5)
        Sigma_ij_label = MathTex(r"\Sigma_{ij}").scale(1.5)
        V_ij_t_label = MathTex(r"V_{ij}^T").scale(1.5)

        D_ij_equal_sign = MathTex("=").next_to(U_ij, LEFT)

        # The new task matrix Delta_ij, written in step 6.
        D_ij = matrix_n_by_m(4, 3).next_to(D_ij_equal_sign, LEFT)
        D_ij_label = MathTex(r"\Delta_{ij}").next_to(D_ij, 2 * DOWN).scale(1.5)

        group_ij = VGroup(
            U_ij, *lines_U_ij, Sigma_ij, V_ij_t, *lines_V_ij_t,
            U_ij_label, Sigma_ij_label, V_ij_t_label,
            D_ij, D_ij_label, D_ij_equal_sign,
        )
        group_ij.scale(0.5)
        group_ij.move_to(ORIGIN).shift(LEFT * 2)

        # Lay out the factors left to right, with labels underneath.
        Sigma_ij.next_to(U_ij, RIGHT)
        V_ij_t.next_to(Sigma_ij, RIGHT)

        U_ij_label.next_to(U_ij, DOWN)
        Sigma_ij_label.next_to(Sigma_ij, DOWN)
        V_ij_t_label.next_to(V_ij_t, DOWN)

        # Colour entries by origin: block [0, 1] from task i (RED), [2, 3] from task j (GREEN).
        u_ij_entries = U_ij.get_entries()
        sigma_ij_entries = Sigma_ij.get_entries()
        v_ij_t_entries = V_ij_t.get_entries()
        for color, block in ((RED, [0, 1]), (GREEN, [2, 3])):
            for entry in pick_columns(u_ij_entries, 4, 4, block):
                entry.set_color(color)
            for entry in pick_rows(v_ij_t_entries, 4, 3, block):
                entry.set_color(color)
            for r in block:
                for c in block:
                    sigma_ij_entries[r * 4 + c].set_color(color)

        # Move the kept sub-blocks into the concatenated matrices.
        transforms = []

        # U_ij columns 0-1 <- U_i columns 0-1; U_ij columns 2-3 <- U_j columns 0-1.
        u_sources = ((U_i, 0, 0), (U_i, 1, 1), (U_j, 0, 2), (U_j, 1, 3))
        for row in range(4):
            for source, src_col, dst_col in u_sources:
                transforms.append(
                    ReplacementTransform(
                        source.get_entries()[row * 3 + src_col],
                        u_ij_entries[row * 4 + dst_col],
                    )
                )

        # Singular-vector lines: task i -> slots 0-1, task j -> slots 2-3.
        for k in range(2):
            for source_line, target_line in (
                (lines_U_i[k], lines_U_ij[k]),
                (lines_U_j[k], lines_U_ij[k + 2]),
                (lines_V_i_t[k], lines_V_ij_t[k]),
                (lines_V_j_t[k], lines_V_ij_t[k + 2]),
            ):
                transforms.append(ReplacementTransform(source_line, target_line))
                transforms.append(target_line.animate.set_opacity(1))

        # Sigma_ij: task i's 2x2 block -> top-left, task j's 2x2 block -> bottom-right.
        # Sigma_ij entries were hidden above, so each target is also faded back in.
        for source, offset in ((Sigma_i, 0), (Sigma_j, 2)):
            for r in (0, 1):
                for c in (0, 1):
                    target = sigma_ij_entries[(r + offset) * 4 + (c + offset)]
                    transforms.append(
                        ReplacementTransform(source.get_entries()[r * 3 + c], target)
                    )
                    transforms.append(target.animate.set_opacity(1))

        # V_ij^T rows 0-1 <- V_i^T rows 0-1; V_ij^T rows 2-3 <- V_j^T rows 0-1.
        for source, offset in ((V_i_t, 0), (V_j_t, 2)):
            for row in (0, 1):
                for col in range(3):
                    transforms.append(
                        ReplacementTransform(
                            source.get_entries()[row * 3 + col],
                            v_ij_t_entries[(row + offset) * 3 + col],
                        )
                    )

        # Everything of the two original systems that is not transformed goes away.
        to_fade = VGroup(
            D_i, D_i_label,
            U_i.brackets, U_i_label,
            Sigma_i.brackets, Sigma_i_label,
            V_i_t.brackets, V_i_t_label,
            D_j, D_j_label,
            U_j.brackets, U_j_label,
            Sigma_j.brackets, Sigma_j_label,
            V_j_t.brackets, V_j_t_label,
            approx_sign_i,
            approx_sign_j,
        )

        to_fade_in = VGroup(
            U_ij, U_ij_label,
            Sigma_ij, Sigma_ij_label,
            V_ij_t, V_ij_t_label,
            *lines_U_ij, *lines_V_ij_t,
        )

        self.play(ReplacementTransform(step_3_text, step_4_text))
        self.wait(2)

        self.play(FadeIn(*to_fade_in), *transforms, FadeOut(*to_fade), run_time=3)
        self.wait(2)

        # Step 5: relabel the singular vectors as orthogonal. Sigma_ij keeps its
        # label because only the singular vectors are orthogonalised.
        # The extra 0.5 matches the scale already applied to group_ij.
        U_ij_ortho_label = (
            MathTex(r"U_{ij}^{\perp}").scale(1.5).scale(0.5).next_to(U_ij, DOWN)
        )
        V_ij_t_ortho_label = (
            MathTex(r"(V_{ij}^T)^{\perp}").scale(1.5).scale(0.5).next_to(V_ij_t, DOWN)
        )

        self.play(ReplacementTransform(step_4_text, step_5_text))
        self.play(
            ReplacementTransform(U_ij_label, U_ij_ortho_label),
            ReplacementTransform(V_ij_t_label, V_ij_t_ortho_label),
        )

        # The old labels were replaced on screen. Keep group_ij in sync so that
        # the final fade-out handles the visible labels and does not re-add the stale ones.
        group_ij.remove(U_ij_label, V_ij_t_label)
        group_ij.add(U_ij_ortho_label, V_ij_t_ortho_label)

        # Indicate the labels that are actually on screen now.
        self.play(
            Indicate(U_ij_ortho_label, scale_factor=1.5, color=YELLOW),
            Indicate(V_ij_t_ortho_label, scale_factor=1.5, color=YELLOW),
        )

        self.wait(2)

        # Step 6: write the new per-layer task matrix Delta_ij in YELLOW.
        self.play(ReplacementTransform(step_5_text, step_6_text))

        color_matrix(D_ij, YELLOW)

        self.play(
            Write(D_ij), Write(D_ij_label),
            Write(D_ij_equal_sign),
        )

        self.play(Indicate(D_ij_label, scale_factor=1.5, color=YELLOW))

        self.wait(2)

        # group_ij now holds D_ij, its label, the "=" sign and the orthogonal labels.
        self.play(FadeOut(group_ij), FadeOut(step_6_text))
        self.wait(1)
        

                                                                                                                                                             