In [8]:
from typing import List
import json
import numpy as np
import random
import warnings
from utils.synonyms import SimilarCategories
from utils.visual_genome import filter_relationships
from utils.visual_genome import get_vg_obj_name
from utils.visual_genome import count_relationships
from utils.visual_genome import sample_relationships
from utils.visualization import view_n_image_rels
from utils.visual_genome import extract_obj_categories
from utils.primitives import get_primitive_features
from utils.primitives import BBoxPrim
from utils.visual_genome import get_labels
from sklearn.tree import DecisionTreeClassifier
from snorkel.labeling import labeling_function
from snorkel.labeling import LFApplier
from snorkel.labeling.model.label_model import LabelModel
from utils.visual_genome import get_vg_obj_name
import pickle

%load_ext autoreload
%autoreload 2
warnings.filterwarnings(action="ignore")
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
THRESH = 0.3 # Threshold for rounding the heuristic outputs
MIN_SAMPLES_SPLITS = [2, 4, 8, 16, 32, 64, 128] # Number of samples to look at in our decision trees
PREDICATE_FILE_PATH = "data/VisualGenome/orig_pred_list.txt"

def load_dataset(dataset_path):
    return np.load(dataset_path)

def load_annotations(ann_path):
    return json.load(ann_path)

def filter_corrupted_images(list1, list2):
    return [list1[i] for i in list2]

def split_data(annotations, splits):
    train = [annotations[i] for i in list(np.where(splits == 0)[0])]
    val   = [annotations[i] for i in list(np.where(splits == 1)[0])]
    test  = [annotations[i] for i in list(np.where(splits == 2)[0])]
    return train, val, test

def get_predicates(pred_path):
    with open(pred_path, "r") as f:
        predicates = sorted([x.strip() for x in f.readlines()])
    return predicates

def get_object_list(object_list_path):
    return [x.strip() for x in open(object_list_path, "r")]

def get_object_synonyms(object_list, synonyms):
    object_synonyms = {}
    for o in object_list:
        object_synonyms[o] = set(synonyms.get_similar_objects([o]) + [o])
    return object_synonyms

########

def fit_heuristic(X_train, Y_train, min_samples_split):
    dt = DecisionTreeClassifier(min_samples_split=min_samples_split)
    dt.fit(X_train, Y_train)
    return dt

def prob_to_label(prob_labels, thresh):
    # We default all labels to abstains = -1, by Snorkel convention
    rounded_labels = np.ones(prob_labels.shape[0]) * -1
    for i in range(rounded_labels.shape[0]):
        if np.max(prob_labels[i, :]) >= thresh:
            # Other classes are labeled as the argmax of estimated probabilities
            rounded_labels[i] = np.argmax(prob_labels[i, :])
    return rounded_labels

def examples_to_feat_matrix(examples: List[any], feature: str):
    feature_matrix = []
    for x in examples:
        feature_matrix.append(getattr(x, feature))
    feature_matrix = np.array(feature_matrix)
    return feature_matrix

