In [None]:
# imports

import os

import networkx as nx
import numpy as np
import random
import torch
from torch_geometric.data import Data, Dataset
from torch_geometric.utils import to_networkx

In [None]:
# load the data
#
# Each .npy file stores a pickled Python dict, hence allow_pickle=True plus .item() to
# unwrap it. Unpickling can execute arbitrary code, so only load files you trust.
submission_clips, user_train = (
    np.load(npy_path, allow_pickle=True).item()
    for npy_path in ('data/raw/submission_data.npy', 'data/raw/user_train.npy')
)

In [None]:
# understanding the arrays
#
# Per-sequence keypoints are indexed [frame][mouse][body part][coordinate], e.g.
# user_train['sequences'][sequence_names[1599]]['keypoints'][1799][2][11] has length 2:
# 1600 sequences x 1800 frames x 3 mice x 12 body parts x (x, y) coordinates.
#
# Body-part indices as used in code (0-based):
#   0 nose, 1 left ear, 2 right ear, 3 neck, 4 left forepaw, 5 right forepaw,
#   6 center back, 7 left hindpaw, 8 right hindpaw, 9 tail base, 10 tail middle, 11 tail tip.

sequence_names = list(user_train['sequences'].keys())

raw_features = user_train['sequences'][sequence_names[0]]['keypoints']
# Guard the layout before flattening: reshape(-1, 36, 2) would also "succeed" on an array
# with a different mouse/keypoint layout and silently scramble the nodes.
assert raw_features.shape[1:] == (3, 12, 2), f'unexpected keypoint layout {raw_features.shape}'
# One graph per frame: merge the mouse and body-part axes into 3 * 12 = 36 nodes with (x, y) features.
features = raw_features.reshape(-1, 3 * 12, 2)

In [None]:
# understanding the data

# print("Dataset keys -", submission_clips.keys())
# print("Number of submission sequences -", len(submission_clips['sequences']))

# sequence_names = list(submission_clips["sequences"].keys())
# sequence_key = sequence_names[0]
# single_sequence = submission_clips["sequences"][sequence_key]["keypoints"]

# print("Sequence name -", sequence_key)
# print("Single Sequence shape", single_sequence.shape)
# print(f"Number of Frames in {sequence_key} -", len(single_sequence))

In [None]:
# code for creating the edge indices

def replicate_skeleton(src, dst, num_mice, num_keypoints, bridge_node=None):
    """Tile a single-mouse skeleton over ``num_mice`` mice.

    Mouse ``m`` uses node ids offset by ``m * num_keypoints``. If ``bridge_node`` is given,
    an extra edge joins that keypoint of mouse ``m - 1`` to the same keypoint of mouse ``m``;
    it is inserted just before mouse ``m``'s own edges.

    Args:
        src, dst: per-mouse edge endpoints (equal-length lists of keypoint indices).
        num_mice: number of mice to replicate the skeleton over.
        num_keypoints: keypoints per mouse, i.e. the node-id offset between mice.
        bridge_node: optional keypoint index used to link consecutive mice.

    Returns:
        ``[sources, targets]`` lists, ready for ``torch.tensor(..., dtype=torch.long)``.

    Raises:
        ValueError: if ``src`` and ``dst`` differ in length.
    """
    if len(src) != len(dst):
        raise ValueError(f'src and dst differ in length: {len(src)} != {len(dst)}')
    sources, targets = [], []
    for mouse in range(num_mice):
        offset = mouse * num_keypoints
        if bridge_node is not None and mouse > 0:
            # Same keypoint on the previous mouse -> this mouse.
            sources.append(bridge_node + offset - num_keypoints)
            targets.append(bridge_node + offset)
        sources.extend(node + offset for node in src)
        targets.extend(node + offset for node in dst)
    return [sources, targets]


# 2-mice dataset: 7 keypoints per mouse, plus a keypoint-3 -> keypoint-3 edge between the mice.
MOUSE7_SRC = [0, 0, 1, 1, 3, 3, 4, 5]
MOUSE7_DST = [1, 2, 3, 2, 4, 5, 6, 6]
second_option = replicate_skeleton(MOUSE7_SRC, MOUSE7_DST, num_mice=2, num_keypoints=7, bridge_node=3)

# 3-mice dataset: 12 keypoints per mouse, no edges between mice.
# For neck-to-neck links between consecutive mice (3 -> 15 -> 27), pass bridge_node=3.
MOUSE12_SRC = [0, 0, 1, 2, 3, 3, 3, 4, 5, 6, 6, 6, 7, 8, 9, 10]
MOUSE12_DST = [1, 2, 3, 3, 4, 5, 6, 6, 6, 7, 8, 9, 9, 9, 10, 11]
third_option = replicate_skeleton(MOUSE12_SRC, MOUSE12_DST, num_mice=3, num_keypoints=12)

In [None]:
# dataset: one torch_geometric graph per video frame

