# SETUP

In [None]:
import os

import numpy as np
import torch

from src.core.eqprop.eqprop_util import OTS, P3OTS
from src.core.eqprop.strategy import GradientDescentStrategy, NewtonStrategy
from src.utils.logging_utils import LogCapture

In [None]:
# Load the trained checkpoint and rebuild the two-layer, bias-free linear
# network from its stored weights.
ckpt_dir = "../../logs/test"
ckpt_filename = "last.ckpt"
ckpt_path = os.path.join(ckpt_dir, ckpt_filename)

# NOTE: depending on the torch version's `weights_only` default, torch.load may
# unpickle arbitrary objects — only load checkpoints you trust.
ckpt = torch.load(ckpt_path, map_location="cpu")
state_dict = ckpt["state_dict"]
w1 = state_dict["net.model.0.weight"]
w2 = state_dict["net.model.1.weight"]

# Derive the layer sizes from the stored weights (nn.Linear stores its weight
# as (out_features, in_features)) instead of hard-coding them, so each layer's
# in/out_features metadata always agrees with the weight assigned below.
model = torch.nn.Sequential(
    torch.nn.Linear(w1.shape[1], w1.shape[0], bias=False),
    torch.nn.Linear(w2.shape[1], w2.shape[0], bias=False),
)
# Assign via .data so the parameters take on the checkpoint tensors directly
# (including their dtype) rather than being copied/cast into float32 buffers.
model[0].weight.data = w1
model[1].weight.data = w2

torch.set_printoptions(precision=10)

In [None]:
# Load the reference node voltages stored next to the checkpoint under the file
# name "[1, 1]" (presumably produced by a Xyce simulation, judging by the
# variable name; NOTE(review): confirm what the file name encodes).
# The loaded sequence of tensors is concatenated along dim 1 into `gt`, which
# is used below as the ground truth for plotting and residual checks.
v_xyce = torch.load(os.path.join(ckpt_dir, "[1, 1]"), map_location="cpu")
gt = torch.cat(v_xyce, dim=1)
gt

In [None]:
# Show the weight matrices loaded for both layers.
print(w1)
print(w2)

In [None]:
# Newton-based strategy with an OTS activation; `st.solve(x, i_ext)` is used
# below to compute the network state for an input, and
# `set_strategy_params(model)` hands it the network rebuilt from the checkpoint.
# NOTE(review): the meaning of clip_threshold, amp_factor, momentum, atol, etc.
# is defined in src.core.eqprop.strategy — check there before changing values.
st = NewtonStrategy(
    activation=OTS(Is=1e-8, Vl=0.1, Vr=0.9, Vth=0.026),
    clip_threshold=0.5,
    amp_factor=1.0,
    max_iter=50,
    atol=1e-7,
    add_nonlin_last=False,
    momentum=0.1,
)
st.set_strategy_params(model)

In [None]:
# Dump the positive-node tensors of both layers, then the negative-node ones.
for node_key in (
    "net.model.0.positive_node",
    "net.model.1.positive_node",
    "net.model.0.negative_node",
    "net.model.1.negative_node",
):
    print(ckpt["state_dict"][node_key])

In [None]:
# Pull out the positive-node tensors of both layers for comparison.
p0 = ckpt["state_dict"]["net.model.0.positive_node"]
p1 = ckpt["state_dict"]["net.model.1.positive_node"]
# Disabled check of whether the two layers' positive nodes coincide:
# torch.allclose(p0, p1, rtol=1e-7)
p1

# Run

In [None]:
import re


def extract_and_convert_tensors(log_list):
    """Parse the first printed single-row 2-D tensor out of each log message.

    Each message is searched for text of the form ``tensor([[v0, v1, ...]]``,
    as produced by printing a torch tensor with exactly one row. Messages
    without such text (including ones that print multi-row matrices) are
    skipped.

    Parameters:
        log_list (Iterable[str]): Captured log messages.

    Returns:
        list[torch.Tensor]: One 1-D float tensor per matching message, holding
        that row's values in order. The leading row dimension of the printed
        tensor is not restored.
    """
    # The capture group excludes brackets, so it only matches a single row and
    # multi-row matrices are skipped rather than mis-parsed. A negated class
    # also matches newlines, so rows that torch wraps across several lines are
    # still found. No ")" is required after "]]", so trailing printer
    # arguments such as ", dtype=torch.float64)" do not prevent a match.
    tensor_pattern = re.compile(r"tensor\(\[\[([^\[\]]*)\]\]")

    tensor_list = []
    for log in log_list:
        match = tensor_pattern.search(log)
        if match is None:
            continue
        # Split on commas and tolerate any surrounding whitespace/newlines
        # (wrapped output uses ",\n   " rather than ", ").
        tokens = (tok.strip() for tok in match.group(1).split(","))
        tensor_list.append(torch.tensor([float(tok) for tok in tokens if tok]))

    return tensor_list

In [None]:
# Run the solver on one example input of shape (1, 3) and capture everything
# the "src.core.eqprop.strategy" logger emits during the solve; those log
# lines are parsed into a trajectory further below.
x = torch.tensor([[-2, 2, 1]]).float()
i_ext = None  # no external input; alternative value tried: torch.tensor(0)
# Re-assigned after construction (same values as passed to NewtonStrategy),
# presumably so they can be tweaked per run without rebuilding `st`.
st.momentum = 0.1
st.max_iter = 50
with LogCapture("src.core.eqprop.strategy") as log_capture:
    v = st.solve(x, i_ext)
    log_list = log_capture.get_log_list()

In [None]:
# Inspect the raw captured log messages.
log_list

