# Medical Conversation Analysis

This notebook analyzes [medical conversations](https://www.kaggle.com/datasets/artemminiailo/medicalconversations2disease) to predict diseases and highlight the most relevant words, using spaCy for text preprocessing and a deep learning model for classification.

In [None]:
# Import required libraries
import html
import re

import numpy as np
import pandas as pd
import spacy
import torch
import torch.nn as nn
from IPython.display import HTML, display
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader

# Download the spaCy model `en_core_web_sm`, which a later cell loads with spacy.load.
# The leading `!` makes Jupyter run this line as a shell command rather than as Python.
!python -m spacy download en_core_web_sm

In [None]:
# Load the raw conversations
df = pd.read_csv('data/medical_conversations.csv')

# Later cells read df['conversations'] and df['disease']; checking for them here turns
# an obscure KeyError several cells further down into an immediate, descriptive error.
missing_columns = {'conversations', 'disease'} - set(df.columns)
if missing_columns:
    raise KeyError(f"Dataset is missing required column(s): {sorted(missing_columns)}")

print(f"Dataset shape: {df.shape}")
print("\nSample conversations:")
display(df.head())

In [None]:
# Load spaCy model
# Preprocessing only needs tokenization and the lexical token flags (stop word /
# punctuation), so the dependency parser and named-entity recognizer are disabled to
# make each nlp(text) call cheaper. Tokenization itself is unaffected by this.
nlp = spacy.load('en_core_web_sm', disable=['parser', 'ner'])

def remove_bot_responses(text):
    """Return only the user's side of a conversation.

    The conversation is cut at every '</s>' marker. Pieces that begin with
    "User:" (after trimming surrounding whitespace) are kept with that prefix
    removed; every other piece is discarded. The kept pieces are joined back
    together with ' </s> ' between them.
    """
    speaker_tag = 'User:'
    kept = []
    for segment in text.split('</s>'):
        segment = segment.strip()
        if segment.startswith(speaker_tag):
            # Drop the speaker tag, then any blanks that followed it
            kept.append(segment[len(speaker_tag):].strip())
    return ' </s> '.join(kept)

# Preprocess text
def preprocess_text(text):
    """Turn a raw conversation into a lowercase, space-separated token string.

    Steps: drop every turn that is not a user turn, blank out tag-like markup
    such as '</s>', tokenize with the module-level spaCy pipeline ``nlp``, then
    keep the lowercased text of each token that is not a stop word, punctuation
    or whitespace.

    Args:
        text: Raw conversation string. Non-string values (for example the NaN
            pandas produces for an empty CSV cell) are treated as an empty
            conversation instead of raising.

    Returns:
        The surviving tokens joined by single spaces ('' if nothing survives).
    """
    if not isinstance(text, str):
        return ''

    # First remove bot responses
    text = remove_bot_responses(text)
    # Remove HTML-like tags (including </s> tags)
    text = re.sub(r'<[^>]+>', ' ', text)

    doc = nlp(text)
    # Replacing tags with spaces leaves runs of blanks, which spaCy emits as
    # whitespace-only tokens; skip them so they do not end up in the output.
    tokens = [
        token.text.lower()
        for token in doc
        if not (token.is_stop or token.is_punct or token.is_space)
    ]
    return ' '.join(tokens)

# Apply preprocessing
# Runs preprocess_text on every value of the 'conversations' column and stores the
# result in a new 'processed_text' column of the same DataFrame.
df['processed_text'] = df['conversations'].apply(preprocess_text)

# Encode labels
# Replaces each distinct 'disease' value with an integer id. label_encoder is kept
# around so later cells can turn ids back into names via inverse_transform.
label_encoder = LabelEncoder()
df['disease_encoded'] = label_encoder.fit_transform(df['disease'])

In [None]:
# Create vocabulary and word embeddings
class Vocabulary:
    """Two-way mapping between words and integer ids.

    Id 0 is reserved for the '<PAD>' token; every other word receives the next
    free id in order of first appearance.
    """

    def __init__(self, texts):
        """Create the mapping from an iterable of whitespace-separated texts."""
        self.word2idx = {'<PAD>': 0}  # word -> id
        self.idx2word = {0: '<PAD>'}  # id -> word
        self.build_vocab(texts)

    def build_vocab(self, texts):
        """Add every not-yet-known word found in ``texts`` to the vocabulary.

        Numbering continues from the current vocabulary size rather than
        restarting at 1, so calling this again with more texts never hands out
        an id that is already in use.
        """
        for text in texts:
            for word in text.split():
                if word not in self.word2idx:
                    # Ids are dense (0..len-1), so the current size is the next free id
                    idx = len(self.word2idx)
                    self.word2idx[word] = idx
                    self.idx2word[idx] = word

    def __len__(self):
        """Return the number of entries, including '<PAD>'."""
        return len(self.word2idx)

# Create vocabulary
# NOTE(review): the vocabulary is built from every row of df['processed_text'], i.e.
# before the train/test split performed in a later cell, so words that only occur in
# test conversations are included as well.
vocab = Vocabulary(df['processed_text'])
print(f"Vocabulary size: {len(vocab)}")

In [None]:
# Create dataset class
class MedicalDataset(Dataset):
    """Dataset of (token-id sequence, label) pairs with fixed-length sequences.

    Args:
        texts: Indexable collection of whitespace-separated token strings.
        labels: Indexable collection of integer class ids, aligned with ``texts``.
        vocab: Object exposing a ``word2idx`` dict that maps words to ids.
        max_length: Every sequence is truncated or zero-padded to this length.
    """

    def __init__(self, texts, labels, vocab, max_length=100):
        self.texts = texts
        self.labels = labels
        self.vocab = vocab
        self.max_length = max_length

    def __len__(self):
        """Return the number of samples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(indices, label)`` for sample ``idx`` as int64 tensors.

        Words missing from the vocabulary map to 0, the same id that is used
        for padding.
        """
        words = self.texts[idx].split()[:self.max_length]
        # Convert words to indices and pad
        indices = [self.vocab.word2idx.get(word, 0) for word in words]
        indices += [0] * (self.max_length - len(indices))
        # Force int64: nn.Embedding needs integer ids and nn.CrossEntropyLoss needs
        # int64 class targets, which inferred dtypes do not guarantee (e.g. int32
        # label arrays, or an empty index list that would otherwise become float32).
        return (
            torch.tensor(indices, dtype=torch.long),
            torch.tensor(self.labels[idx], dtype=torch.long),
        )

# Split data
# Holds out 20% of the rows as a test set; random_state is fixed so the same rows are
# selected on every run.
X_train, X_test, y_train, y_test = train_test_split(
    df['processed_text'], df['disease_encoded'], test_size=0.2, random_state=42
)

# Create datasets
# .values passes plain arrays, so MedicalDataset indexes samples by position rather
# than by the original DataFrame index labels carried over by the split.
train_dataset = MedicalDataset(X_train.values, y_train.values, vocab)
test_dataset = MedicalDataset(X_test.values, y_test.values, vocab)

In [None]:
# Define the model
class TextClassifier(nn.Module):
    """LSTM text classifier with attention pooling over the sequence.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each word embedding.
        hidden_dim: Size of the LSTM hidden state.
        output_dim: Number of classes (size of the returned logit vector).
        padding_idx: Token id used for padding. Its embedding is pinned to zero
            and its positions are excluded from attention. Pass ``None`` to
            treat every position as a real token.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, padding_idx=0):
        super().__init__()
        self.padding_idx = padding_idx
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.attention = nn.Linear(hidden_dim, 1)
        # Attention weights of the most recent forward pass (None until forward runs)
        self.last_attention_weights = None

    def forward(self, text):
        """Return class logits for a batch of token-id sequences.

        Args:
            text: Integer tensor of token ids, shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, output_dim) with unnormalized class scores.
        """
        embedded = self.embedding(text)
        lstm_out, _ = self.lstm(embedded)

        # Calculate attention weights
        scores = self.attention(lstm_out)  # (batch, seq_len, 1)
        if self.padding_idx is not None:
            # Keep padding from absorbing attention mass. The most negative finite
            # value is used instead of -inf so an all-padding row still yields a
            # valid (uniform) distribution rather than NaN.
            pad_mask = (text == self.padding_idx).unsqueeze(-1)
            scores = scores.masked_fill(pad_mask, torch.finfo(scores.dtype).min)
        attention_weights = torch.softmax(scores, dim=1)
        # Store for later visualization; detached so the stored copy neither keeps
        # the autograd graph alive nor requires grad when converted to NumPy.
        self.last_attention_weights = attention_weights.detach()

        # Apply attention
        attended = torch.sum(attention_weights * lstm_out, dim=1)
        return self.fc(attended)

# Initialize model
EMBEDDING_DIM = 100  # size of each word embedding vector
HIDDEN_DIM = 64  # size of the LSTM hidden state
OUTPUT_DIM = len(label_encoder.classes_)  # one output score per distinct disease label

model = TextClassifier(len(vocab), EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM)
# Cross-entropy over the raw class scores returned by the model
criterion = nn.CrossEntropyLoss()
# Adam with its default hyperparameters (no learning rate is passed)
optimizer = torch.optim.Adam(model.parameters())

In [None]:
# Training loop
# Training batches are reshuffled every epoch; the test loader keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

num_epochs = 10
for epoch in range(num_epochs):
    # --- one optimisation pass over the training set ---
    model.train()
    total_loss = 0  # sum of per-batch losses, averaged when printed below
    for batch_texts, batch_labels in train_loader:
        # Clear gradients left over from the previous step before backpropagating
        optimizer.zero_grad()
        predictions = model(batch_texts)
        loss = criterion(predictions, batch_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Validation
    # NOTE(review): accuracy is measured on test_loader (the held-out test split);
    # there is no separate validation split in this notebook.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_texts, batch_labels in test_loader:
            predictions = model(batch_texts)
            # Index of the highest score along dim 1 is the predicted class id
            _, predicted = torch.max(predictions, 1)
            total += batch_labels.size(0)
            correct += (predicted == batch_labels).sum().item()

    print(f'Epoch {epoch+1}/{num_epochs}')
    # Mean of the per-batch training losses for this epoch
    print(f'Average Loss: {total_loss/len(train_loader):.4f}')
    print(f'Accuracy: {100 * correct / total:.2f}%\n')

In [None]:
# Function to visualize word relevance
def visualize_word_relevance(text, attention_weights, vocab):
    """Render ``text`` as HTML with each word shaded by its attention weight.

    Args:
        text: Whitespace-separated words, in the order they were fed to the model.
        attention_weights: Tensor holding one weight per sequence position. It is
            flattened and only the first ``len(words)`` entries are used, so
            trailing padding positions are ignored.
        vocab: Unused; kept so existing calls keep working.

    Returns:
        A string of space-separated ``<span>`` elements whose background runs from
        white (lowest weight) to red (highest weight). Returns '' when there are
        no words or no weights.
    """
    words = text.split()
    # detach()/cpu() let this accept tensors that track gradients or live on a GPU;
    # reshape(-1) (rather than squeeze()) keeps the array 1-D even for one position.
    weights = attention_weights.detach().cpu().reshape(-1).numpy()[:len(words)]
    if not words or weights.size == 0:
        return ''

    # Normalize weights to [0, 1]
    lowest = weights.min()
    spread = weights.max() - lowest
    if spread > 0:
        weights = (weights - lowest) / spread
    else:
        # All weights are equal (or there is a single word): dividing by zero would
        # give NaN and int() below would raise, so show every word unhighlighted.
        weights = weights - lowest

    # Create HTML with colored words
    spans = []
    for word, weight in zip(words, weights):
        color_intensity = int(255 * (1 - weight))
        # Escape the word so characters such as '<' or '&' are shown literally
        # instead of being interpreted as markup.
        spans.append(
            f'<span style="background-color: rgba(255, {color_intensity}, {color_intensity}, 0.5)">'
            f'{html.escape(word)}</span>'
        )

    return ' '.join(spans)

# Display examples with highlighted words
model.eval()
print("Examples with word relevance highlighting:")

# Cap the sample size at the size of the test set: np.random.choice(..., replace=False)
# raises ValueError when asked for more items than are available.
num_examples = min(10, len(test_dataset))
sample_indices = np.random.choice(len(test_dataset), num_examples, replace=False)
for idx in sample_indices:
    text_indices, label = test_dataset[idx]
    # test_dataset was built from X_test.values, so position idx refers to the same row
    text = X_test.iloc[idx]

    # Get model prediction and attention weights
    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension the model expects
        prediction = model(text_indices.unsqueeze(0))
        attention_weights = model.last_attention_weights

    # Map the integer class ids back to disease names
    pred_label = label_encoder.inverse_transform([prediction.argmax().item()])[0]
    true_label = label_encoder.inverse_transform([label.item()])[0]

    print(f"\nTrue label: {true_label}")
    print(f"Predicted label: {pred_label}")
    display(HTML(visualize_word_relevance(text, attention_weights, vocab)))


# Challenge
- Can you explain why the model's attention concentrates on those particular words?
- Can you write some code to find the 5 most relevant words across all results?
