In [4]:
"""
Compute Cox-model (CancerRisk) absolute cancer risks for every patient in pre.bin.

Builds each patient's binary covariate vector from their token history, forms the
linear predictor with theta, converts it into the absolute risk of each of the first
20 events over an int(365*2.2)-step window starting at the patient's maximum recorded
age (capped at 82*365), and writes the (N, 20) float32 result to out/cox.bin.
"""
import re, sys
from tqdm import tqdm
import os
import pickle
#from contextlib import nullcontext
import torch
#import tiktoken
from model_ukb import GPTConfig, GPT
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.nn import functional as F
# -----------------------------------------------------------------------------
# Run configuration.
# NOTE(review): apart from `seed` (and `device`/`dtype`, which only feed `device_type`/`ptdtype`),
# none of these names -- including device_type, ptdtype, load_meta and meta_path -- are read by
# any other cell of this notebook; they look like leftovers from a sampling script.
init_from = 'resume'        # 'resume' (load from out_dir) or a gpt2 variant such as 'gpt2-xl'
out_dir = 'out-ukb'         # only meaningful when init_from == 'resume'
start = "\n"                # prompt text, e.g. "<|endoftext|>"; "FILE:prompt.txt" reads it from a file
num_samples = 10            # how many samples to draw
max_new_tokens = 10         # tokens generated per sample
temperature = 0.8           # 1.0 = unchanged, <1.0 = less random, >1.0 = more random
top_k = 200                 # keep only the top_k most likely tokens (others get probability 0)
seed = 1337
device = 'cpu'              # e.g. 'cpu', 'cuda', 'cuda:0', 'cuda:1'
dtype = 'float64'           # one of 'float32', 'float64', 'bfloat16', 'float16'
compile = False             # compile the model with PyTorch 2.0 for speed
t_min = 100.0
# -----------------------------------------------------------------------------

# Seeding and TF32 matmul/cudnn switches.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Device family (for torch.autocast-style APIs) and the torch dtype named by `dtype`.
if 'cuda' in device:
    device_type = 'cuda'
else:
    device_type = 'cpu'
ptdtype = dict(
    float32=torch.float32,
    float64=torch.float64,
    bfloat16=torch.bfloat16,
    float16=torch.float16,
)[dtype]

load_meta = True
meta_path = 'data/ukb/meta.pkl'

In [5]:
# Make the CancerRisk `model` directory importable (hard-coded cluster path).
# NOTE(review): `predictor` is not referenced anywhere else in this notebook; this import
# (and the sys.path change) may be removable -- confirm.
sys.path.append('/nfs/research/sds/sds-ukb-cancer/projects/gpt/CancerRisk-main/model')
from m1 import predictor

In [6]:
# Token vocabulary: tab-separated file without a header, read as a single unnamed column.
labels = pd.read_csv(
    "data/ukb/labels.csv",
    sep="\t",
    header=None,
)
labels

Unnamed: 0,0
0,Padding
1,Healthy
2,Female
3,Male
4,BMI_low
...,...
1265,D46 Myelodysplastic syndromes
1266,D47 Other neoplasms of uncertain or unknown be...
1267,D48 Neoplasm of uncertain or unknown behaviour...
1268,O01 Hydatidiform mole38


In [107]:
# Raw uint32 records, 3 fields per row. Downstream cells use column 0 to group rows
# (presumably patient id), column 1 as age and column 2 as token id -- confirm against the exporter.
pre = np.fromfile(
    '/nfs/research/sds/sds-ukb-cancer/projects/gpt/data/pre.bin',
    dtype=np.uint32,
).reshape((-1, 3))

In [108]:
def get_p2i(data):
    """Index the contiguous runs of equal ids in column 0 of `data`.

    Consecutive rows sharing the same column-0 value (presumably a patient id) form
    one run. The result has one row per run, ``[first_row, n_rows]``, so
    ``data[first_row:first_row + n_rows]`` selects that run.

    Every run is returned, including the last one. The previous loop only emitted a
    run when the id changed, so the final run was silently dropped. Non-contiguous
    repeats of an id are reported as separate runs.

    Parameters
    ----------
    data : np.ndarray, shape (n_rows, k)
        Column 0 holds the grouping id.

    Returns
    -------
    np.ndarray of int, shape (n_runs, 2)
        ``[first_row, n_rows]`` per run; shape (0, 2) when `data` has no rows.
    """
    px = data[:, 0].astype('int')
    if px.size == 0:
        return np.zeros((0, 2), dtype=int)
    # A run starts at row 0 and at every row whose id differs from the previous row's.
    starts = np.concatenate(([0], np.flatnonzero(px[1:] != px[:-1]) + 1))
    # Each run extends up to the next start (or the end of the data for the last run).
    lengths = np.diff(np.append(starts, px.size))
    return np.stack([starts, lengths], axis=1)

