In [1]:
# Add age and sex to fundus data

import pandas as pd

demofile = '/run/media/anton/Elements/UKB/AgeSexDemographics_participant.csv'

df = pd.read_csv(demofile)
df

Unnamed: 0,eid,p21022,p31,p22189,p50_i0,p21002_i0
0,5869068,62,Female,-3.11,157.0,70.8
1,1285065,57,Male,5.33,191.0,107.1
2,3875604,56,Male,-4.27,188.0,80.1
3,2324036,49,Female,0.74,161.0,88.4
4,1063851,59,Female,-4.75,155.0,59.7
...,...,...,...,...,...,...
502138,4305653,62,Female,-4.26,158.0,71.3
502139,2536038,41,Female,5.27,163.0,80.1
502140,3768226,48,Female,0.41,158.0,49.7
502141,2018360,70,Male,-1.48,175.0,89.8


In [2]:
# Get codes for eye fundus subjects

import json

work_loc = '/run/media/anton/Elements/UKB/21015'

# Get number of codes

with open(f'{work_loc}/codes.json', 'r') as f:
    codes = json.load(f)
    D = len(codes)
    print(D)

with open(f'{work_loc}/eidCodesMap.json', 'r') as f:
    eidCodesMap = json.load(f)

print(len(eidCodesMap))

for k in eidCodesMap:
    print(eidCodesMap[k])
    break

subs = sorted(list(eidCodesMap.keys()))
print(subs[:5])

1470
10566
[[803], [803, 813, 818], [818, 1457], [635, 818, 1400, 1457], [635], [1012, 1012, 1433], [635, 817, 818]]
['1000080', '1000140', '1000309', '1000457', '1000522']


In [3]:
# Get eye fundus data

import pickle

fundusData = dict()

for batch in range(10,19):
    data = pickle.load(open(f'{work_loc}/{batch}.pkl', 'rb'))
    fundusData = fundusData | data
    print(f'Done {batch}')

print(type(fundusData))
print(len(fundusData))

Done 10
Done 11
Done 12
Done 13
Done 14
Done 15
Done 16
Done 17
Done 18
<class 'dict'>
12001


In [24]:
# Create training data with clustered codes

import numpy as np

x = []
y = []
ages = []
sexes = []
clusteredCodes = '''B97
F05
G31
G81
I63
I67
I69
J12
J69
R26
R29
R41
R45
R47
R53
R54
U07
'''.split()
clusteredCodesInt = [None]*len(clusteredCodes)

for i,code in enumerate(codes):
    if code not in clusteredCodes:
        continue
    idx = clusteredCodes.index(code)
    clusteredCodesInt[idx] = i

print(clusteredCodesInt)

for i,sub in enumerate(subs):
    img = fundusData[sub]
    diag = 0
    for visit in eidCodesMap[sub]:
        for code in visit:
            if code in clusteredCodesInt:
                diag = 1
                break
        if diag == 1:
            break
    x.append(img)
    y.append(diag)
    sub = int(sub)
    row = df[df['eid'] == sub]
    age = row['p21022'].iloc[0]
    sex = int(row['p31'].iloc[0] == 'Male')
    ages.append(age)
    sexes.append(sex)
    if i % 1000 == 0:
        print(f'Done {i}')

x = np.stack(x)
y = np.array(y)
ages = np.array(ages)/50
sexes = np.array(sexes)
agesex = np.stack([ages, sexes], axis=1)

print(y[:50])   
print(x.shape)
print(agesex[:5])
print(agesex.shape)

[71, 283, 354, 385, 503, 507, 509, 542, 576, 1002, 1004, 1014, 1018, 1020, 1026, 1027, 1207]
Done 0
Done 1000
Done 2000
Done 3000
Done 4000
Done 5000
Done 6000
Done 7000
Done 8000
Done 9000
Done 10000
[0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 1 0 1 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]
(10566, 256, 342, 3)
[[0.92 1.  ]
 [1.32 0.  ]
 [1.12 0.  ]
 [1.2  0.  ]
 [1.   1.  ]]
(10566, 2)


In [25]:
# Swizzle channels (color) to be first (second) dimension

x = np.transpose(x, (0, 3, 1, 2))
print(x.shape)

(10566, 3, 256, 342)


In [26]:
from sklearn.model_selection import train_test_split
import math

xz = x[np.where(y == 0)]
xo = x[np.where(y == 1)]
asz = agesex[np.where(y == 0)]
aso = agesex[np.where(y == 1)]

n = min(len(xz), len(xo))

xzi = np.arange(len(xz))
np.random.shuffle(xzi)

xoi = np.arange(len(xo))
np.random.shuffle(xoi)

