# Mamba vs. Transformer-based RALMs

This project compares the performance of Retrieval Augmented Language Models (RALMs) based on the newly released Mamba architecture with those based on the more prevalent Transformer architecture. We compare the [Mamba-Chat](https://huggingface.co/havenhq/mamba-chat) model to the [Dolly](https://huggingface.co/databricks/dolly-v2-3b) model. Both models have approximately 2.8B parameters.

## Part 0: Imports, Environment Setup, and Dataset Loading

First, to make this notebook usable in Colab, load the rest of the repo. ONLY RUN THIS CODE IN COLAB.

In [None]:
!wget -q https://raw.githubusercontent.com/tsunrise/colab-github/main/colab_github.py
import colab_github
colab_github.github_auth(persistent_key=True)

In [None]:
# After adding the SSH key to your GitHub account (if necessary), run the commands below
!git clone git@github.com:abarton51/CS_4650_Project.git
!mv CS_4650_Project/* .

Next, install the required dependencies.

In [None]:
!pip install -r requirements.txt

For Colab, if you have not already done so, extract the tarball stored in Google Drive to get all the evaluation data.

In [None]:
!mkdir -p data/triviaqa
!tar -xvf drive/MyDrive/Mamba_RAG/CS_4650_Project/data/triviaqa-rc.tar.gz -C data/triviaqa

Finally, import all necessary packages.

In [7]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("src/")  # make the project's own modules importable
import json
from datasets import load_dataset
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import CharacterTextSplitter
from mamba_ralm import MambaRALM  # needed in Part 2
import vector_store


## Part 1: Construct Vector Database

Next, construct the vector database, or load it if it has already been created.

In [None]:
data_directory_path = "data/triviaqa/evidence"

triviaqa_vector_store = vector_store.RAGVectorStore(data_directory_path)

try:
    db = triviaqa_vector_store.load_db("triviaqa_vector_store")
except RuntimeError:
    # The database does not exist yet, so build it.
    # Note: any other RuntimeError raised while loading also lands here.
    db = triviaqa_vector_store.create_db("triviaqa_vector_store", verbose=True)

In [None]:
query = "What is the capital of Japan?"
docs = db.similarity_search(query, k=4)  # retrieve the 4 most similar chunks
print(docs)
"""Example Output: Japan[a] is an island country in East Asia. It is in the northwest Pacific Ocean and is bordered on the west by the Sea of Japan, extending from the
Sea of Okhotsk in the north toward the East China Sea, Philippine Sea, and Taiwan in the south. Japan is a part of the Ring of Fire, and spans an
archipelago of 14,125 islands, with the five main islands being Hokkaido, Honshu (the "mainland"), Shikoku, Kyushu, and Okinawa. Tokyo is the country's
capital and largest city, followed by Yokohama, Osaka, Nagoya, Sapporo, Fukuoka, Kobe, and Kyoto.
"""
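
Conceptually, `similarity_search` embeds the query and returns the `k` stored chunks whose embeddings lie closest to it. The sketch below illustrates that idea with a toy bag-of-words embedding and cosine similarity in plain Python. It is not how `RAGVectorStore` or FAISS is implemented; `embed`, `cosine`, and `top_k` are hypothetical names introduced only for illustration.

```python
import math
from collections import Counter


def embed(text):
    """Toy embedding: lowercase word counts (a stand-in for a real embedding model)."""
    return Counter(text.lower().split())


def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[w] * b[w] for w in a if w in b)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k(query, chunks, k=4):
    """Return the k chunks most similar to the query, best match first."""
    q = embed(query)
    ranked = sorted(chunks, key=lambda c: cosine(q, embed(c)), reverse=True)
    return ranked[:k]


# Hypothetical miniature corpus, used only to exercise the functions above.
chunks = [
    "Tokyo is the capital of Japan",
    "Paris is the capital of France",
    "The Utah Jazz play basketball in Salt Lake City",
]
print(top_k("capital of Japan", chunks, k=1))
```

A real vector store replaces the word counts with dense embeddings from a trained model and uses an index to avoid scoring every chunk, but the ranking principle is the same.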


## Part 2: Initialize Model

Initialize the model. Select the desired model here.

In [None]:
model = MambaRALM("havenhq/mamba-chat", db)

Run inference on the model.

In [None]:
model.provide_no_context = True # for testing
model._no_context_string = "The station went on the air as KXIV in 1989. It functioned as the second independent station for the Salt Lake City area. In 1993, Larry H. Miller, the then-owner of the Utah Jazz of the NBA, purchased the station and renamed it KJZZ-TV; it also became the new TV home of the basketball team for 16 seasons. During Miller's ownership, the station affiliated for five years with UPN, with the station's decision not to renew leading to accusations of racism against management; in the latter years, operations and programming were outsourced in turn to two other Salt Lake stations."
output = model.predict("What team does KJZZ-TV broadcast for?")
print(output)
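
A retrieval-augmented model typically places the retrieved passages in the prompt ahead of the question. The sketch below shows one way that might look. The template and the `build_prompt` helper are assumptions made for illustration; they do not reflect how `MambaRALM.predict` actually formats its input.

```python
def build_prompt(question, passages, max_chars=2000):
    """Join passages into a context block (truncated to max_chars) and append the question."""
    context = "\n\n".join(p.strip() for p in passages)[:max_chars]
    return (
        "Answer the question using the context below.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\n"
        "Answer:"
    )


# Hypothetical passage, shortened from the test string used above.
passages = ["In 1993, Larry H. Miller purchased the station and renamed it KJZZ-TV."]
prompt = build_prompt("What team does KJZZ-TV broadcast for?", passages)
print(prompt)
```

Truncating the context by characters is a crude choice made to keep the sketch short; a real system would more likely budget by tokens.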

## Part 3: Perform Evaluation

Here we load the evaluation dataset, which we will use to evaluate our RAG system.

In [32]:
import json

evaluation_dataset_name = "wikipedia-dev"
evaluation_ds_filepath = "data/triviaqa-rc/qa/{ds_name}.json".format(ds_name=evaluation_dataset_name)

with open(evaluation_ds_filepath, "r") as json_file:
    eval_ds = json.load(json_file)["Data"]

In [None]:
model.provide_no_context = False
model._no_context_string = ""  # unused while provide_no_context is False

MAX_EVALS = 100

evals = 0
accurate_evals = 0

# Evaluation loop: an answer counts as correct only if it exactly matches
# one of the accepted aliases (case-insensitive).
for i in range(min(len(eval_ds), MAX_EVALS)):
    correct_answers = [answer.lower() for answer in eval_ds[i]["Answer"]["Aliases"]]
    model_answer = model.predict(eval_ds[i]["Question"]).lower()
    if model_answer in correct_answers:
        accurate_evals += 1
    evals += 1

print("Evaluation Complete")
print("Over {evals} datapoints, the model accurately answered {accurate_evals}, reflecting an overall accuracy of {overall_accuracy}".format(
    evals=evals,
    accurate_evals=accurate_evals,
    overall_accuracy=(accurate_evals / evals) if evals else 0.0  # avoid dividing by zero on an empty dataset
))
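
The loop above compares lowercased strings for exact equality, so an answer that differs from an alias only in punctuation, articles, or spacing is scored as wrong. A common remedy in question-answering evaluation is to normalize both sides before comparing. The sketch below is one such normalization; `normalize_answer` and `is_correct` are hypothetical helpers, not part of this repository, and the exact rules are an assumption that may need tuning.

```python
import re
import string


def normalize_answer(text):
    """Lowercase, drop punctuation and English articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def is_correct(model_answer, aliases):
    """Return True if the normalized answer equals any normalized alias."""
    normalized_aliases = {normalize_answer(alias) for alias in aliases}
    return normalize_answer(model_answer) in normalized_aliases


# Hypothetical answer and alias list, used only to exercise the helpers.
print(is_correct("The Utah Jazz.", ["Utah Jazz", "Jazz"]))
```

In the evaluation loop, `is_correct(model.predict(question), aliases)` could replace the `in` check. Generative models often answer in full sentences, so even normalized exact match can undercount correct answers.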
    
    