# Intro to 3D CNNs
A walkthrough notebook that builds a 3D CNN (a convolutional core plus a readout) and runs it on dummy data.

## 1. Setting up dummy data

Before using the cores, let's create a batch of dummy data. It has the same shape as a real batch of video clips: `(32, 1, 50, 144, 256)`.

Throughout the notebook we will refer to the elements of this shape as follows:

- [32] is the batch size, which is not very relevant for understanding the material in this notebook.
- [1] is the number of channels (input, hidden, or output).
- [50] is the number of frames, representing time.
- [144] is the height of the images or feature maps.
- [256] is the width of the images or feature maps.
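PyTorch's `Conv3d` expects 5-D input laid out as `(N, C, D, H, W)`: batch, channels, depth, height, width. Here the depth axis holds time. A small plain-Python sketch that labels each dimension of the dummy batch; the names in `DIM_NAMES` are our own labels, not part of `neuropixel_predictor`:

```python
# torch.nn.Conv3d expects input of shape (N, C, D, H, W); here D is time (frames).
DIM_NAMES = ("batch", "channels", "frames", "height", "width")


def describe_shape(shape):
    """Pair each dimension of a 5-D (N, C, D, H, W) shape with a readable name."""
    if len(shape) != len(DIM_NAMES):
        raise ValueError(f"expected 5 dimensions, got {len(shape)}")
    return dict(zip(DIM_NAMES, shape))


for name, size in describe_shape((32, 1, 50, 144, 256)).items():
    print(f"{name:>8}: {size}")
```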

In [2]:
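# Make the neuropixel_predictor package (one directory up) importable
import sys
sys.path.append('../')

# Standard-library imports
import warnings
import random

# Third-party imports
import numpy as np
import torch

# Dummy batch: (batch, channels, frames, height, width).
# 50 frames leave enough time steps for the unpadded 7- and 5-frame kernels below;
# with only 10 frames the second convolution has no room left along time.
# A full-size batch is memory-hungry once it passes through 64-channel convolutions,
# so lower the batch size if the forward pass in section 4 runs out of memory.
images = torch.ones(32, 1, 50, 144, 256)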

In [3]:
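warnings.filterwarnings("ignore", category=UserWarning)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every random number generator in use so results are reproducible
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)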

## 2. Using the Stacked 3D Core

In [4]:
stacked3dcore_config = {
    # core args
    'input_channels': 1,
    'input_kernel': 7,
    'hidden_kernel': 5,
    'hidden_channels': 64,
    'layers': 2
}
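The config maps onto the layers printed in the next cell: `layers: 2` gives one input layer with a 7x7x7 kernel and one hidden layer with a 5x5x5 kernel, each with 64 channels. Because those `Conv3d` modules use stride 1 and no padding, each layer shrinks frames, height, and width by `kernel - 1`. A plain-Python sketch of that bookkeeping, assuming stride 1 and no padding throughout; the helper names are hypothetical, and the size formula is the one given in the PyTorch `Conv3d` documentation:

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length along one dimension (formula from the torch.nn.Conv3d docs)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def core_output_shape(in_shape, config):
    """Trace (C, D, H, W) through the stacked core, assuming stride 1 and no padding."""
    _, d, h, w = in_shape
    # First layer uses input_kernel, every later layer uses hidden_kernel
    kernels = [config["input_kernel"]] + [config["hidden_kernel"]] * (config["layers"] - 1)
    for k in kernels:
        d, h, w = (conv_out(s, k) for s in (d, h, w))
        if min(d, h, w) < 1:
            raise ValueError(f"kernel {k} does not fit; dims shrank to {(d, h, w)}")
    return (config["hidden_channels"], d, h, w)


config = {'input_channels': 1, 'input_kernel': 7, 'hidden_kernel': 5,
          'hidden_channels': 64, 'layers': 2}
print(core_output_shape((1, 50, 144, 256), config))  # (64, 40, 134, 246)
```

With 10 frames instead of 50, the time axis drops to 4 after the 7-frame input kernel, which is shorter than the 5-frame hidden kernel, so the sketch raises `ValueError`.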

In [5]:
from neuropixel_predictor.layers.cores import Basic3dCore 

stacked3d_core = Basic3dCore(**stacked3dcore_config)
stacked3d_core
# stacked3dcore_out = stacked3d_core(images)
# print(stacked3dcore_out.shape)

Basic3dCore(
  (_input_weight_regularizer): LaplaceL2norm(
    (laplace): Laplace()
  )
  (temporal_regularizer): DepthLaplaceL21d(
    (laplace): Laplace1d()
  )
  (features): Sequential(
    (layer0): Sequential(
      (conv): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(1, 1, 1))
      (norm): BatchNorm3d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      (nonlin): ELU(alpha=1.0)
    )
    (layer1): Sequential(
      (conv_1): Conv3d(64, 64, kernel_size=(5, 5, 5), stride=(1, 1, 1))
      (norm): BatchNorm3d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      (nonlin): ELU(alpha=1.0)
    )
  )
) [Basic3dCore regularizers: gamma_input_spatial = 0|gamma_input_temporal = 0]
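To see what the printed layers compute, here is a plain `torch.nn` sketch that mirrors them: `Conv3d`, `BatchNorm3d` with `momentum=0.01`, and `ELU`, repeated once per layer. It is an approximation for illustration, not the library's implementation: `make_plain_core` is a hypothetical helper, and it leaves out the Laplace spatial and temporal regularizers shown in the repr (both weighted 0 here) and any other internals of `Basic3dCore`. A small input keeps memory low.

```python
import torch
from torch import nn


def make_plain_core(in_channels=1, hidden_channels=64, input_kernel=7,
                    hidden_kernel=5, layers=2):
    """nn.Sequential mirroring the printed Basic3dCore layers (no regularizers)."""
    blocks = []
    channels = in_channels
    for i in range(layers):
        kernel = input_kernel if i == 0 else hidden_kernel
        blocks.append(nn.Sequential(
            nn.Conv3d(channels, hidden_channels, kernel_size=kernel),  # stride 1, no padding
            nn.BatchNorm3d(hidden_channels, momentum=0.01),
            nn.ELU(),
        ))
        channels = hidden_channels
    return nn.Sequential(*blocks)


plain_core = make_plain_core()
x = torch.ones(2, 1, 20, 32, 32)  # small (batch, channels, frames, height, width)
with torch.no_grad():
    y = plain_core(x)
print(y.shape)  # torch.Size([2, 64, 10, 22, 22])
```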

## 3. Attaching a Gaussian Readout

In [6]:
from neuropixel_predictor.layers.readouts import Gaussian3d, MultiReadoutBase

In [7]:
in_shapes_dict = {
    '21067-10-18': torch.Size([64, 144, 256]),
    '22846-10-16': torch.Size([64, 144, 256])
}

n_neurons_dict = {'21067-10-18': 8372, '22846-10-16': 7344}
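The readout is configured per session: both dictionaries use the same session keys, and the `data_key` passed at call time selects which session's neurons to predict. A plain-Python sanity check that the keys line up, using only the two dictionaries above (shapes written as tuples so the snippet runs without torch); `check_sessions` is a hypothetical helper:

```python
in_shapes = {'21067-10-18': (64, 144, 256), '22846-10-16': (64, 144, 256)}
n_neurons = {'21067-10-18': 8372, '22846-10-16': 7344}


def check_sessions(in_shapes, n_neurons):
    """Raise if the two dicts disagree on session keys; return the total neuron count."""
    mismatched = set(in_shapes) ^ set(n_neurons)  # keys present in only one dict
    if mismatched:
        raise KeyError(f"sessions present in only one dict: {sorted(mismatched)}")
    return sum(n_neurons.values())


print(check_sessions(in_shapes, n_neurons))  # 15716
```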

In [9]:
gaussian_readout = MultiReadoutBase(
    in_shape_dict=in_shapes_dict,
    n_neurons_dict=n_neurons_dict,
    base_readout=Gaussian3d,
    bias=True,
)

## 4. Invoke core and readout

In [None]:
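# Pass the dummy batch through the core, then predict one session's neurons
core_output = stacked3d_core(images)
readout_output_sample = gaussian_readout(core_output, data_key="21067-10-18")
print(readout_output_sample.shape)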
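Putting the pieces together, one core feeds a readout whose session is chosen with `data_key`. A minimal sketch of that composition using hypothetical stand-ins: `CoreReadoutModel` and `MeanReadout` are not part of `neuropixel_predictor`, a single `Conv3d` replaces the real core, and `MeanReadout` simply averages over space instead of doing what `Gaussian3d` does.

```python
import torch
from torch import nn


class MeanReadout(nn.Module):
    """Hypothetical toy readout: one linear head per session on spatially averaged features."""

    def __init__(self, channels, n_neurons_dict):
        super().__init__()
        self.heads = nn.ModuleDict(
            {key: nn.Linear(channels, n) for key, n in n_neurons_dict.items()}
        )

    def forward(self, features, data_key):
        # (N, C, T, H, W) -> average over H, W -> (N, C, T) -> (N, T, C)
        pooled = features.mean(dim=(3, 4)).transpose(1, 2)
        return self.heads[data_key](pooled)  # (N, T, n_neurons)


class CoreReadoutModel(nn.Module):
    """Hypothetical wrapper: one shared core, readout chosen per session."""

    def __init__(self, core, readout):
        super().__init__()
        self.core = core
        self.readout = readout

    def forward(self, x, data_key):
        return self.readout(self.core(x), data_key=data_key)


model = CoreReadoutModel(
    core=nn.Conv3d(1, 4, kernel_size=3),  # stand-in for a real 3D core
    readout=MeanReadout(4, {"21067-10-18": 5, "22846-10-16": 7}),
)
out = model(torch.ones(2, 1, 8, 16, 16), data_key="22846-10-16")
print(out.shape)  # torch.Size([2, 6, 7])
```

Keeping the core shared and only the readout session-specific lets every session's data train the same feature extractor.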