# Setup

In [None]:
from collections.abc import Sequence

import torch

from src.prog_scheme.utils import generate_target_weights

# Problem size. The target is used transposed (`w_target.T`) as the weight
# matrix of an (output_size x input_size) tile further below.
input_size, output_size, rank = 100, 50, 30
dim = output_size * input_size
over_sampling = 1
x_rand = False

# Programming-method hyper-parameters.
batch_size = 3
tol = 1e-8
max_iter = 1000
norm_type = "fro"
svd_every_k_iter = 5

# Noise levels shared by the simulated device and the EKF models.
read_noise_std = update_noise_std = 0.1

# Generate the low-rank target matrix and peek at its top-left corner.
w_target = generate_target_weights(input_size, output_size, rank)
print(w_target[:5, :5])

In [None]:
from aihwkit.simulator.configs import FloatingPointRPUConfig, SingleRPUConfig
from aihwkit.simulator.configs.devices import (
    ConstantStepDevice,
    DriftParameter,
    ExpStepDevice,
    FloatingPointDevice,
    IdealDevice,
    LinearStepDevice,
    SimpleDriftParameter,
)
from aihwkit.simulator.configs.utils import (
    InputRangeParameter,
    PrePostProcessingParameter,
    UpdateParameters,
)
from aihwkit.simulator.parameters.enums import PulseType
from aihwkit.simulator.presets.configs import IdealizedPreset, PCMPreset, ReRamSBPreset
from aihwkit.simulator.presets.devices import IdealizedPresetDevice
from aihwkit.simulator.tiles import FloatingPointTile

from src.core.aihwkit.utils import rpuconf2dict
from src.prog_scheme.kalman import DeviceKF, ExpDeviceEKF, LinearDeviceEKF

# Pre/post-processing without input-range handling (only used by the
# commented-out IdealizedPreset alternative below).
pre_post_cfg = PrePostProcessingParameter(input_range=InputRangeParameter(enable=False))

# Device under test; alternatives: ExpStepDevice(), IdealDevice().
device_cfg = LinearStepDevice()

update_cfg = UpdateParameters(pulse_type=PulseType.STOCHASTIC, desired_bl=127)
rpuconfig = SingleRPUConfig(update=update_cfg, device=device_cfg)
rpuconfig.forward.out_noise = read_noise_std

# Device-level overrides, applied in insertion order.
device_overrides = {
    "write_noise_std": update_noise_std,
    "w_max": 1,
    "gamma_up": 0.1,
    "gamma_down": 0.1,
    "w_min": -1,
    "w_max_dtod": 0.01,
    "w_min_dtod": 0.01,
    "dw_min_std": 0.0,
    "mult_noise": False,  # additive noise
}
for attr_name, attr_value in device_overrides.items():
    setattr(rpuconfig.device, attr_name, attr_value)
# rpuconfig.forward.inp_res = 0
# rpuconfig = IdealizedPreset(update=update_cfg, device=device_cfg, pre_post=pre_post_cfg)

rpuconf_dict = rpuconf2dict(rpuconfig, max_depth=2)

# Pick the EKF that matches the device model. Only LinearStepDevice has one
# wired up. For ExpStepDevice the ExpDeviceEKF(..., iterative_update=True)
# variant is currently disabled, so it and every other device get no EKF.
if type(rpuconfig.device) is LinearStepDevice:
    device_ekf = LinearDeviceEKF(
        dim=dim,
        read_noise_std=read_noise_std,
        update_noise_std=update_noise_std,
        iterative_update=False,
        **rpuconf_dict["device"],
    )
else:
    device_ekf = None

fnc = DeviceKF(dim, read_noise_std, update_noise_std)

# Run configuration: the flattened RPU config plus matrix and method settings.
conf = {
    **rpuconf_dict,
    "matrix": dict(input_size=input_size, output_size=output_size, rank=rank),
    "methods": dict(
        fnc=None,  # NOTE(review): `fnc` above is constructed but not passed here
        tolerance=tol,
        max_iter=max_iter,
        batch_size=batch_size,
        norm_type=norm_type,
        svd_every_k_iter=svd_every_k_iter,
        read_noise_std=read_noise_std,
        update_noise_std=update_noise_std,
        w_init=0.01,
        over_sampling=over_sampling,
        x_rand=x_rand,
    ),
}

