In [2]:
import wandb
import torch
from src.prog_scheme.utils import generate_target_weights, program_n_log
from aihwkit.simulator.configs import SingleRPUConfig
from aihwkit.simulator.configs.devices import LinearStepDevice
from aihwkit.simulator.configs.utils import InputRangeParameter, PrePostProcessingParameter, UpdateParameters
from aihwkit.simulator.parameters.enums import PulseType
from aihwkit.simulator.tiles.analog import AnalogTile
from torch import Tensor, eye
from torch.autograd import no_grad
from torch.linalg import lstsq
import matplotlib.pyplot as plt
import copy
from typing import Optional, Tuple


In [3]:
import torch
import wandb
import copy
import matplotlib.pyplot as plt
from typing import Any, List, Optional, Tuple, Union
from torch import Tensor, eye
from torch.autograd import no_grad
from torch.linalg import lstsq

In [4]:
from src.prog_scheme.utils import generate_target_weights, program_n_log,rpuconf2dict
from aihwkit.simulator.configs import FloatingPointRPUConfig, SingleRPUConfig
from aihwkit.simulator.configs.devices import (
    ConstantStepDevice,
    DriftParameter,
    ExpStepDevice,
    FloatingPointDevice,
    IdealDevice,
    LinearStepDevice,
    SimpleDriftParameter,
)
from aihwkit.simulator.configs.utils import (
    InputRangeParameter,
    PrePostProcessingParameter,
    UpdateParameters,
)
from aihwkit.simulator.parameters.enums import PulseType
from aihwkit.simulator.presets.configs import IdealizedPreset, PCMPreset, ReRamSBPreset
from aihwkit.simulator.presets.devices import IdealizedPresetDevice
from aihwkit.simulator.tiles import FloatingPointTile

from src.prog_scheme.kalman import ExpDeviceEKF, LinearDeviceEKF

from aihwkit.simulator.tiles.analog import AnalogTile

from aihwkit.simulator.tiles.periphery import TileWithPeriphery

from src.prog_scheme.program_methods import gdp2, svd, svd_ekf_lqg, svd_kf

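Define the `customread_tile` class: an `AnalogTile` whose weights are read out through the analog forward pass instead of being read directly.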

In [5]:

class customread_tile(AnalogTile):
    """``AnalogTile`` variant with a customised, forward-pass based weight read-out.

    Only :meth:`read_weights_` is overridden; everything else is inherited
    from ``AnalogTile``.
    """

    @no_grad()
    def read_weights_(
        self,
        apply_weight_scaling: bool = False,
        x_values: Optional[Tensor] = None,
        x_rand: bool = False,
        over_sampling: int = 10,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Reads the weights (and biases) in a realistic manner
        by using the forward pass for weights readout.

        The weights are not read directly. Probe inputs are pushed through
        the analog forward pass in eval mode and the weight matrix is
        estimated linearly from the resulting input/output pairs.

        Note:

            If the tile includes digital periphery (e.g. out scaling),
            these will be applied. Thus this weight is the logical
            weights that correspond to the weights in an FP network.

        Note:
            weights are estimated using the ``lstsq`` solver from torch.

        Args:
            apply_weight_scaling: If ``False`` (default), the scales
                returned by ``self.get_scales()`` (when not ``None``) are
                divided out of the estimated weights (and out of the
                bias when ``self.analog_bias`` is set). If ``True``, the
                estimate is returned as is.

            x_values: Probe inputs used for the estimation, one input
                vector per row. They are moved to the tile's device. If
                not given, see ``x_rand``.

            x_rand: Only used when ``x_values`` is not given. If
                ``False`` (default) the probe inputs are the rows of the
                ``in_size x in_size`` identity matrix; if ``True`` they
                are ``in_size`` vectors drawn with ``torch.rand``.

            over_sampling: Number of times the whole set of probe inputs
                is repeated (stacked row-wise) before the forward pass.
                Applies to given and to generated ``x_values`` alike.

        Returns:
            a tuple where the first item is the ``[out_size, in_size]`` weight
            matrix; and the second item is either the ``[out_size]`` bias vector
            or ``None`` if the tile is set not to use bias.

        Raises:
            ValueError: if ``over_sampling`` is smaller than 1, which
                would leave no probe inputs for the least-squares fit.
        """
        if over_sampling < 1:
            raise ValueError(f"over_sampling must be >= 1, got {over_sampling}")

        dtype = self.get_dtype()
        if x_values is not None:
            x_values = x_values.to(self.device)
        elif x_rand:
            x_values = torch.rand(self.in_size, self.in_size, device=self.device, dtype=dtype)
        else:
            x_values = eye(self.in_size, self.in_size, device=self.device, dtype=dtype)

        # Stack the same probes `over_sampling` times so that the
        # least-squares fit sees several forward-pass samples per probe.
        x_values = x_values.repeat(over_sampling, 1)

        # Forward pass in eval mode and with indexed mode switched off.
        # The previous state is restored even if the forward pass raises,
        # so a failed read-out does not leave the tile reconfigured.
        was_training = self.training
        is_indexed = self.is_indexed()
        try:
            self.eval()
            if is_indexed:
                self.analog_ctx.set_indexed(False)
            y_values = self.forward(x_values)
        finally:
            if was_training:
                self.train()
            if is_indexed:
                self.analog_ctx.set_indexed(True)

        # Remove the bias contribution before fitting the weights.
        if self.bias is not None:
            y_values -= self.bias

        # Solve x_values @ W.T ~= y_values for W in the least-squares sense.
        est_weight = lstsq(x_values, y_values).solution.T.cpu()
        weight, bias = self._separate_weights(est_weight)

        if self.digital_bias:
            bias = self.bias.detach().cpu()

        if not apply_weight_scaling:
            # we de-apply all scales
            alpha = self.get_scales()
            if alpha is not None:
                alpha = alpha.detach().cpu()
                return (weight / alpha.view(-1, 1), bias / alpha if self.analog_bias else bias)
        return weight, bias

Define `main`, the sweep entry point that W&B calls once per sweep configuration.

In [6]:
def main():
    """Run one programming-method comparison for the active W&B configuration.

    Steps:

    1. Start a W&B run; values supplied by a sweep override ``default_config``.
    2. Generate the target weight matrix with ``generate_target_weights``.
    3. Build a ``SingleRPUConfig`` around a ``LinearStepDevice`` and, when the
       device type is supported, a matching EKF device model.
    4. Create one ``customread_tile`` per programming method (``gdp2``,
       ``svd``, ``svd_kf`` and, if an EKF model exists, ``svd_ekf_lqg``),
       all deep copies of the same initial tile.
    5. Program all tiles with ``program_n_log`` and log the final error of
       each method plus a semilog error plot to W&B.

    The function takes no arguments and returns nothing; it writes
    ``error_plot.png`` to the current working directory.
    """
    # Define default parameters
    default_config = {
        'input_size': 100,
        'output_size': 50,
        'rank': 50,
        'over_sampling': 10,
        'x_rand': False,
        'batch_size': 1,
        'tol': 1e-8,
        'max_iter': 1000,
        'norm_type': "fro",
        'svd_every_k_iter': 5,
        'read_noise_std': 0.1,
        'update_noise_std': 0.1,
        'w_init': 0.01,
        'gamma_up': 0.1,
        'gamma_down': 0.1,
        'desired_bl': 127,
        'w_max': 1,
        'w_min': -1,
    }

    # Initialize wandb
    wandb.init(config=default_config)
    config = wandb.config

    # Extract parameters from wandb.config
    input_size = config.input_size
    output_size = config.output_size
    rank = config.rank
    dim = input_size * output_size
    over_sampling = config.over_sampling
    x_rand = config.x_rand
    batch_size = config.batch_size
    tol = config.tol
    max_iter = config.max_iter
    norm_type = config.norm_type
    svd_every_k_iter = config.svd_every_k_iter
    read_noise_std = config.read_noise_std
    update_noise_std = config.update_noise_std
    w_init = config.w_init

    # Generate low rank matrix
    w_target = generate_target_weights(input_size, output_size, rank)

    # Configure device and RPU
    pre_post_cfg = PrePostProcessingParameter(input_range=InputRangeParameter(enable=False))
    device_cfg = LinearStepDevice()
    update_cfg = UpdateParameters(pulse_type=PulseType.STOCHASTIC, desired_bl=config.desired_bl)
    # pre_post_cfg was previously built but never handed to the RPU config.
    rpuconfig = SingleRPUConfig(update=update_cfg, device=device_cfg, pre_post=pre_post_cfg)
    rpuconfig.forward.out_noise = read_noise_std
    rpuconfig.device.write_noise_std = update_noise_std
    rpuconfig.device.w_max = config.w_max
    rpuconfig.device.gamma_up = config.gamma_up
    rpuconfig.device.gamma_down = config.gamma_down
    rpuconfig.device.w_min = config.w_min
    # Fixed (not swept) device settings.
    rpuconfig.device.w_max_dtod = 0.01
    rpuconfig.device.w_min_dtod = 0.01
    rpuconfig.device.dw_min_std = 0.0
    rpuconfig.device.mult_noise = False  # Additive noise

    # Convert RPU config to dict
    rpuconf_dict = rpuconf2dict(rpuconfig, max_depth=2)

    # Pick the EKF device model that matches the configured device type;
    # unsupported device types get no EKF (and hence no svd_ekf_lqg run).
    if isinstance(rpuconfig.device, LinearStepDevice):
        ekf_cls = LinearDeviceEKF
    elif isinstance(rpuconfig.device, ExpStepDevice):
        ekf_cls = ExpDeviceEKF
    else:
        ekf_cls = None

    if ekf_cls is None:
        device_ekf = None
    else:
        device_ekf = ekf_cls(
            dim=dim,
            read_noise_std=read_noise_std,
            update_noise_std=update_noise_std,
            **rpuconf_dict["device"],
        )

    conf = {
        **rpuconf_dict,
        "matrix": {"input_size": input_size, "output_size": output_size, "rank": rank},
        "methods": {
            "device_ekf": device_ekf,
            "tolerance": tol,
            "max_iter": max_iter,
            "batch_size": batch_size,
            "norm_type": norm_type,
            "svd_every_k_iter": svd_every_k_iter,
            "read_noise_std": read_noise_std,
            "update_noise_std": update_noise_std,
            "w_init": w_init,
            "over_sampling": over_sampling,
            "x_rand": x_rand,
        },
    }

    # Programming methods to compare; the EKF-based one needs a device model.
    programmers = [gdp2, svd, svd_kf]
    if device_ekf is not None:
        programmers.append(svd_ekf_lqg)

    # Initialize tiles: every method starts from a deep copy of the same
    # tile. All copies are taken before any ``program_weights`` is bound so
    # that the bound method is never part of what gets copied.
    base_tile = customread_tile(output_size, input_size, rpu_config=rpuconfig)  # with periphery
    tiles = [base_tile] + [copy.deepcopy(base_tile) for _ in programmers[1:]]
    for tile, programmer in zip(tiles, programmers):
        # Bind the free function as this tile's ``program_weights`` method.
        tile.program_weights = programmer.__get__(tile, TileWithPeriphery)
    method_names = [tile.program_weights.__name__ for tile in tiles]

    # Program and log errors
    err_lists = program_n_log(tiles, w_target.T, conf["methods"])

    # Log the last error value of every method in a single call so that
    # all of them land on the same W&B step.
    wandb.log({f"{name} Error": err[-1] for name, err in zip(method_names, err_lists)})

    # Plot on an explicit figure and always close it, so that a failure
    # cannot leak curves into the figure of the next sweep run.
    fig, ax = plt.subplots()
    try:
        for name, err in zip(method_names, err_lists):
            ax.semilogy(err, label=name)
        ax.set_xlabel("Iteration")
        ax.set_ylabel(f"{norm_type} norm of weight error")
        ax.set_title(f"Error vs Iteration @ {input_size}x{output_size}, rank={rank}")
        ax.legend()
        fig.savefig("error_plot.png")
        wandb.log({"Error Plot": wandb.Image("error_plot.png")})
    finally:
        plt.close(fig)
    
 

Run the W&B grid sweep over `rank` and `svd_every_k_iter`.

In [7]:
   
# Sweep configuration: exhaustive grid search over the listed values
# (set 'method' to 'random' for a random search instead).
sweep_config = dict(
    method='grid',
    parameters=dict(
        rank=dict(values=[10, 20, 30, 40, 50]),
        svd_every_k_iter=dict(values=[1, 5, 10]),
    ),
)

# Register the sweep, then start an agent that calls `main` once per run.
sweep_id = wandb.sweep(sweep_config, project='prog-scheme-sweep')
wandb.agent(sweep_id, function=main)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


Create sweep with ID: isn9lcsx
Sweep URL: https://wandb.ai/spk/prog-scheme-sweep/sweeps/isn9lcsx


[34m[1mwandb[0m: Agent Starting Run: 8qp56h70 with config:
[34m[1mwandb[0m: 	rank: 10
[34m[1mwandb[0m: 	svd_every_k_iter: 1
[34m[1mwandb[0m: Currently logged in as: [33mminwookk5[0m ([33mspk[0m). Use [1m`wandb login --relogin`[0m to force relogin


Programming time: 13.09s
Programming time: 4.94s
Programming time: 4.72s
Programming time: 4.70s


0,1
gdp2 Error,▁
svd Error,▁
svd_ekf_lqg Error,▁
svd_kf Error,▁

0,1
gdp2 Error,2.4298
svd Error,2.42531
svd_ekf_lqg Error,1.72679
svd_kf Error,1.63733


[34m[1mwandb[0m: Agent Starting Run: ez5f2h74 with config:
[34m[1mwandb[0m: 	rank: 10
[34m[1mwandb[0m: 	svd_every_k_iter: 5


Programming time: 12.97s
Programming time: 2.03s
Programming time: 4.16s
Programming time: 4.41s


0,1
gdp2 Error,▁
svd Error,▁
svd_ekf_lqg Error,▁
svd_kf Error,▁

0,1
gdp2 Error,2.53986
svd Error,2.50805
svd_ekf_lqg Error,1.80133
svd_kf Error,1.80133


[34m[1mwandb[0m: Agent Starting Run: zbh5e4s0 with config:
[34m[1mwandb[0m: 	rank: 10
[34m[1mwandb[0m: 	svd_every_k_iter: 10


Programming time: 13.21s
Programming time: 1.71s
Programming time: 4.56s
Programming time: 4.40s


0,1
gdp2 Error,▁
svd Error,▁
svd_ekf_lqg Error,▁
svd_kf Error,▁

0,1
gdp2 Error,2.49583
svd Error,2.45985
svd_ekf_lqg Error,1.69543
svd_kf Error,1.69484


[34m[1mwandb[0m: Agent Starting Run: tnbur1o1 with config:
[34m[1mwandb[0m: 	rank: 20
[34m[1mwandb[0m: 	svd_every_k_iter: 1


Programming time: 13.21s
Programming time: 5.38s
Programming time: 5.10s
Programming time: 5.04s


0,1
gdp2 Error,▁
svd Error,▁
svd_ekf_lqg Error,▁
svd_kf Error,▁

0,1
gdp2 Error,2.41863
svd Error,2.46497
svd_ekf_lqg Error,1.71614
svd_kf Error,1.70227


[34m[1mwandb[0m: Agent Starting Run: niy9fk4x with config:
[34m[1mwandb[0m: 	rank: 20
[34m[1mwandb[0m: 	svd_every_k_iter: 5


Programming time: 13.22s
Programming time: 2.37s
Programming time: 4.83s
Programming time: 4.69s


0,1
gdp2 Error,▁
svd Error,▁
svd_ekf_lqg Error,▁
svd_kf Error,▁

0,1
gdp2 Error,2.50641
svd Error,2.50549
svd_ekf_lqg Error,1.71257
svd_kf Error,1.69276


[34m[1mwandb[0m: Agent Starting Run: hc3lggfo with config:
[34m[1mwandb[0m: 	rank: 20
[34m[1mwandb[0m: 	svd_every_k_iter: 10


Programming time: 13.36s
Programming time: 1.82s
Programming time: 4.44s
Programming time: 4.52s


0,1
gdp2 Error,▁
svd Error,▁
svd_ekf_lqg Error,▁
svd_kf Error,▁

0,1
gdp2 Error,2.47269
svd Error,2.51385
svd_ekf_lqg Error,1.71874
svd_kf Error,1.7373


[34m[1mwandb[0m: Agent Starting Run: 19t34nxw with config:
[34m[1mwandb[0m: 	rank: 30
[34m[1mwandb[0m: 	svd_every_k_iter: 1


Programming time: 12.90s
Programming time: 5.35s
Programming time: 5.30s
Programming time: 5.50s


0,1
gdp2 Error,▁
svd Error,▁
svd_ekf_lqg Error,▁
svd_kf Error,▁

0,1
gdp2 Error,2.45308
svd Error,2.4787
svd_ekf_lqg Error,1.7326
svd_kf Error,1.71476


[34m[1mwandb[0m: Agent Starting Run: b4o4ct19 with config:
[34m[1mwandb[0m: 	rank: 30
[34m[1mwandb[0m: 	svd_every_k_iter: 5


Programming time: 13.05s
Programming time: 2.40s
Programming time: 4.66s
Programming time: 4.82s


0,1
gdp2 Error,▁
svd Error,▁
svd_ekf_lqg Error,▁
svd_kf Error,▁

0,1
gdp2 Error,2.43264
svd Error,2.49436
svd_ekf_lqg Error,1.70677
svd_kf Error,1.67875


[34m[1mwandb[0m: Agent Starting Run: twtx0w79 with config:
[34m[1mwandb[0m: 	rank: 30
[34m[1mwandb[0m: 	svd_every_k_iter: 10


Programming time: 13.24s
Programming time: 1.80s
Programming time: 4.33s
Programming time: 4.63s


0,1
gdp2 Error,▁
svd Error,▁
svd_ekf_lqg Error,▁
svd_kf Error,▁

0,1
gdp2 Error,2.51729
svd Error,2.52886
svd_ekf_lqg Error,1.77518
svd_kf Error,1.72119


[34m[1mwandb[0m: Agent Starting Run: 2asnsxbw with config:
[34m[1mwandb[0m: 	rank: 40
[34m[1mwandb[0m: 	svd_every_k_iter: 1


Programming time: 13.42s
Programming time: 5.57s
Programming time: 5.27s
Programming time: 5.29s


0,1
gdp2 Error,▁
svd Error,▁
svd_ekf_lqg Error,▁
svd_kf Error,▁

0,1
gdp2 Error,2.48722
svd Error,2.55173
svd_ekf_lqg Error,1.82698
svd_kf Error,1.73268


[34m[1mwandb[0m: Agent Starting Run: c10rrb87 with config:
[34m[1mwandb[0m: 	rank: 40
[34m[1mwandb[0m: 	svd_every_k_iter: 5


Programming time: 13.04s
Programming time: 2.33s
Programming time: 5.29s
Programming time: 5.30s


0,1
gdp2 Error,▁
svd Error,▁
svd_ekf_lqg Error,▁
svd_kf Error,▁

0,1
gdp2 Error,2.44598
svd Error,2.56382
svd_ekf_lqg Error,1.79939
svd_kf Error,1.74783


[34m[1mwandb[0m: Agent Starting Run: q724tioy with config:
[34m[1mwandb[0m: 	rank: 40
[34m[1mwandb[0m: 	svd_every_k_iter: 10


Programming time: 13.66s
Programming time: 2.15s
Programming time: 4.97s
Programming time: 5.16s


0,1
gdp2 Error,▁
svd Error,▁
svd_ekf_lqg Error,▁
svd_kf Error,▁

0,1
gdp2 Error,2.47594
svd Error,2.61344
svd_ekf_lqg Error,1.79469
svd_kf Error,1.7958


[34m[1mwandb[0m: Agent Starting Run: 1beezovi with config:
[34m[1mwandb[0m: 	rank: 50
[34m[1mwandb[0m: 	svd_every_k_iter: 1


Programming time: 13.24s
Programming time: 5.72s
Programming time: 5.72s
Programming time: 5.77s


0,1
gdp2 Error,▁
svd Error,▁
svd_ekf_lqg Error,▁
svd_kf Error,▁

0,1
gdp2 Error,2.24759
svd Error,2.62203
svd_ekf_lqg Error,2.11605
svd_kf Error,1.86052


[34m[1mwandb[0m: Agent Starting Run: q9srh2lh with config:
[34m[1mwandb[0m: 	rank: 50
[34m[1mwandb[0m: 	svd_every_k_iter: 5


Programming time: 13.28s
Programming time: 2.84s
Programming time: 5.12s
Programming time: 5.39s


0,1
gdp2 Error,▁
svd Error,▁
svd_ekf_lqg Error,▁
svd_kf Error,▁

0,1
gdp2 Error,2.29479
svd Error,2.67571
svd_ekf_lqg Error,1.95685
svd_kf Error,1.9833


[34m[1mwandb[0m: Agent Starting Run: 4icaahi8 with config:
[34m[1mwandb[0m: 	rank: 50
[34m[1mwandb[0m: 	svd_every_k_iter: 10


Programming time: 13.45s
Programming time: 2.16s
Programming time: 4.98s
Programming time: 5.10s


0,1
gdp2 Error,▁
svd Error,▁
svd_ekf_lqg Error,▁
svd_kf Error,▁

0,1
gdp2 Error,2.27216
svd Error,2.75272
svd_ekf_lqg Error,1.95632
svd_kf Error,1.94625


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Sweep Agent: Exiting.
