# Extract Constraints from Dataset Annotation

This notebook shows how constraints can be extracted from dataset annotation
and compiled into formulas.

For both the 50Salads and Breakfast datasets, we will:

1. Collect statistics.
2. Compile constraint strings.
3. Verify constraint strings against annotations.

In [None]:
import sys, os
import re
from glob import glob
import json
import os.path as osp
import pickle as pk
from pathlib import Path

from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
import torch

# Add the parent directory to the module search path; `lib.tl` is imported
# right after this.
sys.path.append('..')
from lib.tl import parser

# Raise the recursion limit to avoid problems when parsing formulas
sys.setrecursionlimit(10000)
# Setting TL_RECORD_TRACE asks DTL to record the evaluation trace.
# Using this we can find the conflicting part between logits and formula.
# NOTE(review): presumably this is what makes the per-clause `.value` read by
# the verification cells available -- confirm against lib.tl.
os.environ['TL_RECORD_TRACE'] = '1'
# Root directory of the datasets; the cells below read the `50salads` and
# `breakfast` sub-directories under it.
data_root = '../dataset'

## General Tools

The four functions below collect backward-dependency (BD), 
forward-cancellation (FC), implication (Ip) and exclusivity (Ex) constraints.

For all the functions:

- The first return value is a list of stringified logic clauses (constraints).
- The second return value is a list of tuples of ordered operands. 
  - For example, a tuple `(x, y)` in the list returned by
    `backward_dependency_clauses` means `BD(x, y)`, i.e. action x is
    backward-dependent on action y.

In [None]:
def backward_dependency_clauses(SUBCLASSES, occur_joint, occur_before, compact=False):
    """Build backward-dependency (BD) clauses from pairwise statistics.

    A clause is emitted for every ordered pair of distinct indices ``(i, j)``
    with ``occur_joint[i, j] > 0`` and ``occur_before[i, j] == 0``.

    Args:
        SUBCLASSES: sequence of action names; names are lower-cased before
            being written into the clauses.
        occur_joint: 2-D array indexed as ``[i, j]``.
        occur_before: 2-D array indexed as ``[i, j]``.
        compact: accepted for signature parity with the sibling functions;
            it has no effect here.

    Returns:
        ``(clauses, ordered_operands)`` where ``clauses[k]`` is the string
        ``((F x & F y) -> (~x W y))`` and ``ordered_operands[k]`` is ``(x, y)``.
    """
    names = [label.lower() for label in SUBCLASSES]
    indices = range(len(names))

    # Select the qualifying pairs first, in row-major (i, then j) order ...
    ordered_operands = [
        (names[row], names[col])
        for row in indices
        for col in indices
        if row != col and occur_joint[row, col] > 0 and occur_before[row, col] == 0
    ]
    # ... then render one clause per pair.
    clauses = [
        f'((F {first} & F {second}) -> (~{first} W {second}))'
        for first, second in ordered_operands
    ]
    return clauses, ordered_operands


def mutual_exclusivity_clauses(SUBCLASSES, occur_joint, compact=False):
    """Build mutual-exclusivity (Ex) clauses from the joint-occurrence matrix.

    Two distinct actions ``i`` and ``j`` are treated as exclusive when
    ``occur_joint[i, j] == 0``.

    Args:
        SUBCLASSES: sequence of action names; names are lower-cased before
            being written into the clauses.
        occur_joint: 2-D array indexed as ``[i, j]``.
        compact: if False, emit one clause ``(F x -> ~F y)`` per exclusive
            pair. If True, emit one clause per action ``x`` of the form
            ``(F x -> (~F y1 & ~F y2 & ...))``.

    Returns:
        ``(clauses, ordered_operands)``. ``ordered_operands`` holds the
        ``(x, y)`` pairs in non-compact mode and is left empty in compact
        mode.
    """
    names = [label.lower() for label in SUBCLASSES]
    clauses, ordered_operands = [], []

    for row, name_row in enumerate(names):
        # Every other action that never co-occurs with `name_row`.
        excluded = [
            names[col]
            for col in range(len(names))
            if col != row and occur_joint[row, col] == 0
        ]

        if compact:
            if excluded:
                negated = " & ".join(f'~F {other}' for other in excluded)
                clauses.append(f'(F {name_row} -> ({negated}))')
            continue

        for other in excluded:
            clauses.append(f'(F {name_row} -> ~F {other})')
            ordered_operands.append((name_row, other))

    return clauses, ordered_operands


