# Util & setup

In [None]:
# From a list of log strings, find the `Error: {err_normalized}` pattern and
# extract the values into a list
import re

import torch

from src.utils.logging_utils import LogCapture


def extract_error(log_list, prefix: str = "Error: ") -> list:
    """Extract the numeric values that follow ``prefix`` in log lines.

    Args:
        log_list: Iterable of log strings (e.g. the list returned by
            ``LogCapture.get_log_list()``).
        prefix: Literal text that precedes the value. It is matched verbatim;
            regex metacharacters in it are escaped.

    Returns:
        List of floats, one per line that contains ``prefix`` immediately
        followed by a number, in line order. Only the first such occurrence
        per line is used. Lines with ``prefix`` but no parsable number after it
        (e.g. a textual error message) are skipped rather than raising.
    """
    # Escape the prefix so characters like "(" or "." match literally, and
    # accept signed values, exponents in either case ("1e+05", "2.5E-3"),
    # and the float reprs "nan"/"inf".
    pattern = re.compile(
        re.escape(prefix)
        + r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|[-+]?(?:nan|inf))"
    )
    err_list = []
    for log in log_list:
        match = pattern.search(log)
        if match is not None:
            err_list.append(float(match.group(1)))
    return err_list

In [None]:
from matplotlib import pyplot as plt


def plot_singular_values(Ws: tuple[torch.Tensor, ...], labels: tuple[str, ...] = ()):
    """Plot the singular-value spectrum of each matrix in ``Ws`` on a log y-axis.

    Args:
        Ws: Matrices to analyse; one curve is drawn per matrix, in order.
        labels: Optional legend entries, one per matrix. When empty (the
            default) no legend is drawn, matching the original behaviour.
    """
    for mat in Ws:
        # detach + cpu so tensors that require grad or live on a GPU can be
        # passed to matplotlib, which converts its inputs through numpy.
        singular_values = torch.linalg.svdvals(mat.detach()).cpu()
        plt.plot(singular_values)
    plt.yscale("log")
    plt.xlabel("Singular Value Index")
    plt.ylabel("Singular Value")
    plt.title("Singular Values of Weight Matrix")
    if labels:
        plt.legend(list(labels))
    plt.show()

In [None]:
input_size = 100
output_size = 50
batch_size = 1

# Generate the target weight matrix w with shape (input_size, output_size) in
# float64, entries clamped to [-1, 1], optionally truncated to `rank` via SVD.
# With rank == min(w.shape) (the current setting) no truncation happens.
rank = 50
w = torch.randn(input_size, output_size).double()
w.clamp_(-1, 1)
if rank < min(w.shape):
    # torch.svd is deprecated; torch.linalg.svd returns Vh (= V^T) directly.
    u, s, vh = torch.linalg.svd(w, full_matrices=False)
    # Scaling the columns of u by s is equivalent to u @ diag(s).
    # Note: truncation happens after clamping, so entries may leave [-1, 1].
    w = (u[:, :rank] * s[:rank]) @ vh[:rank, :]

# AnalogTile

## Compare

In [None]:
from aihwkit.simulator.configs import FloatingPointRPUConfig, SingleRPUConfig
from aihwkit.simulator.configs.devices import (
    ConstantStepDevice,
    DriftParameter,
    ExpStepDevice,
    FloatingPointDevice,
    IdealDevice,
    SimpleDriftParameter,
)
from aihwkit.simulator.configs.utils import (
    InputRangeParameter,
    PrePostProcessingParameter,
    UpdateParameters,
)
from aihwkit.simulator.parameters.enums import PulseType
from aihwkit.simulator.presets.configs import IdealizedPreset, PCMPreset, ReRamSBPreset
from aihwkit.simulator.presets.devices import IdealizedPresetDevice
from aihwkit.simulator.tiles import FloatingPointTile
from aihwkit.simulator.tiles.analog import AnalogTile

# Tile configuration: input-range pre/post-processing enabled, idealized preset
# device, and compressed stochastic update pulses.
pre_post_cfg = PrePostProcessingParameter(input_range=InputRangeParameter(enable=True))
device_cfg = IdealizedPresetDevice()  # alternative: IdealDevice()
update_cfg = UpdateParameters(pulse_type=PulseType.STOCHASTIC_COMPRESSED)
rpuconfig = SingleRPUConfig(device=device_cfg, update=update_cfg, pre_post=pre_post_cfg)
# NOTE(review): presumably makes the forward pass ideal (no IO non-idealities);
# confirm against aihwkit's IOParameters.is_perfect.
rpuconfig.forward.is_perfect = True
# Other configs that were tried: SingleRPUConfig(), IdealizedPreset()

