## Medical NLP Summarization

In [3]:
from datasets import load_dataset

df = load_dataset("har1/MTS_Dialogue-Clinical_Note")

In [4]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("har1/HealthScribe-Clinical_Note_Generator")
model = AutoModelForSeq2SeqLM.from_pretrained("har1/HealthScribe-Clinical_Note_Generator")

In [5]:
# 1️⃣ Prepare the conversation as one string
input_text = (
    "Physician: How are you feeling today?\n"
    "Patient: I’m doing better, but I still have some discomfort now and then.\n"
    "Physician: I understand you were in a car accident last September. Can you walk me through what happened?\n"
    "Patient: Yes, it was on September 1st, around 12:30 in the afternoon. I was driving from Cheadle Hulme to Manchester when I had to stop in traffic. Out of nowhere, another car hit me from behind, which pushed my car into the one in front.\n"
    "Physician: That sounds like a strong impact. Were you wearing your seatbelt?\n"
    "Patient: Yes, I always do.\n"
    "Physician: What did you feel immediately after the accident?\n"
    "Patient: At first, I was just shocked. But then I realized I had hit my head on the steering wheel, and I could feel pain in my neck and back almost right away.\n"
    "Physician: Did you seek medical attention at that time?\n"
    "Patient: Yes, I went to Moss Bank Accident and Emergency. They checked me over and said it was a whiplash injury, but they didn’t do any X-rays. They just gave me some advice and sent me home.\n"
    "Physician: How did things progress after that?\n"
    "Patient: The first four weeks were rough. My neck and back pain were really bad—I had trouble sleeping and had to take painkillers regularly. It started improving after that, but I had to go through ten sessions of physiotherapy to help with the stiffness and discomfort.\n"
    "Physician: That makes sense. Are you still experiencing pain now?\n"
    "Patient: It’s not constant, but I do get occasional backaches. It’s nothing like before, though.\n"
    "Physician: That’s good to hear. Have you noticed any other effects, like anxiety while driving or difficulty concentrating?\n"
    "Patient: No, nothing like that. I don’t feel nervous driving, and I haven’t had any emotional issues from the accident.\n"
    "Physician: And how has this impacted your daily life? Work, hobbies, anything like that?\n"
    "Patient: I had to take a week off work, but after that, I was back to my usual routine. It hasn’t really stopped me from doing anything.\n"
    "Physician: That’s encouraging. Let’s go ahead and do a physical examination to check your mobility and any lingering pain.\n"
    "[Physical Examination Conducted]\n"
    "Physician: Everything looks good. Your neck and back have a full range of movement, and there’s no tenderness or signs of lasting damage. Your muscles and spine seem to be in good condition.\n"
    "Patient: That’s a relief!\n"
    "Physician: Yes, your recovery so far has been quite positive. Given your progress, I’d expect you to make a full recovery within six months of the accident. There are no signs of long-term damage or degeneration.\n"
    "Patient: That’s great to hear. So, I don’t need to worry about this affecting me in the future?\n"
    "Physician: That’s right. I don’t foresee any long-term impact on your work or daily life. If anything changes or you experience worsening symptoms, you can always come back for a follow-up. But at this point, you’re on track for a full recovery.\n"
    "Patient: Thank you, doctor. I appreciate it.\n"
    "Physician: You’re very welcome, Ms. Jones. Take care, and don’t hesitate to reach out if you need anything."
)


# 2️⃣ Tokenize
inputs = tokenizer(
    input_text,
    return_tensors="pt",
    truncation=True,      # just in case it's long
    max_length=512
)

# 3️⃣ Generate output
output_ids = model.generate(
    **inputs,
    max_new_tokens=200,   # adjust as needed
    num_beams=4,          # beam search for better quality
    early_stopping=True
)

# 4️⃣ Decode
generated_note = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(generated_note)

Symptoms: occasional backaches
Diagnosis: whiplash injury
History of Patient: Involved in a motor vehicle accident on September 1st, 2001, while driving a Ford Taurus from Cheadle Hulme to Manchester, hit head on steering wheel, experienced pain in neck and back immediately after accident, sought medical care at Moss Bank Accident and Emergency, received advice and sent home
Plan of Action: N/A



In [6]:
import json

# 1️⃣ Prepare the conversation as one string
input_text = (
    "Physician: How are you feeling today?\n"
    "Patient: I’m doing better, but I still have some discomfort now and then.\n"
    "Physician: I understand you were in a car accident last September. Can you walk me through what happened?\n"
    "Patient: Yes, it was on September 1st, around 12:30 in the afternoon. I was driving from Cheadle Hulme to Manchester when I had to stop in traffic. Out of nowhere, another car hit me from behind, which pushed my car into the one in front.\n"
    "Physician: That sounds like a strong impact. Were you wearing your seatbelt?\n"
    "Patient: Yes, I always do.\n"
    "Physician: What did you feel immediately after the accident?\n"
    "Patient: At first, I was just shocked. But then I realized I had hit my head on the steering wheel, and I could feel pain in my neck and back almost right away.\n"
    "Physician: Did you seek medical attention at that time?\n"
    "Patient: Yes, I went to Moss Bank Accident and Emergency. They checked me over and said it was a whiplash injury, but they didn’t do any X-rays. They just gave me some advice and sent me home.\n"
    "Physician: How did things progress after that?\n"
    "Patient: The first four weeks were rough. My neck and back pain were really bad—I had trouble sleeping and had to take painkillers regularly. It started improving after that, but I had to go through ten sessions of physiotherapy to help with the stiffness and discomfort.\n"
    "Physician: That makes sense. Are you still experiencing pain now?\n"
    "Patient: It’s not constant, but I do get occasional backaches. It’s nothing like before, though.\n"
    "Physician: That’s good to hear. Have you noticed any other effects, like anxiety while driving or difficulty concentrating?\n"
    "Patient: No, nothing like that. I don’t feel nervous driving, and I haven’t had any emotional issues from the accident.\n"
    "Physician: And how has this impacted your daily life? Work, hobbies, anything like that?\n"
    "Patient: I had to take a week off work, but after that, I was back to my usual routine. It hasn’t really stopped me from doing anything.\n"
    "Physician: That’s encouraging. Let’s go ahead and do a physical examination to check your mobility and any lingering pain.\n"
    "[Physical Examination Conducted]\n"
    "Physician: Everything looks good. Your neck and back have a full range of movement, and there’s no tenderness or signs of lasting damage. Your muscles and spine seem to be in good condition.\n"
    "Patient: That’s a relief!\n"
    "Physician: Yes, your recovery so far has been quite positive. Given your progress, I’d expect you to make a full recovery within six months of the accident. There are no signs of long-term damage or degeneration.\n"
    "Patient: That’s great to hear. So, I don’t need to worry about this affecting me in the future?\n"
    "Physician: That’s right. I don’t foresee any long-term impact on your work or daily life. If anything changes or you experience worsening symptoms, you can always come back for a follow-up. But at this point, you’re on track for a full recovery.\n"
    "Patient: Thank you, doctor. I appreciate it.\n"
    "Physician: You’re very welcome, Ms. Jones. Take care, and don’t hesitate to reach out if you need anything."
)

