In [1]:
import os

import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import MeanShift, estimate_bandwidth, OPTICS
from cebra_utils import *


In [2]:
# Checkpoint of the trained CEBRA solver; relative to the notebook's working directory.
cebra_model_path = "models/cebra_model_flattened_offset1.pt"


In [3]:
# Session name (an entry under brain/) -> index used to select `model[session]`.
# The position of each name in the tuple below is its model id.
name_to_model_id = {
    session_name: model_id
    for model_id, session_name in enumerate((
        '2021_1_8_MV1_run',
        '2020_11_23_MV1_run',
        '2020_12_4_MV1_run',
        '2020_11_2_MV1_run',
        '2021_1_12_MV1_run',
        '2020_12_10_MV1_run',
        '2020_11_9_MV1_run',
        '2020_11_17_MV1_run',
    ))
}

In [4]:
# Root of the per-session data. The session list comes from brain/; the matching
# camera1/, dino/ and output4/ entries are assumed to use the same names.
data_directory = '/mnt/teams/Tsuchitori/MV1_run_30hz_30frame_brain2behav_DFF_new/'

# List the sessions exactly once and sort them: os.listdir order is arbitrary, so a
# single sorted listing keeps the four path lists index-aligned and the run reproducible.
session_names = sorted(os.listdir(os.path.join(data_directory, 'brain')))

neural_data_paths = [os.path.join(data_directory, 'brain', name) for name in session_names]
behavior_data_paths = [os.path.join(data_directory, 'camera1', name) for name in session_names]
dino_paths = [os.path.join(data_directory, 'dino', name) for name in session_names]
output_folder_paths = [os.path.join(data_directory, 'output4', name) for name in session_names]

In [5]:
def match_frame_to_embeddings(predicted_embedding, embedding_train, image_train):
    """For each predicted vector, return the training image whose vector scores highest.

    Scores are plain dot products, so they equal cosine similarity only when the
    rows of both embedding arrays are unit-normalized. Ties go to the lowest
    training index (``argmax`` semantics).

    Args:
        predicted_embedding: query vectors, presumably 2-D (n_query, d).
        embedding_train: reference vectors, presumably 2-D (n_train, d).
        image_train: items aligned with the rows of ``embedding_train``; must support
            integer-array indexing along its first axis.

    Returns:
        ``image_train`` gathered at the best-matching training row for each query, in query order.
    """
    similarity = np.matmul(embedding_train, predicted_embedding.T)  # one column per query
    best_train_row = similarity.argmax(axis=0)
    return image_train[best_train_row]

def _load_split(neural_path, behavior_path, dino_path, **bounds):
    """Load one slice of a session as (brain, brain names, behavior, dino).

    ``bounds`` is forwarded unchanged to every ``import_data`` call (its ``min``/``max``
    keywords), so the three modalities are always cut at the same boundaries.
    """
    brain, names = import_data(neural_path, process_brain, **bounds)
    behavior, _ = import_data(behavior_path, lambda x: x, **bounds)
    dino, _ = import_data(dino_path, lambda x: x, **bounds)
    return brain, names, behavior, dino


def _embed_recordings(model, session, recordings, device):
    """Embed each recording with ``model[session]`` and return the results as numpy arrays.

    Runs under ``torch.no_grad()``: the outputs are only consumed as numpy data, so
    recording an autograd graph would just waste GPU memory.
    """
    with torch.no_grad():
        return [
            model[session](torch.from_numpy(np.array(x)).float().to(device)).cpu().numpy()
            for x in recordings
        ]


def _movement_labels(name_data, brain_data):
    """Build one-hot movement labels, one row per frame of each recording.

    A recording whose name's first '_'-separated token is 'move' gets [0, 1] for every
    frame; all others get [1, 0]. Rows are emitted per element of each ``brain_data``
    recording, which assumes the model yields one embedding per input frame so the
    labels line up with the flattened training embeddings.
    """
    rows = []
    for label, data in zip(name_data, brain_data):
        one_hot = [0, 1] if label.split('_')[0] == 'move' else [1, 0]
        rows.extend([one_hot] * len(data))
    return np.array(rows)


