# Fine-Tuning LSTM + Self-Attention Model on MIMIC-IV Notes for Diagnosis Prediction

In this section, we fine-tune a previously trained LSTM + Self-Attention model (originally trained on the MeDAL dataset) on the MIMIC-IV clinical notes dataset for the task of **diagnosis prediction**.

We use discharge summaries and other clinical notes from the MIMIC-IV dataset and aim to predict one or more ICD diagnosis codes for each note. This is a **multi-label classification** problem, as each clinical note can be associated with multiple diagnoses.
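To make the multi-label framing concrete, here is a minimal sketch in plain Python. The helper names `multi_hot` and `predict_codes`, the toy three-code label space, and the 0.5 threshold are assumptions for illustration and are not part of the project code. Each note's codes become a 0/1 vector, and predictions come from an independent sigmoid per class rather than a single softmax over all classes.

```python
import math

def multi_hot(codes, class_to_idx):
    """Turn a list of ICD codes into a 0/1 vector; unknown codes are skipped."""
    vec = [0.0] * len(class_to_idx)
    for code in codes:
        if code in class_to_idx:
            vec[class_to_idx[code]] = 1.0
    return vec

def predict_codes(logits, idx_to_class, threshold=0.5):
    """Apply an independent sigmoid per class and keep codes at or above the threshold."""
    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    return [idx_to_class[i] for i, p in enumerate(probs) if p >= threshold]

# Toy label space (hypothetical); the real one has thousands of codes
class_to_idx = {"I269": 0, "I959": 1, "E106": 2}
idx_to_class = {i: c for c, i in class_to_idx.items()}

target = multi_hot(["I269", "I959"], class_to_idx)       # two diagnoses for one note
predicted = predict_codes([2.0, 0.3, -1.5], idx_to_class)  # made-up logits
print(target, predicted)
```

Because every class is scored independently, a note can receive zero, one, or several codes.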

Here’s a summary of the pipeline we built:
- Loaded and preprocessed the MIMIC-IV dataset, extracting tokenized summaries of up to 200 words.
- Used GloVe embeddings with an external vocabulary built from the training set.
- Created lazy loading dataloaders to efficiently process and batch data during training.
- Fine-tuned the LSTM + Self-Attention model by training on this new dataset, using the pretrained model weights from MeDAL as a starting point.

This approach helps us evaluate how well a model trained on biomedical abbreviation disambiguation (MeDAL) can transfer to a real-world clinical task like diagnosis prediction from notes.
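Since the MeDAL model and the MIMIC-IV model differ in their number of output classes, only some pretrained tensors can be reused. How `ModelTrainer` selects them is not shown in this notebook, so the sketch below is one common approach and not necessarily the project's: copy every tensor whose name and shape match, and leave the rest at their fresh initialization. The helper `load_matching_weights` and the toy `nn.Sequential` models are hypothetical stand-ins.

```python
import torch
import torch.nn as nn

def load_matching_weights(model, pretrained_state):
    """Copy only tensors whose name and shape match; return the names copied."""
    own_state = model.state_dict()
    matched = {
        name: tensor
        for name, tensor in pretrained_state.items()
        if name in own_state and own_state[name].shape == tensor.shape
    }
    own_state.update(matched)
    model.load_state_dict(own_state)
    return sorted(matched)

# Toy stand-ins: same first layer, different number of output classes
source = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 10))
target = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 3))

copied = load_matching_weights(target, source.state_dict())
print(copied)  # the output layer is skipped because its shape differs
```

Printing the list of copied names is a cheap way to confirm that the classification head was reinitialized rather than silently reused.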

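The MIMIC-IV dataset is distributed through PhysioNet. The link originally given [here](https://physionet.org/content/mimiciii-demo/1.4/) points to the MIMIC-III demo (v1.4) rather than to MIMIC-IV.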

In [1]:
%load_ext autoreload
%autoreload 2
%run ../setup.py

from src.data.mimic import MIMIC_IV
from env import ProjectPaths
import torch
import pandas as pd
import yaml
from torch.utils.data import Dataset, DataLoader
from src.models.trainer import ModelTrainer
from src.vectorizer.glove_embeddings import GloVeEmbedding
import pyarrow.parquet as pq
import numpy as np
from tqdm import tqdm
from collections import Counter

Environment set up: sys.path updated, working dir set to project root.


[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/prashanthjaganathan/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     /Users/prashanthjaganathan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.


#### Download and load the dataset

In [2]:
mimic_dataset = MIMIC_IV('mimic-iv')
data = mimic_dataset.load_dataset()
class_to_idx, icd_to_class, icd_to_classes = mimic_dataset.convert_class_to_idx()
mimic_dataset.group_multilabel_data()
train_data, val_data, test_data = mimic_dataset.split_dataset()
train_data.head()

MIMIC-IV dataset initialized with name: mimic-iv
Saved sklearn-computed class weights for 3370 classes to /home/jaganathan.p/pretaining-language-models-for-medical-text/dataset/MIMIC-IV/class_weights.json


Unnamed: 0,subject_id,hadm_id,stay_id,text,icd_code,icd_title
0,10733714,20907974,33005972,\nName: ___ Unit No: ___\n ...,"[I269, I959]","[Hypotension, unspecified, Other pulmonary emb..."
1,18171767,22031497,38870680,\nName: ___ Unit No: ___\...,[E106],[Type 1 diabetes mellitus with hyperglycemia]
2,13084154,22214693,30334417,\nName: ___ Unit No: ...,"[R079, R109]","[Chest pain, unspecified, Unspecified abdomina..."
3,12872503,24929324,38402167,\nName: ___ Unit No: ___\...,"[276, 812, 401, E888]","[HYPOKALEMIA, FX UPPER HUMERUS NEC-CL, HYPERTE..."
4,12795168,22548131,33831755,\nName: ___ Unit No: ...,[434],[CEREBRAL ART OCCLUS W/INFARCT]


#### Statistics on the dataset

In [4]:
# Flatten the list of lists into a single list of all ICD codes
all_codes = [code for codes in mimic_dataset.data['icd_code'] for code in codes]

# Count occurrences of each ICD code
code_counts = Counter(all_codes)

# Convert to DataFrame
grouped = pd.DataFrame(code_counts.items(), columns=['icd_code', 'count'])
grouped = grouped.sort_values(by='count', ascending=False)

# Print statistics
print(f"Total classes: {len(grouped)}")
print(f"Total samples: {len(mimic_dataset.data)}")
print(f"Mean samples/class: {grouped['count'].mean():.2f}")
print(f"Median samples/class: {grouped['count'].median():.2f}")
print(f"Min samples/class: {grouped['count'].min()}")
print(f"Max samples/class: {grouped['count'].max()}")

# Print top and bottom 10 classes
print("\nTop 10 most frequent classes:")
print(grouped.head(10).to_string(index=False))

print("\nBottom 10 least frequent classes:")
print(grouped.tail(10).to_string(index=False))

Total classes: 3370
Total samples: 154704
Mean samples/class: 88.93
Median samples/class: 4.00
Min samples/class: 1
Max samples/class: 11440

Top 10 most frequent classes:
icd_code  count
     780  11440
     401  10440
     250   7267
     789   6979
     786   6668
    R060   4759
    R109   4261
     486   3974
     599   3965
    R079   3745

Bottom 10 least frequent classes:
icd_code  count
    K297      1
    T348      1
    W501      1
    Y793      1
    M799      1
    B690      1
    V224      1
    G248      1
    H819      1
    G032      1
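The counts above are heavily skewed (median 4 samples per class against a maximum of 11,440), which matters for any per-class loss weighting. As a rough illustration, the sketch below computes the negatives-to-positives ratio that the PyTorch documentation suggests for the `pos_weight` argument of `BCEWithLogitsLoss`. This is not the sklearn-based weighting the project saves to `class_weights.json`; the helper `pos_weights` and the cap of 50 are assumptions chosen only to show how extreme the raw ratios get for codes that occur once.

```python
def pos_weights(code_counts, num_samples, max_weight=50.0):
    """Negatives/positives ratio per class, clamped to max_weight (hypothetical cap)."""
    weights = {}
    for code, positives in code_counts.items():
        negatives = num_samples - positives
        weights[code] = min(negatives / positives, max_weight)
    return weights

# Three counts taken from the statistics printed above
counts = {"780": 11440, "401": 10440, "G032": 1}

raw = pos_weights(counts, num_samples=154704, max_weight=float("inf"))
capped = pos_weights(counts, num_samples=154704)

for code in counts:
    print(f"{code}: raw={raw[code]:.1f} capped={capped[code]:.1f}")
```

An uncapped weight in the hundred-thousands multiplies the loss on a single positive example by the same factor, so very large weights are worth ruling out when training becomes unstable.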


#### Load configurations

In [5]:
%load_ext autoreload
%autoreload 2

def load_config(path):
    with open(path, 'r') as f:
        return yaml.safe_load(f)

config = load_config('config/config.yaml')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### Preprocess the train and validation sets, tokenize them with a whitespace tokenizer, and store the tokens in `.parquet` files

In [13]:
preprocessed_train = mimic_dataset.preprocess(splits=['train'], summary_len=200)
preprocessed_train = preprocessed_train[preprocessed_train['text'].str.strip().astype(bool)]
preprocessed_train.reset_index(drop=True, inplace=True)

mimic_dir = ProjectPaths.DATASET_DIR.value / 'MIMIC-IV' / 'preprocessed_subset'
mimic_dir.mkdir(parents=True, exist_ok=True)
preprocessed_train.to_csv(mimic_dir / 'train.csv', index=False)
preprocessed_train.head()

Unnamed: 0,subject_id,hadm_id,stay_id,text,icd_code,icd_title
0,15642508,25781750,39789671,past medical history significant aortic valve ...,[R509],"[Fever, unspecified]"
1,10149765,26535625,36381439,name unit admission date discharge date date b...,[N631],"[Unspecified lump in the right breast, unspeci..."
2,12166185,26271007,36468972,medicine nightfloat admission note admission d...,[K838],[Other specified diseases of biliary tract]
3,19979469,22114071,36419793,hepatitis c develop obstructive jaundice sever...,"[157, 780, 070]",[UNSPECIFIED VIRAL HEPATITIS C WITHOUT HEPATIC...
4,17691221,29996653,35619578,old female significant pmh right mca stroke sp...,[N179],"[Acute kidney failure, unspecified]"


In [14]:
preprocessed_val = mimic_dataset.preprocess(splits=['valid'], summary_len=200)
preprocessed_val = preprocessed_val[preprocessed_val['text'].str.strip().astype(bool)]
preprocessed_val.reset_index(drop=True, inplace=True)


mimic_dir = ProjectPaths.DATASET_DIR.value / 'MIMIC-IV' / 'preprocessed_subset'
mimic_dir.mkdir(parents=True, exist_ok=True)
preprocessed_val.to_csv(mimic_dir / 'valid.csv', index=False)
preprocessed_val.head()


Unnamed: 0,subject_id,hadm_id,stay_id,text,icd_code,icd_title
0,15401792,27111495,35301659,yo f history l elbow surgery nstemi sp pci 2 s...,[I269],[Other pulmonary embolism without acute cor pu...
1,12429688,22424939,31874067,right handed woman significant past medical hi...,"[787, 780]","[ALTERED MENTAL STATUS , NAUSEA WITH VOMITING,..."
2,18280004,20198748,34991331,parkinsonism recently admit neurology worsen l...,[780],[ALTERED MENTAL STATUS ]
3,14913511,20239006,32165170,ms female past medical history notable afib wa...,"[S121, W183]","[Posterior displaced Type II dens fracture, in..."
4,13878740,20989835,39134102,mr year old man nash cirrhosis cb hepatic ence...,[R418],"[Altered mental status, unspecified]"


In [15]:
train_tokens = mimic_dataset.tokenize(tokenizer_type='whitespace', splits=['train'])

if isinstance(train_tokens, pd.Series):
    df = pd.DataFrame(train_tokens)
    df.columns = ['CONTEXT']
else:
    df = pd.DataFrame(train_tokens, columns=['CONTEXT'])

# Save as a Parquet file
mimic_dir = ProjectPaths.DATASET_DIR.value / 'MIMIC-IV' / 'whitespace_tokenized_200ctx_subset'
mimic_dir.mkdir(parents=True, exist_ok=True)

file_name = "dataset/MIMIC-IV/whitespace_tokenized_200ctx_subset/train.parquet"
df.to_parquet(file_name)
print('Parquet file saved successfully!')

Parquet file saved successfully!


In [16]:
val_tokens = mimic_dataset.tokenize(tokenizer_type='whitespace', splits=['valid'])

if isinstance(val_tokens, pd.Series):
    df = pd.DataFrame(val_tokens)
    df.columns = ['CONTEXT']
else:
    df = pd.DataFrame(val_tokens, columns=['CONTEXT'])

file_name = "dataset/MIMIC-IV/whitespace_tokenized_200ctx_subset/valid.parquet"
df.to_parquet(file_name)
print('Parquet file saved successfully!')

Parquet file saved successfully!


In [6]:
train_tokens = pd.read_parquet("dataset/MIMIC-IV/whitespace_tokenized_200ctx_subset/train.parquet", engine="pyarrow")
# to make it as a list[list[str]]
tokenized_train_corpus = [context.tolist() for context in tqdm(train_tokens['CONTEXT'], 'Docs', len(train_tokens['CONTEXT']))] 
print(f'Number of documents in train corpus: {len(tokenized_train_corpus)}')

Docs: 100%|██████████| 123763/123763 [00:00<00:00, 153429.34it/s]

Number of documents in train corpus: 123763





#### Create vocabulary from the train set

In [7]:
from collections import Counter

min_word_freq = 2
counter = Counter()

for tokens in tokenized_train_corpus:
    counter.update(tokens)

external_vocab = {"<PAD>": 0, "<UNK>": 1}
for token, freq in counter.items():
    if freq >= min_word_freq:
        external_vocab[token] = len(external_vocab)

print(f"External vocab size: {len(external_vocab)}")

External vocab size: 52445
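To show what this vocabulary does downstream, here is a small self-contained sketch of the same construction on a toy corpus, followed by a lookup that maps out-of-vocabulary tokens to `<UNK>`. The `encode` helper is hypothetical; the project's `GloVeEmbedding.token_indices` presumably plays this role, but its implementation is not shown in this notebook.

```python
from collections import Counter

def build_vocab(corpus, min_word_freq=2):
    """Same rule as above: keep tokens seen at least min_word_freq times."""
    counter = Counter(tok for doc in corpus for tok in doc)
    vocab = {"<PAD>": 0, "<UNK>": 1}
    for token, freq in counter.items():
        if freq >= min_word_freq:
            vocab[token] = len(vocab)
    return vocab

def encode(tokens, vocab):
    """Map tokens to indices, falling back to <UNK> for unseen or rare tokens."""
    return [vocab.get(tok, vocab["<UNK>"]) for tok in tokens]

corpus = [["chest", "pain", "fever"], ["chest", "pain", "cough"]]
vocab = build_vocab(corpus)
ids = encode(["chest", "fever", "pain"], vocab)
print(vocab, ids)
```

With `min_word_freq = 2`, tokens that appear only once ("fever" and "cough" here) share the `<UNK>` index, which keeps the embedding table small at the cost of losing rare terms.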


In [8]:
config_glove = config['embedding_models']['glove']
config_glove['external_vocab'] = external_vocab  # Pass the external vocabulary

### Model Training

First, let's create dataloaders that yield, for each note, the padded token indices (or embeddings), a mask marking the valid positions, and a multi-hot label vector.

In [9]:
class LazyEmbeddingDataset(Dataset):
    def __init__(
            self, 
            file_path, 
            embedding_model, 
            labels,  # list of list of labels
            class_to_idx, 
            return_tokens=True,
            max_seq_len=None):
        """
        Args:
            file_path (str): Path to the Parquet file containing the tokenized text.
            embedding_model: The custom embedding model (e.g., GloVeEmbedding).
            labels (List[List[str]]): Multi-label list for each document.
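            class_to_idx (dict): Mapping from class label to integer index.
            return_tokens (bool): If True, use `embedding_model.token_indices`; otherwise use `embedding_model.embed`.
            max_seq_len (int, optional): Max sequence length for padding/truncating.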
        """
        self.file_path = file_path
        table = pq.read_table(self.file_path)
        self.tokenized_corpus = table['CONTEXT']
        self.embedding_model = embedding_model
        self.labels = labels
        self.class_to_idx = class_to_idx
        self.max_seq_len = max_seq_len
        self.return_tokens = return_tokens
        self.num_classes = len(class_to_idx)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        tokens = self.tokenized_corpus[idx].as_py()

        # Get token indices or embeddings
        if self.return_tokens:
            embeddings = self.embedding_model.token_indices(tokens)
        else:
            embeddings = self.embedding_model.embed(tokens)

        # Truncate to max_seq_len (when set), then right-pad with 0 (the <PAD> index)
        if self.max_seq_len is not None:
            embeddings = embeddings[:self.max_seq_len]
        seq_len = len(embeddings)
        pad_len = 0 if self.max_seq_len is None else self.max_seq_len - seq_len
        if pad_len > 0:
            embeddings = embeddings + [0] * pad_len
        embeddings_np = np.array(embeddings, dtype=np.float32)

        # Mask to indicate valid tokens (1 = token, 0 = padding).
        # seq_len is measured after truncation so the mask always matches the sequence length.
        mask = np.concatenate([
            np.ones(seq_len, dtype=np.float32),
            np.zeros(pad_len, dtype=np.float32),
        ])

        # Convert label list into multi-hot vector
        label_list = self.labels[idx]
        label_vector = torch.zeros(self.num_classes, dtype=torch.float32)
        for label in label_list:
            if label in self.class_to_idx:
                label_vector[self.class_to_idx[label]] = 1.0

        return (
            torch.tensor(embeddings_np, dtype=torch.float32), 
            torch.tensor(mask, dtype=torch.bool),
            label_vector
        )


def create_lazy_dataloader(file_path, embedding_model, labels, class_to_idx, batch_size, max_seq_len=None, shuffle=False):
    dataset = LazyEmbeddingDataset(file_path, embedding_model, labels, class_to_idx, max_seq_len=max_seq_len, return_tokens=True)
    # return dataset
    return DataLoader(
        dataset, 
        batch_size=batch_size, 
        shuffle=shuffle, 
        num_workers=0, 
        pin_memory=True
        )
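The contract the dataset relies on is that the padded sequence and its mask always have the same length, for both short and long notes. The following plain-Python sketch isolates that logic so it can be checked without loading any data. The function name `pad_and_mask` and the toy inputs are hypothetical and not part of the project code.

```python
def pad_and_mask(token_ids, max_seq_len, pad_id=0):
    """Truncate or right-pad to max_seq_len and return (ids, mask) of equal length."""
    token_ids = list(token_ids)[:max_seq_len]
    n_valid = len(token_ids)
    n_pad = max_seq_len - n_valid
    padded = token_ids + [pad_id] * n_pad
    mask = [True] * n_valid + [False] * n_pad
    return padded, mask

# A note shorter than the limit is padded; a longer one is truncated
short_ids, short_mask = pad_and_mask([5, 9, 4], max_seq_len=5)
long_ids, long_mask = pad_and_mask([5, 9, 4, 7, 8, 6], max_seq_len=4)

print(short_ids, short_mask)
print(long_ids, long_mask)
```

Counting valid positions after truncation is the detail that keeps the mask aligned when a note exceeds `max_seq_len`.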

In [10]:
glove_embedding_model = GloVeEmbedding(**config['embedding_models']['glove'])

max_seq_len = config['datasets']['mimic-iv']['max_sequence_length']
batch_size = config['training']['hyperparameters']['batch_size']

trainloader = create_lazy_dataloader(
    'dataset/MIMIC-IV/whitespace_tokenized_200ctx_subset/train.parquet', 
    glove_embedding_model, 
    mimic_dataset.train_data['icd_code'],
    class_to_idx, 
    batch_size=batch_size,
    max_seq_len=max_seq_len,
    shuffle=True
)


valloader = create_lazy_dataloader(
    'dataset/MIMIC-IV/whitespace_tokenized_200ctx_subset/valid.parquet', 
    glove_embedding_model, 
    mimic_dataset.val_data['icd_code'],
    class_to_idx, 
    batch_size=batch_size,
    max_seq_len=max_seq_len,
    shuffle = False
    )

#### Execute training pipeline

In [13]:
model_trainer = ModelTrainer(
    config_file='config.yaml',
    pretrained_model_path=ProjectPaths.PROJECT_DIR.value / 'trained_models/models/medal_glove_200ctx_lstm_and_self_attention_model_model.pth'
    )

train_results = model_trainer.train(
    trainloader,
    valloader,
    dataset='mimic-iv',
    embedding_dim=100,
    embedding_model = glove_embedding_model
)

Using pretrained model for training


  self.pretrained_model = torch.load(


------- lstm_and_self_attention --------
{'lstm_units': 2, 'lstm_hidden_dim': 128, 'num_attention_heads': 16, 'dropout': 0.3, 'num_classes': 3370, 'embedding_dim': 100, 'create_embedding_layer': True, 'embedding_model': <src.vectorizer.glove_embeddings.GloVeEmbedding object at 0x2ac4cdc7fbb0>}
Loaded 24 parameters from pretrained model.


Training:   2%|▏         | 32/1934 [00:25<25:05,  1.26it/s]

torch.Size([64, 200, 100])
[NaN DETECTED] NaN found in input tensor x_residual. Shape: torch.Size([64, 200, 100])





AttributeError: 'NoneType' object has no attribute 'size'

## ⚠️ Current Issue: NaNs During Fine-Tuning on MIMIC-IV Notes

While fine-tuning the pretrained LSTM + Self-Attention model (originally trained on the MeDAL dataset) for diagnosis prediction on MIMIC-IV notes, training fails about 32 batches into the 1,934-batch training loop. NaNs appear during the forward pass, specifically in `x_residual`, the input tensor to the feedforward layer, and the run then stops with `AttributeError: 'NoneType' object has no attribute 'size'`.

### What I've Tried:
- Applied gradient clipping to guard against exploding gradients.
- Verified that the input tensors are normalized and contain no NaNs.
- Applied class weights in the loss function (`BCEWithLogitsLoss`) to handle class imbalance.
- Checked model output and target sizes; both match as expected.
- Inspected the weights and noticed values shrinking (to ~0.01), which may indicate vanishing gradients.
- Tried clamping gradients to a minimum/maximum threshold, with no effect.
- Monitored the weights to see when the values explode or vanish. This seems to happen suddenly, within just a few steps.
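One further diagnostic that may help narrow this down is to locate the first submodule whose output stops being finite, rather than the point where the NaN is finally noticed. The sketch below does this with PyTorch forward hooks. The helper `attach_nan_hooks`, the `BadLayer` module, and the toy model are hypothetical and stand in for the real LSTM + Self-Attention model, whose code is not shown here.

```python
import torch
import torch.nn as nn

def attach_nan_hooks(model):
    """Register forward hooks that record the first submodule emitting NaN/Inf."""
    report = {"first_bad_module": None}
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            tensors = output if isinstance(output, (tuple, list)) else (output,)
            for t in tensors:
                if (torch.is_tensor(t) and t.is_floating_point()
                        and not torch.isfinite(t).all()
                        and report["first_bad_module"] is None):
                    report["first_bad_module"] = name
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module itself
            handles.append(module.register_forward_hook(make_hook(name)))
    return report, handles

class BadLayer(nn.Module):
    """Toy layer that always produces NaN (log of a negative number)."""
    def forward(self, x):
        return torch.log(-torch.abs(x) - 1.0)

model = nn.Sequential(nn.Linear(4, 4), BadLayer(), nn.Linear(4, 2))
report, handles = attach_nan_hooks(model)
model(torch.randn(3, 4))
for h in handles:
    h.remove()  # detach the hooks once the check is done

print(report["first_bad_module"])
```

The hooks only cover the forward pass; `torch.autograd.set_detect_anomaly(True)` plays a similar role for the backward pass, at a noticeable cost in speed.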

### Status
Still debugging. The root cause is unclear; I suspect instability in the residual connection or the attention layer when the model is initialized with pretrained weights.
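One hypothesis worth ruling out for the attention layer is a fully masked sequence. If a note has no valid tokens after preprocessing and the attention scores are masked with negative infinity, the softmax over that row is undefined and yields NaN, which then spreads through the residual path. Whether such rows actually reach the model here is not known, and how the model applies its mask is not shown in this notebook, so the sketch below is a minimal illustration of the failure mode on made-up tensors.

```python
import torch

# Attention scores for 2 sequences over 4 key positions (toy values)
scores = torch.zeros(2, 4)
mask = torch.tensor([[True, True, False, False],
                     [False, False, False, False]])  # second row has no valid tokens

# Common masking pattern: padded positions get -inf before the softmax
masked = scores.masked_fill(~mask, float("-inf"))
weights = torch.softmax(masked, dim=-1)

# Indices of sequences with no valid positions at all
empty_rows = (~mask.any(dim=-1)).nonzero(as_tuple=True)[0].tolist()

print(weights)
print("fully masked rows:", empty_rows)
```

Running the same `mask.any(dim=-1)` check on each batch from `trainloader` before the forward pass would show quickly whether this case occurs.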

If you have any ideas or suggestions for fixing this, I'd really appreciate your help!


In [None]:
model_trainer.plot_results(train_results)