# 2️⃣ Add instruction for JSON output
prompt = (
    input_text +
    "\n\nPlease extract the structured medical summary in the following JSON format only:\n"
    "{\n"
    '  "Patient_Name": "Janet Jones",\n'
    '  "Symptoms": ["Neck pain", "Back pain", "Head impact"],\n'
    '  "Diagnosis": "Whiplash injury",\n'
    '  "Treatment": ["10 physiotherapy sessions", "Painkillers"],\n'
    '  "Current_Status": "Occasional backache",\n'
    '  "Prognosis": "Full recovery expected within six months"\n'
    "}"
)

# 3️⃣ Tokenize
inputs = tokenizer(
    prompt,
    return_tensors="pt",
    truncation=True,
    max_length=512
)

# 4️⃣ Generate JSON output
output_ids = model.generate(
    **inputs,
    max_new_tokens=300,
    num_beams=4,
    early_stopping=True
)

# 5️⃣ Decode and parse JSON
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Try parsing to dict
try:
    generated_json = json.loads(generated_text)
    print(json.dumps(generated_json, indent=2))
except json.JSONDecodeError:
    print("Raw model output (not valid JSON):")
    print(generated_text)


Raw model output (not valid JSON):
Symptoms: occasional backaches
Diagnosis: whiplash injury
History of Patient: Involved in a motor vehicle accident on September 1st, 2001, while driving a Ford Taurus from Cheadle Hulme to Manchester, hit head on steering wheel, experienced pain in neck and back immediately after accident, sought medical care at Moss Bank Accident and Emergency, received advice and sent home
Plan of Action: N/A



In [2]:
import pandas as pd

# If you haven’t loaded it yet
df = pd.read_csv("Combined Data.csv")   # or pd.read_excel("your_file.xlsx")

# Unique entries in the 'status' column
print(df['status'].unique())


['Anxiety' 'Normal' 'Depression' 'Suicidal' 'Stress' 'Bipolar'
 'Personality disorder']


In [3]:
import pandas as pd

# drop rows where status is in the unwanted list
df = df[~df['status'].isin(['Suicidal', 'Bipolar', 'Personality disorder'])]

# replace Depression and Stress with Anxiety
df['status'] = df['status'].replace({
    'Depression': 'Anxiety',
    'Stress': 'Anxiety'
})

# (optional) check unique values again
print(df['status'].unique())
for i in range(5):
    print(df["statement"][i])

['Anxiety' 'Normal']
oh my gosh
trouble sleeping, confused mind, restless heart. All out of tune
All wrong, back off dear, forward doubt. Stay in a restless and restless place
I've shifted my focus to something else but I'm still worried
I'm restless and restless, it's been a month now, boy. What do you mean?


In [16]:
# Save to CSV (without the pandas index column)
balanced_df.to_csv("cleaned_dataset2.csv", index=False)


In [17]:
import pandas as pd

df = pd.read_csv("cleaned_dataset2.csv")
df = df[['statement','status']]  # keep only the two needed columns


In [18]:
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
balanced_df['label'] = le.fit_transform(balanced_df['status'])
print(le.classes_)  # e.g. ['Anxiety' 'Normal']


['Anxiety' 'Normal']


In [14]:
df['status'].value_counts()

status
Anxiety    21832
Normal     16343
Name: count, dtype: int64

In [15]:
import pandas as pd

# Assuming your df has 'status' column with classes and original data
# Downsample each class to 8000 rows randomly

anxiety_df = df[df["status"] == "Anxiety"].sample(n=8000, random_state=42)
normal_df = df[df["status"] == "Normal"].sample(n=8000, random_state=42)

# Concatenate balanced data into a new DataFrame
balanced_df = pd.concat([anxiety_df, normal_df]).reset_index(drop=True)

print(balanced_df["status"].value_counts())


status
Anxiety    8000
Normal     8000
Name: count, dtype: int64


In [8]:
!pip install tf-keras

Collecting keras>=3.10.0 (from tensorflow<2.21,>=2.20->tf-keras)
  Using cached keras-3.11.3-py3-none-any.whl.metadata (5.9 kB)
