# Setup

In [None]:
import torch

from src.prog_scheme.utils import generate_target_weights

# Seed torch's global RNG so that Restart & Run All reproduces the same results.
# NOTE(review): assumes generate_target_weights draws from torch's global RNG,
# and the analog tile simulator may keep RNG state of its own that this seed
# does not control -- confirm if bit-exact reproducibility is required.
SEED = 0
torch.manual_seed(SEED)

# Shape and rank of the target weight matrix.
input_size = 100
output_size = 50
rank = 30
dim = input_size * output_size  # total number of weight cells
over_sampling = 2
x_rand = False

# Settings forwarded to the programming methods (collected in conf["methods"]).
batch_size = 2
tol = 1e-8
max_iter = 1000
norm_type = "fro"
svd_every_k_iter = 1
read_noise_std = 0.1  # also written to rpuconfig.forward.out_noise
update_noise_std = 0.1  # also written to rpuconfig.device.write_noise_std
input_ratio = 1.1

# Generate the low-rank target matrix and preview its top-left corner.
w_target = generate_target_weights(input_size, output_size, rank)
print(w_target[:5, :5])

In [None]:
from aihwkit.simulator.configs import FloatingPointRPUConfig, SingleRPUConfig
from aihwkit.simulator.configs.devices import (
    ConstantStepDevice,
    DriftParameter,
    ExpStepDevice,
    FloatingPointDevice,
    IdealDevice,
    LinearStepDevice,
    SimpleDriftParameter,
)
from aihwkit.simulator.configs.utils import (
    InputRangeParameter,
    PrePostProcessingParameter,
    UpdateParameters,
)
from aihwkit.simulator.parameters.enums import PulseType
from aihwkit.simulator.presets.configs import IdealizedPreset, PCMPreset, ReRamSBPreset
from aihwkit.simulator.presets.devices import IdealizedPresetDevice
from aihwkit.simulator.tiles import FloatingPointTile

from src.core.aihwkit.utils import rpuconf2dict

# Pre/post-processing parameters with the input range disabled. Only the
# alternative preset configuration mentioned below would consume them.
pre_post_cfg = PrePostProcessingParameter(input_range=InputRangeParameter(enable=False))

# Device model under test. ExpStepDevice() or IdealDevice() can be swapped in here.
device_cfg = LinearStepDevice()

update_cfg = UpdateParameters(pulse_type=PulseType.STOCHASTIC, desired_bl=127)
rpuconfig = SingleRPUConfig(update=update_cfg, device=device_cfg)
# Alternative configuration:
#   rpuconfig = IdealizedPreset(update=update_cfg, device=device_cfg, pre_post=pre_post_cfg)

# Read noise acts on the forward pass output (rpuconfig.forward.inp_res = 0 is an optional tweak).
rpuconfig.forward.out_noise = read_noise_std

# Device overrides, applied in this order on top of the defaults.
device_overrides = {
    "write_noise_std": update_noise_std,
    "w_max": 1.0,
    "gamma_up": 0.3,
    "gamma_down": 0.5,
    "w_min": -1.0,
    "w_max_dtod": 0.01,
    "w_min_dtod": 0.01,
    "dw_min_std": 0.0,
    "mult_noise": False,  # additive noise
}
for field_name, value in device_overrides.items():
    setattr(rpuconfig.device, field_name, value)

rpuconf_dict = rpuconf2dict(rpuconfig, max_depth=2)

# Full experiment configuration: RPU settings + matrix shape + method settings.
matrix_conf = dict(input_size=input_size, output_size=output_size, rank=rank)
methods_conf = dict(
    tolerance=tol,
    max_iter=max_iter,
    batch_size=batch_size,
    norm_type=norm_type,
    svd_every_k_iter=svd_every_k_iter,
    input_ratio=input_ratio,
    read_noise_std=read_noise_std,
    update_noise_std=update_noise_std,
    w_init=0.01,
    over_sampling=over_sampling,
    x_rand=x_rand,
)
conf = {**rpuconf_dict, "matrix": matrix_conf, "methods": methods_conf}

