In [None]:
import sys
import numpy as np
import matplotlib.pyplot as plt
import cebra
from PIL import Image
import cv2
import torch
import pickle
import os
import random

In [None]:
# Inputs for the training session and the pretrained artifacts.
# Directory of .tif stacks holding the brain recordings.
brain_path = '2020_11_9_MV1_run_brain'
# Directory of .tif stacks holding the behavior videos for the same session.
behavior_path = '2020_11_9_MV1_run_behavior'
# Pickle file with the feature labels the kNN decoder is fit on.
feature_label_path = 'feature_labels'
# Saved CEBRA model, loaded with cebra.CEBRA.load.
cebra_model_path = 'cebra_multi_model2.pt'
# Folder that display_frames_as_video writes its .mp4 files into.
output_folder_path = 'output_videos4'

In [None]:
# True: predict on a separately recorded session (paths below).
# False: predict on the tail of the training session.
use_separate_test_set = True
# Fraction of the training session that lies after the train cutoff.
# NOTE(review): the split cell computes the cutoff from this value in BOTH modes,
# so with a separate test set the trailing fraction of the training session is
# still left out of decoder fitting -- confirm this is intended.
test_set_size = 0.2
# Session used as the test set when use_separate_test_set is True.
brain_path_test = '2020_12_4_MV1_run_brain'
behavior_path_test = '2020_12_4_MV1_run_behavior'
# Cutoff passed as `max` to import_data when loading the test session, to save memory.
test_cutoff = 100


In [None]:
def process_brain(brain_seq):
  """Flatten a stack of brain frames into one row of active pixels per frame.

  A pixel is kept when its values summed over axis 0 (the first axis of the
  stack) are positive. The mask is computed per call, so two recordings may
  yield a different number of columns.

  Args:
    brain_seq: Array-like stack of frames; axis 0 indexes the frames.

  Returns:
    A float array with the masked pixels of every frame, or ``None`` when the
    input could not be processed (a diagnostic line is printed in that case).
  """
  try:
    brain_seq = np.array(brain_seq)
    brain_mask = np.sum(brain_seq, axis=0) > 0
    return brain_seq[:, brain_mask].astype(float)
  except Exception as err:
    # Only ordinary errors are treated as "bad recording"; KeyboardInterrupt
    # and SystemExit must still be able to stop a long import.
    try:
      shape = np.shape(brain_seq)
    except Exception:
      # np.shape itself can fail (e.g. ragged input); never raise from here.
      shape = 'unknown'
    print(f"process_brain failed for input of shape {shape}: {err!r}")
    return None

def import_data(filepath, processor, max = -1):
    """Load the ``.tif`` stacks found directly inside a directory.

    Every matching file is read with ``cv2.imreadmulti`` and the resulting
    list of pages is passed through ``processor``.

    Args:
        filepath: Directory to scan (not recursive).
        processor: Callable applied to the pages of each file; its return
            value is what ends up in the data list.
        max: Maximum number of ``.tif`` files to load. Any value <= 0 means
            "no limit". (The name shadows the builtin but is kept because
            callers pass it by keyword.)

    Returns:
        ``(output_data, output_name)``: two parallel lists with the processed
        data and the file name of every loaded stack.
    """
    output_data = []
    output_name = []
    limit = max
    # NOTE(review): os.listdir returns entries in arbitrary order, and callers
    # pair the results of separate calls (brain vs. behavior) by position.
    # Confirm the two directories list matching files in the same order.
    for entry in os.listdir(filepath):
        filename = os.fsdecode(entry)
        if not filename.endswith(".tif"):
            continue
        # Count loaded stacks only, so stray non-tif files cannot shift the
        # cutoff and exactly `max` files are loaded at most.
        if limit > 0 and len(output_data) >= limit:
            break
        pages = cv2.imreadmulti(os.path.join(filepath, filename))[1]
        output_data.append(processor(pages))
        output_name.append(filename)
    return output_data, output_name

