In [None]:
%run ../setup.py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.backends.cudnn as cudnn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np
import pandas as pd
from collections import Counter
import random
import spacy
import os

# -------------------------------
# Set random seeds for reproducibility
# -------------------------------
# The same fixed seed is applied to Python's `random`, NumPy and PyTorch.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# -------------------------------
# Enable cuDNN benchmark for potential speed-up (only effective if using GPU)
# -------------------------------
# NOTE(review): cuDNN benchmark mode may pick different kernels between runs,
# which can work against the reproducibility goal above -- confirm the trade-off.
if torch.cuda.is_available():
    cudnn.benchmark = True

# -------------------------------
# Load spaCy English model (disable parser and NER for speed)
# -------------------------------
# `nlp` is the module-level pipeline used by the tokenization code in this cell.
nlp = spacy.load("en_core_web_sm", disable=["parser", "ner"])

# -------------------------------
# Step 1: Load and Inspect the MIMIC Discharge Dataset
# -------------------------------
discharge_csv = "discharge.csv"
try:
    df_mimic = pd.read_csv(discharge_csv)
except FileNotFoundError as err:
    # Raise instead of calling exit(): exit() reports success to the caller
    # and, inside a notebook, shuts the kernel down instead of stopping the cell.
    raise FileNotFoundError(
        f"Error: '{discharge_csv}' not found. Please ensure the file is in the working directory."
    ) from err
print("Discharge dataset loaded successfully.")

# Print out columns for inspection
print("Columns in discharge dataset:", df_mimic.columns)

# -------------------------------
# Decrease the number of samples for faster execution.
# Change max_samples to the desired number, e.g., 10000.
# -------------------------------
max_samples = 10000
if len(df_mimic) > max_samples:
    # Fixed random_state keeps the subsample identical between runs.
    df_mimic = df_mimic.sample(n=max_samples, random_state=42).reset_index(drop=True)
    print(f"Dataset reduced to {max_samples} samples.")

# -------------------------------
# Step 2: Preprocess the Data
# -------------------------------
# We use the "text" column as input and "note_type" as label.
# Rows missing either field cannot be used for supervised training.
df_mimic = df_mimic.dropna(subset=["text", "note_type"])
df_mimic['note_type'] = df_mimic['note_type'].astype(str)

texts = df_mimic["text"].tolist()
labels = df_mimic["note_type"].tolist()

# -------------------------------
# Step 3: Encode Labels
# -------------------------------
# LabelEncoder maps each distinct note_type string to an integer id.
label_encoder = LabelEncoder()
labels_encoded = label_encoder.fit_transform(labels)
num_classes = len(label_encoder.classes_)
print(f"Number of target classes (note types): {num_classes}")
if num_classes < 2:
    # With a single class every prediction is trivially correct and the
    # cross-entropy loss is always zero, so the reported metrics say nothing.
    print("Warning: the target column has fewer than 2 distinct classes; "
          "training will be degenerate (loss 0, accuracy 100%).")

# -------------------------------
# Step 4: Tokenization and Vocabulary Construction
# -------------------------------
def batch_tokenize(texts, batch_size=1000):
    """Tokenize every text with the module-level spaCy pipeline ``nlp``.

    Args:
        texts: iterable of strings to tokenize.
        batch_size: number of texts handed to ``nlp.pipe`` per batch.

    Returns:
        A list with one list of token strings per input text; punctuation
        and whitespace tokens are dropped.
    """
    def _is_word(tok):
        # Keep everything that is neither punctuation nor whitespace.
        return not (tok.is_punct or tok.is_space)

    return [
        [tok.text for tok in doc if _is_word(tok)]
        for doc in nlp.pipe(texts, batch_size=batch_size)
    ]

tokenized_texts = batch_tokenize(texts, batch_size=1000)
print("Tokenization complete on discharge texts.")

# Flatten all documents into one token list and count occurrences.
all_tokens = []
for tokens in tokenized_texts:
    all_tokens.extend(tokens)
vocab_counter = Counter(all_tokens)

# Keep only tokens seen at least `min_word_freq` times.
min_word_freq = 2
vocab = set()
for token, count in vocab_counter.items():
    if count >= min_word_freq:
        vocab.add(token)

# Index 0 is reserved for padding and index 1 for unknown tokens; the
# remaining words are numbered in sorted order so the mapping is deterministic.
word_to_index = {"<PAD>": 0, "<UNK>": 1}
for word in sorted(vocab):
    word_to_index[word] = len(word_to_index)
vocab_size = len(word_to_index)
print(f"Vocabulary size for MIMIC discharge: {vocab_size}")

# Convert tokenized texts to sequences of indices.
def text_to_sequence(tokens):
    """Return the vocabulary index of each token, using the ``<UNK>`` index for unseen tokens.

    Reads the module-level ``word_to_index`` mapping.
    """
    indices = []
    for tok in tokens:
        if tok in word_to_index:
            indices.append(word_to_index[tok])
        else:
            indices.append(word_to_index["<UNK>"])
    return indices

# One list of vocabulary indices per document, in the same order as `tokenized_texts`.
sequences = [text_to_sequence(tokens) for tokens in tokenized_texts]

# -------------------------------
# Step 5: Pad Sequences
# -------------------------------
# Every document is later cut or zero-padded to exactly this many token ids.
max_len = 256  # Fixed maximum sequence length.

def pad_sequence_fn(seq, max_len):
    """Return a new list of exactly ``max_len`` items built from ``seq``.

    Shorter sequences are right-padded with 0; longer ones are truncated
    to their first ``max_len`` items.
    """
    missing = max_len - len(seq)
    if missing > 0:
        return seq + [0] * missing
    return seq[:max_len]

# Every padded sequence has length max_len, so X is a 2-D integer array
# (one row per document) and y holds the matching encoded labels.
padded_sequences = [pad_sequence_fn(seq, max_len) for seq in sequences]
X = np.array(padded_sequences)
y = np.array(labels_encoded)

# -------------------------------
# Step 6: Train/Validation Data Split
# -------------------------------
# 80/20 split; `stratify=y` keeps the label proportions equal in both parts
# and the fixed random_state makes the split repeatable.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
print(f"Training samples: {len(X_train)}, Validation samples: {len(X_val)}")

# -------------------------------
# Step 7: Create PyTorch Dataset and DataLoader for Discharge Data
# -------------------------------
class MIMICDataset(Dataset):
    """Dataset pairing index-able features ``X`` with labels ``y``.

    Each item is returned as a ``(features, label)`` pair of ``torch.long``
    tensors created on access.
    """

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        features = torch.tensor(self.X[idx], dtype=torch.long)
        target = torch.tensor(self.y[idx], dtype=torch.long)
        return features, target

# Increase batch size and number of workers for faster loading.
batch_size = 128  # Increase batch size if memory allows.
num_workers = 8   # Adjust based on available CPU cores.

# Only the training loader shuffles; validation keeps the dataset order.
# NOTE(review): pin_memory=True is only useful when batches are copied to a
# GPU, and num_workers > 0 may be problematic on platforms that spawn worker
# processes when the dataset class lives in a notebook -- confirm on target.
train_dataset = MIMICDataset(X_train, y_train)
val_dataset = MIMICDataset(X_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True)

# -------------------------------
# Step 8: Load Pre-trained GloVe Embeddings and Build Embedding Matrix
# -------------------------------
def load_glove_embeddings(filepath, embedding_dim):
    """Read a GloVe-style text file into a ``{word: vector}`` dictionary.

    Each useful line is ``word v1 v2 ... vN`` separated by whitespace.

    Args:
        filepath: path of the UTF-8 encoded embedding file.
        embedding_dim: expected number of vector components per word.

    Returns:
        dict mapping each word to a ``float32`` NumPy vector of length
        ``embedding_dim``. Lines that do not carry exactly one word plus
        ``embedding_dim`` values (blank lines, header lines, other
        dimensionalities) are skipped.

    Raises:
        ValueError: if a line of the right length contains a non-numeric value.
    """
    embeddings_index = {}
    expected_fields = embedding_dim + 1
    with open(filepath, encoding="utf8") as f:
        for line in f:
            values = line.split()
            # Checking the field count first skips blank lines (which used to
            # raise IndexError) and avoids parsing vectors that would be
            # discarded anyway.
            if len(values) != expected_fields:
                continue
            embeddings_index[values[0]] = np.asarray(values[1:], dtype='float32')
    return embeddings_index

embedding_dim = 100
glove_path = "glove.6B.100d.txt"
if not os.path.exists(glove_path):
    raise FileNotFoundError(f"{glove_path} not found. Please download it and place it in the working directory.")

glove_embeddings = load_glove_embeddings(glove_path, embedding_dim)
print(f"Loaded {len(glove_embeddings)} word vectors from GloVe.")

# Row i of the matrix is the initial embedding of the word with index i.
embedding_matrix = np.zeros((vocab_size, embedding_dim), dtype=np.float32)
for word, idx in word_to_index.items():
    if word == "<PAD>":
        # Padding must stay an all-zero vector: the model copies this matrix
        # over the embedding weights, so a random row here would feed noise
        # into the LSTM for every padded position.
        continue
    # The GloVe 6B vocabulary is lowercase, so fall back to the lowercased
    # form before giving up on a cased token such as "Patient".
    vector = glove_embeddings.get(word)
    if vector is None:
        vector = glove_embeddings.get(word.lower())
    if vector is not None:
        embedding_matrix[idx] = vector
    else:
        # No pretrained vector available: start from small random values.
        embedding_matrix[idx] = np.random.normal(scale=0.6, size=(embedding_dim,))

# -------------------------------
# Step 9: Define the LSTM-Only Model (No Attention)
# -------------------------------
class LSTMClassifier(nn.Module):
    """Bidirectional, multi-layer LSTM classifier over token-id sequences.

    Token ids are embedded (index 0 is the padding id), passed through a
    batch-first bidirectional LSTM, and the last two entries of the final
    hidden state are concatenated, regularised with dropout and projected
    to ``output_dim`` logits.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each embedding vector.
        hidden_dim: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        output_dim: number of logits produced.
        dropout: dropout probability used inside the LSTM and before ``fc``.
        pretrained_embeddings: optional array-like copied into the embedding weights.
        freeze_embeddings: when pretrained embeddings are given, whether to
            exclude them from gradient updates.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, output_dim, dropout=0.3,
                 pretrained_embeddings=None, freeze_embeddings=False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        if pretrained_embeddings is not None:
            initial_weights = torch.tensor(pretrained_embeddings)
            self.embedding.weight.data.copy_(initial_weights)
            trainable = not freeze_embeddings
            self.embedding.weight.requires_grad = trainable
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True,
        )
        # Two directions are concatenated, hence twice the hidden size.
        self.fc = nn.Linear(2 * hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the logits for a batch of token-id sequences ``x``."""
        _, (final_hidden, _) = self.lstm(self.embedding(x))
        # For bidirectional LSTM, the last two entries of h_n are the final
        # forward and backward states; join them along the feature axis.
        both_directions = torch.cat((final_hidden[-2], final_hidden[-1]), dim=1)
        return self.fc(self.dropout(both_directions))

