In [1]:
import os

import jax
import numpy as np
import jax.numpy as jnp
import haiku as hk
import matplotlib.pyplot as plt

from _src.dataset.cifar10 import load_cifar10
from _src.logger.WANDB import WANDBLogger

# `jax.disable_jit` is a context manager: calling it without `with` only returns
# an unused manager and leaves JIT enabled. Flip the global config flag instead so
# every later cell runs un-jitted (handy for debugging / printing concrete values).
jax.config.update("jax_disable_jit", True)

  from .autonotebook import tqdm as notebook_tqdm


<contextlib._GeneratorContextManager at 0x105f046d0>

In [2]:
# Build the CIFAR-10 train / validation / test loaders, 32 samples per batch.
(train_loader,
 val_loader,
 test_loader) = load_cifar10(batch_size=32)

Train size: 40000, Eval size: 10000, Test size: 10000
Batch images shape: torch.Size([32, 32, 32, 3]), Batch labels shape: torch.Size([32])


In [3]:
# Peek at a single training batch to confirm the (batch, H, W, C) image layout.
images, labels = next(iter(train_loader))
(images.shape, labels.shape)

(torch.Size([32, 32, 32, 3]), torch.Size([32]))

In [None]:
def show_samples(loader, class_names=None, n=8):
    """Plot up to ``n`` images from the first batch of ``loader`` in a single row.

    Args:
        loader: iterable yielding ``(images, labels)`` batches (e.g. a PyTorch
            DataLoader). Images are passed straight to ``imshow``, so they are
            expected channels-last (H, W, C), as the CIFAR-10 loader yields.
        class_names: optional sequence mapping an integer label to a name used
            as the subplot title.
        n: maximum number of samples to show; clipped to the batch size so a
            small final batch does not raise an IndexError.
    """
    images, labels = next(iter(loader))
    n = min(n, len(images))
    if n <= 0:
        return

    # squeeze=False keeps `axes` 2-D even when n == 1, so indexing stays uniform.
    fig, axes = plt.subplots(1, n, figsize=(12, 2), squeeze=False)
    for ax, img, label in zip(axes[0], images[:n], labels[:n]):
        ax.imshow(img)
        if class_names is not None:
            # int() works for torch / numpy / jax scalar labels alike.
            ax.set_title(class_names[int(label)])
        ax.axis('off')
    plt.show()

# CIFAR-10 class names, indexed by integer label:
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# Show a few samples from each split (train, validation, test).
for split_loader in (train_loader, val_loader, test_loader):
    show_samples(split_loader, class_names, n=8)


