In [2]:
from src.datasets import TextConcatFactCheck, TextConcatPosts
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

tasks_path = "data/complete_data/tasks.json"
posts_path = "data/complete_data/posts.csv"
fact_checks_path = "data/complete_data/fact_checks.csv"
gs_path = "data/complete_data/pairs.csv"
langs = ['fra', 'spa', 'eng', 'por', 'tha', 'deu', 'msa', 'ara']

class EmbeddingModel:
    def __init__(self, model_name, df_fc, device="cuda", show_progress_bar=True, batch_size=128, normalize_embeddings=True, k=10):
        self.model = SentenceTransformer(model_name, device=device)
        self.device = device
        self.show_progress_bar = show_progress_bar
        self.batch_size = batch_size
        self.normalize_embeddings = normalize_embeddings
        self.emb_fc = self.encode(df_fc["full_text"].values)
        self.pos_to_idx = {pos: idx for pos, idx in enumerate(df_fc.index)}
        self.k = k

    def encode(self, texts):
        return torch.tensor(self.model.encode(texts, device="cuda", show_progress_bar=self.show_progress_bar, 
                                              batch_size=self.batch_size, normalize_embeddings=self.normalize_embeddings))
    
    def similarity(self, emb1, emb2):
        return torch.mm(emb1, emb2.T).cpu().numpy()
    
    def predict(self, texts):
        arr1 = self.encode(texts)
        sim = self.similarity(arr1, self.emb_fc)
        idx_sim = np.argsort(sim, axis=1)[:, ::-1][:, :self.k]
        # Apply the function element-wise to the array
        vectorized_map = np.vectorize(lambda x: self.pos_to_idx.get(x, None))
        return vectorized_map(idx_sim)

model_name = '/home/bsc/bsc830651/.cache/huggingface/hub/models--intfloat--multilingual-e5-large/snapshots/ab10c1a7f42e74530fe7ae5be82e6d4f11a719eb'

d_out = {}
for lang in tqdm(langs, desc="Languages"):

    posts = TextConcatPosts(posts_path, tasks_path, task_name="monolingual", gs_path=gs_path, lang=lang)
    fact_checks = TextConcatFactCheck(fact_checks_path, tasks_path, task_name="monolingual", lang=lang)

    df_fc = fact_checks.df
    df_posts_train = posts.df_train
    df_posts_dev = posts.df_dev

    model = EmbeddingModel(model_name, df_fc)

    df_posts_dev["preds"] = model.predict(df_posts_dev["full_text"].values).tolist()
    d_out.update(df_posts_dev["preds"].to_dict())



    

Batches: 100%|██████████| 35/35 [00:04<00:00,  7.06it/s]
Batches: 100%|██████████| 2/2 [00:01<00:00,  1.82it/s]
Batches: 100%|██████████| 111/111 [00:15<00:00,  7.35it/s]
Batches: 100%|██████████| 5/5 [00:02<00:00,  2.48it/s]
Languages:  25%|██▌       | 2/8 [00:52<02:44, 27.49s/it]

: 

In [9]:
posts.langs

['fra', 'spa', 'eng', 'por', 'tha', 'deu', 'msa', 'ara']

In [2]:
df_fc.head()

