<a href="https://colab.research.google.com/github/Danny2173/RAGproject/blob/main/5_Retriever.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports and Installs

In [None]:
# Install dependencies
# (transformers, faiss and tqdm are imported further down in this cell)
%pip install -q transformers faiss-cpu tqdm

# Mount Google Drive
# Later cells read and write their files under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

# Imports
import os, json, re, pickle
import torch
import numpy as np
import faiss
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from tqdm import tqdm


Mounted at /content/drive


## Load Corpus

In [None]:
load_path = '/content/drive/MyDrive/corpus.json'

# Load the corpus.
# The encoding is given explicitly so that decoding does not depend on the
# platform's default locale (non-ASCII characters would otherwise be read
# differently on machines whose default is not UTF-8).
with open(load_path, "r", encoding="utf-8") as f:
    corpus = json.load(f)

print("Corpus loaded")

Corpus loaded


## Formatting for retriever understanding

In [None]:
import re

def format_text(text, merge_bullets=True):
    """Reflow a document so that every sentence sits on its own line.

    The work happens in three passes:

    1. Regex clean-up of the raw string (sentence breaks after ".  ",
       removal of the "Subsection: Non-urgent advice:" heading, and merging
       of inline lists that follow a colon and two or more whitespace
       characters).
    2. A line-by-line scan that buffers "- " bullet items and flushes them
       whenever a heading, an intro line ending in ":" or a normal paragraph
       line is reached.
    3. Final regex polish: headings get a trailing period, colons that are
       not part of a Section/Subsection/Subsubsection label are removed and
       text is split after ".", "?" and "!".

    Behaviour worth knowing when reading the output:

    * A line ending in ":" (other than "do:" / "don't:") is remembered as
      the intro for the following bullets but is never emitted by itself.
    * A "Subsubsection:" line whose lower-cased text contains "do" is not
      emitted; it switches on a "Do" flag that stays on for the rest of
      the document.
    * Bullets still buffered at the end of the text while that flag is on
      are emitted as "Do a, b, and c." regardless of ``merge_bullets``.

    Args:
        text: The document text to reformat.
        merge_bullets: If true, buffered bullets that have an intro line are
            merged into one sentence prefixed by that intro; otherwise they
            are emitted as "- item" lines.

    Returns:
        The reformatted text as a single string.
    """
    heading_prefixes = ("Section:", "Subsection:", "Subsubsection:")

    def join_items(items):
        # "a" -> "a"; "a", "b", "c" -> "a, b, and c"
        if len(items) > 1:
            return ", ".join(items[:-1]) + ", and " + items[-1]
        return items[0]

    def render_bullets(items, intro, as_sentence):
        # One merged sentence when allowed and an intro exists, else "- " lines.
        if as_sentence and intro:
            return [f"{intro.rstrip(':').strip()} {join_items(items)}."]
        return [f"- {item}" for item in items]

    def merge_inline_list(match):
        # Items are separated by runs of two or more whitespace characters.
        chunks = [c.strip() for c in re.split(r'\s{2,}', match.group(3)) if c.strip()]
        full_sentences = [c for c in chunks if c.endswith('.')]
        fragments = [c for c in chunks if not c.endswith('.')]

        merged = ""
        if fragments:
            # Fragments become one sentence, so close it with a period.
            fragments[-1] += '.'
            merged = join_items(fragments)

        # Chunks that already were sentences follow on their own lines.
        trailing = "".join("\n" + sentence for sentence in full_sentences)
        return match.group(1) + ' ' + merged + trailing

    # ---- Pass 1: clean-up on the raw string ----
    text = re.sub(r'(\.\s{2,})([a-zA-Z])', r'\1\n\2', text)
    text = re.sub(r'^Subsection: Non-urgent advice:\s*', '', text, flags=re.MULTILINE)
    # A colon directly before a heading line becomes a period.
    text = re.sub(r'(:)\n(Section:|Subsection:|Subsubsection:)', r'.\n\2', text)
    text = re.sub(r'(:)(\s{2,})([^\n]+)', merge_inline_list, text)

    # ---- Pass 2: line scan with a bullet buffer ----
    rendered = []
    pending = []        # bullet items waiting to be flushed
    intro_line = None   # last line ending in ":" seen before the bullets
    do_block = False    # set once a "Do" subsubsection has been seen

    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        lowered = line.lower()

        # "Do" subsubsection heading: flush, raise the flag, drop the heading.
        if "subsubsection:" in lowered and "do" in lowered:
            if pending:
                rendered.extend(render_bullets(pending, intro_line, merge_bullets))
                pending = []
            intro_line = None
            do_block = True
            continue

        # Regular heading: bullets are never merged once the "Do" flag is on.
        if line.startswith(heading_prefixes):
            if pending:
                rendered.extend(
                    render_bullets(pending, intro_line, merge_bullets and not do_block)
                )
                pending = []
                intro_line = None
            rendered.append(line)
            continue

        # Intro line for a bullet list ("Symptoms include:").
        is_intro = (
            line.endswith(":")
            and not line.startswith("-")
            and lowered.strip() not in ["do:", "don't:"]
        )
        if is_intro:
            if pending:
                rendered.extend(render_bullets(pending, intro_line, merge_bullets))
                pending = []
            intro_line = line
            continue

        # Bullet line; one bullet may hold several items split by wide gaps.
        if line.startswith("- "):
            for piece in re.split(r'\s{2,}', line[2:].strip()):
                piece = piece.strip()
                if piece:
                    pending.append(piece)
            continue

        # Ordinary paragraph line.
        if pending:
            rendered.extend(render_bullets(pending, intro_line, merge_bullets))
            pending = []
            intro_line = None
        rendered.append(line)

    # Whatever is still buffered at the end of the document.
    if pending:
        if do_block:
            rendered.append("Do " + join_items(pending) + ".")
        else:
            rendered.extend(render_bullets(pending, intro_line, merge_bullets))

    # ---- Pass 3: polish ----
    result = "\n".join(rendered)
    # Headings end with a period.
    for heading_pattern in (
        r'^(Section:[^\n]*?)(?<!\.)$',
        r'^(Subsection:[^\n]*?)(?<!\.)$',
        r'^(Subsubsection:[^\n]*?)(?<!\.)$',
    ):
        result = re.sub(heading_pattern, r'\1.', result, flags=re.MULTILINE)
    # Drop colons except the ones that belong to a heading label.
    result = re.sub(r'(?<!Section)(?<!Subsection)(?<!Subsubsection):', '', result)
    # One sentence per line.
    return re.sub(r'([.?!])\s+', r'\1\n', result)

