# Setup

In [None]:
import numpy as np
import torch

from src.core.eqprop.eqprop_util import OTS, P3OTS
from src.core.eqprop.strategy import GradientDescentStrategy, NewtonStrategy

In [None]:
# Checkpoint whose stored tensors are inspected throughout this notebook.
ckpt_path = "../../logs/test/last.ckpt"

# NOTE(review): torch.load unpickles the file; only open checkpoints from a
# trusted source.
ckpt = torch.load(ckpt_path, map_location="cpu")
w1 = ckpt["state_dict"]["net.model.0.weight"]
w2 = ckpt["state_dict"]["net.model.1.weight"]

# nn.Linear keeps its weight as (out_features, in_features). Size each layer
# from the loaded tensor instead of hard-coding it: assigning to `.data` does
# no shape check, so a mismatching checkpoint would otherwise go unnoticed.
model = torch.nn.Sequential(
    torch.nn.Linear(w1.shape[1], w1.shape[0], bias=False),
    torch.nn.Linear(w2.shape[1], w2.shape[0], bias=False),
)
model[0].weight.data = w1
model[1].weight.data = w2

In [None]:
# Display the first layer's weight tensor read from the checkpoint.
w1

In [None]:
# Display the second layer's weight tensor read from the checkpoint.
w2

In [None]:
# Activation object handed to the solver strategy below.
ots_activation = P3OTS(Is=1e-6, Vl=-0.6, Vr=0.6, Vth=1.0)

# Newton-type strategy limited to 5 iterations with an absolute tolerance of 1e-7.
st = NewtonStrategy(
    activation=ots_activation,
    clip_threshold=0.5,
    amp_factor=1.0,
    max_iter=5,
    atol=1e-7,
    add_nonlin_last=False,
)

# Hand the model built above to the strategy.
st.set_strategy_params(model)

In [None]:
# Print the tensors stored under the positive/negative node keys of both layers.
for node_key in (
    "net.model.0.positive_node",
    "net.model.1.positive_node",
    "net.model.0.negative_node",
    "net.model.1.negative_node",
):
    print(ckpt["state_dict"][node_key])

In [None]:
# Pull the tensors stored under the `positive_node` keys of both layers.
state_dict = ckpt["state_dict"]
p0 = state_dict["net.model.0.positive_node"]
p1 = state_dict["net.model.1.positive_node"]
# Optional comparison of the two tensors, left disabled:
# torch.allclose(p0, p1, rtol=1e-7)
# Display the second layer's tensor.
p1

# Run

In [None]:
import io
import logging


