In [25]:
import numpy as np
import warnings  
warnings.filterwarnings('ignore', category=FutureWarning)  

from vade_new import VaDE
from utility import create_project_folders, set_random_seed,set_device
from config import config
import torch
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
import torch.nn.functional as F

# Fix the random seed once, up front, so the stochastic steps further down
# are repeatable across kernel restarts.
# NOTE(review): assumes set_random_seed also seeds NumPy's global RNG, which
# the later np.random.multivariate_normal call draws from — confirm in utility.py.
set_random_seed(123)

# Load Data

In [4]:
# --- Run configuration ------------------------------------------------------
device = "cuda:1"
project_tag = '/mnt/sda/gene/zhangym/VADER/VADER/Test_MCREC/0919_test_gene'
memo = "NC_9"
Pretrain_epochs = 100
epochs = 300
batch_size = 128
n_gene = 2000

# --- Reference data ---------------------------------------------------------
# The 2-D arrays are reversed along axis 1 and the wavenumber vector along
# axis 0, so all three share the same ordering along that dimension.
# The reversed slices are views with negative strides, exactly like np.flip;
# downstream code calls .copy() before handing them to torch.
train_data = np.load(r"/mnt/sda/gene/zhangym/VADER/Data/NC_9/X_reference_9.npy")[:, ::-1]
train_label = np.load(r"/mnt/sda/gene/zhangym/VADER/Data/NC_9/y_reference_9.npy").astype(int)
S = np.load(r"/mnt/sda/gene/zhangym/VADER/Data/NC_9/MCR_NC9_S_20.npy")[:, ::-1]
Wavenumber = np.load(r'/mnt/sda/gene/zhangym/VADER/Data/NC_9/wavenumbers.npy')[::-1]

In [5]:
# Hyper-parameters come from the project config; `device` is re-bound to
# whatever set_device returns for the requested device string.
model_params = config.get_model_params()
device = set_device(device)

# train_data is a reversed (negatively strided) view, hence the .copy()
# before it is converted to a tensor.
tensor_data = torch.tensor(train_data.copy(), dtype=torch.float32)
tensor_gpu_data = tensor_data.to(device)
input_dim = tensor_data.shape[1]

project_dir = create_project_folders(project_tag)
weight_scheduler_config = config.get_weight_scheduler_config()

# The number of rows in S is used both as the number of components and as
# the latent dimensionality of the model.
n_component = S.shape[0]

# The same six initial loss weights are passed to the path builder and to
# the model, so collect them once.
init_weights = weight_scheduler_config['init_weights']
lamb_weights = {
    name: init_weights[name]
    for name in ('lamb1', 'lamb2', 'lamb3', 'lamb4', 'lamb5', 'lamb6')
}

paths = config.get_project_paths(
    project_dir,
    n_component,
    **lamb_weights,
    memo=memo,
)
l_c_dim = config.encoder_type(model_params['encoder_type'], paths['train_path'])

# S is also a reversed view; copy it before moving it to the device.
S_on_device = torch.tensor(S.copy()).float().to(device)

model = VaDE(
    input_dim=input_dim,
    intermediate_dim=model_params['intermediate_dim'],
    latent_dim=n_component,
    tensor_gpu_data=tensor_gpu_data,
    n_components=n_component,
    S=S_on_device,
    wavenumber=Wavenumber,
    # prior_y=train_label,
    **lamb_weights,
    device=device,
    l_c_dim=l_c_dim,
    batch_size=batch_size,
    encoder_type=model_params['encoder_type'],
    pretrain_epochs=Pretrain_epochs,
    num_classes=9,
    clustering_method=model_params['clustering_method'],
    resolution_1=model_params['resolution_1'],
    resolution_2=model_params['resolution_2'],
).to(device)

配置文件路径: /mnt/sda/gene/zhangym/VADER/VADER/model_config.yaml
成功复制配置文件到: /mnt/sda/gene/zhangym/VADER/VADER/Test_MCREC/0919_test_gene/NC_9