Using cached keras-3.11.3-py3-none-any.whl (1.4 MB)
Installing collected packages: keras
Successfully installed keras-3.11.3


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-intel 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.11.3 which is incompatible.
tensorflow-intel 2.15.0 requires ml-dtypes~=0.2.0, but you have ml-dtypes 0.5.3 which is incompatible.
tensorflow-intel 2.15.0 requires numpy<2.0.0,>=1.23.5, but you have numpy 2.2.6 which is incompatible.
tensorflow-intel 2.15.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3, but you have protobuf 6.32.1 which is incompatible.
tensorflow-intel 2.15.0 requires tensorboard<2.16,>=2.15, but you have tensorboard 2.20.0 which is incompatible.
tensorflow-intel 2.15.0 requires wrapt<1.15,>=1.11.0, but you have wrapt 1.17.3 which is incompatible.


In [12]:
!pip install --upgrade transformers



In [None]:
import json
from transformers import pipeline
def extract_medical_info_generic(conversation_text, tokenizer, model):
    """
    Generic extraction function that works with your existing model.
    Requires your model to be instruction-tuned or fine-tuned for extraction.
    """
    
    # Create a detailed prompt for extraction
    extraction_prompt = f"""<|system|>
You are a medical information extraction assistant. Extract information from conversations and return ONLY valid JSON.

<|user|>
Extract the following from this medical conversation:
1. Patient_Name (string)
2. Symptoms (list of strings)
3. Diagnosis (string)
4. Treatment (list of strings)
5. Current_Status (string)
6. Prognosis (string)

Conversation:
{conversation_text}

Return ONLY a JSON object with these exact keys.

<|assistant|>
"""

    # Tokenize
    inputs = tokenizer(
        extraction_prompt,
        return_tensors="pt",
        truncation=True,
        max_length=1024  # Increased for longer conversations
    )

    # Generate with better parameters for structured output
    output_ids = model.generate(
        **inputs,
        max_new_tokens=300,
        num_beams=5,
        temperature=0.3,  # Lower temperature for more consistent output
        do_sample=False,
        early_stopping=True
    )

    # Decode
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    
    # Extract JSON from response
    try:
        import re
        json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', generated_text, re.DOTALL)
        if json_match:
            medical_info = json.loads(json_match.group())
            return medical_info
        else:
            # Fallback: try to parse the whole response
            return json.loads(generated_text)
    except json.JSONDecodeError:
        print("Could not parse JSON. Raw output:")
        print(generated_text)
        return parse_text_to_json(generated_text)

def parse_text_to_json(text):
    """
    Fallback parser when model doesn't return proper JSON.
    Attempts to extract structured information from free text.
    """
    import re
    
    medical_info = {
        "Patient_Name": None,
        "Symptoms": [],
        "Diagnosis": None,
        "Treatment": [],
        "Current_Status": None,
        "Prognosis": None
    }
    
    # Try to find each field in the text
    lines = text.split('\n')
    current_field = None
    
    for line in lines:
        line = line.strip()
        
        if 'patient_name' in line.lower() or 'patient name' in line.lower():
            match = re.search(r'["\']([^"\']+)["\']|:\s*(.+?)(?:,|\n|$)', line, re.IGNORECASE)
            if match:
                medical_info['Patient_Name'] = match.group(1) or match.group(2)
        
        elif 'symptoms' in line.lower():
            # Extract list items
            items = re.findall(r'["\']([^"\']+)["\']', line)
            if items:
                medical_info['Symptoms'] = items
        
        elif 'diagnosis' in line.lower():
            match = re.search(r'["\']([^"\']+)["\']|:\s*(.+?)(?:,|\n|$)', line, re.IGNORECASE)
            if match:
                medical_info['Diagnosis'] = match.group(1) or match.group(2)
        
        elif 'treatment' in line.lower():
            items = re.findall(r'["\']([^"\']+)["\']', line)
            if items:
                medical_info['Treatment'] = items
        
        elif 'current_status' in line.lower() or 'current status' in line.lower():
            match = re.search(r'["\']([^"\']+)["\']|:\s*(.+?)(?:,|\n|$)', line, re.IGNORECASE)
            if match:
                medical_info['Current_Status'] = match.group(1) or match.group(2)
        
        elif 'prognosis' in line.lower():
            match = re.search(r'["\']([^"\']+)["\']|:\s*(.+?)(?:,|\n|$)', line, re.IGNORECASE)
            if match:
                medical_info['Prognosis'] = match.group(1) or match.group(2)
    
    return medical_info


# ============================================================================
# MAIN USAGE EXAMPLE