def forward_cancellation_clauses(SUBCLASSES, occur_joint, occur_after, compact=False):
    """Build forward-cancellation (FC) clauses from pairwise statistics.

    A clause is emitted for every index pair ``(i, j)`` (including ``i == j``)
    with ``occur_joint[i, j] > 0`` and ``occur_after[i, j] == 0``.

    Args:
        SUBCLASSES: sequence of action names; names are lower-cased before
            being written into the clauses.
        occur_joint: 2-D array indexed as ``[i, j]``.
        occur_after: 2-D array indexed as ``[i, j]``.
        compact: accepted for signature parity with the sibling functions;
            it has no effect here.

    Returns:
        ``(clauses, ordered_operands)``. For two different names ``x`` (row)
        and ``y`` (column) the clause is ``((F x & F y) -> (~x S y))`` and the
        recorded operand tuple is ``(y, x)`` -- note the reversed order. When
        both names are equal the clause is ``(F x -> (~x S x))`` and no
        operand tuple is recorded.
    """
    names = [label.lower() for label in SUBCLASSES]
    clauses, ordered_operands = [], []

    for row, name_row in enumerate(names):
        for col, name_col in enumerate(names):
            if not (occur_joint[row, col] > 0 and occur_after[row, col] == 0):
                continue

            # Same name on both sides: single-operand form, nothing recorded.
            if name_row == name_col:
                clauses.append(f'(F {name_row} -> (~{name_row} S {name_col}))')
                continue

            clauses.append(f'((F {name_row} & F {name_col}) -> (~{name_row} S {name_col}))')
            ordered_operands.append((name_col, name_row))

    return clauses, ordered_operands


def implication_clauses(SUBCLASSES, occur_joint_vid, occur_count_vid, compact=False):
    """Build implication (Ip) clauses from per-video statistics.

    Action ``j`` implies action ``i`` when
    ``occur_joint_vid[i, j] / occur_count_vid[j] >= 1``.

    Args:
        SUBCLASSES: sequence of action names; names are lower-cased before
            being written into the clauses.
        occur_joint_vid: 2-D array-like indexed as ``[i, j]``.
        occur_count_vid: 1-D array-like indexed as ``[j]``.
        compact: if False, emit one clause ``(F y -> F x)`` per implication.
            If True, emit one clause per implied action ``x`` of the form
            ``((F y1 | F y2 | ...) -> F x)``.

    Returns:
        ``(clauses, ordered_operands)``. ``ordered_operands`` holds the
        ``(y, x)`` pairs in non-compact mode and is left empty in compact
        mode.
    """
    clauses = []
    ordered_operands = []

    joint = np.asarray(occur_joint_vid, dtype=float)
    count = np.asarray(occur_count_vid, dtype=float)
    # An action that never occurs has a zero count, so the ratio is 0/0 (NaN).
    # NaN fails the `>= 1` test, which is the intended outcome, so silence the
    # floating-point warning instead of emitting one per row.
    # The ratio is computed once for the whole matrix rather than per row.
    with np.errstate(divide='ignore', invalid='ignore'):
        implied = (joint / count) >= 1

    for i in range(0, len(SUBCLASSES)):
        act_i = SUBCLASSES[i].lower()
        sub_clauses_i = []
        for j in range(0, len(SUBCLASSES)):
            if i == j or not implied[i, j]:
                continue
            act_j = SUBCLASSES[j].lower()
            if not compact:
                clauses.append(f'(F {act_j} -> F {act_i})')
                ordered_operands.append((act_j, act_i))
            else:
                sub_clauses_i.append(f'F {act_j}')
        if compact and len(sub_clauses_i) > 0:
            clauses.append(f'(({" | ".join(sub_clauses_i)}) -> F {act_i})')
    return clauses, ordered_operands