In [6]:
# Restore the pretrained weights.
# map_location remaps every stored tensor onto the device selected for this
# run; without it torch.load restores tensors onto the device they were saved
# from, which fails (or silently uses another GPU) when that differs.
checkpoint_path = '/mnt/sda/gene/zhangym/VADER/VADER/Test_MCREC/0915_Save_Model/NC_9/txt/Epoch_101_Acc=1.00_model.pth'
state_dict = torch.load(checkpoint_path, map_location=device)
# Kept as the last expression so the key-matching report is still displayed.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [7]:
# Run the whole data set through the restored model and score the cluster
# assignments against the reference labels.
#
# The pass is inference only, so it runs in eval mode and under no_grad:
# this avoids building an autograd graph for the full data set and keeps any
# train-only layers (dropout / batch-norm, if the model has them) from
# perturbing the result. The previous train/eval mode is restored afterwards
# so later cells see the model in the state they expect.
#
# NOTE: this rebinds the notebook-level name `S` (until now the NumPy array
# loaded from disk) to the tensor returned by the model. Downstream code
# reads `S` after this point, so the name is deliberately kept.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        recon_x, mean, gaussian_means, log_var, z, gamma, pi, S = model(
            tensor_gpu_data, labels_batch=None
        )
finally:
    model.train(was_training)

# gamma holds one row of per-component scores per sample; the arg-max over
# axis 1 is taken as the hard cluster assignment.
gmm_probs = gamma.detach().cpu().numpy()
gmm_labels = np.argmax(gmm_probs, axis=1)

print(f'NMI of clustering: {normalized_mutual_info_score(gmm_labels, train_label):.2f}')
print(f'ARI of clustering: {adjusted_rand_score(gmm_labels, train_label):.2f}')

NMI of clustering: 0.93
ARI of clustering: 0.88


In [16]:
# Cut-off that the following cell applies to z_prob.max() before it accepts
# the sample.
threshold = 0.95

# Copy the mixture parameters to host memory as NumPy arrays.
# `gaussian_vars` actually holds the *log*-variances; it is exponentiated
# below before being used as a variance.
gaussian_means = model.gaussian.means.detach().cpu().numpy()
gaussian_vars = model.gaussian.log_variances.detach().cpu().numpy()
print(gaussian_means.shape)

# Draw a single latent vector from component index 1, using a diagonal
# covariance built from that component's variances.
component_mean = gaussian_means[1]
component_cov = np.diag(np.exp(gaussian_vars[1]))
z_sample = np.random.multivariate_normal(component_mean, component_cov, size=(1,))

(9, 20)


In [None]:
# Score the sampled latent vector against the mixture components and, if the
# best component is confident enough, project it through S to obtain new_X.

# np.random.multivariate_normal returns float64; cast to the dtype of the
# mixture parameters so the model is not fed a mismatched precision.
z_tensor = torch.as_tensor(z_sample, dtype=model.gaussian.means.dtype, device=device)
with torch.no_grad():  # scoring only, no gradients needed
    z_prob = F.softmax(model.gaussian.gaussian_log_prob(z_tensor), dim=1)

# `S` is a NumPy array right after loading and a tensor once the inference
# cell has rebound it; accept either so this cell does not depend on which
# cells happened to run before it.
S_matrix = S.detach().cpu().numpy() if torch.is_tensor(S) else np.asarray(S)

# Always (re)bind new_X so a rejected sample can never leave a stale value
# from a previous run, or an undefined name, for the cells that follow.
new_X = None
if z_prob.max() > threshold:
    new_X = np.matmul(z_sample, S_matrix)
else:
    print(f'Sample rejected: max component probability {z_prob.max().item():.3f} <= threshold {threshold}')

In [41]:
# Display new_X, the product of z_sample and S computed in the previous cell
# (only assigned there when the sample passed the threshold check).
new_X

array([[0.41208685, 0.35056194, 0.41051253, 0.36756137, 0.37760283,
        0.3342634 , 0.38360236, 0.41818665, 0.45116785, 0.48587412,
        0.43604001, 0.47136588, 0.47246208, 0.49353732, 0.55020844,
        0.60021903, 0.57161405, 0.60716922, 0.52193601, 0.57898132,
        0.63158334, 0.66658111, 0.63408655, 0.65278609, 0.60279547,
        0.63074008, 0.61113263, 0.65270173, 0.67996595, 0.68120216,
        0.71167336, 0.67247236, 0.66986944, 0.65548125, 0.65923238,
        0.61727108, 0.56473146, 0.58166573, 0.5589365 , 0.57643353,
        0.56819423, 0.52937491, 0.50234396, 0.50851333, 0.58661465,
        0.59423393, 0.58378417, 0.59422266, 0.60192349, 0.588655  ,
        0.58516514, 0.58056725, 0.59438665, 0.5920856 , 0.61476425,
        0.59195444, 0.57951411, 0.58453893, 0.59335993, 0.56464607,
        0.56163252, 0.54012442, 0.5552243 , 0.56043106, 0.56830123,
        0.53387605, 0.54477511, 0.52177664, 0.55319792, 0.52494041,
        0.54892706, 0.54613563, 0.52701768, 0.52