### Import libraries and dataset

In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time, os
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data.sampler import SubsetRandomSampler
from scipy.constants import speed_of_light as c
import scipy.io
from collections import OrderedDict
import torch.fft as tfft
import datetime
%matplotlib inline

# Device handle for the notebook. The "mps" alternative is kept commented out;
# the active choice is the CPU.
# gpu = torch.device("mps")
gpu = torch.device("cpu")

# ## Sets everything to double point precision (use with gradcheck)
# torch.set_default_dtype(torch.float64)



In [2]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import mutual_info_score

In [3]:
# Rounded constants stored as plain Python floats.
speedoflight = 3e8
epsilon0 = 8.85e-12

# Rounded constants stored as 0-dim tensors so they combine directly with
# torch expressions. Note that T is created from an int literal (int64 tensor).
kb = torch.tensor(1.38e-23)
T = torch.tensor(300)
hbar = torch.tensor(1.05e-34)
eleccharge = torch.tensor(1.6e-19)

In [4]:
def boseeinstein(omega, T):
    """Bose-Einstein occupation number ``1 / (exp(hbar*omega / (kb*T)) - 1)``.

    Args:
        omega: angular frequency (number or tensor).
        T: temperature (number or tensor).

    Returns:
        Tensor holding the occupation number. ``hbar`` and ``kb`` are the
        module-level tensor constants.
    """
    # expm1 evaluates exp(x) - 1 without the catastrophic cancellation that the
    # explicit subtraction suffers when hbar*omega << kb*T (x close to zero).
    return 1 / torch.expm1(hbar * omega / (kb * T))

In [6]:
"""
Physics module for MIWEN implementation containing physical constants,
noise models, and mixing functions.
"""

import torch
import scipy.constants as const
from dataclasses import dataclass

@dataclass
class PhysicalConstants:
    """Bundle of physical constants and device parameters.

    The fundamental constants default to the values provided by
    ``scipy.constants``; every field can be overridden per instance.
    """
    # --- fundamental constants and operating temperature ---
    kb: float = const.k          # Boltzmann constant
    T: float = 300.0             # Room temperature (K)
    e: float = const.e           # Elementary charge
    c: float = const.c           # Speed of light
    h: float = const.h           # Planck constant
    hbar: float = const.hbar     # Reduced Planck constant
    # --- device parameters ---
    R: float = 50.0              # Load resistance (ohm)
    eta: float = 1.0             # Multiplier on kb*T/e in VT (presumably a diode ideality factor - TODO confirm)

    @property
    def VT(self):
        """Thermal voltage, computed as ``eta * kb * T / e``."""
        scaled_thermal_energy = self.eta * self.kb * self.T
        return scaled_thermal_energy / self.e

class NoiseModel:
    """Additive noise sources applied to signal tensors."""

    def __init__(self, constants: PhysicalConstants, bandwidth: float):
        # constants supplies kb, T and R; bandwidth is the Delta-f factor of 4*kb*T*R*Delta-f.
        self.bandwidth = bandwidth
        self.constants = constants

    def thermal_noise(self, signal: torch.Tensor) -> torch.Tensor:
        """Return ``signal`` plus zero-mean Gaussian Johnson-Nyquist noise.

        The noise variance is ``4 * kb * T * R * bandwidth``; a fresh sample with
        the shape of ``signal`` is drawn on every call.
        """
        consts = self.constants
        variance = 4 * consts.kb * consts.T * consts.R * self.bandwidth
        sigma = torch.tensor(variance).sqrt()
        return signal + sigma * torch.randn_like(signal)

    def shot_noise(self, signal: torch.Tensor) -> torch.Tensor:
        """Placeholder for shot noise: currently returns ``signal`` unchanged."""
        # TODO: Implement shot noise
        return signal