In [109]:
# Per-run row index into `pre`: one [first_row, n_rows] pair per contiguous id run.
pre_p2i = get_p2i(data=pre)


In [110]:
# Shift every token id (last column of `pre`) up by one, in place.
# NOTE(review): presumably this aligns ids with labels.csv, whose row 0 is 'Padding' -- confirm.
# WARNING: not idempotent -- re-running this cell without reloading `pre` shifts the ids again.
pre[:, -1] += 1

In [69]:
# Per-patient sex flag and final age, read from each run of rows in `pre`.
# sex is 1 when token 3 occurs in the run, otherwise 0; every run must contain token 2 or 3
# (labels.csv rows 2/3 are 'Female'/'Male', assuming the +1 id shift has been applied).
age = []
sex = []
for first, count in tqdm(pre_p2i):
    rows = pre[first:first + count]
    tokens = rows[:, 2].astype(int)
    is_male = np.any(tokens == 3)
    assert np.logical_or(is_male, np.any(tokens == 2))
    sex.append(is_male.astype(int))
    # Age value on the run's last row, as float32 (converted to int below).
    age.append(rows[-1, 1].astype(np.float32))
sex = np.asarray(sex).astype(int)
age = np.asarray(age).astype(int)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 471057/471057 [00:20<00:00, 22916.48it/s]


In [70]:

# CancerRisk model directory (hard-coded cluster path).
cr_dir = '/nfs/research/sds/sds-ukb-cancer/projects/gpt/CancerRisk-main/'

# Competing events: 20 cancers followed by 'other' and 'death'.
events = ['oesophagus', 'stomach', 'colorectal', 'liver', 'pancreas', 'lung', 'melanoma', 'breast',
          'cervix_uteri', 'corpus_uteri', 'ovary', 'prostate', 'testis', 'kidney', 'bladder', 'brain',
          'thyroid', 'non_hodgkin_lymphoma', 'multiple_myeloma', 'AML', 'other', 'death']

# One code per cancer in events[:20] (presumably ICD-10 blocks).
# NOTE(review): not used anywhere else in this notebook.
coxCcodes = ['C15','C16','C18','C22','C25','C34','C43','C50','C53','C54','C56','C61','C62','C64','C65','C70','C73','C82','C90','C91']

# Print "index name" for every event; enumerate stays in step with `events`
# (the previous hard-coded range(22) would break if the list changed).
for cc, event_name in enumerate(events):
    print(cc, event_name)

# SECURITY: allow_pickle=True unpickles arbitrary objects -- only load this file from a trusted source.
disease_codes = np.load(cr_dir + 'model/disease_codes.npy', allow_pickle=True)

# Four covariate names per cancer (presumably family-history indicators) for the 20 cancers.
gene_names = np.asarray([jj+ii for jj in events[:20] for ii in [' First Degree', ' All', ' Multiple', ' Early']])

bth_names = np.asarray(['Alcoholic', 'Smoker', 'High Blood Pressure', 'Low Blood Pressure', 'Height', 'Weight', 'Age at first Birth'])

# NOTE(review): this rebinds the configuration cell's `dtype` string ('float64') to a torch
# tensor type; `dtype` is not read again later in this notebook.
dtype = torch.FloatTensor

# Covariate ordering: disease-code names, then gene_names, then bth_names.
column_headers = np.concatenate((disease_codes[:, 1], gene_names, bth_names))


0 oesophagus
1 stomach
2 colorectal
3 liver
4 pancreas
5 lung
6 melanoma
7 breast
8 cervix_uteri
9 corpus_uteri
10 ovary
11 prostate
12 testis
13 kidney
14 bladder
15 brain
16 thyroid
17 non_hodgkin_lymphoma
18 multiple_myeloma
19 AML
20 other
21 death


