In [1]:
import sys
import numpy as np
import matplotlib.pyplot as plt
import cebra
from PIL import Image
import cv2
import os
import torch
import torch.nn.functional as F
from torch import nn
import itertools
from torch.utils.tensorboard import SummaryWriter
import random
import gc

In [2]:
#check if cuda is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
output_model_path = 'models/cebra_model_complete.pt'
neural_data_directory= []
behavior_data_directory = []

In [4]:
data_directory = '/mnt/teams/Tsuchitori/MV1_run_30hz_30frame_brain2behav_DFF_new/'
neural_data_paths = [ data_directory + 'brain/' + \
                     file for file in os.listdir(data_directory + 'brain/')]

behavior_data_paths = [  data_directory + 'camera1/' + \
                     file for file in os.listdir(data_directory + 'brain/')]

dino_paths = [ data_directory + 'dino/' + \
                        file for file in os.listdir(data_directory + 'brain/')]

In [5]:

def process_brain(brain_seq):
  brain_seq = np.array(brain_seq)
  flat_seq = np.array([(brain_frame.flatten()) for brain_frame in brain_seq])
  return flat_seq.astype(float)


## Loads data from a folder of TIF files
# filepath: path to folder
# processor: function to process each image
# max: max images to load as a proportion of array size
# min: min images to load as a proportion of array size
# returns: list of processed images, list of filenames
def import_data(filepath, processor, min = 0, max = 1):
    output_data = []
    output_name = []
    path_list = os.listdir(filepath)
    path_list.sort()
    random.Random(4).shuffle(path_list)
    min_index = int(min * len(path_list))
    max_index = int(max * len(path_list))
    for file in itertools.islice(path_list, min_index, max_index):
     filename = os.fsdecode(file)
     if filename.endswith(".tif"):
         out = cv2.imreadmulti(filepath + '/' + filename)[1]
         output_data.append(processor(out))
         output_name.append(filename.split('.')[0])
     elif filename.endswith(".npy"):
         output_data.append(processor(np.load(filepath + '/' + filename)))
         output_name.append(filename.split('.')[0])
     else:
         continue
    return output_data, output_name

In [6]:
def normalize_array(in_array):
    return np.array([x / np.linalg.norm(x) for x in in_array])

def flatten_data(data):
    return np.concatenate(data, axis=0)


In [7]:
def init_dataloader(brain_data, feature_data, num_steps, time_offset, conditional, batch_size=1, cebra_offset=None ):
    datasets = []
    print('loading data')
    for session in zip(brain_data, feature_data):
        brain_data_tensor  = torch.FloatTensor(session[0])
        feature_data_tensor = torch.FloatTensor(session[1])
        datasets.append(cebra.data.datasets.TensorDataset(brain_data_tensor, continuous=feature_data_tensor, offset=cebra_offset))
    dataset_collection = cebra.data.datasets.DatasetCollection(*datasets)
    return cebra.data.multi_session.ContinuousMultiSessionDataLoader(
        dataset=dataset_collection,
        batch_size=batch_size,
        num_steps=num_steps,
        time_offset=time_offset,
        conditional=conditional,
    )   



In [8]:
def load_partial_data(brain_paths, behvaior_paths, min, max, num_steps, time_offset, conditional, batch_size, cebra_offset):
    brain_data, feature_data = [], []
    for path in zip(brain_paths, behvaior_paths):
        print('importing from: ' + path[0])
        brain_data_temp, _ = import_data(path[0], process_brain, min, max)
        feature_data_temp, _ = import_data(path[1], lambda x: np.squeeze(x), min, max)
        brain_data.extend((brain_data_temp))
        feature_data.extend((feature_data_temp))
        del brain_data_temp
        del feature_data_temp
        del _
        gc.collect()
    return init_dataloader(brain_data, feature_data, num_steps, time_offset, conditional, batch_size, cebra_offset)