def predict_embeddings(neural_path, behavior_path, dino_path, validation_cutoff, valid_size,
                       model, label_model, session,
                       dino_neighbors=20, label_neighbors=72, device='cuda'):
    """Decode behavior video frames and a movement label from neural data for one session.

    Pipeline:
      1. Load the training slice (``import_data(..., max=validation_cutoff)``) and the test
         slice (``min=validation_cutoff, max=validation_cutoff + valid_size``) of the brain,
         behavior and DINO data, and embed the brain data with ``model[session]``.
      2. Fit a cosine KNN from training embeddings to DINO features, predict DINO features
         for the test embeddings, and normalize them with ``normalize_array``.
      3. Replace each predicted DINO vector by the training behavior frame whose DINO
         vector scores highest (``match_frame_to_embeddings``).
      4. Fit a second cosine KNN from training embeddings to one-hot movement labels and
         take the argmax per test frame (1 = 'move' recording, 0 = otherwise).

    Args:
        neural_path, behavior_path, dino_path: per-session data locations for ``import_data``.
        validation_cutoff, valid_size: split bounds forwarded to ``import_data``. The notebook
            passes 0.8 and 0.2 (presumably fractions of each recording; confirm in cebra_utils).
        model: model indexable by ``session``.
        label_model: currently unused; kept for interface compatibility.
        session: index into ``model``.
        dino_neighbors: neighbours for the DINO-feature KNN (default 20).
        label_neighbors: neighbours for the movement-label KNN (default 72).
        device: torch device used for inference (default 'cuda').

    Returns:
        Tuple of (predicted frames per test recording, test recording names,
        ground-truth behavior per test recording, predicted labels per test recording).
    """
    print('Loading data')
    brain_data, name_data, behavior_data, dino_data = _load_split(
        neural_path, behavior_path, dino_path, max=validation_cutoff)
    print('Generating embeddings')
    embeddings = _embed_recordings(model, session, brain_data, device)

    print('Loading test data')
    # The test-slice DINO features are not needed for prediction, only loaded for symmetry.
    brain_data_test, name_data_test, behavior_data_test, _ = _load_split(
        neural_path, behavior_path, dino_path,
        min=validation_cutoff, max=validation_cutoff + valid_size)
    print('Generating test embeddings')
    embeddings_test = _embed_recordings(model, session, brain_data_test, device)

    # Concatenate per-recording arrays into single (frames, features) matrices.
    embeddings_flat = flatten_data(embeddings).squeeze()
    behavior_flat = flatten_data(behavior_data).squeeze()
    dino_flat = flatten_data(dino_data).squeeze()
    embedding_test_flat = flatten_data(embeddings_test).squeeze()

    print('Running KNN')
    dino_decoder = cebra.KNNDecoder(n_neighbors=dino_neighbors, metric="cosine")
    dino_decoder.fit(embeddings_flat, dino_flat)
    predicted_dino = normalize_array(dino_decoder.predict(embedding_test_flat))

    print('generating labels')
    labels = _movement_labels(name_data, brain_data)
    label_decoder = cebra.KNNDecoder(n_neighbors=label_neighbors, metric='cosine')
    label_decoder.fit(embeddings_flat, labels)
    predicted_labels = np.argmax(label_decoder.predict(embedding_test_flat), axis=1)
    predicted_labels = reshape_frames(predicted_labels, embeddings_test)

    print('generating videos')
    predicted_images = match_frame_to_embeddings(predicted_dino, dino_flat, behavior_flat)
    reshaped_predicted_images = reshape_frames(predicted_images, brain_data_test)

    return reshaped_predicted_images, name_data_test, behavior_data_test, predicted_labels


In [6]:
def _to_bgr_uint8(frame):
    """Convert a 2-D grayscale frame into a 3-channel uint8 image for cv2.VideoWriter.

    Values are clipped to [0, 255] first: a bare ``astype(np.uint8)`` wraps out-of-range
    values (platform-dependent for negative floats) instead of saturating them.
    """
    return cv2.cvtColor(np.clip(frame, 0, 255).astype(np.uint8), cv2.COLOR_GRAY2BGR)


def display_frames_as_video(frames, ground_truth, frame_rate, name, output_folder_path):
    """Write predicted frames stacked above ground-truth frames to ``<output_folder_path>/<name>.mp4``.

    Frames are paired with ``zip``, so the video stops at the shorter of the two sequences.
    Nothing is written when ``frames`` is empty.

    Args:
        frames: sequence of 2-D grayscale frames; ``frames[0]`` sets the video size.
        ground_truth: sequence of 2-D grayscale frames, assumed the same size as ``frames``.
        frame_rate: output frames per second.
        name: output file name without extension.
        output_folder_path: directory for the video; created if it does not exist.

    Raises:
        IOError: if OpenCV cannot open the output file for writing (otherwise it fails silently).
    """
    if len(frames) == 0:
        return

    os.makedirs(output_folder_path, exist_ok=True)
    frame_height, frame_width = frames[0].shape
    video_path = os.path.join(output_folder_path, name + '.mp4')
    # Twice the frame height: each output frame is prediction on top of ground truth.
    video_writer = cv2.VideoWriter(video_path,
                                   cv2.VideoWriter_fourcc(*'mp4v'),
                                   frame_rate,
                                   (frame_width, 2 * frame_height))
    if not video_writer.isOpened():
        raise IOError(f'Could not open video writer for {video_path}')

    try:
        for predicted, truth in zip(frames, ground_truth):
            stacked = np.concatenate((_to_bgr_uint8(predicted), _to_bgr_uint8(truth)))
            video_writer.write(stacked)
    finally:
        video_writer.release()

