In [9]:
import os
import random
import numpy as np
import pandas as pd
from collections import Counter
import spacy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from datasets import load_dataset

# -----------------------------------------------------------
# 1. Set random seeds for reproducibility
# -----------------------------------------------------------
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# -----------------------------------------------------------
# 2. Load the spaCy English model (disable parser & NER for speed)
# -----------------------------------------------------------
nlp = spacy.load("en_core_web_sm", disable=["parser", "ner"])

# -----------------------------------------------------------
# 3. Load the Hugging Face Dataset and Label Mapping
# -----------------------------------------------------------
ds_default = load_dataset("TimSchopf/medical_abstracts", "default")
ds_labels = load_dataset("TimSchopf/medical_abstracts", "labels")

df = pd.DataFrame(ds_default["train"])
print(f"Default split loaded successfully with {len(df)} samples.")

labels_df = pd.DataFrame(ds_labels["train"])
label_map = dict(zip(labels_df["condition_label"], labels_df["condition_name"]))
print("Original Label Mapping:", label_map)

# Adjust labels from 1–5 to 0–4 for compatibility with CrossEntropyLoss
label_map = {k - 1: v for k, v in label_map.items()}
print("Adjusted Label Mapping:", label_map)

# -----------------------------------------------------------
# 4. Preprocess and Clean Data
# -----------------------------------------------------------
# Use the correct column names: "medical_abstract" for text and "condition_label" for labels.
df = df.dropna(subset=["medical_abstract", "condition_label"]).reset_index(drop=True)

min_label_freq = 5
label_counts = df["condition_label"].value_counts()
valid_labels = label_counts[label_counts >= min_label_freq].index
df = df[df["condition_label"].isin(valid_labels)].reset_index(drop=True)
print(f"After filtering, {len(df)} samples remain.")

texts = df["medical_abstract"].tolist()
labels = df["condition_label"].tolist()  # originally 1–5
labels = [l - 1 for l in labels]           # convert to 0-based indexing

print("Unique labels:", np.unique(labels))

# -----------------------------------------------------------
# 5. Tokenization and Vocabulary Construction
# -----------------------------------------------------------
def batch_advanced_tokenize(texts, batch_size=1000):
    tokenized_texts = []
    for doc in nlp.pipe(texts, batch_size=batch_size):
        tokens = [token.text for token in doc if not token.is_punct and not token.is_space]
        tokenized_texts.append(tokens)
    return tokenized_texts

tokenized_texts = batch_advanced_tokenize(texts, batch_size=1000)
print("Tokenization complete.")

all_tokens = [token for tokens in tokenized_texts for token in tokens]
vocab_counter = Counter(all_tokens)
min_word_freq = 2
vocab = {token for token, count in vocab_counter.items() if count >= min_word_freq}

# Reserve indices: 0 for <PAD>, 1 for <UNK>
word_to_index = {"<PAD>": 0, "<UNK>": 1}
for word in sorted(vocab):
    word_to_index[word] = len(word_to_index)
vocab_size = len(word_to_index)
print(f"Vocabulary size: {vocab_size}")

def text_to_sequence(tokens):
    return [word_to_index.get(token, word_to_index["<UNK>"]) for token in tokens]

sequences = [text_to_sequence(tokens) for tokens in tokenized_texts]

# -----------------------------------------------------------
# 6. Pad Sequences to a Fixed Maximum Length
# -----------------------------------------------------------
max_len = 256  # Adjust as needed

def pad_sequence_fn(seq, max_len):
    if len(seq) < max_len:
        return seq + [0] * (max_len - len(seq))
    else:
        return seq[:max_len]

padded_sequences = [pad_sequence_fn(seq, max_len) for seq in sequences]
X = np.array(padded_sequences)
y = np.array(labels)

# -----------------------------------------------------------
# 7. Split the Data into Training and Validation Sets
# -----------------------------------------------------------
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
print(f"Training samples: {len(X_train)}, Validation samples: {len(X_val)}")