# Model hyperparameters.
hidden_dim = 128
num_layers = 2
# Use the number of target classes for fine-tuning (derived from note_type)
new_output_dim = num_classes  
dropout = 0.3

# -------------------------------
# Step 10: Initialize Model and Load Pretrained Weights from Medal
# -------------------------------
# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Instantiate the fine-tuning model.
# The embedding layer starts from `embedding_matrix` and stays trainable
# (freeze_embeddings=False).
model = LSTMClassifier(vocab_size, embedding_dim, hidden_dim, num_layers, new_output_dim, dropout,
                       pretrained_embeddings=embedding_matrix, freeze_embeddings=False)
model.to(device)
print(model)

# Load pretrained Medal LSTM weights (if available).
pretrained_path = "trained_models/models/medal_trained_model_LSTM.pth"
if os.path.exists(pretrained_path):
    print("Loading pretrained weights from Medal model...")
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    pretrained_state = torch.load(pretrained_path, map_location=device)
    model_dict = model.state_dict()
    # Load only matching layers (ignores final fc layer if dimensions differ).
    pretrained_state = {k: v for k, v in pretrained_state.items() if k in model_dict and v.size() == model_dict[k].size()}
    not_transferred = sorted(set(model_dict) - set(pretrained_state))
    model_dict.update(pretrained_state)
    model.load_state_dict(model_dict)
    if pretrained_state:
        print("Pretrained weights loaded successfully (partial transfer if dimensions differ).")
        # Report what was actually transferred so a silent near-empty
        # transfer (e.g. different vocabulary or hidden size) is visible.
        print(f"Transferred {len(pretrained_state)} of {len(model_dict)} tensors; "
              f"left at their initial values: {not_transferred}")
    else:
        print("Warning: no tensor in the checkpoint matched the model; fine-tuning will start from scratch.")
else:
    print("Pretrained model not found; fine-tuning will start from scratch.")

# -------------------------------
# Step 11: Define Loss, Optimizer, and Training Loop for Fine-Tuning
# -------------------------------
# Cross-entropy over the model's raw logits; Adam updates every model
# parameter with a learning rate of 1e-3.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return its metrics.

    Args:
        model: network to train; switched to training mode.
        loader: iterable yielding ``(inputs, labels)`` tensor batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepping the model parameters.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples actually
        processed, or ``(0.0, 0.0)`` when the loader yields no batch.
    """
    model.train()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    for inputs, labels in loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_count = inputs.size(0)
        # The batch loss is weighted by batch size so that a smaller final
        # batch does not skew the average.
        total_loss += loss.item() * batch_count
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
        total_seen += batch_count
    if total_seen == 0:
        return 0.0, 0.0
    # Divide by the samples seen, not len(loader.dataset): the two differ
    # when the loader drops the last batch or uses a sampler.
    return total_loss / total_seen, total_correct / total_seen

def evaluate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: network to evaluate; switched to evaluation mode.
        loader: iterable yielding ``(inputs, labels)`` tensor batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples actually
        processed, or ``(0.0, 0.0)`` when the loader yields no batch.
    """
    model.eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            batch_count = inputs.size(0)
            total_loss += loss.item() * batch_count
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_seen += batch_count
    if total_seen == 0:
        return 0.0, 0.0
    # Divide by the samples seen, not len(loader.dataset): the two differ
    # when the loader drops the last batch or uses a sampler.
    return total_loss / total_seen, total_correct / total_seen

# Train for a fixed number of epochs; after each one, evaluate on the
# validation loader and print both sets of metrics (accuracy as a percentage).
num_epochs = 10
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = evaluate_epoch(model, val_loader, criterion, device)
    print(f"Epoch {epoch+1}/{num_epochs}: Train Loss={train_loss:.4f}, Train Acc={train_acc*100:.2f}% | Val Loss={val_loss:.4f}, Val Acc={val_acc*100:.2f}%")

# -------------------------------
# Step 12: Define Inference Function for the Fine-Tuned MIMIC Model
# -------------------------------
def predict_mimic(model, text, word_to_index, max_len, device, label_encoder):
    """Predict the label of a single raw text.

    The text is tokenized with the module-level spaCy pipeline ``nlp``
    (punctuation and whitespace tokens dropped), mapped to vocabulary
    indices, padded/truncated to ``max_len`` and fed to ``model`` as a
    batch of one.

    Args:
        model: trained classifier returning logits.
        text: raw input string.
        word_to_index: token -> index mapping containing an ``"<UNK>"`` key.
        max_len: sequence length expected by the model.
        device: device on which inference runs.
        label_encoder: fitted encoder used to map the class id back to its label.

    Returns:
        The decoded label of the highest-scoring class.
    """
    # Tokenize the entire text using spaCy.
    indices = []
    for tok in nlp(text):
        if tok.is_punct or tok.is_space:
            continue
        indices.append(word_to_index.get(tok.text, word_to_index["<UNK>"]))
    padded = pad_sequence_fn(indices, max_len)
    batch = torch.tensor(padded, dtype=torch.long).unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        scores = model(batch)
    class_id = scores.argmax(dim=1).item()
    return label_encoder.inverse_transform([class_id])[0]

# Example Inference on a sample discharge text:
# (adjacent string literals are concatenated into one text)
sample_discharge_text = (
    "Patient admitted with severe chest pain. "
    "Discharge summary indicates a myocardial infarction and appropriate treatment was provided. "
    "Follow-up is recommended."
)
predicted_label = predict_mimic(model, sample_discharge_text, word_to_index, max_len, device, label_encoder)
print(f"\nFor the sample discharge text, predicted note type is: {predicted_label}")

# -------------------------------
# Step 13: Save the Fine-Tuned Model
# -------------------------------
# Only the state dict (parameter tensors) is written, to the working
# directory; the vocabulary and label encoder are not part of this file.
fine_tuned_model_path = "fine_tuned_model_MIMIC.pth"
torch.save(model.state_dict(), fine_tuned_model_path)
print(f"Fine-tuned model saved to {fine_tuned_model_path}")

Discharge dataset loaded successfully.
Columns in discharge dataset: Index(['note_id', 'subject_id', 'hadm_id', 'note_type', 'note_seq',
       'charttime', 'storetime', 'text'],
      dtype='object')
