In [323]:
import os
import nltk
import json
import random
import jsonlines
import itertools
import collections
import numpy as np
import pandas as pd

from glob import glob
from tqdm import tqdm
from nltk.corpus import wordnet as wn
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

from pgmpy.models import BayesianModel
from pgmpy.factors.discrete import TabularCPD
from pgmpy.inference import BeliefPropagation

# Shared WordNet lemmatizer instance; used further down to lemmatize
# predicate arguments as nouns (`lemmatizer.lemmatize(arg, pos=wn.NOUN)`).
lemmatizer = nltk.stem.WordNetLemmatizer()

In [295]:
def _build_synset_table(groups):
    """Invert ``{synset_name: [words]}`` into ``{word: canonical_synset_name}``.

    Every synset name is resolved through WordNet exactly once, so a misspelt
    name still fails immediately (``wn.synset`` raises) and the stored value is
    whatever ``Synset.name()`` returns for it.

    Args:
        groups: dict mapping a WordNet synset name (e.g. ``'crime.n.01'``) to
            the list of words that should be assigned to that synset.

    Returns:
        dict mapping each word to the name of its synset.

    Raises:
        ValueError: if a word is listed under two different synsets. A plain
            dict literal would silently keep only the last assignment.
    """
    table = {}
    for synset_name, words in groups.items():
        canonical_name = wn.synset(synset_name).name()
        for word in words:
            if table.get(word, canonical_name) != canonical_name:
                raise ValueError(
                    f"{word!r} is listed under both {table[word]!r} "
                    f"and {canonical_name!r}"
                )
            table[word] = canonical_name
    return table


