In [1]:
# --- Experiment configuration ---------------------------------------------
dr_method = 'KPCA'   # dimensionality-reduction tag used in data/model file names
ncomps = 1500        # total number of encoded components (width of the latent matrix)
nevecs = 60          # number of encoded components fed to each model
nbasecomps = 20      # components at the start of each window that get their own prior slice
ncauses = 3          # number of non-latent variables placed before the latents (sex, age, bmi)
crop_size = 180      # reconstructed images are reshaped to crop_size x crop_size

In [2]:
import os, sys
from pathlib import Path

# Project root: three directory levels above the current working directory.
macaw_path = Path(os.getcwd()).parent.parent.parent
# Make the project root importable for the project-local imports used later.
sys.path.append(str(macaw_path) +'/')

# Pickled training bundle and the output directory for the models trained below.
train_path = macaw_path/'data'/'ukbb'/'axial'/f'train_hc_data_{dr_method}_{ncomps}.pkl'
model_base_path = macaw_path/'models'/f'{dr_method}_{ncomps}'/f'{nevecs}'

# Create the output directory together with any missing parents.
# exist_ok=True replaces the separate "exists?" check, so there is no window
# between checking and creating in which another process could create it.
model_base_path.mkdir(parents=True, exist_ok=True)

In [3]:
import pickle

# Read the pickled training bundle.
# NOTE(review): pickle.load can execute arbitrary code; only open files you trust.
with open(train_path, 'rb') as f:
    train = pickle.load(f)

# Per-subject covariates.
sex, age, bmi = train['sex'], train['age'], train['bmi']
# imgs = train['imgs']   # raw images are deliberately not loaded here
# Offset added back to `age` when reconstructions are labelled.
min_age = train['min_age']

# Encoded data matrix and the fitted transformer used to invert it.
encoded_data, kpca = train['encoded_data'], train['kpca']

In [4]:
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt
import numpy as np

def recons(age,bmi,latents, latent_offset=0):
    """Map covariates and latent codes back to displayable values.

    Parameters
    ----------
    age : array-like
        Age values with ``min_age`` subtracted; ``min_age`` is added back.
        The input is NOT modified.
    bmi : array-like
        Returned unchanged.
    latents : 2-D array, shape (n_samples, n_latents)
        Encoded components to decode.
    latent_offset : int, default 0
        Column at which ``latents`` is placed inside the full ``ncomps``-wide
        code; all other columns are left at zero.

    Returns
    -------
    tuple
        ``(age + min_age, bmi, imgs)`` where ``imgs`` is the result of
        ``kpca.inverse_transform`` on the zero-padded code.
    """
    # Out-of-place add: the previous `age += min_age` modified the caller's
    # array whenever a view (or the full array) was passed in.
    age = age + min_age

    # Embed the supplied latents in a zero code of full width.
    latent_enc = np.zeros((latents.shape[0],ncomps))
    latent_enc[:,latent_offset:latent_offset+latents.shape[1]] = latents
    imgs = kpca.inverse_transform(latent_enc)
    return age,bmi,imgs

In [5]:
import seaborn as sns

sb = 4
fig, axs = plt.subplots(1,sb, figsize=(sb*5,5))

# One histogram per panel: (values, title).
# The last panel shows the distribution of the FIRST encoded component across
# all subjects, i.e. column 0. The previous `encoded_data[0]` selected row 0
# (every component of a single subject), which does not match the "PCA0" title.
panels = [
    (sex, "Sex"),
    (age, "Age"),
    (bmi, "BMI"),
    (encoded_data[:,0], "PCA0"),
]
for ax, (values, title) in zip(axs, panels):
    sns.histplot(values,ax=ax,fill=True)
    ax.set(title=title)

In [6]:
import utils.visualize as vis

plt.rcParams['figure.figsize'] = (20,4)
nsamples = 5

# Pick random subjects (independent draws, so repeats are possible) and decode them.
idx = np.random.randint(0,age.shape[0],nsamples)
re = recons(age[idx],bmi[idx],encoded_data[idx])
rec_age, rec_bmi, rec_imgs = re

# Human-readable titles for each reconstructed image.
sex_t = ['Male' if round(s) else 'Female' for s in sex[idx]]
titles_sam = [
    f'Sex:{s}, Age:{a}, BMI:{np.round(b)}'
    for s,a,b in zip(sex_t, rec_age, rec_bmi)
]

