In [1]:
import sys
import numpy as np
import matplotlib.pyplot as plt
import cebra
from PIL import Image
import cv2
import os
import torch
import torch.nn.functional as F
from torch import nn
import itertools
from torch.utils.tensorboard import SummaryWriter

In [2]:
output_model_path = 'models/cebra_model_complete.pt'
neural_data_directory= []
behavior_data_directory = []

In [3]:
neural_data_paths = [ 'MV1_run_30hz_30frame_brain2behav_DFF_new/brain/' + \
                     file for file in os.listdir('MV1_run_30hz_30frame_brain2behav_DFF_new/brain/')]

behavior_data_paths = [ 'MV1_run_30hz_30frame_brain2behav_DFF_new/camera1/' + \
                     file for file in os.listdir('MV1_run_30hz_30frame_brain2behav_DFF_new/camera1/')]

dino_paths = [ 'MV1_run_30hz_30frame_brain2behav_DFF_new/dino/' + \
                        file for file in os.listdir('MV1_run_30hz_30frame_brain2behav_DFF_new/dino/')]

In [4]:

def process_brain(brain_seq):
  brain_seq = np.array(brain_seq)
  flat_seq = np.array([(brain_frame.flatten()) for brain_frame in brain_seq])
  return flat_seq.astype(float)


## Loads data from a folder of TIF files
# filepath: path to folder
# processor: function to process each image
# max: max images to load as a proportion of array size
# min: min images to load as a proportion of array size
# returns: list of processed images, list of filenames
def import_data(filepath, processor, min = 0, max = 1):
    output_data = []
    output_name = []
    min_index = int(min * len(os.listdir(filepath)))
    max_index = int(max * len(os.listdir(filepath)))
    for file in itertools.islice(os.listdir(filepath), min_index, max_index):
     filename = os.fsdecode(file)
     if filename.endswith(".tif"):
         out = cv2.imreadmulti(filepath + '/' + filename)[1]
         output_data.append(processor(out))
         output_name.append(filename)
     else:
         continue
    return output_data, output_name

In [13]:
#Getting DINO embeddings from behavior data

# from https://huggingface.co/facebook/dino-vits8
# testing transformer
from transformers import ViTImageProcessor, ViTModel
from PIL import Image
import requests

processor = ViTImageProcessor.from_pretrained('facebook/dino-vits8')
vit_model = ViTModel.from_pretrained('facebook/dino-vits8')
device = "cuda:0" if torch.cuda.is_available() else "cpu"
vit_model = vit_model.to(device)


## Given a DINO model and an image, return the DINO embedding
# model: the DINO model
# image: a PIL image
def get_features(model, image):
  return model(**processor(images=image, return_tensors="pt").to(device)).pooler_output.cpu().detach().numpy()

## Convert a numpy array to a PIL image
# numpy_image: a numpy array
# returns: a PIL image
def np_to_PIL(numpy_image):
    return Image.fromarray(np.uint8(numpy_image)).convert('RGB')

## Given a sequence of behavior frames, return a sequence of DINO embeddings
# behavior_video: a sequence of behavior frames
# model: the DINO model
# returns: a sequence of DINO embeddings
def get_dino_embeddings(behavior_video, model):
  behavior_video = np.array(behavior_video)
  feature_sequence = []
  for frame in behavior_video:
    feature_sequence.append(get_features(model, np_to_PIL(frame)))
  return np.array(feature_sequence)

# Get DINO embeddings for a set behavior data
def get_dino_embeddings_array(behavior_data, model):
  dino_embeddings = []
  for behavior_video in behavior_data:
    dino_embeddings.append(np.squeeze(get_dino_embeddings(behavior_video, model)))
  return dino_embeddings

