In [14]:
import sys
import numpy as np
import matplotlib.pyplot as plt
import cebra
from PIL import Image
import cv2
import os
import torch
import torch.nn.functional as F
from torch import nn
import itertools
from torch.utils.tensorboard import SummaryWriter
import random
import gc

In [15]:
# Target file for the model; presumably used by a later cell (not referenced in this part of the notebook)
output_model_path = 'models/cebra_model_complete.pt'

# Root folder that every dataset sub-folder below is built from (keeps its trailing '/')
data_directory = '/mnt/teams/Tsuchitori/Allen-movie/'

# Folders with the neural recordings; the train/validation name lists are derived from these
neural_data_paths = [f'{data_directory}test_set/']

# Folders with the behavior movies that are fed to the DINO model
behavior_data_paths = [f'{data_directory}movie/']

# Folders the DINO embeddings are saved to, paired by position with behavior_data_paths
dino_paths = [f'{data_directory}dino/']

In [16]:
## Flattens every frame of a brain recording into a 1-D vector
# brain_seq: sequence of frames (anything np.array can convert)
# returns: float array with one flattened frame per row
def process_brain(brain_seq):
  frames = np.array(brain_seq)
  flattened = []
  for frame in frames:
    flattened.append(frame.flatten())
  return np.array(flattened).astype(float)


## Loads the .tif and .npy files found in a folder
# The directory listing is sorted and then shuffled with a fixed seed (4), so the
# same folder content always yields the same order. Only the entries between the
# min and max proportions of that shuffled listing are visited; entries with any
# other extension inside that window are skipped (they still count towards the
# proportions).
# filepath: path to folder
# processor: function applied to each loaded image / array
# min: start of the window, as a proportion of the listing length
# max: end of the window, as a proportion of the listing length
# returns: list of processed items, list of names (file name up to its first '.')
def import_data(filepath, processor, min = 0, max = 1):
    entries = sorted(os.listdir(filepath))
    random.Random(4).shuffle(entries)

    entry_count = len(entries)
    start = int(min * entry_count)
    stop = int(max * entry_count)

    output_data, output_name = [], []
    for entry in itertools.islice(entries, start, stop):
        filename = os.fsdecode(entry)
        if filename.endswith(".tif"):
            # imreadmulti returns (success flag, pages); only the pages are kept
            loaded = cv2.imreadmulti(filepath + '/' + filename)[1]
        elif filename.endswith(".npy"):
            loaded = np.load(filepath + '/' + filename)
        else:
            continue
        output_data.append(processor(loaded))
        output_name.append(filename.split('.')[0])
    return output_data, output_name

## Scales every element of in_array by its own norm
# in_array: iterable of vectors / arrays
# returns: array of the scaled elements (there is no guard against a zero norm)
def normalize_array(in_array):
    unit_rows = []
    for row in in_array:
        unit_rows.append(row / np.linalg.norm(row))
    return np.array(unit_rows)

## Joins a sequence of arrays end-to-end along their first axis
# data: sequence of arrays
# returns: one array holding all of them, concatenated on axis 0
def flatten_data(data):
    joined = np.concatenate(data, axis=0)
    return joined

In [17]:
# Getting DINO embeddings from behavior data

# Model taken from https://huggingface.co/facebook/dino-vits8
# Trying out the pretrained DINO vision transformer
from transformers import ViTImageProcessor, ViTModel
from PIL import Image
import requests

# Use the first CUDA device when one is available, otherwise stay on the CPU
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Image pre-processor and ViT backbone of the facebook/dino-vits8 checkpoint;
# the backbone is moved to the selected device right after loading.
# NOTE: loading reports that pooler.dense.weight / pooler.dense.bias are not in the
# checkpoint and are newly initialised.
processor = ViTImageProcessor.from_pretrained('facebook/dino-vits8')
vit_model = ViTModel.from_pretrained('facebook/dino-vits8').to(device)


## Given a DINO model and an image, return the DINO embedding
# model: the DINO model (a transformers ViTModel)
# image: a PIL image
# returns: numpy array with the CLS-token embedding, one row per input image
#
# Relies on the notebook-level `processor` (image pre-processor) and `device`.
def get_features(model, image):
  inputs = processor(images=image, return_tensors="pt").to(device)
  # Inference only: no autograd graph needs to be built for each frame.
  with torch.no_grad():
    outputs = model(**inputs)
  # Take the CLS token of the final hidden state. `pooler_output` would pass it
  # through the pooler dense layer, whose weights are missing from the
  # facebook/dino-vits8 checkpoint and get randomly initialised on load, so it is
  # not a pretrained feature and changes from one run to the next.
  # Embeddings saved to disk by earlier runs used the pooler and should be regenerated.
  cls_embedding = outputs.last_hidden_state[:, 0]
  return cls_embedding.cpu().detach().numpy()

