In [1]:
# Get codes for liver subjects

import json

work_loc = '/run/media/anton/Elements/UKB/liverCodes'


def _load_json(path):
    """Open `path` for reading, parse it as JSON and return the result."""
    with open(path, 'r') as fh:
        return json.load(fh)


# Code vocabulary; D is the number of entries in it
codes = _load_json(f'{work_loc}/codes.json')
D = len(codes)
print(D)

# Subject id -> list of visits, each visit being a list of integer codes
# (structure as shown by the sample entry printed below)
eidCodesMap = _load_json(f'{work_loc}/eidCodesMap.json')
print(len(eidCodesMap))

# Show a single record so the structure of the map is visible
for visits in list(eidCodesMap.values())[:1]:
    print(visits)

# Subject ids in sorted order; later cells iterate over this list
subs = sorted(eidCodesMap)
print(subs[:5])

1291
8564
[[888, 1278]]
['1000177', '1000217', '1000411', '1000463', '1000499']


In [2]:
# Get liver data

import pickle

# Load the subject -> image dictionary.
# NOTE(review): pickle.load can execute arbitrary code, so only load this
# file if it comes from a trusted source.
# The `with` block guarantees the file handle is closed after loading.
with open(f'{work_loc}/../liver10-15.pkl', 'rb') as f:
    liverData = pickle.load(f)

print(type(liverData))

<class 'dict'>


In [3]:
# Inspect one (subject, image) pair to see the key type and the image shape.
first_item = next(iter(liverData.items()), None)
if first_item is not None:
    sub, img = first_item
    print(type(sub), sub, img.shape)

<class 'str'> 1099545 (288, 384)


In [48]:
# 1: Create vectorized data sets, with codes associated with stroke and sequelae of stroke (plus COVID)?
# 2: Instead of stroke etc., test with female-related illness
# 3: Try esophageal and gastric diseases (K) among others
# 4: Pregnancy-related codes (better than chance)
# 5: T (Burn codes) [not good, few have burn codes, even random have sensitivity of zero]
# 6: 1174-1207 well balanced, maybe pneumonia among them, not sure what overall theme is (marginally better than chance or the same)
# 7: 1496 1519 kidney stones and urinary disease (better sensitivity, marginally better sp and acc)
# 8: 1519 1538 self poisoning and overdose? (small absolute numbers) [impossible no sensitivity]
# 9: 1538 1558 hypertension and heart disease (fifty percent ones) (early stop around 600 epochs is best, then goes almost to random)

import numpy as np

x = []
y = []
# Code lists from earlier experiments, kept for reference:
# strokeCodes = ['B97','F05','G31','G81','I63','I67','I69','J12','J69','R26','R29','R41',
#                'R45','R47','R53','R54','U07']
# strokeCodes = ['C50','C54','D25','D27','N70','N72','N73','N80','N81','N83','N84','N85',
#                'N88','N89','N90','N92','N93','N94','N95','Z40','Z42']
# NOTE(review): the name `strokeCodes` is historical; the active list below
# is simply the set of target codes for the current experiment.
strokeCodes = '''E11
E78
I08
I10
I20
I21
I24
I25
I35
I44
I48
I50
I51
N18
R00
R07
Z82
Z86
Z92
Z95'''.split()

# Position of every code string within `codes`, built once so each target
# code is a single dictionary lookup instead of a list scan.
codeToIdx = {code: i for i, code in enumerate(codes)}

# Integer index for each target code, in the same order as `strokeCodes`.
# A target code that is absent from `codes` maps to None.
strokeCodesInt = [codeToIdx.get(code) for code in strokeCodes]

# Report unknown codes explicitly rather than leaving silent None entries.
missing = [code for code, idx in zip(strokeCodes, strokeCodesInt) if idx is None]
if missing:
    print(f'Warning: target codes not found in codes: {missing}')

print(strokeCodesInt)

# Build the image list `x` and the binary label list `y`, one entry per
# subject in `subs`.
for i, sub in enumerate(subs):
    img = liverData[sub]
    # Label is 1 when any code in any of the subject's visits is one of the
    # target code indices, otherwise 0.
    has_target = any(
        code in strokeCodesInt
        for visit in eidCodesMap[sub]
        for code in visit
    )
    diag = 1 if has_target else 0
    x.append(img)
    y.append(diag)
    # Progress report every 1000 subjects
    if i % 1000 == 0:
        print(f'Done {i}')

# Convert the Python lists into arrays (images stacked along a new first axis)
x = np.stack(x)
y = np.array(y)

print(y[:50])

[200, 224, 387, 389, 394, 395, 398, 399, 407, 414, 418, 420, 421, 705, 870, 877, 1273, 1277, 1283, 1286]
Done 0
Done 1000
Done 2000
Done 3000
Done 4000
Done 5000
Done 6000
Done 7000
Done 8000
[0 0 1 1 1 1 1 0 1 0 1 1 1 1 1 1 1 0 1 0 0 1 1 1 0 1 0 1 0 1 0 1 1 0 1 1 1
 1 0 0 1 0 0 0 0 1 0 0 0 0]


In [49]:
from sklearn.model_selection import train_test_split