def flatten_data(data):
    """Concatenate a sequence of arrays along their first axis.

    Each element is passed through ``np.squeeze`` first, so singleton axes
    are dropped before the pieces are joined.

    Args:
        data: Non-empty sequence of array-likes whose squeezed shapes agree on
            every axis but the first.

    Returns:
        One array holding all pieces back to back. For a single-element input
        the squeezed element itself is returned.

    Raises:
        IndexError: If ``data`` is empty.
    """
    pieces = [np.squeeze(x) for x in data]
    if not pieces:
        # Same exception type that indexing data[0] on an empty list raises.
        raise IndexError("flatten_data needs at least one array")
    if len(pieces) == 1:
        return pieces[0]
    # A single concatenate copies every piece once; joining pairwise in a loop
    # would re-copy the growing result on every iteration.
    return np.concatenate(pieces)

def generate_CEBRA_embeddings(neural, name, model, session = 'run'):
    """Transform every recording with ``model`` and collect the embeddings.

    Recordings whose transform raises are skipped and removed **in place**
    from both ``neural`` and ``name`` so those two lists stay aligned with the
    returned embeddings. Any other list that runs parallel to them (behavior
    data, feature labels, ...) is not touched by this function and has to be
    realigned by the caller.

    Args:
        neural: List of per-recording inputs for ``model.transform``. Mutated.
        name: List of recording names parallel to ``neural``. Mutated.
        model: Object exposing ``transform(data, session_id=...)``.
        session: ``'run'`` uses the position of the recording in ``neural`` as
            ``session_id``; any other value uses ``session_id=1`` for all.

    Returns:
        List with one embedding per recording that transformed successfully.
    """
    embedding = []
    failed = []
    for run, data in enumerate(neural):
        session_id = run if session == 'run' else 1
        try:
            embedding.append(model.transform(data, session_id=session_id))
        except Exception as err:
            # Ordinary failures mark the recording as unusable; interrupts
            # (KeyboardInterrupt/SystemExit) are deliberately not caught.
            failed.append(run)
            print(f"recording {run}: transform failed ({err!r})")
    # Delete from the back so earlier indices remain valid.
    for index in reversed(failed):
        del name[index]
        del neural[index]
    return embedding

In [None]:
# Load the training session. Brain stacks are flattened by process_brain; the
# behavior stacks are kept exactly as read (identity processor).
# The two lists are sliced and flattened in parallel further down, so both
# directories must yield matching files in the same order.
brain_data, name_data = import_data(brain_path, process_brain)
behavior_data, _ = import_data(behavior_path, lambda x: x)

In [None]:
# Load the separate test session with the same processors as the training
# session; max=test_cutoff limits how many files are read to save memory.
if use_separate_test_set:
    brain_data_test, name_data_test = import_data(brain_path_test, process_brain, max=test_cutoff)
    behavior_data_test, _ = import_data(behavior_path_test, lambda x: x, max=test_cutoff)


In [None]:
# Load the feature labels; they are sliced with the same train cutoff as the
# embeddings below, so they are expected to be ordered like the training data.
# SECURITY: pickle.load can execute arbitrary code -- only open files you created
# or otherwise trust.
with open(feature_label_path, 'rb') as f:
    feature_labels = pickle.load(f)

In [None]:
# Get neural embeddings for the training session using the saved CEBRA model.
# generate_CEBRA_embeddings deletes recordings that fail to transform from
# brain_data and name_data in place; behavior_data and feature_labels are NOT
# pruned, so any reported failure leaves them misaligned with the embeddings.
model = cebra.CEBRA.load(cebra_model_path)
neural_embeddings = generate_CEBRA_embeddings(brain_data, name_data, model, session = 'run')