class MABDataset(Dataset):
    """PyG dataset that turns mouse keypoints into one graph per frame.

    Every frame becomes a ``Data`` object with
      * ``x``: (NUM_MICE * NUM_KEYPOINTS, 2) float tensor of keypoint coordinates,
      * ``edge_index``: the fixed skeleton from ``_edge_index_creator``,
      * ``y``: a scalar int64 class label (currently a random placeholder).

    Args:
        root: dataset root directory; raw files are read from ``self.raw_dir`` and graphs
            are written to ``self.processed_dir``.
        test: read the submission file instead of the training file. Processing the
            test split is not implemented yet (``process`` raises NotImplementedError).
        transform, pre_transform, pre_filter: forwarded to ``torch_geometric.data.Dataset``.
        label_seed: seed for the placeholder labels. ``None`` (default) draws from the
            global ``random`` state.
    """

    NUM_MICE = 3
    NUM_KEYPOINTS = 12  # body parts per mouse
    NUM_CLASSES = 4     # size of the placeholder label space

    def __init__(self, root, test=False, transform=None, pre_transform=None, pre_filter=None,
                 label_seed=None):
        # Must be set before super().__init__, which reads raw_file_names and calls process().
        self.test = test
        self.label_seed = label_seed
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        """Raw file expected in ``raw_dir`` for the selected split."""
        return ['submission_data.npy'] if self.test else ['user_train.npy']

    @property
    def processed_file_names(self):
        """File(s) whose presence in ``processed_dir`` makes PyG skip ``process()``.

        ``process()`` never writes this name, so ``process()`` runs on every
        instantiation. ``len()`` relies on that: it reads ``self.labels``, which is only
        populated inside ``process()``.
        """
        # TODO: return the real per-frame file names (and persist the labels) to enable caching.
        return 'not_processed.pt'

    def download(self):
        # Nothing to download: the raw .npy files must be placed in raw_dir by hand.
        pass

    def _edge_index_creator(self) -> list:
        """Skeleton edges for all mice as ``[sources, targets]``.

        One per-mouse skeleton is replicated for every mouse with node ids offset by
        ``mouse * NUM_KEYPOINTS``; there are no edges between mice, and each edge is
        listed once (source -> target).
        """
        # NOTE(review): PyG message passing defaults to source -> target flow, so these
        # one-directional edges only propagate information one way along the skeleton.
        # Add reverse edges (e.g. torch_geometric.utils.to_undirected) if that is unintended.
        src = [0, 0, 1, 2, 3, 3, 3, 4, 5, 6, 6, 6, 7, 8, 9, 10]
        dst = [1, 2, 3, 3, 4, 5, 6, 6, 6, 7, 8, 9, 9, 9, 10, 11]
        offsets = [mouse * self.NUM_KEYPOINTS for mouse in range(self.NUM_MICE)]
        return [[node + off for off in offsets for node in src],
                [node + off for off in offsets for node in dst]]

    def _frame_path(self, idx):
        """Path of the saved graph for frame ``idx`` of the current split."""
        prefix = 'data_test' if self.test else 'data'
        return os.path.join(self.processed_dir, f'{prefix}_{idx}.pt')

    def process(self):
        """Read the raw keypoints and save one ``Data`` graph per frame to ``processed_dir``."""
        if self.test:
            # Fail fast rather than loading a large file whose processing is not implemented.
            raise NotImplementedError('processing of the test split is not implemented')

        # The .npy holds a pickled dict wrapped in an object array; only load trusted files.
        self.raw_data = np.load(self.raw_paths[0], allow_pickle=True)
        sequences = self.raw_data[()]['sequences']
        sequence_names = list(sequences)
        # NOTE(review): only the first sequence is converted -- looks like a prototype limit.
        raw_features = sequences[sequence_names[0]]['keypoints']

        # TODO: placeholder labels; replace with the real per-frame annotations.
        rng = random if self.label_seed is None else random.Random(self.label_seed)
        self.labels = [rng.randrange(0, self.NUM_CLASSES) for _ in range(len(raw_features))]
        # NOTE(review): spelled `num_clases` (sic) on purpose. Recent PyG versions expose
        # `num_classes` as a read-only property on Dataset, so assigning that name would fail.
        self.num_clases = self.NUM_CLASSES

        # Merge the mouse and keypoint axes: one node per (mouse, keypoint) with (x, y) features.
        features = raw_features.reshape(-1, self.NUM_MICE * self.NUM_KEYPOINTS, 2)
        edge_index = torch.tensor(self._edge_index_creator(), dtype=torch.long)
        for i, frame in enumerate(features):
            graph = Data(
                x=torch.tensor(frame, dtype=torch.float),
                edge_index=edge_index,
                # int64 class index: torch's cross-entropy / NLL losses reject int32 targets.
                y=torch.tensor(self.labels[i], dtype=torch.long),
            )
            torch.save(graph, self._frame_path(i))

    def len(self):
        """Number of frames (graphs); only valid after ``process()`` has run."""
        return len(self.labels)

    def get(self, idx):
        """Load the graph saved for frame ``idx``."""
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may refuse
        # to unpickle `Data` objects depending on the installed PyG version -- confirm.
        return torch.load(self._frame_path(idx))

