# Fine-Tuning a Large Language Model

### 1. Introduction

This project fine-tunes a Large Language Model (LLM) on data scraped from the AskDocs subreddit so that it responds to medical queries the way a doctor would. First, posts and comments are pulled from AskDocs. The pulled data are then cleaned and standardized for fine-tuning. Next, a few candidate LLMs are benchmarked using Ollama, and the most suitable one is selected for fine-tuning. Once fine-tuned, the model is tested and benchmarked again, and its metrics are compared with those of the original model to evaluate its performance.


### 2. Project Initialization

This project uses a number of Python libraries, which are imported below:


In [2]:
import numpy as np
import pandas as pd
import praw
import html
import ast
import os
import re
import ftfy
from markdown_it import MarkdownIt
from bs4 import BeautifulSoup
from textblob import TextBlob


In [7]:
user_agent = "Doc"
posts_data = []


With the help of the PRAW library (the Python Reddit API Wrapper), posts and comments from the AskDocs subreddit are scraped.

In [8]:
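# Read the API credentials from environment variables rather than hard-coding
# them in the notebook. The variable names below are an assumed convention.
rd = praw.Reddit(
    client_id = os.environ["REDDIT_CLIENT_ID"],
    client_secret = os.environ["REDDIT_CLIENT_SECRET"],
    user_agent = user_agent
)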
subreddit = rd.subreddit('AskDocs')

The `scrape_data` function automates data collection from the subreddit. The caller chooses how posts are sorted ("hot" or "top"), the maximum number of posts to process, and how many top comments to keep per post. For each post, it records the ID, title, text, author, score, number of comments, creation time, and flair. It then collects the highest-scoring comments, skipping deleted accounts as well as moderator and automated accounts (AutoModerator and AskDocs-ModTeam), and records each comment's ID, body, author, score, parent ID, creation time, author flair, and whether the commenter is the original poster. Progress is checkpointed to a CSV file every 100 posts to limit data loss, and the full result is saved to a CSV file named after the sort type. If an exception occurs, the data collected so far is saved before the error is re-raised.

In [53]:
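def scrape_data(sort_type='hot', limit=1000, top_comments=10):
    # Start from an empty list on every call so that a "top" run and a "hot" run
    # do not accumulate into the same results.
    posts_data = []
    try: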
        if sort_type not in ['hot', 'top']:
            raise ValueError("sort_type must be 'hot' or 'top'")

        posts_sort = subreddit.top(limit=limit) if sort_type == 'top' else subreddit.hot(limit=limit)

        for index, post in enumerate(posts_sort):
            if index % 10 == 0:
                print(f"Accessed {index} posts")

            post_data = {
                'id': post.id,
                'title': post.title,
                'selftext': post.selftext,
                'author': post.author.name if post.author else None,
                'score': post.score,
                'num_comments': post.num_comments,
                'created_utc': post.created_utc,
                'flair': post.link_flair_text
            }

            comments_data = []

            post.comments.replace_more(limit=0)
            comments = sorted([comment for comment in post.comments.list() if comment.author and not comment.author.name.startswith(('AutoModerator', 'AskDocs-ModTeam'))], key=lambda x: x.score, reverse=True)[:top_comments]

            for comment in comments:
                comment_data = {
                    'id': comment.id,
                    'body': comment.body,
                    'author': comment.author.name if comment.author else None,
                    'score': comment.score,
                    'parent_id': comment.parent_id,
                    'created_utc': comment.created_utc,
                    'is_submitter': comment.is_submitter,
                    'author_flair': comment.author_flair_text if comment.author_flair_text else None
                }
                comments_data.append(comment_data)

            post_data['comments'] = comments_data
            posts_data.append(post_data)

            if index % 100 == 0 and index != 0:
                initial_df = pd.DataFrame(posts_data)
                initial_df.to_csv(f'reddit_data_{index}.csv', index=False)
                print(f"Data saved at {index} posts")

    except Exception as e:
        print(f"An error occurred: {e}")
        initial_df = pd.DataFrame(posts_data)
        initial_df.to_csv('reddit_data_error_save.csv', index=False)
        raise

    posts_df = pd.DataFrame(posts_data)
    posts_df.to_csv(f'ScrapeAskDoc_{sort_type}data.csv', index=False)
    print("Successfully scraped the data")
    return posts_df
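The comment-selection step inside `scrape_data` can be looked at in isolation. The following is a minimal sketch, assuming plain dictionaries in place of PRAW `Comment` objects; the sample comments and the helper name `select_top_comments` are hypothetical and only illustrate the filter-then-sort logic.

```python
# Toy comments standing in for PRAW Comment objects (hypothetical data).
comments = [
    {"author": "AutoModerator", "body": "Thank you for your submission.", "score": 1},
    {"author": "doc_a", "body": "See a neurologist.", "score": 42},
    {"author": None, "body": "[deleted]", "score": 99},
    {"author": "user_b", "body": "Same thing happened to me.", "score": 7},
    {"author": "doc_c", "body": "This needs urgent review.", "score": 120},
]

EXCLUDED_AUTHORS = ("AutoModerator", "AskDocs-ModTeam")


def select_top_comments(comments, top_n=10):
    """Drop deleted-account and moderator comments, then keep the top_n by score."""
    kept = [
        c for c in comments
        if c["author"] and not c["author"].startswith(EXCLUDED_AUTHORS)
    ]
    return sorted(kept, key=lambda c: c["score"], reverse=True)[:top_n]


top_two = select_top_comments(comments, top_n=2)
print([c["author"] for c in top_two])  # ['doc_c', 'doc_a']
```

Filtering before sorting matters here: a high-scoring comment from a deleted account would otherwise take up one of the `top_n` slots.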

Scraping the top posts

In [54]:
top_posts = scrape_data("top")


Accessed 0 posts
Accessed 10 posts
Accessed 20 posts
Accessed 30 posts
Accessed 40 posts
Accessed 50 posts
Accessed 60 posts
Accessed 70 posts
Accessed 80 posts
Accessed 90 posts
Accessed 100 posts
Data saved at 100 posts
Accessed 110 posts
Accessed 120 posts
Accessed 130 posts
Accessed 140 posts
Accessed 150 posts
Accessed 160 posts
Accessed 170 posts
Accessed 180 posts
Accessed 190 posts
Accessed 200 posts
Data saved at 200 posts
Accessed 210 posts
Accessed 220 posts
Accessed 230 posts
Accessed 240 posts
Accessed 250 posts
Accessed 260 posts
Accessed 270 posts
Accessed 280 posts
Accessed 290 posts
Accessed 300 posts
Data saved at 300 posts
Accessed 310 posts
Accessed 320 posts
Accessed 330 posts
Accessed 340 posts
Accessed 350 posts
Accessed 360 posts
Accessed 370 posts
Accessed 380 posts
Accessed 390 posts
Accessed 400 posts
Data saved at 400 posts
Accessed 410 posts
Accessed 420 posts
Accessed 430 posts
Accessed 440 posts
Accessed 450 posts
Accessed 460 posts
Accessed 470 posts
Acc

Scraping the hot posts

In [57]:
hot_posts = scrape_data("hot")

Accessed 0 posts
Accessed 10 posts
Accessed 20 posts
Accessed 30 posts
Accessed 40 posts
Accessed 50 posts
Accessed 60 posts
Accessed 70 posts
Accessed 80 posts
Accessed 90 posts
Accessed 100 posts
Data saved at 100 posts
Accessed 110 posts
Accessed 120 posts
Accessed 130 posts
Accessed 140 posts
Accessed 150 posts
Accessed 160 posts
Accessed 170 posts
Accessed 180 posts
Accessed 190 posts
Accessed 200 posts
Data saved at 200 posts
Accessed 210 posts
Accessed 220 posts
Accessed 230 posts
Accessed 240 posts
Accessed 250 posts
Accessed 260 posts
Accessed 270 posts
Accessed 280 posts
Accessed 290 posts
Accessed 300 posts
Data saved at 300 posts
Accessed 310 posts
Accessed 320 posts
Accessed 330 posts
Accessed 340 posts
Accessed 350 posts
Accessed 360 posts
Accessed 370 posts
Accessed 380 posts
Accessed 390 posts
Accessed 400 posts
Data saved at 400 posts
Accessed 410 posts
Accessed 420 posts
Accessed 430 posts
Accessed 440 posts
Accessed 450 posts
Accessed 460 posts
Accessed 470 posts
Acc

**Rejecting the top data**:

Originally, the idea was to use a merged version of the top data and the hot data.

After careful consideration, the top-data CSV file was rejected. Most of the top posts in the subreddit turned out to be off topic: for example, patients saying goodbye to the community after a long battle with a terminal disease, appreciation posts, and meta posts about the subreddit and its community.

Since the hot data is more likely to contain recent, urgent medical enquiries, it is used instead of the top data.

In [3]:
data1 = pd.read_csv('ScrapeAskDoc_topdata.csv')
# data1_copy = data1.copy()
data2 = pd.read_csv('ScrapeAskDoc_hotdata.csv')  # hot data, so the merge combines top and hot rather than top twice
# data2_copy = data2.copy()
data3 = pd.read_csv('AskDoc_topdata.csv')
# data3_copy = data3.copy()

#merged_data = pd.concat([data1, data2, data3], ignore_index = True)
merged_data = pd.concat([data1, data2])
merged_data.reset_index(drop = True, inplace = True)


#final_data = merged_data.copy()


In [4]:
final_data = pd.read_csv('ScrapeAskDoc_hotdata.csv')
final_data.head()

Unnamed: 0,id,title,selftext,author,score,num_comments,created_utc,flair,comments
0,1cfva9b,Weekly Discussion/General Questions Thread - A...,**This is a weekly general discussion and gene...,AutoModerator,1,71,1714385000.0,,"[{'id': 'l1tov1a', 'body': 'say you’re there t..."
1,1ck71u1,My nonverbal son was given 1000 mg ketamine pr...,My (21 m) son who is 6ft and weighs approximat...,Odd-Magician-3397,176,28,1714848000.0,Physician Responded,"[{'id': 'l2l0syr', 'body': ""Was this all at on..."
2,1ckgzxa,Why do I get sick when I don't drink milk?,"Caucasian male 22, height 5'11, 145lb/65.7kg \...",Creative-Yak-8287,42,13,1714876000.0,Physician Responded,"[{'id': 'l2n7h1x', 'body': 'There is unlikely ..."
3,1ckg0mk,"Elderly 90+ dad CHF, organ failure - is this c...",EDIT: you all are so sweet and honest and than...,seaw33dthrowaway,36,35,1714873000.0,Physician Responded,"[{'id': 'l2mpqx8', 'body': 'Have you spoken at..."
4,1cjxy7q,My friend [36M] is sending 30k to a girl [21F]...,Exam: MRI BRAIN/PITUITARY W/WO\nDate of Exam: ...,ProbablyShouldStop,278,46,1714821000.0,Physician Responded,"[{'id': 'l2j83zd', 'body': 'She sent him a 3.5..."


## Data Preprocessing
In this phase, the scraped Reddit data is prepared for fine-tuning in several steps. Null entries are dropped to preserve dataset integrity. Duplicates are removed to avoid biasing the model and to reduce the risk of overfitting. Comments from moderators and automated accounts, which would skew the data's natural language patterns, are excluded. Because 'hot' posts are unpredictable in content, additional cleaning is performed, including the removal of special characters, HTML entities, and encoded characters. The goal is to make the training data as clean as possible.
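The null and duplicate handling described above can be sketched without the scraped CSV. This is a minimal illustration in plain Python, assuming hypothetical rows and a hypothetical helper `drop_nulls_and_duplicates`; in the notebook itself the same effect comes from pandas `dropna()` and `drop_duplicates(subset='id')`.

```python
# Hypothetical rows shaped like the scraped posts.
rows = [
    {"id": "a1", "title": "Headache for a week", "selftext": "34F, no meds."},
    {"id": "a1", "title": "Headache for a week", "selftext": "34F, no meds."},  # duplicate id
    {"id": "b2", "title": "Rash on arm", "selftext": None},                     # null body
    {"id": "c3", "title": "Knee pain", "selftext": "52M, runner."},
]


def drop_nulls_and_duplicates(rows, key="id"):
    """Keep the first row per key and skip any row containing a None value."""
    seen = set()
    cleaned = []
    for row in rows:
        if any(value is None for value in row.values()):
            continue  # comparable to DataFrame.dropna()
        if row[key] in seen:
            continue  # comparable to DataFrame.drop_duplicates(subset=key)
        seen.add(row[key])
        cleaned.append(row)
    return cleaned


cleaned = drop_nulls_and_duplicates(rows)
print([r["id"] for r in cleaned])  # ['a1', 'c3']
```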



In [35]:
#Length of data before cleaning
print(f"Initial number of rows for the scraped data: {len(final_data)}")

Initial number of rows for the scraped data: 500


In [36]:
mdi = MarkdownIt()

Since Reddit is a social forum, the language used there is often informal, and many posts and comments make heavy use of emojis and special characters. These need to be dealt with before fine-tuning the model. The `clean_text` function does this by stripping emojis and a range of special characters.

The posts and comments also contain hyperlinks, Markdown, and HTML entities used for formatting. These are removed using markdown-it, BeautifulSoup, and Python's `html` module. The text is then normalized by converting it to lowercase, removing URLs, and collapsing extra whitespace. The cleaned data is then exported as a backup.

In [27]:
def clean_text(text, author):
    # Skip processing for bot authors or deleted/removed content
    if author in ('AskDocs-ModTeam', 'AutoModerator') or text in ('[removed]', '[deleted]'):
        return None 

    emojis = re.compile(
        "["
        u"\U0001F600-\U0001F64F"  # emoticons
        u"\U0001F300-\U0001F5FF"  # symbols & pictographs
        u"\U0001F680-\U0001F6FF"  # transport & map symbols
        u"\U0001F700-\U0001F77F"  # alchemical symbols
        u"\U0001F780-\U0001F7FF"  # Geometric Shapes Extended
        u"\U0001F800-\U0001F8FF"  # Supplemental Arrows-C
        u"\U0001F900-\U0001F9FF"  # Supplemental Symbols and Pictographs
        u"\U0001FA00-\U0001FA6F"  # Chess Symbols
        u"\U0001FA70-\U0001FAFF"  # Symbols and Pictographs Extended-A
        u"\U00002702-\U000027B0"  # Dingbats
        u"\U000024C2-\U0001F251" 
        "]+", flags=re.UNICODE
    )
    
    # Continue processing if not by a bot or deleted/removed
    text = mdi.render(text)  # Markdown -> HTML
    text = BeautifulSoup(text, 'html.parser').get_text()  # Strip HTML tags
    text = html.unescape(text)  # Decode HTML entities such as &amp;
    text = text.lower()  # Normalize case
    text = re.sub(r'http\S+', '', text)  # Remove URLs
    #text = re.sub(r'[^\w\s]', '', text)  # Remove punctuation
    text = emojis.sub(r'', text)  # Strip emojis
    text = re.sub(r'\s+', ' ', text).strip()  # Collapse whitespace last so removed emojis leave no gaps
    
    return text

final_data['title'] = final_data.apply(lambda row: clean_text(row['title'], row['author']), axis=1)
final_data['selftext'] = final_data.apply(lambda row: clean_text(row['selftext'], row['author']), axis=1)


In [24]:
final_data['comments']
final_data.to_csv('check.csv')

In [19]:
# Earlier version: keeps only the single highest-scoring comment per post
def clean_comments(comments, author):
    comments_list = ast.literal_eval(comments)

    # Find the top comment based on the maximum score
    top_comment = max(comments_list, key=lambda x: x['score']) if comments_list else None

    # Filter out comments from AutoModerator, AskDocs-ModTeam, and removed/deleted comments
    if top_comment and top_comment['author'] not in ('AskDocs-ModTeam', 'AutoModerator') and top_comment['body'] not in ('[removed]', '[deleted]'):
        cleaned_text = clean_text(top_comment['body'], top_comment['author'])
        if cleaned_text is not None:
            return [cleaned_text]

    return ''

final_data['comments'] = final_data.apply(lambda row: clean_comments(row['comments'], row['author']), axis=1)

In [37]:
def clean_text(text, author):
    if author in ('AskDocs-ModTeam', 'AutoModerator') or text in ('[removed]', '[deleted]'):
        return None

    emojis = re.compile(
        "["
        u"\U0001F600-\U0001F64F"  # emoticons
        u"\U0001F300-\U0001F5FF"  # symbols & pictographs
        u"\U0001F680-\U0001F6FF"  # transport & map symbols
        u"\U0001F700-\U0001F77F"  # alchemical symbols
        u"\U0001F780-\U0001F7FF"  # Geometric Shapes Extended
        u"\U0001F800-\U0001F8FF"  # Supplemental Arrows-C
        u"\U0001F900-\U0001F9FF"  # Supplemental Symbols and Pictographs
        u"\U0001FA00-\U0001FA6F"  # Chess Symbols
        u"\U0001FA70-\U0001FAFF"  # Symbols and Pictographs Extended-A
        u"\U00002702-\U000027B0"  # Dingbats
        u"\U000024C2-\U0001F251"
        "]+", flags=re.UNICODE
    )

    # Decode HTML entities, normalize case, remove URLs and emojis,
    # then collapse whitespace last so removed tokens leave no double spaces
    text = html.unescape(text)
    text = text.lower()
    text = re.sub(r'http\S+', '', text)
    text = emojis.sub(r'', text)
    text = re.sub(r'\s+', ' ', text).strip()
    
    return text
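To make the normalization steps concrete, here is a small worked example using only the standard library. It is a sketch that leaves out the emoji pattern and the bot/deleted checks of `clean_text`; the helper name `normalize` and the sample string are hypothetical.

```python
import html
import re


def normalize(text):
    """Apply the entity, case, URL and whitespace steps of clean_text."""
    text = html.unescape(text)                # "&amp;" -> "&"
    text = text.lower()                       # normalize case
    text = re.sub(r"http\S+", "", text)       # drop URLs
    text = re.sub(r"\s+", " ", text).strip()  # collapse runs of whitespace
    return text


sample = "My BP is 140/90 &amp; rising.\n\nSee https://example.com/chart   for details"
result = normalize(sample)
print(result)  # my bp is 140/90 & rising. see for details
```

Note that punctuation and digits survive these steps; only entities, case, URLs, and spacing are changed.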


In [38]:
def clean_comments(comments, author):
    comments_list = ast.literal_eval(comments)
    cleaned_comments = []

    # Process each comment
    for comment in comments_list:
        if comment['author'] not in ('AskDocs-ModTeam', 'AutoModerator') and comment['body'] not in ('[removed]', '[deleted]'):
            cleaned_text = clean_text(comment['body'], comment['author'])
            if cleaned_text is not None:
                cleaned_comments.append(cleaned_text)

    return cleaned_comments

# Applying the function to the DataFrame
final_data['comments'] = final_data.apply(lambda row: clean_comments(row['comments'], row['author']), axis=1)
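The call to `ast.literal_eval` is needed because a list of dictionaries does not survive a CSV round trip as a list: it is written out as its string representation and read back as plain text. The sketch below shows this with the standard `csv` module and made-up values; it assumes the column was written as the `str()` of the list, which is what the stored `comments` cells look like.

```python
import ast
import csv
import io

# Hypothetical comment list for one post.
comments = [{"id": "c1", "body": "Please see your GP.", "score": 12}]

# Write one row, storing the list as its string representation.
buffer = io.StringIO()
writer = csv.writer(buffer)
writer.writerow(["id", "comments"])
writer.writerow(["p1", str(comments)])

# Read it back: the comments cell is now a string, not a list.
buffer.seek(0)
row = list(csv.DictReader(buffer))[0]
print(type(row["comments"]).__name__)  # str

# ast.literal_eval rebuilds the Python literal without executing arbitrary code.
parsed = ast.literal_eval(row["comments"])
print(parsed[0]["body"])  # Please see your GP.
```

`ast.literal_eval` only accepts Python literals, which makes it a safer choice than `eval` for text that originated on a public forum.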

In [39]:
new_data = final_data.copy()
new_data = new_data.dropna()
#new_data = new_data.drop_duplicates()
# final_data = final_data[final_data['comments'].str.len() > 5]
print(f"Number of rows after dropping null values: {len(new_data)}")

Number of rows after dropping null values: 121


In [42]:
new_data.to_csv('LatestCleaned.csv', encoding = 'utf-8', index = False)
new_data

Unnamed: 0,id,title,selftext,author,score,num_comments,created_utc,flair,comments
1,1ck71u1,My nonverbal son was given 1000 mg ketamine pr...,My (21 m) son who is 6ft and weighs approximat...,Odd-Magician-3397,176,28,1.714848e+09,Physician Responded,"[was this all at once, or as two separate inje..."
2,1ckgzxa,Why do I get sick when I don't drink milk?,"Caucasian male 22, height 5'11, 145lb/65.7kg \...",Creative-Yak-8287,42,13,1.714876e+09,Physician Responded,[there is unlikely to be anything physically w...
3,1ckg0mk,"Elderly 90+ dad CHF, organ failure - is this c...",EDIT: you all are so sweet and honest and than...,seaw33dthrowaway,36,35,1.714873e+09,Physician Responded,[have you spoken at all with the doctors about...
4,1cjxy7q,My friend [36M] is sending 30k to a girl [21F]...,Exam: MRI BRAIN/PITUITARY W/WO\nDate of Exam: ...,ProbablyShouldStop,278,46,1.714821e+09,Physician Responded,[she sent him a 3.5yr old mri? i have doubts a...
6,1ckb874,"12 month old unresponsive, stumped hospital. A...",12 months old\n\nFemale\n\nNo medications\n\nL...,Babyd2hardrn,26,10,1.714859e+09,Physician Responded,[all of her bloods are relatively normal (even...
...,...,...,...,...,...,...,...,...,...
471,1cjztlt,Is it bad that I drink to much water?,I work 12 hour shifts at my job. When I’m at w...,PossiblyAburd,1,2,1.714828e+09,Physician Responded,[sounds pretty normal. i also work 12 hour shi...
493,1cisrya,Slurred speech continued in 4 year old,4M. 52 pounds. \n\nI posted the other day abou...,lolly1997,905,312,1.714689e+09,Physician Responded,[did they do a lumbar puncture (a “spinal tap”...
496,1cjy1o5,Penis atrophy/shrinkage question,\n\n20 year old male\n\nSo since my prostate p...,Jankis2000,1,5,1.714822e+09,Physician Responded,[fascinating! there is this strange phenomenon...
498,1cjxnku,What are the chances my baby has birth defects,28F 5’4” 180lbs 4ish weeks pregnant\n\nI had a...,because-throw,1,2,1.714820e+09,Physician Responded,"[if you're only about 4 weeks pregnant now, th..."


In [1]:
# new_data = new_data.drop_duplicates(subset = 'id')
# # new_data.head(26)
# for x in new_data['comments']:
#     print (x)

In [45]:
new_data.to_csv('CleanLatest.csv', encoding = 'utf-8', index = False)

In [44]:
len(new_data)

121

In [34]:
# new_data = pd.read_csv('CleanData.csv')
# new_data.head(3)
for x in new_data['comments']:
    print(x[0])
    break

sorry this is happening you dont say your moms age but its unusual to develop schizophrenia later in life there are many other things that can cause auditory and visual hallucinations that arent schizophrenia some of which are more likely than schizophrenia depending on her age although having a sibling with it does increase her risk alcohol use disorder can predispose someone to certain types of brain dysfunction and dementia that could look like this too its also worth noting that a ct scan cannot rule in or out schizophrenia its hard to speculate further based on the information you have


In [35]:
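from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize

# Requires the NLTK tokenizer and WordNet resources (download once via nltk.download)

def standardize_text(text):
    # Tokenize text
    tokens = word_tokenize(text)
    
    # Initialize the WordNet Lemmatizer
    lemmatizer = WordNetLemmatizer()

    # Lemmatize alphabetic tokens and drop non-alphabetic ones (numbers, punctuation)
    lemmatized_tokens = [
        lemmatizer.lemmatize(token) for token in tokens if token.isalpha()
    ]
    
    # Re-join lemmatized tokens into a single string
    return ' '.join(lemmatized_tokens)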
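One consequence of the `token.isalpha()` filter is that ages, doses, and measurements are dropped from the text. The sketch below isolates that filter; it assumes `str.split()` as a stand-in for `word_tokenize` so that it runs without NLTK data, and the helper name and sample sentence are hypothetical.

```python
def keep_alpha_tokens(text):
    """Keep only purely alphabetic tokens, as the isalpha() filter does."""
    # str.split() is a stand-in tokenizer so the sketch needs no NLTK resources.
    return " ".join(tok for tok in text.split() if tok.isalpha())


sample = "caucasian male 22 height 5'11 145lb took 500 mg ibuprofen"
filtered = keep_alpha_tokens(sample)
print(filtered)  # caucasian male height took mg ibuprofen
```

Whether losing these numbers is acceptable depends on how much the fine-tuned model is expected to reason about ages and dosages.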


In [36]:
stand_data = new_data.copy()

In [37]:
stand_data['title'] = stand_data['title'].apply(standardize_text)

In [38]:
stand_data['selftext'] = stand_data['selftext'].apply(standardize_text)

In [39]:
stand_data['comments'] = stand_data['comments'].apply(lambda x: [standardize_text(comment) for comment in x])

In [40]:
for x in new_data['comments']:
    print(x[0])
    break

sorry this is happening you dont say your moms age but its unusual to develop schizophrenia later in life there are many other things that can cause auditory and visual hallucinations that arent schizophrenia some of which are more likely than schizophrenia depending on her age although having a sibling with it does increase her risk alcohol use disorder can predispose someone to certain types of brain dysfunction and dementia that could look like this too its also worth noting that a ct scan cannot rule in or out schizophrenia its hard to speculate further based on the information you have


In [42]:
# final_data.drop(columns = ['author', 'num_comments', 'flair', 'id'])
# stand_data.drop(columns = ['author', 'num_comments', 'flair', 'id'], inplace = True)
# stand_data.to_csv('LatestStandardizedData.csv')
stand_data = pd.read_csv('LatestStandardizedData.csv')
# stand_data['comment'] = stand_data['comments'].apply(lambda x: x[0] if x else '')
# stand_data.head(2)

stand_data.head()

Unnamed: 0.1,Unnamed: 0,title,selftext,score,created_utc,comments
0,1,mother hospitalized with acute psychosis is it...,i wa informed that my mother and best friend w...,60,1713978000.0,['sorry this is happening you dont say your mo...
1,3,what are the actual chance of this happening,my son wa stillborn last month when i wa week ...,425,1713915000.0,['sorry for your loss my deepest sympathy true...
2,4,is it weird or unsafe to clean your butthole i...,age sex male height weight race hispanic durat...,228,1713921000.0,['short answer this is kinda strange but youre...
3,5,seeking adviceavenues to explore leg paralysis...,month ago i woke up with extreme leg pain to t...,9,1713980000.0,['usual disclaimer no one can provide specific...
4,7,ultrasound report,male lump on right testicle complain about inc...,5,1713994000.0,['hello my friend im wary of firing off anythi...


In [7]:
import logging
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
import pandas as pd
import ast

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def extract_disease_names_hf(text, max_length=512):
    # Load tokenizer and model from Hugging Face Hub
    tokenizer = AutoTokenizer.from_pretrained("Clinical-AI-Apollo/Medical-NER", max_length=max_length, truncation=True)
    model = AutoModelForTokenClassification.from_pretrained("Clinical-AI-Apollo/Medical-NER")
    
    # Create a pipeline for named entity recognition
    pipe = pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy='simple')
    
    # Split the text into smaller segments
    segments = [text[i:i+max_length] for i in range(0, len(text), max_length)]
    
    diseases = []
    for segment in segments:
        # Process each segment through the pipeline
        ner_results = pipe(segment)
        # Extract entities labeled as diseases (depending on the model's labeling scheme)
        segment_diseases = [result['word'] for result in ner_results if 'disease' in result['entity_group'].lower()]
        diseases.extend(segment_diseases)
    
    return diseases

class MedicalSpecialtyPipeline:
    def __init__(self, model_name="MoritzLaurer/deberta-v3-large-zeroshot-v2.0", max_length=512):
        logging.info("Initializing MedicalSpecialtyPipeline with model: %s", model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, max_length=max_length, truncation=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.medical_specialties = [
                    "Cardiology",
                    "Dermatology",
                    "Emergency Medicine",
                    "Endocrinology",
                    "Gastroenterology",
                    "General Surgery",
                    "Geriatrics",
                    "Gynecology",
                    "Hematology",
                    "Infectious Disease",
                    "Internal Medicine",
                    "Nephrology",
                    "Neurology",
                    "Obstetrics",
                    "Oncology",
                    "Ophthalmology",
                    "Orthopedics",
                    "Otolaryngology (ENT)",
                    "Pediatrics",
                    "Psychiatry",
                    "Pulmonology",
                    "Rheumatology",
                    "Urology",
                    "Others"
        ]
        self.advice_classifier = pipeline("zero-shot-classification", model=model_name)
        self.toxicity_classifier = pipeline("text-classification", model="unitary/toxic-bert")
        self.max_length = max_length

    def get_medical_specialty(self, text):
        logging.info("Getting medical specialty for text: %s", text[:50])
        
        # Extract disease names from the text
        diseases = extract_disease_names_hf(text, max_length=self.max_length)
        logging.info("Extracted diseases: %s", diseases)
        
        # Formulate the input for zero-shot classification
        hypothesis_template = "This medical case involves {}."
        
        # Classify medical specialty using zero-shot classification
        logging.info("Running zero-shot classification with hypothesis template")
        result = self.advice_classifier(text, self.medical_specialties, hypothesis_template=hypothesis_template)
        
        # Extract top result
        top_specialty = result['labels'][0]
        logging.info("Top medical specialty: %s", top_specialty)
        
        return top_specialty

    def is_doctor_advice(self, comment):
        logging.info("Checking if comment contains doctor advice: %s", comment[:50])
        # Adjust the hypothesis template to be more explicit and contextual
        hypothesis_template = "The statement '{}', is a piece of medical advice."
        candidate_labels = ["true", "false"]  # Using true/false to align with the hypothesis
        logging.info("Running zero-shot classifier on comment")
        # Adjust the call to pass the hypothesis template
        result = self.advice_classifier(hypothesis_template.format(comment), candidate_labels)
        logging.info("Classifier result: %s", result['labels'][0])
        return result['labels'][0] == 'true'

    def find_doctor_advice_comment(self, comments):
        logging.info("Finding doctor advice comment in comments")
        # Ensure comments are iterated correctly
        if isinstance(comments, str):
            comments = [comments]  # Single string to list
        elif isinstance(comments, list):
            pass  # Already in list form, do nothing
        else:
            logging.error("Unsupported comment format: %s", type(comments))
            return None

        for comment in comments:
            logging.info("Checking comment: %s", comment[:50])
            if self.is_doctor_advice(comment):
                logging.info("Found doctor advice comment: %s", comment[:50])
                return comment
        logging.info("No doctor advice comment found")
        return None
    
    def is_toxic(self, text):
        result = self.toxicity_classifier(text[:self.max_length])[0]
        return result['label'] == 'toxic' and result['score'] >= 0.7

    def process_dataframe(self, df):
        logging.info("Processing dataframe with %d rows", len(df))
        df = df.drop_duplicates(subset='id', keep='first')
        df = df[df['num_comments']< 15]
        df = df[['title','selftext','comments', 'flair']]
        df = df.reset_index()
        df.drop(['index'], inplace=True, axis=1)
        df['comments'] = df['comments'].apply(ast.literal_eval)
        df['medical_specialty'] = df['selftext'].apply(self.get_medical_specialty)
        df['doctor_advice_comment'] = df['comments'].apply(self.find_doctor_advice_comment)
        
        # Toxicity analysis
        df['title_non_toxic'] = ~df['title'].apply(self.is_toxic)
        df['selftext_non_toxic'] = ~df['selftext'].apply(self.is_toxic)
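        # Use `not` here: `~` on a plain Python bool is bitwise (~True == -2, ~False == -1), and both results are truthy
        df['comments_non_toxic'] = df['comments'].apply(lambda comments: all(not self.is_toxic(comment) for comment in comments))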
        
        # Drop rows with toxic content
        df = df[(df['title_non_toxic'] == True) & (df['selftext_non_toxic'] == True) & (df['comments_non_toxic'] == True)]
        
        return df

import pandas as pd
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

class MedicalSummaryPipeline:
    def __init__(self, summarization_model_name="Falconsai/medical_summarization", max_length=512):
        self.tokenizer = AutoTokenizer.from_pretrained(summarization_model_name, max_length=max_length, truncation=True)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_name)
        self.summarizer = pipeline("summarization", model=self.model, tokenizer=self.tokenizer, max_length=max_length)
        self.max_length = max_length

    def summarize_text(self, text, max_length=50, min_length=10, do_sample=False):
        # Split the text into smaller segments
        segments = [text[i:i+self.max_length] for i in range(0, len(text), self.max_length)]
        
        summaries = []
        for segment in segments:
            # Summarize each segment
            summary = self.summarizer(segment, max_length=max_length, min_length=min_length, do_sample=do_sample)[0]['summary_text']
            summaries.append(summary)
        
        # Combine the summaries
        combined_summary = ' '.join(summaries)
        
        return combined_summary

    def process_dataframe(self, df):
        df['selftext_summary'] = df['selftext'].apply(self.summarize_text)
        df['question'] = df['title'] + '. ' + df['selftext_summary']
        return df
    

class PromptInstructionDataset:
    def __init__(self, data_file):
        self.data_file = data_file
        self.df = pd.read_csv(data_file)

    def create_dataset(self):
        prompt_instruction_pairs = []

        for index, row in self.df.iterrows():
            question = row['question']
            doctor_advice_comment = row['doctor_advice_comment']
            medical_specialty = row['medical_specialty']

            prompt = f"Question: {question}\n\nBased on the above information, provide a general medical advice comment and suggest the most appropriate medical specialty."
            instruction = f"Specialty Suggestion: {medical_specialty}\nMedical Advice: {doctor_advice_comment}"

            prompt_instruction_pairs.append({"prompt": prompt, "instruction": instruction})

        prompt_instruction_df = pd.DataFrame(prompt_instruction_pairs)
        return prompt_instruction_df

    def save_dataset(self, output_file):
        prompt_instruction_df = self.create_dataset()
        prompt_instruction_df.to_csv(output_file, index=False)


# Example usage and testing
df = pd.read_csv('data.csv')
df = df[:5]
medical_pipeline = MedicalSpecialtyPipeline(max_length=512)
df = medical_pipeline.process_dataframe(df)
df.to_csv('pro.csv')

summary_pipeline = MedicalSummaryPipeline(max_length=512)
df = summary_pipeline.process_dataframe(df)
print(df['question'].head())
df = df[['question','doctor_advice_comment','medical_specialty']]

df.to_csv('sum.csv')

# prompt_dataset = PromptInstructionDataset('sum.csv')
# prompt_dataset.save_dataset('prompt_instruction_dataset.csv')

2024-05-06 22:38:30,025 - INFO - Initializing MedicalSpecialtyPipeline with model: MoritzLaurer/deberta-v3-large-zeroshot-v2.0
2024-05-06 22:38:37,156 - INFO - Processing dataframe with 5 rows
2024-05-06 22:38:37,167 - INFO - Getting medical specialty for text: Caucasian male 22, height 5'11, 145lb/65.7kg 

If 
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
2024-05-06 22:38:41,500 - INFO - Extracted diseases: []
2024-05-06 22:38:41,502 - INFO - Running zero-shot classification with hypothesis template
2024-05-06 22:39:25,879 - INFO - Top medical specialty: Endocrinology
2024-05-06 22:39:25,882 - INFO - Getting medical specialty for text: 12 months old

Female

No medications

Labs

Alkal
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
2024-05-06 22:39:29,503 - INFO - Extracted diseases: []
2024-05-06 2

0    Why do I get sick when I don't drink milk?. ca...
1    12 month old unresponsive, stumped hospital. A...
Name: question, dtype: object
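Both `extract_disease_names_hf` and `summarize_text` split long inputs with the same slicing expression before passing them to a model. A minimal sketch of that step follows; the helper name `split_into_segments` and the sample text are hypothetical.

```python
def split_into_segments(text, max_length=512):
    """Slice text into consecutive chunks of at most max_length characters."""
    return [text[i:i + max_length] for i in range(0, len(text), max_length)]


text = "chest pain " * 100  # 1100 characters
segments = split_into_segments(text, max_length=512)
print([len(s) for s in segments])  # [512, 512, 76]
```

The limit here counts characters, not model tokens, and a slice boundary can fall in the middle of a word or entity name, so this is best read as a rough guard against over-long inputs rather than exact token-level truncation.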


## Data analysis


## Benchmark

Three candidate models were selected: Phi-3, Llama 3, and Mistral. They are benchmarked below:


In [9]:
import requests

def generate_medical_assistant_response(user_query, model, api_key, endpoint):
    prompt = f"""
User Query: {user_query}
[You are a medical professional bot that behaves like a doctor. You only answer to medical-related queries. If you encounter off-topics, you should respond with, 'As an AI medical assistant, my expertise lies in providing assistance with medical queries. Unfortunately, I'm unable to address non-medical topics. If you have any medical-related questions or concerns, feel free to ask, and I'll do my best to assist you']

Medical Specialty Prediction:
Based on the provided query, the most relevant medical specialty is likely [Insert predicted medical specialty]. This specialty deals with [Briefly describe the focus of the predicted medical specialty].

Expert Response:
[Provide a detailed, expert-level response to the user's query based on the predicted medical specialty, including the following:
- Possible diagnosis or differential diagnoses
- Brief explanation of the suspected condition(s)
- Potential causes or risk factors
- Recommended diagnostic tests or procedures
- Treatment options and management strategies
- Prognosis and potential complications
- Guidance on when to seek further medical attention]

Lifestyle Recommendations:
[Insert specific advice on lifestyle modifications, such as diet, exercise, sleep, or stress management, that may help manage the condition or alleviate symptoms]

Next Steps:
[Insert recommendations for follow-up with a healthcare provider, including the type of specialist to consult and the timing of the follow-up]

Disclaimer:
This information is provided for educational purposes only and is not intended to replace professional medical advice, diagnosis, or treatment. Always seek the guidance of a qualified healthcare provider with any questions regarding a medical condition or treatment.
"""

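    headers = {
        "Content-Type": "application/json",
        # Not required by a local Ollama server; kept for proxied endpoints
        "Authorization": f"Bearer {api_key}"
    }

    # Ollama's /api/generate takes sampling settings under "options";
    # "num_predict" caps the number of generated tokens.
    data = {
        "model": model,
        "prompt": prompt,
        "options": {"num_predict": 500, "temperature": 0.7},
        "stream": False
    }

    response = requests.post(endpoint, headers=headers, json=data)

    # Return the parsed JSON on success, otherwise an error string
    if response.status_code == 200:
        return response.json()
    else:
        return f"Error: {response.status_code}"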

# Example usage
models = ["mistral", "llama2"]
api_key = "YOUR_API_KEY"
endpoint = "http://localhost:11434/api/generate"

user_query = "I have been experiencing persistent headaches, sensitivity to light, and occasional nausea for the past few weeks. What could be causing these symptoms, and what should I do?"

for model in models:
    result = generate_medical_assistant_response(user_query, model, api_key, endpoint)
    print(f"Results from {model}:", result)


Results from mistral: {'model': 'mistral', 'created_at': '2024-05-06T23:32:08.5412682Z', 'response': " As an assistant, I'm here to help answer your medical-related questions to the best of my ability. Based on your symptoms of persistent headaches, sensitivity to light, and occasional nausea, it is possible that you are experiencing Migraines or Tension-Type Headaches.\n\nMigraines are a common, complex neurological disorder characterized by recurring moderate to severe headaches typically accompanied by sensitivities to light, sound, or other environmental stimuli. They can also be associated with visual disturbances, nausea, and vomiting. The exact cause of migraines is not fully understood, but they are believed to involve abnormal brain activity, genetics, and environmental factors.\n\nTension-type headaches are the most common type of primary headache and are characterized by bilateral pain that is mild to moderate in intensity and is often described as a feeling of pressure or t
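
The following cells call `medical_assistant_prompt`, whose definition does not appear in this part of the notebook. A minimal sketch of such a helper is given here, assuming it only wraps the user's query in the same kind of instruction template used by `generate_medical_assistant_response`. The shortened template text is a stand-in, not the original prompt.

```python
# Hypothetical helper: the notebook's own definition is not shown in this section.
# The template is a shortened stand-in for the full prompt used above.
def medical_assistant_prompt(user_query: str) -> str:
    """Wrap a user query in a shortened medical-assistant instruction template."""
    return (
        f"User Query: {user_query}\n"
        "[You are a medical professional bot that behaves like a doctor. "
        "Answer only medical queries and politely decline anything else.]\n\n"
        "Medical Specialty Prediction:\n"
        "Expert Response:\n"
        "Lifestyle Recommendations:\n"
        "Next Steps:\n"
        "Disclaimer:\n"
    )
```

Keeping the template in one function means every model receives exactly the same prompt, so differences in the output come from the model rather than from the wording.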

In [14]:
# Example usage
user_query = "I have been experiencing persistent headaches, sensitivity to light, and occasional nausea for the past few weeks. What could be causing these symptoms, and what should I do?"

prompt = medical_assistant_prompt(user_query)

# Send the request to the local Ollama server, which serves the Mistral model
api_key = "YOUR_API_KEY"  # not used by a local Ollama server
endpoint = "http://localhost:11434/api/generate"

headers = {
    "Content-Type": "application/json"
}

# Sampling settings go under "options" in an Ollama /api/generate request
data = {
    "model": "mistral",
    "prompt": prompt,
    "options": {"num_predict": 500, "temperature": 0.7},
    "stream": False
}

response = requests.post(endpoint, headers=headers, json=data)

if response.status_code == 200:
    result = response.json()
    # The generated text is stored under result["response"]
    print(result)
else:
    print(f"Error: {response.status_code}")


In [70]:
result.keys()

dict_keys(['model', 'created_at', 'response', 'done', 'context', 'total_duration', 'load_duration', 'prompt_eval_count', 'prompt_eval_duration', 'eval_count', 'eval_duration'])
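
The timing fields listed above can serve as simple benchmark metrics. The sketch below derives generation speed from `eval_count` and `eval_duration`, relying on Ollama reporting durations in nanoseconds. The function name `generation_speed` and the sample values are illustrative assumptions, not measurements from this notebook.

```python
def generation_speed(result: dict) -> dict:
    """Derive simple speed metrics from an Ollama /api/generate response."""
    ns_per_second = 1e9  # Ollama reports durations in nanoseconds
    eval_seconds = result["eval_duration"] / ns_per_second
    return {
        "model": result["model"],
        "tokens_generated": result["eval_count"],
        # Guard against a zero duration to avoid dividing by zero
        "tokens_per_second": result["eval_count"] / eval_seconds if eval_seconds else 0.0,
        "total_seconds": result["total_duration"] / ns_per_second,
    }

# Illustrative values only, not results from the runs above
sample = {
    "model": "mistral",
    "eval_count": 200,
    "eval_duration": 4_000_000_000,
    "total_duration": 5_000_000_000,
}
print(generation_speed(sample))
```

Collecting these values for each candidate model on the same query gives a like-for-like speed comparison before and after fine tuning.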

In [77]:
print(f"Model: {result['response']} \n ")

Model:  Predicted Medical Specialty: Neurology

Expert Response:
Based on your symptoms of persistent headaches, sensitivity to light, and occasional nausea for several weeks, it is important that you consult with a neurologist. These symptoms could be indicative of various conditions, including migraines, tension headaches, or more serious disorders such as a brain tumor or a condition called photophobia (extreme sensitivity to light).

Migraines are characterized by recurring headaches, usually unilateral, often accompanied by nausea, vomiting, and sensitivity to light and sound. Tension headaches are another common type of headache that may cause constant pain or pressure in the forehead, temples, or back of the head.

The exact causes of these conditions vary. Migraines may be triggered by stress, certain foods, hormonal changes, or other factors. Tension headaches can result from muscle tension due to stress, poor posture, or other reasons. Brain tumors and photophobia are less co

In [8]:
user_query2 = "Hey can you help me do my homework please?"
prompt2 = medical_assistant_prompt(user_query2)
response_2 = requests.post(
    endpoint,
    headers=headers,
    json={
        "model": "mistral",
        "prompt": prompt2,
        # "num_predict" limits the reply length for this off-topic check
        "options": {"num_predict": 80, "temperature": 0.7},
        "stream": False,
    },
)

if response_2.status_code == 200:
    result = response_2.json()
    # The generated text is stored under result["response"]
else:
    print(f"Error: {response_2.status_code}")
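
The off-topic query above tests whether the model declines non-medical requests. One rough way to check this automatically is to look for the opening phrase of the refusal message that the prompt asks for. The helper name `is_off_topic_refusal` is an assumption, and a substring match is a crude test: a model that paraphrases the refusal would not be detected.

```python
# Opening phrase of the refusal message requested in the prompt template
REFUSAL_MARKER = "As an AI medical assistant"

def is_off_topic_refusal(response_text: str) -> bool:
    """Return True if the reply contains the refusal wording from the prompt."""
    return REFUSAL_MARKER.lower() in response_text.lower()
```

Given a successful response, it could be applied as `is_off_topic_refusal(result["response"])`.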

