In [1]:
# Notebook setup: live-reload edited modules, inline plots, and open a Snorkel DB session.
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os

# TO USE A DATABASE OTHER THAN SQLITE, USE THIS LINE
# Note that this is necessary for parallel execution amongst other things...
# NOTE(review): presumably SNORKELDB must be set before the snorkel import below
# for it to take effect — confirm.
# os.environ['SNORKELDB'] = 'postgres:////babble_model_drink'

from snorkel import SnorkelSession
# Shared DB session used by every later cell (candidate extraction, gold labels).
session = SnorkelSession()

In [2]:
import numpy as np

# Drink annotation files live under SNORKELHOME (they were previously read from
# /dfs/scratch0/paroma/coco/annotations/). The trailing '' keeps a trailing
# separator on anns_folder, matching the original string form.
anns_folder = os.path.join(os.environ['SNORKELHOME'], 'experiments', 'babble', 'drink', 'data', '')
train_path = os.path.join(anns_folder, 'drink_train_anns.npy')
val_path = os.path.join(anns_folder, 'drink_val_anns.npy')

# allow_pickle=True: if these .npy files hold Python objects (object arrays),
# numpy >= 1.16.3 refuses to load them without it. Only load trusted files.
train_anns = np.load(train_path, allow_pickle=True).tolist()
val_anns = np.load(val_path, allow_pickle=True).tolist()

In [3]:
from snorkel.models import candidate_subclass

# A Drinker candidate pairs a 'person' argument with a 'cup' argument
# (bounding boxes, judging by the stable ids built in assign_gold_labels).
DRINKER_ARGS = ['person', 'cup']
Drinker = candidate_subclass('Drinker', DRINKER_ARGS)

In [4]:
from snorkel.parser import ImageCorpusExtractor, CocoPreprocessor

corpus_extractor = ImageCorpusExtractor(candidate_class=Drinker)

coco_preprocessor = CocoPreprocessor(train_path, source=0)
%time corpus_extractor.apply(coco_preprocessor, person_id=[1], object_id=[44,46,47])

coco_preprocessor = CocoPreprocessor(val_path, source=1)
%time corpus_extractor.apply(coco_preprocessor, person_id=[1], object_id=[44,46,47], clear=False)

for split in [0, 1]:
    num_candidates = session.query(Drinker).filter(Drinker.split == split).count()
    print("Split {} candidates: {}".format(split, num_candidates))

Clearing existing...
Running UDF...
CPU times: user 3.92 s, sys: 104 ms, total: 4.02 s
Wall time: 4.03 s
Running UDF...
CPU times: user 1 s, sys: 72 ms, total: 1.08 s
Wall time: 1.09 s
Split 0 candidates: 2524
Split 1 candidates: 495


In [5]:
from snorkel.models import StableLabel
from snorkel.db_helpers import reload_annotator_labels

def assign_gold_labels(session, labels_by_candidate, candidate_class=None, annotator_name='gold'):
    """Store gold labels as StableLabels and reload them into annotator labels.

    Args:
        session: Snorkel DB session used to query, add and commit labels.
        labels_by_candidate: dict mapping a candidate hash of the form
            "<set_name>:<image_idx>:<bbox1_idx>:<bbox2_idx>" (set_name is
            'train' or 'val') to a label; truthy -> 1, falsy -> -1.
        candidate_class: candidate subclass passed to reload_annotator_labels;
            defaults to the notebook's ``Drinker`` class.
        annotator_name: annotator name stored on each StableLabel.

    A StableLabel that already exists for the same context ids and annotator is
    left as-is (its value is NOT updated). After committing, labels are reloaded
    for every split that appears in ``labels_by_candidate``.
    """
    if candidate_class is None:
        candidate_class = Drinker
    split_by_set_name = {'train': 0, 'val': 1}
    touched_splits = set()

    for candidate_hash, label in labels_by_candidate.items():
        set_name, image_idx, bbox1_idx, bbox2_idx = candidate_hash.split(':')
        source = split_by_set_name[set_name]
        touched_splits.add(source)
        stable_id_1 = "{}:{}::bbox:{}".format(source, image_idx, bbox1_idx)
        stable_id_2 = "{}:{}::bbox:{}".format(source, image_idx, bbox2_idx)
        context_stable_ids = "~~".join([stable_id_1, stable_id_2])
        query = session.query(StableLabel).filter(StableLabel.context_stable_ids == context_stable_ids)
        query = query.filter(StableLabel.annotator_name == annotator_name)
        if query.count() == 0:
            session.add(StableLabel(
                context_stable_ids=context_stable_ids,
                annotator_name=annotator_name,
                value=1 if label else -1))

    session.commit()
    # Reload each split seen in the input. The original reloaded only the split of
    # the last loop item and raised NameError on an empty dict.
    for split in sorted(touched_splits):
        reload_annotator_labels(session, candidate_class, annotator_name, split=split, filter_label_split=False)

