In [8]:
import pandas as pd
from pymethylprocess.MethylationDataTypes import MethylationArray
from sklearn.metrics import mean_absolute_error, r2_score, accuracy_score
import warnings
from sklearn.decomposition import PCA
import plotly.express as px
import plotly.offline as py
warnings.filterwarnings("ignore")

In [9]:
from pybedtools import BedTool
import numpy as np

In [10]:
from functools import reduce

In [11]:
BedTool('hg19.genome').makewindows(g='hg19.genome',w=1000000).saveas('hg19.1m.bed')#.to_dataframe().shape

<BedTool(hg19.1m.bed)>

In [12]:
ma=MethylationArray.from_pickle('train_val_test_sets/train_methyl_array.pkl')
ma_v=MethylationArray.from_pickle('train_val_test_sets/val_methyl_array.pkl')

In [13]:
include_last=False
def get_final_modules(ma=ma,a='450kannotations.bed',b='lola_vignette_data/activeDHS_universe.bed', include_last=False):
    allcpgs=ma.beta.columns.values
    df=BedTool(a).to_dataframe()
    df.iloc[:,0]=df.iloc[:,0].astype(str).map(lambda x: 'chr'+x.split('.')[0])
    df=df.set_index('name').loc[list(ma.beta)].reset_index().iloc[:,[1,2,3,0]]
    df_bed=pd.read_table(b,header=None)
    df_bed['features']=np.arange(df_bed.shape[0])
    df_bed=df_bed.iloc[:,[0,1,2,-1]]
    b=BedTool.from_dataframe(df)
    a=BedTool.from_dataframe(df_bed)#('lola_vignette_data/activeDHS_universe.bed')
    c=a.intersect(b,wa=True,wb=True).sort()
    d=c.groupby(g=[1,2,3,4],c=(8,8),o=('count','distinct'))
    df2=d.to_dataframe()
    df3=df2.loc[df2.iloc[:,-2]>25]
    modules = [cpgs.split(',') for cpgs in df3.iloc[:,-1].values]
    modulecpgs=np.array(list(set(list(reduce(lambda x,y:x+y,modules)))))
    missing_cpgs=np.setdiff1d(allcpgs,modulecpgs).tolist()
    final_modules = modules+([missing_cpgs] if include_last else [])
    module_names=(df3.iloc[:,0]+'_'+df3.iloc[:,1].astype(str)+'_'+df3.iloc[:,2].astype(str)).tolist()
    return final_modules,modulecpgs,module_names

final_modules,modulecpgs,module_names=get_final_modules(b='hg19.1m.bed',include_last=include_last)
ma.beta=ma.beta.loc[:,modulecpgs]
ma_v.beta=ma_v.beta.loc[:,modulecpgs]


In [14]:
pd.DataFrame(list(map(len,final_modules))).sum()
#len(final_modules)

0    7747
dtype: int64

In [15]:
len(final_modules)

172

In [16]:
import torch
import torch.nn.functional as F

def softmax(input_tensor, dim=1):
    # transpose input
    transposed_input = input_tensor.transpose(dim, len(input_tensor.size()) - 1)
    # calculate softmax
    softmaxed_output = F.softmax(transposed_input.contiguous().view(-1, transposed_input.size(-1)), dim=-1)
    # un-transpose result
    return softmaxed_output.view(*transposed_input.size()).transpose(dim, len(input_tensor.size()) - 1)

In [17]:
from sklearn.preprocessing import LabelBinarizer
from torch.utils.data import Dataset, DataLoader
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
# https://github.com/higgsfield/Capsule-Network-Tutorial/blob/master/Capsule%20Network.ipynb

class MLP(nn.Module): # add latent space extraction, and spits out csv line of SQL as text for UMAP
    def __init__(self, n_input, hidden_topology, dropout_p, n_outputs=1, binary=False, softmax=False):
        super(MLP,self).__init__()
        self.hidden_topology=hidden_topology
        self.topology = [n_input]+hidden_topology+[n_outputs]
        layers = [nn.Linear(self.topology[i],self.topology[i+1]) for i in range(len(self.topology)-2)]
        for layer in layers:
            torch.nn.init.xavier_uniform_(layer.weight)
        self.layers = [nn.Sequential(layer,nn.ReLU(),nn.Dropout(p=dropout_p)) for layer in layers]
        self.output_layer = nn.Linear(self.topology[-2],self.topology[-1])
        torch.nn.init.xavier_uniform_(self.output_layer.weight)
        if binary:
            output_transform = nn.Sigmoid()
        elif softmax:
            output_transform = nn.Softmax()
        else:
            output_transform = nn.Dropout(p=0.)
        self.layers.append(nn.Sequential(self.output_layer,output_transform))
        self.mlp = nn.Sequential(*self.layers)
        
    def forward(self, x):
        #print(x.shape)
        return self.mlp(x)
        
