In [1]:
!pip install --upgrade pip wheel

Collecting pip
  Downloading pip-25.0.1-py3-none-any.whl.metadata (3.7 kB)
Downloading pip-25.0.1-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.0.1


In [2]:
!pip install requests tqdm faiss-cpu transformers torch sentence-transformers textblob gensim numba accelerate ninja

Collecting faiss-cpu
  Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.4 kB)
Collecting ninja
  Downloading ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting 

In [3]:
!MAX_JOBS=12 python -m pip -v install flash-attn --no-build-isolation  --use-pep517

Using pip 25.0.1 from /usr/local/lib/python3.11/dist-packages/pip (python 3.11)
Collecting flash-attn
  Downloading flash_attn-2.7.4.post1.tar.gz (6.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m71.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Running command Preparing metadata (pyproject.toml)


  torch.__version__  = 2.5.1+cu124


  running dist_info
  creating /tmp/pip-modern-metadata-98clmalw/flash_attn.egg-info
  writing /tmp/pip-modern-metadata-98clmalw/flash_attn.egg-info/PKG-INFO
  writing dependency_links to /tmp/pip-modern-metadata-98clmalw/flash_attn.egg-info/dependency_links.txt
  writing requirements to /tmp/pip-modern-metadata-98clmalw/flash_attn.egg-info/requires.txt
  writing top-level names to /tmp/pip-modern-metadata-98clmalw/flash_attn.egg-info/top_level.txt
  writing manifest file '/tmp/pip-modern-metadata-98clmalw/flash_attn.egg-info/SOURCES.txt'
  reading manifest file '/tmp/pip-modern-metadata-98clmalw/flash_attn.egg-info/SOU

In [4]:
import os
import requests
import zipfile
from pathlib import Path
from tqdm import tqdm
import re
import json
from gensim.utils import simple_preprocess
from textblob import TextBlob

# Directory to store downloaded and extracted data
DATA_DIR = Path("./mimic_textbooks")
# URLs for the dataset and pre-chunked JSON
dataset_url = "https://www.dropbox.com/scl/fi/gk1y8ll3d7wllwbb24kqe/textbooks.zip?rlkey=cdpqf8cbeu3difouvhwsc866w&st=resv96io&dl=1"
CHUNKED_DOCUMENTS_PATH = Path("./chunked_documents.json")
CHUNKED_DOCUMENTS_URL = "https://www.dropbox.com/scl/fi/07wd0zwvz2xcq80hy5f91/chunked_documents.json?rlkey=jwvfpczo4zeyke9j74cdphovi&st=oeqmcfi8&dl=1"

# Download and extract the dataset zip file
def download_and_extract_zip(url, extract_to=DATA_DIR):
    # Ensure the directory exists
    extract_to.mkdir(parents=True, exist_ok=True)

    # Download the zip file
    zip_path = extract_to / "textbooks.zip"
    print("Downloading dataset...")
    response = requests.get(url, stream=True)
    with open(zip_path, "wb") as file:
        for chunk in tqdm(response.iter_content(chunk_size=1024), unit='KB'):
            if chunk:
                file.write(chunk)

    # Extract the zip file
    print("Extracting dataset...")
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extract_to)
    print("Dataset downloaded and extracted.")

# Load text files from a given directory
def load_text_files(directory):
    texts = []
    for file_path in Path(directory).glob("*.txt"):
        with open(file_path, "r", encoding="utf-8") as file:
            texts.append(file.read())
    return texts

# Cleaning and preprocessing function
def clean_and_tokenize(text):
    # Remove extra spaces, lowercase text, and remove special characters
    text = re.sub(r'\s+', ' ', text)
    text = text.lower()
    text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
    tokens = simple_preprocess(text)
    return ' '.join(tokens)

# Chunk text into fixed-size chunks
def chunk_text(text, chunk_size=1000):
    words = text.split()
    return [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]

# Main process:
if CHUNKED_DOCUMENTS_PATH.exists():
    print("Loading existing chunked_documents.json...")
    with open(CHUNKED_DOCUMENTS_PATH, "r", encoding="utf-8") as f:
        chunked_documents = json.load(f)
