In [None]:
import itertools
import random
import string
from collections import defaultdict
from functools import cache

import numpy as np
import pandas as pd
import pygrove
import py_stringmatching as sm
import requests
from IPython.utils import io
from nltk.tokenize import WordPunctTokenizer
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm

tqdm.pandas()

# Namespace interpolated into the S3 paths (`s3://{author}/...`) and the model
# identifier (`{author}/distractors_prediction`) used below. It is blanked here,
# so presumably it has to be filled in before anything can be loaded or saved.
author = '' #anonymized

In [None]:
# Names of the corpora handled by this notebook.
# NOTE(review): the build loop at the end of the notebook iterates over its own
# literal list (in a different order) rather than over this variable.
sources= ['medwiki','wikem','wikidoc']

In [None]:
def parse(source):
    """Load the sections table of one corpus and normalise it.

    Reads ``s3://{author}/{source}_sections.csv``, turns the stringified
    ``diff_diag_text`` column back into Python objects, replaces missing
    values with ``''`` and cuts each title at its first ``'('``.

    Args:
        source: corpus name used to build the CSV path.

    Returns:
        The loaded ``pandas.DataFrame``.
    """
    def _drop_parenthetical(title):
        # Titles that *start* with '(' are kept whole; otherwise everything from
        # the first '(' on is dropped (whitespace before it is kept).
        if title.startswith('('):
            return title
        return title.split('(')[0]

    frame = pd.read_csv(f's3://{author}/{source}_sections.csv')
    # SECURITY NOTE(review): eval() executes whatever the CSV cell contains.
    # Only safe while the file is fully trusted; ast.literal_eval would be the
    # safe choice if the cells are plain list literals - confirm before switching.
    frame['diff_diag_text'] = frame['diff_diag_text'].map(eval)
    frame = frame.fillna('')
    frame['title'] = frame['title'].map(_drop_parenthetical)
    return frame

# Distractor retrieval

In [None]:
# Load the sentence-embedding model used by get_distractors().
# io.capture_output() swallows whatever the loading step prints; the captured
# output is kept in `captured` and not shown.
with io.capture_output() as captured:
    model = SentenceTransformer(f"{author}/distractors_prediction")

In [None]:
def get_distractors(df,key='title',n_distractors=50):
    """Retrieve, for every distinct answer, its most similar other answers.

    Each distinct value of ``df[key]`` is encoded twice with the global
    ``model`` (once as ``{'DOC': ...}``, once as ``{'QRY': ...}``). For every
    answer the ``n_distractors`` highest query/document dot products are
    inspected, and candidates that contain the answer, or are contained in it,
    as a substring are dropped.

    Args:
        df: table (anything supporting ``df[key]``) holding the answers.
        key: column with the answer strings.
        n_distractors: number of top-scoring candidates inspected per answer
            (the returned lists can be shorter after filtering).

    Returns:
        dict mapping every answer to its list of distractors, best first.
        Every answer has an entry, possibly an empty list. For backward
        compatibility the dict also holds one ``'AVG_<answer>'`` float per
        answer.
    """
    # dict.fromkeys de-duplicates like set() but keeps first-seen order, so the
    # result no longer depends on the per-process string hash seed.
    answers = list(dict.fromkeys(df[key]))
    if not answers:
        return {}

    Hd = model.encode([{'DOC':x} for x in answers],normalize_embeddings=True)
    Hq = model.encode([{'QRY':x} for x in answers],normalize_embeddings=True)

    distractors = defaultdict(list)
    # Score in slices to bound the size of the similarity matrix; never ask for
    # more slices than there are answers (that only produced empty chunks).
    n_chunks = min(1000, len(answers))
    chunks = np.array_split(np.arange(len(answers)), n_chunks)
    for chunk in tqdm(chunks):
        scores = Hq[chunk] @ Hd.T
        for i, a in enumerate(chunk):
            answer = answers[a]
            # Touch the key so answers whose candidates are all filtered out
            # still map to [] instead of being absent (KeyError for callers).
            kept = distractors[answer]
            top = np.argsort(scores[i])[::-1][:n_distractors]
            # NOTE(review): this is the mean of the top *indices*, not of the
            # top scores; kept unchanged because its purpose is unclear.
            distractors[f'AVG_{answer}'] = top.mean()
            for j in top:
                candidate = answers[j]
                if answer not in candidate and candidate not in answer:
                    kept.append(candidate)
    return dict(distractors)

In [None]:
def interlace(l1,l2):
    """Alternate the items of two sequences and drop duplicates.

    Items are taken as ``l1[0], l2[0], l1[1], l2[1], ...``; once the shorter
    input is exhausted the rest of the longer one follows. Falsy items
    (``None``, ``''``, ...) are skipped and only the first occurrence of each
    item is kept.

    Items must be hashable (strings in this notebook); unlike a deep flatten,
    nested lists are not unpacked.

    Returns:
        list of unique items in interleaved order.
    """
    merged = (
        item
        for pair in itertools.zip_longest(l1, l2)  # pads the shorter input with None
        for item in pair
        if item
    )
    # dict.fromkeys keeps insertion order, giving an order-preserving de-duplication.
    return list(dict.fromkeys(merged))

