# **Building Gemma Research Assistant**

<img src="../figures/GemmaAIO-main-image.webp" alt="main-image"/>

Hey everyone, welcome to this notebook where we're diving into something super cool: building an all-in-one research chatbot using the power of the Gemma Large Language Model! üöÄ

Here's the game plan:

- First up, Section 1: We're kicking things off with a research paper query engine. Imagine being able to find any research paper with just a simple chat. Sounds handy, right?

- Moving on to Section 2: We'll spice things up with a graph paper relationship engine. This is all about connecting the dots between different papers and seeing the bigger picture.

- Section 3: We'll add a basic data science assistant to our toolkit. This chatbot will help with all those tricky data questions, from stats to machine learning.

- Section 4 is for the coders: We're building an AI code assistant that's going to be like your coding sidekick, helping you solve problems and understand complex codes.

- And for the grand finale, Section 5: We're bringing it all together with a combination module. This is where we make sure everything works in harmony, giving you a powerhouse tool for any research or coding project.

So, let's roll up our sleeves and jump into this exciting project. An overview of this project is below: üåü

<img src="../figures/RAG%20-%20Scientific%20Assistant%20-%20Frame%201.jpg" alt="pipeline" width=800/>

# **1. Scientific Research Assistant**

In this section, we're focusing on creating the first part of our chatbot: a tool that can search through a huge number of research papers on arXiv. The key to this tool is using embeddings, taken from paper abstracts. Think of these as unique IDs that sum up what each paper is about.

When you ask the chatbot something, it uses these embeddings to look through the abstracts and find papers that really match what you're looking for, not just by keywords, but by the actual ideas and concepts you're interested in. This is more about understanding the meaning of your question and finding papers that really match.

We'll go through everything: picking the right papers from arXiv, getting the abstracts ready, and choosing a way to turn these abstracts into embeddings. Then, we'll set up a smart search that can quickly find the best matches when you ask a question.

<img src="../figures/Science-Paper-Search.jpg" alt="science paper search" width=800/>

## **1.1 Data Preprocessing**

In [1]:
# https://www.kaggle.com/code/matthewmaddock/nlp-arxiv-dataset-transformers-and-umap

# This takes about 1 minute.
import json
import pandas as pd

cols = ['id', 'title', 'abstract', 'categories']
data = []
file_name = '../data/arxiv-metadata-oai-snapshot.json'


with open(file_name, encoding='latin-1') as f:
    for line in f:
        doc = json.loads(line)
        lst = [doc['id'], doc['title'], doc['abstract'], doc['categories']]
        data.append(lst)

df_data = pd.DataFrame(data=data, columns=cols)

print(df_data.shape)

df_data.head()

(2455227, 4)


