In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm

from preprocessing import preprocess, base_bp
from models import *



def process_labels(X):
    """Return the behavior name of every entry in a labeled-behaviors cell.

    Args:
        X: either a list of comma-separated strings or the string repr of
            such a list (as stored in the CSV). Anything else (e.g. NaN)
            is treated as "no labels".

    Returns:
        A list with the text after the last comma of each entry, with single
        quotes stripped; an empty list when ``X`` is not (or does not
        evaluate to) a list.
    """
    # NOTE(review): eval() executes whatever expression the CSV cell holds.
    # Acceptable only for trusted files; consider ast.literal_eval.
    entries = eval(X) if isinstance(X, str) else X
    if not isinstance(entries, list):
        return []
    return [entry.split(',')[-1].replace("'", "") for entry in entries]

# Dataset folder, relative to the directory the notebook is run from.
PATH = '../../Datasets/MABe-mouse-behavior-detection/'
# Label vocabulary: the sorted unique behavior names found in train.csv
# (np.unique sorts), followed by a trailing 'none' entry.
LABELS = np.unique(pd.read_csv(PATH+'train.csv').behaviors_labeled.apply(process_labels).explode().dropna()).tolist()+['none']
# Metadata table for the test videos; the inference loop iterates over its rows.
df = pd.read_csv(PATH+'test.csv')

In [2]:
import torch
import torchvision
import torch.nn as nn

# NOTE(review): `data` is not referenced anywhere else in the visible notebook —
# presumably a leftover from training code; confirm before deleting.
data = pd.DataFrame(columns=['lab_id', 'video_id', 'mouse_1', 'mouse_2', 'chunk', 'ok'])
# Use the GPU when one is available, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
class EncoderBlock(nn.Module):
    """Convolution stage followed by a 2x average-pool along the last axis.

    Args:
        in_channels, out_channels, kernel_size, dilation, activation, padding:
            forwarded to ``SEResConv`` (project helper from ``models``).
        return_skip: when True, ``forward`` also returns the un-pooled
            features so a decoder can use them as a skip connection.
        bias: accepted for signature compatibility.
            NOTE(review): it is never forwarded to ``SEResConv``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, activation=nn.GELU, padding='same', return_skip=True, bias=True):
        super().__init__()
        self.pool = nn.AvgPool1d(2)
        self.conv = SEResConv(
            in_channels,
            out_channels,
            kernel_size,
            dilation=dilation,
            activation=activation,
            padding=padding,
            dropout=0.3,
        )
        self.return_skip = return_skip

    def forward(self, x):
        """Return ``(pooled, features)`` or just ``pooled`` (see ``return_skip``)."""
        features = self.conv(x)
        pooled = self.pool(features)
        return (pooled, features) if self.return_skip else pooled


class DecoderBlock(nn.Module):
    """Upsample, cross-attend to the encoder skip tensor, then deconvolve.

    Args:
        in_channels, out_channels, kernel_size, dilation, activation:
            forwarded to ``SEResDeconv`` (project helper from ``models``).
        padding: explicit padding, or ``'same'`` which is translated to
            ``kernel_size // 2 * dilation``.
        bias: accepted for signature compatibility.
            NOTE(review): it is never forwarded to ``SEResDeconv``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, activation=nn.GELU, padding='same', bias=True):
        super().__init__()
        if padding == 'same':
            padding = kernel_size // 2 * dilation
        self.conv = SEResDeconv(
            in_channels,
            out_channels,
            kernel_size,
            dilation=dilation,
            activation=activation,
            padding=padding,
            dropout=0.3,
        )
        # Project helpers; the argument meanings are not visible in this file
        # (presumably Upsample(dim, scale) and CrossAttention(dim, heads) — TODO confirm).
        self.upsample = Upsample(1,2)
        self.attention = CrossAttention(in_channels, 8)

    def forward(self, x, skip):
        """Fuse the upsampled ``x`` with ``skip`` and run the deconvolution."""
        upsampled = self.upsample(x)
        # The attention module is fed with the last two axes swapped; its
        # output is swapped back before the convolution.
        query = upsampled.transpose(-1, -2)
        memory = skip.transpose(-1, -2)
        attended = self.attention(query, memory).transpose(-1, -2)
        return self.conv(attended)