In [6]:
# Gold-label dicts are read from a separate folder (overridable via DRINK_LABELS_DIR).
# A dedicated name avoids overwriting `anns_folder`, which the data-loading cell
# sets to the SNORKELHOME data directory.
labels_folder = os.environ.get('DRINK_LABELS_DIR', '/dfs/scratch0/paroma/coco/annotations/')
# The .npy holds a pickled dict (unwrapped by .tolist()); numpy >= 1.16.3 needs
# allow_pickle=True for that. Only load trusted files.
validation_labels_by_candidate = np.load(
    os.path.join(labels_folder, 'drink_val_labels_by_candidate.npy'), allow_pickle=True).tolist()
assign_gold_labels(session, validation_labels_by_candidate)

AnnotatorLabels created: 464


In [7]:
# Gold-label dicts are read from a separate folder (overridable via DRINK_LABELS_DIR).
# A dedicated name avoids overwriting `anns_folder`, which the data-loading cell
# sets to the SNORKELHOME data directory.
labels_folder = os.environ.get('DRINK_LABELS_DIR', '/dfs/scratch0/paroma/coco/annotations/')
# The .npy holds a pickled dict (unwrapped by .tolist()); numpy >= 1.16.3 needs
# allow_pickle=True for that. Only load trusted files.
train_labels_by_candidate = np.load(
    os.path.join(labels_folder, 'drink_train_labels_by_candidate.npy'), allow_pickle=True).tolist()
assign_gold_labels(session, train_labels_by_candidate)

AnnotatorLabels created: 2440


In [8]:
# list(...) is required on Python 3: np.array(dict.values()) builds a 0-d object
# array, so `== True` would silently match nothing.
# NOTE: these are positions in the dict's iteration order, not row indices of
# L_gold_train, so only the number of matches is comparable with later cells.
train_label_values = np.array(list(train_labels_by_candidate.values()))
np.where(train_label_values == True)

(array([  66,   69,  355,  389,  413,  655,  800,  824,  882,  887,  965,
         966,  967, 1033, 1040, 1059, 1174, 1182, 1336, 1420, 1444, 1453,
        1497, 1609, 1649, 1723, 1731, 1827, 1853, 1864, 1870, 1886, 1921,
        1980, 2153, 2271, 2325, 2352, 2356, 2411]),)

In [9]:
from snorkel.annotations import load_gold_labels

# Gold labels for the validation split (split=1), one row per candidate.
L_gold_dev = load_gold_labels(session, split=1, annotator_name='gold')
L_gold_dev

<495x1 sparse matrix of type '<type 'numpy.int64'>'
	with 464 stored elements in Compressed Sparse Row format>

In [10]:
from snorkel.annotations import load_gold_labels

# Gold labels for the train split (split=0), one row per candidate.
L_gold_train = load_gold_labels(session, split=0, annotator_name='gold')
L_gold_train

<2524x1 sparse matrix of type '<type 'numpy.int64'>'
	with 2440 stored elements in Compressed Sparse Row format>

In [11]:
# Flatten the sparse (n_candidates x 1) train gold-label matrix into a 1-D dense array.
a = L_gold_train.toarray().ravel()

In [12]:
# Candidate row indices whose train gold label is positive (+1).
(a == 1).nonzero()

(array([  24,   28,  196,  334,  438,  476,  503,  544,  631,  642,  645,
         698,  712,  783,  850, 1095, 1134, 1211, 1238, 1260, 1338, 1464,
        1504, 1608, 1634, 1782, 1829, 1838, 1841, 1863, 1884, 1920, 1969,
        2102, 2161, 2169, 2319, 2422, 2423, 2476]),)