class MethylationDataset(Dataset):
    def __init__(self, methyl_arr, outcome_col,binarizer=None, modules=[]):
        if binarizer==None:
            binarizer=LabelBinarizer()
            binarizer.fit(methyl_arr.pheno[outcome_col].astype(str).values)
        self.y=binarizer.transform(methyl_arr.pheno[outcome_col].astype(str).values)
        self.y_unique=np.unique(np.argmax(self.y,1))
        self.binarizer=binarizer
        if not modules:
            modules=[list(methyl_arr.beta)]
        self.modules=modules
        self.X=methyl_arr.beta
        self.length=methyl_arr.beta.shape[0]
        
    def __len__(self):
        return self.length

    def __getitem__(self,i):
        return tuple([torch.FloatTensor(self.X.iloc[i].values)]+[torch.FloatTensor(self.X.iloc[i].loc[module].values) for module in self.modules]+[torch.FloatTensor(self.y[i])])
    
class PrimaryCaps(nn.Module):
    def __init__(self,modules,hidden_topology,n_output):
        super(PrimaryCaps, self).__init__()
        self.capsules=nn.ModuleList([MLP(len(module),hidden_topology,0.,n_outputs=n_output) for module in modules])
        
    def forward(self, x):
        #print(self.capsules)
        u = [self.capsules[i](x[i]) for i in range(len(self.capsules))]
        u = torch.stack(u, dim=1)
        #print(u.size())
        return self.squash(u)
    
    def squash(self, x):
        squared_norm = (x ** 2).sum(-1, keepdim=True)
        #print('prim_norm',squared_norm.size())
        output_tensor = squared_norm *  x / ((1. + squared_norm) * torch.sqrt(squared_norm))
        #print('z_init',output_tensor.size())
        return output_tensor
    
    def get_weights(self):
        return list(self.capsules[0].parameters())[0].data#self.state_dict()#[self.capsules[i].state_dict() for i in range(len(self.capsules))]
        
class CapsLayer(nn.Module):
    def __init__(self, n_capsules, n_routes, n_input, n_output, routing_iterations=3):
        super(CapsLayer, self).__init__()
        self.n_capsules=n_capsules
        self.num_routes = n_routes
        self.W=nn.Parameter(torch.randn(1, n_routes, n_capsules, n_output, n_input))
        self.routing_iterations=routing_iterations
        
    def forward(self,x):
        batch_size = x.size(0)
        x = torch.stack([x] * self.n_capsules, dim=2).unsqueeze(4)
        
        W = torch.cat([self.W] * batch_size, dim=0)
        #print('affine',W.size(),x.size())
        u_hat = torch.matmul(W, x)
        #print('affine_trans',u_hat.size())

        b_ij = Variable(torch.zeros(1, self.num_routes, self.n_capsules, 1))
        
        if torch.cuda.is_available():
            b_ij=b_ij.cuda()
            

        for iteration in range(self.routing_iterations):
            c_ij = softmax(b_ij)
            #print(c_ij)
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)
            #print('coeff',c_ij.size())#[0,:,0,:])#.size())

            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            v_j = self.squash(s_j)
            #print('z',v_j.size())
            
            if iteration < self.routing_iterations - 1:
                a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))
                b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)

        return v_j.squeeze(1)
        
        
    def squash(self, x):
        #print(x.size())
        squared_norm = (x ** 2).sum(-1, keepdim=True)
        #print('norm',squared_norm.size())
        output_tensor = squared_norm *  x / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor
        
class Decoder(nn.Module):
    def __init__(self, n_input, n_output, hidden_topology):
        super(Decoder, self).__init__()
        self.decoder=MLP(n_input,hidden_topology, 0., n_outputs=n_output, binary=True)
        
    def forward(self, x):
        return self.decoder(x)
    