# Stratified split: 5000 samples for training, the remainder for testing.
# random_state is fixed so that re-running the notebook reproduces the same
# partition (and therefore comparable metrics between runs).
xtr, xt, ytr, yt = train_test_split(x, y, stratify=y, train_size=5000, random_state=0)

print(xtr.shape)
print(ytr.shape)

(5000, 288, 384)
(5000,)


In [50]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import confusion_matrix

class LiverCNN(nn.Module):
    """Small 4-layer convolutional classifier producing 2 logits per image.

    Every convolution uses 10 output channels, a 7x7 kernel and stride 3.
    The final linear layer expects exactly 10 features per sample, so the
    input size must be such that the feature map is 1x1 after pooling; this
    holds for the 288x384 images used in this notebook.

    Args:
        device: where to place the parameters. When None (the default) CUDA
            is used if available, otherwise the CPU, so the model can still
            be constructed on a machine without a GPU.
    """

    def __init__(self, device=None):
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.conv1 = nn.Conv2d(1, 10, (7, 7), stride=3)
        self.conv2 = nn.Conv2d(10, 10, (7, 7), stride=3)
        self.conv3 = nn.Conv2d(10, 10, (7, 7), stride=3)
        self.conv4 = nn.Conv2d(10, 10, (7, 7), stride=3)
        self.mp = nn.MaxPool2d((1, 2))
        self.fc = nn.Linear(10, 2)
        # Cast and move the whole module once instead of layer by layer
        self.float().to(device)

    def forward(self, x):
        """Return class logits of shape (N, 2) for input of shape (N, 1, H, W)."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # No ReLU after the last convolution; it feeds the max-pool directly
        x = self.conv4(x)
        x = self.mp(x)
        # Flatten everything except the batch dimension
        x = x.reshape(x.shape[0], -1)
        return self.fc(x)

cnn = LiverCNN()
optim = torch.optim.Adam(cnn.parameters(), lr=1e-4, weight_decay=1e-4)

nepochs = 3000   # number of optimisation steps; each step uses one mini-batch
bsize = 1000     # mini-batch size
pperiod = 50     # evaluate on the test set every `pperiod` steps

# Feed data to whatever device the model's parameters live on
device = next(cnn.parameters()).device


def _acc_sp_sn(y_true, y_pred):
    """Return (accuracy, specificity, sensitivity) for binary 0/1 labels."""
    acc = np.mean(y_true == y_pred)
    # labels=[0, 1] pins the matrix to 2x2 so the 4-way unpacking is always valid
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    sp = tn / (tn + fp)
    sn = tp / (tp + fn)
    return acc, sp, sn


for e in range(nepochs):
    optim.zero_grad()
    # Walk through the training set in consecutive slices, wrapping around.
    # Clamping with min() gives the same slices as before for bsize < len(xtr)
    # and avoids an empty batch when bsize is a multiple of len(xtr).
    starti = (e*bsize) % len(xtr)
    endi = min(starti + bsize, len(xtr))
    xx = xtr[starti:endi]
    yy = ytr[starti:endi]
    # Add the channel dimension expected by the first convolution
    xx = torch.from_numpy(xx).float().to(device).unsqueeze(1)
    yy = torch.from_numpy(yy).long().to(device)
    logits = cnn(xx)
    loss = F.cross_entropy(logits, yy)
    loss.backward()
    optim.step()
    if e % pperiod == 0 or e == nepochs-1:
        print(f'{e} {float(loss)}')

        # Try it on the test set
        with torch.no_grad():
            xx = torch.from_numpy(xt).float().to(device).unsqueeze(1)
            logits = cnn(xx)
            yhat = torch.argmax(logits, dim=1)
            yhat = yhat.cpu().numpy()
            # Majority-class accuracy, printed as the "null" baseline
            p = np.mean(yt)
            null = max(p, 1 - p)
            acc, sp, sn = _acc_sp_sn(yt, yhat)
            print(null, acc, sp, sn)
            # Try random guesses with correct proportion of ones
            yhat = np.random.binomial(1, p, size=len(yt))
            acc, sp, sn = _acc_sp_sn(yt, yhat)
            print(null, acc, sp, sn)
print('Done')
    

0 0.769627034664154
0.5028058361391694 0.5333894500561167 0.8560267857142857 0.20711060948081264
0.5028058361391694 0.49915824915824913 0.5100446428571429 0.4881489841986456
50 0.6771009564399719
0.5028058361391694 0.5507856341189674 0.6138392857142857 0.48702031602708806
0.5028058361391694 0.5115039281705949 0.5184151785714286 0.5045146726862303
100 0.6635559797286987
0.5028058361391694 0.5561167227833894 0.62109375 0.4904063205417607
0.5028058361391694 0.4983164983164983 0.48995535714285715 0.5067720090293454
150 0.6470837593078613
0.5028058361391694 0.5597643097643098 0.5993303571428571 0.5197516930022573
0.5028058361391694 0.494949494949495 0.4921875 0.49774266365688485
200 0.6229677796363831
0.5028058361391694 0.5516273849607183 0.5758928571428571 0.5270880361173815
0.5028058361391694 0.4994388327721661 0.5005580357142857 0.4983069977426637
250 0.5897917151451111
0.5028058361391694 0.5552749719416387 0.5686383928571429 0.5417607223476298
0.5028058361391694 0.5014029180695847 0.513