def make_qa(df, n_distractors=10):
    """Add multiple-choice columns to ``df`` (in place) and return it.

    Columns written:
        distractors: up to ``n_distractors`` retrieved distractors for the title.
        options: the title, followed by the non-empty ``diff_diag_text`` entries
            interleaved with the distractors, without repeating the title.
        options_no_diffdiag: the title followed by the distractors only.
        label: always 0, since the correct answer is the first option.

    Args:
        df: table with ``title`` and ``diff_diag_text`` columns; it is mutated.
        n_distractors: maximum number of retrieved distractors kept per row.

    Returns:
        The same ``df`` object.
    """
    key = 'title'
    distractors = get_distractors(df, key)
    # .get(): a title may have no entry when every candidate was filtered out.
    df['distractors'] = df[key].progress_map(
        lambda title: distractors.get(title, [])[:n_distractors]
    )

    def _build_options(row):
        answer = row[key]
        candidates = interlace([a for a in row.diff_diag_text if a], row.distractors)
        # The answer goes first and must not reappear among the alternatives.
        return [answer] + [c for c in candidates if c != answer]

    df['options'] = df.apply(_build_options, axis=1)
    df['options_no_diffdiag'] = df.apply(lambda row: [row[key]] + row.distractors, axis=1)
    df['label'] = 0
    return df

# Masking

In [None]:
def randargmax(b,**kw):
    """Like ``np.argmax`` but breaks ties between maximal entries at random.

    Args:
        b: array-like of scores.
        **kw: forwarded to ``np.argmax`` (e.g. ``axis``).

    Returns:
        Index of an entry equal to the global maximum of ``b``.
    """
    values = np.array(b)
    is_max = values == values.max()
    # Random weights survive only on maximal entries, so argmax picks one of them.
    noise = np.random.random(values.shape)
    return np.argmax(noise * is_max, **kw)

from nltk.corpus import stopwords
# Rebinds `stopwords` from the NLTK corpus reader to a plain list of English
# stop words. Re-running only this line (without the import above) fails,
# because the list has no `.words` method.
stopwords=stopwords.words('english')


oc = sm.similarity_measure.overlap_coefficient.OverlapCoefficient()
# Built once instead of twice per sim() call.
_alphabetic_tokenizer = sm.AlphabeticTokenizer()


def sim(x, y):
    """Overlap coefficient between the lowercased alphabetic tokens of two strings.

    Args:
        x, y: the strings to compare.

    Returns:
        The score reported by ``oc.get_sim_score`` for the two token lists.
    """
    tokens_x = _alphabetic_tokenizer.tokenize(x.lower())
    tokens_y = _alphabetic_tokenizer.tokenize(y.lower())
    return oc.get_sim_score(tokens_x, tokens_y)
def diagnose(df, key='text',n=1000, n_options=6):
    """Estimate how often each option is the best lexical match for the text.

    For a random sample of rows, computes ``sim`` between ``row[key]`` and each
    of the first ``n_options`` options and records the index of the best one
    (ties broken at random by ``randargmax``). Index 0 is the correct answer,
    so a large share for 0 means the text still overlaps with the answer.

    Args:
        df: table with an ``options`` column and a ``key`` text column.
        key: text column to test.
        n: maximum number of rows sampled; capped at ``len(df)`` so that small
            tables no longer raise ``ValueError``.
        n_options: number of leading options compared.

    Returns:
        ``pandas.Series`` of relative frequencies indexed by winning option index.
    """
    sample = df.sample(min(n, len(df)))
    winners = sample.apply(
        # Slice before scoring so options beyond n_options are never tokenised.
        lambda x: randargmax([sim(x[key], o) for o in x.options[:n_options]]),
        axis=1,
    )
    return winners.value_counts(normalize=True)

In [None]:
def fuzzy_mask(s,to_mask,tol=0):
    """Replace occurrences of ``to_mask`` in ``s`` with ``'<unk> '``.

    An occurrence must be preceded by whitespace or ``'('``; that preceding
    character and an optional closing ``')'`` are consumed by the replacement.
    Matching is case-insensitive and allows up to ``tol`` errors through the
    ``regex`` module's fuzzy quantifier.

    Args:
        s: text to mask.
        to_mask: literal string to hide (escaped before use).
        tol: maximum number of errors tolerated in a match.

    Returns:
        The masked text. A single space is prepended to ``s`` before matching
        and is not stripped here; callers strip the final result.
    """
    import regex  # third-party `regex`, not `re`: needed for the `{e<=N}` fuzzy quantifier

    # Pad so that an occurrence at the very start has a delimiter in front of it.
    s = ' ' + s
    if not to_mask:
        # An empty pattern would match after every delimiter and mask the whole text.
        return s
    # Raw f-string: `\s` and `\)` are regex escapes, not Python string escapes.
    # NOTE(review): the leading delimiter is swallowed while the replacement adds a
    # trailing space, so "has asthma and" becomes "has<unk>  and"; left unchanged
    # to keep the generated data identical.
    pattern = rf'[\s(]({regex.escape(to_mask)}){{e<={tol}}}\)?'
    return regex.sub(pattern, '<unk> ', s, flags=regex.IGNORECASE)

