In [1]:
import nltk
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, modified_precision
from nltk.translate.chrf_score import sentence_chrf, corpus_chrf
from nltk.metrics import scores
import scipy.io.wavfile
from IPython.display import Audio
from IPython.display import display
from nltk.stem import *
# from nltk.stem.snowball import SnowballStemmer
from stemming.porter2 import stem
import stemming
from nltk.metrics.scores import recall

from basics import *

from nltk.corpus import stopwords

%matplotlib inline

In [2]:
smooth_fun = nltk.translate.bleu_score.SmoothingFunction()

In [3]:
from nmt_run import *

In [4]:
cfg_path = "interspeech_bpe/sp_20hrs_best_bn-nobias_batch-32_buck-n25-w80"
dec_key = "bpe_w"

In [5]:
%%capture
last_epoch, model, optimizer, m_cfg, t_cfg = check_model(cfg_path)

In [6]:
%%capture
map_dict, vocab_dict, bucket_dict = get_data_dicts(m_cfg)

In [7]:
random.seed("meh")
# random.seed("haha")

In [8]:
def clean_out_str(out_str):
    """Remove quoting artifacts and BPE joiners from a decoded string.

    The following substrings are deleted, in this order: backtick, double
    quote, inverted question mark, two consecutive single quotes, the BPE
    continuation marker followed by a space ("@@ "), and any remaining bare
    "@@". Leading and trailing whitespace is trimmed afterwards.

    Args:
        out_str: the string to clean.

    Returns:
        The cleaned string.
    """
    # Order matters: "@@ " has to be removed before "@@" so that sub-word
    # pieces are glued together instead of being left with a space between.
    for junk in ("`", '"', '¿', "''", "@@ ", "@@"):
        out_str = out_str.replace(junk, "")
    return out_str.strip()


# In[46]:


def get_out_str(word_list, dec_key="bpe_w"):
    """Decode a sequence of byte tokens and join them into one cleaned string.

    Args:
        word_list: iterable of bytes objects; each is decoded with
            ``bytes.decode()``.
        dec_key: selects how tokens are joined.
            * ``"en_w"``: tokens starting with an apostrophe, and the token
              ``"n't"``, are appended directly; every other token is
              preceded by a single space.
            * any key containing ``"bpe_w"``: tokens are joined with spaces.
            * anything else (including ``"en_c"``): tokens are concatenated
              with no separator.

    Returns:
        The joined string after being passed through ``clean_out_str``.
    """
    tokens = [tok.decode() for tok in word_list]

    if dec_key == "en_w":
        # Clitics attach to the preceding word; all other words get a
        # leading space.
        joined = "".join(
            tok if (tok.startswith("'") or tok == "n't") else " " + tok
            for tok in tokens
        )
    elif "bpe_w" in dec_key:
        joined = " ".join(tokens)
    else:
        # "en_c" and every unrecognised key share the same behaviour.
        joined = "".join(tokens)

    return clean_out_str(joined)

In [9]:
def get_text(set_key, ref_num=0):
    """Build an ``{utterance id: text}`` dict for one data set in ``map_dict``.

    For set keys containing ``"train"`` the text is rebuilt from the
    utterance's ``"bpe_w"`` tokens. For every other set it is rebuilt from
    the ``"en_w"`` reference selected by ``ref_num``.

    Args:
        set_key: key into the module-level ``map_dict``.
        ref_num: index of the ``"en_w"`` reference to use for non-train sets.

    Returns:
        Dict mapping each utterance id to its decoded string. The number of
        entries is printed as a side effect.
    """
    utterances = map_dict[set_key]
    set_text = {}
    for utt_id in utterances:
        entry = utterances[utt_id]
        if "train" not in set_key:
            # fisher_dev has 4 references, but only 1 for BPE
            # modify logic if using en_w
            set_text[utt_id] = get_out_str(entry["en_w"][ref_num], "en_w")
        else:
            set_text[utt_id] = get_out_str(entry["bpe_w"])
    print(len(set_text))
    return set_text

In [10]:
train_text = get_text("fisher_train")
list(train_text.values())[:10]

138819


['hello',
 'hello',
 'hello',
 'hello',
 'with whom am i speaking',
 'eh silvia yes what is your name',
 'hello silvia eh my name is nicole',
 'ah nice to meet you',
 'nice to meet you em and where are you from',
 "eh i'm in philadelphia"]

