In [None]:
import sys
import numpy as np
import matplotlib.pyplot as plt
from cebra import CEBRA
from PIL import Image
import cv2
import torch
import pickle
import cebra
import os

In [None]:
from logging import exception
def process_brain(brain_seq):
  """Reduce an imaging stack to a (frames, masked_pixels) float array.

  A pixel is kept when its values summed over all frames are > 0, so every
  frame is flattened to the same 1-D vector of selected pixels (for
  non-negative input — e.g. uint8 frames — this means "non-zero in at least
  one frame").

  Args:
    brain_seq: sequence of equally-shaped frames (e.g. the frame list/tuple
      returned by ``cv2.imreadmulti``) or an array of shape (frames, ...).

  Returns:
    float array of shape (frames, n_selected_pixels), or None when the stack
    cannot be processed (a message is printed in that case).
  """
  try:
    frames = np.asarray(brain_seq)
    brain_mask = np.sum(frames, axis=0) > 0
    return frames[:, brain_mask].astype(float)
  except Exception as err:
    # Deliberately best-effort: report and skip this stack. np.shape() on a
    # ragged frame list would raise again here, so only report the frame count.
    n_frames = len(brain_seq) if hasattr(brain_seq, '__len__') else '?'
    print('process_brain: skipping stack of', n_frames, 'frames:', err)
    return None

In [None]:
def import_data(filepath, processor, max = -1):
    """Load every ``.tif`` stack in a directory and run it through ``processor``.

    Args:
        filepath: directory to scan (non-recursive).
        processor: callable applied to the frame sequence returned by
            ``cv2.imreadmulti``; its result is stored for each file.
        max: when > 0, stop after this many ``.tif`` files have been loaded;
            any value <= 0 (default -1) loads all of them.

    Returns:
        ``(output_data, output_name)``: processed stacks and their file names,
        in matching order.
    """
    output_data = []
    output_name = []
    # NOTE(review): os.listdir order is arbitrary. Other cells pair brain and
    # behavior runs by list index, which assumes both directories list their
    # files in the same order — consider sorting if that is not guaranteed.
    for entry in os.listdir(filepath):
        filename = os.fsdecode(entry)
        if not filename.endswith(".tif"):
            continue
        frames = cv2.imreadmulti(os.path.join(filepath, filename))[1]
        output_data.append(processor(frames))
        output_name.append(filename)
        # Count loaded .tif files only, so other entries don't consume the limit.
        if 0 < max <= len(output_name):
            break
    return output_data, output_name

In [None]:
# Training session: masked, flattened brain-imaging stacks plus their file names.
neural_data, name_data = import_data(filepath="2020_11_9_MV1_run_brain", processor=process_brain)

In [None]:
# Training session: raw behavior-video frames, kept as loaded (identity processor).
image_data, _ = import_data(filepath="2020_11_9_MV1_run_behavior", processor=lambda frames: frames)

In [None]:
# Held-out session: brain stacks processed exactly like the training data.
neural_data_test, name_data_test = import_data(filepath="2020_12_4_MV1_run_brain", processor=process_brain)

In [None]:
# Held-out session: raw behavior-video frames (used as ground truth later).
image_data_test, _ = import_data(filepath="2020_12_4_MV1_run_behavior", processor=lambda frames: frames)

In [None]:
# Drop held-out runs whose file name starts with 'nomove' from all three
# parallel lists. Collect the indices first and delete back-to-front:
# deleting inside the enumerate loop shifts the list and skips the run that
# slides into each deleted slot (e.g. two consecutive 'nomove' runs).
nomove_idx = [i for i, name in enumerate(name_data_test) if name.split('_')[0] == 'nomove']
for i in reversed(nomove_idx):
    print('del')
    del name_data_test[i]
    del image_data_test[i]
    del neural_data_test[i]

In [None]:
# Per-run behavioral feature labels saved earlier with pickle.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('feature_labels', mode='rb') as label_file:
    behav_feature = pickle.load(label_file)

In [None]:
# Previously trained CEBRA model (presumably multi-session, given the session_id use below).
multi_cebra_model = cebra.CEBRA.load('cebra_multi_model2.pt')