# Hand-curated word -> synset assignments, grouped by schema and then by the
# target synset. Each word appears once per schema (the previous flat literal
# repeated several keys, e.g. "cuba", "plague", "lung", "warplane").
#
# NOTE(review): entries shaped like "document.n.01" look like synset names
# rather than words; the lookup visible in this notebook queries the table with
# a lemmatized word, so those entries are presumably matched elsewhere - confirm.
# NOTE(review): "isreal" looks like a misspelling of "israel"; it is kept
# because it may deliberately capture a misspelling in the corpus - confirm.
_SYNSET_GROUPS = {
    "protests": {
        'law_enforcement_agency.n.01': ["police"],
        'politician.n.02': [
            "bush", "putin", "hussein", "hizballah", "hezbollah", "gore",
            "politician",
        ],
        'urban_area.n.01': [
            "city", "broadway", "orleans", "tokyo", "selma", "guadalajara",
            "paris", "zaragoza", "beijing", "washington",
        ],
        'reformer.n.01': ["demonstrator", "protester"],
        'protest.n.02': ["demonstration", "protest", "sit-in", "march", "rally"],
        'person.n.01': ["woman", "man", "daughter", "student"],
        'group.n.01': ["group"],
        'political_unit.n.01': [
            "japan", "china", "israel", "isreal", "egypt", "lithuania",
            "lebanon", "haiti", "fiji", "nigeria", "cuba", "iraq", "uganda",
        ],
    },
    "ce_005_disease_outbreak": {
        'organism.n.01': [
            "organism", "fly", "cat", "rat", "mosquito", "cattle", "bird",
            "mouse", "gorilla", "deer", "horse", "worm", "bat", "sheep", "cow",
            "dog", "livestock", "bison", "duck", "frog", "chimpanzee", "pigeon",
            "swan", "shellfish", "pig", "salmon", "trout", "tick", "fish",
            "ferret", "rodent", "gerbil", "hamster", "raccoon", "lion", "crab",
            "elk", "mammal", "toad",
        ],
        'food.n.02': ["poultry", "beef"],
        'illness.n.01': [
            "mumps", "meningitis", "polio", "tuberculosis", "dysentery",
            "leprosy", "cholera", "herpes", "fibrosis", "osteoporosis",
            "sclerosis", "influenza", "flu", "sickness", "infection", "illness",
            "injury", "leukemia", "carcinoma", "hiv", "asbestosis", "smallpox",
            "rabies", "anemia", "encephalitis", "zoonosis", "rinderpest", "bse",
            "osteoarthritis", "sepsis", "cancer", "trachoma", "malaria",
            "melanoma", "angioma", "plague", "sars", "pneumonia", "parkinson",
            "psoriasis", "arthritis", "mucopolysaccharidosis", "gonorrhea",
        ],
        'body_part.n.01': ["intestine", "liver", "lung", "heart"],
        'infectious_agent.n.01': [
            "bacillus", "spirochete", "microbe", "bacteria", "bacterium",
        ],
    },
    "election": {
        'campaigner.n.01': ["candidate", "nominee", "pac"],
        'legislation.n.01': [
            "amendment", "bill", "treaty", "contract", "document.n.01",
        ],
        'civil_authority.n.01': [
            "kennedy", "administrative_unit.n.01", "presiding_officer.n.01",
            "clergyman.n.01", "chief", "chieftain", "lawmaker", "politician",
            "bush", "gore", "clinton", "nixon", "president", "governor", "whip",
            "senator", "congresswoman", "congressman", "sheriff", "deputy",
            "representative", "king", "emperor", "czar", "judge", "congress",
            "parliament", "senate", "duma", "eisenhower", "house", "jefferson",
            "cabinet", "official", "representative.n.01", "speaker", "general",
            "polity.n.02", "legislator", "hussein", "mandela", "reagan",
            "buchanan", "lawman.n.01",
        ],
        'political_orientation.n.01': [
            "conservative", "liberal", "moderate", "nationalist", "center",
            "middle", "left", "right", "separatist", "libertarian",
            "democrat",  # or party.n.01?
            "republican", "communist", "socialist", "independent",
        ],
        'people.n.01': ["people", "adult.n.01", "person.n.01"],
        'vote.n.01': ["decision", "vote", "voting"],
        'voter.n.01': ["voter", "elector"],
        'result.n.03': ["negative.n.01", "affirmative.n.01", "ending.n.04"],
        'political_unit.n.01': [
            "colombia", "mexico", "georgia", "florida", "carolina", "colorado",
            "washington", "texas", "new_jersey.n.01", "missouri", "mississippi",
            "united_kingdom.n.01", "israel.n.01", "california", "peru",
            "province", "state", "country", "district", "municipality.n.01",
            "america", "ireland.n.01",
        ],
        'demographic.n.01': [
            "white", "black", "women", "men", "worker", "gay", "homosexual",
            "jew", "age", "old", "young", "muslim.n.01", "social_group.n.01",
            "arab.n.01", "south_american.n.01", "american.n.01", "asian.n.01",
            "latino", "inhabitant.n.01", "absentee", "central_american.n.01",
            "catholic.n.01", "christian.n.01", "religious_person.n.01",
            "juvenile.n.01",
        ],
        'number.n.01': [
            "large_integer.n.01", "proportion.n.01", "common_fraction.n.01",
            "digit.n.01", "large_indefinite_quantity.n.01",
        ],
    },
    "arrest": {
        'crime.n.01': [
            "murder", "smuggling", "intimidation", "assault", "scam",
            "kidnapping", "killing",
        ],
        'criminal.n.01': [
            "trafficker", "dealer", "shoplifter", "rioter", "troublemaker",
            "plotter", "abductor", "offender", "hijacker", "rapist", "stoner",
            "predator", "felon", "criminal", "pimp", "abuser", "attacker",
            "perpetrator", "pedophile", "flasher",
        ],
        # TODO: group special witnesses together (doctor, psychologist, etc)
        'law_enforcement_agency.n.01': ["police"],
        'politician.n.02': [
            "bush", "diplomat", "vargas", "putin", "chiluba", "president",
        ],
        'political_unit.n.01': [
            "syria", "iraq", "yugoslavia", "zagreb", "nato", "europe",
            "eritrea", "lebanon", "haiti", "fiji", "nigeria", "cuba", "uganda",
            "pakistan", "iran", "belgium", "spain", "bengal",
        ],
    },
    "plane_crash": {
        'airplane.n.01': [
            "plane", "helicopter", "sailplane", "airplane", "glider",
            "warplane", "jetliner", "aircraft", "turboprop",
        ],
    },
}

# schema name -> {word: synset name}
synsets = {
    schema: _build_synset_table(groups)
    for schema, groups in _SYNSET_GROUPS.items()
}

# Pronoun-like arguments that are skipped when arguments are mapped to synsets.
# NOTE(review): "u" is presumably what "us" becomes after noun lemmatization - confirm.
stopwords = set("we they who which that i it he she them him her u me a".split())

In [324]:
# Schema to process. It selects the schema description JSON, the schema-related
# document list and the causal-graph file read further down.
# NOTE(review): the `synsets` table above is keyed by 'arrest', not
# 'ce_040_arrest' - keep that in mind when looking this name up in `synsets`.
schema_name = 'ce_040_arrest'

