In [1]:
from flask import Flask, render_template, request, jsonify
import os
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
from openai import AzureOpenAI
import requests
from pathlib import Path
import PyPDF2
from pptx import Presentation
from docx import Document
import logging
import nltk
from nltk.tokenize import sent_tokenize
from sklearn.cluster import KMeans
import re
import yaml  
import torch
from langchain_community.document_loaders import PyPDFLoader

torch.set_num_threads(1)

# Download NLTK data for sentence tokenization
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

app = Flask(__name__)

# Retrieve configurations from environment variables
API_TYPE = os.environ.get("API_TYPE")
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
DEEPSEEK_ENDPOINT = os.environ.get("DEEPSEEK_ENDPOINT")
AZURE_OPENAI_KEY = os.environ.get("AZURE_OPENAI_KEY")
AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_VERSION = os.environ.get("AZURE_OPENAI_VERSION")
AZURE_OPENAI_DEPLOYMENT = os.environ.get("AZURE_OPENAI_DEPLOYMENT")

# Validate configurations
if API_TYPE == "deepseek":
    if not DEEPSEEK_API_KEY or not DEEPSEEK_ENDPOINT:
        raise ValueError("DEEPSEEK_API_KEY and DEEPSEEK_ENDPOINT must be set")
elif API_TYPE == "azure":
    if not all([AZURE_OPENAI_KEY, AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_VERSION, AZURE_OPENAI_DEPLOYMENT]):
        raise ValueError("All Azure OpenAI configurations must be set")
else:
    raise ValueError("API_TYPE must be either 'deepseek' or 'azure'")

# Initialize Azure OpenAI client if using Azure
if API_TYPE == "azure":
    azure_client = AzureOpenAI(
        api_key=AZURE_OPENAI_KEY,
        api_version=AZURE_OPENAI_VERSION,
        azure_endpoint=AZURE_OPENAI_ENDPOINT,
    )

# Initialize embedding model
embedder = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
num_retrieved_indices = 5 # k

# Global variables for indexing
documents = []
embeddings = []
chunk_metadata = []
faiss_index = None
INDEX_INITIALIZED = False



# Store conversation history (for multi-turn dialogue)
conversation_history = []

def PDFLoader(file_path):
    logger.info(f"Attempting to read PDF: {file_path}")
    loader = PyPDFLoader(file_path)
    pages = []
    for page in loader.lazy_load():
        pages.append(page)
    logger.info('PDF file loaded.')
    return pages


# Document processing functions with page tracking
def extract_text_from_pdf(file_path):
    pages = PDFLoader(file_path)
    try:
        chunks_with_pages = []
        for i in range(len(pages)):
            page_num = i + 1
            text = pages[i].page_content
            chunks = chunk_text(text)
            for chunk in chunks:
                chunks_with_pages.append({
                                'text': chunk,
                                'page': page_num
                            })
        logger.info(f"Successfully read {len(chunks_with_pages)} chunks from {file_path}")
        return chunks_with_pages
    except Exception as e:
        logger.error(f"Error reading PDF {file_path}: {str(e)}")
        return []

    # try:
    #     with open(file_path, 'rb') as file:
    #         pdf_reader = PyPDF2.PdfReader(file)
    #         chunks_with_pages = []
    #         for page_num, page in enumerate(pdf_reader.pages, start=1):
    #             text = page.extract_text() + "\n"
    #             if text.strip():
    #                 chunks = chunk_text(text.strip())
    #                 for chunk in chunks:
    #                     chunks_with_pages.append({
    #                         'text': chunk,
    #                         'page': page_num
    #                     })
    #         logger.info(f"Successfully read {len(chunks_with_pages)} chunks from {file_path}")
    #         return chunks_with_pages
    # except Exception as e:
    #     logger.error(f"Error reading PDF {file_path}: {str(e)}")
    #     return []

