In [None]:
import pandas as pd
import seaborn as sns
import os, json
from collections import Counter
from itertools import chain, combinations, cycle

import matplotlib.pyplot  as plt
import seaborn as sns
from skimage import io


import nltk

In [None]:
# Lift pandas' display limits so frames are shown with every row and column.
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)

In [None]:
# Nouns that `build_context` removes from every context it builds.
black_list = [
    'i',
    'background',
    '/',
    'image',
    'icon',
    'illustration',
    'view',
    'garden',
]

In [None]:
# Caption corpus: a tab-separated file without a header row, so the two
# column names are supplied here.
cc = pd.read_csv(
    '../data/gcc_train.tsv',
    sep='\t',
    names=['Caption', 'URL'],
)
# print(cc.shape)


In [None]:
def show_img(url):
    """Read the image at `url` with `skimage.io.imread` and display it with matplotlib."""
    pixels = io.imread(url)
    plt.imshow(pixels)
    plt.show()

In [None]:
def show_sent(row_id):
    """Print and return the caption and URL stored in row `row_id` of `cc`.

    Parameters
    ----------
    row_id : index label
        Label of a single row of the global `cc` frame (looked up with `.loc`).

    Returns
    -------
    tuple
        ``(caption, url)`` of that row.
    """
    record = cc.loc[row_id]

    # Read the fields by column label. Integer keys on a label-indexed Series
    # only work through a positional fallback that pandas has deprecated.
    caption, url = record['Caption'], record['URL']

    print(caption)
    print(url)

    return caption, url

In [None]:
def find_all_occ( label_list ):
    """Return the rows of `cc` whose caption contains every label in `label_list`.

    A label only matches as a space-delimited phrase (``' label '``), so a
    caption that starts or ends with the label is not matched.

    Parameters
    ----------
    label_list : iterable of str
        Labels that must all occur in a caption.

    Returns
    -------
    pandas.DataFrame
        A new frame with the matching rows; the global `cc` is never modified
        and never returned itself.
    """
    matches = cc

    for label in label_list:
        # Literal (non-regex), case-sensitive substring test. `na=False` makes
        # a row with a missing caption simply not match instead of raising,
        # and the result is always a boolean mask, even for an empty frame.
        mask = matches["Caption"].str.contains(f' {label} ', regex=False, na=False)
        matches = matches[mask]

    if matches is cc:
        # No label was applied: hand back a copy so callers never alias `cc`.
        matches = cc.copy()

    return matches


# find_all_occ(['stop sign', 'icon']).head()

## Building Context

In [None]:
def extract_nouns(row, noun=1, v=False):
    """Return the tokens of a tagged caption whose tag starts with 'NN'.

    Parameters
    ----------
    row : iterable of (token, tag) pairs
        One caption as a sequence of tagged tokens.
    noun : int
        Unused; kept so existing calls keep working.
    v : bool
        When true, print `row` before extracting.

    Returns
    -------
    list
        The matching tokens, in their original order.
    """
    if v:
        print(row)

    return [pair[0] for pair in row if pair[1].startswith('NN')]


def build_ond_for_label(label_list, nouns_needed=1, v=False):
    """Tag the captions that contain every label and keep those with an exact noun count.

    Parameters
    ----------
    label_list : iterable of str
        Labels that must all occur in a caption (see `find_all_occ`).
    nouns_needed : int
        A caption is kept only when `extract_nouns` finds exactly this many nouns.
    v : bool
        When true, print how many captions were kept.

    Returns
    -------
    pandas.DataFrame
        Columns 'Tagged' (the POS-tagged tokens of each caption) and 'Nouns'
        (the nouns extracted from them). The frame is an independent copy, so
        callers may add columns to it.
    """
    # Materialise once: the labels are used for filtering and for the message.
    label_list = list(label_list)

    subset = find_all_occ( label_list )

    tagged = subset['Caption'].apply(nltk.word_tokenize).apply(nltk.pos_tag)

    ond = pd.DataFrame(
        {
            'Tagged': tagged,
            'Nouns': tagged.apply(extract_nouns)
        })

    has_wanted_count = ond['Nouns'].apply(len) == nouns_needed
    ond = ond[has_wanted_count]

    if v:
        # The message used to reference an undefined name `label`, which made
        # every verbose call fail with NameError.
        shown = ', '.join(map(str, label_list))
        print('\nCaptions with ', nouns_needed, ' nouns that include the word "', shown, '" found: ', ond.shape[0], sep='')

    return ond[['Tagged', 'Nouns']].copy()

