In [6]:
import numpy as np
import scipy.spatial.distance as sd
from neighborhood import neighbor_graph, laplacian
from correspondence import Correspondence
from stiefel import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from datareader import *
import pandas as pd 
import os.path
import pdb
cuda = torch.device('cuda') 
import scipy as sp
from collections import Counter
import seaborn as sns
from random import sample
import random
from sklearn import preprocessing
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
"""Defines the neural network"""

class Net(nn.Module):
    def __init__(self, D_in, H1, H2, D_out):
        super(Net, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H1)
        self.linear2 = torch.nn.Linear(H1, H2)
        self.linear3 = torch.nn.Linear(H2, D_out)

    def forward(self, x):
        h1_sigmoid = self.linear1(x).sigmoid()
        h2_sigmoid = self.linear2(h1_sigmoid).sigmoid()
        y_pred = self.linear3(h2_sigmoid)
        return y_pred

In [22]:
def train_and_project(x1_np, x2_np):
    
    torch.manual_seed(0)

    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H1, H2, D_out = x1_np.shape[0], x1_np.shape[1], 512, 64, 3

    model = Net(D_in, H1, H2, D_out)

    x1 = torch.from_numpy(x1_np.astype(np.float32))
    x2 = torch.from_numpy(x2_np.astype(np.float32))
    print(x1.dtype)
    
    adj1 = neighbor_graph(x1_np, k=5)
    adj2 = neighbor_graph(x2_np, k=5)

    #corr = Correspondence(matrix=np.eye(N))

    w1 = np.corrcoef(x1, x2)[0:x1.shape[0],x1.shape[0]:(x1.shape[0]+x2.shape[0])]
    w1[abs(w1) > 0.5] = 1
    w1[w1 != 1] = 0
    w = np.block([[w1,adj1],
                  [adj2,w1.T]])

    L_np = laplacian(w, normed=False)
    L = torch.from_numpy(L_np.astype(np.float32))
    
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.00001)
    
    for t in range(500):
        # Forward pass: Compute predicted y by passing x to the model
        y1_pred = model(x1)
        y2_pred = model(x2)

        outputs = torch.cat((y1_pred, y2_pred), 0)
        
        # Project the output onto Stiefel Manifold
        u, s, v = torch.svd(outputs, some=True)
        proj_outputs = u@v.t()

        # Compute and print loss
        print(L.dtype)
        loss = torch.trace(proj_outputs.t()@L@proj_outputs)
        print(t, loss.item())

        # Zero gradients, perform a backward pass, and update the weights.
        proj_outputs.retain_grad()

        optimizer.zero_grad()
        loss.backward(retain_graph=True)

        # Project the (Euclidean) gradient onto the tangent space of Stiefel Manifold (to get Rimannian gradient)
        rgrad = proj_stiefel(proj_outputs, proj_outputs.grad) 

        optimizer.zero_grad()
        # Backpropogate the Rimannian gradient w.r.t proj_outputs
        proj_outputs.backward(rgrad)

        optimizer.step()
        
    proj_outputs_np = proj_outputs.detach().numpy()
    return proj_outputs_np

In [34]:
Efeature = pd.read_csv('../data/efeature_filtered.csv',index_col=0)
geneExp = pd.read_csv('../data/expMat_filtered.csv',index_col=0)
label = pd.read_csv('../data/label_visual.csv')
print('Shape of geneExp: ', geneExp.shape)
print('Shape of Efeature: ', Efeature.shape)

#x1_np = preprocessing.scale(np.log(geneExp+1).to_numpy())
#x2_np = preprocessing.scale(Efeature.T.to_numpy())

x1_np = np.log(geneExp+1).to_numpy()
x2_np = preprocessing.scale(Efeature.T.to_numpy())

print(x1_np.shape)
print(x2_np.shape)

Shape of geneExp:  (1000, 3654)
Shape of Efeature:  (3654, 41)
(1000, 3654)
(41, 3654)


In [35]:
projections = train_and_project(x1_np, x2_np)

torch.float32
torch.float32
0 16.509029388427734
torch.float32
1 15.672364234924316
torch.float32
2 15.045405387878418
torch.float32
3 14.495298385620117
torch.float32
4 14.020167350769043
torch.float32
5 13.609445571899414
torch.float32
6 13.233115196228027
torch.float32
7 12.873891830444336
torch.float32
8 12.531214714050293
torch.float32
9 12.20956039428711
torch.float32
10 11.910453796386719
torch.float32
11 11.631485939025879
torch.float32
12 11.369108200073242
torch.float32
13 11.120898246765137
torch.float32
14 10.885858535766602
torch.float32
15 10.663627624511719
torch.float32
16 10.453712463378906
torch.float32
17 10.255171775817871
torch.float32
18 10.066837310791016
torch.float32
19 9.887600898742676
torch.float32
20 9.716368675231934
torch.float32
21 9.551756858825684
torch.float32
22 9.392093658447266
torch.float32
23 9.236048698425293
torch.float32
24 9.083226203918457
torch.float32
25 8.934130668640137
torch.float32
26 8.789505958557129
torch.float32
27 8.64972877502441

torch.float32
225 -0.4464569687843323
torch.float32
226 -0.4443609118461609
torch.float32
227 -0.4477754235267639
torch.float32
228 -0.4624914824962616
torch.float32
229 -0.4734916090965271
torch.float32
230 -0.4742472469806671
torch.float32
231 -0.4774527847766876
torch.float32
232 -0.48938411474227905
torch.float32
233 -0.4968068301677704
torch.float32
234 -0.4977031946182251
torch.float32
235 -0.5003952980041504
torch.float32
236 -0.5090774297714233
torch.float32
237 -0.5144522190093994
torch.float32
238 -0.516864538192749
torch.float32
239 -0.5205973982810974
torch.float32
240 -0.529538631439209
torch.float32
241 -0.5342814922332764
torch.float32
242 -0.5388941168785095
torch.float32
243 -0.5467015504837036
torch.float32
244 -0.5560463666915894
torch.float32
245 -0.5613645911216736
torch.float32
246 -0.5651077032089233
torch.float32
247 -0.5706303715705872
torch.float32
248 -0.5766957998275757
torch.float32
249 -0.580296516418457
torch.float32
250 -0.5828105807304382
torch.float32


In [36]:
projections.shape

(1041, 3)

In [37]:
projections = pd.DataFrame(projections)
features = geneExp.index.tolist()+Efeature.columns.tolist()
projections.index = features
projections

Unnamed: 0,0,1,2
Adarb2,-0.017885,0.012411,-0.000088
Sst,-0.011594,0.012967,0.000829
Vip,-0.019374,0.015498,-0.000188
Npy,-0.009411,0.011734,-0.009520
Synpr,-0.022196,0.018100,-0.003972
...,...,...,...
fast_trough_v_short_square,0.003909,0.022146,-0.000598
fast_trough_t_short_square,0.000611,0.017694,-0.007263
threshold_v_short_square,0.003939,0.019455,-0.004537
threshold_i_short_square,-0.017429,0.013782,-0.009230


In [38]:
projections.to_csv("../data/deepmanreg_latent.csv")