In [None]:
import math
import os

import lightning as L
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from src.core.spice import ShallowCircuit, SPICEParser, XyceSim, create_circuit
from src.utils.eqprop_utils import AdjustParams, deltaV

In [None]:
def load_XOR():
    """Build the four-sample XOR toy dataset.

    Inputs are the four sign combinations of ``(+-2, +-2)``, each stored with
    shape ``(1, 2)``. Targets are two-element one-hot rows: ``[1, 0]`` when both
    inputs have the same sign (XOR is false) and ``[0, 1]`` when the signs differ
    (XOR is true).

    Returns:
        TensorDataset: four ``(x, y)`` pairs, ``x`` of shape ``(1, 2)`` and ``y``
        of shape ``(2,)``, both ``float32``.
    """
    from torch.utils.data import TensorDataset

    X = torch.tensor(
        [[[-2, -2]], [[-2, 2]], [[2, -2]], [[2, 2]]],
        dtype=torch.float32,
    )
    # One-hot XOR of the input signs. The rows for (2, -2) and (2, 2) used to be
    # swapped, which made the label depend on the second input only.
    Y = torch.tensor(
        [[1, 0], [0, 1], [0, 1], [1, 0]],
        dtype=torch.float32,
    )
    return TensorDataset(X, Y)

In [None]:
def ray_train(
    circuit: ShallowCircuit,
    dimensions: list,
    batch,
    beta,
    mpi_commands=None,
    normalize_per_input: bool = False,
):
    """Run one free-phase and one nudged-phase simulation for a single batch.

    The input is clamped onto ``circuit`` and simulated with Xyce (free phase).
    The gradient of the summed squared error between the last layer's output
    voltages and the target, scaled by ``beta``, is then handed to
    ``SPICEParser.releaseLayer`` and the circuit is simulated again (nudged phase).

    Args:
        circuit (ShallowCircuit): circuit passed to ``SPICEParser.clampLayer`` /
            ``SPICEParser.releaseLayer`` and to the simulator (presumably updated
            in place by those helpers; their return values are not used).
        dimensions (list): forwarded to ``SPICEParser.fastRawfileParser``.
        batch: ``(x, y)`` pair holding the input and the target tensor.
        beta: factor multiplied into the output-layer gradient before nudging.
        mpi_commands (list, optional): forwarded to ``XyceSim`` as ``mpi_command``.
            Defaults to None.
        normalize_per_input (bool, optional): standardise ``x`` with its own mean
            and standard deviation before clamping. Defaults to False.

    Returns:
        tuple: ``(free_Vdrops, nudged_Vdrops, loss)`` where the first two are lists
        of ``deltaV(vin, vout)`` for every ``(vin, vout)`` pair the parser returned
        after the free / nudged simulation, and ``loss`` is the absolute value of
        the summed squared error as a NumPy scalar.

    Raises:
        ValueError: if the free-phase parse yields no layer voltages.
    """
    x, y = batch

    if normalize_per_input:
        std, mean = torch.std_mean(x)
        if torch.isfinite(std) and std > 0:
            x = (x - mean) / std
        else:
            # A constant input (std == 0) or a single-element input (std is NaN)
            # would otherwise turn x into NaN/inf before it reaches the circuit.
            x = x - mean

    # free phase
    SPICEParser.clampLayer(circuit, x)
    xyce = XyceSim(mpi_command=mpi_commands)
    raw_file = xyce(spice_input=circuit)
    voltages = SPICEParser.fastRawfileParser(
        raw_file, nodenames=circuit.nodes, dimensions=dimensions
    )
    free_Vdrops = []
    ypred = None
    for vin, vout in voltages:
        free_Vdrops.append(deltaV(vin, vout))
        ypred = vout  # after the loop this is the last layer's output
    if ypred is None:
        raise ValueError("fastRawfileParser returned no layer voltages")

    # Output-layer gradient of the summed squared error.
    # as_tensor leaves a tensor untouched and wraps a NumPy array, so this works
    # whichever of the two the parser yields.
    ypred = torch.as_tensor(ypred).expand(1, -1)
    ypred.requires_grad = True
    loss = F.mse_loss(ypred, y.expand(1, -1).double(), reduction="sum")
    loss.backward()
    ygrad = ypred.grad * beta

    # nudged phase
    SPICEParser.releaseLayer(circuit, ygrad)
    raw_file2 = xyce(spice_input=circuit)
    voltages2 = SPICEParser.fastRawfileParser(
        raw_file2, nodenames=circuit.nodes, dimensions=dimensions
    )
    nudged_Vdrops = [deltaV(vin, vout) for (vin, vout) in voltages2]

    return (free_Vdrops, nudged_Vdrops, np.abs(loss.detach().numpy()))

