In [9]:
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import MeanShift, estimate_bandwidth, OPTICS
from cebra_utils import *


In [10]:
from torch import nn
import cebra.models
import cebra.data
from cebra.models.model import _OffsetModel, ConvolutionalModelMixin

class ChangeOrderLayer(nn.Module):
    """Layer that relocates the second-to-last axis of its input to position 1.

    For a 4-D tensor this is the same as swapping axes 1 and 2; all other
    axes keep their relative order.
    """

    def forward(self, x):
        # movedim(source, destination): take axis -2 and place it at index 1.
        reordered = x.movedim(-2, 1)
        return reordered

@cebra.models.register("convolutional-model-offset11")
class ConvulotionalModel1(_OffsetModel, ConvolutionalModelMixin):
    """Small 2-D convolutional encoder registered as "convolutional-model-offset11".

    The network reorders the input axes, applies two Conv2d -> GELU ->
    MaxPool2d(4) stages, flattens the result, and projects it to
    ``num_output`` features with a linear layer.

    Args:
        num_neurons: Forwarded to the base class as ``num_input``.
        num_units: Accepted for signature compatibility; not used here.
        num_output: Size of the output vector (also forwarded to the base class).
        normalize: Forwarded to the base class unchanged.
    """

    def __init__(self, num_neurons, num_units, num_output, normalize=True):
        # Maps an image stack to a 1-D vector of length num_output.
        # The first convolution expects 5 input channels. The linear layer
        # expects 1024 flattened features, i.e. 16 channels x 8 x 8, which is
        # what two 4x poolings leave of a 128 x 128 image.
        layers = (
            ChangeOrderLayer(),
            nn.Conv2d(5, 16, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.MaxPool2d(kernel_size=4, stride=4),
            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.MaxPool2d(kernel_size=4, stride=4),
            nn.Flatten(),
            nn.Linear(1024, num_output),
        )
        super().__init__(
            *layers,
            num_input=num_neurons,
            num_output=num_output,
            normalize=normalize,
        )

    # The forward method may also be overridden here, exactly as for any
    # regular pytorch module.

    def get_offset(self):
        """Return the ``cebra.data.Offset(2, 3)`` associated with this model."""
        return cebra.data.Offset(2, 3)


ValueError: Name convolutional-model-offset11 is already registered for class (<class '__main__.ConvulotionalModel1'>, None).

In [11]:
# Location of the saved CEBRA checkpoint, relative to the working directory.
cebra_model_path = "models/cebra_model_complete.pt"


In [12]:
data_directory = '/mnt/teams/Tsuchitori/MV1_run_30hz_30frame_brain2behav_DFF_new/'

# List the session entries ONCE and derive every per-session path from the
# same listing. This guarantees that index i refers to the same session in all
# four lists and avoids repeating the directory scan.
# NOTE(review): the listing is deliberately left in os.listdir order (not
# sorted) because the session index is used downstream to pick a per-session
# model; presumably it must match the order used at training time -- confirm.
session_files = os.listdir(data_directory + 'brain/')

# Entries are taken from 'brain/' for every modality, so the other folders are
# assumed to contain identically named entries.
neural_data_paths = [data_directory + 'brain/' + file for file in session_files]

behavior_data_paths = [data_directory + 'camera1/' + file for file in session_files]

dino_paths = [data_directory + 'dino/' + file for file in session_files]

output_folder_paths = [data_directory + 'output3/' + file for file in session_files]

In [13]:
def match_frame_to_embeddings(predicted_embedding, embedding_train, image_train):
    """Pick, for every predicted embedding, the training image whose embedding
    has the largest dot product with it.

    Args:
        predicted_embedding: 2-D array, one predicted embedding per row.
        embedding_train: 2-D array, one training embedding per row.
        image_train: Array indexed along axis 0 in the same order as
            ``embedding_train``.

    Returns:
        ``image_train`` indexed by the best-matching training row for each
        predicted embedding.
    """
    # Rows: training samples, columns: predicted samples.
    similarity = embedding_train @ predicted_embedding.T
    # Best training row for each predicted column.
    best_match = similarity.argmax(axis=0)
    return image_train[best_match]

def predict_embeddings(neural_path, behavior_path,dino_path, validation_cutoff, valid_size, model, label_model, session):
    """Decode behavior frames and behavior labels for the held-out part of a session.

    Pipeline:
      1. Load brain / behavior / dino data up to ``validation_cutoff`` and embed
         the brain data with ``model``.
      2. Load the brain and behavior data between ``validation_cutoff`` and
         ``validation_cutoff + valid_size`` and embed the brain data.
      3. Fit a cosine KNN decoder (20 neighbors) from the first-part embeddings
         to the dino features, and predict dino features for the held-out part.
      4. Replace each predicted dino feature by the behavior frame whose dino
         feature has the largest dot product with it.
      5. Run ``label_model`` on the held-out embeddings and take the argmax
         over axis 1 as the predicted class per frame.

    Args:
        neural_path, behavior_path, dino_path: Locations passed to ``import_data``.
        validation_cutoff: Upper bound (``max``) for the first load and lower
            bound (``min``) for the held-out load.
            NOTE(review): looks like a fraction of the data -- confirm
            against ``import_data``.
        valid_size: Extent of the held-out range, added to ``validation_cutoff``.
        model: Model handed to ``generate_CEBRA_embeddings``.
        label_model: Callable taking a torch tensor of embeddings and returning
            a tensor of per-class scores.
        session: Session identifier handed to ``generate_CEBRA_embeddings``.

    Returns:
        Tuple ``(reshaped_predicted_images, name_data_test, predicted_labels,
        behavior_data_test)``.
    """
    def _identity(x):
        # import_data takes a transform; no transform is wanted here.
        return x

    test_end = validation_cutoff + valid_size

    # Load data
    print('Loading data')
    brain_data, _ = import_data(neural_path, _identity, max=validation_cutoff)
    behavior_data, _ = import_data(behavior_path, _identity, max=validation_cutoff)
    dino_data, _ = import_data(dino_path, _identity, max=validation_cutoff)
    # Generate embeddings
    print('Generating embeddings')
    embeddings = [generate_CEBRA_embeddings(model, brain_data_vid, session) for brain_data_vid in brain_data]

    print('Loading test data')
    # Load the test set of data. The dino features of the test range are not
    # needed by anything below, so they are intentionally not loaded.
    brain_data_test, name_data_test = import_data(neural_path, _identity, min=validation_cutoff, max=test_end)
    behavior_data_test, _ = import_data(behavior_path, _identity, min=validation_cutoff, max=test_end)
    # Generate embeddings
    print('Generating test embeddings')
    embeddings_test = [generate_CEBRA_embeddings(model, vid_test, session) for vid_test in brain_data_test]

    # Flatten data
    # NOTE(review): presumably this concatenates the per-video arrays along the
    # frame axis -- confirm against flatten_data.
    embeddings_flat = flatten_data(embeddings).squeeze()
    behavior_flat = flatten_data(behavior_data).squeeze()
    dino_flat = flatten_data(dino_data).squeeze()
    embedding_test_flat = flatten_data(embeddings_test).squeeze()

    print('Running KNN')
    # Create KNN decoder mapping CEBRA embeddings -> dino features
    decoder = cebra.KNNDecoder(n_neighbors=20, metric="cosine")
    decoder.fit(embeddings_flat, dino_flat)

    # Predict dino features for the held-out embeddings
    predicted_dino = decoder.predict(embedding_test_flat)

    # Normalize predicted features before the dot-product matching below
    predicted_dino = normalize_array(predicted_dino)

    print('generating videos')
    # Match predicted features to the behavior frames of the first part
    predicted_images = match_frame_to_embeddings(predicted_dino, dino_flat, behavior_flat)
    reshaped_predicted_images = reshape_frames(predicted_images, brain_data_test)

    print('Running Clustering')
    # Despite the log message, no clustering happens here: the supplied
    # label_model scores every held-out embedding.
    predicted_labels_flat = label_model(torch.from_numpy(embedding_test_flat)).detach().numpy()
    # Turn per-class scores into a single class index per frame
    predicted_labels_flat = np.argmax(predicted_labels_flat, axis=1)
    # Reshape predicted labels using the test brain data as the template
    predicted_labels = reshape_frames(predicted_labels_flat, brain_data_test)

    return reshaped_predicted_images, name_data_test, predicted_labels, behavior_data_test


In [14]:
def display_frames_as_video(frames, ground_truth, frame_rate, name, labels, label_dict, output_folder_path):
    """Write predicted frames stacked on top of ground-truth frames to an mp4 file.

    Each predicted frame is averaged with the preceding predicted frame before
    being written; the first frame has no predecessor and is written as is.
    The per-frame label is drawn onto every output frame, and the file name is
    prefixed with the class looked up from the rounded median of ``labels``.

    Args:
        frames: Sequence of 2-D grayscale predicted frames.
        ground_truth: Sequence of 2-D grayscale reference frames, same length.
        frame_rate: Frame rate of the written video.
        name: Base name used in the output file name.
        labels: Per-frame numeric labels.
        label_dict: Mapping from integer label to class name.
        output_folder_path: Folder receiving ``<class>_<name>.mp4``; created
            if it does not exist.

    Raises:
        KeyError: If the rounded median label is not a key of ``label_dict``.
    """
    import os  # local import so the function does not rely on a star-import

    fontScale = 0.5
    org = (50, 50)
    font = cv2.FONT_HERSHEY_SIMPLEX
    thickness = 1
    # Get the dimensions of the frames
    frame_height, frame_width = frames[0].shape
    # Classify video based on median label of all frames
    label = int(np.round(np.median(labels)))
    label_class = label_dict[label]

    # cv2.VideoWriter does not raise when the target folder is missing; it
    # simply produces no file, so make sure the folder exists first.
    if output_folder_path:
        os.makedirs(output_folder_path, exist_ok=True)

    # Create a VideoWriter object to write the frames into a video file
    video_writer = cv2.VideoWriter(output_folder_path + '/'+ str(label_class) + '_' + name +'.mp4',
                                   cv2.VideoWriter_fourcc(*'mp4v'),
                                   frame_rate,
                                   (frame_width, 2 * frame_height))

    try:
        for index, (predicted, truth, frame_label) in enumerate(zip(frames, ground_truth, labels)):
            # Smooth with the previous frame. At index 0 there is no previous
            # frame; indexing frames[-1] would blend in the LAST frame.
            previous = frames[index - 1] if index > 0 else predicted
            blended = (predicted / 2 + previous / 2).astype(np.uint8)

            color_frame = cv2.cvtColor(blended, cv2.COLOR_GRAY2RGB)
            color_truth = cv2.cvtColor(truth.astype(np.uint8), cv2.COLOR_GRAY2RGB)
            # Prediction on top, ground truth below
            combined = np.concatenate((color_frame, color_truth))

            # Write corresponding label to video corner
            combined = cv2.putText(combined, str(frame_label), org, font,
                                   fontScale, (0,0,255), thickness, cv2.LINE_AA)
            video_writer.write(combined)
    finally:
        # Always finalize the file, even if a frame fails to convert.
        video_writer.release()
        cv2.destroyAllWindows()

In [15]:
# Class index -> behavior name; the position in the tuple is the class index.
key_behavior = dict(enumerate((
    'still',
    'sniffing',
    'walking',
    'grooming',
)))

In [19]:
# Split boundaries and video settings used for every session.
validation_cutoff = 0.8   # passed to predict_embeddings as the train/test boundary
valid_size = 0.2          # extent of the held-out range after the cutoff
window_length = 30        # first argument of choose_random_window
frame_rate = 30           # frame rate of the written videos

# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# checkpoints from a trusted source.
solver = torch.load(cebra_model_path)
model = solver.model.eval()
# Indexed by session below, so this is expected to hold one classifier per session.
label_model = torch.load('models/cebra_classifier_complete.pt')

session_inputs = zip(neural_data_paths, behavior_data_paths, dino_paths, output_folder_paths)
for session_idx, (neural_path, behavior_path, dino_path, output_folder) in enumerate(session_inputs):
    pred_images, names, labels, predict_behavior = predict_embeddings(
        neural_path, behavior_path, dino_path,
        validation_cutoff, valid_size,
        model, label_model[session_idx], session_idx)

    for video_labels, video_frames, video_name, video_truth in zip(labels, pred_images, names, predict_behavior):
        # The video name is passed to every call.
        # NOTE(review): presumably it makes the three calls pick the same
        # window -- confirm against choose_random_window.
        windowed_labels = choose_random_window(window_length, video_name, video_labels)
        windowed_frames = choose_random_window(window_length, video_name, video_frames)
        windowed_truth = choose_random_window(window_length, video_name, video_truth)
        display_frames_as_video(windowed_frames, windowed_truth, frame_rate, video_name,
                                windowed_labels, key_behavior, output_folder)

Loading data
Generating embeddings
Loading test data
Generating test embeddings
Running KNN
generating videos
Running Clustering
Loading data
Generating embeddings
Loading test data
Generating test embeddings
Running KNN
generating videos
Running Clustering
Loading data
Generating embeddings
Loading test data
Generating test embeddings
Running KNN
generating videos
Running Clustering
Loading data
Generating embeddings
Loading test data
Generating test embeddings
Running KNN
generating videos
Running Clustering
Loading data
Generating embeddings
Loading test data
Generating test embeddings
Running KNN
generating videos
Running Clustering
Loading data
Generating embeddings
Loading test data
Generating test embeddings
Running KNN
generating videos
Running Clustering
Loading data
Generating embeddings
Loading test data
Generating test embeddings
Running KNN
generating videos
Running Clustering
Loading data
Generating embeddings
Loading test data
Generating test embeddings
Running KNN
gener