# Simulating the Cortex Model

This is the IPython notebook that accompanies the tutorial [Cortex Simulation](https://matisse.eecs.berkeley.edu/tutorials/3_CortexSimulation.html).

Before running this code, please make sure you have installed the required packages and set up the environment as described in [Getting Started](https://matisse.eecs.berkeley.edu/tutorials/1_GettingStarted.html).

Also, make sure you have selected the correct Python kernel (e.g., `MatisseEnv`).

---

## Setup

In [1]:
from helper import *

## Instantiate retina and cortex models

In [2]:
from Simulated.Retina.RetinaModel import RetinaModel
from Simulated.Cortex.CortexModel import CortexModel

# Default parameter set for the trichromatic (LMS) simulation; the same
# dictionary configures both the retina and the cortex models below.
config_path = f'{ROOT_DIR}/Experiment/Config/Default/LMS.yaml'
with open(config_path, 'r') as config_file:
    params = yaml.safe_load(config_file)

# Retina model, placed on DEVICE
retina = RetinaModel(params, device=DEVICE)

# Cortex model, placed on DEVICE and wrapped with torch.compile
cortex = torch.compile(CortexModel(params, device=DEVICE))

In [3]:
# Point example_image_path at your own image to simulate it instead of the sample.
example_image_path = f'{ROOT_DIR}/Tutorials/data/sample_sRGB_image.png'
example_sRGB_image = load_sRGB_image(retina, example_image_path, params).to(DEVICE)

# retina.CST (color space transform) chain: sRGB -> linear sRGB -> LMS.
example_linsRGB_image = retina.CST.sRGB_to_linsRGB(example_sRGB_image)
example_LMS_image = (
    retina.CST.linsRGB_to_LMS(example_linsRGB_image)
    .unsqueeze(0)            # prepend a new leading (batch) axis
    .permute(0, 3, 1, 2)     # move the last axis to position 1 (presumably channels-first)
)

# The retina simulation is inference only, so autograd tracking is disabled.
with torch.no_grad():
    list_of_retinal_responses = retina.forward(example_LMS_image, intermediate_outputs=True)

# First entry of the response list is the optic nerve signal; keep only index 0 along dim 1.
optic_nerve_signals = list_of_retinal_responses[0][:, :1]

In [None]:
learned_percepts = []

# Checkpoints are looked up every 100 gradient updates (0, 100, ..., 100000);
# steps without a saved checkpoint file are skipped.
checkpoint_dir = f'{ROOT_DIR}/Experiment/LearnedWeights/LMS'

for num_gradient_updates in tqdm(range(0, 100001, 100), desc='Generating learned percepts... '):
    checkpoint_path = f'{checkpoint_dir}/{num_gradient_updates}.pt'
    if not os.path.exists(checkpoint_path):
        continue

    # Load the pre-trained weights for the default cortex model at this training step
    cortex.load_state_dict(torch.load(checkpoint_path, weights_only=True, map_location=DEVICE))

    with torch.no_grad():
        warped_internal_percept = cortex.decode(optic_nerve_signals)

        # internal percept is N-channel image, where N is the latent dimension (N is formally defined in the paper)
        # We use the ns_ip module (neural scope for internal percept) to project the percept to the linsRGB space
        warped_internal_percept_linsRGB = cortex.ns_ip.forward(warped_internal_percept)

        # Then we use the retina.CST (color space transform) to convert the linsRGB space to the sRGB space
        warped_internal_percept_sRGB = retina.CST.linsRGB_to_sRGB(warped_internal_percept_linsRGB)

        # get_unwarped_percept is not defined in this notebook; presumably it is
        # provided by the `helper` module star-imported in the setup cell.
        internal_percept_sRGB = get_unwarped_percept(warped_internal_percept_sRGB, cortex)

        # Record (training step, unwarped sRGB percept) for the visualization cells
        learned_percepts.append([num_gradient_updates, internal_percept_sRGB])

In [None]:
# Repeats the unwarping computed in the last iteration of the checkpoint loop,
# using whatever warped percept and cortex weights that iteration left behind.
internal_percept_sRGB = get_unwarped_percept(
    warped_internal_percept_sRGB,
    cortex,
)

In [9]:
# Render each learned percept to its own PNG frame (one per checkpoint);
# the frames are later stitched into an animation.
frame_dir = f'{ROOT_DIR}/Tutorials/CortexSimulation/Results/LP'
os.makedirs(frame_dir, exist_ok=True)

for num_gradient_updates, learned_percept in learned_percepts:
    fig = plt.figure(figsize=(10, 10))
    plt.imshow(learned_percept)
    plt.axis('off')
    plt.title(f'Learned Percept at {num_gradient_updates:06d} gradient updates')
    plt.tight_layout()
    plt.savefig(f'{frame_dir}/learned_percept_{num_gradient_updates:06d}.png', bbox_inches='tight', pad_inches=0)
    # Close exactly the figure created above so open figures do not accumulate across frames
    plt.close(fig)

In [None]:
# Stitch the saved PNG frames into an animated GIF, then transcode the GIF to MP4.
results_dir = f'{ROOT_DIR}/Tutorials/CortexSimulation/Results'
images = [
    imageio.imread(f'{results_dir}/LP/learned_percept_{timestep:06d}.png')
    for timestep, _ in learned_percepts
]
gif_path = f'{results_dir}/learned_percept.gif'
imageio.mimsave(gif_path, images, fps=10, loop=0)

# libx264 with yuv420p requires even frame dimensions, but savefig(bbox_inches='tight',
# pad_inches=0) can produce odd widths/heights, so round both down to an even size.
# stderr is piped rather than discarded, so a failure's CalledProcessError exposes
# ffmpeg's diagnostic via its .stderr attribute.
subprocess.run(
    [
        'ffmpeg', '-y', '-i', gif_path,
        '-vf', 'fps=10,scale=trunc(iw/2)*2:trunc(ih/2)*2',
        '-c:v', 'libx264', '-pix_fmt', 'yuv420p',
        f'{results_dir}/LP/learned_percept.mp4',
    ],
    stdout=subprocess.DEVNULL,
    stderr=subprocess.PIPE,
    check=True,
)

In [None]:
# Show the animated GIF of learned percepts inline (last expression of the cell is rendered)
learned_percept_gif = f'{ROOT_DIR}/Tutorials/CortexSimulation/Results/learned_percept.gif'
IPython.display.Image(filename=learned_percept_gif)
