In [1]:
import os
import xml.etree.ElementTree as ET
import json


# Function to load data from XML files within subfolders
def load_medquad_data(data_path):
    data = []
    for subfolder in os.listdir(data_path):
        subfolder_path = os.path.join(data_path, subfolder)
        if os.path.isdir(subfolder_path):
            for file in os.listdir(subfolder_path):
                if file.endswith(".xml"):
                    file_path = os.path.join(subfolder_path, file)
                    tree = ET.parse(file_path)
                    root = tree.getroot()
                    
                    
                    focus_annotations = root.find('FocusAnnotations/UMLS/CUIs/CUI').text if root.find('FocusAnnotations/UMLS/CUIs/CUI') is not None else None
                    semantic_types = [st.text for st in root.findall('FocusAnnotations/UMLS/SemanticTypes/SemanticType')] if root.find('FocusAnnotations/UMLS/SemanticTypes/SemanticType') is not None else []
                    semantic_group = root.find('FocusAnnotations/UMLS/SemanticGroup').text if root.find('FocusAnnotations/UMLS/SemanticGroup') is not None else None
                    
                    synonyms = [syn.text for syn in root.findall('FocusAnnotations/UMLS/Synonyms/Synonym')] if root.find('FocusAnnotations/UMLS/Synonyms/Synonym') is not None else []
                    
                    for qa_pair in root.findall('QAPairs/QAPair'):
                        pid = qa_pair.get('pid')
                        question = qa_pair.find('Question').text
                        question_type = qa_pair.find('Question').get('qtype')
                        answer = qa_pair.find('Answer').text
                        
                        data.append({
                            
                            'semantic_group': semantic_group,
                            'synonyms': synonyms,
                            'question': question,
                            'question_type': question_type,
                            'answer': answer
                        })
    return data

# Path to the dataset
data_path = "C:\\Users\\Moshe\\Desktop\\ML_Chatbot\\MedQuAD"

# Load the dataset
dataset = load_medquad_data(data_path)

# Extract questions and answers
questions = [item['question'] for item in dataset]
answers = [item['answer'] for item in dataset]

# Print a sample of the data to verify
for sample in dataset[:1]:
    print(json.dumps(sample, indent=4))


{
    "semantic_group": "Disorders",
    "synonyms": [],
    "question": "What is (are) A guide to clinical trials for cancer ?",
    "question_type": "information",
    "answer": null
}


In [2]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Vectorize the questions
vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(questions)

def get_retrieval_response(query):
    query_vec = vectorizer.transform([query])
    similarities = cosine_similarity(query_vec, X).flatten()
    best_match_index = similarities.argmax()
    return questions[best_match_index], answers[best_match_index]


In [3]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments

# Load pre-trained model and tokenizer
model_name = 'gpt2'
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Prepare dataset for training
def save_training_data(questions, answers, file_path):
    with open(file_path, 'w',encoding='utf-8') as f:
        for q, a in zip(questions, answers):
            f.write(f"<|startoftext|>{q}<|sep|>{a}<|endoftext|>\n")

train_file_path = "medquad_train.txt"
save_training_data(questions, answers, train_file_path)

# Load dataset
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path=train_file_path,
    block_size=64
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=16,
    save_steps=10_000,
    save_total_limit=2,
)

# Train the model
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

trainer.train()




  0%|          | 0/6189 [00:00<?, ?it/s]

