In [None]:
##### paper: "Comparing high-dimensional neural recordings by aligning their low-dimensional latent representations" ######
"""
In this paper, the authors propose a method for comparing neural recordings across different animals, 
brain regions, and tasks. The method involves mapping the high-dimensional neural recordings onto a 
low-dimensional latent space using a dimensionality reduction technique called non-negative matrix 
factorization (NMF). The authors then align the latent representations of the neural recordings using 
a technique called Procrustes analysis, which allows for comparisons between neural data from different sources.
"""

In [1]:
####### dropout resnet18 vs without dropout
#### 
import torch
import sys
sys.path.append("..")
import numpy as np

CLEAN_PATH = "/home/yifan/dataset/resnet18_with_dropout/pairflip/cifar10/0/"
REF_PATH = "/home/yifan/dataset/clean/pairflip/cifar10/0"


ENCODER_DIMS=[512,256,256,256,256,2]
DECODER_DIMS= [2,256,256,256,256,512]
VIS_MODEL_NAME = 'vis'

DEVICE='cuda:1'
########## initulize reference data and target data
from alignment.data_preprocess import DataInit
REF_EPOCH = 200
TAR_EPOCH = 200
ref_datainit = DataInit(REF_PATH,REF_PATH,REF_EPOCH)
tar_datainit = DataInit(CLEAN_PATH,CLEAN_PATH,TAR_EPOCH)
ref_model, ref_provider, ref_train_data, ref_prediction, ref_prediction_res, ref_scores = ref_datainit.getData()
tar_model, tar_provider, tar_train_data, tar_prediction, tar_prediction_res, tar_scores = tar_datainit.getData()


from alignment.ReferenceGenerator import ReferenceGenerator
gen = ReferenceGenerator(ref_provider=ref_provider, tar_provider=tar_provider,REF_EPOCH=REF_EPOCH,TAR_EPOCH=TAR_EPOCH,ref_model=ref_model,tar_model=tar_model,DEVICE=DEVICE)

absolute_alignment_indicates,predict_label_diff_indicates,predict_confidence_Diff_indicates,high_distance_indicates = gen.subsetClassify(18,0.8,0.3,0.05)


from representationTrans.trans_visualizer_border import visualizer
from singleVis.SingleVisualizationModel import VisModel
from singleVis.projector import TimeVisProjector
model = VisModel(ENCODER_DIMS, DECODER_DIMS)

I = np.eye(512)
projector = TimeVisProjector(vis_model=model, content_path=REF_PATH, vis_model_name=VIS_MODEL_NAME, device="cpu")
vis = visualizer(ref_provider, I,I, np.dot(ref_provider.train_representation(TAR_EPOCH),I), projector, 200,[0,1],'tab10')


  from .autonotebook import tqdm as notebook_tqdm


NET resnet18
Finish initialization...


100%|██████████| 250/250 [00:00<00:00, 462.56it/s]


NET resnet18_with_dropout
Finish initialization...


100%|██████████| 250/250 [00:00<00:00, 6411.23it/s]
100%|██████████| 250/250 [00:00<00:00, 5697.85it/s]
100%|██████████| 250/250 [00:00<00:00, 5087.26it/s]


absolute alignment indicates number: 106 label diff indicates number: 12 confidence diff indicates number: 16 high distance number: 97


In [26]:
import torch
import numpy as np
from scipy.linalg import orthogonal_procrustes
def nmf(X, k,device,max_iter=100):
    """
    Non-negative matrix factorization using PyTorch.
    :param X: input data matrix of shape (n_samples, n_features)
    :param k: number of latent factors
    :param max_iter: maximum number of iterations
    :return: factors W and H
    """
    n_samples, n_features = X.shape
    X_tensor = torch.tensor(X, dtype=torch.float32, device=device)
    W = torch.randn(n_samples, k, dtype=torch.float32, device=device)
    H = torch.randn(k, n_features, dtype=torch.float32, device=device)
    for i in range(max_iter):
        H *= torch.mm(W.transpose(0, 1), X_tensor) / torch.mm(torch.mm(W.transpose(0, 1), W), H)
        W *= torch.mm(X_tensor, H.transpose(0, 1)) / torch.mm(W, torch.mm(H, H.transpose(0, 1)))
    return W.cpu().numpy(), H.cpu().numpy()
