In [None]:
# Put the parent directory on the import path before importing ppanlp
# (presumably so the local checkout of the package resolves -- confirm).
import sys;sys.path.append('..')
# NOTE(review): the star import supplies many names used further down this
# notebook without an explicit import (PPA, and apparently pd, np, tqdm,
# Counter, defaultdict, cached_property, iterlim, random, os) -- confirm
# against ppanlp's exports.
from ppanlp import *
# Corpus handle shared by the cells below.
ppa = PPA()

In [None]:
# adding topic models
# Options forwarded unchanged to ppa.topic_model().
qd = {'min_doc_len': 25, 'max_per_cluster': 50, 'frac': 1}
tm = ppa.topic_model(model_type='bertopic', **qd)

# Per-topic summary table, with column names lower-cased.
tdf = tm.mdl.get_topic_info()
tdf.columns = [colname.lower() for colname in tdf.columns]
# Translate each topic's representative documents into their ids via tm.doc2id.
tdf['representative_docs_ids'] = [
    [tm.doc2id[text] for text in rep_docs]
    for rep_docs in tdf['representative_docs']
]

# Per-document table; attach the id of each document, then map id -> topic name.
docinfo = tm.mdl.get_document_info(tm.docs)
docinfo['page_id'] = [tm.doc2id[text] for text in docinfo['Document']]
page2topic = dict(zip(docinfo['page_id'], docinfo['Name']))
# page2topic

In [None]:
from string import punctuation
from scipy.stats.contingency import odds_ratio
from scipy.stats import fisher_exact


