## Imports

In [9]:
from manim import Matrix
from manim import * 
import json
import math 
import numpy as np

## Data and configs

In [10]:
# Named palette shared by the charts: human-readable colour name -> hex string.
color_dict = dict(
    (
        ("Eggshell", "#f4f1de"),
        ("Burnt sienna", "#e07a5f"),
        ("Delft Blue", "#3d405b"),
        ("Cambridge blue", "#81b29a"),
        ("Sunset", "#f2cc8f"),
    )
)

In [11]:
# Accuracy table: one row per merging method, one column per benchmark setting.
# NOTE(review): the meaning of the three columns is not stated in this file — confirm.
accuracy_data = np.array(
    (
        (64.70, 68.20, 65.23),  # Zeroshot
        (79.56, 76.73, 71.60),  # Weight Averaging
        (84.93, 79.41, 74.01),  # Task Vector
        (86.34, 82.22, 79.00),  # Consensus TA
        (92.98, 89.17, 87.72),  # TSV-M (Ours)
    )
)

In [12]:
# Notebook-level colour order, resolved from the named palette.
# Note: RadarChartScene.construct defines its own local ``bar_colors`` with a
# different order, so this list is not the one used inside that scene.
bar_colors = [
    color_dict[name]
    for name in ("Delft Blue", "Eggshell", "Burnt sienna", "Sunset", "Cambridge blue")
]

# Load the radar-chart results. The encoding is given explicitly so that
# parsing does not depend on the platform's default locale encoding.
with open('data/radar_results_ViT-L-14.json', encoding='utf-8') as f:
    data = json.load(f)

# Results for the 8-task benchmark, indexed as eight_tasks[method][dataset].
eight_tasks = data['8 Tasks']

# Fixed orderings: methods define the polygon order, datasets the spoke order.
method_order = ["Zero-shot", "Weight Average", "Task Arithmetic", "Consensus TA", "TSV-M"]
dataset_order = ['Cars', 'DTD', 'EuroSAT', 'GTSRB', 'MNIST', 'RESISC45', 'SVHN', 'SUN397']

# One row per method, one value per dataset, both in the orders above.
# NOTE(review): the scene scales these values as if they lie in [0, 1] — confirm
# against the JSON file.
method_data = [
    [eight_tasks[method][dataset] for dataset in dataset_order]
    for method in method_order
]

## Scene

In [13]:
%%manim -qh -v WARNING RadarChartScene

