In [1]:
"""
Sample from a trained model
"""
import re
from tqdm import tqdm
import os
import pickle
from contextlib import nullcontext
import torch
#import tiktoken
from model_ukb import GPTConfig, GPT
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.nn import functional as F

# All relative paths below (out_dir, data/ukb/...) are resolved against this project directory.
os.chdir('/nfs/research/sds/sds-ukb-cancer/projects/gpt/nanoGPT-healthGPT')
# -----------------------------------------------------------------------------
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = 'out-ukb' # ignored if init_from is not 'resume'
start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 10 # number of samples to draw
max_new_tokens = 10 # number of tokens generated in each sample
temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype ='float64'#'bfloat16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
t_min = 100.0
#exec(open('configurator.py').read()) # overrides from command line or config file
# -----------------------------------------------------------------------------

# Seed both the CPU and CUDA generators so sampling below is repeatable.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
# 'cuda' for any CUDA device string ('cuda', 'cuda:0', ...), otherwise 'cpu'.
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'float64': torch.float64, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# CUDA autocast only targets float16/bfloat16; for every other combination
# (CPU, or float32/float64 such as the default above) run without autocast.
ctx = (
    torch.amp.autocast(device_type=device_type, dtype=ptdtype)
    if device_type == 'cuda' and ptdtype in (torch.float16, torch.bfloat16)
    else nullcontext()
)

load_meta=True
meta_path='data/ukb/meta.pkl'

In [2]:
# Earlier way of assembling the label table, kept for reference:
#labels = pd.read_csv("data/ukb/fields.txt", header=None).merge(pd.read_csv("data/ukb/icd10_codes_mod.tsv", sep='\t',header=None, index_col=0), left_on=0, right_index=True)
#labels[1] = labels[1].str.replace("Source of report of ","")
#train.max(0)

# Read the label table: tab-separated, no header row.
labels = pd.read_csv(
    "data/ukb/labels.csv",
    sep="\t",
    header=None,
)
# Short labels: cast every entry to a 3-byte string (truncating longer text),
# then convert back to ordinary str.
labels_short = np.asarray(labels).astype('S3')
labels_short = labels_short.astype(str)
labels

Unnamed: 0,0
0,Padding
1,Healthy
2,Female
3,Male
4,BMI_low
...,...
1265,D46 Myelodysplastic syndromes
1266,D47 Other neoplasms of uncertain or unknown be...
1267,D48 Neoplasm of uncertain or unknown behaviour...
1268,O01 Hydatidiform mole38


In [3]:
# model
if init_from == 'resume':
    # init from a model saved in a specific directory
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device)
    #checkpoint['model_args']['token_dropout'] = 1.0
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # Strip the '_orig_mod.' prefix from parameter names so the keys match a plain GPT module.
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
elif init_from.startswith('gpt2'):
    # init from a given GPT-2 model
    # NOTE(review): `checkpoint` is not defined on this path, yet the code below reads
    # checkpoint['model_args'] - this branch only works together with init_from == 'resume' state.
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))
else:
    # Fail with a clear message instead of a NameError on `model` further down.
    raise ValueError(f"Unsupported init_from: {init_from!r}")

model.eval()
model.to(device)
if compile:
    model = torch.compile(model) # requires PyTorch 2.0 (optional)

# look for the meta pickle in case it is available in the dataset folder
if load_meta:
    print(f"Loading meta from {meta_path}...")
    # NOTE(review): pickle.load executes arbitrary code from the file; meta.pkl must be trusted.
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    # TODO want to make this more general to arbitrary encoder/decoder schemes
    stoi, itos = meta['stoi'], meta['itos']
    #print(stoi, itos)
    print(checkpoint['model_args']['vocab_size'])
    # Number of base-`meta['vocab_size']` digits per model token (log ratio, truncated to int).
    token_length = int(np.log(checkpoint['model_args']['vocab_size'])/np.log(meta['vocab_size']))
    print(token_length)
    encode = lambda s: [stoi[c] for c in s]
    # Each model token is expanded into `token_length` base-`meta['vocab_size']` digits
    # (least-significant first), and every digit is mapped through `itos`.
    # NOTE(review): np.base_repr only accepts bases 2-36, so decode raises for a larger
    # meta['vocab_size'] - confirm before relying on it.
    decode = lambda l: ''.join([itos[int(j)] for i in l for j in np.base_repr(i,meta['vocab_size'], token_length)[-token_length:][::-1]])