In [11]:
dev_text = []
for i in range(4):
    dev_text.append(get_text("fisher_dev", ref_num=i))
list(dev_text[0].values())[:10]

3979
3979
3979
3979


['afternoon',
 'good afternoon',
 'my name is carmen in chicago you',
 'oh my name is ricardo',
 'of',
 "ah i'm in ah pennsylvania",
 'okay good afternoon',
 'well how are you good',
 'good thank god and you',
 "very good thank god as well it's very cold here is it very cold in chicago"]

In [12]:
dev2_text = []
for i in range(4):
    dev2_text.append(get_text("fisher_dev2", ref_num=i))
list(dev2_text[0].values())[:10]

3961
3961
3961
3961


['hi good af good evening',
 "good evening it's norma here from atlanta",
 "oh well would you look at that i'm from",
 "oh yeah we're we're near",
 "yes we're pretty close is this the first call you took",
 "yes because yesterday what happened is that since i was out they called me right the the computer called me but i wasn't here so couldn't",
 "i couldn't answer",
 "oh and i am on my cellphone they called me three times but i was at the store so i couldn't you know take ten minutes to talk",
 'sure',
 'but where are you from']

In [13]:
en_stop_words = set(nltk.corpus.stopwords.words("english"))

es_stop_words = set(nltk.corpus.stopwords.words("spanish"))

es_en_stop_words = en_stop_words | es_stop_words
len(en_stop_words), len(es_stop_words), len(es_en_stop_words)

(127, 313, 435)

In [14]:
def search_text(set_text, query_list):
    """Find, for every stemmed query term, the utterances that contain it.

    Both the query terms and the whitespace-separated words of each
    utterance are stemmed with ``stem`` before comparison.

    Note: utterances are split on whitespace, so a query whose stem contains
    a space (e.g. a multi-word phrase) can never match and always maps to an
    empty list.

    Args:
        set_text: dict mapping utterance id -> utterance text.
        query_list: iterable of query strings.

    Returns:
        Dict mapping each distinct stemmed query to the list of utterance
        ids whose text contains a word with the same stem. Every stemmed
        query is present as a key, even when nothing matched.
    """
    # The same words recur across utterances, so each distinct word is
    # stemmed only once per call instead of on every occurrence.
    stem_cache = {}

    def cached_stem(word):
        try:
            return stem_cache[word]
        except KeyError:
            stemmed = stem_cache[word] = stem(word)
            return stemmed

    query_set = {cached_stem(q) for q in query_list}
    query_results = {q: [] for q in query_set}

    for u, v in tqdm(set_text.items(), ncols=120):
        words_in_text = {cached_stem(w) for w in v.split()}
        # Each matching query records this utterance exactly once.
        for q in query_set & words_in_text:
            query_results[q].append(u)
    return query_results

In [15]:
def get_common_ref_queries(dev_query_results):
    """Keep, per query, only the utterances matched in *every* reference.

    Args:
        dev_query_results: non-empty sequence with one entry per reference
            translation; each entry is a dict mapping query -> iterable of
            utterance ids. Every entry must contain all queries found in
            the first one.

    Returns:
        Dict mapping each query of the first reference to the set of
        utterance ids present in that query's results for all references.

    Raises:
        IndexError: if ``dev_query_results`` is empty.
        KeyError: if a later reference is missing a query of the first one.
    """
    first_ref = dev_query_results[0]
    # Intersect over however many references were supplied rather than
    # assuming there are exactly four.
    other_refs = dev_query_results[1:]
    return {
        q: set(utts).intersection(*(ref[q] for ref in other_refs))
        for q, utts in first_ref.items()
    }

