# Train a neural network to predict Therapeutic Dose of Warfarin that achieves a given INR
---
#### This notebook uses PyTorch to train a feed‑forward network
  
## How it works
---
* Read the CSV, treating Therapeutic Dose of Warfarin (mg/week) as the target.  

* All other columns are used as features, including the patient's INR measured on their current dose. At inference time you can set the INR column to the value you *want* the patient to reach (e.g. 2.5) and obtain a recommended weekly dose.  

* All feature columns are z‑score normalised with `sklearn.preprocessing.StandardScaler`.  

* A simple fully‑connected network (4 hidden layers: 512‑256‑128‑64) is trained with mean‑squared‑error (MSE) loss.  

* Validation metrics (RMSE & MAE) are printed every epoch.  

* The best model (lowest validation RMSE) is saved to `best_nn.pt`, and the fitted scaler to `scaler.pkl`.  

---
# LIBRARY
---

In [1]:
from typing import Dict, List
from tqdm import tqdm
import pandas as pd
import numpy as np
import pickle

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler

from sklearn.metrics import root_mean_squared_error as rmse
from sklearn.metrics import mean_absolute_error as mae

from torch.utils.data import DataLoader, Dataset
from torch import nn
import torch

---
# CLASS
---

## Data wrapper around ndarray

In [2]:
class WarfarinDataset(Dataset):
    """Map-style dataset that holds features and targets as float32 tensors.

    Both inputs are converted to tensors once, at construction time. Indexing
    returns an ``(x, y)`` pair; ``y`` carries an extra axis inserted at
    position 1, so a 1-D target array of length N is stored with shape (N, 1).
    """

    def __init__(self, X:np.ndarray, y:np.ndarray):
        features = torch.tensor(X, dtype=torch.float32)
        targets = torch.tensor(y, dtype=torch.float32)

        self.X = features
        # New axis at position 1: a 1-D target becomes a column.
        self.y = targets[:, None]

    def __len__(self):
        # Number of samples = length of the leading axis of the features.
        return len(self.X)

    def __getitem__(self, idx):
        sample = (self.X[idx], self.y[idx])
        return sample

## Feed Forward Neural Network

