In [1]:
import os
import pickle

import cv2 as cv
import numpy as np
import pandas as pd
import sklearn.metrics.pairwise as pairwise
import vptree
from sklearn import preprocessing

In [2]:
import import_ipynb
import ActionClassificationCosineDatasetGen as Dataset

importing Jupyter notebook from ActionClassificationCosineDatasetGen.ipynb


In [3]:
print('Loading Poses Data....')
# Read the dataset once; X and Y are both sliced from this single frame so they
# share the same rows/index and the file is not re-read for each slice.
DATA_DF = Dataset.read_dataset(norm=True)
ARGS = {'TREE':None,
       # pose feature columns: every column except 'label'
       'X': DATA_DF.loc[:,DATA_DF.columns != 'label'],
       # the 'label' column, kept as a one-column dataframe
       'Y': DATA_DF.loc[:,DATA_DF.columns == 'label'],
       # directory holding the pickled VP-trees; set VPTREE_PATH to override the default
       'TREEPATH': os.environ.get('VPTREE_PATH', '/Users/sandeep/Desktop/dataandmodles/models/vptrees'),
       'cosineDistanceMatching':True,
       'weightedDistanceMatching': False}


Loading Poses Data....


In [4]:
def cosineDistanceMatching(poseVector1,poseVector2):
    '''
    Distance function for the VP-tree, derived from cosine similarity.

    Computes sqrt(2 * cosine_distance(poseVector1, poseVector2)). For L2-normalized
    vectors this equals their Euclidean distance, since ||u - v||^2 = 2 - 2*cos(u, v).
    The value gets smaller as the vectors become more similar and is 0 for vectors
    pointing in the same direction.

    Parameters
    ----------
    poseVector1, poseVector2 : array-like pose vectors; each is flattened into a
        single row before the comparison.

    Returns
    -------
    float : the distance as a plain Python float (not a 1x1 array), so callers such as
        the VP-tree compare and sort ordinary scalars.
    '''
    row1 = np.asarray(poseVector1).reshape(1, -1)
    row2 = np.asarray(poseVector2).reshape(1, -1)
    # cosine_distances returns a (1, 1) matrix for two single-row inputs
    cosineDistance = pairwise.cosine_distances(row1, row2)[0, 0]
    return float(np.sqrt(cosineDistance * 2))

    

In [5]:
def getPoseData():
    '''
    Helper of buildVPTree().

    Returns
    -------
    list : the rows of ARGS['X'] as plain Python lists, one list per pose vector.
    '''
    poseFrame = ARGS['X']
    poseMatrix = poseFrame.values
    return poseMatrix.tolist()

In [6]:
def buildVPTree():
    '''
    Build a vantage-point tree over the pose vectors returned by getPoseData(),
    using cosineDistanceMatching as the distance, and store it in ARGS['TREE'].
    '''
    tree = vptree.VPTree(getPoseData(), cosineDistanceMatching)
    ARGS['TREE'] = tree


In [7]:
def l2Normalize(poses):
    '''
    Helper of findMatch(): scale a single pose vector to unit L2 norm.

    PARAM poses: one pose as a flat list [....]
    Returns the normalized pose (the first and only row of the normalized batch).
    '''
    # sklearn's normalize works on 2D input, so wrap the pose as a one-row batch
    batch = [poses]
    normalizedBatch = preprocessing.normalize(batch, axis=1)
    return normalizedBatch[0]

In [8]:
def findMatch(pose):
    '''
    Helper of getPred(): find the dataset vector closest to `pose` in the VP-tree.

    The pose is L2-normalized first, then the tree in ARGS['TREE'] is queried.

    Returns
    ---------
    The nearest-neighbour result of the tree, used by callers as a single
    (distance, vector) pair, e.g. (0.5 , [..]).
    '''
    normalizedPose = l2Normalize(pose)
    return ARGS['TREE'].get_nearest_neighbor(normalizedPose)

