In [None]:
pip install deepsparse pandas transformers

Collecting deepsparse
  Downloading deepsparse-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting sparsezoo~=1.8.0 (from deepsparse)
  Downloading sparsezoo-1.8.1-py3-none-any.whl.metadata (20 kB)
Collecting onnx<1.15.0,>=1.5.0 (from deepsparse)
  Downloading onnx-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (15 kB)
Collecting py-machineid>=0.3.0 (from sparsezoo~=1.8.0->deepsparse)
  Downloading py_machineid-0.6.0-py3-none-any.whl.metadata (2.3 kB)
Collecting geocoder>=1.38.0 (from sparsezoo~=1.8.0->deepsparse)
  Downloading geocoder-1.38.1-py2.py3-none-any.whl.metadata (14 kB)
Collecting onnxruntime>=1.0.0 (from sparsezoo~=1.8.0->deepsparse)
  Downloading onnxruntime-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting future (from geocoder>=1.38.0->sparsezoo~=1.8.0->deepsparse)
  Downloading future-1.0.0-py3-none-any.whl.metadata (4.0 kB)
Collecting ratelim (from geocode

In [None]:
from deepsparse import compile_model

# SparseZoo stub naming the model to download and compile.
stub = "zoo:roberta-base-conll2003_wikipedia_bookcorpus-pruned85"
# Build a DeepSparse engine for the stub, compiled with batch_size=1.
engine = compile_model(stub, batch_size=1)
# Printing the engine shows its configuration (ONNX path, batch size, cores, ...).
print(engine)

Downloading Chunks for deployment.tar.gz: 100%|██████████| 189M/189M [00:02<00:00, 73.7MB/s]
Combining Chunks: 100%|██████████| 189M/189M [00:00<00:00, 641MB/s]


deepsparse.engine.Engine:
	onnx_file_path: /root/.cache/sparsezoo/neuralmagic/roberta-base-conll2003_wikipedia_bookcorpus-pruned85/deployment/model.onnx
	batch_size: 1
	num_cores: 48
	num_streams: 1
	scheduler: Scheduler.default
	fraction_of_supported_ops: 0.9825
	cpu_avx_type: avx512
	cpu_vnni: False


In [None]:
import pandas as pd
import pickle
import numpy as np
import torch
from deepsparse import Engine
from transformers import AutoTokenizer

In [None]:
# Load the news dataset TSV file
# Read the tab-separated news articles into a DataFrame
# (read_table uses a tab delimiter by default).
df = pd.read_table("news_min.tsv")

# Fetch the pretrained "roberta-base" tokenizer used to encode the article text
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



In [None]:
# Preview the first ten articles
df.head(n=10)

Unnamed: 0,News ID,Category,Topic,Headline,News body
0,N10000,sports,soccer,Predicting Atlanta United's lineup against Col...,"Only FIVE internationals allowed, count em, FI..."
1,N10001,news,newspolitics,Mitch McConnell: DC statehood push is 'full bo...,WASHINGTON -- Senate Majority Leader Mitch McC...
2,N10002,news,newsus,Home In North Highlands Damaged By Fire,NORTH HIGHLANDS (CBS13) Fire damaged a home ...
3,N10003,news,newspolitics,Meghan McCain blames 'liberal media' and 'thir...,Meghan McCain is speaking out after a journali...
4,N10004,news,newsworld,Today in History: Aug 1,"1714: George I becomes King Georg Ludwig, Elec..."
5,N10005,sports,football_nfl,Odell Beckham Jr New Custom Rolls Royce Cullinan,Odell Beckham Jr New Custom Rolls Royce Cullin...
6,N10006,autos,autosclassics,This Attention-Grabbing Chevrolet Malibu Packs...,This muscle car has the power to justify its d...
7,N10007,news,newspolitics,GOP senators urge Trump to reject Iran's 'nucl...,Three GOP senators this week urged President T...
8,N10008,foodanddrink,newstrends,5 great Mexican restaurants in Lebanon County,The latest from the Richland Borough Council M...
9,N10009,sports,football_nfl,NFL officials union head expects a lot of pres...,The NFL has formalized the language of the rul...


In [None]:
def get_embeddings(text, model=engine):
    """
    Run one text through the tokenizer and the DeepSparse engine.

    The text is coerced to ``str``, tokenized to exactly 128 tokens (padded or
    truncated), and the resulting ``input_ids`` and ``attention_mask`` are fed
    to ``model`` as int64 arrays.

    Args:
        text: The input text; any object is accepted and converted with ``str()``.
        model: Callable engine taking ``[input_ids, attention_mask]``. Defaults to
            the ``engine`` object that exists when this function is defined.

    Returns:
        torch.Tensor: The first model output for the first (only) batch item.

    NOTE(review): the tensors saved by this notebook have shape (128, 9), i.e. one
    row per token position rather than a single pooled vector, and the model stub
    mentions conll2003 -- these look like per-token classification scores, not
    sentence embeddings. Confirm this is the intended representation.
    """
    # Pad/truncate to 128 tokens so every input has the same fixed length
    encoded = tokenizer(
        str(text),
        return_tensors='np',
        padding='max_length',
        truncation=True,
        max_length=128,
    )

    # The engine is given int64 arrays, in the order input_ids, attention_mask
    model_inputs = [
        np.array(encoded[key], dtype=np.int64)
        for key in ('input_ids', 'attention_mask')
    ]

    outputs = model(model_inputs)

    # outputs[0] is the first output array; [0] selects the single batch item
    return torch.tensor(outputs[0][0])


In [None]:
from tqdm import tqdm
import os

In [None]:
# Pickle file holding partial results so an interrupted run can be picked up again
checkpoint_file = 'news_embeddings_checkpoint.pkl'
checkpoint_freq = 5000  # Save a checkpoint every 5,000 rows

In [None]:
# Restore partial results from an earlier, interrupted run when a checkpoint exists.
# NOTE: pickle.load can execute arbitrary code -- only load checkpoint files that
# were written by this notebook.
if os.path.exists(checkpoint_file):
    # The checkpoint stores (embeddings so far, index of the last row it covers)
    with open(checkpoint_file, 'rb') as f:
        news_embeddings_dict, last_processed_index = pickle.load(f)
    print(f"Resuming from checkpoint. Last processed row: {last_processed_index}")
else:
    # Start from scratch if no checkpoint is found
    news_embeddings_dict = {}
    last_processed_index = -1

# Iterate through each row in the dataset
for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing news articles"):
    # Rows up to and including the checkpointed index already have an entry in
    # news_embeddings_dict, so skip them instead of recomputing their embeddings.
    # (Assumes df keeps the default increasing integer index from read_csv.)
    if index <= last_processed_index:
        continue

    news_id = row['News ID']      # Key for the embeddings dictionary
    news_body = row['News body']  # Text that gets embedded

    # Compute and store the embedding under its news ID
    news_embeddings_dict[news_id] = get_embeddings(news_body)

    # Periodically persist progress together with the row index it covers
    if index > 0 and index % checkpoint_freq == 0:
        with open(checkpoint_file, 'wb') as f:
            pickle.dump((news_embeddings_dict, index), f)
        print(f"Checkpoint saved at row: {index}")

# Save the complete embeddings dictionary as a .pkl file
with open('news_embeddings.pkl', 'wb') as f:
    pickle.dump(news_embeddings_dict, f)

print("Embeddings have been saved to 'news_embeddings.pkl'.")

Processing news articles:   0%|          | 507/113762 [00:09<35:29, 53.19it/s]

Checkpoint saved at row: 500


Processing news articles:   1%|          | 1010/113762 [00:20<33:23, 56.29it/s]

Checkpoint saved at row: 1000


Processing news articles:   1%|▏         | 1508/113762 [00:31<53:06, 35.23it/s]  

Checkpoint saved at row: 1500


Processing news articles:   2%|▏         | 2004/113762 [00:40<44:07, 42.21it/s]

Checkpoint saved at row: 2000


Processing news articles:   2%|▏         | 2509/113762 [00:50<41:12, 44.99it/s]

Checkpoint saved at row: 2500


Processing news articles:   3%|▎         | 3008/113762 [01:01<46:45, 39.47it/s]

Checkpoint saved at row: 3000


Processing news articles:   3%|▎         | 3512/113762 [01:14<49:57, 36.78it/s]

Checkpoint saved at row: 3500


Processing news articles:   4%|▎         | 4013/113762 [01:23<43:15, 42.28it/s]

Checkpoint saved at row: 4000


Processing news articles:   4%|▍         | 4509/113762 [01:34<54:18, 33.53it/s]  

Checkpoint saved at row: 4500


Processing news articles:   4%|▍         | 5007/113762 [01:45<56:38, 32.00it/s]

Checkpoint saved at row: 5000


Processing news articles:   5%|▍         | 5508/113762 [01:55<1:16:44, 23.51it/s]

Checkpoint saved at row: 5500


Processing news articles:   5%|▌         | 6010/113762 [02:05<53:13, 33.74it/s]  

Checkpoint saved at row: 6000


Processing news articles:   6%|▌         | 6506/113762 [02:16<1:23:40, 21.36it/s]

Checkpoint saved at row: 6500


Processing news articles:   6%|▌         | 7013/113762 [02:27<1:05:32, 27.14it/s]

Checkpoint saved at row: 7000


Processing news articles:   7%|▋         | 7510/113762 [02:37<1:00:34, 29.23it/s]

Checkpoint saved at row: 7500


Processing news articles:   7%|▋         | 8001/113762 [02:49<1:29:23, 19.72it/s]

Checkpoint saved at row: 8000


Processing news articles:   7%|▋         | 8511/113762 [03:00<1:03:27, 27.64it/s]

Checkpoint saved at row: 8500


Processing news articles:   8%|▊         | 9009/113762 [03:11<1:07:39, 25.81it/s]

Checkpoint saved at row: 9000


Processing news articles:   8%|▊         | 9513/113762 [03:21<1:11:55, 24.16it/s]

Checkpoint saved at row: 9500


Processing news articles:   9%|▉         | 10008/113762 [03:32<59:54, 28.87it/s]  

Checkpoint saved at row: 10000


Processing news articles:   9%|▉         | 10509/113762 [03:43<1:06:42, 25.79it/s]

Checkpoint saved at row: 10500


Processing news articles:  10%|▉         | 11007/113762 [03:55<1:13:31, 23.29it/s]

Checkpoint saved at row: 11000


Processing news articles:  10%|█         | 11513/113762 [04:06<1:14:21, 22.92it/s]

Checkpoint saved at row: 11500


Processing news articles:  11%|█         | 12010/113762 [04:16<1:23:26, 20.32it/s]

Checkpoint saved at row: 12000


Processing news articles:  11%|█         | 12511/113762 [04:28<1:27:35, 19.27it/s]

Checkpoint saved at row: 12500


Processing news articles:  11%|█▏        | 13005/113762 [04:40<1:38:08, 17.11it/s]

Checkpoint saved at row: 13000


Processing news articles:  12%|█▏        | 13508/113762 [04:52<2:14:40, 12.41it/s]

Checkpoint saved at row: 13500


Processing news articles:  12%|█▏        | 14003/113762 [05:04<2:26:49, 11.32it/s]

Checkpoint saved at row: 14000


Processing news articles:  13%|█▎        | 14495/113762 [05:13<25:24, 65.11it/s]

Checkpoint saved at row: 14500


Processing news articles:  13%|█▎        | 15008/113762 [05:28<1:24:32, 19.47it/s]

Checkpoint saved at row: 15000


Processing news articles:  14%|█▎        | 15512/113762 [05:40<1:26:28, 18.94it/s]

Checkpoint saved at row: 15500


Processing news articles:  14%|█▍        | 16008/113762 [05:52<1:21:55, 19.89it/s]

Checkpoint saved at row: 16000


Processing news articles:  15%|█▍        | 16509/113762 [06:04<2:01:44, 13.31it/s]

Checkpoint saved at row: 16500


Processing news articles:  15%|█▍        | 17003/113762 [06:15<2:45:18,  9.76it/s]

Checkpoint saved at row: 17000


Processing news articles:  15%|█▌        | 17499/113762 [06:26<29:10, 54.99it/s]

Checkpoint saved at row: 17500


Processing news articles:  16%|█▌        | 18009/113762 [06:40<1:36:01, 16.62it/s]

Checkpoint saved at row: 18000


Processing news articles:  16%|█▋        | 18505/113762 [06:52<2:22:12, 11.16it/s]

Checkpoint saved at row: 18500


Processing news articles:  17%|█▋        | 19010/113762 [07:05<1:50:45, 14.26it/s]

Checkpoint saved at row: 19000


Processing news articles:  17%|█▋        | 19509/113762 [07:18<1:37:00, 16.19it/s]

Checkpoint saved at row: 19500


Processing news articles:  18%|█▊        | 20008/113762 [07:31<1:43:18, 15.12it/s]

Checkpoint saved at row: 20000


Processing news articles:  18%|█▊        | 20508/113762 [07:43<1:45:38, 14.71it/s]

Checkpoint saved at row: 20500


Processing news articles:  18%|█▊        | 21006/113762 [07:53<2:19:36, 11.07it/s]

Checkpoint saved at row: 21000


Processing news articles:  19%|█▉        | 21512/113762 [08:05<1:49:07, 14.09it/s]

Checkpoint saved at row: 21500


Processing news articles:  19%|█▉        | 22011/113762 [08:18<1:38:48, 15.48it/s]

Checkpoint saved at row: 22000


Processing news articles:  20%|█▉        | 22509/113762 [08:31<2:04:04, 12.26it/s]

Checkpoint saved at row: 22500


Processing news articles:  20%|██        | 23009/113762 [08:44<1:58:53, 12.72it/s]

Checkpoint saved at row: 23000


Processing news articles:  21%|██        | 23501/113762 [08:56<2:39:37,  9.42it/s]

Checkpoint saved at row: 23500


Processing news articles:  21%|██        | 24008/113762 [09:09<2:20:14, 10.67it/s]

Checkpoint saved at row: 24000


Processing news articles:  22%|██▏       | 24508/113762 [09:21<1:59:41, 12.43it/s]

Checkpoint saved at row: 24500


Processing news articles:  22%|██▏       | 25006/113762 [09:32<2:03:39, 11.96it/s]

Checkpoint saved at row: 25000


Processing news articles:  22%|██▏       | 25509/113762 [09:44<1:59:09, 12.34it/s]

Checkpoint saved at row: 25500


Processing news articles:  23%|██▎       | 26008/113762 [09:59<2:19:15, 10.50it/s]

Checkpoint saved at row: 26000


Processing news articles:  23%|██▎       | 26505/113762 [10:13<2:31:29,  9.60it/s]

Checkpoint saved at row: 26500


Processing news articles:  24%|██▎       | 27005/113762 [10:27<2:46:47,  8.67it/s]

Checkpoint saved at row: 27000


Processing news articles:  24%|██▍       | 27509/113762 [10:40<2:23:09, 10.04it/s]

Checkpoint saved at row: 27500


Processing news articles:  25%|██▍       | 28006/113762 [10:55<2:11:21, 10.88it/s]

Checkpoint saved at row: 28000


Processing news articles:  25%|██▌       | 28508/113762 [11:09<2:03:29, 11.51it/s]

Checkpoint saved at row: 28500


Processing news articles:  25%|██▌       | 29008/113762 [11:23<2:13:48, 10.56it/s]

Checkpoint saved at row: 29000


Processing news articles:  26%|██▌       | 29507/113762 [11:37<2:13:18, 10.53it/s]

Checkpoint saved at row: 29500


Processing news articles:  26%|██▋       | 30001/113762 [11:51<2:45:51,  8.42it/s]

Checkpoint saved at row: 30000


Processing news articles:  27%|██▋       | 30501/113762 [12:06<2:55:57,  7.89it/s]

Checkpoint saved at row: 30500


Processing news articles:  27%|██▋       | 30999/113762 [12:18<22:54, 60.22it/s]

Checkpoint saved at row: 31000


Processing news articles:  28%|██▊       | 31505/113762 [12:35<3:13:44,  7.08it/s]

Checkpoint saved at row: 31500


Processing news articles:  28%|██▊       | 32007/113762 [12:48<2:13:52, 10.18it/s]

Checkpoint saved at row: 32000


Processing news articles:  29%|██▊       | 32507/113762 [13:01<2:11:33, 10.29it/s]

Checkpoint saved at row: 32500


Processing news articles:  29%|██▉       | 33007/113762 [13:15<2:25:59,  9.22it/s]

Checkpoint saved at row: 33000


Processing news articles:  29%|██▉       | 33511/113762 [13:29<2:14:29,  9.94it/s]

Checkpoint saved at row: 33500


Processing news articles:  30%|██▉       | 34012/113762 [13:43<2:11:03, 10.14it/s]

Checkpoint saved at row: 34000


Processing news articles:  30%|███       | 34508/113762 [13:57<2:23:10,  9.23it/s]

Checkpoint saved at row: 34500


Processing news articles:  31%|███       | 35011/113762 [14:13<2:33:38,  8.54it/s]

Checkpoint saved at row: 35000


Processing news articles:  31%|███       | 35510/113762 [14:28<2:22:04,  9.18it/s]

Checkpoint saved at row: 35500


Processing news articles:  32%|███▏      | 36003/113762 [14:42<3:17:55,  6.55it/s]

Checkpoint saved at row: 36000


Processing news articles:  32%|███▏      | 36507/113762 [14:56<2:21:14,  9.12it/s]

Checkpoint saved at row: 36500


Processing news articles:  33%|███▎      | 37007/113762 [15:10<2:53:36,  7.37it/s]

Checkpoint saved at row: 37000


Processing news articles:  33%|███▎      | 37501/113762 [15:26<3:31:56,  6.00it/s]

Checkpoint saved at row: 37500


Processing news articles:  33%|███▎      | 38011/113762 [15:41<2:24:39,  8.73it/s]

Checkpoint saved at row: 38000


Processing news articles:  34%|███▍      | 38506/113762 [15:55<3:02:50,  6.86it/s]

Checkpoint saved at row: 38500


Processing news articles:  34%|███▍      | 39002/113762 [16:09<3:04:58,  6.74it/s]

Checkpoint saved at row: 39000


Processing news articles:  35%|███▍      | 39506/113762 [16:23<3:08:38,  6.56it/s]

Checkpoint saved at row: 39500


Processing news articles:  35%|███▌      | 40007/113762 [16:37<2:57:16,  6.93it/s]

Checkpoint saved at row: 40000


Processing news articles:  36%|███▌      | 40505/113762 [16:50<3:21:00,  6.07it/s]

Checkpoint saved at row: 40500


Processing news articles:  36%|███▌      | 41012/113762 [17:04<2:19:17,  8.70it/s]

Checkpoint saved at row: 41000


Processing news articles:  36%|███▋      | 41511/113762 [17:18<2:36:22,  7.70it/s]

Checkpoint saved at row: 41500


Processing news articles:  37%|███▋      | 42005/113762 [17:32<2:40:09,  7.47it/s]

Checkpoint saved at row: 42000


Processing news articles:  37%|███▋      | 42502/113762 [17:46<3:39:47,  5.40it/s]

Checkpoint saved at row: 42500


Processing news articles:  38%|███▊      | 43006/113762 [18:00<3:34:03,  5.51it/s]

Checkpoint saved at row: 43000


Processing news articles:  38%|███▊      | 43508/113762 [18:14<2:18:37,  8.45it/s]

Checkpoint saved at row: 43500


Processing news articles:  39%|███▊      | 44009/113762 [18:28<2:39:52,  7.27it/s]

Checkpoint saved at row: 44000


Processing news articles:  39%|███▉      | 44507/113762 [18:45<3:01:56,  6.34it/s]

Checkpoint saved at row: 44500


Processing news articles:  40%|███▉      | 45008/113762 [19:01<3:06:01,  6.16it/s]

Checkpoint saved at row: 45000


Processing news articles:  40%|████      | 45507/113762 [19:16<2:39:44,  7.12it/s]

Checkpoint saved at row: 45500


Processing news articles:  40%|████      | 46013/113762 [19:31<2:22:19,  7.93it/s]

Checkpoint saved at row: 46000


Processing news articles:  41%|████      | 46508/113762 [19:46<2:52:18,  6.51it/s]

Checkpoint saved at row: 46500


Processing news articles:  41%|████▏     | 47010/113762 [20:03<2:39:26,  6.98it/s]

Checkpoint saved at row: 47000


Processing news articles:  42%|████▏     | 47507/113762 [20:18<2:47:43,  6.58it/s]

Checkpoint saved at row: 47500


Processing news articles:  42%|████▏     | 48006/113762 [20:33<2:46:52,  6.57it/s]

Checkpoint saved at row: 48000


Processing news articles:  43%|████▎     | 48510/113762 [20:49<2:22:15,  7.64it/s]

Checkpoint saved at row: 48500


Processing news articles:  43%|████▎     | 49011/113762 [21:04<2:40:41,  6.72it/s]

Checkpoint saved at row: 49000


Processing news articles:  44%|████▎     | 49506/113762 [21:17<3:21:51,  5.31it/s]

Checkpoint saved at row: 49500


Processing news articles:  44%|████▍     | 50005/113762 [21:31<3:18:25,  5.36it/s]

Checkpoint saved at row: 50000


Processing news articles:  44%|████▍     | 50512/113762 [21:45<2:28:17,  7.11it/s]

Checkpoint saved at row: 50500


Processing news articles:  45%|████▍     | 51012/113762 [21:59<2:24:11,  7.25it/s]

Checkpoint saved at row: 51000


Processing news articles:  45%|████▌     | 51511/113762 [22:13<2:34:47,  6.70it/s]

Checkpoint saved at row: 51500


Processing news articles:  46%|████▌     | 52007/113762 [22:27<2:53:18,  5.94it/s]

Checkpoint saved at row: 52000


Processing news articles:  46%|████▌     | 52509/113762 [22:42<2:40:56,  6.34it/s]

Checkpoint saved at row: 52500


Processing news articles:  47%|████▋     | 53009/113762 [22:58<2:46:48,  6.07it/s]

Checkpoint saved at row: 53000


Processing news articles:  47%|████▋     | 53507/113762 [23:15<2:35:23,  6.46it/s]

Checkpoint saved at row: 53500


Processing news articles:  47%|████▋     | 54005/113762 [23:30<3:38:41,  4.55it/s]

Checkpoint saved at row: 54000


Processing news articles:  48%|████▊     | 54508/113762 [23:46<2:28:14,  6.66it/s]

Checkpoint saved at row: 54500


Processing news articles:  48%|████▊     | 55007/113762 [24:02<2:48:15,  5.82it/s]

Checkpoint saved at row: 55000


Processing news articles:  49%|████▉     | 55506/113762 [24:16<3:18:44,  4.89it/s]

Checkpoint saved at row: 55500


Processing news articles:  49%|████▉     | 56007/113762 [24:33<2:58:13,  5.40it/s]

Checkpoint saved at row: 56000


Processing news articles:  50%|████▉     | 56506/113762 [24:48<3:28:55,  4.57it/s]

Checkpoint saved at row: 56500


Processing news articles:  50%|█████     | 57007/113762 [25:05<2:59:46,  5.26it/s]

Checkpoint saved at row: 57000


Processing news articles:  51%|█████     | 57506/113762 [25:20<3:18:59,  4.71it/s]

Checkpoint saved at row: 57500


Processing news articles:  51%|█████     | 58010/113762 [25:37<2:29:09,  6.23it/s]

Checkpoint saved at row: 58000


Processing news articles:  51%|█████▏    | 58508/113762 [25:53<2:34:02,  5.98it/s]

Checkpoint saved at row: 58500


Processing news articles:  52%|█████▏    | 59008/113762 [26:10<2:29:57,  6.09it/s]

Checkpoint saved at row: 59000


Processing news articles:  52%|█████▏    | 59508/113762 [26:24<2:16:53,  6.61it/s]

Checkpoint saved at row: 59500


Processing news articles:  53%|█████▎    | 60011/113762 [26:40<2:50:11,  5.26it/s]

Checkpoint saved at row: 60000


Processing news articles:  53%|█████▎    | 60506/113762 [26:57<2:37:26,  5.64it/s]

Checkpoint saved at row: 60500


Processing news articles:  54%|█████▎    | 61011/113762 [27:12<2:34:42,  5.68it/s]

Checkpoint saved at row: 61000


Processing news articles:  54%|█████▍    | 61510/113762 [27:28<2:28:15,  5.87it/s]

Checkpoint saved at row: 61500


Processing news articles:  54%|█████▍    | 62000/113762 [27:41<16:41, 51.68it/s]

Checkpoint saved at row: 62000


Processing news articles:  55%|█████▍    | 62508/113762 [28:02<2:39:38,  5.35it/s]

Checkpoint saved at row: 62500


Processing news articles:  55%|█████▌    | 63009/113762 [28:20<2:43:26,  5.18it/s]

Checkpoint saved at row: 63000


Processing news articles:  56%|█████▌    | 63512/113762 [28:36<2:25:26,  5.76it/s]

Checkpoint saved at row: 63500


Processing news articles:  56%|█████▋    | 64012/113762 [28:53<2:21:00,  5.88it/s]

Checkpoint saved at row: 64000


Processing news articles:  57%|█████▋    | 64510/113762 [29:10<2:36:07,  5.26it/s]

Checkpoint saved at row: 64500


Processing news articles:  57%|█████▋    | 65007/113762 [29:25<2:26:41,  5.54it/s]

Checkpoint saved at row: 65000


Processing news articles:  58%|█████▊    | 65505/113762 [29:40<3:52:55,  3.45it/s]

Checkpoint saved at row: 65500


Processing news articles:  58%|█████▊    | 66004/113762 [29:55<3:11:51,  4.15it/s]

Checkpoint saved at row: 66000


Processing news articles:  58%|█████▊    | 66499/113762 [30:07<16:29, 47.79it/s]

Checkpoint saved at row: 66500


Processing news articles:  59%|█████▉    | 67008/113762 [30:29<2:29:06,  5.23it/s]

Checkpoint saved at row: 67000


Processing news articles:  59%|█████▉    | 67504/113762 [30:46<3:17:48,  3.90it/s]

Checkpoint saved at row: 67500


Processing news articles:  60%|█████▉    | 68010/113762 [31:03<2:14:46,  5.66it/s]

Checkpoint saved at row: 68000


Processing news articles:  60%|██████    | 68508/113762 [31:19<2:58:26,  4.23it/s]

Checkpoint saved at row: 68500


Processing news articles:  61%|██████    | 69006/113762 [31:35<3:06:28,  4.00it/s]

Checkpoint saved at row: 69000


Processing news articles:  61%|██████    | 69501/113762 [31:52<3:36:31,  3.41it/s]

Checkpoint saved at row: 69500


Processing news articles:  62%|██████▏   | 70008/113762 [32:09<2:17:00,  5.32it/s]

Checkpoint saved at row: 70000


Processing news articles:  62%|██████▏   | 70508/113762 [32:26<2:34:27,  4.67it/s]

Checkpoint saved at row: 70500


Processing news articles:  62%|██████▏   | 71007/113762 [32:41<2:36:44,  4.55it/s]

Checkpoint saved at row: 71000


Processing news articles:  63%|██████▎   | 71510/113762 [32:58<2:26:27,  4.81it/s]

Checkpoint saved at row: 71500


Processing news articles:  63%|██████▎   | 72010/113762 [33:15<2:17:33,  5.06it/s]

Checkpoint saved at row: 72000


Processing news articles:  64%|██████▎   | 72506/113762 [33:35<3:40:04,  3.12it/s]

Checkpoint saved at row: 72500


Processing news articles:  64%|██████▍   | 73008/113762 [33:51<2:27:59,  4.59it/s]

Checkpoint saved at row: 73000


Processing news articles:  65%|██████▍   | 73508/113762 [34:07<3:12:33,  3.48it/s]

Checkpoint saved at row: 73500


Processing news articles:  65%|██████▌   | 73999/113762 [34:19<12:15, 54.09it/s]

Checkpoint saved at row: 74000


Processing news articles:  65%|██████▌   | 74500/113762 [34:38<14:00, 46.71it/s]

Checkpoint saved at row: 74500


Processing news articles:  66%|██████▌   | 75007/113762 [35:00<2:16:02,  4.75it/s]

Checkpoint saved at row: 75000


Processing news articles:  66%|██████▋   | 75507/113762 [35:19<2:23:22,  4.45it/s]

Checkpoint saved at row: 75500


Processing news articles:  67%|██████▋   | 76002/113762 [35:37<3:32:18,  2.96it/s]

Checkpoint saved at row: 76000


Processing news articles:  67%|██████▋   | 76512/113762 [35:55<2:12:33,  4.68it/s]

Checkpoint saved at row: 76500


Processing news articles:  68%|██████▊   | 77007/113762 [36:10<2:11:49,  4.65it/s]

Checkpoint saved at row: 77000


Processing news articles:  68%|██████▊   | 77507/113762 [36:27<2:42:51,  3.71it/s]

Checkpoint saved at row: 77500


Processing news articles:  69%|██████▊   | 78008/113762 [36:45<2:19:47,  4.26it/s]

Checkpoint saved at row: 78000


Processing news articles:  69%|██████▉   | 78506/113762 [37:02<3:01:16,  3.24it/s]

Checkpoint saved at row: 78500


Processing news articles:  69%|██████▉   | 79008/113762 [37:18<2:07:26,  4.55it/s]

Checkpoint saved at row: 79000


Processing news articles:  70%|██████▉   | 79504/113762 [37:36<4:19:52,  2.20it/s]

Checkpoint saved at row: 79500


Processing news articles:  70%|███████   | 80008/113762 [37:54<1:58:39,  4.74it/s]

Checkpoint saved at row: 80000


Processing news articles:  71%|███████   | 80505/113762 [38:12<3:01:54,  3.05it/s]

Checkpoint saved at row: 80500


Processing news articles:  71%|███████   | 81006/113762 [38:29<2:45:56,  3.29it/s]

Checkpoint saved at row: 81000


Processing news articles:  72%|███████▏  | 81509/113762 [38:46<2:10:24,  4.12it/s]

Checkpoint saved at row: 81500


Processing news articles:  72%|███████▏  | 82008/113762 [39:04<2:24:55,  3.65it/s]

Checkpoint saved at row: 82000


Processing news articles:  73%|███████▎  | 82508/113762 [39:22<1:57:19,  4.44it/s]

Checkpoint saved at row: 82500


Processing news articles:  73%|███████▎  | 83006/113762 [39:38<2:26:20,  3.50it/s]

Checkpoint saved at row: 83000


Processing news articles:  73%|███████▎  | 83505/113762 [39:54<3:02:18,  2.77it/s]

Checkpoint saved at row: 83500


Processing news articles:  74%|███████▍  | 84011/113762 [40:13<1:58:21,  4.19it/s]

Checkpoint saved at row: 84000


Processing news articles:  74%|███████▍  | 84510/113762 [40:31<1:49:08,  4.47it/s]

Checkpoint saved at row: 84500


Processing news articles:  75%|███████▍  | 85001/113762 [40:49<2:34:48,  3.10it/s]

Checkpoint saved at row: 85000


Processing news articles:  75%|███████▌  | 85506/113762 [41:06<2:08:09,  3.67it/s]

Checkpoint saved at row: 85500


Processing news articles:  76%|███████▌  | 86005/113762 [41:25<2:08:10,  3.61it/s]

Checkpoint saved at row: 86000


Processing news articles:  76%|███████▌  | 86508/113762 [41:43<1:51:02,  4.09it/s]

Checkpoint saved at row: 86500


Processing news articles:  76%|███████▋  | 87005/113762 [41:59<2:51:24,  2.60it/s]

Checkpoint saved at row: 87000


Processing news articles:  77%|███████▋  | 87511/113762 [42:18<1:44:52,  4.17it/s]

Checkpoint saved at row: 87500


Processing news articles:  77%|███████▋  | 88008/113762 [42:36<1:32:15,  4.65it/s]

Checkpoint saved at row: 88000


Processing news articles:  78%|███████▊  | 88507/113762 [42:54<1:49:39,  3.84it/s]

Checkpoint saved at row: 88500


Processing news articles:  78%|███████▊  | 89007/113762 [43:13<1:47:01,  3.86it/s]

Checkpoint saved at row: 89000


Processing news articles:  79%|███████▊  | 89508/113762 [43:31<1:53:10,  3.57it/s]

Checkpoint saved at row: 89500


Processing news articles:  79%|███████▉  | 90007/113762 [43:49<1:34:04,  4.21it/s]

Checkpoint saved at row: 90000


Processing news articles:  80%|███████▉  | 90509/113762 [44:09<2:05:16,  3.09it/s]

Checkpoint saved at row: 90500


Processing news articles:  80%|███████▉  | 91008/113762 [44:27<1:32:00,  4.12it/s]

Checkpoint saved at row: 91000


Processing news articles:  80%|████████  | 91509/113762 [44:45<1:28:39,  4.18it/s]

Checkpoint saved at row: 91500


Processing news articles:  81%|████████  | 92009/113762 [45:03<1:25:37,  4.23it/s]

Checkpoint saved at row: 92000


Processing news articles:  81%|████████▏ | 92507/113762 [45:23<1:26:50,  4.08it/s]

Checkpoint saved at row: 92500


Processing news articles:  82%|████████▏ | 93006/113762 [45:40<2:02:44,  2.82it/s]

Checkpoint saved at row: 93000


Processing news articles:  82%|████████▏ | 93506/113762 [45:59<1:49:39,  3.08it/s]

Checkpoint saved at row: 93500


Processing news articles:  83%|████████▎ | 94007/113762 [46:18<1:24:09,  3.91it/s]

Checkpoint saved at row: 94000


Processing news articles:  83%|████████▎ | 94506/113762 [46:36<1:33:28,  3.43it/s]

Checkpoint saved at row: 94500


Processing news articles:  84%|████████▎ | 95007/113762 [46:55<1:17:44,  4.02it/s]

Checkpoint saved at row: 95000


Processing news articles:  84%|████████▍ | 95507/113762 [47:13<1:15:01,  4.06it/s]

Checkpoint saved at row: 95500


Processing news articles:  84%|████████▍ | 96007/113762 [47:30<1:13:17,  4.04it/s]

Checkpoint saved at row: 96000


Processing news articles:  85%|████████▍ | 96507/113762 [47:49<1:22:13,  3.50it/s]

Checkpoint saved at row: 96500


Processing news articles:  85%|████████▌ | 97006/113762 [48:08<1:39:01,  2.82it/s]

Checkpoint saved at row: 97000


Processing news articles:  86%|████████▌ | 97512/113762 [48:27<1:09:09,  3.92it/s]

Checkpoint saved at row: 97500


Processing news articles:  86%|████████▌ | 98006/113762 [48:46<1:16:23,  3.44it/s]

Checkpoint saved at row: 98000


Processing news articles:  87%|████████▋ | 98506/113762 [49:05<1:07:17,  3.78it/s]

Checkpoint saved at row: 98500


Processing news articles:  87%|████████▋ | 99010/113762 [49:22<1:09:31,  3.54it/s]

Checkpoint saved at row: 99000


Processing news articles:  87%|████████▋ | 99505/113762 [49:42<1:24:11,  2.82it/s]

Checkpoint saved at row: 99500


Processing news articles:  88%|████████▊ | 100009/113762 [50:00<1:00:21,  3.80it/s]

Checkpoint saved at row: 100000


Processing news articles:  88%|████████▊ | 100508/113762 [50:18<59:56,  3.69it/s]  

Checkpoint saved at row: 100500


Processing news articles:  89%|████████▉ | 101001/113762 [50:38<1:35:15,  2.23it/s]

Checkpoint saved at row: 101000


Processing news articles:  89%|████████▉ | 101500/113762 [50:51<03:47, 53.79it/s]

Checkpoint saved at row: 101500


Processing news articles:  90%|████████▉ | 102004/113762 [51:17<1:21:05,  2.42it/s]

Checkpoint saved at row: 102000


Processing news articles:  90%|█████████ | 102502/113762 [51:38<1:19:37,  2.36it/s]

Checkpoint saved at row: 102500


Processing news articles:  91%|█████████ | 103004/113762 [51:56<1:18:16,  2.29it/s]

Checkpoint saved at row: 103000


Processing news articles:  91%|█████████ | 103499/113762 [52:08<04:44, 36.09it/s]

Checkpoint saved at row: 103500


Processing news articles:  91%|█████████▏| 104006/113762 [52:35<1:04:41,  2.51it/s]

Checkpoint saved at row: 104000


Processing news articles:  92%|█████████▏| 104503/113762 [52:55<1:06:20,  2.33it/s]

Checkpoint saved at row: 104500


Processing news articles:  92%|█████████▏| 105003/113762 [53:14<54:14,  2.69it/s]

Checkpoint saved at row: 105000


Processing news articles:  93%|█████████▎| 105510/113762 [53:34<48:36,  2.83it/s]  

Checkpoint saved at row: 105500


Processing news articles:  93%|█████████▎| 106002/113762 [53:53<49:00,  2.64it/s]

Checkpoint saved at row: 106000


Processing news articles:  94%|█████████▎| 106508/113762 [54:11<31:18,  3.86it/s]

Checkpoint saved at row: 106500


Processing news articles:  94%|█████████▍| 107007/113762 [54:31<46:50,  2.40it/s]  

Checkpoint saved at row: 107000


Processing news articles:  95%|█████████▍| 107510/113762 [54:51<30:11,  3.45it/s]

Checkpoint saved at row: 107500


Processing news articles:  95%|█████████▍| 108002/113762 [55:09<47:21,  2.03it/s]

Checkpoint saved at row: 108000


Processing news articles:  95%|█████████▌| 108503/113762 [55:30<33:53,  2.59it/s]

Checkpoint saved at row: 108500


Processing news articles:  96%|█████████▌| 109009/113762 [55:50<23:26,  3.38it/s]

Checkpoint saved at row: 109000


Processing news articles:  96%|█████████▋| 109508/113762 [56:10<20:30,  3.46it/s]

Checkpoint saved at row: 109500


Processing news articles:  97%|█████████▋| 110007/113762 [56:29<17:20,  3.61it/s]

Checkpoint saved at row: 110000


Processing news articles:  97%|█████████▋| 110500/113762 [56:42<01:30, 36.00it/s]

Checkpoint saved at row: 110500


Processing news articles:  98%|█████████▊| 111006/113762 [57:11<18:20,  2.51it/s]

Checkpoint saved at row: 111000


Processing news articles:  98%|█████████▊| 111510/113762 [57:32<10:35,  3.54it/s]

Checkpoint saved at row: 111500


Processing news articles:  98%|█████████▊| 112005/113762 [57:52<09:56,  2.95it/s]

Checkpoint saved at row: 112000


Processing news articles:  99%|█████████▉| 112504/113762 [58:11<08:05,  2.59it/s]

Checkpoint saved at row: 112500


Processing news articles:  99%|█████████▉| 113000/113762 [58:24<00:12, 60.21it/s]

Checkpoint saved at row: 113000


Processing news articles: 100%|█████████▉| 113511/113762 [58:51<01:17,  3.22it/s]

Checkpoint saved at row: 113500


Processing news articles: 100%|██████████| 113762/113762 [58:58<00:00, 32.15it/s]


Embeddings have been saved to 'news_embeddings.pkl'.


In [None]:
pickle_file_path = '/content/news_embeddings.pkl'

# The final pickle holds only the {news_id: embedding} dictionary; unlike the
# checkpoint file it carries no "last processed index".
with open(pickle_file_path, 'rb') as f:
    news_embeddings_dict = pickle.load(f)

# Nothing to restore from this file, so the index is simply reset to its
# "nothing processed" sentinel before being reported.
last_processed_index = -1

print(f"Last processed index: {last_processed_index}")
print("First 10 Embeddings from the dictionary:\n")

# Show the first ten (news_id, embedding) pairs in insertion order
first_ten = list(news_embeddings_dict.items())[:10]
for news_id, embedding in first_ten:
    print(f"News ID: {news_id}")
    print(f"Embedding (shape: {embedding.shape}): {embedding}\n")

  return torch.load(io.BytesIO(b))


Last processed index: -1
First 10 Embeddings from the dictionary:

News ID: N10000
Embedding (shape: torch.Size([128, 9])): tensor([[ 1.0791, -0.0520, -0.6545,  ..., -0.8640,  0.2703, -0.8104],
        [10.2101, -1.1307, -2.1981,  ..., -2.1493, -1.5064, -2.2739],
        [ 6.6266, -0.4424, -3.0224,  ..., -2.8820,  1.3735, -2.1363],
        ...,
        [10.3190, -1.6721, -2.4339,  ..., -2.0109, -1.7192, -1.7840],
        [ 1.1526,  1.0283, -2.8732,  ..., -3.0361,  0.5396, -2.9618],
        [ 1.0785, -0.0523, -0.6550,  ..., -0.8642,  0.2719, -0.8099]])

News ID: N10001
Embedding (shape: torch.Size([128, 9])): tensor([[ 1.0557, -0.0509, -0.6470,  ..., -0.8223,  0.2251, -0.8401],
        [-0.0704,  0.1551, -2.5323,  ..., -0.2067, -1.5623, -3.3790],
        [10.2061, -1.5499, -2.3800,  ..., -1.8061, -1.8805, -2.1766],
        ...,
        [10.5437, -1.2833, -2.3469,  ..., -2.1078, -1.5302, -2.4848],
        [10.6229, -1.5746, -2.2555,  ..., -1.8906, -1.7168, -2.2884],
        [ 1.0545, -0.

In [None]:
def distance(embedding1, embedding2):
    """Return the mean squared element-wise difference between two embeddings.

    Both inputs are converted to NumPy arrays before any arithmetic. The
    embeddings stored in this notebook are torch tensors, and calling
    ``np.mean`` directly on a tensor forwards NumPy-style keyword arguments
    (``axis``/``dtype``/``out``) to ``Tensor.mean``, which PyTorch rejects.

    Args:
        embedding1: Array-like or tensor-like embedding.
        embedding2: Array-like or tensor-like embedding whose shape is equal
            to, or broadcastable against, ``embedding1``.

    Returns:
        A NumPy scalar: the mean of the squared element-wise differences.
    """
    def _to_numpy(embedding):
        # Tensor-like objects expose detach(); detach and move to CPU so the
        # conversion also works for tensors that track gradients.
        if hasattr(embedding, "detach"):
            embedding = embedding.detach().cpu().numpy()
        return np.asarray(embedding)

    squared_diff = np.square(_to_numpy(embedding1) - _to_numpy(embedding2))
    return np.mean(squared_diff)

In [None]:
def divergence(Dt1, Dt2, Ut1, Ut2):
    """Compute the divergence ratio for two consecutive time steps.

    The value is::

        (distance(Dt1, Ut1) - distance(Dt2, Ut2))
        / (distance(Dt1, Dt2) - distance(Ut1, Ut2))

    Args:
        Dt1: Document embedding at the first time step.
        Dt2: Document embedding at the second time step.
        Ut1: User embedding at the first time step.
        Ut2: User embedding at the second time step.

    Returns:
        The ratio above, in whatever numeric type ``distance`` produces.

    Note:
        The denominator is not guarded: when the document-to-document and
        user-to-user distances are equal, the division is by zero.
    """
    # Change in the document-to-user gap between the two time steps.
    gap_change = distance(Dt1, Ut1) - distance(Dt2, Ut2)

    # Document movement minus user movement over the same interval.
    movement_gap = distance(Dt1, Dt2) - distance(Ut1, Ut2)

    return gap_change / movement_gap

In [None]:
# Load the pickled embeddings and normalise them into a {news_id: embedding} dict.
# NOTE: pickle.load can execute arbitrary code; only open files you produced yourself.
with open(pickle_file_path, 'rb') as f:
    news_embeddings_data = pickle.load(f)

if isinstance(news_embeddings_data, dict):
    news_embeddings = news_embeddings_data
else:
    # Fallback for a non-dict payload: treat it as an iterable of
    # (news_id, embedding, ...) records and index it by the first field.
    news_embeddings = {item[0]: item[1] for item in news_embeddings_data}

# Print a compact summary rather than the whole mapping: dumping every
# embedding overflows the notebook output buffer and hides everything else.
print(f"Loaded {len(news_embeddings)} news embeddings")
sample_news_id = next(iter(news_embeddings), None)
if sample_news_id is not None:
    sample_embedding = news_embeddings[sample_news_id]
    print(
        f"Example entry: {sample_news_id} -> {type(sample_embedding).__name__} "
        f"with shape {getattr(sample_embedding, 'shape', None)}"
    )

Buffered data was truncated after reaching the output size limit.

In [None]:
# Read the tab-separated personalized_test.tsv file into a DataFrame.
# on_bad_lines='skip' makes pandas drop malformed rows instead of raising, so
# the resulting row count can be smaller than the number of lines in the file.
dataset = pd.read_csv(
    filepath_or_buffer='/content/personalized_test.tsv',
    sep='\t',
    on_bad_lines='skip',
)

# Display the first ten rows as the cell output.
dataset.head(n=10)

In [None]:
# Accumulate the divergence over every pair of consecutive time steps in every
# user trajectory, then normalise by the number of trajectories.

# pd.read_csv returns the list-valued columns as plain strings, so they must be
# split before they can be indexed per time step.
# NOTE(review): the separators below are assumed from the PENS dataset layout
# (news IDs joined by ',', rewritten titles joined by ';;') - confirm against
# dataset.head() before trusting the result.
NEWS_ID_SEP = ","
TITLE_SEP = ";;"


def _split_field(value, sep):
    """Split a delimited string cell into stripped, non-empty items ([] for a missing cell)."""
    if not isinstance(value, str):
        return []
    return [item.strip() for item in value.split(sep) if item.strip()]


def _as_array(embedding):
    """Return the embedding as a NumPy array, detaching tensor-like inputs first."""
    if hasattr(embedding, "detach"):
        embedding = embedding.detach().cpu().numpy()
    return np.asarray(embedding)


overall_divergence = 0.0
skipped_steps = 0  # steps whose divergence was inf/nan (zero denominator)

# The documents are the news in 'posnewID'; the 'clicknewsID' column could be
# used instead by swapping the column name below.
for pos_ids_raw, titles_raw in zip(dataset['posnewID'], dataset['rewrite_titles']):
    news_ids = _split_field(pos_ids_raw, NEWS_ID_SEP)
    summaries = _split_field(titles_raw, TITLE_SEP)

    # Only time steps that have both a document and a summary can be compared.
    n_steps = min(len(news_ids), len(summaries))
    if n_steps < 2:
        continue

    # Embed every step once; each step is used both as t1 and as t2.
    doc_embeddings = [_as_array(news_embeddings[news_id]) for news_id in news_ids[:n_steps]]
    # The user embedding at time t is taken to be the embedding of the summary at time t.
    # NOTE(review): assumes get_embeddings returns something shape-compatible
    # with the stored news embeddings - confirm.
    user_embeddings = [_as_array(get_embeddings(summary)) for summary in summaries[:n_steps]]

    for t1 in range(n_steps - 1):
        # divergence() already evaluates the full ratio for the (t1, t1 + 1) pair.
        document_divergence = divergence(
            doc_embeddings[t1],
            doc_embeddings[t1 + 1],
            user_embeddings[t1],
            user_embeddings[t1 + 1],
        )
        # A zero denominator yields inf/nan, which would poison the running sum.
        if not np.isfinite(document_divergence):
            skipped_steps += 1
            continue
        overall_divergence += document_divergence

# Normalize the overall divergence by the number of trajectories.
# NOTE(review): the 0.5 factor is carried over unchanged from the original
# cell; confirm it matches the intended metric definition.
overall_divergence = 0.5 * overall_divergence / len(dataset)

# Output the final overall divergence
print("Overall divergence in the dataset:", overall_divergence)
if skipped_steps:
    print(f"Skipped {skipped_steps} time steps with a non-finite divergence")