# AnalogTile

## Compare

In [None]:
import copy

from aihwkit.simulator.tiles.analog import AnalogTile

# Reference tile (with periphery). Its state dict is exported into `atile_dic`,
# e.g. for a later `load_state_dict(atile_dic, assign=True)`.
atile = AnalogTile(output_size, input_size, rpu_config=rpuconfig)
atile_dic = {}
atile.state_dict(atile_dic)

# One programming method per tile; the EKF variant is only added for the
# linear-step device model.
method_names = ["gdp", "gdp-kf", "svd", "svd-kf"]
if type(rpuconfig.device).__name__ == "LinearStepDevice":
    method_names.append("svd-ekf")

# The first method uses `atile` itself, every further method a deep copy of it.
tiles = [atile] + [copy.deepcopy(atile) for _ in method_names[1:]]
print(atile.tile.get_info())

In [None]:
from aihwkit.utils.visualization import plot_programming_error

# Plot the programming error of the configured device model using aihwkit's
# visualization helper with its default arguments.
plot_programming_error(rpuconfig.device)

In [None]:
from aihwkit.simulator.tiles.periphery import TileWithPeriphery

from src.prog_scheme.program_methods import GDP, SVD

# Attach a programming routine to every tile: the first two tiles get GDP's
# `program_weights`, all remaining tiles get SVD's (same order as `method_names`).
method_classes = [GDP] * 2 + [SVD] * (len(tiles) - 2)
for tile, method_cls in zip(tiles, method_classes):
    bound_program = method_cls.program_weights.__get__(tile, TileWithPeriphery)
    setattr(tile, "program_weights", bound_program)

In [None]:
import time

from src.prog_scheme.filters import DeviceKF, LinearDeviceEKF
from src.prog_scheme.utils import program_n_log

def _make_filter(name, tile):
    """Return the filter object for programming method ``name``.

    Plain "gdp" and "svd" use no filter (``None``), names ending in "-kf" get
    a ``DeviceKF``, and "svd-ekf" gets a ``LinearDeviceEKF`` when the tile's
    device is a ``LinearStepDevice``.

    Raises:
        ValueError: if ``name`` matches none of the cases above.
    """
    if name in ("gdp", "svd"):
        return None
    if name.endswith("-kf"):
        return DeviceKF(dim=dim, read_noise_std=read_noise_std, update_noise_std=update_noise_std)
    if name == "svd-ekf" and tile.rpu_config.device.__class__.__name__ == "LinearStepDevice":
        return LinearDeviceEKF(
            dim=dim,
            read_noise_std=read_noise_std,
            update_noise_std=update_noise_std,
            iterative_update=False,
            **rpuconf_dict["device"],
        )
    raise ValueError(f"Unknown method name: {name}")


# Program every tile towards the (transposed) target with its own method and
# collect one error history per method, in `method_names` order.
method_kwargs = conf["methods"].copy()
err_lists = []
for tile, name in zip(tiles, method_names):
    method_kwargs["fnc"] = _make_filter(name, tile)
    err_lists.append(program_n_log(tile, w_target.T, **method_kwargs))

## Visualize

In [None]:
import matplotlib.pyplot as plt


