In [None]:
import re
import matplotlib.pyplot as plt
import torch
import os
import numpy as np
import argparse
import itertools
from tqdm import tqdm

In [None]:
# Directory that is scanned for checkpoint dumps and that also receives the
# rendered figure. The trailing slash matters: the plotting cell builds the
# output file name by plain string concatenation (f"{path}graph.svg"), and the
# plot title is chosen by looking for "sorsa" / "LoRA" as substrings of this path.
path = "/root/autodl-tmp/runs/sorsa_qv_ana/"
# Matches dump names of the form "metadata.pt_<step>.pt"; group 1 is the
# training step as a decimal string. Used with .match(), so only the start of
# the file name is anchored.
file_pattern = re.compile(r"metadata\.pt_(\d+)\.pt")

In [None]:
def calc_x_y(w_0, w_t):
    """Quantify how far a matrix has moved from a reference decomposition.

    Args:
        w_0: Reference factors as a 3-item sequence ``(U, S, Vh)``, laid out
            the way ``torch.linalg.svd(..., full_matrices=False)`` returns them.
        w_t: Matrix to compare; its reduced SVD is computed here. Its factors
            must be broadcast-compatible with those in ``w_0``.

    Returns:
        A pair of Python floats ``(ds, dd)``:
        ``ds`` is the mean absolute difference between the two singular-value
        vectors; ``dd`` is one minus the mean, over singular components, of the
        average absolute dot product between matching left singular vectors
        (columns of U) and matching right singular vectors (rows of Vh).
    """
    cur_u, cur_s, cur_vt = torch.linalg.svd(w_t, full_matrices=False)
    ref_u, ref_s, ref_vt = w_0

    # Magnitude change: average |sigma_t - sigma_0|.
    sigma_shift = torch.mean(torch.abs(cur_s - ref_s))

    # Direction change: |<u_t_i, u_0_i>| per column of U and |<v_t_i, v_0_i>|
    # per row of Vh; abs() makes the measure insensitive to SVD sign flips.
    left_overlap = torch.abs(torch.sum(cur_u * ref_u, dim=0))
    right_overlap = torch.abs(torch.sum(cur_vt * ref_vt, dim=1))
    direction_shift = 1 - torch.mean(left_overlap + right_overlap) / 2

    return sigma_shift.item(), direction_shift.item()

In [None]:
# Reference SVD factors of the step-0 weights, keyed like the checkpoint dict:
# key -> (U, S, Vh). The loading loop fills this only while it is still empty.
w_0 = dict()

In [None]:
# Per-step measurements: step -> {key: (x, y)}, where (x, y) is the pair
# returned by calc_x_y for that weight against its step-0 reference.
data = dict()

In [None]:
# Collect (step, filename) for every checkpoint dump found in `path`.
# fullmatch() is used so that stray files whose names merely *start* like a
# dump (e.g. "metadata.pt_10.pt.bak") are not loaded.
checkpoints = []
for filename in os.listdir(path):
    match = file_pattern.fullmatch(filename)
    if match and os.path.isfile(os.path.join(path, filename)):
        checkpoints.append((int(match.group(1)), filename))