# Each flat reconstruction is reshaped to a square image and transposed for display.
tiles = [i.reshape(crop_size,crop_size).T for i in rec_imgs]
fig  = vis.img_grid(tiles,cols=nsamples,titles=titles_sam)

## Causal Graph

In [7]:
# Node indices: 0 = sex, 1 = age, 2 = bmi; latent nodes follow from `ncauses` onwards.
latent_nodes = range(ncauses, nevecs+ncauses)

# Each cause points at every latent; sex and age additionally point at bmi.
sex_to_latents = [(0,n) for n in latent_nodes]
sex_to_bmi = [(0,2)]

age_to_latents = [(1,n) for n in latent_nodes]
age_to_bmi = [(1,2)]

bmi_to_latents = [(2,n) for n in latent_nodes]

# Every latent points at all latents with a higher index.
autoregressive_latents = [
    (src,dst)
    for src in latent_nodes
    for dst in range(src+1, latent_nodes.stop)
]
# Alternatives that were tried and left disabled:
# autoregressive_latents = [(i,j) for i in range(ncauses,2*ncauses) for j in range(i+1,nevecs+ncauses)]
# autoregressive_latents = []

edges = [
    *sex_to_latents, *sex_to_bmi,
    *age_to_latents, *age_to_bmi,
    *bmi_to_latents,
    *autoregressive_latents,
]

## Priors

In [8]:
# Mean of `sex` (presumably 0/1 coded, so this is the fraction of ones -- TODO confirm).
P_sex = sex.sum() / len(sex)
print(P_sex)

# Relative frequency of each distinct age value, ordered by ascending value.
unique_values, counts = np.unique(age, return_counts=True)
P_age = counts / counts.sum()
P_age

In [9]:
# Bar chart of the empirical age prior. The x positions are derived from P_age
# itself, so the plot keeps working if the number of distinct ages is not 36.
plt.bar(np.arange(len(P_age)),P_age)

In [10]:
import torch
import yaml
from utils.helpers import dict2namespace

# Read the experiment configuration.
# NOTE(review): yaml.FullLoader is only appropriate for trusted config files.
config_file = macaw_path/'config'/'ukbb.yaml'
with open(config_file, 'r') as f:
    config_raw = yaml.load(f, Loader=yaml.FullLoader)

# Attribute-style access to the config, plus the compute device (GPU when available).
config = dict2namespace(config_raw)
config.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [11]:
import torch.distributions as td
import torch

def _std_normal(dim):
    """Return a zero-mean, unit-scale Normal over `dim` entries on config.device."""
    return td.Normal(torch.zeros(dim).to(config.device), torch.ones(dim).to(config.device))

# (column slice of the model input, prior distribution for those columns)
priors = [
    (slice(0,1), td.Bernoulli(torch.tensor([P_sex]).to(config.device))),    # sex
    (slice(1,2), td.Categorical(torch.tensor([P_age]).to(config.device))),  # age
    (slice(2,3), _std_normal(1)),                                            # BMI
    (slice(3,nbasecomps+3), _std_normal(nbasecomps)),                        # base_comps
    (slice(nbasecomps+3,nevecs+3), _std_normal(nevecs-nbasecomps)),          # new_comps
]

In [12]:
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

# TensorBoard writer; with no arguments it writes to the ./runs/ directory by default.
# NOTE(review): `writer` is not referenced by any other cell visible here -- confirm it is still needed.
writer = SummaryWriter()

In [13]:
from macaw import MACAW

# Train one model per window of `nevecs` encoded components.
# Windows start every (nevecs - nbasecomps) columns, so consecutive windows
# share `nbasecomps` components. Windows whose model file already exists are skipped.
loss_vals = []
for e in range(0, ncomps-nbasecomps, nevecs-nbasecomps):
    save_path = model_base_path/f'{e}.pt'
    ed =  encoded_data[:,e:e+nevecs]
    print(e,e+nevecs)

    if os.path.exists(save_path):
        print("Skipping")
        continue

    # Model input columns: sex, age, bmi, then the window of encoded components.
    X = np.hstack([sex[:,np.newaxis], age[:,np.newaxis], bmi[:,np.newaxis], ed])

    macaw = MACAW.MACAW(config)
    loss_vals.append(macaw.fit_with_priors(X,edges, priors))

    torch.save(macaw,save_path)

