In [None]:
import copy

import matplotlib.pyplot as plt
import torch
import wandb
from aihwkit.simulator.configs.utils import (
    InputRangeParameter,
    PrePostProcessingParameter,
    UpdateParameters,
)

In [None]:
from torch import Tensor, eye
from torch.autograd import no_grad
from torch.linalg import lstsq

In [None]:
from aihwkit.simulator.configs import FloatingPointRPUConfig, SingleRPUConfig
from aihwkit.simulator.configs.devices import (
    ConstantStepDevice,
    DriftParameter,
    ExpStepDevice,
    FloatingPointDevice,
    IdealDevice,
    LinearStepDevice,
    SimpleDriftParameter,
)
from aihwkit.simulator.parameters.enums import PulseType
from aihwkit.simulator.presets.configs import IdealizedPreset, PCMPreset, ReRamSBPreset
from aihwkit.simulator.presets.devices import IdealizedPresetDevice
from aihwkit.simulator.tiles import FloatingPointTile
from aihwkit.simulator.tiles.analog import AnalogTile
from aihwkit.simulator.tiles.periphery import TileWithPeriphery

from src.prog_scheme.kalman import ExpDeviceEKF, LinearDeviceEKF
from src.prog_scheme.program_methods import gdp2, svd, svd_ekf_lqg, svd_kf
from src.prog_scheme.utils import generate_target_weights, program_n_log, rpuconf2dict

## Define customread_tile class

In [None]:
class customread_tile(AnalogTile):
    """Analog tile with a forward-pass based weight readout (``read_weights_``).

    The readout probes the tile with identity rows by default (or uniform
    random vectors when ``x_rand`` is set) and solves a least-squares problem
    on the observed outputs.
    """

    @no_grad()
    def read_weights_(
        self,
        apply_weight_scaling: bool = False,
        x_values: Tensor | None = None,
        x_rand: bool = False,
        over_sampling: int = 10,
    ) -> tuple[Tensor, Tensor | None]:
        """Estimate the weights (and biases) through the analog forward pass.

        The weights are not read directly. Probe inputs are pushed through
        ``self.forward`` in eval mode and the weight matrix is recovered with
        the ``lstsq`` solver from torch. If ``self.bias`` is set, it is
        subtracted from the outputs before solving.

        Args:
            apply_weight_scaling: If ``False`` (default) and
                ``self.get_scales()`` returns a value, the estimated weights
                (and the analog bias, if any) are divided by these scales
                before being returned. If ``True``, the estimate is returned
                without de-applying the scales.

            x_values: Probe inputs to use for the estimation. They are moved
                to ``self.device``. If not given, the rows of an
                ``in_size x in_size`` identity matrix are used, unless
                ``x_rand`` is set.

            x_rand: Only used when ``x_values`` is not given. If ``True``,
                an ``in_size x in_size`` matrix drawn with ``torch.rand`` is
                used as probe inputs instead of the identity matrix.

            over_sampling: Number of times the whole probe matrix is stacked
                (``x_values.repeat(over_sampling, 1)``). Note that this
                repetition is applied to user-supplied ``x_values`` as well.

        Returns:
            a tuple where the first item is the ``[out_size, in_size]`` weight
            matrix (on CPU); and the second item is the bias as returned by
            ``self._separate_weights``, replaced by a detached CPU copy of
            ``self.bias`` when ``self.digital_bias`` is set.

        Raises:
            TileError: NOTE(review): carried over from the upstream docstring;
                nothing in this method raises it directly.

        """
        dtype = self.get_dtype()
        if x_values is not None:
            x_values = x_values.to(self.device)
        elif x_rand:
            x_values = torch.rand(self.in_size, self.in_size, device=self.device, dtype=dtype)
        else:
            x_values = eye(self.in_size, self.in_size, device=self.device, dtype=dtype)

        # Stack the probe set so that the least-squares fit sees each probe
        # vector ``over_sampling`` times.
        x_values = x_values.repeat(over_sampling, 1)

        # Forward pass in eval mode with indexed mode switched off. The
        # previous state is restored even if the forward pass raises, so a
        # failed readout does not leave the tile in eval / non-indexed mode.
        was_training = self.training
        is_indexed = self.is_indexed()
        self.eval()
        if is_indexed:
            self.analog_ctx.set_indexed(False)
        try:
            y_values = self.forward(x_values)
        finally:
            if was_training:
                self.train()
            if is_indexed:
                self.analog_ctx.set_indexed(True)

        if self.bias is not None:
            y_values -= self.bias

        est_weight = lstsq(x_values, y_values).solution.T.cpu()
        weight, bias = self._separate_weights(est_weight)

        if self.digital_bias:
            bias = self.bias.detach().cpu()

        if not apply_weight_scaling:
            # we de-apply all scales
            alpha = self.get_scales()
            if alpha is not None:
                alpha = alpha.detach().cpu()
                return (weight / alpha.view(-1, 1), bias / alpha if self.analog_bias else bias)
        return weight, bias