In [None]:
def ray_predict(circuit: ShallowCircuit, dimensions: list, X, mpi_commands=None):
    """Run a single free-phase simulation and return the differential output.

    Args:
        circuit (ShallowCircuit): circuit passed to ``SPICEParser.clampLayer`` and
            to the simulator.
        dimensions (list): forwarded to ``SPICEParser.fastRawfileParser``.
        X (Tensor): input clamped onto the circuit.
        mpi_commands (list, optional): forwarded to ``XyceSim`` as ``mpi_command``.
            Defaults to None.

    Returns:
        Tensor: even-indexed minus odd-indexed entries of the last output voltage
        vector (presumably one value per differential output pair - confirm
        against the circuit layout).
    """
    # free phase
    SPICEParser.clampLayer(circuit, X)
    xyce = XyceSim(mpi_command=mpi_commands)
    raw_file = xyce(spice_input=circuit)

    # Keep only the last item the parser returns instead of unpacking exactly two,
    # so networks with a different number of layers no longer fail here.
    *_, last_layer = SPICEParser.fastRawfileParser(
        raw_file, nodenames=circuit.nodes, dimensions=dimensions
    )
    vout = last_layer[-1]
    out = vout[::2] - vout[1::2]
    # as_tensor wraps a NumPy array (as from_numpy did) and passes a tensor through.
    return torch.as_tensor(out)