Some weights of ViTModel were not initialized from the model checkpoint at facebook/dino-vits8 and are newly initialized: ['pooler.dense.weight', 'pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
def normalize_array(in_array):
    return np.array([x / np.linalg.norm(x) for x in in_array])

def flatten_data(data):
    data_flat = np.squeeze(data[0])
    for x in data[1::]:
        data_flat = np.concatenate((data_flat, np.squeeze(x)))
    return data_flat


In [21]:
def save_dino_embeddings(behavior_data, name_data, model, output_path):
  for i, video in enumerate(behavior_data):
    dino_embeddings = normalize_array(get_dino_embeddings(video, model))
    print(output_path + '/' + name_data[i] + '.npy')
    np.save(output_path + '/' + name_data[i] + '.npy', dino_embeddings)
  return dino_embeddings

def process_behavior_data(behavior_paths, output_paths, model):
  for paths in zip(behavior_paths, output_paths):
    behavior_data_temp, name_data_temp = import_data(paths[0], lambda x : x, max=1)
    print('Saving DINO embeddings for ' + paths[0])
    save_dino_embeddings(behavior_data_temp, name_data_temp, model, paths[1])


In [22]:
process_behavior_data(behavior_data_paths, dino_paths, vit_model)

Saving DINO embeddings for MV1_run_30hz_30frame_brain2behav_DFF_new/camera1/2021_1_12_MV1_run
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2021_1_12_MV1_run/nomove_0015_0002.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2021_1_12_MV1_run/nomove_0028_0011.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2021_1_12_MV1_run/nomove_0039_0012.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2021_1_12_MV1_run/move_0015_0008.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2021_1_12_MV1_run/move_0072_0006.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2021_1_12_MV1_run/move_0056_0003.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2021_1_12_MV1_run/nomove_0013_0017.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2021_1_12_MV1_run/nomove_0044_0006.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2021_1_12_MV1_run/move_0017_0002.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2021_1_12_MV1_run/nomove_0066_0011.tif.npy
MV1_run_30hz_30frame_bra

[ WARN:0@15118.548] global grfmt_tiff.cpp:838 readData OpenCV TIFF(line 838): failed TIFFReadRGBAStrip(tif, y, (uint32*)src_buffer)
imreadmulti_('MV1_run_30hz_30frame_brain2behav_DFF_new/camera1/2020_11_2_MV1_run/nomove_0030_0014.tif'): can't read data: OpenCV(4.7.0) /io/opencv/modules/imgcodecs/src/grfmt_tiff.cpp:838: error: (-2:Unspecified error) OpenCV TIFF: failed TIFFReadRGBAStrip(tif, y, (uint32*)src_buffer) in function 'readData'



Saving DINO embeddings for MV1_run_30hz_30frame_brain2behav_DFF_new/camera1/2020_11_2_MV1_run
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2020_11_2_MV1_run/nomove_0015_0002.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2020_11_2_MV1_run/nomove_0039_0012.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2020_11_2_MV1_run/move_0012_0010.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2020_11_2_MV1_run/move_0056_0003.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2020_11_2_MV1_run/move_0042_0003.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2020_11_2_MV1_run/move_0034_0010.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2020_11_2_MV1_run/nomove_0044_0006.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2020_11_2_MV1_run/move_0017_0002.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2020_11_2_MV1_run/nomove_0008_0011.tif.npy
MV1_run_30hz_30frame_brain2behav_DFF_new/dino/2020_11_2_MV1_run/nomove_0039_0008.tif.npy
MV1_run_30hz_30frame_brain

In [None]:
def init_dataloader(brain_data, feature_data, num_steps, time_offset, conditional, batch_size=1, cebra_offset=None ):
    datasets = []
    for session in zip(brain_data, feature_data):
        brain_data_tensor  = torch.FloatTensor(session[0])
        feature_data_tensor = torch.FloatTensor(session[1])
        datasets.append(cebra.data.datasets.TensorDataset(brain_data_tensor, continuous=feature_data_tensor, offset=cebra_offset))
    dataset_collection = cebra.data.datasets.DatasetCollection(*datasets)
    return cebra.data.multi_session.ContinuousMultiSessionDataLoader(
        dataset=dataset_collection,
        batch_size=batch_size,
        num_steps=num_steps,
        time_offset=time_offset,
        conditional=conditional,
    )   



In [None]:
def load_partial_data(brain_paths, behvaior_paths, min, max, num_steps, time_offset, conditional, batch_size, cebra_offset):
    brain_data, feature_data = [], []
    for path in zip(brain_paths, behvaior_paths):
        brain_data_temp, _ = import_data(path[0], process_brain, min, max)
        behavior_data_temp, _ = import_data(path[1], lambda x: x, min, max)
        print('calculating dino embeddings')
        feature_data_temp = flatten_data(get_dino_embeddings_array(behavior_data_temp, vit_model))
        brain_data.append(flatten_data(brain_data_temp))
        feature_data.append(normalize_array(feature_data_temp))
    return init_dataloader(brain_data, feature_data, num_steps, time_offset, conditional, batch_size, cebra_offset)

In [None]:
## Creat and train the model in partial batches of data
def partial_train(brain_paths, behavior_paths, min, max, valid_max, slice_size, num_steps, time_offset, conditional, batch_size, cebra_offset, hidden_units, output_dimension, model_name, device):
    ## Load dataloader for first slice of data
    dataloader= load_partial_data(brain_paths, behavior_paths, min, min+slice_size, num_steps, time_offset, conditional, batch_size, cebra_offset)
    validation_loader = load_partial_data(brain_paths, behavior_paths, max, max+valid_max, num_steps, time_offset, conditional, batch_size, cebra_offset)
    ## Create a summary writer for tensorboard
    writer = SummaryWriter()
    
    ## create list of models
    model = torch.nn.ModuleList([
    cebra.models.init(model_name, dataset.input_dimension,
                        hidden_units, output_dimension, True)
    for dataset in dataloader.dataset.iter_sessions()
    ]).to(device)
    ## Load optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    ## Load criterion
    criterion = cebra.models.criterions.LearnableCosineInfoNCE(temperature=1.0, min_temperature=None)

    ## Load solver and train on first slice of data
    solver = cebra.solver.MultiSessionSolver(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        tqdm_on=True
    )
    writer.add_graph(model, dataloader.dataset[0][0])
    print('Training on slice 1')
    solver.fit(dataloader, valid_loader = validation_loader)
    torch.save(model.state_dict(), output_model_path)
    for i in range(2, int((max-min)/slice_size)+1):
        ## Load next slice of data
        dataloader= load_partial_data(brain_paths, behavior_paths, min+slice_size*(i-1), min+slice_size*i, num_steps, time_offset, conditional, batch_size, cebra_offset)
        ## Train on next slice of data
        print('Training on slice '+str(i))
        solver.fit(dataloader, valid_loader = validation_loader)
        torch.save(model.state_dict(), output_model_path)
    print('Training complete, saving model')
    writer.close()
    torch.save(model.state_dict(), output_model_path)


In [None]:
partial_train(neural_data_paths,
            behavior_data_paths,
            min=0,
            max=0.05,
            valid_max=0.1,
            slice_size=0.05,
            num_steps=20,
            time_offset=20,
            conditional='time_delta',
            batch_size=1024,
            cebra_offset=cebra.data.datatypes.Offset(0,1),
            hidden_units=128,
            output_dimension=16,
            model_name='offset1-model',
            device='cuda:1'  
          )

In [None]:
# output_dimension = 8
# model_name = 'offset5-model'
# hidden_units = 128
# temp = 0.1
# device = 'cuda:1'
# ## Load data into CEBRA format
# #convert data to tensors
# #data must be in float tensor format
# datasets = []
# for session in zip(brain_data, features):
#     brain_data_t = torch.FloatTensor(session[0])
#     features_t = torch.FloatTensor(session[1])
#     datasets.append(cebra.data.datasets.TensorDataset(brain_data_t, continuous=features_t, offset=cebra.data.datatypes.Offset(3,2)))

# dataset_collection = cebra.data.datasets.DatasetCollection(*datasets)

# dataloader = cebra.data.multi_session.ContinuousMultiSessionDataLoader(
#     dataset=dataset_collection, num_steps=200,
#     batch_size=512,
#     conditional='time_delta',
#     time_offset=10
#     )

# ## Load CEBRA model
# #load model
# model = torch.nn.ModuleList([
#     cebra.models.init(model_name, dataset.input_dimension,
#                         hidden_units, output_dimension, True)
#     for dataset in dataloader.dataset.iter_sessions()
# ]).to(device)

# ## Load  Adam Optimizer
# #load optimizer
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# ## Load criterion
# criterion = cebra.models.criterions.LearnableCosineInfoNCE(temperature=1.0, min_temperature=None)

# ## Load solver
# solver = cebra.solver.MultiSessionSolver(model=model, criterion=criterion, optimizer=optimizer, tqdm_on=True)