In [None]:
def patch_positions(input_shape, patch_shape):
    """Normalized (y, x) centre coordinates of a non-overlapping patch grid.

    Args:
        input_shape: ``(H, W, C)`` of the image being tiled.
        patch_shape: ``(ph, pw, C)`` of one patch; the channel entry is ignored.

    Returns:
        float32 array of shape ``(num_patches, 2)``. Rows are patch centres in
        row-major order (y varies slowest), scaled so that the image spans
        ``[-1, 1]`` along each axis.
    """
    height, width, _ = input_shape
    patch_h, patch_w, _ = patch_shape

    def axis_centres(count, step, extent):
        # Offset of each patch centre from the image centre, divided by the half-extent.
        half = extent / 2
        offsets = (jnp.arange(count) * step + step / 2) - half
        return (offsets / half).astype(jnp.float32)

    ys = axis_centres(height // patch_h, patch_h, height)
    xs = axis_centres(width // patch_w, patch_w, width)

    grid_y, grid_x = jnp.meshgrid(ys, xs, indexing="ij")
    return jnp.stack([grid_y, grid_x], axis=-1).reshape(-1, 2)

In [None]:
"""
 The Inference of a cortical column
(torch.Size([32, 32, 32, 3]), torch.Size([32]))
Args:
    sensory_input:
    location_input:
    action_efference:
Layers:
    Layer 4: The what signal, Recieves sensory input 
        takes -> Sensory Input
    
    Layer 5: Recieves an action efference copy and passes it to layer 6
    Layer 6: The where signal, Grid cell like encoding position of the sensor
        Inputs -> Location Input, action efference (from layer 5)
        This layer performs path integration using the input

    Layer 1: Biological Wiring (Can be ignored) [Assumption made, I could be catastophically wrong ;)]
    Layer 2/3: Place Neurons: This is the binding layer combining what and where infor
        Inputs -> Sensory Input from L4, and the location context from L6
    
returns:
    sensory_location_latent:
    location_update:
Concerns:
    - How many neurons in each column
    - Should any part of this be learnable or should it just be an algorithm ? (it's OK if learnable)
""" 
class CorticalColumn(hk.Module):
    NetName = 'CorticalColumn'
    # The core idea is learning the structure of things through sensory motor coupling
    # Every column has its own sensory-motor loop/system - Thousand Brains Theory
    # Reference frame, path integration
    # Columns gain predictive power after a large enough refrence frame is built
    # Brain Region -> Sensory Input, Motor Output
    def __init__(
            self,
            n_hidden,
            hidden_size,
            output_size,
            enable_inhibition=False,
            name='CorticalColumn'
    ):
        super().__init__(name=name)
        self.n_hidden = n_hidden
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.enable_inhibition = enable_inhibition

    def __call__(self, sensory_input, location_input):
        sensory_emb = hk.Linear(self.hidden_size)(sensory_input.reshape(-1))
        location_emb = hk.Linear(self.hidden_size)(location_input)
        
        # Combine with nonlinear mixing instead of trivial addition
        sensory_location_emb = jax.nn.relu(
            hk.Linear(self.hidden_size)(jnp.concatenate([sensory_emb, location_emb], axis=-1))
        )

        
        # Residual MLP for binding "what" and "where"
        x = sensory_location_emb
        for _ in range(self.n_hidden):
            residual = x
            x = hk.Linear(self.hidden_size)(x)
            x = jax.nn.relu(x)
            x = residual + x
        
        # Final projections
        sensory_location_latent = hk.Linear(self.output_size)(x)
        return sensory_location_latent



In [None]:
def _extract_patches(image, side_size):
    """Split an (H, W, C) image into (num_patches, side, side, C) spatial tiles.

    Tiles are ordered row-major (left-to-right, then top-to-bottom), which is
    the same order in which ``patch_positions`` lists patch centres.
    """
    H, W, C = image.shape
    nh, nw = H // side_size, W // side_size
    # (H, W, C) -> (nh, side, nw, side, C): split each spatial axis into (tile index, offset).
    tiles = jnp.reshape(image, (nh, side_size, nw, side_size, C))
    # Bring the two tile indices together: (nh, nw, side, side, C).
    tiles = jnp.transpose(tiles, (0, 2, 1, 3, 4))
    return jnp.reshape(tiles, (nh * nw, side_size, side_size, C))


class VisualCortex(hk.Module):
    """A sheet of weight-shared cortical columns tiling a square image.

    The image is cut into non-overlapping ``side x side`` tiles. Each tile is
    paired with the normalized (y, x) centre of its location and processed by
    one shared ``CorticalColumn`` (the same parameters for every tile).

    Inspired by the Thousand Brains Theory: every column has its own
    sensory-motor loop, builds a reference frame via path integration, and
    gains predictive power once that frame is large enough.
    (Brain region -> sensory input, motor output.)
    """
    # Spelling kept as-is: the module name prefixes the Haiku parameter keys.
    NetName = 'VisiualCortex'

    def __init__(
            self,
            # cortical
            cortical_side_size,
            cortical_n_hidden,
            cortical_hidden_size,
            cortical_output_size,
            name='VisiualCortex'
            ):
        super().__init__(name=name)
        # Cortical config
        self.cortical_side_size = cortical_side_size
        self.cortical_n_hidden = cortical_n_hidden
        self.cortical_hidden_size = cortical_hidden_size
        self.cortical_output_size = cortical_output_size

    def cortical_observation(
        self,
        cortical_input,
        side_size,
        n_hidden,
        hidden_size,
        output_size
    ):
        """Run one weight-shared column over every ``side_size`` tile of the input.

        Args:
            cortical_input: image array of shape (H, W, C).
            side_size: tile edge length; must be positive, smaller than H and
                divide both H and W exactly.
            n_hidden, hidden_size, output_size: forwarded to ``CorticalColumn``.

        Returns:
            array with one row per tile (row-major tile order), each row the
            column's latent of size ``output_size``.
        """
        H, W, C = cortical_input.shape
        # Previously `side_size // H == 0` was used, which only re-checks side_size < H.
        assert 0 < side_size < H and H % side_size == 0 and W % side_size == 0, \
            'Cortical Side must perfectly fit'

        # True spatial tiles; a flat reshape to (-1, s, s, C) would slice memory rows instead.
        sensory_input = _extract_patches(cortical_input, side_size)
        location_input = patch_positions(
            input_shape=cortical_input.shape,
            patch_shape=(side_size, side_size, C)
        )

        column = CorticalColumn(
            n_hidden=n_hidden,
            hidden_size=hidden_size,
            output_size=output_size,
            enable_inhibition=True,
            name=f'CorticalColumn_{side_size}'
        )
        # hk.vmap (not jax.vmap) is Haiku's supported way to map a module: parameters
        # are created once and shared by every tile. split_rng=False hands every tile
        # the same key; the column draws no random numbers in its forward pass.
        cortical_sheet_net = hk.vmap(column, in_axes=0, split_rng=False)
        return cortical_sheet_net(sensory_input, location_input)

    def __call__(self, visual_input):
        """Extract per-tile features from a square image with cortical columns.

        Args:
            visual_input: tensor of shape (height, width, channels), e.g. (32, 32, 3).

        Returns:
            visual_features: one ``cortical_output_size`` latent per tile.
        """
        H, W, _ = visual_input.shape
        assert H == W, "Must be a square observation"
        return self.cortical_observation(
            visual_input,
            self.cortical_side_size,
            self.cortical_n_hidden,
            self.cortical_hidden_size,
            self.cortical_output_size
        )

In [None]:
class VisualCortexConstructor():
    """Builds, transforms and initializes a ``VisualCortex`` network.

    Args:
        seed: integer seed for the PRNG key used at initialization.
        input_width_size, input_height_size, input_channel_size: input image
            dimensions (W, H, C).
        cortical_side_size, cortical_n_hidden, cortical_hidden_size,
        cortical_output_size: forwarded to ``VisualCortex``.
        Logger: logger object, stored on ``self.Logger`` (not used in this class).
        *args, **kwargs: accepted and ignored, so a whole config dict with
            extra keys can be splatted into the constructor.
    """
    def __init__(
        self,
        seed,
        # Input dims
        input_width_size,
        input_height_size,
        input_channel_size,
        # cortical
        cortical_side_size,
        cortical_n_hidden,
        cortical_hidden_size,
        cortical_output_size,
        # meta
        Logger,
        *args,
        **kwargs,
    ):
        self.seed = seed
        # Input dimensions
        self.W = input_width_size
        self.H = input_height_size
        self.C = input_channel_size

        # Cortical config
        self.cortical_side_size = cortical_side_size
        self.cortical_n_hidden = cortical_n_hidden
        self.cortical_hidden_size = cortical_hidden_size
        self.cortical_output_size = cortical_output_size

        # meta
        self.NetName = VisualCortex.NetName
        self.Logger = Logger

    def construct(self):
        """Transform the network with Haiku and initialize its parameters.

        Returns:
            ``(model_init, model_apply, model_params)``: the pure init/apply pair
            from ``hk.transform`` and parameters initialized from ``self.seed``.
        """
        def net_module(x):
            return VisualCortex(
                cortical_side_size=self.cortical_side_size,
                cortical_n_hidden=self.cortical_n_hidden,
                cortical_hidden_size=self.cortical_hidden_size,
                cortical_output_size=self.cortical_output_size,
            )(x)

        key = jax.random.PRNGKey(seed=self.seed)
        # The example input only fixes parameter shapes. VisualCortex expects
        # (height, width, channels).
        example_input = jax.random.normal(key, (self.H, self.W, self.C))
        # Recent Haiku rejects the removed `apply_rng` argument; plain hk.transform
        # already returns an `apply` that takes an rng argument.
        model_init, model_apply = hk.transform(net_module)
        model_params = model_init(key, example_input)
        return model_init, model_apply, model_params

In [None]:
configs = {
    "seed": 0,
    "logger": 'wandb',
    # Reference Frame configuration
    # Input configuration
    "input_width_size": 32,
    "input_height_size": 32,
    "input_channel_size": 3,
    # Cortical Configuration
    "cortical_side_size": 4,
    "cortical_n_hidden": 3,
    "cortical_hidden_size": 8,
    "cortical_output_size": 16,
}
logger_config = {
    "name": 'wandb',
    # Never hardcode credentials in a notebook: read the key from the environment.
    # The key previously committed here must be treated as leaked and revoked.
    # NOTE(review): confirm WANDBLogger accepts api_key=None when running offline.
    "api_key": os.environ.get("WANDB_API_KEY"),
    "mode": 'offline',
}
logger = WANDBLogger(**logger_config)
column_constructor = VisualCortexConstructor(
    **configs,
    Logger=logger,
)
_, model_apply, model_param = column_constructor.construct()

In [None]:
# Parameter shapes per module, mirroring the nested parameter dict.
jax.tree_util.tree_map(jnp.shape, model_param)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import jax.numpy as jnp


def mse(a, b):
    """Mean squared error between two broadcast-compatible arrays."""
    residual = a - b
    return jnp.mean(residual ** 2)


def mirror_symmetry_loss(img):
    """Mean squared difference between ``img`` and its left-right mirror image.

    The mirror reverses axis 1 (the column axis of a height-first image).
    Lower = more symmetric; 0 for a perfectly mirror-symmetric image.
    """
    mirrored = jnp.flip(img, axis=1)
    return mse(img, mirrored)


def rotate_180(img):
    """Return ``img`` rotated by a half turn, i.e. reversed along axes 0 and 1."""
    return jnp.flip(img, axis=(0, 1))


def rotational_symmetry_loss(img):
    """Mean squared difference between ``img`` and its 180-degree rotation.

    Lower = more symmetric; 0 for an image that is unchanged by a half turn.
    """
    return mse(img, rotate_180(img))


def composite_symmetry_loss(img, w_mirror=0.5, w_rot=0.5):
    """Blend mirror and 180-degree rotational asymmetry into a single loss.

    Args:
        img: image array.
        w_mirror: weight of the mirror (left-right) term.
        w_rot: weight of the rotational term.

    Returns:
        ``w_mirror * mirror_loss + w_rot * rotational_loss``; lower = more symmetric.
    """
    mirror_term = w_mirror * mirror_symmetry_loss(img)
    rotation_term = w_rot * rotational_symmetry_loss(img)
    return mirror_term + rotation_term

# Normalization example: map a symmetry loss to a bounded score.


def symmetry_score(img, sharpness=5, w_mirror=0.5, w_rot=0.5):
    """Convert the composite symmetry loss into a score; higher = more symmetric.

    score = exp(-sharpness * loss). For non-negative weights and sharpness the
    score lies in (0, 1], with 1 for a perfectly symmetric image.

    Args:
        img: image array.
        sharpness: how quickly the score decays as the loss grows (previously
            hard-coded to 5, which remains the default).
        w_mirror: weight of the mirror term in the composite loss.
        w_rot: weight of the rotational term in the composite loss.

    Returns:
        Scalar score array.
    """
    loss = composite_symmetry_loss(img, w_mirror=w_mirror, w_rot=w_rot)
    return jnp.exp(-sharpness * loss)


def load_image(path, size=(128, 128)):
    """Load an image file as a grayscale array scaled to [0, 1].

    Args:
        path: file path passed to ``PIL.Image.open``.
        size: ``(width, height)`` passed to ``Image.resize``; defaults to 128x128.

    Returns:
        jax array of shape ``(height, width)`` with 8-bit gray levels divided by 255.
    """
    # The context manager closes the file handle deterministically; convert()/resize()
    # return new in-memory images, so they stay valid after the file is closed.
    with Image.open(path) as img:
        grayscale = img.convert('L').resize(size)
    arr = np.asarray(grayscale) / 255.0
    return jnp.array(arr)


sym_img = load_image('sym.png')
asym_img = load_image('asym.png')

# Convert once to Python floats so printing and f-string formatting behave the same
# regardless of jax.Array formatting support.
sym_score = float(symmetry_score(sym_img))
asym_score = float(symmetry_score(asym_img))

print("Symmetric image score:", sym_score)
print("Asymmetric image score:", asym_score)

# Visualization: fixed [0, 1] gray range (load_image scales to [0, 1]) so imshow does
# not auto-stretch each image's contrast independently, which would mislead the comparison.
fig, axs = plt.subplots(1, 2, figsize=(8, 4))
panels = (
    (axs[0], sym_img, f'Symmetric (score={sym_score:.3f})'),
    (axs[1], asym_img, f'Asymmetric (score={asym_score:.3f})'),
)
for ax, img, title in panels:
    ax.imshow(img, cmap='gray', vmin=0.0, vmax=1.0)
    ax.set_title(title)
    ax.axis('off')
plt.show()