else:
    print("chunked_documents.json does not exist. Trying to download from remote URL...")
    try:
        response = requests.get(CHUNKED_DOCUMENTS_URL, allow_redirects=True)
        with open(CHUNKED_DOCUMENTS_PATH, "wb") as f:
            f.write(response.content)
        print("Successfully downloaded chunked_documents.json from remote URL.")
        with open(CHUNKED_DOCUMENTS_PATH, "r", encoding="utf-8") as f:
            chunked_documents = json.load(f)
    except Exception as e:
        print(f"Failed to download chunked_documents.json: {e}")
        print("Creating chunked_documents.json from dataset...")

        # Download and extract the textbooks if needed
        download_and_extract_zip(dataset_url)

        # Load, clean, and process documents
        documents = load_text_files(DATA_DIR / "en")
        cleaned_documents = [clean_and_tokenize(doc) for doc in documents]
        chunked_documents = []
        for doc in cleaned_documents:
            chunked_documents.extend(chunk_text(doc))
        print(f"Total document chunks created: {len(chunked_documents)}")

        # Save the chunked documents to JSON
        with open(CHUNKED_DOCUMENTS_PATH, "w", encoding="utf-8") as f:
            json.dump(chunked_documents, f)
        print("chunked_documents.json created.")

print(f"Total document chunks available: {len(chunked_documents)}")


chunked_documents.json does not exist. Trying to download from remote URL...
Successfully downloaded chunked_documents.json from remote URL.
Total document chunks available: 12272


In [5]:
import os
import numpy as np
import faiss
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModel

# Assuming 'chunked_documents' is already defined from earlier processing
INDEX_PATH = "./faiss_index.idx"
FAISS_INDEX_URL = "https://www.dropbox.com/scl/fi/05ez2886nz5fkkcqsv6hs/faiss_index.idx?rlkey=yil6ollju5smk04upluenqot4&st=yu0oji49&dl=1"
dimension = 384  # Embedding size from MiniLM model
# Load model and tokenizer (using PyTorch) only when needed
retrieval_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
retrieval_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

if os.path.exists(INDEX_PATH):
    print("Loading existing FAISS index from disk...")
    index = faiss.read_index(INDEX_PATH)
    print(f"Total embeddings indexed: {index.ntotal}")