# Reference tile. Its state dict is copied into two fresh tiles, intended so
# that all three programming methods start from the same tile state.
atile = AnalogTile(output_size, input_size, rpu_config=rpuconfig)  # with periphery
atile_dic = {}
atile.state_dict(atile_dic)  # atile_dic is passed as the destination dict
_clones = []
for _ in range(2):
    _tile = AnalogTile(output_size, input_size, rpu_config=rpuconfig)
    _tile.load_state_dict(atile_dic, assign=True)
    _clones.append(_tile)
atile2, atile3 = _clones
print(rpuconfig)

In [None]:
# Display the forward IO parameters as a dict (vars() returns the same __dict__).
vars(rpuconfig.forward)

In [None]:
# Print the simulator-tile summary reported by atile2's inner tile.
print(atile2.tile.get_info())

In [None]:
from aihwkit.simulator.tiles.periphery import TileWithPeriphery

from src.prog_scheme.program_methods import gdp2, svd, svd_ekf

# Attach one programming routine per tile as an instance-level
# `program_weights` bound method (shadows the class method on that instance only).
for _tile_obj, _method in ((atile, gdp2), (atile2, svd), (atile3, svd_ekf)):
    _tile_obj.program_weights = _method.__get__(_tile_obj, TileWithPeriphery)

In [None]:
# Program atile (gdp2) toward the target w.T and capture the emitted log lines.
with LogCapture() as logc:
    # w.clone().T gives the inner tile a tensor that does not share storage with w.
    atile.tile.set_weights(w.clone().T)
    atile.program_weights(batch_size=batch_size, tolerance=1e-10, max_iter=1000)
    log_list1 = logc.get_log_list()

In [None]:
# Program atile2 (svd) toward the target w.T and capture the emitted log lines.
with LogCapture() as logc:
    # atile2.set_weights(weight=w, realistic=True)
    atile2.tile.set_weights(w.clone().T)
    # svd_once=False: forwarded to svd; presumably recomputes the SVD every iteration — confirm.
    atile2.program_weights(max_iter=1000, tolerance=1e-10, svd_once=False)
    log_list2 = logc.get_log_list()

In [None]:
# Program atile3 (svd_ekf) toward the target w.T and capture the emitted log lines.
with LogCapture() as logc:
    atile3.tile.set_weights(w.clone().T)
    # NOTE(review): max_iter=10 here vs 1000 for the other two methods —
    # confirm this is intended when comparing the convergence curves.
    atile3.program_weights(tolerance=1e-10, max_iter=10)
    log_list3 = logc.get_log_list()

In [None]:
# Residual (target minus programmed) weight matrix of each tile, in tile order.
W = tuple(w.T - tile_.tile.get_weights() for tile_ in (atile, atile2, atile3))
plot_singular_values(W)

In [None]:
# Nuclear norm (sum of singular values) of each residual; smaller is closer to the target.
print(
    "nuclear norm of \n"
    + ",\n".join(
        f"{name}: {torch.linalg.matrix_norm(residual, ord='nuc')}"
        for name, residual in zip(("atile", "atile2", "atile3"), W)
    )
)

In [None]:
# Show the last 10 captured log lines of the svd_ekf run.
log_list3[-10:]

In [None]:
# Overlay the convergence curves of the three programming methods (log y-axis).
method_labels = [f"gdp-seq(batchsize {batch_size})", "svd", "ekf"]
for captured in (log_list1, log_list2, log_list3):
    plt.semilogy(extract_error(captured))
plt.legend(method_labels)
plt.xlabel("Iteration")
plt.ylabel("Nuclear norm of weight error (sum of singular values)")
plt.title(f"Error vs Iteration @ {input_size}x{output_size}, rank={rank}")
plt.show()

## GDP batch-size effect

In [None]:
# Sweep the GDP batch size on atile and overlay the resulting convergence curves.
# NOTE(review): program_weights runs with its default tolerance/max_iter here,
# unlike the comparison cell (tolerance=1e-10, max_iter=1000) — confirm intended.
for batch_size_ in [1, 5, 10, 20, 50, input_size]:
    with LogCapture() as logc:
        # Defensive copy, consistent with the comparison cells.
        atile.tile.set_weights(w.clone().T)
        atile.program_weights(batch_size=batch_size_)
        log_list = logc.get_log_list()
    plt.semilogy(extract_error(log_list), label=f"batch_size={batch_size_}")