In [None]:
class LightningRay(L.LightningModule):
    """Lightning module that trains a Xyce-simulated circuit with a two-phase update.

    Layer weights live in bias-free ``nn.Linear`` modules (``self.W``) and are
    mirrored into a circuit built by ``create_circuit``. Each training step runs
    ``ray_train`` (free + nudged simulation), converts the squared voltage drops of
    the two phases into per-layer gradients, takes an SGD step, clips the weights
    and writes them back into the circuit. Optimisation is manual
    (``automatic_optimization = False``).

    Args:
        dims (list): layer sizes, input first; ``len(dims) - 1`` weight layers are
            created.
        batch_size (int): batch size; a value of 1 triggers the MPI command rewrite
            in ``__init__``.
        num_classes (int): number of classes, used when one-hot encoding labels.
        SPICE_params (dict): circuit / learning settings. Keys read by this class:
            ``"alpha"`` (per-layer learning rate), ``"L"`` and ``"U"`` (bounds of
            the uniform weight init) and ``"beta"`` (nudging strength). The dict is
            also forwarded to ``create_circuit`` as keyword arguments.
        mpi_commands (list): MPI launcher command forwarded to ``ray_train`` and
            ``ray_predict``.
        optim_kwargs (dict, optional): extra keyword arguments for
            ``torch.optim.SGD``. Defaults to None (no extra arguments).
        **kwargs: not used by this class directly.
    """

    def __init__(
        self,
        dims: list,
        batch_size: int,
        num_classes: int,
        SPICE_params: dict,
        mpi_commands: list,
        optim_kwargs: dict = None,
        **kwargs,
    ):
        super().__init__()

        self.dimensions = dims
        self.n_layers = len(self.dimensions)
        self.batch_size = batch_size
        self.num_classes = num_classes
        # None used to be splatted into SGD(**None) and crash in configure_optimizers.
        self.optim_kwargs = dict(optim_kwargs) if optim_kwargs else {}
        # Work on a shallow copy: _hparams_tolist rewrites top-level keys and must
        # not change the dict owned by the caller.
        self.SPICE_params = dict(SPICE_params) if SPICE_params is not None else None
        self.save_hyperparameters(ignore=["mpi_commands"])
        self._hparams_tolist()

        # weight initialize
        self._weight_init()

        self.Pycircuit = create_circuit(W=self.W, dimensions=self.dimensions, **self.SPICE_params)
        self.circuit = ShallowCircuit.copyFromCircuit(self.Pycircuit)

        # manual optimization
        self.automatic_optimization = False
        # Copy so the rewrite below does not modify the caller's list; re-running
        # the construction cell used to pop/append on an already rewritten command.
        self.mpi_commands = list(mpi_commands) if mpi_commands is not None else None

        if batch_size == 1:
            # NOTE(review): assumes the command ends with a process count followed
            # by a bare "-cpu-set" (both are dropped and re-added) - confirm.
            self.mpi_commands.pop()
            self.mpi_commands.pop()
            self.mpi_commands.append(str(os.cpu_count() - 2))
            self.mpi_commands.append("-cpu-set")
            self.mpi_commands.append("2-" + str(os.cpu_count() - 1))
            print(f"batch size 1 detected. change mpi cmd as {self.mpi_commands}")
        self.clipper = AdjustParams(L=self.SPICE_params["L"], U=None)

    def _hparams_tolist(self, keys=("alpha", "L", "U")) -> None:
        """Broadcast scalar per-layer hyperparameters to one value per weight layer.

        Args:
            keys (tuple, optional): entries of ``self.SPICE_params`` to broadcast.
                Values that are already a ``list`` are left untouched.
                Defaults to ('alpha', 'L', 'U').
        """
        assert self.SPICE_params is not None, "hparams not set"
        n_weight_layers = self.n_layers - 1
        for key in keys:
            if key in self.SPICE_params and type(self.SPICE_params[key]) is not list:
                self.SPICE_params[key] = [self.SPICE_params[key]] * n_weight_layers
        assert (
            len(self.SPICE_params["alpha"]) == n_weight_layers
        ), "alpha length does not match n_layers"

    def _weight_init(self) -> None:
        """Create the bias-free linear layers and draw their weights from U(L, U).

        Raises:
            ValueError: if a per-layer lower or upper bound is None.
        """
        assert self.dimensions is not None, "dimensions not set"
        assert self.SPICE_params is not None, "hyper_params not set"
        self.W = nn.ModuleList(
            nn.Linear(dim1, dim2, bias=False)  # include bias in weight
            for dim1, dim2 in zip(self.dimensions[:-1], self.dimensions[1:])
        )
        lower, upper = self.SPICE_params["L"], self.SPICE_params["U"]
        assert lower is not None, "L not set"
        assert upper is not None, "U not set"
        for module, Li, Ui in zip(self.W, lower, upper):
            # A scalar None has already been broadcast to [None, ...] by
            # _hparams_tolist, so the asserts above cannot catch it.
            if Li is None or Ui is None:
                raise ValueError("SPICE_params['L'] and SPICE_params['U'] must be set for every layer")
            nn.init.uniform_(module.weight, Li, Ui)  # in-place

    def configure_optimizers(self):
        """Return an SGD optimizer with one parameter group (and lr) per layer."""
        param_groups = [
            {"params": layer.parameters(), "lr": self.SPICE_params["alpha"][idx]}
            for idx, layer in enumerate(self.W)
        ]
        optimizer = torch.optim.SGD(param_groups, **(self.optim_kwargs or {}))
        # NOTE(review): with its default arguments ConstantLR scales the lr by 1/3
        # for its first 5 steps, and nothing in this class calls scheduler.step()
        # (manual optimization) - confirm the reduced lr is intended.
        lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
        return [optimizer], [lr_scheduler]

    def training_step(self, batch, batch_idx):
        """Run both simulation phases for ``batch`` and apply one weight update."""
        # random sign for the nudging strength
        sign = (-1, 1)[int(torch.randint(0, 2, (1,)))]
        beta_i = self.SPICE_params["beta"] * sign

        # NOTE(review): the raw batch is used here while validation_step goes
        # through _data_preprocess first - confirm which input layout is intended.
        # ray_train returns one (free drops, nudged drops, loss) triple, so it is
        # unpacked directly; zip(*...) over it fails on the scalar loss.
        self.fdV, self.ndV, loss = ray_train(
            self.circuit, self.dimensions, batch, beta=beta_i, mpi_commands=self.mpi_commands
        )

        opt = self.optimizers()
        opt.zero_grad()
        for p, fdv, ndv in zip(self.parameters(), self.fdV, self.ndV):
            delta = torch.as_tensor(ndv**2 - fdv**2).transpose(1, 0)
            grad = -(1 / beta_i) * delta.to(device=p.device, dtype=p.dtype)
            # contiguous() returns a new tensor; its result has to be the one stored.
            p.grad = grad.contiguous()

        opt.step()
        # clip weights(G)
        self.clipper(self.W)
        # push the updated weights into the circuit
        SPICEParser.updateWeight(self.circuit, self.W)

        # logs metrics for each training_step,
        # and the average across the epoch, to the progress bar and logger
        self.log(
            "train_loss",
            float(loss),
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

    def validation_step(self, batch, batch_idx):
        """Simulate the free phase for ``batch`` and log loss, accuracy and confidence."""
        (X, Y) = self._data_preprocess(batch, num_classes=self.num_classes)

        pred = ray_predict(
            circuit=self.circuit, dimensions=self.dimensions, X=X, mpi_commands=self.mpi_commands
        )
        # ray_predict returns a single tensor, which torch.stack rejects; a sequence
        # of per-sample tensors is still stacked as before.
        if isinstance(pred, torch.Tensor):
            outs = pred
        else:
            outs = torch.stack(list(pred), dim=0)
        # one row per sample in the batch; no gradient is needed for validation
        outs = outs.detach().reshape(Y.size(0), -1)
        Y = Y.to(outs.dtype)

        loss = F.mse_loss(outs, Y, reduction="sum")
        acc = self.accuracy(outs, Y)
        self.log(
            "valid_loss", torch.abs(loss), on_step=True, on_epoch=True, prog_bar=False, logger=True
        )
        self.log("valid_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("confidence", torch.max(outs, dim=1)[0].mean(), on_epoch=True, logger=True)

    def test_step(self, batch, batch_idx):
        """No test-time behaviour is implemented."""
        pass

    def accuracy(self, ypreds, labels):
        """Return the fraction of rows whose arg-max matches the one-hot label.

        Args:
            ypreds (Tensor): scores, one row per sample.
            labels (Tensor): one-hot targets, one row per sample.

        Returns:
            Tensor: scalar tensor holding ``correct / len(labels)``.
        """
        _, predicted = torch.max(ypreds.detach(), 1)
        correct = (predicted == torch.argmax(labels, 1)).sum().item()
        return torch.tensor(correct / len(labels))

    @staticmethod
    def _data_preprocess(batch, num_classes=10):
        """Flatten the inputs into (x, -x) pairs and one-hot encode the labels.

        Args:
            batch: ``(X_batch, Y_batch)`` pair of tensors.
            num_classes (int, optional): number of classification classes.
                Defaults to 10.

        Returns:
            tuple: ``(X, Y)`` where ``X`` has shape ``(batch, 2 * features)`` with
            every feature followed by its negation, and ``Y`` is one-hot encoded
            unless it already had ``num_classes`` columns.
        """
        X_batch, Y_batch = batch
        X_flat = X_batch.reshape(X_batch.size(0), -1)
        X_pairs = X_flat.repeat_interleave(2, dim=1)
        X_pairs[:, 1::2] = -X_pairs[:, ::2]
        # A 1-D label vector is always class indices, even when its length happens
        # to equal num_classes (which the size check alone mistook for one-hot).
        if Y_batch.dim() == 1 or Y_batch.size(-1) != num_classes:
            Y_batch = F.one_hot(Y_batch, num_classes=num_classes).expand(Y_batch.size(0), -1)
        return (X_pairs, Y_batch)

# Train

In [None]:
# Circuit / learning settings handed to LightningRay as `SPICE_params`.
# NOTE(review): "U" is None here, but LightningRay._weight_init draws the initial
# weights from uniform(L, U) - it has to be given a value before training.
SPICE_params = dict(
    L=1e-7,
    U=None,
    A=4,
    alpha=[0.1, 0.05],  # ~learning rate
    beta=1e-2,
    Diode=dict(
        Path="/path/to/libraries/diode/switching/1N4148.lib",
        ModelName="1N4148",
        Rectifier="BidRectifier",
    ),
    noise=0,  # ratio
)

# MPI launcher command; LightningRay rewrites its tail when batch_size == 1.
mpi_commands = [
    "mpirun",
    "-use-hwthread-cpus",
    "-np",
    "1",
    "-cpu-set",
]

# Run-level settings.
# NOTE(review): "optimizer" says "adam" while LightningRay.configure_optimizers
# always builds SGD, and "SPICE_params/L" (1e-5) differs from SPICE_params["L"]
# (1e-7) above - confirm which values are meant to apply.
cfg = dict(frac=1, num_epochs=10, upper_frac=2.8, optimizer="adam", std_dev=1)
cfg.update(
    {
        "SPICE_params/A": 4,
        "SPICE_params/Diode/Rectifier": "BidRectifier",
        "SPICE_params/L": 1e-5,
    }
)

In [None]:
dataset = load_XOR()

# The same four samples feed both loaders; only the training loader shuffles.
train_loader, val_loader = (
    DataLoader(
        dataset,
        batch_size=1,
        shuffle=shuffle_samples,
        drop_last=True,
        num_workers=1,
        persistent_workers=False,
    )
    for shuffle_samples in (True, False)
)

In [None]:
# Seed before building the model so its random weight initialisation is reproducible.
L.seed_everything(42)

# `cfg` is a plain dict (it has no `.to_dict()`) and does not hold the constructor's
# required arguments, so they are passed explicitly; `cfg` still travels along as
# extra keyword arguments.
# NOTE(review): SPICE_params["U"] must be set before this can run, because
# LightningRay._weight_init draws the weights from uniform(L, U).
model = LightningRay(
    # Three entries because SPICE_params["alpha"] holds two per-layer rates.
    # Input 2: each XOR sample has two values. Output 2: targets are two-element
    # one-hot rows.
    # NOTE(review): the hidden width (4) is a placeholder - it is not defined
    # anywhere in this notebook; confirm.
    dims=[2, 4, 2],
    batch_size=train_loader.batch_size,
    num_classes=2,
    SPICE_params=dict(SPICE_params),  # copies: the constructor rewrites both
    mpi_commands=list(mpi_commands),
    optim_kwargs={},
    **cfg,
)
trainer = L.Trainer(max_epochs=cfg["num_epochs"])
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)