def chunk_text(text, max_length=1000, min_length=300):
    """Split text into chunks with accurate length calculation and optimized merging.
    
    Args:
        text: Input text to be chunked
        max_length: Maximum character length per chunk (default: 1000)
        min_length: Minimum character length for final chunks (default: 300)
    
    Returns:
        List of text chunks meeting length requirements
    """
    if not text.strip():
        logger.warning("Empty text provided to chunk_text")
        return []

    chunks = []
    current_chunk = []
    current_length = 0

    for sentence in sent_tokenize(text):
        sentence_length = len(sentence)
        space_length = 1 if current_chunk else 0  # Space between sentences
        
        # Calculate potential new length
        new_length = current_length + space_length + sentence_length
        
        if current_chunk and new_length > max_length:
            # Finalize current chunk if meets minimum length
            if current_length >= min_length:
                chunks.append(" ".join(current_chunk))
                current_chunk = []
                current_length = 0
            # Else keep accumulating even if over max_length
            
        # Add sentence to current chunk
        current_chunk.append(sentence)
        current_length += space_length + sentence_length

    # Handle remaining text
    if current_chunk:
        final_chunk = " ".join(current_chunk)
        # Merge small final chunk with previous if needed
        if len(final_chunk) < min_length and chunks:
            chunks[-1] += " " + final_chunk
            # Split back if merge caused overflow
            if len(chunks[-1]) > max_length:
                last_chunk = chunks.pop()
                chunks.extend([last_chunk[:max_length], last_chunk[max_length:]])
        else:
            chunks.append(final_chunk)

    # Post-process to ensure all chunks meet length requirements
    final_chunks = []
    for chunk in chunks:
        while len(chunk) > max_length:
            final_chunks.append(chunk[:max_length])
            chunk = chunk[max_length:]
        if chunk:
            final_chunks.append(chunk)
    
    # Log chunk details
    logger.info('='*10 + f"Created {len(chunks)} chunks from {len(text)} characters" + "="*10)

    # logger.info(f"Created {len(final_chunks)} chunks from {len(text)} characters")
    for idx, chunk in enumerate(final_chunks):
        logger.info(f"Chunk {idx}: {len(chunk):4} chars | Start: {chunk[:120].strip()}")
        logger.info('--'*20)      
    
    
    return final_chunks

# Function to initialize the FAISS index
def initialize_index():
    global documents, embeddings, chunk_metadata, faiss_index, INDEX_INITIALIZED
    if INDEX_INITIALIZED:
        logger.info("Index already initialized, skipping...")
        return
    
    documents_dir = "./documents"
    
    if not os.path.exists(documents_dir):
        os.makedirs(documents_dir)
        logger.warning(f"Created empty documents directory: {documents_dir}")
    
    documents = []
    embeddings = []
    chunk_metadata = []
    
    file_extractors = {
        # '.txt': extract_text_from_txt,
        '.pdf': extract_text_from_pdf,
        # '.pptx': extract_text_from_ppt,
        # '.docx': extract_text_from_docx
    }
    
    logger.info("Starting document indexing...")
    for root, dirs, files in os.walk(documents_dir):
        for filename in files:
            logger.info(f'process for {filename}')
            ext = Path(filename).suffix.lower()
            if ext in file_extractors:
                filepath = os.path.join(root, filename)
                logger.info(f'process for {filepath}')
                chunks_with_pages = file_extractors[ext](filepath)
                logger.info(f'total pages extracted: {len(chunks_with_pages)}')
                for i, chunk_info in enumerate(chunks_with_pages):
                    documents.append(chunk_info['text'])
                    chunk_metadata.append({
                        'filepath': filepath,
                        'chunk_index': i,
                        'original_text': chunk_info['text'],
                        'page': chunk_info['page']
                    })
    
    if documents:
        logger.info(f"Generating embeddings for {len(documents)} document chunks")
        logger.info(f"Sample document chunk: {documents[0][:100]}...")
        embeddings = embedder.encode(documents, show_progress_bar=True)
        embeddings = np.array(embeddings).astype('float32')
        logger.info(f"Generated embeddings - shape: {embeddings.shape}, dtype: {embeddings.dtype}")
    
        dimension = embeddings.shape[1]
        logger.info(f"Creating FAISS index with dimension {dimension}")
        faiss_index = faiss.IndexFlatL2(dimension)
        faiss_index.add(embeddings)
        logger.info(f"Indexed {len(documents)} document chunks")
        logger.info(f"FAISS index size: {faiss_index.ntotal}")
        if len(documents) != len(chunk_metadata):
            logger.error(f"Mismatch between documents ({len(documents)}) and chunk_metadata ({len(chunk_metadata)})")
    else:
        logger.warning("No documents found to index. Please add files to the 'documents/' directory.")
        faiss_index = None
    
    INDEX_INITIALIZED = True
    return faiss_index


