In [1]:
import pandas as pds
import numpy as np
import time
import importlib
import random

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
from torchdiffeq import odeint_adjoint as odeint

In [3]:
from torchdiffeq import odeint as dto

In [4]:
import ode_models
importlib.reload(ode_models)

<module 'ode_models' from '/alt/applic/user-maint/zq224/WS/torchdiffeq/examples/ode_models.py'>

In [5]:
device = torch.device('cuda:' + str(1) if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=1)

In [6]:
D_TYPE = torch.float32

## Get Data

In [7]:
dat = pds.read_csv('data/cprd_full_4_markers.csv.gz', compression='gzip')
bio_markers = dat.columns[2:]

for b in bio_markers:
    dat[b] = (dat[b] - dat[b].mean()) / dat[b].std()

bio_markers

Index(['creatinine', 'dbp', 'sbp', 'tchol'], dtype='object')

In [8]:
dat.head()

Unnamed: 0,patid,ts,creatinine,dbp,sbp,tchol
0,1025,0.0,-1.196299,2.501302,1.808223,1.934243
1,1025,2.871233,-1.450525,2.410151,1.584469,1.67512
2,1025,3.19726,-1.450525,0.495972,-0.597133,-0.311484
3,1025,6.841096,-1.23867,1.04288,0.521637,-0.22511
4,1025,8.484932,-1.23867,0.131367,-0.149625,-0.743355


In [9]:
dat_dict = dict()

dat_grouped = dat.groupby('patid')

for name,group in dat_grouped:
    t = group['ts'].values
    x = group[bio_markers].values
    len_t = len(t)
    dim = x.shape[1]
    x_reshaped = x.reshape(len_t, 1, 1, dim)
    dat_dict[str(name)] = dict(t=torch.tensor(t, dtype=D_TYPE), x=torch.tensor(x_reshaped, dtype=D_TYPE))
    

In [45]:
def get_fold(dat_dict, fold=5, seed=666):
    random.seed(seed)
    eids = list(dat_dict.keys())
    eid_set = set(eids)
    random.shuffle(eids)

    fold_list = list()
    for i in range(fold):
        eid_test = eids[i::fold]
        dat_test_fold = { k: dat_dict[k] for k in eid_test }

        eid_remain = list(eid_set - set(eid_test))
        eid_val = eid_remain[::10]
        dat_val_fold = { k: dat_dict[k] for k in eid_val }

        eid_train = list(set(eid_remain) - set(eid_val))

        dat_train_fold = { k: dat_dict[k] for k in eid_train }
        fold_dict = {'train': dat_train_fold, 'val': dat_val_fold, 'test': dat_test_fold}
        fold_list.append(fold_dict)

    return fold_list

# testit
# t = set()
# for i in range(5):
#     print(len(dat_folds[i]['train'].keys()), len(dat_folds[i]['val'].keys()), len(dat_folds[i]['test'].keys()))
#     print(len(set(dat_folds[i]['train'].keys()).union(set(dat_folds[i]['val'].keys()),set(dat_folds[i]['test'].keys()))))
#     t=t.union(set(dat_folds[i]['test'].keys()))
# len(t)

In [46]:
dat_folds = get_fold(dat_dict, fold=5, seed=666)

In [49]:
def get_batch(dat_dict, batch_size, seed=42):
    random.seed(seed)
    eids = random.sample(list(dat_dict.keys()), batch_size)
    t_list = [dat_dict[e]['t'] for e in eids]
    t_max = max([len(x) for x in t_list])
    t_padded = [F.pad(x, (0, t_max - len(x)), "constant", -1.).reshape((-1, 1)) for x in t_list]
    t_tensor = torch.cat(t_padded, dim=1)
    t_mask = t_tensor >= 0
    
    x_list = [dat_dict[e]['x'] for e in eids]
    x_padded = [F.pad(i, (0,0,0,0,0,0,0,t_max - i.shape[0]), "constant", -1.) for i in x_list]
    x_tensor = torch.cat(x_padded, dim=1)
    x_mask = x_tensor >= 0
    x0_tensor = x_tensor[0, ...]
    
    return t_tensor, x_tensor, x0_tensor, t_mask, x_mask, eids

