In [1]:
import numpy as np
import scipy.spatial.distance as sd
from neighborhood import neighbor_graph, laplacian
from correspondence import Correspondence
from stiefel import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from datareader import *
import pandas as pd 
import os.path
import pdb
cuda = torch.device('cuda') 
import scipy as sp
from collections import Counter
import seaborn as sns
from random import sample
import random
from sklearn import preprocessing
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
"""Defines the neural network"""

class Net(nn.Module):
    def __init__(self, D_in, H1, H2, D_out):
        super(Net, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H1)
        self.linear2 = torch.nn.Linear(H1, H2)
        self.linear3 = torch.nn.Linear(H2, D_out)

    def forward(self, x):
        h1_sigmoid = self.linear1(x).sigmoid()
        h2_sigmoid = self.linear2(h1_sigmoid).sigmoid()
        y_pred = self.linear3(h2_sigmoid)
        return y_pred

In [3]:
def train_and_project(x1_np, x2_np):
    
    torch.manual_seed(0)

    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H1, H2, D_out = x1_np.shape[0], x1_np.shape[1], 512, 64, 10

    model = Net(D_in, H1, H2, D_out)

    x1 = torch.from_numpy(x1_np.astype(np.float32))
    x2 = torch.from_numpy(x2_np.astype(np.float32))
    print(x1.dtype)
    
    adj1 = neighbor_graph(x1_np, k=5)
    adj2 = neighbor_graph(x2_np, k=5)

    #corr = Correspondence(matrix=np.eye(N))

    w1 = np.corrcoef(x1, x2)[0:x1.shape[0],x1.shape[0]:(x1.shape[0]+x2.shape[0])]
    w1[abs(w1) > 0.5] = 1
    w1[w1 != 1] = 0
    w = np.block([[w1,adj1],
                  [adj2,w1.T]])

    L_np = laplacian(w, normed=False)
    L = torch.from_numpy(L_np.astype(np.float32))
    
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.00001)
    
    for t in range(500):
        # Forward pass: Compute predicted y by passing x to the model
        y1_pred = model(x1)
        y2_pred = model(x2)

        outputs = torch.cat((y1_pred, y2_pred), 0)
        
        # Project the output onto Stiefel Manifold
        u, s, v = torch.svd(outputs, some=True)
        proj_outputs = u@v.t()

        # Compute and print loss
        print(L.dtype)
        loss = torch.trace(proj_outputs.t()@L@proj_outputs)
        print(t, loss.item())

        # Zero gradients, perform a backward pass, and update the weights.
        proj_outputs.retain_grad()

        optimizer.zero_grad()
        loss.backward(retain_graph=True)

        # Project the (Euclidean) gradient onto the tangent space of Stiefel Manifold (to get Rimannian gradient)
        rgrad = proj_stiefel(proj_outputs, proj_outputs.grad) 

        optimizer.zero_grad()
        # Backpropogate the Rimannian gradient w.r.t proj_outputs
        proj_outputs.backward(rgrad)

        optimizer.step()
        
    proj_outputs_np = proj_outputs.detach().numpy()
    return proj_outputs_np

In [4]:
Efeature = pd.read_csv('../data/efeature_filtered.csv',index_col=0)
geneExp = pd.read_csv('../data/expMat_filtered.csv',index_col=0)
label = pd.read_csv('../data/label_visual.csv')
print('Shape of geneExp: ', geneExp.shape)
print('Shape of Efeature: ', Efeature.shape)

#x1_np = preprocessing.scale(np.log(geneExp+1).to_numpy())
#x2_np = preprocessing.scale(Efeature.T.to_numpy())

x1_np = np.log(geneExp+1).to_numpy()
x2_np = preprocessing.scale(Efeature.T.to_numpy())

print(x1_np.shape)
print(x2_np.shape)

Shape of geneExp:  (1000, 3654)
Shape of Efeature:  (3654, 41)
(1000, 3654)
(41, 3654)


In [5]:
projections = train_and_project(x1_np, x2_np)