def load_prompt_config(file_name):
    # Load configuration
    try:
        with open(file_name, 'r', encoding='utf-8') as config_file:
            config = yaml.safe_load(config_file)
        return config
    except FileNotFoundError:
        logger.error("prompt.yaml not found")
        raise
    except yaml.YAMLError:
        logger.error("Invalid YAML in prompt.yaml")
        raise

def create_prompt(prompt_template, passages, query, max_context_tokens=3000):
    total_length = len(query) + len(prompt_template) - len("{passages}") - len("{query}")
    passages_text = ""
    passage_refs = []
    for idx, passage in enumerate(passages, 1):
        # passage = ''.join([i.strip() for i in '\n'.split(passage)])
        ref_id = f"[Ref{idx}]"
        if total_length + len(passage) < max_context_tokens * 4:
            passages_text += f"{ref_id} {passage}\n"
            passage_refs.append((ref_id, passage))
            total_length += len(passage) + 1
        else:
            logger.warning(f"Truncated passages to fit within {max_context_tokens} tokens")
            break
    logger.info(f"Passages provided in prompt: {passages_text}")
    return prompt_template.format(passages=passages_text, query=query), passage_refs

# Function to cluster passages based on semantic similarity
def cluster_passages(passages, embeddings, max_clusters=3):
    if len(passages) <= 1:
        return [(passages, embeddings)] if passages else []
    
    # Use KMeans to cluster embeddings
    num_clusters = min(max_clusters, len(passages))
    kmeans = KMeans(n_clusters=num_clusters, random_state=42)
    labels = kmeans.fit_predict(embeddings)
    
    # Group passages by cluster
    clustered_passages = [[] for _ in range(num_clusters)]
    clustered_embeddings = [[] for _ in range(num_clusters)]
    for idx, label in enumerate(labels):
        clustered_passages[label].append(passages[idx])
        clustered_embeddings[label].append(embeddings[idx])
    
    # Return clusters as list of (passages, embeddings) tuples
    return [(clustered_passages[i], clustered_embeddings[i]) for i in range(num_clusters) if clustered_passages[i]]

# Function to compute semantic similarity between two texts
def compute_similarity(text1, text2):
    embeddings = embedder.encode([text1, text2], show_progress_bar=False)
    similarity = np.dot(embeddings[0], embeddings[1]) / (np.linalg.norm(embeddings[0]) * np.linalg.norm(embeddings[1]))
    return similarity