def read_file(path, gt_mode):
    """Read a label file and return its labels as a NumPy array of strings.

    Args:
        path: path of the text file to read.
        gt_mode: if truthy, every line of the file is one label (one entry
            per line). Otherwise the labels are the whitespace-separated
            tokens on the second line of the file.

    Returns:
        ``np.ndarray`` of label strings.

    Raises:
        IndexError: if ``gt_mode`` is falsy and the file has fewer than two
            lines.
    """
    with open(path, 'r') as f:
        lines = f.readlines()
    if gt_mode:
        # Strip only the line terminator. Slicing off the last character
        # would truncate the final label when the file does not end with a
        # newline.
        return np.array([line.rstrip('\n') for line in lines])
    return np.array(lines[1].split())

# Rule from Annotation (Frequentist)


## 50Salads

In [None]:
# Action labels of 50Salads. The position of a label in this list is used as
# its index by the statistics and verification cells below, so the order
# matters.
SUBCLASSES = [
    'cut_tomato',
    'place_tomato_into_bowl',
    'cut_cheese',
    'place_cheese_into_bowl',
    'cut_lettuce',
    'place_lettuce_into_bowl',
    'add_salt',
    'add_vinegar',
    'add_oil',
    'add_pepper',
    'mix_dressing',
    'peel_cucumber',
    'cut_cucumber',
    'place_cucumber_into_bowl',
    'add_dressing',
    'mix_ingredients',
    'serve_salad_onto_plate',
    'action_start',
    'action_end',
]

Collect statistics.

In [None]:
# split can be set to 0, which means all samples will be used;
# or 1~5, which means using the training samples of the corresponding split.
split = 1
# Read the split bundle from `data_root` instead of a machine-specific
# absolute path, so the notebook runs on any checkout.
with open(f'{data_root}/50salads/splits/train.split{split}.bundle', 'r') as f:
    fns = f.read().splitlines()

annotation_files = [f'{data_root}/50salads/groundTruth/' + i for i in fns]

num_classes = len(SUBCLASSES)
# Pairwise statistics over action segments (runs of identical frame labels):
#   occur_before[a, b]    -- # segment pairs in which a's segment precedes b's
#   occur_after[a, b]     -- # segment pairs in which a's segment follows b's
#   occur_joint[a, b]     -- # segment pairs of a and b in the same video
#   occur_joint_vid[a, b] -- # videos containing at least one such pair
# Per-action statistics:
#   occur_count[a]        -- # segments of a
#   occur_count_vid[a]    -- # videos containing a
occur_before = np.zeros((num_classes, num_classes))
occur_after = np.zeros((num_classes, num_classes))
occur_joint = np.zeros((num_classes, num_classes))
occur_joint_vid = np.zeros((num_classes, num_classes))
occur_count = np.zeros(num_classes)
occur_count_vid = np.zeros(num_classes)

for anno in tqdm(annotation_files):
    # Use a context manager so the file handle is closed deterministically.
    with open(anno, 'r') as f:
        anno_items = f.read().splitlines()

    # Collapse consecutive identical frame labels into one segment each.
    actions = [anno_items[0]]
    for item in anno_items[1:]:
        if item != actions[-1]:
            actions.append(item)

    # Resolve every segment to its class index once per video.
    action_ids = [SUBCLASSES.index(act) for act in actions]

    # Per-video "already counted" sets (O(1) membership tests).
    seen_actions = set()
    seen_pairs = set()

    for i, act_i in enumerate(action_ids):
        occur_count[act_i] += 1
        if act_i not in seen_actions:
            occur_count_vid[act_i] += 1
            seen_actions.add(act_i)

        for j, act_j in enumerate(action_ids):
            if j == i:
                continue
            if (act_j, act_i) not in seen_pairs:
                occur_joint_vid[act_j, act_i] += 1
                seen_pairs.add((act_j, act_i))
            occur_joint[act_j, act_i] += 1
            if j < i:
                # segment j comes earlier in the video than segment i
                occur_before[act_j, act_i] += 1
            else:
                # segment j comes later in the video than segment i
                occur_after[act_j, act_i] += 1


Generate rule strings.

In [None]:
# Each helper returns (clauses, ordered_operands); only the clause strings
# are needed here.
# BD: ~a_1 W a_2
bd_str = backward_dependency_clauses(SUBCLASSES, occur_joint, occur_before)[0]
# Ex: F a_1 -> ~F a_2
ex_str = mutual_exclusivity_clauses(SUBCLASSES, occur_joint)[0]
# FC: ~a_1 S a_2
fc_str = forward_cancellation_clauses(SUBCLASSES, occur_joint, occur_after)[0]
# Ip: F a_1 -> F a_2
ip_str = implication_clauses(SUBCLASSES, occur_joint_vid, occur_count_vid)[0]