Unnamed: 0_level_0,claim,instances,title,full_text
fact_check_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
56,"""(A Pere Aragonés) ustedes eliminaron la conse...",[https://www.newtral.es/fact-check-medio-ambie...,El Departamento de Medio Ambiente de Cataluña ...,el departamento de medio ambiente de cataluña ...
58,"""(Andalucía tiene) un 36% de paro juvenil de m...",[https://www.newtral.es/paro-juvenil-andalucia...,Andalucía no dobla al resto de España en paro ...,andalucía no dobla al resto de españa en paro ...
59,"""(Biden) destruirá la protección para las enfe...",[https://www.telemundo.com/noticias/noticias-t...,"No, Biden no eliminaría la cobertura de salud ...","no, biden no eliminaría la cobertura de salud ..."
64,"""(El PP) vincula su oposición a renovar el CGP...",[https://www.newtral.es/factcheck-abalos-pp-re...,El PP no ha vinculado su oposición a renovar e...,el pp no ha vinculado su oposición a renovar e...
65,"""(En Euskadi) llevamos 40 años de victorias el...",[https://www.newtral.es/el-bloque-de-izquierda...,El bloque de izquierda no lleva 40 años de vic...,el bloque de izquierda no lleva 40 años de vic...


In [3]:
df_posts_train.head()

Unnamed: 0_level_0,ocr,verdicts,text,fb,tw,ig,full_text,gs
post_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
4,"""Bienaventurados los perseguidos por mi causa ...",,,1,0,0,"""Bienaventurados los perseguidos por mi causa ...",[80729]
6,"""Coman lo que quieran en Semana Santa, el sacr...",False information,,1,0,0,"""Coman lo que quieran en Semana Santa, el sacr...",[50769]
7,"""DISCURSO DE PEDRO CASTILLO, ESTUBO BASADO EN ...",False information,,1,0,0,"""DISCURSO DE PEDRO CASTILLO, ESTUBO BASADO EN ...",[56968]
8,"""Debemos ser Solidarios con los que menos tien...",Partly false information,,1,0,0,"""Debemos ser Solidarios con los que menos tien...",[148668]
11,"""EL AGUINALDO DE 2 MILLONES NO ES SOLO PARA NO...",False information,,1,0,0,"""EL AGUINALDO DE 2 MILLONES NO ES SOLO PARA NO...","[3332, 51854]"


In [4]:
import torch
from sentence_transformers import SentenceTransformer

class EmbeddingModel:
    def __init__(self, model_name, df_fc, device="cuda", show_progress_bar=True, batch_size=128, normalize_embeddings=True, k=10):
        self.model = SentenceTransformer(model_name, device=device)
        self.device = device
        self.show_progress_bar = show_progress_bar
        self.batch_size = batch_size
        self.normalize_embeddings = normalize_embeddings
        self.emb_fc = self.encode(df_fc["full_text"].values)
        self.pos_to_idx = {pos: idx for pos, idx in enumerate(df_fc.index)}
        self.k = k

    def encode(self, texts):
        return torch.tensor(self.model.encode(texts, device="cuda", show_progress_bar=self.show_progress_bar, 
                                              batch_size=self.batch_size, normalize_embeddings=self.normalize_embeddings))
    
    def similarity(self, emb1, emb2):
        return torch.mm(emb1, emb2.T).cpu().numpy()
    
    def predict(self, texts):
        arr1 = self.encode(texts)
        sim = self.similarity(arr1, self.emb_fc)
        idx_sim = np.argsort(sim, axis=1)[:, ::-1][:, :self.k]
        # Apply the function element-wise to the array
        vectorized_map = np.vectorize(lambda x: self.pos_to_idx.get(x, None))
        return vectorized_map(idx_sim)

model_name = '/home/bsc/bsc830651/.cache/huggingface/hub/models--intfloat--multilingual-e5-large/snapshots/ab10c1a7f42e74530fe7ae5be82e6d4f11a719eb'
model = EmbeddingModel(model_name, df_fc)


  from tqdm.autonotebook import tqdm, trange
Batches: 100%|██████████| 111/111 [00:15<00:00,  7.28it/s]


In [5]:
df_posts_train_aux = df_posts_train.copy()
df_posts_train_aux["preds"] = model.predict(df_posts_train_aux["full_text"].values).tolist()
df_posts_train_aux.head()

Batches:   0%|          | 0/44 [00:00<?, ?it/s]

Batches: 100%|██████████| 44/44 [00:15<00:00,  2.75it/s]