# Function to format the response as HTML with citations at the end
def format_response(bot_message, passage_refs, merged_chunks):
    logger.info(f"Bot message: {bot_message}")
    
    # Convert the bot message to HTML, preserving paragraphs and bullet points
    lines = bot_message.strip().split('\n')
    formatted_lines = []
    in_list = False
    used_refs = set()
    
    # Parse the answer to find referenced passages
    for line in lines:
        line = line.strip()
        # Skip the "References:" line at the end
        if line.startswith("References:"):
            continue
        # Look for [RefX] patterns in the line
        refs_in_line = re.findall(r'\[Ref\d+\]', line)
        used_refs.update(refs_in_line)
        # Remove [RefX] from the line for display
        line = re.sub(r'\[Ref\d+\]', '', line).strip()
        if not line:
            if in_list:
                formatted_lines.append('</ul>')
                in_list = False
            continue
        if line.startswith('- '):
            if not in_list:
                formatted_lines.append('<ul>')
                in_list = True
            formatted_lines.append(f'<li>{line[2:]}</li>')
        else:
            if in_list:
                formatted_lines.append('</ul>')
                in_list = False
            formatted_lines.append(f'<p>{line}</p>')
    
    if in_list:
        formatted_lines.append('</ul>')
    
    # Generate citations for used references
    citations = []
    if used_refs:
        for ref_id, passage in passage_refs:
            if ref_id in used_refs:
                filepath = passage.split('\n')[0].replace("Document: ", "")
                for chunk in merged_chunks:
                    if chunk['filepath'] == filepath:
                        citation = f"Source: {os.path.basename(chunk['filepath'])}"
                        if chunk['pages']:
                            pages_str = ', '.join(map(str, sorted(set(chunk['pages']))))
                            citation += f", Pages/Slides: {pages_str}"
                        if citation not in citations:
                            citations.append(citation)
    else:
        # Fallback: Use semantic similarity to determine relevant passages
        logger.warning("No references explicitly used in the answer. Using semantic similarity to find relevant passages.")
        answer_text = bot_message.lower()
        for ref_id, passage in passage_refs:
            passage_text = passage.lower()
            # Check for keyword overlap
            answer_words = set(answer_text.split())
            passage_words = set(passage_text.split())
            common_words = answer_words.intersection(passage_words)
            # Compute semantic similarity
            similarity = compute_similarity(answer_text, passage_text)
            # Include the passage if there is significant overlap or high similarity
            if len(common_words) > 3 or similarity > 0.7:
                used_refs.add(ref_id)
                filepath = passage.split('\n')[0].replace("Document: ", "")
                for chunk in merged_chunks:
                    if chunk['filepath'] == filepath:
                        citation = f"Source: {os.path.basename(chunk['filepath'])}"
                        if chunk['pages']:
                            pages_str = ', '.join(map(str, sorted(set(chunk['pages']))))
                            citation += f", Pages/Slides: {pages_str}"
                        if citation not in citations:
                            citations.append(citation)
        if not citations:
            logger.warning("No relevant passages found via semantic similarity. No citations will be included.")

    # Add citations at the end
    if citations:
        citations_html = '<div class="references"><strong>References:</strong><br>' + '<br>'.join(citations) + '</div>'
        formatted_lines.append(citations_html)
    
    return ''.join(formatted_lines)

