In [1]:
import pandas as pd
import glob
import os
import json
from pyserini.search.lucene import LuceneSearcher
from pyserini.index.lucene import IndexReader
import tiktoken

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# encoding = 'cl100k_base'
encoding = 'p50k_base'
enc = tiktoken.get_encoding(encoding)
assert enc.decode(enc.encode("hello world")) == "hello world"

In [3]:
def tokenize(text):
    return ' '.join(map(str,enc.encode(text, disallowed_special=())))

In [8]:
!ls --color


README.md             hs_err_pid34690.log         scrape.py
bm25.ipynb            hs_err_pid73328.log         scrape_local.py
clone_repos.py        [0m[01;34mindex[0m                       temp.ipynb
clone_repos.sh        [01;34mlogs[0m                        temp.py
code_extensions.json  [01;34mmisc[0m                        test.json
[01;34mdata[0m                  plot.py                     test.parquet
[01;34mdata_api[0m              programming_languages.json  test_repos.txt
[01;34mdata_test[0m             [01;34mrepos[0m                       [01;34mtmp[0m
diff_json.py          requirements.txt            tmp.txt
hs_err_pid11351.log   run_scrape.sh               top_repos.txt
hs_err_pid11926.log   sbatch_scrape.sh


In [4]:
!ls -GFlash data/karpathy_llama2.c/

total 2.4M
   0 drwxr-xr-x  9 siddharth  288 Oct  3 01:48 ./
   0 drwxr-xr-x 17 siddharth  544 Oct  5 03:14 ../
   0 drwxr-xr-x  3 siddharth   96 Oct  3 01:45 jsonl/
868K -rw-r--r--  1 siddharth 868K Oct  2 22:02 karpathy_llama2.c_commit_data_0.parquet
616K -rw-r--r--  1 siddharth 615K Oct  2 22:02 karpathy_llama2.c_commit_data_1.parquet
372K -rw-r--r--  1 siddharth 372K Oct  2 22:03 karpathy_llama2.c_commit_data_2.parquet
284K -rw-r--r--  1 siddharth 283K Oct  2 22:03 karpathy_llama2.c_commit_data_3.parquet
240K -rw-r--r--  1 siddharth 238K Oct  2 22:03 karpathy_llama2.c_commit_data_4.parquet
   0 drwxr-xr-x 21 siddharth  672 Oct  3 00:52 searcher/


In [4]:
# Load the parquet file
df = pd.read_parquet('data/karpathy_llama2.c/karpathy_llama2.c_commit_data_0.parquet')

In [5]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 85 entries, 0 to 84
Data columns (total 13 columns):
 #   Column                 Non-Null Count  Dtype              
---  ------                 --------------  -----              
 0   owner                  85 non-null     string             
 1   repo_name              85 non-null     string             
 2   commit_date            85 non-null     datetime64[ns, UTC]
 3   commit_id              85 non-null     string             
 4   commit_message         85 non-null     string             
 5   file_path              85 non-null     string             
 6   previous_commit_id     85 non-null     string             
 7   previous_file_content  82 non-null     string             
 8   cur_file_content       79 non-null     string             
 9   diff                   76 non-null     string             
 10  status                 85 non-null     category           
 11  is_merge_request       85 non-null     bool               
 

In [6]:
 # print just the memory usage in human readable format (MB) to 2 decimal places
print(f'{df.memory_usage(deep=True).sum() / 1024 ** 2:.2f} MB')

6.41 MB


In [7]:
print('Number of unique commits stored (others excluded for not being code commits):', df.commit_id.nunique())

Number of unique commits stored (others excluded for not being code commits): 72


In [31]:
# BASE_DIR = 'data/karpathy_llama2.c/'
REPO_LIST = ['karpathy_llama2.c', 'siddharth-gandhi_refpred', 'facebook_react', 'apache_kafka', 'ggerganov_llama.cpp', 'nodejs_node']
# REPO_LIST = ['karpathy_llama2.c', 'siddharth-gandhi_refpred', 'facebook_react', 'apache_kafka']

In [4]:
REPO_LIST = ['karpathy_llama2.c']

In [61]:
# def convert_data_to_jsonl(data_dir, output_file):
#     all_files = glob.glob(os.path.join(data_dir, '*.parquet'))
#     all_dataframes = [pd.read_parquet(file) for file in all_files]
#     combined_df = pd.concat(all_dataframes, ignore_index=True)
#     # replace NaN with empty string
#     combined_df.fillna('', inplace=True)

#     with open(output_file, 'w') as f:
#         for index, row in combined_df.iterrows():
#             doc = {
#                 'id': row['commit_id'],
#                 'contents': row['commit_message'] + '\n' + row['cur_file_content'],
#                 # Optionally include source code
#                 # 'source_code': row['cur_file_content']
#             }
#             f.write(json.dumps(doc) + '\n')

In [49]:
def count_commits(repo_dir):
    all_files = glob.glob(os.path.join(repo_dir, '*.parquet'))
    all_dataframes = [pd.read_parquet(file) for file in all_files]
    combined_df = pd.concat(all_dataframes, ignore_index=True)

    # number of unique commit_id columns
    return combined_df.commit_id.nunique()

In [50]:
total_commits = 0
for repo in REPO_LIST:
    total_commits += count_commits('data/' + repo + '/')

In [51]:
print('Total number of commits:', total_commits)

Total number of commits: 36731


In [33]:
def convert_repo_to_jsonl(repo_dir, output_file, use_tokenizer=False):
    all_files = glob.glob(os.path.join(repo_dir, '*.parquet'))
    all_dataframes = [pd.read_parquet(file) for file in all_files]
    combined_df = pd.concat(all_dataframes, ignore_index=True)
    # replace NaN with empty string in non-category columns
    # combined_df.fillna('', inplace=True)

    combined_df['commit_message'] = combined_df['commit_message'].fillna('')
    combined_df['cur_file_content'] = combined_df['cur_file_content'].fillna('')

    # print combined_df memory usage
    # print(combined_df.info(memory_usage='deep'))
    print(f'Combined Memory Usage: {combined_df.memory_usage(deep=True).sum() / 1024 ** 2:.2f} MB for {len(combined_df)} rows')
    print(output_file)
    with open(output_file, 'x') as f:
        for index, row in combined_df.iterrows():
            doc = {
                'id': row['commit_id'],
                # 'contents': row['commit_message'] + '\n' + row['cur_file_content'],
                # 'source_code': row['cur_file_content'],  # Optionally include source code
                # 'contents': tokenize(row['commit_message']) + '\n' + tokenize(row['cur_file_content']),
                'contents': tokenize(row['commit_message']) + '\n' + tokenize(row['cur_file_content']) if use_tokenizer else row['commit_message'] + '\n' + row['cur_file_content'],
                'repo_name': row['repo_name'],
                'file_path': row['file_path'],
            }
            f.write(json.dumps(doc) + '\n')

In [10]:
# empty data/jsonl if it has data
# !rm -rf data/jsonl_tiktoken

In [35]:
jsonl_dir_name = 'jsonl_6'
for repo_name in REPO_LIST:
    repo_dir = os.path.join('data', repo_name)
    # create data/jsonl directory if it doesn't exist
    os.makedirs(os.path.join('data', jsonl_dir_name), exist_ok=True)

    # store in data/jsonl
    output_jsonl_file = os.path.join('data', jsonl_dir_name, f'{repo_name}.jsonl')
    convert_repo_to_jsonl(repo_dir, output_jsonl_file)

Combined Memory Usage: 18.29 MB for 402 rows
data/jsonl_6/karpathy_llama2.c.jsonl
Combined Memory Usage: 0.94 MB for 108 rows
data/jsonl_6/siddharth-gandhi_refpred.jsonl
Combined Memory Usage: 2699.89 MB for 73551 rows
data/jsonl_6/facebook_react.jsonl
Combined Memory Usage: 3645.70 MB for 75870 rows
data/jsonl_6/apache_kafka.jsonl
Combined Memory Usage: 605.11 MB for 2111 rows
data/jsonl_6/ggerganov_llama.cpp.jsonl
Combined Memory Usage: 11010.96 MB for 208188 rows
data/jsonl_6/nodejs_node.jsonl


In [62]:
# Usage
# jsonl_file_path = f'{BASE_DIR}/jsonl/llama2.jsonl'
# convert_data_to_jsonl(BASE_DIR, jsonl_file_path)

In [121]:
# # get list of jsonl files which are present in data/repo_name/jsonl/repo_name.jsonl
# jsonl_files = glob.glob('data/*/*/*.jsonl')
# print(jsonl_files)

For normal untokenized
- Parquet -> JSONL 22s
- Index build 1m26s
- 6 repos
    Parquet -> JSONL 1m11s
    Same mem usage as before, just lower time since no need for tokenization
    Index Build 3m51s
    Index Size 5Gb

For tokenized
- Parquet -> JSONL 8m3s
- Index Build 2m12s
- 6 repos:
    Parquert -> JSONL 24m
        Combined Memory Usage: 18.29 MB for 402 rows data/isonl_6/karpathy_llama2.c.jsonl \\
        Combined Memory Usage: 0.94 MB for 108 rows data/json1_6/siddharth-gandhi_refpred.jsonl \\
        Combined Memory Usage: 2699.89 MB for 73551 rows data/jsonl_6/facebook_react.jsonl \\
        Combined Memory Usage: 3645.70 MB for 75870 rows data/jsonl_6/apache_kafka. jsonl \\
        Combined Memory Usage: 605.11 MB for 2111 rows data/jsonl_6/ggerganov_llama.cpp.jsonl \\
        Combined Memory Usage: 11010.96 MB for 208188 rows data/jsonl_6/nodejs_node.json
        36731 total commits 
        Total ~360K rows
        Interesting heuristic, on avg 10 files edited per commit?
    Index build 6m42s
    Index Size 10GB

In [10]:
!ls -GFlash .

total 976128
     0 drwxr-xr-x@ 46 siddharth  staff   1.4K Oct  5 03:15 [1m[36m.[m[m/
     0 drwxrwxrwx  25 siddharth  staff   800B Oct  4 18:13 [30m[43m..[m[m/
    24 -rw-r--r--@  1 siddharth  staff    10K Oct  3 07:52 .DS_Store
     8 -rw-r--r--@  1 siddharth  staff    62B Sep 17 22:59 .env
     0 drwxr-xr-x@ 14 siddharth  staff   448B Oct  3 03:45 [1m[36m.git[m[m/
     8 -rw-r--r--@  1 siddharth  staff   3.6K Oct  3 03:45 .gitignore
     0 drwxr-xr-x@  9 siddharth  staff   288B Oct  3 00:54 [1m[36m.idea[m[m/
     0 drwxr-xr-x@  4 siddharth  staff   128B Sep 23 17:54 [1m[36m.vscode[m[m/
     8 -rw-r--r--@  1 siddharth  staff   542B Oct  2 17:45 README.md
   160 -rw-r--r--@  1 siddharth  staff    77K Oct  5 03:15 bm25.ipynb
     0 drwxr-xr-x@ 72 siddharth  staff   2.3K Oct  3 03:20 [1m[36mbm25_index[m[m/
     0 drwxr-xr-x@ 72 siddharth  staff   2.3K Oct  3 14:40 [1m[36mbm25_index_6[m[m/
     0 drwxr-xr-x@ 72 siddharth  staff   2.3K Oct  3 03:33 [1m[36mbm2

In [6]:
%%bash
# Directory to store the index
# index_dir="./bm25_index_6/"
# jsonl_dir_name="jsonl_6"

index_dir="./idx_karpathy/"
# jsonl_dir_name="jsonl_tiktoken_6"
jsonl_dir_name="test_dir"

# Create the directory if it doesn't exist
mkdir -p "$index_dir"

# Remove any existing indexes
rm -rf "$index_dir/*"

# build the index from data/jsonl
python -m pyserini.index.lucene -collection JsonCollection -generator DefaultLuceneDocumentGenerator \
 -threads 4 -input data/"$jsonl_dir_name"/ -index "$index_dir" -storePositions -storeDocvectors -storeRaw -impact -pretokenized

2023-10-05 03:15:51,452 INFO  [main] index.IndexCollection (IndexCollection.java:380) - Setting log level to INFO
2023-10-05 03:15:51,453 INFO  [main] index.IndexCollection (IndexCollection.java:383) - Starting indexer...
2023-10-05 03:15:51,453 INFO  [main] index.IndexCollection (IndexCollection.java:385) - DocumentCollection path: data/test_dir/
2023-10-05 03:15:51,453 INFO  [main] index.IndexCollection (IndexCollection.java:386) - CollectionClass: JsonCollection
2023-10-05 03:15:51,453 INFO  [main] index.IndexCollection (IndexCollection.java:387) - Generator: DefaultLuceneDocumentGenerator
2023-10-05 03:15:51,453 INFO  [main] index.IndexCollection (IndexCollection.java:388) - Threads: 4
2023-10-05 03:15:51,454 INFO  [main] index.IndexCollection (IndexCollection.java:389) - Language: en
2023-10-05 03:15:51,454 INFO  [main] index.IndexCollection (IndexCollection.java:390) - Stemmer: porter
2023-10-05 03:15:51,454 INFO  [main] index.IndexCollection (IndexCollection.java:391) - Keep sto

In [117]:
# !python -m pyserini.index.lucene \
#   --collection JsonCollection \
#   --input data/karpathy_llama2.c/jsonl/ \
#   --index data/karpathy_llama2.c/searcher/ \
#   --generator DefaultLuceneDocumentGenerator \
#   --threads 1 \
#   --storePositions --storeDocvectors --storeRaw

6ce91b1b3b56ff7d43d894c204f965bfbf5d63c9

In [57]:
# llama2.c
# query = 'nInference for Llama-2 Transformer model in pure C'

# refpred
# query = 'if is_arxiv:\n return f"https://api.semanticscholar.org/graph/v1/paper/arXiv:{paper_id}/references?fields=title,
# abstract,url,venue,publicationVenue,year,referenceCount,citationCount,influentialCitationCount,isOpenAccess'

# react
# query = "export {default} from './npm/Circle';"

# kafka
# public class MockKafkaLog4jAppender extends KafkaLog4jAppender {
#     private MockProducer<byte[], byte[]> mockProducer =
#             new MockProducer<>(false, new MockSerializer(), new MockSerializer());

#     private Properties producerProperties;

#     @Override
#     protected Producer<byte[], byte[]> getKafkaProducer(Properties props) {
#         producerProperties = props;
#         return mockProducer;
#     }

#     void setKafkaProducer(MockProducer<byte[], byte[]> producer) {
#         this.mockProducer = producer;
#     }
# """

# Kakfa
# query = """
# /**
#  * Local file based quorum state store. It takes the JSON format of {@link QuorumStateData}
#  * with an extra data version number as part of the data for easy deserialization.
#  *
#  * Example format:
#  * <pre>
#  * {"clusterId":"",
#  *   "leaderId":1,
#  *   "leaderEpoch":2,
#  *   "votedId":-1,
#  *   "appliedOffset":0,
#  *   "currentVoters":[],
#  *   "data_version":0}
#  * </pre>
#  * */

# """

# kakfa
query = """Convert coordinator retriable errors to a known producer…
… response error (#14378)

KIP-890 Part 1 tries to address hanging transactions on old clients. Thus, the produce version can not be bumped and no new errors can be added. Before we used the java client's notion of retriable and abortable errors -- retriable errors are defined as such by extending the retriable error class, fatal errors are defined explicitly, and abortable errors are the remaining. However, many other clients treat non specified errors as fatal and that means many retriable errors kill the application."""

# kakfa
# query = """Fix flaky TopicAdminTest::retryEndOffsetsShouldRetryWhenTopicNotFound test case"""

# nodejs
# query = """bool ShouldAbortOnUncaughtException(Isolate* isolate) {
#   DebugSealHandleScope scope(isolate);
#   Environment* env = Environment::GetCurrent(isolate);
#   return env != nullptr &&
#          (env->is_main_thread() || !env->is_stopping()) &&
#          env->abort_on_uncaught_exception() &&
#          env->should_abort_on_uncaught_toggle()[0] &&
#          !env->inside_should_not_abort_on_uncaught_scope();
# }"""

In [58]:
bm25searcher = LuceneSearcher('bm25_index_6/')
hits = bm25searcher.search(query, k=10)
# print(hits[0])
for i in range(len(hits)):
    # print(f'{i+1:2} {hits[i].docid:4} {hits[i].score:.5f}')
    # print with repo name and file name
    obj = json.loads(hits[i].raw)
    print(f'{i+1:2} {hits[i].docid:4} {hits[i].score:.5f} {obj["repo_name"]}/{obj["file_path"]}')

 1 5aecd2825644728f68a26558c957f5dfd4643423 99.51060 kafka/core/src/main/scala/kafka/server/ReplicaManager.scala
 2 29a1a16668d76a1cc04ec9e39ea13026f2dce1de 82.57980 kafka/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
 3 5aad085a8e7514c14a17121d316a2e2b2add8bcc 81.72260 kafka/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
 4 5aecd2825644728f68a26558c957f5dfd4643423 81.36090 kafka/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
 5 ef09a2e3fc11a738f6681fd57fb84ad109593fd3 80.57710 kafka/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
 6 f5d5f654db359af077088685e29fbe5ea69616cf 79.69870 kafka/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
 7 2b6365c78b6e659f8df0651a24013d028f39edd9 79.64400 kafka/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
 8 ff77b3ad041c1a4c80119f960e1f

In [13]:
index_reader = IndexReader('idx_karpathy/')
index_reader.stats()

{'total_terms': 696778,
 'documents': 402,
 'non_empty_documents': 402,
 'unique_terms': 6840}

In [14]:
from pyserini.index import IndexReader

In [18]:
index_reader = IndexReader('idx_karpathy_double_token/')

In [19]:
index_reader.dump_documents_BM25('tmp/idx_karpathy_double.jsonl')

100%|██████████| 402/402 [00:02<00:00, 190.05it/s]


In [12]:
index_reader = IndexReader('idx_karpathy_double_token/')
index_reader.stats()

{'total_terms': 578447,
 'documents': 402,
 'non_empty_documents': 402,
 'unique_terms': 3034}

In [59]:
tiktoken_searcher = LuceneSearcher('bm25_index_tiktoken_6/')
# get tokenized query with enc.encode
tokeninzed_query = tokenize(query)
hits = tiktoken_searcher.search(tokeninzed_query, k=10)
# print(hits[0])
for i in range(len(hits)):
    # print(f'{i+1:2} {hits[i].docid:4} {hits[i].score:.5f}')
    # print with repo name and file name
    obj = json.loads(hits[i].raw)
    print(f'{i+1:2} {hits[i].docid:4} {hits[i].score:.5f} {obj["repo_name"]}/{obj["file_path"]}')

 1 5aecd2825644728f68a26558c957f5dfd4643423 141.63670 kafka/core/src/main/scala/kafka/server/ReplicaManager.scala
 2 5aecd2825644728f68a26558c957f5dfd4643423 112.99820 kafka/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
 3 5aad085a8e7514c14a17121d316a2e2b2add8bcc 111.59350 kafka/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
 4 ff77b3ad041c1a4c80119f960e1f87c07b9e93dd 111.57550 kafka/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
 5 29a1a16668d76a1cc04ec9e39ea13026f2dce1de 110.54000 kafka/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
 6 ea0bb001262320bc9233221955a2be31c85993b9 109.68660 kafka/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
 7 f5d5f654db359af077088685e29fbe5ea69616cf 109.62250 kafka/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
 8 b937ec7567

In [46]:
tiktoken_index_reader = IndexReader('bm25_index_tiktoken_6/')
tiktoken_index_reader.stats()

{'total_terms': 2698903862,
 'documents': 360230,
 'non_empty_documents': 360230,
 'unique_terms': -1}

In [47]:
# print the document source code inside the first hit raw
content = json.loads(hits[0].raw)['contents']

# print the document source code inside the first hit raw by decoding the tokenized string with enc.decode (convert to array of int and then decode)
# print(enc.decode(json.loads(hits[0].raw)['contents']))

# convert content to array of int
content_arr = [int(i) for i in content.split()]

In [48]:
print(enc.decode(content_arr))

worker: fix --abort-on-uncaught-exception handling

The `set_abort_on_uncaught_exception(false)` line was supposed to
prevent aborting when running Workers in
`--abort-on-uncaught-exception` mode, but it was incorrectly set
and not checked properly in the should-abort callback.

PR-URL: https://github.com/nodejs/node/pull/34724
Reviewed-By: Colin Ihrig <cjihrig@gmail.com>
Reviewed-By: Richard Lau <riclau@uk.ibm.com>
Reviewed-By: James M Snell <jasnell@gmail.com>
Reviewed-By: Mary Marchini <oss@mmarchini.me>

#include "node.h"
#include "node_context_data.h"
#include "node_errors.h"
#include "node_internals.h"
#include "node_native_module_env.h"
#include "node_platform.h"
#include "node_v8_platform-inl.h"
#include "uv.h"

#if HAVE_INSPECTOR
#include "inspector/worker_inspector.h"  // ParentInspectorHandle
#endif

namespace node {
using errors::TryCatchScope;
using v8::Array;
using v8::Context;
using v8::EscapableHandleScope;
using v8::Function;
using v8::FunctionCallbackInfo;
using v8::H

In [153]:
msg = '''
A crawler for the Semantic Scholar API.

import asyncio
import json
import logging
import logging.config
import sys
import time
from dataclasses import dataclass, field
from typing import Dict, List, Set

import httpx  # https://github.com/encode/httpx
import requests
from crawl_utils import get_batch_url, get_reference_url
from db import MongoDBClient

from config import S2_API_KEY, S2_RATE_LIMIT

logging.config.fileConfig(fname="logging.conf", disable_existing_loggers=False)
logger = logging.getLogger(__name__)

class RateLimitExceededException(Exception):
    """Exception raised when rate limit is exceeded"""

    def __init__(self, message):
        self.message = message

    def __str__(self):
        return f"RateLimitExceededException: {self.message}"


class TimeoutException(Exception):
    """Exception raised when a request times out"""
    def __init__(self, message):
        self.message = message

    def __str__(self):
        return f"TimeoutException: {self.message}"


@dataclass()
class Crawler:
    """A crawler for the Semantic Scholar API"""

    client: httpx.AsyncClient = field(repr=False)
    initial_papers: List[str] = field(default_factory=list)
    num_workers: int = 10
    max_papers: int = 100
    mongodb_client: MongoDBClient = field(default_factory=MongoDBClient)
    headers: dict = field(repr=False, default_factory=dict)
    todo: asyncio.Queue = field(init=False, repr=False, default_factory=asyncio.Queue)
    seen: Set[str]= field(init=False, default_factory=set)
    done: Set[str] = field(init=False, default_factory=set)
    retry: Dict[str, int] = field(init=False, default_factory=dict)
    total: int = field(init=False, default=0)
    MAX_RETRIES: int = field(init=False, default=3)

    @classmethod
    def from_dict(cls, settings: dict) -> "Crawler":
        """
        Create a Crawler instance from a dict of settings"""
        return cls(**settings)

    async def run(self) -> None:
        """Run the crawler by creating workers until todo queue is empty"""
        self.init_done()
        await self.init_queue()
        workers = [asyncio.create_task(self.worker()) for _ in range(self.num_workers)]
        await self.todo.join()
        for worker in workers:
            worker.cancel()

    async def init_queue(self) -> None:
        """Initialize the queue with the initial papers"""
        batch_url = get_batch_url()
        data = json.dumps({"ids": self.initial_papers})
        response = requests.post(url=batch_url, data=data, headers=self.headers, timeout=10)
        # initial_paper_id = self.initial_papers[0]
        # initial_url = get_paper_url(initial_paper_id)
        # response = requests.get(initial_url, headers=self.headers, timeout=10)
        if response.status_code != 200:
            logger.error("Error fetching initial papers")
            sys.exit(1)
        logger.debug(f"Fetching data for intial papers {self.initial_papers}")
        result_data = response.json()
        # result_data["_id"] = result_data["paperId"]
        for paper in result_data:
            paper["_id"] = paper["paperId"]
        # prime the queue
        await self.on_found_papers(result_data, initial=True)

    def init_done(self) -> None:
        """Initialize the seen set with already stored papers from DB"""
        # self.seen = set(self.initial_papers)
        self.done = self.mongodb_client.get_ids()
        logger.info(f"Already stored {len(self.done)} papers")

    async def worker(self) -> None:
        """One worker processes one paper at a time from the queue in a loop until cancelled"""
        while True:
            try:
                await self.process_one()
            except asyncio.CancelledError:
                return

    async def retry_crawl(self, paper) -> None:
        """Retry crawling a paper in case of an exception"""
        if paper["_id"] in self.retry and self.retry[paper["_id"]] > self.MAX_RETRIES:
            logger.error(f"Error processing {paper['_id']} even after retrying {self.MAX_RETRIES} times")
            return
        # self.retry.add(paper["_id"])
        self.retry[paper["_id"]] = self.retry.get(paper["_id"], 0) + 1
        logger.info(f"Retry #{self.retry[paper['_id']]} for {paper['_id']}")
        # await self.todo.put_nowait(cur_paper)
        await asyncio.sleep(1)
        await self.crawl(paper)

    async def process_one(self) -> None:
        """Gets one paper from the queue and processes it"""
        # cur_paper is a dict
        cur_paper = await self.todo.get()
        try:
            await self.crawl(cur_paper)
        except TimeoutException as te:
            # logger.warning(f"Timeout for {cur_paper['_id']}")
            logger.warning(te)
            await self.retry_crawl(cur_paper)
        except RateLimitExceededException as rlee:
            logger.critical("Rate limit exceeded, retrying in 2 second")
            logger.critical(rlee)
            await asyncio.sleep(2)
            await self.retry_crawl(cur_paper)
        finally:
            self.todo.task_done()

    async def crawl(self, cur_paper: dict) -> None:
        """
        Crawl a paper and its references, stores them in the database.
        """
        # TODO proper rate limiting to 100 requests / second
        # await asyncio.sleep(1 / self.num_workers)
        await asyncio.sleep(1)

        cur_paper_id = cur_paper["paperId"]
        ref_url = get_reference_url(cur_paper_id)
        cur_paper["_id"] = cur_paper_id
        if cur_paper["title"] is None or cur_paper["abstract"] is None:
            logger.debug(f"Skipping {cur_paper_id} as empty title or abstract")
            # I have no clue why this total -= 1 is here, it shouldn't be required, but crawler just prematurely stops
            self.total -= 1
            return
        # async with self.semaphore:
        # async with self.client.get(ref_url, headers=self.headers) as response:

        response = await self.client.get(ref_url, headers=self.headers)

        # if self.semaphore.locked():
        #     logger.warning(f"Semaphore locked for {cur_paper_id}")
        #     await asyncio.sleep(1)

        if response.status_code == 429:
            # logger.critical(
            #     f"Rated limited for {cur_paper_id} - {response.status_code}"
            # )
            # # await self.todo.put_nowait(cur_paper)
            # await asyncio.sleep(1)
            # await self.crawl(cur_paper)
            raise RateLimitExceededException(
                f"Rated limited for {cur_paper_id} - {response.status_code}"
            )

        if response.status_code == 504:
            # raise asyncio.exceptions.TimeoutError(
            #     f"Timeout for {cur_paper_id} - {response.status_code}"
            # )
            raise TimeoutException(f"Timeout for {cur_paper_id} - {response.status_code}")

        if response.status_code != 200:
            logger.error(f"Error fetching references for {cur_paper_id} - {response.status_code}")
            return

        logger.debug(f"Fetching references for {cur_paper_id} - {response.status_code}")

        result_data = response.json()
        found_references = result_data["data"]
        found_references = [ref["citedPaper"] for ref in found_references]
        found_references = sorted(found_references, key=lambda x: x["citationCount"] or 0, reverse=True)
        ref_ids = [ref["paperId"] for ref in found_references if ref["paperId"] is not None]
        cur_paper["references"] = ref_ids
        cur_paper["allReferencesStored"] = True
        if len(ref_ids) != cur_paper["referenceCount"]:
            cur_paper["allReferencesStored"] = False

        # self.collection.insert_one(cur_paper)
        self.mongodb_client.insert_one(cur_paper)
        self.done.add(cur_paper["paperId"])
        # self.stored += 1
        # if self.stored % 100 == 0:
        #     logger.info(f"Stored {self.stored} papers")

        await self.on_found_papers(found_references)

    # async def get_paper_references(self, base: str, text: str) -> set[str]:
    #     parser = UrlParser(base, self.filter_url)
    #     parser.feed(text)
    #     return parser.found_references

    async def on_found_papers(self, papers: List[dict], initial: bool = False) -> None:
        """
        Called when new papers are found. Filters out papers that have already been seen and puts the new ones in the queue.
        """
        if initial:
            for paper in papers:
                await self.put_todo(paper)
            return
        ids = {paper["paperId"] for paper in papers if paper["paperId"] is not None}
        new = ids - self.seen
        self.seen.update(new)

        for paper in papers:
            if paper["paperId"] in new:
                await self.put_todo(paper)

    async def put_todo(self, paper: dict) -> None:
        """Put a paper in the queue"""
        # paper is a dict with fields like paper_id, title, abstract, etc.
        if self.total >= self.max_papers:
            return
        self.total += 1
        await self.todo.put(paper)


async def main() -> None:
    """Main function"""
    start = time.perf_counter()
    headers={
        "Content-type": "application/json",
        "x-api-key": S2_API_KEY,
    }
    mongodb_client = MongoDBClient(mongo_url='mongodb://localhost:27017', db_name='refpred', collection_name='review3_demo', init_new=True)
    timeout = httpx.Timeout(10, connect=10, read=None, write=10)
    # based on https://towardsdatascience.com/top-10-research-papers-in-ai-1f02cf844e26
    initial_papers = ["204e3073870fae3d05bcbc2f6a8e263d9b72e776", "bee044c8e8903fb67523c1f8c105ab4718600cdb", "36eff562f65125511b5dfab68ce7f7a943c27478", "8388f1be26329fa45e5807e968a641ce170ea078", "846aedd869a00c09b40f1f1f35673cb22bc87490", "e0e9a94c4a6ba219e768b4e59f72c18f0a22e23d", "fa72afa9b2cbc8f0d7b05d52548906610ffbb9c5", "424561d8585ff8ebce7d5d07de8dbf7aae5e7270", "4d376d6978dad0374edfa6709c9556b42d3594d3", "a6cb366736791bcccc5c8639de5a8f9636bf87e8", "df2b0e26d0599ce3e70df8a9da02e51594e0e992", "913f54b44dfb9202955fe296cf5586e1105565ea", "156d217b0a911af97fa1b5a71dc909ccef7a8028", "a3e4ceb42cbcd2c807d53aff90a8cb1f5ee3f031", "5c5751d45e298cea054f32b392c12c61027d2fe7", "bc1586a2e74d6d1cf87b083c4cbd1eede2b09ea5", "921b2958cac4138d188fd5047aa12bbcf37ac867", "cb92a7f9d9dbcf9145e32fdfa0e70e2a6b828eb1"]
    MAX_PAPERS = 10000
    async with httpx.AsyncClient(timeout=timeout) as client:
        # starting with the famous paper 'Attention is all you need'
        crawler = Crawler(
            client=client,
            initial_papers=initial_papers,
            num_workers=S2_RATE_LIMIT,
            max_papers=MAX_PAPERS,
            mongodb_client=mongodb_client,
            headers=headers,
        )
        await crawler.run()
    end = time.perf_counter()

    logger.info("Results:")
    logger.info(f"Crawled: {len(crawler.done)} Papers")
    logger.info(f"Found: {len(crawler.seen)} Papers")
    logger.info(f"Done in {end - start:.2f}s")


if __name__ == "__main__":
    asyncio.run(main())


# TODO
# 1. Batch processing of seed papers
# 2. Initialize seen from dataset to avoid restarting over
# 3. Null abstract papers need to be removed from the dataset ✅'''

In [157]:
simple_msg = '''# A simple hello world program in python with docstring
def hello_world():
    """A simple hello world program in python with docstring"""
    print("Hello World!")'''

In [159]:
# see tokenized output of the above code
enc.encode(simple_msg)[:10]

[2, 362, 4382, 24748, 1917, 2068, 304, 10344, 449, 4733]

In [161]:
# see tokenized output of the above code
enc.encode(simple_msg)[:10]

[2, 317, 2829, 23748, 995, 1430, 287, 21015, 351, 2205]