In [None]:
def capt_hash(noun_list):
    """Return an order-independent fingerprint: the sum of `hash()` over the nouns."""
    return sum(map(hash, noun_list))

# Sanity check: the fingerprint does not depend on noun order.
capt_hash(['chair', 'porch']) == capt_hash(['porch', 'chair'])

In [None]:
def remove_tail( df, v=False, rt=2 ):
    """Drop the least frequent rows `rt` times, then index the rest by noun.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns 'Freq', 'ContextNouns' (an iterable of nouns per
        row) and 'Tagged'.
    v : bool
        When true, print a short report after every trimming pass.
    rt : int
        Number of trimming passes. Each pass removes all rows whose 'Freq'
        equals the current minimum.

    Returns
    -------
    tuple of dict
        ``(tag_dict, freq_dict)``. `tag_dict[noun]` lists the 'Tagged' value of
        every remaining row that contains the noun. `freq_dict[noun]` is
        ``Freq * 100`` of the first remaining row that contains it, in the row
        order of `df`.
    """
    for _ in range(rt):

        if df.empty:
            # Nothing left to trim. Taking the minimum of an empty column
            # used to raise ValueError here.
            break

        lowest = df['Freq'].min()
        rows_before = len(df)

        df = df[ df['Freq'] > lowest ]

        if v:
            print('Filtered', rows_before-len(df), 'pairs')
            print('Now minimum freq is', df['Freq'].min())
            print('Remaining samples:', len(df), '\n')

    freq_dict = {}
    tag_dict = {}

    for freq, nouns, tagged in zip(df['Freq'], df['ContextNouns'], df['Tagged']):

        for noun in nouns:
            # Only the first row seen for a noun sets its frequency.
            freq_dict.setdefault(noun, freq*100)
            tag_dict.setdefault(noun, []).append( tagged )

    return tag_dict, freq_dict

In [None]:

# scope: the exact number of nouns a caption must contain to be used

def build_context( label, scope=2, rt=2, v=False):
    """Build the noun context of `label` from captions with exactly `scope` nouns.

    Parameters
    ----------
    label : str
        Label whose captions are analysed.
    scope : int
        Number of nouns a caption must contain (passed on as `nouns_needed`).
    rt : int
        Number of tail-trimming passes (see `remove_tail`).
    v : bool
        When true, print the sample count and the trimming reports.

    Returns
    -------
    dict
        'freq_dict' and 'tag_dict' as produced by `remove_tail`, minus every
        noun listed in `black_list`, and 'noun_set', the keys of 'freq_dict'.

    A one-line summary is always printed; its sentence count is taken before
    the tail is trimmed.
    """
    # FIXED NOUN NUMBER DATASET
    # Work on a private copy so the column assignments below never write into
    # a slice of another frame (avoids pandas' SettingWithCopyWarning).
    b = build_ond_for_label( [label], nouns_needed=scope).copy()

    # HASHING
    # NOTE(review): the nouns come from word-level tokens, so a multi-word
    # label such as 'dining table' presumably never equals a noun and is not
    # filtered out here. Confirm whether that is intended.
    b['ContextNouns'] = b['Nouns'].apply(
        lambda nouns: [n for n in nouns if n != label ]
    )
    b['Hash'] = b['ContextNouns'].apply(capt_hash)

    # FREQ DISTRIBUTION
    # Share of captions that carry each hash, looked up per row in one pass.
    shares = b['Hash'].value_counts(normalize=True)
    b['Freq'] = b['Hash'].map(shares)

    b = b.sort_values(by=['Freq'], ascending=False)

    if v:
        print(f'\n========================\n{label} orig samples:', len(b), '\n========================\n')

    tag_dict, freq_dict = remove_tail( b, v, rt )

    for persona_non_grata in black_list:
        tag_dict.pop(persona_non_grata, None)
        freq_dict.pop(persona_non_grata, None)

    print(f'Built context for "{label}": {len(b)} sentences, {len(freq_dict)} unique nouns.')

    return {

        'freq_dict': freq_dict,
        'tag_dict': tag_dict,
        'noun_set': set( freq_dict.keys() )
    }



    