In [None]:
# Embed the separate test session with the same model. With session='run' the
# position of each test recording in the list is passed as session_id.
# NOTE(review): confirm those ids refer to the intended sessions of the model.
# Failed recordings are removed from brain_data_test/name_data_test in place,
# but not from behavior_data_test.
if use_separate_test_set:
    neural_embeddings_test = generate_CEBRA_embeddings(brain_data_test, name_data_test, model, session = 'run')

In [None]:
# Split the data into a training part (used to fit the decoder) and a
# prediction part (used to generate the videos).
# The training slice is the same in both modes: the leading
# (1 - test_set_size) fraction of the training session.
# NOTE(review): with a separate test set the trailing fraction of the training
# session is therefore unused -- confirm this is intended.
train_cutoff = int(len(neural_embeddings) * (1 - test_set_size))
train_neural_embeddings = neural_embeddings[:train_cutoff]
train_names = name_data[:train_cutoff]
train_behavior = behavior_data[:train_cutoff]
train_feature_labels = feature_labels[:train_cutoff]

if use_separate_test_set:
    # Predict on the separately loaded session.
    predict_neural_embeddings = neural_embeddings_test
    predict_names = name_data_test
    predict_behavior = behavior_data_test
    # Drop the now redundant names to release memory. This makes the cell
    # non-rerunnable without reloading the data first.
    del behavior_data_test
    del behavior_data
else:
    # Predict on the remainder of the training session.
    predict_neural_embeddings = neural_embeddings[train_cutoff:]
    predict_names = name_data[train_cutoff:]
    predict_behavior = behavior_data[train_cutoff:]

In [None]:
# Flatten the per-recording training lists into single arrays (recordings are
# joined along the first axis). The three results are used row-by-row in
# parallel below, so they must end up with the same number of rows.
train_neural_embeddings_flat = flatten_data(train_neural_embeddings)
train_behavior_flat = flatten_data(train_behavior)
train_feature_labels_flat = flatten_data(train_feature_labels)

In [None]:
# Flatten the embeddings of the recordings to predict into one array.
predict_neural_embeddings_flat = flatten_data(predict_neural_embeddings)

In [None]:
# Fit a kNN decoder (20 neighbours, cosine metric) that maps a neural embedding
# to a feature label vector, using the flattened training data.
image_decoder = cebra.KNNDecoder(n_neighbors=20, metric="cosine")
image_decoder.fit(train_neural_embeddings_flat, train_feature_labels_flat)


In [None]:
# Predict one feature label vector per row of the flattened prediction embeddings.
predict_feature_labels_flat = image_decoder.predict(predict_neural_embeddings_flat)

In [None]:
def normalize_array(in_array):
    """Divide every entry of ``in_array`` by its own ``np.linalg.norm``.

    Iterates over the first axis, so for a 2-D input each row is scaled to
    unit length. There is no guard against a zero norm.

    Args:
        in_array: Iterable of arrays (typically a 2-D array of row vectors).

    Returns:
        A new array built from the scaled entries.
    """
    unit_rows = []
    for row in in_array:
        unit_rows.append(row / np.linalg.norm(row))
    return np.array(unit_rows)

def match_frame_to_embeddings(predicted_embedding, embedding_train, image_train):
  """Pick, for every predicted vector, the training image with the best match.

  The match score is the dot product between a training vector and a
  predicted vector; it equals the cosine similarity only when both inputs
  hold unit-length rows.

  Args:
    predicted_embedding: Array with one predicted vector per row.
    embedding_train: Array with one training vector per row.
    image_train: Array indexed like ``embedding_train`` along its first axis.

  Returns:
    ``image_train`` entries selected by the best-scoring training row, one per
    predicted vector.
  """
  # Rows index training vectors, columns index predicted vectors.
  similarity = embedding_train @ predicted_embedding.T
  best_train_row = np.argmax(similarity, axis=0)
  return image_train[best_train_row]