In [325]:
# Input locations, relative to the notebook's working directory.
# Directory holding one `<schema_name>.json` schema description per schema.
schema_descr_path = '../data/schemas/descrs'
# Directory holding one `<schema_name>.jsonl` list of schema-related documents per schema.
cgw_schema_rel_docs_path = '../data/cgw/schema_related/pos'
# Directory of `*.jsonl` files with the extracted (predicate, argument) records per document.
cgw_preds_args_path = '../data/cgw/preds_args'

In [326]:
# Publication year -> split name: 1994-2006 train, 2007-2008 dev, 2009-2010 test.
data_split = {
    year: split_name
    for split_name, years in (
        ('train', range(1994, 2007)),
        ('dev', (2007, 2008)),
        ('test', (2009, 2010)),
    )
    for year in years
}

In [None]:
# Assign every schema-related document to a split (train, dev or test).
# The year is read from characters 8-11 of the id, right after the asserted
# 'NYT_ENG_' prefix; a year missing from `data_split` raises KeyError.
split_docs_ids = {}

schema_rel_docs_file = f'{cgw_schema_rel_docs_path}/{schema_name}.jsonl'
with jsonlines.open(schema_rel_docs_file) as reader:
    for doc in reader:
        doc_key = doc['id']
        assert doc_key.startswith('NYT_ENG_')
        year = int(doc_key[8:12])
        split_docs_ids[doc_key] = data_split[year]

In [328]:
# Load schema events: the 'predpatt' entry of the schema description holds the
# event predicates that are kept in a document's representation.
with open(f'{schema_descr_path}/{schema_name}.json') as fin:
    schema_descr = json.load(fin)
    filtered_events = schema_descr['predpatt']

# Experiment switch (previously a bare `if True:`).
#   True  -> a document is represented by its schema events only.
#   False -> each argument is additionally mapped to a synset root (custom
#            table first, WordNet hypernym otherwise); see the TODO below.
EVENTS_ONLY = True

# A document is kept only if it has strictly more than this many schema events.
MIN_EVENTS_EXCLUSIVE = 5

# Custom word -> synset table for the argument experiment. `synsets` mixes
# prefixed keys ('ce_005_disease_outbreak') with bare ones ('arrest'), so fall
# back to the schema name without its 'ce_<number>_' prefix, and finally to an
# empty table (pure WordNet lookup) instead of raising KeyError.
custom_synsets = synsets.get(schema_name)
if custom_synsets is None:
    custom_synsets = synsets.get(schema_name.split('_', 2)[-1], {})

# Load all docs
split_raw_docs_ids = {'train': [], 'dev': [], 'test': []}
split_raw_docs = {'train': [], 'dev': [], 'test': []}

# Sorted so that the document order does not depend on the filesystem.
for pred_arg_path in tqdm(sorted(glob(f'{cgw_preds_args_path}/*.jsonl'))):
    with jsonlines.open(pred_arg_path) as reader:
        for doc in reader:
            assert doc['filename'].endswith('.comm')
            doc_id = doc['filename'][:-5]  # drop the '.comm' suffix
            # All documents are used here; restricting them to the
            # schema-related ids in `split_docs_ids` was disabled by the author.
            # Characters 8-11 of the id are parsed as the year
            # (NOTE(review): the id layout is not asserted in this cell).
            doc_split = data_split[int(doc_id[8:12])]
            # Only training documents are collected; the 'dev' and 'test'
            # lists stay empty.
            if doc_split != 'train':
                continue

            if EVENTS_ONLY:
                # Each entry is (a0, a1, ix); the event is a0 when ix == 0, else a1.
                doc_repr = [
                    event
                    for event in (
                        (a0 if ix == 0 else a1)
                        for a0, a1, ix in doc['preds_args']
                    )
                    if event in filtered_events
                ]
            else:
                doc_repr = []
                for pred_args in doc['preds_args']:
                    pred_idx = pred_args[2]
                    arg_idx = 1 - pred_idx
                    pred, arg = pred_args[pred_idx], pred_args[arg_idx].lower()
                    if pred not in filtered_events:
                        continue
                    arg = lemmatizer.lemmatize(arg, pos=wn.NOUN)
                    if arg.lower() in stopwords:
                        continue
                    # Check if word is in the custom synset list,
                    # use WordNet lookup if not.
                    if arg in custom_synsets:
                        syn_root = custom_synsets[arg]
                    else:
                        noun_synsets = wn.synsets(arg, pos=wn.NOUN)
                        if not noun_synsets:
                            continue
                        syn = noun_synsets[0]
                        hypernyms = syn.hypernyms()
                        # First hypernym of the first noun sense, or the sense
                        # itself when it has no hypernym.
                        syn_root = hypernyms[0].name() if hypernyms else syn.name()
                    # TODO: remove later - only events whose argument maps to
                    # 'illness.n.01' are kept, and the synset is not encoded
                    # (alternative: f'{pred}_{syn_root}_{pred_idx}').
                    if syn_root == 'illness.n.01':
                        doc_repr.append(f'{pred}')

            if len(doc_repr) > MIN_EVENTS_EXCLUSIVE:
                split_raw_docs_ids[doc_split].append(doc_id)
                split_raw_docs[doc_split].append(doc_repr)