# os.listdir() returns names in arbitrary order. Processing in ascending step
# order guarantees the step-0 reference factors exist before any later step
# is compared against them.
for step, filename in sorted(checkpoints):
    file_path = os.path.join(path, filename)
    print(f"Step: {step}")
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load dumps from a trusted run directory (or pass
    # weights_only=True if the installed torch version supports it).
    weight_dict = torch.load(file_path, map_location="cpu")

    if step == 0:
        # Step 0 is the reference: cache its SVD once and reuse it on re-runs.
        if not w_0:
            for key, value in tqdm(weight_dict.items(), total=len(weight_dict)):
                u, s, vt = torch.linalg.svd(value.T, full_matrices=False)
                w_0[key] = (u, s, vt)
    elif len(data.get(step, {})) != len(weight_dict):
        # Only (re)compute a step whose cached result is missing or incomplete.
        if not w_0:
            raise RuntimeError(
                f"No step-0 reference available: expected a step-0 checkpoint "
                f"(e.g. 'metadata.pt_0.pt') in {path!r}"
            )
        step_points = {}
        for key, value in tqdm(weight_dict.items(), total=len(weight_dict)):
            step_points[key] = calc_x_y(w_0[key], value.T)
        # Publish the step only once every key has been measured.
        data[step] = step_points

In [None]:
# Render the per-layer drift scatter plot and save it as an SVG next to the data.
# Each point is one weight at one step: x = singular-value change, y = singular-
# vector change (the pair produced by calc_x_y). Colour encodes the step; the
# marker glyph is the weight's position within its step, modulo 32.

steps = sorted(data.keys())
if not steps:
    raise RuntimeError(
        f"No step data to plot; run the loading cell first (path={path!r})"
    )

# One viridis colour per step, looked up directly instead of via steps.index().
colors = plt.cm.viridis(np.linspace(0, 1, len(steps)))
step_colors = dict(zip(steps, colors))

# Numeric glyphs "$0$" .. "$31$" rendered through mathtext.
markers_list = [f"${i}$" for i in range(32)]
# Not used by this cell; kept in case other cells rely on the name.
markers = itertools.cycle(markers_list)

# Initialize plot
fig, ax = plt.subplots(figsize=(15, 10))

# Plot data
layer_points = {name: [] for step in steps for name in data[step]}
for step in steps:
    color = step_colors[step]
    for i, (name, (x, y)) in enumerate(data[step].items()):
        ax.scatter(
            x,
            y,
            label=f"Step {step}, Layer {name}",
            marker=markers_list[i % len(markers_list)],
            color=color,
            alpha=0.6,
        )
        layer_points[name].append((x, y, color))

# Calculate and plot mean points for each step; (0, 0) is the step-0 origin.
ax.scatter(0, 0, color="black", s=100)
mean_positions = []
for step in steps:
    if not data[step]:
        # A step with no measurements has no mean to draw.
        continue
    xs, ys = zip(*data[step].values())
    mean_x = np.mean(xs)
    mean_y = np.mean(ys)
    mean_positions.append((mean_x, mean_y, step_colors[step]))
    ax.scatter(mean_x, mean_y, color=step_colors[step], s=100)

# Connect consecutive mean points; each segment takes the colour of its start.
for (x0, y0, seg_color), (x1, y1, _) in zip(mean_positions, mean_positions[1:]):
    ax.plot(
        [x0, x1],
        [y0, y1],
        color=seg_color,
        linestyle="-",
        linewidth=4,
    )

# Connect the first mean point to (0, 0) with a black line
if mean_positions:
    ax.plot(
        [0, mean_positions[0][0]],
        [0, mean_positions[0][1]],
        color="black",
        linestyle="-",
        linewidth=4,
    )

# Connect points with same name. Points were appended while iterating the
# sorted steps, so each list is already in step order.
for name, points in layer_points.items():
    for (x0, y0, seg_color), (x1, y1, _) in zip(points, points[1:]):
        ax.plot(
            [x0, x1],
            [y0, y1],
            color=seg_color,
            linestyle="-",
            linewidth=2,
            alpha=0.1,  # Set transparency for the connecting lines
        )

# Custom legend for steps (one round marker per step colour), placed outside
# the axes on the right.
handles = [
    plt.Line2D([0], [0], marker="o", color=color, linestyle="", markersize=10)
    for color in colors
]
labels = [f"Step {step}" for step in steps]
ax.legend(handles, labels, title="Steps", loc="upper left", bbox_to_anchor=(1, 1))

# Add labels and title. Raw strings keep the LaTeX backslashes without relying
# on invalid escape sequences.
ax.set_xlabel(r"$\Delta \Sigma$")
ax.set_ylabel(r"$\Delta D$")
if "sorsa" in path:
    ax.set_title("SORSA")
elif "LoRA" in path:
    ax.set_title("LoRA")
else:
    ax.set_title("FT")
ax.grid(True)

# os.path.join works whether or not `path` ends with a separator;
# bbox_inches="tight" keeps the legend, which sits outside the axes, inside
# the saved image.
fig.savefig(os.path.join(path, "graph.svg"), format="svg", bbox_inches="tight")