In [1]:
import pandas as pd
import egoviz.models.processing as pr
import egoviz.models.evaluation as ev

from collections import Counter
from sklearn.ensemble import RandomForestClassifier


SEED = 42

In [2]:
data = pr.load_pickle(r"C:\Users\adesh\Documents\GitHub\EgoVizML\data\home_date_all_preds.pkl")

In [3]:
# create data df
df = pd.DataFrame(columns=['video', 'frame', 'classes', 'active', 'adl'])

for id, dets in data.items():
    adl = id.split('_', 1)[0]
    video = id.split('_')[1]
    frame = id.split('_')[2]
    classes = dets['remapped_metadata']
    active = dets['active_objects']

    row = {'video': video, 'frame': frame, 'classes': classes, 'adl': adl, 'active': active}

    df.loc[len(df)] = row

df.head()

Unnamed: 0,video,frame,classes,active,adl
0,SCI06-10--12,frame0,"[clothing_accessory, phone_tablet, other, offi...","[False, False, False, True, True, False, False...",communication-management
1,SCI06-10--12,frame147,"[phone_tablet, office_stationary, other, offic...","[False, True, False, True, False, False, False...",communication-management
2,SCI06-10--12,frame196,"[phone_tablet, phone_tablet, other, office_sta...","[False, False, False, True, False, False, Fals...",communication-management
3,SCI06-10--12,frame245,"[phone_tablet, clothing_accessory, other, offi...","[False, False, False, False, True, False, Fals...",communication-management
4,SCI06-10--12,frame294,"[phone_tablet, clothing_accessory, other, offi...","[False, False, False, True, False, False, Fals...",communication-management


In [4]:
def count_occurrences(classes, active):
    class_counts = Counter(classes)
    active_counts = Counter({cls: sum([act and (cls == c) for act, c in zip(active, classes)]) for cls in set(classes)})
    return class_counts, active_counts

# Apply the function to create new columns
df['class_counts'], df['active_counts'] = zip(*df.apply(lambda row: count_occurrences(row['classes'], row['active']), axis=1))

df.head()


Unnamed: 0,video,frame,classes,active,adl,class_counts,active_counts
0,SCI06-10--12,frame0,"[clothing_accessory, phone_tablet, other, offi...","[False, False, False, True, True, False, False...",communication-management,"{'clothing_accessory': 1, 'phone_tablet': 3, '...","{'furniture': 0, 'other': 0, 'footwear': 0, 'c..."
1,SCI06-10--12,frame147,"[phone_tablet, office_stationary, other, offic...","[False, True, False, True, False, False, False...",communication-management,"{'phone_tablet': 2, 'office_stationary': 3, 'o...","{'furnishing': 0, 'furniture': 0, 'other': 0, ..."
2,SCI06-10--12,frame196,"[phone_tablet, phone_tablet, other, office_sta...","[False, False, False, True, False, False, Fals...",communication-management,"{'phone_tablet': 2, 'other': 1, 'office_statio...","{'furnishing': 0, 'furniture': 0, 'other': 0, ..."
3,SCI06-10--12,frame245,"[phone_tablet, clothing_accessory, other, offi...","[False, False, False, False, True, False, Fals...",communication-management,"{'phone_tablet': 2, 'clothing_accessory': 1, '...","{'other': 0, 'footwear': 0, 'clothing_accessor..."
4,SCI06-10--12,frame294,"[phone_tablet, clothing_accessory, other, offi...","[False, False, False, True, False, False, Fals...",communication-management,"{'phone_tablet': 2, 'clothing_accessory': 1, '...","{'other': 0, 'footwear': 0, 'clothing_accessor..."


In [5]:
# Create a new DataFrame from class_counts and active_counts
counts_df = pd.DataFrame(df.apply(lambda row: {'adl': row['adl'], 'video': row['video'], **{f'count_{key}': value for key, value in row['class_counts'].items()}, **{f'active_{key}': value for key, value in row['active_counts'].items()}}, axis=1).tolist())

# Group by video and sum the values for each video
grouped_counts_df = counts_df.groupby('video').agg({**{'adl': 'first'}, **{col: 'sum' for col in counts_df.columns if col not in ['adl', 'video']}})

In [6]:
grouped_counts_df = grouped_counts_df.reset_index(); grouped_counts_df.head()

Unnamed: 0,video,adl,count_clothing_accessory,count_phone_tablet,count_other,count_office_stationary,count_footwear,count_furniture,active_furniture,active_other,...,count_house_fixtures,active_house_fixtures,count_tableware,active_tableware,count_bathroom_fixture,active_bathroom_fixture,count_plant,active_plant,count_hat,active_hat
0,SCI02-1--1,functional-mobility,0.0,6.0,7.0,13.0,3.0,10.0,0.0,0.0,...,20.0,0.0,6.0,0.0,4.0,0.0,0.0,0.0,0.0,0.0
1,SCI02-1--10,meal-preparation-cleanup,2.0,0.0,6.0,2.0,1.0,2.0,0.0,0.0,...,17.0,0.0,18.0,1.0,6.0,0.0,0.0,0.0,0.0,0.0
2,SCI02-1--11,meal-preparation-cleanup,1.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0,...,17.0,0.0,25.0,0.0,13.0,0.0,0.0,0.0,0.0,0.0
3,SCI02-1--12,meal-preparation-cleanup,0.0,0.0,8.0,3.0,0.0,4.0,0.0,0.0,...,15.0,0.0,21.0,0.0,7.0,0.0,1.0,0.0,0.0,0.0
4,SCI02-1--2,meal-preparation-cleanup,10.0,4.0,6.0,2.0,1.0,2.0,0.0,0.0,...,35.0,1.0,8.0,0.0,11.0,0.0,3.0,0.0,0.0,0.0