if __name__ == "__main__":
    # Sample conversation
    input_text = """
    Physician: How are you feeling today?
    Patient: I'm doing better, but I still have some discomfort now and then.
    Physician: I understand you were in a car accident last September. Can you walk me through what happened?
    Patient: Yes, it was on September 1st, around 12:30 in the afternoon. I was driving from Cheadle Hulme to Manchester when I had to stop in traffic. Out of nowhere, another car hit me from behind, which pushed my car into the one in front.
    Physician: That sounds like a strong impact. Were you wearing your seatbelt?
    Patient: Yes, I always do.
    Physician: What did you feel immediately after the accident?
    Patient: At first, I was just shocked. But then I realized I had hit my head on the steering wheel, and I could feel pain in my neck and back almost right away.
    Physician: Did you seek medical attention at that time?
    Patient: Yes, I went to Moss Bank Accident and Emergency. They checked me over and said it was a whiplash injury, but they didn't do any X-rays. They just gave me some advice and sent me home.
    Physician: How did things progress after that?
    Patient: The first four weeks were rough. My neck and back pain were really bad—I had trouble sleeping and had to take painkillers regularly. It started improving after that, but I had to go through ten sessions of physiotherapy to help with the stiffness and discomfort.
    Physician: That makes sense. Are you still experiencing pain now?
    Patient: It's not constant, but I do get occasional backaches. It's nothing like before, though.
    Physician: That's good to hear. Have you noticed any other effects, like anxiety while driving or difficulty concentrating?
    Patient: No, nothing like that. I don't feel nervous driving, and I haven't had any emotional issues from the accident.
    Physician: And how has this impacted your daily life? Work, hobbies, anything like that?
    Patient: I had to take a week off work, but after that, I was back to my usual routine. It hasn't really stopped me from doing anything.
    Physician: That's encouraging. Let's go ahead and do a physical examination to check your mobility and any lingering pain.
    [Physical Examination Conducted]
    Physician: Everything looks good. Your neck and back have a full range of movement, and there's no tenderness or signs of lasting damage. Your muscles and spine seem to be in good condition.
    Patient: That's a relief!
    Physician: Yes, your recovery so far has been quite positive. Given your progress, I'd expect you to make a full recovery within six months of the accident. There are no signs of long-term damage or degeneration.
    Patient: That's great to hear. So, I don't need to worry about this affecting me in the future?
    Physician: That's right. I don't foresee any long-term impact on your work or daily life. If anything changes or you experience worsening symptoms, you can always come back for a follow-up. But at this point, you're on track for a full recovery.
    Patient: Thank you, doctor. I appreciate it.
    Physician: You're very welcome, Ms. Jones. Take care, and don't hesitate to reach out if you need anything.
    """
    
    print("=" * 80)
    print("OPTION 1: Using FLAN-T5 (Recommended)")
    print("=" * 80)
    result = extract_medical_info_with_llm(input_text)
    print(json.dumps(result, indent=2))
    
    print("\n" + "=" * 80)
    print("OPTION 2: Using Your Existing Model")
    print("=" * 80)
    print("Uncomment and use with your tokenizer and model:")
    print("# result = extract_medical_info_generic(input_text, tokenizer, model)")
    print("# print(json.dumps(result, indent=2))")
    
    print("\n" + "=" * 80)
    print("OPTION 3: Using OpenAI API")
    print("=" * 80)
    print("Uncomment and provide API key:")
    print("# result = extract_with_openai(input_text, api_key='your-api-key')")
    print("# print(json.dumps(result, indent=2))")

## Sentiment & Intent Analysis 

In [None]:
from datasets import Dataset
from transformers import (
    DistilBertTokenizerFast,
    TFAutoModel,
    DataCollatorWithPadding,
    create_optimizer
)
import tensorflow as tf
from tensorflow import keras
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import joblib
import json



df = pd.read_csv("cleaned_dataset2.csv")

# If you only have 'status' column, you'll need to create intent labels
# For now, I'll show you how to handle both sentiment and intent

# Option A: If you have both columns in CSV
if 'intent' in df.columns:
    df = df[["statement", "status", "intent"]].dropna()
else:
    # Option B: Create intent labels based on patterns (you'll need to customize this)
    def assign_intent(row):
        statement = str(row['statement']).lower()
        status = row['status']
        
        # Rule-based intent assignment (customize based on your data)
        if status == 'anxiety':
            if any(word in statement for word in ['help', 'what should', 'worried', 'scared']):
                return 'seeking_reassurance'
            elif any(word in statement for word in ['can\'t', 'unable', 'difficult']):
                return 'expressing_difficulty'
            else:
                return 'expressing_concern'
        else:  # neutral
            if any(word in statement for word in ['good', 'fine', 'okay', 'well']):
                return 'expressing_confidence'
            elif any(word in statement for word in ['think', 'believe', 'feel']):
                return 'sharing_opinion'
            else:
                return 'making_statement'
    
    df['intent'] = df.apply(assign_intent, axis=1)
    df = df[["statement", "status", "intent"]].dropna()

# Encode both labels
sentiment_le = LabelEncoder()
intent_le = LabelEncoder()

df["sentiment_labels"] = sentiment_le.fit_transform(df["status"])
df["intent_labels"] = intent_le.fit_transform(df["intent"])

print(f"Sentiment classes: {sentiment_le.classes_}")
print(f"Intent classes: {intent_le.classes_}")

# Split data
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

# Convert to Hugging Face Dataset
train_ds = Dataset.from_pandas(train_df)
test_ds = Dataset.from_pandas(test_df)

# ============================================================================
# STEP 2: Tokenizer
# ============================================================================

model_name = "distilbert-base-uncased"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)

def preprocess_data(batch):
    return tokenizer(batch["statement"], truncation=True, padding=True, max_length=128)

tokenized_train = train_ds.map(preprocess_data, batched=True)
tokenized_test = test_ds.map(preprocess_data, batched=True)

# ============================================================================
# STEP 3: Data Collator
# ============================================================================

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")

# ============================================================================
# STEP 4: Convert to tf.data.Dataset with Multiple Labels
# ============================================================================

tf_train_dataset = tokenized_train.to_tf_dataset(
    columns=["input_ids", "attention_mask"],
    label_cols=["sentiment_labels", "intent_labels"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=8,
)

tf_validation_dataset = tokenized_test.to_tf_dataset(
    columns=["input_ids", "attention_mask"],
    label_cols=["sentiment_labels", "intent_labels"],
    shuffle=False,
    collate_fn=data_collator,
    batch_size=8,
)

# ============================================================================
# STEP 5: Build Multi-Task Model
# ============================================================================

class MultiTaskModel(keras.Model):
    def __init__(self, num_sentiments, num_intents, model_name="distilbert-base-uncased"):
        super().__init__()
        
        # Load pre-trained DistilBERT
        from transformers import TFAutoModel
        self.distilbert = TFAutoModel.from_pretrained(model_name, from_pt=True)
        
        # Sentiment classification head
        self.sentiment_dropout = keras.layers.Dropout(0.1)
        self.sentiment_classifier = keras.layers.Dense(
            num_sentiments, 
            activation=None, 
            name='sentiment_output'
        )
        
        # Intent classification head
        self.intent_dropout = keras.layers.Dropout(0.1)
        self.intent_classifier = keras.layers.Dense(
            num_intents, 
            activation=None, 
            name='intent_output'
        )
    
    def call(self, inputs, training=False):
        # Get DistilBERT outputs
        outputs = self.distilbert(inputs, training=training)
        sequence_output = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]
        
        # Use [CLS] token representation
        cls_output = sequence_output[:, 0, :]  # [batch_size, hidden_size]
        
        # Sentiment branch
        sentiment_output = self.sentiment_dropout(cls_output, training=training)
        sentiment_logits = self.sentiment_classifier(sentiment_output)
        
        # Intent branch
        intent_output = self.intent_dropout(cls_output, training=training)
        intent_logits = self.intent_classifier(intent_output)
        
        return {
            'sentiment_output': sentiment_logits,
            'intent_output': intent_logits
        }