In [3]:
class FeedForwardNN(nn.Module):
    """512-256-128-64 multilayer perceptron with a single regression output.

    Layer stack (indices are the positions inside ``self.NN``):

    ==  ======================  ===========================================
    0   Linear(in_dim -> 512)   wide first layer
    1   BatchNorm1d(512)        normalises the activations
    2   LeakyReLU()
    3   Dropout(0.3)            the only dropout layer in the network
    4   Linear(512 -> 256)      compress the representation
    5   BatchNorm1d(256)
    6   LeakyReLU()
    7   Linear(256 -> 128)      further compression
    8   BatchNorm1d(128)
    9   LeakyReLU()
    10  Linear(128 -> 64)       penultimate dense layer (no BatchNorm)
    11  LeakyReLU()
    12  Linear(64 -> 1)         single continuous output value
    ==  ======================  ===========================================

    Every ``nn.Linear`` is re-initialised with Kaiming *uniform* weights
    (``nonlinearity='leaky_relu'``) and zero biases, see ``_init_weights``.

    Args:
        in_dim: Number of input features per sample.
    """

    def __init__(self, in_dim: int):
        super().__init__()

        # Layers are listed in execution order; the list position becomes the
        # sub-module index inside the Sequential container.
        stack = [
            # Block 1 - wide entry layer, the only one followed by dropout
            nn.Linear(in_dim, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(),
            nn.Dropout(0.3),

            # Block 2
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(),

            # Block 3
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(),

            # Block 4 - last non-linear mixing step, no normalisation
            nn.Linear(128, 64),
            nn.LeakyReLU(),

            # Regression head
            nn.Linear(64, 1),
        ]
        self.NN = nn.Sequential(*stack)

        # Visit every sub-module and re-initialise the linear ones.
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        """Give ``nn.Linear`` modules Kaiming-uniform weights and zero biases."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.kaiming_uniform_(m.weight, nonlinearity='leaky_relu')
        nn.init.zeros_(m.bias)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Run ``X`` through the layer stack and return the network output."""
        out = self.NN(X)
        return out

## Trainer

In [4]:
"""------------------ CONFIGURE ------------------"""
# Output artefacts (paths are relative to the working directory).
MODEL_WEIGHTS: str = 'best_nn.pt'   # checkpointed network weights
SCALER_FILE: str = 'scaler.pkl'     # pickled fitted StandardScaler

# Train/test split settings.
TEST_SIZE: float = 0.2              # fraction of rows held out
RANDOM_STATE: int = 42              # seed for the split

In [5]:
def train(df:pd.DataFrame, target_col:str, epochs:int=100, batch_size:int=64, lr:float=1e-3):
    """Train a `FeedForwardNN` regressor on `df` and checkpoint its best epoch.

    Steps: split `df` into features/target, hold out `TEST_SIZE` of the rows,
    standardise the features with statistics learned from the training split
    only, train with MSE loss and AdamW, halve the learning rate when the
    held-out RMSE plateaus, and save the weights every time the held-out RMSE
    improves.

    Args:
        df: Table holding the target column plus the feature columns; every
            column is cast to float32.
        target_col: Name of the column to predict.
        epochs: Number of passes over the training split.
        batch_size: Mini-batch size used by both data loaders.
        lr: Initial AdamW learning rate.

    Returns:
        None. Side effects: the fitted scaler is pickled to `SCALER_FILE`, the
        best `state_dict` is written to `MODEL_WEIGHTS`, and progress is
        printed to stdout.
    """
    # ------------------ 1) Load data ------------------
    print('[*] Loading data...')
    # Split dataset into features and target
    X = df.drop(columns=[target_col]).values.astype(np.float32)
    y = df[target_col].values.astype(np.float32)

    # Train:Test -> 8:2
    X_train, X_test, y_train, y_test = train_test_split(
        X, y,
        test_size=TEST_SIZE,
        random_state=RANDOM_STATE,
    )

    # ------------------ 2) Fit scaler ------------------
    print('[*] Scaling features...')
    scaler = StandardScaler()

    # Fit on the training split ONLY, then re-use those statistics for the
    # held-out rows. Re-fitting on the test split would leak test statistics
    # into evaluation and leave the persisted scaler describing the wrong data.
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Persist scaler to re-use at inference time
    with open(SCALER_FILE, 'wb') as filo:
        pickle.dump(scaler, filo)

    # ------------------ 3) Build datasets/loaders ------------------
    print('[*] Wrapping train/test datasets...')
    train_ds = WarfarinDataset(X_train_scaled, y_train)
    test_ds = WarfarinDataset(X_test_scaled, y_test)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

    # ------------------ 4) Model/optimiser ------------------
    print('[*] Initializing neural network...')
    # Preference order: CUDA, then Apple MPS, then CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    print(f'\tUsing {device}.')

    model = FeedForwardNN(X_train.shape[1]).to(device)
    print(f'\tModel:\n{model}\n')

    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, patience=10, factor=0.5)
    print(
        f'\tCirterion:\n{criterion}\n\n'
        f'\tOptimizer:\n{optimizer}\n\n'
        f'\tScheduler:\n{scheduler}\n\n'
    )

    # ------------------ 5) Training loop ------------------
    print(f'[*] Start training({epochs} epochs):')
    best_test_rmse = float('inf')
    for epoch in range(1, epochs + 1):
        # ---- train ----
        model.train()
        train_losses: List[float] = []
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            y_preds = model(x_batch)
            loss = criterion(y_preds, y_batch)
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())

        # ---- test ----
        model.eval()
        test_preds, test_targets = [], []
        with torch.no_grad():
            for x_batch, y_batch in test_loader:
                y_preds = model(x_batch.to(device))

                # Metrics are computed with NumPy, so gather everything on CPU.
                test_preds.append(y_preds.cpu().numpy())
                test_targets.append(y_batch.numpy())

        # ravel() (not squeeze()) keeps the arrays 1-D even for a single row.
        test_preds_np = np.concatenate(test_preds).ravel()
        test_targets_np = np.concatenate(test_targets).ravel()
        test_rmse = rmse(test_targets_np, test_preds_np)
        test_mae = mae(test_targets_np, test_preds_np)

        # '\r' + end='' rewrites the same console line on every epoch.
        print(f'\r\tEpoch {epoch:03d}: Train MSE = {np.mean(train_losses):.4f} | Test RMSE = {test_rmse:.4f} | Test MAE = {test_mae:.4f}', end='')

        # Plateau scheduler – auto LR decay if progress stalls
        scheduler.step(test_rmse)

        # Checkpoint if this epoch is best so far
        if test_rmse < best_test_rmse:
            best_test_rmse = test_rmse
            torch.save(model.state_dict(), MODEL_WEIGHTS)

    # ------------------ 6) Done ------------------
    print(
        f'\n[*] Training complete, best validation RMSE: {best_test_rmse:.4f}.\n'
        f'[*] Model saved to {MODEL_WEIGHTS}\n'
        f'[*] Scaler saved to {SCALER_FILE}'
    )