class UNet(nn.Module):
    """Three-level 1-D encoder/decoder with skip connections and a 37-channel head.

    Args:
        in_channels: channel count of the input sequence.
        hidden_dim: width of the first encoder stage; deeper stages use 2x and 4x.
        padding, activation, n_dim: accepted for signature compatibility.
            NOTE(review): none of them is used by the constructor.
    """

    def __init__(self, in_channels, hidden_dim, padding='same', activation=nn.GELU, n_dim=1):
        super().__init__()
        wide, wider = hidden_dim * 2, hidden_dim * 4

        # Encoder: kernel size 5 with dilation 1, 2, 4.
        self.enc_1 = EncoderBlock(in_channels, hidden_dim, 5, 1)
        self.enc_2 = EncoderBlock(hidden_dim, wide, 5, 2)
        self.enc_3 = EncoderBlock(wide, wider, 5, 4)

        self.bottleneck = SEResConv(wider, wider, 3, dropout=0.3)

        # Decoder mirrors the encoder; each padding equals kernel_size // 2 * dilation.
        self.dec_1 = DecoderBlock(wider, wide, 5, 4, padding=8)
        self.dec_2 = DecoderBlock(wide, hidden_dim, 5, 2, padding=4)
        self.dec_3 = DecoderBlock(hidden_dim, hidden_dim, 5, 1, padding=2)

        # Output head with a hard-coded 37 output channels.
        self.head = nn.Sequential(
            nn.ConvTranspose1d(hidden_dim, 37, 5, padding=2, dilation=1, bias=True)
        )

    def forward(self, x, h=None):
        """Run encoder -> bottleneck -> decoder -> head. ``h`` is unused."""
        x, skip_a = self.enc_1(x)
        x, skip_b = self.enc_2(x)
        x, skip_c = self.enc_3(x)

        x = self.bottleneck(x)

        # Skips are consumed deepest-first.
        x = self.dec_1(x, skip_c)
        x = self.dec_2(x, skip_b)
        x = self.dec_3(x, skip_a)

        return self.head(x)

In [17]:
class Feature_Engineering(nn.Module):
    """Hand-crafted kinematic features for a pair of tracked animals.

    Input ``x`` is unpacked as ``(B, 2*C, N, 2)``: the first ``C`` channels
    are treated as the keypoints of animal 1, the remaining ones as animal 2;
    axis 2 is differentiated (time) and the last axis holds the 2-D coordinate.

    Output is ``(B, C*(2*C+9), N)`` with channels in this order:
    speed 1, speed 2, velocity x acceleration 1, velocity x acceleration 2,
    turn cosine 1, turn cosine 2, cosine between the two velocities,
    all-pairs keypoint distances (C*C), their forward difference (C*C),
    and the cosine between each animal's velocity and the direction to the other.

    Changes versus the previous revision:
      * ``self.config`` is now created before being updated (passing a
        non-empty ``config`` used to raise AttributeError).
      * tensors that were computed but never returned are no longer computed;
        this also drops the implicit requirement of at least 7 keypoints.
    """

    def __init__(self, config=None):
        super().__init__()
        self.config = {}
        if config: self.config.update(config)
        # Note: PairwiseDistance adds its own eps to the difference, so it is
        # not interchangeable with a plain ``.norm`` of the difference.
        self.pdist = torch.nn.PairwiseDistance(p=2)

    @staticmethod
    def _step(p):
        """Frame-to-frame difference along axis 2; the first frame yields zeros."""
        return p.diff(dim=2, prepend=p[:, :, :1])

    @staticmethod
    def _cosine(a, b, eps):
        """Cosine of the angle between ``a`` and ``b`` over the last axis (eps-stabilised)."""
        return torch.einsum('...i,...i->...', a, b) / (a.norm(dim=-1) * b.norm(dim=-1) + eps)

    @staticmethod
    def _cross_z(a, b):
        """Scalar 2-D cross product ``a.x * b.y - a.y * b.x`` over the last axis."""
        return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]

    def _turn_cosine(self, v):
        """Cosine between consecutive velocity vectors, zero-padded at the first frame."""
        cos = self._cosine(v[:, :, 1:], v[:, :, :-1], 1e-4)
        # Padding is shaped from ``v`` so a single-frame input still gets one column.
        return torch.cat([torch.zeros_like(v[:, :, :1, 0]), cos], dim=2)

    def forward(self, x, context):
        """Compute the feature stack described in the class docstring.

        ``context`` is accepted for interface compatibility and is not used.
        """
        _, C, _, _ = x.shape
        C = C//2
        x1 = x[:,:C]
        x2 = x[:,C:]

        rel_x = x1 - x2

        # Velocities and accelerations (finite differences along time).
        dx1, dx2 = self._step(x1), self._step(x2)
        ddx1, ddx2 = self._step(dx1), self._step(dx2)

        speed_1 = dx1.norm(dim=-1)
        speed_2 = dx2.norm(dim=-1)

        cross_prod_1 = self._cross_z(dx1, ddx1)
        cross_prod_2 = self._cross_z(dx2, ddx2)

        adx1 = self._turn_cosine(dx1)
        adx2 = self._turn_cosine(dx2)

        dirs = self._cosine(dx1, dx2, 1e-6)

        # Rolling animal 1's keypoints through every cyclic shift pairs each of
        # its keypoints with each keypoint of animal 2 -> C*C distance channels.
        d = torch.cat([
            self.pdist(x1.roll(i+1, dims=1), x2)
        for i in range(C)], dim=1)

        # NOTE(review): a zero is appended before differencing, so the last
        # frame of ``dd`` equals ``-d`` at the last frame. Left as-is on
        # purpose: existing checkpoints were presumably trained with it.
        dd = torch.cat([d,torch.zeros_like(d[:,:,:1])], dim=2).diff(dim=-1)

        lead_1 = self._cosine(dx1, rel_x, 1e-6)
        lead_2 = self._cosine(dx2, -rel_x, 1e-6)

        x = torch.concat([speed_1, speed_2, cross_prod_1, cross_prod_2, adx1, adx2, dirs, d, dd, lead_1, lead_2], dim=1)
        return x