In [None]:
def flatten_data(data):
    """Squeeze each item and concatenate all of them along axis 0.

    Items are wrapped with ``np.atleast_1d`` after squeezing, so an item that
    squeezes down to a 0-d scalar contributes one row instead of making
    ``np.concatenate`` raise.

    NOTE(review): a later cell redefines ``flatten_data`` without the squeeze,
    and the cells that flatten data run after that redefinition, so this
    version is shadowed there.

    Args:
        data: non-empty sequence of array-likes.

    Returns:
        np.ndarray with the squeezed items stacked along axis 0.

    Raises:
        ValueError: if ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError('flatten_data needs at least one item')
    # One concatenate call avoids re-copying a growing array for every item.
    return np.concatenate([np.atleast_1d(np.squeeze(x)) for x in data])

In [None]:
def generate_CEBRA_embeddings(neural, name, model, session = 'run', test_session_id = 1):
    """Embed every run with ``model.transform`` and prune runs that fail.

    Args:
        neural: list of per-run neural arrays. Runs whose transform raises are
            removed from this list in place.
        name: list of run names parallel to ``neural``; pruned in place
            together with it.
        model: object exposing ``transform(data, session_id=...)``.
        session: ``'run'`` embeds run ``i`` with ``session_id=i``; any other
            value embeds every run with ``test_session_id``.
        test_session_id: session id used when ``session != 'run'``
            (default 1, the previously hard-coded value).

    Returns:
        List of embeddings for the runs that transformed successfully, in order.

    NOTE(review): parallel image lists are NOT pruned here, so callers pairing
    images with embeddings by index must drop the same runs themselves.
    """
    embedding = []
    failed = []
    # Loop over the unmodified list so each run keeps its original index (and
    # thus its session id); deleting inside the loop would skip the next run.
    for run, data in enumerate(neural):
        session_id = run if session == 'run' else test_session_id
        try:
            embedding.append(model.transform(data, session_id=session_id))
        except Exception as err:
            print(run, err)
            failed.append(run)
    # Delete back-to-front so earlier indices remain valid.
    for run in reversed(failed):
        del name[run]
        del neural[run]
    return embedding

In [None]:
# Embed each training run with its own session id (session defaults to 'run').
neural_embedding = generate_CEBRA_embeddings(
    neural=neural_data, name=name_data, model=multi_cebra_model)

In [None]:
# Held-out runs: any session other than 'run' embeds every run with one fixed session id.
neural_embedding_test = generate_CEBRA_embeddings(
    neural=neural_data_test, name=name_data_test, model=multi_cebra_model, session='test')

In [None]:
# 80/20 split by run. The same `cutoff` later slices image_data and
# behav_feature, so those lists must stay aligned run-for-run with neural_embedding.
cutoff = int(len(image_data) * 0.8)
if len(neural_embedding) != len(image_data):
    # generate_CEBRA_embeddings drops runs it fails to embed, but image_data is
    # not pruned, so index-based pairing is off from the first dropped run onward.
    print('warning: neural_embedding has', len(neural_embedding), 'runs but image_data has',
          len(image_data), '- train/test pairing may be misaligned')
embedding_train = neural_embedding[:cutoff]
embedding_pred = neural_embedding[cutoff:]

In [None]:
def flatten_data(data):
    """Concatenate a sequence of per-run items into one array along axis 0.

    Each item goes through ``np.atleast_1d`` first, so 0-d items such as
    run-name strings contribute one row each instead of making
    ``np.concatenate`` raise. Items that are already >= 1-D (frame stacks,
    embeddings, label arrays) are joined as before. A single item is now
    returned as an array rather than unchanged.

    Args:
        data: non-empty sequence of array-likes whose trailing shapes match.

    Returns:
        np.ndarray with the items stacked along axis 0.

    Raises:
        ValueError: if ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError('flatten_data needs at least one item')
    # One concatenate call avoids re-copying a growing array for every item.
    return np.concatenate([np.atleast_1d(x) for x in data])

In [None]:
# Flatten the training runs to per-frame arrays (same run cutoff for images and labels).
image_data_flat = flatten_data(image_data[:cutoff])
neural_embeddings_train_flat = flatten_data(embedding_train)
feature_label_train = np.squeeze(flatten_data(behav_feature[:cutoff]))
# Run names are plain strings, which np.concatenate cannot join (they become
# 0-d arrays), so build the 1-D array of held-out run names directly.
# NOTE(review): this is one entry per run, not per frame — confirm intended use.
name_data_test_flat = np.array(name_data_test)

In [None]:
# All held-out embeddings as one per-frame array, in run order.
neural_embeddings_test_flat = flatten_data(data=neural_embedding_test)

In [None]:
# All held-out behavior frames as one per-frame array, in run order.
image_data_test_flat = flatten_data(data=image_data_test)

In [None]:
import sklearn.metrics
# k-NN decoder (cosine metric, 20 neighbours): training CEBRA embeddings -> behavioral feature labels.
image_decoder = cebra.KNNDecoder(metric="cosine", n_neighbors=20)
image_decoder.fit(neural_embeddings_train_flat, feature_label_train)


In [None]:
# Decode feature labels for every held-out frame from its CEBRA embedding.
predicted = image_decoder.predict(neural_embeddings_test_flat)

In [None]:
def normalize_array(in_array):
    """Scale each entry along axis 0 to unit L2 norm.

    ``in_array[i]`` is divided by the 2-norm of all its values: for a 2-D input
    each row becomes a unit vector; for a 1-D input each scalar becomes its sign.
    Entries whose norm is 0 come back as zeros instead of NaN — a NaN row in
    the training labels would win ``np.argmax`` in ``match_frame_to_embeddings``
    for every prediction.

    Args:
        in_array: array-like indexed by entry along axis 0.

    Returns:
        np.ndarray of the same shape (integer input is promoted to float).
    """
    arr = np.asarray(in_array)
    if not np.issubdtype(arr.dtype, np.inexact):
        arr = arr.astype(float)
    # 2-norm over every axis except 0, kept broadcastable against arr.
    norms = np.sqrt(np.sum(np.square(np.abs(arr)),
                           axis=tuple(range(1, arr.ndim)), keepdims=True))
    return np.divide(arr, norms, out=np.zeros_like(arr), where=norms > 0)

