In [1]:
"""
Sample from a trained model
"""
import re
from tqdm import tqdm
import os
import pickle
from contextlib import nullcontext
import torch
#import tiktoken
from model_ukb import GPTConfig, GPT
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.nn import functional as F
from sklearn.metrics import roc_auc_score
from sklearn.calibration import calibration_curve
from torchmetrics import CalibrationError
# -----------------------------------------------------------------------------
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = 'out-ukb' # ignored if init_from is not 'resume'
start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 10 # number of samples to draw
max_new_tokens = 10 # number of tokens generated in each sample
temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
device = 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype ='float64'#'bfloat16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
t_min = 100.0
#exec(open('configurator.py').read()) # overrides from command line or config file
# -----------------------------------------------------------------------------

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cpu' #'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'float64': torch.float64, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

load_meta=True
meta_path='data/ukb/meta.pkl'

%matplotlib inline
%config InlineBackend.figure_format='retina'
plt.rcParams['figure.facecolor'] = 'white'
plt.rcParams.update({'axes.grid': True,
                     'grid.linestyle': ':',
                     'axes.spines.bottom': False,
                     'font.size': 16,
          'axes.spines.left': False,
          'axes.spines.right': False,
          'axes.spines.top': False})
plt.rcParams['figure.dpi']=72

run=False

In [2]:
#labels = pd.read_csv("data/ukb/fields.txt", header=None).merge(pd.read_csv("data/ukb/icd10_codes_mod.tsv", sep='\t',header=None, index_col=0), left_on=0, right_index=True)
#labels[1] = labels[1].str.replace("Source of report of ","")
#train.max(0)
labels = pd.read_csv("data/ukb/labels.csv", header=None, sep="\t")
labels_long = np.asarray(labels)[:, 0]
labels_short = np.asarray(labels).astype('S3').astype(str)
labels

Unnamed: 0,0
0,Padding
1,Healthy
2,Female
3,Male
4,BMI_low
...,...
1265,D46 Myelodysplastic syndromes
1266,D47 Other neoplasms of uncertain or unknown be...
1267,D48 Neoplasm of uncertain or unknown behaviour...
1268,O01 Hydatidiform mole38


In [3]:
def get_p2i(data):
    px = data[:,0].astype('int')
    pix = sorted(list(set(px)))
    #p2i = np.array([(p, (px==p).argmax(), (px==p).sum()) for i,p in enumerate(pix)])
    p2i = []
    j = 0
    q = px[0]
    for i,p in enumerate(px):
        if p != q:
            p2i.append([j,i-j])
            q = p
            j = i
    return np.array(p2i)

In [4]:
train = np.fromfile('/nfs/research/sds/sds-ukb-cancer/projects/gpt/data/pre_wd.bin', dtype=np.uint32).reshape(-1,3)
train = np.concatenate((train, np.asarray([[9999999, 0, 0]])))
train_p2i = get_p2i(train)
train[:, -1] += 1
train = train[:-1]

In [5]:
age = []
sex = []
sexext = []
for ii in tqdm(range(train_p2i.shape[0])):
    x = train[train_p2i[ii, 0]:train_p2i[ii, 0]+train_p2i[ii, 1], 2].astype(int)
    a = train[train_p2i[ii, 0]:train_p2i[ii, 0]+train_p2i[ii, 1], 1].astype(np.float32)
    age.extend([a.max()])
    sex.extend([np.any(x==3).astype(int)])   
    sexext.extend([np.repeat(np.any(x==3).astype(int), x.shape[0]).tolist()])   
sexext = np.concatenate(sexext)
sex = np.asarray(sex).astype(int)
age = np.asarray(age)

100%|████████████████████████████████| 502309/502309 [00:19<00:00, 26321.16it/s]


In [6]:
class A0_fun():
    def __init__(self, tt, basehaz):
        self.tt = tt
        self.basehaz = basehaz
        
    def __call__(self, ii):
        if np.sum(ii > self.tt) == 0:
            return(0)
        elif np.sum(ii > self.tt) <= len(self.tt):
            return(self.basehaz[np.sum(ii > self.tt)-1][0])
        else:
            return(self.basehaz[-1][0])

def agesex(jj, X, amax):
    idx = X[:, -1]==jj
    if idx.sum() > 100:
        tt = []
        basehaz = []
        for aa in np.sort(X[idx, 1]):
            tt.extend([aa])
            basehaz.extend([1/(aa<=amax).sum()])
        tt = np.asarray(tt)
        basehaz = np.asarray(basehaz)
        delta_time =[]
        for ii in np.arange(0, tt.shape[0]-1):
            delta_time.append(tt[ii+1] - tt[ii])
        delta_time.append(0.1)
        delta_time = np.asarray(delta_time)[:, None]
        delta_time = np.asarray([np.sum(delta_time[ii==tt], axis=0) for ii in np.unique(tt)])
        basehaz = np.asarray([np.sum(basehaz[ii==tt], axis=0) for ii in np.unique(tt)])[:, None]
        tt = np.unique(tt)
        basehaz = basehaz/delta_time
        A0 = A0_fun(tt=tt[:-2], basehaz=basehaz[:-2])
        A0_eval = np.asarray([A0(ii) for ii in range(36500)])
    else:
        A0_eval = np.zeros((36500,))
    return(A0_eval)

In [7]:
A = []
for jj in tqdm(np.arange(13,1270)):
    a = agesex(jj=jj, X=train[sexext==0], amax=age[sex==0])
    b = agesex(jj=jj, X=train[sexext==1], amax=age[sex==1])
    A.extend([[a, b]]) 
A = np.stack((A)).astype(np.float32)

  7%|██▋                                      | 83/1257 [02:05<29:29,  1.51s/it]


KeyboardInterrupt: 

In [None]:
pre = np.fromfile('/nfs/research/sds/sds-ukb-cancer/projects/gpt/data/pre.bin', dtype=np.uint32).reshape(-1,3)
pre_p2i = get_p2i(pre)
pre[:, -1] += 1

In [None]:
# get age and sex
age = []
sex = []
for ii in tqdm(range(pre_p2i.shape[0])):
    x = pre[pre_p2i[ii, 0]:pre_p2i[ii, 0]+pre_p2i[ii, 1], 2].astype(int)
    a = pre[pre_p2i[ii, 0]:pre_p2i[ii, 0]+pre_p2i[ii, 1], 1].astype(np.float32)
    assert np.logical_or(np.any(x==2), np.any(x==3))
    sex.extend([np.any(x==3).astype(int)])   
    age.extend([a[-1]])
sex = np.asarray(sex).astype(int)
age = np.asarray(age).astype(int)

In [None]:
baserate = []
for ii in tqdm(range(pre_p2i.shape[0])):
    br = A[:, sex[ii], age[ii]:age[ii]+900].sum(axis=1)
    br = 1-np.exp(-br)
    baserate.extend([br.tolist()])


baserate = np.asarray(baserate).astype(np.float32)


In [None]:
baserate.tofile(f'/nfs/research/sds/sds-ukb-cancer/projects/gpt/out/baserate.bin')