def plot_singular_values(Ws: tuple[torch.Tensor]):
    """Plot the singular values of every matrix in ``Ws`` on one log-scale figure.

    Args:
        Ws: Tensors to analyse. Each one is squeezed and passed to
            ``torch.linalg.svdvals``; one curve is drawn per tensor.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Lazy generator: each spectrum is computed right before it is drawn.
    spectra = (torch.linalg.svdvals(matrix.squeeze()) for matrix in Ws)
    for spectrum in spectra:
        plt.plot(spectrum)
    plt.title("Singular Values of Weight Matrix")
    plt.xlabel("Singular Value Index")
    plt.ylabel("Singular Value")
    plt.yscale("log")
    plt.show()

In [None]:
from src.core.aihwkit.utils import get_persistent_weights

# Residual of every tile: transposed target minus the weights returned by
# get_persistent_weights for that tile.
W = [w_target.T - get_persistent_weights(tile_.tile) for tile_ in tiles]

# Singular-value spectrum of each residual, followed by its matrix norm.
plot_singular_values(W)
print(f"{norm_type} norm of \n")
# NOTE: `i` and `w` stay bound in the notebook namespace after this loop;
# `w` ends up referring to the residual of the last tile, not to the target.
for i, w in enumerate(W):
    print(f"atile{i}: {torch.linalg.matrix_norm(w, ord=norm_type)}")

In [None]:
# One curve per method: recorded error divided by the number of weight cells,
# on a logarithmic y-axis. The final recorded error of each method is printed.
for err in err_lists:
    plt.semilogy(torch.tensor(err) / dim)
    print(err[-1])

plt.title(f"Error vs Iteration @ {input_size}x{output_size}, rank={rank}")
plt.xlabel("Iteration")
plt.ylabel(f"Avg. {norm_type} norm of weight error per cell")
# Legend entries follow the order in which the curves were drawn.
plt.legend(method_names)
plt.show()

In [None]:
import numpy as np
from aihwkit.utils.visualization import (
    compute_pulse_response,
    plot_programming_error,
    plot_response_overview,
)

# Pulse directions as a numpy array, handed to aihwkit's compute_pulse_response.
# NOTE(review): presumably +1 denotes an "up" pulse -- confirm against the aihwkit docs.
direction = np.array([1, 1, 1, 1])
# Response of `atile` to that pulse sequence. `w_trace` is not read again in
# the cells shown here.
w_trace = compute_pulse_response(atile, direction)

In [None]:
# Programming-error plot for the configured device, this time with
# realistic_read=True and n_bins=51.
plot_programming_error(rpuconfig.device, realistic_read=True, n_bins=51)

## Log

In [None]:
import wandb

from itertools import zip_longest

# conf["methods"]["fnc"] = conf["methods"]["fnc"].__class__.__name__
with wandb.init(project="prog-scheme", entity="spk", config=conf, dir="../../logs") as run:
    # Walk all error histories in lock-step, one wandb step per iteration.
    # zip_longest fills histories that ended early with None, exactly like the
    # former in-place padding, but leaves `err_lists` untouched. The plotting
    # cells (which call torch.tensor on each history) therefore stay
    # re-runnable after logging, and re-running this cell is harmless.
    for step_errors in zip_longest(*err_lists):
        run.log(
            {f"{name}_{norm_type}": err for name, err in zip(method_names, step_errors)}
        )

## VISUALIZE UPDATES

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import TruncatedSVD

# assert (atile.initial_weights - atile2.initial_weights).max() == 0

# Ideal total update: transposed target minus the initial weights of the first
# tile, flattened to a vector. The target is read from `w_target` explicitly;
# a bare `w` is only the leftover loop variable of the norm-printing loop
# (the last tile's residual), not the target matrix.
optimal_change = (w_target.T - tiles[0].initial_weights).flatten()

# Data processing: stack the recorded updates of a tile, flatten each update
# to one row and accumulate along the row axis.
actual_updates = atile.actual_weight_updates
data = np.array(actual_updates)
flattened_data = data.reshape(data.shape[0], -1)
cumulative_update = np.cumsum(flattened_data, axis=0)

# Same processing for the third tile (the "svd" entry of `method_names`).
data2 = np.array(tiles[2].actual_weight_updates)
flattened_data2 = data2.reshape(data2.shape[0], -1)
cumulative_update2 = np.cumsum(flattened_data2, axis=0)

In [None]:
# Remaining distance to the ideal total change after each recorded update,
# one row per update, for both trajectories.
distance = np.array([optimal_change - row for row in cumulative_update])
distance2 = np.array([optimal_change - row for row in cumulative_update2])
concat_distances = np.concatenate((distance, distance2), axis=0)

# Fit ONE 2-D projection on both trajectories together so they share the same
# axes. Fitting on `distance` alone leaves no rows for the second trajectory.
svd = TruncatedSVD(n_components=2)
svd_result = svd.fit_transform(concat_distances)

# Split the projection back by the actual trajectory lengths. The number of
# recorded updates need not equal `max_iter` (programming may stop early).
n_first = len(distance)
svd_gdp = svd_result[:n_first]
svd_svd = svd_result[n_first:]

# Grid spanning the projected points.
grid_x = np.linspace(svd_result[:, 0].min(), svd_result[:, 0].max(), 100)
grid_y = np.linspace(svd_result[:, 1].min(), svd_result[:, 1].max(), 100)
X, Y = np.meshgrid(grid_x, grid_y)

# Distance from the origin in the projected space (the origin is zero remaining error).
Z = np.sqrt(X**2 + Y**2)

# Visualization
plt.figure(figsize=(5, 4))
contour = plt.contour(X, Y, Z, levels=20, cmap="viridis")
plt.colorbar(contour, label="Distance from Origin (SVD space)")
plt.scatter(svd_gdp[:, 0], svd_gdp[:, 1], alpha=0.7, label="gdp2")
plt.scatter(svd_svd[:, 0], svd_svd[:, 1], alpha=0.3, label="svd")
plt.legend()

plt.xlabel("First Principal Component")
plt.ylabel("Second Principal Component")
plt.title("Truncated SVD of Weight Updates with Distance Contours")

# Label every 50th point of each trajectory with its update index.
for trajectory in (svd_gdp, svd_svd):
    for i in range(0, len(trajectory), 50):
        plt.annotate(str(i), tuple(trajectory[i]), xytext=(5, 5), textcoords="offset points")

plt.grid(True)
plt.tight_layout()
plt.show()

## GDP batch-size effect

In [None]:
from sklearn.decomposition import PCA

# 2-component PCA fitted on `distance` only, i.e. the trajectory derived from
# `atile`'s recorded updates. `pca_result` is not read again in the cells shown here.
pca = PCA(n_components=2)
pca_result = pca.fit_transform(distance)

In [None]:
from src.prog_scheme.utils import extract_error
from src.utils.logging_utils import LogCapture

In [None]:
import numpy as np

# Re-program `atile` without a filter for several batch sizes. The error
# history is parsed from the captured log and plotted against
# (history index * batch size).
method_kwargs["fnc"] = None
for batch_size_ in (1, 5, 10, 20, 50):
    method_kwargs["batch_size"] = batch_size_
    with LogCapture() as logc:
        # Store the transposed target in the tile, then run the programming routine.
        atile.tile.set_weights(w_target.T)
        atile.program_weights(atile, **method_kwargs)
        log_list = logc.get_log_list()
    err_list = extract_error(log_list)
    num_iter = batch_size_ * np.arange(len(err_list))
    plt.semilogy(num_iter, err_list, label=f"batch_size={batch_size_}")

plt.title(
    f"{input_size}x{output_size} rank={rank} matrix with {atile.rpu_config.device.__class__.__name__}"
)
plt.xlabel("Iteration")
plt.ylabel(f"{norm_type} norm of weight error")
plt.legend()

### d2d variation

In [None]:
# Display the instance attributes of the tile's device configuration (its
# dataclass fields) as a dict.
atile.rpu_config.device.__dict__

In [None]:
# Top-left 5x5 corner of the transposed target matrix. Reads `w_target`
# explicitly: a bare `w` is only the leftover loop variable of the
# norm-printing loop (the last tile's residual), not the target.
w_target.T[:5, :5]

In [None]:
# Check whether an element-wise perturbation is applied on write: store the
# transposed target in the tile, read the weights back and compare.
# The target is read from `w_target` explicitly: a bare `w` is only the
# leftover loop variable of the norm-printing loop (a residual whose transpose
# does not have the orientation used for `set_weights` elsewhere).

atile.tile.set_weights(w_target.T)
wtile = atile.tile.get_weights()
torch.allclose(wtile, w_target.T)

# CustomTile

In [None]:
from aihwkit.simulator.tiles.custom import CustomTile

# Two CustomTile instances created without an explicit rpu_config.
ctile = CustomTile(output_size, input_size)
# Realistic read of the fresh tile; the return value is discarded because this
# is not the last expression of the cell.
ctile.get_weights(realistic=True)
ctile2 = CustomTile(output_size, input_size)

### RealisticTile(Ours)

In [None]:
from src.prog_scheme.realistic import RealisticTile, RPUConfigwithProgram

# rpu_config = RPUConfigwithProgram(program_weights=gdp2)
# ctile = RealisticTile(output_size, input_size, rpu_config=rpu_config)

# rpu_config2 = RPUConfigwithProgram(program_weights=svd)
# ctile2 = RealisticTile(output_size, input_size, rpu_config=rpu_config2)

In [None]:
# print(rpu_config)

In [None]:
# Write `w` to each custom tile with realistic=True and keep the captured log
# lines for the error extraction in the next cell.
# NOTE(review): `w` is not assigned in this section; it is whatever an earlier
# cell left bound (the loop variable of the norm-printing loop). Presumably the
# target matrix is intended here -- confirm both the intent and the expected
# orientation (`w_target` vs. `w_target.T`).
# NOTE(review): both tiles are plain CustomTile instances; the RealisticTile
# setup with distinct programming methods is commented out, so the two runs
# may not differ in method -- verify.
with LogCapture() as logc:
    ctile.set_weights(w, realistic=True)
    log_list = logc.get_log_list()

with LogCapture() as logc:
    ctile2.set_weights(w, realistic=True)
    log_list2 = logc.get_log_list()

In [None]:
# extract error and plot
import matplotlib.pyplot as plt

# Error histories parsed from the two captured logs.
err_list = extract_error(log_list)
err_list2 = extract_error(log_list2)

# One curve per captured run.
for curve, curve_label in ((err_list, "gpc"), (err_list2, "svd")):
    plt.plot(curve, label=curve_label)

plt.title("Error vs Iteration")
plt.xlabel("Iteration")
plt.ylabel("Error")
plt.legend()
plt.show()

# ETC

Only `AnalogTile`, which inherits from the `TileWithPeriphery` class, has a `program_weights` method.

By default, the `program_weights` method implements "Gradient descent-based programming of analog in-memory computing cores".

The `set_weights` method sets the weights of the analog tile to the given values.\
The `program_weights` method is called internally by `set_weights` to program the weights of the analog tile.

The `get_weights` method returns the weights of the analog tile.\
The `read_weights` method reads the weights of the analog tile with read noise.

In [None]:
from aihwkit.nn import AnalogLinear
from aihwkit.optim import AnalogSGD

In [None]:
# Bias-free digital linear layer with the same input/output sizes as the tiles,
# handed to AnalogLinear.from_digital together with the notebook's rpuconfig.
digital_layer = torch.nn.Linear(input_size, output_size, bias=False)
layer = AnalogLinear.from_digital(digital_layer, rpuconfig)

In [None]:
# Minimise the sum of squared layer outputs on random inputs with AnalogSGD:
# one freshly sampled input vector (torch.rand) per step, 1000 steps in total.
# The scalar loss of every step is recorded in `losses`.
optimizer = AnalogSGD(layer.parameters(), lr=0.005)
losses = []
for _ in range(1000):
    x = torch.rand(input_size)
    loss = layer(x).pow(2).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # The loss value is unaffected by the parameter step, so it can be stored afterwards.
    losses.append(loss.item())

In [None]:
# plot losses
import matplotlib.pyplot as plt

# One point per optimisation step of the AnalogSGD loop (values from `losses`).
plt.plot(losses)