def get_lfs(img_agn_features, labeled_mask, multiclass_limited_labels):
    lfs = []
    for _, ms in enumerate(MIN_SAMPLES_SPLITS):
        for feat_key, feat in img_agn_features.items():
            heuristic = fit_heuristic(feat[labeled_mask], multiclass_limited_labels, ms)

            @labeling_function(
                name=f"heuristic_ms:{ms}_feat:{feat_key}", 
                resources=dict(feat_key=feat_key, heuristic=heuristic, thresh=THRESH)
            )
            def lf(x, feat_key, heuristic, thresh):
                feat = getattr(x, feat_key)
                feat = np.expand_dims(feat, axis=0)
                probs = heuristic.predict_proba(feat)
                return prob_to_label(probs, thresh=thresh).squeeze()
            lfs.append(lf)
    return lfs

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
splits = load_dataset("data/VisualGenome/split.npy")
valid = load_dataset("data/VisualGenome/valid.npy")
annotations = load_annotations(open("data/VisualGenome/relationships.json"))
valid_annotations = filter_corrupted_images(annotations, list(valid))
train, val, test = split_data(valid_annotations, splits)
predicates = get_predicates(PREDICATE_FILE_PATH)
object_list = get_object_list("data/VisualGenome/object_list.txt")
synonyms = SimilarCategories()
all_objects = set(object_list + synonyms.get_similar_objects(object_list))
object_filter = lambda r: (get_vg_obj_name(r["subject"]) in all_objects and get_vg_obj_name(r["object"]) in all_objects)
filter_relationships(annotations, object_filter, inplace=True)
predicate_filter = lambda r: r["predicate"].lower() in predicates
filtered_train = filter_relationships(train, predicate_filter)
filtered_val = filter_relationships(val, predicate_filter)
filtered_test = filter_relationships(test, predicate_filter)
train_counts = count_relationships(filtered_train)
cardinality = len(predicates)
LIMITED_LABEL_TRAIN = sample_relationships(filtered_train, train_counts, n_per_pred=10)
# view_n_image_rels(LIMITED_LABEL_TRAIN, n=3)
object_synonyms = get_object_synonyms(object_list, synonyms)
obj_categories = extract_obj_categories(annotations, predicates, object_synonyms)
limited_labels = get_labels(LIMITED_LABEL_TRAIN, predicates)
EVAL_VALID = filtered_val # Used to validate performance of our training labels, hyperparameters, etc.
# ORACLE_TRAIN = filtered_train # Includes ALL labels
EVAL_TEST = filtered_test # Used to evaluate the downstream scene graph model performance

100%|██████████| 108077/108077 [00:02<00:00, 45783.42it/s]
100%|██████████| 75651/75651 [00:00<00:00, 111315.47it/s]
100%|██████████| 10807/10807 [00:00<00:00, 104909.29it/s]
100%|██████████| 21615/21615 [00:00<00:00, 107223.18it/s]
100%|██████████| 75651/75651 [00:00<00:00, 331783.66it/s]


In [10]:
def get_index_of_label(relation):
    for pred in predicates:
        if pred == relation:
            return predicates.index(relation)

index_of_wear = get_index_of_label("wearing")

In [11]:
def get_primitives_labels():
    train_examples = get_primitive_features(LIMITED_LABEL_TRAIN, obj_categories, object_synonyms)
    limited_labels = get_labels(LIMITED_LABEL_TRAIN, predicates)
    np.unique(limited_labels)
    unlabeled_mask = np.sum(limited_labels, axis=1) == len(predicates) * -1
    labeled_mask = np.logical_not(unlabeled_mask)
    assert np.sum(labeled_mask) + np.sum(unlabeled_mask) == len(limited_labels)
    multiclass_limited_labels = np.where(limited_labels[labeled_mask, :] == 1.0)[1]
    img_agn_features = {
        "spatial": examples_to_feat_matrix(train_examples, feature="spatial"),
        "categorical": examples_to_feat_matrix(train_examples, feature="categorical")}
    lfs = get_lfs(img_agn_features, labeled_mask, multiclass_limited_labels)
    applier = LFApplier(lfs)
    L_train = applier.apply(train_examples)
    return L_train, applier, lfs

L_train_paper, applier_paper, lfs_paper = get_primitives_labels()
with open('./lf1234678/L_train_paper.pickle', 'wb') as f:
    pickle.dump(L_train_paper, f)

100%|██████████| 75651/75651 [00:01<00:00, 42965.20it/s]
100%|██████████| 75651/75651 [00:00<00:00, 334629.05it/s]
70978it [01:03, 1118.38it/s]


In [12]:
my_lfs = []
@labeling_function(name="my_lf1")
def my_lf1(x):
    if x[0].area > x[1].area:
        return index_of_wear
    else:
        return -1
my_lfs.append(my_lf1)
@labeling_function(name="my_lf2")
def my_lf2(x):
    clothes = ["shirt", "t-shirt", "pullover", "sweatshirt", "hoodie",
               "dress", "skirt", "jean", "short", "pyjama", "pants",
               "suit",  "blouse", "jacket", "glove", "swimsuit", "bikini",
               "shoe", "sandal", "boot", "slipper", "sock", "hat", "cap",
               "glasses", "scarf", "sunglasses", "watch", "belt"]
    if x[3] in clothes:
        return index_of_wear
    else:
        return -1