class DiodeMixing:
    """Diode mixing models that combine an activation signal ``z`` with a weight signal ``w``."""

    def __init__(self, constants: PhysicalConstants, noise_model: NoiseModel):
        # constants must expose VT; noise_model must expose thermal_noise().
        self.constants = constants
        self.noise_model = noise_model

    def simple_mixing(self, z: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """Simplified mixing: ``tanh((z + w) / (2 * VT))`` after adding thermal noise to both inputs."""
        z_noisy = self.noise_model.thermal_noise(z)
        w_noisy = self.noise_model.thermal_noise(w)
        return torch.tanh((z_noisy + w_noisy) / (2 * self.constants.VT))

    def exact_mixing(self, z: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """
        EXACT diode formula, evaluated in a numerically stable way:
        out = w/2 + (VT/2)* log( (exp(z/VT)+ exp(-w/VT)) / (exp(z/VT)+ exp(w/VT)) )

        No noise is added here. ``z`` and ``w`` are broadcast against each other.

        Returns:
            The mixed signal, clamped to [-1e3, 1e3].
        """
        VT = self.constants.VT

        # Work in a common dtype so mixed-precision inputs behave like ordinary arithmetic.
        common = torch.result_type(z, w)
        z_scaled = (z / VT).to(common)
        w_scaled = (w / VT).to(common)

        # logaddexp(a, b) = log(exp(a) + exp(b)) shifts by each pair's own maximum.
        # The previous version shifted both sums by one shared maximum and added
        # eps=1e-8 inside the log: whenever the numerator terms were much smaller
        # than that shared maximum they underflowed and eps dominated, giving a
        # result far from the formula (e.g. z = -40*VT, w = 40*VT). It also clamped
        # the scaled inputs to +-50 while leaving the w/2 term unclamped; no input
        # clamp is needed for stability any more.
        log_numerator = torch.logaddexp(z_scaled, -w_scaled)
        log_denominator = torch.logaddexp(z_scaled, w_scaled)

        out = (w / 2.0) + (VT / 2.0) * (log_numerator - log_denominator)
        return torch.clamp(out, min=-1e3, max=1e3)

### Custom layers

In [7]:
class Params:
    """Plain container for the simulation settings shared by the custom layers.

    Every constructor argument is stored unchanged as an attribute of the same
    name; no validation or conversion is performed.
    """

    def __init__(self, weightshape=None, red=True, units=1, fixedband=False, totBW=None,
                 noisein=False, noiseout=False, lowtemp=False, weightfreqspacing=None,
                 k_th_AWG=None, k_th_AWG_lowtemp=None, k_th_AWG_lowtemp_opt=None,
                 k_th_w=None, k_th_out=None, k_th_w_lowtemp=None, k_th_out_lowtemp=None,
                 k_th_w_lowtemp_opt=None, k_th_A2D=None, k_th_A2D_lowtemp=None,
                 scaletoz=False, scaletoW=False, wunits=1, unitconverter=1,
                 includeinputfft=False, includeoutputfft=False, semianalog=True,
                 digitallike=False, noisefactor=4, optics=False,
                 photodiodearea=1e-6, omegaoptical=2*torch.pi*193.41e12,
                 inttime=1e-7, capval=1e-12, 
                 includecarrier=False, 
                 carrieroffmult=5):

        # Layer geometry and frequency layout
        self.weightshape = weightshape
        self.red = red
        self.units = units
        self.weightfreqspacing = weightfreqspacing
        self.fixedband = fixedband
        self.totBW = totBW
        self.includecarrier = includecarrier
        self.carrieroffmult = carrieroffmult

        # Optional Fourier transforms on the input / output side
        self.includeinputfft = includeinputfft
        self.includeoutputfft = includeoutputfft

        # Scaling of the mixed signal
        self.scaletoz = scaletoz
        self.scaletoW = scaletoW
        self.wunits = wunits
        self.unitconverter = unitconverter

        # Noise switches and modifiers
        self.noisein = noisein
        self.noiseout = noiseout
        self.lowtemp = lowtemp
        self.digitallike = digitallike
        self.semianalog = semianalog
        self.noisefactor = noisefactor

        # Noise coefficients: activation source (AWG)
        self.k_th_AWG = k_th_AWG
        self.k_th_AWG_lowtemp = k_th_AWG_lowtemp
        self.k_th_AWG_lowtemp_opt = k_th_AWG_lowtemp_opt

        # Noise coefficients: weight signal and mixer output
        self.k_th_w = k_th_w
        self.k_th_w_lowtemp = k_th_w_lowtemp
        self.k_th_w_lowtemp_opt = k_th_w_lowtemp_opt
        self.k_th_out = k_th_out
        self.k_th_out_lowtemp = k_th_out_lowtemp

        # Noise coefficients: A2D stage
        self.k_th_A2D = k_th_A2D
        self.k_th_A2D_lowtemp = k_th_A2D_lowtemp

        # Optical-mode settings
        self.optics = optics
        self.photodiodearea = photodiodearea
        self.omegaoptical = omegaoptical
        self.inttime = inttime
        self.capval = capval


In [32]:
class ActivationEncoding(nn.Module):
    """Places the input activations on frequency bins and returns the time-domain signal.

    ``params.weightshape`` is read as ``(R, N)``. In the ``red`` layout the N
    inputs land on every R-th bin; otherwise they land on N consecutive bins.
    The spectrum is zero-padded to ``highestindex`` bins and converted with an
    orthonormal inverse real FFT, giving ``total = 2 * (highestindex - 1)``
    time samples per batch row.
    """

    def __init__(self, params: Params):
        super(ActivationEncoding, self).__init__()
        self.params = params

        self.N = self.params.weightshape[1]
        self.R = self.params.weightshape[0]

        if self.params.red:
            self.n0 = 0
            self.r0 = self.R * (self.N - 1) / 2
        else:
            self.n0 = int(self.N * (self.R - 1) / 2)
            self.r0 = 0

        # With a carrier offset the occupied band is stretched by carrieroffmult.
        span = self._band_span()
        if self.params.includecarrier:
            span = self.params.carrieroffmult * span
        self.highestindex = int(2 * span + 4)
        self.total = int(2 * (self.highestindex - 1))

    def _band_span(self):
        """Number of bins spanned by the base (carrier-free) frequency layout."""
        if self.params.red:
            return (self.r0 + self.R) + (self.n0 + self.N) * self.R
        return (self.r0 + self.R) * self.N + (self.n0 + self.N)

    def forward(self, x):
        """Encode a batch of activations.

        Args:
            x: tensor indexed as (batch, N).

        Returns:
            ``(x_time, xener)``: the time-domain signal of shape (batch, total),
            with thermal noise added when ``params.noisein`` is set and
            ``params.optics`` is not, and the per-row energy ``sum(x_time**2)``
            measured before the noise is added.
        """
        if self.params.includeinputfft:
            x_complex = torch.complex(x, torch.zeros_like(x))
            x = torch.fft.fft(x_complex.conj().t(), dim=0, norm="ortho").conj().t()

        batch = x.shape[0]

        def zeros(*shape):
            # Keep the padding on the same device as the data.
            return torch.zeros(*shape, device=x.device)

        if self.params.red:
            # (batch, N) -> (batch, N, R) with each sample in the last slot, then flattened,
            # so consecutive inputs end up R bins apart.
            comb = torch.cat([zeros(batch, x.shape[1], self.R - 1), x.unsqueeze(2)], dim=-1)
            comb = comb.reshape(batch, -1)
            lead = self.n0 * self.R + 1
        else:
            comb = x
            lead = self.n0 + 1
        spectrum = torch.cat([zeros(batch, lead), comb], dim=-1)

        if self.params.includecarrier:
            extrabins = int((self.params.carrieroffmult - 1) * self._band_span())
            spectrum = torch.cat([zeros(batch, extrabins), spectrum], dim=-1)

        tail = zeros(batch, self.highestindex - spectrum.shape[-1])
        y = torch.cat([spectrum, tail], dim=-1)
        x = torch.real(torch.fft.irfft(y, dim=-1, norm="ortho"))
        xener = torch.sum(x ** 2, dim=-1)

        # adding noise in the time domain
        if self.params.noisein and not self.params.optics:
            noisetouse = self.params.k_th_AWG_lowtemp if self.params.lowtemp else self.params.k_th_AWG
            if not self.params.fixedband:
                sigma2_x_th = noisetouse * self.params.weightfreqspacing * self.total
            else:
                # NOTE(review): `omegac` is read from module scope and is not defined in the
                # visible part of this notebook; it must exist before this branch runs.
                # The earlier variant used self.params.totBW instead.
                sigma2_x_th = noisetouse * (omegac/(2*np.pi))
            if self.params.digitallike:
                sigma2_x_th = sigma2_x_th*self.params.noisefactor
            # as_tensor lets the variance be a plain Python/NumPy number as well as a
            # tensor (torch.sqrt alone rejects non-tensor arguments).
            noise_std = float(torch.sqrt(torch.as_tensor(sigma2_x_th)))
            xadditivenoise = torch.normal(0., noise_std, size=x.shape, device=x.device)
            x = x + xadditivenoise

        return x, xener


In [33]:
class WeightEncodingandMixing(nn.Module):
    """Encodes the weight matrix on frequency bins and mixes it with the activation signal.

    The learnable weight ``W`` has shape ``params.weightshape`` = (R, N). It is
    unrolled onto frequency bins (layout selected by ``params.red``), converted
    to a time-domain signal of ``total`` samples and mixed with the incoming
    activation signal, either through ``diodemixer.exact_mixing`` or through
    the optical branch when ``params.optics`` is set.
    """

    def __init__(self, params: Params, diodemixer):
        super(WeightEncodingandMixing, self).__init__()

        self.params = params

        self.N = self.params.weightshape[1]
        self.R = self.params.weightshape[0]

        self.diodemixer = diodemixer

        if self.params.red:
            self.n0 = 0
            self.r0 = self.R * (self.N - 1) / 2
        else:
            self.n0 = int(self.N * (self.R - 1) / 2)
            self.r0 = 0

        # With a carrier offset the occupied band is stretched by carrieroffmult.
        span = self._band_span()
        if self.params.includecarrier:
            span = self.params.carrieroffmult * span
        self.highestindex = int(2 * span + 4)
        self.total = int(2 * (self.highestindex - 1))

        # Complex weight with Kaiming-uniform real part and zero imaginary part.
        complex_tensor = torch.empty(*self.params.weightshape, dtype=torch.complex64)
        nn.init.kaiming_uniform_(complex_tensor.real, a=0)
        complex_tensor.imag.zero_()
        self.W = nn.Parameter(complex_tensor)

    def _band_span(self):
        """Number of bins spanned by the base (carrier-free) frequency layout."""
        if self.params.red:
            return (self.r0 + self.R) + (self.n0 + self.N) * self.R
        return (self.r0 + self.R) * self.N + (self.n0 + self.N)

    @staticmethod
    def _noise_std(sigma2):
        """Standard deviation as a float; accepts tensors and plain numbers."""
        return float(torch.sqrt(torch.as_tensor(sigma2)))

    def _band_variance(self, noisetouse):
        """Noise variance for the configured bandwidth convention."""
        if not self.params.fixedband:
            return noisetouse * self.params.weightfreqspacing * self.total
        # NOTE(review): `omegac` is read from module scope and is not defined in the
        # visible part of this notebook; it must exist before this branch runs.
        # The earlier variant used self.params.totBW instead.
        return noisetouse * (omegac/(2*np.pi))

    def forward(self, z):
        """Mix the activation signal ``z`` (batch, total) with the encoded weights.

        Returns:
            The mixed time-domain signal of shape (batch, total).
        """
        device = self.W.device

        if self.params.red:
            if not self.params.includeinputfft and not self.params.includeoutputfft:
                Wtr = self.W.t()
            else:
                selfW = self.W.to(dtype=torch.complex64)

                if self.params.includeoutputfft:
                    selfW = torch.fft.fft(selfW, dim=0, norm="ortho")

                if self.params.includeinputfft:
                    selfW = torch.fft.fft(selfW.conj().t(), dim=0, norm="ortho").conj().t()

                Wtr = selfW.t()

            Wunrolled = Wtr.reshape(-1)
            lead = int(self.r0 + 1 + (self.n0 + 1) * self.R)
        else:
            if not self.params.includeinputfft:
                Wunrolled = self.W.reshape(-1)
            else:
                # NOTE(review): `fourier_matrix` is not defined in the visible part of
                # this notebook; it must be provided elsewhere for this branch to run.
                FW = fourier_matrix(self.R) @ self.W
                Wunrolled = FW.reshape(-1)
            lead = int((self.r0 + 1) * self.N + (self.n0 + 1))

        Wpaddedleft = torch.cat([torch.zeros(lead, device=device), Wunrolled], dim=-1)

        # The carrier offset is applied in BOTH layouts. It used to be skipped for the
        # non-red layout although ActivationEncoding shifts the activations there too,
        # which moved the difference frequencies away from the bins MAFTFilter reads.
        if self.params.includecarrier:
            extrabins = int((self.params.carrieroffmult - 1) * self._band_span())
            Wpaddedleft = torch.cat([torch.zeros(extrabins, device=device), Wpaddedleft], dim=-1)

        zeropaddingright = torch.zeros(self.highestindex - Wpaddedleft.shape[-1], device=device)
        Wpadded = torch.cat([Wpaddedleft, zeropaddingright], dim=-1)
        Wtime = torch.real(torch.fft.irfft(Wpadded, dim=-1, norm="ortho").unsqueeze(0))

        # Thermal noise on the weight signal (one realization, shared by the batch).
        if self.params.noisein and not self.params.optics:
            noisetouse = self.params.k_th_w_lowtemp if self.params.lowtemp else self.params.k_th_w
            sigma2_W_th = self._band_variance(noisetouse)

            if self.params.digitallike:
                sigma2_W_th = sigma2_W_th*self.params.noisefactor

            Wadditivenoise = torch.normal(0., self._noise_std(sigma2_W_th),
                                          size=Wtime.shape, device=Wtime.device)
            Wtime = Wtime + Wadditivenoise

        if not self.params.optics:
            outputs_in_time = self.diodemixer.exact_mixing(z, Wtime)
        else:
            upper_out = (z + Wtime) / torch.sqrt(torch.tensor(2))
            lower_out = (z - Wtime) / torch.sqrt(torch.tensor(2))
            nth = boseeinstein(self.params.omegaoptical, T)

            upper_var = nth + nth**2 + (2 * nth + 1) * upper_out**2
            lower_var = nth + nth**2 + (2 * nth + 1) * lower_out**2
            diff_current = eleccharge * torch.normal(mean=2 * z * Wtime,
                                                     std=torch.sqrt(upper_var + lower_var))

            outputs_in_time = diff_current / self.params.capval

        if not self.params.optics and self.params.scaletoz:
            # Rescale each row to the L2 norm of the activation signal.
            outputs_in_time = F.normalize(outputs_in_time, p=2, dim=-1)
            znorm = torch.norm(z, p=2, dim=-1, keepdim=True)
            outputs_in_time = torch.mul(znorm, outputs_in_time)

        if not self.params.optics and self.params.scaletoW:
            # Rescale each row to the L2 norm of the weight signal.
            outputs_in_time = F.normalize(outputs_in_time, p=2, dim=-1)
            Wnorm = torch.norm(Wtime, p=2, dim=-1, keepdim=True)
            outputs_in_time = torch.mul(Wnorm, outputs_in_time)

        outputs_in_time = outputs_in_time * self.params.unitconverter

        if self.params.noiseout:
            noisetouse = self.params.k_th_out_lowtemp if self.params.lowtemp else self.params.k_th_out
            sigma2_output_th = self._band_variance(noisetouse)

            if not self.params.digitallike:
                # Draw independent noise for every batch row. Sampling with Wtime's
                # (1, total) shape broadcast one identical realization over the batch.
                output_additivenoise = torch.normal(0., self._noise_std(sigma2_output_th),
                                                    size=outputs_in_time.shape,
                                                    device=outputs_in_time.device)
                outputs_in_time = outputs_in_time + output_additivenoise

        return outputs_in_time


In [34]:
class A2D(nn.Module):
    """Read-out stage: optional additive noise, rescaling, then a real FFT to the frequency domain."""

    def __init__(self, params: Params):
        super(A2D, self).__init__()

        self.params = params

        self.N = self.params.weightshape[1]
        self.R = self.params.weightshape[0]

        if self.params.red:
            self.n0 = 0
            self.r0 = self.R * (self.N - 1) / 2
            span = (self.r0 + self.R) + (self.n0 + self.N) * self.R
        else:
            self.n0 = int(self.N * (self.R - 1) / 2)
            self.r0 = 0
            span = (self.r0 + self.R) * self.N + (self.n0 + self.N)

        # With a carrier offset the occupied band is stretched by carrieroffmult.
        if self.params.includecarrier:
            span = self.params.carrieroffmult * span
        self.highestindex = int(2 * span + 4)
        self.total = int(2 * (self.highestindex - 1))

    def forward(self, outputs_in_time):
        """Convert the mixed time-domain signal to frequency-domain outputs.

        Args:
            outputs_in_time: tensor whose last dimension is time.

        Returns:
            Real part of the orthonormal rFFT along the last dimension,
            multiplied by ``sqrt(total)``.
        """
        if self.params.noiseout:
            noisetouse = self.params.k_th_A2D_lowtemp if self.params.lowtemp else self.params.k_th_A2D

            if not self.params.fixedband:
                sigma2_A2D_th = noisetouse * self.params.weightfreqspacing * self.total
            else:
                # NOTE(review): `omegac` is read from module scope and is not defined in the
                # visible part of this notebook; it must exist before this branch runs.
                # The earlier variant used self.params.totBW instead.
                sigma2_A2D_th = noisetouse * (omegac/(2*np.pi))

            if self.params.semianalog:
                sigma2_A2D_th = sigma2_A2D_th*self.params.noisefactor

            if not self.params.digitallike:
                # as_tensor lets the variance be a plain number as well as a tensor
                # (torch.sqrt alone rejects non-tensor arguments).
                noise_std = float(torch.sqrt(torch.as_tensor(sigma2_A2D_th)))
                additivenoise = torch.normal(0., noise_std, size=outputs_in_time.shape,
                                             device=outputs_in_time.device)
                outputs_in_time = outputs_in_time + additivenoise

        if self.params.optics:
            outputs_in_time = outputs_in_time * self.params.capval / (2 * eleccharge)
        else:
            # 4 * kb*T/e undoes the 1/(4*VT) small-signal gain of the diode mixer.
            outputs_in_time = outputs_in_time * 4 * (kb * T / eleccharge)

        outputs_in_freq = torch.sqrt(torch.tensor(self.total)) * torch.real(torch.fft.rfft(outputs_in_time, dim=-1, norm="ortho"))

        return outputs_in_freq


In [11]:
class MAFTFilter(nn.Module):
    """Selects the R output bins from the frequency-domain signal.

    ``params.weightshape`` is read as ``(R, N)``. In the ``red`` layout the
    outputs are R consecutive bins starting at ``r0 + 1``; otherwise they are
    every N-th bin from ``N`` up to ``R * N``.
    """

    def __init__(self, params: Params):
        super(MAFTFilter, self).__init__()

        self.params = params
        self.R, self.N = self.params.weightshape[0], self.params.weightshape[1]
        self.r0 = self.R * (self.N - 1) / 2 if self.params.red else 0

    def forward(self, z):
        """Return the selected columns of ``z``; optionally apply an inverse FFT across them."""
        if self.params.red:
            columns = slice(int(self.r0 + 1), int(self.r0 + self.R + 1))
        else:
            columns = slice(int((self.r0 + 1) * self.N),
                            int((self.r0 + self.R) * self.N + 1),
                            int(self.N))
        z = z[:, columns]

        if not self.params.includeoutputfft:
            return z

        # Orthonormal inverse FFT over the selected outputs; only the real part is kept.
        z_complex = torch.complex(z, torch.zeros_like(z))
        return torch.fft.ifft(z_complex.t(), dim=0, norm="ortho").t().real


### Make the network

In [14]:
class AnaNetwork1layer(nn.Module):
    """Single analog layer: flatten -> activation encoding -> weight mixing -> A2D -> output filter.

    Args:
        params: shared ``Params`` object handed to every stage.
        outputsize, inputsize: accepted for interface compatibility; the layer
            dimensions are taken from ``params.weightshape`` by the stages.
        diodemixer: mixer object passed to ``WeightEncodingandMixing``.
    """

    def __init__(self, params: Params, outputsize, inputsize, diodemixer):
        super(AnaNetwork1layer, self).__init__()

        self.params = params

        self.diodemixer = diodemixer

        self.flatten = nn.Flatten()

        # The same stages are built whether or not params.semianalog is set; the
        # stages read that flag themselves where it matters. The former
        # semianalog=False branch constructed WeightEncodingandMixing without its
        # required diodemixer argument and raised TypeError.
        self.analogprepro = ActivationEncoding(self.params)

        self.analoglayer = WeightEncodingandMixing(self.params, self.diodemixer)

        self.A2D = A2D(self.params)

        self.filter = MAFTFilter(self.params)

    def forward(self, x):
        """Run the input through all stages and return the filtered frequency-domain outputs."""
        x = self.flatten(x)

        # The activation energy returned by the encoder is not used here.
        x, _ = self.analogprepro(x)

        x = self.analoglayer(x)

        x = self.A2D(x)

        x = self.filter(x)

        return x


In [15]:
# --- unit scale factors ---
preprounits = torch.tensor(1)  # current in A
wunits = torch.tensor(1)  # voltage in V

# converts voltage to conductance, and also has a unitless voltage divider like factor
unitconverter = torch.tensor(1)

# --- resistances entering the 4*kb*T*R noise coefficients below (presumably ohms) ---
R_AWG = 50
R_w = 50
R_out = 50
R_A2D = 50

# --- frequency comb ---
weightfreqspacing = torch.tensor(1e3)  # in Hz, 1 kHz
timespan = 1/weightfreqspacing
totalweightbandwidth = torch.tensor(1e8)  # in Hz, 100 MHz
numcomblines = totalweightbandwidth/weightfreqspacing  # which is the number of MACs

# --- RF drive ---
R_trans = torch.tensor(50)  # ohms
RF_power = torch.tensor(250e-6)  # Watts
Vref = torch.sqrt(RF_power*R_trans)

analogenergyperMAC = RF_power*timespan/numcomblines

# --- thermal noise coefficients 4*kb*T*R for each stage ---
k_th_AWG = 4*kb*T*R_AWG
k_th_w = 4*kb*T*R_w
k_th_out = 4*kb*T*R_out
k_th_A2D = 4*kb*T*R_A2D

## Multifreq plot

In [17]:
def fidelity(outputsize, inputsize, params, diodemixer, 
             scalefact=1., numtrials=1000, weightscalefact=None,
             omegac=None, omegadiff=None, total=None, uni=True):
    """Run random inputs through a freshly built analog layer and return it with the exact product.

    Side effect: the four ``k_th_*_lowtemp`` attributes of ``params`` are
    overwritten, using ``omegadiff`` when ``params.optics`` is set and
    ``omegac`` otherwise.

    Args:
        outputsize, inputsize: shape of the random weight matrix.
        params: ``Params`` object (mutated, see above).
        diodemixer: mixer handed to the network.
        scalefact, weightscalefact: multipliers for the random inputs / weights.
        numtrials: number of random input vectors.
        total: accepted but not used.
        uni: draw from uniform [-1, 1) when True, standard normal otherwise.

    Returns:
        ``(analogout, target)`` where ``target = (weights @ inputs.T).T``.
    """
    omega = omegadiff if params.optics else omegac

    # hbar*omega*coth(hbar*omega / (2*kb*T)), shared by all four coefficients.
    half_ratio = hbar * omega / (2 * kb * T)
    density = 2 * (hbar * omega * np.cosh(half_ratio) / np.sinh(half_ratio))

    params.k_th_AWG_lowtemp = density * R_AWG
    params.k_th_w_lowtemp = density * R_w
    params.k_th_out_lowtemp = density * R_out
    params.k_th_A2D_lowtemp = density * R_A2D

    # Refactored the network creation to use the params object
    my_network = AnaNetwork1layer(params, outputsize, inputsize, diodemixer)

    def draw(rows, cols):
        # float64 samples: uniform on [-1, 1) or standard normal
        if uni:
            return 2*torch.rand(rows, cols, dtype=torch.float64)-1
        return torch.randn(rows, cols, dtype=torch.float64)

    # Randomly generate matrix and input vector (weights first, then inputs)
    weights = weightscalefact * draw(outputsize, inputsize)
    my_network.analoglayer.W.data = weights

    inputs = scalefact * draw(numtrials, inputsize)

    analogout = my_network(inputs)
    target = (weights @ inputs.t()).t()

    return analogout, target


In [35]:
def nsr_vs_energy(params, diodemixer, stddev=1, optics=False, 
                  inputsize=1000, outputsize=1, 
                  omegac=2*np.pi*1e9, omegadiff=2*np.pi*1e9, 
                  totalweightenergyopt=1e-12, totalweightenergyrf=1e-12,
                  energiespermac=None, numtrials=100, 
                  uni=True, 
                  maxsnr=True):
    """Sweep the energy per MAC and return the normalized squared error of the analog layer.

    For every entry of ``energiespermac`` the input and weight scale factors are
    derived from the energy budget, ``fidelity`` is called once, and the mean
    squared error between analog output and exact product is divided by either
    ``(inputsize * scalefact * weightscalefact)**2`` (``maxsnr=True``) or the
    mean squared target.

    Returns:
        List with one result per entry of ``energiespermac``, in order.
    """
    # Weight energy budget per MAC for the selected regime.
    if optics:
        weight_energy_per_mac = totalweightenergyopt / inputsize
    else:
        weight_energy_per_mac = totalweightenergyrf / inputsize

    # Same frequency layout as the layers (without carrier offset); only forwarded to fidelity.
    if params.red:
        r0 = outputsize * (inputsize - 1) / 2
        span = (r0 + outputsize) + (0 + inputsize) * outputsize
    else:
        n0 = int(inputsize * (outputsize - 1) / 2)
        span = (0 + outputsize) * inputsize + (n0 + inputsize)
    highestindex = int(2 * span + 4)
    total = int(2 * (highestindex - 1))

    # Sample period taken as the inverse of omegac / (2*pi).
    delta_t = torch.tensor(1 / (omegac/(2*np.pi)), dtype=torch.float64)

    results = []
    for energypermac in energiespermac:

        if optics:
            scalefact = torch.sqrt(energypermac / (hbar * omegac * stddev**2))
            weightscalefact = torch.sqrt(weight_energy_per_mac / (hbar * omegac * stddev**2))
        else:
            scalefact = torch.sqrt(energypermac * R_trans / (delta_t * stddev**2))
            weightscalefact = torch.sqrt(weight_energy_per_mac * R_trans / (delta_t * stddev**2))

        # Call fidelity function using the params object
        analogout, target = fidelity(outputsize, inputsize,
                                     params, diodemixer,
                                     scalefact=scalefact,
                                     weightscalefact=weightscalefact,
                                     numtrials=numtrials,
                                     omegac=omegac,
                                     omegadiff=omegadiff,
                                     total=total, 
                                     uni=uni)

        squared_error = torch.mean((analogout - target).flatten()**2)
        if maxsnr:
            reference = (inputsize*scalefact*weightscalefact)**2
        else:
            reference = torch.mean(target.flatten()**2)

        results.append(squared_error / reference)

    return results


### Plotting

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import LinearSegmentedColormap, Normalize
import matplotlib.ticker as ticker

def plot_contour_with_dual_axes(start, end, ENOBrf, energiespermac, wbw=0.25e8, 
                                dukeplot=False, optics=False, uni=False, maxsnr=False, 
                                noisein=False, noiseout=True, eta=None):
    """
    Plots a contour plot with power (W) on the bottom and left axes 
    and energy (J) on the top and right axes.
    
    The colorbar is positioned so it does not cover the right-side axis labels.
    The primary (bottom/left) axes are in log10(power), and the twin (top/right) 
    axes are in log10(energy). Ticks on both sets of axes are aligned exactly.

    Parameters
    ----------
    start, end : int
        Decade exponents of the first and last energy tick; ticks are placed at
        10**i for i in range(start, end + 1) and every second one is drawn.
        They also fix the visible axis limits.
    ENOBrf : 2-D array-like
        Values to contour. Rows follow the y axis (server) and columns the
        x axis (client), both sampled at `energiespermac`.
    energiespermac : 1-D array-like (NumPy array or CPU torch tensor)
        Energies per MAC used for both axes.
    wbw : float
        Bandwidth used to convert energy to power via tbin = 1 / wbw
        (presumably in Hz -- confirm against the caller).
    dukeplot : bool
        If True, mark three hard-coded reference points with black stars.
    optics : bool
        Selects the colorbar label and the output file name (homodyne vs. RF).
    uni, maxsnr : bool
        Only used to build the output file name.
    noisein, noiseout : bool
        Only used to build the RF output file name.
    eta : optional
        Value embedded in the RF output file name. When None (the default) the
        notebook-level global `eta` is used, as in earlier versions of this
        function. Ignored when `optics` is True.

    Raises
    ------
    NameError
        If `optics` is False, `eta` is None and no global `eta` exists.

    Side effects
    ------------
    Creates the "Figures" directory if needed, writes a PDF into it and shows
    the figure.
    """
    # The RF file name embeds eta. Resolve it before any drawing so that a
    # missing value fails immediately instead of after the figure is built.
    if eta is None and not optics:
        try:
            eta = globals()["eta"]
        except KeyError:
            raise NameError(
                "plot_contour_with_dual_axes needs `eta` for the RF file name; "
                "pass eta=... or define a notebook-level `eta`"
            ) from None

    # Work on plain NumPy arrays so torch tensors never reach matplotlib.
    energies = np.asarray(energiespermac, dtype=float)
    enob_grid = np.asarray(ENOBrf, dtype=float)

    # 1) Convert energy -> power using tbin = 1 / wbw
    tbin = 1 / wbw
    powers = energies / tbin  # For your data and contours
    
    fontsize = 13.5

    # 2) Convert those power values to log scale for contour plotting
    x_power = np.log10(powers)
    y_power = np.log10(powers)
    X, Y = np.meshgrid(x_power, y_power)

    # 3) Define a custom colormap (runs red -> yellow -> green; the name
    #    string is only a label and does not affect the colours)
    colors = [(0.8, 0.0, 0.0), (0.85, 0.85, 0.0), (0.0, 0.8, 0.0)]
    cmap_custom = LinearSegmentedColormap.from_list("BlueGreenYellow", colors, N=256)

    # 4) Normalize the color mapping
    norm = Normalize(vmin=np.min(enob_grid), vmax=np.max(enob_grid))

    # 5) Create figure and main axes (for power)
    fig, ax_power = plt.subplots(figsize=(6, 5))

    # 6) Contour plot using the power domain (one level per integer bit)
    contour_levels = np.arange(1, np.ceil(np.max(enob_grid)) + 1, 1)
    CS = ax_power.contour(X, Y, enob_grid, levels=contour_levels, 
                          linewidths=2.5, cmap=cmap_custom, norm=norm)
    ax_power.clabel(CS, inline=True, fontsize=fontsize, colors='black')
    
    if dukeplot:
        # marking Zhihui's points
        points_power = [
        (np.log10(0.5e-3), np.log10(0.5e-9)),
        (np.log10(0.5e-3), np.log10(0.5e-8)),
        (np.log10(0.5e-3), np.log10(0.5e-7)),
        ]

        for (px, py) in points_power:
            ax_power.plot(px, py, '*', markersize=12, color='black')

    # 7) Label the main (power) axes
    ax_power.set_xlabel('Client power (W)', fontsize=fontsize)
    ax_power.set_ylabel('Server power (W)', fontsize=fontsize)

    # -------------------------------------------------------------------------
    #   A) CHOOSE DESIRED ENERGY TICKS (these are the "most important" ticks)
    #      One tick per decade from 10**start to 10**end
    # -------------------------------------------------------------------------
    desired_energy_ticks = [10.0**i for i in range(start, end+1)]

    # -------------------------------------------------------------------------
    #   B) CONVERT THESE ENERGY TICKS TO POWER (in linear space), 
    #      THEN TO log10(POWER) FOR THE MAIN AXIS
    # -------------------------------------------------------------------------
    power_tick_positions = [np.log10(E / tbin) for E in desired_energy_ticks]
    # We'll label them in scientific notation
    power_tick_labels = [f"{(E / tbin):.1e}" for E in desired_energy_ticks]

    # -------------------------------------------------------------------------
    #   C) APPLY THOSE TICKS TO THE MAIN (POWER) AXES
    # -------------------------------------------------------------------------
    # X-axis
    ax_power.set_xticks(power_tick_positions[::2])
    ax_power.set_xticklabels(power_tick_labels[::2], fontsize=fontsize)
    # Y-axis
    ax_power.set_yticks(power_tick_positions[::2])
    ax_power.set_yticklabels(power_tick_labels[::2], fontsize=fontsize)

    # -------------------------------------------------------------------------
    #   D) CREATE TWIN AXES FOR ENERGY (TOP AND RIGHT)
    # -------------------------------------------------------------------------
    ax_energy_top = ax_power.twiny()
    ax_energy_right = ax_power.twinx()

    # We'll use a log scale for these axes
    ax_energy_top.set_xscale('log')
    ax_energy_right.set_yscale('log')

    # -------------------------------------------------------------------------
    #   E) SET TICKS FOR THE ENERGY AXES
    # -------------------------------------------------------------------------
    # Top axis (client energy)
    ax_energy_top.set_xticks(desired_energy_ticks[::2])
    ax_energy_top.xaxis.set_major_formatter(ticker.FormatStrFormatter('%.0e'))
    ax_energy_top.tick_params(axis='x', labelsize=fontsize)
    ax_energy_top.set_xlabel('Client energy per MAC (J)', fontsize=fontsize, labelpad=12)

    # Right axis (server energy)
    ax_energy_right.set_yticks(desired_energy_ticks[::2])
    ax_energy_right.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.0e'))
    ax_energy_right.tick_params(axis='y', labelsize=fontsize)
    ax_energy_right.set_ylabel('Server energy per MAC (J)', fontsize=fontsize, 
                               rotation=-90, labelpad=20)

    # -------------------------------------------------------------------------
    #   F) FORCE THE MAIN AXES LIMITS & MATCH THE TWIN AXES LIMITS
    #      so that the top/right truly align with bottom/left
    #
    #   The main axis is in log10(power), so let's set xlim/ylim to 
    #   the first and last tick positions in log10(POWER).
    #
    #   Then for the twin axes (in energy), we apply the same range 
    #   but in linear scale for energy = 10^(log10(power)) * tbin.
    # -------------------------------------------------------------------------
    xmin_p, xmax_p = power_tick_positions[0], power_tick_positions[-1]
    ax_power.set_xlim(xmin_p, xmax_p)
    ax_power.set_ylim(xmin_p, xmax_p)

    # For the twin top axis: E = P * tbin, so in linear scale that is
    # from 10^xmin_p * tbin to 10^xmax_p * tbin
    ax_energy_top.set_xlim(10**xmin_p * tbin, 10**xmax_p * tbin)
    # For the twin right axis: same approach for y-limits
    ax_energy_right.set_ylim(10**xmin_p * tbin, 10**xmax_p * tbin)

    # -------------------------------------------------------------------------
    #   G) Grid and style adjustments
    # -------------------------------------------------------------------------
    for spine in ax_power.spines.values():
        spine.set_linewidth(1.5)
    ax_power.grid(True, linestyle=(0, (1, 3)), linewidth=2)

    # Adjust layout so the colorbar doesn't overlap the right axis
    fig.subplots_adjust(right=0.85)

    # -------------------------------------------------------------------------
    #   H) Add the colorbar on the right
    # -------------------------------------------------------------------------
    sm = plt.cm.ScalarMappable(cmap=CS.cmap, norm=CS.norm)
    sm.set_array([])

    # The colorbar axes sit just outside the figure's right edge; the
    # bbox_inches='tight' option below pulls them back into the saved PDF.
    cbar_ax = fig.add_axes([1.02, 0.15, 0.03, 0.7])  # [left, bottom, width, height]
    colorbar = plt.colorbar(sm, cax=cbar_ax)
    colorbar.ax.tick_params(labelsize=fontsize)
    colorbar.outline.set_linewidth(1.5)
    if not optics:
        colorbar.set_label('RF mixer output effective number of bits', fontsize=fontsize, 
                           rotation=270, labelpad=15)
    else:
        colorbar.set_label('Homodyne output effective number of bits', fontsize=fontsize, 
                           rotation=270, labelpad=15)

    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs("Figures", exist_ok=True)
    if not optics:
        fig.savefig(f"Figures/ENOB_rf_2D_Jun3_eta{eta}_uni{uni}_maxsnr{maxsnr}_noisein{noisein}_noiseout{noiseout}.pdf", bbox_inches='tight')
    else:
        fig.savefig(f"Figures/ENOB_optics_2D_Jun3_uni{uni}_maxsnr{maxsnr}.pdf", bbox_inches='tight')

    plt.show()
    


In [41]:
# Decade exponents bounding the energy-per-MAC sweep: 10**start ... 10**end
start = -17  # log10(1e-17)
end = -10    # log10(1e-10)

# Number of log-spaced samples in the sweep (a 37-point grid was used previously)
num_points = 16

# Log-spaced energies per MAC, kept in double precision
energiespermac = torch.logspace(
    start,
    end,
    steps=num_points,
    dtype=torch.float64,
)

In [42]:
# --- Carrier and bandwidth settings ---
includecarrier = True
carrieroffmult = 40  # UHF cell: carrier frequency as a multiple of totBW
totBW = torch.tensor(0.25e8, dtype=torch.float64)

UHFfreq = carrieroffmult * totBW  # UHF cell
omegac = 2 * np.pi * UHFfreq      # angular frequency corresponding to UHFfreq

# --- Layer dimensions ---
inputsize, outputsize = 256, 1

# --- Device parameter forwarded to PhysicalConstants ---
eta = 1.0

# --- Noise switches forwarded to Params and used in output file names ---
noisein = True
noiseout = True

In [43]:
# Physical constants, with eta taken from the configuration cell
constants = PhysicalConstants(eta=eta)

# Noise model over a bandwidth of 0.25e8 (presumably Hz -- confirm)
noisemodel = NoiseModel(
    constants=constants,
    bandwidth=0.25e8,
)

# Diode mixer sharing the same constants and noise model
diodemixer = DiodeMixing(
    constants=constants,
    noise_model=noisemodel,
)

In [44]:
# RF-path simulation parameters.
# The total bandwidth is read from the configuration cell's `totBW` instead of
# repeating the literal, so the two cannot drift apart; float() keeps the
# argument a plain Python float, as it was before.
paramsrf = Params(
    weightshape=[outputsize, inputsize],
    red=True,
    units=1,
    fixedband=True,
    totBW=float(totBW),
    noisein=noisein,
    noiseout=noiseout,
    lowtemp=True,
    weightfreqspacing=1e3,
    k_th_AWG=k_th_AWG,
    k_th_w=k_th_w,
    k_th_out=k_th_out,
    k_th_A2D=k_th_A2D,
    scaletoz=False,
    scaletoW=False,
    wunits=1,
    unitconverter=1,
    includeinputfft=True,
    includeoutputfft=True,
    semianalog=True,
    digitallike=False,
    noisefactor=4,
    optics=False,
    photodiodearea=1e-6,
    omegaoptical=2*torch.pi*cbandfreq,
    inttime=1e-7,
    capval=1e-12,
    includecarrier=includecarrier,
    carrieroffmult=carrieroffmult,
)

In [45]:
# Outer sweep: total weight energy, one value per entry of energiespermac
# scaled by the input size. Each outer value yields one row of results.
totalweightenergyrf_array = energiespermac*inputsize

results_rf_all_nsr = []

uni = True
maxsnr = True

# NOTE(review): sqrt(1/3) is the standard deviation of a uniform distribution
# on [-1, 1]; presumably that is the distribution `uni` selects -- confirm.
stddev = np.sqrt(1/3) if uni else 1

for idx, totalweightenergyrf in enumerate(totalweightenergyrf_array):

    results_rf = nsr_vs_energy(
        paramsrf, diodemixer,
        stddev=stddev,
        optics=False,
        inputsize=inputsize,
        # Use the configured output size (the same variable paramsrf was built
        # from) rather than a separate hard-coded 1.
        outputsize=outputsize,
        omegac=2*np.pi*UHFfreq,
        # NOTE(review): omegadiff is given the same value as omegac -- confirm
        # that this is intended.
        omegadiff=2*np.pi*UHFfreq,
        totalweightenergyopt=None,
        totalweightenergyrf=totalweightenergyrf,
        energiespermac=energiespermac,
        numtrials=200,
        uni=uni,
        maxsnr=maxsnr,
    )

    # Keep plain Python floats so the grid converts cleanly to a NumPy array.
    results_rf_all_nsr.append([tensor.item() for tensor in results_rf])
    print(f"iteration {idx} done")

  delta_t = torch.tensor(1 / (omegac/(2*np.pi)), dtype=torch.float64)


iteration 0 done
iteration 1 done
iteration 2 done
iteration 3 done
iteration 4 done
iteration 5 done
iteration 6 done
iteration 7 done
iteration 8 done
iteration 9 done
iteration 10 done
iteration 11 done
iteration 12 done
iteration 13 done
iteration 14 done
iteration 15 done


In [46]:
# Convert the noise-to-signal grid to bits: 0.5 * log2(1 + 1/NSR),
# written as a natural log divided by ln(2).
ENOBrf_nsr = (
    0.5
    * np.log(1 + 1 / np.array(results_rf_all_nsr))
    / np.log(2)
)

In [None]:
# Draw and save the RF contour plot for the sweep computed above.
plot_contour_with_dual_axes(
    start,
    end,
    ENOBrf_nsr,
    energiespermac,
    wbw=0.25e8,
    optics=False,
    dukeplot=False,
    uni=uni,
    maxsnr=maxsnr,
    noisein=noisein,
    noiseout=noiseout,
)

In [None]:
import json
# Persist the ENOB grid as nested JSON lists in the working directory.
with open("ENOBrf_nsr_0.25e8_unimaxsnr_noisein_Jun4.json", "w") as json_file:
    json.dump(ENOBrf_nsr.tolist(), json_file)