# Instantiate model
num_sentiments = len(sentiment_le.classes_)
num_intents = len(intent_le.classes_)

model = MultiTaskModel(num_sentiments, num_intents)

# ============================================================================
# STEP 6: Compile with Multiple Losses
# ============================================================================

batch_size = 8
num_epochs = 3
batches_per_epoch = len(tf_train_dataset)
total_train_steps = batches_per_epoch * num_epochs

optimizer, schedule = create_optimizer(
    init_lr=5e-5, 
    num_warmup_steps=0, 
    num_train_steps=total_train_steps
)

# Define losses and metrics for both tasks
losses = {
    'sentiment_output': keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    'intent_output': keras.losses.SparseCategoricalCrossentropy(from_logits=True)
}

metrics = {
    'sentiment_output': [keras.metrics.SparseCategoricalAccuracy(name='sentiment_accuracy')],
    'intent_output': [keras.metrics.SparseCategoricalAccuracy(name='intent_accuracy')]
}

# You can weight the losses if one task is more important
loss_weights = {
    'sentiment_output': 1.0,
    'intent_output': 1.0
}

model.compile(
    optimizer=optimizer,
    loss=losses,
    metrics=metrics,
    loss_weights=loss_weights
)

# ============================================================================
# STEP 7: Train Model
# ============================================================================

# Build the model by calling it once
dummy_input = {
    'input_ids': tf.constant([[0] * 128]),
    'attention_mask': tf.constant([[1] * 128])
}
_ = model(dummy_input)

history = model.fit(
    tf_train_dataset,
    validation_data=tf_validation_dataset,
    epochs=num_epochs
)

# ============================================================================
# STEP 8: Save Model and Encoders
# ============================================================================

model.save_weights("./multi_task_model/model_weights")
tokenizer.save_pretrained("./multi_task_model")
joblib.dump(sentiment_le, "./multi_task_model/sentiment_encoder.pkl")
joblib.dump(intent_le, "./multi_task_model/intent_encoder.pkl")

# Save model config
config = {
    'num_sentiments': num_sentiments,
    'num_intents': num_intents,
    'model_name': model_name
}
with open("./multi_task_model/config.json", 'w') as f:
    json.dump(config, f)

print("\n✅ Model training complete!")
print(f"Sentiment classes: {list(sentiment_le.classes_)}")
print(f"Intent classes: {list(intent_le.classes_)}")

# ============================================================================
# STEP 9: Inference Function to Get JSON Output
# ============================================================================

def predict_with_json(text, model, tokenizer, sentiment_le, intent_le):
    """
    Predict sentiment and intent, return as JSON
    """
    # Tokenize
    inputs = tokenizer(
        text,
        truncation=True,
        padding=True,
        max_length=128,
        return_tensors="tf"
    )
    
    # Predict
    outputs = model(inputs, training=False)
    
    # Get predictions
    sentiment_logits = outputs['sentiment_output']
    intent_logits = outputs['intent_output']
    
    sentiment_pred = tf.argmax(sentiment_logits, axis=1).numpy()[0]
    intent_pred = tf.argmax(intent_logits, axis=1).numpy()[0]
    
    # Decode labels
    sentiment_label = sentiment_le.inverse_transform([sentiment_pred])[0]
    intent_label = intent_le.inverse_transform([intent_pred])[0]
    
    # Format intent label nicely
    intent_formatted = intent_label.replace('_', ' ').title()
    sentiment_formatted = sentiment_label.capitalize()
    
    # Return JSON
    result = {
        "Sentiment": sentiment_formatted,
        "Intent": intent_formatted
    }
    
    return json.dumps(result, indent=2)

# ============================================================================
# STEP 10: Test the Model
# ============================================================================

# Test examples
test_texts = [
    "I'm really worried about the upcoming exam",
    "Everything is going well today",
    "I can't handle this stress anymore"
]

print("\n" + "="*60)
print("TESTING MODEL WITH JSON OUTPUT")
print("="*60)

for text in test_texts:
    print(f"\nInput: {text}")
    result = predict_with_json(text, model, tokenizer, sentiment_le, intent_le)
    print(f"Output: {result}")

## SOAP NOTES GENERATION

In [None]:
"""
SOAP Note Generation System
Implements multiple approaches for converting medical transcripts to SOAP notes
"""

import json
import re
from typing import Dict, List, Any
from dataclasses import dataclass, asdict
import pandas as pd
from sklearn.model_selection import train_test_split

# ============================================================================
# APPROACH 1: Rule-Based + NER (Named Entity Recognition)
# ============================================================================