my_lfs.append(my_lf2)
@labeling_function(name="my_lf3")
def my_lf3(x):
    clothes = ["hat", "cap"]
    middle_point_of_subject = x[0].x0 + int(x[0].height/2)
    if x[3] in clothes and x[1].x0 < middle_point_of_subject:
        return index_of_wear
    else:
        return -1
my_lfs.append(my_lf3)
@labeling_function(name="my_lf4")
def my_lf4(x):
    clothes = ["skirt", "jean", "short", "pants"]
    middle_point_of_subject = x[0].x0 + int(x[0].height/2)
    if x[3] in clothes and x[1].x0 > middle_point_of_subject:
        return index_of_wear
    else:
        return -1
my_lfs.append(my_lf4)
import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

@labeling_function(name="my_lf6")
def my_lf6(x):
    text = f"[SEP] {x[2]} [MASK] {x[3]}"
    tokenized_text = tokenizer.tokenize(text)
    masked_index = tokenized_text.index("[MASK]")
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])
    with torch.no_grad():
        outputs = model(tokens_tensor)
        predictions = outputs[0]
    probs = torch.nn.functional.softmax(predictions[0, masked_index], dim=-1)
    _, top_fifty_indices = torch.topk(probs, 50, sorted=True)
    for pred_idx in top_fifty_indices:
        predicted_token = tokenizer.convert_ids_to_tokens([pred_idx])[0]
        if predicted_token in predicates:
            return predicates.index(predicted_token)
    return -1
my_lfs.append(my_lf6)

@labeling_function(name="my_lf7")
def my_lf7(x):
    text = f"[SEP] {x[2]} and {x[3]} are [MASK]"
    tokenized_text = tokenizer.tokenize(text)
    masked_index = tokenized_text.index("[MASK]")
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])
    with torch.no_grad():
        outputs = model(tokens_tensor)
        predictions = outputs[0]
    probs = torch.nn.functional.softmax(predictions[0, masked_index], dim=-1)
    _, top_fifty_indices = torch.topk(probs, 50, sorted=True)
    for pred_idx in top_fifty_indices:
        predicted_token = tokenizer.convert_ids_to_tokens([pred_idx])[0]
        if predicted_token in predicates:
            return predicates.index(predicted_token)
    return -1
my_lfs.append(my_lf7)

@labeling_function(name="my_lf8")
def my_lf8(x):
    text = f"[SEP] {x[2]} is [MASK] {x[3]}"
    tokenized_text = tokenizer.tokenize(text)
    masked_index = tokenized_text.index("[MASK]")
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])
    with torch.no_grad():
        outputs = model(tokens_tensor)
        predictions = outputs[0]
    probs = torch.nn.functional.softmax(predictions[0, masked_index], dim=-1)
    _, top_fifty_indices = torch.topk(probs, 50, sorted=True)
    for pred_idx in top_fifty_indices:
        predicted_token = tokenizer.convert_ids_to_tokens([pred_idx])[0]
        if predicted_token in predicates:
            return predicates.index(predicted_token)
    return -1
my_lfs.append(my_lf8)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [13]:
def get_my_train_examples(train_examples):
    examples = []
    for a in train_examples:
        for r in a["relationships"]:
            sub_bbox = BBoxPrim.from_vg_obj(r["subject"])
            obj_bbox = BBoxPrim.from_vg_obj(r["object"])
            sub_name = get_vg_obj_name(r["subject"])
            obj_name = get_vg_obj_name(r["object"])
            image_id = a["image_id"]
            x = (sub_bbox, obj_bbox, sub_name, obj_name, image_id)
            examples.append(x)
    return examples

train_examples = get_my_train_examples(LIMITED_LABEL_TRAIN)
with open('./lf1234678/train_examples.pickle', 'wb') as f:
    pickle.dump(train_examples, f)
my_applier = LFApplier(my_lfs)
my_L_train = my_applier.apply(train_examples)
with open('./lf1234678/my_L_train.pickle', 'wb') as f:
    pickle.dump(my_L_train, f)
L_train_main = np.concatenate((L_train_paper, my_L_train), axis=1)
with open('./lf1234678/L_train_main.pickle', 'wb') as f:
    pickle.dump(L_train_main, f)

70978it [1:56:11, 10.18it/s]


