In [None]:
"""
Sample from a trained model
"""
import re, sys
from tqdm import tqdm
import os
import pickle
#from contextlib import nullcontext
import torch
#import tiktoken
from model_ukb import GPTConfig, GPT
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.nn import functional as F

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import Predictive
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO

import probcox as pcox

def predictor(data):
    theta =  pyro.sample("theta", dist.StudentT(1, loc=0, scale=0.001).expand([data[1].shape[1], 1])).type(torch.float32)
    pred = torch.mm(data[1], theta)
    return(pred)

# -----------------------------------------------------------------------------
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = 'out-ukb' # ignored if init_from is not 'resume'
start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 10 # number of samples to draw
max_new_tokens = 10 # number of tokens generated in each sample
temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
device = 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype ='float64'#'bfloat16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
t_min = 100.0
#exec(open('configurator.py').read()) # overrides from command line or config file
# -----------------------------------------------------------------------------

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'float64': torch.float64, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
#ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
load_meta=True
meta_path='data/ukb/meta.pkl'

In [None]:
n_subsample = 1024
target = 'E11'
iter_ = 1000

In [49]:
labels_short[215][0]

'E11'

In [62]:
np.where(labels_short=='M10')[0][0]

783

In [None]:
labels = pd.read_csv("data/ukb/labels.csv", header=None, sep="\t")
labels_long = np.asarray(labels)
labels_short = np.asarray(labels).astype('S3').astype(str)
labels

Unnamed: 0,0
0,Padding
1,Healthy
2,Female
3,Male
4,BMI_low
...,...
1265,D46 Myelodysplastic syndromes
1266,D47 Other neoplasms of uncertain or unknown be...
1267,D48 Neoplasm of uncertain or unknown behaviour...
1268,O01 Hydatidiform mole38


In [None]:
pre = np.fromfile('/nfs/research/sds/sds-ukb-cancer/projects/gpt/data/pre.bin', dtype=np.uint32).reshape(-1,3)

In [None]:
def get_p2i(data):
    px = data[:,0].astype('int')
    pix = sorted(list(set(px)))
    #p2i = np.array([(p, (px==p).argmax(), (px==p).sum()) for i,p in enumerate(pix)])
    p2i = []
    j = 0
    q = px[0]
    for i,p in enumerate(px):
        if p != q:
            p2i.append([j,i-j])
            q = p
            j = i
    return np.array(p2i)

In [None]:
pre_p2i = get_p2i(pre)


In [None]:
pre[:, -1] += 1

In [None]:
def longformat(pre, pre_p2i, sindx, labels, target):
    
    time = []
    rindx = []
    cindx = []
    values = []
    rmax = 0
    for ii in sindx:
        x = pre[pre_p2i[ii, 0]:pre_p2i[ii, 0]+pre_p2i[ii, 1], 2].astype(int)
        a = pre[pre_p2i[ii, 0]:pre_p2i[ii, 0]+pre_p2i[ii, 1], 1].astype(np.float32)
        a[-1] = a[-1] + 1
        
        eindx = np.where(target == labels)[0]
        event = np.any(x == eindx)
        
        if event: 
            event_age = a[np.where(x == eindx)[0][0]]
            indx = a < event_age
            x = x[indx]
            a = a[indx]
            a = np.concatenate((a, np.asarray([event_age])))

        au = np.unique(a)
        time_ = np.concatenate((au[:-1, None], au[1:, None], np.zeros((au.shape[0]-1, 1))), axis=1)
        
        if event: 
            cindx_ = x.tolist()
            time_[-1, -1] = 1
        else:
            cindx_ = (x[:-1]).tolist()
            
        time_ = time_.tolist()
        rindx_ = np.asarray([np.where(aa==au)[0][0] for aa in a[:-1]])
        rindx_max = np.max(rindx_)
        
        # expand cumulative 
        cindx_cum = []
        rindx_cum = []
        for kk in range(len(rindx_)):
            rrep = np.arange(rindx_[kk], rindx_max+1)
            rindx_cum.extend(rrep.tolist())
            cindx_cum.extend(np.repeat(cindx_[kk], rrep.shape[0]).tolist())
        rindx_cum = np.asarray(rindx_cum)

        rindx_cum = rindx_cum + rmax
        rmax = np.max(rindx_cum)+1
        rindx_cum = rindx_cum.tolist()
        values_ = np.ones((len(rindx_cum), )).tolist()
        time.extend(time_)
        rindx.extend(rindx_cum)
        cindx.extend(cindx_cum)
        values.extend(values_)

    time = torch.tensor(time).type(torch.float32)
    Xs = torch.sparse_coo_tensor(torch.cat((torch.tensor(rindx)[None, :], torch.tensor(cindx)[None, :])), torch.tensor(values), (rmax, labels.shape[0]))

    return(time, Xs)

    