In [16]:
def get_query_terms_thresh(train_results, dev_common, dev2_common, min_c, max_c):
    """Select queries by their dev match count and print a summary table.

    A query is kept when the number of utterances in ``dev_common[query]``
    lies in the inclusive range ``[min_c, max_c]``.

    Args:
        train_results: dict query -> matching train utterance ids.
        dev_common: dict query -> matching dev utterance ids.
        dev2_common: dict query -> matching dev2 utterance ids.
        min_c: minimum dev match count (inclusive).
        max_c: maximum dev match count (inclusive).

    Returns:
        Dict mapping each kept query to
        ``{"train": ..., "dev": ..., "dev2": ...}`` holding the corresponding
        entries of the three input dicts.
    """
    selected = []
    for query, dev_utts in dev_common.items():
        if min_c <= len(dev_utts) <= max_c:
            selected.append((query, dev_utts, dev2_common[query], train_results[query]))

    header_fmt = "{0:>5s} -- {1:10s} | {2:10s} | {3:10s} | {4:10s}"
    row_fmt = "{0:>5d} -- {1:10s} | {2:10d} | {3:10d} | {4:10d}"
    print(header_fmt.format("#", "query", "dev count", "dev2_count", "train count"))

    query_terms = {}
    for rank, (query, dev_utts, dev2_utts, train_utts) in enumerate(selected, start=1):
        print(row_fmt.format(rank, query, len(dev_utts), len(dev2_utts), len(train_utts)))
        query_terms[query] = {"train": train_utts, "dev": dev_utts, "dev2": dev2_utts}
    return query_terms

In [17]:
# def search_text(set_text, query_list):
#     query_results = {q:[] for q in query_list}
#     for u, v in tqdm(set_text.items(), ncols=80):
#         for q in query_list:
#             words_in_text = set([stem(w) for w in v.strip().split()])
#             if stem(q) in words_in_text:
#                 query_results[q].append(u)
#             # end if query check
#         # end for all query terms
#     # end for all utterances
# # end function

### Topics

In [18]:
topics_fname = "../criseslex/fsp06_topics_in_english.txt"

In [19]:
topics = [ "peace", "Music", "Marriage", "Religion", "Cell phones", 
           "Dating", "Telemarketing and SPAM", "Politics", "Travel", 
           "Technical devices", "Healthcare", "Advertisements", "Power", 
           "Occupations", "Movies", "Welfare", "Breaking up", "Location", 
           "Justice", "Memories", "Crime", "Violence against women", "Equality", 
            "Housing", "Immigration",     
            # new topics
           "Interracial", "Christians", "muslims", "jews", "e-mail", 
           "phone", "democracy", "Democratic", "Republican", "technology", 
           "leadership", "community", "jury", "police", "inequality", 
           "renting", "Violence", "immigrants", "immigrant", "skilled", 
           "Telemarketing", "SPAM", "skill", "job", "health", "mobile", 
            "ads", "physical", "emotional", "bubble", "rent", "economy", 
            "abuse", "women", "city", "country", "suburban", "dollar", 
            "united states", "laws", "phone", "race", "biracial", "interracial", 
            "marriage", "lyrics", "sexuality", "medicine", "television", "european",
            "home", "protect", "spouse", "language", "cellphone", "money",
            "doctor", "insurance", "cigarettes", "alcohol", "income", "salary",
            "class", "censor", "rating", "programs", "government",
            "relationship", "legal", "event", "life", "safe", "victim", "cops",
            "wage", "illegal"
            ]
topics = list(set(t.lower() for t in topics))
topics_stem = [stem(t) for t in topics]
len(topics), len(topics_stem)

(98, 98)

In [20]:
topics_stem[:5]

['doctor', 'travel', 'christian', 'sexual', 'cellphon']

In [21]:
topics_train_query_results = search_text(train_text, topics)

100%|█████████████████████████████████████████████████████████████████████████| 138819/138819 [00:25<00:00, 5545.94it/s]


In [22]:
topics_dev_query_results = []
for i in range(4):
    topics_dev_query_results.append(search_text(dev_text[i], topics))    