# The combined rule string: the conjunction of every clause, in the order
# BD, Ex, FC, Ip.
all_clauses = bd_str + ex_str + fc_str + ip_str
rule_str = '(' + ' & '.join(all_clauses) + ')'

Verify constraints.
Note that if you set `split` to a non-zero value, some samples may violate the constraints.
This is because the constraints are then extracted from the training split only, while the verification below runs over every annotation file of the dataset.

In [None]:
# Compile the combined rule string into an evaluator.
evaluator = parser.parse(rule_str)

# Maps an atomic proposition (action name) to its row in the label array.
ap_map = lambda x: SUBCLASSES.index(x)
annotation_files = glob(osp.join(data_root, '50salads', 'groundTruth', '*.txt'))
for fn in tqdm(annotation_files):
    gt_label = read_file(fn, True)

    # One row per action, one column per frame: +1 for the annotated action
    # of a frame, -1 everywhere else.
    gt_array = -torch.ones((len(SUBCLASSES), len(gt_label)))
    for frame, label in enumerate(gt_label):
        gt_array[SUBCLASSES.index(label), frame] = 1
    tl_value = evaluator(gt_array, ap_map=ap_map, rho=1000)

    # A negative value is reported as a conflict with the annotation.
    if not (tl_value < 0.0):
        continue

    print('conflict found: ', fn)
    # List the top-level clauses whose recorded value goes negative.
    for clause_idx, clause in enumerate(evaluator.children):
        if clause.value.min() < 0.0:
            print('clause', clause_idx, clause)

## Breakfast

In [None]:
# Action labels of Breakfast. The position of a label in this list is used as
# its index by the statistics and verification cells below, so the order
# matters.
# NOTE(review): this cell used to be annotated "rule file generation (SIL
# ignored)", but 'SIL' is part of the list and the cells below treat it like
# any other label -- confirm the intent.
SUBCLASSES = [
    'take_cup',
    'pour_coffee',
    'pour_milk',
    'pour_sugar',
    'stir_coffee',
    'spoon_sugar',
    'add_teabag',
    'pour_water',
    'stir_tea',
    'cut_bun',
    'smear_butter',
    'put_toppingOnTop',
    'put_bunTogether',
    'take_plate',
    'take_knife',
    'take_butter',
    'take_topping',
    'cut_orange',
    'squeeze_orange',
    'take_glass',
    'pour_juice',
    'take_squeezer',
    'take_bowl',
    'pour_cereals',
    'stir_cereals',
    'spoon_powder',
    'stir_milk',
    'pour_oil',
    'take_eggs',
    'crack_egg',
    'add_saltnpepper',
    'fry_egg',
    'put_egg2plate',
    'butter_pan',
    'cut_fruit',
    'put_fruit2bowl',
    'peel_fruit',
    'stir_fruit',
    'stirfry_egg',
    'stir_egg',
    'pour_egg2pan',
    'spoon_flour',
    'stir_dough',
    'pour_dough2pan',
    'fry_pancake',
    'put_pancake2plate',
    'pour_flour',
    'SIL',
]

# Activity names of Breakfast.
METACLASSES = [
    'coffee',
    'tea',
    'sandwich',
    'juice',
    'cereals',
    'milk',
    'friedegg',
    'salat',
    'scrambledegg',
    'pancake',
]

Collect statistics.

In [None]:
# split can be set to 0, which means all samples will be used;
# or 1~4, which means using the training samples of the corresponding split.
split = 0
with open(f'{data_root}/breakfast/splits/train.split{split}.bundle', 'r') as f:
    fns = f.read().splitlines()
fns = [f'{data_root}/breakfast/groundTruth/' + i for i in fns]

annotation_files = fns

