In [None]:
!python -m pip install --upgrade pip
!pip install mindscope_utilities --upgrade
!pip install --upgrade scikit-learn
!pip install dill
!pip install pickle

Collecting pip
  Downloading pip-21.2.4-py3-none-any.whl (1.6 MB)
[K     |████████████████████████████████| 1.6 MB 6.8 MB/s 
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.1.3
    Uninstalling pip-21.1.3:
      Successfully uninstalled pip-21.1.3
Successfully installed pip-21.2.4
Collecting mindscope_utilities
  Downloading mindscope_utilities-0.1.8.tar.gz (11 kB)
Collecting flake8
  Downloading flake8-3.9.2-py2.py3-none-any.whl (73 kB)
[K     |████████████████████████████████| 73 kB 1.7 MB/s 
Collecting allensdk
  Downloading allensdk-2.12.2-py3-none-any.whl (1.7 MB)
[K     |████████████████████████████████| 1.7 MB 15.3 MB/s 
Collecting simpleitk<3.0.0,>=2.0.2
  Downloading SimpleITK-2.1.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (48.4 MB)
[K     |████████████████████████████████| 48.4 MB 35 kB/s 
[?25hCollecting psycopg2-binary<3.0.0,>=2.7
  Downloading psycopg2_binary-2.9.1-cp37-cp37m-manylinux_2_17_

Collecting scikit-learn
  Downloading scikit_learn-0.24.2-cp37-cp37m-manylinux2010_x86_64.whl (22.3 MB)
[K     |████████████████████████████████| 22.3 MB 1.4 MB/s 
Collecting threadpoolctl>=2.0.0
  Downloading threadpoolctl-2.2.0-py3-none-any.whl (12 kB)
Installing collected packages: threadpoolctl, scikit-learn
  Attempting uninstall: scikit-learn
    Found existing installation: scikit-learn 0.22.2.post1
    Uninstalling scikit-learn-0.22.2.post1:
      Successfully uninstalled scikit-learn-0.22.2.post1
Successfully installed scikit-learn-0.24.2 threadpoolctl-2.2.0
[31mERROR: Could not find a version that satisfies the requirement pickle (from versions: none)[0m
[31mERROR: No matching distribution found for pickle[0m


In [None]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import json

from tqdm.notebook import tqdm_notebook
from scipy.stats import zscore
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache
from allensdk.brain_observatory.ecephys.visualization import plot_mean_waveforms, plot_spike_counts, raster_plot
from allensdk.brain_observatory.visualization import plot_running_speed

## Load the session and experiment summary tables

The AllenSDK provides functionality for downloading tables that describe all session types ('brain_observatory_1.1', 'functional_connectivity') in the Visual Coding – Neuropixels dataset. We first have to download the data cache:


- Brain Observatory 1.1 and Functional Connectivity sessions correspond to different stimulus sets

In [None]:
# this path determines where downloaded data will be stored
# NOTE(review): this is an absolute, machine-specific directory; adjust it
# before running the notebook elsewhere.
manifest_path = os.path.join('/local1/ecephys_cache_dir/', "manifest.json")
#manifest_path = os.path.join('/temp', "manifest.json")

# Build the project cache object around that manifest file.
cache = EcephysProjectCache.from_warehouse(manifest=manifest_path)

# Show the session types known to the cache.
print(cache.get_all_session_types())

['brain_observatory_1.1', 'functional_connectivity']


Then we can access the session table directly

In [None]:
# Table of all sessions; its index values are used below as session ids.
sessions = cache.get_session_table()
# Session ids split by the value of the "session_type" column.
brain_observatory_ids = list(sessions[sessions["session_type"] == "brain_observatory_1.1"].index.unique())
func_connectivity_ids = list(sessions[sessions["session_type"] == "functional_connectivity"].index.unique())

### Which sessions have all the areas in the visual cortex `VISp` `VISl` `VISpm` `VISam` `VISal` `VISrl`

In [None]:
# Visual-cortex structure acronyms that a session must contain.
vis_cortex_areas = ["VISp", "VISl", "VISpm", "VISam", "VISal", "VISrl"]

# True for every session whose structure list contains all acronyms above.
mask = sessions.ecephys_structure_acronyms.apply(lambda x: all(elem in x for elem in vis_cortex_areas))
# Ids (index values) of the sessions that pass the filter.
vis_cortex_ids = sessions[mask].index.values
vis_cortex_ids

array([719161530, 750332458, 750749662, 754312389, 755434585, 756029989,
       778240327, 778998620, 791319847, 794812542, 797828357, 831882777,
       847657808])

#### Make a `get_session` function

In [None]:
def get_session(session_id):
  """Return whatever `cache.get_session_data` yields for `session_id`.

  Uses the notebook-level `cache` object; the result is passed through
  unchanged.
  """
  return cache.get_session_data(session_id)

#### Make a `get_brain_regions` function

In [None]:
def get_brain_regions(session_id):
  """List the structure acronyms recorded for one session.

  Looks up the "ecephys_structure_acronyms" entry of `session_id` in the
  notebook-level `sessions` table and converts it with `.tolist()`.
  """
  acronyms = sessions.loc[session_id, "ecephys_structure_acronyms"]
  return acronyms.tolist()

In [None]:
# Structure acronyms of the session at position 1 of `vis_cortex_ids`.
brain_regions = get_brain_regions(vis_cortex_ids[1])

## Make `get_spikes` function to load spike data into memory


The function below will load the spikes data for the specified cells and stimulus type as well as the corresponding presentations into memory.

We will extract spikes using `EcephysSession.presentationwise_spike_times`, which returns spikes annotated by the `units` (`neurons`) that emitted them and the stimulus presentation during which they were emitted.

In [None]:
def get_spikes(session, neuron_types, stimulus_types, stimulus_table):
  """Collect presentation-wise spike times per stimulus type and brain region.

  Args:
    session: object exposing a `units` table (with an
      "ecephys_structure_acronym" column) and a
      `presentationwise_spike_times(stimulus_presentation_ids=..., unit_ids=...)`
      method.
    neuron_types: structure acronym(s) to select units by. Either a single
      string or an iterable of strings.
    stimulus_types: stimulus name(s) to select presentations by. Either a
      single string or an iterable of strings.
    stimulus_table: table with a "stimulus_name" column, indexed by
      stimulus presentation id.

  Returns:
    dict: `spikes[stimulus_type][neuron_type]` holds the value returned by
    `session.presentationwise_spike_times` for that combination.
  """
  # A bare string would otherwise be iterated character by character
  # ("VISpm" -> "V", "I", "S", ...), silently selecting no units.
  if isinstance(neuron_types, str):
    neuron_types = [neuron_types]
  if isinstance(stimulus_types, str):
    stimulus_types = [stimulus_types]

  units = session.units
  spikes = {}

  for stimulus_type in stimulus_types:
    scene_presentations = stimulus_table[stimulus_table["stimulus_name"] == stimulus_type]
    presentation_ids = scene_presentations.index.values

    spikes[stimulus_type] = {}

    for neuron_type in neuron_types:
      neuron_units = units[units["ecephys_structure_acronym"] == neuron_type]

      spikes[stimulus_type][neuron_type] = session.presentationwise_spike_times(
          stimulus_presentation_ids=presentation_ids,
          unit_ids=neuron_units.index.values)

  return spikes

## Build Design Matrix

In [None]:
def _spike_count_matrix(spike_times):
  """Pivot a spike table into a presentations x units matrix of spike counts.

  `spike_times` needs "stimulus_presentation_id" and "unit_id" columns, one
  row per spike. Combinations without any spike are filled with 0; note
  that presentations with no spike from any unit do not appear as rows.
  """
  counts = (spike_times
            .groupby(["stimulus_presentation_id", "unit_id"])
            .size()
            .rename("count")
            .reset_index())
  return pd.pivot_table(
      counts,
      values="count",
      index="stimulus_presentation_id",
      columns="unit_id",
      fill_value=0.0,
      aggfunc="sum",
  )


def build_matrix(session, stimulus_types, neuron_types, stimulus_table, features):
  """Build the design matrices for the requested features.

  Args:
    session: session object, forwarded to `get_spikes`.
    stimulus_types: stimulus name(s); a single string or an iterable.
    neuron_types: structure acronym(s); a single string or an iterable.
    stimulus_table: stimulus presentation table, forwarded to `get_spikes`.
    features: iterable of feature names. Only "spike count" is implemented;
      any other name is skipped and gets no entry in the result.

  Returns:
    dict: `design_matrix["spike count"][stimulus_type][neuron_type]` is a
    DataFrame of spike counts with one row per stimulus presentation id and
    one column per unit id.
  """
  # Normalise here as well, so a bare string such as "VISpm" is never
  # iterated character by character further down.
  if isinstance(neuron_types, str):
    neuron_types = [neuron_types]
  if isinstance(stimulus_types, str):
    stimulus_types = [stimulus_types]

  design_matrix = {}

  for feature in features:
    if feature != "spike count":
      continue

    spike_data = get_spikes(session, neuron_types, stimulus_types, stimulus_table)

    for stimulus_type in tqdm_notebook(spike_data.keys()):
      spike_data[stimulus_type] = {
          neuron_type: _spike_count_matrix(spike_times)
          for neuron_type, spike_times in spike_data[stimulus_type].items()
      }

    design_matrix[feature] = spike_data

  return design_matrix
    



In [None]:
def get_labels(design_matrix, stimulus_types, stimulus_table=None):
  """Look up the stimulus label of every row of the spike-count matrices.

  Args:
    design_matrix: result of `build_matrix`; only the "spike count" entry
      is read.
    stimulus_types: stimulus names to produce labels for. "natural_scenes"
      is labelled with the "frame" column and "static_gratings" with the
      "orientation" column; other names get no entry.
    stimulus_table: stimulus presentation table (indexed by presentation
      id, with a "stimulus_name" column) to read the labels from. When
      omitted, the notebook-level `stimulus_table` variable is used, which
      is what this function always did before the argument existed.

  Returns:
    dict: one single-column DataFrame per handled stimulus type, indexed by
    the presentation ids of the corresponding design matrix.

  Raises:
    NameError: if no `stimulus_table` is passed and none exists at
      notebook level.
  """
  # Column of the stimulus table that holds the label of each stimulus type.
  label_columns = {"natural_scenes": "frame", "static_gratings": "orientation"}

  if stimulus_table is None:
    stimulus_table = globals().get("stimulus_table")

  spike_counts = design_matrix["spike count"]
  labels = {}
  first_region = None

  for stimulus_type in tqdm_notebook(stimulus_types):
    column = label_columns.get(stimulus_type)
    if column is None:
      continue
    if stimulus_table is None:
      raise NameError(
          "name 'stimulus_table' is not defined; pass it via the "
          "`stimulus_table` argument")

    # Presentation ids are taken from the first region of the first stimulus
    # type, as before; only the row index of that matrix is used.
    if first_region is None:
      first_region = next(iter(spike_counts[stimulus_types[0]]))
    design_matr = spike_counts[stimulus_type][first_region]

    presentations = stimulus_table[stimulus_table["stimulus_name"] == stimulus_type]
    labels[stimulus_type] = presentations.loc[design_matr.index.values, [column]]

  return labels


In [None]:
def _standardize_columns(design_matrix):
  """Z-score every column (zero mean, unit population standard deviation).

  Constant columns are only centred, so they become all zeros instead of
  NaN. A DataFrame input keeps its index and column labels.
  """
  values = np.asarray(design_matrix, dtype=float)
  spread = values.std(axis=0)
  spread[spread == 0] = 1.0
  scaled = (values - values.mean(axis=0)) / spread
  if isinstance(design_matrix, pd.DataFrame):
    return pd.DataFrame(scaled, index=design_matrix.index,
                        columns=design_matrix.columns)
  return scaled


def get_pca(design_matrix, threshold, normalize=False):
  """Project a design matrix onto its leading principal components.

  Args:
    design_matrix: 2-D samples x features data (array or DataFrame).
    threshold: cumulative explained-variance ratio; the components whose
      cumulative ratio is <= `threshold` are kept, with a minimum of one.
    normalize: if True, z-score each column before fitting.
      NOTE(review): the original code called `normalize(design_matrix)`,
      which invoked this boolean argument and raised TypeError, so the
      intended normaliser is unknown; column z-scoring is an assumption.

  Returns:
    tuple: (fitted PCA model, DataFrame of the projected data with a fresh
    integer index and one column per kept component).
  """
  if normalize:
    design_matrix = _standardize_columns(design_matrix)

  # First fit with all components to obtain the variance spectrum.
  spectrum = PCA().fit(design_matrix)
  cumulative = np.cumsum(spectrum.explained_variance_ratio_)
  # Keep at least one component: when the first component alone exceeds
  # `threshold` the count would be 0 and the projection would be empty.
  n_components = max(1, int(np.sum(cumulative <= threshold)))

  pca_model = PCA(n_components=n_components)
  projected = pca_model.fit_transform(design_matrix)

  return pca_model, pd.DataFrame(projected)

In [None]:
def get_images(labels_df):
  """Fetch the natural-scene template of every frame that was presented.

  Reads the distinct values of `labels_df["natural_scenes"]["frame"]`,
  leaves out -1, and asks the notebook-level `cache` for the template of
  each remaining frame.

  Returns:
    dict: frame value -> result of `cache.get_natural_scene_template`.
  """
  distinct_frames = labels_df["natural_scenes"]["frame"].unique()
  # `unique()` yields each value once, so this drops the single -1 entry.
  wanted = [frame for frame in distinct_frames if frame != -1]

  return {
      frame: cache.get_natural_scene_template(frame)
      for frame in tqdm_notebook(wanted)
  }

In [None]:
features = ["spike count"]
stimulus_types = ["natural_scenes", "static_gratings"]
# A list rather than a bare string, so the acronym is treated as one region
# instead of being iterated character by character.
neuron_types = ["VISpm"]

# `session` and `stimulus_table` were not defined by any earlier cell, which
# made this cell fail with NameError on a fresh kernel.
# NOTE(review): the choice of session is an assumption — position 1 of
# `vis_cortex_ids` is the session already used for `get_brain_regions`
# above; confirm this is the intended training session.
session = get_session(vis_cortex_ids[1])
stimulus_table = session.get_stimulus_table()

design_matrix = build_matrix(session, stimulus_types, neuron_types, stimulus_table, features)
labels_df = get_labels(design_matrix, stimulus_types)

NameError: ignored

In [None]:
# Natural-scene templates keyed by frame value.
# NOTE(review): the same call is repeated in a later cell, right before the
# images are written to disk.
images = get_images(labels_df)

NameError: ignored

# Model

In [None]:
# Import libraries
import os
import time
from tqdm.notebook import tqdm_notebook
import torch
import IPython
import torchvision

import numpy as np
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.nn.functional as F

from torchvision import transforms
from torchvision.models import AlexNet
from torchvision.utils import make_grid
from torchvision.datasets import ImageFolder
from PIL import Image
from torchvision.utils import save_image
from torchsummary import summary

from PIL import Image
from io import BytesIO

# Instantiate a ResNet-50 with pretrained weights.
# NOTE(review): `resnet` is not referenced by any later cell in the visible
# notebook — confirm it is still needed.
resnet = torchvision.models.resnet50(pretrained=True)
#AlexNet = torchvision.models.alexnet(pretrained=True)

## Test learned `weights` on another session

### Get usual `spike_count` data. Also make `test_design_matrix` and obtain corresponding `test_labels`

In [None]:
# The session at position 3 of `vis_cortex_ids` serves as the test session.
test_session = get_session(vis_cortex_ids[3])

In [None]:
# Stimulus presentation table of the test session.
test_stimulus_table = test_session.get_stimulus_table()

In [None]:
# NOTE(review): `vis_cortex_area` (singular) is not defined in the visible
# notebook — only the list `vis_cortex_areas` is. Confirm which was meant.
neuron_types = vis_cortex_area
# Spike times of the test session for the selected region(s).
# NOTE(review): `test_spike_data` is not used by any later visible cell;
# `build_matrix` calls `get_spikes` itself.
test_spike_data = get_spikes(test_session, neuron_types, stimulus_types, test_stimulus_table)


In [None]:
# Same feature and stimulus selection as for the first session.
features = ["spike count"]
stimulus_types = ["natural_scenes", "static_gratings"]
# Spike-count matrices of the test session.
test_design_matrix = build_matrix(test_session, stimulus_types, neuron_types, test_stimulus_table, features)


In [None]:
# NOTE(review): called like this, `get_labels` reads the notebook-level
# `stimulus_table`, not `test_stimulus_table` — confirm that the labels of
# the test session are looked up in the right table.
test_labels_df = get_labels(test_design_matrix, stimulus_types)

In [None]:
# Shape of the natural-scenes spike-count matrix of the test session.
# NOTE(review): `visual_cortex_area` is not defined in the visible notebook.
test_design_matrix["spike count"]["natural_scenes"][visual_cortex_area].shape

In [None]:
# NOTE(review): `design_array` and `visual_cortex_area` are not defined in
# the visible notebook.
# Number of PCA components = number of columns of `design_array`.
train_voxel_dims = design_array.shape[1]
pca = PCA(n_components=train_voxel_dims)
# The PCA is fitted on the test-session matrix itself and then applied to it.
pca.fit(test_design_matrix["spike count"]["natural_scenes"][visual_cortex_area])
test_design_array = pd.DataFrame(pca.transform(test_design_matrix["spike count"]["natural_scenes"][visual_cortex_area]))

In [None]:
# Display the PCA-projected test matrix.
test_design_array

### Encoded Feature Vectors

In [None]:
# Load the weights from a CSV file named after the region.
# NOTE(review): `vis_cortex_area` is not defined in the visible notebook.
weights = pd.read_csv(f"weights_{vis_cortex_area}")

In [None]:
# Matrix product of the projected test matrix with the transposed weights.
encoded_feature_vectors = test_design_array.to_numpy() @ weights.T.to_numpy()

In [None]:
# Wrap the array in a DataFrame indexed by `stim_ids`.
# NOTE(review): `stim_ids` is not defined in the visible notebook.
encoded_feature_vectors = pd.DataFrame(encoded_feature_vectors, index=stim_ids)

In [None]:
# Drop the rows whose index labels are listed in `nan_obsIDs`.
# NOTE(review): `nan_obsIDs` is not defined in the visible notebook.
encoded_feature_vectors = encoded_feature_vectors.drop(nan_obsIDs)

In [None]:
# Save the encoded feature vectors to a CSV file named after the region.
encoded_feature_vectors.to_csv(f"encoded_feature_vectors_{vis_cortex_area}")

In [None]:
# Preview the first five rows.
encoded_feature_vectors.head(5)

In [None]:
# (number of rows, number of features) of the encoded feature vectors.
encoded_feature_vectors.shape

# Reconstruct Images

### Create `VAE_decoder` class

In [None]:
class VAE_decoder(nn.Module):
    """
    Decoder network mapping feature vectors to single-channel image logits.

    A stack of fully connected layers expands each feature vector to a
    16 x 61 x 61 feature map, which two upsampling + transposed-convolution
    stages grow to 1 x 256 x 256 (61 -> 122 -> 126 -> 252 -> 256).
    """

    # Layout of the feature map handed from the linear stack to the
    # convolutional stack (16 * 61 * 61 = 59536 linear outputs).
    _CONV_IN_CHANNELS = 16
    _CONV_IN_SIZE = 61

    def __init__(self, feat_size, output_dim=(1, 256, 256)):
        """
        Initializes the VAE decoder network.
        Required args:
        - feat_size (int): size of the input feature vectors
        Optional args:
        - output_dim (tuple or list): expected shape of one decoded image,
            i.e. the decoder output shape without the batch dimension
            (default: (1, 256, 256))
        Raises:
        - ValueError: if the network does not produce `output_dim`. The
            layer sizes are fixed, so only (1, 256, 256) passes.
        """

        super().__init__()
        self.feat_size = feat_size
        self._vae = True
        # Stored as a tuple so that a list argument compares equal to the
        # torch.Size checked in _test_output_dim().
        self.output_dim = tuple(output_dim)

        flat_size = self._CONV_IN_CHANNELS * self._CONV_IN_SIZE ** 2

        self.decoder_linear = nn.Sequential(
            nn.Linear(self.feat_size, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512, affine=False),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024, affine=False),
            nn.Linear(1024, flat_size),
            nn.ReLU(),
        )
        self.decoder_conv = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.BatchNorm2d(self._CONV_IN_CHANNELS, affine=False),
            nn.ConvTranspose2d(
                in_channels=self._CONV_IN_CHANNELS, out_channels=6,
                kernel_size=5, stride=1
                ),
            nn.ReLU(),
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.BatchNorm2d(6, affine=False),
            nn.ConvTranspose2d(
                in_channels=6, out_channels=1, kernel_size=5, stride=1
                ),
        )

        self._test_output_dim()

    @property
    def vae(self):
        """Flag set to True in __init__."""
        return self._vae

    def _test_output_dim(self):
        """Run a dummy input through the decoder and check the output shape.

        The module is switched to eval mode for the check (a batch of one
        cannot pass through batch norm in training mode) and put back into
        training mode afterwards if it was in training mode before.
        """
        dummy_tensor = torch.ones(1, self.feat_size)
        reset_training = self.training
        self.eval()
        with torch.no_grad():
            decoder_output_shape = tuple(self.reconstruct(dummy_tensor).shape[1:])
        if decoder_output_shape != self.output_dim:
            raise ValueError(f"Decoder produces output of shape "
                f"{decoder_output_shape} instead of expected "
                f"{self.output_dim}.")
        if reset_training:
            self.train()

    def decode(self, z):
        """Return image logits for a batch of feature vectors `z`."""
        h3 = self.decoder_linear(z.float())
        h3 = h3.view(
            -1, self._CONV_IN_CHANNELS, self._CONV_IN_SIZE, self._CONV_IN_SIZE
            )
        recon_x_logits = self.decoder_conv(h3)
        return recon_x_logits

    def forward(self, X):
        """Return image logits for `X`; identical to `decode(X)`."""
        return self.decode(X)

    def reconstruct(self, X):
        """Return sigmoid-activated reconstructions, computed without gradients."""
        with torch.no_grad():
            recon_x = torch.sigmoid(self.decode(X))
        return recon_x


#### Specify VAE `loss_function`

In [None]:
def vae_loss_function(recon_X_logits, X, beta=1.0):
    """
    vae_loss_function(recon_X_logits, X)
    Returns the reconstruction loss for the batch: binary cross-entropy
    between the reconstruction logits and the targets, summed over all
    elements. No KL-divergence term is added.
    Required args:
    - recon_X_logits (4D tensor): logits of the X reconstruction
        (batch_size x shape of x)
    - X (4D tensor): X (batch_size x shape of x)
    Optional args:
    - beta (float): accepted for interface compatibility; it is not used
        in the computation. (default: 1.0)

    Returns:
    - (0-dim tensor): summed reconstruction loss
    """

    criterion = torch.nn.BCEWithLogitsLoss(reduction="sum")
    return criterion(recon_X_logits, X)



#### Train VAE

In [None]:
def _frame_from_class_name(class_name, default):
    """Return the trailing number of a class name such as "image10" or "image5.0"."""
    import re

    match = re.search(r"(-?\d+(?:\.\d+)?)$", class_name)
    return int(float(match.group(1))) if match else default


def _targets_by_frame(image_dataset):
    """Map each frame number to its single-channel target image.

    ImageFolder orders its classes as sorted strings ("image0", "image1",
    "image10", ...), so the position of an image in the dataset is not its
    frame number. The frame is therefore parsed from the class name; a
    dataset without a `classes` attribute falls back to the position.
    Channels are averaged so that targets stay in the value range of the
    input images.
    """
    class_names = getattr(image_dataset, "classes", None)
    targets = {}
    for position in range(len(image_dataset)):
        image, class_idx = image_dataset[position]
        frame = position
        if class_names is not None:
            frame = _frame_from_class_name(class_names[int(class_idx)], position)
        targets[frame] = image.float().mean(dim=0, keepdim=True)
    return targets


def _plot_reconstructions(target_imgs, output_imgs, title):
    """Plot target images (top row) above their reconstructions (bottom row)."""
    import matplotlib.pyplot as plt

    num_images = len(target_imgs)
    fig, axes = plt.subplots(
        2, num_images, figsize=(2 * num_images, 4), squeeze=False
        )
    for col in range(num_images):
        axes[0, col].imshow(target_imgs[col, 0], cmap="gray")
        axes[1, col].imshow(output_imgs[col, 0], cmap="gray")
        axes[0, col].axis("off")
        axes[1, col].axis("off")
    axes[0, 0].set_title("Target")
    axes[1, 0].set_title("Reconstr.")
    fig.suptitle(title)
    plt.show()


def train_vae(encoder, labels_df, image_dataset, dataset, train_sampler, num_epochs=10, batch_size=500, 
              beta=1.0,  use_cuda=True, verbose=False):
    """
    train_vae(encoder, labels_df, image_dataset, dataset, train_sampler)
    Trains a VAE_decoder to reconstruct images from feature vectors.
    
    Required args:
    - encoder (pd.DataFrame): feature vectors, one row per sample. For
        backward compatibility, any object without a `to_numpy` method is
        ignored and the notebook-level `encoded_feature_vectors` is used.
    - labels_df (pd.DataFrame): table whose "frame" column gives, row by
        row in the same order as the feature vectors, the frame number of
        the image shown for that sample.
    - image_dataset (dataset of (image tensor, class index) pairs): target
        images, e.g. a torchvision ImageFolder whose class folders are
        named "image<frame>". Images are expected to be 256 x 256.
    - dataset: unused, kept for interface compatibility.
    - train_sampler: unused, kept for interface compatibility.
    
    Optional args:
    - num_epochs (int): Number of epochs over which to train the decoder. 
        (default: 10)
    - batch_size (int): Batch size. (default: 500)
    - beta (float): forwarded to vae_loss_function. (default: 1.0)
    - use_cuda (bool): If True, cuda is used, if available. (default: True)
    - verbose (bool): If True, up to 5 first-batch reconstructions are
        plotted on every 10th epoch. (default: False)
    Returns: 
    - decoder (nn.Module): trained decoder, moved to the CPU and left in
        training mode
    - loss_arr (list): mean training loss per sample at each epoch
    Raises:
    - ValueError: if there are no feature vectors, or if labels_df does not
        have exactly one row per feature vector.
    """

    device = "cuda" if use_cuda and torch.cuda.is_available() else "cpu"

    # The features used to be read from the global `encoded_feature_vectors`
    # regardless of what was passed as `encoder`.
    features = encoder if hasattr(encoder, "to_numpy") else encoded_feature_vectors
    if len(features) == 0:
        raise ValueError("No feature vectors to train on.")
    if len(labels_df) != len(features):
        raise ValueError(
            f"labels_df has {len(labels_df)} rows but there are "
            f"{len(features)} feature vectors."
            )

    decoder = VAE_decoder(features.shape[1], (1, 256, 256)).to(device)

    train_dataloader = torch.utils.data.DataLoader(
        torch.tensor(features.to_numpy()), batch_size=batch_size,
        )

    # Load every target image once, keyed by frame number.
    targets_by_frame = _targets_by_frame(image_dataset)
    frames = labels_df["frame"]

    # Define loss and optimizers
    optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=500
        )

    # Train model on training set
    decoder.train()

    loss_arr = []
    for epoch in tqdm_notebook(range(num_epochs)):
        total_loss = 0
        num_total = 0
        for batch_idx, X in enumerate(train_dataloader):
            optimizer.zero_grad()
            recon_X_logits = decoder(X.to(device))

            # Rows of labels_df are aligned with the feature rows, so the
            # batch covers the same positional slice in both.
            start_idx = batch_idx * batch_size
            batch_frames = frames.iloc[start_idx:start_idx + X.shape[0]]
            ims = torch.stack(
                [targets_by_frame[int(frame)] for frame in batch_frames]
                ).to(device)

            loss = vae_loss_function(
                recon_X_logits=recon_X_logits, X=ims, beta=beta
                )
            total_loss += loss.item()
            num_total += len(recon_X_logits)
            loss.backward()
            optimizer.step()
            if verbose and epoch % 10 == 9 and batch_idx == 0:
                num_images = 5
                decoder.eval()
                output_imgs = decoder.reconstruct(
                    X[:num_images].to(device)
                    ).detach().cpu().numpy()
                decoder.train()
                target_imgs = ims[:num_images].detach().cpu().numpy()

                title = (f"Epoch {epoch}, batch {batch_idx}, "
                    f"loss {loss.item():.2f}")
                _plot_reconstructions(target_imgs, output_imgs, title)

        loss_arr.append(total_loss / num_total)
        scheduler.step()

    # set final decoder state
    decoder.train()
    decoder.cpu()

    return decoder, loss_arr

In [None]:
# Natural-scene templates keyed by frame value; written to disk in the next cell.
images = get_images(labels_df)

In [None]:
# Root folder of the image dataset: one sub-folder per image.
parent_dir = 'decoding_datasets'

# Convert to a tensor, resize to 256 x 256, and convert back to a PIL image
# so the result can be saved.
transformations = [transforms.ToTensor(),
                   transforms.Resize((256, 256)),
                   transforms.ToPILImage()]
# The pipeline does not depend on the image, so it is built once.
transform = transforms.Compose(transformations)

image_labels = list(labels_df["natural_scenes"]["frame"].unique())

# Frame -1 has no entry in `images` (get_images leaves it out).
if -1 in image_labels:
  image_labels.remove(-1)

for image_index in tqdm_notebook(image_labels):

  directory = f"image{image_index}"
  path = os.path.join(parent_dir, directory)
  os.makedirs(path, exist_ok=True)

  im = Image.fromarray(images[image_index])
  # Apply the pipeline once; it used to be run twice per image.
  transformed_image = transform(im)
  transformed_image.save(os.path.join(path, f'image_{image_index}.jpeg'))

In [None]:
# Image dataset read from the folder written by the previous cell.
# NOTE(review): `train_transform` is not defined in the visible notebook.
image_dataset = ImageFolder('decoding_datasets',
                              transform=train_transform)


In [None]:
# Natural-scenes labels without the frame == -1 rows; reset_index() gives the
# remaining rows a fresh 0..n-1 index and keeps the old index as a column.
temp_labels_df = labels_df["natural_scenes"][labels_df["natural_scenes"]["frame"] != -1].reset_index()

In [None]:
# Train the decoder for 10 epochs; the encoded feature vectors are passed in
# the `encoder` position, and `dataset` / `train_sampler` are left as None.
decoder, loss = train_vae(encoded_feature_vectors, temp_labels_df, image_dataset, dataset=None, train_sampler=None, num_epochs=10, 
                          batch_size=500, beta=1.0,  use_cuda=True, verbose=False)