100%|█████████████████████████████████████████████████████████████████████████████| 3979/3979 [00:00<00:00, 5825.34it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3979/3979 [00:00<00:00, 5863.21it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3979/3979 [00:00<00:00, 5349.02it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3979/3979 [00:00<00:00, 5645.29it/s]


In [23]:
topics_dev2_query_results = []
for i in range(4):
    topics_dev2_query_results.append(search_text(dev2_text[i], topics))    

100%|█████████████████████████████████████████████████████████████████████████████| 3961/3961 [00:00<00:00, 5727.97it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3961/3961 [00:00<00:00, 5807.30it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3961/3961 [00:00<00:00, 5879.98it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3961/3961 [00:00<00:00, 5916.87it/s]


In [24]:
# all_dev_counts = [(q, [len(topics_dev_query_results[i][q]) for i in range(4)]) for q in topics_train_query_results]

In [25]:
topics_dev_common = get_common_ref_queries(topics_dev_query_results)

In [26]:
topics_dev2_common = get_common_ref_queries(topics_dev2_query_results)

In [27]:
topic_query_terms = get_query_terms_thresh(topics_train_query_results, 
                                           topics_dev_common, topics_dev2_common, 2, 30)

    # -- query      | dev count  | dev2_count | train count
    1 -- doctor     |          2 |          1 |        350
    2 -- travel     |         19 |          0 |        403
    3 -- juri       |          2 |          0 |        208
    4 -- job        |          3 |          6 |        634
    5 -- interraci  |          2 |          0 |         65
    6 -- spam       |          9 |         15 |         86
    7 -- rent       |          3 |          3 |        330
    8 -- medicin    |          3 |          0 |        143
    9 -- alcohol    |          3 |          0 |         47
   10 -- christian  |          6 |          3 |        248
   11 -- hous       |         19 |         10 |       1380
   12 -- money      |         30 |         38 |       1219
   13 -- protect    |          2 |          1 |        111
   14 -- justic     |          2 |          0 |         85
   15 -- home       |          7 |         12 |        555
   16 -- govern     |          8 |          9 |        

In [28]:
print(sum([1 if len(d["dev2"]) > 0 else 0 for d in topic_query_terms.values()]))
print(np.sum([len(d["dev2"]) for d in topic_query_terms.values()]))

25
254


### Read predictions

In [29]:
# Utterance ids of the two evaluation sets, one id per line of each file.
with open("./fisher/fisher_dev/fisher_dev_eval.ids", "r", encoding="utf-8") as in_f:
    dev_ids = [line.strip() for line in in_f]

with open("fisher/fisher_dev2/fisher_dev2_eval.ids", "r", encoding="utf-8") as in_f:
    dev2_ids = [u.strip() for u in in_f]

In [30]:
google_s2t_hyps_path = os.path.join("../chainer2/speech2text/both_fbank_out/", "google_s2t_hyps.dict")
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
# Use a context manager so the file handle is closed once loading finishes.
with open(google_s2t_hyps_path, "rb") as in_f:
    google_s2t_hyps = pickle.load(in_f)
google_hyp_dev = google_s2t_hyps['fisher_dev_r0']

In [31]:
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
# Use a context manager so the file handle is closed once loading finishes.
with open("./google/google_s2t_dev2_hyps.dict", "rb") as in_f:
    google_s2t_hyps_dev2 = pickle.load(in_f)
# NOTE(review): the dev2 hypotheses are read from the 'fisher_dev_r0' key --
# confirm the dev2 file really stores its entries under that name.
google_hyp_dev2 = google_s2t_hyps_dev2['fisher_dev_r0']

In [32]:
# Maps a model label to the directory that holds that model's hypothesis
# files. The label is used as the row name in the evaluation tables.
# NOTE(review): "10h" and "20h" point to the same sp_20hrs directory, so
# both labels load identical hypotheses and produce identical scores --
# confirm the intended path for "10h".
model_pred_files = {
    "2.5h": "experiments/nmt_asr/sp-2.5hrs_swbd1-train_nodev_ep25_baseline",
    "2.5h+asr": "experiments/nmt_asr/sp-2.5hrs_swbd1-train_nodev_ep25",
    "5h": "experiments/nmt_asr/sp-5hrs_swbd1-train100k_baseline",
    "5h+asr": "experiments/nmt_asr/sp-5hrs_swbd1-train_nodev_ep25",
    "10h": "interspeech_bpe/sp_20hrs_best_bn-nobias_batch-32_buck-n25-w80/",
    "10h+asr": "experiments/nmt_asr/sp-10hrs_swbd1-train_nodev_ep25_enc-attn-dec",
    "20h": "interspeech_bpe/sp_20hrs_best_bn-nobias_batch-32_buck-n25-w80/",
    "20h+asr": "experiments/nmt_asr/sp-20hrs_swbd1-train-nodev_ep25_enc-attn-dec/",
    "50h": "interspeech_bpe/sp_50hrs_best_bn-nobias_bucks-n25-w80_x0.2/",
    "50h+asr": "experiments/nmt_asr/sp-50hrs_swbd1-train-nodev_ep25/",
#     "160h": "interspeech_bpe/sp_160hrs_cnn-512-9-mfcc-13_drpt-0.3_l2e-4_rnn-3"
}

# Hypothesis file names; each is joined onto a model directory from the
# dict above when the predictions are read.
dev_beam_fname = "fisher_dev_beam_len-norm_min-0_max-300_N-5_K-5_W-0.6.en"
dev2_beam_fname = "fisher_dev2_beam_len-norm_min-0_max-300_N-5_K-5_W-0.6.en"

In [33]:
# For every model, read its dev and dev2 hypothesis files into
# {model label: {utterance id: hypothesis text}}.
pred_dev_text = {}
pred_dev2_text = {}
for k, v in model_pred_files.items():
    # Line i of a hypothesis file is paired with the i-th utterance id.
    with open(os.path.join(v,dev_beam_fname), "r", encoding="utf-8") as in_f:
        pred_dev_text[k] = {dev_ids[i]: line.strip() for i, line in enumerate(in_f)}

    with open(os.path.join(v,dev2_beam_fname), "r", encoding="utf-8") as in_f:
        pred_dev2_text[k] = {dev2_ids[i]: line.strip() for i, line in enumerate(in_f)}

In [34]:
pred_dev_text["Weiss"] = {}
for utt in dev_ids:
    pred_dev_text["Weiss"][utt] = " ".join(google_hyp_dev[utt])

pred_dev2_text["Weiss"] = {}
for utt in dev2_ids:
    pred_dev2_text["Weiss"][utt] = " ".join(google_hyp_dev2[utt])

### Evaluation

In [35]:
topics_pred_dev_results = {}
topics_pred_dev2_results = {}
for k in pred_dev_text:
    topics_pred_dev_results[k] = search_text(pred_dev_text[k], topics)
    if k in pred_dev2_text:
        topics_pred_dev2_results[k] = search_text(pred_dev2_text[k], topics)

100%|█████████████████████████████████████████████████████████████████████████████| 3977/3977 [00:00<00:00, 6525.01it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3959/3959 [00:00<00:00, 6980.10it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3977/3977 [00:00<00:00, 6588.12it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3959/3959 [00:00<00:00, 6485.48it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3977/3977 [00:00<00:00, 7724.98it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3959/3959 [00:00<00:00, 7421.03it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3977/3977 [00:00<00:00, 6142.30it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3959/3959 [00:00<00:00, 6129.89it/s]
100%|███████████████████████████

In [36]:
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0


def compute_prec_recall(query_docs, pred_docs, set_key):
    """Compute per-query and pooled precision/recall of retrieved utterances.

    Args:
        query_docs: dict query -> dict of set name -> iterable of reference
            utterance ids (a set or a list both work).
        pred_docs: dict query -> iterable of utterance ids retrieved from
            the predictions. Must contain every query in ``query_docs``.
        set_key: which set of ``query_docs[query]`` to score against,
            e.g. ``"dev"``.

    Returns:
        ``(query_matches, aggregate_results)``. ``query_matches`` maps each
        query to a dict with:
            ``"t"``    number of distinct reference utterances,
            ``"tp"``   number of distinct predicted utterances,
            ``"tc"``   number of utterances in both,
            ``"prec"`` tc / tp (0 when tp is 0),
            ``"rec"``  tc / t  (0 when t is 0).
        ``aggregate_results`` has the same keys, with the three counts
        summed over all queries and precision/recall computed from those
        sums.

    Raises:
        KeyError: if ``pred_docs`` lacks a query or a query lacks ``set_key``.
    """
    query_matches = {}
    for q in query_docs:
        pred_set = set(pred_docs[q])
        # Convert to a set so that list-valued references are accepted too
        # and duplicates cannot inflate the reference count.
        ref_set = set(query_docs[q][set_key])
        n_ref = len(ref_set)
        n_pred = len(pred_set)
        n_hit = len(pred_set & ref_set)
        query_matches[q] = {
            "t": n_ref,
            "tp": n_pred,
            "tc": n_hit,
            "prec": _safe_ratio(n_hit, n_pred),
            "rec": _safe_ratio(n_hit, n_ref),
        }

    # Pool the raw counts first, then derive precision/recall from the totals.
    aggregate_results = {
        k: sum(m[k] for m in query_matches.values()) for k in ("t", "tp", "tc")
    }
    aggregate_results["prec"] = _safe_ratio(aggregate_results["tc"], aggregate_results["tp"])
    aggregate_results["rec"] = _safe_ratio(aggregate_results["tc"], aggregate_results["t"])

    return query_matches, aggregate_results
        

In [76]:
def evaluate_preds(query_terms, eval_matches, set_key):
    """Score every model's retrieved utterances and print a summary table.

    Args:
        query_terms: dict query -> dict of set name -> reference utterance
            ids, passed through to ``compute_prec_recall``.
        eval_matches: dict model label -> (dict query -> retrieved
            utterance ids).
        set_key: which reference set to score against, e.g. ``"dev"``.

    Returns:
        Dict model label -> ``{"query": per-query results,
        "aggr": aggregate results}`` as returned by ``compute_prec_recall``.
        One table row per model is printed, with aggregate precision and
        recall shown as percentages.
    """
    print("{0:10s} | {1:10s}% | {2:10s}%".format("model", "precision", "recall"))

    eval_results = {}
    for model_name, model_matches in eval_matches.items():
        per_query, aggregate = compute_prec_recall(query_terms, model_matches, set_key)
        eval_results[model_name] = {"query": per_query, "aggr": aggregate}
        print("{0:10s} | {1:10.1f} | {2:10.1f}".format(
            model_name, aggregate["prec"]*100, aggregate["rec"]*100))
    return eval_results
    

In [77]:
topics_dev_eval = evaluate_preds(topic_query_terms, topics_pred_dev_results, "dev")

model      | precision % | recall    %
2.5h       |        0.0 |        0.0
2.5h+asr   |        7.1 |        9.2
5h         |        0.0 |        0.0
5h+asr     |       16.9 |       19.5
10h        |       20.8 |       24.5
10h+asr    |       24.9 |       34.5
20h        |       20.8 |       24.5
20h+asr    |       33.9 |       50.6
50h        |       44.3 |       59.0
50h+asr    |       44.7 |       66.3
Weiss      |       60.2 |       91.6


In [78]:
topics_dev2_eval = evaluate_preds(topic_query_terms, topics_pred_dev2_results, "dev2")

model      | precision % | recall    %
2.5h       |        1.0 |        0.4
2.5h+asr   |        6.1 |        7.9
5h         |        0.0 |        0.0
5h+asr     |       16.1 |       18.9
10h        |       24.7 |       31.5
10h+asr    |       29.8 |       40.9
20h        |       24.7 |       31.5
20h+asr    |       36.0 |       51.6
50h        |       38.5 |       55.5
50h+asr    |       44.0 |       64.6
Weiss      |       56.2 |       87.8


In [37]:
topics_pred_dev_results["20h+asr"].keys()

dict_keys(['doctor', 'travel', 'juri', 'rate', 'job', 'interraci', 'breaking up', 'relationship', 'telemarket', 'spam', 'democraci', 'polic', 'republican', 'technical devic', 'violence against women', 'rent', 'medicin', 'alcohol', 'muslim', 'date', 'christian', 'jew', 'hous', 'memori', 'wage', 'united st', 'music', 'money', 'locat', 'protect', 'citi', 'health', 'incom', 'justic', 'home', 'govern', 'marriag', 'law', 'skill', 'women', 'countri', 'cellphon', 'welfar', 'e-mail', 'telemarketing and spam', 'dollar', 'polit', 'suburban', 'violenc', 'legal', 'race', 'crime', 'emot', 'european', 'illeg', 'event', 'movi', 'insur', 'sexual', 'lyric', 'phone', 'immigr', 'bubbl', 'televis', 'censor', 'program', 'religion', 'safe', 'peac', 'languag', 'cell phon', 'ad', 'spous', 'occup', 'abus', 'healthcar', 'economi', 'salari', 'cop', 'power', 'mobil', 'technolog', 'physic', 'democrat', 'life', 'communiti', 'equal', 'advertis', 'inequ', 'victim', 'biraci', 'class', 'cigarett', 'leadership'])

### Crisis words

In [87]:
crisis_lex_fname = "../criseslex/CrisisLexLexicon/CrisisLexRec.txt"

In [88]:
# Collect the distinct whitespace-separated terms of the crisis lexicon file.
crisis = set()
# The path variable is `crisis_lex_fname`; the previous spelling
# `crises_lex_fname` is not defined anywhere in this notebook.
with open(crisis_lex_fname, "r", encoding="utf-8") as in_f:
    for line in in_f:
        crisis.update(line.split())
# Sorted so the term list (and its stems) come out in the same order on every run.
crisis = sorted(crisis)
crisis_stem = [stem(w) for w in crisis]

In [89]:
len(crisis_stem)

288

In [91]:
crisis_train_query_results = search_text(train_text, crisis)

100%|█████████████████████████████████████████████████████████████████████████| 138819/138819 [00:24<00:00, 5579.66it/s]


In [92]:
crisis_dev_query_results = []
for i in range(4):
    crisis_dev_query_results.append(search_text(dev_text[i], crisis))    

100%|█████████████████████████████████████████████████████████████████████████████| 3979/3979 [00:00<00:00, 5850.64it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3979/3979 [00:00<00:00, 5861.59it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3979/3979 [00:00<00:00, 5852.02it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3979/3979 [00:00<00:00, 5764.79it/s]


In [93]:
crisis_dev2_query_results = []
for i in range(4):
    crisis_dev2_query_results.append(search_text(dev2_text[i], crisis))    

100%|█████████████████████████████████████████████████████████████████████████████| 3961/3961 [00:00<00:00, 5814.70it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3961/3961 [00:00<00:00, 5826.08it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3961/3961 [00:00<00:00, 5887.33it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3961/3961 [00:00<00:00, 5925.61it/s]


In [94]:
crisis_dev_common = get_common_ref_queries(crisis_dev_query_results)

In [95]:
crisis_dev2_common = get_common_ref_queries(crisis_dev2_query_results)

In [102]:
crisis_query_terms = get_query_terms_thresh(crisis_train_query_results, 
                                           crisis_dev_common, crisis_dev2_common, 2, 30)

    # -- query      | dev count  | dev2_count | train count
    1 -- love       |         19 |          5 |        879
    2 -- case       |          3 |          1 |        605
    3 -- name       |         23 |         27 |       1150
    4 -- accept     |          2 |          0 |        205
    5 -- die        |          4 |          4 |        282
    6 -- second     |          4 |          8 |        267
    7 -- number     |          3 |         10 |        356
    8 -- anoth      |          9 |          8 |       1215
    9 -- wait       |         15 |          4 |        363
   10 -- free       |          5 |          5 |        263
   11 -- girl       |         12 |          9 |        748
   12 -- black      |          6 |          2 |        269
   13 -- death      |          3 |          0 |         45
   14 -- thousand   |          6 |          8 |        454
   15 -- chang      |         21 |          8 |        785
   16 -- lost       |          2 |          4 |        

In [103]:
print(sum([1 if len(d["dev2"]) > 0 else 0 for d in crisis_query_terms.values()]))
print(np.sum([len(d["dev2"]) for d in crisis_query_terms.values()]))

57
403


In [105]:
crisis_pred_dev_results = {}
crisis_pred_dev2_results = {}
for k in pred_dev_text:
    crisis_pred_dev_results[k] = search_text(pred_dev_text[k], crisis)
    if k in pred_dev2_text:
        crisis_pred_dev2_results[k] = search_text(pred_dev2_text[k], crisis)

100%|█████████████████████████████████████████████████████████████████████████████| 3977/3977 [00:00<00:00, 6253.43it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3959/3959 [00:00<00:00, 7020.55it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3977/3977 [00:00<00:00, 6609.46it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3959/3959 [00:00<00:00, 6518.22it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3977/3977 [00:00<00:00, 7698.82it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3959/3959 [00:00<00:00, 7467.92it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3977/3977 [00:00<00:00, 6175.67it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 3959/3959 [00:00<00:00, 6137.65it/s]
100%|███████████████████████████

In [106]:
crisis_dev_eval = evaluate_preds(crisis_query_terms, crisis_pred_dev_results, "dev")

model      | precision % | recall    %
2.5h       |        1.0 |        0.2
2.5h+asr   |        5.4 |        4.9
5h         |        3.8 |        0.4
5h+asr     |       10.3 |       10.7
10h        |       16.5 |       12.2
10h+asr    |       18.5 |       21.7
20h        |       16.5 |       12.2
20h+asr    |       25.6 |       32.2
50h        |       30.7 |       40.1
50h+asr    |       35.4 |       48.8
Weiss      |       42.8 |       78.0


In [107]:
crisis_dev2_eval = evaluate_preds(crisis_query_terms, crisis_pred_dev2_results, "dev2")

model      | precision % | recall    %
2.5h       |        0.9 |        0.2
2.5h+asr   |        4.2 |        4.7
5h         |        2.0 |        0.2
5h+asr     |        9.2 |       12.7
10h        |       14.1 |       12.7
10h+asr    |       16.9 |       21.1
20h        |       14.1 |       12.7
20h+asr    |       23.9 |       32.3
50h        |       27.4 |       38.2
50h+asr    |       31.4 |       47.1
Weiss      |       36.3 |       75.2


In [108]:
crisis_pred_dev_results["20h+asr"].keys()

dict_keys(['love', 'bomb', 'case', 'arrest', 'explos', 'report', 'amaz', 'flash', 'name', 'say', 'militari', 'nurs', 'respond', 'dead', 'effect', 'confirm', 'accept', 'effort', 'shot', 'stream', 'aftermath', 'launch', 'unfold', 'recov', 'donat', 'neglig', 'imag', 'buri', 'huge', 'blast', 'die', 'path', 'cleanup', 'hotlin', 'flood', 'cost', 'affect', 'tragic', 'second', 'terrifi', 'suppli', 'alert', 'twister', 'polic', 'number', 'anoth', 'wait', 'surg', 'injuri', 'free', 'lend', 'onlin', 'girl', 'black', 'death', 'homeown', 'txting', 'video', 'inund', 'thousand', 'chang', 'lost', 'devast', 'unaccount', 'cross', 'regist', 'hous', 'bushfir', 'tonight', 'memori', 'resid', 'thought', 'massiv', 'offic', 'prepar', 'run', 'want', 'rememb', 'dozen', 'evacue', 'recoveri', 'leg', 'send', 'deepen', 'citi', 'relief', 'surviv', 'join', 'reconnect', 'retweet', 'risk', 'terribl', 'heart', 'financi', 'magnitud', 'break', 'home', 'govern', 'million', 'time', 'reced', 'floodwat', 'rain', 'major', 'leav',

In [111]:
crisis_dev2_eval["50h+asr"]

{'aggr': {'prec': 0.3140495867768595,
  'rec': 0.47146401985111663,
  't': 403,
  'tc': 190,
  'tp': 605},
 'query': {'accept': {'prec': 0.0, 'rec': 0, 't': 0, 'tc': 0, 'tp': 5},
  'anoth': {'prec': 0.14814814814814814,
   'rec': 0.5,
   't': 8,
   'tc': 4,
   'tp': 27},
  'assist': {'prec': 0, 'rec': 0, 't': 0, 'tc': 0, 'tp': 0},
  'black': {'prec': 0.5, 'rec': 1.0, 't': 2, 'tc': 2, 'tp': 4},
  'bodi': {'prec': 0, 'rec': 0.0, 't': 3, 'tc': 0, 'tp': 0},
  'break': {'prec': 0.0, 'rec': 0.0, 't': 1, 'tc': 0, 'tp': 1},
  'brought': {'prec': 0.0, 'rec': 0.0, 't': 1, 'tc': 0, 'tp': 1},
  'case': {'prec': 0.08333333333333333, 'rec': 1.0, 't': 1, 'tc': 1, 'tp': 12},
  'chang': {'prec': 0.35294117647058826,
   'rec': 0.75,
   't': 8,
   'tc': 6,
   'tp': 17},
  'come': {'prec': 0.1111111111111111, 'rec': 0.25, 't': 8, 'tc': 2, 'tp': 18},
  'communiti': {'prec': 1.0, 'rec': 0.75, 't': 4, 'tc': 3, 'tp': 3},
  'death': {'prec': 0, 'rec': 0, 't': 0, 'tc': 0, 'tp': 0},
  'die': {'prec': 0.5, 'rec':