In [9]:
def getSimilarPosesClassesIndex(similarPoses, X=None):
    '''
    Helper function of getPred(): map each matched vector back to its row label in the pose dataframe.

    Parameters
    --------------
    similarPoses : a list of (distance, vector) pairs, one per query pose, e.g. [(0.5 , [..]) , ...]
    X : optional dataframe of pose vectors to search; defaults to ARGS['X'].

    Returns
    -------------
    index : a list with exactly one index label of X per entry of similarPoses, in the
            same order, so the i-th label belongs to the i-th query pose. If a vector
            occurs in several rows of X, the first matching row is used.

    Raises
    ------
    ValueError : if a vector returned by the VP-tree does not occur in X
                 (e.g. the loaded tree was built from different data).
    '''
    if X is None:
        X = ARGS['X']
    values = X.values
    labels = X.index.tolist()
    index = []
    for _, p in similarPoses:
        # a row matches when every column equals the corresponding entry of p
        matches = np.flatnonzero((values == np.asarray(p)).all(axis=1))
        if matches.size == 0:
            raise ValueError('Pose vector returned by the VP-tree was not found in the pose data')
        # keep exactly one label per pose (no set()) so order and count follow the input
        index.append(labels[matches[0]])
    return index

In [10]:
def getSimilarPosesClasses(indexs, Y=None):
    '''
    Helper function of getPred(): look up the class rows for the given index labels.

    Parameters
    ----------
    indexs : a list of index labels of the pose dataframe (as produced by
             getSimilarPosesClassesIndex()).
    Y : optional label dataframe; defaults to ARGS['Y'].

    Returns
    --------
    numpy.ndarray : a 2D array with one row of classes per label, in the given order,
                    e.g. [['walk'], ['shoot'], ...]
    '''
    if Y is None:
        Y = ARGS['Y']
    # The values are index *labels*, so select with .loc; .iloc would treat them as
    # positions and pick the wrong rows whenever the index is not 0..n-1.
    return Y.loc[list(indexs)].values

In [23]:
def getPred(poses):
    '''
    Find a pose class for every human in a frame using the VP-tree.

    Parameters
    ---------
    poses : a 2D list with one pose vector per detected human.
    Example [[0_x, 0_y.....] , [0_x, 0_y.....] , ....... ]
    The pose coordinates must be relative to an image of size (244 x 244).

    Returns
    ---------
    posesClasses : the class rows returned by getSimilarPosesClasses()
    Example [['walk'] , ['shoot'] ,......]
    '''
    nearestMatches = [findMatch(pose) for pose in poses]
    matchIndices = getSimilarPosesClassesIndex(nearestMatches)
    return getSimilarPosesClasses(matchIndices)
    

In [12]:
def fit(posenetPred):
    '''
    Classify every human detected by PoseNet.

    The PoseNet output is converted to pose vectors with Dataset.generate_data_posenet(),
    and those vectors are classified with getPred() using the loaded VP-tree.

    Parameter
    ---------
    posenetPred : the non-empty output of the PoseNet model
    Example {'detectionList': [{keypoints:[{score: , part:.. ,position:..}]} , human2 , ....]

    Return
    ------
    posesClasses : the classes found for each pose, e.g. [['walk'] , ['shoot'] ,......]
    '''
    return getPred(Dataset.generate_data_posenet(posenetPred))

# Saving and loading trees

In [13]:
def saveTree(fileName):
    '''
    Pickle ARGS['TREE'] to <ARGS['TREEPATH']>/<fileName>.pkl using the highest pickle protocol.

    Parameters
    -----------
    fileName : name of the .pkl file, without the extension
    '''
    targetPath = "{}/{}.pkl".format(ARGS['TREEPATH'], fileName)
    with open(targetPath, 'wb') as treeFile:
        pickle.dump(ARGS['TREE'], treeFile, protocol=pickle.HIGHEST_PROTOCOL)

In [14]:
def loadTree():
    '''
    Load the pickled VP-tree selected by the ARGS switches into ARGS['TREE'].

    The file is read from the directory ARGS['TREEPATH']:
      - ARGS['cosineDistanceMatching'] -> vptreeCosineDistance.pkl (checked first)
      - ARGS['weightedDistanceMatching'] -> vptreeWeightedDistance.pkl

    Raises
    ------
    ValueError : if neither switch is enabled, so there is no tree file to choose.
    '''
    if ARGS['cosineDistanceMatching']:
        fileName = 'vptreeCosineDistance'
    elif ARGS['weightedDistanceMatching']:
        fileName = 'vptreeWeightedDistance'
    else:
        raise ValueError("Enable ARGS['cosineDistanceMatching'] or ARGS['weightedDistanceMatching'] before calling loadTree()")

    file_vptree = f"{ARGS['TREEPATH']}/{fileName}.pkl"
    # pickle.load can execute arbitrary code: only load tree files you created yourself.
    # NOTE(review): the pickled tree presumably references its distance function by name,
    # so that function must already be defined when this runs — confirm.
    with open(file_vptree, 'rb') as treeFile:
        ARGS['TREE'] = pickle.load(treeFile)

        