Unnamed: 0,id,title,abstract,categories
0,704.0001,Calculation of prompt diphoton production cros...,A fully differential calculation in perturba...,hep-ph
1,704.0002,Sparsity-certifying Graph Decompositions,"We describe a new algorithm, the $(k,\ell)$-...",math.CO cs.CG
2,704.0003,The evolution of the Earth-Moon system based o...,The evolution of Earth-Moon system is descri...,physics.gen-ph
3,704.0004,A determinant of Stirling cycle numbers counts...,We show that a determinant of Stirling cycle...,math.CO
4,704.0005,From dyadic $\Lambda_{\alpha}$ to $\Lambda_{\a...,In this paper we show how to compute the $\L...,math.CA math.FA


There are a total of almost 2,5M papers on arxiv, that's too much! However, not all of them are about AI, so let's narrow down to the topics we're interested in.

In [3]:
import pandas as pd

topics = ['cs.AI', 'cs.CV', 'cs.IR', 'cs.LG', 'cs.CL']

# Create a regular expression pattern that matches any of the topics
# The pattern will look like 'cs.AI|cs.CV|cs.IR|cs.LG|cs.CL'
pattern = '|'.join(topics)

# Filter the DataFrame to include rows where the 'categories' column contains any of the topics
# The na=False parameter makes sure that NaN values are treated as False
df_filtered = df_data[df_data['categories'].str.contains(pattern, na=False)]

# Display the filtered DataFrame
df_filtered

Unnamed: 0,id,title,abstract,categories
46,0704.0047,Intelligent location of simultaneously active ...,The intelligent acoustic emission locator is...,cs.NE cs.AI
49,0704.0050,Intelligent location of simultaneously active ...,Part I describes an intelligent acoustic emi...,cs.NE cs.AI
303,0704.0304,The World as Evolving Information,This paper discusses the benefits of describ...,cs.IT cs.AI math.IT q-bio.PE
670,0704.0671,Learning from compressed observations,The problem of statistical learning is to co...,cs.IT cs.LG math.IT
953,0704.0954,Sensor Networks with Random Links: Topology De...,"In a sensor network, in practice, the commun...",cs.IT cs.LG math.IT
...,...,...,...,...
2443613,quant-ph/0411140,Improved Bounds on Quantum Learning Algorithms,In this article we give several new results ...,quant-ph cs.LG
2445483,quant-ph/0507231,Algebras of Measurements: the logical structur...,"In Quantum Physics, a measurement is represe...",quant-ph cs.AI
2448330,quant-ph/0607111,`Plausibilities of plausibilities': an approac...,Probability-like parameters appearing in som...,quant-ph cs.AI
2450042,quant-ph/0702072,Markovian Entanglement Networks,Graphical models of probabilistic dependenci...,quant-ph cs.AI


Great! Now we down to about 330K papers. Now, let's clean the text.

In [6]:
df_filtered.iloc[110]

id                                                    0707.0705
title         Optimal Solutions for Sparse Principal Compone...
abstract        Given a sample covariance matrix, we examine...
categories                                          cs.AI cs.LG
Name: 13875, dtype: object

In [7]:
def clean_text(x):
    
    # Replace newline characters with a space
    new_text = " ".join([c.strip() for c in x.replace("\n", "").split()])
    # Remove leading and trailing spaces
    new_text = new_text.strip()
    
    return new_text

df_data['title'] = df_data['title'].apply(clean_text)
df_data['abstract'] = df_data['abstract'].apply(clean_text)

df_data['prepared_text'] = df_data['title'] + '\n ' + df_data['abstract']
df_data.head()

Unnamed: 0,id,title,abstract,categories,prepared_text
46,704.0047,Intelligent location of simultaneously active ...,The intelligent acoustic emission locator is d...,cs.NE cs.AI,Intelligent location of simultaneously active ...
49,704.005,Intelligent location of simultaneously active ...,Part I describes an intelligent acoustic emiss...,cs.NE cs.AI,Intelligent location of simultaneously active ...
303,704.0304,The World as Evolving Information,This paper discusses the benefits of describin...,cs.IT cs.AI math.IT q-bio.PE,The World as Evolving Information\n This paper...
670,704.0671,Learning from compressed observations,The problem of statistical learning is to cons...,cs.IT cs.LG math.IT,Learning from compressed observations\n The pr...
953,704.0954,Sensor Networks with Random Links: Topology De...,"In a sensor network, in practice, the communic...",cs.IT cs.LG math.IT,Sensor Networks with Random Links: Topology De...


In [8]:
from llama_index.core import Document

arxiv_documents = [Document(text=prepared_text, doc_id=id) for prepared_text,id in list(zip(df_data['prepared_text'], df_data['id']))]

  from .autonotebook import tqdm as notebook_tqdm


## **1.2 Creating Index**

The `VectorStoreIndex` is by far the most frequently used type of Index in llamaindex. This class takes your Documents and splits them up into Nodes. Then, it creates `vector_embeddings` of the text of every node. But what is `vector_embedding`?

Vector embeddings are like turning the essence of your words into a mathematical sketch. Imagine every idea or concept in your text getting its unique numerical fingerprint. This is handy because even if two snippets of text use different words, if they're sharing the same idea, their numerical sketches‚Äîor embeddings‚Äîwill be close neighbors in the numerical space. This magic is done using tools known as embedding models.

Choosing the right embedding model is crucial. It's like picking the right artist to paint your portrait; you want the one who captures you best. A great place to start is the MTEB leaderboard, where the cr√®me de la cr√®me of embedding models are ranked. As we have quite a large dataset, the model size matters, we don't want to wait all day for the model to extract all the vector embeddings. When I last checked, the `BAAI/bge-small-en-v1.5` model was leading the pack, especially considering its size. It could be a solid choice if you're diving into the world of text embeddings.


In [62]:
from llama_index.core import VectorStoreIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import Settings
import chromadb
import torch
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core import StorageContext

Settings.llm = None
# Create embed model
device_type = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5", cache_folder="../models", device=device_type)

LLM is explicitly disabled. Using MockLLM.


Great! Now we have to find somewhere to store all of the embeddings extracted by the model, and that's why we need a `vector store`. There are many to choose from, in this tutorial, I will choose the `chroma` vector store

In [68]:
chroma_client = chromadb.PersistentClient(path="../DB/arxiv")
chroma_collection = chroma_client.get_or_create_collection("gemma_assistant_arxiv_papers")


# Create vector store
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)

This part takes quite a lot of time! So I precomputed the embedding and store them into chroma db

In [9]:
# index = VectorStoreIndex.from_documents(
#     arxiv_documents, storage_context=storage_context, embed_model=embed_model, show_progress=True
# )

## **1.3 Loading from arxiv vector store**

In [10]:
from llama_index.core import VectorStoreIndex, Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import chromadb
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core import StorageContext
import torch


Settings.llm = None # Set this to none to make the index only do retrieval
device_type = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5", cache_folder="../models", device=device_type) # must be the same as the previous stage

chroma_client = chromadb.PersistentClient(path="../DB/arxiv")
chroma_collection = chroma_client.get_or_create_collection("gemma_assistant_arxiv_papers")
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
# load the vectorstore
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_vector_store(vector_store, storage_context=storage_context, embed_model=embed_model)

LLM is explicitly disabled. Using MockLLM.


In [11]:
paper_query_engine = index.as_query_engine(
    similarity_top_k=10,
)

In [14]:
print(paper_query_engine.query("What are some papers about image generation?"))

Context information is below.
---------------------
deep image synthesis from intuitive user input: a review and perspectives
 In many applications of computer graphics, art and design, it is desirablefor a user to provide intuitive non-image input, such as text, sketch, stroke,graph or layout, and have a computer system automatically generatephoto-realistic images that adhere to the input content. While classic worksthat allow such automatic image content generation have followed a framework ofimage retrieval and composition, recent advances in deep generative models suchas generative adversarial networks (GANs), variational autoencoders (VAEs), andflow-based methods have enabled more powerful and versatile image generationtasks. This paper reviews recent works for image synthesis given intuitive userinput, covering advances in input versatility, image generation methodology,benchmark datasets, and evaluation metrics. This motivates new perspectives oninput representation and interact

# **2. Graph-based paper relationship search**

In this section, we dive into constructing a knowledge graph about the relationships of papers. This graph could be used for interactive visualization, searching relationships between papers (e.g. How is paper A related to paper B), or search for a specific relationship in a paper (e.g. What are works that paper A based on?). The steps of constructing this knowledge graph are:


- Step 1: arXiv Data Extraction: The process starts with academic papers from the arXiv database, which undergo OCR (Optical Character Recognition) and PDF parsing, which organizes the content into structured data such as the title, abstract, sections, and references of the papers. 

- Step 2: Text Splitter: The text in each section is then processed by a Text Splitter, which split the paper section into smaller chunks, which could be easier for LLMs to process.  

- Step 3 GPT-3.5 Processing: Gemma couldn't generate the knowledge graph out-of-the-box. So we need knowledge distillation from a bigger model, which I choose GPT-3.5. The structured data is passed to GPT-3.5 to extract citation relationships, such as "Data Source", "Extension", or "Theoretical Foundation", etc. Each relationship is paired with a dense explanation. I extracted a total of ~300 papers, which cost around 4$.

- Step 4 Training Gemma - 7B: The distilled knowledge data are then used to train Gemma-7b. Then I use this model to generate citation relationships for as many papers as I can. In total, I extracted 7k papers, with around 150K triplets! Crazy!!

- Step 5 Graph Store: Finally, a Graph Store is created containing 7K papers and 586K triplets. This could then be used for searching relationships or visualization.


<img src="../figures/Graph-Paper-Search.jpg" alt="graph-search" width=1200/>

## **2.1 Download pre-extracted citation data**

In [80]:
from datasets import load_dataset

parsed_article = load_dataset("BachNgoH/ParsedArxivPapers")['train']

In [81]:
parsed_article = parsed_article.to_list()

In [84]:
import json

for article in parsed_article:
    if article['citation_data'] != None:
        article['citation_data'] = json.loads(article['citation_data'])

In [86]:
parsed_article[104]['citation_data']

[{'Category': 'Supporting Evidence',
  'Citation': '(Tatman, 2017)',
  'Explanation': 'The cited work by Tatman (2017) provides evidence of the potential impact of demographic biases in NLP models, highlighting the importance of addressing fairness in the field.'},
 {'Category': 'Supporting Evidence',
  'Citation': '(Perez, 2019)',
  'Explanation': 'The cited work by Perez (2019) emphasizes the need for accurate identification of speakers and their needs in NLP models, further supporting the importance of fairness in the field.'},
 {'Category': 'Supporting Evidence',
  'Citation': '(Agarwal et al., 2019)',
  'Explanation': 'The cited work by Agarwal et al. (2019) highlights the issue of hurtful stereotypes in NLP models, emphasizing the need for fairness in the field to address this problem.'},
 {'Category': 'Supporting Evidence',
  'Citation': '(Nozza et al., 2022)',
  'Explanation': 'The cited work by Nozza et al. (2022) provides further evidence of the need for fairness in NLP model

Let's see the number of annotated papers for now!

In [87]:
annotated_article = [x for x in parsed_article if x['citation_data'] is not None]

In [88]:
print("Annotated Papers: ", len(annotated_article))

Annotated Papers:  7243


## **2.2 Parsing generated data**

From my observation, there are 2 main citation styles in AI papers, Author-year style and Numeric style:

Example of Author-year style:
- (Bassignana and Plank, 2022a) 
- (Liu et al., 2021)
- (K√∂ksal and √ñzg√ºr, 2020)

Example of Numeric style:
- [1], [2], [3]
- [2, 56, 67]
- [7 - 9]

Therefore, we need different strategy to handle each style of citation

### **2.2.1 Handle Author-Year citation style**

Handling this citation style can be quite frustrating. Initially, we must separate combined citations like (Liu et al., 2021; Littell et al., 201) into individual entries. Then, we need to identify the first author and publication year. Subsequently, we have to locate the corresponding reference within our reference list based on the author's name and publication year.

In [89]:
# Parse annotated articles
import re

# Function to normalize author names for comparison
def normalize_author_name(name):
    # Convert to lowercase and remove middle initials
    name = name.lower()
    name = re.sub(r"\s+[a-z]\.", "", name)  # Remove middle initials
    return name


citation_names = [c['Citation'] for c in annotated_article[0]['citation_data']]
citation_names

['(Cohn et al., 1996)',
 '(Settles, 2009)',
 '(Dasgupta, 2011)',
 '(Gururangan et al., 2020)',
 '(Houlsby et al., 2019)',
 '(Pfeiffer et al., 2023)',
 '(He et al., 2021;Li and Liang, 2021;Karimi Mahabadi et al., 2021)',
 '(Toneva et al., 2019)',
 '(Ein-Dor et al., 2020)',
 '(Margatina et al., 2021)',
 '(Shelmanov et al., 2021)',
 '(Karamcheti et al., 2021)',
 '(Schr√∂der et al., 2022)',
 '(Mosbach et al., 2021)',
 '(Zhang et al., 2021)',
 '(Dodge et al., 2020)',
 '(Grie√ühaber et al., 2020)',
 '(Yuan et al., 2020)',
 '(Yu et al., 2022)',
 '(Margatina et al., 2022)',
 '(Jukiƒá and ≈†najder, 2023)',
 '(Ansell et al., 2021)',
 '(Lee et al., 2022)',
 '(Paroviƒá et al., 2022)',
 '(Li and Liang, 2021)',
 '(Mao et al., 2022)',
 '(He et al., 2021)',
 '(Kim et al., 2021)',
 '(Pang and Lee, 2004)',
 '(Li and Roth, 2002)',
 '(Socher et al., 2013)',
 '(Zhang et al., 2015)',
 '(Houlsby et al., 2019)',
 '(Li and Liang, 2021)',
 '(Hu et al., 2022)',
 '(Mao et al., 2022)',
 '(Devlin et al., 2019)',
 '

In [90]:
# Refined function to identify and normalize the first author from a citation
def identify_and_normalize_first_author(citation_authors):
    # Check for 'et al.' and 'and' to find the first author
    if 'et al.' in citation_authors:
        first_author = citation_authors.split('et al.')[0].strip()
    elif ' and ' in citation_authors:
        first_author = citation_authors.rsplit(' and ', 1)[0].split(',')[0].strip()
    else:
        first_author = citation_authors.split(',')[0].strip()
    # Normalize the first author's name for comparison
    return first_author.lower()


# Function to split and parse citations in cases of citation 
# like (Culotta and Sorensen 2004; Bunescu and Mooney 2005; Ittoo and Bouma 2013)
def split_and_parse_citation(citation):

    # Remove outer parentheses
    citation = citation.strip("()")
    # Split on semicolon if it's present, indicating multiple citations within one
    if ';' in citation:
        sub_citations = citation.split(';')
    else:
        sub_citations = [citation]
    
    # Parse each sub-citation for author names and year
    for sub_citation in sub_citations:
        # Splitting based on the last occurrence of space which is assumed to be before the year
        *authors, year = sub_citation.rsplit(' ', 1)
        authors = ' '.join(authors)  # Joining back the authors in case there are multiple names
        parsed_citation = {'Author': identify_and_normalize_first_author(authors), 'Year': year}
    
    return parsed_citation

In [91]:
references = annotated_article[0]['references']
references

[{'authors': 'Alan Ansell; Maria Edoardo; Jonas Ponti; Sebastian Pfeiffer; Goran Ruder; Ivan Glava≈°; Anna Vuliƒá;  Korhonen',
  'journal': 'Association for Computational Linguistics',
  'ref_id': 'b0',
  'title': 'MAD-G: Multilingual adapter generation for efficient cross-lingual transfer',
  'year': '2021'},
 {'authors': 'Robert Baldock; Hartmut Maennel; Behnam Neyshabur',
  'journal': '',
  'ref_id': 'b1',
  'title': 'Deep learning through the lens of example difficulty',
  'year': '2021'},
 {'authors': 'Curran Associates; Inc ',
  'journal': '',
  'ref_id': 'b2',
  'title': '',
  'year': ''},
 {'authors': 'Zoubin David A Cohn; Michael I Ghahramani;  Jordan',
  'journal': 'Journal of artificial intelligence research',
  'ref_id': 'b3',
  'title': 'Active learning with statistical models',
  'year': '1996'},
 {'authors': 'Sanjoy Dasgupta',
  'journal': '',
  'ref_id': 'b4',
  'title': 'Two faces of active learning',
  'year': '2009'},
 {'authors': 'Jacob Devlin; Ming-Wei Chang; Kento

In [92]:
# Function to normalize and extract the first author's name
def get_first_author(authors_str):
    first_author = authors_str.split(';')[0].strip()
    # Normalize the first author's name for comparison
    return first_author.lower()

# Generalized regular expression for detecting years in various date formats and standalone years

# Function to detect various year patterns and extract the year
def extract_years(string):
    general_year_pattern = re.compile(r'(?:\b|\D)(\d{4})(?:\b|\D)')
    # Find all matches for the general year pattern

    matches = general_year_pattern.findall(string)
    # Add all unique years found in this string
    year = matches[0] if matches else None
    return year

# Function to match citations with references
def match_citations_with_references(citation, references):
    match = None
    citation_first_author = citation['Author']
    citation_year = citation['Year'].strip()
    for ref in references:
        ref_first_author = get_first_author(ref['authors'])
        ref_year = extract_years(ref['year']) if ref['year'] is not None else None
        # Check for match by first author and year
        if citation_first_author in ref_first_author: #and (citation_year == ref_year or ref_year is None):
            match = {
                'ref_id': ref['ref_id']
            }
    return match

In [93]:
# test with the first sample
for citation in annotated_article[0]['citation_data']:
    parsed_name = split_and_parse_citation(citation['Citation'])
    match = match_citations_with_references(parsed_name, references)
    citation['ref_id'] = match['ref_id'] if match else None

In [94]:
annotated_article[0]['citation_data']

[{'Category': 'Methodological Basis',
  'Citation': '(Cohn et al., 1996)',
  'Explanation': 'The cited work introduces the concept of active learning as a potential solution to the challenge of data labeling in low-resource settings, which the citing paper builds upon in its research on efficient finetuning methods for PLMs.',
  'ref_id': 'b3'},
 {'Category': 'Methodological Basis',
  'Citation': '(Settles, 2009)',
  'Explanation': 'The cited work provides a more in-depth discussion of active learning and its potential benefits in reducing labeling costs, which the citing paper further explores in the context of PLMs and low-resource settings.',
  'ref_id': 'b37'},
 {'Category': 'Methodological Basis',
  'Citation': '(Dasgupta, 2011)',
  'Explanation': 'The cited work highlights the importance of label complexity in active learning and the need to reduce it for efficient model training, which the citing paper addresses in its research on efficient finetuning methods for PLMs in low-res

In [95]:
references[26:]

[{'authors': 'Xin Li; Dan Roth',
  'journal': '',
  'ref_id': 'b26',
  'title': 'Learning question classifiers',
  'year': '2002'},
 {'authors': 'Yuning Mao; Lambert Mathias; Rui Hou; Amjad Almahairi; Hao Ma; Jiawei Han; Scott Yih; Madian Khabsa',
  'journal': 'Association for Computational Linguistics',
  'ref_id': 'b27',
  'title': 'UniPELT: A unified framework for parameter-efficient language model tuning',
  'year': '2022'},
 {'authors': 'Katerina Margatina; Loic Barrault; Nikolaos Aletras',
  'journal': '',
  'ref_id': 'b28',
  'title': 'On the importance of effectively adapting pretrained language models for active learning',
  'year': '2022'},
 {'authors': 'Katerina Margatina; Giorgos Vernikos; Lo√Øc Barrault; Nikolaos Aletras',
  'journal': 'Association for Computational Linguistics',
  'ref_id': 'b29',
  'title': 'Active learning by acquiring contrastive examples',
  'year': '2021'},
 {'authors': 'Marius Mosbach; Maksym Andriushchenko; Dietrich Klakow',
  'journal': '',
  'ref

Now we need to group the citation data by ref_id

In [96]:
# Function to regroup citations by ref_id
def regroup_citations_by_ref_id(citations):
    grouped_citations = {}
    for citation in citations:
        if 'ref_id' in citation.keys():
            ref_id = citation['ref_id']
            # Create a copy of the citation without the ref_id
            citation_copy = {k: v for k, v in citation.items() if k != 'ref_id'}
            # Append the citation to the list associated with its ref_id
            if ref_id in grouped_citations:
                grouped_citations[ref_id].append(citation_copy)
            else:
                grouped_citations[ref_id] = [citation_copy]
    return grouped_citations


# Regroup the citationb list by ref_id
grouped_citations = regroup_citations_by_ref_id(annotated_article[0]['citation_data'])
print(grouped_citations)

{'b3': [{'Category': 'Methodological Basis', 'Citation': '(Cohn et al., 1996)', 'Explanation': 'The cited work introduces the concept of active learning as a potential solution to the challenge of data labeling in low-resource settings, which the citing paper builds upon in its research on efficient finetuning methods for PLMs.'}], 'b37': [{'Category': 'Methodological Basis', 'Citation': '(Settles, 2009)', 'Explanation': 'The cited work provides a more in-depth discussion of active learning and its potential benefits in reducing labeling costs, which the citing paper further explores in the context of PLMs and low-resource settings.'}], 'b4': [{'Category': 'Methodological Basis', 'Citation': '(Dasgupta, 2011)', 'Explanation': 'The cited work highlights the importance of label complexity in active learning and the need to reduce it for efficient model training, which the citing paper addresses in its research on efficient finetuning methods for PLMs in low-resource settings.'}], 'b11': 

Let's combine all the steps together into one function

In [114]:
def preprocess_citation_author_year(article):
    for citation in article['citation_data']:
        try:
            parsed_name = split_and_parse_citation(citation['Citation'])
            match = match_citations_with_references(parsed_name, article['references'])
            citation['ref_id'] = match['ref_id'] if match else None
        except:
            citation['ref_id'] = None
    return article

Now we have a grouped citation data for author-year citation style, let's start solving cases with numeric-style. 

### **2.2.2 Handle Numeric Citation Style**

This style of citation seems simple at first, but there are many edge cases that we have to deal with. From my observation, there are 3 main types:

- Singular citations such as [1] or [4]: These are processed conventionally, where the reference ID equals the citation number minus one.
- Lists, for instance [1, 4, 6]: In this scenario, the citations are split into individual entries: [1], [4], and [6].
- Ranges, like [1 - 5]: Here, the citation is divided into separate entries: [1], [2], [3], [4], [5].
- Mixed ranges, such as [1] - [5]: These are split into distinct citations: [1] and [5].

In [117]:
def split_numeric_citations(citations):
    # Helper function to parse ranges and individual numbers
    def parse_part(part):
        if '-' in part:  # Handle ranges
            start, end = map(int, part.split('-'))
            return list(range(start, end + 1))
        else:  # Handle individual numbers
            return [int(part)]

    # Initialize the result list
    result = []

    # Find all parts of the input that match the patterns
    parts = re.findall(r'\[([^]]+)]', citations)
    
    for part in parts:
        # For each part, remove spaces, split by commas and extend the result list
        for subpart in part.replace(' ', '').split(','):
            try:
                result.extend(parse_part(subpart))
            except:
                continue

    return [f"[{num}]" for num in result]

# Function to apply citation splitting to a list of citation entries
def split_citations_in_entries(citation_entries):
    expanded_citation_entries = []
    for entry in citation_entries:
        try:
        # Use the split_citations function to get a list of individual citations from the Citation field
            split_citations_list = split_numeric_citations(entry['Citation'])
            for citation in split_citations_list:
                # Create a new citation entry for each split citation, keeping other fields the same
                
                new_entry = {
                    'Citation': citation,
                    'Category': entry['Category'],
                    'Explanation': entry['Explanation']
                }
                expanded_citation_entries.append(new_entry)
        except:
            continue
    return expanded_citation_entries


def match_numeric_citation(citations):
    for citation in citations:
        # Regular expression to find single numbers inside square brackets
        pattern = re.compile(r'\[\(?(?P<number>\d+)\)?\]')
        try:
            #Find all matches in the text and convert them to integers
            reference_num = [int(match.group('number')) for match in pattern.finditer(citation['Citation'])][0]
            citation['ref_id'] = f"b{reference_num -1}"

        except:
            continue
    return citations



Before parsing

In [99]:
annotated_article[108]['citation_data'][:5]

[{'Category': 'Methodological Basis',
  'Citation': '[29,72]',
  'Explanation': 'The cited works provide a method of using the features learned by predicting different auxiliary maps to assist in predicting saliency maps, which the citing paper adopts in its research on saliency object detection.'},
 {'Category': 'Methodological Basis',
  'Citation': '[52]',
  'Explanation': 'The cited work uses auxiliary maps as input to guide the training process, which the citing paper adopts in its research on saliency object detection.'},
 {'Category': 'Methodological Basis',
  'Citation': '[46,14]',
  'Explanation': 'The cited works introduce a boundary-aware loss to make the models pay more attention to edge pixels, which the citing paper adopts in its research on saliency object detection.'},
 {'Category': 'Methodological Basis',
  'Citation': '[58,60]',
  'Explanation': 'The cited works provide a method of using multiple heads in a single encoder to learn different semantic information, which 

After parsing

In [100]:
annotated_article[108]['citation_data'] = split_citations_in_entries(annotated_article[108]['citation_data'])
annotated_article[108]['citation_data'] =  match_numeric_citation(annotated_article[108]['citation_data'])

In [101]:
annotated_article[108]['citation_data'][:6]

[{'Citation': '[29]',
  'Category': 'Methodological Basis',
  'Explanation': 'The cited works provide a method of using the features learned by predicting different auxiliary maps to assist in predicting saliency maps, which the citing paper adopts in its research on saliency object detection.',
  'ref_id': 'b28'},
 {'Citation': '[72]',
  'Category': 'Methodological Basis',
  'Explanation': 'The cited works provide a method of using the features learned by predicting different auxiliary maps to assist in predicting saliency maps, which the citing paper adopts in its research on saliency object detection.',
  'ref_id': 'b71'},
 {'Citation': '[52]',
  'Category': 'Methodological Basis',
  'Explanation': 'The cited work uses auxiliary maps as input to guide the training process, which the citing paper adopts in its research on saliency object detection.',
  'ref_id': 'b51'},
 {'Citation': '[46]',
  'Category': 'Methodological Basis',
  'Explanation': 'The cited works introduce a boundary-

As you can see, single citation like [14], [58] will be parsed normaly. But for citation like [29, 72], they will get split to 2 separated citations [29] and [72].

Now, let's combine the steps together

In [102]:
def proprocess_citation_numeric(article):
    
    article['citation_data'] = split_citations_in_entries(article['citation_data'])
    article['citation_data'] = match_numeric_citation(article['citation_data'])
    return article

### **2.2.3 Process 2 citation style**

First we need to detect the citation style

In [103]:
def detect_citation_style(text):
    # Pattern to match numeric citations like [1], [1, 2], [1-6], [1, 2-6], [1, 2, 3-6], etc.
    numeric_pattern = re.compile(r'\[\d+(-\d+)?(,\s*\d+(-\d+)?)*\]')
    # Pattern for "Author-Year" citations like (Author, Year)
    author_year_pattern = re.compile(r'\([A-Za-z]+,\s*\d{4}\)')

    # Check for numeric citation style
    if numeric_pattern.search(text):
        return "Numeric"
    # Check for author-year citation style
    elif author_year_pattern.search(text):
        return "Author-Year"
    else:
        return "Author-Year"

In [104]:
detect_citation_style("(Amin et al., 2019)")

'Author-Year'

In [105]:
detect_citation_style("[1,6]")

'Numeric'

In [118]:
from tqdm import tqdm

for article in tqdm(annotated_article, total=len(annotated_article)):
    references = article['references']
    if len(article['citation_data']) == 0:
        continue
    citation_style = detect_citation_style(article['citation_data'][0]["Citation"])
    # try:
    if citation_style == "Author-Year":
        article = preprocess_citation_author_year(article)
    elif citation_style == "Numeric":
        article = proprocess_citation_numeric(article)
    else:
        print(f"Uncertain citation style: {citation_style}")
        continue
    # except Exception as e:
    #     print(article['citation_data'])
    #     print(e)
    #     break

    grouped_citations = regroup_citations_by_ref_id(article['citation_data'])
    article['grouped_citations'] = grouped_citations


  2%|‚ñè         | 157/7243 [00:00<00:27, 255.67it/s]

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7243/7243 [00:11<00:00, 643.55it/s]


In [119]:
article['citation_data'][0]['Citation']

'[4]'

In [120]:
annotated_article[17]['grouped_citations']

{'b14': [{'Category': 'Methodological Basis',
   'Citation': '(Liu et al., 2020)',
   'Explanation': 'The cited work by Liu et al. introduces a fusion forget gate to control the flow of information between multimodal sequences, which the citing paper adopts as a method to address the problem of redundancy and noise in video multimodal fusion.'},
  {'Category': 'Methodological Basis',
   'Citation': '(Liu et al., 2020)',
   'Explanation': 'The cited work proposed a multistage fusion network with a fusion forget gate module for controlling redundant information in multimodal long sequences, which the citing paper adopts in their video multimodal summarization research.'},
  {'Category': 'Methodological Basis',
   'Citation': '(Liu et al., 2018b)',
   'Explanation': 'The cited work introduces the LMF model, which the citing paper references for text generation tasks.'},
  {'Category': 'Methodological Basis',
   'Citation': '(Liu et al., 2018b)',
   'Explanation': 'The cited work by Liu et

## **2.3 Building citation graph**

### **2.3.1 Parsing annotated triplets**

In [121]:
relationships_dict = {
    "Supporting Evidence": "Is Evidence For",
    "Methodological Basis": "Is Methodological Basis For",
    "Theoretical Foundation": "Is Theoretical Foundation For", 
    "Data Source": "Is Data Source For",
    "Extension or Continuation": "Is Extension or Continuation Of",
}

We have grouped citation data; now we need to find the papers cited in the arXiv dataset by name.

In [122]:
df_data['title'] = df_data['title'].str.lower()
titles = df_data['title'].tolist()


def search_paper_by_name(name):
    # matches = df_data['title'].str.contains(name, case=False, na=False, regex=False)
    # filtered_df = df_data[matches]
    # if len(filtered_df) == 0:
    #     return None
    # return filtered_df.iloc[0]['id']
    titles = df_data['title'].tolist()
    for idx, title in enumerate(titles):
        if name in title:
            return df_data.iloc[idx]['id']
    return None

for article_dict in tqdm(annotated_article, total=len(annotated_article)):

    article_dict["arxiv_id"] = search_paper_by_name(article_dict['title'].lower())

    if "grouped_citations" in article_dict.keys():
        article_dict["mapped_citation"] = {}
        for key,val in article_dict['grouped_citations'].items():
            for ref in article_dict["references"]:
                if ref["ref_id"] == key:
                    title = ref["title"]

            title = title.lower()
            arxiv_id = search_paper_by_name(title)
            article_dict['mapped_citation'][key] = {"title": title, 'arxiv_id': arxiv_id, 'citation': val}

  0%|          | 0/7243 [00:00<?, ?it/s]

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7243/7243 [1:54:35<00:00,  1.05it/s]  


In [125]:
with open("../outputs/annotated_articles.json", "w") as f:
    json.dump(annotated_article, f)

In [126]:
article_dict['mapped_citation']

{'b3': {'title': 'an image is worth 16x16 words: transformers for image recognition at scale',
  'arxiv_id': '2010.11929',
  'citation': [{'Citation': '[4]',
    'Category': 'Methodological Basis',
    'Explanation': 'The cited works introduce the concept of using image transformers as a backbone for video descriptor extraction, which the citing paper adopts for its video copy detection model.'}]},
 'b6': {'title': 'swin transformer v2: scaling up capacity and resolution',
  'arxiv_id': '2111.09883',
  'citation': [{'Citation': '[7]',
    'Category': 'Methodological Basis',
    'Explanation': 'The cited works introduce the concept of using image transformers as a backbone for video descriptor extraction, which the citing paper adopts for its video copy detection model.'}]},
 'b0': {'title': 'a simple framework for contrastive learning of visual representations',
  'arxiv_id': '2002.05709',
  'citation': [{'Citation': '[1]',
    'Category': 'Methodological Basis',
    'Explanation': "Th

Let's define a class for a paper node

In [127]:
class PaperNode:
    title: str
    arxiv_id: str
    
    def __init__(self, title, arxiv_id):
        self.title = title
        self.arxiv_id = arxiv_id

    def __str__(self) -> str:
        return f"Title: {self.title},\n Arxiv ID: {self.arxiv_id}"

class PaperEdge:
    category: str
    explanation: str
    verbose = True

    def __init__(self, category, explanation):
        self.category = category
        self.explanation = explanation

    def __str__(self) -> str:
        if self.verbose:
            return f"Category: {self.category},\n Explanation: {self.explanation}"
        else:
            return f"Category: {self.category}"

In [151]:
paper_dict = {}

for article_dict in tqdm(annotated_article, total=len(annotated_article)):
    paper_dict[article_dict['title'].lower()] = PaperNode(title=article_dict['title'], arxiv_id=article_dict['arxiv_id'])

    if "mapped_citation" in article_dict.keys():
        for key,val in article_dict['mapped_citation'].items():
            title = val['title']
            if title not in paper_dict.keys():
                paper_node = PaperNode(title=val['title'], arxiv_id=val['arxiv_id'])
                paper_dict[title] = paper_node


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7243/7243 [00:00<00:00, 23971.08it/s]


In [149]:
annotated_article[-1]['mapped_citation']

{'b3': {'title': 'an image is worth 16x16 words: transformers for image recognition at scale',
  'arxiv_id': '2010.11929',
  'citation': [{'Citation': '[4]',
    'Category': 'Methodological Basis',
    'Explanation': 'The cited works introduce the concept of using image transformers as a backbone for video descriptor extraction, which the citing paper adopts for its video copy detection model.'}]},
 'b6': {'title': 'swin transformer v2: scaling up capacity and resolution',
  'arxiv_id': '2111.09883',
  'citation': [{'Citation': '[7]',
    'Category': 'Methodological Basis',
    'Explanation': 'The cited works introduce the concept of using image transformers as a backbone for video descriptor extraction, which the citing paper adopts for its video copy detection model.'}]},
 'b0': {'title': 'a simple framework for contrastive learning of visual representations',
  'arxiv_id': '2002.05709',
  'citation': [{'Citation': '[1]',
    'Category': 'Methodological Basis',
    'Explanation': "Th

In [152]:
len(paper_dict.keys())

112611

In [153]:
paper_dict["make-a-video: text-to-video generation without text-video data"]

<__main__.PaperNode at 0x70bff98d5900>

In [154]:
triplets = []

for article_dict in annotated_article:
    if "mapped_citation" not in article_dict.keys():
        print(article_dict['title'])
        continue
    for key, val in article_dict['mapped_citation'].items():
        title = val['title']
        citation = val['citation']
        
        # Use a dictionary to group explanations by category
        category_explanations = {}
        for rel in citation:
            # try:
                category = rel['Category']
                explanation = rel['Explanation']
                if category not in category_explanations:
                    category_explanations[category] = []
                category_explanations[category].append(explanation)

        source_node = paper_dict[title]
        target_node = paper_dict[article_dict['title'].lower()]

        # Construct triplets with aggregated explanations for each category
        for category, explanations in category_explanations.items():
            if category not in relationships_dict.keys():
                relationships_dict[category] = f"Is {category} Of"

            aggregated_explanation = "; ".join(set(explanations))  # Remove duplicates and join explanations
            rel = PaperEdge(category=category, explanation=aggregated_explanation)
            reverse_rel = PaperEdge(category=relationships_dict[category], explanation=aggregated_explanation)

            # Add the relationship in both directions
            triplets.append((source_node, rel, target_node))
            triplets.append((target_node, reverse_rel, source_node))

{'Category': 'Methodological Basis', 'Citation': 'The citing paper adopts the data and results from the cited works by Shridhar et al. and Yang et al. to demonstrate the on-par success rates of AutoPlan with human-written demonstrations.'}
{'Category': 'Extension or Continuation', 'Citation': 'The proposed LPS metric is used in the study conducted in the citing paper to further split test nodes into different sensitive groups based on the LPS values.'}
{'Category': 'Data Source', 'Citation': 'The cited work is used to acknowledge the origin of a dataset or specific information that the citing paper utilizes in their research or analysis.'}
{'Category': 'Extension or Continuation', 'Citation': 'The cited work is extended in the citing paper by choosing the optimal number of prompt features and keeping the neural network architecture and hyperparameters the same as in PLOT.'}
{'Category': 'Methodological Basis', 'Citation': 'The cited work contributes a method for choosing the featurizer

In [155]:
rel['Category']

TypeError: 'PaperEdge' object is not subscriptable

In [131]:
import networkx as nx

# Assuming 'triplets' is your list of relationships, 
# and each PaperNode object in the triplets has an 'arxiv_id' attribute

G = nx.DiGraph()

# Add nodes and edges
for source_node, relationship, target_node in triplets:
    # Add nodes if they are not already in the graph
    if source_node.arxiv_id not in G:
        G.add_node(source_node.title, title=str(source_node), arxiv_id=source_node.arxiv_id)
    if target_node.arxiv_id not in G:
        G.add_node(target_node.title, title=str(target_node), arxiv_id=target_node.arxiv_id)
    
    # Add edge with relationship details
    G.add_edge(source_node.title, target_node.title, title=str(relationship), category=relationship.category, explanation=relationship.explanation)

In [135]:
print(len(triplets))

586952


In [143]:
len(G.nodes)

3506

### **2.3.2 Visualizing citation graph**

In [141]:
len(paper_dict.keys())

112611

In [138]:
def find_connected_nodes(graph, node, relationship=None):
    """
    Find nodes connected to the given node with an optional filter on the type of relationship.
    """
    connected_nodes = []
    for n, nbrs in graph.adj.items():
        if n == node:
            for nbr, eattr in nbrs.items():
                if relationship is None or eattr['label'] == relationship:
                    connected_nodes.append(nbr)
    return connected_nodes

# Function to search for a node by arxiv_id and return its details
def find_nodes_by_arxiv_id(graph, arxiv_id):
    for node, data in graph.nodes(data=True):
        if data.get('arxiv_id') == arxiv_id:
            return data  # or return data['paper_node'] to return the PaperNode object itself
    return "Paper not found in the graph."


def find_shortest_path(graph, source, target):
    """
    Find the shortest path between two nodes.
    """
    try:
        path = nx.shortest_path(graph, source=source, target=target)
        return path
    except nx.NetworkXNoPath:
        return None

# Example Usage
phenaki_related_topics = find_connected_nodes(G, 'cogview: mastering text-to-image generation via transformers')
print("Topics related to cogvideo:", phenaki_related_topics)

# Example search
search_result = find_nodes_by_arxiv_id(G, "2209.14792")
print(search_result)


Topics related to cogvideo: []
Paper not found in the graph.


In [142]:
def find_nodes_by_keyword(graph, keyword):
    """
    Find nodes that contain the given keyword in their name and retrieve their connected nodes and relationships.
    """
    keyword = keyword.lower()  # Convert keyword to lowercase for case-insensitive matching
    matching_nodes = [node for node in graph.nodes if keyword in node.lower()]

    related_nodes = {}
    for node in matching_nodes:
        connections = []
        for neighbor, details in graph[node].items():
            connections.append((neighbor, details['title'].split('\n')[0]))
        related_nodes[node] = connections

    return related_nodes

# Example Usage
keyword = "make"
phenaki_related = find_nodes_by_keyword(G, keyword)
for node, connections in phenaki_related.items():
    print(f"Node: {node}")
    for conn in connections:
        print(f"  Connected to: {conn[0]} via {conn[1]}")

Node: google's ai chatbot bard makes factual error in first demo
  Connected to: DecipherPref: Analyzing Influential Factors in Human Preference Judgments via GPT-4 via Category: Data Source,
Node: what makes a good conversation? how controllable attributes affect human judgments
  Connected to: LEFTOVER-LUNCH: ADVANTAGE-BASED OFFLINE REINFORCEMENT LEARNING FOR LANGUAGE MODELS via Category: Extension or Continuation,
Node: does bert make any sense? interpretable word sense disambiguation with contextualized embeddings
  Connected to: SENTECON: Leveraging Lexicons to Learn Human-Interpretable Language Representations via Category: Methodological Basis,
Node: learning the difference that makes a difference with counterfactually-augmented data
  Connected to: Prompting Large Language Models for Counterfactual Generation: An Empirical Study via Category: Methodological Basis,
Node: rethinking the role of demonstrations: what makes in-context learning work? in proceedings of empirical metho

In [140]:
import networkx as nx
from pyvis.network import Network

# Assuming G is your original graph
# Step 1: Create the subgraph for "Node1" and its neighbors
subgraph = nx.ego_graph(G, 'make-a-video: text-to-video generation without text-video data', radius=3, center=True, undirected=False)
# Nodes to be removed because they have a degree of 1 in the full graph
nodes_to_remove = [node for node in subgraph if subgraph.degree(node) < 3]

# Remove the nodes from the ego graph
subgraph.remove_nodes_from(nodes_to_remove)

nt = Network(notebook=True, font_color='#10000000')
nt.from_nx(subgraph)
nt.show("nx.html")

NodeNotFound: Source make-a-video: text-to-video generation without text-video data is not in G

## **2.4 Building Graph Query Engine**

In [72]:
paper_retriever = index.as_retriever(
    similarity_top_k=10,
)

In [1]:
paper_dict['make-a-video: text-to-video generation without text-video data']

NameError: name 'paper_dict' is not defined

In [78]:
results = paper_retriever.retrieve("Give me some paper about Video diffusion models")
nodes = []
for r in results:
    title = r.text.split("/n")[0]
    print(title)
    node = find_nodes_by_keyword(G, title)
    print(node)
    nodes.append(node)

a survey on video diffusion models
 The recent wave of AI-generated content (AIGC) has witnessed substantialsuccess in computer vision, with the diffusion model playing a crucial role inthis achievement. Due to their impressive generative capabilities, diffusionmodels are gradually superseding methods based on GANs and auto-regressiveTransformers, demonstrating exceptional performance not only in imagegeneration and editing, but also in the realm of video-related research.However, existing surveys mainly focus on diffusion models in the context ofimage generation, with few up-to-date reviews on their application in the videodomain. To address this gap, this paper presents a comprehensive review ofvideo diffusion models in the AIGC era. Specifically, we begin with a conciseintroduction to the fundamentals and evolution of diffusion models.Subsequently, we present an overview of research on diffusion models in thevideo domain, categorizing the work into three key areas: video generation,

# **3. Basic Data Science Assistant**

## **3.1 Download Wikipedia Data**

For data science questions, I will use the source from wikipedia

In [48]:
!pip install -q -U wikipedia-api

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [49]:
import re

# Pre-compile the regular expression pattern for better performance
BRACES_PATTERN = re.compile(r'\{.*?\}|\}')

def remove_braces_and_content(text):
    """Remove all occurrences of curly braces and their content from the given text"""
    return BRACES_PATTERN.sub('', text)

def clean_string(input_string):
    """Clean the input string."""
    
    # Remove extra spaces by splitting the string by spaces and joining back together
    cleaned_string = ' '.join(input_string.split())
    
    # Remove consecutive carriage return characters until there are no more consecutive occurrences
    cleaned_string = re.sub(r'\r+', '\r', cleaned_string)
    
    # Remove all occurrences of curly braces and their content from the cleaned string
    cleaned_string = remove_braces_and_content(cleaned_string)
    
    # Return the cleaned string
    return cleaned_string

In [50]:
def extract_wikipedia_pages(wiki_wiki, category_name):
    """Extract all references from a category on Wikipedia"""
    
    # Get the Wikipedia page corresponding to the provided category name
    category = wiki_wiki.page("Category:" + category_name)
    
    # Initialize an empty list to store page titles
    pages = []
    
    # Check if the category exists
    if category.exists():
        # Iterate through each article in the category and append its title to the list
        for article in category.categorymembers.values():
            pages.append(article.title)
    
    # Return the list of page titles
    return pages

In [51]:
import wikipediaapi
from tqdm import tqdm

def get_wikipedia_pages(categories):
    """Retrieve Wikipedia pages from a list of categories and extract their content"""
    
    # Create a Wikipedia object
    wiki_wiki = wikipediaapi.Wikipedia('Kaggle Data Science Assistant with Gemma', 'en')
    
    # Initialize lists to store explored categories and Wikipedia pages
    explored_categories = []
    wikipedia_pages = []

    # Iterate through each category
    print("- Processing Wikipedia categories:")
    for category_name in categories:
        print(f"\tExploring {category_name} on Wikipedia")
        
        # Get the Wikipedia page corresponding to the category
        category = wiki_wiki.page("Category:" + category_name)
        
        # Extract Wikipedia pages from the category and extend the list
        wikipedia_pages.extend(extract_wikipedia_pages(wiki_wiki, category_name))
        
        # Add the explored category to the list
        explored_categories.append(category_name)

    # Extract subcategories and remove duplicate categories
    categories_to_explore = [item.replace("Category:", "") for item in wikipedia_pages if "Category:" in item]
    wikipedia_pages = list(set([item for item in wikipedia_pages if "Category:" not in item]))
    
    # Explore subcategories recursively
    while categories_to_explore:
        category_name = categories_to_explore.pop()
        print(f"\tExploring {category_name} on Wikipedia")
        
        # Extract more references from the subcategory
        more_refs = extract_wikipedia_pages(wiki_wiki, category_name)

        # Iterate through the references
        for ref in more_refs:
            # Check if the reference is a category
            if "Category:" in ref:
                new_category = ref.replace("Category:", "")
                # Add the new category to the explored categories list
                if new_category not in explored_categories:
                    explored_categories.append(new_category)
            else:
                # Add the reference to the Wikipedia pages list
                if ref not in wikipedia_pages:
                    wikipedia_pages.append(ref)

    # Initialize a list to store extracted texts
    extracted_texts = []
    
    # Iterate through each Wikipedia page
    print("- Processing Wikipedia pages:")
    for page_title in tqdm(wikipedia_pages, total=len(wikipedia_pages)):
        # Get the Wikipedia page
        page = wiki_wiki.page(page_title)

        # Append the page title and summary to the extracted texts list
        if len(page.summary) > len(page.title):
            extracted_texts.append(page.title + " : " + clean_string(page.summary))
        
        # Iterate through the sections in the page
        for section in page.sections:
            # Append the page title and section text to the extracted texts list
            if len(section.text) > len(page.title):
                extracted_texts.append(page.title + " : " + clean_string(section.text))
                
    # Return the extracted texts
    return extracted_texts

In [24]:
categories = ["Machine_learning", "Data_science", "Statistics", "Deep_learning", "Artificial_intelligence"]
extracted_texts = get_wikipedia_pages(categories)
print("Found", len(extracted_texts), "Wikipedia pages")

- Processing Wikipedia categories:
	Exploring Machine_learning on Wikipedia


	Exploring Data_science on Wikipedia
	Exploring Statistics on Wikipedia
	Exploring Deep_learning on Wikipedia
	Exploring Artificial_intelligence on Wikipedia
	Exploring Artificial intelligence stubs on Wikipedia
	Exploring Works created using artificial intelligence on Wikipedia
	Exploring Virtual assistants on Wikipedia
	Exploring Turing tests on Wikipedia
	Exploring AI software on Wikipedia
	Exploring Rule engines on Wikipedia
	Exploring Artificial intelligence publications on Wikipedia
	Exploring Philosophy of artificial intelligence on Wikipedia
	Exploring Artificial intelligence people on Wikipedia
	Exploring Open-source artificial intelligence on Wikipedia
	Exploring Non-fiction books about Artificial intelligence on Wikipedia
	Exploring Neural networks on Wikipedia
	Exploring Multi-agent systems on Wikipedia
	Exploring Mind‚Äìbody problem on Wikipedia
	Exploring Machine learning on Wikipedia
	Exploring Artificial intelligence laboratories on Wikipedia
	Exploring Knowledge repres

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3448/3448 [22:26<00:00,  2.56it/s]

Found 16232 Wikipedia pages





In [25]:
wiki_documents = [Document(text=extracted_text, doc_id=str(i)) for i, extracted_text in enumerate(extracted_texts)]

In [70]:
chroma_client = chromadb.PersistentClient(path="../DB/wiki")
chroma_collection = chroma_client.get_or_create_collection("gemma_assistant_wiki")


# Create vector store
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)

In [27]:
index = VectorStoreIndex.from_documents(
    wiki_documents, storage_context=storage_context, embed_model=embed_model, show_progress=True
)

Parsing nodes: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 16232/16232 [00:08<00:00, 1830.75it/s]
Generating embeddings: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [00:11<00:00, 175.13it/s]
Generating embeddings: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [00:10<00:00, 188.60it/s]
Generating embeddings: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [00:11<00:00, 175.40it/s]
Generating embeddings: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [00:11<00:00, 177.25it/s]
Generating embeddings: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [00:10<00:00, 202.03it/s]
Generating embeddings: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [00:09<00:00, 209.87it/s]
Generating embeddings: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [00:11<00:00, 185.12it/s]
Generating embeddings: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2048/2048 [00:11<00:00, 182.42it/s]
Generating embeddings: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 73/73 [00:00<00:00, 197.81it/s]


## **3.2 Loading from vector store**

In [5]:
from llama_index.core import VectorStoreIndex, Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import chromadb
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core import StorageContext
import torch


Settings.llm = None # Set this to none to make the index only do retrieval
device_type = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5", cache_folder="../models", device=device_type) # must be the same as the previous stage

chroma_client = chromadb.PersistentClient(path="../DB/wiki")
chroma_collection = chroma_client.get_or_create_collection("gemma_assistant_wiki")
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
# load the vectorstore
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_vector_store(vector_store, storage_context=storage_context, embed_model=embed_model)


LLM is explicitly disabled. Using MockLLM.


In [71]:
data_science_query_engine = index.as_query_engine(
    similarity_top_k=10,
)

In [7]:
print(data_science_query_engine.query("What is linear regression"))

Context information is below.
---------------------
Outline of regression analysis : Regression analysis Linear regression

Regression diagnostic : Regression diagnostics have often been developed or were initially proposed in the context of linear regression or, more particularly, ordinary least squares. This means that many formally defined diagnostics are only available for these contexts.

Linear predictor function : In statistics and in machine learning, a linear predictor function is a linear function (linear combination) of a set of coefficients and explanatory variables (independent variables), whose value is used to predict the outcome of a dependent variable. This sort of function usually comes in linear regression, where the coefficients are called regression coefficients. However, they also occur in various types of linear classifiers (e.g. logistic regression, perceptrons, support vector machines, and linear discriminant analysis), as well as in various other models, such 

# **4. Python Code Assistant**

## **3.2 Define a code intepreter**

In [74]:
import os
import io
import regex
import pickle
import traceback
import copy
import datetime
import dateutil.relativedelta
import multiprocess
from multiprocess import Pool
from typing import Any, Dict, Optional
from pebble import ProcessPool
from tqdm import tqdm
from concurrent.futures import TimeoutError
from functools import partial
from timeout_decorator import timeout
from contextlib import redirect_stdout


class GenericRuntime:
    GLOBAL_DICT = {}
    LOCAL_DICT = None
    HEADERS = []
    def __init__(self):
        self._global_vars = copy.copy(self.GLOBAL_DICT)
        self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None

        for c in self.HEADERS:
            self.exec_code(c)

    def exec_code(self, code_piece: str) -> None:
        if regex.search(r'(\s|^)?input\(', code_piece) or regex.search(r'(\s|^)?os.system\(', code_piece):
            raise RuntimeError()
        exec(code_piece, self._global_vars)
        
    def eval_code(self, expr: str) -> Any:
        return eval(expr, self._global_vars)
    
    def inject(self, var_dict: Dict[str, Any]) -> None:
        for k, v in var_dict.items():
            self._global_vars[k] = v
    
    @property
    def answer(self):
        return self._global_vars['answer']

class DateRuntime(GenericRuntime):
    GLOBAL_DICT = {
        'datetime': datetime.datetime, 
        'timedelta': dateutil.relativedelta.relativedelta,
        'relativedelta': dateutil.relativedelta.relativedelta
    }


class CustomDict(dict):
    def __iter__(self):
        return list(super().__iter__()).__iter__()

class ColorObjectRuntime(GenericRuntime):
    GLOBAL_DICT = {'dict': CustomDict}


In [77]:
class PythonExecutor:
    def __init__(
        self,
        runtime: Optional[Any] = None,
        get_answer_symbol: Optional[str] = None,
        get_answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = False,
        timeout_length: int = 5,
    ) -> None:
        self.runtime = runtime if runtime else GenericRuntime()
        self.answer_symbol = get_answer_symbol
        self.answer_expr = get_answer_expr
        self.get_answer_from_stdout = get_answer_from_stdout
        self.pool = Pool(multiprocess.cpu_count())
        self.timeout_length = timeout_length

    def process_generation_to_code(self, gens: str):
        return [g.split('\n') for g in gens]

    @staticmethod
    def execute(
        code,
        get_answer_from_stdout = None,
        runtime = None,
        answer_symbol = None,
        answer_expr = None,
        timeout_length = 10,
    ):
        try:
            if get_answer_from_stdout:
                program_io = io.StringIO()
                with redirect_stdout(program_io):
                    timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                program_io.seek(0)
                result = program_io.read()
            elif answer_symbol:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                result = runtime._global_vars[answer_symbol]
            elif answer_expr:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
            else:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1]))
                result = timeout(timeout_length)(runtime.eval_code)(code[-1])
            report = "Done"
            str(result)
            pickle.dumps(result) # serialization check
        except:
            result = ''
            report = traceback.format_exc().split('\n')[-2]
        return result, report

    def apply(self, code):
        return self.batch_apply([code])[0]

    @staticmethod
    def truncate(s, max_length=400):
        half = max_length // 2
        if len(s) > max_length:
            s = s[:half] + "..." + s[-half:]
        return s

    def batch_apply(self, batch_code):
        all_code_snippets = self.process_generation_to_code(batch_code)

        timeout_cnt = 0
        all_exec_results = []
        with ProcessPool(max_workers=min(len(all_code_snippets), os.cpu_count())) as pool:
            executor = partial(
                self.execute,
                get_answer_from_stdout=self.get_answer_from_stdout,
                runtime=self.runtime,
                answer_symbol=self.answer_symbol,
                answer_expr=self.answer_expr,
                timeout_length=self.timeout_length, # this timeout not work
            )
            future = pool.map(executor, all_code_snippets, timeout=self.timeout_length)
            iterator = future.result()

            if len(all_code_snippets) > 100:  
                progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")  
            else:  
                progress_bar = None 

            while True:
                try:
                    result = next(iterator)
                    all_exec_results.append(result)
                except StopIteration:
                    break
                except TimeoutError as error:
                    print(error)
                    all_exec_results.append(("", "Timeout Error"))
                    timeout_cnt += 1
                except Exception as error:
                    print(error)
                    exit()
                if progress_bar is not None:
                    progress_bar.update(1) 
            
            if progress_bar is not None:
                progress_bar.close() 

        batch_results = []
        for code, (res, report) in zip(all_code_snippets, all_exec_results):
            # post processing
            res, report = str(res).strip(), str(report).strip()
            res, report = self.truncate(res), self.truncate(report)
            batch_results.append((res, report))
        return batch_results


def test():
    batch_code = [
"""
print("Hello world!")
"""
    ]

    executor = PythonExecutor(get_answer_from_stdout=True)
    predictions = executor.apply(batch_code[0])
    print(predictions)


test()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

('Hello world!', 'Done')


# **5. Combine all of them together**

## **5.1 Define Router Engine**

In [73]:
from llama_index.core.query_engine import RouterQueryEngine
from llama_index.core.selectors import LLMSingleSelector
from llama_index.core.tools import QueryEngineTool
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.llms.openai import OpenAI
from unsloth import FastLanguageModel
import os

max_seq_length = 8192 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"


# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name = "unsloth/gemma-7b-bnb-4bit", # Choose ANY! eg teknium/OpenHermes-2.5-Mistral-7B
#     max_seq_length = max_seq_length,
#     dtype = dtype,
#     load_in_4bit = load_in_4bit,
#     token = "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", # use one if using gated models like meta-llama/Llama-2-7b-hf
#     cache_dir = "../models",
# )
# FastLanguageModel.for_inference(model) 

# llm = HuggingFaceLLM(model=model, tokenizer=tokenizer)

llm = OpenAI(model_name="gpt-3.5-turbo", api_key="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")
Settings.llm = llm

In [72]:
paper_tool = QueryEngineTool.from_defaults(
    query_engine=paper_query_engine,
    description="Useful for search for papers",
)
ds_tool = QueryEngineTool.from_defaults(
    query_engine=data_science_query_engine,
    description="Useful for answering data science concepts",
)

query_engine = RouterQueryEngine(
    selector=LLMSingleSelector.from_defaults(),
    query_engine_tools=[
        paper_tool,
        ds_tool,
    ],
    verbose=True
)
print(query_engine.query("What is linear regression?"))

ValueError: Failed to convert output to JSON: ''