In [None]:
# create dataset
#
# Only the training split is built here: constructing MABDataset with test=True
# currently raises NotImplementedError from process().
train_dataset = MABDataset('./data', test=False)

In [None]:
# plot the graph
#
# Skeleton graph of the first frame (3 mice x 12 keypoints = 36 nodes).
# to_networkx(..., to_undirected=True) already returns an undirected nx.Graph, so it can be
# drawn directly. Undirected edges are drawn without arrows, so no arrow options are passed.
# Seeding the spring layout makes the figure identical on every re-run.
G = to_networkx(train_dataset[0], to_undirected=True)
pos = nx.spring_layout(G, seed=42)
nx.draw(G, pos=pos, node_size=100, alpha=0.8, with_labels=False)

In [None]:
# Reference plotting/animation code copied from the AIcrowd web page (per its first line).
# NOTE: the whole cell is a bare string literal, so nothing in it executes; its imports,
# constants and functions are NOT available to other cells.
# Known issues to fix if it is ever turned back into code:
#  - animate_pose_sequence uses `class_to_color`, which is not defined in this notebook as shown.
#  - plot_mouse draws keypoint markers only for indices 0-9 (range(10)) although each mouse
#    has 12 body parts; tail middle/tip (10, 11) appear only via the skeleton lines.
#  - fig.canvas.tostring_rgb() is deprecated in recent Matplotlib releases (TODO confirm the
#    installed version); fig.canvas.buffer_rgba() is the replacement (RGBA, 4 channels).
'''code from the aicrowd webpage
import matplotlib.pyplot as plt
from matplotlib import animation
from matplotlib import colors
from matplotlib import rc

rc('animation', html='jshtml')

# Note: Image processing may be slow if too many frames are animated.

#Plotting constants
FRAME_WIDTH_TOP = 850
FRAME_HEIGHT_TOP = 850

M1_COLOR = 'lawngreen'
M2_COLOR = 'skyblue'
M3_COLOR = 'tomato'

PLOT_MOUSE_START_END = [(0, 1), (1, 3), (3, 2), (2, 0),        # head
                        (3, 6), (6, 9),                        # midline
                        (9, 10), (10, 11),                     # tail
                        (4, 5), (5, 8), (8, 9), (9, 7), (7, 4) # legs
                       ]

class_to_number = {s: i for i, s in enumerate(user_train['vocabulary'])}

number_to_class = {i: s for i, s in enumerate(user_train['vocabulary'])}

def num_to_text(anno_list):
  return np.vectorize(number_to_class.get)(anno_list)

def set_figax():
    fig = plt.figure(figsize=(8, 8))

    img = np.zeros((FRAME_HEIGHT_TOP, FRAME_WIDTH_TOP, 3))

    ax = fig.add_subplot(111)
    ax.imshow(img)

    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    return fig, ax

def plot_mouse(ax, pose, color):
    # Draw each keypoint
    for j in range(10):
        ax.plot(pose[j, 0], pose[j, 1], 'o', color=color, markersize=3)

    # Draw a line for each point pair to form the shape of the mouse

    for pair in PLOT_MOUSE_START_END:
        line_to_plot = pose[pair, :]
        ax.plot(line_to_plot[:, 0], line_to_plot[
                :, 1], color=color, linewidth=1)

def animate_pose_sequence(video_name, seq, start_frame = 0, stop_frame = 100, skip = 0,
                          annotation_sequence = None):
    # Returns the animation of the keypoint sequence between start frame
    # and stop frame. Optionally can display annotations.

    image_list = []

    counter = 0
    if skip:
        anim_range = range(start_frame, stop_frame, skip)
    else:
        anim_range = range(start_frame, stop_frame)

    for j in anim_range:
        if counter%20 == 0:
          print("Processing frame ", j)
        fig, ax = set_figax()
        plot_mouse(ax, seq[j, 0, :, :], color=M1_COLOR)
        plot_mouse(ax, seq[j, 1, :, :], color=M2_COLOR)
        plot_mouse(ax, seq[j, 2, :, :], color=M3_COLOR)

        if annotation_sequence is not None:
          annot = annotation_sequence[j]
          annot = number_to_class[annot]
          plt.text(50, -20, annot, fontsize = 16,
                   bbox=dict(facecolor=class_to_color[annot], alpha=0.5))

        ax.set_title(
            video_name + '\n frame {:03d}.png'.format(j))

        ax.axis('off')
        fig.tight_layout(pad=0)
        ax.margins(0)

        fig.canvas.draw()
        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(),
                                        dtype=np.uint8)
        image_from_plot = image_from_plot.reshape(
            fig.canvas.get_width_height()[::-1] + (3,))

        image_list.append(image_from_plot)

        plt.close()
        counter = counter + 1

    # Plot animation.
    fig = plt.figure(figsize=(8,8))
    plt.axis('off')
    im = plt.imshow(image_list[0])

    def animate(k):
        im.set_array(image_list[k])
        return im,
    ani = animation.FuncAnimation(fig, animate, frames=len(image_list), blit=True)
    return ani
    '''