plt.legend()
plt.xlabel("Iteration")
plt.ylabel("Nuclear norm of weight error")
plt.title(
    "{}x{} rank={} matrix with {}".format(
        input_size, output_size, rank, atile.rpu_config.device.__class__.__name__
    )
)
plt.show()

### d2d variation

In [None]:
# Display the device config's dataclass fields as a dict (shown as cell output).
vars(atile.rpu_config.device)

In [None]:
# Peek at the top-left 5x5 corner of the target in tile orientation (w.T).
w.T[:5, :5]

In [None]:
# Check whether set_weights applies an element-wise perturbation (e.g. d2d
# variation): True means the read-back weights match the target.
# NOTE: this overwrites atile's current weights.
atile.tile.set_weights(w.T)
wtile = atile.tile.get_weights()
# torch.allclose raises on mismatched dtypes; w is float64 while the tile may
# return float32, so cast the read-back weights before comparing.
torch.allclose(wtile.to(w.dtype), w.T)

# CustomTile

In [None]:
from aihwkit.simulator.tiles.custom import CustomTile

# Build a CustomTile with the same (out, in) shape as the AnalogTiles and read its weights.
ctile = CustomTile(output_size, input_size)
# NOTE(review): confirm CustomTile.get_weights accepts `realistic`.
ctile.get_weights(realistic=True)

### RealisticTile (ours)

In [None]:
from src.prog_scheme.realistic import RealisticTile, RPUConfigwithProgram

# Two RealisticTiles that differ only in the programming routine passed to
# RPUConfigwithProgram: gdp2 for ctile, svd for ctile2.
rpu_config = RPUConfigwithProgram(program_weights=gdp2)
ctile = RealisticTile(output_size, input_size, rpu_config=rpu_config)

rpu_config2 = RPUConfigwithProgram(program_weights=svd)
ctile2 = RealisticTile(output_size, input_size, rpu_config=rpu_config2)

In [None]:
# Show the RPUConfigwithProgram used by ctile (gdp2).
print(rpu_config)

In [None]:
# Program both RealisticTiles toward w with realistic=True and capture their logs.
# NOTE(review): w is passed untransposed here, while the AnalogTile cells set
# w.T — confirm RealisticTile.set_weights expects (input_size, output_size).
with LogCapture() as logc:
    ctile.set_weights(w, realistic=True)
    log_list = logc.get_log_list()

# NOTE: this rebinds log_list2, overwriting the svd AnalogTile log captured earlier.
with LogCapture() as logc:
    ctile2.set_weights(w, realistic=True)
    log_list2 = logc.get_log_list()

In [None]:
# Extract the logged errors of both RealisticTile runs and plot them.
import matplotlib.pyplot as plt

err_list = extract_error(log_list)
err_list2 = extract_error(log_list2)

# Log-scale y-axis, consistent with the other convergence plots in this notebook.
# ctile was configured with gdp2, so its curve is labelled "gdp" (was "gpc").
plt.semilogy(err_list, label="gdp")
plt.semilogy(err_list2, label="svd")
plt.legend()
plt.xlabel("Iteration")
plt.ylabel("Error")
plt.title("Error vs Iteration")
plt.show()

# ETC

Only `AnalogTile`, which inherits from the `TileWithPeriphery` class, has a `program_weights` method.

By default, the `program_weights` method implements "Gradient descent-based programming of analog in-memory computing cores".

The `set_weights` method sets the weights of the analog tile to the given values.\
The `program_weights` method is called internally by `set_weights` to program the weights of the analog tile.

The `get_weights` method returns the weights of the analog tile.\
The `read_weights` method reads the weights of the analog tile with read noise.

In [None]:
from aihwkit.nn import AnalogLinear
from aihwkit.optim import AnalogSGD

In [None]:
# Build an AnalogLinear layer (using rpuconfig) from a bias-free digital Linear layer.
# NOTE(review): from_digital presumably copies the digital weights into the analog tile — confirm.
digital_layer = torch.nn.Linear(input_size, output_size, bias=False)
layer = AnalogLinear.from_digital(digital_layer, rpuconfig)

In [None]:
# Minimal analog training loop: 1000 steps, each on one random input vector.
# The loss is the squared L2 norm of the output, so minimizing it pushes the
# layer outputs toward zero.
optimizer = AnalogSGD(layer.parameters(), lr=0.005)
losses = []
for _ in range(1000):
    x = torch.rand(input_size)  # uniform [0, 1), 1-D (no batch dimension)
    yhat = layer(x)
    loss = (yhat**2).sum()
    losses.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# Plot the per-step training loss recorded in `losses`.
# (matplotlib is re-imported so this cell also runs on its own.)
import matplotlib.pyplot as plt

plt.plot(losses)