In [7]:
def write_generated_video(vid, name, output_folder_path, frame_rate=24, size=(64, 64)):
    """Downsample grayscale frames and write them to ``<output_folder_path>/<name>.mp4``.

    Args:
        vid: iterable of 2-D grayscale frames.
        name: output file name without extension.
        output_folder_path: directory for the video; created if it does not exist.
        frame_rate: output frames per second (default 24).
        size: output (width, height), used both for ``cv2.resize`` and the writer (default 64x64).

    Raises:
        IOError: if OpenCV cannot open the output file for writing (otherwise it fails silently).
    """
    os.makedirs(output_folder_path, exist_ok=True)
    video_path = os.path.join(output_folder_path, name + '.mp4')
    video_writer = cv2.VideoWriter(video_path,
                                   cv2.VideoWriter_fourcc(*'mp4v'),
                                   frame_rate,
                                   size)
    if not video_writer.isOpened():
        raise IOError(f'Could not open video writer for {video_path}')

    try:
        # Resize and write frame by frame rather than materializing the whole resized video.
        for frame in vid:
            small = cv2.resize(frame, size)
            # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
            video_writer.write(
                cv2.cvtColor(np.clip(small, 0, 255).astype(np.uint8), cv2.COLOR_GRAY2BGR))
    finally:
        video_writer.release()


In [8]:
# Integer code -> behavior name. NOTE(review): not referenced by the visible cells.
key_behavior = dict(enumerate(('still', 'sniffing', 'walking', 'grooming')))

In [9]:
CLIP_FRAMES = 30      # frames taken per clip via choose_first_second
VIDEO_FPS = 24        # playback rate of the written videos
TRAIN_CUTOFF = 0.8    # forwarded to predict_embeddings as validation_cutoff
TEST_SIZE = 0.2       # forwarded to predict_embeddings as valid_size

# torch.load unpickles arbitrary objects: only load checkpoints from trusted sources.
solver = torch.load(cebra_model_path)
model = solver.model.eval().to('cuda')
label_model = torch.load('models/cebra_classifier_complete.pt')

for i, (neural_path, behavior_path, dino_path, output_folder_path) in enumerate(
        zip(neural_data_paths, behavior_data_paths, dino_paths, output_folder_paths)):
    session_id = name_to_model_id[neural_path.split('/')[-1]]
    # NOTE(review): label_model is indexed by listing position while model uses session_id;
    # predict_embeddings currently ignores label_model, so this has no effect yet.
    pred_images, names, predict_behavior, pred_labels = predict_embeddings(
        neural_path, behavior_path, dino_path, TRAIN_CUTOFF, TEST_SIZE,
        model, label_model[i], session_id)

    for vid, name, ground_truth, frame_labels in zip(pred_images, names, predict_behavior, pred_labels):
        windowed_frames = choose_first_second(CLIP_FRAMES, vid)
        windowed_truth = choose_first_second(CLIP_FRAMES, ground_truth)
        windowed_labels = choose_first_second(CLIP_FRAMES, frame_labels)
        # Median of 0/1 labels is a majority vote; np.round sends an exact 0.5 tie to 0 (no move).
        clip_label = np.round(np.median(windowed_labels))
        suffix = '_pred_no_move' if clip_label == 0 else '_pred_move'
        display_frames_as_video(windowed_frames, windowed_truth, VIDEO_FPS,
                                name + suffix, output_folder_path)

Loading data
Generating embeddings
Loading test data
Generating test embeddings
Running KNN
generating labels
generating videos
Loading data
Generating embeddings
Loading test data
Generating test embeddings
Running KNN
generating labels
generating videos
Loading data
Generating embeddings
Loading test data
Generating test embeddings
Running KNN
generating labels
generating videos
Loading data
Generating embeddings
Loading test data
Generating test embeddings
Running KNN
generating labels
generating videos
Loading data
Generating embeddings
Loading test data
Generating test embeddings
Running KNN
generating labels
generating videos
Loading data
Generating embeddings
Loading test data
Generating test embeddings
Running KNN
generating labels
generating videos
Loading data
Generating embeddings
Loading test data
Generating test embeddings
Running KNN
generating labels
generating videos
Loading data
Generating embeddings
Loading test data
Generating test embeddings
Running KNN
generating l