In [14]:
# Plot the two loss series returned by the first model trained in this session
# (presumably training and validation loss -- TODO confirm against fit_with_priors).
# `loss_vals` is empty when every window was skipped because its model file
# already existed; guard against that instead of raising IndexError.
if loss_vals:
    first_run = loss_vals[0]
    plt.plot(np.array(first_run[0]))
    plt.plot(np.array(first_run[1]))
else:
    print("No loss history to plot: all windows were skipped because saved models already exist.")

In [15]:
model_path = model_base_path/'hyperparameters.pkl'

# Settings needed to interpret the saved models, stored next to them.
hyperparameters = {
    'ncomps':ncomps,
    'nevecs':nevecs,
    'nbasecomps':nbasecomps,
    'ncauses':ncauses,
    'crop_size':crop_size,
    'age_bins':len(P_age),
}
with open(model_path, 'wb') as f:
    pickle.dump(hyperparameters, f)

## Linear regression

In [16]:
# Design matrix for the age regression: sex, bmi and the first `nevecs` encoded components.
ed = encoded_data[:, :nevecs]
X = np.column_stack([sex, bmi, ed])

In [17]:
from sklearn.linear_model import LinearRegression
# Ordinary least-squares baseline: predict age from the design matrix X.
reg = LinearRegression().fit(X, age)

In [18]:
# Coefficient of determination (R^2), evaluated on the same data the model was fitted on.
reg.score(X, age)

In [19]:
lr_path = model_base_path/'lr.pkl'

# Store the fitted regression next to the other models.
payload = {'reg':reg}
with lr_path.open('wb') as f:
    pickle.dump(payload, f)

In [20]:
# Display where the regression pickle was written.
lr_path

In [21]:
# Fitted coefficients, in the column order of X (sex, bmi, then the encoded components).
reg.coef_

In [22]:
# Mean absolute error of the regression, on the same data it was fitted on.
np.mean(np.abs(reg.predict(X) - age))

In [23]:
# Load the saved model for the window starting at component `idx` and rebuild
# its input matrix with columns: sex, age, bmi, then that window of components.
# NOTE(review): torch.load unpickles a whole model object; only load trusted files.
idx = 0
macaw = torch.load(model_base_path/f'{idx}.pt')
X_test = np.column_stack([sex, age, bmi, encoded_data[:, idx:nevecs+idx]])

In [24]:
# Sanity check: dimensions of the encoded data matrix.
encoded_data.shape

In [25]:
# For every candidate age label, set the age column of ALL samples to that label
# and record the model's log-likelihood of the resulting inputs.
probs=[]
# Work on a scratch copy so X_test keeps the true ages; the previous in-place
# overwrite left X_test changed for every later cell.
X_query = X_test.copy()
# One candidate per age category of the prior, instead of a hard-coded 36.
for i in range(len(P_age)):
    X_query[:,1] = i
    probs.append(macaw.log_likelihood(X_query))

In [26]:
# Normalise the per-label log-likelihoods into probabilities over labels (axis 0).
# Subtracting each sample's maximum before exponentiating leaves `p`
# mathematically unchanged but stops very negative log-likelihoods from
# underflowing to 0, which would make the division 0/0 = NaN.
log_probs = np.array(probs)
pexp = np.exp(log_probs - np.max(log_probs, axis=0))  # likelihoods relative to each sample's best label
pexp_sum = np.sum(pexp,axis=0)
p = pexp/pexp_sum

In [27]:
# Per sample, the index of the candidate age label with the highest log-likelihood.
# NOTE(review): later cells compare these indices directly with `age`; this
# assumes `age` is stored as a 0-based label -- confirm.
pred_labels = np.argmax(probs, axis=0)

In [28]:
# Number of histogram bins = number of age categories in the prior. This is the
# same value stored under 'age_bins' in the hyperparameter pickle, replacing a
# hard-coded 36.
age_bins = len(P_age)

In [29]:
# Overlay the histogram of true ages (drawn first) with that of the predicted labels (drawn second).
sns.histplot(age, bins=age_bins)
sns.histplot(pred_labels,bins=age_bins)

In [30]:
# Log-likelihood of each candidate age label for one example subject (index 80),
# shifted so that the smallest value sits at zero.
l = np.array(probs)[:,80]
# x positions follow the number of candidates actually scored, not a hard-coded 36.
plt.bar(np.arange(len(l)),l-np.min(l))

In [31]:
# Distribution of the signed error (true age minus predicted label).
sns.histplot(age-pred_labels, bins=age_bins)

In [32]:
# Mean absolute difference between the predicted labels and `age`.
np.mean(np.abs(pred_labels - age))