In [2]:
import torch as th
import torch.nn as nn
import timeit
import statistics
from branchNetwork.BranchLayer import BranchLayer as BL # Adjust the import according to your project structure
from branchNetwork.BranchLayerMM import BranchLayer as BLMM # Adjust the import accordingly as BranchLayerMM

# Optional: If you use GPU and want to measure GPU memory
# th.cuda.reset_peak_memory_stats()
# th.cuda.empty_cache()

In [15]:
def test_branch_layer(layer_class, branch_params, num_trials=1000):
    durations = []
    memory_usages = []
    grad_checks = []

    for _ in range(num_trials):
        branch_layer = layer_class(**branch_params)
        branch_layer.train()  # Ensure the layer is in training mode for gradient checks

        x = th.randn(5, branch_params['n_in'], requires_grad=True)
        start_time = timeit.default_timer()

        # Forward pass
        out = branch_layer(x)
        assert out.shape == (5, branch_params['n_b'], branch_params['n_next_h']), "Output shape mismatch"

        # Backward pass for gradient check
        out.sum().backward()

        # Record the end time and compute the duration
        durations.append(timeit.default_timer() - start_time)

        # Memory usage check (if on GPU)
        if th.cuda.is_available():
            memory_usages.append(th.cuda.memory_allocated() / (1024 ** 2))  # Memory in MB

        # Gradient check
        grad_checks.append(x.grad is not None)

        # Reset gradients
        branch_layer.zero_grad()
        if x.grad is not None:
            x.grad.zero_()

    # Calculate average time and memory usage
    avg_duration = statistics.mean(durations)
    std_duration = statistics.stdev(durations)
    avg_memory_usage = statistics.mean(memory_usages) if memory_usages else 0
    std_memory_usage = statistics.stdev(memory_usages) if memory_usages else 0

    all_gradients_ok = all(grad_checks)

    print(f"Layer: {layer_class.__name__}")
    print(f"Average Duration: {avg_duration:.4f} sec, Std Dev: {std_duration:.4f} sec")
    print(f"Average Memory Usage: {avg_memory_usage:.2f} MB, Std Dev: {std_memory_usage:.2f} MB")
    print(f"All gradients computed: {all_gradients_ok}\n")

    return avg_duration, avg_memory_usage, all_gradients_ok


In [23]:
def gpu_test_branch_layer(layer_class_1, layer_class_2, branch_params, num_trials=10, device='cpu'):
    durations = []
    memory_usages = []
    outputs_equivalent = []

    x = th.randn(5, branch_params['n_in'], requires_grad=True).to(device)
    expected_output = None

    for _ in range(num_trials):
        # Instantiate layers and move to specified device
        layer1 = layer_class_1(**branch_params).to(device)
        layer2 = layer_class_2(**branch_params).to(device)

        layer1.train()
        layer2.train()

        start_time = timeit.default_timer()

        # Forward pass for both layers
        out1 = layer1(x)
        out2 = layer2(x)
        print(out1)
        print(out2)

        # Check for equivalence of outputs
        outputs_equivalent.append(th.allclose(out1, out2, atol=1e-6))

        # Backward pass for gradient checks (on one to save time)
        out1.sum().backward()

        # Record time and reset gradients
        durations.append(timeit.default_timer() - start_time)
        layer1.zero_grad()
        x.grad.zero_()

        # Memory usage check (if on GPU)
        if device == 'cuda':
            memory_usages.append(th.cuda.memory_allocated() / (1024 ** 2))  # Memory in MB

    # Calculate statistics
    avg_duration = statistics.mean(durations)
    std_duration = statistics.stdev(durations)
    avg_memory_usage = statistics.mean(memory_usages) if memory_usages else 0
    std_memory_usage = statistics.stdev(memory_usages) if memory_usages else 0
    output_consistency = all(outputs_equivalent)

    print(f"Layer comparison: {layer_class_1.__name__} vs {layer_class_2.__name__}")
    print(f"Average Duration: {avg_duration:.4f} sec, Std Dev: {std_duration:.4f} sec")
    print(f"Average Memory Usage: {avg_memory_usage:.2f} MB, Std Dev: {std_memory_usage:.2f} MB")
    print(f"Output consistency across all trials: {output_consistency}\n")

    return avg_duration, avg_memory_usage, output_consistency


In [24]:
# Set parameters common to both layer types
branch_params = {
    'n_in': 800,
    'n_npb': 40,
    'n_b': 20,
    'n_next_h': 400
}

# Run test for BranchLayer
test_branch_layer(BL, branch_params)

# Run test for BranchLayerMM
test_branch_layer(BLMM, branch_params)


Layer: BranchLayer
Average Duration: 0.0067 sec, Std Dev: 0.0005 sec
Average Memory Usage: 0.00 MB, Std Dev: 0.00 MB
All gradients computed: True

Layer: BranchLayer
Average Duration: 0.0028 sec, Std Dev: 0.0023 sec
Average Memory Usage: 0.00 MB, Std Dev: 0.00 MB
All gradients computed: True



(0.002842916499823332, 0, True)

In [26]:
# Device configuration
device = 'cuda' if th.cuda.is_available() else 'cpu'

# Parameters common to both layer types
branch_params = {
    'n_in': 8,
    'n_npb': 4,
    'n_b': 2,
    'n_next_h': 4
}

# Run test for both BranchLayer versions
gpu_test_branch_layer(BL, BLMM, branch_params, device=device)


tensor([[[ 2.2933e-01, -2.4450e+00, -1.2587e-01,  8.2587e-01],
         [ 3.4846e-01,  1.8635e+00, -3.8021e-01, -2.0549e-01]],

        [[-5.0282e-01,  3.8076e-01,  6.8913e-01, -1.3163e-01],
         [ 1.4247e+00,  6.6359e-01,  2.7320e-01, -2.4472e-01]],

        [[ 2.0457e-01, -6.7910e-01,  2.5330e-01,  1.0349e-01],
         [ 9.5158e-01,  1.1612e+00, -1.6054e-01, -1.8892e-01]],

        [[ 4.6132e-01,  7.5304e-01,  5.1296e-01,  2.5104e-01],
         [ 1.6296e-03, -6.3104e-01,  2.3748e-02,  9.0217e-02]],

        [[ 4.5345e-01, -8.4221e-01, -1.2751e+00,  2.6968e-01],
         [-1.2051e+00, -1.1460e+00,  5.1721e-01,  4.0247e-01]]],
       grad_fn=<ViewBackward0>)
tensor([[[-1.9712e+00, -4.6442e-01,  3.5482e+00, -1.6731e+00],
         [-5.8860e-01, -6.4764e-02,  9.5958e-01,  1.5001e+00]],

        [[ 3.1809e-03, -6.5825e-01,  9.3834e-01, -8.8280e-01],
         [-6.1468e-01,  7.0424e-02, -1.8015e-01, -2.9362e-01]],

        [[-1.3388e+00, -7.0561e-01,  1.3234e+00, -1.0264e+00],
         

(0.0009005951229482889, 0, False)