In [18]:
class ARModel(nn.Module):
    """Feature extractor + dilated conv encoder + deconv decoder (37 output channels).

    Args:
        config: optional dict merged into ``self.config``; no key is read by
            the code in this class.

    The attribute names and the positions inside each ``nn.Sequential`` define
    the state-dict keys, so they must stay as they are for existing
    checkpoints to load. ``context_encoder`` is not used in ``forward`` but is
    kept for the same reason.
    """

    def __init__(self, config=None):
        super().__init__()
        self.config = {}
        if config:
            self.config.update(config)

        # Same value nn.Module.__init__ already assigned; kept from the original.
        self.training = True

        self.feature_eng = Feature_Engineering()
        n_channels = 7
        # Channel count produced by Feature_Engineering for 7 keypoints per animal.
        n_features = n_channels * (n_channels * 2 + 9)

        self.context_encoder = nn.Sequential(
            nn.Linear(71, n_features),
            nn.ReLU(),
        )

        base_h_dim = 128

        # Two 2x poolings here are matched by the two Upsample stages below.
        encoder_layers = [
            SEResConv(n_features, base_h_dim, 5, dilation=1, dropout=0.3),
            nn.AvgPool1d(2),
            SEResConv(base_h_dim, base_h_dim*2, 5, dilation=2, dropout=0.3),
            nn.AvgPool1d(2),
            SEResConv(base_h_dim*2, base_h_dim*4, 5, dilation=4, dropout=0.3),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            SEResDeconv(base_h_dim*4, base_h_dim*2, 5, padding=8, dilation=4, dropout=0.3),
            Upsample(1,2),
            SEResDeconv(base_h_dim*2, base_h_dim, 5, padding=4, dilation=2, dropout=0.3),
            Upsample(1,2),
            nn.ConvTranspose1d(base_h_dim, 37, 5, padding=2, dilation=1, bias=True),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x, context):
        """Return the decoder output for the engineered features of ``x``.

        ``context`` is only handed to ``feature_eng``.
        """
        features = self.feature_eng(x, context)
        return self.decoder(self.encoder(features))

In [48]:
# Configuration dict handed to every model constructor in the loading cell.
# NOTE(review): in the visible code ARModel only stores it in `self.config` and
# no key is read at inference time — presumably these are training-time
# settings carried over; verify before relying on any of them.
CFG = {
    'num_workers':24,
    'seed':2,
    'batch_size':64,
    'verbose':2,
    'max_grad_norm':1,
    'train_only':True,
    'train_backbone':True,
    'scheduler':True,
    "lr":1e-4,
    
}

