In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.autograd import grad

import matplotlib.pyplot as plt

        
# Download dataset from: https://biostat.app.vumc.org/wiki/pub/Main/DataSets/titanic3.csv
dataset_path = "../../../../data/titanic3.csv"
titanic_data = pd.read_csv(dataset_path)

titanic_data = pd.concat([titanic_data,
                          pd.get_dummies(titanic_data['sex']),
                          pd.get_dummies(titanic_data['embarked'],prefix="embark"),
                          pd.get_dummies(titanic_data['pclass'],prefix="class")], axis=1)
titanic_data["age"] = titanic_data["age"].fillna(titanic_data["age"].mean())
titanic_data["fare"] = titanic_data["fare"].fillna(titanic_data["fare"].mean())
titanic_data = titanic_data.drop(['name','ticket','cabin','boat','body','home.dest','sex','embarked','pclass'], axis=1)

# Set random seed for reproducibility.
np.random.seed(131254)

# Convert features and labels to numpy arrays.
labels = titanic_data["survived"].to_numpy()
titanic_data = titanic_data.drop(['survived'], axis=1)
feature_names = list(titanic_data.columns)
data = titanic_data.to_numpy()

# Separate training and test sets using 
train_indices = np.random.choice(len(labels), int(0.7*len(labels)), replace=False)
test_indices = list(set(range(len(labels))) - set(train_indices))
train_features = data[train_indices]
train_labels = labels[train_indices]
test_features = data[test_indices]
test_labels = labels[test_indices]


torch.manual_seed(1)  # Set seed for reproducibility.
class TitanicSimpleNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(12, 12)
        self.sigmoid1 = nn.Sigmoid()
        self.linear2 = nn.Linear(12, 8)
        self.sigmoid2 = nn.Sigmoid()
        self.linear3 = nn.Linear(8, 2)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        lin1_out = self.linear1(x)
        sigmoid_out1 = self.sigmoid1(lin1_out)
        sigmoid_out2 = self.sigmoid2(self.linear2(sigmoid_out1))
        return self.softmax(self.linear3(sigmoid_out2))
        
        
net = TitanicSimpleNNModel()

criterion = nn.CrossEntropyLoss()
num_epochs = 200

optimizer = torch.optim.Adam(net.parameters(), lr=0.1)
input_tensor = torch.from_numpy(train_features).type(torch.FloatTensor)
label_tensor = torch.from_numpy(train_labels)
for epoch in range(num_epochs):    
    output = net(input_tensor)
    loss = criterion(output, label_tensor)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 20 == 0:
        print ('Epoch {}/{} => Loss: {:.2f}'.format(epoch+1, num_epochs, loss.item()))
        
input_tensor = torch.from_numpy(train_features).type(torch.FloatTensor)
test_input_tensor = torch.from_numpy(test_features).type(torch.FloatTensor)

Epoch 1/200 => Loss: 0.70
Epoch 21/200 => Loss: 0.55
Epoch 41/200 => Loss: 0.50
Epoch 61/200 => Loss: 0.49
Epoch 81/200 => Loss: 0.48
Epoch 101/200 => Loss: 0.49
Epoch 121/200 => Loss: 0.48
Epoch 141/200 => Loss: 0.48
Epoch 161/200 => Loss: 0.48
Epoch 181/200 => Loss: 0.48


In [16]:
train_features.shape

(916, 12)

In [13]:
n_steps = 50
idx = 1

def preds_and_grads(inputs, model, baselines=None, n_steps=50, target=1, numpy=False):
    """
    function to get predictions and gradients of the output wrt features values
    
    Args:
        inputs : Tensor of inputs
        model : pytorch model
        baselines : Tensor or None of baselines 
        n_steps : int number of steps to approximate integrated gradients
        target : int target class
        numpy : bool if true return numpy, else Tensor
    """
    if inputs.dim() == 2:
        inputs = inputs.unsqueeze(0)
    
    if baselines == None:
        baselines = torch.zeros_like(inputs)
        
    print(f"Inputs shape: {inputs.shape}")
    print(f"baselines shape: {baselines.shape}")
    
    # k/m in the formula
    alphas = torch.linspace(0, 1, n_steps).tolist()
    
    print(f"alphas: {alphas}")
    
    # direct path from baseline to input. shape : ([n_steps, n_features], )
    scaled_features = tuple(
            torch.cat(
                [baseline + alpha * (input - baseline) for alpha in alphas], dim=0
            ).requires_grad_()
            for input, baseline in zip(inputs, baselines)
        )
    print(f"scaled features len shape :{scaled_features} {len(scaled_features)} and {scaled_features[0].shape}")
    # predictions at every step. shape : [n_steps, 1]
    preds = model(scaled_features[0])[:, target]
    # gradients of predictions wrt input features. shape : [n_steps, n_features]
    grads = grad(outputs=torch.unbind(preds), inputs=scaled_features)
    print(f"grads shape is: {grads[0].shape}")
    if numpy:
        return preds.detach().numpy(), grads[0].detach().numpy()
    return preds, grads

In [14]:
preds, grads = preds_and_grads(test_input_tensor[idx:idx+1], net, n_steps=n_steps, target=1, numpy=True)

Inputs shape: torch.Size([1, 1, 12])
baselines shape: torch.Size([1, 1, 12])
alphas: [0.0, 0.020408162847161293, 0.040816325694322586, 0.06122449040412903, 0.08163265138864517, 0.10204081237316132, 0.12244898080825806, 0.1428571343421936, 0.16326530277729034, 0.18367347121238708, 0.20408162474632263, 0.22448979318141937, 0.2448979616165161, 0.26530611515045166, 0.2857142686843872, 0.30612242221832275, 0.3265306055545807, 0.34693875908851624, 0.36734694242477417, 0.3877550959587097, 0.40816324949264526, 0.4285714030265808, 0.44897958636283875, 0.4693877398967743, 0.4897959232330322, 0.5102040767669678, 0.5306122303009033, 0.5510203838348389, 0.5714285969734192, 0.5918367505073547, 0.6122449040412903, 0.6326530575752258, 0.6530612707138062, 0.6734694242477417, 0.6938775777816772, 0.7142857313156128, 0.7346939444541931, 0.7551020979881287, 0.7755102515220642, 0.7959184050559998, 0.8163264989852905, 0.8367346525192261, 0.8571428060531616, 0.8775509595870972, 0.8979591727256775, 0.918367326

In [17]:
test_input_tensor[idx:idx+1]

tensor([[ 47.0000,   1.0000,   0.0000, 227.5250,   0.0000,   1.0000,   1.0000,
           0.0000,   0.0000,   1.0000,   0.0000,   0.0000]])