# -----------------------------------------------------------
# 8. Create PyTorch Dataset and DataLoader
# -----------------------------------------------------------
class MedicalAbstractDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return torch.tensor(self.X[idx], dtype=torch.long), torch.tensor(self.y[idx], dtype=torch.long)

batch_size = 64
train_dataset = MedicalAbstractDataset(X_train, y_train)
val_dataset = MedicalAbstractDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        num_workers=4, pin_memory=True)

# -----------------------------------------------------------
# 9. Load Pre-trained GloVe Embeddings and Build the Embedding Matrix
# -----------------------------------------------------------
def load_glove_embeddings(filepath, embedding_dim):
    embeddings_index = {}
    with open(filepath, encoding="utf8") as f:
        for line in f:
            values = line.split()
            word = values[0]
            vector = np.asarray(values[1:], dtype="float32")
            if vector.shape[0] == embedding_dim:
                embeddings_index[word] = vector
    return embeddings_index

embedding_dim = 100
glove_path = "glove.6B.100d.txt"
if not os.path.exists(glove_path):
    raise FileNotFoundError(f"{glove_path} not found. Please download it and place it in your working directory.")

glove_embeddings = load_glove_embeddings(glove_path, embedding_dim)
print(f"Loaded {len(glove_embeddings)} word vectors from GloVe.")

embedding_matrix = np.zeros((vocab_size, embedding_dim), dtype=np.float32)
for word, idx in word_to_index.items():
    if word in glove_embeddings:
        embedding_matrix[idx] = glove_embeddings[word]
    else:
        embedding_matrix[idx] = np.random.normal(scale=0.6, size=(embedding_dim,))

