<h1>
Behavior classification starter kit 🐁🐀
</h1>
This code is adapted from a notebook created by Dipam Chakraborty at AIcrowd for the <a href=https://www.aicrowd.com/challenges/multi-agent-behavior-representation-modeling-measurement-and-applications>Multi-Agent Behavior Challenge</a>.


# Import necessary modules and packages 📚


In [None]:
import pandas as pd
import numpy as np
import os
import json
import urllib.request

# Download the dataset 📲




The CalMS21 dataset is hosted by Caltech at https://data.caltech.edu/records/1991. For now, we'll focus on the Task 1 data, which can be downloaded as follows:

In [None]:
urllib.request.urlretrieve('https://data.caltech.edu/tindfiles/serve/a86f4297-a087-4f40-9ed4-765779105c2c/', 'task1.zip')
urllib.request.urlretrieve('https://data.caltech.edu/tindfiles/serve/ca84a583-ea06-440a-995c-c184bcb0291c/', 'calms21_convert_to_npy.py')
!unzip task1.zip

The dataset files are stored as JSON. For ease of handling, we'll first convert them to `.npy` files using the script we just downloaded, `calms21_convert_to_npy.py`. The output of this script is a pair of files, `calms21_task1_train.npy` and `calms21_task1_test.npy`.

If you include the optional `parse_treba` flag, the script will create files `calms21_task1_train_features.npy` and `calms21_task1_test_features.npy`, which contain 32 features created using <a href=https://openaccess.thecvf.com/content/CVPR2021/html/Sun_Task_Programming_Learning_Data_Efficient_Behavior_Representations_CVPR_2021_paper.html>Task Programming</a>.



In [None]:
!python calms21_convert_to_npy.py  --input_directory '.' --output_directory '.'
!python calms21_convert_to_npy.py  --input_directory '.' --output_directory '.' --parse_treba

# Load the data 💾
The following loader function can be used to unpack the `.npy` files containing your train and test sets.

In [None]:
import numpy as np

def load_task1_data(data_path):
    """ 
    Load data for task 1:
        The vocabulary tells you how to map behavior names to class ids;
        it is the same for all sequences in this dataset.
    """
    data_dict = np.load(data_path, allow_pickle=True).item()
    dataset = data_dict['annotator-id_0']
    # Get any sequence key.
    sequence_id = list(data_dict['annotator-id_0'].keys())[0]
    vocabulary = data_dict['annotator-id_0'][sequence_id]['metadata']['vocab']
    return dataset, vocabulary


In [None]:
training_data, vocab = load_task1_data('./calms21_task1_train.npy')
test_data, _ = load_task1_data('./calms21_task1_test.npy')

## Dataset Specifications

`training_data` and `test_data` are both dictionaries with a key for each Sequence in the dataset, where a Sequence is a single resident-intruder assay. Each Sequence contains the following fields:

<ul>
<li><b>keypoints</b>: tracked locations of body parts on the two interacting mice. These are produced using a Stacked Hourglass network trained on 15,000 hand-labeled frames.
<ul>
<li>Dimensions: (# frames) x (mouse ID) x (x, y coordinate) x (body part).
<li>Units: pixels; coordinates are relative to the entire image. Original image dimensions are 1024 x 570.
</ul>
<li><b>scores</b>: confidence estimates for the tracked keypoints.
<ul>
<li>Dimensions: (# frames) x (mouse ID) x (body part).
<li>Units: unitless, range 0 (lowest confidence) to 1 (highest confidence).
</ul>
<li> <b>annotations</b>: the behavior id, as an integer, annotated at each frame by a domain expert. See below for the mapping between behavior ids and behavior names.
<ul>
<li>Dimensions: (# frames) .
</ul>
<li><b>metadata</b>: the recorded metadata are annotator_id, an int identifying the annotator, and vocab, a dictionary that maps behavior names to the integer ids used in annotations.
</ul>

The 'taskprog_features' file contains the additional field:

<ul>
<li><b>features</b>: pre-computed features from a model trained with task programming on the trajectory data of the CalMS21 unlabeled videos set.
<ul>
<li>Dimensions: (# frames) x (feature dim = 32).
</li>
</ul>
</ul>

<b>NOTE:</b> for all keypoints, mouse 0 is the resident (black) mouse and mouse 1 is the intruder (white) mouse. There are 7 tracked body parts, ordered (nose, left ear, right ear, neck, left hip, right hip, tail base).
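
To make the axis order concrete, here is a minimal sketch of how one body part could be pulled out of a `keypoints` array. It uses a synthetic random array rather than the real data, and the names `BODY_PARTS`, `RESIDENT`, and `INTRUDER` are introduced here for illustration only; they are not part of the dataset.

```python
import numpy as np

# Body-part order as listed in the note above.
BODY_PARTS = ["nose", "left ear", "right ear", "neck",
              "left hip", "right hip", "tail base"]
RESIDENT, INTRUDER = 0, 1  # mouse IDs from the note above

# Synthetic stand-in for one Sequence's keypoints (NOT real data):
# (# frames) x (mouse ID) x (x, y coordinate) x (body part)
rng = np.random.default_rng(0)
keypoints = rng.uniform(0, 570, size=(100, 2, 2, 7))

nose = BODY_PARTS.index("nose")

# Nose trajectory of each mouse: one (x, y) pair per frame.
resident_nose = keypoints[:, RESIDENT, :, nose]
intruder_nose = keypoints[:, INTRUDER, :, nose]

# Per-frame nose-to-nose distance, in pixels.
nose_distance = np.linalg.norm(resident_nose - intruder_nose, axis=-1)
print(resident_nose.shape, nose_distance.shape)
```

The same indexing pattern should apply to a real Sequence once it has been loaded, provided its `keypoints` array follows the layout described above.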

## What does the data look like? 🔍

### Data overview

As described above, our dataset consists of train and test sets, which are both dictionaries of Sequences, and an accompanying vocabulary telling us which behavior is which:

In [None]:
print("Sample dataset keys: ", list(training_data.keys())[:3])
print("Vocabulary: ", vocab)
print("Number of train Sequences: ", len(training_data))
print("Number of test Sequences: ", len(test_data))

### Sample overview
Next let's take a look at one example Sequence:

In [None]:
sequence_names = list(training_data.keys())
sample_sequence_key = sequence_names[0]
single_sequence = training_data[sample_sequence_key]
print("Name of our sample sequence: ", sample_sequence_key)
print("Sequence keys: ", single_sequence.keys())
print("Sequence metadata: ", single_sequence['metadata'])
print(f"Number of Frames in Sequence \"{sample_sequence_key}\": ", len(single_sequence['annotations']))
print(f"Keypoints data shape of Sequence \"{sample_sequence_key}\": ", single_sequence['keypoints'].shape)


# Helper functions for visualization 💁


This cell contains some helper functions that we'll use to create an animation of the mouse movements. You can ignore the contents, but be sure to run it or the next section won't work.

In [None]:
import matplotlib.pyplot as plt
from matplotlib import animation
from matplotlib import colors
from matplotlib import rc
 
rc('animation', html='jshtml')
 
# Note: Image processing may be slow if too many frames are animated.                
 
#Plotting constants
FRAME_WIDTH_TOP = 1024
FRAME_HEIGHT_TOP = 570
 
RESIDENT_COLOR = 'lawngreen'
INTRUDER_COLOR = 'skyblue'
 
PLOT_MOUSE_START_END = [(0, 1), (0, 2), (1, 3), (2, 3), (3, 4),
                        (3, 5), (4, 6), (5, 6), (1, 2)]
 
class_to_color = {'other': 'white', 'attack' : 'red', 'mount' : 'green',
                  'investigation': 'orange'}
 
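# vocab already maps behavior names to integer ids, so use it directly
# instead of relying on the iteration order of its keys.
class_to_number = dict(vocab)
 
number_to_class = {i: s for s, i in vocab.items()}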
 
def num_to_text(anno_list):
  return np.vectorize(number_to_class.get)(anno_list)
 
def set_figax():
    fig = plt.figure(figsize=(6, 4))
 
    img = np.zeros((FRAME_HEIGHT_TOP, FRAME_WIDTH_TOP, 3))
 
    ax = fig.add_subplot(111)
    ax.imshow(img)
 
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
 
    return fig, ax
 
def plot_mouse(ax, pose, color):
    # Draw each keypoint
    for j in range(7):
        ax.plot(pose[j, 0], pose[j, 1], 'o', color=color, markersize=5)
 
    # Draw a line for each point pair to form the shape of the mouse
 
    for pair in PLOT_MOUSE_START_END:
        line_to_plot = pose[pair, :]
        ax.plot(line_to_plot[:, 0], line_to_plot[
                :, 1], color=color, linewidth=1)
 
def animate_pose_sequence(video_name, keypoint_sequence, start_frame = 0, stop_frame = 100, 
                          annotation_sequence = None):
    # Returns the animation of the keypoint sequence between start frame
    # and stop frame. Optionally can display annotations.
    seq = keypoint_sequence.transpose((0,1,3,2))
 
    image_list = []
    
    counter = 0
    for j in range(start_frame, stop_frame):
        if counter%20 == 0:
          print("Processing frame ", j)
        fig, ax = set_figax()
        plot_mouse(ax, seq[j, 0, :, :], color=RESIDENT_COLOR)
        plot_mouse(ax, seq[j, 1, :, :], color=INTRUDER_COLOR)
        
        if annotation_sequence is not None:
          annot = annotation_sequence[j]
          annot = number_to_class[annot]
          plt.text(50, -20, annot, fontsize = 16, 
                   bbox=dict(facecolor=class_to_color[annot], alpha=0.5))
 
        ax.set_title(
            video_name + '\n frame {:03d}.png'.format(j))
 
        ax.axis('off')
        fig.tight_layout(pad=0)
        ax.margins(0)
 
        fig.canvas.draw()
        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(),
                                        dtype=np.uint8)
        image_from_plot = image_from_plot.reshape(
            fig.canvas.get_width_height()[::-1] + (3,)) 
 
        image_list.append(image_from_plot)
 
        plt.close()
        counter = counter + 1
 
    # Plot animation.
    fig = plt.figure()
    plt.axis('off')
    im = plt.imshow(image_list[0])
 
    def animate(k):
        im.set_array(image_list[k])
        return im,
    ani = animation.FuncAnimation(fig, animate, frames=len(image_list), blit=True)
    return ani
 
def plot_behavior_raster(annotation_sequence, start_frame = 0, stop_frame = 100, title="Behavior Labels"):
  # Plot annotations as a behavior raster
 
  # Map annotations to a number.
  annotation_num = []
  for item in annotation_sequence[start_frame:stop_frame]:
    annotation_num.append(class_to_number[item])
 
  all_classes = list(set(annotation_sequence[start_frame:stop_frame]))
 
  cmap = colors.ListedColormap(['red', 'orange', 'green', 'white'])
  bounds=[-0.5,0.5,1.5, 2.5, 3.5]
  norm = colors.BoundaryNorm(bounds, cmap.N)
 
  height = 200
  arr_to_plot = np.repeat(np.array(annotation_num)[:,np.newaxis].transpose(),
                                                  height, axis = 0)
  
  fig, ax = plt.subplots(figsize = (16, 3))
  ax.imshow(arr_to_plot, interpolation = 'none',cmap=cmap, norm=norm)
 
  ax.set_yticks([])
  ax.set_xlabel('Frame Number')
  plt.title(title)
 
  import matplotlib.patches as mpatches
 
  legend_patches = []
  for item in all_classes:
    legend_patches.append(mpatches.Patch(color=class_to_color[item], label=item))
 
  plt.legend(handles=legend_patches,loc='center left', bbox_to_anchor=(1, 0.5))
 
  plt.tight_layout()

# Visualize the animals' movements 🎥

Let's make some gifs of our sample sequence to get a sense of what the raw data looks like! You can change the values of `start_frame` and `stop_frame` to look around.

In [None]:
#@title
keypoint_sequence = single_sequence['keypoints']
annotation_sequence = single_sequence['annotations']

ani = animate_pose_sequence(sample_sequence_key,
                            keypoint_sequence, 
                            start_frame = 5000,
                            stop_frame = 5100,
                            annotation_sequence = annotation_sequence)

# Display the animation in Colab
ani

### We can also look at a **behavior raster**, which shows what behavior was annotated on each frame of this video.

In [None]:
annotation_sequence = single_sequence['annotations']
text_sequence = num_to_text(annotation_sequence)
 
plot_behavior_raster(
    text_sequence,
    start_frame=0,
    stop_frame=len(annotation_sequence)
)

# Basic exploratory data analysis 🤓
Each Sequence has different amounts of each behavior, depending on what the mice do during the assay. Here, we get the percentage of frames of each behavior in each sequence. We can use this to split the training set into train and validation sets in a stratified way.

In [None]:
def get_percentage(sequence_key):
  anno_seq = num_to_text(training_data[sequence_key]['annotations'])
  counts = {k: np.mean(np.array(anno_seq) == k)*100.0 for k in vocab}
  return counts

anno_percentages = {k: get_percentage(k) for k in training_data}

anno_perc_df = pd.DataFrame(anno_percentages).T
print("Percentage of frames in every sequence for every class")
anno_perc_df.head()

## Percent of frames of each behavior in the full training set
Having looked at behavior distributions in a couple example Sequences, let's now look at the average over the entire training set.

In [None]:
all_annotations = []
for sk in training_data:
  anno = training_data[sk]['annotations']
  all_annotations.extend(list(anno))
all_annotations = num_to_text(all_annotations)
classes, counts = np.unique(all_annotations, return_counts=True)
pd.DataFrame({"Behavior": classes,
              "Percentage Frames": counts/len(all_annotations)*100.0})

# Split training data into train/validation sets
Because we don't want to overfit to our test set, we'll create a new validation set to test on while we're experimenting with our model.

We'll use the first cell to create some helper functions, and then implement the split in the following cell.

In [None]:
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd


def num_to_text(number_to_class, anno_list):
    """ 
    Convert list of class numbers to list of class names
    """
    return np.vectorize(number_to_class.get)(anno_list)


def split_validation(orig_pose_dictionary, vocabulary, seed=2021,
               test_size=0.5, split_videos=False):
    """ 
    Split data into train and validation sets:
    * Full sequences are either put into train or validation to avoid data leakage
    * By default, the "attack" behavior's presence is used to stratify the split
    * Optionally, the sequences may be split into half and treated as separate sequences
    """

    if test_size == 0.0:
        return orig_pose_dictionary, None

    number_to_class = {v: k for k, v in vocabulary.items()}
    if split_videos:
        pose_dictionary = {}
        for key in orig_pose_dictionary:
            key_pt1 = key + '_part1'
            key_pt2 = key + '_part2'
            anno_len = len(orig_pose_dictionary[key]['annotations'])
            split_idx = anno_len//2
            pose_dictionary[key_pt1] = {
                'annotations': orig_pose_dictionary[key]['annotations'][:split_idx],
                'keypoints': orig_pose_dictionary[key]['keypoints'][:split_idx]}
            pose_dictionary[key_pt2] = {
                'annotations': orig_pose_dictionary[key]['annotations'][split_idx:],
                'keypoints': orig_pose_dictionary[key]['keypoints'][split_idx:]}
    else:
        pose_dictionary = orig_pose_dictionary

    def get_percentage(sequence_key):
        anno_seq = num_to_text(
            number_to_class, pose_dictionary[sequence_key]['annotations'])
        counts = {k: np.mean(np.array(anno_seq) == k) for k in vocabulary}
        return counts

    anno_percentages = {k: get_percentage(k) for k in pose_dictionary}

    anno_perc_df = pd.DataFrame(anno_percentages).T

    rng_state = np.random.RandomState(seed)
    try:
        idx_train, idx_val = train_test_split(anno_perc_df.index,
                                              stratify=anno_perc_df['attack'] > 0,
                                              test_size=test_size,
                                              random_state=rng_state)
    except ValueError:
        # Stratification fails when a class has too few sequences; fall back to a plain split.
        idx_train, idx_val = train_test_split(anno_perc_df.index,
                                              test_size=test_size,
                                              random_state=rng_state)

    train_data = {k: pose_dictionary[k] for k in idx_train}
    val_data = {k: pose_dictionary[k] for k in idx_val}
    return train_data, val_data

In [None]:
train, val = split_validation(training_data, vocab, test_size=0.25)
print("Number of Sequences in train set: ", len(train))
print("Number of Sequences in validation set: ", len(val))

# Training and testing a baseline model implemented in TensorFlow 🏋️‍♂️

The CalMS21 dataset is accompanied by a set of baseline models, implemented in TensorFlow. These can be found with accompanying documentation at <a href=https://gitlab.aicrowd.com/aicrowd/research/mab-e/mab-e-baselines>https://gitlab.aicrowd.com/aicrowd/research/mab-e/mab-e-baselines</a>.

The baseline model is a simple neural network that takes as input a "Trajectory" - a short snippet of a sequence showing what the mice are doing before and after the "current" frame (the frame on which behavior is to be classified.) We tested multiple simple baseline architectures in the CalMS21 paper, found at <a href=https://arxiv.org/abs/2104.02710>https://arxiv.org/abs/2104.02710</a>.
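
As a rough illustration of what a Trajectory is, the sketch below cuts a window of past and future frames around one frame of a synthetic sequence. The function name `extract_trajectory` is hypothetical, and zero-padding at the sequence boundaries is an assumption chosen to mirror the data generator defined later in this notebook.

```python
import numpy as np

def extract_trajectory(sequence, frame, past_frames, future_frames):
    """Return the window of frames around `frame`, zero-padded at the ends."""
    # Pad only the time axis (axis 0); leave all other axes untouched.
    pad_width = [(past_frames, future_frames)] + [(0, 0)] * (sequence.ndim - 1)
    padded = np.pad(sequence, pad_width)
    # After padding, original frame i sits at index i + past_frames, so the
    # window covering [i - past_frames, i + future_frames] starts at index i.
    return padded[frame:frame + past_frames + 1 + future_frames]

# Synthetic stand-in: 50 frames, 2 mice, (x, y), 7 body parts, all ones.
sequence = np.ones((50, 2, 2, 7))

# The first frame has no past, so its three past frames come from padding.
window = extract_trajectory(sequence, frame=0, past_frames=3, future_frames=2)
print(window.shape)
```

Every frame in a sequence yields a window of the same length, `past_frames + 1 + future_frames`, which is what lets the model use a fixed input size.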

In [None]:
! git clone http://gitlab.aicrowd.com/aicrowd/research/mab-e/mab-e-baselines.git
%cd mab-e-baselines
! pip install -r requirements.txt

In [None]:
!cp ../calms21_task1_train.npy data/train.npy
!cp ../calms21_task1_test.npy data/test.npy

In [None]:
import numpy as np
import os

from tensorflow import keras
import tensorflow as tf
from keras.models import Sequential
import keras.layers as layers
import tensorflow_addons as tfa

import sklearn
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from copy import deepcopy
import tqdm
import gc

## Seeding helper
It's good practice to seed before every run, so you can reproduce your results.

In [None]:
def seed_everything(seed):
  np.random.seed(seed)
  os.environ['PYTHONHASHSEED'] = str(seed)
  tf.random.set_seed(seed)

seed=2021
seed_everything(seed)

## Generator 🔌

The generator is used to feed data into our models. It randomly samples frames from the Sequences in our training set, and for each frame takes the animals' poses in a small window in time around that frame (we call these snippets Trajectories).

It also provides support for data augmentation:
1.   Random rotation
2.   Random translation

🚧 Note that the same augmentation is applied across all frames in a selected window. For example, a random rotation by 10 degrees rotates every frame in the input window by the same angle.
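
The sketch below illustrates this point on a synthetic window: a single angle and a single shift are drawn once and applied to every frame, so distances between body parts are unchanged. The array layout (coordinates on the last axis) and the shift range are assumptions chosen to match the generator code that follows; `nose_to_tail` is a helper made up for this check.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic window (NOT real data): (frames, mouse, body part, xy).
window = rng.uniform(0, 1, size=(5, 2, 7, 2))

# Draw ONE angle and ONE shift for the whole window.
angle = rng.uniform(-np.pi, np.pi)
c, s = np.cos(angle), np.sin(angle)
rotation = np.array([[c, -s], [s, c]])
shift = rng.uniform(-0.25, 0.25, size=2)

# Matrix-multiplying on the last axis rotates every (x, y) pair identically.
augmented = window @ rotation + shift

def nose_to_tail(w):
    """Distance between body part 0 (nose) and body part 6 (tail base)."""
    return np.linalg.norm(w[:, :, 0, :] - w[:, :, 6, :], axis=-1)

# A shared rotation plus shift is a rigid motion, so body lengths are preserved.
print(np.allclose(nose_to_tail(window), nose_to_tail(augmented)))
```

Had a different angle been drawn per frame, the poses would jump between frames and the window would no longer look like a plausible movement.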


In [None]:
from tensorflow import keras
import numpy as np


def calculate_input_dim(feature_dim, architechture, past_frames, future_frames):
    """
    Data is arranged as [t, flattened_feature_dimensions]
           where t => [past_frames + 1 + future_frames]

    In this version, we flatten the feature dimensions
    But another generator, inherited from this class,
    could very well retain the actual structure of the mice
    coordinates.
    """
    flat_dim = np.prod(feature_dim)
    if architechture != 'fully_connected':
        input_dim = ((past_frames + future_frames + 1), flat_dim,)
    else:
        input_dim = (flat_dim * (past_frames + future_frames + 1),)
    return input_dim


def mabe_generator(data, augment, shuffle, sequence_key, kwargs):
    if data is not None:
        return MABe_Data_Generator(data,
                               augment=augment,
                               shuffle=shuffle,
                               sequence_key=sequence_key,
                               **kwargs)
    else:
        return None


class MABe_Data_Generator(keras.utils.Sequence):
    """
    Generates window of frames from sequence data
    * Each window comprises past and future frames
    * A frame skip > 1 can be used for subsampling
    * Augments by rotation and shifting frames
    * Boundaries are padded with zeros for when window exceeds the limits
    """
    def __init__(self,  pose_dict,
                 class_to_number,
                 batch_size=2,
                 input_dimensions=(2, 2, 7),
                 augment=False,
                 past_frames=100,
                 future_frames=100,
                 frame_skip=1,
                 shuffle=True,
                 sequence_key = 'keypoints'):

        self.batch_size = batch_size
        self.dim = input_dimensions

        self.classname_to_index_map = class_to_number
        self.n_classes = len(self.classname_to_index_map)

        self.past_frames = past_frames
        self.future_frames = future_frames
        self.frame_skip = frame_skip

        self.shuffle = shuffle
        self.augment = augment

        self.sequence_key = sequence_key

        # Raw Data Containers
        self.X = {}
        self.y = []

        # Setup Dimensions of data points
        # self.setup_dimensions()

        # Load raw pose dictionary
        self.load_pose_dictionary(pose_dict)

        # Setup Utilities
        self.setup_utils()

        # Generate a global index of all datapoints
        self.generate_global_index()

        # Epoch End preparations
        self.on_epoch_end()

    def load_pose_dictionary(self, pose_dict):
        """ Load raw pose dictionary """
        self.pose_dict = pose_dict
        self.video_keys = list(pose_dict.keys())

    def setup_utils(self):
        """ Set up padding utilities """
        self.setup_padding_utils()

    def setup_padding_utils(self):
        """ Prepare to pad frames """
        self.left_pad = self.past_frames * self.frame_skip
        self.right_pad = self.future_frames * self.frame_skip

        if self.sequence_key == 'keypoints':
            self.pad_width = (self.left_pad, self.right_pad), (0, 0), (0, 0), (0, 0)
        else:
            self.pad_width = (self.left_pad, self.right_pad), (0, 0)

    def classname_to_index(self, annotations_list):
        """
        Converts a list of string classnames into numeric indices
        """
        return np.vectorize(self.classname_to_index_map.get)(annotations_list)

    def generate_global_index(self):
        """ Define arrays to map video keys to frames """
        self.video_indices = []
        self.frame_indices = []

        self.action_annotations = []

        # For all video keys....
        for video_index, video_key in enumerate(self.video_keys):
            # Extract all annotations
            annotations = self.pose_dict[video_key]['annotations']
            # add annotations to action_annotations
            self.action_annotations.extend(annotations)

            number_of_frames = len(annotations)

            # Keep a record for video and frame indices
            # Keep a record of video_indices
            self.video_indices.extend([video_index] * number_of_frames)
            # Keep a record of frame indices
            self.frame_indices.extend(range(number_of_frames))
            # Add padded keypoints for each video key
            self.X[video_key] = np.pad(
                self.pose_dict[video_key][self.sequence_key], self.pad_width)

        self.y = np.array(self.action_annotations)
        # self.y = self.classname_to_index(self.action_annotations) # convert text labels to indices
        self.X_dtype = self.X[video_key].dtype  # Store D_types of X

        # generate a global index list for all data points
        self.indices = np.arange(len(self.frame_indices))

    def __len__(self):
        ct = len(self.indices) // self.batch_size
        ct += int((len(self.indices) % self.batch_size) > 0)
        return ct

    def get_X(self, data_index):
        """
        Obtains the X value from a particular global index
        """
        # Obtain video key for this datapoint
        video_key = self.video_keys[
            self.video_indices[data_index]
        ]
        # Identify the (local) frame_index
        # to offset original data padding
        frame_index = self.frame_indices[data_index] + self.left_pad
        # Slice from beginning of past frames to end of future frames
        slice_start_index = frame_index - self.left_pad
        slice_end_index = frame_index + self.frame_skip + self.right_pad
        assert slice_start_index >= 0
        _X = self.X[video_key][
            slice_start_index:slice_end_index:self.frame_skip
        ]
        if self.augment:
            _X = self.augment_fn(_X)
        return _X

    def augment_fn(self, to_augment):
        """ 
        Augment sequences
            * Rotation - All frames in the sequence are rotated by the same angle
                using the euler rotation matrix
            * Shift - All frames in the sequence are shifted randomly
                but by the same amount
        """
        if len(to_augment.shape) != 4:
            x = to_augment[:, :28].reshape(-1, 2, 7, 2)
        else:
            x = to_augment

        # Rotate
        angle = (np.random.rand()-0.5) * (np.pi * 2)
        c, s = np.cos(angle), np.sin(angle)
        rot = np.array([[c, -s], [s, c]])
        x = np.dot(x, rot)

        # Shift - All get shifted together
        shift = (np.random.rand(2)-0.5) * 2 * 0.25
        x = x + shift

        if len(to_augment.shape) != 4:

            x = np.concatenate([x.reshape(-1, 28), to_augment[:, 28:]], axis = -1)

        return x

    def __getitem__(self, index):
        batch_size = self.batch_size
        batch_indices = self.indices[
            index*batch_size:(index+1)*batch_size]  # List indexing overflow gets clipped

        batch_size = len(batch_indices)  # For the case when list indexing is clipped

        X = np.empty((batch_size, *self.dim), self.X_dtype)

        for batch_index, data_index in enumerate(batch_indices):
            # Obtain the post-processed X value at the said data index
            _X = self.get_X(data_index)
            # Reshape the _X to the expected dimensions
            X[batch_index] = np.reshape(_X, self.dim)

        y_vals = self.y[batch_indices]
        # Converting to one hot because F1 callback needs one hot
        y = keras.utils.to_categorical(y_vals, num_classes=self.n_classes)
        return X, y

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)


## Trainer 🏋️

The trainer class implements a unified interface for using the data Generator we made above.

It supports fully connected or 1D convolutional networks, as well as other hyperparameters for the model and the generator.
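
The two architecture families expect differently shaped inputs, which is what `calculate_input_dim` above accounts for. Here is a small sketch of the difference on a dummy window; the window length is an illustrative value, not a recommended setting.

```python
import numpy as np

feature_dim = (2, 7, 2)             # mouse, body part, (x, y)
past_frames, future_frames = 3, 2   # illustrative values only
n_frames = past_frames + 1 + future_frames
flat_dim = int(np.prod(feature_dim))  # 2 * 7 * 2 = 28 values per frame

# Dummy window standing in for one Trajectory.
window = np.zeros((n_frames, *feature_dim))

# Fully connected: time and features are flattened into a single vector.
fc_input = window.reshape(n_frames * flat_dim)

# 1D convolution: the time axis is kept, features are flattened per frame.
conv_input = window.reshape(n_frames, flat_dim)

print(fc_input.shape, conv_input.shape)
```

Keeping the time axis lets a convolution slide along it, whereas the fully connected network sees the whole window as one unordered vector.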

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
import tensorflow.keras.layers as layers
import tensorflow_addons as tfa
import sklearn
import pandas as pd

from utils.model_utils import add_layer


class Trainer:
    """
    Custom Trainer class for sequential window data
    Setup and manage training for different models
    Supports different architectures
    """
    def __init__(self, *,
                 train_generator,
                 val_generator,
                 input_dim,
                 num_classes,
                 class_to_number=None,
                 architecture="conv_1d",
                 test_generator=None,
                 arch_params={}):

        self.input_dim = input_dim
        self.num_classes = num_classes

        self.class_to_number = class_to_number

        self.train_generator = train_generator
        self.val_generator = val_generator
        self.test_generator = test_generator

        self.architecture = architecture
        self.arch_params = arch_params

    def delete_model(self):
        self.model = None

    def initialize_model(self, layer_channels=(512, 256), dropout_rate=0.,
                         learning_rate=1e-3, conv_size=5):
        """ Instantiate the model based on the architecture """
        inputs = layers.Input(self.input_dim)
        x = layers.BatchNormalization()(inputs)

        if self.architecture == 'lstm':
            lstm_size = self.arch_params.lstm_size
            x = layers.LSTM(lstm_size, activation='tanh')(x)

        for ch in layer_channels:
            x = add_layer(x, ch, drop=dropout_rate,
                          architecture=self.architecture,
                          arch_params=self.arch_params)
        x = layers.Flatten()(x)
        x = layers.Dense(self.num_classes, activation='softmax')(x)

        metrics = [tfa.metrics.F1Score(num_classes=self.num_classes)]
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

        model = Model(inputs, x)

        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=metrics)

        self.model = model

    def _set_model(self, model):
        """ Set an externally initialized and compiled Keras model """
        self.model = model

    def train(self, epochs=20, class_weight=None, callbacks=[]):
        """ Train the model for given epochs """
        # self.model only exists after initialize_model or _set_model has run
        if getattr(self, 'model', None) is None:
            print("Please Call trainer.initialize_model first")
            return

        self.model.fit(self.train_generator,
                    #    validation_data=self.val_generator,
                       epochs=epochs,
                       class_weight=class_weight,
                       callbacks=callbacks)

    def get_generator_by_mode(self, mode='validation'):
        """ Select the generator - Train, Validation or Test"""
        if mode == 'validation':
            return self.val_generator
        elif mode == 'train':
            return self.train_generator
        elif mode == 'test':
            return self.test_generator
        else:
            raise NotImplementedError

    def get_labels(self, generator):
        """ Get all the ground truth labels"""
        y_val = []
        for _, y in generator:
            y_val.extend(list(y))
        y_val = np.argmax(np.array(y_val), axis=-1)
        return y_val

    def get_prediction_probabilities(self, generator):
        """ Get all the model predictions """
        return self.model.predict(generator, verbose=True)

    def get_predictions(self, generator):
        """ Get all the model predictions """
        y_pred = self.get_prediction_probabilities(generator)
        y_pred = np.argmax(y_pred, axis=-1)
        return y_pred

    def get_metrics(self, mode='validation'):
        """
        Get metrics - F1, Precision, Recall for each class
        "mode" can be set to use training or validation data
        """
        generator = self.get_generator_by_mode(mode)

        if generator is None:
            return None

        labels = self.get_labels(generator)
        probabilites = self.get_prediction_probabilities(generator)
        predictions = np.argmax(probabilites, axis=-1)

        f1_scores = sklearn.metrics.f1_score(labels, predictions, average=None)
        prec_scores = sklearn.metrics.precision_score(
            labels, predictions, average=None)
        rec_scores = sklearn.metrics.recall_score(
            labels, predictions, average=None)

        # Average precision - all labels not equal to correct label are mistakes
        ap_scores = []
        for single_label in sorted(np.unique(labels)):
            labels_l = labels == single_label
            probabilites_l = probabilites[:, single_label] 
            ap_score_l = sklearn.metrics.average_precision_score(
                labels_l, probabilites_l, average='micro')
            ap_scores.append(ap_score_l)

        classes = sorted(self.class_to_number, key=self.class_to_number.get)
        metrics = pd.DataFrame({"Class": classes, "F1": f1_scores,
                                "Precision": prec_scores, "Recall": rec_scores,
                                "Average Precision": ap_scores})

        if len(classes) > 2:
            try:
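                # Average the numeric columns only; DataFrame.append was removed in pandas 2.0
                average_scores = metrics[metrics['Class'] != 'other'].mean(numeric_only=True)
                metrics = pd.concat([metrics, average_scores.to_frame().T],
                                    ignore_index=True)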
                metrics.iloc[-1, 0] = 'Macro Average'
            except Exception:
                pass

        return metrics

## Preprocess

We'll normalize the data using the known frame size of 1024 x 570 pixels.

The original data has shape (sequence length, mouse, x y coordinate, keypoint) = (length, 2, 2, 7).

We'll swap the x y coordinate axis and the keypoint axis, which simplifies the rotation augmentation.
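
Here is a minimal sketch of these two steps on a synthetic array, assuming the (length, 2, 2, 7) layout described above. Unlike the function in the next cell, this version works on a copy, so the synthetic "raw" array is left unchanged; that is a choice made for the illustration, not a description of the baseline code.

```python
import numpy as np

FRAME_WIDTH, FRAME_HEIGHT = 1024, 570  # frame size stated above

# Synthetic keypoints in pixels (NOT real data): (length, mouse, xy, keypoint).
rng = np.random.default_rng(0)
keypoints = np.empty((10, 2, 2, 7))
keypoints[:, :, 0, :] = rng.uniform(0, FRAME_WIDTH, size=(10, 2, 7))
keypoints[:, :, 1, :] = rng.uniform(0, FRAME_HEIGHT, size=(10, 2, 7))

# Swap the last two axes so the coordinates come last: (length, 2, 7, 2).
# astype(float) returns a copy, so the original array is not modified.
normalized = keypoints.transpose((0, 1, 3, 2)).astype(float)

# Broadcasting divides x by the frame width and y by the frame height.
normalized /= np.array([FRAME_WIDTH, FRAME_HEIGHT])

print(normalized.shape, normalized.min() >= 0, normalized.max() <= 1)
```

With the coordinates on the last axis, a 2 x 2 rotation matrix can be applied to every keypoint with a single matrix product.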

In [None]:
def normalize_data(orig_pose_dictionary):
  for key in orig_pose_dictionary:
    X = orig_pose_dictionary[key]['keypoints']
    X = X.transpose((0,1,3,2)) #last axis is x, y coordinates
    X[..., 0] = X[..., 0]/1024
    X[..., 1] = X[..., 1]/570
    orig_pose_dictionary[key]['keypoints'] = X
  return orig_pose_dictionary

# Train function and inference

The function below uses a set of hyperparameters we found with some tuning; results can likely be improved with further tuning or other models.

You can adjust hyperparameter values using the `config` dictionary defined in the following cell.
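
Because the training function reads its settings as attributes (for example `config.seed`), a plain `dict` would need to be wrapped in something that supports attribute access. The sketch below uses `types.SimpleNamespace` for that purpose. The field names are the ones read by the training function; the values, and the name `example_config`, are illustrative placeholders rather than the tuned settings.

```python
from types import SimpleNamespace

# Placeholder values for illustration only; these are NOT the tuned settings.
example_config = SimpleNamespace(
    seed=2021,                # random seed for reproducibility
    normalize=True,           # scale x, y coordinates by the frame size
    val_size=0.25,            # fraction of sequences held out for validation
    split_videos=False,       # whether to cut each sequence in half first
    architecture="conv_1d",   # model family
    past_frames=50,           # frames of context before the current frame
    future_frames=50,         # frames of context after the current frame
)

# Length of each input window seen by the model.
window_length = example_config.past_frames + 1 + example_config.future_frames
print(window_length)
```

Any object with these attributes would serve equally well; `SimpleNamespace` is used here only because it is in the standard library.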

In [None]:
import os
import argparse
from copy import deepcopy

import tensorflow as tf
from utils.load_data import load_mabe_data_task1
from utils.dirs import create_dirs
from utils.preprocessing import normalize_data, transpose_last_axis
from utils.split_data import split_data
from utils.seeding import seed_everything
from trainers.mab_e_trainer import Trainer
from data_generator.mab_e_data_generator import mabe_generator
from data_generator.mab_e_data_generator import calculate_input_dim
from utils.save_results import save_results


def train_task1(results_dir, dataset, vocabulary, test_data, config,
                pretrained_model_path=None, skip_training=False, read_features=False):


    # Create directories if not present
    create_dirs([results_dir])

    # Seed for reproducibility
    seed_everything(config.seed)

    if not read_features:
      sequence_key = 'keypoints'
      feature_dim = (2, 7, 2)
    else:
      sequence_key = 'features'      
      feature_dim = (60)

    # Transpose last axis, used for augmentation and normalization
    dataset = transpose_last_axis(deepcopy(dataset), sequence_key = sequence_key)
    test_data = transpose_last_axis(deepcopy(test_data), sequence_key = sequence_key)

    # Normalize the x y coordinates
    if config.normalize:
        dataset = normalize_data(dataset, sequence_key = sequence_key)
        test_data = normalize_data(test_data, sequence_key = sequence_key)

    # Split the data
    train_data, val_data = split_data(dataset,
                                      seed=config.seed,
                                      vocabulary=vocabulary,
                                      test_size=config.val_size,
                                      split_videos=config.split_videos)
    num_classes = len(vocabulary)

    # Calculate the input dimension based on past and future frames
    # Also flattens the channels as required by the architecture
    input_dim = calculate_input_dim(feature_dim,
                                    config.architecture,
                                    config.past_frames,
                                    config.future_frames)

    # Initialize data generators
    common_kwargs = {"batch_size": config.batch_size,
                     "input_dimensions": input_dim,
                     "past_frames": config.past_frames,
                     "future_frames": config.future_frames,
                     "class_to_number": vocabulary,
                     "frame_skip": config.frame_gap}

    train_generator = mabe_generator(train_data,
                                     augment=config.augment,
                                     shuffle=True,
                                     sequence_key=sequence_key,
                                     kwargs=common_kwargs)

    val_generator = mabe_generator(val_data,
                                   augment=False,
                                   shuffle=False,
                                   sequence_key=sequence_key,                                   
                                   kwargs=common_kwargs)

    test_generator = mabe_generator(test_data,
                                    augment=False,
                                    shuffle=False,
                                    sequence_key=sequence_key,
                                    kwargs=common_kwargs)

    trainer = Trainer(train_generator=train_generator,
                      val_generator=val_generator,
                      test_generator=test_generator,
                      input_dim=input_dim,
                      class_to_number=vocabulary,
                      num_classes=num_classes,
                      architecture=config.architecture,
                      arch_params=config.architecture_parameters)

    # If only running inference, load a pretrained model instead of training
    if skip_training and pretrained_model_path is not None:
        trainer.model = tf.keras.models.load_model(pretrained_model_path)

        # Print model summary
        trainer.model.summary()

        print("Skipping Training")
    else:
        # Model initialization
        trainer.initialize_model(layer_channels=config.layer_channels,
                                 dropout_rate=config.dropout_rate,
                                 learning_rate=config.learning_rate)
        # Print model summary
        trainer.model.summary()

        # Train model
        trainer.train(epochs=config.epochs)

    # Get metrics
    train_metrics = trainer.get_metrics(mode='train')
    val_metrics = trainer.get_metrics(mode='validation')
    test_metrics = trainer.get_metrics(mode='test')

    # Save the results
    save_results(results_dir, 'task1',
                 trainer.model, config,
                 train_metrics, val_metrics, test_metrics)
    
    # return predictions on the test set
    return trainer.get_predictions(test_generator)


In [None]:
from easydict import EasyDict

# Baseline config - Convolution 1D
config = {"seed": 42,
          "val_size": 0.2,
          "split_videos": False,
          "normalize": True,
          "past_frames": 100,
          "future_frames": 100,
          "frame_gap": 2,
          "architecture": "conv_1D",
          "architecture_parameters": EasyDict({"conv_size": 5}),
          "batch_size": 128,
          "learning_rate": 1e-3,
          "dropout_rate": 0.5,
          "layer_channels": (128, 64, 32),
          "epochs": 15,
          "augment": False}

config = EasyDict(config)
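
To get a feel for how `past_frames`, `future_frames`, and `frame_gap` interact, the sketch below counts the frames in a temporal window around the current frame. It assumes the window is sampled every `frame_gap` frames from `-past_frames` to `+future_frames`; the actual sampling in `mabe_generator` may differ, and `window_offsets` is a hypothetical helper, not part of the starter kit.

```python
import numpy as np


def window_offsets(past_frames, future_frames, frame_gap):
    """Frame offsets relative to the current frame (offset 0).

    Assumption: every `frame_gap`-th frame is taken on both sides.
    """
    past = np.arange(-past_frames, 0, frame_gap)           # -past_frames, ..., -frame_gap
    future = np.arange(0, future_frames + 1, frame_gap)    # 0, ..., +future_frames
    return np.concatenate([past, future])


offsets = window_offsets(past_frames=100, future_frames=100, frame_gap=2)
print(len(offsets), offsets[0], offsets[-1])
```

Under this assumption, the baseline settings give a window of 101 frames spanning 100 frames on either side of the current one.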


Run this cell to train the network defined by the settings in `config` above and evaluate it on the test set. If you want to keep the test set untouched during hyperparameter exploration, split `training_data` into training and validation sets, and pass the validation set in place of the test set here.



In [None]:
results_dir = '.'
predictions = train_task1(results_dir,
            dataset=training_data,
            vocabulary=vocab,
            test_data=test_data,
            config=config)
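
One way to hold out a validation set for hyperparameter exploration, as suggested above, is to split the data by sequence. The sketch below assumes the data is a dictionary keyed by sequence id, with one entry per video; `split_by_sequence` and the toy dictionary are hypothetical and only illustrate the idea.

```python
import numpy as np


def split_by_sequence(data, val_fraction=0.2, seed=42):
    """Split a {sequence_id: sequence} dict into train and validation dicts."""
    keys = sorted(data)                          # fixed order for reproducibility
    order = np.random.default_rng(seed).permutation(len(keys))
    n_val = max(1, int(round(val_fraction * len(keys))))
    val_keys = {keys[i] for i in order[:n_val]}  # randomly chosen held-out sequences
    train = {k: v for k, v in data.items() if k not in val_keys}
    val = {k: v for k, v in data.items() if k in val_keys}
    return train, val


# Toy stand-in for the training data: 10 sequences of 10 frames each
toy_data = {f"seq_{i}": {"keypoints": np.zeros((10, 2, 2, 7))} for i in range(10)}
train_split, val_split = split_by_sequence(toy_data)
print(len(train_split), len(val_split))
```

With the real data, `train_split` would be passed as `dataset` and `val_split` as `test_data` in the call to `train_task1`, leaving the actual test set for a final evaluation.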