## Convert a numpy array to a PIL image
# The values are cast to uint8 as they are (no rescaling) and the image is
# converted to RGB mode.
# numpy_image: a numpy array
# returns: a PIL image
def np_to_PIL(numpy_image):
    as_uint8 = np.uint8(numpy_image)
    pil_image = Image.fromarray(as_uint8)
    return pil_image.convert('RGB')

## Given a sequence of behavior frames, return a sequence of DINO embeddings
# Every frame is converted to a PIL image and embedded on its own, in order.
# behavior_video: a sequence of behavior frames
# model: the DINO model
# returns: array stacking the per-frame results of get_features
def get_dino_embeddings(behavior_video, model):
  frames = np.array(behavior_video)
  return np.array([get_features(model, np_to_PIL(frame)) for frame in frames])

## Get DINO embeddings for a set of behavior videos
# behavior_data: iterable of behavior videos (each a sequence of frames)
# model: the DINO model
# returns: list with one embedding array per video, size-1 axes squeezed out
def get_dino_embeddings_array(behavior_data, model):
  return [np.squeeze(get_dino_embeddings(video, model)) for video in behavior_data]

Some weights of ViTModel were not initialized from the model checkpoint at facebook/dino-vits8 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
## Computes, normalizes and saves the DINO embeddings of each behavior video
# A video whose <name>.npy file already exists in output_path is skipped.
# behavior_data: sequence of behavior videos (each a sequence of frames)
# name_data: file stem for each video, in the same order as behavior_data
# model: the DINO model
# output_path: folder the <name>.npy files are written to
# returns: embeddings of the last video computed by this call, or None when
#          nothing had to be computed (no videos, or all of them already saved)
def save_dino_embeddings(behavior_data, name_data, model, output_path):
  dino_embeddings = None
  for video, name in zip(behavior_data, name_data):
    save_path = output_path + '/' + name + '.npy'
    # Test the full path directly: os.listdir() only yields bare file names, so
    # looking the full path up in that listing can never match.
    if os.path.isfile(save_path):
      print('embeddings already found')
      continue
    dino_embeddings = normalize_array(get_dino_embeddings(video, model))
    print(save_path)
    np.save(save_path, dino_embeddings)
  return dino_embeddings

## Saves DINO embeddings for every behavior folder
# The two path lists are paired by position: the videos loaded from
# behavior_paths[i] have their embeddings saved into output_paths[i].
# behavior_paths: folders containing the behavior videos
# output_paths: folders receiving the embedding files
# model: the DINO model
def process_behavior_data(behavior_paths, output_paths, model):
  for behavior_path, output_path in zip(behavior_paths, output_paths):
    # identity processor: the videos are loaded unchanged, whole folder (max=1)
    videos, video_names = import_data(behavior_path, lambda x : x, max=1)
    print('Saving DINO embeddings for ' + behavior_path)
    save_dino_embeddings(videos, video_names, model, output_path)

In [19]:
# Embed every configured behavior movie folder and save the results to the DINO folders
process_behavior_data(
    behavior_paths=behavior_data_paths,
    output_paths=dino_paths,
    model=vit_model,
)

Saving DINO embeddings for /mnt/teams/Tsuchitori/Allen-movie/movie/
/mnt/teams/Tsuchitori/Allen-movie/dino//natural_movie_one.npy


In [20]:
## Writes the training / validation name lists, one text file per neural data folder
# Each line of a list is "<name> <label>", where label is 1 when the part of the
# name before its first '_' is exactly 'move' and 0 otherwise.
# The first 80% of the (deterministically shuffled) folder listing goes to
# training_names/, the remaining 20% to validation_names/.

# Folder name used for the list file. The path is normalised first because a path
# ending in '/' otherwise gives an empty name, i.e. a file called '.txt' that every
# folder would overwrite. NOTE: for such paths the file name therefore changes from
# '.txt' to '<folder>.txt'; readers of these lists must use the same name.
def _dataset_name(path):
    return os.path.basename(os.path.normpath(path))

# Writes the name/label list for one folder and one split window; returns (names, labels)
def _write_split_names(path, out_dir, min_fraction, max_fraction):
    # Only the names are needed, so the loaded arrays are dropped straight away
    _, names = import_data(path, lambda loaded: None, min_fraction, max_fraction)
    labels = np.array(['move' == name.split('_')[0] for name in names]).astype(int)
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, _dataset_name(path) + '.txt'), 'w') as f:
        for name, label in zip(names, labels):
            f.write(name + ' ' + str(label) + '\n')
    return names, labels

for path in neural_data_paths:
    names, labels = _write_split_names(path, 'training_names', 0.0, 0.8)

for path in neural_data_paths:
    names, labels = _write_split_names(path, 'validation_names', 0.8, 1)