xx = np.concatenate([xz[xzi[:n]], xo[xoi[:n]]])
aa = np.concatenate([asz[xzi[:n]], aso[xoi[:n]]])
yy = np.concatenate([np.zeros(n), np.ones(n)])

xtr, xt, atr, at, ytr, yt = train_test_split(xx, aa, yy, stratify=yy, train_size=math.floor(2*n/3))

print(xtr.shape)
print(atr.shape)
print(ytr.shape)

(987, 3, 256, 342)
(987, 2)
(987,)


In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix

class FundusCNN(nn.Module):
    def __init__(self):
        super(FundusCNN, self).__init__()
        self.conv1 = nn.Conv2d(3,10,(7,7),stride=3).float().cuda()
        self.conv2 = nn.Conv2d(10,10,(7,7),stride=3).float().cuda()
        self.conv3 = nn.Conv2d(10,10,(7,7),stride=3).float().cuda()
        self.conv4 = nn.Conv2d(10,10,(7,7),stride=3).float().cuda()
        self.mp = nn.MaxPool2d((1,2)).float().cuda()
        self.fc = nn.Linear(10+2,2).float().cuda()

    def forward(self, x, a):
        N = x.shape[0]
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.conv4(x)
        x = self.mp(x)
        x = x.reshape(N,-1)
        x = torch.cat([x, a], dim=1)
        x = self.fc(x)
        return x

cnn = FundusCNN()
optim = torch.optim.Adam(cnn.parameters(), lr=1e-4, weight_decay=1e-4)

nepochs = 3000
bsize = 1000
pperiod = 50
limit = (math.floor(len(xtr)/bsize)+1)*bsize

for e in range(nepochs):
    optim.zero_grad()
    starti = (e*bsize)%limit
    endi = starti+bsize
    xx = xtr[starti:endi]
    aa = atr[starti:endi]
    yy = ytr[starti:endi]
    xx = torch.from_numpy(xx).float().cuda()
    aa = torch.from_numpy(aa).float().cuda()
    yy = torch.from_numpy(yy).long().cuda()
    logits = cnn(xx, aa)
    loss = F.cross_entropy(logits, yy)
    loss.backward()
    optim.step()
    if e % pperiod == 0 or e == nepochs-1:
        print(f'{e} {float(loss)}')
        
        # Try it on the test set
        with torch.no_grad():
            xx = torch.from_numpy(xt).float().cuda()
            aa = torch.from_numpy(at).float().cuda()
            logits = cnn(xx, aa)
            yhat = torch.argmax(logits, dim=1)
            yhat = yhat.cpu().numpy()
            p = np.mean(yt)
            null = p
            if null < 0.5:
                null = 1-null
            acc = np.mean((yt == yhat)+0)
            tn, fp, fn, tp = confusion_matrix(yt, yhat).ravel()
            sp = tn / (tn+fp)
            sn = tp / (tp+fn)
            fs = 2*tp / (2*tp + fp + fn)
            print(null, acc, sp, sn, fs)
            # Binary AU ROC
            auroc = roc_auc_score(yt, logits[:,1].cpu().numpy())
            print(auroc)
            
            # Try random guesses with correct proportion of ones
            yhat = np.random.binomial(1, p, size=len(yt))
            tn, fp, fn, tp = confusion_matrix(yt, yhat).ravel()
            sp = tn / (tn+fp)
            sn = tp / (tp+fn)
            fs = 2*tp / (2*tp + fp + fn)
            acc = np.mean((yt == yhat)+0)
            print(null, acc, sp, sn, fs)
            print('---')
print('Done')
    

0 0.7422732710838318
0.500253164556962 0.500253164556962 1.0 0.0 0.0
0.4821310641579398
0.500253164556962 0.5007594936708861 0.5030364372469636 0.49848024316109424 0.49949238578680205
---
50 0.7090728878974915
0.500253164556962 0.47037974683544304 0.6376518218623481 0.3029381965552178 0.3637469586374696
0.5254092678504774
0.500253164556962 0.5174683544303798 0.5040485829959515 0.5309017223910841 0.5237381309345327
---
100 0.6920337080955505
0.500253164556962 0.5159493670886076 0.597165991902834 0.43465045592705165 0.4729878721058434
0.5159748799166493
0.500253164556962 0.499746835443038 0.5111336032388664 0.48834853090172237 0.49385245901639346
---
150 0.6849734783172607
0.500253164556962 0.530126582278481 0.5748987854251012 0.48530901722391084 0.5079533404029692
0.5351728338850399
0.500253164556962 0.5032911392405063 0.5101214574898786 0.49645390070921985 0.49974502804691484
---
200 0.6779310703277588
0.500253164556962 0.5432911392405063 0.5850202429149798 0.5015197568389058 0.5232558