In [1]:
from src.datasets import TextConcatPosts, TextConcatFactCheck
from src.models import EmbeddingModel

# Task split definition.
# Placeholder: this file will be replaced by the new split in the final version.
tasks_path = "data/splits/tasks_no_gs_overlap.json"

# Raw input tables: posts, fact-checks, and the pairs file passed as `gs_path`.
posts_path = "data/complete_data/posts.csv"
fact_checks_path = "data/complete_data/fact_checks.csv"
gs_path = "data/complete_data/pairs.csv"

# Three-letter language codes.
langs = [
    "fra", "spa", "eng", "por",
    "tha", "deu", "msa", "ara",
]

def succ_at_k(df, k, group=True):
    """Compute success@k over a frame of predictions.

    Args:
        df: DataFrame with a "preds" column (sliceable sequence of predicted
            ids per row) and a "gs" column (list-like of gold ids per row).
        k: Number of leading predictions to consider (``preds[:k]``).
        group: If True, a row counts as a success when at least one of its
            gold ids appears in the first ``k`` predictions. If False, the
            frame is exploded on "gs" and every (row, gold id) pair is scored
            separately.

    Returns:
        The mean of the per-row (or per-pair) boolean hits.
    """
    if not group:
        # One row per gold id; each pair is a hit iff that id is in the top k.
        pairs = df.explode("gs")
        pair_hits = pairs.apply(lambda row: row["gs"] in row["preds"][:k], axis=1)
        return pair_hits.mean()

    def _any_gold_in_top_k(row):
        top_k = set(row["preds"][:k])
        return not top_k.isdisjoint(set(row["gs"]))

    return df.apply(_any_gold_in_top_k, axis=1).mean()
    
def print_succ_at_k(df, k, group=True):
    """Print success@k for ``df`` in both scoring modes.

    Two lines are printed: the grouped score (``group=True``) followed by the
    exploded score (``group=False``), both computed by ``succ_at_k``.

    Args:
        df: DataFrame forwarded unchanged to ``succ_at_k``.
        k: Cut-off forwarded to ``succ_at_k``.
        group: Accepted for signature compatibility but not used; both modes
            are always printed.
    """
    grouped_score = succ_at_k(df, k, group=True)
    print(f"S@{k} (group)", grouped_score)

    exploded_score = succ_at_k(df, k, group=False)
    print(f"S@{k} (explode)", exploded_score)


# Evaluate the embedding model on the "fra" monolingual task.
# NOTE(review): this cell is repeated for other languages with only `lang`
# changed; a helper taking `lang` as a parameter would remove the duplication.
lang = "fra"
    
print("\n\nProcessing", lang)
# Build the post and fact-check datasets for task_name="monolingual" and this language.
posts = TextConcatPosts(posts_path, tasks_path, task_name="monolingual", gs_path=gs_path, lang=lang)
fact_checks = TextConcatFactCheck(fact_checks_path, tasks_path, task_name="monolingual", lang=lang)

df_fc = fact_checks.df
df_posts_train = posts.df_train  # not used by the evaluation below
df_posts_dev = posts.df_dev

# Absolute path to a local Hugging Face cache snapshot of
# intfloat/multilingual-e5-large. NOTE(review): machine-specific; consider
# making it configurable.
model_name = '/home/bsc/bsc830651/.cache/huggingface/hub/models--intfloat--multilingual-e5-large/snapshots/ab10c1a7f42e74530fe7ae5be82e6d4f11a719eb'

# The model is constructed with the fact-check frame; presumably these are the
# retrieval candidates — TODO confirm against src.models.EmbeddingModel.
model = EmbeddingModel(model_name, df_fc, batch_size=512)

# Adds a "preds" column to df_posts_dev. succ_at_k slices preds[:k], so each
# entry is presumably ranked best-first — TODO confirm.
df_posts_dev["preds"] = model.predict(df_posts_dev["full_text"].values).tolist()

# Report success@k at k = 1, 5, 10 (reads the "preds" and "gs" columns).
print_succ_at_k(df_posts_dev, 1)
print_succ_at_k(df_posts_dev, 5)
print_succ_at_k(df_posts_dev, 10)


  from tqdm.autonotebook import tqdm, trange




Processing fra


Batches: 100%|██████████| 9/9 [00:06<00:00,  1.50it/s]
Batches: 100%|██████████| 1/1 [00:01<00:00,  1.23s/it]

