## Imports

In [1]:
from manim import Matrix
from manim import * 
import json
import math 

## Data and configs

In [2]:
# Colour palette for the notebook: human-readable colour name -> hex code.
# Built from ordered (name, hex) pairs so the insertion order is explicit.
color_dict = dict(
    [
        ("Eggshell", "#f4f1de"),
        ("Burnt sienna", "#e07a5f"),
        ("Delft Blue", "#3d405b"),
        ("Cambridge blue", "#81b29a"),
        ("Sunset", "#f2cc8f"),
    ]
)

In [3]:
# Load the radar-chart results. JSON is UTF-8 by specification, so the encoding
# is pinned explicitly instead of relying on the platform's default codec.
with open('data/radar_results_ViT-L-14.json', encoding='utf-8') as f:
    data = json.load(f)

# Only the "8 Tasks" entry is used below; it is indexed as
# eight_tasks[method][dataset] -> value.
eight_tasks = data['8 Tasks']

# Fixed orderings: methods define the drawing/legend order, datasets define the
# order of the radar axes.
method_order = ["Zero-shot", "Weight Average", "Task Arithmetic", "Consensus TA", "TSV-M"]
dataset_order = ['Cars', 'DTD', 'EuroSAT', 'GTSRB', 'MNIST', 'RESISC45', 'SVHN', 'SUN397']

# One row per method (in method_order), one value per dataset (in dataset_order).
# A missing method or dataset key raises KeyError here rather than later in the scene.
method_data = [
    [eight_tasks[method][dataset] for dataset in dataset_order]
    for method in method_order
]

## Scene

In [6]:
%%manim -qm -v WARNING RadarChartScene

class RadarChartScene(Scene):
    """Animated radar (spider) chart with one polygon per method.

    Reads three notebook-level names: ``method_data`` (one list of per-dataset
    values for each method), ``dataset_order`` (axis labels, one per radial
    axis) and ``color_dict`` (palette). The axes, rings and axis labels are
    drawn first; then each method's polygon is created together with its
    legend entry.
    """

    @staticmethod
    def _polar_offset(radius, angle):
        """Return the 3D offset for a point at ``radius`` and ``angle`` (radians)."""
        return radius * np.array([math.cos(angle), math.sin(angle), 0])

    def construct(self):
        # LaTeX legend labels; positions correspond to the rows of method_data.
        methods = [
            r"Zeroshot",
            r"Weight~Averaging",
            r"Task~Arithmetic",
            r"Consensus~TA",
            r"\textbf{TSV-M}",
        ]

        # One colour per method, in the same order as `methods`.
        bar_colors = [
            color_dict["Eggshell"],
            color_dict["Sunset"],
            color_dict["Burnt sienna"],
            color_dict["Delft Blue"],
            color_dict["Cambridge blue"],
        ]

        # One label per radial axis; derived from dataset_order instead of a
        # hard-coded count so the chart follows the data.
        dataset_labels = [str(name) for name in dataset_order]

        num_datasets = len(dataset_labels)
        num_methods = len(methods)
        max_radius = 3.0  # how far out the max value (1.0) extends

        # Shift the radar chart well to the left to leave room for the legend.
        center = 3 * LEFT

        # Angle of each radial axis, evenly spaced around the full circle.
        angles = [TAU * i / num_datasets for i in range(num_datasets)]

        #### RADAR CHART ####

        # 1) Radial axes, one per dataset.
        axes_group = VGroup()
        for angle in angles:
            end_point = center + self._polar_offset(max_radius, angle)
            axes_group.add(Line(center, end_point, color=GRAY))

        # 2) Concentric circles (spider web).
        ring_group = VGroup()
        num_rings = 4
        for ring_radius in np.linspace(max_radius / num_rings, max_radius, num_rings):
            ring_group.add(
                Circle(radius=ring_radius, stroke_color=GRAY, stroke_opacity=0.5)
            )
        # The rings are concentric, so centring the group centres every ring.
        ring_group.move_to(center)

        # 3) Axis labels, placed slightly beyond the end of each axis.
        axis_labels = VGroup()
        label_offset = 0.4
        for d_label, angle in zip(dataset_labels, angles):
            label_pos = center + self._polar_offset(max_radius + label_offset, angle)
            axis_labels.add(Tex(d_label, font_size=24).move_to(label_pos))

        # 4) Polygons for each method.
        # Vertices are built directly around `center`. Calling
        # `polygons.move_to(center)` afterwards would align the group's
        # bounding-box centre with `center`, which for asymmetric polygons is
        # not the chart origin and would misalign them with the axes.
        polygons = VGroup()
        for method_idx in range(num_methods):
            data_values = method_data[method_idx]
            # Values are assumed to lie in [0, 1] — TODO confirm against the JSON.
            vertices = [
                center + self._polar_offset(val * max_radius, angle)
                for val, angle in zip(data_values, angles)
            ]
            polygon = Polygon(*vertices, color=bar_colors[method_idx])
            polygon.set_fill(bar_colors[method_idx], opacity=0.3)
            polygons.add(polygon)

        # 5) Legend (one swatch + label per method).
        legend_items = VGroup()
        for i, method_name in enumerate(methods):
            color_swatch = Square(side_length=0.3, color=bar_colors[i], fill_opacity=0.3)
            text_label = Tex(method_name, font_size=28)
            item_group = VGroup(color_swatch, text_label).arrange(RIGHT, buff=0.2)
            legend_items.add(item_group)

        # Stack legend entries vertically with aligned left edges, then place
        # the legend near the right edge, slightly below the vertical centre.
        legend_items.arrange(DOWN, aligned_edge=LEFT, buff=0.3)
        legend_items.to_edge(RIGHT, buff=1).shift(DOWN * 1)

        # 6) Animate the chart scaffolding.
        self.play(Create(axes_group))
        self.play(Create(ring_group))
        self.play(FadeIn(axis_labels))

        # 7) Reveal the methods one by one: draw each polygon while fading in
        # that method's legend entry.
        for idx, poly in enumerate(polygons):
            self.play(
                Create(poly),
                FadeIn(legend_items[idx]),
                run_time=1,
            )
            self.wait(0.3)

        self.wait(2)


                                                                                     