In [None]:
# Display the EKF chosen in the config cell (None unless the device is a LinearStepDevice).
device_ekf

In [None]:
from typing import Any, List, Optional, Tuple, Union

from aihwkit.simulator.tiles.analog import AnalogTile
from torch import Tensor, eye
from torch.autograd import no_grad
from torch.linalg import lstsq


class customread_tile(AnalogTile):
    """``AnalogTile`` with a configurable realistic weight read-out.

    ``read_weights_`` estimates the weight matrix from analog forward passes
    and a least-squares fit. The probe is the identity matrix or, optionally,
    a uniform random matrix.
    """

    @no_grad()
    def read_weights_(
        self,
        apply_weight_scaling: bool = False,
        x_values: Optional[Tensor] = None,
        x_rand: bool = False,
        over_sampling: int = 10,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Estimate the weights (and biases) through the analog forward pass.

        The weights are not read directly. Probe inputs are pushed through
        ``self.forward`` in eval mode, and the weight matrix is recovered with
        ``torch.linalg.lstsq``. Anything applied in the forward pass (e.g.
        digital periphery such as output scaling) is therefore part of the
        estimate.

        Args:
            apply_weight_scaling: If ``False`` (default) and ``get_scales()``
                returns scales, those scales are divided out of the returned
                weight (and out of the bias when the bias is analog). If
                ``True``, the estimate is returned as measured.
            x_values: Probe inputs, one row per input vector. If ``None``, an
                ``in_size x in_size`` probe is generated (see ``x_rand``).
            x_rand: When ``x_values`` is ``None``, use a uniform ``[0, 1)``
                random probe (``torch.rand``) instead of the identity matrix.
            over_sampling: Number of stacked copies of the probe (given or
                generated) fed through the forward pass. Presumably this is
                meant to average out read noise in the least-squares fit.

        Returns:
            A tuple ``(weight, bias)``. ``weight`` is the ``[out_size, in_size]``
            weight matrix on CPU. ``bias`` is the ``[out_size]`` bias vector,
            or ``None`` if the tile is set not to use bias.

        Raises:
            TileError: in case wrong code usage of TileWithPeriphery
        """
        dtype = self.get_dtype()
        if x_values is not None:
            x_values = x_values.to(self.device)
        elif x_rand:
            x_values = torch.rand(self.in_size, self.in_size, device=self.device, dtype=dtype)
        else:
            x_values = eye(self.in_size, self.in_size, device=self.device, dtype=dtype)

        # Stack `over_sampling` copies of the probe; lstsq fits all rows jointly.
        x_values = x_values.repeat(over_sampling, 1)

        # Forward pass in eval mode with indexing disabled. Restore both
        # settings even if the forward pass raises, so the tile is never left
        # in the wrong mode.
        was_training = self.training
        is_indexed = self.is_indexed()
        self.eval()
        if is_indexed:
            self.analog_ctx.set_indexed(False)
        try:
            y_values = self.forward(x_values)
        finally:
            if was_training:
                self.train()
            if is_indexed:
                self.analog_ctx.set_indexed(True)

        # Remove the bias contribution before fitting the linear map.
        if self.bias is not None:
            y_values -= self.bias

        # Solve x_values @ W^T ≈ y_values for W.
        est_weight = lstsq(x_values, y_values).solution.T.cpu()
        weight, bias = self._separate_weights(est_weight)

        if self.digital_bias:
            bias = self.bias.detach().cpu()

        if not apply_weight_scaling:
            # we de-apply all scales
            alpha = self.get_scales()
            if alpha is not None:
                alpha = alpha.detach().cpu()
                return (weight / alpha.view(-1, 1), bias / alpha if self.analog_bias else bias)
        return weight, bias

# AnalogTile

## Compare

In [None]:
import copy

from aihwkit.simulator.tiles.analog import AnalogTile

# Tile (with periphery) that every method starts from; the others are deep copies of it.
atile = customread_tile(output_size, input_size, rpu_config=rpuconfig)
atile_dic = {}
atile.state_dict(atile_dic)  # snapshot for the load_state_dict(assign=True) alternative
# Alternative to deepcopy: AnalogTile(output_size, input_size, rpu_config=rpuconfig)
# followed by .load_state_dict(atile_dic, assign=True).
atile2, atile3, atile4 = (copy.deepcopy(atile) for _ in range(3))
print(atile.tile.get_info())

In [None]:
from aihwkit.simulator.tiles.periphery import TileWithPeriphery

from src.prog_scheme.program_methods import GDP, SVD

def enroll_program_method(tile, method_cls):
    """Bind ``method_cls``'s programming hooks onto ``tile``.

    Installs ``program_weights`` (from ``call_Program_Method``), ``read_weights_``
    and ``init_setup`` as methods bound to ``tile`` itself, so each hook
    programs and reads that tile and no other.
    """
    tile.program_weights = method_cls.call_Program_Method.__get__(tile, TileWithPeriphery)
    tile.read_weights_ = method_cls.read_weights_.__get__(tile, TileWithPeriphery)
    tile.init_setup = method_cls.init_setup.__get__(tile, TileWithPeriphery)


# enroll the programming methods; each tile's hooks must be bound to that tile
# (binding SVD to `atile` made `atile2.program_weights` program `atile` instead)
enroll_program_method(atile, GDP)
enroll_program_method(atile2, SVD)

tiles = [atile, atile2]
# tiles.append(atile4) if device_ekf is not None else None
method_names = [t.program_weights.__name__ for t in tiles]

In [None]:
from src.prog_scheme.utils import program_n_log

# Program every enrolled tile toward the target (tile orientation `w_target.T`)
# using the method settings from `conf`; one error curve is returned per tile.
methods_cfg = conf.get("methods", {})
err_lists = program_n_log(tiles, w_target.T, methods_cfg)

In [None]:
# Display the programming-method settings that are passed to `program_n_log`.
conf.get("methods", {})

## Visualize

In [None]:
# atile2.tile.target_weights = w.T
# atile2.program_weights(**conf.get("methods", {}))

In [None]:
import matplotlib.pyplot as plt


def plot_singular_values(Ws: Sequence[torch.Tensor]):
    """Plot the singular-value spectrum of each matrix in ``Ws`` on one log-scale axis.

    Args:
        Ws: Matrices to analyse, in any sequence (e.g. a list of residual
            matrices). Singleton dimensions are squeezed away before
            ``torch.linalg.svdvals``.
    """
    for w in Ws:
        s = torch.linalg.svdvals(w.squeeze())
        # matplotlib converts via NumPy, which needs a detached CPU tensor.
        plt.plot(s.detach().cpu())
    plt.yscale("log")
    plt.xlabel("Singular Value Index")
    plt.ylabel("Singular Value")
    plt.title("Singular Values of Weight Matrix")
    plt.show()

In [None]:
# Residual between the target and the weights each tile holds after programming.
W = [w_target.T - tile_.tile.get_weights() for tile_ in tiles]

plot_singular_values(W)
print(f"{norm_type} norm of \n")
# Label each residual with its programming method; W[i] belongs to tiles[i].
# The old "atile{i}" labels did not match the tile variables (atile, atile2).
for name, w in zip(method_names, W):
    print(f"{name}: {torch.linalg.matrix_norm(w, ord=norm_type)}")

In [None]:
# Error curve of every method on a shared log-scale axis; echo each final error.
for err in err_lists:
    plt.semilogy(err)
    print(err[-1])

ax = plt.gca()
ax.legend(method_names)
ax.set_xlabel("Iteration")
ax.set_ylabel(f"{norm_type} norm of weight error")
ax.set_title(f"Error vs Iteration @ {input_size}x{output_size}, rank={rank}")
plt.show()

In [None]:
# Display the last curve left over from the plotting loop (the final entry of `err_lists`).
err

## Log

In [None]:
import wandb

# Build a loggable copy of the config. `conf["methods"]` has no "device_ekf"
# entry (indexing it raised KeyError), so fall back to the `device_ekf`
# global and record only its class name. `conf` itself stays untouched, so
# re-running the programming cell sees the same kwargs.
log_conf = {
    **conf,
    "methods": {
        **conf["methods"],
        "device_ekf": type(conf["methods"].get("device_ekf", device_ekf)).__name__,
    },
}
with wandb.init(project="prog-scheme", entity="spk", config=log_conf, dir="../../logs") as run:
    # Log one step per iteration up to the longest curve. Methods with a
    # shorter curve report None for the remaining steps; err_lists is not
    # modified, so later plotting cells still see the original curves.
    max_len = max((len(errs) for errs in err_lists), default=0)
    for i in range(max_len):
        run.log(
            {
                f"{name}_{norm_type}": errs[i] if i < len(errs) else None
                for name, errs in zip(method_names, err_lists)
            }
        )

## Visualize updates

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import TruncatedSVD

def cumulative_weight_updates(tile):
    """Return the running sum of ``tile.actual_weight_updates``, one flattened row per step."""
    updates = np.array(tile.actual_weight_updates)
    return np.cumsum(updates.reshape(updates.shape[0], -1), axis=0)


# assert (atile.initial_weights - atile2.initial_weights).max() == 0
# Ideal total change: from atile2's initial weights to the target, in the same
# orientation (`w_target.T`) used for programming. (`w` was a stale loop
# variable holding the last residual matrix, not the target.)
optimal_change = (w_target.T - atile2.initial_weights).flatten()

# Data processing: cumulative updates of the GDP tile (atile) and the SVD tile (atile2).
cumulative_update = cumulative_weight_updates(atile)
# NOTE(review): this previously read atile3, which is never programmed;
# atile2 is the tile enrolled with SVD.
cumulative_update2 = cumulative_weight_updates(atile2)

In [None]:
# Calculate distances: how far each cumulative-update step still is from the
# ideal total change (one row per step), for each method.
distance = np.array([optimal_change - step for step in cumulative_update])
distance2 = np.array([optimal_change - step for step in cumulative_update2])
concat_distances = np.concatenate((distance, distance2), axis=0)

# Fit TruncatedSVD on BOTH trajectories so they share one 2-D basis, then
# split the projection back per method. The old code fitted on `distance`
# only but sliced with `max_iter` bounds as if both were stacked, so the
# "svd" slice did not hold the SVD trajectory and each slice dropped its last row.
svd = TruncatedSVD(n_components=2)
svd_result = svd.fit_transform(concat_distances)

n_gdp = len(distance)
svd_gdp = svd_result[:n_gdp]
svd_svd = svd_result[n_gdp:]

# Set grid in SVD result range
x = np.linspace(svd_result[:, 0].min(), svd_result[:, 0].max(), 100)
y = np.linspace(svd_result[:, 1].min(), svd_result[:, 1].max(), 100)
X, Y = np.meshgrid(x, y)

# Distance from the origin in SVD space (TruncatedSVD does not center, so the
# origin corresponds to zero remaining distance).
Z = np.sqrt(X**2 + Y**2)

# Visualization
plt.figure(figsize=(5, 4))
contour = plt.contour(X, Y, Z, levels=20, cmap="viridis")
plt.colorbar(contour, label="Distance from Origin (SVD space)")
plt.scatter(svd_gdp[:, 0], svd_gdp[:, 1], alpha=0.7, label="gdp2")
plt.scatter(svd_svd[:, 0], svd_svd[:, 1], alpha=0.3, label="svd")
plt.legend()

plt.xlabel("First Principal Component")
plt.ylabel("Second Principal Component")
plt.title("Truncated SVD of Weight Updates with Distance Contours")

# Annotate every 50th step of each trajectory with its index.
for points in (svd_gdp, svd_svd):
    for i, (px, py) in enumerate(points):
        if i % 50 == 0:
            plt.annotate(str(i), (px, py), xytext=(5, 5), textcoords="offset points")

plt.grid(True)
plt.tight_layout()
plt.show()

## GDP batch-size effect

In [None]:
from sklearn.decomposition import PCA

# 2-D PCA projection of the GDP distance trajectory (compare with the TruncatedSVD view).
pca = PCA(n_components=2)
pca_result = pca.fit_transform(X=distance)

In [None]:
from src.prog_scheme.utils import extract_error
from src.utils.logging_utils import LogCapture

In [None]:
# Re-program the GDP tile for several batch sizes and overlay the error curves.
for batch_size_ in [1, 5, 10, 20, 50, input_size]:
    with LogCapture() as logc:
        # Load the target (tile orientation, as in the comparison cells); the
        # old `w` here was a stale loop variable, not the target.
        atile.tile.set_weights(w_target.T)
        atile.program_weights(batch_size=batch_size_)
        log_list = logc.get_log_list()
    err_list = extract_error(log_list)
    plt.semilogy(err_list, label=f"batch_size={batch_size_}")
plt.legend()
plt.xlabel("Iteration")
# NOTE(review): only batch_size is passed, so program_weights uses its own
# default norm here; confirm the label matches that default.
plt.ylabel("Nuclear norm of weight error")
plt.title(
    f"{input_size}x{output_size} rank={rank} matrix with "
    f"{atile.rpu_config.device.__class__.__name__}"
)
plt.show()

### Device-to-device (d2d) variation

In [None]:
# Show the device config's fields (its instance __dict__) as the cell output.
atile.rpu_config.device.__dict__

In [None]:
# Top-left corner of the target in tile orientation (`w_target.T`);
# the former `w` was a stale loop variable, not the target.
w_target.T[:5, :5]

In [None]:
# Check whether an element-wise perturbation is applied on write: store the
# target directly in the tile and compare the read-back against it.
# False means the stored weights differ from what was written.
atile.tile.set_weights(w_target.T)
wtile = atile.tile.get_weights()
torch.allclose(wtile, w_target.T)

# CustomTile

In [None]:
from aihwkit.simulator.tiles.custom import CustomTile

# Two CustomTiles for the set_weights(..., realistic=True) comparison below.
ctile = CustomTile(output_size, input_size)
# Realistic read of the fresh tile. The result is not stored, and since this
# is not the cell's last expression it is not displayed either.
ctile.get_weights(realistic=True)
ctile2 = CustomTile(output_size, input_size)

### RealisticTile (ours)

In [None]:
from src.prog_scheme.realistic import RealisticTile, RPUConfigwithProgram

# rpu_config = RPUConfigwithProgram(program_weights=gdp2)
# ctile = RealisticTile(output_size, input_size, rpu_config=rpu_config)

# rpu_config2 = RPUConfigwithProgram(program_weights=svd)
# ctile2 = RealisticTile(output_size, input_size, rpu_config=rpu_config2)

In [None]:
# print(rpu_config)

In [None]:
# Write the target (tile orientation `w_target.T`, matching the
# (output_size, input_size) tiles) into both CustomTiles and capture the log
# lines emitted while writing. `w` was a stale loop variable, not the target.
with LogCapture() as logc:
    ctile.set_weights(w_target.T, realistic=True)
    log_list = logc.get_log_list()

with LogCapture() as logc:
    ctile2.set_weights(w_target.T, realistic=True)
    log_list2 = logc.get_log_list()

In [None]:
# extract error and plot
import matplotlib.pyplot as plt

# Error curves parsed from the captured logs, one line per tile.
err_list, err_list2 = extract_error(log_list), extract_error(log_list2)

for curve, curve_label in ((err_list, "gpc"), (err_list2, "svd")):
    plt.plot(curve, label=curve_label)

ax = plt.gca()
ax.legend()
ax.set_xlabel("Iteration")
ax.set_ylabel("Error")
ax.set_title("Error vs Iteration")
plt.show()

# ETC

Only `AnalogTile`, which inherits from the `TileWithPeriphery` class, has a `program_weights` method.

By default, the `program_weights` method implements the algorithm from "Gradient descent-based programming of analog in-memory computing cores".

`set_weights` sets the weights of the analog tile to the given values.\
Internally, `set_weights` calls `program_weights` to program those values into the tile.

`get_weights` returns the weights of the analog tile.\
`read_weights` reads the weights of the analog tile including read noise.

In [None]:
from aihwkit.nn import AnalogLinear
from aihwkit.optim import AnalogSGD

In [None]:
# Bias-free digital reference layer, converted to an analog layer with the same rpuconfig.
digital_layer = torch.nn.Linear(in_features=input_size, out_features=output_size, bias=False)
layer = AnalogLinear.from_digital(digital_layer, rpuconfig)

In [None]:
# Toy training loop: minimise the squared output of the analog layer on
# random inputs, recording the loss at every step.
optimizer = AnalogSGD(layer.parameters(), lr=0.005)
losses = []
n_steps = 1000
for _ in range(n_steps):
    x = torch.rand(input_size)
    yhat = layer(x)
    loss = yhat.pow(2).sum()
    losses.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# plot losses
import matplotlib.pyplot as plt

# Training-loss curve of the analog layer, one point per optimisation step.
plt.plot(range(len(losses)), losses)