S@1 (group) 0.7320261437908496
S@1 (explode) 0.7225806451612903
S@5 (group) 0.8627450980392157
S@5 (explode) 0.864516129032258
S@10 (group) 0.8758169934640523
S@10 (explode) 0.8774193548387097





In [2]:
# Evaluate the embedding model on the "por" monolingual task.
# NOTE(review): same pipeline as the other per-language cells; only `lang`
# differs. Re-assigns posts, fact_checks, df_fc, df_posts_train, df_posts_dev
# and model, replacing the values from the previous cell.
lang = "por"
    
print("\n\nProcessing", lang)
# Build the post and fact-check datasets for task_name="monolingual" and this language.
posts = TextConcatPosts(posts_path, tasks_path, task_name="monolingual", gs_path=gs_path, lang=lang)
fact_checks = TextConcatFactCheck(fact_checks_path, tasks_path, task_name="monolingual", lang=lang)

df_fc = fact_checks.df
df_posts_train = posts.df_train  # not used by the evaluation below
df_posts_dev = posts.df_dev

# Absolute, machine-specific path to a local Hugging Face cache snapshot of
# intfloat/multilingual-e5-large.
model_name = '/home/bsc/bsc830651/.cache/huggingface/hub/models--intfloat--multilingual-e5-large/snapshots/ab10c1a7f42e74530fe7ae5be82e6d4f11a719eb'

# A new EmbeddingModel is constructed for this language's fact-check frame.
model = EmbeddingModel(model_name, df_fc, batch_size=512)

# Adds a "preds" column to df_posts_dev from the model's predictions on "full_text".
df_posts_dev["preds"] = model.predict(df_posts_dev["full_text"].values).tolist()

# Report success@k at k = 1, 5, 10 (reads the "preds" and "gs" columns).
print_succ_at_k(df_posts_dev, 1)
print_succ_at_k(df_posts_dev, 5)
print_succ_at_k(df_posts_dev, 10)



Processing por


Batches: 100%|██████████| 43/43 [00:23<00:00,  1.81it/s]
Batches: 100%|██████████| 1/1 [00:01<00:00,  1.33s/it]

S@1 (group) 0.39759036144578314
S@1 (explode) 0.38372093023255816
S@5 (group) 0.7951807228915663
S@5 (explode) 0.7906976744186046
S@10 (group) 0.8493975903614458
S@10 (explode) 0.8430232558139535





In [3]:
# Evaluate the embedding model on the "deu" monolingual task.
# NOTE(review): same pipeline as the other per-language cells; only `lang`
# differs. The df_posts_train / df_posts_dev assigned here are the ones the
# final cell measures when the notebook is run in order.
lang = "deu"
    
print("\n\nProcessing", lang)
# Build the post and fact-check datasets for task_name="monolingual" and this language.
posts = TextConcatPosts(posts_path, tasks_path, task_name="monolingual", gs_path=gs_path, lang=lang)
fact_checks = TextConcatFactCheck(fact_checks_path, tasks_path, task_name="monolingual", lang=lang)

df_fc = fact_checks.df
df_posts_train = posts.df_train  # not used by the evaluation below
df_posts_dev = posts.df_dev

# Absolute, machine-specific path to a local Hugging Face cache snapshot of
# intfloat/multilingual-e5-large.
model_name = '/home/bsc/bsc830651/.cache/huggingface/hub/models--intfloat--multilingual-e5-large/snapshots/ab10c1a7f42e74530fe7ae5be82e6d4f11a719eb'

# A new EmbeddingModel is constructed for this language's fact-check frame.
model = EmbeddingModel(model_name, df_fc, batch_size=512)

# Adds a "preds" column to df_posts_dev from the model's predictions on "full_text".
df_posts_dev["preds"] = model.predict(df_posts_dev["full_text"].values).tolist()

# Report success@k at k = 1, 5, 10 (reads the "preds" and "gs" columns).
print_succ_at_k(df_posts_dev, 1)
print_succ_at_k(df_posts_dev, 5)
print_succ_at_k(df_posts_dev, 10)



Processing deu


Batches: 100%|██████████| 10/10 [00:08<00:00,  1.14it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.97it/s]

S@1 (group) 0.32786885245901637
S@1 (explode) 0.31746031746031744
S@5 (group) 0.5901639344262295
S@5 (explode) 0.6031746031746031
S@10 (group) 0.6721311475409836
S@10 (explode) 0.6825396825396826





In [4]:
# Row counts of the train and dev post frames currently in the kernel, i.e.
# those from whichever per-language cell ran last ("deu" when run in order).
len(df_posts_train), len(df_posts_dev)

(606, 61)