# Hardware For ML Class Project

# Modeling Albiero
To model Albiero, we divide the dot product kernel into several steps:
- Input Conversion:
    - Handles conversion from DE -> AE -> AO.
    - Accounts for the losses/noises that occur along the way.
- Weight Conversion:
    - Handles conversion from DE -> AE.
    - Accounts for normalization to [-1, 1].
- Dot Product:
    - Performs the AE/AO dot product.
    - Handles the conversion from AO to AE in the PD.
- Output Conversion:
    - Handles conversion from AE to quantized DE.


I don't know if this is the expected level of detail, but it is a good start toward understanding what the accelerator actually does.


There are many things I am very unsure about.
I have left them as `TODO(Ask)` in the code. We should ask about them in office hours.
Feel free to modify the code or add your own questions.

Once we have clarified these points, we can turn these classes into PyTorch operations and run the DNNs.

# Outline of Clarifications to Ask
## General Questions
1. Level of Detail:
    - It seems impossible to capture the noise without a semi-detailed step by step computation.
    - Is this overkill? What's the alternative? The proposal above looks basically good to go.
2. Parameter Values:
    - The paper does not specify all the values (e.g., feedback resistance at the PD, or crosstalk noise in PLCU MRRs)
        - Do you know where to find them? TBD.
        - Or can they be derived from the provided ones (e.g., MRR crosstalk from $k^2$ and FSR)? Look below for cross-talk for specifics.
        - Or can we assume some 'ideal' default (e.g., the feedback resistance that would allow lossless computation)? Answer: yes, we can assume the ideal to begin with.
3. Losses:
    - In addition to noise, there are also losses.
    - Do we ignore them, or do we take them into account? Answer: we should ignore them and we can justify this by mentioning that the losses are predictable. If the losses are not predictable then maybe we should model them.
4. Cross-talk:
    - Cross-talk seems input dependent: the amount of noise depends on the surrounding values (i.e., the receptive fields that are multiplexed in the same waveguide).
    - Should we derive cross-talk for the micro-ring resonator? Answer: we should try; if that does not work out, we might skip it. Cross-talk is important.
5. Do we assume constants or make things parameterized? Answer: parameterize. Do not just hard-code.
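
For the cross-talk question (item 4), one possible starting point is the textbook drop-port response of an add-drop ring, evaluated at the spacing between neighboring channels. The sketch below is not taken from the paper. It assumes a lossless ring with two identical couplers, treats $k^2$ as the power cross-coupling coefficient, and approximates the round-trip phase detuning as linear in wavelength. The channel spacing and channel powers are hypothetical placeholders.

```python
import math


def drop_port_transmission(detuning_nm, k2=0.03, fsr_nm=16.1):
    """Drop-port power transmission of a lossless, symmetric add-drop ring.

    Assumptions (not from the paper): no round-trip loss, identical couplers,
    and a phase detuning that is linear in the wavelength detuning.
    """
    r2 = 1.0 - k2  # self-coupling power coefficient r^2
    phase = 2 * math.pi * detuning_nm / fsr_nm  # round-trip phase detuning
    return k2 ** 2 / (k2 ** 2 + 4 * r2 * math.sin(phase / 2) ** 2)


def leaked_power(channel_powers, target_index, spacing_nm, k2=0.03, fsr_nm=16.1):
    """Power from all other channels that leaks into the ring tuned to target_index."""
    return sum(
        power * drop_port_transmission((j - target_index) * spacing_nm, k2, fsr_nm)
        for j, power in enumerate(channel_powers)
        if j != target_index
    )


# Hypothetical 1.6 nm channel spacing and arbitrary channel powers.
on_resonance = drop_port_transmission(0.0)
neighbor = drop_port_transmission(1.6)
leak_small = leaked_power([0.1, 1.0, 0.1], target_index=1, spacing_nm=1.6)
leak_large = leaked_power([1.0, 1.0, 1.0], target_index=1, spacing_nm=1.6)
```

Because the leaked power is a weighted sum of the neighboring channel powers, it is input dependent by construction: `leak_large` is ten times `leak_small` here even though the ring itself is unchanged.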
    

## Specific Questions
### Input Conversion
- I understand that quantized inputs are turned into voltages.
    - With what precision? In what range? Answer: if the paper does not define it, just assume an ideal conversion (i.e., don't model it).
    - Like [0, 1.0]?
- The voltage is then turned into an optical signal, after being multiplied by a 'gain' in (W/V).
    - I can't find this value.
    - Can I assume defaults that match the output conversion?
- AWG (Arrayed Waveguide Grating) Crosstalk.
    - This is given as a fixed value in the paper.
    - Can we assume it?
    - Isn't crosstalk input-dependent?
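
The input chain described above can be sketched for a single value as follows. This is only an illustration under ideal assumptions: the `[0, 1.0]` voltage range and the gain of 1 W/V are placeholders, not values from the paper, and the function names are hypothetical.

```python
def de_to_ae(code, bitwidth=8, v_min=0.0, v_max=1.0):
    """Map an unsigned integer code to a voltage in [v_min, v_max] (ideal DAC)."""
    max_code = 2 ** bitwidth - 1
    if not 0 <= code <= max_code:
        raise ValueError(f"code {code} does not fit in {bitwidth} bits")
    return v_min + (v_max - v_min) * code / max_code


def ae_to_ao(voltage, gain_w_per_v=1.0):
    """Scale a voltage by a gain in W/V to get an optical power in W."""
    return voltage * gain_w_per_v


# Placeholder codes: zero, one fifth of full scale, and full scale.
codes = [0, 51, 255]
powers = [ae_to_ao(de_to_ae(code)) for code in codes]
```

With these placeholder defaults the full-scale code 255 maps to 1 V and then to 1 W, so the optical power is simply the code divided by 255.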

### Weight Conversion
- The paper expects weights to be in [-1, 1]. So I assume we have to manually scale down, then scale back up, right?
- When the weights become voltages, can we assume a perfect conversion?
    - E.g., if the weight is $0.378934373$, the voltage can exactly match that.
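
The scale-down-then-scale-up idea can be checked on a small example. This sketch assumes a single per-tensor scale (the maximum absolute weight) and a perfect weight-to-voltage conversion. Both are assumptions, and the helper names are hypothetical.

```python
def normalize_weights(weights):
    """Divide by the largest magnitude so every weight lands in [-1, 1]."""
    scale = max(abs(w) for w in weights)
    if scale == 0:
        # All-zero weights: nothing to scale, and dividing would fail.
        return list(weights), 1.0
    return [w / scale for w in weights], scale


def dot(xs, ws):
    """Plain dot product of two equal-length sequences."""
    return sum(x * w for x, w in zip(xs, ws))


inputs = [1.0, 2.0, 3.0]
weights = [0.5, -4.0, 2.0]

normalized, scale = normalize_weights(weights)
analog_result = dot(inputs, normalized)  # what the analog stage would compute
restored = analog_result * scale         # scale back up in the digital domain
reference = dot(inputs, weights)         # ordinary dot product for comparison
```

Because the dot product is linear in the weights, multiplying the analog result by the stored scale recovers the unnormalized result, as long as nothing nonlinear (such as clipping or quantization) happens in between.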

### Optical Dot Product
- How to compute MRR cross-talk?
    - We are given $k^2$ (cross-coupling factor) and FSR (free spectral range).
    - Should it be input-dependent?
- How to capture RIN (relative intensity noise)?
    - The units we are given are decibels relative to the carrier per hertz (dBc/Hz)?
        - The bandwidth (frequency?) is later given as 5 GHz.
- How to get the "feedback resistance"?
    - Allows converting current to voltage.
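
As a possible starting point for the RIN question, the standard textbook expressions for photodetector noise currents can be combined as below. This is a sketch, not the paper's method. The RIN value of -140 dBc/Hz and the 1 mA photocurrent are hypothetical placeholders, while 5 GHz, 300 K and 50 Ohm mirror the defaults used in the code in this notebook.

```python
import math

BOLTZMANN = 1.380649e-23           # J/K
ELECTRON_CHARGE = 1.602176634e-19  # C


def pd_noise_std(photocurrent_a, rin_dbc_hz, bandwidth_hz, temperature_k, resistance_ohm):
    """Standard deviations (in A) of the RIN, shot and thermal noise currents."""
    rin_linear = 10 ** (rin_dbc_hz / 10)  # dBc/Hz -> 1/Hz
    var_rin = rin_linear * photocurrent_a ** 2 * bandwidth_hz
    var_shot = 2 * ELECTRON_CHARGE * photocurrent_a * bandwidth_hz
    var_thermal = 4 * BOLTZMANN * temperature_k * bandwidth_hz / resistance_ohm
    return {
        "rin": math.sqrt(var_rin),
        "shot": math.sqrt(var_shot),
        "thermal": math.sqrt(var_thermal),
        # Independent noise sources add in variance, not in standard deviation.
        "total": math.sqrt(var_rin + var_shot + var_thermal),
    }


# Hypothetical operating point: 1 mA photocurrent, RIN of -140 dBc/Hz.
noise = pd_noise_std(
    photocurrent_a=1e-3,
    rin_dbc_hz=-140.0,
    bandwidth_hz=5e9,
    temperature_k=300.0,
    resistance_ohm=50.0,
)
```

Note that the RIN term grows with the square of the photocurrent, so under these formulas the noise is signal dependent. Multiplying the total noise current by the feedback resistance would give the corresponding voltage noise.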

### Output Conversion
- How do we map voltage back to integers?
- Like:
    - Can we assume some uniform mapping, from (V_min -> 0) and (V_max -> int_max)?
    - Are V_min and V_max fixed parameters, or do they change input by input?
        - I.e., does 1V always correspond to the same integer, or is it relative to the other voltage values in the output?
- Same question about voltage precision.
    - Can we assume perfect voltage precision, or is something lost?
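
The difference between a fixed and a per-output voltage range can be made concrete with a small sketch. It assumes a uniform mapping with clipping, and the function name and voltage values are hypothetical.

```python
def ae_to_de(voltage, v_min, v_max, bitwidth=8):
    """Uniformly map a voltage in [v_min, v_max] to an integer code, with clipping."""
    if v_max <= v_min:
        raise ValueError("v_max must be greater than v_min")
    max_code = 2 ** bitwidth - 1
    clipped = min(max(voltage, v_min), v_max)
    return round((clipped - v_min) / (v_max - v_min) * max_code)


outputs_a = [0.0, 0.5, 1.0]
outputs_b = [0.0, 1.0, 4.0]

# Fixed range: 1 V maps to the same code no matter which output it appears in.
fixed_a = [ae_to_de(v, 0.0, 4.0) for v in outputs_a]
fixed_b = [ae_to_de(v, 0.0, 4.0) for v in outputs_b]

# Per-output range: the code for 1 V depends on the other values in that output.
relative_a = [ae_to_de(v, min(outputs_a), max(outputs_a)) for v in outputs_a]
relative_b = [ae_to_de(v, min(outputs_b), max(outputs_b)) for v in outputs_b]
```

With the fixed range, 1 V maps to code 64 in both outputs. With per-output ranges it maps to 255 in `outputs_a` and to 64 in `outputs_b`, so the integers are only comparable across outputs if the range is carried along with them.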

In [None]:
import torch
import typing as t
import math
import numpy as np

def dB_to_linear(dB):
    """
    Convert a decibel (dB) value to a linear scale factor.

    For a loss L in dB, the linear efficiency factor is:
       factor = 10^(-L/10)
    """
    return 10 ** (-dB / 10)



# TODO(Ask): Some of these values are hardcoded in the paper. Do we parametrize all/some of them?
# And some of them are missing (e.g., some waveguide lengths).
# Or is it okay to just assume the paper's values? And ideal values for the missing ones?
class InputConversion:
    """
    InputConversion handles everything from the input tensor up to right before the optical dot product.
    It includes:
    - Uniform quantization in case the input tensor is not already quantized.
    - DE to AE conversion.
    - AE to AO conversion.
    - AO losses and noises that occur before the optical dot product.
        - E.g., inherent losses, AWG cross talk, etc.
    """

    def __init__(
        self,
        starting_tensor,
        quantization_bitwidth=8,
        # Used for converting from DE to AE.
        voltage_min=0,
        voltage_max=255,
        # Used for converting from AE to AO.
        optical_gain=1, # What the voltage is multiplied by to get the optical power.
        inherent_losses_DB: t.Sequence[float]=(), # Losses from interconnects, etc. (immutable default)
        wag_cross_talk_DB=0, # Cross talk from other channels in the AWG.
    ):
        self.starting_tensor = starting_tensor
        self.quantization_bitwidth = quantization_bitwidth
        self.max_q_val = 2**quantization_bitwidth - 1
        self.voltage_min = voltage_min
        self.voltage_max = voltage_max
        self.optical_gain = optical_gain
        # Compute inherent losses.
        self.inherent_losses_DB = inherent_losses_DB
        self.inherent_losses = [dB_to_linear(loss) for loss in inherent_losses_DB]
        self.inherent_loss = 1
        for loss in self.inherent_losses:
            self.inherent_loss *= loss
        self.wag_cross_talk_DB = wag_cross_talk_DB
        self.wag_cross_talk = dB_to_linear(wag_cross_talk_DB)


        # Compute the conversions.
        self.quantized_tensor = self.uniform_quantization(tensor=starting_tensor)
        self.ae_tensor = self.DE_to_AE()
        self.ao_tensor = self.AE_to_AO()
        self.final_optical_tensor = self.apply_conversion_losses_and_noise()


    def uniform_quantization(self, tensor):
        """
        Uniformly quantize a tensor to the specified bitwidth.
        Assume the tensor only contains positive values (post RELU, or at initial input).
        Assume the zero point is 0.
        Noop if the tensor is already quantized.
        """
        tensor = tensor.clamp(0, self.max_q_val)
        return torch.round(tensor).to(torch.int)


    def DE_to_AE(self):
        """
        Convert from DE (Digital Electric) to AE (Analog Electric).
        Linearly maps the quantized tensor to a voltage range.
        """
        tensor = self.quantized_tensor.float()
        scaled = tensor / self.max_q_val
        return self.voltage_min + (self.voltage_max - self.voltage_min) * scaled


    def AE_to_AO(self):
        """
        Convert from AE (Analog Electric) to AO (Analog Optical).
        """
        return self.ae_tensor * self.optical_gain



    def apply_conversion_losses_and_noise(self):
        """
        Apply the conversion losses and noise to the AO tensor.
        This includes:
        - Inherent losses (e.g., from interconnects, etc.)
        - AWG cross talk (e.g., from other channels in the AWG)
        - RIN, thermal, etc. noise (not modeled yet).
        """
        # Apply inherent loss.
        tensor = self.ao_tensor * self.inherent_loss * self.wag_cross_talk
        return tensor



class WeightConversion:
    """
    WeightConversion handles everything from the weight tensor up to right before the optical dot product.
    It only includes:
    - Normalization of the weight tensor to put in the [-1, 1] range.
    - DE to AE conversion.
    """

    def __init__(
        self,
        starting_tensor,
    ):
        self.starting_tensor = starting_tensor
        self.normalized_tensor, self.normalization_scale = self.uniform_normalization()
        self.ae_tensor = self.DE_to_AE()


    def uniform_normalization(self):
        """
        Uniformly normalize the weight tensor to the [-1, 1] range.
        """
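        scale = torch.max(torch.abs(self.starting_tensor))
        # Guard against an all-zero weight tensor, which would otherwise divide by zero.
        if scale == 0:
            scale = torch.ones_like(scale)
        tensor = self.starting_tensor / scale
        return tensor, scale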


    def DE_to_AE(self):
        """
        Convert from DE (Digital Electric) to AE (Analog Electric).
        Linearly maps the normalized tensor to a voltage range.
        """
        return self.normalized_tensor.clone()



class OpticalDotProduct:
    """
    OpticalDotProduct handles the actual dot product when given two tensors:
    - The AO tensor from the InputConversion.
    - The AE tensor from the WeightConversion.

    It includes:
    - The optical dot product itself.
    - The sum at the PD, which does the AO to AE conversion.
    - RELU activation.

    The output is the final AE tensor.
    """


    def __init__(
        self,
        ao_input_tensor: torch.Tensor,
        ae_weight_tensor: torch.Tensor,
        # MRR
        mrr_k2 = 0.03,
        mrr_fsr_nm = 16.1,
        mrr_loss_dB = 0,
        # MZM and Y-branch
        mzm_loss_DB = 0,
        y_branch_loss_DB = 0,
        # PD
        # TODO(Ask): How is the final noise computed from these values?
        pd_rin_DBCHZ = 0,
        pd_GHZ = 5,
        pd_T = 300, # Temperature in Kelvin.
        pd_responsivity = 1.0, # In A/W.
        pd_dark_current_pA = 0, # In pA @ 1V.
        pd_resistance = 50, # In Ohm. TODO: Not specified anywhere in the paper.
    ):
        self.ao_input_tensor = ao_input_tensor
        self.ae_weight_tensor = ae_weight_tensor
        self.mzm_loss_DB = mzm_loss_DB
        self.mzm_loss = dB_to_linear(mzm_loss_DB)
        self.y_branch_loss_DB = y_branch_loss_DB
        self.y_branch_loss = dB_to_linear(y_branch_loss_DB)
        self.mrr_loss_dB = mrr_loss_dB
        self.mrr_loss = dB_to_linear(mrr_loss_dB)
        self.mrr_k2 = mrr_k2
        self.mrr_fsr_nm = mrr_fsr_nm
        self.pd_resistance = pd_resistance
        self.pd_responsivity = pd_responsivity
        self.pd_dark_current_pA = pd_dark_current_pA
        self.pd_rin_DBCHZ = pd_rin_DBCHZ
        self.pd_GHZ = pd_GHZ
        self.pd_T = pd_T



        # Compute
        self.mzm_output = self.apply_mzm()
        self.mrr_output = self.apply_mrr()
        self.pd_output = self.sum_at_pd()
        self.ae_output = self.relu_activation()


    def apply_mzm(self):
        """
        Apply the MZM loss to the optical tensor.
        This includes:
        - Element-wise multiplication of the two tensors.
        - Y-branch loss
        - MZM loss
        """
        mzm_output = self.ao_input_tensor * self.ae_weight_tensor
        return mzm_output * self.y_branch_loss * self.mzm_loss

    def apply_mrr(self):
        """
        Apply the MRR loss to the optical tensor.
        This includes:
        - MRR loss.
        - MRR cross talk.
        """
        mrr_output = self.mzm_output * self.mrr_loss
        mrr_cross_talk_DB = 0 # TODO: Figure me out.
        mrr_cross_talk = dB_to_linear(mrr_cross_talk_DB)
        return mrr_output * mrr_cross_talk



    def sum_at_pd(self):
        """
        Sum the optical tensor at the PD.
        This includes:
        - PD responsivity.
        - PD dark current.
        - PD noises (RIN, thermal, shot)
        - PD resistance.
        """
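        # Total optical power (W) arriving at the PD.
        optical_power = torch.sum(self.mrr_output, dim=0)
        noise_DB = 0 # TODO: Figure me out.
        noise = dB_to_linear(noise_DB)
        # The responsivity (A/W) turns optical power into photocurrent (A).
        current = optical_power * noise * self.pd_responsivity
        # Dark current is an electrical current, so add it after the responsivity (pA -> A).
        current = current + self.pd_dark_current_pA * 1e-12
        # The resistance (Ohm) turns the current into a voltage (V).
        return current * self.pd_resistance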


    def relu_activation(self):
        """
        Apply the ReLU activation function to the PD output.
        """
        return torch.relu(self.pd_output)


class OutputConversion:
    """
    OutputConversion handles everything after the optical dot product up to the final output tensor.
    It only includes:
    - AE to DE conversion.
    - Quantization of the DE tensor.
    """


    def __init__(
        self,
        ae_tensor,
        quantization_bitwidth,
        voltage_min: t.Optional[float]=None,
        voltage_max: t.Optional[float]=None,
    ):
        self.voltage_min = voltage_min
        self.voltage_max = voltage_max
        self.ae_tensor = ae_tensor
        self.quantization_bitwidth = quantization_bitwidth
        self.max_q_val = 2**quantization_bitwidth - 1
        self.de_tensor = self.AE_to_DE()
        self.final_output = self.de_tensor



    def AE_to_DE(self):
        """
        Convert from AE (Analog Electric) to DE (Digital Electric).
        """
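        # Compare against None explicitly: `or` would wrongly discard a valid bound of 0.
        voltage_min = self.ae_tensor.min() if self.voltage_min is None else self.voltage_min
        voltage_max = self.ae_tensor.max() if self.voltage_max is None else self.voltage_max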
        if voltage_min == voltage_max:
            return self.ae_tensor.round().clamp(0, self.max_q_val).to(torch.int)
        quantized_min = 0
        quantized_max = self.max_q_val
        scale = (quantized_max - quantized_min) / (voltage_max - voltage_min)
        tensor = (self.ae_tensor - voltage_min) * scale
        return tensor.round().clamp(quantized_min, quantized_max).to(torch.int)