def apply_mask(x,n_options=8, naive=False):
    """Mask mentions of the correct answer inside ``x.text``.

    The candidates ("suspects") are the full answer ``x.options[0]`` followed by
    its lowercased WordPunct tokens. A suspect that occurs ``p`` times in the
    lowercased competing options ``x.options[1:n_options]`` is masked with
    probability ``1 / (1 + p)``, so words shared with the alternatives tend to
    be kept. Stop words and punctuation-only tokens are skipped, unless
    ``naive`` is set, in which case they are masked as well.

    Args:
        x: row-like object exposing ``text`` and ``options``.
        n_options: competing options are ``x.options[1:n_options]``.
        naive: also mask stop words and punctuation tokens.

    Returns:
        The masked text, stripped of surrounding whitespace.
    """
    s = x.text
    answer = x.options[0]
    tokenize = WordPunctTokenizer().tokenize
    suspects = [answer] + tokenize(answer.lower())
    # Loop-invariant: the lowercased text of the competing options.
    competitors = " ".join(x.options[1:n_options]).lower()

    for word in suspects:
        if not word:
            # Empty answer: nothing meaningful to mask.
            continue
        # Character-wise test; the previous substring test missed runs such as '),'.
        is_punct = all(ch in string.punctuation for ch in word)
        if word in stopwords or is_punct:
            if naive:
                s = fuzzy_mask(s, word)
            continue
        # Lowercase the suspect too: the full answer keeps its original casing and
        # would otherwise never be found in the lowercased competitors.
        p = competitors.count(word.lower())
        if random.random() <= 1 / (1 + p):
            s = fuzzy_mask(s, word)
    return s.strip()


def apply_mask_token(x, naive=False, n_options=6):
    """Token-level variant of ``apply_mask`` built on the global tokenizer ``tz``.

    The answer ``x.options[0]`` (lowercased) and ``x.text`` are converted to
    token ids. With ``naive=True`` every text token whose id belongs to the
    answer is replaced by the id returned for ``'##unk'``. Otherwise each answer
    token that occurs ``p`` times in the competing options
    ``x.options[1:n_options]`` is replaced with probability ``1 / (1 + p)``.

    Previously a mis-indented ``continue`` skipped the probabilistic branch
    entirely, so nothing was masked unless ``naive`` was set, and that branch
    would have masked every answer token instead of the current one.

    Args:
        x: row-like object exposing ``text`` and ``options``.
        naive: mask every answer token unconditionally.
        n_options: competing options are ``x.options[1:n_options]``.

    Returns:
        The decoded text after masking.
    """
    def tok(text):
        return tz(text, add_special_tokens=False).input_ids

    suspects = tok(x.options[0].lower())
    token_ids = tok(x.text)
    unk_id = tz.convert_tokens_to_ids('##unk')

    if naive:
        to_mask = set(suspects)
    else:
        # Loop-invariant: token ids of the competing options.
        competitors = tok(" ".join(x.options[1:n_options]).lower())
        to_mask = set()
        for word in suspects:
            p = competitors.count(word)
            if random.random() <= 1 / (1 + p):
                to_mask.add(word)

    return tz.decode([unk_id if w in to_mask else w for w in token_ids])


# Build and export 

In [None]:
# For every corpus: load its sections, build the answer options, mask the answer
# in the text, print the diagnose() statistics and write the result back to S3.
for source in ['wikem','medwiki','wikidoc']:
    df=parse(source)
    df=make_qa(df)
    # Keep the original text before `text` is overwritten with its masked version.
    df['unmasked']=df['text']
    df['text']=df.progress_apply(apply_mask,axis=1)
    # NOTE(review): apply_mask reads x.text, which at this point already holds the
    # masked text from the line above, so the naive mask is applied on top of it
    # rather than on `unmasked` - confirm this is intended.
    df['text_naive']=df.progress_apply(lambda x: apply_mask(x,naive=True),axis=1)
    print(source, len(df))
    # Share of sampled rows in which each option index is the best lexical match.
    for key in ['text','text_naive']:
        print(key,diagnose(df, key=key))
    df[['text','unmasked','text_naive','options','options_no_diffdiag','label']].to_csv(f's3://{author}/mc_{source}.csv', index=False)