class RuleBasedSOAPGenerator:
    """
    Rule-based system using pattern matching and medical keywords
    Best for: Quick prototyping, interpretable results
    """
    
    def __init__(self):
        # Define keyword patterns for each SOAP section
        self.subjective_keywords = {
            'complaint': ['hurt', 'pain', 'ache', 'discomfort', 'feeling', 'symptom'],
            'history': ['accident', 'injury', 'started', 'began', 'weeks ago', 'months ago', 
                       'experienced', 'had', 'was', 'happened']
        }
        
        self.objective_keywords = {
            'physical_exam': ['examination', 'range of motion', 'tenderness', 'swelling', 
                            'temperature', 'blood pressure', 'heart rate', 'respiratory rate'],
            'observations': ['appears', 'looks', 'gait', 'posture', 'condition', 'normal', 'abnormal']
        }
        
        self.assessment_keywords = {
            'diagnosis': ['diagnosis', 'diagnosed with', 'condition', 'injury', 'disease', 
                         'whiplash', 'strain', 'sprain', 'fracture'],
            'severity': ['mild', 'moderate', 'severe', 'improving', 'worsening', 'stable']
        }
        
        self.plan_keywords = {
            'treatment': ['treatment', 'medication', 'therapy', 'physiotherapy', 'surgery', 
                         'continue', 'prescribe', 'recommend', 'analgesic', 'painkiller'],
            'follow_up': ['follow-up', 'return', 'come back', 'appointment', 'check-up', 
                         'monitor', 'if', 'worsens', 'persists']
        }
    
    def extract_sentences_by_keywords(self, text: str, keywords: List[str]) -> List[str]:
        """Extract sentences containing specific keywords"""
        sentences = re.split(r'[.!?]', text)
        matching_sentences = []
        
        for sentence in sentences:
            sentence = sentence.strip()
            if any(keyword.lower() in sentence.lower() for keyword in keywords):
                matching_sentences.append(sentence)
        
        return matching_sentences
    
    def extract_subjective(self, text: str) -> Dict[str, str]:
        """Extract Subjective section"""
        complaint_sentences = self.extract_sentences_by_keywords(
            text, self.subjective_keywords['complaint']
        )
        history_sentences = self.extract_sentences_by_keywords(
            text, self.subjective_keywords['history']
        )
        
        return {
            "Chief_Complaint": ' '.join(complaint_sentences[:2]) if complaint_sentences else "Not documented",
            "History_of_Present_Illness": ' '.join(history_sentences[:3]) if history_sentences else "Not documented"
        }
    
    def extract_objective(self, text: str) -> Dict[str, str]:
        """Extract Objective section"""
        exam_sentences = self.extract_sentences_by_keywords(
            text, self.objective_keywords['physical_exam']
        )
        observation_sentences = self.extract_sentences_by_keywords(
            text, self.objective_keywords['observations']
        )
        
        return {
            "Physical_Exam": ' '.join(exam_sentences) if exam_sentences else "Not documented",
            "Observations": ' '.join(observation_sentences) if observation_sentences else "Not documented"
        }
    
    def extract_assessment(self, text: str) -> Dict[str, str]:
        """Extract Assessment section"""
        diagnosis_sentences = self.extract_sentences_by_keywords(
            text, self.assessment_keywords['diagnosis']
        )
        severity_sentences = self.extract_sentences_by_keywords(
            text, self.assessment_keywords['severity']
        )
        
        return {
            "Diagnosis": ' '.join(diagnosis_sentences) if diagnosis_sentences else "Not documented",
            "Severity": ' '.join(severity_sentences) if severity_sentences else "Not documented"
        }
    
    def extract_plan(self, text: str) -> Dict[str, str]:
        """Extract Plan section"""
        treatment_sentences = self.extract_sentences_by_keywords(
            text, self.plan_keywords['treatment']
        )
        followup_sentences = self.extract_sentences_by_keywords(
            text, self.plan_keywords['follow_up']
        )
        
        return {
            "Treatment": ' '.join(treatment_sentences) if treatment_sentences else "Not documented",
            "Follow_Up": ' '.join(followup_sentences) if followup_sentences else "Not documented"
        }
    
    def generate_soap_note(self, transcript: str) -> Dict[str, Any]:
        """Generate complete SOAP note from transcript"""
        return {
            "Subjective": self.extract_subjective(transcript),
            "Objective": self.extract_objective(transcript),
            "Assessment": self.extract_assessment(transcript),
            "Plan": self.extract_plan(transcript)
        }


# ============================================================================
# APPROACH 2: Transformer-Based Fine-Tuning (T5/BART)
# ============================================================================