def prepare_passages(user_message, faiss_index):
    query_embedding = embedder.encode([user_message], show_progress_bar=False)
    query_embedding = np.array(query_embedding).astype('float32')
    if len(query_embedding.shape) == 1:
        query_embedding = query_embedding.reshape(1, -1)
    logger.info(f"Query embedding generated - shape: {query_embedding.shape}, norm: {np.linalg.norm(query_embedding)}")
    passages = []
    passage_embeddings = []
    merged_chunks = []
    clusters = []
    if faiss_index is not None and len(documents) > 0:
        logger.info('='*10 + " Performing FAISS search" + "="*10)
        distances, indices = faiss_index.search(query_embedding, k=num_retrieved_indices)
        logger.info(f"FAISS search results:")
        logger.info(f"- Retrieved indices: {indices[0].tolist()}")
        logger.info(f"- Distances: {distances[0].tolist()}")
        logger.info(distances[0])
        logger.info('='*30)
        valid_indices = [idx for idx in indices[0] if 0 <= idx < len(chunk_metadata)]
        if not valid_indices:
            logger.warning("No valid indices retrieved from FAISS search")
            passages = []
            passage_embeddings = []
        else:
            valid_indices.sort(key=lambda idx: (chunk_metadata[idx]['filepath'], chunk_metadata[idx]['chunk_index']))
            
            current_chunk = None
            for idx in valid_indices:
                chunk_info = chunk_metadata[idx]
                if (current_chunk is None or
                        current_chunk['filepath'] != chunk_info['filepath'] or
                        current_chunk['chunk_index'] + 1 != chunk_info['chunk_index']):
                    if current_chunk is not None:
                        merged_chunks.append(current_chunk)
                    current_chunk = {
                        'filepath': chunk_info['filepath'],
                        'chunk_index': chunk_info['chunk_index'],
                        'text': chunk_info['original_text'],
                        'pages': [chunk_info['page']] if chunk_info['page'] is not None else []
                    }
                else:
                    current_chunk['text'] += " " + chunk_info['original_text']
                    current_chunk['chunk_index'] = chunk_info['chunk_index']
                    if chunk_info['page'] is not None:
                        current_chunk['pages'].append(chunk_info['page'])
            if current_chunk is not None:
                merged_chunks.append(current_chunk)
            
            # Create passages and embeddings for clustering
            for chunk in merged_chunks:
                sep_sign = '--'*15
                passage = f"Document: {chunk['filepath']}\nContent: {chunk['text']}\n{sep_sign}\n"
                passages.append(passage)
                passage_embeddings.append(embedder.encode([passage], show_progress_bar=False)[0])
            
            # Cluster passages based on semantic similarity
            passage_embeddings = np.array(passage_embeddings).astype('float32')
            clusters = cluster_passages(passages, passage_embeddings)
            
            # Rebuild passages based on clusters
            passages = []
            for clustered_passages, _ in clusters:
                cluster_text = "\n\n".join(clustered_passages)
                passages.append(cluster_text)
            
            logger.info(f"Clustered passages: {passages}")
            return passages, merged_chunks

  from .autonotebook import tqdm as notebook_tqdm
INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: mps
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: multi-qa-mpnet-base-dot-v1


In [2]:
faiss_index = initialize_index()
prompt_config = load_prompt_config('prompt.yaml')
SYSTEM_PROMPT = prompt_config.get('system_prompt', '')
prompt_template = prompt_config.get('prompt_template', '')

INFO:__main__:Starting document indexing...
INFO:__main__:process for LLD Dacomitinib - CHN - Chinese (Simplified).PDF
INFO:__main__:process for ./documents/LLD Dacomitinib - CHN - Chinese (Simplified).PDF
INFO:__main__:Attempting to read PDF: ./documents/LLD Dacomitinib - CHN - Chinese (Simplified).PDF
INFO:__main__:PDF file loaded.
INFO:__main__:Chunk 0:  617 chars | Start: 第1页，共12页
Version No. : 20240229
核准日期：2019 年05 月15 日
修改日期：2020 年10 月16 日；2021 年 08 月09 日；2024 年 02 月29 日
达可替尼片说明书
INFO:__main__:----------------------------------------
INFO:__main__:Chunk 0:  901 chars | Start: 第2页，共12页
Version No. : 20240229
本品的推荐剂量为每日一次口服 45 mg，直至出现疾病进展或不可接受的毒性。本品可与食物同服，
也可不与食物同服（见【药代动力学】）。
每天在大致相同的时间服用本品。如果患者呕
INFO:__main__:----------------------------------------
INFO:__main__:Chunk 0: 1000 chars | Start: 第3页，共12页
Version No. : 20240229
不建议 对轻度或中度肾功能损害（ 依据 Cockcroft-Gault 公式预计肌酐清 除率 [CLcr] 在 30~89
mL/min）的患者调整剂量。尚未确定重度肾功能损害
INFO:__main__:----------------------------------------
INFO:__main__:C