Unnamed: 0_level_0,ocr,verdicts,text,fb,tw,ig,full_text,gs,preds
post_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
4,"""Bienaventurados los perseguidos por mi causa ...",,,1,0,0,"""Bienaventurados los perseguidos por mi causa ...",[80729],"[80729, 82043, 81835, 88952, 21759, 1755, 8956..."
6,"""Coman lo que quieran en Semana Santa, el sacr...",False information,,1,0,0,"""Coman lo que quieran en Semana Santa, el sacr...",[50769],"[50769, 150832, 27416, 139364, 101111, 36525, ..."
7,"""DISCURSO DE PEDRO CASTILLO, ESTUBO BASADO EN ...",False information,,1,0,0,"""DISCURSO DE PEDRO CASTILLO, ESTUBO BASADO EN ...",[56968],"[56968, 56952, 152115, 106355, 41513, 102318, ..."
8,"""Debemos ser Solidarios con los que menos tien...",Partly false information,,1,0,0,"""Debemos ser Solidarios con los que menos tien...",[148668],"[148668, 197461, 102317, 102321, 102313, 10231..."
11,"""EL AGUINALDO DE 2 MILLONES NO ES SOLO PARA NO...",False information,,1,0,0,"""EL AGUINALDO DE 2 MILLONES NO ES SOLO PARA NO...","[3332, 51854]","[51854, 3332, 195763, 94629, 195996, 27418, 11..."


In [6]:
df_posts_train_aux.apply(lambda x: len(set(x["gs"]).intersection(x["preds"])), axis=1).mean()

np.float64(0.9822316986496091)

In [7]:
df_posts_dev_aux = df_posts_dev.copy()
df_posts_dev_aux["preds"] = model.predict(df_posts_dev_aux["full_text"].values).tolist()
df_posts_dev_aux.head()

Batches: 100%|██████████| 5/5 [00:01<00:00,  2.50it/s]


Unnamed: 0_level_0,ocr,verdicts,text,fb,tw,ig,full_text,preds
post_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
48,"""Un pueblo que elige corruptos, impostores, la...",Partly false information,,1,0,0,"""Un pueblo que elige corruptos, impostores, la...","[98618, 20499, 20498, 54015, 35924, 154254, 88..."
149,10:39 PM | 8.9kB/s ← Tweet Euquico [USER] f Gr...,Partly false information,,1,0,0,10:39 PM | 8.9kB/s ← Tweet Euquico [USER] f Gr...,"[39355, 89359, 48291, 118790, 74222, 50905, 13..."
280,"44 ANDRÉS FELIPE ARIAS NO SE HA FUGADO, ESTA E...",,,1,0,0,"44 ANDRÉS FELIPE ARIAS NO SE HA FUGADO, ESTA E...","[74210, 53901, 113641, 138621, 50641, 112439, ..."
324,"8:26 PM 3 Resistencia Uruguay 154 miembros, 3 ...",False information,,1,0,0,"8:26 PM 3 Resistencia Uruguay 154 miembros, 3 ...","[80872, 52933, 80873, 80871, 82062, 84432, 820..."
326,95 CO DE OME BANCO DEL COLOM 50% SALARIO MÍNIM...,Partly false information,,1,0,0,95 CO DE OME BANCO DEL COLOM 50% SALARIO MÍNIM...,"[51436, 51433, 51434, 51197, 115426, 56927, 19..."


In [8]:
df_posts_dev_aux["preds"].to_dict()

{48: [98618, 20499, 20498, 54015, 35924, 154254, 88303, 133885, 136850, 51623],
 149: [39355,
  89359,
  48291,
  118790,
  74222,
  50905,
  137496,
  134233,
  48274,
  48284],
 280: [74210,
  53901,
  113641,
  138621,
  50641,
  112439,
  44675,
  33542,
  112491,
  114082],
 324: [80872, 52933, 80873, 80871, 82062, 84432, 82009, 9081, 81993, 139364],
 326: [51436,
  51433,
  51434,
  51197,
  115426,
  56927,
  195833,
  63966,
  156720,
  118481],
 411: [57403, 57404, 57396, 197500, 8628, 25068, 57398, 57395, 57399, 73682],
 465: [80349, 80348, 21704, 80347, 84466, 99724, 36400, 80737, 41575, 53721],
 496: [124639,
  124641,
  124645,
  124640,
  106789,
  7210,
  80338,
  139640,
  54301,
  29927],
 596: [34099, 34143, 75748, 30895, 44948, 51580, 53866, 48279, 53524, 52484],
 642: [35178,
  14624,
  146909,
  22875,
  88961,
  107161,
  44673,
  120803,
  86502,
  107151],
 653: [80981, 22720, 2265, 51414, 7210, 50797, 137304, 41513, 150835, 72763],
 658: [51345, 146945, 54409, 

In [23]:
df_posts_train["full_text"].values[:10].reshape(1, -1).shape

(1, 10)

In [24]:
model.emb_fc.sum(axis=1).shape


torch.Size([300])

In [28]:
model.predict(df_posts_train["full_text"].values[:20])

Batches: 100%|██████████| 1/1 [00:00<00:00, 16.95it/s]


array([[1755, 3856,  937, 3921,  813, 3336, 3245, 2176, 3886, 3860],
       [3064, 3885, 3857, 3454, 3856, 2155, 3342, 3247, 3248, 4542],
       [4440, 1182, 2148, 1699, 1452,  711, 2133, 4546, 3210, 4542],
       [3955, 1698, 1672, 3211,  937, 3311, 2156, 3860, 1592, 1851],
       [3332, 3856, 3634, 1674,   99, 1697, 1698, 3247, 1604, 2264],
       [3311, 1586, 3231, 3339, 1450, 1674, 1388, 1612, 4408, 1451],
       [  73, 3342, 3247,  175, 3202, 2156, 3248, 4545, 1580, 1677],
       [3886, 3964,  209, 3955, 1603, 3887, 3857, 1355, 4456,  502],
       [3886, 3964, 3955,  209, 3857, 4545, 3887, 1603, 4456,  175],
       [ 502, 3213, 3071,  934, 4855, 3202,   59,  935,   99, 1604],
       [3211, 1580, 1709, 3212, 1753, 3210, 2025,  175, 4549, 1672],
       [ 946, 3515, 3078, 1608, 3072, 3064, 3887,  464, 3330, 4545],
       [3857, 2155, 3064, 4542, 2859, 1699, 1597, 1594, 1893, 2131],
       [3857, 2155, 3064, 4542, 2859, 1699, 1594, 1597, 4456,  335],
       [3337, 1615, 1598, 1763, 35

In [13]:
def get_k_similar_posts(emb_post, emb_fc, k=10,):
    # emb_post_sparse = sparse.csr_matrix(post)
    # emb_fc_sparse = sparse.csr_matrix(fc_pool)

    similarity = torch.mm(emb_post.reshape(1, -1), emb_fc.T).cpu().numpy()

    # Sort in descending order
    st_similarity = np.sort(similarity)[::-1]
    idx_sim = np.argsort(similarity)[::-1]
    print(st_similarity[:k],"\n", idx_sim[:k])
    return idx_sim[:k], st_similarity[:k]

def compute_similar_posts(post_ids, df_posts, df_fc, similarity, k=10, show=False):

    idx_to_pos = {idx: pos for pos, idx in enumerate(df_posts.index)}
    pos_to_idx_fc = {pos: idx for pos, idx in enumerate(df_fc.index)}

    ls_pos = [idx_to_pos[post_id] for post_id in post_ids]
    idx_sim, similarity = np.argsort(similarity[ls_pos, :], axis=1)[:, ::-1][:, :k], np.sort(similarity[ls_pos, :], axis=1)[:, ::-1][:, :k]

    # Apply the function element-wise to the array
    vectorized_map = np.vectorize(lambda x: pos_to_idx_fc.get(x, None))
    fc_ids = vectorized_map(idx_sim)

    if show:
    # Show the full text of the post
        print(df_posts.loc[post_ids, :].to_markdown())
        print("="*100)
        print("\n\n")
        for i, idx in enumerate(idx_sim):
            print(f"Similarity: {similarity[i]}")
            print(df_fc.iloc[idx].to_markdown())
            print("\n\n")
            
    return fc_ids, similarity


compute_similar_posts([11, 8], df_posts_train, df_fc, similarity=prec_similarity, show=True);

|   post_id | ocr                                                                                                                                                                                                                               | verdicts                 | text   |   fb |   tw |   ig | gs            | full_text                                                                                                                                                                                                                         | preds                                                                         |
|----------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-------------------------|:-------|-----:|-----:|-----:|:--------------|:--------------------------------------------------------------

In [8]:
df_posts_train["preds"] = compute_similar_posts(df_posts_train.index, df_posts_train, df_fc, similarity=prec_similarity, k=10)[0].tolist()
print(df_posts_train.apply(lambda x: len(list(set(x["gs"]).intersection(set(x["preds"])))) > 0, axis=1).mean())
df_posts_train.head()

0.8930348258706468


Unnamed: 0_level_0,ocr,verdicts,text,fb,tw,ig,gs,full_text,preds
post_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
4,"""Bienaventurados los perseguidos por mi causa ...",,,1,0,0,[80729],"""bienaventurados los perseguidos por mi causa ...","[80729, 82043, 81982, 53979, 89563, 81835, 107..."
6,"""Coman lo que quieran en Semana Santa, el sacr...",False information,,1,0,0,[50769],"""coman lo que quieran en semana santa, el sacr...","[50769, 80782, 101111, 39736, 27415, 36525, 90..."
7,"""DISCURSO DE PEDRO CASTILLO, ESTUBO BASADO EN ...",False information,,1,0,0,[56968],"""discurso de pedro castillo, estubo basado en ...","[56968, 152115, 38000, 101078, 107780, 35886, ..."
8,"""Debemos ser Solidarios con los que menos tien...",Partly false information,,1,0,0,[148668],"""debemos ser solidarios con los que menos tien...","[148668, 197461, 102321, 102317, 102313, 10231..."
11,"""EL AGUINALDO DE 2 MILLONES NO ES SOLO PARA NO...",False information,,1,0,0,"[3332, 51854]","""el aguinaldo de 2 millones no es solo para no...","[51854, 3332, 195763, 37266, 6070, 55226, 8479..."


In [15]:
df_posts_dev["preds"] = compute_similar_posts(df_posts_dev.index, df_posts_dev, df_fc, similarity=prec_similarity_dev, k=10)[0].tolist()
print(df_posts_dev.apply(lambda x: len(list(set(x["gs"]).intersection(set(x["preds"])))) > 0, axis=1).mean())
df_posts_dev.head()

0.0


Unnamed: 0_level_0,ocr,verdicts,text,fb,tw,ig,gs,full_text,preds
post_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
48,"""Un pueblo que elige corruptos, impostores, la...",Partly false information,,1,0,0,[],"""un pueblo que elige corruptos, impostores, la...","[98618, 20499, 20498, 54015, 35924, 53032, 154..."
149,10:39 PM | 8.9kB/s ← Tweet Euquico [USER] f Gr...,Partly false information,,1,0,0,[],10:39 pm | 8.9kb/s ← tweet euquico [user] f gr...,"[39355, 50905, 89359, 137496, 48291, 74222, 50..."
280,"44 ANDRÉS FELIPE ARIAS NO SE HA FUGADO, ESTA E...",,,1,0,0,[],"44 andrés felipe arias no se ha fugado, esta e...","[74210, 150855, 112997, 113641, 140952, 118496..."
324,"8:26 PM 3 Resistencia Uruguay 154 miembros, 3 ...",False information,,1,0,0,[],"8:26 pm 3 resistencia uruguay 154 miembros, 3 ...","[80872, 80873, 52933, 80871, 82062, 84432, 908..."
326,95 CO DE OME BANCO DEL COLOM 50% SALARIO MÍNIM...,Partly false information,,1,0,0,[],95 co de ome banco del colom 50% salario mínim...,"[51436, 51433, 51434, 115426, 195833, 51197, 1..."


In [14]:
compute_similar_posts([48, 149], df_posts_dev, df_fc, similarity=prec_similarity_dev, show=True);

|   post_id | ocr                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               | verdicts                 | text   |   fb |   tw |   ig | gs   | full_text                                                                                                                                             

In [25]:
np.dot(emb_post[0], emb_fc[0]) / (np.linalg.norm(emb_post[0]) * np.linalg.norm(emb_fc[0]))

np.float32(0.24843584)

In [27]:
np.dot(emb_post[0], emb_fc[1]) / (np.linalg.norm(emb_post[0]) * np.linalg.norm(emb_fc[1]))

np.float32(0.20492055)

In [28]:
np.dot(emb_post[1], emb_fc[0]) / (np.linalg.norm(emb_post[1]) * np.linalg.norm(emb_fc[0]))

np.float32(0.3132233)