100%|██████████| 197/197 [01:33<00:00,  2.10it/s]


In [329]:
# Number of splits (keys of `split_raw_docs`), not the number of documents.
len(split_raw_docs)

3

In [330]:
# Total number of kept documents over the train, dev and test splits.
sum(len(split_raw_docs[split_name]) for split_name in ('train', 'dev', 'test'))

10582

In [331]:
# For debug purposes only: document frequency of every event, i.e. the number
# of training documents in which the event occurs at least once.
all_event_cntr = collections.Counter(
    event
    for train_doc in split_raw_docs['train']
    for event in set(train_doc)
)

In [332]:
# Read the JSON description of the graph's edges: a list of (parent, child) pairs.
# Alternative graph with arguments:
#   f'../data/cgw/causal/with-args/{schema_name}_out_edges.json'
graph_path = f'../data/cgw/causal/{schema_name}_out_edges.json'

with open(graph_path) as fin:
    edges = json.load(fin)

# child -> list of its parents, in edge-file order
parents = collections.defaultdict(list)
for parent, child in edges:
    parents[child].append(parent)

# every event that appears on either end of an edge
vertices = {event for edge in edges for event in edge}

In [333]:
G = BayesianModel(edges)

# Document frequency, on the train split, of the events without any parent in
# the graph. Only the disabled manual-CPD experiment below consumes it.
marginal_vertices = vertices - parents.keys()

event_cntr = collections.Counter()
for doc in split_raw_docs['train']:
    for e in set(doc) & marginal_vertices:
        event_cntr[e] += 1

# Disabled experiment, part 1: marginal CPDs of the parentless events.
# event_prob = {}
# for e, cnt in event_cntr.items():
#     event_prob[e] = cnt / len(split_raw_docs['train'])
#     G.add_cpds(TabularCPD(e, 2, [[1 - event_prob[e]], [event_prob[e]]]))

# Build features. Documents are already lists of event tokens, so the tokenizer
# is the identity and no lowercasing is applied.
vectorizer = TfidfVectorizer(tokenizer=lambda x: x, lowercase=False)
train_tfidf = vectorizer.fit_transform(split_raw_docs['train'])
train_tfidf = train_tfidf.toarray()

# Disabled experiment, part 2: one LogisticRegression per node with at least
# one causal parent, evaluated on every parent configuration to fill its CPD.
# for u, vs in parents.items():
#     X_train = (train_tfidf[:, [vectorizer.vocabulary_[ix] for ix in vs]] > 0).astype(int)
#     y_train = (train_tfidf[:, vectorizer.vocabulary_[u]] > 0).astype(int)
#     clf = LogisticRegression(C=1.0, random_state=0)
#     clf.fit(X_train, y_train)
#     print(all_event_cntr[u], [all_event_cntr[v] for v in vs], clf.coef_, clf.intercept_)
#     # Create a binary prediction mask
#     X_cpd = []
#     pad_length = len(vs)
#     pad_fmt_mask = '{:0%db}' % pad_length
#     for x in range(2**pad_length):
#         x_str = pad_fmt_mask.format(x)
#         x_instance = [int(d) for d in x_str]
#         X_cpd.append(x_instance)
#     X_cpd = np.array(X_cpd, dtype=np.float32)
#     u_cpd = clf.predict_proba(X_cpd).T
#     G.add_cpds(TabularCPD(u, 2, u_cpd, evidence=vs, evidence_card=[2] * len(vs)))

# Column names of the feature matrix. `header` was previously not defined in
# any cell, so this cell failed with NameError on a fresh kernel.
# `vocabulary_` maps each event to its column index; ordering the events by
# that index yields the names in column order.
header = sorted(vectorizer.vocabulary_, key=vectorizer.vocabulary_.get)