{'loss': 2.1424, 'grad_norm': 3.041133165359497, 'learning_rate': 4.5960575214089515e-05, 'epoch': 0.08}
{'loss': 1.8472, 'grad_norm': 2.7776498794555664, 'learning_rate': 4.192115042817903e-05, 'epoch': 0.16}
{'loss': 1.7807, 'grad_norm': 2.1909759044647217, 'learning_rate': 3.788172564226855e-05, 'epoch': 0.24}
{'loss': 1.6989, 'grad_norm': 1.922439455986023, 'learning_rate': 3.3842300856358054e-05, 'epoch': 0.32}
{'loss': 1.668, 'grad_norm': 1.940169334411621, 'learning_rate': 2.980287607044757e-05, 'epoch': 0.4}
{'loss': 1.6517, 'grad_norm': 1.9824458360671997, 'learning_rate': 2.5763451284537084e-05, 'epoch': 0.48}
{'loss': 1.6293, 'grad_norm': 1.8078033924102783, 'learning_rate': 2.1724026498626597e-05, 'epoch': 0.57}
{'loss': 1.6076, 'grad_norm': 2.0845260620117188, 'learning_rate': 1.768460171271611e-05, 'epoch': 0.65}
{'loss': 1.6032, 'grad_norm': 1.8674557209014893, 'learning_rate': 1.3645176926805623e-05, 'epoch': 0.73}
{'loss': 1.5741, 'grad_norm': 1.8647161722183228, 'lear

TrainOutput(global_step=6189, training_loss=1.6941662096672903, metrics={'train_runtime': 52892.9103, 'train_samples_per_second': 1.872, 'train_steps_per_second': 0.117, 'total_flos': 3234240110592000.0, 'train_loss': 1.6941662096672903, 'epoch': 1.0})

In [5]:
def get_retrieval_response(query):
    query_vec = vectorizer.transform([query])
    similarities = cosine_similarity(query_vec, X).flatten()
    best_match_index = similarities.argmax()
    
    question = questions[best_match_index] if best_match_index < len(questions) else None
    answer = answers[best_match_index] if best_match_index < len(answers) else None
    
    if question is None or answer is None:
        print(f"Warning: Retrieved question or answer is None. Index: {best_match_index}")
        return "I couldn't find a specific answer to that question.", "I'm sorry, but I don't have enough information to provide a reliable answer."
    
    return question, answer

def get_retrieval_confidence(query):
    # Compute similarity/confidence score for retrieval-based response
    query_vec = vectorizer.transform([query])
    similarities = cosine_similarity(query_vec, X).flatten()
    return similarities.max()

def get_response(query):
    retrieval_confidence = get_retrieval_confidence(query)
    print(f"Retrieval confidence: {retrieval_confidence}")
    if retrieval_confidence > 0.5:
        question, response = get_retrieval_response(query)
        print(f"Retrieved question: {question}")
        print(f"Retrieved response: {response}")
    else:
        response = generate_response(query)
        print(f"Generated response: {response}")
    return response

def generate_response(query):
    try:
        inputs = tokenizer.encode(f"<|startoftext|>{query}<|sep|>", return_tensors='pt')
        outputs = model.generate(inputs, max_length=100, num_return_sequences=1, no_repeat_ngram_size=2)
        decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = decoded_output.split("<|sep|>")[-1].strip()
        return response if response else "I'm sorry, I couldn't generate a response."
    except Exception as e:
        print(f"Error generating response: {e}")
        return "I'm sorry, there was an error generating a response."

# Example usage
query = "What is Abetalipoproteinemia?"
response = get_response(query)
print(response)


Retrieval confidence: 0.9844728203897749
Retrieved question: What is (are) Abetalipoproteinemia ?
Retrieved response: Abetalipoproteinemia is a condition characterized by the inability to fully absorb dietary fats, cholesterol and fat-soluble vitamins. Signs and symptoms appear in the first few months of life and can include failure to thrive; diarrhea; acanthocytosis; and stool abnormalities. Other features develop later in childhood and often impair the function of the nervous system, potentially causing slower intellectual development; poor muscle coordination; progressive ataxia; and an eye disorder called retinitis pigmentosa. Most of the symptoms are due to defects in the absorption and transport of vitamin E. Abetalipoproteinemia is caused by mutations in the MTTP gene and is inherited in an autosomal recessive manner. Early diagnosis, high-dose vitamin E therapy, and medium-chain fatty acid supplements may slow the progression of the nervous system abnormalities. Long-term outl

In [7]:
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')

lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))

def preprocess_query(query):
    tokens = nltk.word_tokenize(query.lower())
    lemmatized = [lemmatizer.lemmatize(token) for token in tokens if token not in stop_words]
    return ' '.join(lemmatized)

def get_response(query):
    preprocessed_query = preprocess_query(query)
    retrieval_confidence = get_retrieval_confidence(preprocessed_query)
    print(f"Retrieval confidence: {retrieval_confidence}")

    if retrieval_confidence > 0.8:
        question, response = get_retrieval_response(preprocessed_query)
        print(f"Retrieved question: {question}")
        print(f"Retrieved response: {response}")
        return response
    elif retrieval_confidence > 0.5:
        question, retrieved_response = get_retrieval_response(preprocessed_query)
        generated_response = generate_response(query, context=retrieved_response)
        print(f"Combined response: {generated_response}")
        return generated_response
    else:
        clarification = ask_for_clarification(query)
        if clarification:
            return get_response(query + " " + clarification)
        else:
            generated_response = generate_response(query)
            print(f"Generated response: {generated_response}")
            return generated_response

def generate_response(query, context=None):
    if context:
        prompt = f"Context: {context}\nQuestion: {query}\nAnswer:"
    else:
        prompt = f"Question: {query}\nAnswer:"
    
    inputs = tokenizer.encode(prompt, return_tensors='pt')
    outputs = model.generate(inputs, max_length=200, num_return_sequences=1, no_repeat_ngram_size=2)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response.split("Answer:")[-1].strip()

def ask_for_clarification(query):
    print(f"I'm not sure I fully understand. Can you provide more context about '{query}'?")
    return input("Your clarification (or press Enter to skip): ")

query = "symptoms of hypertension?"
response = get_response(query)
print(response)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Moshe\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Moshe\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\Moshe\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


Retrieval confidence: 0.4915027989825902
I'm not sure I fully understand. Can you provide more context about 'symptoms of hypertension?'?


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Retrieval confidence: 0.590298533520142
Combined response: Hypertensive symptoms are the most common symptom of a hypertrophic cardiomyopathy. The symptoms include
  
-  -  blood pressure - high blood sugar  (low blood glucose)  and high levels of blood cholesterol  These symptoms can be caused by
 - - a heart disease or a stroke  or by a disease
 or stroke
or by an heart attack or an injury
 and a blood clot
These symptoms may be related to a condition called hypertensive hypertension.  The most commonly reported symptoms for hypertrophosphate deficiency are
A heart condition that causes high or low blood flow to the heart. This condition is called ahystolic heart failure. It causes blood to flow in the arteries that carry blood from the lungs to your heart
 a high level of
Hypertensive symptoms are the most common symptom of a hypertrophic cardiomyopathy. The symptoms include
  
-  -  blood pressure - high blood sugar  (low blood glucose)  and high levels of blood cholesterol  These 

In [None]:
import joblib

# Save the vectorizer and the retrieval data
joblib.dump(vectorizer, 'vectorizer.pkl')
joblib.dump((questions, answers), 'qa_data.pkl')


: 

In [11]:
torch.save(model, 'generative_model.pth')
