In [1]:
import os
import numpy as np
import pandas as pd


import torch
import torchvision
import torch.nn as nn

from models import *
from augmentation import *
from utils import set_seed
from dataset import MABeDataset, create_dataloaders
from metrics import BCE, F1_score
from trainer import Trainer

# Output folders for submissions, model checkpoints and training history.
for folder in ('deliveries', 'models', 'history'):
    os.makedirs(folder, exist_ok=True)


In [2]:
# Root of the MABe dataset; later cells read train.csv / test.csv and the
# train_tracking/, train_annotation/ and test_tracking/ parquet folders from here.
PATH = '../../Datasets/MABe-mouse-behavior-detection/'

In [3]:
# Project-local dataset wrapper; its label vocabulary (LABELS) is reused by the
# inference and decoding cells below.
ds = MABeDataset()
LABELS = ds.LABELS
ds.DF.columns

1156 / 1156


Index(['lab_id', 'video_id', 'mouse1_strain', 'mouse1_color', 'mouse1_sex',
       'mouse1_id', 'mouse1_age', 'mouse1_condition', 'mouse2_strain',
       'mouse2_color', 'mouse2_sex', 'mouse2_id', 'mouse2_age',
       'mouse2_condition', 'mouse3_strain', 'mouse3_color', 'mouse3_sex',
       'mouse3_id', 'mouse3_age', 'mouse3_condition', 'mouse4_strain',
       'mouse4_color', 'mouse4_sex', 'mouse4_id', 'mouse4_age',
       'mouse4_condition', 'frames_per_second', 'video_duration_sec',
       'pix_per_cm_approx', 'video_width_pix', 'video_height_pix',
       'arena_width_cm', 'arena_height_cm', 'arena_shape', 'arena_type',
       'body_parts_tracked', 'behaviors_labeled', 'tracking_method', 'label',
       'mouse_1', 'mouse_2', 'chunk', 'ok', 'arena_w', 'arena_h'],
      dtype='object')

In [4]:
# Fixed sample used for the feature-engineering smoke test below (random pick disabled).
idx = 4#np.random.randint(1024)
# x: keypoints and c: context vector, both fed to Feature_Engineering below;
# al / y are not used in this notebook section (presumably auxiliary inputs and labels — confirm in dataset.py).
(x,al,c),y=ds[idx]
idx

4

In [5]:
#np.unique(ds.DF['label'].explode().dropna()).tolist()