else:
    # ok let's assume gpt-2 encodings by default
    print("No meta.pkl found, assuming GPT-2 encodings...")
    # The module-level tiktoken import is commented out, so import it where it is needed.
    import tiktoken
    enc = tiktoken.get_encoding("gpt2")
    encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
    decode = lambda l: enc.decode(l)
checkpoint['model_args']

number of parameters: 2.24M
Loading meta from data/ukb/meta.pkl...
1270
1


{'n_layer': 12,
 'n_head': 12,
 'n_embd': 120,
 'block_size': 48,
 'bias': False,
 'vocab_size': 1270,
 'dropout': 0.0,
 'token_dropout': 0.0,
 't_min': 0.1,
 'mask_ties': True,
 'ignore_tokens': [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]}

In [4]:
def get_p2i(data):
    """
    Index the contiguous runs of equal ids in column 0 of `data`.

    Parameters
    ----------
    data : 2-D array
        Rows are expected to be grouped so that all rows sharing the same
        value in column 0 are adjacent. Only column 0 is read.

    Returns
    -------
    np.ndarray of shape (n_runs, 2)
        One row `[start, length]` per run, in order of appearance, so that
        `data[start:start + length]` is the full run. The final run is
        included. Empty input yields an array of shape (0, 2).
    """
    px = np.asarray(data)[:, 0].astype('int')
    if px.size == 0:
        return np.empty((0, 2), dtype=int)
    # A run starts at row 0 and wherever the id differs from the previous row.
    starts = np.concatenate(([0], np.flatnonzero(px[1:] != px[:-1]) + 1))
    # Each run ends where the next begins; the last run ends at the end of the data.
    lengths = np.diff(np.append(starts, px.size))
    return np.stack([starts, lengths], axis=1)

In [5]:
# Read the raw uint32 stream and view it as rows of three columns.
pre = np.fromfile(
    '/nfs/research/sds/sds-ukb-cancer/projects/gpt/data/pre.bin',
    dtype=np.uint32,
).reshape(-1, 3)
# Run index over column 0: one [start, length] row per run.
pre_p2i = get_p2i(pre)
# Shift every value in the last column up by one.
pre[:, -1] = pre[:, -1] + 1

In [8]:
# For every run in pre_p2i: interleave padding events (token 1) into the
# sequence, run the model once and keep the logits at the last position.
nextlogits = []
padding='random'
with torch.no_grad():
    for ii in tqdm(range(pre_p2i.shape[0])):
        row_lo = pre_p2i[ii, 0]
        row_hi = row_lo + pre_p2i[ii, 1]
        # Column 2 is fed to the model as tokens, column 1 as the matching ages.
        x = torch.from_numpy((pre[row_lo:row_hi, 2][None, :]).astype(int)).to(device)
        a = torch.from_numpy((pre[row_lo:row_hi, 1][None, :]).astype(int)).to(device)
        if padding == 'regular':
            # Evenly spaced padding ages, every 3652.5/2 up to (excluding) 36525.
            pad = torch.arange(3652.5/2, 36525, 3652.5/2) * torch.ones(1,1)
        else:
            # 20 padding ages drawn uniformly from 1..36525.
            pad = torch.randint(36525, (1, 20)) + 1

        pad = pad.to(device)
        m = a.max()
        # Padding events carry token 1.
        x = torch.hstack([x, torch.ones(1, pad.shape[1], dtype=torch.int).to(device)])
        a = torch.hstack([a, pad])

        # Sort tokens and ages jointly by age.
        s = torch.argsort(a, 1)
        x = torch.gather(x,1,s)
        a = torch.gather(a,1,s)
        a = a.type(torch.int32)
        # Drop padding that falls after the last observed age.
        idx = a <= m
        a = a[idx][None, :]
        x = x[idx][None, :]
        # Keep each (1, vocab) row as a float32 array; converting hundreds of
        # thousands of rows to Python floats would cost far more memory.
        nextlogits.append(model(x, a)[0][:, -1, :].cpu().numpy().astype(np.float32))