# NOTE(review): duplicates the `device` assignment made in the imports cell.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [49]:
# Checkpoint file stems; each is loaded from models/<stem>.pth below.
M = [
    '3f6515fb758330f72dbad465dfeae81a4c58120c0b695af47bd0279baea249f1',
    '39aba54b6489b8502e5a8b9e6ef502e7df459bdd3854d521f7f9160a269bf0a8',
    'f8c34eb26c887a79802f436c19825af5ccacd25b80015bbc7801ea5b86e540d7',
    'b688fc4b98d9328388235928caaf29311a4075abc4743346fbd17a1281a4d289',
    'ea423bc8007e78e8e32e1e6ab63986496b3b99b7dae08f23be204073c751d49e',
    '978f98560a3d8e0107f8429ed969ce48034e526f92cfbb5f33289afd7a7e2c8d',
]
# Ensemble members, in eval mode and already moved to `device`.
MODELS = []
# NOTE(review): ARCH has 8 entries but M has 6; zip() stops at the shorter
# list, so the extra entries are ignored — and a 9th checkpoint would be
# silently skipped.
ARCH = [ARModel,ARModel,ARModel,ARModel,ARModel,ARModel,ARModel,ARModel]
for m,Arch in zip(M, ARCH):
    # `m` is rebound from the stem to the checkpoint path.
    m = f'models/{m}.pth'
    model = Arch(CFG)
    # map_location lets CUDA-saved weights load on a CPU-only machine.
    model.load_state_dict(torch.load(m, map_location=torch.device(device)))
    MODELS.append(model.eval().to(device))

In [50]:
def load_track(track, fps, pix_per_cm=None):
    """Turn one pair's keypoint columns into a 30 fps, centimetre-scaled tensor.

    Args:
        track: DataFrame whose cells each hold a 2-element coordinate array;
            rows are frames, columns are keypoints (7 for a single animal,
            14 for a pair).
        fps: frame rate of the recording; the time axis is resized by
            ``30 / fps``.
        pix_per_cm: pixels-per-centimetre divisor. When omitted, the
            notebook-level global ``ppc`` is used, which is what the original
            implementation always did.

    Returns:
        Float tensor of shape ``(int(n_frames * 30 / fps), 14, 2)``.

    Raises:
        NameError: if ``pix_per_cm`` is omitted and no global ``ppc`` exists.
    """
    # Explicit argument wins; otherwise keep the historical reliance on the global.
    scale = ppc if pix_per_cm is None else pix_per_cm

    x = torch.tensor(np.stack([np.stack([np.array(point) for point in row]) for row in track.values]))
    # Single-animal ("self") input: duplicate the 7 keypoints so the model
    # always receives 14 channels.
    if x.size(1)==7: x = torch.cat([x,x], dim=1)

    # Each keypoint track (frames, 2) is resized as a 1-channel image: the
    # height (time) is rescaled, the width (x/y) stays 2.
    resize_x = torchvision.transforms.Resize((int(len(x)*(30./fps)), 2))

    channels = []
    for i in range(14):
        channels.append(resize_x(x[:,i].unsqueeze(0))[0].unsqueeze(1) / scale)
    x = torch.cat(channels, dim=1)

    return x.float()


def get_behaviors_labeled(x, labels=None):
    """Build, per (agent, target) pair, a 0/1 mask over the label vocabulary.

    Args:
        x: string repr of a list of ``"agent,target,behavior"`` entries, as
            stored in the metadata CSV. Any non-string value (e.g. NaN)
            means "nothing labeled".
        labels: label vocabulary to index into; defaults to the notebook-level
            ``LABELS``. Its last entry is always switched on in every mask.

    Returns:
        Dict mapping ``(agent, target)`` to a float tensor of length
        ``len(labels)``; an empty dict when ``x`` is not a string.

    Raises:
        ValueError: if an entry does not split into exactly three fields, or
            names a behavior that is not in ``labels``.
    """
    if not isinstance(x, str):
        return {}
    if labels is None:
        labels = LABELS

    # NOTE(review): eval() executes whatever expression the CSV cell holds.
    # Acceptable only for trusted files; consider ast.literal_eval.
    names_by_pair = {}
    for entry in eval(x):
        m1, m2, b = entry.split(',')
        name = b.replace('"', '').replace("'", '')
        names_by_pair.setdefault((m1, m2), []).append(name)

    masks = {}
    for pair, names in names_by_pair.items():
        vec = torch.zeros((len(labels)))
        for name in names:
            vec[labels.index(name)] = 1
        # The last slot is always enabled. The former
        # `if vec.sum()==0: vec += 1` fallback could never trigger after this
        # assignment and has been removed.
        vec[-1] = 1
        masks[pair] = vec
    return masks

