# Bayesian Base Module

## Libraries

In [None]:
import os
import sys
import importlib
from abc import abstractmethod
from typing import Tuple, Any

import numpy as np
import torch
import tensorflow as tf
import matplotlib.pyplot as plt

## Functions

In [None]:
def test_freeze_unfreeze():
    """Check that ``freeze()`` / ``unfreeze()`` toggle the ``frozen`` flag.

    Runs the same three-step check on the module-level ``torch_module`` and
    then on ``tf_module``:
    1. the module starts unfrozen;
    2. it reports frozen after ``freeze()``;
    3. it reports unfrozen again after ``unfreeze()``.

    Raises:
        AssertionError: if any of these state transitions is not observed.
    """
    print("Testing freeze and unfreeze...")

    for framework, module in (("PyTorch", torch_module), ("TensorFlow", tf_module)):
        assert not module.frozen, f"{framework} module should not be frozen initially"
        module.freeze()

        assert module.frozen, f"{framework} module should be frozen after freeze()"
        module.unfreeze()

        assert not module.frozen, f"{framework} module should not be frozen after unfreeze()"

    print("Freeze and unfreeze test passed!", "\n\n")

In [None]:
def test_kl_cost(*, rtol=1e-05, atol=1e-08):
    """Check that both backends report the same KL cost and count.

    Calls ``kl_cost()`` on the module-level ``torch_module`` and ``tf_module``.
    Each call returns a ``(kl, n)`` pair, where ``kl`` is a single-element
    tensor of the respective framework. The ``kl`` values are compared within
    a tolerance and the ``n`` values are compared exactly.

    Args:
        rtol: Relative tolerance for the KL comparison (``numpy.isclose``).
        atol: Absolute tolerance for the KL comparison (``numpy.isclose``).

    Raises:
        AssertionError: if the KL values are not close or the counts differ.
    """
    print("Testing KL cost...")

    torch_kl, torch_n = torch_module.kl_cost()
    tf_kl, tf_n = tf_module.kl_cost()

    # Extract host-side scalars once; they are reused for printing,
    # comparison and the failure message.
    torch_kl_value = torch_kl.item()
    tf_kl_value = tf_kl.numpy()

    print(f'\nPyTorch : {torch_kl_value}, {torch_n}')
    print(f'TensorFlow : {tf_kl_value}, {tf_n}\n')

    # Exact float equality is fragile across frameworks: float32 rounding and
    # operation order can differ. Compare within a tolerance instead.
    assert np.isclose(torch_kl_value, tf_kl_value, rtol=rtol, atol=atol), (
        f"KL divergence mismatch: PyTorch {torch_kl_value}, TensorFlow {tf_kl_value}"
    )
    assert torch_n == tf_n, f"N mismatch: PyTorch {torch_n}, TensorFlow {tf_n}"

    print("KL cost test passed!", '\n\n')

In [None]:
def _copy_torch_linear_to_tf_dense(torch_linear, tf_dense):
    """Overwrite a built Keras ``Dense`` layer's parameters with a ``torch.nn.Linear``'s."""
    # torch.nn.Linear stores its weight as (out_features, in_features), while a
    # Dense kernel is (input_dim, units). Hence the transpose.
    kernel = torch_linear.weight.detach().cpu().numpy().T
    bias = torch_linear.bias.detach().cpu().numpy()
    tf_dense.set_weights([kernel, bias])


def test_forward_pass(*, rtol=1e-4, atol=1e-5):
    """Check that the PyTorch and TensorFlow test modules compute the same output.

    Each framework initializes its ``linear`` layer independently, so the raw
    outputs cannot be compared. This test first copies the parameters of
    ``torch_module.linear`` into ``tf_module.linear``. It then feeds the same
    random ``(1, 10)`` float32 input through both modules and compares the
    outputs within a float32-appropriate tolerance.

    Note: this overwrites the weights of the module-level ``tf_module``.

    Args:
        rtol: Relative tolerance passed to ``np.testing.assert_allclose``.
        atol: Absolute tolerance passed to ``np.testing.assert_allclose``.

    Raises:
        AssertionError: if the outputs differ beyond the tolerances.
    """
    print("Testing forward pass...")

    # Input data
    input_data = np.random.randn(1, 10).astype(np.float32)
    torch_input = torch.from_numpy(input_data)
    tf_input = tf.convert_to_tensor(input_data)

    # Keras creates Dense weights lazily on the first call. Make sure they
    # exist before overwriting them with the PyTorch parameters.
    if not tf_module.linear.built:
        tf_module(tf_input)
    _copy_torch_linear_to_tf_dense(torch_module.linear, tf_module.linear)

    # Forward passes with identical parameters
    torch_np = torch_module(torch_input).detach().numpy()
    tf_np = tf_module(tf_input).numpy()

    max_diff = np.max(np.abs(torch_np - tf_np))
    print(f"Maximum absolute difference: {max_diff}")

    np.testing.assert_allclose(torch_np, tf_np, rtol=rtol, atol=atol)

    print("Forward pass test passed!", '\n\n')

In [None]:
def run_all_tests():
    """Run the parity checks in order: freeze/unfreeze, KL cost, then forward pass.

    The first failing check raises ``AssertionError`` and stops the run.
    """
    for check in (test_freeze_unfreeze, test_kl_cost, test_forward_pass):
        check()

## Random seeds

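Set the random seeds of NumPy, PyTorch and TensorFlow so that results are reproducible across runs. This makes each framework repeatable on its own; it does not make PyTorch and TensorFlow draw identical initial weights.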

In [None]:
# Seed the global TensorFlow, PyTorch and NumPy random generators with the
# same value (0). The calls are independent, so their order does not matter.
tf.random.set_seed(0)
torch.manual_seed(0)
np.random.seed(0)

## Illia

Once the backend has been selected, we can import illia. The backend cannot be changed dynamically: to switch to a different backend, restart the kernel.

In [None]:
import illia
from illia.torch.nn.base import BayesianModule as TorchBayesianModule
from illia.tf.nn.base import BayesianModule as TFBayesianModule

We can check the available backends using the following function:

In [None]:
# Ask illia which backends are available in this environment.
illia.show_available_backends()

Class definitions: minimal PyTorch and TensorFlow modules whose forward methods are exercised by the tests.

In [None]:
class TorchTestModule(TorchBayesianModule):
    """Minimal PyTorch Bayesian module used by the parity tests.

    Holds a single ``torch.nn.Linear(10, 5)`` layer and reports a fixed KL cost.
    """

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=10, out_features=5)

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        output = self.linear(x)
        return output

    def kl_cost(self) -> Tuple[torch.Tensor, int]:
        """Return the constant pair ``(tensor(1.0), 1)``.

        The values are fixed so the tests know what to expect.
        """
        kl = torch.tensor(1.0)
        n = 1
        return kl, n


torch_module = TorchTestModule()

In [None]:
class TFTestModule(TFBayesianModule):
    """Minimal TensorFlow Bayesian module used by the parity tests.

    Holds a single ``tf.keras.layers.Dense(5)`` layer and reports a fixed KL cost.
    """

    def __init__(self):
        super().__init__()
        self.linear = tf.keras.layers.Dense(units=5)

    def call(self, x):
        """Apply the dense layer to ``x``."""
        output = self.linear(x)
        return output

    def kl_cost(self) -> Tuple[tf.Tensor, int]:
        """Return the constant pair ``(constant(1.0), 1)``.

        The values are fixed so the tests know what to expect.
        """
        kl = tf.constant(1.0)
        n = 1
        return kl, n


tf_module = TFTestModule()

In [None]:
# Run the freeze/unfreeze, KL-cost and forward-pass checks defined above.
run_all_tests()