# One row per run; an empty index gives an empty float32 array.
nextlogits = np.concatenate(nextlogits, axis=0) if nextlogits else np.empty((0,), dtype=np.float32)
nextlogits.tofile('/nfs/research/sds/sds-ukb-cancer/projects/gpt/out/nextlogits.bin')
  

100%|██████████████████████████████████| 471057/471057 [44:33<00:00, 176.18it/s]


In [None]:
# For every run in pre_p2i: draw `nsamples` generated continuations (each with
# freshly drawn padding), average the logits over the generated steps, then
# average those per-sample means.
trajectory = []
padding='random'
nsamples=5
with torch.no_grad():
    for ii in tqdm(range(pre_p2i.shape[0])):
        ll = []
        for _ in range(nsamples):
            row_lo = pre_p2i[ii, 0]
            row_hi = row_lo + pre_p2i[ii, 1]
            # Column 2 is fed to the model as tokens, column 1 as the matching ages.
            x = torch.from_numpy((pre[row_lo:row_hi, 2][None, :]).astype(int)).to(device)
            a = torch.from_numpy((pre[row_lo:row_hi, 1][None, :]).astype(int)).to(device)
            if padding == 'regular':
                # Evenly spaced padding ages, every 3652.5/2 up to (excluding) 36525.
                pad = torch.arange(3652.5/2, 36525, 3652.5/2) * torch.ones(1,1)
            else:
                # 20 padding ages drawn uniformly from 1..36525.
                pad = torch.randint(36525, (1, 20)) + 1

            pad = pad.to(device)
            m = a.max()
            # Padding events carry token 1.
            x = torch.hstack([x, torch.ones(1, pad.shape[1], dtype=torch.int).to(device)])
            a = torch.hstack([a, pad])

            # Sort tokens and ages jointly by age.
            s = torch.argsort(a, 1)
            x = torch.gather(x,1,s)
            a = torch.gather(a,1,s)
            a = a.type(torch.int32)
            # Drop padding that falls after the last observed age.
            idx = a <= m
            a = a[idx][None, :]
            x = x[idx][None, :]
            age = a[-1,-1]
            # Generate up to 10 further tokens, stopping 900 age units after the last input age.
            x_, y_, logits = model.generate(idx=x, age=a, max_new_tokens=10, max_age=age+900, temperature=1.0, no_repeat=True) #3.2*365.25
            logits = logits.cpu().numpy()
            # Keep the logits from the last input position onwards.
            logits = logits[0, x.shape[1]-1:, :]

            # Columns that are -inf at the first kept step stay -inf in the mean;
            # any other -inf entry becomes NaN so that nanmean skips it.
            # NOTE(review): presumably -inf marks tokens masked by the model - confirm in model_ukb.
            idxinf = logits[0, :] == -np.inf
            logits[logits == -np.inf] = np.nan
            logits[:, idxinf] = -np.inf
            ll.extend(np.nanmean(logits, axis=0)[None, :].tolist())
        trajectory.extend(np.mean(ll, axis=0)[None, :].tolist())
trajectory = np.asarray(trajectory).astype(np.float32)
trajectory.tofile('/nfs/research/sds/sds-ukb-cancer/projects/gpt/out/trajectory.bin')
        


 27%|███████▉                      | 124974/471057 [5:15:54<15:11:32,  6.33it/s]