In [7]:
df_no_active = grouped_counts_df.drop(columns=[col for col in grouped_counts_df.columns if 'active' in col]); df_no_active.head()

Unnamed: 0,video,adl,count_clothing_accessory,count_phone_tablet,count_other,count_office_stationary,count_footwear,count_furniture,count_furnishing,count_drinkware,...,count_musical_instrument,count_sink,count_cabinetry,count_kitchen_appliance,count_tv_computer,count_house_fixtures,count_tableware,count_bathroom_fixture,count_plant,count_hat
0,SCI02-1--1,functional-mobility,0.0,6.0,7.0,13.0,3.0,10.0,8.0,19.0,...,0.0,6.0,15.0,7.0,0.0,20.0,6.0,4.0,0.0,0.0
1,SCI02-1--10,meal-preparation-cleanup,2.0,0.0,6.0,2.0,1.0,2.0,3.0,18.0,...,0.0,15.0,1.0,3.0,0.0,17.0,18.0,6.0,0.0,0.0
2,SCI02-1--11,meal-preparation-cleanup,1.0,0.0,2.0,0.0,0.0,0.0,1.0,7.0,...,0.0,18.0,0.0,0.0,0.0,17.0,25.0,13.0,0.0,0.0
3,SCI02-1--12,meal-preparation-cleanup,0.0,0.0,8.0,3.0,0.0,4.0,2.0,28.0,...,0.0,14.0,10.0,2.0,0.0,15.0,21.0,7.0,1.0,0.0
4,SCI02-1--2,meal-preparation-cleanup,10.0,4.0,6.0,2.0,1.0,2.0,1.0,47.0,...,0.0,13.0,19.0,11.0,0.0,35.0,8.0,11.0,3.0,0.0


## With Active Objects

In [8]:
X = grouped_counts_df.drop(['adl', 'video'], axis=1)
y = grouped_counts_df['adl']
groups = grouped_counts_df['video'].str[:5]

active = ev.leave_one_group_out_cv(grouped_counts_df, X, y, groups, RandomForestClassifier(n_estimators=300, random_state=SEED)); active

    

Unnamed: 0,group_left_out,accuracy,precision,recall,f1,mean_accuracy,mean_precision,mean_recall,mean_f1
0,SCI02,0.895833,0.866414,0.414035,0.434039,0.673679,0.614821,0.604869,0.481714
1,SCI03,0.765625,0.70709,0.56192,0.326276,0.673679,0.614821,0.604869,0.481714
2,SCI06,0.621359,0.316465,0.388312,0.906655,0.673679,0.614821,0.604869,0.481714
3,SCI08,0.571429,0.666667,0.666667,0.333333,0.673679,0.614821,0.604869,0.481714
4,SCI10,0.188073,0.406832,0.708256,0.142857,0.673679,0.614821,0.604869,0.481714
5,SCI11,0.752525,0.545405,0.395079,0.540745,0.673679,0.614821,0.604869,0.481714
6,SCI12,0.736364,0.506316,0.584802,0.659696,0.673679,0.614821,0.604869,0.481714
7,SCI13,0.644068,0.528935,0.654593,0.339665,0.673679,0.614821,0.604869,0.481714
8,SCI14,0.652174,0.657043,0.554915,0.430868,0.673679,0.614821,0.604869,0.481714
9,SCI15,0.905882,0.953431,0.639907,0.665693,0.673679,0.614821,0.604869,0.481714


## Without Active Objects

In [9]:
X = df_no_active.drop(['adl', 'video'], axis=1)
y = df_no_active['adl']
groups = df_no_active['video'].str[:5]

no_active = ev.leave_one_group_out_cv(grouped_counts_df, X, y, groups, RandomForestClassifier(n_estimators=300, random_state=SEED)); no_active

Unnamed: 0,group_left_out,accuracy,precision,recall,f1,mean_accuracy,mean_precision,mean_recall,mean_f1
0,SCI02,0.864583,0.856504,0.382456,0.407556,0.646083,0.590697,0.570583,0.479456
1,SCI03,0.78125,0.711966,0.567183,0.332621,0.646083,0.590697,0.570583,0.479456
2,SCI06,0.61165,0.317269,0.256061,0.766707,0.646083,0.590697,0.570583,0.479456
3,SCI08,0.571429,0.666667,0.666667,0.333333,0.646083,0.590697,0.570583,0.479456
4,SCI10,0.165138,0.401361,0.687899,0.124753,0.646083,0.590697,0.570583,0.479456
5,SCI11,0.717172,0.694306,0.358825,0.372689,0.646083,0.590697,0.570583,0.479456
6,SCI12,0.731818,0.439139,0.519836,0.745143,0.646083,0.590697,0.570583,0.479456
7,SCI13,0.661017,0.554674,0.640911,0.332276,0.646083,0.590697,0.570583,0.479456
8,SCI14,0.5,0.468031,0.453692,0.495879,0.646083,0.590697,0.570583,0.479456
9,SCI15,0.894118,0.686264,0.609604,0.888911,0.646083,0.590697,0.570583,0.479456