In [6]:
class Feature_Engineering(nn.Module):
    """Hand-crafted kinematic and relational features for a pair of tracked mice.

    ``forward`` splits the body-part axis of ``x`` into two equal halves
    (mouse 1 and mouse 2) and concatenates along dim 1, in this order:

    * per-part speed of mouse 1 and mouse 2 (2 x C)
    * per-part 2-D cross product of velocity and acceleration (2 x C)
    * per-part cosine between consecutive velocity vectors (2 x C)
    * per-part cosine between the two mice's velocities (C)
    * all C x C inter-mouse part distances and their frame differences (2 x C^2)
    * per-part cosine between each mouse's velocity and the inter-mouse vector (2 x C)
    * per-part distance to the arena wall for each mouse (2 x C)

    giving ``C * (2 * C + 11)`` channels (175 for C = 7).
    """

    def __init__(self, config=None):
        super().__init__()
        # Must exist before update(); previously any non-empty config raised AttributeError.
        self.config = {}
        if config: self.config.update(config)
        self.pdist = torch.nn.PairwiseDistance(p=2)

    @staticmethod
    def _cosine(a, b, eps):
        """Cosine similarity along the last axis, with ``eps`` added to the denominator."""
        return torch.einsum('...i,...i->...', a, b) / (a.norm(dim=-1) * b.norm(dim=-1) + eps)

    @staticmethod
    def _cross_2d(a, b):
        """z-component of the 2-D cross product ``a x b`` over the last (xy) axis."""
        return (a[..., [0]] * b[..., [1]] - a[..., [1]] * b[..., [0]]).squeeze(dim=-1)

    @staticmethod
    def _wall_distance(pts, arena_wh, diameter, circle):
        """Distance from each point to the arena boundary.

        Circular arenas: ``|distance to centre - diameter / 2|``.
        Rectangular arenas: distance to the nearest wall of ``[0, w] x [0, h]``.
        ``circle`` (1 = circular, 0 = rectangular) selects/blends the two.
        """
        d_circle = torch.abs((pts - arena_wh / 2).norm(dim=-1) - diameter / 2)
        d_rect, _ = torch.cat([torch.abs(pts - arena_wh), torch.abs(pts)], dim=-1).min(dim=-1)
        return circle * d_circle + (1 - circle) * d_rect

    def _turning(self, dx):
        """Cosine between consecutive velocity vectors; the first frame is set to 0."""
        cos = self._cosine(dx[:, :, 1:], dx[:, :, :-1], 1e-4)
        return torch.cat([torch.zeros_like(cos[:, :, :1]), cos], dim=2)

    def forward(self, x, context):
        """Compute the pair features.

        Args:
            x: tensor indexed as (batch, 2*C parts, frames, xy); the first C parts
               belong to mouse 1, the remaining C to mouse 2.
            context: (batch, K) tensor whose last three columns are
               [is_circular, arena_width, arena_height].

        Returns:
            (batch, C * (2C + 11), frames) feature tensor.
        """
        _, n_parts, _, _ = x.shape
        C = n_parts // 2
        x1 = x[:, :C]
        x2 = x[:, C:]

        # Move context once so every arena tensor lives on x's device.
        context = context.to(x.device)
        arena_wh = context[:, [-2, -1]].unsqueeze(1).unsqueeze(1)   # (B, 1, 1, 2)
        circle = context[:, -3].unsqueeze(1).unsqueeze(1)           # (B, 1, 1)
        # The arena height column doubles as the circle diameter.
        diameter = context[:, -1].unsqueeze(1).unsqueeze(1)          # (B, 1, 1)
        dist_wall_1 = self._wall_distance(x1, arena_wh, diameter, circle)
        dist_wall_2 = self._wall_distance(x2, arena_wh, diameter, circle)

        # Velocity and acceleration, first frame repeated so frame count is unchanged.
        dx1 = x1.diff(dim=2, prepend=x1[:, :, :1])
        dx2 = x2.diff(dim=2, prepend=x2[:, :, :1])
        ddx1 = dx1.diff(dim=2, prepend=dx1[:, :, :1])
        ddx2 = dx2.diff(dim=2, prepend=dx2[:, :, :1])

        speed_1 = dx1.norm(dim=-1)
        speed_2 = dx2.norm(dim=-1)
        vel_acc_cross_1 = self._cross_2d(dx1, ddx1)
        vel_acc_cross_2 = self._cross_2d(dx2, ddx2)
        turning_1 = self._turning(dx1)
        turning_2 = self._turning(dx2)
        dirs = self._cosine(dx1, dx2, 1e-6)

        # Distances between every part of mouse 1 and every part of mouse 2:
        # rolling mouse 1's part axis by 1..C pairs each part with all others.
        d = torch.cat([self.pdist(x1.roll(i + 1, dims=1), x2) for i in range(C)], dim=1)
        dd = d.diff(dim=-1, prepend=d[:, :, :1])

        rel_x = x1 - x2
        lead_1 = self._cosine(dx1, rel_x, 1e-6)
        lead_2 = self._cosine(dx2, -rel_x, 1e-6)

        return torch.concat([
            speed_1, speed_2,
            vel_acc_cross_1, vel_acc_cross_2,
            turning_1, turning_2,
            dirs, d, dd,
            lead_1, lead_2,
            dist_wall_1, dist_wall_2,
        ], dim=1)

In [7]:
# Stand-alone feature extractor, used only for the shape check below.
feng = Feature_Engineering()

In [8]:
# Replicate the single sample into a batch of 64 and check the channel count:
# C * (2C + 11) = 175 for C = 7 parts per mouse (matches n_features in ARModel).
fs = feng(x.unsqueeze(0).repeat(64,1,1,1), c.unsqueeze(0).repeat(64,1))
fs.shape

torch.Size([64, 175, 1024])

