<a href="https://colab.research.google.com/github/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D5_Microcircuits/student/W1D5_Tutorial1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> &nbsp; <a href="https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/neuromatch/NeuroAI_Course/main/tutorials/W1D5_Microcircuits/student/W1D5_Tutorial1.ipynb" target="_parent"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" alt="Open in Kaggle"/></a>

# Tutorial 1: Sparsity and Sparse Coding

**Week 1, Day 5: Microcircuits**

**By Neuromatch Academy**

__Content creators:__ Noga Mudrik, Xaq Pitkow

__Content reviewers:__ Yizhou Chen, RyeongKyung Yoon, Ruiyi Zhang, Lily Chamakura, Hlib Solodzhuk

__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk


___


# Tutorial Objectives

*Estimated timing of tutorial: 1 hour*

In this tutorial we will discuss the notion of sparsity. In particular, we will:

- Recognize various types of sparsity (population, lifetime, interaction).
- Relate sparsity to inductive bias, interpretability, and efficiency.


In [None]:
# @title Tutorial slides
# @markdown These are the slides for the videos in all tutorials today

# from IPython.display import IFrame
#link_id = "<YOUR_LINK_ID_HERE>"
# print(f"If you want to download the slides: https://osf.io/download/{link_id}/")
#IFrame(src=f"https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{link_id}/?direct%26mode=render", width=854, height=480)

---
# Setup



In [None]:
# @title Install and import feedback gadget

# !pip3 install vibecheck datatops --quiet

# from vibecheck import DatatopsContentReviewContainer
# def content_review(notebook_section: str):
#     return DatatopsContentReviewContainer(
#         "",  # No text prompt - leave this as is
#         notebook_section,
#         {
#             "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
#             "name": "sciencematch_sm", # change the name of the course : neuromatch_dl, climatematch_ct, etc
#             "user_key": "y1x3mpx5",
#         },
#     ).render()

# feedback_prefix = "W1D5_T1"

In [None]:
# @title Imports

#working with data
import numpy as np
import pandas as pd
from scipy.stats import kurtosis

#plotting
import matplotlib.pyplot as plt
import seaborn as sns
import logging
import matplotlib.patheffects as path_effects

#interactive display
import ipywidgets as widgets
from ipywidgets import interact, IntSlider
from tqdm.notebook import tqdm as tqdm

#modeling
from sklearn.datasets import make_sparse_coded_signal, make_regression
from sklearn.decomposition import DictionaryLearning, PCA
from sklearn.linear_model import OrthogonalMatchingPursuit
import tensorflow as tf

#utils
import os
import warnings
warnings.filterwarnings("ignore")

In [None]:
# @title Figure settings

logging.getLogger('matplotlib.font_manager').disabled = True
sns.set_context('talk')

%matplotlib inline
%config InlineBackend.figure_format = 'retina' # perform high-definition rendering for images and plots
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle")

In [None]:
# @title Plotting functions

def show_slice(slice_index):
    """
    Plot one slice of sequential data.

    Inputs:
    - slice_index (int): index of the slice to plot.
    """
    with plt.xkcd():
        plt.figure(figsize=(6, 6))
        plt.imshow(data[slice_index])
        ind = (66,133)
        plt.scatter([ind[1]], [ind[0]], facecolors='none', edgecolors='r', marker='s', s = 100, lw = 4)
        plt.axis('off')
        plt.show()

def remove_edges(ax, include_ticks = True, top = False, right = False, bottom = True, left = True):
    ax.spines['top'].set_visible(top)
    ax.spines['right'].set_visible(right)
    ax.spines['bottom'].set_visible(bottom)
    ax.spines['left'].set_visible(left)
    if not include_ticks:
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])

def add_labels(ax, xlabel='X', ylabel='Y', zlabel='', title='', xlim = None, ylim = None, zlim = None,xticklabels = np.array([None]),
               yticklabels = np.array([None] ), xticks = [], yticks = [], legend = [],
               ylabel_params = {'fontsize':19},zlabel_params = {'fontsize':19}, xlabel_params = {'fontsize':19},
               title_params = {'fontsize':29}, format_xticks = 0, format_yticks = 0):
  """
  This function adds labels, titles, limits, etc. to figures.
  Inputs:
      ax      = the subplot to edit
      xlabel  = xlabel
      ylabel  = ylabel
      zlabel  = zlabel (if the figure is 2d please define zlabel = None)
      etc.
  """
  if xlabel != '' and xlabel != None: ax.set_xlabel(xlabel, **xlabel_params)
  if ylabel != '' and ylabel != None:ax.set_ylabel(ylabel, **ylabel_params)
  if zlabel != '' and zlabel != None:ax.set_zlabel(zlabel,**zlabel_params)
  if title != '' and title != None: ax.set_title(title, **title_params)
  if xlim != None: ax.set_xlim(xlim)
  if ylim != None: ax.set_ylim(ylim)
  if zlim != None: ax.set_zlim(zlim)

  if (np.array(xticklabels) != None).any():
      if len(xticks) == 0: xticks = np.arange(len(xticklabels))
      ax.set_xticks(xticks);
      ax.set_xticklabels(xticklabels);
  if (np.array(yticklabels) != None).any():
      if len(yticks) == 0: yticks = np.arange(len(yticklabels)) +0.5
      ax.set_yticks(yticks);
      ax.set_yticklabels(yticklabels);
  if len(legend)       > 0:  ax.legend(legend)

def plot_signal(signal, title = "Pixel's activity over time", ylabel = '$pixel_t$'):
    """
    Plot the given signal over time.

    Inputs:
    - signal (np.array): given signal.
    - title (str, default = "Pixel's activity over time"): title to give to the plot.
    - ylabel (str, default = '$pixel_t$'): y-axis label.
    """
    with plt.xkcd():
        fig, ax = plt.subplots(1,1, figsize = (30,10), sharex = True)
        ax.plot(signal, lw = 5)
        ax.set_xlim(left = 0)
        ax.set_ylim(bottom = 0)
        add_labels(ax, xlabel = 'Time (Frames)',ylabel = ylabel, title = title)
        remove_edges(ax)
        plt.show()

def plot_relu_signal(signal, theta = 0):
    """
    Plot the given signal over time and its thresholded value with the given theta.

    Inputs:
    - signal (np.array): given signal.
    - theta (float, default = 0): threshold parameter.
    """
    with plt.xkcd():
        fig, ax = plt.subplots(1,1, figsize = (30,10), sharex = True)
        thres_x = ReLU(signal, theta)
        ax.plot(signal, lw = 5)
        ax.plot(thres_x, lw = 5)
        ax.set_xlim(left = 0)
        ax.legend(['Signal', '$ReLU_{%d}$(signal)'%theta], ncol = 2)
        add_labels(ax, xlabel = 'Time', ylabel = 'Signal')
        remove_edges(ax)
        plt.show()

def plot_relu_histogram(signal, theta = 0):
    """
    Plot histogram of the values in the signal before and after applying ReLU operation with the given threshold.

    Inputs:
    - signal (np.array): given signal.
    - theta (float, default = 0): threshold parameter.
    """
    with plt.xkcd():
        fig, axs = plt.subplots(1,2,figsize = (15,10), sharex = True, sharey = True)
        thres_x = ReLU(signal, theta)
        axs[0].hist(signal, bins = 100)  # use the function argument, not the global `sig`
        axs[1].hist(thres_x, bins = 100)
        [remove_edges(ax) for ax in axs]
        [add_labels(ax, ylabel = 'Count', xlabel = 'Value') for ax in axs]
        [ax.set_title(title) for title, ax in zip(['Before Thresholding', 'After Thresholding'], axs)]
    plt.show()

def plot_relu_signals(signal, theta_values):
    """
    Plot the given signal over time and its thresholded value with the given theta values.

    Inputs:
    - signal (np.array): given signal.
    - theta_values (np.array): threshold parameter.
    """
    #define colormap
    cmap_name = 'viridis'
    samples = np.linspace(0, 1, theta_values.shape[0])
    colors = plt.colormaps[cmap_name](samples)

    fig, ax = plt.subplots(1,1, figsize = (30,10), sharex = True)
    for counter, theta in enumerate(theta_values):
      ax.plot(ReLU(signal, theta), label = '$\\theta = %d$'%theta, color = colors[counter])
    ax.set_xlim(left = 0)
    ax.legend(ncol = 5)
    add_labels(ax, xlabel = 'Time', ylabel = '$ReLU_{\\theta}$(Signal)')
    remove_edges(ax)

def plot_images(images):
    """
    Plot given images.

    Inputs:
    - images (list): list of 2D np.arrays which represent images.
    """
    with plt.xkcd():
        fig, ax = plt.subplots(1, len(images), figsize = (20,8))
        if len(images) == 1:
            ax.imshow(images[0])
        else:
            for index, image in enumerate(images):
                ax[index].imshow(image)
    plt.show()

def plot_labeled_kurtosis(frame, frame_HT, labels = ['Frame', 'HT(frame)']):
    """
    Plot kurtosis value for the frame before and after applying hard threshold operation.

    Inputs:
    - frame (np.array): given image.
    - frame_HT (np.array): thresholded version of the given image.
    - labels (list): list of labels to apply for the given data.
    """
    with plt.xkcd():
        fig, ax = plt.subplots()
        pd.DataFrame([kurtosis(frame.flatten()), kurtosis(frame_HT.flatten())],index = labels, columns = ['kurtosis']).plot.bar(ax = ax, alpha = 0.5, color = 'purple')
        remove_edges(ax)
    plt.show()

def plot_temporal_difference_histogram(signal, temporal_diff):
    """
    Plot histogram for the values of the given signal as well as for its temporal differenced version.

    Inputs:
    - signal (np.array): given signal.
    - temporal_diff (np.array): temporal differenced version of the signal.
    """
    with plt.xkcd():
        fig, axs = plt.subplots(1,2,figsize = (15,5), sharex = True, sharey = True)
        axs[0].hist(signal, bins = 100);
        axs[1].hist(temporal_diff, bins = 100);
        [remove_edges(ax) for ax in axs]
        [add_labels(ax, ylabel = 'Count', xlabel = 'Value') for ax in axs]
        [ax.set_title(title) for title, ax in zip(['Pixel \n Before Diff.', 'Pixel \n After Diff.'], axs)]
        for line in axs[0].get_children():
            line.set_path_effects([path_effects.Normal()])
        for line in axs[1].get_children():
            line.set_path_effects([path_effects.Normal()])
        plt.show()

def plot_temp_diff_histogram(signal, taus, taus_list):
    """
    Plot the histogram for the given signal over time and its temporal differenced versions for different values of lag tau.

    Inputs:
    - signal (np.array): given signal.
    - taus (np.array): array of tau values (lags).
    - taus_list (list): temporal differenced versions of the given signal.
    """
    #define colormap
    cmap_name = 'cool'
    samples = np.linspace(0, 1, taus.shape[0])
    colors = plt.colormaps[cmap_name](samples)

    with plt.xkcd():

        # histograms
        bin_edges = np.arange(0, 256, 5)  # Define bin edges from 0 to 255 with a step of 5

        # Compute histogram values using custom bin edges
        hist = [np.histogram(tau_list, bins=bin_edges)[0] for tau, tau_list in zip(taus, taus_list)]


        fig, ax = plt.subplots(figsize = (20,5))
        [ax.plot(bin_edges[:-1]*0.5 + bin_edges[1:]*0.5, np.vstack(hist)[j], marker = 'o', color = colors[j], label = '$\\tau = %d$'%taus[j]) for j in range(len(taus))]
        ax.legend(ncol = 2)
        remove_edges(ax)
        ax.set_xlim(left = 0, right = 100)
        add_labels(ax, xlabel = '$\\tau$', ylabel = 'Count')
    plt.show()

def plot_temp_diff_separate_histograms(signal, lags, lags_list, tau = True):
    """
    Plot the histogram for the given signal over time and its temporal differenced versions for different values of lag tau or windows.

    Inputs:
    - signal (np.array): given signal.
    - lags (np.array): array of lags (taus or windows).
    - lags_list (list): temporal differenced versions of the given signal.
    - tau (bool, default = True): which regime to use (tau or window).
    """
    with plt.xkcd():
        cmap_name = 'cool'
        samples = np.linspace(0, 1, lags.shape[0])  # use the `lags` argument, not a global `taus`
        colors = plt.colormaps[cmap_name](samples)

        fig, axs = plt.subplots(2, int(0.5*(2+lags.shape[0])),figsize = (25,19), sharex = True, sharey = False)
        axs = axs.flatten()
        axs[0].hist(signal, bins = 100, color = 'black');

        if tau:
            # histograms
            bin_edges = np.arange(0, 256, 5)  # Define bin edges from 0 to 255 with a step of 5

            # Compute histogram values using custom bin edges
            hist = [np.histogram(lag_list, bins=bin_edges)[0] for lag, lag_list in zip(lags, lags_list)]

            [axs[j+1].bar(bin_edges[:-1]*0.5 + bin_edges[1:]*0.5, np.abs( np.vstack(hist)[j]), color = colors[j]) for j in range(len(lags))]

        else:
            [axs[j+1].hist(np.abs(signal - diff_box_values_i), bins = 100, color = colors[j]) for j, diff_box_values_i in enumerate(lags_list)]

        [remove_edges(ax) for ax in axs]
        [add_labels(ax, ylabel = 'Count', xlabel = 'Value') for ax in axs]
        axs[0].set_title('Pixel \n Before Diff.');
        if tau:
            [ax.set_title( '$\\tau =$ %.2f'%lags[j]) for  j, ax in enumerate(axs[1:]) if j < lags.shape[0]]
        else:
            [ax.set_title( 'Window %d'%lags[j]) for  j, ax in enumerate(axs[1:]) if j < lags.shape[0]]

        for ax in axs:
            for line in ax.get_children():
                line.set_path_effects([path_effects.Normal()])
    plt.show()

def plot_temp_diff_kurtosis(signal, lags, lags_list, tau = True):
    """
    Plot the kurtosis for the given signal over time and its temporal differenced versions for different values of lag tau or windows.

    Inputs:
    - signal (np.array): given signal.
    - lags (np.array): array of lags (taus or windows).
    - lags_list (list): temporal differenced versions of the given signal.
    - tau (bool, default = True): which regime to use (tau or window).
    """
    with plt.xkcd():
        if tau:
            fig, ax = plt.subplots(figsize = (10,3))
            tauskur = [kurtosis(tau_i) for tau_i in lags_list]
            pd.DataFrame([kurtosis(signal)] + tauskur, index = ['Signal']+ ['$\\tau = {%d}$'%tau for tau in lags], columns = ['kurtosis']).plot.barh(ax = ax, alpha = 0.5, color = 'purple')
            remove_edges(ax)
        else:
            fig, ax = plt.subplots(figsize = (7,7))
            tauskur = [kurtosis(np.abs(signal - diff_box_values_i)) for diff_box_values_i in lags_list]
            pd.DataFrame([kurtosis(signal)] + tauskur, index = ['Signal']+ ['$window = {%d}$'%tau for tau in lags], columns = ['kurtosis']).plot.bar(ax = ax, alpha = 0.5, color = 'purple')
            remove_edges(ax)

def plot_diff_box(signal, filter, diff_box_signal):
    """
    Plot signal, the window function and the resulted convolution.

    Inputs:
    - signal (np.array): the given signal.
    - filter (int): size of the window function.
    - diff_box_signal (np.array): the resulted signal.
    """


    with plt.xkcd():
            fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 4))
            ax[0].plot(signal)
            ax[0].set_title('Signal')
            ax[1].plot(filter)
            ax[1].set_title('Filter Function')
            ax[2].plot(diff_box_signal)
            ax[2].set_title('diff_box Signal')
            plt.subplots_adjust(wspace=0.3)
            [ax_i.set_xlabel('Time') for ax_i in ax]
    plt.show()

def plot_diff_with_diff_box(signal, windows, diff_box_values):
    """
    Plot difference between given signal and diff_box ones for different windows.

    Inputs:
    - signal (np.array): given signal.
    - windows (np.array): list of window sizes.
    - diff_box_values (list): list for diff_box versions of the signal.
    """
    with plt.xkcd():
        cmap_name = 'cool'
        samples = np.linspace(0, 1, windows.shape[0])
        colors = plt.colormaps[cmap_name](samples)

        fig, ax = plt.subplots(figsize = (20,10))
        ax.plot(signal, label = "Signal")
        [ax.plot(signal - diff_box_values_i, color = colors[j], label = 'window = %d'%windows[j]) for j, diff_box_values_i in enumerate(diff_box_values)]
        ax.legend(ncol = 4)
        remove_edges(ax)
        ax.set_xlim(left = 0, right = 100)
        add_labels(ax, xlabel = 'Time (frame)', ylabel = '$\\Delta(pixel)$')
    plt.show()

def plot_spatial_diff(frame, diff_x,  diff_y):
    """
    Plot spatial differentiation of the given image with lag one.

    Inputs:
    - frame (np.array): given 2D signal (image).
    - diff_x (np.array): spatial difference along x-axis.
    - diff_y (np.array): spatial difference along y-axis.
    """

    with plt.xkcd():
        fig, axs = plt.subplots(1,3, figsize = (50,10))
        diff_x_norm = (diff_x - np.percentile(diff_x, 10))/(np.percentile(diff_x, 90) - np.percentile(diff_x, 10))
        diff_x_norm = diff_x_norm*255
        diff_x_norm[diff_x_norm > 255] = 255
        diff_x_norm[diff_x_norm < 0] = 0
        axs[0].imshow(diff_x_norm.astype(np.uint8))
        diff_y_norm = (diff_y - np.percentile(diff_y, 10))/(np.percentile(diff_y, 90) - np.percentile(diff_y, 10))
        diff_y_norm = diff_y_norm*255
        diff_y_norm[diff_y_norm > 255] = 255
        diff_y_norm[diff_y_norm < 0] = 0
        axs[1].imshow(diff_y_norm.astype(np.uint8))
        axs[2].imshow(frame)
        [ax.set_xticks([]) for ax in axs]
        [ax.set_yticks([]) for ax in axs]
        [ax.set_title(title, fontsize = 40) for title, ax in zip(['$\\Delta x$', '$\\Delta y$', 'Original'], axs)];

def plot_spatial_diff_histogram(taus, taus_list_x, taus_list_y):
    """
    Plot histograms for each of the spatial differenced version of the signal.
    """

    with plt.xkcd():
        cmap_name = 'cool'
        samples = np.linspace(0, 1, taus.shape[0])
        colors = plt.colormaps[cmap_name](samples)

        bin_edges = np.arange(0, 256, 5)  # Define bin edges from 0 to 255 with a step of 5

        hist_x = [np.histogram(tau_list.flatten() ,  bins=bin_edges)[0] for tau, tau_list in zip(taus, taus_list_x )]
        hist_y = [np.histogram(tau_list.flatten(),  bins=bin_edges)[0] for tau, tau_list in zip(taus, taus_list_y)]

        fig, axs = plt.subplots(2,1,figsize = (20,9))
        ax = axs[0]
        [ax.plot(bin_edges[:-1]*0.5 + bin_edges[1:]*0.5, np.vstack(hist_x)[j], marker = 'o', color = colors[j], label = '$\\tau = %d$'%taus[j]) for j in range(len(taus))]
        ax.set_yscale('log')
        ax.legend(ncol = 2)
        remove_edges(ax)
        ax.set_xlim(left = 0, right = 100)
        add_labels(ax, xlabel = '$\\tau_x$', ylabel = 'Count', title = 'Diff. in $x$')

        ax = axs[1]
        [ax.plot(bin_edges[:-1]*0.5 + bin_edges[1:]*0.5, np.vstack(hist_y)[j], marker = 'o', color = colors[j], label = '$\\tau = %d$'%taus[j]) for j in range(len(taus))]
        ax.set_yscale('log')
        ax.legend(ncol = 2)
        remove_edges(ax)
        ax.set_xlim(left = 0, right = 100)
        add_labels(ax, xlabel = '$\\tau_y$', ylabel = 'Count', title = 'Diff. in $y$')

        fig.tight_layout()
    plt.show()

def plot_spatial_kurtosis(frame, diff_x, diff_y):
    """
    Plot kurtosis for the signal and its spatial differenced version.
    """
    with plt.xkcd():
        fig, ax = plt.subplots()
        pd.DataFrame([kurtosis(frame.flatten()), kurtosis(diff_x.flatten()), kurtosis(diff_y.flatten())],index = ['Signal', 'Diff $x$', 'Diff $y$'], columns = ['kurtosis']).plot.barh(ax = ax, alpha = 0.5, color = 'purple')
        remove_edges(ax)
    plt.show()

def plot_spatial_histogram(frame, diff_x, diff_y):
    """
    Plot histogram for values in frame and differenced versions.
    """
    with plt.xkcd():
        fig, axs = plt.subplots(1,3,figsize = (25,10), sharex = True, sharey = True)
        axs[0].hist(np.abs(frame.flatten()), bins = 100);
        axs[1].hist(np.abs(diff_x.flatten()), bins = 100);
        axs[2].hist(np.abs(diff_y.flatten()), bins = 100);
        [remove_edges(ax) for ax in axs]
        [add_labels(ax, ylabel = 'Count', xlabel = 'Value') for ax in axs]
        [ax.set_title(title) for title, ax in zip(['Frame \n Before Diff.', 'Frame \n After Diff. x', 'Frame \n After Diff. y'], axs)]

def visualize_images_diff_box(frame, diff_box_values_x, diff_box_values_y):
    """
    Plot images with diff_box difference method.
    """
    with plt.xkcd():
        cmap_name = 'cool'
        samples = np.linspace(0, 1, len(diff_box_values_x))  # derive the count from the argument, not a global
        colors = plt.colormaps[cmap_name](samples)
        fig, axs = plt.subplots( 2, int(0.5*(2+ len(diff_box_values_x))),  figsize = (50,15))
        axs = axs.flatten()
        axs[0].imshow(frame)
        [axs[j+1].imshow(normalize(frame - diff_box_values_i)) for j, diff_box_values_i in enumerate(diff_box_values_x)]
        remove_edges(axs[-1], left = False, bottom = False, include_ticks=  False)
        fig.suptitle('x diff')

        fig, axs = plt.subplots( 2, int(0.5*(2+ len(diff_box_values_x))),  figsize = (50,15))
        axs = axs.flatten()
        axs[0].imshow(frame)
        [axs[j+1].imshow(normalize(frame.T - diff_box_values_i).T) for j, diff_box_values_i in enumerate(diff_box_values_y)]
        remove_edges(axs[-1], left = False, bottom = False, include_ticks=  False)
        fig.suptitle('y diff')

def create_legend(dict_legend, size = 30, save_formats = ['.png','.svg'],
                      save_addi = 'legend' , dict_legend_marker = {},
                      marker = '.', style = 'plot', s = 500, plot_params = {'lw':5},
                      params_leg = {}):
    fig, ax = plt.subplots(figsize = (5,5))
    if style == 'plot':
        [ax.plot([],[],
                     c = dict_legend[area], label = area, marker = dict_legend_marker.get(area), **plot_params) for area in dict_legend]
    else:
        if len(dict_legend_marker) == 0:
            [ax.scatter([],[], s=s,c = dict_legend.get(area), label = area, marker = marker, **plot_params) for area in dict_legend]
        else:
            [ax.scatter([],[], s=s,c = dict_legend[area], label = area, marker = dict_legend_marker.get(area), **plot_params) for area in dict_legend]
    ax.legend(prop = {'size':size},**params_leg)
    remove_edges(ax, left = False, bottom = False, include_ticks = False)

def plot_kurtosis(theta_value):
    """
    Plot kurtosis value for the signal before and after applying ReLU operation with the given threshold value.

    Inputs:
    - theta_value (int):  threshold parameter value.
    """
    with plt.xkcd():
        fig, ax = plt.subplots()
        relu = kurtosis(ReLU(sig, theta_value))
        pd.DataFrame([kurtosis(sig)] + [relu], index = ['Signal']+ ['$ReLU_{%d}$(signal)'%theta_value], columns = ['kurtosis']).plot.bar(ax = ax, alpha = 0.5, color = 'purple')
        remove_edges(ax)
        plt.show()

In [None]:
# @title Helper functions

def normalize(mat):
    """
    Normalize input matrix from 0 to 255 values (in RGB range).

    Inputs:
    - mat (np.ndarray): data to normalize.

    Outputs:
    - (np.ndarray): normalized data.
    """
    mat_norm = (mat - np.percentile(mat, 10))/(np.percentile(mat, 90) - np.percentile(mat, 10))
    mat_norm = mat_norm*255
    mat_norm[mat_norm > 255] = 255
    mat_norm[mat_norm < 0] = 0
    return mat_norm

def lists2list(xss):
    """
    Flatten a list of lists into a single list.

    Inputs:
    - xss (list): list of lists. The list of lists to be flattened.

    Outputs:
    - (list): The flattened list.
    """
    return [x for xs in xss for x in xs]

In [None]:
# @title Data retrieval

import os
import requests
import hashlib

# Variables for file and download URL
fnames = ["frame1.npy", "sig.npy", "reweight_digits.npy", "model.npy", "video_array.npy"] # The names of the files to be downloaded
urls = ["https://osf.io/n652y/download", "https://osf.io/c9qxk/download", "https://osf.io/ry5am/download", "https://osf.io/uebw5/download", "https://osf.io/t9g2m/download"] # URLs from where the files will be downloaded
expected_md5s = ["6ce619172367742dd148cc5830df908c", "f3618e05e39f6df5997f78ea668f2568", "1f2f3a5d08e13ed2ec3222dca1e85b60", "ae20e6321836783777c132149493ec70", "bbd1d73eeb7f5c81768771ceb85c849e"] # MD5 hashes for verifying files integrity

for fname, url, expected_md5 in zip(fnames, urls, expected_md5s):
    if not os.path.isfile(fname):
        try:
            # Attempt to download the file
            r = requests.get(url) # Make a GET request to the specified URL
        except requests.ConnectionError:
            # Handle connection errors during the download
            print("!!! Failed to download data !!!")
        else:
            # No connection errors, proceed to check the response
            if r.status_code != requests.codes.ok:
                # Check if the HTTP response status code indicates a successful download
                print("!!! Failed to download data !!!")
            elif hashlib.md5(r.content).hexdigest() != expected_md5:
                # Verify the integrity of the downloaded file using MD5 checksum
                print("!!! Data download appears corrupted !!!")
            else:
                # If download is successful and data is not corrupted, save the file
                with open(fname, "wb") as fid:
                    fid.write(r.content) # Write the downloaded content to a file

In [None]:
# @title Set random seed

import random
import numpy as np

def set_seed(seed=None, seed_torch=True):
  if seed is None:
    seed = np.random.choice(2 ** 32)
  random.seed(seed)
  np.random.seed(seed)

set_seed(seed = 42)

# Introduction

### Sparsity in Neuroscience and Artificial Intelligence

Sparse means rare or thinly spread. Sparsity refers to the use of minimal resources or activations to achieve an efficient representation. In both neuroscience and artificial intelligence (AI), leveraging sparsity can significantly enhance performance and efficiency.

### In neuroscience

In neuroscience, sparsity allows for economical use of neural resources, improving cognitive functions and energy efficiency. Neurons operate under a sparse regime, which facilitates clearer signal processing and robustness against noise, crucial for effective communication and learning in the brain.

### In AI

Similarly, in AI, sparsity is employed to simplify models while maintaining or enhancing predictive accuracy. Sparse architectures in neural networks reduce computational costs and memory usage, which is vital for processing large datasets and deploying AI on hardware with limited resources.

### Sparse connections and sparse activity

The concept of sparsity is subdivided into two main types: **sparse connections** and **sparse activity**.

**Sparse connections** refer to the selective interaction between neurons or nodes, reducing the number of active links in a network and thereby focusing processing power where it is most needed.

**Sparse activity**, on the other hand, involves activating only a small number of components at any given time, which helps in minimizing energy consumption and focusing computational efforts on the most salient features of the data.
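
As a rough illustration of the two notions, the sketch below builds a small random layer, zeroes out most of its weights (sparse connections), and separately thresholds the output of the dense layer (sparse activity). The layer sizes, the 10% connection density, and the threshold of 1.0 are arbitrary assumptions chosen for illustration; they are not taken from the tutorial.

```python
import numpy as np

rng = np.random.default_rng(0)

n_inputs, n_units = 50, 200          # hypothetical layer sizes
W = rng.normal(size=(n_units, n_inputs))

# Sparse connections: keep roughly 10% of the weights, zero out the rest.
mask = rng.random(W.shape) < 0.10
W_sparse = W * mask
connection_density = np.count_nonzero(W_sparse) / W_sparse.size

# Sparse activity: dense weights, but a threshold silences most units.
x = rng.normal(size=n_inputs)
pre_activation = W @ x / np.sqrt(n_inputs)   # scaled so the drive is of order one
activity = np.maximum(pre_activation - 1.0, 0)
active_fraction = np.count_nonzero(activity) / activity.size

print(f"fraction of non-zero connections: {connection_density:.2f}")
print(f"fraction of active units:         {active_fraction:.2f}")
```

The two quantities are independent: a network can have dense weights and sparse activity, sparse weights and dense activity, or both.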


### Why sparsity?
Understanding and implementing sparsity in both fields can lead to more efficient neural and computational models---mimicking the high efficiency seen in natural neural processes and advancing the capabilities of artificial systems.

### How can we measure sparsity?
- **$\ell_0$ pseudo-norm** - the count of non-zero values. It is not a true norm because it violates absolute homogeneity: scaling a vector by a non-zero constant leaves its $\ell_0$ value unchanged. It is also non-convex, which makes it difficult to optimize directly. For practical purposes, the $\ell_1$ norm is often used as a convex proxy.

- **$\ell_1$ norm** - the sum of the absolute values.

- **Kurtosis** - a measure of the "tailedness" of a distribution. Higher kurtosis indicates that most values are concentrated near the center with a few large values in the tails, which suggests greater sparsity.

![Sparsity notion visualization.](https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D5_Microcircuits/static/sparsity.png?raw=true)
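
The three measures can be compared on toy data. The sketch below is only an illustration under assumed settings: the vector length, the 20 non-zero entries, and the helper name `sparsity_measures` are hypothetical and not part of the tutorial code.

```python
import numpy as np
from scipy.stats import kurtosis

rng = np.random.default_rng(0)
n = 1000

dense = rng.normal(size=n)                     # every entry is non-zero
sparse = np.zeros(n)
idx = rng.choice(n, size=20, replace=False)
sparse[idx] = rng.normal(scale=5.0, size=20)   # only 20 non-zero entries

def sparsity_measures(v):
    """Return the l0 count, l1 norm, and excess kurtosis of a 1D array."""
    l0 = int(np.count_nonzero(v))
    l1 = float(np.sum(np.abs(v)))
    kurt = float(kurtosis(v))  # Fisher definition: about 0 for a Gaussian
    return l0, l1, kurt

for name, v in [("dense", dense), ("sparse", sparse)]:
    l0, l1, kurt = sparsity_measures(v)
    print(f"{name:>6}: l0 = {l0:4d}, l1 = {l1:8.2f}, kurtosis = {kurt:6.2f}")
```

Note that `scipy.stats.kurtosis` returns excess kurtosis by default, so a Gaussian sample scores close to 0 rather than 3.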

--------------------------
# Section 1: Natural sparsity

**Under what scenarios do we encounter sparsity in real life?**

In this first section, we will explore various contexts and methods through which natural sparsity manifests in real-world signals.

In [None]:
# @title Video 1: Introduction to sparsity

from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents


video_ids = [('Youtube', 'TQdxKttMYlc')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)

## Coding Exercise 1

In this exercise, we will explore what causes sparsity by measuring it under different nonlinear operations, e.g., thresholding. We will create temporal sparsity in a 1D time series (e.g., piecewise constant) through feedforward inhibition (temporal differencing), and population sparsity in images through lateral inhibition.

For the first exercise, we will analyze a video of a bird in San Francisco to extract both temporal and spatial sparsity. You can navigate through the video using the slider provided below.

Specifically, to explore temporal sparsity (i.e., sparsity observable over time), we will focus on the activity of a particular pixel, the one marked in red, across various frames.
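
To see why temporal differencing produces sparsity for a piecewise-constant signal, consider the minimal sketch below. The signal is synthetic and its plateau values are made up for illustration; it is not the pixel trace used in this tutorial.

```python
import numpy as np

# Hypothetical piecewise-constant signal: four plateaus, three jumps.
signal = np.concatenate([
    np.full(25, 10.0),
    np.full(25, 40.0),
    np.full(25, 25.0),
    np.full(25, 60.0),
])

# First-order temporal difference: x[t] - x[t-1].
diff = np.diff(signal)

print("non-zero samples in signal:    ", np.count_nonzero(signal), "of", signal.size)
print("non-zero samples in difference:", np.count_nonzero(diff), "of", diff.size)
print("jump locations:", np.flatnonzero(diff) + 1)
```

The original signal is non-zero everywhere, while its difference is non-zero only at the three jumps, so the same information is carried by far fewer active samples.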

In [None]:
# @markdown Execute the cell to see the interactive widget!

data = np.load('video_array.npy')
slider = IntSlider(min=0, max=data.shape[0] - 1, step=1, value=0, description= 'Time Point')
interact(show_slice, slice_index=slider)

Now, let's take a look at the activity of the marked pixel over time and visualize it.

In [None]:
# @markdown Plot of change in pixel's value over time

sig = np.load('sig.npy')
plot_signal(sig)

Write a function called "**ReLU**" that receives 2 inputs:

1) $x$ - a 1d numpy array of $p$ floats.

2) $\theta$ - a scalar.

The function should return a numpy array called `thres_x` such that:

$$
\text{thres-x}_j =
\begin{cases}
x_j - \theta & \text{if } x_j \geq \theta \\
0 & \text{otherwise}
\end{cases}
$$
for every element $j\in [1,p]$.

Apply the ReLU  function to the signal "**sig**" with a threshold of $\theta =  150$.
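
As a quick sanity check of the definition, here is a worked example with made-up numbers (they are not taken from `sig`). Take $x = (100, 150, 200, 240)$ and $\theta = 150$:

$$
\text{thres-x} = (0,\; 150 - 150,\; 200 - 150,\; 240 - 150) = (0,\; 0,\; 50,\; 90)
$$

The first entry is below the threshold and is set to zero; the second equals the threshold, so $x_j - \theta = 0$. Only two of the four entries remain non-zero.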

In [None]:
def ReLU(x, theta = 0):
    """
    Calculates ReLU function for the given level of theta.

    Inputs:
    - x (np.ndarray): input data.
    - theta (float, default = 0): threshold parameter.

    Outputs:
    - thres_x (np.ndarray): filtered values.
    """
    ###################################################################
    ## Fill out the following then remove
    raise NotImplementedError("Student exercise: complete `thres_x` array calculation as defined.")
    ###################################################################

    thres_x = ...

    return thres_x

In [None]:
#to_remove solution
def ReLU(x, theta = 0):
    """
    Calculates ReLU function for the given level of theta.

    Inputs:
    - x (np.ndarray): input data.
    - theta (float, default = 0): threshold parameter.

    Outputs:
    - thres_x (np.ndarray): filtered values.
    """

    thres_x = np.maximum(x - theta, 0)

    return thres_x

In [None]:
# @markdown Plot your results

plot_relu_signal(sig, theta = 150)

Let us also take a look at the aggregated plot, which takes threshold values from 0 to 220 in steps of 20 and plots the ReLU version of the signal for each of them.

In [None]:
# @markdown Threshold value impact on ReLU version of the signal

theta_values = np.arange(0, 240, 20)
plot_relu_signals(sig ,theta_values)

Finally, let's calculate the kurtosis to estimate the sparsity of the signal compared to its version passed through the ReLU function.

Try to gradually increase the threshold parameter ($\theta$) from 0 to 240 in intervals of 20 and plot the result for each value.
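To make the link between thresholding and kurtosis concrete, here is a minimal sketch on synthetic data. It assumes a hypothetical Gaussian toy signal (not the tutorial's `sig`), and the helper names `excess_kurtosis` and `toy_relu` are illustrative only. The kurtosis is computed as the fourth standardized moment minus 3, which should match the default convention of `scipy.stats.kurtosis`.

```python
import numpy as np

def excess_kurtosis(x):
    """Fourth standardized moment minus 3 (close to 0 for Gaussian data)."""
    x = np.asarray(x, dtype=float).ravel()
    centered = x - x.mean()
    return float(np.mean(centered**4) / np.mean(centered**2) ** 2 - 3.0)

def toy_relu(x, theta=0.0):
    """Shifted ReLU: x - theta where x >= theta, 0 elsewhere."""
    return np.maximum(x - theta, 0.0)

# Hypothetical toy signal: Gaussian noise around 100 (an assumption, not real pixel data)
rng = np.random.default_rng(0)
toy_sig = rng.normal(loc=100.0, scale=20.0, size=10_000)

thresholded = toy_relu(toy_sig, theta=130.0)

kurt_raw = excess_kurtosis(toy_sig)
kurt_relu = excess_kurtosis(thresholded)
frac_zero = float(np.mean(thresholded == 0))

print(f"raw kurtosis: {kurt_raw:.2f}")
print(f"thresholded kurtosis: {kurt_relu:.2f}")
print(f"fraction of zeros after thresholding: {frac_zero:.1%}")
```

In this sketch, most samples are set to zero by the threshold and only a few large values survive, which is the heavy-tailed shape that a high kurtosis reflects.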

In [None]:
# @markdown Kurtosis value comparison

slider = IntSlider(min=0, max=240, step=20, value=0, description='Threshold')
interact(plot_kurtosis, theta_value = slider)

---
# Section 2: Temporal differentiation

*Estimated timing to here from start of tutorial: 15 minutes.*

In this section, you will look at the temporal differentiation of a natural signal and quantify its sparsity.


In [None]:
# @title Video 2: Temporal differentiation

from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents


video_ids = [('Youtube', 'JwltmHZzIYg')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)

## Coding Exercise 2

Now, we will focus on the temporal differences of a natural signal. In particular, let the pixel value at time $t$ be $pixel_t$ for all $t = 1\dots T$. We look at:
$$\Delta(pixel_t) = pixel_t - pixel_{t-1} \quad  \forall \quad t = 2\dots T$$
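As a small illustration of why differencing can produce sparsity, the sketch below applies the definition to a hypothetical piecewise-constant signal (`toy_pixel` is an assumption made for this example, not part of the tutorial data). The signal itself has no zeros, but its differences are non-zero only at the jumps.

```python
import numpy as np

# Hypothetical piecewise-constant signal: four levels held for five samples each
toy_pixel = np.repeat([1.0, 4.0, 2.0, 7.0], 5)

# Delta(pixel_t) = pixel_t - pixel_{t-1}, defined for t = 2..T, so it is one sample shorter
delta = toy_pixel[1:] - toy_pixel[:-1]

print(len(toy_pixel), len(delta))    # 20 19
print(np.count_nonzero(toy_pixel))   # 20
print(np.count_nonzero(delta))       # 3
```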

### Coding Exercise 2A: Temporal differencing signal

Define the absolute temporal-difference signal, called `temporal_diff`. You can use the built-in `np.diff` function. The calculation of `temporal_diff` consists of 2 steps:
- apply `np.diff` on the signal `sig` (to obtain difference between consecutive values);
- apply `np.abs` on the result to get absolute values.


In [None]:
###################################################################
## Fill out the following then remove
raise NotImplementedError("Student exercise: complete temporal differentiation.")
###################################################################
temporal_diff = ...
plot_signal(temporal_diff, title = "", ylabel = "$| pixel_t - pixel_{t-1} | $")

In [None]:
#to_remove solution
temporal_diff = np.abs(np.diff(sig))
plot_signal(temporal_diff, title = "", ylabel = "$| pixel_t - pixel_{t-1} | $")

Let's take a look at the histogram of the temporal difference values as well as kurtosis values.

In [None]:
# @markdown Histogram plots for pixel values for the signal and its differenced version
plot_temporal_difference_histogram(sig, temporal_diff)

In [None]:
# @markdown Kurtosis values for the signal and its differenced version

plot_labeled_kurtosis(sig, temporal_diff, labels = ['Signal', 'Temporal \n Diff.'])

### Coding Exercise 2B: Increasing the temporal differencing intervals

So far, we focused on single intervals, i.e. $\Delta_t = pixel_t - pixel_{t-1}$.

What happens if we look at longer-scale differences, i.e., $$\Delta_t(\tau) = pixel_t - pixel_{t-\tau}$$ for some $\tau > 1$?

In this exercise, we will explore the effects of increasing $\tau$ values on the sparsity of the temporal differentiation signal.

1) Create an array of 10 different $\tau$ values: `taus = [1, 11, 21, ..., 91]`.

2) Create a list called **taus_list** composed of 10 arrays, where the $j$-th array is the temporal difference with an interval equal to the $j$-th $\tau$.

3) Plot a heatmap that compares the histograms of the different taus.

Pay attention: here, it is NOT recommended to use the built-in `np.diff` function.
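One likely reason for this advice is that the `n` argument of `np.diff` applies the adjacent difference repeatedly (a higher-order difference) rather than taking a difference at lag $\tau$. The sketch below contrasts the two on a hypothetical toy array; `toy`, `lagged`, and `repeated` are names introduced only for this illustration.

```python
import numpy as np

toy = np.array([0.0, 1.0, 4.0, 9.0, 16.0, 25.0])  # hypothetical toy signal
tau = 2

# Lagged difference: toy_t - toy_{t - tau}, computed by slicing
lagged = toy[tau:] - toy[:-tau]

# np.diff with n=tau applies the adjacent difference tau times (a tau-th order difference)
repeated = np.diff(toy, n=tau)

print(lagged)    # values: 4, 8, 12, 16
print(repeated)  # values: 2, 2, 2, 2
```

Both results have length `len(toy) - tau`, so they are easy to confuse, but they measure different things.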

In [None]:
###################################################################
## Fill out the following then remove
raise NotImplementedError("Student exercise: complete calculation of `taus` and `taus_list`.")
###################################################################
num_taus = 10

# create taus
taus = np.linspace(1, 91, ...).astype(int)

# create taus_list
taus_list = [np.abs(sig[...:] - sig[:-tau]) for tau in taus]

In [None]:
#to_remove solution
num_taus = 10

# create taus
taus = np.linspace(1,91,num_taus).astype(int)

# create taus_list
taus_list = [np.abs(sig[tau:] - sig[:-tau]) for tau in taus]

In [None]:
# @markdown Plot your results
plot_temp_diff_histogram(sig, taus, taus_list)

Now, let us take a look at the histograms for each $\tau$ separately, as well as the kurtosis values.

In [None]:
# @markdown Histogram plots for different values of tau
plot_temp_diff_separate_histograms(sig, taus, taus_list)

In [None]:
# @markdown Kurtosis value comparison for different values of tau
plot_temp_diff_kurtosis(sig, taus, taus_list)

### Exploring temporal differencing with box filter

While various types of filters could be considered, here, we will focus on the box moving average filter.


First, let's define the `diff_box` function, which accepts a signal and the size of the window as inputs. Internally, it performs a simple convolution between the signal and a box (rectangular) filter, where the window size determines the length of the filter. Visually inspect the results for `window = 10`. We will explore temporal differences by comparing the raw signal with its `diff_box`-processed version for various window sizes.

In [None]:
# @markdown Filter visualization

def diff_box(data, window, pad_size = 4):
    filter = np.concatenate([np.repeat(0,pad_size), np.repeat(0,window), np.array([1]), np.repeat(-1,window), np.repeat(0,pad_size)]).astype(float)
    filter /= np.sum(filter**2)**0.5

    filter_plus_sum =  filter[filter > 0].sum()
    filter_min_sum = np.abs(filter[filter < 0]).sum()
    filter[filter > 0] *= filter_min_sum/filter_plus_sum
    diff_box = np.convolve(data, filter, mode='full')[:len(data)]
    diff_box[:window] = diff_box[window]
    return diff_box, filter

window = 10
diff_box_signal, filter = diff_box(sig, window)

with plt.xkcd():
    fig, ax = plt.subplots()
    plot_e1 = np.arange(len(filter))
    plot_e2 = np.arange(len(filter)) + 1
    plot_edge_mean = 0.5*(plot_e1 + plot_e2)
    plot_edge = lists2list( [[e1 , e2] for e1 , e2 in zip(plot_e1, plot_e2)])
    ax.plot(plot_edge, np.repeat(filter, 2), alpha = 0.3, color = 'purple')
    ax.scatter(plot_edge_mean, filter, color = 'purple')
    add_labels(ax,ylabel = 'filter value', title = 'box filter', xlabel = 'Frame')

#### Discussion

1. Why do you think the filter is asymmetric?
2. How might a filter influence the sparsity patterns observed in data?

In [None]:
#to_remove explanation

"""
Discussion: 1. Why do you think the filter is asymmetric?
2. How might a filter influence the sparsity patterns observed in data?

1. The filter reflects how we process time series data: it weights the past and the future differently. In particular, for this filter, past information is not included in the result at all, while future information is weighted evenly across the defined window size.
2. Note that the filter takes the average of the future (not only one time point), so the result is much smoother than the regular (single-step) difference.
""";

In [None]:
# @markdown Plot signal and its diff-box versions with window of size 10
plot_diff_box(sig, filter, diff_box_signal)

Now, we will define 10 different window values, `windows = [1, 11, 21, ..., 91]`, and calculate the corresponding `diff_box` versions of the signal.

In [None]:
# @markdown Define different window values
windows = np.linspace(1, 91, 10)
diff_box_values = [diff_box(sig, int(window))[0] for window in windows]

Now, let's take a look at the temporally differenced versions of the filtered signals, as well as the separate histograms and kurtosis values.

In [None]:
# @markdown Visualization of differenced signals for different window sizes

plot_diff_with_diff_box(sig, windows, diff_box_values)

In [None]:
# @markdown Histogram for each window size

plot_temp_diff_separate_histograms(sig, windows, diff_box_values, tau = False)

In [None]:
# @markdown Kurtosis values comparison for different window sizes

plot_temp_diff_kurtosis(sig, windows, diff_box_values, tau = False)

#### Discussion

1. What do you observe about the kurtosis after applying the temporal differentiation?

In [None]:
#to_remove explanation

"""
Discussion: What do you observe about the kurtosis after applying the temporal differentiation?

In general, the kurtosis becomes higher, meaning that temporal differentiation makes the signal more concentrated (sparser).
""";

---
# Section 3: Sparsity in space

*Estimated timing to here from start of tutorial: 30 minutes.*

In this section we will focus on spatial sparsity.


Below you can find the 1st frame which we will focus on for the spatial differentiation (i.e. differentiation in space).

In [None]:
# @markdown Visualization of the frame 

frame = np.load('frame1.npy')
plot_images([frame])

In [None]:
# @title Video 3: Sparsity in space

from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents


video_ids = [('Youtube', 'ylKP9s8qqYY')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)

## Coding Exercise 3

So far, we focused on applying ReLU, which can also be considered as a form of **'soft thresholding'**. Now, we will apply a different kind of thresholding, called '**hard thresholding**'.

Here, you will have to write a function called **hard_thres** that receives 2 inputs:

1) `frame` - a 2D array of $p \times p$ floats.

2) `theta` - a scalar threshold value.

The function should return a 2D array called **frame_HT** with the same dimensions as **frame**, in which all values smaller than $\theta$ are set to 0, i.e.:
$$
\text{frame-HT}_{[i,j]} =
\begin{cases}
frame_{[i,j]} & \text{if } frame_{[i,j]} \geq \theta \\
0 & \text{otherwise}
\end{cases}
$$
for every pixel [i,j] in frame.
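The difference between the two kinds of thresholding is easiest to see side by side. The following sketch uses a hypothetical 1D array of pixel values (`toy_values` is an assumption for illustration): the shifted ReLU from Coding Exercise 1 subtracts $\theta$ from the surviving values, while hard thresholding leaves them unchanged.

```python
import numpy as np

toy_values = np.array([50.0, 120.0, 150.0, 200.0])  # hypothetical pixel values
theta = 150.0

# Shifted ReLU: surviving values are reduced by theta
soft = np.maximum(toy_values - theta, 0.0)

# Hard threshold: surviving values keep their original magnitude
hard = np.where(toy_values >= theta, toy_values, 0.0)

print(soft)  # values: 0, 0, 0, 50
print(hard)  # values: 0, 0, 150, 200
```

Note that the value exactly at the threshold (150) maps to 0 under the shifted ReLU but is kept under hard thresholding.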

In [None]:
def hard_thres(frame, theta):
    """
    Return hard thresholded array of values based on the parameter value theta.

    Inputs:
    - frame (np.array): 2D signal.
    - theta (float): threshold parameter.
    """
    ###################################################################
    ## Fill out the following then remove
    raise NotImplementedError("Student exercise: complete `frame_HT` array calculation as defined.")
    ###################################################################
    frame_HT = frame.copy()
    frame_HT[...] = 0
    return frame_HT

In [None]:
#to_remove solution
def hard_thres(frame, theta):
    """
    Return hard thresholded array of values based on the parameter value theta.

    Inputs:
    - frame (np.array): 2D signal.
    - theta (float): threshold parameter.
    """
    frame_HT = frame.copy()
    frame_HT[frame_HT < theta] = 0
    return frame_HT

In [None]:
# @markdown Test your function

frame_HT = hard_thres(frame, 150)
plot_images([frame, frame_HT])

Notice how many pixels now share the same **violet** color. Above, the violet color corresponds to 0. Hence, these pixels are the **thresholded** ones.

Here, we will redefine the variable **frame_HT** as the first frame after the hard threshold operator, using the above **hard_thres** function with the threshold set to the 80th percentile of the frame values.

In [None]:
# @markdown Observe the signal with 80% of its values thresholded

low_perc = np.percentile(frame, 80)
frame_HT = hard_thres(frame, low_perc)
plot_images([frame, frame_HT])

Notice that even more pixels become violet as we threshold the vast majority of the signal (80%).
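To check what a percentile-based threshold does, here is a minimal sketch on a hypothetical frame whose values are simply 1 to 100 (an assumption chosen so the result is easy to verify by hand). `np.percentile(..., 80)` returns the value below which 80% of the pixels fall, so hard thresholding at that value zeroes out 80% of the frame.

```python
import numpy as np

# Hypothetical 10x10 frame with values 1..100
toy_frame = np.arange(1.0, 101.0).reshape(10, 10)

# Value below which 80% of the pixels fall
cutoff = np.percentile(toy_frame, 80)

# Hard threshold at the cutoff
toy_HT = np.where(toy_frame >= cutoff, toy_frame, 0.0)

frac_zeroed = float(np.mean(toy_HT == 0))
print(cutoff, frac_zeroed)
```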

We will now explore the differences before and after applying hard thresholding using histograms. Generate a histogram for `frame_HT`. It might take a while, as the image is of high quality and contains a lot of pixels.

In [None]:
# @markdown Histogram comparison

with plt.xkcd():
    fig, axs = plt.subplots(1,2,figsize = (15,5), sharey = True)
    axs[0].hist(frame.flatten(), bins = 100);
    axs[1].hist(frame_HT.flatten(), bins = 100);

    #utils
    [ax.set_yscale('log') for ax in axs]
    [remove_edges(ax) for ax in axs]
    [add_labels(ax, ylabel = 'Count', xlabel = 'Value') for ax in axs]
    [ax.set_title(title) for title, ax in zip(['Before Thresholding', 'After Thresholding'], axs)]

Let us also take a look at the kurtosis to assess sparsity.

In [None]:
# @markdown Kurtosis values comparison

plot_labeled_kurtosis(frame, frame_HT)

---
# Section 4: Spatial differentiation

*Estimated timing to here from start of tutorial: 40 minutes.*

In this section, we are going to apply the differentiation operation to spatial data. Spatial differentiation is commonly used to identify image patterns and extract meaningful features.
Let's apply it to the 1st frame and look at the result (meaning we return to the 2D signal that represents an image).

In [None]:
# @title Video 4: Spatial differentiation

from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents


video_ids = [('Youtube', 'ylKP9s8qqYY')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)

## Coding Exercise 4

Let's apply spatial differentiation to the 1st frame and look at the result.

Below, apply single-step spatial differentiation (differences between adjacent pixels) to **frame**.

In particular, define a variable **diff_x** that stores the x-axis differentiation and **diff_y** that stores the y-axis differentiation. Compare the results by displaying the differentiated images.
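The sketch below shows how the `axis` argument of `np.diff` selects the direction of differencing, using a hypothetical 3x4 frame with a single vertical edge (`toy_frame`, `d_cols`, and `d_rows` are names made up for this example). Differencing along the columns detects the edge, while differencing along the rows returns only zeros.

```python
import numpy as np

# Hypothetical 3x4 frame with one vertical edge between columns 1 and 2
toy_frame = np.array([[0.0, 0.0, 5.0, 5.0],
                      [0.0, 0.0, 5.0, 5.0],
                      [0.0, 0.0, 5.0, 5.0]])

d_cols = np.diff(toy_frame, axis=1)  # differences between horizontally adjacent pixels
d_rows = np.diff(toy_frame, axis=0)  # differences between vertically adjacent pixels

print(d_cols.shape, np.count_nonzero(d_cols))  # (3, 3) 3
print(d_rows.shape, np.count_nonzero(d_rows))  # (2, 4) 0
```

Each call shortens the differenced axis by one, which is why the two outputs have different shapes.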

In [None]:
###################################################################
## Fill out the following then remove
raise NotImplementedError("Student exercise: complete calculation of `diff_x` and `diff_y`.")
###################################################################
diff_x = np.diff(frame, axis = ...)
diff_y = np.diff(...)

In [None]:
#to_remove solution

diff_x = np.diff(frame, axis = 1)
diff_y = np.diff(frame, axis = 0)

In [None]:
# @markdown Observe the result
plot_spatial_diff(frame, diff_x, diff_y)

Now, as usual, let's look at the histogram and kurtosis.

In [None]:
# @markdown Observe the histogram 
plot_spatial_histogram(frame, diff_x, diff_y)

In [None]:
# @markdown Observe the kurtosis values 
plot_spatial_kurtosis(frame, diff_x, diff_y)

### Multi-Step Analysis in Spatial Differencing

For spatial differencing, rather than focusing on differences between adjacent pixels, we'll employ wider filters. Specifically, we will apply the differentiation using the `diff_box` method, similar to our approach in the temporal case. You will define the `diff_box_values_x` and `diff_box_values_y` variables, which apply the filter to the first frame for the specified windows in the `x` and `y` directions, respectively.

In contrast to the temporal differencing, we aim for the `diff_box` to be symmetric in the spatial case. This is important because, in spatial analysis, we need to consider the context from both sides of a given point, as opposed to time-series data, where there is a clear directional flow from past to present.
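The contrast between the two filter shapes can be sketched as follows. This is a simplified illustration, not the tutorial's implementation: it omits the zero padding and the L2 normalization, and the helper `balance` and the names `one_sided` and `two_sided` are assumptions introduced here.

```python
import numpy as np

def balance(kernel):
    """Rescale the positive taps so that the kernel sums to zero."""
    kernel = kernel.astype(float)  # astype returns a copy, so the input is untouched
    pos = kernel > 0
    kernel[pos] *= np.abs(kernel[~pos]).sum() / kernel[pos].sum()
    return kernel

window = 3  # hypothetical small window, chosen for readability

# Temporal-style filter: negative taps on one side of the centre only
one_sided = balance(np.concatenate([np.zeros(window), [1.0], -np.ones(window)]))

# Spatial-style filter: negative taps on both sides of the centre
two_sided = balance(np.concatenate([-np.ones(window), [1.0], -np.ones(window)]))

print(one_sided)  # centre tap 3, three taps of -1 on one side
print(two_sided)  # centre tap 6, three taps of -1 on each side
```

Both kernels sum to zero, so a constant region produces no response; only the two-sided kernel is unchanged when reversed.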

![Comparison of filters.](https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D5_Microcircuits/static/filters.png?raw=true)

Define a **symmetric** box filter in the function below.

In [None]:
def diff_box_spatial(data, window, pad_size = 4):
    """
    Implement & apply spatial filter on the signal.

    Inputs:
    - data (np.ndarray): input signal.
    - window (int): size of the window.
    - pad_size (int, default = 4): size of pad around data.
    """
    ###################################################################
    ## Fill out the following then remove
    raise NotImplementedError("Student exercise: add pad around to complete filter operation.")
    ###################################################################
    filter = np.concatenate([np.repeat(0, ...), np.repeat(-1, window), np.array([1]), np.repeat(-1, window), np.repeat(0, ...)]).astype(float)
    
    #normalize
    filter /= np.sum(filter**2)**0.5

    #make sure the filter sums to 0
    filter_plus_sum =  filter[filter > 0].sum()
    filter_min_sum = np.abs(filter[filter < 0]).sum()
    filter[filter > 0] *= filter_min_sum/filter_plus_sum

    #convolution of the signal with the filter
    diff_box = np.convolve(data, filter, mode='full')[:len(data)]
    diff_box[:window] = diff_box[window]
    return diff_box, filter

In [None]:
#to_remove solution
def diff_box_spatial(data, window, pad_size = 4):
    """
    Implement & apply spatial filter on the signal.

    Inputs:
    - data (np.ndarray): input signal.
    - window (int): size of the window.
    - pad_size (int, default = 4): size of pad around data.
    """
    filter = np.concatenate([np.repeat(0, pad_size), np.repeat(-1, window), np.array([1]), np.repeat(-1, window), np.repeat(0, pad_size)]).astype(float)
    
    #normalize
    filter /= np.sum(filter**2)**0.5

    #make sure the filter sums to 0
    filter_plus_sum =  filter[filter > 0].sum()
    filter_min_sum = np.abs(filter[filter < 0]).sum()
    filter[filter > 0] *= filter_min_sum/filter_plus_sum

    #convolution of the signal with the filter
    diff_box = np.convolve(data, filter, mode='full')[:len(data)]
    diff_box[:window] = diff_box[window]
    return diff_box, filter

Let's observe the filter. How does it differ from the temporal one?

In [None]:
# @markdown Visualization of the filter
window = 10
diff_box_signal, filter = diff_box_spatial(sig, window)

with plt.xkcd():
    fig, ax = plt.subplots()
    plot_e1 = np.arange(len(filter))
    plot_e2 = np.arange(len(filter)) + 1
    plot_edge_mean = 0.5*(plot_e1 + plot_e2)
    plot_edge = lists2list( [[e1 , e2] for e1 , e2 in zip(plot_e1, plot_e2)])
    ax.scatter(plot_edge_mean, filter, color = 'purple')
    ax.plot(plot_edge, np.repeat(filter, 2), alpha = 0.3, color = 'purple')
    add_labels(ax,ylabel = 'Filter value', title = 'Box Filter', xlabel = 'Space (pixel)')

Let's apply the filter to the data. What do you observe?

In [None]:
###################################################################
## Fill out the following then remove.
raise NotImplementedError("Student exercise: complete calculation of `diff_box_values_y`.")
###################################################################

num_winds = 10
windows = np.linspace(2,92,num_winds)
rows = frame.shape[0]
cols = frame.shape[1]

diff_box_values_x = [np.array([diff_box_spatial(frame[row], int(window))[0]  for row in range(rows)]) for window in windows]

diff_box_values_y = [np.array([diff_box_spatial(frame[:, ...], int(window))[0] for col in range(cols)]) for window in windows]

In [None]:
#to_remove solution

num_winds = 10
windows = np.linspace(2,92,num_winds)
rows = frame.shape[0]
cols = frame.shape[1]

diff_box_values_x = [np.array([diff_box_spatial(frame[row], int(window))[0]  for row in range(rows)]) for window in windows]

diff_box_values_y = [np.array([diff_box_spatial(frame[:,col], int(window))[0] for col in range(cols)]) for window in windows]

In [None]:
# @markdown Observe the results

visualize_images_diff_box(frame, diff_box_values_x, diff_box_values_y)

In [None]:
# @markdown Kurtosis values comparison

with plt.xkcd():
    fig, ax = plt.subplots(figsize = (17,7))
    tauskur = [kurtosis(np.abs(frame - diff_box_values_i).flatten()) for diff_box_values_i in diff_box_values_x]
    tauskur_y = [kurtosis(np.abs(frame.T - diff_box_values_i).flatten()) for diff_box_values_i in diff_box_values_y]

    df1 = pd.DataFrame([kurtosis(sig)] + tauskur, index = ['Signal']+ ['$window = {%d}$'%tau for tau in windows], columns = ['kurtosis x'])
    df2 = pd.DataFrame([kurtosis(sig)] + tauskur_y, index = ['Signal']+ ['$window = {%d}$'%tau for tau in windows], columns = ['kurtosis y'])
    dfs = pd.concat([df1,df2], axis = 1)
    dfs.plot.barh(ax = ax, alpha = 0.5, color =['purple', 'pink'])
    remove_edges(ax)
    plt.show()

---
# Section 5: Sparse coding

*Estimated timing to here from start of tutorial: 50 minutes.*

Sparse coding refers to a coding strategy where the representation of data is sparse, meaning that only a small number of elements are active or nonzero while the majority are zero or inactive. This concept is often applied to sensory inputs, such as images or sounds, where the goal is to find a concise and efficient representation that captures the essential features of the input.

In [None]:
# @title Video 5: Sparse coding

from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents


video_ids = [('Youtube', 'hLxSTxmLbwg')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)

## Neuroscience background to sparse coding


A pivotal experiment conducted by David Hubel and Torsten Wiesel [1] at Johns Hopkins University in 1959 investigated what kind of information neurons in the visual cortex encode, by implanting an electrode into the visual cortex of a living cat and measuring its electrical activity in response to displayed images.

Despite prolonged exposure to various images, no significant activity was recorded.

However, an unexpected observation occurred when the slide was inserted and removed from the projector, revealing a response to the shape of the slide's edge.

This led to the discovery of neurons highly sensitive to **edge orientation and location**, providing the first insights into the type of information coded by neurons.



*[1] Hubel, David H., and Torsten N. Wiesel. ["Effects of monocular deprivation in kittens."](http://hubel.med.harvard.edu/papers/HubelWiesel1964NaunynSchmiedebergsArchExpPatholPharmakol.pdf) Naunyn-Schmiedebergs Archiv for Experimentelle Pathologie und Pharmakologie 248 (1964): 492–7.*

- **What implications do these specialized properties of neural representation hold for our understanding of visual perception?**
- **How might these findings inform computational models of visual processing in artificial systems?**

## Neuroscience to AI

In 1996, Bruno Olshausen and David Field [2] published a paper that demonstrated sparse coding in the visual system, particularly in V1.

Their research focused on understanding receptive field properties of neurons in the visual cortex, an aspect of sensory processing in the brain.

Olshausen and Field showed that emergent properties of learned receptive fields resembled those in biological visual systems.

They found that neuron selectivity in the visual cortex could be explained through **sparse coding**, where only **a small subset of neurons responded to specific features or patterns**.


[2] *Olshausen, B. A., & Field, D. J. (1996). ["Emergence of simple-cell receptive field properties by learning a sparse code for natural images."](https://www.nature.com/articles/381607a0) Nature*

## $\ell_0$ pseudo-norm regularization to promote sparsity

The $\ell_0$ pseudo-norm is defined as the number of non-zero elements in the signal. Particularly, let $h \in \mathbb{R}^{J}$ be a vector with $J$ "latent activity" elements. Then:

$$\|h\|_0 = \sum_{j = 1}^J \mathbb{1}_{h_{j} \neq 0}$$

Hence, the $\ell_0$ pseudo-norm can be used to promote sparsity by adding it to a cost function to "punish" the number of non-zero elements.

In particular, let's assume that we have a simple linear model, where we want to capture the observations $y$ using the linear model $D$ with sparse weights $h$.
Specifically, we are looking for the weights $h$ under the assumption that:
$$ y = Dh + \epsilon$$ where $\epsilon$ is *i.i.d.* Gaussian noise with zero mean and standard deviation $\sigma_\epsilon$, i.e., $\epsilon \sim \mathcal{N}(0, \sigma_\epsilon^2)$.

If we assume that $h$ is sparse, i.e., that it has only a few non-zero elements, we would like to "punish" the number of non-zero elements.

We thus want to solve the following minimization problem:
$$
\hat{h} = \arg \min_h \|y - Dh \|_2^2 + \lambda \|h\|_0
$$
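To see how the penalty trades reconstruction error against the number of active elements, here is a minimal sketch that evaluates the cost for two candidate solutions. The helper `l0_cost`, the identity dictionary `D_toy`, and the candidate vectors are all assumptions chosen so the numbers can be checked by hand.

```python
import numpy as np

def l0_cost(y, D, h, lam):
    """Squared reconstruction error plus lam times the number of non-zero entries of h."""
    residual = y - D @ h
    return float(residual @ residual + lam * np.count_nonzero(h))

D_toy = np.eye(3)                     # hypothetical dictionary (identity, for hand-checking)
y_toy = np.array([1.0, 0.0, 2.0])

exact = np.array([1.0, 0.0, 2.0])     # zero error, two active elements
sparser = np.array([0.0, 0.0, 2.0])   # squared error of 1, one active element

# Small lambda: the exact solution is cheaper
print(l0_cost(y_toy, D_toy, exact, lam=0.5))    # 0 + 0.5 * 2 = 1.0
print(l0_cost(y_toy, D_toy, sparser, lam=0.5))  # 1 + 0.5 * 1 = 1.5

# Large lambda: the sparser solution is cheaper
print(l0_cost(y_toy, D_toy, exact, lam=2.0))    # 0 + 2 * 2 = 4.0
print(l0_cost(y_toy, D_toy, sparser, lam=2.0))  # 1 + 2 * 1 = 3.0
```

Increasing $\lambda$ therefore shifts the minimizer toward solutions with fewer non-zero elements, at the price of a larger reconstruction error.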




![Features to data.](https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D5_Microcircuits/static/dictionary.png?raw=true)

One of the methods to find the sparsest solution for a linear decomposition with $\ell_0$ regularization is through OMP [3].


[3] Pati, Y. C., Rezaiifar, R., & Krishnaprasad, P. S. (1993, November). ["Orthogonal matching pursuit: Recursive function approximation with applications to wavelet decomposition."](https://www.khoury.northeastern.edu/home/eelhami/courses/EE290A/OMP_Krishnaprasad.pdf) In Proceedings of 27th Asilomar conference on signals, systems and computers (pp. 40-44). IEEE.


As explained in the video, the OMP (Orthogonal Matching Pursuit) algorithm is a method to find the best matching features to represent a target signal.

OMP iteratively selects the features that best correlate with the remaining part of the signal, updates the remaining part by subtracting the contribution of the chosen features, and repeats until the remaining part is minimized or the desired number of features is selected.


In this context, a "dictionary" is a collection of features that we use to represent the target signal. These features are like building blocks, and the goal of the OMP algorithm is to find the right combination of these blocks from the dictionary to best match the target signal.
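The loop described above can be sketched in a few lines of NumPy. This is a simplified illustration under stated assumptions, not the scikit-learn implementation: `omp_sketch` is a hypothetical helper, there is no intercept or stopping tolerance, and the toy dictionary has orthonormal atoms (taken from a QR factorization) so that exact recovery is expected.

```python
import numpy as np

def omp_sketch(D, y, n_nonzero):
    """Greedy OMP: pick the atom most correlated with the residual, refit, repeat."""
    support = []
    coef = np.zeros(0)
    residual = y.copy()
    for _ in range(n_nonzero):
        correlations = np.abs(D.T @ residual)
        if support:
            correlations[support] = 0.0  # never pick the same atom twice
        support.append(int(np.argmax(correlations)))
        # Least-squares fit of y on the atoms selected so far
        coef, *_ = np.linalg.lstsq(D[:, support], y, rcond=None)
        # Remaining part of the signal not yet explained
        residual = y - D[:, support] @ coef
    h = np.zeros(D.shape[1])
    h[support] = coef
    return h

# Hypothetical dictionary: 6 orthonormal atoms in 20 dimensions
rng = np.random.default_rng(0)
D_toy, _ = np.linalg.qr(rng.standard_normal((20, 6)))

# Target built from atoms 1 and 4 only
y_toy = 2.0 * D_toy[:, 1] - 3.0 * D_toy[:, 4]

h_hat = omp_sketch(D_toy, y_toy, n_nonzero=2)
print(np.round(h_hat, 3))
```

Refitting all selected coefficients at every step (the "orthogonal" part) is what distinguishes OMP from plain matching pursuit, which only updates the newly chosen atom.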

## Coding Exercise 5

Apply the OMP algorithm with increasing sparsity levels to explore how the pixel values are captured by different frequencies.



In [None]:
T_ar = np.arange(len(sig))

#100 different frequency values
freqs = np.linspace(0.001, 1, 100)
set_sigs = [np.sin(f*T_ar) for f in freqs]

# define 'reg' --- an sklearn object of OrthogonalMatchingPursuit and fit it to the data, where the frequency bases are the atoms and the signal is the label
reg = OrthogonalMatchingPursuit(fit_intercept = True, n_nonzero_coefs = 10).fit(np.vstack(set_sigs).T, sig)

Observe the plot of 3 example signals.

In [None]:
# @markdown Plot of 3 basis signals

with plt.xkcd():
    fig, axs = plt.subplots(3,1,sharex = True, figsize = (10,5), sharey = True)
    axs[0].plot(set_sigs[0], lw = 3)
    axs[1].plot(set_sigs[1], lw = 3)
    axs[2].plot(set_sigs[-1], lw = 3)
    [remove_edges(ax) for ax in axs]
    [ax.set_xlim(left = 0) for ax in axs]
    fig.tight_layout()

Here we plot the basis elements, the reconstruction, the weight assigned to each frequency, and the contribution of each selected frequency.

In [None]:
# @markdown Weights visualization and original vs reconstructed signal comparison

with plt.xkcd():
    fig, axs = plt.subplots(2,2, figsize = (30,20))
    axs = axs.flatten()

    sns.heatmap(np.vstack(set_sigs), ax = axs[0])
    axs[0].set_yticks(np.arange(0,len(freqs)-.5, 4)+ 0.5)
    axs[0].set_yticklabels(['$f = %.4f$'%freqs[int(j)] for j in np.arange(0,len(freqs)-0.5, 4)], rotation = 0)
    add_labels(axs[0], xlabel = 'Time', ylabel= 'Basis Elements', title = 'Frequency Basis Elements')

    axs[1].plot(sig, label = 'Original', lw = 4)
    axs[1].plot(reg.predict(np.vstack(set_sigs).T), lw = 4, label = 'Reconstruction')
    remove_edges(axs[1])
    axs[1].legend()
    add_labels(axs[1], xlabel = 'Time', ylabel= 'Signal', title = 'Reconstruction')
    axs[1].set_xlim(left = 0)

    axs[2].stem(freqs, reg.coef_)
    remove_edges(axs[2])
    add_labels(axs[2], xlabel = 'Frequencies', ylabel= 'Frequency weight', title = 'Frequency Components Contributions')
    axs[2].set_xlim(left = 0)

    num_colors = np.sum(reg.coef_ != 0)
    cmap_name = 'winter'
    samples = np.linspace(0, 1, num_colors)
    colors = plt.colormaps[cmap_name](samples)
    colors_expand = np.zeros((len(reg.coef_), 4))
    colors_expand[reg.coef_!= 0] = np.vstack(colors)

    axs[-1].plot(sig, label = 'Original', color = 'black')
    [axs[-1].plot(reg.coef_[j]*set_sigs[j] + reg.intercept_, label = '$f = %.4f$'%f, lw = 4, color = colors_expand[j]) for j, f in enumerate(freqs) if  reg.coef_[j] != 0]
    remove_edges(axs[-1])
    axs[-1].legend(ncol = 4)
    axs[-1].set_xlim(left = 0)
    add_labels(axs[-1], xlabel = 'Time', ylabel= 'Signal', title = 'Contribution')

Now we will explore the effect of increasing the cardinality, i.e., the number of non-zero elements in the coefficients.
Below, please run OMP with an increasing number of non-zero coefficients, from 1 to 96 in steps of 5.
We will then compare the accuracy and reconstruction performance for each cardinality.

In [None]:
# @markdown OMP for different cardinality values

# define a list or numpy array of candidate cardinalities from 1 to 96 in steps of 5.
cardinalities = np.arange(1,101,5)

# for each candidate cardinality, run OMP using the pixel's signal from before. Create a list called "regs" that includes all of the fitted OMP objects
regs = [OrthogonalMatchingPursuit(fit_intercept = True, n_nonzero_coefs = card).fit(np.vstack(set_sigs).T, sig) for card in cardinalities]

Now let's observe the effect of the cardinality on the reconstruction.

In [None]:
# @markdown Cardinality impact on reconstruction quality

with plt.xkcd():
    fig, axs = plt.subplots(1, 2, figsize = (20,5))
    ax = axs[0]
    ax.plot(cardinalities, [reg.score(np.vstack(set_sigs).T, sig) for reg in regs], marker = '.')

    ax.set_xlim(left = 0)
    ax.set_ylim(bottom = 0)
    add_labels(ax, ylabel = 'coefficient of determination', xlabel  = 'Cardinality')
    remove_edges(ax)
    mean_er = np.vstack([np.mean((reg.predict(np.vstack(set_sigs).T) - sig)**2)**0.5 for reg in regs])

    axs[1].plot(cardinalities, mean_er)
    axs[1].set_xlim(left = 0)
    axs[1].set_ylim(bottom = 0)
    remove_edges(axs[1])
    add_labels(axs[1], ylabel = 'rMSE', xlabel  = 'Cardinality')
    plt.show()

What is your conclusion? How do you think the above will affect generalization? How would you define the best cardinality?

In [None]:
# @markdown Assigned weights plot for each frequency for different cardinality values

with plt.xkcd():
    fig, axs = plt.subplots(5, int(len(cardinalities)/5), figsize = (40,20),sharey = False, sharex = True)
    axs = axs.flatten()
    for j, (ax,reg) in enumerate(zip(axs,regs)):
      ax.stem(freqs, reg.coef_)
      remove_edges(ax)
      add_labels(ax, xlabel = 'Frequencies', ylabel= 'Frequency weight', title = 'Cardinality %d'%cardinalities[j])
      ax.set_xlim(left = 0)
    fig.tight_layout()
    plt.show()

---
# Section 6: Hidden features as dictionary learning

*Estimated timing to here from start of tutorial: 1 hour 10 minutes.*

## Challenges with OMP

While OMP is a common and simple method for finding a sparse representation over a dictionary of features, it presents multiple challenges.

First, the $\ell_0$ penalty is not convex, and OMP selects features greedily, so it is not guaranteed to identify the best sparse vector. In other words, inferring the sparse components in a different, potentially non-iterative way may yield better results. **What happens if we explore all possible k-sparse representations? Will we obtain the same result as with OMP?**
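One way to probe this question on a small problem is to enumerate every support of size $k$, solve least squares on each, and compare the best one with the OMP fit. The sketch below is only an illustration under assumed toy settings (random `D_toy` and `y_toy`, 12 atoms, $k = 3$); the helper `best_k_sparse` is hypothetical. The number of supports grows combinatorially, so this is only feasible for tiny dictionaries.

```python
from itertools import combinations

import numpy as np
from sklearn.linear_model import OrthogonalMatchingPursuit

rng = np.random.default_rng(1)
n_samples, n_atoms, k = 30, 12, 3
D_toy = rng.standard_normal((n_samples, n_atoms))   # columns are atoms
y_toy = rng.standard_normal(n_samples)              # arbitrary target signal

def best_k_sparse(D, y, k):
    """Exhaustively search all supports of size k (feasible only for small D)."""
    best_err, best_support = np.inf, None
    for support in combinations(range(D.shape[1]), k):
        cols = list(support)
        coef, *_ = np.linalg.lstsq(D[:, cols], y, rcond=None)
        err = np.sum((y - D[:, cols] @ coef) ** 2)
        if err < best_err:
            best_err, best_support = err, support
    return best_support, best_err

support_exh, err_exh = best_k_sparse(D_toy, y_toy, k)

# Greedy OMP with the same cardinality and no intercept, for a fair comparison
omp = OrthogonalMatchingPursuit(n_nonzero_coefs=k, fit_intercept=False).fit(D_toy, y_toy)
err_omp = np.sum((y_toy - omp.predict(D_toy)) ** 2)

print("exhaustive support:", support_exh, "squared error:", err_exh)
print("OMP support:       ", tuple(np.flatnonzero(omp.coef_)), "squared error:", err_omp)
```

The exhaustive error can never exceed the OMP error, because the support chosen by OMP is one of the candidates the search visits. Whether the two supports coincide depends on the data.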

Another challenge is computational cost. At every iteration, OMP must scan all features, update the support, and re-solve a least-squares problem. All of these steps can be time-consuming.

Several approaches address these issues. Some of them replace the $\ell_0$ norm with its $\ell_1$ relaxation, which also promotes sparsity but is convex and hence easier to work with.

For example, solving the following optimization problem:

$$
\hat{h} = \arg \min_h \|y - Dh \|_2^2 + \lambda \|h\|_1
$$
where we replaced the $\ell_0$ regularization on $h$ with $\ell_1$.
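As a rough illustration of the $\ell_1$ formulation, the sketch below fits sklearn's `Lasso` on an assumed toy problem (`D_toy`, `h_true`, and `y_toy` are made up for this example). Note that sklearn scales the squared error by $1/(2n)$, so its `alpha` plays the role of $\lambda$ only up to a constant factor.

```python
import numpy as np
from sklearn.linear_model import Lasso

rng = np.random.default_rng(2)
D_toy = rng.standard_normal((60, 25))            # columns are atoms
h_true = np.zeros(25)
h_true[[2, 9, 17]] = [1.5, -2.0, 1.0]            # 3 active atoms
y_toy = D_toy @ h_true + 0.05 * rng.standard_normal(60)

# Smallest alpha for which the Lasso solution is exactly zero
alpha_max = np.max(np.abs(D_toy.T @ y_toy)) / len(y_toy)

n_nonzero = []
for alpha in [0.01 * alpha_max, 0.1 * alpha_max, 1.1 * alpha_max]:
    lasso = Lasso(alpha=alpha, fit_intercept=False, max_iter=10_000).fit(D_toy, y_toy)
    n_nonzero.append(int(np.count_nonzero(lasso.coef_)))
    print(f"alpha = {alpha:.4f} -> {n_nonzero[-1]} non-zero coefficients")
```

Unlike OMP, the cardinality is not set directly here. It is controlled indirectly through the regularization strength, and a strong enough penalty drives every coefficient to zero.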

## What should we do if we do not know the dictionary of features?

In OMP and similar methods, we've operated under the assumption that we already know the features and are solely focused on identifying which sparse combination of them can best describe the data. However, in real-world scenarios, we frequently lack knowledge about the fundamental components underlying the data. In other words, we don't know the features and aim to identify them as part of the learning process.

**Dictionary Learning:**
To address this problem, Dictionary Learning aims to simultaneously learn both the dictionary of features and the sparse representation of the data. It iteratively updates the dictionary and the sparse codes until convergence, typically minimizing a reconstruction error term along with a sparsity-inducing regularizer.
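The alternating scheme can be sketched as follows. This is a simplified illustration in the style of the method of optimal directions (a least-squares dictionary update), not the exact algorithm used by sklearn's `DictionaryLearning`. All data and variable names (`Y`, `D_true`, `D_est`, `H`) are assumptions for the example, and atoms are stored as rows to match sklearn's `components_` convention, so $Y \approx H D$.

```python
import numpy as np
from sklearn.decomposition import sparse_encode

rng = np.random.default_rng(3)
n_samples, n_features, n_atoms = 200, 16, 5

# Synthetic data: sparse codes times a ground-truth dictionary, plus small noise
D_true = rng.standard_normal((n_atoms, n_features))
D_true /= np.linalg.norm(D_true, axis=1, keepdims=True)
mask = rng.random((n_samples, n_atoms)) < 0.3
codes_true = rng.standard_normal((n_samples, n_atoms)) * mask
Y = codes_true @ D_true + 0.01 * rng.standard_normal((n_samples, n_features))

# Initialize the dictionary with randomly chosen, normalized data samples
D_est = Y[rng.choice(n_samples, n_atoms, replace=False)].copy()
D_est /= np.linalg.norm(D_est, axis=1, keepdims=True)

errors = []
for _ in range(20):
    # Sparse coding step: l1-penalized codes for the current dictionary
    H = sparse_encode(Y, D_est, algorithm='lasso_lars', alpha=0.05)
    errors.append(np.mean((Y - H @ D_est) ** 2))
    # Dictionary step: least-squares update given the codes, then renormalize
    D_est, *_ = np.linalg.lstsq(H, Y, rcond=None)
    norms = np.linalg.norm(D_est, axis=1, keepdims=True)
    norms[norms == 0] = 1.0   # guard against atoms that were never used
    D_est /= norms

print("reconstruction MSE, first vs last iteration:", errors[0], errors[-1])
```

Renormalizing the atoms after each update removes the scale ambiguity between the dictionary and the codes. Without it, the $\ell_1$ penalty could be made arbitrarily small by inflating the atoms.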

**How is it related to the brain?**
Dictionary Learning draws parallels to the visual system's ability to extract features from raw sensory input. In the brain, the visual system processes raw visual stimuli through hierarchical layers, extracting increasingly complex features at each stage. Similarly, Dictionary Learning aims to decompose complex data into simpler components, akin to how the visual system breaks down images into basic visual elements like edges and textures.

Now, let's apply Dictionary Learning to various types of data to unveil the latent components underlying it.

In [None]:
# @title Video 6: Hidden atoms as dictionary learning

from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents


video_ids = [('Youtube', 'TGgjedW502E')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)

## Coding Exercise 6

In this exercise, we first load a new short video created from MNIST, as shown below. We then use sklearn's `DictionaryLearning` to find the dictionary's ($D$) features.

In [None]:
###################################################################
## Fill out the following then remove.
raise NotImplementedError("Student exercise: complete calculation of `X_transformed`.")
###################################################################

# Video_file is a 3D array representing pixels X pixels X time
video_file = np.load('reweight_digits.npy')

# Create a copy of the video_file array
im_focus = video_file.copy()

# Get the number of frames in the video
T = im_focus.shape[2]

# Get the number of rows in the video
N0 = im_focus.shape[0]

# Get the number of columns in the video, leaving out 10 columns
N1 = im_focus.shape[1] - 10

# Create a copy of the extracted frames
low_res = im_focus.copy()

# Get the shape of a single frame
shape_frame = low_res[:, :, 0].shape

# Flatten each frame and store them in a list
video_file_ar = [low_res[:, :, frame].flatten() for frame in range(low_res.shape[2])]

# Create dict_learner object
dict_learner = DictionaryLearning(
    n_components=15, transform_algorithm='lasso_lars', transform_alpha=0.9,
    random_state=402,
)

# List to np.array
video_v = np.vstack(video_file_ar)

# Fit and transform `video_v`
X_transformed = dict_learner.fit(...).transform(...)

In [None]:
#to_remove solution
# Video_file is a 3D array representing pixels X pixels X time
video_file = np.load('reweight_digits.npy')

# Create a copy of the video_file array
im_focus = video_file.copy()

# Get the number of frames in the video
T = im_focus.shape[2]

# Get the number of rows in the video
N0 = im_focus.shape[0]

# Get the number of columns in the video, leaving out 10 columns
N1 = im_focus.shape[1] - 10

# Create a copy of the extracted frames
low_res = im_focus.copy()

# Get the shape of a single frame
shape_frame = low_res[:, :, 0].shape

# Flatten each frame and store them in a list
video_file_ar = [low_res[:, :, frame].flatten() for frame in range(low_res.shape[2])]

# Create dict_learner object
dict_learner = DictionaryLearning(
    n_components=15, transform_algorithm='lasso_lars', transform_alpha=0.9,
    random_state=402,
)

# List to np.array
video_v = np.vstack(video_file_ar)

# Fit and transform `video_v`
X_transformed = dict_learner.fit(video_v).transform(video_v)

Let's visualize the identified features!

In [None]:
# @markdown Visualization of atoms

with plt.xkcd():
    num_thetas = 20
    cmap_name = 'plasma'
    samples = np.linspace(0, 1, len(dict_learner.components_)+1)
    colors = plt.colormaps[cmap_name](samples)

    num_comps = len(dict_learner.components_)

    fig, axs = plt.subplots(2, int((1+num_comps)/2) , figsize = (16,4), sharey = True, sharex = True)
    axs = axs.flatten()
    [sns.heatmap(dict_learner.components_[j].reshape(shape_frame), square = True, cmap ='gray', robust = True, ax = axs[j]) for j in range(num_comps)]
    titles = ['Atom %d'%j for j in range(1, num_comps +1)]
    [ax.set_title(title, color = colors[j], fontsize = 10)
    for j, (title, ax) in enumerate(zip(titles, axs))]
    axs[-1].set_visible(False)

In [None]:
# @markdown Visualization of the impact of atoms over time

x = np.linalg.pinv(dict_learner.components_.T) @ video_v.T
X_transformed_norm = x.T.copy()
with plt.xkcd():
    fig, ax = plt.subplots(figsize = (16,3))
    [ax.plot(X_transformed_norm[:,j], color = colors[j], lw = 5, label = 'Atom %d'%(j+1))
     for j in range(num_comps)]
    remove_edges(ax)
    ax.set_xlim(left = 0)
    [ax.set_yticks([]) for ax in axs]
    [ax.set_xticks([]) for ax in axs]

    add_labels(ax, xlabel = 'Time', ylabel = 'Coefficients', title = 'Elements Coefficients')
    create_legend({'Atom %d'%(j+1): colors[j] for j in range(num_comps)},
                  params_leg = {'ncol': 3})

Now, let's compare the components to PCA.

In [None]:
# @markdown PCA components visualization

data_mat = np.vstack(video_file_ar)
# Assuming your data is stored in a variable named 'data_mat' (N by p numpy array)
# where N is the number of samples and p is the number of features

# Create a PCA object with 15 components (matching the dictionary size above)
pca = PCA(n_components=15)

# Fit the PCA model to the data
pca.fit(data_mat)

# Access the 15 PCA components
pca_components = pca.components_

pca_componens_images = [pca_components[comp,:].reshape(shape_frame) for comp in range(num_comps)]

with plt.xkcd():
    fig, axs = plt.subplots(2, int((1+num_comps)/2) , figsize = (16,4), sharey = True, sharex = True)
    axs = axs.flatten()
    [sns.heatmap(im, ax = axs[j], square = True, cmap ='gray', robust = True) for j,im in enumerate(pca_componens_images)]
    titles = ['PC %d'%j for j in range(1, num_comps +1)]
    [ax.set_title(title, color = colors[j], fontsize = 10) for j, (title, ax) in enumerate(zip(titles, axs))]
    axs[-1].set_visible(False)

---
# Section 7: Identifying Sparse Visual Fields in Natural Images Using Sparse Coding

*Estimated timing to here from start of tutorial: 1 hour 25 minutes.*

In this section we will understand the basis of visual field perception.

In [None]:
# @title Video 7: Sparse Coding in Natural Images
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents


video_ids = [('Youtube', 'TGgjedW502E')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)

We will use the CIFAR-10 dataset. First, let's load the data.

In [None]:
# @markdown Load CIFAR-10 dataset

import contextlib
import io

with contextlib.redirect_stdout(io.StringIO()): #to suppress output
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

Now, let's create a short video by cycling through these images. Afterwards, we will detect the underlying visual fields.

In [None]:
# @markdown Create & explore video

num_image_show = 50
img_index = np.random.randint(0, len(x_train), num_image_show)
x_train_ar = np.copy(x_train)[img_index]
video_gray = 0.2989 * x_train_ar[:, :, :,0] + 0.5870 * x_train_ar[:, :, :,1] + 0.1140 * x_train_ar[:, :, :,2]

def show_slice_CIFAR(slice_index):
    with plt.xkcd():
        plt.figure(figsize=(2, 2))
        plt.imshow(data_CIFAR[slice_index] , cmap='gray')
        plt.axis('off')
        plt.show()

data_CIFAR = np.repeat(video_gray, 1, axis = 0)

# Create a slider to select the slice
slider = IntSlider(min=0, max=data_CIFAR.shape[0] - 1, step=1, value=0, description='Frame')

# Connect the slider and the function
interact(show_slice_CIFAR, slice_index=slider)

Now, let's try to extract the visual field from the images. Look at the components identified below, and compare the ones identified by PCA to those identified by Olshausen & Field via sparse dictionary learning.

![Components comparison.](https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D5_Microcircuits/static/components.png?raw=true)

In particular, following the Olshausen & Field paper, we will demonstrate how visual field components emerge from real-world images.

We expect to find elements that are:
1.   Spatially localized.
2.   Oriented.
3.   Selective to structure at different spatial scales.

We would like to understand the receptive field properties of neurons in the visual cortex by using sparsity to identify the fundamental components underlying natural images.

We will use the CIFAR data we loaded before and reshape it such that the time (frame) axis is the last one.

In [None]:
video_file = data_CIFAR.transpose((1,2,0))

For the exercise, we will focus on a small subset of the data, just so you can understand the idea. We will later load better results from a model trained for longer.

We will first set some parameters to take a small subset of the video. In the cell below, you can explore the parameters, which are set to default values.

In [None]:
# @markdown Set default parameters

size_check =  4 # 32 for the full size

# how many images we want to consider? (choose larger numbers for better results)
num_frames_include =  80

# set number of dictionary elements
dict_comp = 8

# set number of model iterations
max_iter = 100

num_frames_include = np.min([num_frames_include , x_train.shape[0]])

Let's visualize the small subset we are taking for the training.

In [None]:
# @markdown Images visualization

# let's choose some random images to focus on
img_index = np.random.randint(0, len(x_train), num_frames_include)
x_train_ar = np.copy(x_train)[img_index]

# change to grayscale
video_gray = 0.2989 * x_train_ar[:, :, :,0] + 0.5870 * x_train_ar[:, :, :,1] + 0.1140 * x_train_ar[:, :, :,2]

# update data and create short movie
data_CIFAR = np.repeat(video_gray, 1, axis = 0)

# Create a slider to select the slice
slider = IntSlider(min=0, max=data_CIFAR.shape[0] - 1, step=1, value=0, description='Frame')

# Connect the slider and the function
interact(show_slice_CIFAR, slice_index=slider)

Find the latent components by training the dictionary learner, and then visualize the results. Training takes around 1 minute.

In [None]:
# @markdown Coefficient values through time

# Get the number of frames in the video
T = video_file.shape[2]

# Get the number of rows in the video
N0 = video_file.shape[0]

# Get the number of columns in the video, leaving out 10 columns
N1 = video_file.shape[1] - 10

# Get the shape of a single frame
shape_frame = video_file[:, :, 0].shape

# Make sure the requested subset does not exceed the available data
num_frames_include = np.min([num_frames_include, video_file.shape[2]])

size_check = np.min([size_check, video_file.shape[0]])

# Flatten the top-left patch of each frame and store the patches in a list
video_file_ar = [video_file[:size_check, :size_check, frame].flatten() for frame in range(num_frames_include)]

alpha = 0.01

dict_learner = DictionaryLearning(
    n_components=dict_comp, transform_algorithm = 'lasso_lars', transform_alpha=alpha,
    random_state=42,
)

X_transformed = dict_learner.fit_transform(np.vstack(video_file_ar))

with plt.xkcd():
    num_rows = 3
    num_cols =  int(( num_frames_include + 2)/num_rows)
    shape_frame = (size_check,size_check)
    
    cmap_name = 'plasma'
    samples = np.linspace(0, 1, len(dict_learner.components_)+1)
    colors = plt.colormaps[cmap_name](samples)
    num_comps = len(dict_learner.components_)
    
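    vmin =  np.min(np.vstack(dict_learner.components_))
    vmax = np.max(np.vstack(dict_learner.components_))

    # create a new grid of axes for the atoms instead of reusing axes from an earlier cell
    fig, axs = plt.subplots(2, int(np.ceil(num_comps / 2)), figsize = (20, 10))
    axs = axs.flatten()
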
    [sns.heatmap(dict_learner.components_[j].reshape(shape_frame), ax = axs[j], square = True, vmin = vmin, vmax = vmax,
                  cmap = 'PiYG', center = 0, cbar = False)  for j in range(num_comps)]
    
    [ax.set_xticks([]) for ax in axs]
    [ax.set_yticks([]) for ax in axs]
    
    titles = ['Atom %d'%j for j in range(1, num_comps +1)];
    [ax.set_title(title,  fontsize = 40) for j, (title, ax) in enumerate(zip(titles, axs))]
    
    
    fig, ax = plt.subplots(figsize = (20,5))
    [ax.plot(X_transformed[:,j],  lw = 5, label = 'Atom %d'%(j+1))
      for j in range(num_comps)]
    remove_edges(ax)
    ax.set_xlim(left = 0)
    [ax.set_yticks([]) for ax in axs]
    [ax.set_xticks([]) for ax in axs]
    
    add_labels(ax, xlabel = 'Time', ylabel = 'Coefficients', title = 'Elements Coefficients')    
    fig.tight_layout()

To limit computation time, we focused on a small area of the images and considered only a limited number of them, with a restricted number of iterations. As a result, the quality of the learned components is poor. For your exploration, we provide better results obtained by training on a larger number of images.
This is important for discovering components that better resemble the biological visual field.

In [None]:
# @markdown Load full model & visualize atoms
res = np.load('model.npy', allow_pickle = True).item()
comps = res['dict_learner_components_']

with plt.xkcd():
    shape_frame = (32,32)

    cmap_name = 'plasma'
    samples = np.linspace(0, 1, len(comps)+1)
    colors = plt.colormaps[cmap_name](samples)
    num_comps = len(comps)
    
    vmin =  np.min(np.vstack(comps))
    vmax = np.max(np.vstack(comps))
    
    fig, axs = plt.subplots(2, int(np.ceil((num_comps+1)/2)) , figsize = (50,15), sharey = True, sharex = True)
    axs = axs.flatten()
    
    
    [axs[j].imshow(comps[j].reshape(shape_frame), vmin = 0)  for j in range(num_comps)]
    
    [ax.set_xticks([]) for ax in axs]
    [ax.set_yticks([]) for ax in axs]
    
    titles = ['Atom %d'%j for j in range(1, num_comps +1)];
    [ax.set_title(title,  fontsize = 40, color = colors[j]) for j, (title, ax) in enumerate(zip(titles, axs))]
    axs[-1].set_visible(False)
    axs[-2].set_visible(False)

In [None]:
# @markdown Activity of the components

X_transformed = res['X_transformed']

fig, ax = plt.subplots(figsize = (20,5))
[ax.plot(X_transformed[:,j], color = colors[j], lw = 5, label = 'Atom %d'%(j+1))
  for j in range(num_comps)]
remove_edges(ax)
ax.set_xlim(left = 0)
[ax.set_yticks([]) for ax in axs]
[ax.set_xticks([]) for ax in axs]

add_labels(ax, xlabel = 'Time', ylabel = 'Coefficients', title = 'Elements Coefficients')
create_legend({'Atom %d'%(j+1): colors[j] for j in range(num_comps)}, params_leg = {'ncol': 5})

---
# Summary

Estimated timing of tutorial: 1 hour 40 minutes.