## Sweep main function

In [None]:
def main():
    """Run a single sweep trial and log the programming errors to wandb.

    Hyperparameters are read from ``wandb.config`` (falling back to
    ``default_config``). A target matrix is generated, a
    ``LinearStepDevice``-based RPU configuration is assembled, and identical
    copies of one tile are programmed with ``gdp2``, ``svd``, ``svd_kf`` and,
    when a device EKF is available, ``svd_ekf_lqg``. For every method the last
    error value and a semilog error-vs-iteration plot are logged; the plot is
    also written to ``error_plot.png``.
    """
    # Define default parameters
    default_config = {
        "input_size": 100,
        "output_size": 50,
        "rank": 50,
        "over_sampling": 10,
        "x_rand": False,
        "batch_size": 1,
        "tol": 1e-8,
        "max_iter": 1000,
        "norm_type": "fro",
        "svd_every_k_iter": 5,
        "read_noise_std": 0.1,
        "update_noise_std": 0.1,
        "w_init": 0.01,
        "gamma_up": 0.1,
        "gamma_down": 0.1,
        "desired_bl": 127,
        "w_max": 1,
        "w_min": -1,
    }

    # Initialize wandb; sweep-provided values override the defaults above.
    wandb.init(config=default_config)
    config = wandb.config

    # Extract parameters from wandb.config
    input_size = config.input_size
    output_size = config.output_size
    rank = config.rank
    dim = input_size * output_size
    over_sampling = config.over_sampling
    x_rand = config.x_rand
    batch_size = config.batch_size
    tol = config.tol
    max_iter = config.max_iter
    norm_type = config.norm_type
    svd_every_k_iter = config.svd_every_k_iter
    read_noise_std = config.read_noise_std
    update_noise_std = config.update_noise_std
    w_init = config.w_init

    # Generate the target matrix (presumably of the requested rank -- see
    # generate_target_weights).
    w_target = generate_target_weights(input_size, output_size, rank)

    # Configure device and RPU. The pre/post-processing config (input range
    # disabled) is attached explicitly; previously it was built but never used.
    pre_post_cfg = PrePostProcessingParameter(input_range=InputRangeParameter(enable=False))
    device_cfg = LinearStepDevice()
    update_cfg = UpdateParameters(pulse_type=PulseType.STOCHASTIC, desired_bl=config.desired_bl)
    rpuconfig = SingleRPUConfig(update=update_cfg, device=device_cfg, pre_post=pre_post_cfg)
    rpuconfig.forward.out_noise = read_noise_std
    rpuconfig.device.write_noise_std = update_noise_std
    rpuconfig.device.w_max = config.w_max
    rpuconfig.device.gamma_up = config.gamma_up
    rpuconfig.device.gamma_down = config.gamma_down
    rpuconfig.device.w_min = config.w_min
    rpuconfig.device.w_max_dtod = 0.01
    rpuconfig.device.w_min_dtod = 0.01
    rpuconfig.device.dw_min_std = 0.0
    rpuconfig.device.mult_noise = False  # Additive noise

    # Convert RPU config to dict and pick the EKF model matching the device.
    rpuconf_dict = rpuconf2dict(rpuconfig, max_depth=2)
    if isinstance(rpuconfig.device, LinearStepDevice):
        ekf_cls = LinearDeviceEKF
    elif isinstance(rpuconfig.device, ExpStepDevice):
        ekf_cls = ExpDeviceEKF
    else:
        ekf_cls = None

    if ekf_cls is None:
        device_ekf = None
    else:
        device_ekf = ekf_cls(
            dim=dim,
            read_noise_std=read_noise_std,
            update_noise_std=update_noise_std,
            **rpuconf_dict["device"],
        )

    conf = {
        **rpuconf_dict,
        "matrix": {"input_size": input_size, "output_size": output_size, "rank": rank},
        "methods": {
            "device_ekf": device_ekf,
            "tolerance": tol,
            "max_iter": max_iter,
            "batch_size": batch_size,
            "norm_type": norm_type,
            "svd_every_k_iter": svd_every_k_iter,
            "read_noise_std": read_noise_std,
            "update_noise_std": update_noise_std,
            "w_init": w_init,
            "over_sampling": over_sampling,
            "x_rand": x_rand,
        },
    }

    # Initialize tiles: one tile plus deep copies, so that every programming
    # method starts from the same tile state.
    atile = customread_tile(output_size, input_size, rpu_config=rpuconfig)  # with periphery
    # NOTE(review): the collected state dict is not used afterwards; the call
    # is kept in case it has side effects on the tile -- confirm and remove.
    atile_dic = {}
    atile.state_dict(atile_dic)
    atile2 = copy.deepcopy(atile)
    atile3 = copy.deepcopy(atile)
    atile4 = copy.deepcopy(atile)

    # Bind one programming routine per tile as its ``program_weights`` method.
    atile.program_weights = gdp2.__get__(atile, TileWithPeriphery)
    atile2.program_weights = svd.__get__(atile2, TileWithPeriphery)
    atile3.program_weights = svd_kf.__get__(atile3, TileWithPeriphery)
    atile4.program_weights = svd_ekf_lqg.__get__(atile4, TileWithPeriphery)
    tiles = [atile, atile2, atile3]
    if device_ekf is not None:
        # The EKF/LQG method is only run when an EKF model exists for the device.
        tiles.append(atile4)
    method_names = [t.program_weights.__name__ for t in tiles]

    # Use an explicit figure and always close it, so a failing trial cannot
    # leak curves into the plot of the next sweep trial.
    fig, ax = plt.subplots()
    try:
        # Log errors to wandb
        for name, tile in zip(method_names, tiles):
            err = program_n_log(tile, w_target.T, **conf)
            # Log the last error value for each method
            wandb.log({f"{name} Error": err[-1]})
            ax.semilogy(err, label=name)

        ax.set_xlabel("Iteration")
        ax.set_ylabel(f"{norm_type} norm of weight error")
        ax.set_title(f"Error vs Iteration @ {input_size}x{output_size}, rank={rank}")
        ax.legend()
        fig.savefig("error_plot.png")
        wandb.log({"Error Plot": wandb.Image("error_plot.png")})
    finally:
        plt.close(fig)

## Sweep

In [None]:
# Sweep configuration: which hyperparameters to vary and how to search them.
sweep_config = dict(
    method="grid",  # use "random" instead for a random search
    parameters=dict(
        rank=dict(values=[10, 20, 30, 40, 50]),
        svd_every_k_iter=dict(values=[1, 5, 10]),
    ),
)

# Register the sweep, then start an agent that calls `main` for its runs.
sweep_id = wandb.sweep(sweep_config, project="prog-scheme-sweep")
wandb.agent(sweep_id, function=main)