In [1]:
import ast
import os

import numpy as np
import pandas as pd
from tqdm import tqdm

from preprocessing import preprocess, base_bp
from models import *



def process_labels(X):
    """Extract the behavior names from one ``behaviors_labeled`` cell.

    Args:
        X: A string holding a Python list literal of ``"agent,target,behavior"``
            entries, an already-parsed list of such entries, or anything else
            (e.g. NaN for videos without annotations).

    Returns:
        List of behavior names (the last comma-separated field of each entry,
        with single quotes removed), in input order; ``[]`` when ``X`` is
        neither a string nor a list.
    """
    if isinstance(X, str):
        # literal_eval only accepts Python literals, so a malformed or hostile
        # CSV cell cannot execute code the way eval() could.
        import ast
        X = ast.literal_eval(X)
    if not isinstance(X, list):
        return []
    # NOTE(review): the quote stripping is kept exactly as before because the
    # resulting names (and their sorted order in LABELS) presumably must match
    # the vocabulary the checkpoints were trained with.
    return [entry.split(',')[-1].replace("'", "") for entry in X]

PATH = '../../Datasets/MABe-mouse-behavior-detection/'

# Label vocabulary: every distinct behavior in the training annotations, in the
# sorted order returned by np.unique, with the catch-all 'none' class appended
# as the final entry.
LABELS = (
    np.unique(
        pd.read_csv(PATH + 'train.csv')
        .behaviors_labeled
        .apply(process_labels)
        .explode()
        .dropna()
    ).tolist()
    + ['none']
)

# Test-set metadata; the inference loop reads one tracking file per row.
df = pd.read_csv(PATH + 'test.csv')

In [2]:
import torch
import torchvision
import torch.nn as nn

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
from models import ARModel

In [4]:
# Hyper-parameters passed to every ARModel constructor in the ensemble below.
# NOTE(review): most keys (num_workers, seed, batch_size, lr, scheduler,
# max_grad_norm, train_only, train_backbone) read like training settings;
# presumably they must match what the checkpoints were built with — confirm
# which ones models.ARModel actually reads.
CFG = {
    'num_workers': 24,
    'seed': 2,
    'batch_size': 64,
    'verbose': 2,
    'max_grad_norm': 1,
    'train_only': True,
    'train_backbone': True,
    'scheduler': True,
    'lr': 1e-4,
}
# `device` is intentionally not redefined here: the torch setup cell already
# assigns it with the same CUDA-or-CPU expression.

In [5]:
# Checkpoint identifiers of the ensemble; each is loaded from models/<id>.pth.
M = [
    '3f6515fb758330f72dbad465dfeae81a4c58120c0b695af47bd0279baea249f1',
    '39aba54b6489b8502e5a8b9e6ef502e7df459bdd3854d521f7f9160a269bf0a8',
    'f8c34eb26c887a79802f436c19825af5ccacd25b80015bbc7801ea5b86e540d7',
    'b688fc4b98d9328388235928caaf29311a4075abc4743346fbd17a1281a4d289',
    'ea423bc8007e78e8e32e1e6ab63986496b3b99b7dae08f23be204073c751d49e',
    '978f98560a3d8e0107f8429ed969ce48034e526f92cfbb5f33289afd7a7e2c8d',
]
# One architecture per checkpoint. Deriving the list from M keeps the two
# aligned; zip() below silently ignores any entry without a partner, so a
# hand-maintained list of a different length would hide a mismatch.
ARCH = [ARModel] * len(M)

MODELS = []
for m, Arch in zip(M, ARCH):
    ckpt_path = f'models/{m}.pth'
    model = Arch(CFG)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    MODELS.append(model.eval().to(device))

In [6]:
def load_track(track, fps, pix_per_cm=None):
    """Convert one agent/target tracking frame into a 30 fps model input tensor.

    Args:
        track: DataFrame whose cells each hold a coordinate array; presumably
            one column per body part, 7 columns for a single mouse or 14 for a
            pair — TODO confirm against ``preprocessing.base_bp``.
        fps: Frame rate of ``track``; the time axis is resampled to 30 fps.
        pix_per_cm: Scale used to divide the coordinates. Defaults to the
            notebook-global ``ppc`` (assigned per video in the inference loop),
            which is what the original two-argument call relied on.

    Returns:
        float32 tensor of shape ``(int(len(track) * 30 / fps), 14, 2)``.
    """
    if pix_per_cm is None:
        pix_per_cm = ppc  # hidden dependency kept for backward compatibility

    # (frames, columns, coords)
    x = torch.tensor(np.stack([np.stack([np.array(cell) for cell in row]) for row in track.values]))
    # Single-mouse input: duplicate the columns so the model always sees 14.
    if x.size(1) == 7:
        x = torch.cat([x, x], dim=1)

    # Resizes each (frames, coords) slice to (frames at 30 fps, 2).
    resize_x = torchvision.transforms.Resize((int(len(x) * (30. / fps)), 2))

    x = torch.cat(
        [resize_x(x[:, i].unsqueeze(0))[0].unsqueeze(1) / pix_per_cm for i in range(14)],
        dim=1,
    )
    return x.float()