Dataset reduced to 10000 samples.
Number of target classes (note types): 1
Tokenization complete on discharge texts.
Vocabulary size for MIMIC discharge: 82901
Training samples: 8000, Validation samples: 2000
Loaded 400000 word vectors from GloVe.
Using device: cuda
LSTMClassifier(
  (embedding): Embedding(82901, 100, padding_idx=0)
  (lstm): LSTM(100, 128, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  (fc): Linear(in_features=256, out_features=1, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)
Loading pretrained weights from Medal model...
Pretrained weights loaded successfully (partial transfer if dimensions differ).
Epoch 1/10: Train Loss=0.0000, Train Acc=100.00% | Val Loss=0.0000, Val Acc=100.00%
Epoch 2/10: Train Loss=0.0000, Train Acc=100.00% | Val Loss

In [2]:
# Show the column names of `df_mimic` (expected to exist from an earlier cell).
print(df_mimic.columns)

Index(['note_id', 'subject_id', 'hadm_id', 'note_type', 'note_seq',
       'charttime', 'storetime', 'text'],
      dtype='object')


In [8]:
import pandas as pd

# Define file paths
diagnosis_csv = "diagnosis.csv"
edstays_csv = "edstays.csv"
discharge_csv = "discharge.csv"

# Each file is loaded independently: a failure is reported and the cell
# continues with the next file, leaving that DataFrame name unassigned.

# Load Diagnosis.csv and print its columns
try:
    df_diagnosis = pd.read_csv(diagnosis_csv)
    print("Diagnosis.csv loaded successfully.")
    print("Columns in Diagnosis.csv:")
    print(df_diagnosis.columns)
except Exception as e:
    print(f"Error loading {diagnosis_csv}: {e}")

print("-" * 50)

# Load edstays.csv and print its columns
try:
    df_edstays = pd.read_csv(edstays_csv)
    print("edstays.csv loaded successfully.")
    print("Columns in edstays.csv:")
    print(df_edstays.columns)
except Exception as e:
    print(f"Error loading {edstays_csv}: {e}")

print("-" * 50)

# Load Discharge.csv and print its columns
try:
    df_discharge = pd.read_csv(discharge_csv)
    print("Discharge.csv loaded successfully.")
    print("Columns in Discharge.csv:")
    print(df_discharge.columns)
except Exception as e:
    print(f"Error loading {discharge_csv}: {e}")

Diagnosis.csv loaded successfully.
Columns in Diagnosis.csv:
Index(['subject_id', 'stay_id', 'seq_num', 'icd_code', 'icd_version',
       'icd_title'],
      dtype='object')
--------------------------------------------------
edstays.csv loaded successfully.
Columns in edstays.csv:
Index(['subject_id', 'hadm_id', 'stay_id', 'intime', 'outtime', 'gender',
       'race', 'arrival_transport', 'disposition'],
      dtype='object')
--------------------------------------------------
Discharge.csv loaded successfully.
Columns in Discharge.csv:
Index(['note_id', 'subject_id', 'hadm_id', 'note_type', 'note_seq',
       'charttime', 'storetime', 'text'],
      dtype='object')


In [11]:
import pandas as pd
import time

# -------------------------------
# File paths (adjust if necessary)
# -------------------------------
diagnosis_csv = "diagnosis.csv"
edstays_csv = "edstays.csv"
discharge_csv = "discharge.csv"

# -------------------------------
# Step 1: Load the Datasets, Selecting Only Needed Columns
# -------------------------------
# Columns read from each file (join keys plus the fields kept in the output).
diag_usecols = ['subject_id', 'stay_id', 'icd_code', 'icd_version', 'icd_title']
edstays_usecols = ['subject_id', 'stay_id', 'hadm_id', 'intime', 'outtime', 'gender', 'race', 'arrival_transport', 'disposition']
discharge_usecols = ['note_id', 'subject_id', 'hadm_id', 'note_type', 'note_seq', 'charttime', 'storetime', 'text']

df_diag = pd.read_csv(diagnosis_csv, usecols=diag_usecols)
df_edstays = pd.read_csv(edstays_csv, usecols=edstays_usecols)
df_discharge = pd.read_csv(discharge_csv, usecols=discharge_usecols)

print("diagnosis.csv loaded. Columns:")
print(df_diag.columns)
print("-" * 50)
print("edstays.csv loaded. Columns:")
print(df_edstays.columns)
print("-" * 50)
print("discharge.csv loaded. Columns:")
print(df_discharge.columns)
print("-" * 50)

# -------------------------------
# Step 2: Merge Diagnosis and edstays Using Indexes and join
# -------------------------------

# Set the index on the common keys ('subject_id' and 'stay_id') for both dataframes.
# Note: set_index(..., inplace=True) mutates df_diag and df_edstays.
start_time = time.time()
df_diag.set_index(['subject_id', 'stay_id'], inplace=True)
df_edstays.set_index(['subject_id', 'stay_id'], inplace=True)

# Use join on the indexes – joining on the key is often faster.
# The inner join keeps only (subject_id, stay_id) pairs present in both
# frames and brings in 'hadm_id' from edstays for the next step.
df_diag_ed = df_diag.join(df_edstays, how="inner").reset_index()
print("Merged Diagnosis and edstays. Time taken: {:.2f} seconds".format(time.time() - start_time))
print("Columns after first merge:")
print(df_diag_ed.columns)
print("-" * 50)

# -------------------------------
# Step 3: Merge the Result with Discharge Data Using Indexes
# -------------------------------

# Both the merged dataframe and discharge data share 'subject_id' and 'hadm_id'.
start_time = time.time()
df_diag_ed.set_index(['subject_id', 'hadm_id'], inplace=True)
df_discharge.set_index(['subject_id', 'hadm_id'], inplace=True)

# Join on the indexes.
# NOTE(review): if one admission has several diagnosis rows, the same note
# text is presumably repeated once per diagnosis in the result -- confirm
# this duplication is intended for the downstream single-label classifier.
df_merged = df_diag_ed.join(df_discharge, how="inner").reset_index()
print("Merged with Discharge data. Time taken: {:.2f} seconds".format(time.time() - start_time))
print("Columns in final merged dataset:")
print(df_merged.columns)
print("-" * 50)

# -------------------------------
# Step 4: Save the Final Merged Dataset
# -------------------------------
# index=False keeps the positional index out of the CSV.
output_csv = "MIMIC_diagnosis_prediction_dataset.csv"
df_merged.to_csv(output_csv, index=False)
print(f"Merged dataset saved to {output_csv}.")

diagnosis.csv loaded. Columns:
Index(['subject_id', 'stay_id', 'icd_code', 'icd_version', 'icd_title'], dtype='object')
--------------------------------------------------
edstays.csv loaded. Columns:
Index(['subject_id', 'hadm_id', 'stay_id', 'intime', 'outtime', 'gender',
       'race', 'arrival_transport', 'disposition'],
      dtype='object')
--------------------------------------------------
discharge.csv loaded. Columns:
Index(['note_id', 'subject_id', 'hadm_id', 'note_type', 'note_seq',
       'charttime', 'storetime', 'text'],
      dtype='object')
--------------------------------------------------
Merged Diagnosis and edstays. Time taken: 13.21 seconds
Columns after first merge:
Index(['subject_id', 'stay_id', 'icd_code', 'icd_version', 'icd_title',
       'hadm_id', 'intime', 'outtime', 'gender', 'race', 'arrival_transport',
       'disposition'],
      dtype='object')
--------------------------------------------------
Merged with Discharge data. Time taken: 10.40 seconds
Colu

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.backends.cudnn as cudnn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np
import pandas as pd
from collections import Counter
import random
import spacy
import os
import time

# -------------------------------
# Set random seeds for reproducibility
# -------------------------------
# The same fixed seed is applied to Python's `random`, NumPy and PyTorch.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# -------------------------------
# Enable cuDNN benchmark for potential speed-up (effective if using GPU)
# -------------------------------
# NOTE(review): cuDNN benchmark mode may pick different kernels between runs,
# which can work against the reproducibility goal above -- confirm the trade-off.
if torch.cuda.is_available():
    cudnn.benchmark = True

# -------------------------------
# Load spaCy English model (disable parser and NER for speed)
# -------------------------------
# `nlp` is the module-level pipeline used by the tokenization code in this cell.
nlp = spacy.load("en_core_web_sm", disable=["parser", "ner"])

# -------------------------------
# Step 1: Load and Inspect the Merged Dataset for Diagnosis Prediction
# -------------------------------
merged_csv = "MIMIC_diagnosis_prediction_dataset.csv"
try:
    df_mimic = pd.read_csv(merged_csv)
except FileNotFoundError as err:
    # Raise instead of calling exit(): exit() reports success to the caller
    # and, inside a notebook, shuts the kernel down instead of stopping the cell.
    raise FileNotFoundError(
        f"Error: '{merged_csv}' not found. Please ensure the file is in the working directory."
    ) from err
print("Merged dataset loaded successfully.")

print("Columns in merged dataset:")
print(df_mimic.columns)
print("-" * 50)

# -------------------------------
# Optionally reduce the number of samples for faster execution.
# -------------------------------
max_samples = 10000  # Adjust as needed
if len(df_mimic) > max_samples:
    # Fixed random_state keeps the subsample identical between runs.
    df_mimic = df_mimic.sample(n=max_samples, random_state=42).reset_index(drop=True)
    print(f"Dataset reduced to {max_samples} samples.")

# -------------------------------
# Step 2: Preprocess the Data
# -------------------------------
# For diagnosis prediction, we use the note text as input and diagnosis code (icd_code) as the label.
# Rows missing either field are dropped; codes are cast to str so that
# every label has the same type before encoding.
df_mimic = df_mimic.dropna(subset=["text", "icd_code"])
df_mimic['icd_code'] = df_mimic['icd_code'].astype(str)

texts = df_mimic["text"].tolist()
labels = df_mimic["icd_code"].tolist()

# -------------------------------
# Step 3: Encode Diagnosis Labels
# -------------------------------
# LabelEncoder maps each distinct icd_code string to an integer id;
# num_classes is the number of distinct codes seen here.
label_encoder = LabelEncoder()
labels_encoded = label_encoder.fit_transform(labels)
num_classes = len(label_encoder.classes_)
print(f"Number of target diagnosis classes: {num_classes}")

# -------------------------------
# Step 4: Tokenization and Vocabulary Construction
# -------------------------------
def batch_tokenize(texts, batch_size=1000):
    """Tokenize every text with the module-level spaCy pipeline ``nlp``.

    Args:
        texts: iterable of strings to tokenize.
        batch_size: number of texts handed to ``nlp.pipe`` per batch.

    Returns:
        A list with one list of token strings per input text; punctuation
        and whitespace tokens are dropped.
    """
    results = []
    for doc in nlp.pipe(texts, batch_size=batch_size):
        words = []
        for tok in doc:
            if tok.is_punct or tok.is_space:
                continue
            words.append(tok.text)
        results.append(words)
    return results

# Time the (slow) spaCy tokenization of all note texts.
start_time = time.time()
tokenized_texts = batch_tokenize(texts, batch_size=1000)
print("Tokenization complete on note texts. Time taken: {:.2f} seconds".format(time.time()-start_time))

# Flatten all documents into one token list and count occurrences.
all_tokens = []
for tokens in tokenized_texts:
    all_tokens.extend(tokens)
vocab_counter = Counter(all_tokens)

# Keep only tokens seen at least `min_word_freq` times.
min_word_freq = 2
vocab = set()
for token, count in vocab_counter.items():
    if count >= min_word_freq:
        vocab.add(token)

# Reserve special tokens (0 = padding, 1 = unknown); the remaining words are
# numbered in sorted order so the mapping is deterministic.
word_to_index = {"<PAD>": 0, "<UNK>": 1}
for word in sorted(vocab):
    word_to_index[word] = len(word_to_index)
vocab_size = len(word_to_index)
print(f"Vocabulary size for note texts: {vocab_size}")

def text_to_sequence(tokens):
    """Translate tokens into vocabulary indices via the module-level ``word_to_index``.

    Tokens missing from the mapping receive the ``<UNK>`` index.
    """
    def _index_of(tok):
        if tok in word_to_index:
            return word_to_index[tok]
        return word_to_index["<UNK>"]

    return [_index_of(tok) for tok in tokens]

# One list of vocabulary indices per document, in the same order as `tokenized_texts`.
sequences = [text_to_sequence(tokens) for tokens in tokenized_texts]

# -------------------------------
# Step 5: Pad Sequences
# -------------------------------
# Every document is later cut or zero-padded to exactly this many token ids.
max_len = 256  # Fixed maximum sequence length
def pad_sequence_fn(seq, max_len):
    """Return a new list of exactly ``max_len`` items built from ``seq``.

    Sequences that are already long enough are truncated to their first
    ``max_len`` items; shorter ones are right-padded with 0.
    """
    if len(seq) >= max_len:
        return seq[:max_len]
    padding = [0] * (max_len - len(seq))
    return seq + padding

# Every padded sequence has length max_len, so X is a 2-D integer array
# (one row per document) and y holds the matching encoded labels.
padded_sequences = [pad_sequence_fn(seq, max_len) for seq in sequences]
X = np.array(padded_sequences)
y = np.array(labels_encoded)

# -------------------------------
# Filter Out Underrepresented Classes (with <2 samples)
# -------------------------------
# A class with a single sample cannot appear in both the training and the
# validation part of a stratified split, so such samples are dropped.
# NOTE(review): label ids are not re-encoded after filtering, so
# `num_classes` still counts the removed classes.
class_counts = pd.Series(y).value_counts()
print("Original class counts:")
print(class_counts)
classes_to_remove = class_counts[class_counts < 2].index.tolist()
if classes_to_remove:
    print("The following classes have less than 2 samples and will be removed:", classes_to_remove)
    # Boolean mask selecting the rows whose label is NOT in the removal list.
    mask = ~np.isin(y, classes_to_remove)
    X = X[mask]
    y = y[mask]
    print("Updated class counts:")
    print(pd.Series(y).value_counts())
else:
    print("All classes have at least 2 samples.")

# -------------------------------
# Step 6: Train/Validation Data Split
# -------------------------------
test_ratio = 0.2
test_count = int(len(X)*test_ratio)
# Stratification needs at least one validation sample per class that is
# still present in y. Compare against those remaining classes rather than
# `num_classes`, which still includes classes removed by the rare-class
# filter and would disable stratification unnecessarily.
remaining_classes = len(set(y.tolist()))
if test_count < remaining_classes:
    print(f"Warning: test_count ({test_count}) is less than number of classes ({remaining_classes}).")
    print("Using non-stratified split.")
    stratify_param = None
else:
    stratify_param = y

X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=test_ratio, random_state=42, stratify=stratify_param
)
print(f"Training samples: {len(X_train)}, Validation samples: {len(X_val)}")

# -------------------------------
# Step 7: Create PyTorch Dataset and DataLoader
# -------------------------------
class MIMICDataset(Dataset):
    """Wrap index-able features ``X`` and labels ``y`` as a PyTorch dataset.

    Indexing returns a ``(features, label)`` tuple of ``torch.long``
    tensors built from the stored arrays at that position.
    """

    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        row, label = self.X[idx], self.y[idx]
        return (
            torch.tensor(row, dtype=torch.long),
            torch.tensor(label, dtype=torch.long),
        )

# Mini-batch size and number of loader worker processes.
batch_size = 128
num_workers = 8

# Only the training loader shuffles; validation keeps the dataset order.
# NOTE(review): pin_memory=True is only useful when batches are copied to a
# GPU, and num_workers > 0 may be problematic on platforms that spawn worker
# processes when the dataset class lives in a notebook -- confirm on target.
train_dataset = MIMICDataset(X_train, y_train)
val_dataset = MIMICDataset(X_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True)

# -------------------------------
# Step 8: Load Pre-trained GloVe Embeddings and Build Embedding Matrix
# -------------------------------
def load_glove_embeddings(filepath, embedding_dim):
    """Read a GloVe-style text file into a ``{word: vector}`` dictionary.

    Each useful line is ``word v1 v2 ... vN`` separated by whitespace.

    Args:
        filepath: path of the UTF-8 encoded embedding file.
        embedding_dim: expected number of vector components per word.

    Returns:
        dict mapping each word to a ``float32`` NumPy vector of length
        ``embedding_dim``. Lines that do not carry exactly one word plus
        ``embedding_dim`` values (blank lines, header lines, other
        dimensionalities) are skipped.

    Raises:
        ValueError: if a line of the right length contains a non-numeric value.
    """
    embeddings_index = {}
    expected_fields = embedding_dim + 1
    with open(filepath, encoding="utf8") as f:
        for line in f:
            values = line.split()
            # Checking the field count first skips blank lines (which used to
            # raise IndexError) and avoids parsing vectors that would be
            # discarded anyway.
            if len(values) != expected_fields:
                continue
            embeddings_index[values[0]] = np.asarray(values[1:], dtype="float32")
    return embeddings_index

embedding_dim = 100
glove_path = "glove.6B.100d.txt"
if not os.path.exists(glove_path):
    raise FileNotFoundError(f"{glove_path} not found. Please download it and place it in the working directory.")

glove_embeddings = load_glove_embeddings(glove_path, embedding_dim)
print(f"Loaded {len(glove_embeddings)} word vectors from GloVe.")

# Row i of the matrix is the initial embedding of the word with index i.
embedding_matrix = np.zeros((vocab_size, embedding_dim), dtype=np.float32)
for word, idx in word_to_index.items():
    if word == "<PAD>":
        # Padding must stay an all-zero vector: the model copies this matrix
        # over the embedding weights, so a random row here would feed noise
        # into the LSTM for every padded position.
        continue
    # The GloVe 6B vocabulary is lowercase, so fall back to the lowercased
    # form before giving up on a cased token such as "Patient".
    vector = glove_embeddings.get(word)
    if vector is None:
        vector = glove_embeddings.get(word.lower())
    if vector is not None:
        embedding_matrix[idx] = vector
    else:
        # No pretrained vector available: start from small random values.
        embedding_matrix[idx] = np.random.normal(scale=0.6, size=(embedding_dim,))

# -------------------------------
# Step 9: Define the LSTM-Only Model (No Attention)
# -------------------------------
class LSTMClassifier(nn.Module):
    """Bidirectional, multi-layer LSTM classifier over token-id sequences.

    Token ids are embedded (index 0 is the padding id), passed through a
    batch-first bidirectional LSTM, and the last two entries of the final
    hidden state are concatenated, regularised with dropout and projected
    to ``output_dim`` logits.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each embedding vector.
        hidden_dim: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        output_dim: number of logits produced.
        dropout: dropout probability used inside the LSTM and before ``fc``.
        pretrained_embeddings: optional array-like copied into the embedding weights.
        freeze_embeddings: when pretrained embeddings are given, whether to
            exclude them from gradient updates.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, output_dim, dropout=0.3,
                 pretrained_embeddings=None, freeze_embeddings=False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        if pretrained_embeddings is not None:
            weight = self.embedding.weight
            weight.data.copy_(torch.tensor(pretrained_embeddings))
            weight.requires_grad = not freeze_embeddings
        lstm_options = dict(batch_first=True, dropout=dropout, bidirectional=True)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, **lstm_options)
        # Two directions are concatenated, hence twice the hidden size.
        self.fc = nn.Linear(hidden_dim + hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the logits for a batch of token-id sequences ``x``."""
        vectors = self.embedding(x)
        _, state = self.lstm(vectors)
        h_n = state[0]
        # The last two entries of h_n are the final forward and backward
        # states; join them along the feature axis.
        summary = torch.cat((h_n[-2], h_n[-1]), dim=1)
        summary = self.dropout(summary)
        return self.fc(summary)

# Model hyperparameters.
hidden_dim = 128
num_layers = 2
new_output_dim = num_classes  # Diagnosis classes
dropout = 0.3

# -------------------------------
# Step 10: Initialize Model and Load Pretrained Weights from Medal (if available)
# -------------------------------
# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# The embedding layer starts from `embedding_matrix` and stays trainable
# (freeze_embeddings=False).
model = LSTMClassifier(vocab_size, embedding_dim, hidden_dim, num_layers, new_output_dim, dropout,
                       pretrained_embeddings=embedding_matrix, freeze_embeddings=False)
model.to(device)
print(model)

pretrained_path = "trained_model_LSTM.pth"
if os.path.exists(pretrained_path):
    print("Loading pretrained weights from Medal model...")
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    pretrained_state = torch.load(pretrained_path, map_location=device)
    model_dict = model.state_dict()
    # Keep only tensors whose name and shape both match the new model.
    pretrained_state = {k: v for k, v in pretrained_state.items() if k in model_dict and v.size() == model_dict[k].size()}
    not_transferred = sorted(set(model_dict) - set(pretrained_state))
    model_dict.update(pretrained_state)
    model.load_state_dict(model_dict)
    if pretrained_state:
        print("Pretrained weights loaded successfully (partial transfer if dimensions differ).")
        # Report what was actually transferred so a silent near-empty
        # transfer (e.g. different vocabulary or hidden size) is visible.
        print(f"Transferred {len(pretrained_state)} of {len(model_dict)} tensors; "
              f"left at their initial values: {not_transferred}")
    else:
        print("Warning: no tensor in the checkpoint matched the model; fine-tuning will start from scratch.")
else:
    print("Pretrained model not found; fine-tuning will start from scratch.")

# -------------------------------
# Step 11: Define Loss, Optimizer, and Training Loop for Diagnosis Prediction
# -------------------------------
# Cross-entropy over the model's raw logits; Adam updates every model
# parameter with a learning rate of 1e-3.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return its metrics.

    Args:
        model: network to train; switched to training mode.
        loader: iterable yielding ``(inputs, labels)`` tensor batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepping the model parameters.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples actually
        processed, or ``(0.0, 0.0)`` when the loader yields no batch.
    """
    model.train()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    for inputs, labels in loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_count = inputs.size(0)
        # The batch loss is weighted by batch size so that a smaller final
        # batch does not skew the average.
        total_loss += loss.item() * batch_count
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
        total_seen += batch_count
    if total_seen == 0:
        return 0.0, 0.0
    # Divide by the samples seen, not len(loader.dataset): the two differ
    # when the loader drops the last batch or uses a sampler.
    return total_loss / total_seen, total_correct / total_seen

def evaluate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: network to evaluate; switched to evaluation mode.
        loader: iterable yielding ``(inputs, labels)`` tensor batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples actually
        processed, or ``(0.0, 0.0)`` when the loader yields no batch.
    """
    model.eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            batch_count = inputs.size(0)
            total_loss += loss.item() * batch_count
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_seen += batch_count
    if total_seen == 0:
        return 0.0, 0.0
    # Divide by the samples seen, not len(loader.dataset): the two differ
    # when the loader drops the last batch or uses a sampler.
    return total_loss / total_seen, total_correct / total_seen

# Train for a fixed number of epochs; after each one, evaluate on the
# validation loader and print both sets of metrics (accuracy as a percentage).
num_epochs = 10
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = evaluate_epoch(model, val_loader, criterion, device)
    print(f"Epoch {epoch+1}/{num_epochs}: Train Loss={train_loss:.4f}, Train Acc={train_acc*100:.2f}% | Val Loss={val_loss:.4f}, Val Acc={val_acc*100:.2f}%")

# -------------------------------
# Step 12: Define Inference Function for Diagnosis Prediction
# -------------------------------
def predict_diagnosis(model, text, word_to_index, max_len, device, label_encoder):
    """Predict the label for a single free-text note.

    The text is tokenized with the module-level spaCy pipeline ``nlp``
    (punctuation and whitespace tokens are skipped), mapped to ids with
    ``word_to_index`` (``<UNK>`` for unknown tokens), padded/truncated to
    ``max_len`` via ``pad_sequence_fn`` and fed to the model as a batch of one.

    Returns:
        The original label for the arg-max class, recovered through
        ``label_encoder.inverse_transform``.
    """
    token_ids = []
    for tok in nlp(text):
        if tok.is_punct or tok.is_space:
            continue
        token_ids.append(word_to_index.get(tok.text, word_to_index["<UNK>"]))

    fixed_len_ids = pad_sequence_fn(token_ids, max_len)
    batch = torch.tensor(fixed_len_ids, dtype=torch.long).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        scores = model(batch)
    class_idx = scores.argmax(dim=1).item()
    return label_encoder.inverse_transform([class_idx])[0]

# Example Inference on a sample note:
sample_text = (
    "Patient presented with fever, cough, and shortness of breath. "
    "Chest X-ray revealed infiltrates consistent with pneumonia; "
    "appropriate treatment was initiated."
)
predicted_diagnosis = predict_diagnosis(model, sample_text, word_to_index, max_len, device, label_encoder)
print(f"\nFor the sample note, predicted diagnosis (icd_code) is: {predicted_diagnosis}")

# -------------------------------
# Step 13: Save the Fine-Tuned Diagnosis Prediction Model
# -------------------------------
fine_tuned_model_path = "trained_models/models/lstm_finetuned_for_diagnosis_pred.pth"
# torch.save does not create missing parent directories. Create them first so
# a completed training run is not lost to a failed write at the very end.
os.makedirs(os.path.dirname(fine_tuned_model_path), exist_ok=True)
torch.save(model.state_dict(), fine_tuned_model_path)
print(f"Fine-tuned diagnosis prediction model saved to {fine_tuned_model_path}")

Merged dataset loaded successfully.
Columns in merged dataset:
Index(['subject_id', 'hadm_id', 'stay_id', 'icd_code', 'icd_version',
       'icd_title', 'intime', 'outtime', 'gender', 'race', 'arrival_transport',
       'disposition', 'note_id', 'note_type', 'note_seq', 'charttime',
       'storetime', 'text'],
      dtype='object')
--------------------------------------------------
Dataset reduced to 10000 samples.
Number of target diagnosis classes: 1763
Tokenization complete on note texts. Time taken: 1728.91 seconds
Vocabulary size for note texts: 85054
Original class counts:
188     333
59      158
1312    130
1321    128
515     126
       ... 
1051      1
1035      1
1019      1
1003      1
1759      1
Length: 1763, dtype: int64
The following classes have less than 2 samples and will be removed: [1183, 1071, 1381, 1087, 975, 1047, 1175, 1119, 1031, 1429, 1421, 1023, 1309, 1317, 1135, 1293, 1063, 1007, 1405, 1143, 1277, 1167, 1103, 1389, 1151, 981, 1245, 709, 1599, 717, 725, 1591

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.backends.cudnn as cudnn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np
import pandas as pd
from collections import Counter
import random
import spacy
import os
import time

# --------------------------------------
# 1. Set random seeds for reproducibility
# --------------------------------------
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# --------------------------------------
# 2. Enable cuDNN benchmark for speed-up (if using GPU)
# --------------------------------------
# NOTE: benchmark mode lets cuDNN pick algorithms at run time, so GPU results
# are not guaranteed to be bit-identical between runs despite the seeds above.
if torch.cuda.is_available():
    cudnn.benchmark = True

# --------------------------------------
# 3. Load spaCy English model (with parser and NER disabled for speed)
# --------------------------------------
nlp = spacy.load("en_core_web_sm", disable=["parser", "ner"])

# --------------------------------------
# 4. Load the merged dataset (created from Diagnosis.csv, edstays.csv, and Discharge.csv)
# --------------------------------------
merged_csv = "MIMIC_diagnosis_prediction_dataset.csv"
try:
    df_mimic = pd.read_csv(merged_csv)
except FileNotFoundError:
    print(f"Error: '{merged_csv}' not found. Please ensure the file is in the working directory.")
    # A bare exit() reports success (status 0); stop with a failing status instead.
    raise SystemExit(1)
else:
    print("Merged dataset loaded successfully.")

print("Columns in merged dataset:")
print(df_mimic.columns)
print("-" * 50)

# --------------------------------------
# 5. Optionally reduce the number of samples for faster experimentation.
# --------------------------------------
max_samples = 10000  # change this value as needed
if len(df_mimic) > max_samples:
    df_mimic = df_mimic.sample(n=max_samples, random_state=42).reset_index(drop=True)
    print(f"Dataset reduced to {max_samples} samples.")

# --------------------------------------
# 6. Group the ICD codes by using only the first three characters.
# --------------------------------------
# Drop rows with a missing note or code BEFORE the string cast: astype(str)
# turns NaN into the literal "nan", which a later dropna cannot remove and
# which would otherwise become a bogus "nan" diagnosis class.
df_mimic = df_mimic.dropna(subset=["text", "icd_code"]).copy()

# Ensure that the icd_code column is a string.
df_mimic['icd_code'] = df_mimic['icd_code'].astype(str)

# Create a new column for the grouped ICD codes.
df_mimic['icd_group'] = df_mimic['icd_code'].str[:3]
print("Number of unique ICD groups:", df_mimic['icd_group'].nunique())

# --------------------------------------
# 7. Preprocess the data: use note text as input and the icd_group as target label.
# --------------------------------------
df_mimic = df_mimic.dropna(subset=["text", "icd_group"])
texts = df_mimic["text"].tolist()
labels = df_mimic["icd_group"].tolist()

# --------------------------------------
# 8. Encode the grouped ICD codes using LabelEncoder.
# --------------------------------------
label_encoder = LabelEncoder()
labels_encoded = label_encoder.fit_transform(labels)
num_classes = len(label_encoder.classes_)
print(f"Number of target diagnosis classes after grouping: {num_classes}")

# --------------------------------------
# 9. Tokenization and vocabulary construction.
# --------------------------------------
def batch_tokenize(texts, batch_size=1000):
    """Tokenize ``texts`` with the module-level spaCy pipeline ``nlp``.

    Documents are streamed through ``nlp.pipe`` in batches of ``batch_size``.

    Returns:
        One list of token strings per input text, with punctuation and
        whitespace tokens left out.
    """
    return [
        [tok.text for tok in doc if not (tok.is_punct or tok.is_space)]
        for doc in nlp.pipe(texts, batch_size=batch_size)
    ]

start_time = time.time()
tokenized_texts = batch_tokenize(texts, batch_size=1000)
print(f"Tokenization complete on note texts. Time taken: {time.time() - start_time:.2f} seconds")

# Count every token once, then keep those seen at least min_word_freq times.
all_tokens = [tok for note_tokens in tokenized_texts for tok in note_tokens]
vocab_counter = Counter(all_tokens)
min_word_freq = 2
vocab = {tok for tok in vocab_counter if vocab_counter[tok] >= min_word_freq}

# Ids 0 and 1 are reserved for <PAD> and <UNK>; real words follow in sorted
# order so the mapping is deterministic.
word_to_index = {"<PAD>": 0, "<UNK>": 1}
for word in sorted(vocab):
    word_to_index[word] = len(word_to_index)
vocab_size = len(word_to_index)
print(f"Vocabulary size for note texts: {vocab_size}")

def text_to_sequence(tokens):
    """Map tokens to ids via the module-level ``word_to_index``.

    Tokens missing from the vocabulary receive the ``<UNK>`` id.
    """
    ids = []
    for tok in tokens:
        ids.append(word_to_index.get(tok, word_to_index["<UNK>"]))
    return ids

# One id sequence per tokenized note, in the same order as tokenized_texts.
sequences = list(map(text_to_sequence, tokenized_texts))

# --------------------------------------
# 10. Pad sequences to a fixed length.
# --------------------------------------
max_len = 256  # Fixed maximum sequence length.
def pad_sequence_fn(seq, max_len):
    """Right-pad ``seq`` with 0 up to ``max_len`` items, or truncate it to ``max_len``.

    A new list is returned; ``seq`` itself is not modified.
    """
    missing = max_len - len(seq)
    if missing > 0:
        return seq + [0] * missing
    return seq[:max_len]

padded_sequences = [pad_sequence_fn(seq, max_len) for seq in sequences]
X = np.array(padded_sequences)
y = np.array(labels_encoded)

# --------------------------------------
# 11. Filter out classes with fewer than 2 samples to support stratified splitting.
# --------------------------------------
class_counts = pd.Series(y).value_counts()
print("Original class counts:")
print(class_counts)
classes_to_remove = class_counts[class_counts < 2].index.tolist()
if classes_to_remove:
    print("The following classes (encoded as integers) have less than 2 samples and will be removed:", classes_to_remove)
    mask = ~np.isin(y, classes_to_remove)
    X = X[mask]
    y = y[mask]
    print("Updated class counts:")
    print(pd.Series(y).value_counts())
else:
    print("All classes have at least 2 samples.")

# --------------------------------------
# 12. Split the data into training and validation sets.
# --------------------------------------
test_ratio = 0.2
test_count = int(len(X)*test_ratio)
# Count the classes that are still present in y. num_classes was computed
# before the rare-class filter above, so comparing against it could force a
# non-stratified split even though stratification is possible.
num_classes_present = len(np.unique(y))
if test_count < num_classes_present:
    print(f"Warning: test_count ({test_count}) is less than the number of classes ({num_classes_present}).")
    print("Using non-stratified split.")
    stratify_param = None
else:
    stratify_param = y

X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=test_ratio, random_state=42, stratify=stratify_param
)
print(f"Training samples: {len(X_train)}, Validation samples: {len(X_val)}")

# --------------------------------------
# 13. Create PyTorch Dataset and DataLoaders.
# --------------------------------------
class MIMICDataset(Dataset):
    """Index-aligned ``(features, label)`` pairs served as ``torch.long`` tensors."""

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        features = torch.tensor(self.X[idx], dtype=torch.long)
        label = torch.tensor(self.y[idx], dtype=torch.long)
        return features, label

batch_size = 128
# Do not request more worker processes than the machine has cores.
num_workers = min(8, os.cpu_count() or 1)
# Pinned host memory only speeds up host-to-GPU copies; on a CPU-only machine
# it brings no benefit, so enable it only when CUDA is available.
pin_memory = torch.cuda.is_available()

train_dataset = MIMICDataset(X_train, y_train)
val_dataset = MIMICDataset(X_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        num_workers=num_workers, pin_memory=pin_memory)

# --------------------------------------
# 14. Load pre-trained GloVe embeddings and build the embedding matrix.
# --------------------------------------
def load_glove_embeddings(filepath, embedding_dim):
    """Parse a GloVe-style text file into a ``{word: vector}`` dictionary.

    Each usable line holds a token followed by ``embedding_dim``
    whitespace-separated numbers. Blank lines and lines with a different
    number of fields are skipped.

    Args:
        filepath: Path to the UTF-8 encoded embeddings file.
        embedding_dim: Expected vector length.

    Returns:
        dict mapping each token to a ``float32`` NumPy vector of length
        ``embedding_dim``.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a line with the expected field count contains a
            non-numeric value.
    """
    embeddings_index = {}
    with open(filepath, encoding="utf8") as f:
        for line in f:
            values = line.split()
            # Check the field count before parsing: this skips blank lines
            # (values[0] would raise IndexError) and avoids converting lines
            # that would be discarded anyway.
            if len(values) != embedding_dim + 1:
                continue
            embeddings_index[values[0]] = np.asarray(values[1:], dtype="float32")
    return embeddings_index

embedding_dim = 100
glove_path = "glove.6B.100d.txt"
if not os.path.exists(glove_path):
    raise FileNotFoundError(f"{glove_path} not found. Please download it and place it in the working directory.")

glove_embeddings = load_glove_embeddings(glove_path, embedding_dim)
print(f"Loaded {len(glove_embeddings)} word vectors from GloVe.")

embedding_matrix = np.zeros((vocab_size, embedding_dim), dtype=np.float32)
for word, idx in word_to_index.items():
    if idx == word_to_index["<PAD>"]:
        # Keep the padding row at zero so padded positions contribute no
        # signal; a random vector here would be fed to the LSTM as content.
        continue
    # The glove.6B vectors are uncased while the vocabulary keeps the original
    # casing, so fall back to the lowercase form before giving up.
    vector = glove_embeddings.get(word)
    if vector is None:
        vector = glove_embeddings.get(word.lower())
    if vector is None:
        # Token unknown to GloVe (including <UNK>): random initialisation.
        vector = np.random.normal(scale=0.6, size=(embedding_dim,))
    embedding_matrix[idx] = vector

# --------------------------------------
# 15. Define the LSTM-based classification model.
# --------------------------------------
class LSTMClassifier(nn.Module):
    """Bidirectional multi-layer LSTM text classifier.

    Token ids are embedded (id 0 is the padding index), run through the LSTM,
    and the last two entries of the final hidden state are concatenated,
    passed through dropout and projected to ``output_dim`` logits.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each embedding vector.
        hidden_dim: Hidden size of the LSTM per direction.
        num_layers: Number of stacked LSTM layers.
        output_dim: Number of output classes.
        dropout: Dropout rate used between LSTM layers and before the
            output layer.
        pretrained_embeddings: Optional array-like copied into the
            embedding table.
        freeze_embeddings: When pretrained embeddings are given, whether to
            exclude them from training.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, output_dim, dropout=0.3,
                 pretrained_embeddings=None, freeze_embeddings=False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        if pretrained_embeddings is not None:
            initial_weights = torch.tensor(pretrained_embeddings)
            self.embedding.weight.data.copy_(initial_weights)
            self.embedding.weight.requires_grad = not freeze_embeddings
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True,
        )
        # Two directions are concatenated, hence twice the hidden size.
        self.fc = nn.Linear(2 * hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return class logits for a batch of token-id sequences."""
        _, (h_n, _) = self.lstm(self.embedding(x))
        # The last two entries of h_n are the top layer's forward and
        # backward final states.
        final_states = torch.cat((h_n[-2], h_n[-1]), dim=1)
        return self.fc(self.dropout(final_states))

hidden_dim = 128
num_layers = 2
new_output_dim = num_classes  # Number of ICD groups after grouping.
dropout = 0.3

# --------------------------------------
# 16. Initialize the model and load optional pretrained weights from the Medal model.
# --------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

model = LSTMClassifier(vocab_size, embedding_dim, hidden_dim, num_layers, new_output_dim, dropout,
                       pretrained_embeddings=embedding_matrix, freeze_embeddings=False)
model.to(device)
print(model)

pretrained_path = "trained_model_LSTM.pth"
if os.path.exists(pretrained_path):
    print("Loading pretrained weights from Medal model...")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    pretrained_state = torch.load(pretrained_path, map_location=device)
    model_dict = model.state_dict()
    # Update only matching layers; skip final FC layer if dimensions differ.
    # NOTE(review): an embedding table is transferred whenever its shape matches,
    # even if the checkpoint was built with a different vocabulary - confirm.
    pretrained_state = {k: v for k, v in pretrained_state.items() if k in model_dict and v.size() == model_dict[k].size()}
    if pretrained_state:
        model_dict.update(pretrained_state)
        model.load_state_dict(model_dict)
        # Report how much was actually transferred so a silent mismatch is visible.
        print(f"Transferred {len(pretrained_state)} of {len(model_dict)} tensors from '{pretrained_path}'.")
        print("Pretrained weights loaded successfully (partial transfer if dimensions differ).")
    else:
        print("No checkpoint tensor matches the model by name and shape; fine-tuning will start from scratch.")
else:
    print("Pretrained model not found; fine-tuning will start from scratch.")

# --------------------------------------
# 17. Define the loss, optimizer, and training loop.
# --------------------------------------
# Adam (learning rate 1e-3) updates every model parameter; the objective is
# cross-entropy over the grouped diagnosis classes.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    The model is switched to training mode and updated once per batch.

    Returns:
        ``(loss, accuracy)`` where both totals are divided by
        ``len(loader.dataset)``.
    """
    model.train()
    sample_count = len(loader.dataset)
    running_loss = 0
    running_hits = 0
    for batch_inputs, batch_targets in loader:
        batch_inputs = batch_inputs.to(device, non_blocking=True)
        batch_targets = batch_targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so a smaller final batch
        # does not skew the epoch total (assumes criterion returns a batch mean).
        running_loss += batch_loss.item() * batch_inputs.size(0)
        predicted = logits.argmax(dim=1)
        running_hits += predicted.eq(batch_targets).sum().item()
    return running_loss / sample_count, running_hits / sample_count

def evaluate_epoch(model, loader, criterion, device):
    """Score the model on ``loader`` without updating it.

    The model is switched to evaluation mode and gradients are disabled.

    Returns:
        ``(loss, accuracy)`` where both totals are divided by
        ``len(loader.dataset)``.
    """
    model.eval()
    sample_count = len(loader.dataset)
    running_loss = 0
    running_hits = 0
    with torch.no_grad():
        for batch_inputs, batch_targets in loader:
            batch_inputs = batch_inputs.to(device, non_blocking=True)
            batch_targets = batch_targets.to(device, non_blocking=True)
            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_targets)
            # Weight by batch size so partial batches count proportionally.
            running_loss += batch_loss.item() * batch_inputs.size(0)
            running_hits += logits.argmax(dim=1).eq(batch_targets).sum().item()
    return running_loss / sample_count, running_hits / sample_count

num_epochs = 20
# Each epoch: one training pass, one validation pass, then a one-line report.
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = evaluate_epoch(model, val_loader, criterion, device)
    print(
        f"Epoch {epoch+1}/{num_epochs}: "
        f"Train Loss={train_loss:.4f}, Train Acc={train_acc*100:.2f}% | "
        f"Val Loss={val_loss:.4f}, Val Acc={val_acc*100:.2f}%"
    )

# --------------------------------------
# 18. Define an inference function.
# --------------------------------------
def predict_diagnosis(model, text, word_to_index, max_len, device, label_encoder):
    """Predict the label for a single free-text note.

    The text is tokenized with the module-level spaCy pipeline ``nlp``
    (punctuation and whitespace tokens are skipped), mapped to ids with
    ``word_to_index`` (``<UNK>`` for unknown tokens), padded/truncated to
    ``max_len`` via ``pad_sequence_fn`` and fed to the model as a batch of one.

    Returns:
        The original label for the arg-max class, recovered through
        ``label_encoder.inverse_transform``.
    """
    token_ids = []
    for tok in nlp(text):
        if tok.is_punct or tok.is_space:
            continue
        token_ids.append(word_to_index.get(tok.text, word_to_index["<UNK>"]))

    fixed_len_ids = pad_sequence_fn(token_ids, max_len)
    batch = torch.tensor(fixed_len_ids, dtype=torch.long).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        scores = model(batch)
    class_idx = scores.argmax(dim=1).item()
    return label_encoder.inverse_transform([class_idx])[0]

# Example Inference on a sample note.
sample_text = (
    "Patient presented with fever, cough, and shortness of breath. "
    "Chest X-ray revealed infiltrates consistent with pneumonia; "
    "appropriate treatment was initiated."
)
predicted_diagnosis = predict_diagnosis(model, sample_text, word_to_index, max_len, device, label_encoder)
print(f"\nFor the sample note, predicted diagnosis group (first 3 characters of icd_code) is: {predicted_diagnosis}")

# --------------------------------------
# 19. Save the fine-tuned model.
# --------------------------------------
fine_tuned_model_path = "trained_models/models/lstm_finetuned_for_diagnosis_pred.pth"
# torch.save does not create missing parent directories. Create them first so
# a completed training run is not lost to a failed write at the very end.
os.makedirs(os.path.dirname(fine_tuned_model_path), exist_ok=True)
torch.save(model.state_dict(), fine_tuned_model_path)
print(f"Fine-tuned diagnosis prediction model saved to {fine_tuned_model_path}")

Merged dataset loaded successfully.
Columns in merged dataset:
Index(['subject_id', 'hadm_id', 'stay_id', 'icd_code', 'icd_version',
       'icd_title', 'intime', 'outtime', 'gender', 'race', 'arrival_transport',
       'disposition', 'note_id', 'note_type', 'note_seq', 'charttime',
       'storetime', 'text'],
      dtype='object')
--------------------------------------------------
Dataset reduced to 10000 samples.
Number of unique ICD groups: 782
Number of target diagnosis classes after grouping: 782
Tokenization complete on note texts. Time taken: 1603.81 seconds
Vocabulary size for note texts: 85054
Original class counts:
252    408
112    333
612    255
44     230
258    227
      ... 
373      1
349      1
341      1
333      1
0        1
Length: 782, dtype: int64
The following classes (encoded as integers) have less than 2 samples and will be removed: [385, 302, 326, 294, 1, 33, 760, 305, 574, 65, 121, 313, 273, 414, 249, 209, 438, 454, 366, 470, 478, 145, 105, 598, 358, 743, 35

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.backends.cudnn as cudnn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np
import pandas as pd
from collections import Counter
import random
import spacy
import os
import time
import pickle

# --------------------------------------
# 1. Set random seeds for reproducibility
# --------------------------------------
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# --------------------------------------
# 2. Enable cuDNN benchmark for faster training (if using GPU)
# --------------------------------------
# NOTE: benchmark mode lets cuDNN pick algorithms at run time, so GPU results
# are not guaranteed to be bit-identical between runs despite the seeds above.
if torch.cuda.is_available():
    cudnn.benchmark = True

# --------------------------------------
# 3. Load spaCy English model (disable parser/NER for speed)
# --------------------------------------
nlp = spacy.load("en_core_web_sm", disable=["parser", "ner"])

# --------------------------------------
# 4. Load the merged dataset for Diagnosis Prediction
# --------------------------------------
merged_csv = "MIMIC_diagnosis_prediction_dataset.csv"
try:
    df_mimic = pd.read_csv(merged_csv)
except FileNotFoundError:
    print(f"Error: '{merged_csv}' not found. Please ensure the file is in the working directory.")
    # A bare exit() reports success (status 0); stop with a failing status instead.
    raise SystemExit(1)
else:
    print("Merged dataset loaded successfully.")
print("Columns in merged dataset:")
print(df_mimic.columns)
print("-" * 50)

# --------------------------------------
# 5. Optionally reduce the number of samples for faster experimentation.
# --------------------------------------
max_samples = 3000000  # Adjust this value as needed
if len(df_mimic) > max_samples:
    df_mimic = df_mimic.sample(n=max_samples, random_state=42).reset_index(drop=True)
    print(f"Dataset reduced to {max_samples} samples.")

# --------------------------------------
# 6. Group ICD codes into broader classes.
# --------------------------------------
# Drop rows with a missing note or code BEFORE the string cast: astype(str)
# turns NaN into the literal "nan", which a later dropna cannot remove and
# which would otherwise be grouped as a bogus "na" diagnosis class.
df_mimic = df_mimic.dropna(subset=["text", "icd_code"]).copy()

# Ensure icd_code is a string, then group by taking the first 2 characters.
df_mimic['icd_code'] = df_mimic['icd_code'].astype(str)
group_length = 2
df_mimic['icd_group'] = df_mimic['icd_code'].str[:group_length]
print("Initial number of unique ICD groups (first {} chars): {}".format(
    group_length, df_mimic['icd_group'].nunique()))

# Map groups with very low frequency (threshold) to "Other".
threshold = 20  # minimum number of samples per group to keep
group_counts = df_mimic['icd_group'].value_counts()
# Vectorised lookup of each row's group frequency; a per-row Python lambda is
# needlessly slow on a dataset of this size.
is_frequent = df_mimic['icd_group'].map(group_counts) >= threshold
df_mimic['icd_group'] = df_mimic['icd_group'].where(is_frequent, "Other")
print("Number of unique ICD groups after mapping infrequent ones to 'Other':",
      df_mimic['icd_group'].nunique())

# --------------------------------------
# 7. Preprocess data: use note text as input and "icd_group" as target.
# --------------------------------------
df_mimic = df_mimic.dropna(subset=["text", "icd_group"])
texts = df_mimic["text"].tolist()
labels = df_mimic["icd_group"].tolist()

# --------------------------------------
# 8. Encode the grouped ICD codes.
# --------------------------------------
label_encoder = LabelEncoder()
labels_encoded = label_encoder.fit_transform(labels)
num_classes = len(label_encoder.classes_)
print(f"Number of target diagnosis classes after grouping and mapping: {num_classes}")

# --------------------------------------
# 9. Tokenization with caching.
# --------------------------------------
def batch_tokenize(texts, batch_size=1000):
    """Tokenize ``texts`` with the module-level spaCy pipeline ``nlp``.

    Documents are streamed through ``nlp.pipe`` in batches of ``batch_size``.

    Returns:
        One list of token strings per input text, with punctuation and
        whitespace tokens left out.
    """
    return [
        [tok.text for tok in doc if not (tok.is_punct or tok.is_space)]
        for doc in nlp.pipe(texts, batch_size=batch_size)
    ]

tokenized_cache_file = "trained_models/tokenizers/tokenized_texts.pkl"
tokenized_texts = None

# Reuse the cache only if it holds one entry per current text.
# NOTE(review): a length match does not prove the cache was built from the
# same notes (e.g. a different sample of equal size) - delete the file when
# the input data changes.
if os.path.exists(tokenized_cache_file):
    # NOTE(review): pickle.load can execute arbitrary code; only use cache
    # files written by this script.
    with open(tokenized_cache_file, "rb") as f:
        cached_tokenized_texts = pickle.load(f)
    if len(cached_tokenized_texts) == len(texts):
        tokenized_texts = cached_tokenized_texts
        print("Loaded cached tokenized texts.")
    else:
        print("Cache length does not match current data length. Re-tokenizing.")

if tokenized_texts is None:
    start_time = time.time()
    tokenized_texts = batch_tokenize(texts, batch_size=1000)
    print("Tokenization complete on note texts. Time taken: {:.2f} seconds".format(time.time()-start_time))
    # Create the cache directory first; otherwise a long tokenization run
    # would end in a failed write when the directory is missing.
    os.makedirs(os.path.dirname(tokenized_cache_file), exist_ok=True)
    with open(tokenized_cache_file, "wb") as f:
        pickle.dump(tokenized_texts, f)
    print("Tokenized texts cached for future use.")

# --------------------------------------
# 10. Build vocabulary and convert texts to sequences.
# --------------------------------------
# Count every token once, then keep those seen at least min_word_freq times.
all_tokens = [tok for note_tokens in tokenized_texts for tok in note_tokens]
vocab_counter = Counter(all_tokens)
min_word_freq = 2
vocab = {tok for tok in vocab_counter if vocab_counter[tok] >= min_word_freq}

# Ids 0 and 1 are reserved for <PAD> and <UNK>; real words follow in sorted
# order so the mapping is deterministic.
word_to_index = {"<PAD>": 0, "<UNK>": 1}
for word in sorted(vocab):
    word_to_index[word] = len(word_to_index)
vocab_size = len(word_to_index)
print(f"Vocabulary size for note texts: {vocab_size}")

def text_to_sequence(tokens):
    """Map tokens to ids via the module-level ``word_to_index``.

    Tokens missing from the vocabulary receive the ``<UNK>`` id.
    """
    ids = []
    for tok in tokens:
        ids.append(word_to_index.get(tok, word_to_index["<UNK>"]))
    return ids

# One id sequence per tokenized note, in the same order as tokenized_texts.
sequences = list(map(text_to_sequence, tokenized_texts))

# --------------------------------------
# 11. Pad sequences.
# --------------------------------------
max_len = 256  # fixed length
def pad_sequence_fn(seq, max_len):
    """Right-pad ``seq`` with 0 up to ``max_len`` items, or truncate it to ``max_len``.

    A new list is returned; ``seq`` itself is not modified.
    """
    missing = max_len - len(seq)
    if missing > 0:
        return seq + [0] * missing
    return seq[:max_len]

padded_sequences = [pad_sequence_fn(seq, max_len) for seq in sequences]
X = np.array(padded_sequences)
y = np.array(labels_encoded)

# --------------------------------------
# 12. (Optional) Filter out classes with fewer than 2 samples.
# --------------------------------------
class_counts = pd.Series(y).value_counts()
print("Original class counts:")
print(class_counts)
classes_to_remove = class_counts[class_counts < 2].index.tolist()
if classes_to_remove:
    print("The following classes (encoded as integers) have less than 2 samples and will be removed:", classes_to_remove)
    mask = ~np.isin(y, classes_to_remove)
    X = X[mask]
    y = y[mask]
    print("Updated class counts:")
    print(pd.Series(y).value_counts())
else:
    print("All classes have at least 2 samples.")

# --------------------------------------
# 13. Train/Validation split.
# --------------------------------------
test_ratio = 0.2
test_count = int(len(X) * test_ratio)
# Count the classes that are still present in y. num_classes was computed
# before the rare-class filter above, so comparing against it could force a
# non-stratified split even though stratification is possible.
num_classes_present = len(np.unique(y))
if test_count < num_classes_present:
    print(f"Warning: test_count ({test_count}) is less than number of classes ({num_classes_present}). Using non-stratified split.")
    stratify_param = None
else:
    stratify_param = y

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=test_ratio, random_state=42, stratify=stratify_param)
print(f"Training samples: {len(X_train)}, Validation samples: {len(X_val)}")

# --------------------------------------
# 14. Create PyTorch Datasets and DataLoaders.
# --------------------------------------
class MIMICDataset(Dataset):
    """Index-aligned ``(features, label)`` pairs served as ``torch.long`` tensors."""

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        features = torch.tensor(self.X[idx], dtype=torch.long)
        label = torch.tensor(self.y[idx], dtype=torch.long)
        return features, label

batch_size = 128
# Do not request more worker processes than the machine has cores.
num_workers = min(8, os.cpu_count() or 1)
# Pinned host memory only speeds up host-to-GPU copies; on a CPU-only machine
# it brings no benefit, so enable it only when CUDA is available.
pin_memory = torch.cuda.is_available()

train_dataset = MIMICDataset(X_train, y_train)
val_dataset = MIMICDataset(X_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)

# --------------------------------------
# 15. Load pre-trained GloVe embeddings and build the embedding matrix.
# --------------------------------------
def load_glove_embeddings(filepath, embedding_dim):
    """Parse a GloVe-style text file into a ``{word: vector}`` dictionary.

    Each usable line holds a token followed by ``embedding_dim``
    whitespace-separated numbers. Blank lines and lines with a different
    number of fields are skipped.

    Args:
        filepath: Path to the UTF-8 encoded embeddings file.
        embedding_dim: Expected vector length.

    Returns:
        dict mapping each token to a ``float32`` NumPy vector of length
        ``embedding_dim``.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a line with the expected field count contains a
            non-numeric value.
    """
    embeddings_index = {}
    with open(filepath, encoding="utf8") as f:
        for line in f:
            values = line.split()
            # Check the field count before parsing: this skips blank lines
            # (values[0] would raise IndexError) and avoids converting lines
            # that would be discarded anyway.
            if len(values) != embedding_dim + 1:
                continue
            embeddings_index[values[0]] = np.asarray(values[1:], dtype="float32")
    return embeddings_index

embedding_dim = 100
glove_path = "glove.6B.100d.txt"
if not os.path.exists(glove_path):
    raise FileNotFoundError(f"{glove_path} not found. Please download it and place it in the working directory.")

glove_embeddings = load_glove_embeddings(glove_path, embedding_dim)
print(f"Loaded {len(glove_embeddings)} word vectors from GloVe.")

embedding_matrix = np.zeros((vocab_size, embedding_dim), dtype=np.float32)
for word, idx in word_to_index.items():
    if idx == word_to_index["<PAD>"]:
        # Keep the padding row at zero so padded positions contribute no
        # signal; a random vector here would be fed to the LSTM as content.
        continue
    # The glove.6B vectors are uncased while the vocabulary keeps the original
    # casing, so fall back to the lowercase form before giving up.
    vector = glove_embeddings.get(word)
    if vector is None:
        vector = glove_embeddings.get(word.lower())
    if vector is None:
        # Token unknown to GloVe (including <UNK>): random initialisation.
        vector = np.random.normal(scale=0.6, size=(embedding_dim,))
    embedding_matrix[idx] = vector

# --------------------------------------
# 16. Define an Attention module.
# --------------------------------------
class Attention(nn.Module):
    """Learned attention pooling over the time steps of a bidirectional LSTM.

    A single linear layer scores every time step; the scores are normalised
    with a softmax over the sequence axis and used to form a weighted sum of
    the LSTM outputs.
    """

    def __init__(self, hidden_dim):
        super(Attention, self).__init__()
        # Input width is hidden_dim * 2: forward and backward outputs concatenated.
        self.attn = nn.Linear(hidden_dim * 2, 1)

    def forward(self, lstm_outputs, mask=None):
        """Pool ``lstm_outputs`` into one vector per sequence.

        Args:
            lstm_outputs: Tensor of shape (batch, seq_len, hidden_dim*2).
            mask: Optional (batch, seq_len) tensor, truthy at positions that
                may receive attention. ``None`` (the default) attends over
                every position.

        Returns:
            Tensor of shape (batch, hidden_dim*2).
        """
        scores = self.attn(lstm_outputs)  # (batch, seq_len, 1)
        if mask is not None:
            keep = mask.to(device=scores.device, dtype=torch.bool)
            # A row with no allowed position would softmax over -inf only and
            # produce NaN; leave such rows unmasked instead.
            keep = keep | ~keep.any(dim=1, keepdim=True)
            scores = scores.masked_fill(~keep.unsqueeze(-1), float("-inf"))
        weights = torch.softmax(scores, dim=1)  # (batch, seq_len, 1)
        context = torch.sum(weights * lstm_outputs, dim=1)  # (batch, hidden_dim*2)
        return context

# --------------------------------------
# 17. Define the LSTM model with attention.
# --------------------------------------
class LSTMClassifierWithAttention(nn.Module):
    """Bidirectional LSTM classifier that pools time steps with attention.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each embedding vector.
        hidden_dim: Hidden size of the LSTM per direction.
        num_layers: Number of stacked LSTM layers.
        output_dim: Number of output classes.
        dropout: Dropout rate used between LSTM layers and before the
            output layer.
        pretrained_embeddings: Optional array-like copied into the
            embedding table.
        freeze_embeddings: When pretrained embeddings are given, whether to
            exclude them from training.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, output_dim, dropout=0.3,
                 pretrained_embeddings=None, freeze_embeddings=False):
        super(LSTMClassifierWithAttention, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        if pretrained_embeddings is not None:
            self.embedding.weight.data.copy_(torch.tensor(pretrained_embeddings))
            self.embedding.weight.requires_grad = not freeze_embeddings
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True,
                            dropout=dropout, bidirectional=True)
        self.attention = Attention(hidden_dim)
        self.fc = nn.Linear(hidden_dim*2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return class logits for a batch of token-id sequences."""
        embedded = self.embedding(x)  # (batch, seq_len, emb_dim)
        lstm_out, _ = self.lstm(embedded)  # (batch, seq_len, hidden_dim*2)
        # Padding positions carry no content, so exclude them from the
        # attention distribution instead of letting them absorb weight.
        not_padding = x != self.embedding.padding_idx
        context = self.attention(lstm_out, mask=not_padding)  # (batch, hidden_dim*2)
        context = self.dropout(context)
        logits = self.fc(context)
        return logits

hidden_dim = 128
num_layers = 2
new_output_dim = num_classes  # Number of ICD groups (after grouping)
dropout = 0.3

# --------------------------------------
# 18. Initialize the model and load optional pretrained weights.
# --------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

model = LSTMClassifierWithAttention(vocab_size, embedding_dim, hidden_dim, num_layers, new_output_dim, dropout,
                                    pretrained_embeddings=embedding_matrix, freeze_embeddings=False)
model.to(device)
print(model)

pretrained_path = "trained_model_LSTM_Attention.pth"
if os.path.exists(pretrained_path):
    print("Loading pretrained weights from Medal model...")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    pretrained_state = torch.load(pretrained_path, map_location=device)
    model_dict = model.state_dict()
    # Transfer only tensors whose name and shape both match the current model.
    # NOTE(review): an embedding table is transferred whenever its shape matches,
    # even if the checkpoint was built with a different vocabulary - confirm.
    pretrained_state = {k: v for k, v in pretrained_state.items() if k in model_dict and v.size() == model_dict[k].size()}
    if pretrained_state:
        model_dict.update(pretrained_state)
        model.load_state_dict(model_dict)
        # Report how much was actually transferred so a silent mismatch is visible.
        print(f"Transferred {len(pretrained_state)} of {len(model_dict)} tensors from '{pretrained_path}'.")
        print("Pretrained weights loaded successfully (partial transfer if dimensions differ).")
    else:
        print("No checkpoint tensor matches the model by name and shape; fine-tuning will start from scratch.")
else:
    print("Pretrained model not found; fine-tuning will start from scratch.")

# --------------------------------------
# 19. Define loss, optimizer, and a learning rate scheduler.
# --------------------------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)
# Halve the learning rate after 5 epochs without improvement in the monitored
# value. The `verbose` argument is deprecated in PyTorch, so it is not passed;
# the current rate is available from optimizer.param_groups[0]["lr"].
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

# --------------------------------------
# 20. Define the training and evaluation loops.
# --------------------------------------
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    The model is switched to training mode and updated once per batch.

    Returns:
        ``(loss, accuracy)`` where both totals are divided by
        ``len(loader.dataset)``.
    """
    model.train()
    sample_count = len(loader.dataset)
    running_loss = 0
    running_hits = 0
    for batch_inputs, batch_targets in loader:
        batch_inputs = batch_inputs.to(device, non_blocking=True)
        batch_targets = batch_targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so a smaller final batch
        # does not skew the epoch total (assumes criterion returns a batch mean).
        running_loss += batch_loss.item() * batch_inputs.size(0)
        predicted = logits.argmax(dim=1)
        running_hits += predicted.eq(batch_targets).sum().item()
    return running_loss / sample_count, running_hits / sample_count

def evaluate_epoch(model, loader, criterion, device):
    """Score the model on ``loader`` without updating it.

    The model is switched to evaluation mode and gradients are disabled.

    Returns:
        ``(loss, accuracy)`` where both totals are divided by
        ``len(loader.dataset)``.
    """
    model.eval()
    sample_count = len(loader.dataset)
    running_loss = 0
    running_hits = 0
    with torch.no_grad():
        for batch_inputs, batch_targets in loader:
            batch_inputs = batch_inputs.to(device, non_blocking=True)
            batch_targets = batch_targets.to(device, non_blocking=True)
            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_targets)
            # Weight by batch size so partial batches count proportionally.
            running_loss += batch_loss.item() * batch_inputs.size(0)
            running_hits += logits.argmax(dim=1).eq(batch_targets).sum().item()
    return running_loss / sample_count, running_hits / sample_count