In [None]:
# build_context('dining table')['tag_dict']['background']

## Umbrella Test

In [None]:
# umb_context = build_context( 'umbrella', v=True )
# umb_context

## Intersection

In [None]:
def find_intersect( label_list, context_base ):
    """Rank the nouns shared by the contexts of all labels in `label_list`.

    Parameters
    ----------
    label_list : iterable of str
        Labels to compare; each must be a key of `context_base`.
    context_base : dict
        Maps a label to its context, a dict with at least 'noun_set' and
        'freq_dict' (the shape returned by `build_context`).

    Returns
    -------
    pandas.DataFrame
        One row per shared noun with the columns 'mutual_label', one
        'freq_with_<label>' per label, and 'mutual_freq' (the product of the
        per-label frequencies). Rows are sorted by 'mutual_freq', highest
        first; equal values are ordered alphabetically by noun.
    """
    sub_dict = { label: context_base[label] for label in label_list }

    set_list = [ context['noun_set'] for context in sub_dict.values() ]

    cum_intersection = set_list[0].intersection( *set_list[1:] )

    # A set iterates in hash order, which for strings changes between
    # interpreter runs. Sorting makes the row order reproducible.
    df = pd.DataFrame( sorted(cum_intersection), columns=['mutual_label'] )

    for label, context in sub_dict.items():
        df[ f'freq_with_{label}' ] = df['mutual_label'].map( context['freq_dict'] )

    df[ 'mutual_freq' ] = df.iloc[:, 1:].prod( axis=1 )

    # A stable sort keeps the alphabetical order among equal frequencies, so
    # callers that keep only the top rows always get the same nouns.
    df = df.sort_values(by=['mutual_freq'], ascending=False, kind='mergesort')

#     df = df[ ~ df['mutual_label'].isin( black_list ) ].reset_index(drop=True)
    return df


In [None]:
# Endless iterator over bar colours; `both_context_bases` takes the next one
# unless it is called with allcolor=True.
color_cycle = cycle(
    '#ff7f0e #2ca02c #d62728 #9467bd #8c564b #e377c2 #7f7f7f #bcbd22 #17becf'.split()
)

In [None]:
def both_context_bases( label_list, cb_dict, allcolor=False, v=True ):
    """Compare the shared nouns of `label_list` in context bases '2' and '3'.

    Draws one bar chart per context base (top 10 shared nouns by
    'mutual_freq') and optionally prints both tables.

    Parameters
    ----------
    label_list : sequence of str
        Labels whose contexts are intersected.
    cb_dict : dict
        Context bases keyed by '2' and '3'; each is passed to `find_intersect`.
    allcolor : bool
        When true, use the fixed colour '#1f77b4'; otherwise take the next
        colour from `color_cycle`.
    v : bool
        When true, print both tables, the vote and the top-frequency set.

    Returns
    -------
    tuple of set
        ``(vote, tf)``: the nouns present in both top-10 tables, and the nouns
        whose 'mutual_freq' is at least 1 in either table.
    """
    if allcolor:
        color = '#1f77b4'
    else:
        color = next( color_cycle )

    h = 10
    inter2 = find_intersect( label_list, cb_dict['2'] ).head(h)
    inter3 = find_intersect( label_list, cb_dict['3'] ).head(h)

    fig, axs = plt.subplots(2, figsize=(10,4))
    fig.suptitle(label_list)

    for ax, table in zip(axs, (inter2, inter3)):
        ax.bar( table['mutual_label'], table['mutual_freq'].values, color=color )

    def frequent_labels(table):
        # Nouns whose combined frequency reaches the threshold of 1.
        return set(table.loc[ table['mutual_freq'] >= 1, 'mutual_label' ])

    vote = set(inter2['mutual_label']).intersection( set(inter3['mutual_label']) )

    top_freq3 = frequent_labels(inter3)
    top_freq2 = frequent_labels(inter2)
    tf = top_freq2.union(top_freq3)

    if v:
        joined = ", ".join(label_list)
        print(f'\n\n{joined} (Context Base 2)')
        print('', '=' * (13 * inter2.shape[1]) )
        print(inter2)
        print('\n')
        print(f'{joined} (Context Base 3)')
        print(inter3)
        print('\nVote:', vote)
        print('Top Frequency:', tf)

    return vote, tf