def get_all(dat_dict, seed=42):
    eids = list(dat_dict.keys())
    t_list = [dat_dict[e]['t'] for e in eids]
    t_max = max([len(x) for x in t_list])
    t_padded = [F.pad(x, (0, t_max - len(x)), "constant", -1.).reshape((-1, 1)) for x in t_list]
    t_tensor = torch.cat(t_padded, dim=1)
    t_mask = t_tensor >= 0
    
    x_list = [dat_dict[e]['x'] for e in eids]
    x_padded = [F.pad(i, (0,0,0,0,0,0,0,t_max - i.shape[0]), "constant", -1.) for i in x_list]
    x_tensor = torch.cat(x_padded, dim=1)
    x_mask = x_tensor >= 0
    x0_tensor = x_tensor[0, ...]
    
    return t_tensor, x_tensor, x0_tensor, t_mask, x_mask, eids


In [47]:
t, y, y0, t_mask, y_mask, eids = get_batch(dat_folds[0]['train'], batch_size=70)

# Training

## Baseline - Average

In [48]:
t, y, y0, t_mask, y_mask, eids = get_batch(dat_dict, batch_size=7000)
torch.mean(torch.abs(y))

tensor(0.9626)

## Baseline - LSTM

### LSTM without time info

In [73]:
import baseline_models as baseline
import training_utils

importlib.reload(training_utils)


<module 'training_utils' from '/alt/applic/user-maint/zq224/WS/torchdiffeq/examples/training_utils.py'>

In [74]:
niters = 1000
test_freq = 100
batch_size = 500
input_dim = output_dim = 4
n_hidden = 50

base_lstm = baseline.BaselineLSTM(input_dim, n_hidden, output_dim)
optimizer = optim.Adam(base_lstm.parameters(), lr=1e-3)


In [75]:
def base_lstm_loss_func(t, y, y0, t_mask, y_mask, eids):
    y_pred = base_lstm(y0, t)
    loss = torch.mean(torch.abs(y_pred[y_mask.squeeze()] - y[y_mask]))
    return loss

def base_lstm_save_func():
    model_path = 'models/cprd_lstm_no_time.pth'
    torch.save(base_lstm.state_dict(), model_path)

In [76]:
training_utils.training_loop(niters, 
                    dat_folds[0], 
                    batch_size, 
                    optimizer, 
                    test_freq, 
                    base_lstm_loss_func, 
                    base_lstm_save_func)

Iter 0100 | Total Loss 0.527168
Iter 0200 | Total Loss 0.452818
Iter 0300 | Total Loss 0.420922
Iter 0400 | Total Loss 0.401545
Iter 0500 | Total Loss 0.389461
Iter 0600 | Total Loss 0.382782
Iter 0700 | Total Loss 0.378658
Iter 0800 | Total Loss 0.376341
Iter 0900 | Total Loss 0.373735
Iter 1000 | Total Loss 0.371715


(tensor(0.3717), 62.42432188987732)

### LSTM with time info

In [97]:
niters = 5000
test_freq = 100
batch_size = 500
input_dim = output_dim = 4
n_hidden = 50
model_path = 'models/cprd_lstm.pth'

base_time_lstm = baseline.BaselineTimeLSTM(input_dim, n_hidden, output_dim)

optimizer = optim.Adam(base_time_lstm.parameters(), lr=1e-3)


In [98]:
def base_time_lstm_loss_func(t, y, y0, t_mask, y_mask, eids):
    y_pred = base_time_lstm(y0, t)
    loss = torch.mean(torch.abs(y_pred[y_mask.squeeze()] - y[y_mask]))
    return loss

def base_time_lstm_save_func():
    model_path = 'models/cprd_lstm.pth'
    torch.save(base_time_lstm.state_dict(), model_path)

In [None]:
training_utils.training_loop(niters, 
                    dat_folds[0], 
                    batch_size, 
                    optimizer, 
                    test_freq, 
                    base_time_lstm_loss_func, 
                    base_time_lstm_save_func)

## Vanilla Neural ODE

In [80]:
niters = 100
batch_size = 500
step_size = 1./12
test_freq = 10

# vanila Neural ODE
func0 = ode_models.ODEFunc0(dim_y=4)
optimizer = optim.Adam(func0.parameters(), lr=1e-3)


In [81]:
def Vanilla_ode_loss_func(t, y, y0, t_mask, y_mask, eids):
    pred_y = dto(func0, y0, t, method='euler_par', options={'step_size': step_size})
    loss = torch.mean(torch.abs(pred_y[y_mask] - y[y_mask]))
    return loss

def Vanilla_ode_save_func():
    model_path = 'models/vanilla_ode.pth'
    torch.save(func0.state_dict(), model_path)

In [82]:
training_utils.training_loop(niters, 
                    dat_folds[0], 
                    batch_size, 
                    optimizer, 
                    test_freq, 
                    Vanilla_ode_loss_func, 
                    Vanilla_ode_save_func)