class RadarChartScene(Scene):
    """Radar chart (left) and bar chart (right), revealed one method at a time.

    The scene reads the notebook-level ``accuracy_data``, ``method_data``,
    ``dataset_order`` and ``color_dict``; they must be defined before rendering.
    """

    def construct(self):
        """Build both charts, add the static scaffolding, then animate each method."""
        #### 1) Define data ####

        # Column of ``accuracy_data`` plotted as bars.
        # NOTE(review): the meaning of each column is not documented here — confirm.
        ind = 0

        # LaTeX labels shown under the bars (``~`` is a non-breaking space).
        methods = [
            r"Zeroshot",
            r"Weight~Averaging",
            r"Task~Arithmetic",
            r"Consensus~TA",
            r"\textbf{TSV}",
        ]
        # One spoke per dataset. Derived from ``dataset_order`` itself (not a
        # hard-coded count) so the spokes always match the number of values in
        # each ``method_data`` row.
        dataset_labels = [str(name) for name in dataset_order]

        max_radius = 3.0  # how far out the max value (1.0) extends

        # Initial anchors; both charts are rescaled and pushed to the frame
        # edges further below.
        center = 3 * LEFT + 0.5 * UP
        right_center = 1 * RIGHT + 0.5 * UP

        # One colour per method, shared by a method's bar and its polygon.
        bar_colors = [
            color_dict["Eggshell"],
            color_dict["Sunset"],
            color_dict["Burnt sienna"],
            color_dict["Delft Blue"],
            color_dict["Cambridge blue"],
        ]

        #### 2) Build the two charts ####

        chart, c_bar_lbls, custom_labels = self._build_bar_chart(
            accuracy_data[:, ind], methods, bar_colors, right_center
        )
        axes_group, ring_group, axis_labels, polygons = self._build_radar_chart(
            dataset_labels, method_data, bar_colors, center, max_radius
        )

        #### 3) Final layout ####

        barchart_group = VGroup(chart, c_bar_lbls, custom_labels)
        barchart_group.scale(0.8).to_edge(RIGHT)

        radarchart_group = VGroup(axes_group, ring_group, axis_labels, polygons)
        radarchart_group.scale(0.8).to_edge(LEFT)

        #### 4) Show and animate ####

        # Only the static scaffolding is added up front; bars, labels and
        # polygons enter through the animations below.
        self.add(axes_group, ring_group, axis_labels, chart.axes)

        # Reveal one method at a time: its bar, both of its labels and its
        # radar polygon appear together.
        for poly, bar, value_label, method_label in zip(
            polygons, chart.bars, c_bar_lbls, custom_labels
        ):
            self.play(
                GrowFromEdge(bar, DOWN),
                FadeIn(value_label),
                FadeIn(method_label),
                Create(poly),
                run_time=1,
            )
            self.wait(0.6)

        self.wait(2)

    @staticmethod
    def _polar_offset(index, count, radius):
        """Return the offset at ``radius`` along spoke ``index`` of ``count`` evenly spaced spokes."""
        angle = TAU * index / count
        return radius * np.array([math.cos(angle), math.sin(angle), 0])

    @staticmethod
    def _build_bar_chart(values, method_names, colors, anchor):
        """Create the bar chart, its numeric bar labels and the rotated method labels.

        Returns a ``(chart, value_labels, method_labels)`` tuple.
        """
        chart = BarChart(
            values=values,
            bar_names=None,  # method labels are placed manually below
            y_range=[0, 100, 10],
            y_length=6,
            x_length=10,
            x_axis_config={"font_size": 36},
            bar_colors=colors,
            bar_width=0.3,
        )
        # Make it a bit smaller and move it to its anchor.
        chart.scale(0.7).move_to(anchor)

        # Numbers above the bars.
        value_labels = chart.get_bar_labels(font_size=32)

        # Rotated method names below the bars.
        method_labels = VGroup()
        for bar, name in zip(chart.bars, method_names):
            label = Tex(name, font_size=20)
            label.rotate(45 * DEGREES).scale(1.5)
            label.next_to(bar, 3 * DOWN, buff=0.2)
            method_labels.add(label)

        return chart, value_labels, method_labels

    @staticmethod
    def _build_radar_chart(labels, rows, colors, center, max_radius):
        """Create the radar chart's spokes, rings, dataset labels and per-method polygons.

        ``rows`` holds one sequence of values per method; each value is
        multiplied by ``max_radius`` to get its distance from ``center``.
        Returns an ``(axes, rings, axis_labels, polygons)`` tuple of ``VGroup``s.
        """
        polar = RadarChartScene._polar_offset
        num_axes = len(labels)

        # Radial axes, one per dataset.
        axes_group = VGroup()
        for i in range(num_axes):
            axes_group.add(
                Line(center, center + polar(i, num_axes, max_radius), color=GRAY)
            )

        # Concentric circles (spider web), created at the origin and then
        # moved onto the chart centre.
        ring_group = VGroup()
        num_rings = 4
        for ring_radius in np.linspace(max_radius / num_rings, max_radius, num_rings):
            ring_group.add(
                Circle(radius=ring_radius, stroke_color=GRAY, stroke_opacity=0.5)
            )
        ring_group.move_to(center)

        # Dataset names just outside the outermost ring.
        axis_labels = VGroup()
        label_offset = 0.4
        for i, text in enumerate(labels):
            position = center + polar(i, num_axes, max_radius + label_offset)
            axis_labels.add(Tex(text, font_size=24).move_to(position))

        # One translucent polygon per method, built around the origin.
        polygons = VGroup()
        for values, color in zip(rows, colors):
            vertices = [
                polar(i, num_axes, val * max_radius)  # data in [0,1]
                for i, val in enumerate(values)
            ]
            polygon = Polygon(*vertices, color=color)
            polygon.set_fill(color, opacity=0.3)
            polygons.add(polygon)

        # Move the polygons onto the chart centre.
        polygons.shift(center)

        return axes_group, ring_group, axis_labels, polygons


                                                                                 