In [9]:
class ARModel(nn.Module):
    """Per-frame behaviour logits from a pair of mouse tracks.

    Pipeline: ``Feature_Engineering`` -> encoder of project-defined ``SEResConv``
    blocks with increasing dilation, interleaved with 2x average pooling ->
    decoder of ``SEResDeconv`` blocks with ``Upsample`` -> transposed conv to 37
    output channels.
    """

    def __init__(self, config=None):
        super().__init__()
        self.config = dict()
        if config:
            self.config.update(config)

        self.training = True

        # Creation order matters: it fixes the state_dict keys and the seeded
        # weight initialisation, so sub-modules are built in the original order.
        self.feature_eng = Feature_Engineering()
        parts = 7
        in_features = parts * (2 * parts + 11)  # channel count produced by Feature_Engineering

        # Not used in forward() (the context branch is disabled) but kept so that
        # existing checkpoints still load.
        self.context_encoder = nn.Sequential(nn.Linear(71, in_features), nn.ReLU())

        h = 128
        p_drop = 0.3
        self.encoder = nn.Sequential(
            SEResConv(in_features, h, 5, dilation=1, dropout=p_drop),
            nn.AvgPool1d(2),
            SEResConv(h, 2 * h, 5, dilation=2, dropout=p_drop),
            nn.AvgPool1d(2),
            SEResConv(2 * h, 4 * h, 5, dilation=4, dropout=p_drop),
        )
        self.decoder = nn.Sequential(
            SEResDeconv(4 * h, 2 * h, 5, padding=8, dilation=4, dropout=p_drop),
            Upsample(1, 2),
            SEResDeconv(2 * h, h, 5, padding=4, dilation=2, dropout=p_drop),
            Upsample(1, 2),
            nn.ConvTranspose1d(h, 37, 5, padding=2, dilation=1, bias=True),
        )

    def forward(self, x, context):
        """Return raw logits of shape (batch, 37, frames') for keypoints ``x`` and arena ``context``."""
        features = self.feature_eng(x, context)
        return self.decoder(self.encoder(features))

In [10]:
# Training configuration consumed by the project-local Trainer.
config = dict(
    model=ARModel,
    num_workers=24,
    seed=2,
    batch_size=64,
    losses=[BCE()],
    metrics=[F1_score],
    verbose=2,
    train_only=False,
    scheduler=True,
    lr=1e-4,
)

In [None]:
# Long training run (2 * 2048 epochs); the Trainer logs per-epoch losses and the best validation F1.
trainer = Trainer(config=config)
model = trainer.train(epochs=2048*2)

de48ba5a94ad7fdbd0fbd49e92a5398575a05a7cd3c08b841148a73595cc8274 