class CapsNet(nn.Module):
    def __init__(self, primary_caps, caps_hidden_layers, caps_output_layer, decoder, lr_balance=0.5, gamma=0.005):
        super(CapsNet, self).__init__()
        self.primary_caps=primary_caps
        self.caps_hidden_layers=caps_hidden_layers
        self.caps_output_layer=caps_output_layer
        self.decoder=decoder
        self.recon_loss_fn = nn.BCELoss()
        self.lr_balance=lr_balance
        self.gamma=gamma
        
    def forward(self, x_orig, modules_input):
        x=self.primary_caps(modules_input)
        primary_caps_out=x#.view(x.size(0),x.size(1)*x.size(2))
        #print(x.size())
        for layer in self.caps_hidden_layers:
            x=layer(x)
        
        y_pred=self.caps_output_layer(x)#.squeeze(-1)
        #print(y_pred.shape)
        
        classes = torch.sqrt((y_pred ** 2).sum(2))
        classes = F.softmax(classes)
        
        max_length_indices = classes.argmax(dim=1)
        masked = torch.sparse.torch.eye(8)
        masked = masked.index_select(dim=0, index=max_length_indices.squeeze(1).data)
        
        embedding = (y_pred * masked[:, :, None, None]).view(y_pred.size(0), -1)
        
        #print(y_pred.size())
        x_hat=self.decoder(embedding)#.reshape(y_pred.size(0),-1))
        return x_orig, x_hat, y_pred, embedding, primary_caps_out
    
    def recon_loss(self, x_orig, x_hat):
        return self.recon_loss_fn(x_hat, x_orig)
    
    def margin_loss(self,x, labels):
        batch_size = x.size(0)

        v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))
        
        #print(v_c)

        left = (F.relu(0.9 - v_c)**2).view(batch_size, -1)
        right = (F.relu(v_c - 0.1)**2).view(batch_size, -1)
        #print(left)
        #print(right)
        #print(labels)

        loss = labels * left + self.lr_balance * (1.0 - labels) * right
        #print(loss.shape)
        loss = loss.sum(dim=1).mean()
        return loss
    
    def calculate_loss(self, x_orig, x_hat, y_pred, y_true):
        margin_loss = self.margin_loss(y_pred, y_true)
        recon_loss = self.gamma*self.recon_loss(x_orig,x_hat)
        loss = margin_loss + recon_loss
        return loss, margin_loss, recon_loss
        

In [18]:
ma.pheno['Age_binned'],bins=pd.cut(ma.pheno['Age'],bins=8,retbins=True)
ma_v.pheno['Age_binned'],bins=pd.cut(ma_v.pheno['Age'],bins=bins,retbins=True,)

dataset=MethylationDataset(ma,'Age_binned',modules=final_modules)
dataset_v=MethylationDataset(ma_v,'Age_binned',modules=final_modules)

In [19]:
dataloader=DataLoader(dataset,batch_size=16,shuffle=True,num_workers=6, drop_last=True)
dataloader_v=DataLoader(dataset_v,batch_size=16,shuffle=False,num_workers=6, drop_last=False)

In [20]:
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import Adam


In [21]:
len(final_modules)

172

In [None]:
primary_caps_out_len=40
caps_out_len=20
n_inputs=list(map(len,final_modules))
n_epochs=500
n_primary=len(final_modules)
hidden_topology=[30,100,100]
gamma=1e-2
decoder_top=[100,300]
lr=1e-3
routing_iterations=3

primary_caps = PrimaryCaps(modules=final_modules,hidden_topology=hidden_topology,n_output=primary_caps_out_len)
hidden_caps = []
n_out_caps=len(dataset.y_unique)
output_caps = CapsLayer(n_out_caps,n_primary,primary_caps_out_len,caps_out_len,routing_iterations=routing_iterations)
decoder = Decoder(n_out_caps*caps_out_len,len(list(ma.beta)),decoder_top)
capsnet = CapsNet(primary_caps, hidden_caps, output_caps, decoder, gamma=gamma)

for d in ['figures/embeddings'+x for x in ['','1','2']]:
    os.makedirs(d,exist_ok=True)