class TransformerSOAPGenerator:
    """
    Uses pre-trained transformer models fine-tuned on SOAP note generation
    Best for: High accuracy, handles complex medical language
    """
    
    def __init__(self, model_name="t5-base"):
        from transformers import T5Tokenizer, TFT5ForConditionalGeneration
        
        self.model_name = model_name
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = TFT5ForConditionalGeneration.from_pretrained(model_name, from_pt=True)
    
    def prepare_training_data(self, csv_path: str) -> tuple:
        """
        Prepare training data from CSV with columns:
        - transcript: full conversation text
        - subjective: JSON string
        - objective: JSON string
        - assessment: JSON string
        - plan: JSON string
        """
        df = pd.read_csv(csv_path)
        
        # Create training examples for each SOAP section
        training_data = []
        
        for _, row in df.iterrows():
            transcript = row['transcript']
            
            # Create training pairs for each section
            sections = {
                'Subjective': row['subjective'],
                'Objective': row['objective'],
                'Assessment': row['assessment'],
                'Plan': row['plan']
            }
            
            for section_name, section_content in sections.items():
                input_text = f"extract {section_name.lower()} from medical transcript: {transcript}"
                output_text = section_content
                training_data.append({
                    'input': input_text,
                    'output': output_text
                })
        
        train_data, val_data = train_test_split(training_data, test_size=0.2, random_state=42)
        return train_data, val_data
    
    def train(self, train_data: List[Dict], val_data: List[Dict], epochs: int = 3):
        """
        Fine-tune T5 model on SOAP note generation
        """
        import tensorflow as tf
        from transformers import create_optimizer
        
        # Tokenize training data
        def tokenize_data(examples):
            model_inputs = self.tokenizer(
                [ex['input'] for ex in examples],
                max_length=512,
                truncation=True,
                padding='max_length'
            )
            
            labels = self.tokenizer(
                [ex['output'] for ex in examples],
                max_length=256,
                truncation=True,
                padding='max_length'
            )
            
            model_inputs['labels'] = labels['input_ids']
            return model_inputs
        
        # Prepare datasets
        train_encodings = tokenize_data(train_data)
        val_encodings = tokenize_data(val_data)
        
        # Create TF datasets
        train_dataset = tf.data.Dataset.from_tensor_slices((
            {
                'input_ids': train_encodings['input_ids'],
                'attention_mask': train_encodings['attention_mask']
            },
            train_encodings['labels']
        )).batch(8)
        
        val_dataset = tf.data.Dataset.from_tensor_slices((
            {
                'input_ids': val_encodings['input_ids'],
                'attention_mask': val_encodings['attention_mask']
            },
            val_encodings['labels']
        )).batch(8)
        
        # Setup optimizer
        num_train_steps = len(train_data) // 8 * epochs
        optimizer, schedule = create_optimizer(
            init_lr=5e-5,
            num_warmup_steps=0,
            num_train_steps=num_train_steps
        )
        
        # Compile and train
        self.model.compile(optimizer=optimizer)
        self.model.fit(
            train_dataset,
            validation_data=val_dataset,
            epochs=epochs
        )
    
    def generate_section(self, transcript: str, section: str) -> str:
        """Generate a specific SOAP section"""
        input_text = f"extract {section.lower()} from medical transcript: {transcript}"
        
        input_ids = self.tokenizer(
            input_text,
            return_tensors="tf",
            max_length=512,
            truncation=True
        ).input_ids
        
        outputs = self.model.generate(
            input_ids,
            max_length=256,
            num_beams=4,
            early_stopping=True
        )
        
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    def generate_soap_note(self, transcript: str) -> Dict[str, Any]:
        """Generate complete SOAP note"""
        sections = ['Subjective', 'Objective', 'Assessment', 'Plan']
        soap_note = {}
        
        for section in sections:
            section_text = self.generate_section(transcript, section)
            # Parse JSON if the model outputs JSON format
            try:
                soap_note[section] = json.loads(section_text)
            except json.JSONDecodeError:
                # If not JSON, structure it manually
                soap_note[section] = {"content": section_text}
        
        return soap_note


# ============================================================================
# APPROACH 3: Sequence Labeling with BERT (Token Classification)
# ============================================================================

class BERTSOAPLabeler:
    """
    Uses BERT for sequence labeling to tag each sentence with SOAP category
    Best for: Precise sentence-level categorization
    """
    
    def __init__(self):
        from transformers import TFBertForTokenClassification, BertTokenizerFast
        
        # Label mapping for SOAP sections + subsections
        self.label_map = {
            'O': 0,  # Outside any SOAP section
            'S-COMPLAINT': 1,  # Subjective - Chief Complaint
            'S-HISTORY': 2,    # Subjective - History
            'O-EXAM': 3,       # Objective - Physical Exam
            'O-OBS': 4,        # Objective - Observations
            'A-DIAG': 5,       # Assessment - Diagnosis
            'A-SEV': 6,        # Assessment - Severity
            'P-TREAT': 7,      # Plan - Treatment
            'P-FOLLOW': 8      # Plan - Follow-up
        }
        
        self.id_to_label = {v: k for k, v in self.label_map.items()}
        
        self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
        self.model = TFBertForTokenClassification.from_pretrained(
            'bert-base-uncased',
            num_labels=len(self.label_map),
            from_pt=True
        )
    
    def prepare_training_data(self, annotated_transcripts: List[Dict]) -> tuple:
        """
        Prepare data where each sentence is labeled with SOAP category
        Format: [{'text': 'sentence', 'label': 'S-COMPLAINT'}, ...]
        """
        tokenized_inputs = []
        labels = []
        
        for transcript in annotated_transcripts:
            sentences = transcript['sentences']
            sentence_labels = transcript['labels']
            
            # Tokenize sentences
            encoding = self.tokenizer(
                sentences,
                is_split_into_words=True,
                padding='max_length',
                truncation=True,
                max_length=128,
                return_tensors='tf'
            )
            
            # Convert labels to IDs
            label_ids = [self.label_map[label] for label in sentence_labels]
            
            # Align labels with tokens
            word_ids = encoding.word_ids()
            aligned_labels = []
            
            for word_id in word_ids:
                if word_id is None:
                    aligned_labels.append(-100)  # Ignore padding
                else:
                    aligned_labels.append(label_ids[word_id])
            
            tokenized_inputs.append(encoding)
            labels.append(aligned_labels)
        
        return tokenized_inputs, labels
    
    def train(self, train_data, val_data, epochs=3):
        """Train the BERT token classifier"""
        import tensorflow as tf
        from transformers import create_optimizer
        
        # Prepare TF datasets
        train_dataset = tf.data.Dataset.from_tensor_slices((
            dict(train_data[0]),
            train_data[1]
        )).batch(8)
        
        val_dataset = tf.data.Dataset.from_tensor_slices((
            dict(val_data[0]),
            val_data[1]
        )).batch(8)
        
        # Setup training
        num_train_steps = len(train_data[0]) // 8 * epochs
        optimizer, _ = create_optimizer(
            init_lr=5e-5,
            num_warmup_steps=0,
            num_train_steps=num_train_steps
        )
        
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
        
        self.model.fit(train_dataset, validation_data=val_dataset, epochs=epochs)
    
    def predict_soap_labels(self, transcript: str) -> List[tuple]:
        """Predict SOAP labels for each sentence"""
        sentences = re.split(r'[.!?]', transcript)
        sentences = [s.strip() for s in sentences if s.strip()]
        
        # Tokenize
        inputs = self.tokenizer(
            sentences,
            is_split_into_words=True,
            padding=True,
            truncation=True,
            return_tensors='tf'
        )
        
        # Predict
        outputs = self.model(inputs)
        predictions = tf.argmax(outputs.logits, axis=-1)
        
        # Map predictions to labels
        labeled_sentences = []
        for i, sentence in enumerate(sentences):
            label_id = predictions[i, 1].numpy()  # Skip [CLS] token
            label = self.id_to_label[label_id]
            labeled_sentences.append((sentence, label))
        
        return labeled_sentences
    
    def generate_soap_note(self, transcript: str) -> Dict[str, Any]:
        """Generate SOAP note from labeled sentences"""
        labeled_sentences = self.predict_soap_labels(transcript)
        
        soap_note = {
            "Subjective": {"Chief_Complaint": "", "History_of_Present_Illness": ""},
            "Objective": {"Physical_Exam": "", "Observations": ""},
            "Assessment": {"Diagnosis": "", "Severity": ""},
            "Plan": {"Treatment": "", "Follow_Up": ""}
        }
        
        # Group sentences by label
        for sentence, label in labeled_sentences:
            if label == 'S-COMPLAINT':
                soap_note["Subjective"]["Chief_Complaint"] += sentence + " "
            elif label == 'S-HISTORY':
                soap_note["Subjective"]["History_of_Present_Illness"] += sentence + " "
            elif label == 'O-EXAM':
                soap_note["Objective"]["Physical_Exam"] += sentence + " "
            elif label == 'O-OBS':
                soap_note["Objective"]["Observations"] += sentence + " "
            elif label == 'A-DIAG':
                soap_note["Assessment"]["Diagnosis"] += sentence + " "
            elif label == 'A-SEV':
                soap_note["Assessment"]["Severity"] += sentence + " "
            elif label == 'P-TREAT':
                soap_note["Plan"]["Treatment"] += sentence + " "
            elif label == 'P-FOLLOW':
                soap_note["Plan"]["Follow_Up"] += sentence + " "
        
        # Clean up extra spaces
        for section in soap_note:
            for subsection in soap_note[section]:
                soap_note[section][subsection] = soap_note[section][subsection].strip()
        
        return soap_note