[1m Epoch 2261/4096
[1m Training 	|	 loss=0.619079[0m
[1m Validation 	|	 loss=0.189345 F1_score=0.447[0m

[1m Best : F1_score=0.464 at epoch 2189


Training:   0%|                                                          | 0/17 [00:00<?, ?it/s]

In [None]:
# for i in range(6):
#     if i!=2:
#         trainer = Trainer(config=config|{"seed":i})
#         model = trainer.train(epochs=2048*2)
#         torch.save(model.state_dict(), f"models/{trainer.exp_id}.pth")

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# eval() disables dropout etc. and returns the module itself.
model = trainer.model.eval()

In [None]:
# Checkpoint the trained weights under this experiment's id.
torch.save(model.state_dict(), f"models/{trainer.exp_id}.pth")

In [None]:
# Reload the checkpoint saved above. weights_only=True restricts unpickling to
# tensors and plain containers, which is all a state_dict needs.
state_dict = torch.load(f"models/{trainer.exp_id}.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval();

In [None]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from matplotlib import pyplot as plt

pd.set_option('mode.chained_assignment', None)

PATH = '../../Datasets/MABe-mouse-behavior-detection/'
df = pd.read_csv(PATH + 'train.csv')

# Normalise each mouse slot's condition and age metadata with the dataset's parsers.
for slot in range(1, 5):
    for column, parse in ((f'mouse{slot}_condition', ds.process_condition),
                          (f'mouse{slot}_age', ds.process_age)):
        df[column] = df[column].apply(parse)

In [None]:
# The 7 tracked body parts per mouse. Object dtype so that `base_bp + ' - ' + ...`
# concatenates strings element-wise when building column names.
# NOTE(review): re-bound by `from preprocessing import preprocess, base_bp` in a later
# cell — presumably the same list; confirm and keep only one definition.
base_bp = np.array(['ear_left','ear_right', 'lateral_left','lateral_right','neck','nose','tail_base']).astype('object')

In [None]:
# NOTE(review): `data` is not referenced by any later cell in this notebook — likely leftover; confirm before removing.
data = pd.DataFrame(columns=['lab_id', 'video_id', 'mouse_1', 'mouse_2', 'chunk', 'ok'])

In [None]:
from preprocessing import preprocess, base_bp

In [None]:
def load_track(track, fps, ppc=None):
    """Convert a per-frame keypoint table into a (frames, 14, 2) float tensor.

    The time axis is resampled from ``fps`` to 30 fps and coordinates are divided
    by ``ppc`` (pixels per cm).

    Args:
        track: DataFrame with one row per frame and one column per body part, each
            cell holding an (x, y) pair. A 7-column (single mouse) track is
            duplicated so the result always has 14 parts.
        fps: source frame rate.
        ppc: pixels per cm. If omitted, falls back to the notebook-level ``ppc``
            global, which is what this helper historically (implicitly) used.

    Returns:
        float32 tensor of shape (int(n_frames * 30 / fps), n_parts, 2).
    """
    if ppc is None:
        # Legacy callers rely on the `ppc` global set per video in the loop cells.
        ppc = globals()['ppc']

    x = torch.tensor(np.stack([np.stack([np.array(cell) for cell in row]) for row in track.values]))
    if x.size(1) == 7:
        x = torch.cat([x, x], dim=1)

    resize = torchvision.transforms.Resize((int(len(x) * (30. / fps)), 2))
    # Resize treats the leading dim as channels, so resampling (parts, time, xy) in
    # one call is the same as resizing each body part separately.
    x = resize(x.permute(1, 0, 2)).permute(1, 0, 2) / ppc

    return x.float().contiguous()


def get_behaviors_labeled(x, labels=None):
    """Parse a ``behaviors_labeled`` cell into per-pair multi-hot masks.

    Args:
        x: raw cell value — a string containing a list literal of
           ``"agent,target,action"`` entries. Any non-string (e.g. NaN) yields {}.
        labels: label vocabulary; defaults to the notebook-level ``LABELS``.
            The last position is always set to 1 (callers strip it with ``[:-1]``).

    Returns:
        dict ``(agent, target) -> float tensor of len(labels)`` with 1 at each
        labelled action of that pair and at the final position.

    Raises:
        ValueError: if the string is not a plain literal or an action is unknown.
    """
    import ast

    if not isinstance(x, str):
        return {}
    if labels is None:
        labels = LABELS

    # literal_eval only accepts literals; eval would execute code found in the CSV.
    entries = ast.literal_eval(x)

    actions_by_pair = {}
    for entry in entries:
        agent, target, action = entry.split(',')
        actions_by_pair.setdefault((agent, target), []).append(action.replace('"', '').replace("'", ''))

    masks = {}
    for pair, actions in actions_by_pair.items():
        vec = torch.zeros((len(labels)))
        for action in actions:
            vec[labels.index(action)] = 1
        vec[-1] = 1
        masks[pair] = vec
    return masks
    

In [None]:
def mobile_average(x, kernel=(0.05, 0.1, 0.2, 0.3, 0.2, 0.1, 0.05)):
    """Centered weighted moving average along dim 0.

    Args:
        x: tensor whose dim 0 is time; any trailing dims are smoothed independently.
        kernel: odd-length sequence of weights (default sums to 1).

    Returns:
        A new tensor shaped like ``x``. The first and last ``len(kernel) // 2``
        frames are copied unchanged; if ``x`` is shorter than the kernel, ``x`` is
        returned unchanged (as a copy).

    Raises:
        ValueError: if the kernel length is even.
    """
    kernel = torch.as_tensor(kernel, device=x.device)
    size = len(kernel)
    if size % 2 == 0:
        raise ValueError(f'kernel length must be odd, got {size}')
    half = size // 2

    out = x.clone()
    n = x.size(0)
    if n < size:
        return out

    # unfold -> (n - size + 1, *x.shape[1:], size): one window per interior frame.
    windows = x.unfold(0, size, 1)
    out[half:n - half] = (windows * kernel).sum(dim=-1)
    return out

In [None]:
import warnings
warnings.filterwarnings('ignore')

# Validation pass: decode model predictions into behaviour segments for a few
# training videos (same pipeline as the test-set cell) and collect their
# ground-truth annotations for scoring below.
VAL_INDICES = df.index[4:5]   # rows of train.csv to evaluate
N_CLASSES = 37                # model output channels; index N_CLASSES marks "no action"
THRESHOLD = .2                # minimum max-probability for a frame to get an action
n_feat = 14                   # body parts per pair (7 per mouse, duplicated for self pairs)
seq_len = 1024                # frames per model chunk
min_size = 8                  # minimum segment length in frames (see decoding below)

PRED_COLUMNS = ['row_id', 'video_id', 'agent_id', 'target_id', 'action', 'start_frame', 'stop_frame']
records = []
annots = []
model.eval()

for idx in tqdm(VAL_INDICES):
    lab_id = df.loc[idx, 'lab_id']
    video_id = df.loc[idx, 'video_id']
    filepath = f'{lab_id}/{video_id}'
    track = pd.read_parquet(PATH + f'train_tracking/{filepath}.parquet')

    annot = pd.read_parquet(PATH + f'train_annotation/{filepath}.parquet')
    annot['video_id'] = video_id
    annot['lab_id'] = lab_id
    annot['behaviors_labeled'] = df.loc[idx, 'behaviors_labeled']
    annot['agent_id'] = 'mouse' + annot['agent_id'].astype('str')
    annot['target_id'] = 'mouse' + annot['target_id'].astype('str')
    annot.loc[annot['agent_id'] == annot['target_id'], 'target_id'] = 'self'
    annots.append(annot)

    fps = df.loc[idx, 'frames_per_second']
    ppc = df.loc[idx, 'pix_per_cm_approx']
    min_coord = np.array([track.x.min(), track.y.min()])
    arena_w, arena_h = np.array([track.x.max(), track.y.max()]) - min_coord

    Behaviors = get_behaviors_labeled(df.loc[idx, 'behaviors_labeled'])
    track, _, Mice = preprocess(track, None, False)

    # Maps the 30 fps predictions back onto this video's frame count. Must be
    # defined here: the test-set cell's version has a different video's length.
    resize_back = torchvision.transforms.Resize((len(track), N_CLASSES))

    # Shift into arena coordinates (origin at the min corner), as the test-set cell
    # does and as the arena width/height context assumes; missing points -> (0, 0).
    for col in track.columns:
        track[col] = track[col].apply(lambda v: v - min_coord if (isinstance(v, np.ndarray) and pd.isna([v]).sum() == 0) else np.array([0., 0.]))

    context = torch.tensor([
        float(df.loc[idx, 'arena_shape'] == 'circular'),
        arena_w / ppc,
        arena_h / ppc,
    ]).float().unsqueeze(0).to(device)

    for m1 in Mice:
        for m2 in Mice:
            cols = list(base_bp + ' - ' + str(m1))
            if m1 != m2: cols += list(base_bp + ' - ' + str(m2))

            M1 = f'mouse{m1}'
            M2 = f'mouse{m2}' if m1 != m2 else 'self'
            if (M1, M2) not in Behaviors: continue

            with torch.no_grad():
                X = load_track(track[cols], fps).to(device)
                length = len(X)
                # Pad to a whole number of seq_len chunks (at least one).
                n_chunks = max(1, -(-length // seq_len))
                X = torch.cat([X, X.new_zeros((n_chunks * seq_len - length, n_feat, 2))], dim=0)
                X = X.reshape((n_chunks, seq_len, n_feat, 2)).transpose(1, 2)

                yp = model(X, context.repeat(n_chunks, 1)).sigmoid()
                yp = yp.transpose(1, 2).flatten(0, 1).cpu()
                # Zero out behaviours that are not annotated for this pair.
                yp = yp * Behaviors[(M1, M2)][:-1]
                yp = resize_back(yp[:length].unsqueeze(0))[0]

            # Decode frame-wise argmax into segments. stop_frame is exclusive,
            # matching range(start_frame, stop_frame) in the scorer.
            scores, classes = yp.max(dim=1)
            last = (-1, N_CLASSES)
            history = []
            i = -1
            for i, (m, p) in enumerate(zip(scores.tolist(), classes.tolist())):
                p = p if m > THRESHOLD else N_CLASSES
                if last[1] == p:
                    continue
                if last[1] != N_CLASSES and (i - last[0] > min_size or not history):
                    records.append([len(records), video_id, M1, M2, LABELS[last[1]], last[0], i])
                    history.append(last)
                    last = (i, p)
                elif i - last[0] <= min_size and history:
                    # Too short: fall back to the previous segment.
                    last = history[-1]
                else:
                    history.append(last)
                    last = (i, p)

            # Flush the trailing segment; last[1] may legitimately be class 0.
            if last[1] != N_CLASSES and i - last[0] > min_size:
                records.append([len(records), video_id, M1, M2, LABELS[last[1]], last[0], i + 1])

Pred = pd.DataFrame(records, columns=PRED_COLUMNS)
# Ground truth for every evaluated video, not only the last one.
annot = pd.concat(annots, ignore_index=True)

In [None]:
import json

from collections import defaultdict

import pandas as pd
import polars as pl


class HostVisibleError(Exception):
    """Raised by the scorer for invalid submissions, e.g. one agent/target pair
    predicted with two actions on the same frame."""


def single_lab_f1(lab_solution: pl.DataFrame, lab_submission: pl.DataFrame, beta: float = 1) -> float:
    """Macro F-beta over actions for the videos of a single lab.

    Frames are collected per key: ``label_key`` for ``lab_solution`` rows and
    ``prediction_key`` for ``lab_submission`` rows (both columns are added by the
    caller). Each row covers ``range(start_frame, stop_frame)``, so ``stop_frame``
    is exclusive. TP/FN/FP frame counts are summed per action (the last
    ``'_'``-separated component of the key) and turned into an F-beta score; the
    result is the unweighted mean over actions.

    Args:
        lab_solution: ground-truth segments; also needs ``video_id`` and a
            ``behaviors_labeled`` column holding a JSON list of
            ``"agent,target,action"`` strings.
        lab_submission: predicted segments with ``video_id``, ``agent_id``,
            ``target_id``, ``action``, ``start_frame``, ``stop_frame`` and
            ``prediction_key``.
        beta: weight of recall relative to precision.

    Returns:
        Mean per-action F-beta. Divides by zero if there are no keys at all.

    Raises:
        HostVisibleError: if one agent/target pair is predicted with two
            different actions on the same frame.
    """
    # key -> annotated frames
    label_frames: defaultdict[str, set[int]] = defaultdict(set)
    # key -> predicted frames (only predictions that can be evaluated)
    prediction_frames: defaultdict[str, set[int]] = defaultdict(set)

    for row in lab_solution.to_dicts():
        label_frames[row['label_key']].update(range(row['start_frame'], row['stop_frame']))

    for video in lab_solution['video_id'].unique():
        
        active_labels: str = lab_solution.filter(pl.col('video_id') == video)['behaviors_labeled'].first()  # ty: ignore
        active_labels: set[str] = set(json.loads(active_labels))
        # "agent,target" -> frames already claimed by any action of that pair in this video
        predicted_mouse_pairs: defaultdict[str, set[int]] = defaultdict(set)

        for row in lab_submission.filter(pl.col('video_id') == video).to_dicts():
            # Since the labels are sparse, we can't evaluate prediction keys not in the active labels.
            if ','.join([str(row['agent_id']), str(row['target_id']), row['action']]) not in active_labels:
                continue

            new_frames = set(range(row['start_frame'], row['stop_frame']))
            # Ignore truly redundant predictions.
            new_frames = new_frames.difference(prediction_frames[row['prediction_key']])
            prediction_pair = ','.join([str(row['agent_id']), str(row['target_id'])])
            if predicted_mouse_pairs[prediction_pair].intersection(new_frames):
                # A single agent can have multiple targets per frame (ex: evading all other mice) but only one action per target per frame.
                raise HostVisibleError('Multiple predictions for the same frame from one agent/target pair')
            prediction_frames[row['prediction_key']].update(new_frames)
            predicted_mouse_pairs[prediction_pair].update(new_frames)
            
    tps = defaultdict(int)
    fns = defaultdict(int)
    fps = defaultdict(int)

    #print(list(prediction_frames))
    
    for key, pred_frames in prediction_frames.items():
        # assumes action names contain no '_' (the key is '_'-joined)
        action = key.split('_')[-1]
        # NOTE: indexing the defaultdict inserts an empty set for prediction-only
        # keys, so their actions also enter distinct_actions below (FP only).
        matched_label_frames = label_frames[key]
        
        tps[action] += len(pred_frames.intersection(matched_label_frames))
        fns[action] += len(matched_label_frames.difference(pred_frames))
        fps[action] += len(pred_frames.difference(matched_label_frames))

        # print(action, len(pred_frames), len(matched_label_frames))
        # print(len(pred_frames.intersection(matched_label_frames)), len(matched_label_frames.difference(pred_frames)), len(pred_frames.difference(matched_label_frames)))
        # print()

    distinct_actions = set()
    for key, frames in label_frames.items():
        action = key.split('_')[-1]
        distinct_actions.add(action)
        # Labelled keys with no prediction at all are entirely false negatives.
        if key not in prediction_frames:
            fns[action] += len(frames)

    action_f1s = []
    for action in distinct_actions:
        if tps[action] + fns[action] + fps[action] == 0:
            action_f1s.append(0)
        else:
            action_f1s.append((1 + beta**2) * tps[action] / ((1 + beta**2) * tps[action] + beta**2 * fns[action] + fps[action]))
    
    # Debug: prints the per-action scores.
    print(action_f1s)
    return sum(action_f1s) / len(action_f1s)


def mouse_fbeta(solution: pd.DataFrame, submission: pd.DataFrame, beta: float = 1) -> float:
    """Mean over labs of the per-lab macro F-beta computed by ``single_lab_f1``.

    Raises:
        ValueError: if either frame is empty or lacks a required column.
    """
    if len(solution) == 0 or len(submission) == 0:
        raise ValueError('Missing solution or submission data')

    required = ('video_id', 'agent_id', 'target_id', 'action', 'start_frame', 'stop_frame')
    for name in required:
        if name not in solution.columns:
            raise ValueError(f'Solution is missing column {name}')
        if name not in submission.columns:
            raise ValueError(f'Submission is missing column {name}')

    def with_key(frame: pl.DataFrame, alias: str) -> pl.DataFrame:
        # video/agent/target/action joined with '_' into one matching key
        parts = [pl.col(col).cast(pl.Utf8) for col in ('video_id', 'agent_id', 'target_id')]
        parts.append(pl.col('action'))
        return frame.with_columns(pl.concat_str(parts, separator='_').alias(alias))

    sol = pl.DataFrame(solution)
    sub = pl.DataFrame(submission)
    assert (sol['start_frame'] <= sol['stop_frame']).all()
    assert (sub['start_frame'] <= sub['stop_frame']).all()

    # Align on video ids; row ids are unreliable across public/private splits.
    sub = sub.filter(pl.col('video_id').is_in(set(sol['video_id'].unique())))

    sol = with_key(sol, 'label_key')
    sub = with_key(sub, 'prediction_key')

    per_lab = []
    for lab in sol['lab_id'].unique():
        lab_sol = sol.filter(pl.col('lab_id') == lab).clone()
        lab_video_ids = set(lab_sol['video_id'].unique())
        lab_sub = sub.filter(pl.col('video_id').is_in(lab_video_ids)).clone()
        per_lab.append(single_lab_f1(lab_sol, lab_sub, beta=beta))

    return sum(per_lab) / len(per_lab)


def score(solution: pd.DataFrame, submission: pd.DataFrame, row_id_column_name: str, beta: float = 1) -> float:
    """F1 score for the MABe Challenge.

    Drops the row-id column from both frames (when present) and delegates to
    ``mouse_fbeta``.
    """
    cleaned = [frame.drop(columns=[row_id_column_name], errors='ignore') for frame in (solution, submission)]
    return mouse_fbeta(*cleaned, beta=beta)

In [None]:
# Placeholder lab id for the predictions. The scorer groups by the solution's
# lab_id and selects predictions by video_id, so this value is not used for matching.
Pred['lab_id'] = 'lab'

In [None]:
#Pred = Pred[Pred.start_frame<2048]
#annot = annot[annot.start_frame<2048]

In [None]:
# Score the validation predictions against the annotations ('row_id' is dropped first).
score(annot, Pred, 'row_id')

In [None]:
# Is the evaluated video present in the training dataset index (keys built here as
# "<lab_id> - <video_id>")? If so, the validation score may be optimistic.
i = idx
(df.loc[i, 'lab_id'] + ' - ' + str(df.loc[i, 'video_id'])) in ds.DF.index

In [None]:
# (frames, 37) probabilities of the last processed mouse pair.
yp.shape

In [None]:
# Ground-truth annotation segments used for scoring.
annot

In [None]:
# import matplotlib.animation as animation
# import matplotlib.pyplot as plt

# plt.rcParams["animation.html"] = "jshtml"
# plt.rcParams['figure.dpi'] = 128  
# plt.ioff()
# fig, ax = plt.subplots()

# colors = plt.cm.tab10.colors

# def update(frame):
#     plt.cla()
    
#     loc = x.cpu().transpose(0,-1)
#     plt.scatter(loc[0,frame,:7], loc[1,frame,:7], c=colors[0])
#     plt.scatter(loc[0,frame,7:], loc[1,frame,7:], c=colors[1])

    
#     if y[frame].max()>.5:
#         plt.text(1,1, LABELS[np.argmax(y[frame])], color=colors[0])
#     plt.xlim(0, 100)
#     plt.ylim(0, 100)


# ani = animation.FuncAnimation(fig=fig, func=update, frames=x.size(1), interval=30)

In [None]:
import warnings
warnings.filterwarnings('ignore')

# Test-set inference: predict per-frame behaviours for every labelled
# agent/target pair and decode them into (start, stop) segments.
N_CLASSES = 37    # model output channels; index N_CLASSES marks "no action"
THRESHOLD = .2    # minimum max-probability for a frame to get an action
n_feat = 14       # body parts per pair (7 per mouse, duplicated for self pairs)
seq_len = 1024    # frames per model chunk
min_size = 8      # minimum segment length in frames (see decoding below)

PRED_COLUMNS = ['row_id', 'video_id', 'agent_id', 'target_id', 'action', 'start_frame', 'stop_frame']
# Separate name so the train metadata in `df` is not overwritten.
test_df = pd.read_csv(PATH + 'test.csv')
records = []
model.eval()

for idx in tqdm(test_df.index):
    video_id = test_df.loc[idx, 'video_id']
    filepath = f"{test_df.loc[idx, 'lab_id']}/{video_id}"
    track = pd.read_parquet(PATH + f'test_tracking/{filepath}.parquet')

    fps = test_df.loc[idx, 'frames_per_second']
    ppc = test_df.loc[idx, 'pix_per_cm_approx']
    min_coord = np.array([track.x.min(), track.y.min()])
    arena_w, arena_h = np.array([track.x.max(), track.y.max()]) - min_coord

    Behaviors = get_behaviors_labeled(test_df.loc[idx, 'behaviors_labeled'])
    track, _, Mice = preprocess(track, None)

    # Maps the 30 fps predictions back onto this video's frame count.
    resize_back = torchvision.transforms.Resize((len(track), N_CLASSES))

    # Shift into arena coordinates (origin at the min corner); missing points -> (0, 0).
    for col in track.columns:
        track[col] = track[col].apply(lambda v: v - min_coord if (isinstance(v, np.ndarray) and pd.isna([v]).sum() == 0) else np.array([0., 0.]))

    context = torch.tensor([
        float(test_df.loc[idx, 'arena_shape'] == 'circular'),
        arena_w / ppc,
        arena_h / ppc,
    ]).float().unsqueeze(0).to(device)

    for m1 in Mice:
        for m2 in Mice:
            cols = list(base_bp + ' - ' + str(m1))
            if m1 != m2: cols += list(base_bp + ' - ' + str(m2))

            M1 = f'mouse{m1}'
            M2 = f'mouse{m2}' if m1 != m2 else 'self'
            if (M1, M2) not in Behaviors: continue

            with torch.no_grad():
                X = load_track(track[cols], fps).to(device)
                length = len(X)
                # Pad to a whole number of seq_len chunks (at least one).
                n_chunks = max(1, -(-length // seq_len))
                X = torch.cat([X, X.new_zeros((n_chunks * seq_len - length, n_feat, 2))], dim=0)
                X = X.reshape((n_chunks, seq_len, n_feat, 2)).transpose(1, 2)

                yp = model(X, context.repeat(n_chunks, 1)).sigmoid()
                yp = yp.transpose(1, 2).flatten(0, 1).cpu()
                # Zero out behaviours that are not annotated for this pair.
                yp = yp * Behaviors[(M1, M2)][:-1]
                yp = resize_back(yp[:length].unsqueeze(0))[0]

            # Decode frame-wise argmax into segments. stop_frame is exclusive,
            # matching range(start_frame, stop_frame) in the scorer.
            scores, classes = yp.max(dim=1)
            last = (-1, N_CLASSES)
            history = []
            i = -1
            for i, (m, p) in enumerate(zip(scores.tolist(), classes.tolist())):
                p = p if m > THRESHOLD else N_CLASSES
                if last[1] == p:
                    continue
                if last[1] != N_CLASSES and (i - last[0] > min_size or not history):
                    records.append([len(records), video_id, M1, M2, LABELS[last[1]], last[0], i])
                    history.append(last)
                    last = (i, p)
                elif i - last[0] <= min_size and history:
                    # Too short: fall back to the previous segment.
                    last = history[-1]
                else:
                    history.append(last)
                    last = (i, p)

            # Flush the trailing segment; last[1] may legitimately be class 0.
            if last[1] != N_CLASSES and i - last[0] > min_size:
                records.append([len(records), video_id, M1, M2, LABELS[last[1]], last[0], i + 1])

Pred = pd.DataFrame(records, columns=PRED_COLUMNS)

In [None]:
# One prediction per (video, pair, action, start); the last (longest) one wins.
# row_id is then renumbered 0..n-1.
Pred = (
    Pred
    .drop_duplicates(subset=['video_id', 'agent_id', 'target_id', 'action', 'start_frame'], keep='last')
    .assign(row_id=lambda frame: np.arange(len(frame)))
)

In [None]:
# Final submission table.
Pred

In [None]:
# Parameter/buffer names and shapes (dumping the full tensors floods the output).
{name: tuple(tensor.shape) for name, tensor in model.state_dict().items()}