torch.float32
torch.float32
0 76.15230560302734
torch.float32
1 74.04776000976562
torch.float32
2 72.15640258789062
torch.float32
3 70.50237274169922
torch.float32
4 69.06790924072266
torch.float32
5 67.79653930664062
torch.float32
6 66.63261413574219
torch.float32
7 65.54585266113281
torch.float32
8 64.52932739257812
torch.float32
9 63.584930419921875
torch.float32
10 62.70964431762695
torch.float32
11 61.8917236328125
torch.float32
12 61.11663818359375
torch.float32
13 60.37461853027344
torch.float32
14 59.662174224853516
torch.float32
15 58.97975158691406
torch.float32
16 58.32814025878906
torch.float32
17 57.706607818603516
torch.float32
18 57.112762451171875
torch.float32
19 56.5430908203125
torch.float32
20 55.99395751953125
torch.float32
21 55.46256637573242
torch.float32
22 54.947120666503906
torch.float32
23 54.446781158447266
torch.float32
24 53.96090316772461
torch.float32
25 53.4884147644043
torch.float32
26 53.02803039550781
torch.float32
27 52.57858657836914
torch.float32

torch.float32
229 18.07370376586914
torch.float32
230 18.024158477783203
torch.float32
231 17.975345611572266
torch.float32
232 17.927244186401367
torch.float32
233 17.87984848022461
torch.float32
234 17.833160400390625
torch.float32
235 17.787141799926758
torch.float32
236 17.741823196411133
torch.float32
237 17.697160720825195
torch.float32
238 17.653154373168945
torch.float32
239 17.609806060791016
torch.float32
240 17.567089080810547
torch.float32
241 17.525026321411133
torch.float32
242 17.483558654785156
torch.float32
243 17.442720413208008
torch.float32
244 17.402477264404297
torch.float32
245 17.36283302307129
torch.float32
246 17.323768615722656
torch.float32
247 17.285293579101562
torch.float32
248 17.24738883972168
torch.float32
249 17.210023880004883
torch.float32
250 17.17322540283203
torch.float32
251 17.1369686126709
torch.float32
252 17.101255416870117
torch.float32
253 17.066070556640625
torch.float32
254 17.031389236450195
torch.float32
255 16.99723243713379
torch.flo

torch.float32
453 14.66789436340332
torch.float32
454 14.663707733154297
torch.float32
455 14.660327911376953
torch.float32
456 14.657746315002441
torch.float32
457 14.654730796813965
torch.float32
458 14.651023864746094
torch.float32
459 14.647530555725098
torch.float32
460 14.644716262817383
torch.float32
461 14.6419095993042
torch.float32
462 14.638609886169434
torch.float32
463 14.635212898254395
torch.float32
464 14.632189750671387
torch.float32
465 14.629429817199707
torch.float32
466 14.626461029052734
torch.float32
467 14.623247146606445
torch.float32
468 14.620165824890137
torch.float32
469 14.617340087890625
torch.float32
470 14.614527702331543
torch.float32
471 14.611571311950684
torch.float32
472 14.60853099822998
torch.float32
473 14.605644226074219
torch.float32
474 14.602882385253906
torch.float32
475 14.600107192993164
torch.float32
476 14.597232818603516
torch.float32
477 14.594350814819336
torch.float32
478 14.591574668884277
torch.float32
479 14.588873863220215
torch

In [6]:
projections.shape

(1041, 10)

In [7]:
projections = pd.DataFrame(projections)
features = geneExp.index.tolist()+Efeature.columns.tolist()
projections.index = features
projections

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
Adarb2,-0.006427,0.003653,0.011276,0.008796,-0.002035,0.010519,0.011852,-0.024288,0.004697,0.020238
Sst,0.003257,0.010013,0.011801,0.022704,-0.011400,-0.004367,0.009136,-0.022721,0.004323,0.009220
Vip,-0.010553,0.004490,0.009969,0.012553,-0.004974,0.007019,0.003081,-0.026217,0.006336,0.014434
Npy,-0.003612,0.005848,-0.000513,0.026711,-0.004672,-0.005764,0.004483,-0.021566,0.004539,0.007751
Synpr,0.004285,0.026262,0.018568,0.029322,-0.009616,-0.021698,0.002789,-0.030245,0.000830,-0.002280
...,...,...,...,...,...,...,...,...,...,...
fast_trough_v_short_square,0.002473,0.003269,0.002758,0.019747,-0.005012,0.009959,-0.005279,-0.011582,0.001191,0.017237
fast_trough_t_short_square,0.003247,0.004186,0.000356,0.017510,-0.004141,0.009430,-0.003541,-0.015877,-0.001935,0.017557
threshold_v_short_square,0.001886,0.002282,-0.000514,0.018300,-0.004602,0.011816,-0.005785,-0.011797,0.001866,0.018193
threshold_i_short_square,-0.009573,0.020358,0.010290,0.022886,-0.001082,0.015415,0.000827,-0.024236,-0.000262,0.005419


In [8]:
projections.to_csv("../data/deepmanreg_latent.csv")