In [None]:
# Parse the tensor printed in each captured message into a list of 1-D tensors
# (one per matching message), used as the solver trajectory below.
v_traj = extract_and_convert_tensors(log_list)

2 2 1.358 1.999
2 -2 

## Reimplement the solve method

# Visualize

In [None]:
import matplotlib.pyplot as plt

In [None]:
def project_and_plot(
    trajectories,
    ground_truth,
    target_dim=2,
    plot_label: bool = False,
    label_every: int = 10,
):
    """Project n-dimensional trajectory points onto 2-D or 3-D and plot them.

    The projection axes are the first ``target_dim`` right-singular vectors of
    the matrix whose rows are the trajectory points. The data are not
    mean-centred first, so despite the axis labels these are the dominant
    singular directions rather than textbook PCA components.

    Parameters:
        trajectories (list[torch.Tensor]): Trajectory points in order. Each
            point is flattened into one row, so all points must have the same
            number of elements n.
        ground_truth (torch.Tensor): Reference point(s) with n elements per
            row; a 1-D tensor is treated as a single row. Only the last row is
            plotted (red star).
        target_dim (int): Dimension to project onto (2 or 3).
        plot_label (bool): Whether to annotate trajectory points with their
            index.
        label_every (int): With ``plot_label``, annotate every
            ``label_every``-th point; the final point is never annotated.

    Raises:
        ValueError: If ``target_dim`` is not 2 or 3, ``label_every`` is not
            positive, ``trajectories`` is empty, or n < ``target_dim``.
        RuntimeError: If the trajectory points differ in size.
    """
    # Validate arguments before doing any work.
    if target_dim not in (2, 3):
        raise ValueError("target_dim must be 2 or 3")
    if label_every < 1:
        raise ValueError("label_every must be a positive integer")
    if len(trajectories) == 0:
        raise ValueError("trajectories must not be empty")

    # (num_points, n) data matrix. Flattening then stacking raises on points of
    # unequal size, where concatenating and reshaping could silently regroup
    # values into the wrong rows.
    data = torch.stack([point.reshape(-1) for point in trajectories])
    if data.shape[1] < target_dim:
        raise ValueError(
            f"points have {data.shape[1]} elements; cannot project onto {target_dim} axes"
        )

    # Rows of Vh are the right-singular vectors, ordered by decreasing singular
    # value; keep the first target_dim of them as projection axes.
    _, _, Vh = torch.linalg.svd(data)
    axes = Vh[:target_dim, :]

    # Accept a 1-D reference point, and match the data dtype so that e.g. a
    # float64 reference can be projected alongside float32 trajectory data.
    reference = torch.atleast_2d(ground_truth).to(axes.dtype)

    projected_data = torch.matmul(data, axes.T).numpy()
    projected_gt = torch.matmul(reference, axes.T).numpy()

    if target_dim == 2:
        fig = plt.figure(figsize=(8, 6))
        ax = fig.add_subplot(111)
    else:
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection="3d")
        ax.view_init(elev=20, azim=40)  # fixed viewing angle for the 3-D plot

    # Every projected row has exactly target_dim coordinates, so unpacking it
    # fits both the 2-D and the 3-D Axes plot/text signatures.
    ax.plot(*projected_data.T, "o-", label="Projected Trajectory")
    # Mark the final trajectory point and the (last) ground-truth point.
    ax.plot(*projected_data[-1], "k^", markersize=10)
    ax.plot(*projected_gt[-1], "r*", markersize=10, label="Ground Truth")

    if plot_label:
        for i in range(0, len(projected_data) - 1, label_every):
            ax.text(*projected_data[i], str(i), fontsize=12, color="blue")

    ax.set_xlabel("First Principal Component")
    ax.set_ylabel("Second Principal Component")
    if target_dim == 2:
        ax.set_title("Trajectory Projection onto 2D Plane")
        ax.grid(True)
    else:
        ax.set_zlabel("Third Principal Component")
        ax.set_title("Trajectory Projection onto 3D Space")
    ax.legend()
    plt.show()

In [None]:
# Plot the parsed trajectory (all parsed points except the last) in 3-D
# against the reference solution, numbering every 10th point.
# NOTE(review): the reason for dropping v_traj[-1] is not visible here — confirm.
project_and_plot(v_traj[:-1], gt, target_dim=3, plot_label=True)

# ETC

In [None]:
# Reference tensors as loaded from disk, before concatenation into `gt`.
v_xyce

In [None]:
# Last parsed trajectory point, for comparison with the reference.
v_traj[-1]

In [None]:
# Inspect what the strategy's `bias()` helper returns (defined in
# src.core.eqprop.strategy).
st.bias()

In [None]:
# Evaluate `st.OTS.i` at the last parsed trajectory point.
# NOTE(review): `st.OTS` is presumably the OTS activation passed to the
# constructor and `.i` its current function — confirm in the strategy code.
st.OTS.i(v_traj[-1])

In [None]:
# Shape of what `st.rhs` builds for input x (presumably the right-hand side of
# the system the solver works on).
st.rhs(x).shape

In [None]:
# True if any residual entry at trajectory point v_traj[1] exceeds 1e-5 in
# magnitude (third argument None, as with i_ext above).
# NOTE(review): this inspects the second parsed point, not the final one
# (v_traj[-1]) — confirm that index 1 is intended.
torch.any(st.residual(v_traj[1], x, None).abs() > 1e-5)

In [None]:
# Residual of the reference solution `gt` for the same input x (third
# argument None, as with i_ext above).
st.residual(gt, x, None)

In [None]:
# Keep the result of `st.laplacian()` in `L` for inspection.
L = st.laplacian()