Iter 0010 | Total Loss 0.527961
Iter 0020 | Total Loss 0.456916
Iter 0030 | Total Loss 0.422552
Iter 0040 | Total Loss 0.404244
Iter 0050 | Total Loss 0.394447
Iter 0060 | Total Loss 0.388240
Iter 0070 | Total Loss 0.384390
Iter 0080 | Total Loss 0.381673
Iter 0090 | Total Loss 0.379925
Iter 0100 | Total Loss 0.378580


(tensor(0.3786), 239.4606318473816)

## Augmented ODE

In [87]:
niters = 100
batch_size = 500
step_size = 1./12
test_freq = 10

# augmented Neural ODE
func_aug = ode_models.ODEFuncAug(dim_y=4, dim_aug=4)
optimizer = optim.Adam(func_aug.parameters(), lr=1e-3)


In [88]:
def augmented_ode_loss_func(t, y, y0, t_mask, y_mask, eids):
    y0_aug = F.pad(y0, (0, func_aug.dim_aug, 0, 0, 0, 0), "constant", 0.)
    pred_y = dto(func_aug, y0_aug, t, method='euler_par', options={'step_size': step_size})
    pred_y = pred_y[..., :func_aug.dim_y]
    loss = torch.mean(torch.abs(pred_y[y_mask] - y[y_mask]))
    return loss

def augmented_ode_save_func():
    model_path = 'models/augmented_ode.pth'
    torch.save(func_aug.state_dict(), model_path)

In [89]:
training_utils.training_loop(niters, 
                    dat_folds[0], 
                    batch_size, 
                    optimizer, 
                    test_freq, 
                    augmented_ode_loss_func, 
                    augmented_ode_save_func)

Iter 0010 | Total Loss 0.480909
Iter 0020 | Total Loss 0.426202
Iter 0030 | Total Loss 0.406207
Iter 0040 | Total Loss 0.395764
Iter 0050 | Total Loss 0.390499
Iter 0060 | Total Loss 0.386622
Iter 0070 | Total Loss 0.384260
Iter 0080 | Total Loss 0.382188
Iter 0090 | Total Loss 0.380931
Iter 0100 | Total Loss 0.379610


(tensor(0.3796), 264.3959217071533)

## Higher-Order ODE

In [90]:
niters = 100
batch_size = 500
step_size = 1./12
test_freq = 10

In [91]:
func = ode_models.HigherOrderOde(dat_dict, batch_size=batch_size, dim=4, order=2, hidden_size=50)
func.init_cond_mat.requires_grad = False
optimizer = optim.Adam(func.parameters(), lr=1e-3)


In [92]:
def higher_ode_loss_func(t, y, y0, t_mask, y_mask, eids):
    func.set_init_cond(eids)
    init_zeros = torch.zeros_like(func.init_cond)
    
    pred_y = dto(func, init_zeros, t, method='euler_par', options={'step_size': step_size})
    pred_y_final = (pred_y  + func.init_cond)[..., :func.dim]
    
    loss = torch.mean(torch.abs(pred_y_final[y_mask] - y[y_mask]))
    return loss

def higher_ode_save_func():
    model_path = 'models/higher_ode.pth'
    torch.save(func.state_dict(), model_path)

In [93]:
training_utils.training_loop(niters, 
                    dat_folds[0], 
                    batch_size, 
                    optimizer, 
                    test_freq, 
                    higher_ode_loss_func, 
                    higher_ode_save_func)

Iter 0010 | Total Loss 2.721154
Iter 0020 | Total Loss 2.034932
Iter 0030 | Total Loss 1.702108
Iter 0040 | Total Loss 1.508700
Iter 0050 | Total Loss 1.387567
Iter 0060 | Total Loss 1.299232
Iter 0070 | Total Loss 1.229783
Iter 0080 | Total Loss 1.170778
Iter 0090 | Total Loss 1.124129
Iter 0100 | Total Loss 1.084185


(tensor(1.0842), 254.95494604110718)

In [95]:
torch.mean(torch.abs(func.init_cond_mat[:, :, 4:]))

tensor(0.7702)

In [96]:
torch.mean(torch.abs(func.init_cond_mat[:, :, :4]))

tensor(0.7975)

In [198]:
ii = 0