# Initialize the VP-tree: load the pickled tree selected by the ARGS switches into ARGS['TREE'].
print('Initializing VPTREE....')
loadTree()

Initializing VPTREE....


# Test data

In [15]:
# Sample pose vectors used to smoke-test getPred() below (the next cell shows shape (10, 36)).
# findMatch() L2-normalizes each pose before querying the tree.
d = np.array([[0.18694303, 0.11572664, 0.18694303, 0.11572664, 0.19830082,
        0.12155902, 0.20566803, 0.13230287, 0.19983565, 0.13260983,
        0.18264549, 0.12217295, 0.19615205, 0.10774549, 0.19185451,
        0.10958729, 0.20198442, 0.14335369, 0.21242131, 0.16054385,
        0.21088647, 0.17834795, 0.19031967, 0.14396762, 0.18755697,
        0.16115779, 0.18510123, 0.17834795, 0.18663606, 0.11388483,
        0.18540819, 0.11480574, 0.1909336 , 0.11449877, 0.18356639,
        0.11572664],
       [0.22652992, 0.04472884, 0.22652992, 0.04472884, 0.23198075,
        0.04761457, 0.2332633 , 0.0522638 , 0.2332633 , 0.05899718,
        0.22572833, 0.04841616, 0.22652992, 0.05242412, 0.22604897,
        0.05579081, 0.23550775, 0.05547017, 0.23214107, 0.06236386,
        0.23230139, 0.07005915, 0.23133948, 0.05579081, 0.23021725,
        0.06252418, 0.23166011, 0.06941787, 0.22685056, 0.04408756,
        0.2263696 , 0.04408756, 0.22797279, 0.04472884, 0.22556801,
        0.04488916],
       [0.1455586 , 0.14369843, 0.1455586 , 0.14369843, 0.15532451,
        0.15811477, 0.17439129, 0.15160416, 0.17439129, 0.13300243,
        0.13346747, 0.15811477, 0.15113912, 0.15299929, 0.17392625,
        0.13672277, 0.15253425, 0.19950364, 0.15578955, 0.23345181,
        0.16183512, 0.26879511, 0.13300243, 0.1990386 , 0.13067721,
        0.23484694, 0.13067721, 0.26833007, 0.14834886, 0.14183825,
        0.14230329, 0.14137321, 0.15578955, 0.14741877, 0.13625773,
        0.14788382],
       [0.18277657, 0.1276963 , 0.18277657, 0.1276963 , 0.17603205,
        0.13376637, 0.17333424, 0.14478242, 0.17063643, 0.14950359,
        0.19042036, 0.13444082, 0.19559116, 0.14343352, 0.19356781,
        0.15040286, 0.17895467, 0.15332549, 0.1818773 , 0.16861307,
        0.18614883, 0.18210212, 0.18907146, 0.15287585, 0.18907146,
        0.16838825, 0.19424226, 0.1832262 , 0.18097803, 0.12702185,
        0.18300139, 0.12634739, 0.1780554 , 0.12724666, 0.18435029,
        0.12792112],
       [0.20281541, 0.09841091, 0.20281541, 0.09841091, 0.21267583,
        0.10131103, 0.21634933, 0.10614458, 0.21924945, 0.11001141,
        0.20474883, 0.10111769, 0.2064889 , 0.10556455, 0.20300875,
        0.10498453, 0.21480259, 0.11310488, 0.21364254, 0.12315864,
        0.21731603, 0.13340575, 0.20977571, 0.11271819, 0.21035573,
        0.12335198, 0.21074242, 0.13050563, 0.2041688 , 0.09725086,
        0.20204204, 0.0974442 , 0.20706893, 0.09725086, 0.2018487 ,
        0.09763754],
       [0.22443155, 0.05285568, 0.22443155, 0.05285568, 0.22982259,
        0.05707476, 0.23228372, 0.06293459, 0.22935381, 0.06809124,
        0.22454875, 0.05648878, 0.22478314, 0.06094225, 0.2226736 ,
        0.06176262, 0.22982259, 0.0662161 , 0.22736146, 0.07336509,
        0.2297054 , 0.08180325, 0.22665829, 0.0659817 , 0.22794745,
        0.07359948, 0.23204933, 0.08110007, 0.22536912, 0.0522697 ,
        0.22384557, 0.05238689, 0.22806464, 0.05262129, 0.22349398,
        0.05309007],
       [0.22211447, 0.05895235, 0.22211447, 0.05895235, 0.22414731,
        0.06312502, 0.22371934, 0.0684746 , 0.22136553, 0.07232629,
        0.22821299, 0.06259006, 0.23088778, 0.06676273, 0.23227867,
        0.07040044, 0.22478926, 0.07104239, 0.22468227, 0.07821083,
        0.22757104, 0.08388138, 0.22853396, 0.07147036, 0.2255382 ,
        0.07810384, 0.22778502, 0.08388138, 0.2229704 , 0.05820341,
        0.22254244, 0.05809641, 0.22511023, 0.05905934, 0.22511023,
        0.05863137],
       [0.21485147, 0.08675887, 0.21485147, 0.08675887, 0.21149442,
        0.08990611, 0.21086497, 0.09525641, 0.20960607, 0.09546623,
        0.21946742, 0.08906684, 0.22209012, 0.0937877 , 0.22471282,
        0.09546623, 0.21611036, 0.09745948, 0.21946742, 0.09315825,
        0.22009687, 0.09085028, 0.22009687, 0.09714476, 0.22083122,
        0.09357789, 0.22649625, 0.09703985, 0.21191405, 0.08675887,
        0.21537601, 0.08644415, 0.21296313, 0.08686378, 0.21653   ,
        0.08633924],
       [0.2191449 , 0.06245351, 0.2191449 , 0.06245351, 0.22806683,
        0.06914496, 0.22946088, 0.07639402, 0.23308542, 0.08224904,
        0.21496275, 0.06775091, 0.21301108, 0.07527878, 0.21496275,
        0.08085499, 0.22639397, 0.08475833, 0.22750921, 0.09702599,
        0.22695159, 0.10873602, 0.21886609, 0.0839219 , 0.22026014,
        0.09842004, 0.22137539, 0.10706316, 0.22026014, 0.06133827,
        0.21830847, 0.06133827, 0.22276944, 0.0621747 , 0.21802966,
        0.06245351],
       [0.17711887, 0.12412839, 0.17711887, 0.12412839, 0.19889578,
        0.13501684, 0.20542885, 0.15171248, 0.18510374, 0.15824555,
        0.16114914, 0.13792043, 0.15243837, 0.15461606, 0.16695631,
        0.16187503, 0.18002246, 0.16550452, 0.17131169, 0.18147425,
        0.20688065, 0.21341372, 0.16260093, 0.1705858 , 0.17857066,
        0.19018502, 0.1466312 , 0.20833244, 0.17494118, 0.11832121,
        0.17421528, 0.11977301, 0.17929656, 0.11904711, 0.1698599 ,
        0.1204989 ]])

In [16]:
# Sanity check on the sample: one row per pose (the recorded output below is (10, 36)).
d.shape

(10, 36)

# Test

In [None]:
# Smoke test: predict a class for each sample pose and display the predictions.
c = getPred(d)
c

In [19]:
# NOTE(review): leftover post-mortem debugging. The recorded ipdb session below belongs to an
# earlier version of this cell (`print(var_dic_list())`), which is not part of this notebook;
# consider removing this cell before sharing.
%debug

> [0;32m<ipython-input-19-fe42bc865289>[0m(1)[0;36m<module>[0;34m()[0m
[0;32m----> 1 [0;31m[0mprint[0m[0;34m([0m[0mvar_dic_list[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> up
*** Oldest frame
ipdb> up
*** Oldest frame
ipdb> exit