# ============================================================================
# APPROACH 4: Hybrid System (Rule-Based + Transformer)
# ============================================================================

class HybridSOAPGenerator:
    """
    Combines rule-based extraction with transformer refinement
    Best for: Balanced accuracy and interpretability
    """
    
    def __init__(self):
        self.rule_based = RuleBasedSOAPGenerator()
        self.transformer = None  # Initialize when needed
    
    def generate_soap_note(self, transcript: str, use_transformer: bool = False) -> Dict[str, Any]:
        """
        Generate SOAP note using hybrid approach
        1. Extract with rules
        2. Refine with transformer if available
        """
        # First pass: Rule-based extraction
        soap_note = self.rule_based.generate_soap_note(transcript)
        
        # Second pass: Transformer refinement (if enabled)
        if use_transformer and self.transformer:
            for section in soap_note:
                section_text = json.dumps(soap_note[section])
                refined = self.transformer.generate_section(transcript, section)
                try:
                    soap_note[section] = json.loads(refined)
                except:
                    pass  # Keep rule-based result if parsing fails
        
        return soap_note


# ============================================================================
# MAIN USAGE AND DEMONSTRATION
# ============================================================================

def main():
    # Sample transcript
    transcript = """
    Doctor: How are you feeling today?
    Patient: I had a car accident. My neck and back hurt a lot for four weeks.
    Doctor: Did you receive treatment?
    Patient: Yes, I had ten physiotherapy sessions, and now I only have occasional back pain.
    Doctor: Let me examine you. Your neck and back have full range of motion with no tenderness.
    Doctor: Based on the examination, you have a whiplash injury and lower back strain. It's mild and improving.
    Doctor: Continue physiotherapy as needed and use analgesics for pain relief.
    Doctor: Return if pain worsens or persists beyond six months.
    """
    
    print("="*80)
    print("SOAP NOTE GENERATION DEMONSTRATION")
    print("="*80)
    
    # Approach 1: Rule-Based
    print("\n🔹 APPROACH 1: Rule-Based System")
    print("-" * 80)
    rule_generator = RuleBasedSOAPGenerator()
    soap_note_1 = rule_generator.generate_soap_note(transcript)
    print(json.dumps(soap_note_1, indent=2))
    
    # Approach 2: Transformer (demonstration)
    print("\n🔹 APPROACH 2: Transformer-Based (T5)")
    print("-" * 80)
    print("Note: Requires fine-tuning on medical SOAP note dataset")
    print("Example usage:")
    print("""
    transformer_gen = TransformerSOAPGenerator()
    train_data, val_data = transformer_gen.prepare_training_data('soap_dataset.csv')
    transformer_gen.train(train_data, val_data)
    soap_note_2 = transformer_gen.generate_soap_note(transcript)
    """)
    
    # Approach 3: BERT Sequence Labeling (demonstration)
    print("\n🔹 APPROACH 3: BERT Token Classification")
    print("-" * 80)
    print("Note: Requires annotated data with sentence-level SOAP labels")
    print("Example usage:")
    print("""
    bert_labeler = BERTSOAPLabeler()
    train_data, val_data = bert_labeler.prepare_training_data(annotated_data)
    bert_labeler.train(train_data, val_data)
    soap_note_3 = bert_labeler.generate_soap_note(transcript)
    """)
    
    # Approach 4: Hybrid
    print("\n🔹 APPROACH 4: Hybrid System")
    print("-" * 80)
    hybrid_gen = HybridSOAPGenerator()
    soap_note_4 = hybrid_gen.generate_soap_note(transcript, use_transformer=False)
    print(json.dumps(soap_note_4, indent=2))


if __name__ == "__main__":
    main()