for entry in corpus:
    # Generator copy: bullets are kept as "- item" lines
    entry["text"] = format_text(entry["text"], merge_bullets=False)
    # Retriever copy: bullets are merged into sentences on the normalized text
    entry["normalized_text"] = format_text(entry["normalized_text"], merge_bullets=True)


## Test

In [None]:
# Print both versions of the same document so they can be compared:
# first the generator-facing "text", then the retriever-facing
# "normalized_text" (previously "normalized_text" was printed twice).
print(corpus[2]["text"])
print("-"*100)
print(corpus[2]["normalized_text"])
print("-"*100)

Section: esophageal achalasia.
esophageal achalasia is a rare disorder of the food pipe (oesophagus), which can make it difficult to swallow food and drink.
Normally, the muscles of the oesophagus contract to squeeze food along towards the stomach.
A ring of muscle at the end of the food pipe then relaxes to let food into the stomach.
In esophageal achalasia, the muscles in the oesophagus do not contract correctly and the ring of muscle can fail to open properly, or does not open at all.
Food and drink cannot pass into the stomach and becomes stuck.
It is often brought back up.
Section: esophageal achalasia.
Subsection: Symptoms of esophageal achalasia.
Not everyone with esophageal achalasia will have symptoms.
But most people with esophageal achalasia will find it difficult to swallow food or drink (known as deglutition disorders ).
Swallowing tends to get gradually more difficult or painful over a couple of years, to the point where it is sometimes impossible.
Other symptoms include 

## Importing Term Normalization functions

In [None]:
# Load the two lookup tables used by cui_normalization below.
# NOTE(review): pickle.load can execute code embedded in the file, so these
# files should only ever be ones produced by this project.

# term_to_CUI is looked up with lower-cased word n-grams
with open("/content/drive/MyDrive/filtered_term_to_CUI.pkl", "rb") as term_file:
    term_to_CUI = pickle.load(term_file)

# cui_to_main_term is looked up with the values found in term_to_CUI
with open("/content/drive/MyDrive/filtered_cui_to_main_term.pkl", "rb") as main_term_file:
    cui_to_main_term = pickle.load(main_term_file)

# Build every word n-gram together with the slice it came from
def ngram_tokenize_tokens(tokens, max_len=5):
    """Return all n-grams of up to ``max_len`` tokens with their positions.

    Args:
        tokens: Sequence of word strings.
        max_len: Largest number of tokens a single n-gram may contain.

    Returns:
        A list of ``(ngram, start, stop)`` tuples where ``ngram`` is
        ``' '.join(tokens[start:stop])``.  The list is ordered by ``start``
        and, for equal ``start``, from the shortest to the longest n-gram.
    """
    total = len(tokens)
    return [
        (' '.join(tokens[start:stop]), start, stop)
        for start in range(total)
        for stop in range(start + 1, min(start + 1 + max_len, total + 1))
    ]