class NERModel:
    """Counts and co-occurrence statistics over a corpus's saved named entities.

    The entities are read from ``corpus.ents_db``; an optional hand-made
    annotation CSV (``path_anno``) selects valid entities and maps them to a
    canonical name.
    """

    def __init__(self, corpus=None):
        """Store ``corpus`` (a fresh ``PPA()`` is created when it is None)."""
        self.corpus = corpus if corpus is not None else PPA()
        self.ent2pages=defaultdict(set)
        self.ent2count=Counter()
        self._annodf=None

    @staticmethod
    def clean_ent(ent):
        """Normalise an entity string.

        Strips leading/trailing punctuation, title-cases the result and drops
        a trailing possessive (which ``title()`` has turned into ``'S``).
        """
        o=ent.strip(punctuation).title()
        if o.endswith("'S"): o = o[:-2]
        return o

    @staticmethod
    def iter_by_page(iterr):
        """Group consecutive records sharing a ``page_id`` into lists.

        ``iterr`` yields dicts with a ``'page_id'`` key. Only *adjacent*
        records are grouped; nothing is yielded for an empty input.
        """
        last_pageid=None
        last_l=[]
        for resd in iterr:
            pageid=resd['page_id']
            if last_l and pageid!=last_pageid:
                yield last_l
                last_l=[]
            last_l.append(resd)
            last_pageid = pageid
        if last_l: yield last_l

    def iter_ents(self,ent_types:set=None,lim=None,by_page=False,ents=None):
        """Iterate over the saved entities of the corpus.

        Args:
            ent_types: if given, keep only entities whose type is in this set.
            lim: passed to ``iterlim`` to cap the number of yielded items.
            by_page: if true, yield one list of records per page (see
                ``iter_by_page``) instead of single records.
            ents: if given, keep only entities whose *cleaned* name is in it.

        Yields:
            dicts with keys ``page_id``, ``ent`` (annotated name when one
            exists in ``anno_ents``, else the cleaned name), ``ent_type`` and
            ``ent_orig`` (the cleaned name).

        Note: reading ``self.anno_ents`` loads the annotation CSV on first use.
        """
        def _records():
            with self.corpus.ents_db(flag='r') as db:
                total=len(db)
                progress=tqdm(db.items(),desc='Iterating over saved ents',position=0,total=total)
                for page_id,page_ents in progress:
                    for ent,ent_type in page_ents:
                        ent = self.clean_ent(ent)
                        if (not ent_types or ent_type in ent_types) and (not ents or ent in ents):
                            yield {'page_id':page_id,'ent':self.anno_ents.get(ent,ent),'ent_type':ent_type,'ent_orig':ent}
        oiterr = (self.iter_by_page(_records()) if by_page else _records())
        yield from iterlim(oiterr,lim)

    def iter_persons(self, **kwargs):
        """``iter_ents`` restricted to the ``PERSON`` entity type."""
        kwargs['ent_types']={'PERSON'}
        yield from self.iter_ents(**kwargs)

    def count_ents(self, **kwargs):
        """Count entity mentions and record the pages each entity occurs on.

        ``kwargs`` are forwarded to ``iter_ents``. Both ``ent2count`` and
        ``ent2pages`` are rebuilt from scratch on every call so that they
        always describe the same set of records.

        Returns:
            ``self.ent2count_s``: a Series of counts indexed by entity,
            sorted in descending order.
        """
        self.ent2count=Counter()
        # Reset alongside ent2count; otherwise pages from an earlier call made
        # with different filters would linger in the mapping.
        self.ent2pages=defaultdict(set)
        for resd in self.iter_ents(**kwargs):
            page_id,ent = resd['page_id'],resd['ent']
            self.ent2pages[ent].add(page_id)
            self.ent2count[ent]+=1
        self.ent2count_s = pd.Series(self.ent2count).sort_values(ascending=False)
        return self.ent2count_s

    def count_persons(self, **kwargs):
        """``count_ents`` restricted to the ``PERSON`` entity type."""
        kwargs['ent_types']={'PERSON'}
        return self.count_ents(**kwargs)

    def prep_anno_df(self, min_count=100):
        """Build a blank annotation sheet from the latest ``count_ents`` result.

        Requires ``count_ents``/``count_persons`` to have been called on this
        instance (it reads ``self.ent2count_s``).

        Returns:
            DataFrame indexed by ``name`` with the entity ``count`` (only
            entities with at least ``min_count`` mentions) and an empty
            ``is_valid`` column.
        """
        # Use this instance's counts rather than a notebook-level `ner` global.
        s = self.ent2count_s
        s = s[s>=min_count]
        df = pd.DataFrame({'count':s}).rename_axis('name')
        df['is_valid'] = ''
        return df

    @cached_property
    def path_to_anno(self):
        """Path ``<corpus.path_data>/data.ner.to_anno.csv``."""
        return os.path.join(self.corpus.path_data, 'data.ner.to_anno.csv')

    @cached_property
    def path_anno(self):
        """Path ``<corpus.path_data>/data.ner.anno.csv`` (default annotation file)."""
        return os.path.join(self.corpus.path_data, 'data.ner.anno.csv')

    def load_anno_df(self, fn=None, force=False):
        """Load (and memoise) the annotation CSV, indexed by its ``name`` column.

        ``fn`` defaults to ``path_anno``. The file is only read when nothing
        is cached yet or ``force`` is true; missing values become ``''``.
        """
        if force or self._annodf is None:
            fn=fn if fn else self.path_anno
            self._annodf = pd.read_csv(fn).set_index('name').fillna('')
        return self._annodf

    @cached_property
    def anno_df(self):
        """The annotation DataFrame (see ``load_anno_df``)."""
        return self.load_anno_df()

    @cached_property
    def anno_ents(self):
        """Map each validated entity name to its canonical name.

        Rows are kept when ``is_valid`` starts with ``'y'``; the canonical
        name is the ``who`` column, falling back to the entity name itself
        when ``who`` is empty.
        """
        df=self.anno_df
        df=df[df.is_valid.str.startswith('y')]
        return {
            k:(v if v else k)
            for k,v in zip(df.index, df.who)
        }

    def iter_ents_anno(self, **kwargs):
        """``iter_ents`` restricted to the entities validated in ``anno_ents``."""
        kwargs['ents']=set(self.anno_ents.keys())
        yield from self.iter_ents(**kwargs)

    def iter_persons_anno(self, **kwargs):
        """``iter_ents_anno`` restricted to the ``PERSON`` entity type."""
        kwargs['ent_types']={'PERSON'}
        yield from self.iter_ents_anno(**kwargs)

    @cached_property
    def persons_anno_pagedata(self):
        """List of per-page record lists for validated persons (computed once)."""
        return [pdata for pdata in self.iter_persons_anno(by_page=True) if pdata]

    def person_cooccurence(self, min_page_count=5, lim=None, funcs=(odds_ratio, fisher_exact), min_count=10, **kwargs):
        """Page-level co-occurrence statistics for every pair of validated persons.

        Each page contributes the set of persons found on it, plus a
        ``'TOPIC_<name>'`` pseudo-entity when the notebook-level ``page2topic``
        mapping has a topic for the page whose name does not start with
        ``'-1'`` (presumably the topic model's outlier bucket -- confirm).

        Args:
            min_page_count: accepted for backward compatibility; not used.
            lim: if set and there are more pages than this, a random sample of
                ``lim`` pages is analysed (unseeded, so runs differ).
            funcs: callables applied to the 2x2 contingency table
                ``((both, x only), (y only, neither))``. A result's
                ``statistic`` / ``pvalue`` attributes, when present, are stored
                under ``<func.__name__>`` / ``<func.__name__>_p``.
            min_count: minimum number of pages an entity must appear on to be
                included in any pair.
            **kwargs: accepted for backward compatibility; not used.

        Returns:
            DataFrame with one row per pair whose ``fisher_exact_p`` is present
            and different from 1, sorted by ``odds_ratio`` descending; an empty
            DataFrame when no pair qualifies. All ``prob_*`` columns are
            multiplied by 100.
        """
        data = self.persons_anno_pagedata
        if lim and len(data)>lim:
            data = random.sample(data,lim)
        toppref='TOPIC_'
        person1 = Counter()   # entity -> number of pages it appears on
        person2 = Counter()   # (x, y) with x < y -> number of pages with both
        numpages=0
        allppl=set()
        for pagedata in data:
            numpages+=1
            pageid=pagedata[0]['page_id']
            pageppl = {d['ent'] for d in pagedata}
            topic=page2topic.get(pageid)
            if topic and not topic.startswith('-1'):
                pageppl.add(toppref+topic)
            for x in pageppl:
                person1[x]+=1
                allppl.add(x)
                for y in pageppl:
                    if x<y:
                        person2[x,y]+=1

        def count_ind(x):
            return person1[x]
        def count_solo(x,y):
            # Only one of the two orderings is ever stored, so subtracting
            # both handles either argument order.
            return person1[x] - person2[x,y] - person2[y,x]
        def count_together(x,y):
            return person2[x,y]
        def count_neither(x,y):
            return numpages - count_together(x,y) - count_solo(x,y) - count_solo(y,x)

        def prob_ind(x):
            return person1[x]/numpages
        def prob_obs(x,y):
            return person2[x,y]/numpages
        def prob_exp(x,y):
            return prob_ind(x) * prob_ind(y)
        def prob_obsexp(x,y):
            return prob_obs(x,y) / prob_exp(x,y)

        def get_contingency_table(x,y):
            tl=count_together(x,y)
            tr=count_solo(x,y)
            bl=count_solo(y,x)
            br=count_neither(x,y)
            return ((tl,tr),(bl,br))

        def iter_res():
            # Apply the frequency threshold once per entity instead of once
            # per candidate pair; iteration order over pairs is unchanged.
            eligible = [p for p in allppl if count_ind(p)>=min_count]
            cmps = [(x,y) for x in eligible for y in eligible if x<y]
            for x,y in tqdm(cmps):
                val_d={
                    'person_x':x,
                    'person_y':y,
                    'num_total_x':count_ind(x),
                    'num_total_y':count_ind(y),
                    'num_solo_x':count_solo(x,y),
                    'num_solo_y':count_solo(y,x),
                    'num_both_xy':count_together(x,y),
                    'num_neither_xy':count_neither(x,y),
                    'prob_x':prob_ind(x)*100,
                    'prob_y':prob_ind(y)*100,
                    'prob_xy_obs':prob_obs(x,y)*100,
                    'prob_xy_exp':prob_exp(x,y)*100,
                    'prob_xy_obsexp':prob_obsexp(x,y)*100,
                    'includes_topic':x.startswith(toppref) or y.startswith(toppref),
                }
                ctbl=get_contingency_table(x,y)
                for func in funcs:
                    res = func(ctbl)
                    method=func.__name__
                    stat=getattr(res,'statistic',None)
                    pval=getattr(res,'pvalue',None)
                    if stat is not None: val_d[f'{method}'] = stat
                    if pval is not None: val_d[f'{method}_p'] = pval
                # Rows without a Fisher p-value, or with p == 1, are dropped here,
                # so no further filtering is needed on the resulting frame.
                if val_d.get('fisher_exact_p',1)!=1:
                    yield val_d
        o=list(iter_res())
        return pd.DataFrame() if not o else pd.DataFrame(o).sort_values('odds_ratio',ascending=False)