num_classes = len(SUBCLASSES)
# Pairwise statistics over action segments (runs of identical frame labels):
#   occur_before[a, b]    -- # segment pairs in which a's segment precedes b's
#   occur_after[a, b]     -- # segment pairs in which a's segment follows b's
#   occur_joint[a, b]     -- # segment pairs of a and b in the same video
#   occur_joint_vid[a, b] -- # videos containing at least one such pair
# Per-action statistics:
#   occur_count[a]        -- # segments of a
#   occur_count_vid[a]    -- # videos containing a
occur_before = np.zeros((num_classes, num_classes))
occur_after = np.zeros((num_classes, num_classes))
occur_joint = np.zeros((num_classes, num_classes))
occur_joint_vid = np.zeros((num_classes, num_classes))
occur_count = np.zeros(num_classes)
occur_count_vid = np.zeros(num_classes)

for anno in tqdm(annotation_files):
    # splitlines() removes only the line terminators. Slicing off the last
    # character of every line would truncate the final label when the file
    # does not end with a newline.
    with open(anno, 'r') as f:
        lines = f.read().splitlines()

    # Collapse consecutive identical frame labels into one segment each.
    actions = [lines[0]]
    for l in lines[1:]:
        if l != actions[-1]:
            actions.append(l)

    # Resolve every segment to its class index once per video.
    action_ids = [SUBCLASSES.index(act) for act in actions]

    # Per-video "already counted" sets (O(1) membership tests).
    seen_actions = set()
    seen_pairs = set()

    for i, act_i in enumerate(action_ids):
        occur_count[act_i] += 1
        if act_i not in seen_actions:
            occur_count_vid[act_i] += 1
            seen_actions.add(act_i)

        for j, act_j in enumerate(action_ids):
            if j == i:
                continue
            if (act_j, act_i) not in seen_pairs:
                occur_joint_vid[act_j, act_i] += 1
                seen_pairs.add((act_j, act_i))
            occur_joint[act_j, act_i] += 1
            if j < i:
                # segment j comes earlier in the video than segment i
                occur_before[act_j, act_i] += 1
            else:
                # segment j comes later in the video than segment i
                occur_after[act_j, act_i] += 1
    

Generate rule strings.

In [None]:
# Each helper returns (clauses, ordered_operands); only the clause strings
# are needed here.
# BD: ~a_1 W a_2
bd_str = backward_dependency_clauses(SUBCLASSES, occur_joint, occur_before)[0]
# Ex: F a_1 -> ~F a_2
ex_str = mutual_exclusivity_clauses(SUBCLASSES, occur_joint)[0]
# FC: ~a_1 S a_2
fc_str = forward_cancellation_clauses(SUBCLASSES, occur_joint, occur_after)[0]
# Ip: F a_1 -> F a_2
ip_str = implication_clauses(SUBCLASSES, occur_joint_vid, occur_count_vid)[0]

# The combined rule string: the conjunction of every clause, in the order
# BD, Ex, FC, Ip.
all_clauses = bd_str + ex_str + fc_str + ip_str
rule_str = '(' + ' & '.join(all_clauses) + ')'

Verify constraints.
This can be much slower than 50Salads given the number of constraints.

Note that if you set `split` to a non-zero value, some samples may violate the constraints.
This is because the constraints are then extracted from the training split only, while the verification below runs over every annotation file of the dataset.

In [None]:
annotation_files = glob(osp.join(data_root, 'breakfast', 'groundTruth', '*.txt'))
# The clause builders lower-case every action name, so lookups go through a
# lower-cased copy of the label list.
SUBCLASSES_ = [name.lower() for name in SUBCLASSES]
# Maps an atomic proposition (action name) to its row in the label array.
ap_map = lambda x: SUBCLASSES_.index(x)

# Compile the combined rule string into an evaluator.
evaluator = parser.parse(rule_str)

for fn in tqdm(annotation_files):
    gt_label = read_file(fn, True)

    # One row per action, one column per frame: +1 for the annotated action
    # of a frame, -1 everywhere else.
    gt_array = -torch.ones((len(SUBCLASSES), len(gt_label)))
    for frame, label in enumerate(gt_label):
        gt_array[SUBCLASSES_.index(label.lower()), frame] = 1
    tl_value = evaluator(gt_array, ap_map=ap_map, rho=1000)

    # A negative value is reported as a conflict with the annotation.
    if not (tl_value < 0.0):
        continue

    print('conflict found: ', fn)
    # List the top-level clauses whose recorded value goes negative.
    for clause_idx, clause in enumerate(evaluator.children):
        if clause.value.min() < 0.0:
            print('clause', clause_idx, clause)