In [14]:
val_counts = count_relationships(filtered_val)
print(sorted(val_counts.items(), key=lambda x:x[1]))
class_balance = np.array([val_counts[k] / val_counts["_TOTAL"] for k in sorted(predicates)])
N_EPOCHS = 60
valid_labels = get_labels(EVAL_VALID, predicates)
multiclass_labels_valid = np.where(valid_labels == 1.0)[1]

[('says', 6), ('playing', 7), ('flying in', 11), ('walking in', 16), ('lying on', 48), ('covered in', 51), ('painted on', 59), ('covering', 62), ('hanging from', 64), ('using', 77), ('mounted on', 84), ('parked on', 98), ('eating', 107), ('watching', 114), ('growing on', 122), ('walking on', 204), ('carrying', 205), ('standing on', 341), ('riding', 457), ('sitting on', 565), ('wearing', 6412), ('_TOTAL', 9110)]


100%|██████████| 10807/10807 [00:00<00:00, 675073.99it/s]


In [15]:
label_model_paper = LabelModel(cardinality=cardinality, verbose=True)
label_model_paper.fit(L_train_paper, class_balance=class_balance, seed=SEED, lr=0.01, l2=0.01, log_freq=10, n_epochs=N_EPOCHS)
with open('./lf1234678/label_model_paper.pickle', 'wb') as f:
    pickle.dump(label_model_paper, f)

weak_labels = label_model_paper.predict_proba(L_train_paper)
valid_examples = get_primitive_features(EVAL_VALID, obj_categories, object_synonyms)
L_valid_paper = applier_paper.apply(valid_examples)
with open('./lf1234678/L_valid_paper.pickle', 'wb') as f:
    pickle.dump(L_valid_paper, f)

valid_labels = get_labels(EVAL_VALID, predicates)
multiclass_labels_valid = np.where(valid_labels == 1.0)[1]
label_model_paper.score(L_valid_paper, multiclass_labels_valid)

INFO:root:Computing O...
INFO:root:Estimating \mu...
  0%|          | 0/60 [00:00<?, ?epoch/s]INFO:root:[0 epochs]: TRAIN:[loss=7.136]
INFO:root:[10 epochs]: TRAIN:[loss=2.472]
INFO:root:[20 epochs]: TRAIN:[loss=2.550]
 45%|████▌     | 27/60 [00:00<00:00, 267.31epoch/s]INFO:root:[30 epochs]: TRAIN:[loss=2.319]
INFO:root:[40 epochs]: TRAIN:[loss=2.180]
INFO:root:[50 epochs]: TRAIN:[loss=2.118]
100%|██████████| 60/60 [00:00<00:00, 312.50epoch/s]
INFO:root:Finished Training
100%|██████████| 10807/10807 [00:00<00:00, 40431.94it/s]
9110it [00:08, 1076.22it/s]
100%|██████████| 10807/10807 [00:00<00:00, 720576.16it/s]


{'accuracy': 0.8073545554335895}

In [16]:
label_model_my = LabelModel(cardinality=cardinality, verbose=True)
label_model_my.fit(my_L_train, class_balance=class_balance, seed=SEED, lr=0.01, l2=0.01, log_freq=10, n_epochs=N_EPOCHS)
with open('./lf1234678/label_model_my.pickle', 'wb') as f:
    pickle.dump(label_model_my, f)
    
weak_labels = label_model_my.predict_proba(my_L_train)
my_valid_examples = get_my_train_examples(EVAL_VALID)
with open('./lf1234678/my_valid_examples.pickle', 'wb') as f:
    pickle.dump(my_valid_examples, f)

my_L_valid = my_applier.apply(my_valid_examples)
with open('./lf1234678/my_L_valid.pickle', 'wb') as f:
    pickle.dump(my_L_valid, f)

valid_labels = get_labels(EVAL_VALID, predicates)
multiclass_labels_valid = np.where(valid_labels == 1.0)[1]
label_model_my.score(my_L_valid, multiclass_labels_valid)

INFO:root:Computing O...
INFO:root:Estimating \mu...
  0%|          | 0/60 [00:00<?, ?epoch/s]INFO:root:[0 epochs]: TRAIN:[loss=2.370]