start = time.time()
for itr in range(1, niters + 1):
    
    t, y, y0, t_mask, y_mask, eids = get_batch(dat_dict, batch_size, itr+500)

    optimizer.zero_grad()
    
    func.set_init_cond(eids)
    init_zeros = torch.zeros_like(func.init_cond)
    
    pred_y = dto(func, init_zeros, t, method='euler_par', options={'step_size': step_size})
    pred_y_final = (pred_y  + func.init_cond)[..., :func.dim]
    
    loss = torch.mean(torch.abs(pred_y_final[y_mask] - y[y_mask]))
    loss.backward()
    optimizer.step()

    if itr % test_freq == 0:
        with torch.no_grad():
            print('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item()))
            ii += 1

end = time.time()


Iter 0050 | Total Loss 0.426231
Iter 0100 | Total Loss 0.597094
Iter 0150 | Total Loss 0.583271
Iter 0200 | Total Loss 1.400030
Iter 0250 | Total Loss 0.738257
Iter 0300 | Total Loss 0.420469
Iter 0350 | Total Loss 0.611337
Iter 0400 | Total Loss 0.443519
Iter 0450 | Total Loss 0.491005
Iter 0500 | Total Loss 0.892621
Iter 0550 | Total Loss 0.768308
Iter 0600 | Total Loss 0.525871
Iter 0650 | Total Loss 0.453953
Iter 0700 | Total Loss 1.970526
Iter 0750 | Total Loss 0.988136
Iter 0800 | Total Loss 2.543571
Iter 0850 | Total Loss 1.738551
Iter 0900 | Total Loss 0.450546
Iter 0950 | Total Loss 0.405119
Iter 1000 | Total Loss 0.566877
Iter 1050 | Total Loss 0.449112
Iter 1100 | Total Loss 0.478447
Iter 1150 | Total Loss 0.420794
Iter 1200 | Total Loss 0.375969
Iter 1250 | Total Loss 0.641680
Iter 1300 | Total Loss 0.411319
Iter 1350 | Total Loss 0.385653
Iter 1400 | Total Loss 0.409560
Iter 1450 | Total Loss 0.422992
Iter 1500 | Total Loss 0.390408


In [201]:
model_path = 'models/cprd_ho2.pth'
torch.save(func.state_dict(), model_path)


In [202]:
# func1 = ode_models.HigherOrderOde(dat_dict,batch_size=batch_size, dim=4, order=2, hidden_size=50)
# func1.load_state_dict(torch.load(model_path))
# func1.eval()
# func1.init_cond_mat.requires_grad = False


In [199]:
(end-start)/60

41.17963133653005

In [182]:
itr=498

In [183]:
t, y, y0, t_mask, y_mask, eids = get_batch(dat_dict, batch_size, itr)

In [184]:
eids[0]

'8423206'

In [158]:
dat_dict['14260073']

{'t': tensor([0.0000, 3.1205, 4.7178]),
 'x': tensor([[[[0.4524, 2.7864, 2.9481, 1.3142]]],
 
 
         [[[0.1569, 1.9626, 1.4286, 0.2870]]],
 
 
         [[[0.7479, 1.6880, 1.6537, 0.8006]]]])}

In [156]:
y0[0, ...]

tensor([[0.4524, 2.7864, 2.9481, 1.3142]])

In [168]:
idx = np.array([func.eid_to_id[x] for x in eids])
id_torch = torch.from_numpy(idx)
#  self.init_cond = self.init_cond_mat[id_torch, ...]
id_torch[0]

tensor(5238)

In [169]:
func.init_cond_mat[5238]

tensor([[ 0.4512,  2.7882,  2.9492,  1.3124, -0.0830, -0.2435, -0.4933, -0.3075]],
       grad_fn=<SelectBackward>)

In [161]:
func.init_cond[0]

tensor([[-0.1238, -0.6072, -1.0318,  0.8855, -0.1671,  0.3526,  0.2405, -0.1994]],
       grad_fn=<SelectBackward>)

In [162]:
func.init_cond_mat

torch.Size([7010, 1, 8])

In [192]:
func_untrain = ode_models.HigherOrderOde(dat_dict,batch_size=batch_size, dim=4, order=2, hidden_size=50)


In [193]:
torch.mean(torch.abs(func.init_cond_mat[..., :4] - func_untrain.init_cond_mat[..., :4]))

tensor(0., grad_fn=<MeanBackward0>)

In [166]:
func.init_cond_mat[:5, 0, :4]

tensor([[-0.1154,  0.6854,  1.9798,  2.4262],
        [-1.1180,  0.0436, -0.2041,  0.8874],
        [ 0.4524, -0.5087, -0.8788, -0.2265],
        [ 2.0107, -0.1345,  0.5280,  0.5414],
        [ 0.1557, -0.3421,  0.3031, -0.1577]], grad_fn=<SliceBackward>)

In [167]:
func_untrain.init_cond_mat[:5, 0, :4]

tensor([[-0.0964,  0.6812,  1.9914,  2.4268],
        [-1.1096,  0.0405, -0.2035,  0.8862],
        [ 0.4524, -0.5087, -0.8788, -0.2265],
        [ 2.0144, -0.1426,  0.5282,  0.5438],
        [ 0.1569, -0.3256,  0.3031, -0.1409]], grad_fn=<SliceBackward>)

In [185]:
func.init_cond_mat[:5, 0, -4:]

tensor([[-0.6105,  0.4214,  2.3506, -0.3816],
        [-0.0939,  0.3392,  0.6464, -0.7820],
        [-1.3570,  0.3796, -0.4085,  0.0888],
        [-0.0635, -0.4829, -0.3766, -0.1169],
        [-0.1704,  0.3768, -0.1794,  0.4243]], grad_fn=<SliceBackward>)

In [186]:
func_untrain.init_cond_mat[:5, 0, -4:]

tensor([[-0.5927,  0.4283,  2.3702, -0.4005],
        [-0.0888,  0.3464,  0.6390, -0.7918],
        [-1.3570,  0.3796, -0.4085,  0.0888],
        [-0.0617, -0.4903, -0.3837, -0.1250],
        [-0.1813,  0.3930, -0.1933,  0.4411]], grad_fn=<SliceBackward>)

# old stuff

In [171]:
n_processed = 0
step_size = 1./12

s = time.time()
for eid, v in dat_dict.items():
    t = v['t']
    x = v['x']

    optimizer.zero_grad()
    
    
    func.set_init_cond([eid])
    init_zeros = torch.zeros_like(func.init_cond)
    
#     pred_y = odeint(func, init_zeros, t)
    pred_y = dto(func, init_zeros, t, method='euler', options={'step_size': step_size})
    
    pred_y_final = pred_y  + func.init_cond
    loss = torch.mean(torch.abs(pred_y_final[..., :func.dim] - x))
    loss.backward()
    optimizer.step()
    
    n_processed += 1
    
    if n_processed % 500 == 0:
        print('Processed:', n_processed)
#     if n_processed > 100:
#         break
e = time.time()

Processed: 500
Processed: 1000
Processed: 1500
Processed: 2000
Processed: 2500
Processed: 3000
Processed: 3500
Processed: 4000
Processed: 4500
Processed: 5000
Processed: 5500
Processed: 6000
Processed: 6500
Processed: 7000


In [172]:
e - s

175.36765384674072

In [133]:
init_zeros.requires_grad

True

In [122]:
n_processed

2

In [69]:
func.init_cond

Parameter containing:
tensor([[[100.0000,  91.0000, 128.0000,   6.4000, -11.1027,  -1.3878,   5.5513,
           -2.6369]]], requires_grad=True)

In [62]:
func.init_cond.shape

torch.Size([1, 1, 8])

In [88]:
5 % 3

2

In [102]:
n_processed

31

In [91]:
pred_y = odeint(func, init_zeros, t)

In [135]:
v = dat_dict['1092']

t = v['t']
x = v['x']

optimizer.zero_grad()
func.set_init_cond(eid, t, x)
init_zeros = torch.zeros_like(func.init_cond)




AttributeError: 'HigherOrderOde' object has no attribute 'set_init_cond'

In [99]:
pred_y = odeint(func, init_zeros, t)
pred_y_final = pred_y + func.init_cond


In [100]:
loss = torch.mean(torch.abs(pred_y_final[..., :func.dim] - x))
loss.backward()


In [101]:
optimizer.step()

In [139]:
x[0, ...].shape

torch.Size([1, 1, 4])

In [144]:
eid_list = list(dat_dict.keys())
eid_to_id = dict(zip(eid_list, range(len(eid_list))))


In [174]:
dat_dict['1092']

{'t': tensor([0.0000, 0.8548, 4.8959]),
 'x': tensor([[[[ 86.0000,  86.0000, 174.0000,   7.6000]]],
 
 
         [[[ 74.0000,  90.0000, 210.0000,   7.2000]]],
 
 
         [[[ 88.0000,  82.0000, 140.0000,   5.6000]]]])}

In [175]:
dat_dict['3259']

{'t': tensor([0.0000, 2.3781, 5.7644]),
 'x': tensor([[[[ 62.0000,  79.0000, 135.0000,   5.8000]]],
 
 
         [[[ 57.0000,  88.0000, 162.0000,   3.6000]]],
 
 
         [[[ 74.0000,  78.0000, 146.0000,   4.8000]]]])}

In [165]:
x.shape

torch.Size([3, 1, 1, 4])