# -----------------------------------------------------------
# 10. Define the LSTM-based Model for Text Classification
# -----------------------------------------------------------
# Increase dropout from 0.3 to 0.4 for added regularization.
class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, output_dim,
                 dropout=0.4, pretrained_embeddings=None, freeze_embeddings=False):
        super(LSTMClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        if pretrained_embeddings is not None:
            self.embedding.weight.data.copy_(torch.tensor(pretrained_embeddings))
            self.embedding.weight.requires_grad = not freeze_embeddings
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers,
                            batch_first=True, dropout=dropout, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        embedded = self.embedding(x)
        # Obtain the last hidden states from both directions.
        _, (h_n, _) = self.lstm(embedded)
        forward_h = h_n[-2, :, :]
        backward_h = h_n[-1, :, :]
        hidden = torch.cat((forward_h, backward_h), dim=1)
        hidden = self.dropout(hidden)
        logits = self.fc(hidden)
        return logits

hidden_dim = 128
num_layers = 2
output_dim = len(label_map)  # 5 classes (0 to 4)
dropout = 0.4  # Increased dropout

model = LSTMClassifier(vocab_size, embedding_dim, hidden_dim, num_layers,
                       output_dim, dropout, pretrained_embeddings=embedding_matrix,
                       freeze_embeddings=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(model)

# -----------------------------------------------------------
# 11. Define the Loss, Optimizer, and Training/Evaluation Functions
# -----------------------------------------------------------
criterion = nn.CrossEntropyLoss()
# Add weight decay for regularization.
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

# Use ReduceLROnPlateau to adjust learning rate when validation loss plateaus.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)

def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    epoch_loss, epoch_correct = 0, 0
    for inputs, labels in loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * inputs.size(0)
        epoch_correct += (outputs.argmax(dim=1) == labels).sum().item()
    return epoch_loss / len(loader.dataset), epoch_correct / len(loader.dataset)

def evaluate_epoch(model, loader, criterion, device):
    model.eval()
    epoch_loss, epoch_correct = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            epoch_loss += loss.item() * inputs.size(0)
            epoch_correct += (outputs.argmax(dim=1) == labels).sum().item()
    return epoch_loss / len(loader.dataset), epoch_correct / len(loader.dataset)

num_epochs = 10
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = evaluate_epoch(model, val_loader, criterion, device)
    scheduler.step(val_loss)  # Adjust learning rate based on validation loss
    print(f"Epoch {epoch+1}/{num_epochs}: Train Loss = {train_loss:.4f}, Train Acc = {train_acc*100:.2f}% | " +
          f"Val Loss = {val_loss:.4f}, Val Acc = {val_acc*100:.2f}%")

# -----------------------------------------------------------
# 12. Inference Function: Predict the Label for a New Abstract
# -----------------------------------------------------------
def predict_abstract(model, text, word_to_index, max_len, device, label_map):
    tokens = [token.text for token in nlp(text) if not token.is_punct and not token.is_space]
    seq = [word_to_index.get(token, word_to_index["<UNK>"]) for token in tokens]
    seq = pad_sequence_fn(seq, max_len)
    input_tensor = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        logits = model(input_tensor)
    pred_class = logits.argmax(dim=1).item()
    pred_label = label_map[pred_class]
    return pred_label

# -----------------------------------------------------------
# 13. Example Inference (Single Test Case)
# -----------------------------------------------------------
sample_abstract = (
    "Recent advances in medical research show that artificial intelligence can greatly "
    "improve diagnostic accuracy for various diseases. Further clinical studies are required "
    "to validate these preliminary findings."
)
predicted_label = predict_abstract(model, sample_abstract, word_to_index, max_len, device, label_map)
print("\nFor the sample abstract:\n\"{}\"\nPredicted label is: {}".format(sample_abstract, predicted_label))

# -----------------------------------------------------------
# 14. Save the Trained Model
# -----------------------------------------------------------
model_save_path = "lstm_model.pth"
torch.save(model.state_dict(), model_save_path)
print(f"\nTrained model saved to {model_save_path}")


Default split loaded successfully with 11550 samples.
Original Label Mapping: {1: 'neoplasms', 2: 'digestive system diseases', 3: 'nervous system diseases', 4: 'cardiovascular diseases', 5: 'general pathological conditions'}
Adjusted Label Mapping: {0: 'neoplasms', 1: 'digestive system diseases', 2: 'nervous system diseases', 3: 'cardiovascular diseases', 4: 'general pathological conditions'}
After filtering, 11550 samples remain.
Unique labels: [0 1 2 3 4]
Tokenization complete.
Vocabulary size: 33400
Training samples: 9240, Validation samples: 2310
Loaded 400000 word vectors from GloVe.
LSTMClassifier(
  (embedding): Embedding(33400, 100, padding_idx=0)
  (lstm): LSTM(100, 128, num_layers=2, batch_first=True, dropout=0.4, bidirectional=True)
  (fc): Linear(in_features=256, out_features=5, bias=True)
  (dropout): Dropout(p=0.4, inplace=False)
)
Epoch 1/10: Train Loss = 1.3824, Train Acc = 42.41% | Val Loss = 1.2126, Val Acc = 51.86%
Epoch 2/10: Train Loss = 1.1692, Train Acc = 53.38% 

In [10]:
# -----------------------------------------------------------
# 15. Additional Test Cases for Inference
# -----------------------------------------------------------
test_cases = [
    "The patient was diagnosed with a tumor in the lung.",
    "Severe abdominal pain and persistent nausea.",
    "The patient exhibits tremors and loss of motor control.",
    "High blood pressure and chest pain were observed.",
    "Generalized weakness and fever."
]

print("\nAdditional Test Cases:")
for i, test_abstract in enumerate(test_cases, start=1):
    pred_label = predict_abstract(model, test_abstract, word_to_index, max_len, device, label_map)
    print(f"Test Case {i}:")
    print(f"Abstract: {test_abstract}")
    print(f"Predicted label: {pred_label}\n")


Additional Test Cases:
Test Case 1:
Abstract: The patient was diagnosed with a tumor in the lung.
Predicted label: neoplasms

Test Case 2:
Abstract: Severe abdominal pain and persistent nausea.
Predicted label: digestive system diseases

Test Case 3:
Abstract: The patient exhibits tremors and loss of motor control.
Predicted label: general pathological conditions

Test Case 4:
Abstract: High blood pressure and chest pain were observed.
Predicted label: cardiovascular diseases

Test Case 5:
Abstract: Generalized weakness and fever.
Predicted label: general pathological conditions