def get_behaviors_labeled(x, labels=None):
    """Parse a ``behaviors_labeled`` cell into per-pair multi-hot label masks.

    Args:
        x: The raw cell value: a string holding a Python list literal of
            ``"agent,target,behavior"`` entries, or a non-string (e.g. NaN)
            when the video has no annotations.
        labels: Ordered label vocabulary; defaults to the global ``LABELS``.
            Its last slot is always set to 1 in every mask.

    Returns:
        dict mapping ``(agent, target)`` to a float tensor of length
        ``len(labels)`` with 1 at each annotated behavior and at the last
        slot, in first-seen pair order; ``{}`` for non-string input.

    Raises:
        ValueError: if a behavior is not in ``labels``, an entry does not split
            into exactly three comma-separated fields, or the string is not a
            Python literal.
    """
    if not isinstance(x, str):
        return {}
    if labels is None:
        labels = LABELS

    import ast

    # Group behaviors by (agent, target); literal_eval only accepts Python
    # literals, unlike eval() on file contents.
    pairs = {}
    for entry in ast.literal_eval(x):
        agent, target, behavior = entry.split(',')
        pairs.setdefault((agent, target), []).append(behavior.replace('"', '').replace("'", ''))

    masks = {}
    for pair, behaviors in pairs.items():
        vec = torch.zeros(len(labels))
        for behavior in behaviors:
            vec[labels.index(behavior)] = 1
        vec[-1] = 1  # last slot always on, so a mask is never all-zero
        masks[pair] = vec
    return masks

In [7]:
import warnings
# NOTE(review): this silences every warning category for the rest of the
# session; consider narrowing it to the specific noisy warnings.
warnings.filterwarnings('ignore')


def _decode_segments(yp, threshold, none_idx, min_size):
    """Turn per-frame class scores into ``(label_idx, start, stop)`` segments.

    Each frame gets its arg-max class if that score exceeds ``threshold``,
    otherwise ``none_idx``. When the label changes, the finished run of an
    action label is emitted if it is longer than ``min_size`` frames or nothing
    has been recorded yet. A run no longer than ``min_size`` that follows a
    recorded run is dropped, and the most recently recorded run resumes.
    Because that resume can revive an already-emitted run, the same
    ``(label_idx, start)`` may be returned again with a later stop frame.
    The final run is emitted if it is an action and
    ``last_frame - start > min_size``.

    Args:
        yp: 2-D score tensor, one row per frame (``yp.max(dim=1)`` is taken).
        threshold: minimum best score for a frame to get an action label.
        none_idx: label index meaning "no action".
        min_size: run-length threshold in frames.

    Returns:
        list of ``(label_idx, start, stop)`` tuples with inclusive stop frames.
    """
    if len(yp) == 0:
        return []

    segments = []
    current = (-1, none_idx)  # (start_frame, label_idx) of the run being extended
    recorded = []             # emitted runs plus 'none' runs that were kept
    for i, (score, label) in enumerate(zip(*yp.max(dim=1))):
        label = int(label) if score > threshold else none_idx
        if current[1] == label:
            continue
        if current[1] != none_idx and (i - current[0] > min_size or not recorded):
            segments.append((current[1], current[0], i - 1))
            recorded.append(current)
            current = (i, label)
        elif i - current[0] <= min_size and recorded:
            # Too short: drop it and resume the previously recorded run.
            current = recorded[-1]
        else:
            recorded.append(current)
            current = (i, label)

    # Flush the last run; any label index, including 0, can be emitted here.
    if current[1] != none_idx and i - current[0] > min_size:
        segments.append((current[1], current[0], i))
    return segments


