<a href="https://colab.research.google.com/github/NoureldinAyman/Drug-Recommendation/blob/nour/V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Personalized Drug Recommendation

## Dataset
The dataset is MIMIC-IV.


It is organized into four modules:
- [hosp](https://mimic.mit.edu/docs/iv/modules/hosp) - hospital-level data for patients: labs, microbiology, and electronic medication administration
- [icu](https://mimic.mit.edu/docs/iv/modules/icu) - ICU-level data. These are the event tables, and they are identical in structure to MIMIC-III (chartevents, etc.)
- [ed](https://mimic.mit.edu/docs/iv/modules/ed) - data from the emergency department
- [note](https://mimic.mit.edu/docs/iv/modules/note) - de-identified free-text clinical notes

## Data preprocessing


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


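Set the NumPy random seed for reproducibility. `sample_frac` is defined here but is not used below; the tables are instead limited with `nrows=1000` when they are loaded.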

In [None]:
np.random.seed(42)
sample_frac = 0.10

Tables used:
1. `omr`
   - The Online Medical Record (OMR) table contains miscellaneous information from the EHR.
2. `admissions`
   - Detailed information about hospital stays.
3. `d_labitems` (loading is currently commented out)
   - Dimension table for `labevents`; provides a description of all lab items.
4. `diagnoses_icd`
   - Billed ICD-9/ICD-10 diagnoses for hospitalizations.
5. `labevents`
   - Laboratory measurements sourced from patient-derived specimens.
6. `microbiologyevents`
   - Microbiology cultures.
7. `patients`
   - Patients' gender, age, and date of death if available.
8. `prescriptions`
   - Prescribed medications.

### Loading datasets

In [None]:
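# Load datasets with only the relevant columns.
# nrows=1000 reads just the first 1000 rows of each file, so the subsets are not
# guaranteed to cover the same patients or admissions.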
admissions = pd.read_csv("/content/drive/MyDrive/Data/hosp/admissions.csv.gz",
                         compression="gzip",
                         usecols=["subject_id", "hadm_id", "admittime", "admission_type",
                                  "admission_location", "insurance", "language", "marital_status", "race"],
                         nrows=1000)

patients = pd.read_csv("/content/drive/MyDrive/Data/hosp/patients.csv.gz",
                       compression="gzip",
                       usecols=["subject_id", "gender", "anchor_age"],
                       nrows=1000)

omr = pd.read_csv("/content/drive/MyDrive/Data/hosp/omr.csv.gz",
                  compression="gzip",
                  usecols=["subject_id", "chartdate", "seq_num", "result_name", "result_value"],
                  nrows=1000)

labevents = pd.read_csv("/content/drive/MyDrive/Data/hosp/labevents.csv.gz",
                        compression="gzip",
                        usecols=["subject_id", "hadm_id", "itemid", "charttime", "valuenum", "valueuom", "flag"],
                        nrows=1000)

# d_labitems = pd.read_csv("/content/drive/MyDrive/Data/hosp/d_labitems.csv.gz",
#                          compression="gzip",
#                          usecols=["itemid", "label"])

microbiologyevents = pd.read_csv("/content/drive/MyDrive/Data/hosp/microbiologyevents.csv.gz",
                                 compression="gzip",
                                 usecols=["subject_id", "hadm_id", "charttime", "spec_type_desc", "org_name"],
                                 nrows=1000)

diagnoses_icd = pd.read_csv("/content/drive/MyDrive/Data/hosp/diagnoses_icd.csv.gz",
                            compression="gzip",
                            usecols=["subject_id", "hadm_id", "icd_code", "icd_version"],
                            nrows=1000)

prescriptions = pd.read_csv("/content/drive/MyDrive/Data/hosp/prescriptions.csv.gz",
                            compression="gzip",
                            usecols=["subject_id", "hadm_id", "drug", "dose_val_rx", "dose_unit_rx",
                                     "starttime", "stoptime"],
                            nrows=1000)
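The `read_csv` calls above differ only in file name and column list. A minimal sketch of a table-spec loop, assuming the files follow the `<name>.csv.gz` naming used above; `TABLE_COLUMNS` and `load_tables` are hypothetical names, and the demo writes tiny synthetic files to a temporary directory instead of reading MIMIC-IV.

```python
import os
import tempfile

import pandas as pd

# Hypothetical spec: table name -> columns to keep (a subset of the calls above)
TABLE_COLUMNS = {
    "patients": ["subject_id", "gender", "anchor_age"],
    "diagnoses_icd": ["subject_id", "hadm_id", "icd_code", "icd_version"],
}


def load_tables(data_dir, spec, nrows=None):
    """Read <data_dir>/<name>.csv.gz for each table, keeping only the listed columns."""
    return {
        name: pd.read_csv(os.path.join(data_dir, f"{name}.csv.gz"),
                          compression="gzip", usecols=cols, nrows=nrows)
        for name, cols in spec.items()
    }


# Demo on tiny synthetic files written to a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    pd.DataFrame({"subject_id": [1, 2, 3], "gender": ["F", "M", "F"],
                  "anchor_age": [40, 55, 70], "dod": [None, None, None]}).to_csv(
        os.path.join(tmp, "patients.csv.gz"), index=False, compression="gzip")
    pd.DataFrame({"subject_id": [1, 1, 2], "hadm_id": [10, 10, 20], "seq_num": [1, 2, 1],
                  "icd_code": ["I10", "E119", "J189"], "icd_version": [10, 10, 10]}).to_csv(
        os.path.join(tmp, "diagnoses_icd.csv.gz"), index=False, compression="gzip")
    tables = load_tables(tmp, TABLE_COLUMNS, nrows=2)

for name, df in tables.items():
    print(name, df.shape)
```

Adding a table then only means adding one entry to the spec, and the same `nrows` cap applies to every file.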

In [None]:
print(f"admissions shape: {admissions.shape}")
print(f"patients shape: {patients.shape}")
print(f"omr shape: {omr.shape}")
print(f"labevents shape: {labevents.shape}")
# print(f"d_labitems shape: {d_labitems.shape}")
print(f"microbiologyevents shape: {microbiologyevents.shape}")
print(f"diagnoses_icd shape: {diagnoses_icd.shape}")
print(f"prescriptions shape: {prescriptions.shape}")

admissions shape: (1000, 9)
patients shape: (1000, 3)
omr shape: (1000, 5)
labevents shape: (1000, 7)
microbiologyevents shape: (1000, 5)
diagnoses_icd shape: (1000, 4)
prescriptions shape: (1000, 7)
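`nrows=1000` takes the first rows of each file independently, so an admission may be loaded without its matching diagnoses or prescriptions. A sketch of one way `sample_frac` could be used instead: draw a fraction of `subject_id`s once and keep only those rows in every table. `sample_subjects` and `filter_to_subjects` are hypothetical helpers, and the frames are toy stand-ins.

```python
import numpy as np
import pandas as pd


def sample_subjects(base: pd.DataFrame, frac: float, seed: int = 42) -> np.ndarray:
    """Randomly draw `frac` of the unique subject_ids in `base` (without replacement)."""
    rng = np.random.default_rng(seed)
    subjects = base["subject_id"].unique()
    n = max(1, int(round(frac * len(subjects))))
    return rng.choice(subjects, size=n, replace=False)


def filter_to_subjects(table: pd.DataFrame, subjects: np.ndarray) -> pd.DataFrame:
    """Keep only the rows of `table` that belong to the sampled subjects."""
    return table[table["subject_id"].isin(subjects)].reset_index(drop=True)


# Toy stand-ins: 100 subjects, 2 admissions and 3 prescriptions each
admissions_toy = pd.DataFrame({"subject_id": np.arange(100).repeat(2),
                               "hadm_id": np.arange(200)})
prescriptions_toy = pd.DataFrame({"subject_id": np.arange(100).repeat(3),
                                  "drug": ["A", "B", "C"] * 100})

subjects = sample_subjects(admissions_toy, frac=0.10)
adm_sample = filter_to_subjects(admissions_toy, subjects)
rx_sample = filter_to_subjects(prescriptions_toy, subjects)
print(len(subjects), len(adm_sample), len(rx_sample))  # 10 20 30
```

For the large files (e.g. `labevents`), the same filter can be applied chunk by chunk by passing `chunksize=` to `pd.read_csv` and concatenating the filtered chunks.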


### Merging the datasets

Starting from the base admissions table, I'm going to:
- Create a copy of the base admissions table.
- Merge it with `patients` on `subject_id`.
- Merge with `diagnoses_icd` on `subject_id` and `hadm_id`.
- Merge with `prescriptions` on `subject_id` and `hadm_id`.
- Merge with `labevents` on `subject_id` and `hadm_id`.
- Merge with `microbiologyevents` on `subject_id` and `hadm_id`.
- Merge with `omr` on `subject_id`.

All merges are left joins, so every record in the admissions table is preserved.

In [None]:
# Start with admissions as the base table
dataset = admissions.copy()

# Merge with patients using subject_id
dataset = dataset.merge(patients, on='subject_id', how='left')

# Merge with diagnoses_icd using subject_id and hadm_id
dataset = dataset.merge(diagnoses_icd, on=['subject_id', 'hadm_id'], how='left')

# Merge with prescriptions using subject_id and hadm_id
dataset = dataset.merge(prescriptions, on=['subject_id', 'hadm_id'], how='left')

# Merge with labevents using subject_id and hadm_id
dataset = dataset.merge(labevents, on=['subject_id', 'hadm_id'], how='left')

# Merge with microbiologyevents using subject_id and hadm_id
dataset = dataset.merge(microbiologyevents, on=['subject_id', 'hadm_id'], how='left')

# Merge with omr using subject_id
dataset = dataset.merge(omr, on='subject_id', how='left')

Validate the result: check the merged shape and confirm that the join keys have no missing values.


In [None]:
print("Shape of merged data:", dataset.shape)
print("Missing values in key columns:")
print(dataset[['subject_id', 'hadm_id']].isnull().sum())

Shape of merged data: (7052413, 30)
Missing values in key columns:
subject_id    0
hadm_id       0
dtype: int64
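The merged table has 7,052,413 rows although each input was capped at 1,000 rows. Chaining one-to-many left joins on the same keys pairs every diagnosis with every prescription, lab event, and so on, so the row count per admission is the product of the child-table counts. The toy sketch below (made-up data) reproduces that and shows one alternative: aggregate each child table to one row per admission before merging. The `validate=` argument makes pandas raise if a merge has an unexpected cardinality.

```python
import pandas as pd

keys = ["subject_id", "hadm_id"]
adm = pd.DataFrame({"subject_id": [1], "hadm_id": [10]})
dx = pd.DataFrame({"subject_id": [1, 1], "hadm_id": [10, 10], "icd_code": ["I10", "E119"]})
rx = pd.DataFrame({"subject_id": [1] * 3, "hadm_id": [10] * 3, "drug": ["A", "B", "C"]})
lab = pd.DataFrame({"subject_id": [1] * 4, "hadm_id": [10] * 4,
                    "valuenum": [1.0, 2.0, 3.0, 4.0]})

# Chained left joins, as in the notebook: 2 diagnoses x 3 drugs x 4 labs = 24 rows
chained = (adm.merge(dx, on=keys, how="left", validate="one_to_many")
              .merge(rx, on=keys, how="left")
              .merge(lab, on=keys, how="left"))
print(len(chained))  # 24

# Alternative: collapse each child table to one row per admission first
dx_agg = dx.groupby(keys)["icd_code"].agg(list).reset_index()
rx_agg = rx.groupby(keys)["drug"].agg(list).reset_index()
lab_agg = lab.groupby(keys)["valuenum"].mean().reset_index()
compact = (adm.merge(dx_agg, on=keys, how="left", validate="one_to_one")
              .merge(rx_agg, on=keys, how="left", validate="one_to_one")
              .merge(lab_agg, on=keys, how="left", validate="one_to_one"))
print(len(compact))  # 1
```

The lab mean here is only illustrative; real lab values would more sensibly be aggregated per `itemid`.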


In [None]:
dataset.columns

Index(['subject_id', 'hadm_id', 'admittime', 'admission_type',
       'admission_location', 'insurance', 'language', 'marital_status', 'race',
       'gender', 'anchor_age', 'icd_code', 'icd_version', 'starttime',
       'stoptime', 'drug', 'dose_val_rx', 'dose_unit_rx', 'itemid',
       'charttime_x', 'valuenum', 'valueuom', 'flag', 'charttime_y',
       'spec_type_desc', 'org_name', 'chartdate', 'seq_num', 'result_name',
       'result_value'],
      dtype='object')
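`charttime_x` and `charttime_y` are pandas' default suffixes for the two `charttime` columns: the first comes from `labevents`, the second from `microbiologyevents`. Passing `suffixes=` to `merge` gives them self-describing names. A toy sketch with made-up timestamps:

```python
import pandas as pd

keys = ["subject_id", "hadm_id"]
base = pd.DataFrame({"subject_id": [1], "hadm_id": [10]})
lab = pd.DataFrame({"subject_id": [1], "hadm_id": [10], "charttime": ["2180-05-07 00:10:00"]})
micro = pd.DataFrame({"subject_id": [1], "hadm_id": [10], "charttime": ["2180-05-08 09:00:00"]})

# The first merge adds lab's charttime unchanged; the second merge hits the name clash,
# so the left copy gets "_lab" and the right copy gets "_micro"
merged = (base.merge(lab, on=keys, how="left")
              .merge(micro, on=keys, how="left", suffixes=("_lab", "_micro")))
print(list(merged.columns))
# ['subject_id', 'hadm_id', 'charttime_lab', 'charttime_micro']
```

Both renamed columns still contain "time", so the time-column detection used later would still pick them up.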

In [None]:
dataset.shape

(7052413, 30)

### Encoding/Scaling

In [None]:
numerical_cols = dataset.select_dtypes(include=["number"]).columns
print("Numerical columns:")
print(numerical_cols)

categorical_cols = dataset.select_dtypes(include=["object", "category"]).columns
print("Categorical columns:")
print(categorical_cols)

Numerical columns:
Index(['subject_id', 'hadm_id', 'anchor_age', 'icd_version', 'itemid',
       'valuenum', 'seq_num'],
      dtype='object')
Categorical columns:
Index(['admittime', 'admission_type', 'admission_location', 'insurance',
       'language', 'marital_status', 'race', 'gender', 'icd_code', 'starttime',
       'stoptime', 'drug', 'dose_val_rx', 'dose_unit_rx', 'charttime_x',
       'valueuom', 'flag', 'charttime_y', 'spec_type_desc', 'org_name',
       'chartdate', 'result_name', 'result_value'],
      dtype='object')


In [None]:
for cat_col in categorical_cols:
    uniques = dataset[cat_col].unique()
    print(f"{cat_col} has {len(uniques)} unique values")
    print(f"Examples: {uniques[:15]}")
    print("-------------------------------------------")

admittime has 1000 unique values
Examples: ['2180-05-06 22:23:00' '2180-06-26 18:27:00' '2180-08-05 23:44:00'
 '2180-07-23 12:35:00' '2160-03-03 23:16:00' '2160-11-21 01:56:00'
 '2160-12-28 05:11:00' '2163-09-27 23:17:00' '2181-11-15 02:05:00'
 '2183-09-18 18:10:00' '2163-08-20 01:42:00' '2192-11-30 01:25:00'
 '2151-03-18 03:28:00' '2189-10-15 10:30:00' '2143-12-23 14:55:00']
-------------------------------------------
admission_type has 9 unique values
Examples: ['URGENT' 'EW EMER.' 'EU OBSERVATION' 'OBSERVATION ADMIT'
 'SURGICAL SAME DAY ADMISSION' 'AMBULATORY OBSERVATION' 'DIRECT EMER.'
 'DIRECT OBSERVATION' 'ELECTIVE']
-------------------------------------------
admission_location has 11 unique values
Examples: ['TRANSFER FROM HOSPITAL' 'EMERGENCY ROOM' 'WALK-IN/SELF REFERRAL'
 'PHYSICIAN REFERRAL' 'PROCEDURE SITE' 'CLINIC REFERRAL'
 'TRANSFER FROM SKILLED NURSING FACILITY' 'PACU'
 'INTERNAL TRANSFER TO OR FROM PSYCH' 'INFORMATION NOT AVAILABLE'
 'AMBULATORY SURGERY TRANSFER']
----
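These unique-value counts determine how expensive one-hot encoding will be: `get_dummies` creates one column per level, so a column such as `admittime` (1000 unique values here) adds that many features to every merged row. A sketch that splits columns by cardinality; `split_by_cardinality` and the `max_levels=20` cut-off are hypothetical choices, not part of the notebook.

```python
import pandas as pd


def split_by_cardinality(df: pd.DataFrame, cols, max_levels: int = 20):
    """Return (low, high, counts): columns with <= max_levels distinct values, the rest, and all counts."""
    counts = df[list(cols)].nunique(dropna=False)
    low = counts[counts <= max_levels].index.tolist()
    high = counts[counts > max_levels].index.tolist()
    return low, high, counts


toy = pd.DataFrame({
    "admission_type": ["URGENT", "ELECTIVE", "EW EMER."] * 10,
    "drug": [f"drug_{i}" for i in range(30)],
})
low, high, counts = split_by_cardinality(toy, ["admission_type", "drug"])
print(low, high)  # ['admission_type'] ['drug']
print(counts.to_dict())  # {'admission_type': 3, 'drug': 30}
```

Low-cardinality columns are cheap to one-hot encode; the high-cardinality ones are candidates for another treatment (grouping rare levels, or keeping them as prediction targets rather than features).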

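The time and date columns are stored as strings (e.g. `2180-05-06 22:23:00`, or `2180-04-27` for `chartdate`). I will convert them to integers of the form YYYYMM, where YYYY is the year and MM the month; missing or unparseable values become 0.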

In [None]:
# Identify time-related columns
time_cols = [col for col in dataset.columns if 'time' in col.lower() or 'date' in col.lower()]
print("Time columns:", time_cols)

# Check original data
print("Original sample of time columns:")
for col in time_cols:
    print(f"{col}: {dataset[col].head().tolist()}")

# Convert to datetime and then to YYYYMM
for col in time_cols:
    # Convert to string to avoid type issues
    dataset[col] = dataset[col].astype(str)
    # Convert to datetime, coerce invalid values to NaT
    dataset[col] = pd.to_datetime(dataset[col], errors="coerce")
    # Convert to YYYYMM, fill NaT with 0, cast to int
    dataset[col] = (dataset[col].dt.year * 100 + dataset[col].dt.month).fillna(0).astype(int)

# Check results
print("Sample of converted columns:")
print(dataset[time_cols].head())

Time columns: ['admittime', 'starttime', 'stoptime', 'charttime_x', 'charttime_y', 'chartdate']
Original sample of time columns:
admittime: ['2180-05-06 22:23:00', '2180-05-06 22:23:00', '2180-05-06 22:23:00', '2180-05-06 22:23:00', '2180-05-06 22:23:00']
starttime: ['2180-05-08 08:00:00', '2180-05-08 08:00:00', '2180-05-08 08:00:00', '2180-05-08 08:00:00', '2180-05-08 08:00:00']
stoptime: ['2180-05-07 22:00:00', '2180-05-07 22:00:00', '2180-05-07 22:00:00', '2180-05-07 22:00:00', '2180-05-07 22:00:00']
charttime_x: ['2180-05-07 00:10:00', '2180-05-07 00:10:00', '2180-05-07 00:10:00', '2180-05-07 00:10:00', '2180-05-07 00:10:00']
charttime_y: ['2180-05-07 00:10:00', '2180-05-07 00:10:00', '2180-05-07 00:10:00', '2180-05-07 00:10:00', '2180-05-07 00:10:00']
chartdate: ['2180-04-27', '2180-04-27', '2180-05-07', '2180-05-07', '2180-05-07']
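The conversion in the cell above, factored into a small function and run on toy strings; `to_yyyymm` is a hypothetical name. Missing and unparseable values end up as 0, matching the `fillna(0)` above.

```python
import pandas as pd


def to_yyyymm(values: pd.Series) -> pd.Series:
    """Parse date-time strings and encode them as YYYYMM integers; NaT becomes 0."""
    ts = pd.to_datetime(values, errors="coerce")
    return (ts.dt.year * 100 + ts.dt.month).fillna(0).astype(int)


s = pd.Series(["2180-05-06 22:23:00", "2180-06-26 18:27:00", None, "not a date"])
print(to_yyyymm(s).tolist())  # [218005, 218006, 0, 0]
```

YYYYMM integers sort chronologically but are not evenly spaced: 218101 - 218012 = 89 even though the two dates are one month apart, so a month count such as `year * 12 + month` may be easier for a model to use.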


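One-hot encode the categorical columns, except `admission_type`, which is treated as ordinal.

From the MIMIC-IV documentation: "admission_type is useful for classifying the urgency of the admission. There are 9 possibilities: ‘AMBULATORY OBSERVATION’, ‘DIRECT EMER.’, ‘DIRECT OBSERVATION’, ‘ELECTIVE’, ‘EU OBSERVATION’, ‘EW EMER.’, ‘OBSERVATION ADMIT’, ‘SURGICAL SAME DAY ADMISSION’, ‘URGENT’."

The category order used below is the alphabetical order of this list, so the encoded integers do not rank admissions by urgency.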

In [None]:
filtered_rows = dataset["admission_type"]

filtered_rows.head()

In [None]:
from sklearn.preprocessing import OrdinalEncoder

# One-hot encode the categorical columns, except:
# - admission_type, which is ordinal-encoded below
# - the time columns, which were already converted to YYYYMM integers
# - the columns used later as prediction targets, so they keep their raw values
target_cols = ['icd_code', 'drug', 'dose_unit_rx', 'dose_val_rx', 'starttime', 'stoptime']
excluded = {'admission_type', *time_cols, *target_cols}
one_hot_cols = [col for col in categorical_cols if col not in excluded]

admission_order = ['AMBULATORY OBSERVATION', 'DIRECT EMER.', 'DIRECT OBSERVATION', 'ELECTIVE',
                   'EU OBSERVATION', 'EW EMER.', 'OBSERVATION ADMIT', 'SURGICAL SAME DAY ADMISSION', 'URGENT']
ordinal_encoder = OrdinalEncoder(categories=[admission_order])
dataset['admission_type_encoded'] = ordinal_encoder.fit_transform(dataset[['admission_type']]).ravel()

dataset = dataset.drop('admission_type', axis=1)

dataset_encoded = pd.get_dummies(dataset, columns=one_hot_cols)
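`OrdinalEncoder(categories=[admission_order])` maps each label to its position in the list, so with this alphabetical list `ELECTIVE` (3) sits between `DIRECT OBSERVATION` (2) and `EU OBSERVATION` (4). A sketch of the same encoder with `handle_unknown="use_encoded_value"`, which maps labels missing from the list to -1 instead of raising at transform time; the `UNSEEN TYPE` label is made up for the demo.

```python
import pandas as pd
from sklearn.preprocessing import OrdinalEncoder

admission_order = ['AMBULATORY OBSERVATION', 'DIRECT EMER.', 'DIRECT OBSERVATION', 'ELECTIVE',
                   'EU OBSERVATION', 'EW EMER.', 'OBSERVATION ADMIT',
                   'SURGICAL SAME DAY ADMISSION', 'URGENT']

# Unknown labels map to -1 instead of raising an error at transform time
encoder = OrdinalEncoder(categories=[admission_order],
                         handle_unknown="use_encoded_value", unknown_value=-1)
encoder.fit(pd.DataFrame({"admission_type": admission_order}))

new = pd.DataFrame({"admission_type": ["URGENT", "ELECTIVE", "EW EMER.", "UNSEEN TYPE"]})
print(encoder.transform(new).ravel().tolist())  # [8.0, 3.0, 5.0, -1.0]
```

Reordering `admission_order` by urgency would make the integers meaningful as an ordinal scale; that order is a modeling choice the notebook does not specify.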

In [None]:
filtered_rows = dataset["admission_type_encoded"]

filtered_rows.head()

| | admission_type_encoded |
|---|---|
| 0 | 8.0 |
| 1 | 8.0 |
| 2 | 8.0 |
| 3 | 8.0 |
| 4 | 8.0 |


Check the minimum and maximum of each numerical feature.

In [None]:
# numerical_cols is an Index of column names (not a DataFrame), so look each column up in dataset
for col in numerical_cols:
    print(f"{col}: min = {dataset[col].min()}, max = {dataset[col].max()}")
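A more compact way to get the same range check, plus a count of missing values, is `DataFrame.agg`; sketched here on a toy frame with made-up values. On the notebook's data the input would be `dataset.select_dtypes(include="number")`, which also picks up the converted time columns and `admission_type_encoded`.

```python
import pandas as pd

toy = pd.DataFrame({"anchor_age": [18, 65, 91],
                    "valuenum": [0.5, None, 12.0],
                    "race": ["WHITE", "ASIAN", "BLACK/AFRICAN AMERICAN"]})

numeric = toy.select_dtypes(include="number")
# One row per numeric column with its min, max and number of missing values
summary = numeric.agg(["min", "max"]).T
summary["n_missing"] = numeric.isna().sum()
print(summary)
```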


In [None]:
dataset_encoded.drop(["subject_id", "hadm_id"], axis=1, inplace=True)
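Because the chained merge repeats each admission across many rows, the row-level `train_test_split` used for the model below can place copies of the same admission in both the training and the test set. A sketch of a group-aware split with scikit-learn's `GroupShuffleSplit`, assuming `hadm_id` is kept aside as the group label before it is dropped (for example `groups = dataset["hadm_id"].to_numpy()`, which shares the row index of `dataset_encoded`); the data are toy values.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

# Toy frame: 10 admissions, each repeated 5 times as a chained merge would do
hadm_id = np.repeat(np.arange(10), 5)
X = pd.DataFrame({"feature": np.arange(50, dtype=float)})

# Split whole groups: 20% of the admissions go to the test set
splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(splitter.split(X, groups=hadm_id))

train_groups = set(hadm_id[train_idx])
test_groups = set(hadm_id[test_idx])
print(len(train_groups), len(test_groups), train_groups & test_groups)  # 8 2 set()
```

The returned positions can then index `X`, `y_clf_enc` and `y_reg_scaled` in place of `train_test_split`.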

In [None]:
dataset_encoded.shape

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder, StandardScaler
import pandas as pd
import numpy as np

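# 1) Prepare the dataframe (dataset_encoded from before)
df = dataset_encoded.copy()

# 2) Split features / targets
TARGET_CLF = ['icd_code', 'drug', 'dose_unit_rx']
TARGET_REG = ['dose_val_rx', 'starttime', 'stoptime']
targets = TARGET_CLF + TARGET_REG

# Drop each target column, or its dummy columns if it was one-hot encoded,
# so the features cannot leak the labels
leak_cols = [c for c in df.columns
             if any(c == t or c.startswith(t + '_') for t in targets)]
X = df.drop(columns=leak_cols)

# Take the raw targets from `dataset` (pre-encoding, same row index as df)
y_clf = dataset[TARGET_CLF].astype(str)   # ensure strings; missing values become 'nan'
# dose_val_rx is stored as text, so coerce non-numeric entries to NaN, then fill with 0
y_reg = dataset[TARGET_REG].apply(pd.to_numeric, errors='coerce').fillna(0.0)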

# 3) Encode categorical targets
le_icd   = LabelEncoder().fit(y_clf['icd_code'])
le_drug  = LabelEncoder().fit(y_clf['drug'])
le_unit  = LabelEncoder().fit(y_clf['dose_unit_rx'])

y_clf_enc = pd.DataFrame({
    'icd_code':    le_icd.transform(y_clf['icd_code']),
    'drug':        le_drug.transform(y_clf['drug']),
    'dose_unit_rx':le_unit.transform(y_clf['dose_unit_rx']),
})

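# 4) Scale numeric features & regression targets
num_feats = X.select_dtypes(include=['int64', 'float64']).columns
# Left joins leave NaNs in numeric columns (e.g. valuenum); StandardScaler passes NaNs
# through, so fill them first to keep NaN out of the network
X[num_feats] = X[num_feats].fillna(0.0)
# Both scalers are fit on all rows, before the train/test split below
scaler_X = StandardScaler().fit(X[num_feats])
X[num_feats] = scaler_X.transform(X[num_feats])

scaler_y = StandardScaler().fit(y_reg)
y_reg_scaled = pd.DataFrame(scaler_y.transform(y_reg), columns=TARGET_REG)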

# 5) Build a PyTorch Dataset
class HospDataset(Dataset):
    def __init__(self, X, y_clf, y_reg):
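        # to_numpy(dtype=...) also converts the bool one-hot dummy columns to float
        self.X = torch.from_numpy(X.to_numpy(dtype=np.float32))
        self.y_clf = torch.from_numpy(y_clf.to_numpy(dtype=np.int64))
        self.y_reg = torch.from_numpy(y_reg.to_numpy(dtype=np.float32))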
    def __len__(self):
        return len(self.X)
    def __getitem__(self, i):
        return self.X[i], self.y_clf[i], self.y_reg[i]

# train/test split
from sklearn.model_selection import train_test_split
X_tr, X_te, ycl_tr, ycl_te, yr_tr, yr_te = train_test_split(
    X, y_clf_enc, y_reg_scaled, test_size=0.2, random_state=42
)

train_ds = HospDataset(X_tr, ycl_tr, yr_tr)
test_ds  = HospDataset(X_te, ycl_te, yr_te)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=64)

# 6) Define the multi-task model
class MultiTaskNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, out_icd, out_drug, out_unit):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
        )
        # classification heads
        self.head_icd  = nn.Linear(hidden_dim//2, out_icd)
        self.head_drug = nn.Linear(hidden_dim//2, out_drug)
        self.head_unit = nn.Linear(hidden_dim//2, out_unit)
        # regression head
        self.head_reg  = nn.Linear(hidden_dim//2, 3)  # dose_val, start, stop

    def forward(self, x):
        h = self.shared(x)
        icd_logits  = self.head_icd(h)
        drug_logits = self.head_drug(h)
        unit_logits = self.head_unit(h)
        reg_out     = self.head_reg(h)
        return icd_logits, drug_logits, unit_logits, reg_out

# instantiate
model = MultiTaskNet(
    input_dim   = X.shape[1],
    hidden_dim  = 256,
    out_icd     = len(le_icd.classes_),
    out_drug    = len(le_drug.classes_),
    out_unit    = len(le_unit.classes_)
)

# 7) Losses & optimizer
criterion_clf = nn.CrossEntropyLoss()
criterion_reg = nn.MSELoss()
optimizer     = torch.optim.Adam(model.parameters(), lr=1e-3)

# 8) Training loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

EPOCHS = 10
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for Xb, ycl_b, yr_b in train_loader:
        Xb, ycl_b, yr_b = Xb.to(device), ycl_b.to(device), yr_b.to(device)
        optimizer.zero_grad()
        icd_logits, drug_logits, unit_logits, reg_out = model(Xb)

        # compute losses
        loss_icd  = criterion_clf(icd_logits,  ycl_b[:,0])
        loss_drug = criterion_clf(drug_logits, ycl_b[:,1])
        loss_unit = criterion_clf(unit_logits, ycl_b[:,2])
        loss_reg  = criterion_reg(reg_out, yr_b)

        loss = loss_icd + loss_drug + loss_unit + loss_reg
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{EPOCHS} — loss: {avg_loss:.4f}")

# 9) Quick evaluation on test set
model.eval()
with torch.no_grad():
    correct_icd = correct_drug = correct_unit = total = 0
    mse_reg = 0
    for Xb, ycl_b, yr_b in test_loader:
        Xb, ycl_b, yr_b = Xb.to(device), ycl_b.to(device), yr_b.to(device)
        icd_logits, drug_logits, unit_logits, reg_out = model(Xb)

        # classification accuracy
        pred_icd  = icd_logits.argmax(dim=1)
        pred_drug = drug_logits.argmax(dim=1)
        pred_unit = unit_logits.argmax(dim=1)

        correct_icd  += (pred_icd  == ycl_b[:,0]).sum().item()
        correct_drug += (pred_drug == ycl_b[:,1]).sum().item()
        correct_unit += (pred_unit == ycl_b[:,2]).sum().item()
        total       += Xb.size(0)

        # regression MSE
        mse_reg += criterion_reg(reg_out, yr_b).item() * Xb.size(0)

    print("Test ICD acc:",  correct_icd/total)
    print("Test Drug acc:", correct_drug/total)
    print("Test Unit acc:", correct_unit/total)
    print("Test Reg MSE:", mse_reg/total)
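Many merged rows have no prescription, so their regression targets are missing, and `StandardScaler` keeps NaNs in place rather than filling them. A sketch of a masked MSE that averages only over observed targets; `masked_mse` is a hypothetical helper that could stand in for `criterion_reg` if the missing targets are left as NaN.

```python
import torch


def masked_mse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """MSE over entries whose target is not NaN; returns 0 if nothing is observed."""
    mask = ~torch.isnan(target)
    if mask.sum() == 0:
        return pred.sum() * 0.0  # zero loss that stays connected to the graph
    diff = pred[mask] - target[mask]
    return (diff ** 2).mean()


pred = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], requires_grad=True)
target = torch.tensor([[1.5, float("nan"), 3.0], [float("nan"), float("nan"), 8.0]])
loss = masked_mse(pred, target)
loss.backward()
print(round(loss.item(), 4))  # 1.4167
```

Unlike filling with a constant, the mask gives zero gradient for missing entries, so the regression head is not trained to predict a placeholder value.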