In [None]:
%%bash 
# Kill batch job 76298994.
# NOTE(review): the job id is hard-coded and goes stale; the bjobs listing in this
# notebook shows it as a running /bin/bash job on gpu-a100 - presumably this session's
# own job, so confirm before executing (and keep this cell out of Run All).
bkill 76298994

In [112]:
%%bash 
# List the current user's batch jobs (job id, status, queue, hosts, submit time).
bjobs

JOBID      USER    STAT  QUEUE      FROM_HOST   EXEC_HOST   JOB_NAME   SUBMIT_TIME
76298994   alexwju RUN   gpu-a100   codon-login codon-gpu-0 */bin/bash Jan 15 07:59


In [None]:
# `run` is not assigned by any earlier cell, so on a fresh kernel default it:
# recompute only when the cached file is missing. A `run` that already exists
# in the namespace is respected.
try:
    run
except NameError:
    run = not os.path.exists('/nfs/research/sds/sds-ukb-cancer/projects/gpt/out/next_probs.bin')

if run:
    # Next-token probabilities (softmax of the last-position logits) for every run in pre_p2i.
    next_probs = []
    with torch.no_grad():
        for ii in tqdm(range(pre_p2i.shape[0])):
            x = torch.from_numpy((pre[pre_p2i[ii, 0]:pre_p2i[ii, 0]+pre_p2i[ii, 1], 2][None, :]).astype(int)).to(device)
            a = torch.from_numpy((pre[pre_p2i[ii, 0]:pre_p2i[ii, 0]+pre_p2i[ii, 1], 1][None, :]).astype(int)).to(device)
            # next token probs
            nprobs = F.softmax(model(x, a)[0][:, -1, :], dim=-1)
            nprobs = nprobs.cpu().numpy()
            next_probs.extend(nprobs.tolist())
    next_probs = np.asarray(next_probs).astype(np.float32)
    next_probs.tofile('/nfs/research/sds/sds-ukb-cancer/projects/gpt/out/next_probs.bin')
else:
    # Reload the cached probabilities; 1270 columns per row
    # (the same number as the checkpoint's vocab_size printed when the model was loaded).
    next_probs = np.fromfile('/nfs/research/sds/sds-ukb-cancer/projects/gpt/out/next_probs.bin', dtype=np.float32).reshape(-1, 1270)


# For every run in pre_p2i: generate 5 continuations, transform the logits,
# take the per-token maximum over steps, then the median over the 5 draws.
traj = []
with torch.no_grad():
    for ii in tqdm(range(pre_p2i.shape[0])):
        x = (pre[pre_p2i[ii, 0]:pre_p2i[ii, 0]+pre_p2i[ii, 1], 2][None, :]).astype(int)
        a = (pre[pre_p2i[ii, 0]:pre_p2i[ii, 0]+pre_p2i[ii, 1], 1][None, :]).astype(int)
        age = a[-1,-1]
        rr = []
        for i in range(5):
            # Generate up to 10 further tokens, stopping 900 age units after the last input age.
            x_, y_, logits = model.generate(idx=torch.from_numpy(x).to(device), age=torch.from_numpy(a).to(device), max_new_tokens=10, max_age=age+900, temperature=1.0, no_repeat=True) #3.2*365.25
            logits = logits.cpu().numpy()
            # Element-wise 1 / (1 + exp(-logit) / 365.25).
            # NOTE(review): presumably rescales a per-day rate to a yearly one - confirm.
            r = 1 / (1 + np.exp(-logits[0,:,:])/365.25)
            # Highest value reached by each token over all positions.
            r = r.max(axis=0)[None, :]
            rr.extend(r.tolist())
        # Median across the 5 draws.
        rr = np.median(np.asarray(rr), axis=0)[None, :].tolist()
        traj.extend(rr)
traj = np.asarray(traj).astype(np.float32)
traj.tofile('/nfs/research/sds/sds-ukb-cancer/projects/gpt/out/trajectory3.bin')
  