def procrustes(X, Y):
    """
    Procrustes analysis for aligning two datasets.
    :param X: input data matrix of shape (n_samples, n_features)
    :param Y: input data matrix of shape (n_samples, n_features)
    :return: aligned data matrix Y
    """
    # Center the data matrices
    X_centered = X - np.mean(X, axis=0)
    Y_centered = Y - np.mean(Y, axis=0)
    # Perform Procrustes analysis
    mtx1, mtx2, disparity = orthogonal_procrustes(X_centered, Y_centered)
    Y_aligned = np.dot(Y_centered, mtx2)
    return Y_aligned



In [33]:
import torch
import torch.nn as nn
import torch.optim as optim

class NMF(nn.Module):
    def __init__(self, n, m, r):
        super(NMF, self).__init__()
        self.W = nn.Parameter(torch.rand(n, r, requires_grad=True))
        self.H = nn.Parameter(torch.rand(r, m, requires_grad=True))

    def forward(self):
        return torch.matmul(self.W, self.H)


In [43]:
def perform_nmf(V, r, num_epochs, learning_rate):
    n, m = V.shape
    nmf_model = NMF(n, m, r)
    optimizer = optim.Adam(nmf_model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        V_hat = nmf_model()
        loss = torch.sum((V - V_hat)**2)
        loss.backward()
        optimizer.step()
        print("epoch:",epoch,"loss",loss)
    
    return nmf_model.W, nmf_model.H
def procrustes_analysis(X, Y):
    X_mean = torch.mean(X, dim=0)
    Y_mean = torch.mean(Y, dim=0)
    
    X_centered = X - X_mean
    Y_centered = Y - Y_mean
    
    U, S, V = torch.svd(torch.matmul(X_centered.T, Y_centered))
    R = torch.matmul(U, V.T)
    s = torch.sum(S) / torch.sum(X_centered**2)

    return R, s, X_mean, Y_mean

def align_procrustes(X, Y, R, s, X_mean, Y_mean):
    return s * torch.matmul(X - X_mean, R) + Y_mean



In [70]:
X = torch.Tensor(tar_train_data)
Y = torch.Tensor(ref_train_data)
# 计算X和Y的均值和标准差
# Perform NMF on X and Y
W_X, H_X = perform_nmf(X, 10, 200, 0.1)
W_Y, H_Y = perform_nmf(Y, 10, 200, 0.1)

epoch: 0 loss tensor(1.0957e+08, grad_fn=<SumBackward0>)
epoch: 1 loss tensor(42266376., grad_fn=<SumBackward0>)
epoch: 2 loss tensor(19414678., grad_fn=<SumBackward0>)
epoch: 3 loss tensor(15568738., grad_fn=<SumBackward0>)
epoch: 4 loss tensor(17366386., grad_fn=<SumBackward0>)
epoch: 5 loss tensor(19889836., grad_fn=<SumBackward0>)
epoch: 6 loss tensor(21846224., grad_fn=<SumBackward0>)
epoch: 7 loss tensor(23113804., grad_fn=<SumBackward0>)
epoch: 8 loss tensor(23850860., grad_fn=<SumBackward0>)
epoch: 9 loss tensor(24231314., grad_fn=<SumBackward0>)
epoch: 10 loss tensor(24385930., grad_fn=<SumBackward0>)
epoch: 11 loss tensor(24401176., grad_fn=<SumBackward0>)
epoch: 12 loss tensor(24329976., grad_fn=<SumBackward0>)
epoch: 13 loss tensor(24201868., grad_fn=<SumBackward0>)
epoch: 14 loss tensor(24030270., grad_fn=<SumBackward0>)
epoch: 15 loss tensor(23817208., grad_fn=<SumBackward0>)
epoch: 16 loss tensor(23556222., grad_fn=<SumBackward0>)
epoch: 17 loss tensor(23234172., grad_fn

In [135]:
R, s, W_X_mean, W_Y_mean = procrustes_analysis(W_X, W_Y)
W_X_aligned = align_procrustes(W_X, W_Y, R, s, W_X_mean, W_Y_mean)

In [139]:
# distances = pairwise_distances(W_X_aligned, W_Y)
W_X_aligned - W_Y
# degrees = np.sqrt(np.sum((W_X_aligned - W_Y)**2, axis=1))

tensor([[-2.0192e-01,  2.7818e-01,  8.3335e-02,  ..., -4.1830e-01,
         -3.1814e-01,  2.7906e-01],
        [-6.1115e-02, -1.3147e-01,  2.1422e-02,  ...,  5.1416e-02,
         -2.2498e-01,  3.7401e-02],
        [-2.5321e-01,  1.5755e-01,  1.0832e-01,  ...,  2.4374e-01,
          2.4707e-01,  3.2663e-01],
        ...,
        [ 1.4403e-01, -3.5164e-01,  3.0339e-05,  ...,  2.9838e-01,
          3.6142e-01,  2.8500e-02],
        [ 7.9456e-02, -1.6207e-02, -1.8104e-01,  ..., -3.4400e-01,
          2.2521e-01, -4.1308e-02],
        [ 1.5862e-01, -2.3774e-01, -1.6660e-01,  ..., -1.6032e-01,
          3.1075e-01, -2.8092e-01]], grad_fn=<SubBackward0>)

tensor([[0.8112, 2.1760, 1.6504,  ..., 1.8227, 2.1220, 1.7615],
        [2.4204, 0.5849, 0.6242,  ..., 0.7195, 2.4144, 1.4973],
        [2.2690, 1.0002, 0.6661,  ..., 0.8080, 2.5014, 1.6200],
        ...,
        [1.9942, 0.9942, 0.5425,  ..., 0.6734, 2.0870, 1.2926],
        [2.3369, 2.3433, 2.3271,  ..., 2.3177, 0.5497, 1.3969],
        [2.0177, 1.7465, 1.4984,  ..., 1.6875, 1.4423, 0.6278]],
       grad_fn=<SqrtBackward0>)

In [134]:
len(most_similar_indices)

50000

[(0, 0),
 (1, 1),
 (2, 2),
 (3, 3),
 (4, 4),
 (5, 5),
 (6, 6),
 (7, 7),
 (8, 8),
 (9, 9)]

In [147]:
from scipy.spatial import procrustes
X = W_X_aligned.detach().numpy()
Y = W_Y.detach().numpy()
percentile=90
# Perform Procrustes analysis to transform X to Y
mtx1, mtx2, disparity = procrustes(X, Y)
# Compute the transformation degree for each row of the matrix
degree = np.sqrt(np.sum((mtx1 - X)**2, axis=1))

# Find the subset of samples with the highest transformation degree
threshold = np.percentile(degree, percentile)
high_discrepancy_representation_set = np.where(degree > threshold)[0]

threshold_2 = np.percentile(degree, 20)
low_discrepancy_representation_set = np.where(degree < threshold_2)[0]

In [144]:
X = tar_train_data
Y = ref_train_data
percentile=90
# Perform Procrustes analysis to transform X to Y
mtx1, mtx2, disparity = procrustes(X, Y)
# Compute the transformation degree for each row of the matrix
degree = np.sqrt(np.sum((mtx1 - X)**2, axis=1))

# Find the subset of samples with the highest transformation degree
threshold = np.percentile(degree, percentile)
threshold_2 = np.percentile(degree, 20)
high_discrepancy_representation_set = np.where(degree > threshold)[0]
low_discrepancy_representation_set = np.where(degree < threshold_2)[0]

In [104]:
degree

array([16.48912018, 26.55234335, 25.3769415 , ..., 18.35612092,
       31.3011138 , 20.43379256])

In [148]:

print(f"{len(high_distance_indicates)} samples in high prediction MAE set")
in_both_arrays = np.in1d(high_discrepancy_representation_set, high_distance_indicates)

print(f"{np.count_nonzero(in_both_arrays)}/{len(high_distance_indicates)} samples of high prediction MAE are present in high_discrepancy_representation_set.")


print(f"{len(absolute_alignment_indicates)} samples in absolute alignmeny set")
in_both_arrays = np.in1d(high_discrepancy_representation_set, absolute_alignment_indicates)
print(f"{np.count_nonzero(in_both_arrays)}/{len(absolute_alignment_indicates)} samples of absolute_alignment_indicates are present in high_discrepancy_representation_set.")


print(f"{len(absolute_alignment_indicates)} samples in absolute alignmeny set")
in_both_arrays = np.in1d(low_discrepancy_representation_set, absolute_alignment_indicates)
print(f"{np.count_nonzero(in_both_arrays)}/{len(absolute_alignment_indicates)} samples of absolute_alignment_indicates are present in low_discrepancy_representation_set.")


97 samples in high prediction MAE set
32/97 samples of high prediction MAE are present in high_discrepancy_representation_set.
106 samples in absolute alignmeny set
30/106 samples of absolute_alignment_indicates are present in high_discrepancy_representation_set.
106 samples in absolute alignmeny set
11/106 samples of absolute_alignment_indicates are present in low_discrepancy_representation_set.
