In [None]:
import torch
from aihwkit.nn import AnalogLinear
from aihwkit.optim import AnalogSGD
from aihwkit.simulator.configs import SingleRPUConfig
from aihwkit.simulator.configs.devices import ConstantStepDevice, DriftParameter
from aihwkit.simulator.configs.utils import InputRangeParameter, PrePostProcessingParameter
from aihwkit.simulator.tiles.analog import AnalogTile, AnalogTileWithoutPeriphery

# Layer / tile dimensions shared by every cell below.
input_size = 64
output_size = 10

# RPU configuration: constant-step device with diffusion disabled and default drift
# parameters, plus input-range pre/post-processing enabled.
pre_post = PrePostProcessingParameter(input_range=InputRangeParameter(enable=True))
device = ConstantStepDevice(diffusion=0, drift=DriftParameter())
rpuconfig = SingleRPUConfig(device=device, pre_post=pre_post)

# aihwkit tile constructors take the *output* size first (`out_size, in_size`), the
# opposite order of `torch.nn.Linear(in_features, out_features)`. Keyword arguments keep
# the tiles from being silently built transposed (64 outputs x 10 inputs).
tile = AnalogTileWithoutPeriphery(out_size=output_size, in_size=input_size, rpu_config=rpuconfig)
tile2 = AnalogTile(out_size=output_size, in_size=input_size, rpu_config=rpuconfig)  # with periphery

Only `AnalogTile`, which inherits from the `TileWithPeriphery` class, has a `program_weights` method; `AnalogTileWithoutPeriphery` (`tile` above) does not.

By default, the `program_weights` method implements the algorithm from "Gradient descent-based programming of analog in-memory computing cores".

The `set_weights` method sets the weights of the analog tile to the given values.\
The `program_weights` method is called internally by `set_weights` to program the weights of the analog tile.\

The `get_weights` method returns the weights of the analog tile.\
The `read_weights` method reads the weights of the analog tile, including read noise.

In [None]:
# Compare the stored weights with a (noisy) read-out of the same tile.
weights_ideal, _ = tile2.get_weights()
weights_read, _ = tile2.read_weights()
torch.allclose(weights_ideal, weights_read)

In [None]:
# Inspect which simulator object the periphery-free tile wraps internally.
inner_tile = tile.tile
type(inner_tile)

In [None]:
# `set_weights` needs the target matrix as its first argument; calling it with only
# `realistic=False` raises a TypeError. Build the target from the shape the tile
# currently reports so it always matches the tile's dimensions.
current_weights, _ = tile2.get_weights()
target_weights = torch.rand_like(current_weights) - 0.5
# NOTE(review): per the markdown above, `set_weights` may route through `program_weights`;
# presumably `realistic=False` writes the values directly — confirm in the aihwkit docs.
tile2.set_weights(target_weights, realistic=False)

## Custom Tiles

In [None]:
from aihwkit.simulator.tiles.custom import CustomTile

# TODO: subclass CustomRPUConfig and CustomTile (the original note was cut off here;
#       spell out what the custom tile should add).
# Like the other aihwkit tiles, the constructor takes the output size first; keyword
# arguments keep the dimensions from being swapped.
ctile = CustomTile(out_size=output_size, in_size=input_size)

## Gradient descent-based programming of analog in-memory computing cores

In [None]:
# Convert a bias-free digital linear layer into an AnalogLinear using the shared RPU config.
reference_linear = torch.nn.Linear(in_features=input_size, out_features=output_size, bias=False)
layer = AnalogLinear.from_digital(reference_linear, rpuconfig)

In [None]:
# Leave training mode before programming the analog weights. `eval()` switches the layer
# *and all of its submodules* out of training mode; assigning `layer.training = False`
# only flips the flag on the top-level module.
layer.eval()
# NOTE(review): the positional `None` is passed through unchanged from the original;
# confirm what this argument selects in the installed aihwkit version.
layer.program_analog_weights(None)

In [None]:
# Train the analog layer with AnalogSGD. The previous cell put the layer in eval mode,
# so switch back to training mode explicitly instead of depending on leftover state.
layer.train()
torch.manual_seed(0)  # reproducible random inputs

n_steps = 1000
learning_rate = 0.005

optimizer = AnalogSGD(layer.parameters(), lr=learning_rate)
losses = []
for _ in range(n_steps):
    x = torch.rand(input_size)
    yhat = layer(x)
    # Sum of squared outputs: minimizing it drives the layer's outputs toward zero.
    loss = (yhat**2).sum()
    losses.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# Plot the per-step training loss recorded by the training loop.
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(losses)
# Trailing `;` suppresses the repr of the returned objects in the notebook output.
ax.set(title="AnalogSGD training loss", xlabel="step", ylabel="loss (sum of squared outputs)");