num_epochs = 50
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = evaluate_epoch(model, val_loader, criterion, device)
    # Compare the optimizer's learning rate around the scheduler step so every
    # reduction is logged explicitly, whatever the scheduler itself prints.
    lr_before = optimizer.param_groups[0]["lr"]
    scheduler.step(val_loss)
    lr_after = optimizer.param_groups[0]["lr"]
    if lr_after != lr_before:
        print(f"Learning rate changed from {lr_before:.2e} to {lr_after:.2e}.")
    print(f"Epoch {epoch+1}/{num_epochs}: Train Loss={train_loss:.4f}, Train Acc={train_acc*100:.2f}% | Val Loss={val_loss:.4f}, Val Acc={val_acc*100:.2f}%")

# --------------------------------------
# 21. Define an inference function.
# --------------------------------------
def predict_diagnosis(model, text, word_to_index, max_len, device, label_encoder):
    """Return the decoded label that ``model`` predicts for one free-text note.

    The note is tokenised with the module-level spaCy pipeline ``nlp``
    (punctuation and whitespace tokens are dropped), each token is mapped to
    its vocabulary index with ``"<UNK>"`` as the fallback, the index list is
    passed through ``pad_sequence_fn`` with ``max_len``, and the result is fed
    to ``model`` as a batch of one.

    Args:
        model: Classifier to query; it is switched to evaluation mode.
        text: Raw note text.
        word_to_index: Mapping from token string to vocabulary index; must
            contain the key ``"<UNK>"``.
        max_len: Length argument forwarded to ``pad_sequence_fn``.
        device: Device the input tensor is moved to.
        label_encoder: Object whose ``inverse_transform`` maps class indices
            back to labels.

    Returns:
        The label corresponding to the highest-scoring class.
    """
    indices = []
    for tok in nlp(text):
        if tok.is_punct or tok.is_space:
            continue
        indices.append(word_to_index.get(tok.text, word_to_index["<UNK>"]))

    padded = pad_sequence_fn(indices, max_len)
    batch = torch.tensor(padded, dtype=torch.long)
    # Add a leading batch dimension of size one before moving to the device.
    batch = batch.unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        scores = model(batch)

    best_class = scores.argmax(dim=1).item()
    decoded = label_encoder.inverse_transform([best_class])
    return decoded[0]