In [63]:
import warnings
warnings.filterwarnings('ignore')

# One row per predicted behavior segment.
Pred = pd.DataFrame(columns=['row_id','video_id','agent_id','target_id','action','start_frame','stop_frame'])


for idx in tqdm(df.index):
    filepath = f"{df.loc[idx, 'lab_id']}/{df.loc[idx, 'video_id']}"
    track = pd.read_parquet(PATH+f'test_tracking/{filepath}.parquet')

    fps = df.loc[idx, 'frames_per_second']
    # `ppc` is also read as a global by load_track().
    ppc = df.loc[idx, 'pix_per_cm_approx']
    # Bounding box of all raw coordinates; used to shift the origin and to
    # derive the arena extent.
    min_coord = np.array([track.x.min(), track.y.min()])
    arena_w,arena_h = np.array([track.x.max(), track.y.max()])-min_coord
    

    # (agent, target) -> mask of the behaviors annotated for that pair.
    Behaviors = get_behaviors_labeled(df.loc[idx, 'behaviors_labeled'])
    track, _, Mice = preprocess(track, None)

    # Brings the 30 fps predictions back to one row per row of `track`.
    resize_back = torchvision.transforms.Resize((len(track), 37))


    # Shift valid keypoints by min_coord; cells that are not an ndarray or
    # contain NaN become (0, 0).
    for c in track.columns:
        track[c] = track[c].apply(lambda x: x-min_coord if (isinstance(x, np.ndarray) and pd.isna([x]).sum()==0) else np.array([0., 0.]))

    # Passed to the models; ARModel as defined in this notebook hands it to
    # its feature extractor, which ignores it.
    context = torch.tensor([
        float(df.loc[idx,'arena_shape']=='circular'),
        arena_w/ppc,
        arena_h/ppc,
    ]).float().unsqueeze(0).to(device)

    n_feat = 14
    seq_len = 1024
    min_size = 8

    for m1 in Mice:
        for m2 in Mice:
            cols = list(base_bp + ' - ' + str(m1))
            if m1!=m2: cols += list(base_bp + ' - ' + str(m2))
            out = track[cols]

            M1 = f'mouse{m1}'
            M2 = f'mouse{m2}' if m1!=m2 else 'self'

            # Only score pairs that have labeled behaviors for this video.
            if not ((M1,M2) in Behaviors): continue

            with torch.no_grad():
                X = load_track(out, fps).to(device)
                length = len(X)

                # The sequence is cut into seq_len windows at four different
                # offsets and the four passes are averaged.
                yp = torch.zeros((X.size(0), 37))
                for offset in [0,seq_len//4,seq_len//2,3*seq_len//4]:
                    # Zero-pad in front by `offset` and behind up to a multiple of seq_len.
                    X_ = torch.cat([torch.zeros((offset,n_feat,2)).to(device), X, torch.zeros((seq_len-((X.size(0)+offset)%seq_len),n_feat,2)).to(device)], dim=0)
                    X_ = X_.reshape((X_.size(0)//seq_len, seq_len, n_feat, 2)).transpose(1,2)
    
                    # Mean of the sigmoid outputs over the ensemble.
                    yp_ = torch.zeros((X_.size(0), 37, seq_len)).to(device)
                    for model in MODELS:
                        yp_ += model(X_,context.repeat(X_.size(0),1)).sigmoid()
                    yp_ = yp_.transpose(1,2).flatten(0,1).cpu()/len(MODELS)
                    # Zero the behaviors that are not annotated for this pair
                    # (the mask's last entry is dropped to match the 37 outputs).
                    yp_ = yp_ * Behaviors[(M1,M2)][:-1]
                    yp += yp_[offset:offset+length]
                    
    
                yp = resize_back(yp[:length].unsqueeze(0)/4.)[0]
                
            # Run-length decode the per-frame arg-max into segments.
            # `last` is (start_frame, label) of the open run; label 37 means
            # "no behavior" (best score <= 0.2). `P` is the history of closed runs.
            last = (-1, 37)
            P = []
            for i,(m,p) in enumerate(zip(*yp.max(dim=1))):
                p = int(p) if m>.2 else 37
                if last[1]!=p:  
                    
                    p_len = len(Pred)
                    if last[1]!=37 and (i-last[0]>min_size or len(P)==0):
                        # The open run is long enough (or it is the first one): emit it.
                        Pred.loc[p_len] = [p_len, df.loc[idx, 'video_id'], M1, M2, LABELS[last[1]], last[0], i-1]
                        P.append(last)
                        last = (i,int(p))
                    elif i-last[0]<=min_size and P:
                        # Too short: discard it and re-open the previous run,
                        # which may later be emitted again with the same start
                        # frame (rows sharing a start frame are de-duplicated
                        # in a later cell).
                        last = P[-1]
                    else:
                        P.append(last)
                        last = (i,int(p))
                
            # Flush the run still open at the end of the sequence.
            # Fix: the condition used to start with `last[1] and ...`, which
            # dropped a trailing run whose label index is 0 although
            # LABELS[0] is a regular behavior. The loop variable `i` is also
            # no longer used after the loop, where it is undefined (or stale
            # from the previous pair) when `yp` has no rows.
            last_frame = yp.size(0)-1
            if last[1]!=37 and last_frame-last[0]>min_size:
                Pred.loc[len(Pred)] = [len(Pred), df.loc[idx, 'video_id'], M1, M2, LABELS[last[1]], last[0], last_frame]

100%|█████████████████████████████████████████████████████████████| 1/1 [00:22<00:00, 22.44s/it]


In [69]:
# NOTE(review): scratch arithmetic; what 12 and 1400 stand for is not recorded
# anywhere in the notebook — annotate it or delete the cell.
12/1400

0.008571428571428572

In [65]:
# Keep only the last row emitted for each (video, agent, target, action,
# start_frame) key, then renumber row_id consecutively. The DataFrame index
# is left as-is.
Pred = (
    Pred
    .drop_duplicates(subset=['video_id', 'agent_id', 'target_id', 'action', 'start_frame'], keep='last')
    .assign(row_id=lambda frame: np.arange(len(frame)))
)

In [66]:
Pred

Unnamed: 0,row_id,video_id,agent_id,target_id,action,start_frame,stop_frame
0,0,438887472,mouse1,self,rear,8902,8937
1,1,438887472,mouse1,self,rear,14450,14503
2,2,438887472,mouse1,mouse4,submit,1364,1373
5,3,438887472,mouse1,mouse4,submit,1522,1545
6,4,438887472,mouse2,mouse1,avoid,1130,1209
...,...,...,...,...,...,...,...
240,213,438887472,mouse4,self,rear,16280,16299
241,214,438887472,mouse4,self,rear,16351,16393
244,215,438887472,mouse4,self,rear,16595,16677
245,216,438887472,mouse4,self,rear,16687,16710


In [39]:
# NOTE(review): this cell carries execution count 39, i.e. it was run before
# the inference cell (63) and the de-duplication cell (65); the table below
# comes from an earlier run and does not match the current `Pred`.
Pred

Unnamed: 0,row_id,video_id,agent_id,target_id,action,start_frame,stop_frame
0,0,438887472,mouse1,self,rear,8910,8930
1,1,438887472,mouse1,self,rear,14447,14503
2,2,438887472,mouse1,mouse4,submit,1365,1373
3,3,438887472,mouse1,mouse4,submit,1537,1545
4,4,438887472,mouse2,mouse1,avoid,1130,1209
...,...,...,...,...,...,...,...
249,225,438887472,mouse4,self,rear,16281,16297
250,226,438887472,mouse4,self,rear,16354,16384
251,227,438887472,mouse4,self,rear,16594,16713
252,228,438887472,mouse4,self,rear,17245,17285