In [3]:
user_message = '达可替尼片的分子式和英文名是什么'
# user_message = '根据达可替尼片的分子式,在当前domain,有哪些类似的药品'
# user_message = '根据达可替尼片的特性和副作用，我们目前有哪些相似的药品'
# user_message = '在当前正在实验的新药中，我们目前有哪些相似的药品和达可替尼片的特性相近，并且副作用要小'

passages, merged_chunks = prepare_passages(user_message,faiss_index)

logger.info("Constructing RAG prompt...")
prompt, passage_refs = create_prompt(prompt_template, passages, user_message, max_context_tokens=3000)
logger.info(f"Final prompt length: {len(prompt)} characters")

INFO:__main__:Query embedding generated - shape: (1, 768), norm: 5.977413654327393
INFO:__main__:FAISS search results:
INFO:__main__:- Retrieved indices: [7, 16, 0, 14, 1]
INFO:__main__:- Distances: [30.50945281982422, 32.47759246826172, 33.56614685058594, 33.96271896362305, 34.227542877197266]
INFO:__main__:[30.509453 32.477592 33.566147 33.96272  34.227543]
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
  ret = a @ b
  ret = a @ b
  ret = a @ b
INFO:__main__:Clustered passages: ['Document: ./documents/LLD Dacomitinib - CHN - Chinese (Simplified).PDF\nContent: 风险。\n数据\n动物数据\n器官形成期间，妊娠大鼠每日口服 5 mg/kg/天（约为建议人用剂量时暴露量（基于曲线下面积 [AUC] \n的 1.2 倍）达可替尼后，导致着床后流产、母体毒性以及胎儿体重下降的发生率增加。\n小鼠模型中的 EGFR 被破坏或消耗，表明 EGFR 在生殖和发育过程（包括胚泡植入、胎盘发育和胚\n胎-胎儿/出生后存活和发育）中至关重要。小鼠胚胎-胎儿或母体 E

In [4]:
print(prompt)

Base your answer strictly on the information provided in the following internal project documents. Structure your answer as follows:
1. Start with a brief summary paragraph that provides an overview of the answer.
2. Follow with detailed information in a natural format. Use bullet points only when listing multiple items (e.g., information about multiple projects). Otherwise, use natural paragraphs.
3. For every piece of information you use from the documents, you MUST include the corresponding Reference ID (e.g., [Ref1], [Ref2]) inline with the text. If you do not use any document, explicitly state that no documents were used.
4. At the end of your answer, list the Reference IDs you used (e.g., References: [Ref1], [Ref2]).

Documents:
[Ref1] Document: ./documents/LLD Dacomitinib - CHN - Chinese (Simplified).PDF
Content: 风险。
数据
动物数据
器官形成期间，妊娠大鼠每日口服 5 mg/kg/天（约为建议人用剂量时暴露量（基于曲线下面积 [AUC] 
的 1.2 倍）达可替尼后，导致着床后流产、母体毒性以及胎儿体重下降的发生率增加。
小鼠模型中的 EGFR 被破坏或消耗，表明 EGFR 在生殖和发育过程（包括胚泡植入、胎盘发育和胚
胎-胎儿/出生后存活

In [48]:
messages = [
        {
            "role": "system",
            "content": SYSTEM_PROMPT
        }
    ]

messages.append({"role": "user", "content": prompt})

In [49]:
response = azure_client.chat.completions.create(
                
                messages=messages,
                max_tokens=1500,
                temperature=0.4
            )
bot_message = response.choices[0].message.content.strip()
logger.info("Received LLM response")
logger.info(f"Response length: {len(bot_message)} characters")

formatted_message = format_response(bot_message, passage_refs, merged_chunks)

        # Add bot response to conversation history
conversation_history.append({"role": "assistant", "content": bot_message})

INFO:openai._base_client:Retrying request to /chat/completions in 0.486361 seconds
INFO:openai._base_client:Retrying request to /chat/completions in 0.832427 seconds


APITimeoutError: Request timed out.

In [8]:
print(formatted_message)

NameError: name 'formatted_message' is not defined