---
# TRAIN NEURAL NETWORK
---

In [6]:
# Column the network learns to predict, and the CSV that holds the training
# table (path is relative to the notebook's working directory).
TARGET_COLUMN: str = 'Therapeutic Dose of Warfarin'
DATA_CSV: str = '../datasets/NN_Training_Data.csv'

In [7]:
# Load the training table; the bare `df` on the last line renders it as the cell output.
df = pd.read_csv(filepath_or_buffer=DATA_CSV)
df

Unnamed: 0,INR on Reported Therapeutic Dose of Warfarin,Therapeutic Dose of Warfarin,Weight (kg),Height (cm),Valve Replacement,Gender_female,Gender_male,Age <50,Age 50-69,Age >=70,Fluvastatin (Lescol),Aspirin,Acetaminophen or Paracetamol (Tylenol),Simvastatin (Zocor),Cardiomyopathy/LV Dilation,Rifampin or Rifampicin,Diabetes,Anti-fungal Azoles
0,2.60,49.00,115.70,193.04,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
1,2.15,42.00,144.20,176.53,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,1.90,53.00,77.10,162.56,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,2.40,28.00,90.70,182.24,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,1.90,42.00,72.60,167.64,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3984,3.00,35.00,86.36,165.10,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
3985,2.80,27.51,55.91,160.02,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
3986,2.00,57.47,97.73,187.96,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3987,2.00,70.00,87.27,177.80,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [8]:
# Train with the default hyper-parameters (100 epochs, batch size 64, lr 1e-3).
# Pass e.g. `epochs=500` for a longer run.
train(df, target_col=TARGET_COLUMN)

[*] Loading data...
[*] Scaling features...
[*] Wrapping train/test datasets...
[*] Initializing neural network...
	Using mps.
	Model:
FeedForwardNN(
  (NN): Sequential(
    (0): Linear(in_features=17, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=512, out_features=256, bias=True)
    (5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): LeakyReLU(negative_slope=0.01)
    (7): Linear(in_features=256, out_features=128, bias=True)
    (8): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): LeakyReLU(negative_slope=0.01)
    (10): Linear(in_features=128, out_features=64, bias=True)
    (11): LeakyReLU(negative_slope=0.01)
    (12): Linear(in_features=64, out_features=1, bias=True)
  )
)

	Cirterion:
MSELoss()

	Optimizer:
AdamW (
Par