In [1]:
%load_ext autoreload

In [2]:
import sys

In [3]:
%autoreload

import json
from pathlib import Path
import pandas as pd
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
import tqdm

In [5]:
# Module-level sentence-embedding model, instantiated once for the whole notebook.
# author_pair_similarity() reads this global and calls model.encode() on it.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')



In [6]:
def author_pair_similarity(documents_ids, documents):
    """Score every unordered pair of documents by embedding cosine similarity.

    The texts in ``documents`` are encoded with the module-level ``model`` and
    compared all-against-all with ``cosine_similarity``. Position ``k`` of
    ``documents_ids`` labels position ``k`` of ``documents``.

    Args:
        documents_ids: identifiers, one per entry of ``documents``.
        documents: the texts to embed.

    Returns:
        dict mapping ``(id_a, id_b)`` to the similarity of the two documents,
        with one entry per pair where ``id_a`` appears before ``id_b`` in
        ``documents_ids`` (no self-pairs, no mirrored pairs).
    """
    embeddings = model.encode(documents)
    sim_matrix = cosine_similarity(embeddings, dense_output=True)

    # Only the strict upper triangle is kept: the matrix is symmetric and the
    # diagonal compares a document with itself.
    return {
        (left_id, right_id): sim_matrix[row, col]
        for row, left_id in enumerate(documents_ids)
        for col, right_id in enumerate(documents_ids)
        if col > row
    }

In [58]:
def extract_split_similarity_info(path, split, author_clm="authorIDs", max_sim=0.2,
                                  sample_size=10, random_state=None):
    """Sample authors of a split and keep those whose documents are dissimilar.

    Reads every ``<split>*.jsonl`` file in ``path``, groups documents by their
    first author, samples authors that have more than one document, computes
    the pairwise document similarities per author with
    ``author_pair_similarity`` and keeps the authors whose highest pairwise
    similarity is below ``max_sim``.

    Side effects:
        * for each input file ``X.jsonl`` a sibling ``X_filtered.jsonl`` is
          written containing only the rows of the kept authors;
        * ``<split>_info.json`` is written into ``path`` with the returned
          frame.

    Args:
        path: directory (str or ``Path``) containing the split files.
        split: file-name prefix of the split, e.g. ``'dev'``.
        author_clm: column holding a sequence of author ids per row; only the
            first element is used as the row's author.
        max_sim: authors are kept only if their maximum pairwise document
            similarity is strictly below this value.
        sample_size: maximum number of authors to sample; capped at the number
            of eligible authors.
        random_state: forwarded to ``DataFrame.sample`` to make the sampling
            reproducible; ``None`` keeps it random.

    Returns:
        DataFrame with columns ``authorID``, ``pairwise_sims`` and ``max_sim``
        for the kept authors.

    Raises:
        ValueError: if no input file matches ``<split>*.jsonl`` in ``path``.
    """
    base_dir = Path(path)
    filtered_suffix = '_filtered.jsonl'

    # Outputs of a previous run also match "<split>*.jsonl"; skip them so a
    # re-run does not ingest (and re-filter) its own results.
    df_paths = sorted(
        p for p in base_dir.glob("{}*.jsonl".format(split))
        if not p.name.endswith(filtered_suffix)
    )
    if not df_paths:
        raise ValueError(
            "No '{}*.jsonl' files found in {}".format(split, base_dir)
        )

    dfs = {}
    for p in df_paths:
        split_df = pd.read_json(p, lines=True)
        # The first listed author is treated as the document's author.
        split_df['authorID'] = split_df[author_clm].apply(lambda ids: ids[0])
        dfs[p] = split_df

    all_docs = pd.concat(list(dfs.values()))

    gdf = (
        all_docs.groupby('authorID')
        .agg({'documentID': list, 'fullText': list})
        .reset_index()
    )
    # A pairwise similarity needs at least two documents per author.
    gdf = gdf[gdf.documentID.str.len() > 1]

    # Cap the sample size so small splits do not make DataFrame.sample raise.
    n_sampled = min(sample_size, len(gdf))
    gdf_sample = gdf.sample(n_sampled, random_state=random_state).copy()

    pairwise_sims = [
        author_pair_similarity(row['documentID'], row['fullText'])
        for _, row in tqdm.tqdm(gdf_sample.iterrows(), total=len(gdf_sample))
    ]

    gdf_sample['pairwise_sims'] = pairwise_sims
    gdf_sample['max_sim'] = gdf_sample['pairwise_sims'].apply(
        lambda sims: max(sims.values())
    )
    gdf_sample = gdf_sample[['authorID', 'pairwise_sims', 'max_sim']]
    # Honour the caller's threshold (the parameter shares its name with the column).
    gdf_sample = gdf_sample[gdf_sample['max_sim'] < max_sim]
    filtered_authors = gdf_sample['authorID'].tolist()

    for p, split_df in dfs.items():
        kept_rows = split_df[split_df['authorID'].isin(filtered_authors)]
        # Rename only the file itself; directory names are left untouched.
        out_path = p.with_name(p.stem + filtered_suffix)
        with open(out_path, "w") as f:
            f.write(kept_rows.to_json(orient='records', lines=True))

    gdf_sample.to_json(str(base_dir / (split + '_info.json')))
    return gdf_sample

In [59]:
# Directory that holds the split's *.jsonl files; the filtered files and the
# <split>_info.json summary are written back into it.
ds_path = '/mnt/swordfish-pool2/milad/hiatus-data/performers-data/tmp-data/'
# Run the similarity filter on the 'dev' split with the default settings.
df_info = extract_split_similarity_info(path=ds_path, split='dev')

10it [00:00, 127.89it/s]
