# Analyze based on semantic categories

1.) change tfidf so we compare equivalent categories only - done
2.) update ranking accordingly

In [1]:
import os
from collections import Counter, defaultdict
import csv

In [3]:
f_original = os.listdir('../contexts/giga_full/vocab')
print(len(f_original))
f_update = os.listdir('../contexts/giga_full_updated/vocab')
print(len(f_update))

1446
1636


In [2]:
f_original = os.listdir('../contexts/wiki/vocab')
print(len(f_original))
f_update = os.listdir('../contexts/wiki_updated/vocab')
print(len(f_update))

1669
1874


In [11]:

def get_categories(prop, model_name):
    analysis_type = 'tfidf-raw-10000/each_target_vs_corpus_per_category'
    path_dir = f'../results/{model_name}/{analysis_type}'
    path_dir = f'{path_dir}/{prop}'
    categories = set()
    for d in os.listdir(path_dir):
        categories.add(d)
    return categories

def get_context_cnts(prop, cat, label, model_name):
    
    analysis_type = 'tfidf-raw-10000/each_target_vs_corpus_per_category'
    path_dir = f'../results/{model_name}/{analysis_type}'
    path_dir = f'{path_dir}/{prop}'
    path_label = f'{path_dir}/{cat}/{label}'
    
    context_cnt = Counter()
    for f in os.listdir(path_label):
        full_path = f'{path_label}/{f}'
        if full_path.endswith('.csv'):
            with open(full_path) as infile:
                data = list(csv.DictReader(infile))
            for d in data:
                context = d['']
                diff = float(d['diff'])
                if diff > 0:
                    context_cnt[context] += 1
    return context_cnt
    
def get_n_concepts_total(prop, cat, label, model_name):
    
    analysis_type = 'tfidf-raw-10000/each_target_vs_corpus_per_category'
    path_dir = f'../results/{model_name}/{analysis_type}'
    path_dir = f'{path_dir}/{prop}'
    path_pos = f'{path_dir}/{cat}/{label}'
    path_neg = f'{path_dir}/{cat}/{label}'
    
    files_pos = [f for f in os.listdir(path_pos) if f.endswith('.csv')]
    files_neg = [f for f in os.listdir(path_neg) if f.endswith('.csv')]
    
    return len(files_pos), len(files_neg)


def get_f1_distinctiveness(n_pos, n_neg, total_pos, total_neg):
    
    tp = n_pos
    tn = total_neg - n_neg
    fp = n_neg
    fn = total_pos - n_pos
    
    if tp+fp != 0:
        p = tp/(tp+fp)
    else:
        p = 0
    if tp+fn != 0:
        r = tp/(tp+fn)
    else:
        r = 0
    
    if p+r != 0:
        f1 = 2 * ((p*r)/(p+r))
    else:
        f1=0
    
    return p, r, f1

    
def aggregate_contexts(prop, cutoff, model_name):
    
    annotation_name = 'annotation-tfidf-top20-raw-10000-categories'
    aggregation_name = 'aggregated-tfidf-top20-raw-10000-categories'

    path_dir_agg = f'../analysis/{model_name}/{aggregation_name}/{prop}'
    
    os.makedirs(path_dir_agg, exist_ok = True)
    
    context_cnts_all = Counter()
    context_cat_dict = defaultdict(set)

    cats = get_categories(prop, model_name)

    for cat in cats:
        context_cnts_pos = get_context_cnts(prop, cat, 'pos', model_name)
        context_cnts_neg = get_context_cnts(prop, cat, 'neg', model_name)
        total_pos, total_neg = get_n_concepts_total(prop, cat, label, model_name)
        
        context_f1_dict = Counter()
        context_score_dict = defaultdict(dict)
        
        # get distinctiveness
        for c, cnt_pos in context_cnts_pos.most_common():
            cnt_neg = context_cnts_neg[c]
            p, r, f1 = get_f1_distinctiveness(cnt_pos, cnt_neg, total_pos, total_neg)
            context_f1_dict[context] = f1
            context_score_dict[context] = {'p': p,'r':r, 'f1': f1}
        
        table = []
        for c, f1 in context_f1_dict.most_common():
            scores = context_score_dict[c]
            d = dict()
            d['context'] = c
            d.update(scores)
            d['n_pos'] = context_cnts_pos[c]
            d['total_pos'] = total_pos
            d['n_neg'] = context_cnts_neg[c]
            d['total_neg'] = total_neg
            table.append(d)
        
        # collect and write to file
        f = f'{path_dir_agg}/{cat}.csv'
        
        header = table[0].keys()
        with open(f, 'w') as outfile:
            writer = csv.DictWriter(fieldnames = header)
            writer.writeheader()
            for d in table:
                writer.writerow(d)
        

                
def prepare_annotation(prop, model_name):
    
    annotation_name = 'annotation-tfidf-top20-raw-10000-categories'
    os.makedirs(path_dir_agg, exist_ok = True)
    
    # get categories
    cats = get_categories(prop, model_name)
    
    # collect all contexts and categories 
    context_cats_dict = defaultdict(set)
    
    # load top 20 per category
    for 
    
    
    
    
    
                
# def prepare_annotation(prop, label, cutoff, model_name)

#     annotation_name = 'annotation-tfidf-top20-raw-10000-categories'
#     aggregation_name = 'aggregated-tfidf-top20-raw-10000-categories'
#     path_dir = f'../analysis/{model_name}/{annotation_name}/{prop}-{label}'

    
#     os.makedirs(path_dir_agg, exist_ok = True)
    
#     context_cnts_all = Counter()
#     context_cat_dict = defaultdict(set)

#     cats = get_categories(prop, model_name)

#     for cat in cats:
#         context_cnts = get_context_cnts(prop, cat, label, model_name)
#         # collect and write to file
#         f = f'{path_dir_agg}/{cat}.csv'
#         with open(f, 'w') as outfile:
#             outfile.write('context')
#             for c, cnt in context_cnts.most_common(cutoff):
#                 context_cat_dict[c].add(cat)

#     os.makedirs(path_dir, exist_ok=True)
#     with open(f'{path_dir}/annotation.csv', 'w') as outfile:
#         outfile.write('context,evidence_type,categories\n')
#         for c, cats in context_cat_dict.items():
#             line = f'{c}, ,{" ".join(cats)}\n'
#             outfile.write(line)
        

In [14]:
model_name = 'giga_full_updated'
prop = 'lay_eggs'
label = 'pos'
cutoff = 20
aggregate_contexts(prop, cutoff, model_name)
        
# for c, cnt in context_cnts_all.most_common():
#     print(c, cnt, context_cat_dict[c])

In [52]:
## test if word not in pairs

path = '/Users/piasommerauer/Data/dsm/corpus_exploration/giga_full/counts/pairs'


with open(path) as infile:
    for line in infile:
        w, c = line.strip().split(' ')
        if w in ['hose', 'hoursefly', 'bra', 'dress']:
            print(w)
            break

hose