# Window / decoding settings.
n_classes = 37        # model output width; index n_classes is LABELS' trailing 'none'
n_feat = 14           # feature columns produced by load_track
seq_len = 1024        # frames per model window (at 30 fps)
min_size = 8          # run-length threshold used by _decode_segments
score_threshold = .2  # frames whose best score is <= this are labelled 'none'
# Shifted windowings whose predictions are averaged, so window borders differ per pass.
offsets = [0, seq_len // 4, seq_len // 2, 3 * seq_len // 4]

assert len(LABELS) == n_classes + 1, "LABELS must be the model classes plus a trailing 'none'"

rows = []  # collected once, then turned into a DataFrame (no row-by-row growth)
for idx in tqdm(df.index):
    video_id = df.loc[idx, 'video_id']
    filepath = f"{df.loc[idx, 'lab_id']}/{video_id}"
    track = pd.read_parquet(PATH+f'test_tracking/{filepath}.parquet')

    fps = df.loc[idx, 'frames_per_second']
    # `ppc` is read as a global by load_track, so it must be assigned here.
    ppc = df.loc[idx, 'pix_per_cm_approx']
    min_coord = np.array([track.x.min(), track.y.min()])
    arena_w, arena_h = np.array([track.x.max(), track.y.max()]) - min_coord

    Behaviors = get_behaviors_labeled(df.loc[idx, 'behaviors_labeled'])
    track, _, Mice = preprocess(track, None)

    # Resamples 30 fps predictions back to len(track) rows.
    resize_back = torchvision.transforms.Resize((len(track), n_classes))

    # Shift coordinates by min_coord; missing or NaN points become (0, 0).
    for c in track.columns:
        track[c] = track[c].apply(lambda x: x-min_coord if (isinstance(x, np.ndarray) and pd.isna([x]).sum()==0) else np.array([0., 0.]))

    context = torch.tensor([
        float(df.loc[idx, 'arena_shape'] == 'circular'),
        arena_w / ppc,
        arena_h / ppc,
    ]).float().unsqueeze(0).to(device)

    for m1 in Mice:
        for m2 in Mice:
            M1 = f'mouse{m1}'
            M2 = f'mouse{m2}' if m1 != m2 else 'self'
            if (M1, M2) not in Behaviors:
                continue  # pair not annotated for this video

            cols = list(base_bp + ' - ' + str(m1))
            if m1 != m2:
                cols += list(base_bp + ' - ' + str(m2))
            out = track[cols]

            with torch.no_grad():
                X = load_track(out, fps).to(device)
                length = len(X)
                mask = Behaviors[(M1, M2)][:-1]  # zero out behaviors not annotated for this pair

                yp = torch.zeros((length, n_classes))
                for offset in offsets:
                    # Left-pad by `offset`, right-pad to a multiple of seq_len, cut into windows.
                    X_ = torch.cat([
                        torch.zeros((offset, n_feat, 2)).to(device),
                        X,
                        torch.zeros((seq_len - ((length + offset) % seq_len), n_feat, 2)).to(device),
                    ], dim=0)
                    X_ = X_.reshape((X_.size(0) // seq_len, seq_len, n_feat, 2)).transpose(1, 2)

                    # Mean sigmoid score over the ensemble.
                    yp_ = torch.zeros((X_.size(0), n_classes, seq_len)).to(device)
                    for model in MODELS:
                        yp_ += model(X_, context.repeat(X_.size(0), 1)).sigmoid()
                    yp_ = yp_.transpose(1, 2).flatten(0, 1).cpu() / len(MODELS)
                    yp += (yp_ * mask)[offset:offset + length]

                # Average over the shifted windowings, then resample back.
                yp = resize_back(yp.unsqueeze(0) / len(offsets))[0]

            for label_idx, start, stop in _decode_segments(yp, score_threshold, n_classes, min_size):
                rows.append([len(rows), video_id, M1, M2, LABELS[label_idx], start, stop])

Pred = pd.DataFrame(rows, columns=['row_id', 'video_id', 'agent_id', 'target_id', 'action', 'start_frame', 'stop_frame'])

100%|█████████████████████████████████████████████████████████████| 1/1 [00:09<00:00,  9.50s/it]


In [9]:
# Collapse rows sharing (video, agent, target, action, start_frame), keeping the
# one appended last. Presumably this handles the decoder re-emitting a resumed
# segment with a later stop frame. Then renumber row_id densely from 0.
Pred = (
    Pred
    .drop_duplicates(subset=['video_id', 'agent_id', 'target_id', 'action', 'start_frame'], keep='last')
    .assign(row_id=lambda frame: np.arange(len(frame)))
)

In [10]:
Pred  # display the de-duplicated, renumbered predictions table

Unnamed: 0,row_id,video_id,agent_id,target_id,action,start_frame,stop_frame
0,0,438887472,mouse1,self,rear,8902,8937
1,1,438887472,mouse1,self,rear,14450,14503
2,2,438887472,mouse1,mouse4,submit,1364,1373
5,3,438887472,mouse1,mouse4,submit,1522,1545
6,4,438887472,mouse2,mouse1,avoid,1130,1209
...,...,...,...,...,...,...,...
240,213,438887472,mouse4,self,rear,16280,16299
241,214,438887472,mouse4,self,rear,16351,16393
244,215,438887472,mouse4,self,rear,16595,16677
245,216,438887472,mouse4,self,rear,16687,16710
