In [None]:
import numpy as np
import time
from sega_learn.neural_networks.neuralNetwork import *
from sega_learn.neural_networks.numba_utils import (
    relu as relu_nb, 
    relu_derivative as relu_derivative_nb,
    leaky_relu as leaky_relu_nb,
    leaky_relu_derivative as leaky_relu_derivative_nb,
    tanh as tanh_nb,
    tanh_derivative as tanh_derivative_nb,
    sigmoid as sigmoid_nb,
    sigmoid_derivative as sigmoid_derivative_nb,
    softmax as softmax_nb,
)

from sega_learn.neural_networks.numba_utils import sum_axis0, sum_reduce


In [4]:
def compare_outputs(func1, func2, *args, tolerance=1e-7):
    """Check that two implementations produce numerically equivalent results.

    Both callables are invoked with the same positional ``args`` and their
    outputs are compared element-wise with ``np.allclose`` (absolute tolerance
    ``tolerance``, NumPy's default relative tolerance). Shapes must match
    exactly: ``np.allclose`` broadcasts, so without this check a ``(n,)``
    result would silently "match" a ``(1, n)`` result.

    Nothing is printed on success; a diagnostic line is printed on failure.

    Args:
        func1: Reference implementation.
        func2: Implementation under test.
        *args: Positional arguments forwarded unchanged to both callables.
        tolerance: Absolute tolerance for ``np.allclose`` (keyword-only).

    Returns:
        bool: True when shapes match and outputs are within tolerance,
        False otherwise.
    """
    output1 = np.asarray(func1(*args))
    output2 = np.asarray(func2(*args))

    # Not every callable is guaranteed to expose __name__; fall back to repr.
    name1 = getattr(func1, "__name__", repr(func1))
    name2 = getattr(func2, "__name__", repr(func2))

    if output1.shape != output2.shape:
        print(f"\n{name1} and {name2} output shapes differ: {output1.shape} vs {output2.shape}.")
        return False

    if not np.allclose(output1, output2, atol=tolerance):
        print(f"\n{name1} and {name2} outputs are not within tolerance of {tolerance}.")
        return False

    return True

#### Compare Activation Functions and Their Derivatives

In [None]:
# Compare each NumPy activation (and derivative) on `Activation` against its
# Numba counterpart using the same input. `compare_outputs` prints only when a
# pair disagrees, so a silent cell means every pair matched.
# A fixed seed keeps the input, and therefore any reported mismatch, reproducible.
rng = np.random.default_rng(42)
z = rng.standard_normal((1000, 2))

activation_pairs = [
    # ReLU and ReLU Derivative
    (Activation.relu, relu_nb),
    (Activation.relu_derivative, relu_derivative_nb),
    # Leaky ReLU and Leaky ReLU Derivative
    (Activation.leaky_relu, leaky_relu_nb),
    (Activation.leaky_relu_derivative, leaky_relu_derivative_nb),
    # Tanh and Tanh Derivative
    (Activation.tanh, tanh_nb),
    (Activation.tanh_derivative, tanh_derivative_nb),
    # Sigmoid and Sigmoid Derivative
    (Activation.sigmoid, sigmoid_nb),
    (Activation.sigmoid_derivative, sigmoid_derivative_nb),
    # Softmax
    (Activation.softmax, softmax_nb),
]

for reference_fn, numba_fn in activation_pairs:
    compare_outputs(reference_fn, numba_fn, z)

#### Compare Utility Functions

In [54]:
# Compare the Numba `sum_axis0` against NumPy's column sum on a large random matrix.
# A fixed seed keeps the input (and any reported mismatch) reproducible.
rng = np.random.default_rng(42)
X = rng.standard_normal((1000, 1000))

# Reference result; keepdims=True gives shape (1, n_cols).
np_sum_result = np.sum(X, axis=0, keepdims=True)

# Result for sum_axis0
sum_axis0_result = sum_axis0(X)

# Flatten both so the comparison and the reported index do not depend on whether
# sum_axis0 returns shape (n_cols,) or (1, n_cols). Without this, allclose would
# broadcast, and np.where on the 2-D difference yields the row index (always 0)
# rather than the column index.
np_sum_flat = np.ravel(np_sum_result)
sum_axis0_flat = np.ravel(sum_axis0_result)

# Verify that results are the same
tolerance = 1e-6
shapes_match = np_sum_flat.shape == sum_axis0_flat.shape
if shapes_match and np.allclose(np_sum_flat, sum_axis0_flat, atol=tolerance):
    print("Results match!")