# Every graph node needs a data column, otherwise fitting cannot estimate its CPD.
missing_nodes = set(G.nodes()) - set(header)
assert not missing_nodes, f'graph events never seen in training docs: {sorted(missing_nodes)}'

# Estimate all CPDs from binary "event occurs in document" indicators.
train_presence = pd.DataFrame((train_tfidf > 0).astype(int), columns=header)
G.fit(train_presence)

In [334]:
# Show, per event, the number of training documents that contain it.
all_event_cntr

Counter({'testify': 2677,
         'face': 7053,
         'charge': 3886,
         'accuse': 4001,
         'arrest': 1611,
         'indict': 820,
         'convict': 993,
         'plead': 2352,
         'acquit': 406,
         'sentence': 605})

In [335]:
# Show the fitted CPD of `plead` given its parents. The CPD is looked up by
# node name: the previous positional lookup (`G.cpds[7]`) depended on the
# order in which the CPDs happened to be stored.
print(G.get_cpds('plead'))

+----------+--------------------+---------------------+---------------------+---------------------+
| accuse   | accuse(0)          | accuse(0)           | accuse(1)           | accuse(1)           |
+----------+--------------------+---------------------+---------------------+---------------------+
| charge   | charge(0)          | charge(1)           | charge(0)           | charge(1)           |
+----------+--------------------+---------------------+---------------------+---------------------+
| plead(0) | 0.7307979120059657 | 0.8025801407349492  | 0.8077066965955855  | 0.8117469879518072  |
+----------+--------------------+---------------------+---------------------+---------------------+
| plead(1) | 0.2692020879940343 | 0.19741985926505082 | 0.19229330340441453 | 0.18825301204819278 |
+----------+--------------------+---------------------+---------------------+---------------------+


## Run belief propagation

In [336]:
# Inference engine over the fitted model `G`.
belief_propagation = BeliefPropagation(G)

In [337]:
# Distribution over `convict` given that `accuse`, `sentence` and `charge`
# were all observed in state 1.
# NOTE(review): presumably values[i] is the probability of state i - confirm.
observed_events = {'accuse': 1, 'sentence': 1, 'charge': 1}
belief_propagation.query(variables=['convict'], evidence=observed_events).values

Eliminating: acquit: 100%|██████████| 5/5 [00:00<00:00, 740.47it/s]


array([0.62698413, 0.37301587])

In [2]:
from pgmpy.factors.discrete import TabularCPD
from pgmpy.models import BayesianModel
from pgmpy.inference import BeliefPropagation
# Small hand-specified network used to try out BeliefPropagation on its own,
# independently of the fitted corpus model `G`.
# NOTE: this rebinds `belief_propagation`, which an earlier cell bound to the
# engine for `G`; re-run that cell before querying `G` again.
bayesian_model = BayesianModel([
    ('A', 'J'), ('R', 'J'),
    ('J', 'Q'),
    ('J', 'L'), ('G', 'L'),
])

# Priors of the parentless nodes.
cpd_a = TabularCPD(variable='A', variable_card=2, values=[[0.2], [0.8]])
cpd_r = TabularCPD(variable='R', variable_card=2, values=[[0.4], [0.6]])
cpd_g = TabularCPD(variable='G', variable_card=2, values=[[0.6], [0.4]])

# Conditional tables: one column per configuration of the listed evidence.
cpd_j = TabularCPD(
    variable='J', variable_card=2,
    values=[[0.9, 0.6, 0.7, 0.1],
            [0.1, 0.4, 0.3, 0.9]],
    evidence=['R', 'A'], evidence_card=[2, 2],
)
cpd_q = TabularCPD(
    variable='Q', variable_card=2,
    values=[[0.9, 0.2],
            [0.1, 0.8]],
    evidence=['J'], evidence_card=[2],
)
cpd_l = TabularCPD(
    variable='L', variable_card=2,
    values=[[0.9, 0.45, 0.8, 0.1],
            [0.1, 0.55, 0.2, 0.9]],
    evidence=['G', 'J'], evidence_card=[2, 2],
)

bayesian_model.add_cpds(cpd_a, cpd_r, cpd_j, cpd_q, cpd_l, cpd_g)
belief_propagation = BeliefPropagation(bayesian_model)

# Most probable joint assignment of J and Q under the given evidence.
toy_evidence = {'A': 0, 'R': 0, 'G': 0, 'L': 1}
belief_propagation.map_query(variables=['J', 'Q'], evidence=toy_evidence)

0it [00:00, ?it/s]


{'J': 0, 'Q': 0}

---