In [71]:
class CIF():
    """Cumulative incidence over a time window from a competing-risks Cox model.

    Parameters
    ----------
    cc : int
        Cause index whose incidence is returned when ``full`` is False.
    tt0 : sequence of int
        Per-subject start index of the window along A0's last axis.
    tt_range : int
        Window length (number of steps along A0's last axis).
    A0 : np.ndarray
        Baseline hazard increments indexed ``A0[cause, sex, t]``.
    pred : np.ndarray
        Linear predictors indexed ``pred[subject, cause]``.
    sex : sequence of int
        Per-subject index into A0's second axis.
    full : bool
        If True, return the summed incidence of causes 0-19 (the 20 major cancers);
        if False, return the incidence of cause ``cc`` only.
    **kwds
        Accepted and ignored.

    Calling the instance with a subject index ``ii`` returns a one-element list ``[risk]``.
    Each step's cause-specific hazard ``A0 * exp(pred)`` is weighted by the all-cause
    survival ``exp(-cumsum(sum over all causes of A0 * exp(pred)))`` and summed over
    the window.
    """

    def __init__(self, cc, tt0, tt_range, A0, pred, sex, full=False, **kwds):
        self.cc = cc
        self.tt0 = tt0
        self.tt_range = tt_range
        self.A0 = A0
        self.pred = pred
        self.sex = sex
        self.full = full
        self.sexnames = ['female', 'male']

    def __call__(self, ii):
        sex_ii = self.sex[ii]
        lo = self.tt0[ii]
        hi = lo + self.tt_range
        # All-cause hazard per step, then the survival curve over the window.
        total_hazard = np.sum((self.A0[:, sex_ii, lo:hi] * np.exp(self.pred[ii, :, None])), axis=0)
        survival = np.exp(-np.cumsum(total_hazard))
        if not self.full:
            cause_hazard = self.A0[self.cc, sex_ii, lo:hi] * np.exp(self.pred[ii, self.cc, None])
            return [np.cumsum(cause_hazard * survival)[-1]]
        per_cause = self.A0[:20, sex_ii, lo:hi] * np.exp(self.pred[ii, :20, None]) * survival[None, :]
        return [np.cumsum(per_cause, axis=1)[:, -1].sum()]

In [72]:
# Baseline hazards (file name suggests Breslow estimates); CIF indexes this as A0[cause, sex, t].
A0 = np.load(cr_dir + 'model/breslow.npy')
# Coefficients: every 3rd column starting at column 2 (columns 2, 5, ..., 65 -> 22 columns),
# presumably one per entry of `events` -- TODO confirm against theta.csv's header.
theta = np.asarray(pd.read_csv(cr_dir + 'model/theta.csv').iloc[:, 2:68:3])

In [111]:
# First 3 characters of every label (e.g. 'D46' for 'D46 Myelodysplastic syndromes') as a
# (n_labels, 1) str array; the byte-string cast assumes ASCII labels.
labels_short = labels.to_numpy().astype('S3').astype(str)

In [74]:
# hazard
N = pre_p2i.shape[0]
coxmap = [np.where(disease_codes[:, 0] == kk)[0].tolist() for kk in np.asarray(labels_short)[:, 0]]
X = np.zeros((N, 1392))
aa = []

for ii in tqdm(range(N)):
    with torch.no_grad():
        a = pre[pre_p2i[ii, 0]:pre_p2i[ii, 0]+pre_p2i[ii, 1]][:,1][None, :].astype('float32')
        x = pre[pre_p2i[ii, 0]:pre_p2i[ii, 0]+pre_p2i[ii, 1]][:,2][None, :].astype('float32')
        for jj in x[0, :].astype(int):
            if coxmap[jj] != []:
                X[ii, coxmap[jj][0]] = 1
        X[ii, -6] = np.any(x[0, :]==9).astype(int)
        X[ii, -7] = np.any(x[0, :]==12).astype(int)
        aa.extend([np.max(a).tolist()])

coxpred = np.matmul(X, theta)
aa = np.asarray(aa).astype(int)


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 471057/471057 [00:37<00:00, 12607.88it/s]


In [None]:
# absolute risk
#=======================================================================================================================
# Absolute risk of each of the first 20 events (the major cancers) over HORIZON steps, starting at
# the patient's maximum recorded age `aa` capped at MAX_T0 (presumably days: 82 and 2.2 years).
# This matches CIF(cc=cc, ..., full=False)(0) for cc in range(20), but the all-cause hazard and the
# survival curve do not depend on cc, so they are computed once per patient instead of 20 times.
# NOTE(review): this inlines the CIF computation -- keep the two in sync if either changes.
N_MAJOR_CANCERS = 20
HORIZON = int(365*2.2)
MAX_T0 = 82*365
tt0_all = np.minimum(MAX_T0, aa)
risk = []
for ii in tqdm(range(N)):
    t0 = tt0_all[ii]
    # Cause-specific hazard increments over the window, shape (causes, steps).
    hazard = A0[:, sex[ii], t0:t0 + HORIZON] * np.exp(coxpred[ii, :, None])
    # All-cause survival at each step of the window.
    S0net = np.exp(-np.cumsum(np.sum(hazard, axis=0)))
    # Cumulative incidence of each major cancer at the end of the window.
    rr = np.cumsum(hazard[:N_MAJOR_CANCERS] * S0net[None, :], axis=1)[:, -1]
    risk.append(rr.tolist())
risk = np.asarray(risk)

In [15]:
# Persist the (N, 20) risks as raw float32 in row-major order; the file has no header,
# so readers must know the shape and dtype.
risk = np.asarray(risk, dtype=np.float32)
risk.tofile('/nfs/research/sds/sds-ukb-cancer/projects/gpt/out/cox.bin')

In [16]:
# Stop the interpreter/kernel here.
# NOTE(review): presumably so batch runs of this notebook end after the export -- in an
# interactive session this also shuts down the kernel.
exit()