In [None]:
import sys;sys.path.append('..')
from ppanlp import *
# corpus handle shared by the cells below (NERModel, topic model)
ppa = PPA()

In [None]:
from string import punctuation


class NERModel:
    """Count, annotate and link named entities saved for a corpus.

    Entities are read from ``corpus.ents_db(flag='r')``, which is opened as a
    context manager and iterated via ``items()`` as
    ``(page_id, [(ent, ent_type), ...])`` (see ``iter_ents``). Entity strings
    are normalised with ``clean_ent`` before any filtering or counting.
    """

    def __init__(self, corpus=None):
        """Wrap ``corpus``; a fresh ``PPA()`` is created when none is given."""
        self.corpus = corpus if corpus is not None else PPA()
        # cleaned entity -> set of page ids it was seen on (filled by count_ents)
        self.ent2pages = defaultdict(set)
        # cleaned entity -> number of mentions (filled by count_ents)
        self.ent2count = Counter()
        # ent2count as a Series sorted by descending count (set by count_ents)
        self.ent2count_s = None
        # annotation DataFrame cache used by load_anno_df
        self._annodf = None

    @staticmethod
    def clean_ent(ent):
        """Normalise an entity string.

        Strips surrounding ASCII punctuation, title-cases the result and drops
        a trailing possessive (``"milton's"`` -> ``"Milton"``).
        """
        o = ent.strip(punctuation).title()
        # str.title() upper-cases the letter after an apostrophe, so a
        # possessive shows up as "'S" at this point.
        if o.endswith("'S"):
            o = o[:-2]
        return o

    @staticmethod
    def iter_by_page(iterr):
        """Group ``(page_id, *rest)`` rows into ``(page_id, [rest, ...])`` pairs.

        Only *consecutive* rows sharing a page id are grouped, so the input
        should already be ordered by page (as ``iter_ents`` produces it).
        Every yielded item, including the final page, is a 2-tuple labelled
        with its own page id.
        """
        from itertools import groupby

        for pageid, rows in groupby(iterr, key=lambda res: res[0]):
            yield (pageid, [res[1:] for res in rows])

    def iter_ents(self, ent_types: set = None, lim=None, by_page=False, ents=None):
        """Yield saved entities as ``(page_id, ent, ent_type)`` triples.

        :param ent_types: if given, keep only entities whose type is in it.
        :param lim: forwarded to ``iterlim`` together with the output stream
            (presumably a cap on items yielded -- pages when ``by_page``).
        :param by_page: if true, yield ``(page_id, [(ent, ent_type), ...])``
            groups instead of single triples.
        :param ents: if given, keep only entities whose *cleaned* form is in it.
        """
        def _iter_rows():
            with self.corpus.ents_db(flag='r') as db:
                pbar = tqdm(
                    db.items(),
                    desc='Iterating over saved ents',
                    position=0,
                    total=len(db),
                )
                for page_id, page_ents in pbar:
                    for ent, ent_type in page_ents:
                        ent = self.clean_ent(ent)
                        if (not ent_types or ent_type in ent_types) and (not ents or ent in ents):
                            yield page_id, ent, ent_type

        rows = self.iter_by_page(_iter_rows()) if by_page else _iter_rows()
        yield from iterlim(rows, lim)

    def iter_persons(self, **kwargs):
        """``iter_ents`` restricted to ``PERSON`` entities."""
        kwargs['ent_types'] = {'PERSON'}
        yield from self.iter_ents(**kwargs)

    def count_ents(self, **kwargs):
        """Count mentions and pages per cleaned entity.

        ``kwargs`` are forwarded to ``iter_ents`` (``by_page`` is not
        supported here). Resets and refills ``ent2count`` and ``ent2pages``,
        then returns ``ent2count_s``: the counts sorted in descending order.
        """
        self.ent2count = Counter()
        # reset pages too, so repeated calls don't accumulate stale page ids
        self.ent2pages = defaultdict(set)
        for res in self.iter_ents(**kwargs):
            page_id, ent = res[:2]
            self.ent2pages[ent].add(page_id)
            self.ent2count[ent] += 1
        self.ent2count_s = pd.Series(self.ent2count).sort_values(ascending=False)
        return self.ent2count_s

    def count_persons(self, **kwargs):
        """``count_ents`` restricted to ``PERSON`` entities."""
        kwargs['ent_types'] = {'PERSON'}
        return self.count_ents(**kwargs)

    def prep_anno_df(self, min_count=100):
        """Build a DataFrame of frequent entities for manual annotation.

        Uses the counts from the last ``count_ents`` call. If it has not been
        called yet, this runs it with default arguments. Keeps entities with
        at least ``min_count`` mentions. Returns a frame indexed by ``name``
        with a ``count`` column and an empty ``is_valid`` column for the
        annotator to fill in.
        """
        s = getattr(self, 'ent2count_s', None)
        if s is None:
            s = self.count_ents()
        s = s[s >= min_count]
        df = pd.DataFrame({'count': s}).rename_axis('name')
        df['is_valid'] = ''
        return df

    @cached_property
    def path_to_anno(self):
        """Path of ``data.ner.to_anno.csv`` under ``corpus.path_data``."""
        return os.path.join(self.corpus.path_data, 'data.ner.to_anno.csv')

    @cached_property
    def path_anno(self):
        """Path of ``data.ner.anno.csv`` under ``corpus.path_data`` (read by load_anno_df)."""
        return os.path.join(self.corpus.path_data, 'data.ner.anno.csv')

    def load_anno_df(self, fn=None, force=False):
        """Read and cache the annotation CSV, indexed by ``name``.

        :param fn: CSV path; defaults to ``path_anno``.
        :param force: re-read even if a cached frame already exists.
        Missing values are filled with ``''``.
        """
        if force or self._annodf is None:
            self._annodf = pd.read_csv(fn or self.path_anno).set_index('name').fillna('')
        return self._annodf

    @cached_property
    def anno_df(self):
        """Annotation frame from ``load_anno_df()``.

        Cached on first access; later ``load_anno_df(force=True)`` calls do
        not refresh it.
        """
        return self.load_anno_df()

    @cached_property
    def anno_ents(self):
        """Set of entity names whose ``is_valid`` annotation starts with ``'y'``."""
        df = self.anno_df
        df = df[df.is_valid.str.startswith('y')]
        return set(df.index)

    def iter_ents_anno(self, **kwargs):
        """``iter_ents`` restricted to entities annotated as valid."""
        kwargs['ents'] = self.anno_ents
        yield from self.iter_ents(**kwargs)

    def iter_persons_anno(self, **kwargs):
        """``iter_ents_anno`` restricted to ``PERSON`` entities."""
        kwargs['ent_types'] = {'PERSON'}
        yield from self.iter_ents_anno(**kwargs)

    def link_persons(self, min_page_count=2, **kwargs):
        """Build a co-occurrence graph of annotated persons (requires networkx).

        An edge joins two persons appearing on the same page. Its ``weight``
        is the number of pages they share. Edges with weight below
        ``min_page_count`` are removed; their endpoint nodes are kept.
        ``kwargs`` is accepted for compatibility but unused.
        """
        import networkx as nx

        G = nx.Graph()
        for _pageid, page_ents in self.iter_persons_anno(by_page=True):
            # de-duplicate so a page counts at most once per pair
            names = sorted({ent for ent, _ent_type in page_ents})
            for i, a in enumerate(names):
                for b in names[i + 1:]:
                    if G.has_edge(a, b):
                        G.edges[a, b]['weight'] += 1
                    else:
                        G.add_edge(a, b, weight=1)

        bad = [(a, b) for a, b, d in G.edges(data=True) if d['weight'] < min_page_count]
        G.remove_edges_from(bad)
        return G

    def person_cooccurence(self, min_page_count=5, **kwargs):
        """Score pairs of annotated persons by observed / expected co-occurrence.

        On each page, the distinct annotated PERSON names are taken.
        ``count1``/``count2`` are the number of pages each person appears on,
        and ``count_both`` is the number of pages the pair shares.
        ``prob1``/``prob2`` normalise single counts by their total, and
        ``prob_obs`` normalises pair counts by their total.
        ``obsexp = prob_obs / (prob1 * prob2)``. Pairs sharing fewer than
        ``min_page_count`` pages are skipped; pass a falsy value to keep all.
        Returns a DataFrame sorted by ``obsexp`` descending, with
        ``pair_pages`` listing the shared page ids. The frame is empty (with
        the same columns) when no pair qualifies. ``kwargs`` is unused.
        """
        columns = [
            'person1', 'person2', 'count1', 'count2', 'count_both',
            'prob1', 'prob2', 'prob_exp', 'prob_obs', 'obsexp', 'pair_pages',
        ]
        person1 = Counter()
        person2 = Counter()
        pair_pages = defaultdict(set)
        for pageid, pageents in self.iter_persons_anno(by_page=True):
            names = {x[0] for x in pageents}
            for x in names:
                person1[x] += 1
                for y in names:
                    if x < y:
                        person2[x, y] += 1
                        pair_pages[x, y].add(pageid)

        # calc probs
        person1_sum = sum(person1.values())
        person2_sum = sum(person2.values())
        person1_probs = {k: v / person1_sum for k, v in person1.items()}
        person2_probs = {k: v / person2_sum for k, v in person2.items()}

        o = []
        for pair in tqdm(person2_probs):
            if min_page_count and person2[pair] < min_page_count:
                continue
            p1, p2 = pair
            p1_prob, p2_prob = person1_probs[p1], person1_probs[p2]
            prob_exp = p1_prob * p2_prob
            prob_obs = person2_probs[pair]
            o.append({
                'person1': p1, 'person2': p2,
                'count1': person1[p1], 'count2': person1[p2], 'count_both': person2[pair],
                'prob1': p1_prob, 'prob2': p2_prob,
                'prob_exp': prob_exp, 'prob_obs': prob_obs, 'obsexp': prob_obs / prob_exp,
                'pair_pages': list(pair_pages[pair]),
            })
        if not o:
            # sort_values('obsexp') would raise KeyError on a column-less frame
            return pd.DataFrame(columns=columns)
        return pd.DataFrame(o, columns=columns).sort_values('obsexp', ascending=False)
            
            
            
        
            
                    
        