In [None]:
# Normalize the predicted and the training feature label vectors to unit length
# so that the dot product used for frame matching acts as cosine similarity.
# Both names are overwritten with their normalized versions.
predict_feature_labels_flat = normalize_array(predict_feature_labels_flat)
train_feature_labels_flat = normalize_array(train_feature_labels_flat)

In [None]:
# Generate the series of predicted images: for every predicted feature vector,
# take the training behavior frame whose feature vector scores highest.
# Note that this builds a (training rows x predicted rows) score matrix in memory.
predicted_frames = match_frame_to_embeddings(predict_feature_labels_flat, train_feature_labels_flat, train_behavior_flat)

In [None]:
# Run agglomerative clustering (3 clusters, cosine metric, average linkage) on
# the prediction embeddings to obtain one cluster id per frame.
# NOTE(review): the clustering is unsupervised, so nothing here ties cluster id
# 0/1/2 to a particular behavior -- confirm the id-to-name mapping used later
# (label_classes) against the data before trusting the video labels.
from sklearn.cluster import AgglomerativeClustering
clustering = AgglomerativeClustering(n_clusters=3, metric='cosine', linkage='average' ).fit(predict_neural_embeddings_flat)
predicted_labels_flat = clustering.labels_

In [None]:
def reshape_frames(frames, shape_ref):
    """Cut a flat sequence back into consecutive per-recording chunks.

    Args:
        frames: Flat, sliceable sequence (e.g. array of frames or labels).
        shape_ref: Iterable of reference items; the first-axis length of each
            item (``np.shape(item)[0]``) gives the size of the matching chunk.

    Returns:
        List with one slice of ``frames`` per reference item, in order.
    """
    chunk_sizes = [np.shape(ref)[0] for ref in shape_ref]
    chunks = []
    start = 0
    for size in chunk_sizes:
        stop = start + size
        chunks.append(frames[start:stop])
        start = stop
    return chunks

#choose a random window of set size from the data deterministically based on seed
def choose_random_window( window_size, seed, data):
    """Return a contiguous window of ``data`` whose position depends only on ``seed``.

    The start index is drawn from a private generator seeded with ``seed``, so
    the global ``random`` state is left untouched. Sequences of equal length
    called with the same seed and window size get the same window position.

    Args:
        window_size: Number of items in the window.
        seed: Seed for the private generator (any value ``random.Random``
            accepts, e.g. an int or a string).
        data: Sliceable sequence with a length.

    Returns:
        ``data[start:start + window_size]``; all of ``data`` when it is
        shorter than ``window_size``.
    """
    rng = random.Random(seed)
    # Clamp so that short sequences yield the whole sequence instead of
    # making randint fail on an empty range.
    last_start = max(len(data) - window_size, 0)
    start = rng.randint(0, last_start)
    return data[start:start + window_size]