else:
    print("Results do not match!")
    if not shapes_match:
        print(f"Shape mismatch: {np.shape(np_sum_result)} vs {np.shape(sum_axis0_result)}")
    else:
        # isclose uses the same criterion as allclose, so at least one index is
        # guaranteed here (NaNs are also flagged instead of raising IndexError).
        mismatches = np.flatnonzero(~np.isclose(np_sum_flat, sum_axis0_flat, atol=tolerance))
        diff_index = mismatches[0]
        print(f"Difference found at index {diff_index}: {np_sum_flat[diff_index]} vs {sum_axis0_flat[diff_index]}")

Results match!


#### Compare Loss Functions

In [None]:
# Import the reference and JIT-compiled loss implementations to compare
from sega_learn.neural_networks.loss import CrossEntropyLoss, BCEWithLogitsLoss
from sega_learn.neural_networks.loss_jit import JITCrossEntropyLoss, JITBCEWithLogitsLoss

In [37]:
# Compare Cross Entropy Loss: reference implementation vs. JIT-compiled version.

# Dummy multi-class data; a fixed seed makes the printed losses reproducible.
rng = np.random.default_rng(42)
n_samples, n_classes = 5, 3

logits_ce = rng.standard_normal((n_samples, n_classes))

# Generate integer targets (upper bound exclusive) and convert to one-hot rows
targets_int = rng.integers(0, n_classes, size=n_samples)
targets_onehot = np.eye(n_classes)[targets_int]

# Instantiate loss function objects
base_ce_loss = CrossEntropyLoss()
jit_ce_loss = JITCrossEntropyLoss()

# The base loss is invoked via __call__, the JIT loss via calculate_loss.
loss_base_ce = base_ce_loss(logits_ce, targets_onehot)
loss_jit_ce = jit_ce_loss.calculate_loss(logits_ce, targets_onehot)

tolerance = 1e-7
print("Cross Entropy Loss Comparison:")
print("-" * 75)
if np.allclose(loss_base_ce, loss_jit_ce, atol=tolerance):
    print("Losses are equal to within tolerance of", tolerance)
else:
    # Report a mismatch explicitly instead of silently skipping the message.
    print("Losses are NOT equal to within tolerance of", tolerance)
print("Base Loss     :", loss_base_ce)
print("JIT Loss      :", loss_jit_ce)
print("Difference    :", abs(loss_base_ce - loss_jit_ce))


Cross Entropy Loss Comparison:
---------------------------------------------------------------------------
Losses are equal to within tolerance of 1e-07
Base Loss     : 1.3912039674724226
JIT Loss      : 1.391203967472429
Difference    : 6.439293542825908e-15


In [None]:
# Compare Binary Cross Entropy (with logits) Loss: reference vs. JIT-compiled version.

# Dummy binary-classification data; a fixed seed makes the printed losses reproducible.
rng = np.random.default_rng(42)
n_samples_bce = 10
logits_bce = rng.standard_normal(n_samples_bce)

# Generate binary targets (0 or 1); upper bound is exclusive
targets_bce = rng.integers(0, 2, size=n_samples_bce)

# Instantiate loss function objects
base_bce_loss = BCEWithLogitsLoss()
jit_bce_loss = JITBCEWithLogitsLoss()

# The base loss is invoked via __call__, the JIT loss via calculate_loss.
loss_base_bce = base_bce_loss(logits_bce, targets_bce)
loss_jit_bce = jit_bce_loss.calculate_loss(logits_bce, targets_bce)

tolerance = 1e-7
print("\nBCE With Logits Loss Comparison:")
print("-" * 50)
if np.allclose(loss_base_bce, loss_jit_bce, atol=tolerance):
    print("Losses are equal to within tolerance of", tolerance)
else:
    # Report a mismatch explicitly instead of silently skipping the message.
    print("Losses are NOT equal to within tolerance of", tolerance)

print("Base Loss     :", loss_base_bce)
print("JIT Loss      :", loss_jit_bce)
print("Difference    :", abs(loss_base_bce - loss_jit_bce))



BCE With Logits Loss Comparison:
--------------------------------------------------
Losses are equal to within tolerance of 1e-07
Base Loss     : 0.5771172002684145
JIT Loss      : 0.5771172002684144
Difference    : 1.1102230246251565e-16