In [None]:
# Inference
#=======================================================================================================================
pyro.clear_param_store()
sampling_props = [pre_p2i.shape[0], np.where(target == labels_short)[0][0], n_subsample, None]
m = pcox.PCox(sampling_proportion=sampling_props, predictor=predictor)
m.initialize(eta=0.01, num_particles=1, rank=10) 

loss=[0]
for _ in tqdm(range(iter_)):
    sindx = np.random.choice(range(pre_p2i.shape[0]), n_subsample, replace=False)
    data = longformat(pre=pre, pre_p2i=pre_p2i, sindx=sindx, labels=labels_short, target=target)
    loss.append(m.infer(data=data))
            
g = m.return_guide()
theta_est = g.quantiles([0.025, 0.5, 0.975])

  'with `validate_args=False` to turn off validation.')
  self.sampling_proportion[0] = torch.tensor([self.sampling_proportion[0]])
  self.sampling_proportion[0] = torch.tensor([self.sampling_proportion[0]])
  self.sampling_proportion[1] = torch.tensor([self.sampling_proportion[1]])
  self.sampling_proportion[1] = torch.tensor([self.sampling_proportion[1]])
  self.sampling_proportion[2] = torch.tensor([self.sampling_proportion[2]])
  self.sampling_proportion[2] = torch.tensor([self.sampling_proportion[2]])
  censor_ratio = torch.tensor([self.sampling_proportion[0]/self.sampling_proportion[1]]).type(self.dtype)
  censor_ratio = torch.tensor([self.sampling_proportion[0]/self.sampling_proportion[1]]).type(self.dtype)
  uncensored_ratio = torch.tensor([self.sampling_proportion[2]/self.sampling_proportion[3]]).type(self.dtype)
  uncensored_ratio = torch.tensor([self.sampling_proportion[2]/self.sampling_proportion[3]]).type(self.dtype)
100%|██████████| 1000/1000 [1:29:47<00:00,  5.39s/it]


In [42]:
out = np.concatenate((theta_est['theta'][0].detach().numpy(), theta_est['theta'][1].detach().numpy(), theta_est['theta'][2].detach().numpy()), axis=1).astype(np.float32)
out.tofile(f'./out_cox/param/theta_{target}')

In [None]:
time = []
pred = []

for sindx in tqdm(np.array_split(np.arange(pre_p2i.shape[0]), 1000)):
    data = longformat(pre=pre, pre_p2i=pre_p2i, sindx=sindx, labels=labels_short, target=target)
    with torch.no_grad():
        pp = torch.sparse.mm(data[1], theta_est['theta'][1])
    time.extend(data[0].numpy().tolist())
    pred.extend(pp.numpy().tolist())
time = np.asarray(time)
pred = np.asarray(pred)

100%|██████████| 1000/1000 [19:00<00:00,  1.14s/it]