In [None]:
# Nouns that denote people; `build_bow` removes them from the bag of words
# and returns them separately.
person_names = set('person man woman artist people boy girl'.split())

In [None]:
def build_bow( label_list, cb_dict, v=True ):
    """Assemble a bag of words for `label_list` from its shared context nouns.

    The bag starts with the vote of the full label list and is extended with:
    the top-frequency nouns of every label pair, the nouns common to the votes
    of any two pairs, and the top-frequency nouns of the full list. Person
    nouns (`person_names`) are then removed and the labels themselves added.

    Parameters
    ----------
    label_list : sequence of str
        Labels to build the bag for.
    cb_dict : dict
        Context bases, passed through to `both_context_bases`.
    v : bool
        When true, print a report (also forwarded to `both_context_bases`).

    Returns
    -------
    tuple of set
        ``(bow, peoples)``: the final bag of words and the person nouns that
        were taken out of it.
    """
    bow_org, tf = both_context_bases( label_list, cb_dict, allcolor=True, v=v )
    bow = bow_org

    # POSSIBLE NOUN SETS
    pns_list = []

    for pair in combinations(label_list, 2):
        pair_vote, pair_tf = both_context_bases( pair, cb_dict, v=v )
        pns_list.append( pair_vote )
        bow = bow | pair_tf

    for first_vote, second_vote in combinations(pns_list, 2):
        bow = bow | first_vote.intersection( second_vote )

    bow = bow | tf

    peoples = person_names & bow
    bow = bow - peoples

    bow = bow | set(label_list)

    if v:
        print(f'\nOriginal labels ({len(label_list)} total):', *label_list)
        print(f'Original vote ({len(bow_org)} total):', bow_org)

        print('\nModified vote:', bow)
        d = bow - bow_org
        print(f'\nAdded ({len(d)} total):', d)

        print('\nPerson names found and extracted:', peoples )

        print(f'\nAdding original labels to BoW ({len(bow)} total)')

    return bow, peoples

## Export

In [None]:
def get_full_context( class_list, rt=0, v=True ):
    """Build the contexts of every class for noun scopes 2 and 3.

    Parameters
    ----------
    class_list : iterable of str
        Class names; duplicates are ignored.
    rt : int
        Number of tail-trimming passes, forwarded to `build_context`.
    v : bool
        Verbosity flag, forwarded to `build_context`.

    Returns
    -------
    dict
        ``{'2': {class_name: context, ...}, '3': {...}}`` where each context is
        the dict returned by `build_context` for that scope.
    """
    # Deduplicate once, keeping first-seen order. Rebuilding a set per scope
    # left a one-shot iterable empty for scope 3 and made the build order
    # depend on string hashing.
    class_names = list(dict.fromkeys(class_list))

    full_context = {}

    for scope in (2, 3):
        full_context[str(scope)] = {
            class_name: build_context(class_name, scope=scope, rt=rt, v=v)
            for class_name in class_names
        }

    return full_context

In [None]:
# Print a short usage summary naming the three functions this notebook
# offers to other notebooks, with their signatures.
print('''
Imported functions for use outside:

\tget_full_context( class_list, rt=0, v=True )

\tbuild_bow( label_list, cb_dict, v=True )

\tfind_all_occ( label_list )

''')