In [9]:
## Creat and train the model in partial batches of data
def partial_train(brain_paths, behavior_paths, min, max, slice_size, num_steps, time_offset, conditional, batch_size, cebra_offset, hidden_units, output_dimension, model_name, device, saved_model = None):
    ## Load dataloader for first slice of data
    print('Loading data')
    dataloader= load_partial_data(brain_paths, behavior_paths, min, min+slice_size, num_steps, time_offset, conditional, batch_size, cebra_offset)
    print('Creating model')
    ## create list of models
    model = torch.nn.ModuleList([
    cebra.models.init(model_name, dataset.input_dimension,
                        hidden_units, output_dimension, True)
    for dataset in dataloader.dataset.iter_sessions()
    ]).to(device)
    if saved_model is not None:
        model.__setstate__(saved_model)

    ## Load optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    ## Load criterion
    criterion = cebra.models.criterions.LearnableCosineInfoNCE(temperature=1.0, min_temperature=0.01)

    print('Loading solver')
    ## Load solver and train on first slice of data
    solver = cebra.solver.MultiSessionSolver(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        tqdm_on=True,
    )
    print('Training on slice 1')
    solver.fit(dataloader,
                save_frequency=500,
                logdir='runs',)
    solver.save('models/'+model_name+'_slice1.pt')
    for i in range(2, int((max-min)/slice_size)+1):
        ## Load next slice of data
        dataloader= load_partial_data(brain_paths, behavior_paths, min+slice_size*(i-1), min+slice_size*i, num_steps, time_offset, conditional, batch_size, cebra_offset)
        ## Train on next slice of data
        print('Training on slice '+str(i))
        solver.fit(dataloader,
                save_frequency=5000,
                logdir='runs',)        
        torch.save(model, output_model_path)
    print('Training complete, saving model')
    solver.save(output_model_path)
    return solver


In [10]:
model = partial_train([neural_data_paths[0]],
            [dino_paths[0]],
            min=0.4,
            max=0.8,
            slice_size=0.4,
            num_steps=2500,
            time_offset=10,
            conditional='time_delta',
            batch_size=128,
            cebra_offset=cebra.data.datatypes.Offset(0,1),
            hidden_units=128,
            output_dimension=8,
            model_name='offset1-model',
            device='cuda:0'  
          )

Loading data
importing from: /mnt/teams/Tsuchitori/MV1_run_30hz_30frame_brain2behav_DFF_new/brain/2021_1_8_MV1_run


loading data
Creating model
Loading solver
Training on slice 1


pos:  0.2499 neg:  10.6273 total:  10.8772 temperature:  1.0000:   6%|▌         | 146/2500 [23:01<6:10:47,  9.45s/it]

In [None]:
checkpoint = torch.load('runs/checkpoint_0009990.pth')

In [None]:
[1]

[1]

In [None]:
# model = partial_train(neural_data_paths,
#             dino_paths,
#             min=0.4,
#             max=0.8,
#             valid_max=0.2,
#             slice_size=0.4,
#             num_steps=10000,
#             time_offset=15,
#             conditional='time_delta',
#             batch_size=4096,
#             cebra_offset=cebra.data.datatypes.Offset(0,1),
#             hidden_units=128,
#             output_dimension=16,
#             model_name='offset1-model',
#             device='cuda:0',
#             saved_model = checkpoint['model']
#           )

Loading data
importing from: /mnt/teams/Tsuchitori/MV1_run_30hz_30frame_brain2behav_DFF_new/brain/2021_1_8_MV1_run


importing from: /mnt/teams/Tsuchitori/MV1_run_30hz_30frame_brain2behav_DFF_new/brain/2020_11_23_MV1_run
importing from: /mnt/teams/Tsuchitori/MV1_run_30hz_30frame_brain2behav_DFF_new/brain/2020_12_4_MV1_run
importing from: /mnt/teams/Tsuchitori/MV1_run_30hz_30frame_brain2behav_DFF_new/brain/2020_11_2_MV1_run
importing from: /mnt/teams/Tsuchitori/MV1_run_30hz_30frame_brain2behav_DFF_new/brain/2021_1_12_MV1_run
importing from: /mnt/teams/Tsuchitori/MV1_run_30hz_30frame_brain2behav_DFF_new/brain/2020_12_10_MV1_run
importing from: /mnt/teams/Tsuchitori/MV1_run_30hz_30frame_brain2behav_DFF_new/brain/2020_11_9_MV1_run
importing from: /mnt/teams/Tsuchitori/MV1_run_30hz_30frame_brain2behav_DFF_new/brain/2020_11_17_MV1_run
loading data
Creating model
Loading solver
Training on slice 1


pos:  0.2769 neg:  9.5146 total:  9.7915 temperature:  1.0000: 100%|██████████| 10000/10000 [10:48:38<00:00,  3.89s/it]


Training complete, saving model