# Replace known medical terms in a sentence with their main term
def cui_normalization(sentence, max_ngram_len=5):
    """Rewrite dictionary terms in ``sentence`` to their main term.

    The sentence is split into alternating word / non-word pieces.  Every
    lower-cased word n-gram (up to ``max_ngram_len`` words) is looked up in
    ``term_to_CUI``; hits whose CUI is present in ``cui_to_main_term`` become
    replacement candidates.  Candidates are taken left to right, preferring
    the longest one at each start position, and any candidate overlapping an
    already accepted one is skipped.

    Args:
        sentence: Text to normalize.
        max_ngram_len: Maximum number of words in a candidate term.

    Returns:
        A tuple ``(normalized_text, matched_main_terms, matched_cuis)``:
        the rewritten sentence, the main terms that were actually inserted
        (in sentence order), and a list of the CUIs of *all* dictionary hits,
        including hits that were dropped because they overlapped another
        one.  The CUI list is built from a set, so its order is unspecified.
    """
    pieces = re.findall(r'\w+|\W+', sentence)

    # Only the word pieces take part in the dictionary lookup
    lowered_words = [piece.lower() for piece in pieces if re.match(r'\w+', piece)]

    candidates = []
    matched_cuis = set()
    for phrase, first, stop in ngram_tokenize_tokens(lowered_words, max_ngram_len):
        if phrase not in term_to_CUI:
            continue
        cui = term_to_CUI[phrase]
        if cui not in cui_to_main_term:
            continue
        candidates.append((first, stop, cui_to_main_term[cui]))
        matched_cuis.add(cui)

    # Earliest start first; for the same start, the longest span first
    candidates.sort(key=lambda span: (span[0], span[0] - span[1]))

    # Greedily accept spans that do not touch an already covered word
    covered = set()
    chosen = []
    for first, stop, main_term in candidates:
        word_span = range(first, stop)
        if covered.isdisjoint(word_span):
            chosen.append((first, stop, main_term))
            covered.update(word_span)

    # Accepted spans never overlap, so each start index appears at most once
    span_by_start = {first: (stop, main_term) for first, stop, main_term in chosen}

    # Rebuild the sentence, swapping each accepted span for its main term
    rebuilt = []
    word_pos = 0    # index into the word-only sequence
    pos = 0         # index into pieces
    while pos < len(pieces):
        if re.match(r'\w+', pieces[pos]):
            hit = span_by_start.get(word_pos)
            if hit is not None:
                stop, main_term = hit
                rebuilt.append(main_term)
                # Skip the words of the span and the separators between them
                remaining = stop - word_pos
                while remaining > 0 and pos < len(pieces):
                    if re.match(r'\w+', pieces[pos]):
                        remaining -= 1
                    pos += 1
                word_pos = stop
                continue
            word_pos += 1
        rebuilt.append(pieces[pos])
        pos += 1

    normalized_text = ''.join(rebuilt)
    matched_main_terms = [main_term for _, _, main_term in chosen]
    return normalized_text, matched_main_terms, list(matched_cuis)


## Building FAISS Index

In [None]:
# Use the GPU when one is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# DPR context encoder (moved to the chosen device) and its tokenizer
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base").to(device)
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

corpus_embeddings = []         # one vector per document, same order as corpus
corpus_texts = []              # title + raw text, kept for the Generator
normalized_for_index = []      # title + normalized text, the strings that get embedded


for item in tqdm(corpus, desc="Encoding corpus"):

    # The title is the part of the id before the first underscore,
    # closed with a period so it reads as its own sentence
    title = item.get("id", "").split("_")[0]
    title = title if title.endswith(".") else title + "."

    # Generator-facing entry: title on the first line, raw text below
    body = item.get("text", "").strip()
    corpus_texts.append(f"{title}\n{body}")

    # Retriever-facing entry: normalized title plus normalized text
    normalized_title = cui_normalization(title)[0]
    normalized_body = item.get("normalized_text", "").strip()
    passage = f"{normalized_title}\n{normalized_body}".strip()
    normalized_for_index.append(passage)

    # Encode the passage; anything beyond 512 tokens is truncated
    encoded = ctx_tokenizer(passage, return_tensors="pt", truncation=True, max_length=512).to(device)
    with torch.no_grad():
        pooled = ctx_encoder(**encoded).pooler_output
    corpus_embeddings.append(pooled.squeeze().cpu().numpy())

# Stack the vectors and scale each row to unit length, so that the
# inner-product index below ranks by cosine similarity
corpus_embeddings = np.asarray(corpus_embeddings, dtype="float32")
corpus_embeddings /= np.linalg.norm(corpus_embeddings, axis=1, keepdims=True)

index = faiss.IndexFlatIP(corpus_embeddings.shape[1])
index.add(corpus_embeddings)

print(f"FAISS index built with {index.ntotal} vectors")


## Saving

In [None]:
# Pickle the three Python objects: the generator texts, the strings that
# were embedded, and the corpus itself (whose "text" / "normalized_text"
# fields have been reformatted by format_text above)
for target_path, payload in (
    ("/content/drive/MyDrive/corpus_texts.pkl", corpus_texts),
    ("/content/drive/MyDrive/normalized_for_index.pkl", normalized_for_index),
    ("/content/drive/MyDrive/corpus.pkl", corpus),
):
    with open(target_path, "wb") as out_file:
        pickle.dump(payload, out_file)

# The FAISS index is written with FAISS's own serializer
faiss.write_index(index, "/content/drive/MyDrive/faiss_index.index")