In [29]:
def Breslow(times, pred):
    times[times[:, -1]==1, 1] = times[times[:, -1]==1, 1] - 0.0000001
    event_times = times[times[:, -1] ==1, 1]
    event_times = event_times[np.argsort(event_times)]
    a0 = [0]
    for ii in tqdm(range(event_times.shape[0])):
        risk_set = (times[:, 0] < event_times[ii]) * (event_times[ii] <= times[:, 1])
        a0.append(1/np.sum(np.exp(pred[risk_set])))
    return(event_times, np.asarray(a0[1:]))

class A0_fun():
    def __init__(self, tt, basehaz):
        self.tt = tt
        self.basehaz = basehaz
        
    def __call__(self, ii):
        
        if np.sum(ii > self.tt) == 0:
            return(0)
        elif np.sum(ii > self.tt) <= len(self.tt):
            return(self.basehaz[np.sum(ii > self.tt)-1][0])
        else:
            return(self.basehaz[-1][0])

def absolute_risk(tt0, tt_range, A0, pred):
    A0net = np.sum(A0[tt0:tt0+tt_range]*np.exp(pred))
    S0net = 1-np.exp(-np.cumsum(A0net))
    return([S0net])

In [None]:
tt, basehaz = Breslow(times=time, pred=pred)    
delta_time =[]
for jj in np.arange(0, tt.shape[0]-1):
    delta_time.append(tt[jj+1] - tt[jj])
delta_time.append(0.1)
delta_time = np.asarray(delta_time)[:, None]
delta_time = np.asarray([np.sum(delta_time[jj==tt], axis=0) for jj in np.unique(tt)])
basehaz = np.asarray([np.sum(basehaz[jj==tt], axis=0) for jj in np.unique(tt)])[:, None]
tt = np.unique(tt)
basehaz = basehaz/delta_time
A0 = A0_fun(tt=tt[:-2], basehaz=basehaz[:-2])
A0_eval = np.asarray([A0(ii) for ii in range(36500)])

  7%|▋         | 2629/35403 [05:42<1:08:47,  7.94it/s]

In [22]:
theta = theta_est['theta'][1].detach().numpy()

In [23]:
age = []
sex = []
for ii in tqdm(range(pre_p2i.shape[0])):
    x = pre[pre_p2i[ii, 0]:pre_p2i[ii, 0]+pre_p2i[ii, 1], 2].astype(int)
    a = pre[pre_p2i[ii, 0]:pre_p2i[ii, 0]+pre_p2i[ii, 1], 1].astype(np.float32)
    assert np.logical_or(np.any(x==3), np.any(x==2))
    sex.extend([np.any(x==3).astype(int)])   
    age.extend([a[-1]])
sex = np.asarray(sex).astype(int)
age = np.asarray(age).astype(int)

100%|██████████| 471057/471057 [00:19<00:00, 23641.13it/s]


In [24]:
# hazard
N = pre_p2i.shape[0]
X = np.zeros((N, labels_short.shape[0]))
aa = []
for ii in tqdm(range(N)):
    with torch.no_grad():
        a = pre[pre_p2i[ii, 0]:pre_p2i[ii, 0]+pre_p2i[ii, 1]][:,1][None, :].astype('float32')
        x = pre[pre_p2i[ii, 0]:pre_p2i[ii, 0]+pre_p2i[ii, 1]][:,2][None, :].astype('float32')
        for jj in x[0, :].astype(int):
            X[ii, jj] = 1
        aa.extend([np.max(a).tolist()])
coxpred = np.matmul(X, theta)
aa = np.asarray(aa).astype(int)

100%|██████████| 471057/471057 [00:21<00:00, 21761.29it/s]


In [30]:
# absolute risk
risk = []
for ii in tqdm(range(N)):
    risk.extend(absolute_risk(tt0=np.minimum(80*365, np.asarray([aa[ii]]))[0], tt_range=900, A0=A0_eval, pred=coxpred[ii]))
risk = np.asarray(risk)

100%|██████████| 471057/471057 [00:18<00:00, 25492.50it/s]


In [44]:
risk.astype(np.float32)
risk.tofile(f'./out_cox/pred/{target}')

In [16]:
exit()