In [44]:
import pandas as pd
import lightgbm as lgb
import numpy as np
import warnings
import matplotlib.pyplot as plt
from glob import glob
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, average_precision_score, precision_score

warnings.filterwarnings('ignore')

DATA_FILE = '../metadata/relationship_final/relationship_final.csv'

In [57]:
dataset = pd.read_csv(DATA_FILE)

In [46]:
dataset.head()

Unnamed: 0,ImageID,LabelName1,LabelName2,XMin1,XMax1,YMin1,YMax1,XMin2,XMax2,YMin2,...,Area2,Area3,DistanceTopLeft,DistanceTopRight,DistanceBottomLeft,DistanceBottomRight,DistanceCenter,IoU,RelationshipFrequency,RelationshipLabel
0,9553b9608577b74b,Man,Sunglasses,0.023404,0.985106,0.038344,0.981595,0.238298,0.759574,0.349693,...,0.093542,0.907126,0.667443,0.505549,0.378309,0.384451,0.483938,0.470679,0.787829,wears
1,b2b742920d39272f,Man,Sunglasses,0.099278,1.0,0.0,1.0,0.133574,0.916968,0.238267,...,0.275744,0.900722,0.762505,0.418075,0.240723,0.25232,0.418406,0.634265,0.787829,wears
2,434a26d95fcb7c74,Man,Sunglasses,0.227891,0.858844,0.040307,0.996161,0.282313,0.460884,0.193858,...,0.016109,0.603099,0.804147,0.815749,0.16291,0.426556,0.552341,0.402799,0.787829,wears
3,393bfd0076ce5ce9,Man,Sunglasses,0.548736,1.0,0.17148,1.0,0.676895,0.853791,0.279783,...,0.011495,0.373881,0.731531,0.671349,0.167792,0.181952,0.438156,0.472318,0.787829,wears
4,bc383971045dc428,Man,Sunglasses,0.254425,0.90708,0.023599,1.0,0.473451,0.637168,0.126844,...,0.005554,0.637253,0.900208,0.881569,0.24214,0.288984,0.578225,0.368365,0.787829,wears


In [58]:
def swap_contains(row):
    row.LabelName1, row.XMin1, row.XMax1, row.YMin1, row.YMax1, row.Area1, row.LabelName2, row.XMin2, row.XMax2, row.YMin2, row.YMax2, row.Area2 = row.LabelName2, row.XMin2, row.XMax2, row.YMin2, row.YMax2, row.Area2, row.LabelName1, row.XMin1, row.XMax1, row.YMin1, row.YMax1, row.Area1
    return row

In [59]:
dataset[dataset.RelationshipLabel == 'contain'] = dataset[dataset.RelationshipLabel == 'contain'].apply(swap_contains, axis=1)

In [60]:
columns = dataset.columns.tolist()
columns.remove('RelationshipFrequency')

In [61]:
dataset = dataset[columns]

In [62]:
swap_relations = {
    'highfive': 'interacts_with',
    'talk_on_phone': 'interacts_with',
    'cut': 'interacts_with',
    'holding_hands': 'interacts_with',
    'handshake': 'interacts_with',
    'eat': 'interacts_with',
    'read': 'interacts_with',
    'hug': 'interacts_with',
    'kiss': 'interacts_with',
    'dance': 'interacts_with',
    'hits': 'interacts_with',
    'contain': 'inside_of',
    'snowboard': 'on',
    'ski': 'on',
    'surf': 'on',
    'skateboard': 'on',
    'hang': 'on',
    'ride': 'on',
    'throw': 'hits',
    'kick': 'hits',
    'catch': 'holds',
    'drink': 'holds'
}

In [65]:
dataset = dataset[dataset.RelationshipLabel != 'none']
dataset['RelationshipLabel'] = dataset.RelationshipLabel.apply(lambda x: swap_relations.get(x, x))

In [67]:
targets_string = list(set(dataset.RelationshipLabel.unique()))
targets = {target: idx for idx, target in enumerate(targets_string)}

In [12]:
desc = pd.read_csv('../metadata/class-descriptions-boxable.csv')
desc_labels = dict(zip(desc.LabelName, desc.LabelID))
label_values = {idx: value for idx, value in enumerate(desc_labels.keys())}
label_names = {value: key for key, value in label_values.items()}

In [69]:
dataset['LabelName1'] = dataset.LabelName1.map(label_names)
dataset['LabelName2'] = dataset.LabelName2.map(label_names)
dataset['RelationshipLabel'] = dataset.RelationshipLabel.map(targets)

In [15]:
def get_test_train_split(df):
    X = df.iloc[:, 1:-1].values
    y = df.iloc[:, -1].values
    
    return train_test_split(X, y, test_size=0.25)

In [16]:
def get_dataset(X_train, y_train, features, cat):
    d_train = lgb.Dataset(X_train, label=y_train, feature_name=features, 
                      categorical_feature=cat, free_raw_data=False)
    d_train.construct();
    return d_train

In [17]:
def print_classification_report(model, X_test, y_test):
    y_pred = model.predict(X_test)
    predictions, _ = get_predictions(y_pred)
    print(classification_report(y_test, predictions, target_names=targets_string))
    return y_pred

In [18]:
def get_predictions(y_pred):
    predictions = []
    scores = []
    for idx, prediction in enumerate(y_pred):
        index = np.where(prediction == np.max(prediction))[0][0]
        target_class = targets_string[index]
        predictions.append(index)
        scores.append(prediction[index])
    return np.array(predictions), np.array(scores)

In [19]:
def train_model(df, features, cat, params={}, n=100):
    X_train, X_test, y_train, y_test = get_test_train_split(df)
    d_train = get_dataset(X_train, y_train, features, cat)
    
    model = lgb.train(params, d_train, n)
    y_pred = print_classification_report(model, X_test, y_test)
    
    return model, X_test, y_test, y_pred

In [20]:
params = {}
params['learning_rate'] = 0.003
params['boosting_type'] = 'gbdt'
params['objective'] = 'multiclass'
params['metric'] = 'multi_logloss'
params['sub_feature'] = 0.5
params['num_leaves'] = 10
params['min_data'] = 50
params['max_depth'] = 10
params['num_classes'] = len(targets)

In [196]:
over_sampled = list(dataset.RelationshipLabel.value_counts().index)[:2]

In [198]:
dfs = [dataset[~dataset.RelationshipLabel.isin(over_sampled)]]
for s in over_sampled:
    dfs.append(dataset[dataset.RelationshipLabel == s].sample(60000))
df_undersampled = pd.concat(dfs, axis=0)

In [70]:
model, X_test, y_test, y_pred = train_model(df_undersampled, columns[1:-1], columns[1:3], params, n=500)

                precision    recall  f1-score   support

            on       0.88      0.98      0.93     13877
         plays       0.87      0.29      0.44      2793
     inside_of       0.98      0.87      0.92      6624
interacts_with       0.87      0.77      0.82      3088
            at       0.99      1.00      1.00     29573
         holds       0.77      0.83      0.80      9006
         wears       0.98      1.00      0.99     27018

      accuracy                           0.94     91979
     macro avg       0.91      0.82      0.84     91979
  weighted avg       0.94      0.94      0.94     91979



In [102]:
model.save_model('model.txt');