In [19]:
# ----- Architecture specs -----
# Each net is a dict with:
#   "net":       "mlp" or "conv"
#   "optimizer": optimizer used by the FP-backprop baseline ("sgd" or "adam")
#   "layers":    list of per-layer dicts
#       linear layer: {"d_in": features in, "d_out": features out}
#       conv layer:   {"d_in"/"d_out": square spatial size in/out,
#                      "c_in"/"c_out": channels in/out, "k": kernel size}


def _fc(d_in, d_out):
    """Spec for a linear layer mapping d_in features to d_out features."""
    return {"d_in": d_in, "d_out": d_out}


def _conv(size, c_in, c_out, k=3):
    """Spec for a conv layer whose spatial size is the same at input and output."""
    return {"d_in": size, "d_out": size, "c_in": c_in, "c_out": c_out, "k": k}


def _mlp(optimizer, widths):
    """MLP spec from consecutive layer widths, e.g. [784, 100, 10] -> two linear layers."""
    return {"net": "mlp", "optimizer": optimizer,
            "layers": [_fc(a, b) for a, b in zip(widths, widths[1:])]}


# MLP Architectures
mlp1 = _mlp("sgd", [784, 100, 50, 10])
mlp2 = _mlp("sgd", [784, 200, 100, 50, 10])
mlp3 = _mlp("adam", [784, 1024, 1024, 1024, 10])
mlp4 = _mlp("adam", [1024, 3000, 3000, 3000, 10])

# Conv Architectures: a conv stack followed by a 2048 -> 1024 -> 10 classifier
conv1 = {"net": "conv", "optimizer": "adam", "layers": [
    _conv(32, 3, 128), _conv(32, 128, 256),
    _conv(16, 256, 256), _conv(16, 256, 512),
    _conv(8, 512, 512),
    _conv(4, 512, 512),
    _fc(2048, 1024), _fc(1024, 10),
]}
conv2 = {"net": "conv", "optimizer": "adam", "layers": [
    _conv(32, 3, 128), _conv(32, 128, 128), _conv(32, 128, 128), _conv(32, 128, 256),
    _conv(16, 256, 256), _conv(16, 256, 512),
    _conv(8, 512, 512), _conv(8, 512, 512),
    _conv(4, 512, 512),
    _fc(2048, 1024), _fc(1024, 10),
]}

In [20]:
nitro = False          # False = FP32 backprop baseline (SGD or Adam per net config), True = integer NITRO training (plain SGD)

In [22]:
# ===== Config =====
bs = 64                     # batch size
LOCAL_FC_IN_CONV = 4096     # conv-layer local head input dim
LOCAL_FC_OUT = 10           # local head output classes

# Per-parameter optimizer update cost: (muls, adds, persistent state elements per param).
SGD_UPDATE = (2, 2, 0)      # plain SGD, no momentum -> no optimizer state
ADAM_UPDATE = (9, 4, 2)     # Adam keeps two state tensors (m, v) per param


def _linear_batch_ops(d_in, d_out, batch):
    """Multiply/add counts for one training batch through a bias-free linear layer.

    Counts the forward pass (y = W x), the input gradient (dX) and the weight
    gradient (dW) for `batch` samples. The dW additions are counted once over the
    whole batch (the per-sample outer products are summed), hence the
    ``(batch - 1)`` factor instead of a per-sample add count.

    Returns:
        (muls, adds) as exact ints.
    """
    per_sample_mul = 3 * d_out * d_in                          # fwd, dW and dX each need d_out*d_in muls
    per_sample_add = d_out * (d_in - 1) + d_in * (d_out - 1)   # fwd dot products + dX reductions
    dW_add = d_out * d_in * (batch - 1)                        # summing per-sample dW over the batch
    return batch * per_sample_mul, batch * per_sample_add + dW_add


def _conv_batch_ops(c_in, c_out, d_in, d_out, k, batch):
    """Multiply/add counts for one training batch through a bias-free k x k conv layer.

    Spatial sizes are square: `d_in` x `d_in` input sites, `d_out` x `d_out` output sites.
    Counts forward, dW and dX; dW additions are summed over every output site of every
    sample in the batch, hence ``(batch * S - 1)`` adds per weight.

    Returns:
        (muls, adds) as exact ints.
    """
    K = c_in * k * k            # receptive-field size per output value
    S = d_out * d_out           # output spatial sites
    S_in = d_in * d_in          # input spatial sites

    mul_fwd = c_out * S * K
    add_fwd = c_out * S * (K - 1)
    mul_dW = c_out * c_in * k * k * S
    mul_dX = c_in * S_in * c_out * k * k
    add_dX = c_in * S_in * (c_out * k * k - 1)
    dW_add = (c_out * c_in * k * k) * (batch * S - 1)

    return batch * (mul_fwd + mul_dW + mul_dX), batch * (add_fwd + add_dX) + dW_add