optimizer = Adam(capsnet.parameters(),lr)
scheduler=CosineAnnealingLR(optimizer, T_max=10, eta_min=0, last_epoch=-1)
for epoch in range(n_epochs):
    print(epoch)
    capsnet.train(True)
    running_loss=0.
    Y={'true':[],'pred':[]}
    for i,batch in enumerate(dataloader):
        x_orig=batch[0]
        #print(x_orig)
        y_true=batch[-1]
        module_x = batch[1:-1]
        x_orig, x_hat, y_pred, embedding, primary_caps_out=capsnet(x_orig,module_x)
        loss,margin_loss,recon_loss=capsnet.calculate_loss(x_orig, x_hat, y_pred, y_true)
        Y['true'].extend(y_true.argmax(1).detach().cpu().numpy().tolist())
        Y['pred'].extend(F.softmax(torch.sqrt((y_pred**2).sum(2))).argmax(1).detach().cpu().numpy().tolist())
        train_loss=margin_loss.item()#print(loss)
        running_loss+=train_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    #print(capsnet.primary_caps.get_weights())
    running_loss/=(i+1)
    print('Epoch {}: Train Loss {}, Train R2: {}, Train MAE: {}'.format(epoch,running_loss,r2_score(Y['true'],Y['pred']), mean_absolute_error(Y['true'],Y['pred'])))
    scheduler.step()
    capsnet.train(False)
    running_loss=np.zeros((3,))
    Y={'true':[],'pred':[],'embeddings':[],'embeddings2':[],'embeddings3':[]}
    with torch.no_grad():
        for i,batch in enumerate(dataloader_v):
            x_orig=batch[0]
            y_true=batch[-1]
            module_x = batch[1:-1]
            x_orig, x_hat, y_pred, embedding, primary_caps_out=capsnet(x_orig,module_x)
            #print(primary_caps_out.size())
            Y['embeddings3'].append(torch.cat([primary_caps_out[i] for i in range(x_orig.size(0))],dim=0).detach().cpu().numpy())
            primary_caps_out=primary_caps_out.view(primary_caps_out.size(0),primary_caps_out.size(1)*primary_caps_out.size(2))
            Y['embeddings'].append(embedding.detach().cpu().numpy())
            Y['embeddings2'].append(primary_caps_out.detach().cpu().numpy())
            loss,margin_loss,recon_loss=capsnet.calculate_loss(x_orig, x_hat, y_pred, y_true)
            val_loss=margin_loss.item()#print(loss)
            running_loss+=np.array([loss,margin_loss,recon_loss])
            Y['true'].extend(y_true.argmax(1).detach().cpu().numpy().tolist())
            Y['pred'].extend((y_pred**2).sum(2).argmax(1).detach().cpu().numpy().tolist())
        running_loss/=(i+1)
    Y['embeddings']=pd.DataFrame(PCA(n_components=2).fit_transform(np.vstack(Y['embeddings'])),columns=['x','y'])
    Y['embeddings2']=pd.DataFrame(PCA(n_components=2).fit_transform(np.vstack(Y['embeddings2'])),columns=['x','y'])
    #print(list(map(lambda x: x.shape,Y['embeddings3'])))
    Y['embeddings3']=pd.DataFrame(PCA(n_components=3).fit_transform(np.vstack(Y['embeddings3'])),columns=['x','y','z'])
    Y['embeddings']['color']=Y['true']
    Y['embeddings2']['color']=Y['true']
    Y['embeddings3']['color']=module_names*ma_v.beta.shape[0]#Y['true']
    Y['embeddings3']['name']=list(reduce(lambda x,y:x+y,[[i]*n_primary for i in Y['true']]))
    fig = px.scatter_3d(Y['embeddings3'], x="x", y="y", z="z", color="color", symbol='name', text='name')
    py.plot(fig, filename='figures/embeddings3/embeddings3.{}.pos.html'.format(epoch),auto_open=False)
    #Y['embeddings3']['color']=list(reduce(lambda x,y:x+y,[[i]*n_primary for i in Y['true']]))
    fig = px.scatter_3d(Y['embeddings3'], x="x", y="y", z='z', color="name", text='color')
    py.plot(fig, filename='figures/embeddings3/embeddings3.{}.true.html'.format(epoch),auto_open=False)
    fig = px.scatter(Y['embeddings'], x="x", y="y", color="color")
    py.plot(fig, filename='figures/embeddings/embeddings.{}.true.html'.format(epoch),auto_open=False)
    fig = px.scatter(Y['embeddings2'], x="x", y="y", color="color")
    py.plot(fig, filename='figures/embeddings2/embeddings2.{}.true.html'.format(epoch),auto_open=False)
    Y['embeddings'].loc[:,'color']=Y['pred']
    Y['embeddings2'].loc[:,'color']=Y['pred']
    fig = px.scatter(Y['embeddings'], x="x", y="y", color="color")
    py.plot(fig, filename='figures/embeddings/embeddings.{}.pred.html'.format(epoch),auto_open=False)
    fig = px.scatter(Y['embeddings2'], x="x", y="y", color="color")
    py.plot(fig, filename='figures/embeddings2/embeddings2.{}.pred.html'.format(epoch),auto_open=False)
    print('Epoch {}: Val Loss {}, Margin Loss {}, Recon Loss {}, Val R2: {}, Val MAE: {}'.format(epoch,running_loss[0],running_loss[1],running_loss[2],r2_score(Y['true'],Y['pred']), mean_absolute_error(Y['true'],Y['pred'])))
    #Y=pd.DataFrame([])
            
            
            