In [None]:
# NER helper bound to the corpus object created in the first cell.
ner = NERModel(ppa)

In [None]:
# Pairwise co-occurrence statistics on at most 100000 pages (randomly sampled
# when there are more), restricted to entities that occur on at least 25 pages.
# Relies on `page2topic` from the topic-model cell having been defined.
df = ner.person_cooccurence(lim=100000, min_count=25)
df

In [254]:
# log10 of the observed/expected co-occurrence ratio. This adds a column to
# `df` in place; pairs that never co-occur have a ratio of 0 and become -inf.
df['prob_xy_obsexp_log']=df['prob_xy_obsexp'].apply(np.log10)

In [None]:
# Pairs whose Fisher exact p-value is at most .05.
dfsig = df.loc[df['fisher_exact_p'] <= .05]
# dfsig[dfsig.includes_topic]
# Of those, the pairs with an odds ratio above 1.
dfsig_pos = dfsig.loc[dfsig['odds_ratio'] > 1]
dfsig_pos

In [256]:
import networkx as nx

# Number of leading rows of dfsig_pos turned into graph edges.
MAX_EDGES = 500

# One edge per pair; every column of the row is kept as an edge attribute and
# the log10 observed/expected ratio is used as the edge weight.
G = nx.Graph()
# 'records' is the full orient name: the abbreviated 'record' was deprecated
# and is rejected by pandas >= 2.0. Slice with head() first so only the rows
# actually used are converted to dicts.
for d in dfsig_pos.head(MAX_EDGES).to_dict('records'):
    G.add_edge(d['person_x'], d['person_y'], weight=d['prob_xy_obsexp_log'], **d)
G.order(),G.size()

(377, 500)

In [257]:
from pyvis.network import Network
# Render the co-occurrence graph G as an interactive HTML page written to
# tmp.nx.html; cdn_resources='in_line' presumably embeds the JS/CSS assets in
# the file itself rather than linking to a CDN -- confirm in the pyvis docs.
nt = Network(notebook=True, cdn_resources='in_line')
nt.from_nx(G)
nt.show('tmp.nx.html')

tmp.nx.html