else:
    print("FAISS index does not exist. Trying to download from remote URL...")
    try:
        response = requests.get(FAISS_INDEX_URL, allow_redirects=True)
        with open(INDEX_PATH, "wb") as f:
            f.write(response.content)
        print("Successfully downloaded FAISS index from remote URL.")
        index = faiss.read_index(INDEX_PATH)
        print(f"Total embeddings indexed: {index.ntotal}")
    except Exception as e:
        print("FAISS index not found. Creating index...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        retrieval_model.to(device)
        retrieval_model.eval()  # Set the model to evaluation mode

        # Function to generate embeddings for all chunks in a batch
        def get_embeddings_in_batch(texts, batch_size=16):
            all_embeddings = []

            # Wrap the loop with tqdm to display a progress bar
            for i in tqdm(range(0, len(texts), batch_size), desc="Generating embeddings"):
                batch_texts = texts[i:i + batch_size]

                # Tokenize the batch of texts
                inputs = retrieval_tokenizer(batch_texts, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
                # Generate embeddings on the GPU
                outputs = retrieval_model(**inputs).last_hidden_state  # [batch_size, sequence_length, hidden_size]
                batch_embeddings = torch.mean(outputs, dim=1).cpu().detach().numpy()  # [batch_size, hidden_size]

                # Append batch embeddings to the list
                all_embeddings.extend(batch_embeddings)

            return np.array(all_embeddings)

        # Generate embeddings for all document chunks in batches
        embeddings = get_embeddings_in_batch(chunked_documents, batch_size=128)
        print(f"Generated embeddings for {len(embeddings)} document chunks.")

        # Create the FAISS index and add embeddings
        index = faiss.IndexFlatL2(dimension)
        # Ensure embeddings are in the correct shape and type
        embedding_matrix = np.array([embedding.flatten() for embedding in embeddings]).astype('float32')
        index.add(embedding_matrix)
        print(f"Total embeddings indexed: {index.ntotal}")

        # Write the FAISS index to disk
        faiss.write_index(index, INDEX_PATH)
        print(f"FAISS index written to {INDEX_PATH}")


tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

FAISS index does not exist. Trying to download from remote URL...
Successfully downloaded FAISS index from remote URL.
Total embeddings indexed: 12021


## Retrival Method

In [6]:
retrieval_model.cpu().eval()

# Function to generate embeddings for a new query
def get_query_embedding(query):
    with torch.no_grad():
        inputs = retrieval_tokenizer(query, return_tensors="pt", padding=True, truncation=True)
        outputs = retrieval_model(**inputs)
        embedding = torch.mean(outputs.last_hidden_state, dim=1).cpu().numpy()
    return embedding

# Load FAISS index with existing embeddings
embedding_dim = 384
index = faiss.IndexFlatL2(embedding_dim)

# Function to retrieve relevant documents based on the query
def retrieve_documents(query, top_k=5):
    query_embedding = get_query_embedding(query).astype("float32")
    distances, indices = index.search(query_embedding, top_k)
    results = [chunked_documents[idx] for idx in indices[0]]
    return results

# Test retrieval component
sample_query = "What are the symptoms of heart failure?"
similar_documents = retrieve_documents(sample_query)
print("Retrieved documents:", similar_documents)


Retrieved documents: ['the transcription factor batf xenografts grafted organs taken from different species than the recipient xenoimmunity in the context of immune mediated disease refers to immunity directed against foreign antigens of non human species such as bacteria derived antigens of the commensal microbiota that are targets in inﬂammatory bowel disease ibd xeroderma pigmentosum several autosomal recessive diseases caused by defects in repair of ultraviolet light induced dna damage defects in polη cause type xeroderma pigmentosum xid see linked linked xla genetic disorder in which cell development is arrested at the pre cell stage and no mature cells or antibodies are formed the disease is due to defect in the gene encoding the protein tyrosine kinase btk which is encoded on the chromosome linked hyper igm syndrome see cd ligand deficiency syndrome with some features resembling hyper igm syndrome it is caused by mutations in the protein nemo component of the nfκb signaling path

## Generation Method

In [7]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generation_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct", trust_remote_code=True)
generation_model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3.5-mini-instruct", device_map="cuda", torch_dtype="auto", trust_remote_code=True)
generation_model.to(device).eval()


tokenizer_config.json:   0%|          | 0.00/3.98k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/306 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/3.45k [00:00<?, ?B/s]

configuration_phi3.py:   0%|          | 0.00/11.2k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3.5-mini-instruct:
- configuration_phi3.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_phi3.py:   0%|          | 0.00/73.8k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3.5-mini-instruct:
- modeling_phi3.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors.index.json:   0%|          | 0.00/16.3k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.67G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/195 [00:00<?, ?B/s]

Phi3ForCausalLM(
  (model): Phi3Model(
    (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-31): 32 x Phi3DecoderLayer(
        (self_attn): Phi3Attention(
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
          (rotary_emb): Phi3LongRoPEScaledRotaryEmbedding()
        )
        (mlp): Phi3MLP(
          (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (activation_fn): SiLU()
        )
        (input_layernorm): Phi3RMSNorm()
        (resid_attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
        (post_attention_layernorm): Phi3RMSNorm()
      )
    )
    (norm): Phi3RMSNorm()
  )
  (lm_head): Linear(in_features=3072, out

In [8]:

# Function to generate a response using retrieved context
def generate_response(query, context, max_new_tokens=100):
    input_text = f"User query: {query}\n\nContext:\n{context}\n\nAnswer:"

    # Tokenize the input and move tensors to GPU
    inputs = generation_tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(device)

    # Generate response using max_new_tokens to control output length
    with torch.no_grad():
        outputs = generation_model.generate(**inputs, max_new_tokens=max_new_tokens, num_return_sequences=1)

    # Decode the generated response
    response_text = generation_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response_text

# Testing generation with retrieved documents as context
retrieved_text = " ".join(similar_documents)  # Concatenate retrieved documents as context
response = generate_response(sample_query, retrieved_text)
print("Generated response:", response)

The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.
`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. Calling `get_max_cache()` will raise error from v4.48


Generated response: User query: What are the symptoms of heart failure?

Context:
the transcription factor batf xenografts grafted organs taken from different species than the recipient xenoimmunity in the context of immune mediated disease refers to immunity directed against foreign antigens of non human species such as bacteria derived antigens of the commensal microbiota that are targets in inﬂammatory bowel disease ibd xeroderma pigmentosum several autosomal recessive diseases caused by defects in repair of ultraviolet light induced dna damage defects in polη cause type xeroderma pigmentosum xid see linked linked xla genetic disorder in which cell development is arrested at the pre cell stage and no mature cells or antibodies are formed the disease is due to defect in the gene encoding the protein tyrosine kinase btk which is encoded on the chromosome linked hyper igm syndrome see cd ligand deficiency syndrome with some features resembling hyper igm syndrome it is caused by mutatio