class LogCapture:
    """Context manager that records a logger's messages in memory.

    While the ``with`` block is active, the target logger is switched to
    ``DEBUG`` and a stream handler that writes bare messages (``%(message)s``)
    into an in-memory buffer is attached to it. On exit the handler is
    detached and the logger's own level is put back.
    """

    def __init__(self, logger_name: str = None):
        """Prepare the in-memory buffer and the handler that feeds it.

        Args:
            logger_name (str, optional): Name of the logger. Usually the file path under src.
                ``None`` selects the root logger. Defaults to None.

        Example:
            with LogCapture("src.core.eqprop.strategy") as log_capture:
                logger = logging.getLogger("src.core.eqprop.strategy")
                logger.info("Hello")
                logger.info("World")
            log_list = log_capture.get_log_list()  # ["Hello", "World"]
        """
        self.log_stream = io.StringIO()
        self.logger = logging.getLogger(logger_name)
        # Keep the logger's *own* level (possibly NOTSET), not the effective
        # one: restoring the effective level would pin an explicit level on a
        # logger that previously inherited it from its parent.
        self.logging_level = self.logger.level
        self.stream_handler = logging.StreamHandler(self.log_stream)
        self.formatter = logging.Formatter("%(message)s")
        self.stream_handler.setFormatter(self.formatter)
        # Snapshot of the buffer taken on exit, so results outlive the buffer.
        self._captured = None

    def __enter__(self):
        """Start capturing: lower the logger to DEBUG and attach the handler."""
        # Re-read the level here in case it changed since construction.
        self.logging_level = self.logger.level
        self.logger.setLevel(logging.DEBUG)
        self.logger.addHandler(self.stream_handler)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop capturing, keep the captured text, and restore the logger level."""
        self.logger.removeHandler(self.stream_handler)
        if not self.log_stream.closed:
            # Save the text before closing; a closed StringIO cannot be read.
            self._captured = self.log_stream.getvalue()
            self.log_stream.close()
        self.logger.setLevel(self.logging_level)

    def get_log_list(self):
        """Return the captured text split into lines.

        Works both inside the ``with`` block and after it has exited.

        Returns:
            list[str]: One entry per captured line; an empty list when nothing
            was logged.
        """
        if self.log_stream.closed:
            raw_text = self._captured or ""
        else:
            raw_text = self.log_stream.getvalue()
        raw_text = raw_text.strip()
        return raw_text.split("\n") if raw_text else []

In [None]:
import re


def extract_and_convert_tensors(log_list):
    """Parse tensors printed in log lines back into torch tensors.

    Only the first ``tensor([[...]])`` occurrence on each line is considered,
    and its contents must be a single row of values separated by ``", "``.
    Lines without such an occurrence are skipped.

    Args:
        log_list: Iterable of log line strings.

    Returns:
        list: One 1-D tensor per matching line, in input order.
    """
    row_pattern = re.compile(r"tensor\(\[\[(.*?)\]\]\)")
    hits = (row_pattern.search(line) for line in log_list)
    return [
        torch.tensor([float(token) for token in hit.group(1).split(", ")])
        for hit in hits
        if hit
    ]

In [None]:
# Print tensors with 10 decimal places; this applies to any tensor text
# produced from here on.
torch.set_printoptions(precision=10)

# One input sample with three features, as float32.
x = torch.tensor([[-2, 2, 1]], dtype=torch.float32)
i_ext = None  # torch.tensor(0)

# Run the solver while recording what the strategy module logs.
with LogCapture("src.core.eqprop.strategy") as log_capture:
    st.solve(x, i_ext)
    # Collect the lines while the capture is still active.
    log_list = log_capture.get_log_list()

In [None]:
# Parse the tensors printed in the captured log lines (presumably the solver's
# successive iterates; confirm against the strategy's log messages) and show them.
v_traj = extract_and_convert_tensors(log_list)
print(v_traj)

2 2 1.358 1.999
2 -2 

## Reimplement the solve method

# Visualize

In [None]:
import matplotlib.pyplot as plt

In [None]:
def project_and_plot(trajectories, target_dim=2, plot_arrow: bool = False):
    """
    Project an n-dimensional trajectory onto 2 or 3 dimensions and plot it.

    The points are arranged as one row each, an SVD of that matrix is taken,
    and the rows are projected onto the first ``target_dim`` right-singular
    vectors. The data is not mean-centred before the SVD.

    Parameters:
    trajectories (list): Trajectory points as tensors, each with the same number of elements
    target_dim (int): Dimension to project onto (2 or 3)
    plot_arrow (bool): If True, draw an arrow from every point to the next one

    Raises:
    ValueError: If target_dim is neither 2 nor 3
    """
    # Fail fast, before any numerical work or figure creation.
    if target_dim not in (2, 3):
        raise ValueError("target_dim must be 2 or 3")

    # Build a (num_points, n) matrix with one trajectory point per row.
    size = len(trajectories)
    data = torch.cat(trajectories).reshape(size, -1)

    # Rows of Vh are the right-singular vectors; keep the leading ones.
    _, _, Vh = torch.linalg.svd(data)
    basis = Vh[:target_dim, :]

    # Coordinates of every point in the reduced basis.
    projected_data = torch.matmul(data, basis.T).numpy()

    if target_dim == 2:
        plt.figure(figsize=(8, 6))
        plt.plot(projected_data[:, 0], projected_data[:, 1], "o-", label="Projected Trajectory")
        if plot_arrow:
            # Walk over consecutive (start, end) point pairs.
            for start, end in zip(projected_data[:-1], projected_data[1:]):
                step = end - start
                plt.arrow(
                    start[0],
                    start[1],
                    step[0],
                    step[1],
                    head_width=0.1,
                    head_length=0.2,
                    fc="k",
                    ec="k",
                )
        plt.xlabel("First Principal Component")
        plt.ylabel("Second Principal Component")
        plt.title("Trajectory Projection onto 2D Plane")
        plt.legend()
        plt.grid(True)
        plt.show()
    else:
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection="3d")
        ax.plot(
            projected_data[:, 0],
            projected_data[:, 1],
            projected_data[:, 2],
            "o-",
            label="Projected Trajectory",
        )
        if plot_arrow:
            # Walk over consecutive (start, end) point pairs.
            for start, end in zip(projected_data[:-1], projected_data[1:]):
                step = end - start
                ax.quiver(
                    start[0],
                    start[1],
                    start[2],
                    step[0],
                    step[1],
                    step[2],
                    arrow_length_ratio=0.1,
                    color="k",
                )
        ax.set_xlabel("First Principal Component")
        ax.set_ylabel("Second Principal Component")
        ax.set_zlabel("Third Principal Component")
        ax.set_title("Trajectory Projection onto 3D Space")
        ax.legend()
        plt.show()

In [None]:
# Plot the parsed trajectory projected onto two dimensions.
project_and_plot(v_traj, target_dim=2)

In [None]:
# True if any entry of the residual at the second parsed point exceeds 1e-5 in magnitude.
torch.any(st.residual(v_traj[1], x, None).abs() > 1e-5)

In [None]:
# Residual at the parsed point with index 40.
# NOTE(review): this needs at least 41 parsed tensors, while the strategy above
# is built with max_iter=5 — confirm the index is still valid for this run.
st.residual(v_traj[40], x, None)

In [None]:
# Keep the result of st.laplacian() (presumably the network's Laplacian matrix;
# confirm in src.core.eqprop.strategy) for the product computed afterwards.
L = st.laplacian()

In [None]:
# Matrix product of L with the second parsed point.
L @ v_traj[1]

In [None]:
# Evaluate st.rhs for the input x (presumably the right-hand side of the system
# being solved; confirm in src.core.eqprop.strategy).
st.rhs(x)