def display_frames_as_video(frames, ground_truth, frame_rate, name, labels, label_dict):
    """Write predicted and ground-truth frames, stacked vertically, to an .mp4.

    Each output frame shows the (temporally smoothed) predicted frame on top
    and the ground-truth frame below, with the frame's label text drawn on it.
    The file is written to
    ``<output_folder_path>/<video label>_<name>.mp4``, where the video label is
    looked up from the rounded median of ``labels``. ``output_folder_path`` is
    the notebook-level setting and is created if missing.

    Args:
        frames: Sequence of 2-D (grayscale) predicted frames.
        ground_truth: Sequence of 2-D frames parallel to ``frames``.
        frame_rate: Frames per second of the output video.
        name: Name fragment used in the output file name.
        labels: Per-frame label ids, parallel to ``frames``.
        label_dict: Mapping from label id to the label text.
    """
    fontScale = 1
    org = (50, 50)
    font = cv2.FONT_HERSHEY_SIMPLEX
    thickness = 2
    frame_height, frame_width = frames[0].shape

    # Classify the whole video by the median label. np.median returns a float,
    # so convert to int for the dictionary lookup.
    label_class = label_dict[int(np.round(np.median(labels)))]

    # VideoWriter does not create missing folders; make sure the target exists.
    os.makedirs(output_folder_path, exist_ok=True)
    video_path = os.path.join(output_folder_path, label_class + '_' + name + '.mp4')
    video_writer = cv2.VideoWriter(video_path,
                                   cv2.VideoWriter_fourcc(*'mp4v'),
                                   frame_rate,
                                   (frame_width, 2 * frame_height))
    try:
        # Smooth by averaging each predicted frame with the previous one. The
        # first frame has no predecessor and is averaged with itself (instead
        # of wrapping around to the last frame).
        previous = frames[0]
        for frame, truth, frame_label in zip(frames, ground_truth, labels):
            blended = (frame / 2 + previous / 2).astype(np.uint8)
            previous = frame
            color_frame = cv2.cvtColor(blended, cv2.COLOR_GRAY2RGB)
            color_truth = cv2.cvtColor(truth.astype(np.uint8), cv2.COLOR_GRAY2RGB)
            # Default axis 0: predicted frame on top, ground truth below.
            combined = np.concatenate((color_frame, color_truth))
            combined = cv2.putText(combined, str(label_dict[frame_label]), org, font,
                                   fontScale, (0,0,255), thickness, cv2.LINE_AA)
            video_writer.write(combined)
    finally:
        # Always finalize the file, even when a frame fails to render.
        video_writer.release()

In [None]:
# Split the flat predicted frames and cluster labels back into one chunk per
# recording, using the frame count of each ground-truth behavior recording.
pred_vid_list  = reshape_frames(predicted_frames, predict_behavior)
pred_labels= reshape_frames(predicted_labels_flat, predict_behavior)

In [None]:
# Text shown for each cluster id in the videos and used in the file names.
# NOTE(review): cluster ids come from unsupervised clustering; verify that this
# id-to-behavior assignment matches the clusters actually found.
label_classes = {0 : 'Still', 1 : 'Running', 2 : 'Sniffing'}

In [None]:
# Write one video per recording to predict.
# data = (labels, predicted frames, file name, ground-truth frames).
# The file name is used as the seed for all three calls, so the 30-item windows
# line up as long as the three sequences have the same length.
# Arguments to display_frames_as_video: 30 is the output frame rate.
for data in zip(pred_labels, pred_vid_list, predict_names, predict_behavior):
    windowed_labels = choose_random_window(30, data[2], data[0])
    windowed_frames = choose_random_window(30, data[2], data[1])
    windowed_truth = choose_random_window(30, data[2], data[3])
    display_frames_as_video(windowed_frames, windowed_truth, 30, data[2], windowed_labels, label_classes)

In [None]:
from matplotlib.colors import ListedColormap
from sklearn.manifold import TSNE
# Project the prediction embeddings to 2-D with t-SNE (cosine metric) and colour
# every point by its cluster id.
# random_state pins the random initialisation so the figure is reproducible
# across runs.
# NOTE(review): newer scikit-learn releases rename `n_iter` to `max_iter`;
# switch the keyword if the installed version warns about or rejects it.
tsne_embedding = TSNE(n_components=2, n_iter=5000, learning_rate='auto', metric='cosine',
                   init='random', perplexity=30, random_state=0).fit_transform(predict_neural_embeddings_flat)
colors = ListedColormap(['r','b','g'])
# NOTE(review): legend names assume cluster ids 0/1/2 mean still/running/sniffing.
classes = ['still', 'running', 'sniffing']
#plot tsne plot as a scatterplot with labels
scatter = plt.scatter(tsne_embedding[:,0], tsne_embedding[:,1], c=predicted_labels_flat, cmap=colors)
plt.legend(handles=scatter.legend_elements()[0], labels=classes)
plt.xlabel('t-SNE dimension 1')
plt.ylabel('t-SNE dimension 2')
plt.title('t-SNE of neural embeddings coloured by cluster')
plt.show()