for i, net in enumerate([mlp1, mlp2, mlp3, mlp4, conv1, conv2]):
    if nitro:
        print(net["net"].upper() + f"{i} - Integer NITRO")
    else:
        print(net["net"].upper() + f"{i} - FP Backprop")

    # NITRO always updates with plain SGD; the FP baseline uses the net's configured optimizer.
    if nitro or net["optimizer"] == "sgd":
        PER_PARAM_MUL, PER_PARAM_ADD, OPT_STATE_MULT = SGD_UPDATE
    else:
        PER_PARAM_MUL, PER_PARAM_ADD, OPT_STATE_MULT = ADAM_UPDATE

    # ---- Ops ----
    batch_mul, batch_add = 0, 0   # per batch (fwd + bwd over bs)
    upd_mul, upd_add = 0, 0       # per-step (optimizer update once)

    # ---- Memory (elements, not bytes) ----
    mem_w = 0                     # parameters (Σ per layer incl. local heads)
    mem_h = 0                     # saved activations (Σ per base layer × bs)
    mem_d_peak = 0                # parameter gradients (MAX, streaming)
    mem_opt = 0                   # persistent optimizer state (0 for SGD, 2 × params for Adam)

    layers = net["layers"]
    # Number of layers. (Previously len(net), which counted the dict's keys and was always 3.)
    L = len(layers)
    for li, layer in enumerate(layers):
        # Layer kind: conv if "c_in" is given; otherwise linear.
        if "c_in" in layer:
            k = layer.get("k", 1)
            params = layer["c_out"] * layer["c_in"] * k * k
            mul, add = _conv_batch_ops(layer["c_in"], layer["c_out"], layer["d_in"], layer["d_out"], k, bs)
            activations = bs * layer["c_out"] * layer["d_out"] * layer["d_out"]
            # assumes conv activations are reduced/flattened to LOCAL_FC_IN_CONV features for the head — TODO confirm
            head_in = LOCAL_FC_IN_CONV
        else:
            params = layer["d_out"] * layer["d_in"]
            mul, add = _linear_batch_ops(layer["d_in"], layer["d_out"], bs)
            activations = bs * layer["d_out"]
            head_in = layer["d_out"]

        # ---- base layer: compute, memory and optimizer update ----
        batch_mul += mul
        batch_add += add
        mem_w += params
        mem_h += activations
        mem_d_peak = max(mem_d_peak, params)
        upd_mul += params * PER_PARAM_MUL
        upd_add += params * PER_PARAM_ADD
        mem_opt += params * OPT_STATE_MULT

        # ---- NITRO local head (bias-free linear head -> LOCAL_FC_OUT) ----
        # The last two layers get no local head; for the conv nets these are the FC classifier.
        # The head reuses the base layer's activations, so no extra mem_h is counted.
        if nitro and li < L - 2:
            head_params = head_in * LOCAL_FC_OUT
            head_mul, head_add = _linear_batch_ops(head_in, LOCAL_FC_OUT, bs)
            batch_mul += head_mul
            batch_add += head_add
            mem_w += head_params
            mem_d_peak = max(mem_d_peak, head_params)
            upd_mul += head_params * PER_PARAM_MUL
            upd_add += head_params * PER_PARAM_ADD
            mem_opt += head_params * OPT_STATE_MULT   # 0 under NITRO's plain SGD

    # ---- Per-step totals ----
    step_mul = batch_mul + upd_mul
    step_add = batch_add + upd_add

    # ---- Costs from https://ieeexplore.ieee.org/abstract/document/6757323 ----
    # Joules per op; these appear to be that paper's 32-bit int (NITRO) and 32-bit FP (baseline) figures.
    if nitro:
        mul_cost = 3.1*10**(-12)
        add_cost = 0.1*10**(-12)
    else:
        mul_cost = 3.7*10**(-12)
        add_cost = 0.9*10**(-12)

    # Convert Joules to mJ, Bytes to MB
    print("Energy:", (step_mul*mul_cost+step_add*add_cost) * 1000, "mJ")

    first = net["layers"][0]
    input_dim = first["c_in"] * first["d_in"]**2 if net["net"] == "conv" else first["d_in"]
    if nitro:
        # bit widths: 8-bit inputs, 16-bit weights, 8-bit activations, 32-bit gradients;
        # /8 -> bytes, /1000/1000 -> MB. No optimizer state (plain SGD).
        print("Memory:", (input_dim*bs*8 + mem_w*16 + mem_h*8 + mem_d_peak*32)/8/1000/1000, "MB")
    else:
        # everything stored as 32-bit floats; /8 -> bytes, /1000/1000 -> MB
        print("Memory:", (input_dim*bs*32 + mem_w*32 + mem_h*32 + mem_d_peak*32 + mem_opt*32)/8/1000/1000, "MB")

    print()

MLP0 - FP Backprop
Energy: 0.0747338356 mJ
Memory: 0.890864 MB

MLP1 - FP Backprop
Energy: 0.1624343956 mJ
Memory: 1.6492639999999998 MB

MLP2 - FP Backprop
Energy: 2.6746635648 mJ
Memory: 40.106496 MB

MLP3 - FP Backprop
Energy: 19.395862041599997 mJ
Memory: 291.792704 MB

CONV4 - FP Backprop
Energy: 838.7449070976 mJ
Memory: 278.694912 MB

CONV5 - FP Backprop
Energy: 1238.8826883456002 mJ
Memory: 386.04288 MB