# --------------------------------------
# 22. Example inference on a sample note.
# --------------------------------------
# Short synthetic note used as a smoke test of the inference path.
sample_text = (
    "Patient presented with fever, cough, and shortness of breath. "
    "Chest X-ray revealed infiltrates consistent with pneumonia; "
    "appropriate treatment was initiated."
)
predicted_diagnosis = predict_diagnosis(
    model=model,
    text=sample_text,
    word_to_index=word_to_index,
    max_len=max_len,
    device=device,
    label_encoder=label_encoder,
)
print(f"\nFor the sample note, predicted diagnosis group is: {predicted_diagnosis}")

# --------------------------------------
# 23. Save the fine-tuned model.
# --------------------------------------
fine_tuned_model_path = "trained_models/models/lstm_finetuned_for_diagnosis_pred.pth"
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing the weights.
os.makedirs(os.path.dirname(fine_tuned_model_path), exist_ok=True)
# Only the state dict (parameters and buffers) is persisted, not the module object.
torch.save(model.state_dict(), fine_tuned_model_path)
print(f"Fine-tuned diagnosis prediction model saved to {fine_tuned_model_path}")

Merged dataset loaded successfully.
Columns in merged dataset:
Index(['subject_id', 'hadm_id', 'stay_id', 'icd_code', 'icd_version',
       'icd_title', 'intime', 'outtime', 'gender', 'race', 'arrival_transport',
       'disposition', 'note_id', 'note_type', 'note_seq', 'charttime',
       'storetime', 'text'],
      dtype='object')
--------------------------------------------------
Initial number of unique ICD groups (first 2 chars): 316
Number of unique ICD groups after mapping infrequent ones to 'Other': 245
Number of target diagnosis classes after grouping and mapping: 245
Cache length does not match current data length. Re-tokenizing.