INFO:root:[10 epochs]: TRAIN:[loss=0.665]
INFO:root:[20 epochs]: TRAIN:[loss=0.306]
INFO:root:[30 epochs]: TRAIN:[loss=0.118]
INFO:root:[40 epochs]: TRAIN:[loss=0.059]
INFO:root:[50 epochs]: TRAIN:[loss=0.033]
100%|██████████| 60/60 [00:00<00:00, 1132.02epoch/s]
INFO:root:Finished Training
9110it [14:28, 10.49it/s]
100%|██████████| 10807/10807 [00:00<00:00, 720072.49it/s]


{'accuracy': 0.7327113062568605}

In [17]:
label_model_main = LabelModel(cardinality=cardinality, verbose=True)
label_model_main.fit(L_train_main, class_balance=class_balance, seed=SEED, lr=0.01, l2=0.01, log_freq=10, n_epochs=N_EPOCHS)
with open('./lf1234678/label_model_main.pickle', 'wb') as f:
    pickle.dump(label_model_main, f)
weak_labels_main = label_model_main.predict_proba(L_train_main)

valid_examples_paper = get_primitive_features(EVAL_VALID, obj_categories, object_synonyms)
L_valid_paper = applier_paper.apply(valid_examples_paper)

my_valid_examples = get_my_train_examples(EVAL_VALID)
my_L_valid = my_applier.apply(my_valid_examples)

L_valid_main = np.concatenate((L_valid_paper, my_L_valid), axis=1)
with open('./lf1234678/L_valid_main.pickle', 'wb') as f:
    pickle.dump(L_valid_main, f)
label_model_main.score(L_valid_main, multiclass_labels_valid)

INFO:root:Computing O...
INFO:root:Estimating \mu...
  0%|          | 0/60 [00:00<?, ?epoch/s]INFO:root:[0 epochs]: TRAIN:[loss=16.942]
INFO:root:[10 epochs]: TRAIN:[loss=3.825]
INFO:root:[20 epochs]: TRAIN:[loss=3.217]
 47%|████▋     | 28/60 [00:00<00:00, 276.56epoch/s]INFO:root:[30 epochs]: TRAIN:[loss=2.548]
INFO:root:[40 epochs]: TRAIN:[loss=2.503]
INFO:root:[50 epochs]: TRAIN:[loss=2.266]
100%|██████████| 60/60 [00:00<00:00, 276.15epoch/s]
INFO:root:Finished Training
100%|██████████| 10807/10807 [00:00<00:00, 41724.94it/s]
9110it [00:08, 1086.08it/s]
9110it [14:34, 10.42it/s]


{'accuracy': 0.8159165751920966}

In [18]:
from snorkel.labeling import LFAnalysis
lfs = lfs_paper + my_lfs
LFAnalysis(L_valid_main, lfs).lf_summary(multiclass_labels_valid)

Unnamed: 0,j,Polarity,Coverage,Overlaps,Conflicts,Correct,Incorrect,Emp. Acc.
heuristic_ms:2_feat:spatial,0,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",1.0,1.0,0.989023,2690,6420,0.29528
heuristic_ms:2_feat:categorical,1,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",1.0,1.0,0.989023,3631,5479,0.398573
heuristic_ms:4_feat:spatial,2,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",1.0,1.0,0.989023,1876,7234,0.205928
heuristic_ms:4_feat:categorical,3,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",1.0,1.0,0.989023,3602,5508,0.39539
heuristic_ms:8_feat:spatial,4,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",0.781778,0.781778,0.771021,1540,5582,0.216231
heuristic_ms:8_feat:categorical,5,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",1.0,1.0,0.989023,3267,5843,0.358617
heuristic_ms:16_feat:spatial,6,"[0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 12, 13, 14, 15...",0.835565,0.835565,0.825576,3227,4385,0.423936
heuristic_ms:16_feat:categorical,7,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",0.722942,0.722942,0.712075,3204,3382,0.486486
heuristic_ms:32_feat:spatial,8,"[0, 3, 4, 5, 6, 8, 12, 14, 15, 16]",0.87124,0.87124,0.862569,452,7485,0.056948
heuristic_ms:32_feat:categorical,9,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",0.37629,0.37629,0.365642,2690,738,0.784714