In [None]:
# entity model over the corpus loaded in the first cell
ner = NERModel(corpus=ppa)

In [None]:
# pairwise person co-occurrence scores (pairs sharing >= 5 pages, the default)
df = ner.person_cooccurence(min_page_count=5)
df

In [None]:
# keep strongly over-represented pairs that also share at least 10 pages
pdf = df.query('obsexp>=10 & count_both>=10')
pdf  # cell output

In [None]:
# adding topic models
qd = {'min_doc_len': 25, 'max_per_cluster': 50, 'frac': 1}
tm = ppa.topic_model(model_type='bertopic', **qd)
tm.mdl  # cell output

In [None]:
tdf = tm.mdl.get_topic_info()
# lower-case all column names so they can be used as attributes below
tdf.columns = [col.lower() for col in tdf.columns]
# map each topic's representative documents to their page ids
tdf['representative_docs_ids'] = [
    [tm.doc2id[doc] for doc in docs]
    for docs in tdf['representative_docs']
]
tdf

In [None]:
import networkx as nx

# Topic -> document -> person graph, restricted to pages that are both a
# topic's representative doc and a page shared by a high-scoring person pair.
G = nx.Graph()
pages_tdf = {pageid for pages in tdf.representative_docs_ids for pageid in pages}
pages_pdf = {pageid for pages in pdf.pair_pages for pageid in pages}
pages_both = pages_tdf & pages_pdf

# topic -> doc (only each topic's first representative doc)
for tname, tids in zip(tdf.name, tdf.representative_docs_ids):
    for tid in tids[:1]:
        if tid in pages_both:
            G.add_edge(tname, tid)

# doc -> person, only for docs already attached to a topic above; filtering on
# pages_both alone would add doc nodes with no topic edge
nodes = set(G.nodes())
for pageid, person, _ in ner.iter_persons_anno():
    if pageid in nodes:
        G.add_edge(pageid, person)

In [None]:
(G.number_of_nodes(), G.number_of_edges())  # (node count, edge count)

In [None]:
from pyvis.network import Network
# interactive pyvis view of G, written to tmp.nx.html
nt = Network(cdn_resources='in_line', notebook=True)
nt.from_nx(G)
nt.show('tmp.nx.html')