In [None]:
# Unit-normalize both sides so the dot products in match_frame_to_embeddings
# are cosine similarities.
predicted = normalize_array(in_array=predicted)
feature_label_train = normalize_array(in_array=feature_label_train)

In [None]:
def match_frame_to_embeddings(predicted_embedding, embedding_train, image_train):
  """Return, for each predicted vector, the training image with the highest score.

  The score is the dot product between the predicted vector and each training
  vector (a cosine similarity when both inputs are unit-normalized).

  Args:
    predicted_embedding: array of predicted vectors, one per row.
    embedding_train: array of training vectors, one per row.
    image_train: array indexable by training-row index.

  Returns:
    ``image_train`` gathered at the best-scoring training index for each prediction.
  """
  similarity = embedding_train @ predicted_embedding.T  # (n_train, n_pred) for 2-D inputs
  best_train_idx = similarity.argmax(axis=0)
  return image_train[best_train_idx]

In [None]:
# Nearest training frame for every decoded held-out feature vector.
vid_pred = match_frame_to_embeddings(
    predicted_embedding=predicted,
    embedding_train=feature_label_train,
    image_train=image_data_flat,
)

In [None]:
# NOTE(review): 87 appears to be the frame count of the first held-out run —
# confirm; len(neural_embedding_test[0]) would avoid hard-coding it.
first_vid = predicted[:87]
first_pred = match_frame_to_embeddings(
    predicted_embedding=first_vid,
    embedding_train=feature_label_train,
    image_train=image_data_flat,
)

In [None]:
# Cut the flat per-frame predictions back into one clip per held-out run,
# using each run's embedding length (rows) as the clip length.
shape_list = [np.shape(run_embedding)[0] for run_embedding in neural_embedding_test]
gen_video_list = []
index = 0
for shape in shape_list:
    stop = index + shape
    gen_video_list.append(vid_pred[index:stop])
    index = stop

In [None]:
def display_frames_as_video(frames, ground_truth, frame_rate, name, output_dir='output_videos3'):
    """Write predicted frames stacked above ground-truth frames to an .mp4 file.

    Nothing is shown on screen; the video is only written to disk. Each output
    frame is the average of the current and previous predicted frame (the first
    frame is used as-is), placed on top of the matching ground-truth frame.
    Pairs come from ``zip``, so the shorter input sets the length.
    ``frames[0].shape`` is unpacked as (height, width), so frames are assumed
    to be 2-D grayscale; ground-truth frames are assumed to be the same size —
    TODO confirm, since the writer is sized (width, 2 * height).

    Args:
        frames: sequence of predicted 2-D frames.
        ground_truth: sequence of ground-truth 2-D frames.
        frame_rate: output frames per second.
        name: file stem; the video goes to ``<output_dir>/<name>.mp4``.
        output_dir: output directory, created if missing (default keeps the
            previous hard-coded location).

    Raises:
        OSError: if OpenCV cannot open the output file for writing.
    """
    frame_height, frame_width = frames[0].shape

    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, name + '.mp4')
    video_writer = cv2.VideoWriter(out_path,
                                   cv2.VideoWriter_fourcc(*'mp4v'),
                                   frame_rate,
                                   (frame_width, 2 * frame_height))
    if not video_writer.isOpened():
        raise OSError('could not open video file for writing: ' + out_path)

    try:
        previous = frames[0]
        for frame, truth in zip(frames, ground_truth):
            # Light temporal smoothing with the previous raw prediction. For the
            # first frame this is the frame itself, not the clip's last frame.
            blended = frame / 2 + previous / 2
            previous = frame
            color_frame = cv2.cvtColor(blended.astype(np.uint8), cv2.COLOR_GRAY2RGB)
            color_truth = cv2.cvtColor(truth.astype(np.uint8), cv2.COLOR_GRAY2RGB)
            # Prediction on top, ground truth below.
            video_writer.write(np.concatenate((color_frame, color_truth)))
    finally:
        # Always release so the file is finalized even if a frame fails.
        # No window is opened here, so no destroyAllWindows() call (it can
        # raise on headless OpenCV builds).
        video_writer.release()

In [None]:
# Preview the first held-out run: smoothed predicted frames above the recorded
# behavior frames. `image_data_pred` / `name_data_pred` are not defined anywhere
# above; the predictions come from the held-out (*_test) data, so use those.
display_frames_as_video(first_pred, image_data_test[0], 24, name_data_test[0])

In [None]:
# Write a comparison video for every held-out run whose file name starts with 'move'.
# Each run's own name drives the filter; the three per-run lists are walked in
# lockstep, which assumes they are aligned run-for-run.
for video, truth, run_name in zip(gen_video_list, image_data_test, name_data_test):
    if run_name.split('_')[0] == 'move':
        display_frames_as_video(video, truth, 24, run_name)