0
Epoch 0: Train Loss 0.5725887553258375, Train R2: -0.7333907141798888, Train MAE: 2.125
Epoch 0: Val Loss 0.5783395767211914, Margin Loss 0.5714452713727951, Recon Loss 0.006894311518408358, Val R2: -2.8299663299663296, Val MAE: 3.32
1
Epoch 1: Train Loss 0.5145805315537886, Train R2: 0.300931968421721, Train MAE: 1.1647727272727273
Epoch 1: Val Loss 0.5092523768544197, Margin Loss 0.5024370849132538, Recon Loss 0.006815279251895845, Val R2: 0.058291245791245894, Val MAE: 1.38
2
Epoch 2: Train Loss 0.47692854837937787, Train R2: 0.5827942120598548, Train MAE: 0.8068181818181818
Epoch 2: Val Loss 0.4920669347047806, Margin Loss 0.4853948801755905, Recon Loss 0.006672052899375558, Val R2: -0.020622895622895543, Val MAE: 1.64
3
Epoch 3: Train Loss 0.45093053579330444, Train R2: 0.625011464410941, Train MAE: 0.7159090909090909
Epoch 3: Val Loss 0.4829041138291359, Margin Loss 0.47646401077508926, Recon Loss 0.006440094206482172, Val R2: -0.38362794612794593, Val MAE: 1.7
4
Epoch 4: Train

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import Adam
mlp=MLP(len(list(ma.beta)), [500,100], dropout_p=0., n_outputs=n_out_caps, binary=False, softmax=False)
optimizer = Adam(mlp.parameters(),lr=1e-4,weight_decay=1e-3)
scheduler=CosineAnnealingLR(optimizer, T_max=10, eta_min=0, last_epoch=-1)
loss_fn=nn.CrossEntropyLoss()
mlp.train(True)
for epoch in range(n_epochs):
    running_loss=0.
    y_pred_all,y_true_all=[],[]
    for i,batch in enumerate(dataloader):
        x_orig=batch[0]
        #print(x_orig.size())
        #print(batch[-1])
        #print(batch[-1].size())
        y_true=torch.argmax(batch[-1],dim=1).long()
        #print(y_true)
        module_x = batch[1:-1]
        y_pred=mlp(x_orig)
        #print(y_pred.shape)
        y_true_all.append(y_true.flatten().detach().numpy())
        y_pred_all.append(torch.argmax(y_pred,dim=1).flatten().detach().numpy())
        loss=loss_fn(y_pred,y_true)
        train_loss=loss.item()#print(loss)
        running_loss+=train_loss
        #print(train_loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    y_true,y_pred=np.hstack(y_true_all),np.hstack(y_pred_all)
    print(accuracy_score(y_true,y_pred))
    running_loss/=(i+1)
    print('train',running_loss)
    scheduler.step()
    
            
        

In [None]:
y_pred.size()

In [None]:
a=1